Skip to content

Commit

Permalink
Update OutputDict -> TypedDict for more docstrings.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 560595942
  • Loading branch information
chongkong authored and tfx-copybara committed Aug 28, 2023
1 parent 15b9e88 commit afc6bfb
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 17 deletions.
20 changes: 12 additions & 8 deletions docs/guide/custom_function_component.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,15 @@ Writing your custom component in this style is very straightforward, as in the
following example.

```python
class MyOutput(TypedDict):
accuracy: float

@component
def MyValidationComponent(
model: InputArtifact[Model],
blessing: OutputArtifact[Model],
accuracy_threshold: Parameter[int] = 10,
) -> OutputDict(accuracy=float):
) -> MyOutput:
'''My simple custom model validation component.'''

accuracy = evaluate_model(model)
Expand Down Expand Up @@ -120,12 +123,8 @@ return value using annotations from the
described in the previous section). This argument can be optional or this
argument can be defined with a default value. If your component has simple
data type outputs (`int`, `float`, `str` or `bytes`), you can return these
outputs using an `OutputDict` instance. Apply the `OutputDict` type hint as
your component’s return value.

* For each **output**, add argument `<output_name>=<T>` to the `OutputDict`
constructor, where `<output_name>` is the output name and `<T>` is the
output type, such as: `int`, `float`, `str` or `bytes`.
outputs by using a `TypedDict` as a return type annotation, and returning an
appropriate dict object.

In the body of your function, input and output artifacts are passed as
`tfx.types.Artifact` objects; you can inspect its `.uri` to get its
Expand All @@ -137,16 +136,21 @@ output names and the values are the desired return values.
The completed function component can look like this:

```python
from typing import TypedDict
import tfx.v1 as tfx
from tfx.dsl.component.experimental.decorators import component

class MyOutput(TypedDict):
loss: float
accuracy: float

@component
def MyTrainerComponent(
training_data: tfx.dsl.components.InputArtifact[tfx.types.standard_artifacts.Examples],
model: tfx.dsl.components.OutputArtifact[tfx.types.standard_artifacts.Model],
dropout_hyperparameter: float,
num_iterations: tfx.dsl.components.Parameter[int] = 10
) -> tfx.v1.dsl.components.OutputDict(loss=float, accuracy=float):
) -> MyOutput:
'''My simple trainer component.'''

records = read_examples(training_data.uri)
Expand Down
20 changes: 11 additions & 9 deletions tfx/dsl/component/experimental/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,12 +308,11 @@ def component(
optional argument.
The return value typehint should be either empty or `None`, in the case of a
component function that has no return values, or an instance of
`OutputDict(key_1=type_1, ...)`, where each key maps to a given type (each
type is a primitive value type, i.e. `int`, `float`, `str`, `bytes`, `bool`
`Dict` or `List`; or `Optional[T]`, where T is a primitive type value, in
which case `None` can be returned), to indicate that the return value is a
dictionary with specified keys and value types.
component function that has no return values, or a `TypedDict` of primitive
value types (`int`, `float`, `str`, `bytes`, `bool`, `dict` or `list`; or
`Optional[T]`, where T is a primitive type value, in which case `None` can be
returned), to indicate that the return value is a dictionary with specified
keys and value types.
Note that output artifacts should not be included in the return value
typehint; they should be included as `OutputArtifact` annotations in the
Expand All @@ -330,17 +329,20 @@ def component(
InputArtifact = tfx.dsl.components.InputArtifact
OutputArtifact = tfx.dsl.components.OutputArtifact
Parameter = tfx.dsl.components.Parameter
OutputDict = tfx.dsl.components.OutputDict
Examples = tfx.types.standard_artifacts.Examples
Model = tfx.types.standard_artifacts.Model
class MyOutput(TypedDict):
loss: float
accuracy: float
@component(component_annotation=tfx.dsl.standard_annotations.Train)
def MyTrainerComponent(
training_data: InputArtifact[Examples],
model: OutputArtifact[Model],
dropout_hyperparameter: float,
num_iterations: Parameter[int] = 10
) -> OutputDict(loss=float, accuracy=float):
) -> MyOutput:
'''My simple trainer component.'''
records = read_examples(training_data.uri)
Expand Down Expand Up @@ -370,7 +372,7 @@ def MyTrainerComponent(
model: OutputArtifact[standard_artifacts.Model],
dropout_hyperparameter: float,
num_iterations: Parameter[int] = 10
) -> OutputDict(loss=float, accuracy=float):
) -> Output:
'''My simple trainer component.'''
records = read_examples(training_data.uri)
Expand Down

0 comments on commit afc6bfb

Please sign in to comment.