
Commit

Adds LIT app server code
RyanMullins committed Sep 11, 2024
1 parent 14e700b commit cf1b3e6
Showing 4 changed files with 196 additions and 3 deletions.
6 changes: 6 additions & 0 deletions lit_nlp/examples/gcp/constants.py
@@ -0,0 +1,6 @@
import enum

class LlmHTTPEndpoints(enum.Enum):
  """HTTP endpoint paths exposed by the LIT GCP model_server."""

  GENERATE = 'predict'
  SALIENCE = 'salience'
  TOKENIZE = 'tokenize'
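
A minimal usage sketch (not part of the diff) of how the enum coerces a raw endpoint string into a member whose value is the URL path segment; this is exactly the coercion LlmOverHTTP in model.py relies on:

from lit_nlp.examples.gcp.constants import LlmHTTPEndpoints

endpoint = LlmHTTPEndpoints('predict')  # coerce a raw string to a member
assert endpoint is LlmHTTPEndpoints.GENERATE
assert endpoint.value == 'predict'      # the path segment appended to the base URL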
106 changes: 106 additions & 0 deletions lit_nlp/examples/gcp/model.py
@@ -0,0 +1,106 @@
"""Wrapper for connetecting to LLMs on GCP via the model_server HTTP API."""

import enum

from lit_nlp.api import model as lit_model
from lit_nlp.api import types as lit_types
from lit_nlp.api.types import Spec
from lit_nlp.examples.gcp import constants as lit_gcp_constants
from lit_nlp.examples.prompt_debugging import constants as pd_constants
from lit_nlp.examples.prompt_debugging import utils as pd_utils
from lit_nlp.lib import serialize
import requests

"""
Plan for this module:
From GitHub:
* Rebase to include cl/672527408 and the CL described above
* Define an enum to track HTTP endpoints across Python modules
* Adopt HTTP endpoint enum across model_server.py and LlmOverHTTP
* Adopt model_specs.py in LlmOverHTTP, using HTTP endpoint enum for
conditional additions
"""

_LlmHTTPEndpoints = lit_gcp_constants.LlmHTTPEndpoints


class LlmOverHTTP(lit_model.BatchedRemoteModel):
  """Queries an LLM hosted behind the model_server HTTP API."""

  def __init__(
      self,
      base_url: str,
      endpoint: str | _LlmHTTPEndpoints,
      max_concurrent_requests: int = 4,
      max_qps: int | float = 25
  ):
    super().__init__(max_concurrent_requests, max_qps)
    self.endpoint = _LlmHTTPEndpoints(endpoint)
    self.url = f'{base_url}/{self.endpoint.value}'

  def input_spec(self) -> lit_types.Spec:
    # Copy so the shared INPUT_SPEC constant is not mutated by |= below.
    input_spec = dict(pd_constants.INPUT_SPEC)

    if self.endpoint == _LlmHTTPEndpoints.SALIENCE:
      input_spec |= pd_constants.INPUT_SPEC_SALIENCE

    return input_spec

  def output_spec(self) -> lit_types.Spec:
    if self.endpoint == _LlmHTTPEndpoints.GENERATE:
      return (
          pd_constants.OUTPUT_SPEC_GENERATION
          | pd_constants.OUTPUT_SPEC_GENERATION_EMBEDDINGS
      )
    elif self.endpoint == _LlmHTTPEndpoints.SALIENCE:
      return pd_constants.OUTPUT_SPEC_SALIENCE
    else:
      return pd_constants.OUTPUT_SPEC_TOKENIZER

  def predict_minibatch(
      self, inputs: list[lit_types.JsonDict]
  ) -> list[lit_types.JsonDict]:
    """Run prediction on a batch of inputs.

    Args:
      inputs: sequence of inputs, following model.input_spec()

    Returns:
      list of outputs, following model.output_spec()
    """
    response = requests.post(
        self.url, data=serialize.to_json(list(inputs), simple=True)
    )

    if not (200 <= response.status_code < 300):
      raise RuntimeError(
          f'Request to {self.url} failed with status {response.status_code}.'
      )

    outputs = serialize.from_json(response.text)
    return outputs


def initialize_model_group_for_salience(
    name: str, base_url: str, *args, **kw
) -> dict[str, lit_model.Model]:
  """Creates '{name}' and '_{name}_salience' and '_{name}_tokenizer'."""
  salience_name, tokenizer_name = pd_utils.generate_model_group_names(name)

  generation_model = LlmOverHTTP(
      *args, base_url=base_url, endpoint=_LlmHTTPEndpoints.GENERATE, **kw
  )
  salience_model = LlmOverHTTP(
      *args, base_url=base_url, endpoint=_LlmHTTPEndpoints.SALIENCE, **kw
  )
  tokenizer_model = LlmOverHTTP(
      *args, base_url=base_url, endpoint=_LlmHTTPEndpoints.TOKENIZE, **kw
  )

  return {
      name: generation_model,
      salience_name: salience_model,
      tokenizer_name: tokenizer_model,
  }
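
A usage sketch, not part of this commit, of how the helper above builds a generation/salience/tokenizer group; the model name and base URL are placeholders for a running model_server:

from lit_nlp.examples.gcp import model as lit_gcp_model

models = lit_gcp_model.initialize_model_group_for_salience(
    'gemma', 'http://localhost:5432'  # placeholder name and address
)
# Per the docstring, the dict is keyed 'gemma', '_gemma_salience' and
# '_gemma_tokenizer'; each LlmOverHTTP posts to base_url + '/predict',
# '/salience' or '/tokenize' respectively.
for lit_name, wrapper in models.items():
  print(lit_name, wrapper.url)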
9 changes: 6 additions & 3 deletions lit_nlp/examples/gcp/model_server.py
@@ -7,6 +7,7 @@

from absl import app
from lit_nlp import dev_server
from lit_nlp.examples.gcp import constants as lit_gcp_constants
from lit_nlp.examples.prompt_debugging import models as pd_models
from lit_nlp.examples.prompt_debugging import utils as pd_utils
from lit_nlp.lib import serialize
@@ -19,6 +20,8 @@
DEFAULT_BATCH_SIZE = 1
DEFAULT_MODELS = 'gemma_1.1_2b_IT:gemma_1.1_instruct_2b_en'

_LlmHTTPEndpoints = lit_gcp_constants.LlmHTTPEndpoints


def get_wsgi_app() -> wsgi_app.App:
"""Return WSGI app for an LLM server."""
@@ -60,9 +63,9 @@ def _handler(app: wsgi_app.App, request, unused_environ):
  sal_name, tok_name = pd_utils.generate_model_group_names(gen_name)

  handlers = {
-      '/predict': models[gen_name].predict,
-      '/salience': models[sal_name].predict,
-      '/tokenize': models[tok_name].predict,
+      f'/{_LlmHTTPEndpoints.GENERATE.value}': models[gen_name].predict,
+      f'/{_LlmHTTPEndpoints.SALIENCE.value}': models[sal_name].predict,
+      f'/{_LlmHTTPEndpoints.TOKENIZE.value}': models[tok_name].predict,
  }

  wrapped_handlers = {
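
For orientation (not part of the diff), a client-side sketch of the request these handlers expect, mirroring LlmOverHTTP.predict_minibatch above; the host, port, and the 'prompt' field name are assumptions rather than values from this commit:

import requests

from lit_nlp.examples.gcp import constants as lit_gcp_constants
from lit_nlp.lib import serialize

url = 'http://localhost:5432/' + lit_gcp_constants.LlmHTTPEndpoints.GENERATE.value
inputs = [{'prompt': 'Tell me about LIT.'}]  # field name assumed from the prompt-debugging input spec
response = requests.post(url, data=serialize.to_json(inputs, simple=True))
outputs = serialize.from_json(response.text)  # list of dicts following the model's output spec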
78 changes: 78 additions & 0 deletions lit_nlp/examples/gcp/server.py
@@ -0,0 +1,78 @@
"""Server for sequence salience with a left-to-right language model."""

from collections.abc import Mapping, Sequence
import sys
from typing import Optional

from absl import app
from absl import flags
from absl import logging
from lit_nlp import dev_server
from lit_nlp import server_flags
from lit_nlp.api import model as lit_model
from lit_nlp.api import types as lit_types
from lit_nlp.examples.gcp import model as lit_gcp_model
from lit_nlp.examples.prompt_debugging import datasets as pd_datasets
from lit_nlp.examples.prompt_debugging import layouts as pd_layouts


_FLAGS = flags.FLAGS

_SPLASH_SCREEN_DOC = """
# Language Model Salience
To begin, select an example, then click the segment(s) (tokens, words, etc.)
of the output that you would like to explain. Preceding segment(s) will be
highlighted according to their importance to the selected target segment(s),
with darker colors indicating a greater influence (salience) of that segment on
the model's likelihood of the target segment.
"""


def init_llm_on_gcp(
    name: str, base_url: str, *args, **kw
) -> Mapping[str, lit_model.Model]:
  return lit_gcp_model.initialize_model_group_for_salience(
      name=name, base_url=base_url, *args, **kw
  )


def get_wsgi_app() -> Optional[dev_server.LitServerType]:
  """Return WSGI app for container-hosted demos."""
  _FLAGS.set_default("server_type", "external")
  _FLAGS.set_default("demo_mode", True)
  _FLAGS.set_default("page_title", "LM Prompt Debugging")
  _FLAGS.set_default("default_layout", pd_layouts.THREE_PANEL)
  # Parse flags without calling app.run(main), to avoid conflict with
  # gunicorn command line flags.
  unused = flags.FLAGS(sys.argv, known_only=True)
  if unused:
    logging.info("lm_demo:get_wsgi_app() called with unused args: %s", unused)
  return main([])


def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]:
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")

  lit_demo = dev_server.Server(
      models={},
      datasets={},
      layouts=pd_layouts.PROMPT_DEBUGGING_LAYOUTS,
      model_loaders={
          'LLM on GCP': (init_llm_on_gcp, {
              'name': lit_types.String(),
              'base_url': lit_types.String(),
              'max_concurrent_requests': lit_types.Integer(default=1),
              'max_qps': lit_types.Scalar(default=25),
          })
      },
      dataset_loaders=pd_datasets.get_dataset_loaders(),
      onboard_start_doc=_SPLASH_SCREEN_DOC,
      **server_flags.get_flags(),
  )
  return lit_demo.serve()


if __name__ == "__main__":
  app.run(main)
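
Two sketches, not part of this commit, of how this module is used: what a WSGI host such as gunicorn effectively calls, and what the 'LLM on GCP' loader invokes once a user fills in its form. The name, URL, and port below are placeholders:

from lit_nlp.examples.gcp import server as lit_gcp_server

# Container hosting: the gunicorn-style entry point parses known flags and
# returns the LIT WSGI application.
wsgi_app = lit_gcp_server.get_wsgi_app()

# From the LIT UI, the 'LLM on GCP' model loader ends up calling:
models = lit_gcp_server.init_llm_on_gcp(
    name='gemma', base_url='http://localhost:5432',
    max_concurrent_requests=1, max_qps=25,
)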
