Skip to content

Commit

Permalink
LIT: Load multiple model wrappers with shared model via the /create_m…
Browse files Browse the repository at this point in the history
…odel API

PiperOrigin-RevId: 673859638
  • Loading branch information
RyanMullins authored and LIT team committed Sep 16, 2024
1 parent e190a71 commit ebeac49
Show file tree
Hide file tree
Showing 9 changed files with 217 additions and 65 deletions.
5 changes: 4 additions & 1 deletion lit_nlp/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# ==============================================================================
"""Base classes for LIT models."""
import abc
from collections.abc import Iterable, Iterator
from collections.abc import Iterable, Iterator, Mapping
import inspect
import itertools
import multiprocessing.pool # for ThreadPool
Expand Down Expand Up @@ -203,6 +203,9 @@ def predict(self, inputs: Iterable[JsonDict], **kw) -> Iterable[JsonDict]:
pass


ModelMap = Mapping[str, Model]


class ModelWrapper(Model):
"""Wrapper for a LIT model.
Expand Down
19 changes: 12 additions & 7 deletions lit_nlp/api/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1047,9 +1047,12 @@ def is_param_optional(parameter: inspect.Parameter) -> bool:

# Otherwise, attempt to infer a type from the Parameter object.
if param.annotation is param.empty and param.default is param.empty:
raise TypeError(f"Unable to infer a type for parameter '{param.name}' "
f"of '{func.__name__}'. Please add a type hint or "
"default value, or implement a Spec literal.")
fn_name = getattr(func, "__name__", repr(func))
raise TypeError(
f"Unable to infer a type for parameter '{param.name}' of '{fn_name}'."
" Please add a type hint or default value, or implement a Spec"
" literal."
)

if param.annotation is param.empty:
param_type = type(param.default)
Expand All @@ -1065,9 +1068,11 @@ def is_param_optional(parameter: inspect.Parameter) -> bool:
lit_type_params["default"] = param.default
spec[param.name] = lit_type_cstr(**lit_type_params)
else:
raise TypeError(f"Unsupported type '{param_type}' for parameter "
f"'{param.name}' of '{func.__name__}'. If possible "
"(e.g., this parameter is Optional), please implement a "
"spec literal instead of using inferencing.")
fn_name = getattr(func, "__name__", repr(func))
raise TypeError(
f"Unsupported type '{param_type}' for parameter '{param.name}' of"
f" '{fn_name}'. If possible (e.g., this parameter is Optional),"
" please implement a spec literal instead of using inferencing."
)

return spec
116 changes: 89 additions & 27 deletions lit_nlp/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,11 @@
import collections
from collections.abc import Callable, Iterable, Mapping, Sequence
import functools
import glob
import math
import os
import random
import threading
import time
from typing import Any, Optional, TypedDict, Union
from typing import Any, Optional, TypedDict, Union, cast, get_type_hints

from absl import logging
from lit_nlp.api import components as lit_components
Expand Down Expand Up @@ -53,7 +51,12 @@
DatasetLoader = tuple[Callable[..., lit_dataset.Dataset], Optional[types.Spec]]
DatasetLoadersMap = dict[str, DatasetLoader]

ModelLoader = tuple[Callable[..., lit_model.Model], Optional[types.Spec]]
SingleModelLoader = Callable[..., lit_model.Model]
MultipleModelLoader = Callable[..., lit_model.ModelMap]
ModelLoader = tuple[
Union[SingleModelLoader, MultipleModelLoader],
Optional[types.Spec],
]
ModelLoadersMap = dict[str, ModelLoader]

_EMPTY_DATASET_KEY = '_union_empty'
Expand Down Expand Up @@ -435,10 +438,11 @@ def _create_dataset(
dataset_cls, dataset_init_spec = loader_info

if dataset_init_spec is not None:
initializer_name = getattr(dataset_cls, '__name__', repr(dataset_cls))
utils.validate_config_against_spec(
config,
dataset_init_spec,
f'{dataset_name} ({dataset_cls.__name__})',
f'{dataset_name} ({initializer_name})',
raise_for_unsupported=True,
)

Expand All @@ -450,49 +454,107 @@ def _create_dataset(
self._info = self._build_metadata()
return (self._info, new_name)

def _create_model(self,
data: types.JsonDict,
model_name: Optional[str] = None,
**unused_kw):
"""Create a model, updating and returning the metadata."""
def _create_model(
self, data: types.JsonDict, model_name: Optional[str] = None, **unused_kw
):
"""Create a model, updating and returning the metadata.
LIT supports two types of model loaders:
* Single-model loaders that return an instance of `lit_model.Model`; and
* Multiple-model loaders that return a `Mapping[str, lit_model.Model]`.
Multiple-model loaders are primarily used for LLM use cases, such as the
Prompt Debugging example, where LIT needs to access the generation,
tokenization, and salience computation features of a model separately, and
thus initializes one lit_model.Model wrapper for each of these purposes.
Note that the `Callable` associated with a given Multiple-model
`ModelLoader` must take a `new_name` parameter, as it is assumed that this
`Callable` will initialize multiple LIT Model wrappers for different
functions performed by a shared model, such as the generate, tokenize, and
salience functions of an LLM for prompt debugging use cases.
Single-model loaders are used in most other use cases, such as
classification and regression tasks where the prediction is more stable.
Args:
data: the JSON payload provided in the request.
model_name: the model initializer to use, a key of LitApp._model_loaders.
Returns:
A tuple containing the updated LitApp metadata and the names of the models
that were added.
Raises:
ValueError: If any of the following are missing: model_name, the config,
or a value for new_name in the config; if there is not a model loader
configured for the provided model_name; or if there is a name collision
with one of the models returned by a multiple-model loader.
"""
if model_name is None:
raise ValueError('No base model specified.')

if (loader_info := self._model_loaders.get(model_name)) is None:
raise ValueError(
f'No loader information (Cls + init_spec) found for {model_name}'
)

config: Optional[dict[str, Any]] = data.get('config')
if config is None:
raise ValueError('No config specified.')

new_name: Optional[str] = config.pop('new_name', None)
if new_name is None:
if not new_name:
raise ValueError('No name provided for the new model.')
elif new_name in self._models:
return (self._info, new_name) # Return the existing model

if (loader_info := self._model_loaders.get(model_name)) is None:
raise ValueError(
f'No loader information (Cls + init_spec) found for {model_name}'
)

model_cls, model_init_spec = loader_info
model_initializer, model_init_spec = loader_info

if model_init_spec is not None:
initializer_name = getattr(
model_initializer, '__name__', repr(model_initializer)
)
utils.validate_config_against_spec(
config,
model_init_spec,
f'{model_name} ({model_cls.__name__})',
f'{model_name} ({initializer_name})',
raise_for_unsupported=True,
)

new_model = model_cls(**config)
self._models[new_name] = caching.CachingModelWrapper(
new_model, new_name, **self._caching_model_wrapper_kw
)
return_type = get_type_hints(model_initializer)['return']

if Mapping in return_type.__mro__:
model_initializer = cast(MultipleModelLoader, model_initializer)
new_models = model_initializer(new_name=new_name, **config)
new_model_names: list[str] = list(new_models.keys())
model_name_collisions = [
model_name
for model_name in new_model_names
if model_name in self._models
]
if model_name_collisions:
raise ValueError(f'Model(s) already exist: {model_name_collisions}.')

for model_name, model_instance in new_models.items():
self._models[model_name] = caching.CachingModelWrapper(
model_instance, model_name, **self._caching_model_wrapper_kw
)
else:
if new_name in self._models:
return (self._info, new_name) # Return the existing model

new_model_names: list[str] = [new_name]
model_initializer = cast(SingleModelLoader, model_initializer)
new_model = model_initializer(**config)
self._models[new_name] = caching.CachingModelWrapper(
new_model, new_name, **self._caching_model_wrapper_kw
)

empty_dataset = lit_dataset.NoneDataset(self._models)
self._datasets[_EMPTY_DATASET_KEY] = lit_dataset.IndexedDataset(
base=self._run_annotators(empty_dataset), id_fn=caching.input_hash
)
self._info = self._build_metadata()
return (self._info, new_name)
return (self._info, new_model_names)

def _get_generated(
self,
Expand Down Expand Up @@ -824,7 +886,7 @@ def _handler(app: wsgi_app.App, request, environ):
if (
data
and 'inputs' in data.keys()
and len(data.get('inputs'))
and data.get('inputs')
and 'dataset_name' in kw
):
data['inputs'] = self._reconstitute_inputs(
Expand All @@ -849,7 +911,7 @@ def _handler(app: wsgi_app.App, request, environ):

def __init__(
self,
models: Mapping[str, lit_model.Model],
models: lit_model.ModelMap,
datasets: Mapping[str, lit_dataset.Dataset],
generators: Optional[Mapping[str, lit_components.Generator]] = None,
interpreters: Optional[Mapping[str, lit_components.Interpreter]] = None,
Expand Down
11 changes: 8 additions & 3 deletions lit_nlp/client/core/global_settings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -483,14 +483,19 @@ export class GlobalSettingsComponent extends MobxLitElement {

if (newInfo == null) {return;}

const [metadata, modelName] = newInfo;

// TODO(b/270268760): Adding a model via the UI is profoundly slow for
// LLM use cases due to adding modelNames to the AppState.currentModels,
// which fetches preds for the entire dataset at once. Doing this on
// demand would dramatically improve performance and allow adding
// modelNames to the AppState.currentModels
const [metadata, /* modelNames */] = newInfo;
if (loaderSpec != null) {
this.loadingCallConfig = initializeCallConfig(loaderSpec);
}
this.appState.metadata = metadata;
this.appState.currentModels.push(modelName);
this.resetLoadingCallConfig();
this.initializeLocalState();
// this.status = 'New model initialized and added successfully.';
};

const hideLoadingControls = this.appState.metadata.demoMode ||
Expand Down
2 changes: 1 addition & 1 deletion lit_nlp/client/services/api_service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ export class ApiService extends LitService {
* with and satisfy the requirements of the `Model.init_spec()`.
*/
async createModel(model: string, config: CallConfig):
Promise<[LitMetadata, string]> {
Promise<[LitMetadata, string[]]> {
const loadMessage = 'Loading new model';
return this.queryServer(
'/create_model',
Expand Down
16 changes: 7 additions & 9 deletions lit_nlp/examples/prompt_debugging/keras_lms.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,17 +514,15 @@ def output_spec(self) -> lit_types.Spec:


def initialize_model_group_for_salience(
name: str, *args, **kw
) -> dict[str, lit_model.Model]:
new_name: str, **kw
) -> lit_model.ModelMap:
"""Creates '{new_name}' and '_{new_name}_salience' and '_{new_name}_tokenizer'."""
salience_name, tokenizer_name = pd_utils.generate_model_group_names(name)
generation_model = KerasGenerationModel(*args, **kw)
salience_model = KerasSalienceModel(model=generation_model.model, *args, **kw)
tokenizer_model = KerasTokenizerModel(
model=generation_model.model, *args, **kw
)
salience_name, tokenizer_name = pd_utils.generate_model_group_names(new_name)
generation_model = KerasGenerationModel(**kw)
salience_model = KerasSalienceModel(model=generation_model.model, **kw)
tokenizer_model = KerasTokenizerModel(model=generation_model.model, **kw)
return {
name: generation_model,
new_name: generation_model,
salience_name: salience_model,
tokenizer_name: tokenizer_model,
}
Loading

0 comments on commit ebeac49

Please sign in to comment.