Skip to content

Commit

Permalink
Automated rollback of commit a590dbb
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 560598539
  • Loading branch information
tfx-copybara committed Aug 28, 2023
1 parent afc6bfb commit 6f6a06a
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 81 deletions.
1 change: 0 additions & 1 deletion tfx/dsl/input_resolution/ops/ops_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
MODEL_EXPORT_KEY = 'model_export'
MODEL_INFRA_BLESSING_KEY = 'model_infra_blessing'
MODEL_PUSH_KEY = 'model_push'
TRANSFORMED_EXAMPLES_KEY = 'transformed_examples'

# Taken from tfx.tflex.dsl.types.standard_artifacts. We don't use the existing
# constants due to Copybara.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,6 @@ class SpanDrivenEvaluatorInputs(
# >= start_span_number will be considered.
start_span_number = resolver_op.Property(type=int, default=0)

# Whether to return the materialized Examples produced by the Transform
# component. Should only be used if a Model was trained on materialized
# transformed Examples produced by a Transform. Defaults to False.
use_transformed_examples = resolver_op.Property(type=bool, default=False)

def _get_model_to_evaluate(
self,
trained_examples_by_model: Dict[types.Artifact, List[types.Artifact]],
Expand All @@ -93,7 +88,7 @@ def _get_model_to_evaluate(
# The first Model was trained on spans less than max_span.
return model

# No eligible Model was found, so a SkipSignal is raised.
# No elegible Model was found, so a SkipSignal is raised.
raise exceptions.SkipSignal()

def apply(self, input_dict: typing_utils.ArtifactMultiDict):
Expand Down Expand Up @@ -151,7 +146,7 @@ def apply(self, input_dict: typing_utils.ArtifactMultiDict):
trained_examples_by_model = {}
for model in input_dict[ops_utils.MODEL_KEY]:
trained_examples_by_model[model] = training_range_op.training_range(
self.context.store, model, self.use_transformed_examples
self.context.store, model
)

# Sort the Models by latest created, with ties broken by id.
Expand Down
105 changes: 40 additions & 65 deletions tfx/dsl/input_resolution/ops/training_range_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from tfx.orchestration.portable.mlmd import event_lib
from tfx.types import artifact_utils

from ml_metadata.google.tools.mlmd_resolver import metadata_resolver
from ml_metadata.proto import metadata_store_pb2


Expand All @@ -41,9 +40,7 @@ def _validate_input_list(
return input_list[0]


def training_range(
store: Any, model: types.Artifact, use_transformed_examples: bool = False
) -> List[types.Artifact]:
def training_range(store: Any, model: types.Artifact) -> List[types.Artifact]:
"""ContainsTrainingRange implementation, for shared use across ResolverOps.
Returns the Examples artifact the Model was trained on.
Expand All @@ -53,10 +50,6 @@ def training_range(
Args:
store: The MetadataStore.
model: The Model artifact whose trained Examples to return.
use_transformed_examples: Whether to return the materialized Examples
produced by the Transform component. Should only be used if a Model was
trained on materialized transformed Examples produced by a Transform.
Defaults to False.
Returns:
List of Examples artifacts if found, else empty list. We intentionally don't
Expand All @@ -67,68 +60,57 @@ def training_range(
# Event 1 Event 2
# Examples ------> Execution ------> Model
#
#
# or, in the case where a Transform component materializes Examples:
#
# Event 1 Event 2 Event 3
# Examples ------> Execution ------> Examples -----> Execution ------> Model
#
#
# For a single Model, there may be many parent Examples it was trained on.

# TODO(kshivvy): Support querying multiple Model ids at once, to reduce the
# number of round trip MLMD queries. This will be useful for resolving inputs
# of a span driven evaluator.

# Get all upstream Examples artifacts associated with the Model.
mlmd_resolver = metadata_resolver.MetadataResolver(store)
upstream_examples_dict = mlmd_resolver.get_upstream_artifacts_by_artifact_ids(
artifact_ids=[model.id],
# In MLMD, artifacts are 2 hops away. Because we are considering
# Example -> (transformd) Examples -> Model, we set max_num_hops to 4.
max_num_hops=4,
filter_query=f'type="{ops_utils.EXAMPLES_TYPE_NAME}"',
)
if not upstream_examples_dict:
return []
upstream_examples = upstream_examples_dict[model.id]
if not upstream_examples:
return []

# Get the sets of artifact IDs for Examples produced by Transform and by
# ExampleGen.
all_examples_ids = {a.id for a in upstream_examples}
transformed_examples_ids = set()
for event in store.get_events_by_artifact_ids(all_examples_ids):
if event_lib.is_valid_output_event(
event, expected_output_key=ops_utils.TRANSFORMED_EXAMPLES_KEY
):
transformed_examples_ids.add(event.artifact_id)
# We intentionally do set subtraction instead of filtering by the output_key
# "examples", in case the Examples artifact is produced by a custom
# component.
examples_ids = all_examples_ids - transformed_examples_ids

mlmd_artifacts = []
for artifact in upstream_examples:
# Get all Executions associated with creating the Model.
execution_ids = set()
for event in store.get_events_by_artifact_ids([model.id]):
if event_lib.is_valid_output_event(event):
execution_ids.add(event.execution_id)

# Get all artifact ids associated with an INPUT Event in each Execution.
# These ids correspond to parent artifacts of the Model.
parent_artifact_ids = set()
for event in store.get_events_by_execution_ids(execution_ids):
if event_lib.is_valid_input_event(event):
parent_artifact_ids.add(event.artifact_id)

# Get the type ids of the parent artifacts and only keep ones marked as LIVE.
type_ids = set()
live_artifacts = []
for artifact in store.get_artifacts_by_id(parent_artifact_ids):
# Only consider Examples artifacts that are marked LIVE. This excludes
# garbage collected artifacts (which are marked as DELETED).
if artifact.state != metadata_store_pb2.Artifact.State.LIVE:
continue
elif use_transformed_examples and artifact.id in transformed_examples_ids:
mlmd_artifacts.append(artifact)
elif not use_transformed_examples and artifact.id in examples_ids:
mlmd_artifacts.append(artifact)
if not mlmd_artifacts:
type_ids.add(artifact.type_id)
live_artifacts.append(artifact)

# Find the ArtifactType associated with Examples.
for artifact_type in store.get_artifact_types_by_id(type_ids):
if artifact_type.name == ops_utils.EXAMPLES_TYPE_NAME:
examples_type = artifact_type
break
else:
return []

# Find the ArtifactType associated with the artifacts.
artifact_type = store.get_artifact_types_by_id([mlmd_artifacts[0].type_id])[0]
mlmd_examples = []
for artifact in live_artifacts:
if (
artifact.type_id == examples_type.id
and artifact.state == metadata_store_pb2.Artifact.State.LIVE
):
mlmd_examples.append(artifact)

# Return the sorted, serialized Examples.
artifacts = artifact_utils.deserialize_artifacts(
artifact_type, mlmd_artifacts
)
if not mlmd_examples:
return []

# Return the sorted Examples.
artifacts = artifact_utils.deserialize_artifacts(examples_type, mlmd_examples)
return sorted(
artifacts, key=lambda a: (a.mlmd_artifact.create_time_since_epoch, a.id)
)
Expand All @@ -142,11 +124,6 @@ class TrainingRange(
):
"""TrainingRange operator."""

# Whether to return the materialized Examples produced by the Transform
# component. Should only be used if a Model was trained on materialized
# transformed Examples produced by a Transform. Defaults to False.
use_transformed_examples = resolver_op.Property(type=bool, default=False)

def apply(
self, input_list: Sequence[types.Artifact]
) -> Sequence[types.Artifact]:
Expand All @@ -156,9 +133,7 @@ def apply(

model = _validate_input_list(input_list)

examples = training_range(
self.context.store, model, self.use_transformed_examples
)
examples = training_range(self.context.store, model)

if not examples:
return []
Expand Down
12 changes: 4 additions & 8 deletions tfx/dsl/input_resolution/ops/training_range_op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def testTrainingRangeOp_MultipleModels(self):
actual_2 = self._training_range([model_2])
self.assertArtifactListEqual(actual_2, self.examples[5:])

def testTrainingRangeOp_TrainOnTransformedExamples(
def testTrainingRangeOp_TrainOnTransformedExamples_ReturnsTransformedExamples(
self,
):
transformed_examples = self._build_examples(10)
Expand All @@ -103,11 +103,7 @@ def testTrainingRangeOp_TrainOnTransformedExamples(
self.train_on_examples(
self.model, transformed_examples, self.transform_graph
)

actual = self._training_range([self.model], use_transformed_examples=False)
self.assertArtifactListEqual(actual, self.examples)

actual = self._training_range([self.model], use_transformed_examples=True)
actual = self._training_range([self.model])
self.assertArtifactListEqual(actual, transformed_examples)

def testTrainingRangeOp_SameSpanMultipleVersions_AllVersionsReturned(self):
Expand Down Expand Up @@ -167,9 +163,9 @@ def testTrainingRangeOp_BulkInferrerProducesExamples(self):
# The BulkInferrer takes in the same Examples used to Trainer the Model,
# and outputs 5 new examples to be used downstream. This creates additional
# Examples artifacts in MLMD linked to the Model, but they should NOT be
# returned as the Examples that the Model was trained on.
# returend as the Examples that the Model was trained on.
self.put_execution(
'BulkInferrer',
'TFTrainer',
inputs={
'examples': self.unwrap_tfx_artifacts(self.examples),
'model': self.unwrap_tfx_artifacts([self.model]),
Expand Down

0 comments on commit 6f6a06a

Please sign in to comment.