Commit dc96cff

Adding LIT App server for LIT LLMs on GCP

RyanMullins committed Sep 17, 2024
1 parent 20fc14d
Showing 5 changed files with 112 additions and 46 deletions.
59 changes: 54 additions & 5 deletions lit_nlp/examples/gcp/Dockerfile
@@ -15,7 +15,23 @@
# Use the official lightweight Python image.
# https://hub.docker.com/_/python

FROM nvidia/cuda:12.5.1-base-ubuntu22.04 AS base
# ---- LIT on GCP Base Images ----

FROM python:3.11-slim AS lit-gcp-app-server-base

# Update Ubuntu packages and install basic utils
RUN apt-get update
RUN apt-get install -y wget curl gnupg2 gcc g++ git

# Copy local code to the container image.
ENV APP_HOME /app
WORKDIR $APP_HOME

COPY ./lit_nlp/examples/gcp/server_gunicorn_config.py ./gunicorn_config.py



FROM nvidia/cuda:12.5.1-base-ubuntu22.04 AS lit-gcp-model-server-base
ENV DEBIAN_FRONTEND=noninteractive
ENV LANG C.UTF-8

@@ -71,9 +87,42 @@ RUN rm -rf /var/lib/apt/lists/*



# ---- LIT on GCP from source ----
# ---- LIT on GCP Development Images ----

FROM lit-gcp-app-server-base AS lit-gcp-app-server-dev

# Install yarn
RUN curl -sS https://dl.yarnpkg.com/debian/pubkey.gpg | apt-key add -
RUN echo "deb https://dl.yarnpkg.com/debian/ stable main" | \
tee /etc/apt/sources.list.d/yarn.list
RUN apt update && apt -y install yarn

# Set up python environment with production dependencies
# This step is slow as it installs many packages.
COPY requirements_core.txt ./
COPY lit_nlp/examples/prompt_debugging/requirements.txt \
lit_nlp/examples/prompt_debugging/requirements.txt
COPY lit_nlp/examples/gcp/requirements.txt \
lit_nlp/examples/gcp/requirements.txt
RUN python -m pip install -r lit_nlp/examples/gcp/requirements.txt

# Copy the rest of the lit_nlp package
COPY . ./

# Build front-end with yarn
WORKDIR $APP_HOME/lit_nlp/client
ENV NODE_OPTIONS "--openssl-legacy-provider"
RUN yarn && yarn build && rm -rf node_modules/*

# Run LIT server
# Note that the config file supports configuring the LIT demo that is launched
# via the DEMO_NAME and DEMO_PORT environment variables.
WORKDIR $APP_HOME
ENTRYPOINT ["gunicorn", "--config=gunicorn_config.py"]
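For illustration, a gunicorn config that honors those variables could look roughly like the sketch below. This is a hypothetical sketch, not the shipped config: the server_gunicorn_config.py added later in this diff only reads PORT, and the DEMO_NAME/DEMO_PORT handling here is an assumption.

import os

# Hypothetical: pick the demo module and port from the environment, then
# point gunicorn at that module's WSGI app factory.
_DEMO_NAME = os.getenv('DEMO_NAME', 'lit_nlp.examples.gcp.server')  # Assumed default.
_DEMO_PORT = os.getenv('DEMO_PORT', '5432')

bind = f'0.0.0.0:{_DEMO_PORT}'
wsgi_app = f'{_DEMO_NAME}:get_wsgi_app()'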



FROM base AS lit-gcp-model-server-dev
FROM lit-gcp-model-server-base AS lit-gcp-model-server-dev
ENV APP_HOME /app
WORKDIR $APP_HOME

@@ -89,15 +138,15 @@ RUN echo "deb https://dl.yarnpkg.com/debian/ stable main" | \
tee /etc/apt/sources.list.d/yarn.list
RUN apt update && apt -y install yarn

# TODO(b/353980272): Replace the default config with the GCP-specific config
COPY ./lit_nlp/examples/gcp/model_server_gunicorn_config.py ./

# TODO(b/353980272): Replace this with a requirements file specific to the GCP
# example; this should include the core lit-nlp package.
COPY requirements_core.txt ./
COPY lit_nlp/examples/prompt_debugging/requirements.txt \
lit_nlp/examples/prompt_debugging/requirements.txt
COPY lit_nlp/examples/gcp/requirements.txt lit_nlp/examples/gcp/requirements.txt
COPY lit_nlp/examples/gcp/requirements.txt \
lit_nlp/examples/gcp/requirements.txt
RUN python -m pip install -r lit_nlp/examples/gcp/requirements.txt

# Copy the rest of the lit_nlp package
38 changes: 23 additions & 15 deletions lit_nlp/examples/gcp/model.py
@@ -1,10 +1,8 @@
"""Wrapper for connetecting to LLMs on GCP via the model_server HTTP API."""

import enum

from lit_nlp import app as lit_app
from lit_nlp.api import model as lit_model
from lit_nlp.api import types as lit_types
from lit_nlp.api.types import Spec
from lit_nlp.examples.gcp import constants as lit_gcp_constants
from lit_nlp.examples.prompt_debugging import constants as pd_constants
from lit_nlp.examples.prompt_debugging import utils as pd_utils
@@ -14,18 +12,19 @@
"""
Plan for this module:
From GitHub:
* Rebase to include cl/672527408 and the CL described above
* Define an enum to track HTTP endpoints across Python modules
* Adopt HTTP endpoint enum across model_server.py and LlmOverHTTP
* Adopt model_specs.py in LlmOverHTTP, using HTTP endpoint enum for
conditional additions
"""

_LlmHTTPEndpoints = lit_gcp_constants.LlmHTTPEndpoints

LLM_ON_GCP_INIT_SPEC: lit_types.Spec = {
# Note that `new_name` is not actually passed to LlmOverHTTP, but the
# `/create_model` API validates configs that include a `new_name`, so the
# spec must declare it.
'new_name': lit_types.String(required=False),
'base_url': lit_types.String(),
'max_concurrent_requests': lit_types.Integer(default=1),
'max_qps': lit_types.Scalar(default=25),
}


class LlmOverHTTP(lit_model.BatchedRemoteModel):

@@ -84,10 +83,10 @@ def predict_minibatch(


def initialize_model_group_for_salience(
name: str, base_url: str, *args, **kw
) -> dict[str, lit_model.Model]:
new_name: str, base_url: str, *args, **kw
) -> lit_model.ModelMap:
"""Creates '{name}' and '_{name}_salience' and '_{name}_tokenizer'."""
salience_name, tokenizer_name = pd_utils.generate_model_group_names(name)
salience_name, tokenizer_name = pd_utils.generate_model_group_names(new_name)

generation_model = LlmOverHTTP(
*args, base_url=base_url, endpoint=_LlmHTTPEndpoints.GENERATE, **kw
@@ -100,7 +99,16 @@ def initialize_model_group_for_salience(
)

return {
name: generation_model,
new_name: generation_model,
salience_name: salience_model,
tokenizer_name: tokenizer_model,
}


def get_model_loaders() -> lit_app.ModelLoadersMap:
return {
'LLM Over HTTP': (
initialize_model_group_for_salience,
LLM_ON_GCP_INIT_SPEC
)
}
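As a usage sketch, each entry in the loader map pairs a factory with the init spec used to validate user-supplied arguments; the LIT app calls the factory to build the three-model group (the model name and URL below are hypothetical):

loaders = get_model_loaders()
init_fn, init_spec = loaders['LLM Over HTTP']
# Builds the generation model plus its '_..._salience' and '_..._tokenizer'
# companions, keyed as described in initialize_model_group_for_salience.
models = init_fn(new_name='my_llm', base_url='http://localhost:5432')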
13 changes: 7 additions & 6 deletions lit_nlp/examples/gcp/model_server_test.py
@@ -3,6 +3,7 @@

from absl.testing import absltest
from absl.testing import parameterized
from lit_nlp.examples.gcp import constants as lit_gcp_constants
from lit_nlp.examples.gcp import model_server
from lit_nlp.examples.prompt_debugging import utils as pd_utils
import webtest
@@ -41,22 +42,22 @@ def setUpClass(cls):

@parameterized.named_parameters(
dict(
testcase_name='predict',
endpoint='/predict',
testcase_name=lit_gcp_constants.LlmHTTPEndpoints.GENERATE.value,
endpoint=f'/{lit_gcp_constants.LlmHTTPEndpoints.GENERATE.value}',
expected=[{'response': 'test output text'}],
),
dict(
testcase_name='salience',
endpoint='/salience',
testcase_name=lit_gcp_constants.LlmHTTPEndpoints.SALIENCE.value,
endpoint=f'/{lit_gcp_constants.LlmHTTPEndpoints.SALIENCE.value}',
expected=[{
'tokens': ['test', 'output', 'text'],
'grad_l2': [0.1234, 0.3456, 0.5678],
'grad_dot_input': [0.1234, -0.3456, 0.5678],
}],
),
dict(
testcase_name='tokenize',
endpoint='/tokenize',
testcase_name=lit_gcp_constants.LlmHTTPEndpoints.TOKENIZE.value,
endpoint=f'/{lit_gcp_constants.LlmHTTPEndpoints.TOKENIZE.value}',
expected=[{'tokens': ['test', 'output', 'text']}],
),
)
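The elided setUpClass presumably wires the model server's WSGI app into a webtest client; a minimal sketch of that pattern, assuming model_server exposes a get_wsgi_app() factory like server.py does:

@classmethod
def setUpClass(cls):
  super().setUpClass()
  # Assumed factory name; the actual setup code is truncated in this view.
  cls.app = webtest.TestApp(model_server.get_wsgi_app())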
23 changes: 3 additions & 20 deletions lit_nlp/examples/gcp/server.py
@@ -1,6 +1,6 @@
"""Server for sequence salience with a left-to-right language model."""

from collections.abc import Mapping, Sequence
from collections.abc import Sequence
import sys
from typing import Optional

@@ -9,8 +9,6 @@
from absl import logging
from lit_nlp import dev_server
from lit_nlp import server_flags
from lit_nlp.api import model as lit_model
from lit_nlp.api import types as lit_types
from lit_nlp.examples.gcp import model as lit_gcp_model
from lit_nlp.examples.prompt_debugging import datasets as pd_datasets
from lit_nlp.examples.prompt_debugging import layouts as pd_layouts
@@ -29,18 +27,10 @@
"""


def init_llm_on_gcp(
name: str, base_url: str, *args, **kw
) -> Mapping[str, lit_model.Model]:
return lit_gcp_model.initialize_model_group_for_salience(
name=name, base_url=base_url, *args, **kw
)


def get_wsgi_app() -> Optional[dev_server.LitServerType]:
"""Return WSGI app for container-hosted demos."""
_FLAGS.set_default("server_type", "external")
_FLAGS.set_default("demo_mode", True)
_FLAGS.set_default("demo_mode", False)
_FLAGS.set_default("page_title", "LM Prompt Debugging")
_FLAGS.set_default("default_layout", pd_layouts.THREE_PANEL)
# Parse flags without calling app.run(main), to avoid conflict with
@@ -59,14 +49,7 @@ def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]:
models={},
datasets={},
layouts=pd_layouts.PROMPT_DEBUGGING_LAYOUTS,
model_loaders={
'LLM on GCP': (init_llm_on_gcp, {
'name': lit_types.String(),
'base_url': lit_types.String(),
'max_concurrent_requests': lit_types.Integer(default=1),
'max_qps': lit_types.Scalar(default=25),
})
},
model_loaders=lit_gcp_model.get_model_loaders(),
dataset_loaders=pd_datasets.get_dataset_loaders(),
onboard_start_doc=_SPLASH_SCREEN_DOC,
**server_flags.get_flags(),
25 changes: 25 additions & 0 deletions lit_nlp/examples/gcp/server_gunicorn_config.py
@@ -0,0 +1,25 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""gunicorn configuration for cloud-hosted demos."""

import os

_PORT = os.getenv('PORT', '5432')

bind = f'0.0.0.0:{_PORT}'
timeout = 3600
threads = 8
worker_class = 'gthread'
wsgi_app = 'lit_nlp.examples.gcp.server:get_wsgi_app()'
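Gunicorn resolves the wsgi_app setting by importing lit_nlp.examples.gcp.server and calling its get_wsgi_app() factory; the same factory can be called directly for a quick local check:

from lit_nlp.examples.gcp import server

app = server.get_wsgi_app()  # The same WSGI app gunicorn serves.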
