diff --git a/tfx/dsl/input_resolution/ops/ops_utils.py b/tfx/dsl/input_resolution/ops/ops_utils.py
index dc76666c1a..d70caf12ae 100644
--- a/tfx/dsl/input_resolution/ops/ops_utils.py
+++ b/tfx/dsl/input_resolution/ops/ops_utils.py
@@ -34,7 +34,6 @@
 MODEL_EXPORT_KEY = 'model_export'
 MODEL_INFRA_BLESSING_KEY = 'model_infra_blessing'
 MODEL_PUSH_KEY = 'model_push'
-TRANSFORMED_EXAMPLES_KEY = 'transformed_examples'

 # Taken from tfx.tflex.dsl.types.standard_artifacts. We don't use the existing
 # constants due to Copybara.
diff --git a/tfx/dsl/input_resolution/ops/span_driven_evaluator_inputs_op.py b/tfx/dsl/input_resolution/ops/span_driven_evaluator_inputs_op.py
index e80c3a15d5..20fd43dcf7 100644
--- a/tfx/dsl/input_resolution/ops/span_driven_evaluator_inputs_op.py
+++ b/tfx/dsl/input_resolution/ops/span_driven_evaluator_inputs_op.py
@@ -70,11 +70,6 @@ class SpanDrivenEvaluatorInputs(
   # >= start_span_number will be considered.
   start_span_number = resolver_op.Property(type=int, default=0)

-  # Whether to return the materialized Examples produced by the Transform
-  # component. Should only be used if a Model was trained on materialized
-  # transformed Examples produced by a Transform. Defaults to False.
-  use_transformed_examples = resolver_op.Property(type=bool, default=False)
-
   def _get_model_to_evaluate(
       self,
       trained_examples_by_model: Dict[types.Artifact, List[types.Artifact]],
@@ -93,7 +88,7 @@ def _get_model_to_evaluate(
         # The first Model was trained on spans less than max_span.
         return model

-    # No elegible Model was found, so a SkipSignal is raised.
+    # No eligible Model was found, so a SkipSignal is raised.
     raise exceptions.SkipSignal()

   def apply(self, input_dict: typing_utils.ArtifactMultiDict):
@@ -151,7 +146,7 @@ def apply(self, input_dict: typing_utils.ArtifactMultiDict):
     trained_examples_by_model = {}
     for model in input_dict[ops_utils.MODEL_KEY]:
       trained_examples_by_model[model] = training_range_op.training_range(
-          self.context.store, model, self.use_transformed_examples
+          self.context.store, model
       )

     # Sort the Models by latest created, with ties broken by id.
diff --git a/tfx/dsl/input_resolution/ops/training_range_op.py b/tfx/dsl/input_resolution/ops/training_range_op.py
index 10dbb0f732..ad526b2ffc 100644
--- a/tfx/dsl/input_resolution/ops/training_range_op.py
+++ b/tfx/dsl/input_resolution/ops/training_range_op.py
@@ -22,7 +22,6 @@
 from tfx.orchestration.portable.mlmd import event_lib
 from tfx.types import artifact_utils

-from ml_metadata.google.tools.mlmd_resolver import metadata_resolver
 from ml_metadata.proto import metadata_store_pb2


@@ -41,9 +40,7 @@ def _validate_input_list(
   return input_list[0]


-def training_range(
-    store: Any, model: types.Artifact, use_transformed_examples: bool = False
-) -> List[types.Artifact]:
+def training_range(store: Any, model: types.Artifact) -> List[types.Artifact]:
   """ContainsTrainingRange implementation, for shared use across ResolverOps.

   Returns the Examples artifact the Model was trained on.
@@ -53,10 +50,6 @@ def training_range(
   Args:
     store: The MetadataStore.
     model: The Model artifact whose trained Examples to return.
-    use_transformed_examples: Whether to return the materialized Examples
-      produced by the Transform component. Should only be used if a Model was
-      trained on materialized transformed Examples produced by a Transform.
-      Defaults to False.

   Returns:
     List of Examples artifacts if found, else empty list. We intentionally don't
@@ -67,68 +60,57 @@ def training_range(
   #           Event 1              Event 2
   # Examples ------> Execution ------> Model
   #
-  #
-  # or, in the case where a Transform component materializes Examples:
-  #
-  #           Event 1              Event 2             Event 3
-  # Examples ------> Execution ------> Examples -----> Execution ------> Model
-  #
-  #
   # For a single Model, there may be many parent Examples it was trained on.

   # TODO(kshivvy): Support querying multiple Model ids at once, to reduce the
   # number of round trip MLMD queries. This will be useful for resolving inputs
   # of a span driven evaluator.

-  # Get all upstream Examples artifacts associated with the Model.
-  mlmd_resolver = metadata_resolver.MetadataResolver(store)
-  upstream_examples_dict = mlmd_resolver.get_upstream_artifacts_by_artifact_ids(
-      artifact_ids=[model.id],
-      # In MLMD, artifacts are 2 hops away. Because we are considering
-      # Example -> (transformd) Examples -> Model, we set max_num_hops to 4.
-      max_num_hops=4,
-      filter_query=f'type="{ops_utils.EXAMPLES_TYPE_NAME}"',
-  )
-  if not upstream_examples_dict:
-    return []
-  upstream_examples = upstream_examples_dict[model.id]
-  if not upstream_examples:
-    return []
-
-  # Get the sets of artifact IDs for Examples produced by Transform and by
-  # ExampleGen.
-  all_examples_ids = {a.id for a in upstream_examples}
-  transformed_examples_ids = set()
-  for event in store.get_events_by_artifact_ids(all_examples_ids):
-    if event_lib.is_valid_output_event(
-        event, expected_output_key=ops_utils.TRANSFORMED_EXAMPLES_KEY
-    ):
-      transformed_examples_ids.add(event.artifact_id)
-  # We intentionally do set subtraction instead of filtering by the output_key
-  # "examples", in case the Examples artifact is produced by a custom
-  # component.
-  examples_ids = all_examples_ids - transformed_examples_ids
-
-  mlmd_artifacts = []
-  for artifact in upstream_examples:
+  # Get all Executions associated with creating the Model.
+  execution_ids = set()
+  for event in store.get_events_by_artifact_ids([model.id]):
+    if event_lib.is_valid_output_event(event):
+      execution_ids.add(event.execution_id)
+
+  # Get all artifact ids associated with an INPUT Event in each Execution.
+  # These ids correspond to parent artifacts of the Model.
+  parent_artifact_ids = set()
+  for event in store.get_events_by_execution_ids(execution_ids):
+    if event_lib.is_valid_input_event(event):
+      parent_artifact_ids.add(event.artifact_id)
+
+  # Get the type ids of the parent artifacts and only keep ones marked as LIVE.
+  type_ids = set()
+  live_artifacts = []
+  for artifact in store.get_artifacts_by_id(parent_artifact_ids):
     # Only consider Examples artifacts that are marked LIVE. This excludes
     # garbage collected artifacts (which are marked as DELETED).
     if artifact.state != metadata_store_pb2.Artifact.State.LIVE:
       continue
-    elif use_transformed_examples and artifact.id in transformed_examples_ids:
-      mlmd_artifacts.append(artifact)
-    elif not use_transformed_examples and artifact.id in examples_ids:
-      mlmd_artifacts.append(artifact)
-  if not mlmd_artifacts:
+    type_ids.add(artifact.type_id)
+    live_artifacts.append(artifact)
+
+  # Find the ArtifactType associated with Examples.
+  for artifact_type in store.get_artifact_types_by_id(type_ids):
+    if artifact_type.name == ops_utils.EXAMPLES_TYPE_NAME:
+      examples_type = artifact_type
+      break
+  else:
     return []

-  # Find the ArtifactType associated with the artifacts.
-  artifact_type = store.get_artifact_types_by_id([mlmd_artifacts[0].type_id])[0]
+  mlmd_examples = []
+  for artifact in live_artifacts:
+    if (
+        artifact.type_id == examples_type.id
+        and artifact.state == metadata_store_pb2.Artifact.State.LIVE
+    ):
+      mlmd_examples.append(artifact)

-  # Return the sorted, serialized Examples.
-  artifacts = artifact_utils.deserialize_artifacts(
-      artifact_type, mlmd_artifacts
-  )
+  if not mlmd_examples:
+    return []
+
+  # Return the sorted Examples.
+  artifacts = artifact_utils.deserialize_artifacts(examples_type, mlmd_examples)
   return sorted(
       artifacts, key=lambda a: (a.mlmd_artifact.create_time_since_epoch, a.id)
   )
@@ -142,11 +124,6 @@ class TrainingRange(
 ):
   """TrainingRange operator."""

-  # Whether to return the materialized Examples produced by the Transform
-  # component. Should only be used if a Model was trained on materialized
-  # transformed Examples produced by a Transform. Defaults to False.
-  use_transformed_examples = resolver_op.Property(type=bool, default=False)
-
   def apply(
       self, input_list: Sequence[types.Artifact]
   ) -> Sequence[types.Artifact]:
@@ -156,9 +133,7 @@ def apply(

     model = _validate_input_list(input_list)

-    examples = training_range(
-        self.context.store, model, self.use_transformed_examples
-    )
+    examples = training_range(self.context.store, model)

     if not examples:
       return []
diff --git a/tfx/dsl/input_resolution/ops/training_range_op_test.py b/tfx/dsl/input_resolution/ops/training_range_op_test.py
index 3fd4e4433a..51fd27596d 100644
--- a/tfx/dsl/input_resolution/ops/training_range_op_test.py
+++ b/tfx/dsl/input_resolution/ops/training_range_op_test.py
@@ -81,7 +81,7 @@ def testTrainingRangeOp_MultipleModels(self):
     actual_2 = self._training_range([model_2])
     self.assertArtifactListEqual(actual_2, self.examples[5:])

-  def testTrainingRangeOp_TrainOnTransformedExamples(
+  def testTrainingRangeOp_TrainOnTransformedExamples_ReturnsTransformedExamples(
       self,
   ):
     transformed_examples = self._build_examples(10)
@@ -103,11 +103,7 @@ def testTrainingRangeOp_TrainOnTransformedExamples(
     self.train_on_examples(
         self.model, transformed_examples, self.transform_graph
     )
-
-    actual = self._training_range([self.model], use_transformed_examples=False)
-    self.assertArtifactListEqual(actual, self.examples)
-
-    actual = self._training_range([self.model], use_transformed_examples=True)
+    actual = self._training_range([self.model])
     self.assertArtifactListEqual(actual, transformed_examples)

   def testTrainingRangeOp_SameSpanMultipleVersions_AllVersionsReturned(self):
@@ -167,9 +163,9 @@ def testTrainingRangeOp_BulkInferrerProducesExamples(self):
     # The BulkInferrer takes in the same Examples used to train the Model,
     # and outputs 5 new examples to be used downstream. This creates additional
     # Examples artifacts in MLMD linked to the Model, but they should NOT be
-    # returend as the Examples that the Model was trained on.
+    # returned as the Examples that the Model was trained on.
     self.put_execution(
-        'TFTrainer',
+        'BulkInferrer',
         inputs={
             'examples': self.unwrap_tfx_artifacts(self.examples),
             'model': self.unwrap_tfx_artifacts([self.model]),
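Reviewer note (not part of the patch): the restored training_range() resolves the training Examples purely from MLMD Events. It finds the Executions that output the Model, collects the artifacts those Executions took as inputs, and keeps the LIVE ones whose ArtifactType is Examples. The sketch below illustrates that traversal end to end, assuming only an ml_metadata MetadataStore handle and a Model artifact id; the helper name examples_for_model is hypothetical, and the raw Event.OUTPUT / Event.INPUT checks stand in for the stricter event_lib.is_valid_output_event / is_valid_input_event calls used in the real op.

from ml_metadata.proto import metadata_store_pb2


def examples_for_model(store, model_id):
  """Returns the LIVE Examples artifacts a Model was trained on (sketch)."""
  # 1. Executions that produced the Model (the Model appears in OUTPUT events).
  execution_ids = {
      event.execution_id
      for event in store.get_events_by_artifact_ids([model_id])
      if event.type == metadata_store_pb2.Event.OUTPUT
  }

  # 2. Artifacts consumed by those Executions (INPUT events) are the Model's
  #    parents; the training Examples are among them.
  parent_ids = {
      event.artifact_id
      for event in store.get_events_by_execution_ids(execution_ids)
      if event.type == metadata_store_pb2.Event.INPUT
  }

  # 3. Keep only LIVE parents whose ArtifactType is named 'Examples'.
  examples_type_ids = {
      t.id for t in store.get_artifact_types() if t.name == 'Examples'
  }
  return [
      artifact
      for artifact in store.get_artifacts_by_id(parent_ids)
      if artifact.type_id in examples_type_ids
      and artifact.state == metadata_store_pb2.Artifact.State.LIVE
  ]

The real op additionally deserializes these protos into TFX Artifact objects and sorts them by (create_time_since_epoch, id), as shown in the hunk above.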