Adding TensorFlow support to the Machine Learning overview page (apache#22949)

Co-authored-by: tvalentyn <tvalentyn@users.noreply.github.com>
rszper and tvalentyn authored Sep 2, 2022
1 parent 8c57b21 commit 31561e2
Showing 1 changed file with 48 additions and 4 deletions.

### Shared helper class

Using the `Shared` class within the RunInference implementation makes it possible to load the model only once per process and share it with all DoFn instances created in that process. This feature reduces memory consumption and model loading time. For more information, see the
[`Shared` class documentation](https://github.com/apache/beam/blob/master/sdks/python/apache_beam/utils/shared.py#L20).
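
RunInference manages this for you, but the same pattern can be used in a custom `DoFn`. The sketch below is illustrative only: `_SideModel` and `load_model` are placeholders for an expensive-to-load model, and the object returned by the loader must support weak references.

```
import apache_beam as beam
from apache_beam.utils.shared import Shared


class _SideModel:
    """Stand-in for an expensive-to-load model."""

    def __init__(self, offset):
        self.offset = offset

    def predict(self, value):
        return value + self.offset


def load_model():
    # Imagine this reads a large model from disk.
    return _SideModel(offset=10)


class PredictDoFn(beam.DoFn):
    def __init__(self, shared_handle):
        self._shared_handle = shared_handle

    def setup(self):
        # acquire() calls load_model at most once per process; every DoFn
        # instance in that process reuses the same _SideModel object.
        self._model = self._shared_handle.acquire(load_model)

    def process(self, element):
        yield self._model.predict(element)


with beam.Pipeline() as pipeline:
    shared_handle = Shared()
    _ = (pipeline
         | beam.Create([1, 2, 3])
         | beam.ParDo(PredictDoFn(shared_handle))
         | beam.Map(print))
```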

### Multi-model pipelines

The RunInference API can be composed into multi-model pipelines. Multi-model pipelines can be useful for A/B testing or for building out ensembles made up of models that perform tokenization, sentence segmentation, part-of-speech tagging, named entity extraction, language detection, coreference resolution, and more.
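
For example, a single pipeline can apply two model handlers to the same input `PCollection` and compare their outputs downstream. The following is only a sketch: `input_examples`, `model_handler_a`, and `model_handler_b` are placeholders for your own data and configured model handlers.

```
import apache_beam as beam
from apache_beam.ml.inference.base import RunInference

# input_examples, model_handler_a, and model_handler_b are placeholders
# for your own data and any two configured model handlers.
with beam.Pipeline() as pipeline:
    inputs = pipeline | 'CreateInputs' >> beam.Create(input_examples)
    predictions_a = inputs | 'ModelA' >> RunInference(model_handler_a)
    predictions_b = inputs | 'ModelB' >> RunInference(model_handler_b)
    # Downstream transforms can compare or combine predictions_a and predictions_b.
```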

## Modify a pipeline to use an ML model


## Beam Java SDK support

The RunInference API is available with the Beam Java SDK versions 2.41.0 and later through Apache Beam's [Multi-language Pipelines framework](https://beam.apache.org/documentation/programming-guide/#multi-language-pipelines). For information about the Java wrapper transform, see [RunInference.java](https://github.com/apache/beam/blob/master/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/transforms/RunInference.java). For example pipelines, see [RunInferenceTransformTest.java](https://github.com/apache/beam/blob/master/sdks/java/extensions/python/src/test/java/org/apache/beam/sdk/extensions/python/transforms/RunInferenceTransformTest.java).

## TensorFlow support

To use TensorFlow with the RunInference API, you need to do the following:

* Use `tfx_bsl` version 1.10.0 or later.
* Create a model handler using `tfx_bsl.public.beam.run_inference.CreateModelHandler()`.
* Use the model handler with the [`apache_beam.ml.inference.base.RunInference`](/releases/pydoc/current/apache_beam.ml.inference.base.html) transform.

A sample pipeline might look like the following example:

```
import apache_beam as beam
from apache_beam.ml.inference.base import RunInference
from tensorflow_serving.apis import prediction_log_pb2
from tfx_bsl.public.proto import model_spec_pb2
from tfx_bsl.public.tfxio import TFExampleRecord
from tfx_bsl.public.beam.run_inference import CreateModelHandler

pipeline = beam.Pipeline()

# Point the source at serialized tf.Example records and the handler at a saved model.
tfexample_beam_record = TFExampleRecord(file_pattern='/path/to/examples')
saved_model_spec = model_spec_pb2.SavedModelSpec(model_path='/path/to/model')
inference_spec_type = model_spec_pb2.InferenceSpecType(saved_model_spec=saved_model_spec)
model_handler = CreateModelHandler(inference_spec_type)

with pipeline as p:
    _ = (p | tfexample_beam_record.RawRecordBeamSource()
           | RunInference(model_handler)
           | beam.Map(print))
```

The model handler created with `CreateModelHandler()` is always unkeyed. To make a keyed model handler, wrap the unkeyed model handler in a `KeyedModelHandler`, which takes the `tfx-bsl` model handler as a parameter. For example:

```
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.base import KeyedModelHandler

# tf_handler is the unkeyed tfx-bsl model handler created with CreateModelHandler().
RunInference(KeyedModelHandler(tf_handler))
```
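
With a keyed model handler, each input element is a `(key, example)` tuple, and each output pairs the same key with the model's prediction. The following sketch assumes `tf_handler` is the handler created above and uses placeholder serialized examples:

```
import apache_beam as beam
from apache_beam.ml.inference.base import KeyedModelHandler
from apache_beam.ml.inference.base import RunInference

# serialized_example_1 and serialized_example_2 are placeholders for
# serialized tf.Example bytes; tf_handler is the unkeyed handler from above.
with beam.Pipeline() as p:
    _ = (p
         | beam.Create([('row_1', serialized_example_1),
                        ('row_2', serialized_example_2)])
         | RunInference(KeyedModelHandler(tf_handler))
         | beam.Map(print))  # Each element is a (key, prediction) tuple.
```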

If you are unsure whether your data is keyed, you can also use `MaybeKeyedModelHandler`.

For more information, see [`KeyedModelHandler`](https://beam.apache.org/releases/pydoc/current/apache_beam.ml.inference.base.html#apache_beam.ml.inference.base.KeyedModelHandler).

## Troubleshooting

In some cases, the `PredictionResults` output might not include the correct predictions.

The RunInference API currently expects outputs to be an `Iterable[Any]`. Example return types are `Iterable[Tensor]` or `Iterable[Dict[str, Tensor]]`. When RunInference zips the inputs with the predictions, the predictions iterate over the dictionary keys instead of the batch elements. The result is that the key name is preserved but the prediction tensors are discarded. For more information, see the [Pytorch RunInference PredictionResult is a Dict](https://github.com/apache/beam/issues/22240) issue in the Apache Beam GitHub project.

To work with the current RunInference implementation, you can create a wrapper class that overrides the `model(input)` call. In PyTorch, for example, your wrapper would override the `forward()` function and return an output with the appropriate format of `List[Dict[str, torch.Tensor]]`. For more information, see the [HuggingFace language modeling example](https://github.com/apache/beam/blob/master/sdks/python/apache_beam/examples/inference/pytorch_language_modeling.py#L49).
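
A minimal sketch of such a wrapper is shown below; the base model is a placeholder for any PyTorch module whose forward pass returns a dictionary of batched tensors.

```
import torch

class DictOutputModelWrapper(torch.nn.Module):
    """Sketch of a wrapper that converts a dict of batched tensors into a
    list with one dict per batch element, so RunInference can zip each
    prediction with its input."""

    def __init__(self, base_model: torch.nn.Module):
        super().__init__()
        self._base_model = base_model

    def forward(self, **kwargs):
        output = self._base_model(**kwargs)
        return [dict(zip(output, values)) for values in zip(*output.values())]
```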

### Unable to batch tensor elements

Disable batching by overriding the `batch_elements_kwargs` function in your `ModelHandler`.
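
For example, a sketch of this override on the PyTorch tensor model handler (any model handler subclass can override the method the same way) might look like the following:

```
from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor

class NoBatchingModelHandler(PytorchModelHandlerTensor):
    def batch_elements_kwargs(self):
        # A max batch size of 1 means elements are passed to the model one at
        # a time, so tensors with different shapes are never batched together.
        return {'max_batch_size': 1}
```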

* [RunInference transforms](/documentation/transforms/python/elementwise/runinference)
* [RunInference API pipeline examples](https://github.com/apache/beam/tree/master/sdks/python/apache_beam/examples/inference)
* [RunInference public codelab](https://colab.sandbox.google.com/github/apache/beam/blob/master/examples/notebooks/beam-ml/run_inference_basic.ipynb)
* [RunInference notebooks](https://github.com/apache/beam/tree/master/examples/notebooks/beam-ml)

{{< button-pydoc path="apache_beam.ml.inference" class="RunInference" >}}
