From 5de804db79b060cb6ff9ddd6d1f387b50a948e08 Mon Sep 17 00:00:00 2001 From: Googler Date: Thu, 13 Aug 2020 13:46:19 -0400 Subject: [PATCH] Merged commit includes the following changes: 326478470 by iftenney: Fix unbatching issue and tidy up code for GPT-2 demo -- 326478303 by iftenney: Fix frontend error if no token input is available -- 326465488 by jwexler: Allow setting of hostname used by werkzeug -- 326434919 by jwexler: LM prediction module: reset masked token on example switch -- 326347597 by jwexler: Clean up README -- 326347174 by jwexler: Update README for LIT paper -- 326346943 by iftenney: Fix tokenization bug in click-to-mask mode for MLM -- 326345968 by iftenney: Safer max_length handling for GLUE classifier -- 326339730 by iftenney: Set default layout for LM demo -- 326338272 by iftenney: Internal change 326284480 by iftenney: Internal change 326278999 by iftenney: Internal change 326230051 by jwexler: Internal change PiperOrigin-RevId: 326478470 --- README.md | 89 ++++-- docs/development.md | 27 +- docs/faq.md | 35 ++- docs/index.md | 58 +--- docs/python_api.md | 50 ++- environment.yml | 4 +- lit_nlp/client/layout.ts | 4 +- .../client/modules/lm_prediction_module.ts | 44 +-- lit_nlp/client/tsconfig.json | 2 +- lit_nlp/examples/glue_demo.py | 6 +- lit_nlp/examples/models/glue_models.py | 13 +- lit_nlp/examples/models/pretrained_lms.py | 297 ++++++++---------- lit_nlp/examples/pretrained_lm_demo.py | 6 +- lit_nlp/lib/wsgi_serving.py | 8 +- lit_nlp/server_flags.py | 1 + requirements.txt | 13 - 16 files changed, 324 insertions(+), 333 deletions(-) delete mode 100644 requirements.txt diff --git a/README.md b/README.md index cc58c4ca..5b7c8b0e 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# Language Interpretability Tool (LIT) :fire: +# 🔥 Language Interpretability Tool (LIT) The Language Interpretability Tool (LIT) is a visual, interactive model-understanding tool for NLP models. @@ -29,31 +29,52 @@ Features include: multi-head models and multiple input features out of the box. * **Framework-agnostic** and compatible with TensorFlow, PyTorch, and more. -For a broader overview, check out [our paper](TBD) and the +For a broader overview, check out [our paper](https://arxiv.org/abs/2008.05122) and the [user guide](docs/user_guide.md). -## Getting Started +## Documentation + +* [User Guide](docs/user_guide.md) +* [Developer Guide](docs/development.md) +* [FAQ](docs/faq.md) + +## Download and Installation Download the repo and set up a Python environment: ```sh git clone https://github.com/PAIR-code/lit.git ~/lit + +# Set up Python environment cd ~/lit conda env create -f environment.yml conda activate lit-nlp +conda install cudnn cupti # optional, for GPU support +conda install -c pytorch pytorch # optional, for PyTorch + +# Build the frontend +cd ~/lit/lit_nlp/client +yarn && yarn build ``` -Build the frontend (output will be in `~/lit/client/build`). You only need to do -this once, unless you change the TypeScript or CSS files. +## Running LIT + +### Quick-start: sentiment classifier ```sh -cd ~/lit/lit_nlp/client -yarn # install deps -yarn build --watch +cd ~/lit +python -m lit_nlp.examples.quickstart_sst_demo --port=5432 ``` -And run a LIT server, such as those included in -../lit_nlp/examples: +This will fine-tune a [BERT-tiny](https://arxiv.org/abs/1908.08962) model on the +[Stanford Sentiment Treebank](https://nlp.stanford.edu/sentiment/treebank.html), +which should take less than 5 minutes on a GPU. After training completes, it'll +start a LIT server on the development set; navigate to http://localhost:5432 for +the UI. + +### Quick start: language modeling + +To explore predictions from a pretrained language model (BERT or GPT-2), run: ```sh cd ~/lit @@ -61,30 +82,58 @@ python -m lit_nlp.examples.pretrained_lm_demo --models=bert-base-uncased \ --port=5432 ``` -You can then access the LIT UI at http://localhost:5432. +And navigate to http://localhost:5432 for the UI. -## Full Documentation +### More Examples -[Click here for the full documentation site.](docs/index.md) +See ../lit_nlp/examples. Run similarly to the above: -To learn about the features of the tool as an end-user, check out the -[user guide](docs/user_guide.md). +```sh +cd ~/lit +python -m lit_nlp.examples. --port=5432 [optional --args] +``` + +## User Guide + +To learn about LIT's features, check out the [user guide](user_guide.md), or +watch this [short video](https://www.youtube.com/watch?v=j0OfBWFUqIE). + +## Adding your own models or data You can easily run LIT with your own model by creating a custom `demo.py` -launcher, similar to those in ../lit_nlp/examples. For a full -walkthrough, see -[adding models and data](docs/python_api.md#adding-models-and-data). +launcher, similar to those in ../lit_nlp/examples. The basic +steps are: + +* Write a data loader which follows the + [`Dataset` API](python_api.md#datasets) +* Write a model wrapper which follows the [`Model` API](python_api.md#models) +* Pass models, datasets, and any additional + [components](python_api.md#interpretation-components) to the LIT server + class + +For a full walkthrough, see +[adding models and data](python_api.md#adding-models-and-data). +## Extending LIT with new components LIT is easy to extend with new interpretability components, generators, and more, both on the frontend or the backend. See the -[developer guide](docs/development.md) to get started. +[developer guide](development.md) to get started. ## Citing LIT If you use LIT as part of your work, please cite: -TODO: add BibTeX here once we're on arXiv +``` +@misc{tenney2020language, + title={The Language Interpretability Tool: Extensible, Interactive Visualizations and Analysis for NLP Models}, + author={Ian Tenney and James Wexler and Jasmijn Bastings and Tolga Bolukbasi and Andy Coenen and Sebastian Gehrmann and Ellen Jiang and Mahima Pushkarna and Carey Radebaugh and Emily Reif and Ann Yuan}, + year={2020}, + eprint={2008.05122}, + archivePrefix={arXiv}, + primaryClass={cs.CL} +} +``` ## Disclaimer diff --git a/docs/development.md b/docs/development.md index 0d412325..022c13a5 100644 --- a/docs/development.md +++ b/docs/development.md @@ -39,12 +39,12 @@ browser: models = {'foo': FooModel(...), 'bar': BarModel(...)} datasets = {'baz': BazDataset(...)} -server = lit.Server(models, datasets, port=4321) +server = lit_nlp.dev_server.Server(models, datasets, port=4321) server.serve() ``` -For more, see [adding models and data](python_api.md#adding-models-and-data) or the -examples in ../lit_nlp/examples. +For more, see [adding models and data](python_api.md#adding-models-and-data) or +the examples in ../lit_nlp/examples. [^1]: Naming is just a happy coincidence; the Language Interpretability Tool is not related to the lit-html or lit-element projects. @@ -63,10 +63,10 @@ might define the following spec: ```python # dataset.spec() { - "premise": lit.TextSegment(), - "hypothesis": lit.TextSegment(), - "label": lit.CategoryLabel(vocab=["entailment", "neutral", "contradiction"]), - "genre": lit.CategoryLabel(), + "premise": lit_types.TextSegment(), + "hypothesis": lit_types.TextSegment(), + "label": lit_types.CategoryLabel(vocab=["entailment", "neutral", "contradiction"]), + "genre": lit_types.CategoryLabel(), } ``` @@ -88,8 +88,8 @@ subset of the dataset fields: ```python # model.input_spec() { - "premise": lit.TextSegment(), - "hypothesis": lit.TextSegment(), + "premise": lit_types.TextSegment(), + "hypothesis": lit_types.TextSegment(), } ``` @@ -98,8 +98,9 @@ And the output spec: ```python # model.output_spec() { - "probas": lit.MulticlassPreds(parent="label", - vocab=["entailment", "neutral", "contradiction"]), + "probas": lit_types.MulticlassPreds( + parent="label", + vocab=["entailment", "neutral", "contradiction"]), } ``` @@ -126,7 +127,7 @@ defining multiple `TextSegment` fields as in the above example, while multi-headed models can simply define multiple output fields. Furthermore, new types can easily be added to support custom input modalities, output types, or to provide access to model internals. For a more detailed example, see the -[`lit.Model` documentation](python_api#models). +[`Model` documentation](python_api#models). The actual spec types, such as `MulticlassLabel`, are simple dataclasses (built using [`attr.s`](https://www.attrs.org/en/stable/). They are defined in Python, @@ -134,7 +135,7 @@ but are available in the [TypeScript client](client.md) as well. [`utils.find_spec_keys()`](../lit_nlp/lib/utils.py) (Python) and -[findSpecKeys()](../lit_nlp/client/lib/utils.ts) +[`findSpecKeys()`](../lit_nlp/client/lib/utils.ts) (TypeScript) are commonly used to interact with a full spec and identify fields of interest. These recognize subclasses: for example, `utils.find_spec_keys(spec, Scalar)` will also match any `RegressionScore` diff --git a/docs/faq.md b/docs/faq.md index ee443d58..4e428d11 100644 --- a/docs/faq.md +++ b/docs/faq.md @@ -2,7 +2,34 @@ -### Can LIT work with ``? +### Your implementation of `` is really cool - can I use it in ``? + +For backend components: yes! Models, datasets, and interpretation components +don't depend on the LIT serving code at all, and they're designed for standalone +use. You can treat them as any other Python class and use from Colab, regular +scripts, bulk inference pipelines, etc. For example, to compute LIME: + +```python +from lit_nlp.examples.datasets import glue +from lit_nlp.examples.models import glue_models +from lit_nlp.components import lime_explainer + +dataset = glue.SST2Data('validation') +model = glue_models.SST2Model("/path/to/saved/model") +lime = lime_explainer.LIME() +lime.run([dataset.examples[0]], model, dataset) +# will return {"tokens": ..., "salience": ...} for each example given +``` + +For the frontend, it's a little more difficult. In order to respond to and +interact with the shared UI state, there's a lot more "framework" code involved. +We're working on refactoring the LIT modules +(../lit_nlp/client/modules) to separate framework and API +code from the visualizations (e.g. +../lit_nlp/client/elements), which can then be re-used in +other environments. + +### Can LIT work with ``? Generally, yes! But you'll probably want to use `warm_start=1.0` (or pass `--warm_start=1.0` as a flag) to pre-compute predictions when the server loads, @@ -12,14 +39,14 @@ Also, beware of memory usage: since LIT keeps the models in memory to support new queries, only so many can fit on a single GPU. If you want to load more models than can fit in local memory, LIT has experimental support for remotely-hosted models on another LIT server (see -[`remote_model.py`](../language/lit/components/remote_model.py) -for more details), and you can also write a [`lit.Model`](python_api.md#models) +[`remote_model.py`](../lit_nlp/components/remote_model.py) +for more details), and you can also write a [`Model`](python_api.md#models) class to interface with your favorite serving framework. ### How many datapoints / examples can LIT handle? It depends on your model, and on your hardware. We've successfully tested with -10k examples (the entire MultiNLI `validation_matched` split), including +10k examples (the full MultiNLI `validation_matched` split), including embeddings from the model. But, a couple caveats: * LIT expects predictions to be available on the whole dataset when the UI diff --git a/docs/index.md b/docs/index.md index 9dbf6221..96f38c60 100644 --- a/docs/index.md +++ b/docs/index.md @@ -4,58 +4,6 @@ -## Getting Started - -Download the repo and set up a Python environment: - -```sh -git clone https://github.com/PAIR-code/lit.git ~/lit -cd ~/lit -conda env create -f environment.yml -conda activate lit-nlp -``` - -Build the frontend (output will be in `~/lit/client/build`). You only need to do -this once, unless you change the TypeScript or CSS files. - -```sh -cd ~/lit/client -yarn # install deps -yarn build --watch -``` - -And run a LIT server, such as those included in -../lit_nlp/examples: - -```sh -cd ~/lit -python -m lit_nlp.examples.pretrained_lm_demo --models=bert-base-uncased \ - --port=5432 -``` - -You can then access the LIT UI at http://localhost:4321. - -## User Guide - -To learn about LIT's features, check out the [user guide](user_guide.md). - -## Adding your own models or data - -You can easily run LIT with your own model by creating a custom `demo.py` -launcher, similar to those in ../lit_nlp/examples. The basic steps -are: - -* Write a data loader which follows the - [`lit.Dataset` API](python_api.md#datasets) -* Write a model wrapper which follows the - [`lit.Model` API](python_api.md#models) -* Pass models, datasets, and any additional - [components](python_api.md#interpretation-components) to the LIT server class - -For a full walkthrough, see [adding models and data](python_api.md#adding-models-and-data). - -## Extending LIT with new components - -LIT is easy to extend with new interpretability components, generators, and -more, both on the frontend or the backend. See the -[developer guide](development.md) to get started. +* [User Guide](user_guide.md) +* [Developer Guide](development.md) +* [FAQ](faq.md) diff --git a/docs/python_api.md b/docs/python_api.md index 43c23e2c..9c5d21dd 100644 --- a/docs/python_api.md +++ b/docs/python_api.md @@ -19,19 +19,19 @@ script that passes these to the LIT server. For example: ```py def main(_): - # MulitiNLIData implements lit.Dataset + # MulitiNLIData implements the Dataset API datasets = { 'mnli_matched': MultiNLIData('/path/to/dev_matched.tsv'), 'mnli_mismatched': MultiNLIData('/path/to/dev_mismatched.tsv'), } - # NLIModel implements lit.Model + # NLIModel implements the Model API models = { 'model_foo': NLIModel('/path/to/model/foo/files'), 'model_bar': NLIModel('/path/to/model/bar/files'), } - lit_demo = lit.Server(models, datasets, port=4321) + lit_demo = lit_nlp.dev_server.Server(models, datasets, port=4321) lit_demo.serve() if __name__ == '__main__': @@ -39,29 +39,29 @@ if __name__ == '__main__': ``` Conceptually, a dataset is just a list of examples and a model is just a -function that takes examples and returns predictions. The -[`lit.Dataset`](#datasets) and [`lit.Model`](#models) classes implement this, -and provide metadata (see the [type system](development.md#type-system)) to -describe themselves to other components. +function that takes examples and returns predictions. The [`Dataset`](#datasets) +and [`Model`](#models) classes implement this, and provide metadata (see the +[type system](development.md#type-system)) to describe themselves to other +components. For full examples, see ../lit_nlp/examples ## Datasets -Datasets ([`lit.Dataset`](../lit_nlp/api/dataset.py)) -are just a list of examples, with associated type information following LIT's +Datasets ([`Dataset`](../lit_nlp/api/dataset.py)) are +just a list of examples, with associated type information following LIT's [type system](development.md#type-system). * `spec()` should return a flat dict that describes the fields in each example * `self._examples` should be a list of flat dicts Implementations should subclass -[`lit.Dataset`](../lit_nlp/api/dataset.py). Usually -this is just a few lines of code - for example, the following is a complete -dataset loader for [MultiNLI](https://cims.nyu.edu/~sbowman/multinli/): +[`Dataset`](../lit_nlp/api/dataset.py). Usually this +is just a few lines of code - for example, the following is a complete dataset +loader for [MultiNLI](https://cims.nyu.edu/~sbowman/multinli/): ```py -class MultiNLIData(lit.Dataset): +class MultiNLIData(Dataset): """Loader for MultiNLI development set.""" NLI_LABELS = ['entailment', 'neutral', 'contradiction'] @@ -123,7 +123,7 @@ a `"text"` input via `Dataset.remap({"document": ## Models -Models ([`lit.Model`](../lit_nlp/api/model.py)) are +Models ([`Model`](../lit_nlp/api/model.py)) are functions which take inputs and produce outputs, with associated type information following LIT's [type system](development.md#type-system). The core API consists of three methods: @@ -137,12 +137,11 @@ API consists of three methods: matching `output_spec()`. Implementations should subclass -[`lit.Model`](../lit_nlp/api/model.py). An example -for [MultiNLI](https://cims.nyu.edu/~sbowman/multinli/) might look something -like: +[`Model`](../lit_nlp/api/model.py). An example for +[MultiNLI](https://cims.nyu.edu/~sbowman/multinli/) might look something like: ```py -class NLIModel(lit.Model): +class NLIModel(Model): """Wrapper for a Natural Language Inference model.""" NLI_LABELS = ['entailment', 'neutral', 'contradiction'] @@ -177,14 +176,14 @@ Unlike the dataset example, this model implementation is incomplete - you'll need to customize `predict()` (or `predict_minibatch()`) accordingly with any pre- or post-processing needed, such as tokenization. -Note: The `lit.Model` base class implements simple batching, aided by the +Note: The `Model` base class implements simple batching, aided by the `max_minibatch_size()` function. This is purely for convenience, since most deep learning models will want this behavior. But if you don't need it, you can simply override the `predict()` function directly and handle large inputs accordingly. Note: there are a few additional methods in the model API - see -[`lit.Model`](../lit_nlp/api/model.py) for details. +[`Model`](../lit_nlp/api/model.py) for details. ### Adding more outputs @@ -242,9 +241,9 @@ aids like [UMAP](https://umap-learn.readthedocs.io/en/latest/), and counterfactual generator plug-ins. Most such components implement the -[`lit.Interpreter`](../lit_nlp/api/components.py) -API. Conceptually, this is any function that takes a set of datapoints and a -model, and produces some output.[^identity-component] For example, +[`Interpreter`](../lit_nlp/api/components.py) API. +Conceptually, this is any function that takes a set of datapoints and a model, +and produces some output.[^identity-component] For example, [local gradient-based salience](../lit_nlp/components/gradient_maps.py) processes the `TokenGradients` and `Tokens` returned by a model and produces a list of scores for each token. @@ -361,9 +360,8 @@ Conceptually, a generator is just an interpreter that returns new input examples. These may depend on the input only, as for techniques such as backtranslation, or can involve feedback from the model, such as for adversarial attacks. Currently, generators use a separate API, subclassing -[`lit.Generator`](../lit_nlp/api/components.py), but -in the near future this will be merged into the `Interpreter` API described -above. +[`Generator`](../lit_nlp/api/components.py), but in +the near future this will be merged into the `Interpreter` API described above. The core generator API is: diff --git a/environment.yml b/environment.yml index 7a40b165..184d12dc 100644 --- a/environment.yml +++ b/environment.yml @@ -20,10 +20,10 @@ dependencies: - scipy - pandas - scikit-learn - - tensorflow - - tensorflow-datasets - pip - pip: + - tensorflow + - tensorflow-datasets - lime - sacrebleu - umap-learn diff --git a/lit_nlp/client/layout.ts b/lit_nlp/client/layout.ts index 8cfda4ed..0dbcab21 100644 --- a/lit_nlp/client/layout.ts +++ b/lit_nlp/client/layout.ts @@ -90,7 +90,7 @@ export const LAYOUTS: LitComponentLayouts = { /** * For masked language models */ - 'mlm': { + 'lm': { components : { 'Main': [EmbeddingsModule, DataTableModule, DatapointEditorModule], 'Predictions': [ @@ -124,4 +124,4 @@ export const LAYOUTS: LitComponentLayouts = { } }, }; -// clang-format on \ No newline at end of file +// clang-format on diff --git a/lit_nlp/client/modules/lm_prediction_module.ts b/lit_nlp/client/modules/lm_prediction_module.ts index ff2c1169..9ce3378f 100644 --- a/lit_nlp/client/modules/lm_prediction_module.ts +++ b/lit_nlp/client/modules/lm_prediction_module.ts @@ -16,6 +16,7 @@ */ import '../elements/checkbox'; + // tslint:disable:no-new-decorators import {customElement, html, property} from 'lit-element'; import {classMap} from 'lit-html/directives/class-map'; @@ -24,7 +25,7 @@ import {computed, observable} from 'mobx'; import {app} from '../core/lit_app'; import {LitModule} from '../core/lit_module'; import {IndexedInput, ModelsMap, Spec, TopKResult} from '../lib/types'; -import {doesOutputSpecContain, findSpecKeys, flatten} from '../lib/utils'; +import {doesOutputSpecContain, findSpecKeys, flatten, isLitSubtype} from '../lib/utils'; import {styles} from './lm_prediction_module.css'; import {styles as sharedStyles} from './shared_styles.css'; @@ -68,17 +69,21 @@ export class LanguageModelPredictionModule extends LitModule { } @computed - private get outputTokenKey(): string { + private get outputTokensKey(): string { const spec = this.appState.getModelSpec(this.model); // This list is guaranteed to be non-empty due to checkModule() return spec.output[this.predKey].align as string; } @computed - private get inputTextKey(): string { + private get inputTokensKey(): string|null { const spec = this.appState.getModelSpec(this.model); - // TODO(lit-dev): ensure this is set in order to enable MLM mode. - return spec.output[this.outputTokenKey].parent!; + // Look for an input field matching the output tokens name. + if (spec.input.hasOwnProperty(this.outputTokensKey) && + isLitSubtype(spec.input[this.outputTokensKey], 'Tokens')) { + return this.outputTokensKey; + } + return null; } firstUpdated() { @@ -90,10 +95,10 @@ export class LanguageModelPredictionModule extends LitModule { } private async updateSelection(selectedInput: IndexedInput|null) { + this.selectedTokenIndex = null; if (selectedInput == null) { this.selectedInput = null; this.tokens = []; - this.selectedTokenIndex = null; this.lmResults = []; return; } @@ -106,7 +111,7 @@ export class LanguageModelPredictionModule extends LitModule { if (results === null) return; const predictions = results[0]; - this.tokens = predictions[this.outputTokenKey]; + this.tokens = predictions[this.outputTokensKey]; this.lmResults = predictions[this.predKey]; this.selectedInput = selectedInput; @@ -118,24 +123,23 @@ export class LanguageModelPredictionModule extends LitModule { // If there's nothing to show, enable click-to-mask by default. // TODO(lit-dev): infer this from something in the spec instead. - if (flatten(this.lmResults).length === 0) { + if (flatten(this.lmResults).length === 0 && this.inputTokensKey != null) { this.clickToMask = true; } } // TODO(lit-dev): unify this codepath with updateSelection()? - private async updateLmResults(index: number) { + private async updateLmResults(maskIndex: number) { if (this.selectedInput == null) return; if (this.clickToMask) { + if (this.inputTokensKey == null) return; const tokens = [...this.tokens]; - tokens[index] = this.maskToken; - // TODO(lit-dev): detokenize properly, or feed tokens directly to model? - const input = tokens.join(' '); + tokens[maskIndex] = this.maskToken; - // Use empty id to disable caching on backend. const inputData = Object.assign( - {}, this.selectedInput.data, {[this.inputTextKey]: input}); + {}, this.selectedInput.data, {[this.inputTokensKey]: tokens}); + // Use empty id to disable caching on backend. const inputs: IndexedInput[] = [{'data': inputData, 'id': '', 'meta': {}}]; @@ -148,7 +152,7 @@ export class LanguageModelPredictionModule extends LitModule { this.lmResults = lmResults[0][this.predKey]; this.maskApplied = true; } - this.selectedTokenIndex = index; + this.selectedTokenIndex = maskIndex; } updated() { @@ -177,10 +181,12 @@ export class LanguageModelPredictionModule extends LitModule { // clang-format off return html`
- { this.clickToMask = !this.clickToMask; }} - > + ${this.inputTokensKey ? html` + { this.clickToMask = !this.clickToMask; }} + > + ` : null}
`; // clang-format on diff --git a/lit_nlp/client/tsconfig.json b/lit_nlp/client/tsconfig.json index b2fc4a9c..f1092c50 100644 --- a/lit_nlp/client/tsconfig.json +++ b/lit_nlp/client/tsconfig.json @@ -18,7 +18,7 @@ "esModuleInterop": true, "experimentalDecorators": true, "importsNotUsedAsValues": "preserve", - "types": ["node", "jasmine", "resize-observer-browser"], + "types": ["node", "jasmine", "resize-observer-browser"] }, "include": ["./"], "compileOnSave": false diff --git a/lit_nlp/examples/glue_demo.py b/lit_nlp/examples/glue_demo.py index 7e0bd5bc..e5bd33aa 100644 --- a/lit_nlp/examples/glue_demo.py +++ b/lit_nlp/examples/glue_demo.py @@ -34,8 +34,10 @@ "/, and in standard transformers format, e.g. as " "saved by model.save_pretrained() and tokenizer.save_pretrained().") -flags.DEFINE_integer("max_examples", None, - "Maximum number of examples to load into LIT.") +flags.DEFINE_integer( + "max_examples", None, "Maximum number of examples to load into LIT. " + "Note: MNLI eval set is 10k examples, so will take a while to run and may " + "be slow on older machines. Set --max_examples=200 for a quick start.") def main(_): diff --git a/lit_nlp/examples/models/glue_models.py b/lit_nlp/examples/models/glue_models.py index 9de14bf1..09c9434a 100644 --- a/lit_nlp/examples/models/glue_models.py +++ b/lit_nlp/examples/models/glue_models.py @@ -79,9 +79,7 @@ def __init__(self, from_pt=True, ) - def _preprocess(self, - inputs: Iterable[JsonDict], - do_trim: bool = True) -> Dict[str, tf.Tensor]: + def _preprocess(self, inputs: Iterable[JsonDict]) -> Dict[str, tf.Tensor]: segments = [ (ex[self.config.text_a_name], ex[self.config.text_b_name] if self.config.text_b_name else None) @@ -91,14 +89,17 @@ def _preprocess(self, segments, return_tensors="tf", add_special_tokens=True, - max_length=None if do_trim else self.config.max_seq_length, + max_length=self.config.max_seq_length, pad_to_max_length=True) + # Trim everything to the actual max length, to remove extra padding. + max_tokens = tf.reduce_max( + tf.reduce_sum(encoded_input["attention_mask"], axis=1)) + encoded_input = {k: v[:, :max_tokens] for k, v in encoded_input.items()} return encoded_input def _make_dataset(self, inputs: Iterable[JsonDict]) -> tf.data.Dataset: """Make a tf.data.Dataset from inputs in LIT format.""" - # Set do_trim=False to pad everything to max_seq_length. - encoded_input = self._preprocess(inputs, do_trim=False) + encoded_input = self._preprocess(inputs) if self.is_regression: labels = tf.constant([ex[self.config.label_name] for ex in inputs], dtype=tf.float32) diff --git a/lit_nlp/examples/models/pretrained_lms.py b/lit_nlp/examples/models/pretrained_lms.py index 742d0366..647d0673 100644 --- a/lit_nlp/examples/models/pretrained_lms.py +++ b/lit_nlp/examples/models/pretrained_lms.py @@ -7,7 +7,7 @@ functions to predict a batch of examples and extract information such as hidden states and attention. """ -from typing import Any, Dict, List, Text, Tuple +from typing import Dict, List, Tuple from lit_nlp.api import model as lit_model from lit_nlp.api import types as lit_types @@ -17,14 +17,16 @@ import tensorflow as tf import transformers -MAX_SEQ_LENGTH = 512 - class BertMLM(lit_model.Model): """BERT masked LM using Huggingface Transformers and TensorFlow 2.""" MASK_TOKEN = "[MASK]" + @property + def max_seq_length(self): + return self.model.config.max_position_embeddings + def __init__(self, model_name="bert-base-uncased", top_k=10): super().__init__() self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_name) @@ -34,14 +36,6 @@ def __init__(self, model_name="bert-base-uncased", top_k=10): model_name, output_hidden_states=True, output_attentions=True) self.top_k = top_k - def _batch_encode(self, text: List[str]): - """Encode a batch of strings for model input.""" - return self.tokenizer.batch_encode_plus( - text, - return_tensors="tf", - add_special_tokens=True, - pad_to_max_length=True) - # TODO(lit-dev): break this out as a helper function, write some tests, # and de-duplicate code with the other text generation functions. def _get_topk_tokens(self, @@ -84,9 +78,37 @@ def _postprocess(self, output: Dict[str, np.ndarray]): return output - def _predict_minibatch(self, texts: List[str]): - """Run the model on a batch of texts.""" - encoded_input = self._batch_encode(texts) + ## + # LIT API implementations + def max_minibatch_size(self, unused_config=None) -> int: + # The lit.Model base class handles batching automatically in the + # implementation of predict(), and uses this value as the batch size. + return 8 + + def predict_minibatch(self, inputs, config=None): + """Predict on a single minibatch of examples.""" + # If input has a 'tokens' field, use that. Otherwise tokenize the text. + tokenized_texts = [ + ex.get("tokens") or self.tokenizer.tokenize(ex["text"]) for ex in inputs + ] + # Process to ids, add special tokens, and compute segment ids and masks. + encoded_input = self.tokenizer.batch_encode_plus( + tokenized_texts, + is_pretokenized=True, + return_tensors="tf", + add_special_tokens=True, + max_length=self.max_seq_length, + pad_to_max_length=True) + # We have to set max_length explicitly above so that + # max_tokens <= model_max_length, in order to avoid indexing errors. But + # the combination of max_length= and pad_to_max_length=True means + # that if the max is < model_max_length, we end up with extra padding. + # Thee lines below strip this off. + # TODO(lit-dev): submit a PR to make this possible with tokenizer options? + max_tokens = tf.reduce_max( + tf.reduce_sum(encoded_input["attention_mask"], axis=1)) + encoded_input = {k: v[:, :max_tokens] for k, v in encoded_input.items()} + # logits is a single tensor # [batch_size, num_tokens, vocab_size] # embs is a list of num_layers + 1 tensors, each @@ -105,19 +127,11 @@ def _predict_minibatch(self, texts: List[str]): # Postprocess to remove padding and decode predictions. return map(self._postprocess, unbatched_outputs) - ## - # LIT API implementations - def max_minibatch_size(self, unused_config=None) -> int: - # The lit.Model base class handles batching automatically in the - # implementation of predict(), and uses this value as the batch size. - return 8 - - def predict_minibatch(self, inputs, config=None): - """Predict on a single minibatch of examples.""" - return self._predict_minibatch([ex["text"] for ex in inputs]) - def input_spec(self): - return {"text": lit_types.TextSegment()} + return { + "text": lit_types.TextSegment(), + "tokens": lit_types.Tokens(required=False), + } def output_spec(self): return { @@ -128,45 +142,6 @@ def output_spec(self): class GPT2LanguageModel(lit_model.Model): - """Wrapper for a GPT-2 language model.""" - - def __init__(self, *args, **kw): - # This loads the checkpoint into memory, so we"re ready for interactive use. - self._model = HFGPT2(*args, **kw) - - # LIT API implementations - def max_minibatch_size(self, unused_config=None) -> int: - # The lit.Model base class handles batching automatically in the - # implementation of predict(), and uses this value as the batch size. - return 6 - - def predict_minibatch(self, inputs, config=None): - """Predict on a single minibatch of examples.""" - examples = [self._model.convert_dict_input(d) for d in inputs] - payload = self._model.predict_examples(examples) - return payload - - def input_spec(self): - return { - "text": lit_types.TextSegment(), - } - - def output_spec(self): - spec = { - # the "parent" keyword tells LIT which field in the input spec we should - # compare this to when computing metrics. - "pred_tokens": lit_types.TokenTopKPreds(align="tokens"), - "tokens": lit_types.Tokens(parent="text"), # all tokens - } - # Add attention for each layer. - for i in range(self._model.get_num_layers()): - spec[f"layer_{i:d}_attention"] = lit_types.AttentionHeads( - align=("tokens", "tokens")) - spec[f"layer_{i:d}_avg_embedding"] = lit_types.Embeddings() - return spec - - -class HFGPT2(object): """Wrapper for a Huggingface Transformers GPT-2 model. This class loads a tokenizer and model using the Huggingface library and @@ -174,14 +149,17 @@ class HFGPT2(object): convert and clean tokens and to compute the top_k predictions from logits. """ + @property + def num_layers(self): + return self.model.config.n_layer + def __init__(self, model_name="gpt2", top_k=10): - """Constructor for HFGPT2 class. + """Constructor for GPT2LanguageModel. Args: model_name: Specify the GPT-2 size [distil, small, medium, large, xl]. top_k: How many predictions to prune. """ - super().__init__() # GPT2 is trained without pad_token, so pick arbitrary one and mask out. self.tokenizer = transformers.AutoTokenizer.from_pretrained( @@ -190,57 +168,24 @@ def __init__(self, model_name="gpt2", top_k=10): model_name, output_hidden_states=True, output_attentions=True) self.top_k = top_k - def _tokenize(self, text: List[str]): - """Function to tokenize a batch of strings. - - Args: - text: A list of strings to analyze. - - Returns: - tok_input: Dictionary of input_ids, token_type_ids, and attention_mask; - Each is tf.Tensor of shape (batch_size, max_len_within_batch). - """ - tok_input = self.tokenizer.batch_encode_plus( - text, - # Specify TF over PyTorch. - return_tensors="tf", - # For sequence boundaries. - add_special_tokens=True, - # Otherwise interpreted as starting with space. - add_prefix_space=True, - # Pad up to the max length inside the batch. - pad_to_max_length=True) - return tok_input - - def _clean_bpe(self, tokens): - """Converts the special BPE tokens into readable format.""" + @staticmethod + def clean_bpe_token(tok): + if not tok.startswith("Ä "): + return "_" + tok + else: + return tok.replace("Ä ", "") - def clean_token(tok): - if not tok.startswith("Ä "): - return "_" + tok - else: - return tok.replace("Ä ", "") + def _detokenize(self, ids): + tokens = self.tokenizer.convert_ids_to_tokens(ids) + return [self.clean_bpe_token(t) for t in tokens] - if isinstance(tokens, list): - return [clean_token(t) for t in tokens] - else: - return clean_token(tokens) - - def _detokenize(self, tokenized_text): - """Convert back from tokenized dict to List[str].""" - tokens = [] - for ids, mask in zip(tokenized_text["input_ids"], - tokenized_text["attention_mask"]): - # Filter out padding and remove BPE continuation token. - example = self._clean_bpe( - self.tokenizer.convert_ids_to_tokens( - [t for ix, t in enumerate(ids) if mask[ix] != 0])) - tokens.append(example) - return tokens - - def _pred(self, tokenized_text): + def _pred(self, encoded_inputs): """Predicts one batch of tokenized text. + Also performs some batch-level post-processing in TF. + Single-example postprocessing is done in _postprocess(), and operates on + numpy arrays. + Each prediction has the following returns: logits: tf.Tensor (batch_size, sequence_length, config.vocab_size). past: List[tf.Tensor] of length config.n_layers with each tensor shape @@ -252,68 +197,88 @@ def _pred(self, tokenized_text): Within this function, we combine each Tuple/List into a single Tensor. Args: - tokenized_text: Dictionary with output from self._tokenize + encoded_inputs: output of self.tokenizer.batch_encode_plus() Returns: payload: Dictionary with items described above, each as single Tensor. """ - logits, _, states, attentions = self.model(tokenized_text["input_ids"]) - # Convert representations for each layer from tuples to single Tensor. - payload = {} - for i in range(len(attentions)): - payload[f"layer_{i:d}_attention"] = attentions[i].numpy() - for i in range(len(states)): - payload[f"layer_{i:d}_avg_embedding"] = tf.math.reduce_mean( - states[i], axis=1).numpy() + logits, _, states, attentions = self.model(encoded_inputs["input_ids"]) - payload["pred_tokens"] = self._logits_to_topk_probs(logits, tokenized_text) - return payload - - def _logits_to_topk_probs(self, logits, tokenized_input): - """Softmaxes the logits and prunes to the top k (token, prob) tuples.""" model_probs = tf.nn.softmax(logits, axis=-1) top_k = tf.math.top_k(model_probs, k=self.top_k, sorted=True, name=None) - indices = top_k.indices - probs = top_k.values - format_top_k = [] - for index_batch, prob_batch, mask_batch in zip( - indices.numpy(), probs.numpy(), - tokenized_input["attention_mask"].numpy()): - # Initialize prediction for 0th token as N/A. - formatted_batch = [[("N/A", 1.)]] - # Add all other predictions for tokens. - for index, prob, mask in zip(index_batch, prob_batch, mask_batch): - if mask == 1: - formatted_batch.append([(i, "{:.3f}".format(p)) for i, p in zip( - self._clean_bpe(self.tokenizer.convert_ids_to_tokens(index)), - prob)]) - format_top_k.append(formatted_batch) - return format_top_k - - def _postproc_preds(self, pred): - """Postprocessing done on each batch element.""" - return pred - - def convert_dict_input(self, input_dict: Dict[Text, Any]) -> Dict[Text, Any]: - """Default implementation with generic keys.""" - return { - "text": input_dict["text"], - "guid": input_dict.get("guid", ""), + batched_outputs = { + "input_ids": encoded_inputs["input_ids"], + "ntok": tf.reduce_sum(encoded_inputs["attention_mask"], axis=1), + "top_k_indices": top_k.indices, + "top_k_probs": top_k.values, } - def get_num_layers(self): - return self.model.config.n_layer + # Convert representations for each layer from tuples to single Tensor. + for i in range(len(attentions)): + batched_outputs[f"layer_{i:d}_attention"] = attentions[i] + for i in range(len(states)): + batched_outputs[f"layer_{i:d}_avg_embedding"] = tf.math.reduce_mean( + states[i], axis=1) + + return batched_outputs + + def _postprocess(self, preds): + """Post-process single-example preds. Operates on numpy arrays.""" + ntok = preds.pop("ntok") + ids = preds.pop("input_ids")[:ntok] + preds["tokens"] = self._detokenize(ids) + + # Decode predicted top-k tokens. + # token_topk_preds will be a List[List[(word, prob)]] + # Initialize prediction for 0th token as N/A. + token_topk_preds = [[("N/A", 1.)]] + pred_ids = preds.pop("top_k_indices")[:ntok] # [num_tokens, k] + pred_probs = preds.pop("top_k_probs")[:ntok] # [num_tokens, k] + for token_pred_ids, token_pred_probs in zip(pred_ids, pred_probs): + token_pred_words = self._detokenize(token_pred_ids) + token_topk_preds.append(list(zip(token_pred_words, token_pred_probs))) + preds["pred_tokens"] = token_topk_preds + + return preds - def predict_examples(self, examples): - """Public Function for LITModel to call on examples.""" - # Text as sequence of BPE ID"s. - input_ids = self._tokenize([e["text"] for e in examples]) + ## + # LIT API implementations + def max_minibatch_size(self, unused_config=None) -> int: + # The lit.Model base class handles batching automatically in the + # implementation of predict(), and uses this value as the batch size. + return 6 + + def predict_minibatch(self, inputs, config=None): + """Predict on a single minibatch of examples.""" + # Preprocess inputs. + texts = [ex["text"] for ex in inputs] + encoded_inputs = self.tokenizer.batch_encode_plus( + texts, + return_tensors="tf", + add_special_tokens=True, + add_prefix_space=True, + pad_to_max_length=True) # Get the predictions. - preds = self._pred(input_ids) - # detokenize BPE tokens for interface. - detok = self._detokenize(input_ids) - preds["tokens"] = detok + batched_outputs = self._pred(encoded_inputs) + # Convert to numpy for post-processing. + detached_outputs = {k: v.numpy() for k, v in batched_outputs.items()} + # Split up batched outputs, then post-process each example. + unbatched_outputs = utils.unbatch_preds(detached_outputs) + return map(self._postprocess, unbatched_outputs) - payload = [self._postproc_preds(p) for p in utils.unbatch_preds(preds)] - return payload + def input_spec(self): + return {"text": lit_types.TextSegment()} + def output_spec(self): + spec = { + # the "parent" keyword tells LIT which field in the input spec we should + # compare this to when computing metrics. + "pred_tokens": lit_types.TokenTopKPreds(align="tokens"), + "tokens": lit_types.Tokens(parent="text"), # all tokens + } + # Add attention and embeddings from each layer. + for i in range(self.num_layers): + spec[f"layer_{i:d}_attention"] = lit_types.AttentionHeads( + align=("tokens", "tokens")) + spec[f"layer_{i:d}_avg_embedding"] = lit_types.Embeddings() + return spec diff --git a/lit_nlp/examples/pretrained_lm_demo.py b/lit_nlp/examples/pretrained_lm_demo.py index 7ca07e69..9d096457 100644 --- a/lit_nlp/examples/pretrained_lm_demo.py +++ b/lit_nlp/examples/pretrained_lm_demo.py @@ -28,6 +28,8 @@ # NOTE: additional flags defined in server_flags.py +FLAGS = flags.FLAGS + flags.DEFINE_list( "models", ["bert-base-uncased"], "Models to load. Currently supports variants of BERT and GPT-2.") @@ -45,7 +47,9 @@ "If true, will load examples from the Billion Word Benchmark dataset. This may download a lot of data the first time you run it, so disable by default for the quick-start example." ) -FLAGS = flags.FLAGS +# Set default layout to one better suited to language models. +# You can also change this via URL param e.g. localhost:5432/?layout=default +FLAGS.set_default("default_layout", "lm") def main(_): diff --git a/lit_nlp/lib/wsgi_serving.py b/lit_nlp/lib/wsgi_serving.py index 272aa7ad..985edb27 100644 --- a/lit_nlp/lib/wsgi_serving.py +++ b/lit_nlp/lib/wsgi_serving.py @@ -26,17 +26,19 @@ class BasicDevServer(object): """Basic development server; not recommended for deployment.""" - def __init__(self, wsgi_app, port: int = 4321, **unused_kw): + def __init__(self, wsgi_app, port: int = 4321, host: Text = '127.0.0.1', + **unused_kw): self._port = port + self._host = host self._app = wsgi_app self.can_act_as_model_server = True def serve(self): logging.info(('\n\nStarting Server on port %d' - '\nYou can navigate to http://127.0.0.1:%d\n\n'), self._port, + '\nYou can navigate to %s:%d\n\n'), self._port, self._host, self._port) werkzeug_serving.run_simple( - '127.0.0.1', + self._host, self._port, self._app, use_debugger=False, diff --git a/lit_nlp/server_flags.py b/lit_nlp/server_flags.py index e7836f4e..1b9b0ff6 100644 --- a/lit_nlp/server_flags.py +++ b/lit_nlp/server_flags.py @@ -35,6 +35,7 @@ flags.DEFINE_integer('port', 5432, 'What port to serve on.') flags.DEFINE_string('server_type', 'default', 'Webserver to use; see dev_server.py') +flags.DEFINE_string('host', '127.0.0.1', 'What host address to serve on.') ## diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index acf1403c..00000000 --- a/requirements.txt +++ /dev/null @@ -1,13 +0,0 @@ -absl-py -numpy -scipy -pandas -scikit-learn -tensorflow -tensorflow-datasets -lime -sacrebleu -umap-learn -transformers==2.11.0 -google-cloud-translate -