From 7f04c4f07f2698f823953902bfb79fc7cb6e1584 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Tue, 26 Mar 2024 16:54:00 -0700 Subject: [PATCH] Deduplicate common environments. (#30681) We deduplicate both on proto construction (as before, but fixed) and again after more environments have been resolved. --- sdks/python/apache_beam/pipeline.py | 31 +-------- sdks/python/apache_beam/runners/common.py | 67 +++++++++++++++++++ .../python/apache_beam/runners/common_test.py | 59 ++++++++++++++++ .../runners/dataflow/dataflow_runner.py | 2 + .../portability/fn_api_runner/fn_runner.py | 4 +- 5 files changed, 133 insertions(+), 30 deletions(-) diff --git a/sdks/python/apache_beam/pipeline.py b/sdks/python/apache_beam/pipeline.py index 53044982a066..11bc74d27eca 100644 --- a/sdks/python/apache_beam/pipeline.py +++ b/sdks/python/apache_beam/pipeline.py @@ -86,6 +86,7 @@ from apache_beam.portability import common_urns from apache_beam.portability.api import beam_runner_api_pb2 from apache_beam.runners import PipelineRunner +from apache_beam.runners import common from apache_beam.runners import create_runner from apache_beam.transforms import ParDo from apache_beam.transforms import ptransform @@ -967,35 +968,7 @@ def merge_compatible_environments(proto): Mutates proto as contexts may have references to proto.components. """ - env_map = {} - canonical_env = {} - files_by_hash = {} - for env_id, env in proto.components.environments.items(): - # First deduplicate any file dependencies by their hash. - for dep in env.dependencies: - if dep.type_urn == common_urns.artifact_types.FILE.urn: - file_payload = beam_runner_api_pb2.ArtifactFilePayload.FromString( - dep.type_payload) - if file_payload.sha256: - if file_payload.sha256 in files_by_hash: - file_payload.path = files_by_hash[file_payload.sha256] - dep.type_payload = file_payload.SerializeToString() - else: - files_by_hash[file_payload.sha256] = file_payload.path - # Next check if we've ever seen this environment before. 
- normalized = env.SerializeToString(deterministic=True) - if normalized in canonical_env: - env_map[env_id] = canonical_env[normalized] - else: - canonical_env[normalized] = env_id - for old_env, new_env in env_map.items(): - for transform in proto.components.transforms.values(): - if transform.environment_id == old_env: - transform.environment_id = new_env - for windowing_strategy in proto.components.windowing_strategies.values(): - if windowing_strategy.environment_id == old_env: - windowing_strategy.environment_id = new_env - del proto.components.environments[old_env] + common.merge_common_environments(proto, inplace=True) @staticmethod def from_runner_api( diff --git a/sdks/python/apache_beam/runners/common.py b/sdks/python/apache_beam/runners/common.py index 1cd0a3044663..630ed7910c8d 100644 --- a/sdks/python/apache_beam/runners/common.py +++ b/sdks/python/apache_beam/runners/common.py @@ -24,6 +24,8 @@ # pytype: skip-file +import collections +import copy import logging import sys import threading @@ -43,6 +45,7 @@ from apache_beam.internal import util from apache_beam.options.value_provider import RuntimeValueProvider from apache_beam.portability import common_urns +from apache_beam.portability.api import beam_runner_api_pb2 from apache_beam.pvalue import TaggedOutput from apache_beam.runners.sdf_utils import NoOpWatermarkEstimatorProvider from apache_beam.runners.sdf_utils import RestrictionTrackerView @@ -52,6 +55,7 @@ from apache_beam.runners.sdf_utils import ThreadsafeWatermarkEstimator from apache_beam.transforms import DoFn from apache_beam.transforms import core +from apache_beam.transforms import environments from apache_beam.transforms import userstate from apache_beam.transforms.core import RestrictionProvider from apache_beam.transforms.core import WatermarkEstimatorProvider @@ -1941,3 +1945,66 @@ def validate_transform(transform_id): for t in pipeline_proto.root_transform_ids: validate_transform(t) + + +def 
merge_common_environments(pipeline_proto, inplace=False): + def dep_key(dep): + if dep.type_urn == common_urns.artifact_types.FILE.urn: + payload = beam_runner_api_pb2.ArtifactFilePayload.FromString( + dep.type_payload) + if payload.sha256: + type_info = 'sha256', payload.sha256 + else: + type_info = 'path', payload.path + elif dep.type_urn == common_urns.artifact_types.URL.urn: + payload = beam_runner_api_pb2.ArtifactUrlPayload.FromString( + dep.type_payload) + if payload.sha256: + type_info = 'sha256', payload.sha256 + else: + type_info = 'url', payload.url + else: + type_info = dep.type_urn, dep.type_payload + return type_info, dep.role_urn, dep.role_payload + + def base_env_key(env): + return ( + env.urn, + env.payload, + tuple(sorted(env.capabilities)), + tuple(sorted(env.resource_hints.items())), + tuple(sorted(dep_key(dep) for dep in env.dependencies))) + + def env_key(env): + return tuple( + sorted( + base_env_key(e) + for e in environments.expand_anyof_environments(env))) + + cannonical_enviornments = collections.defaultdict(list) + for env_id, env in pipeline_proto.components.environments.items(): + cannonical_enviornments[env_key(env)].append(env_id) + + if len(cannonical_enviornments) == len( + pipeline_proto.components.environments): + # All environments are already sufficiently distinct. 
+ return pipeline_proto + + environment_remappings = { + e: es[0] + for es in cannonical_enviornments.values() for e in es + } + + if not inplace: + pipeline_proto = copy.copy(pipeline_proto) + + for t in pipeline_proto.components.transforms.values(): + if t.environment_id: + t.environment_id = environment_remappings[t.environment_id] + for w in pipeline_proto.components.windowing_strategies.values(): + if w.environment_id: + w.environment_id = environment_remappings[w.environment_id] + for e in set(pipeline_proto.components.environments.keys()) - set( + environment_remappings.values()): + del pipeline_proto.components.environments[e] + return pipeline_proto diff --git a/sdks/python/apache_beam/runners/common_test.py b/sdks/python/apache_beam/runners/common_test.py index 00645948c3ed..ca2cd2539a8c 100644 --- a/sdks/python/apache_beam/runners/common_test.py +++ b/sdks/python/apache_beam/runners/common_test.py @@ -26,8 +26,11 @@ from apache_beam.io.restriction_trackers import OffsetRestrictionTracker from apache_beam.io.watermark_estimators import ManualWatermarkEstimator from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.portability.api import beam_runner_api_pb2 from apache_beam.runners.common import DoFnSignature from apache_beam.runners.common import PerWindowInvoker +from apache_beam.runners.common import merge_common_environments +from apache_beam.runners.portability.expansion_service_test import FibTransform from apache_beam.runners.sdf_utils import SplitResultPrimary from apache_beam.runners.sdf_utils import SplitResultResidual from apache_beam.testing.test_pipeline import TestPipeline @@ -584,5 +587,61 @@ def test_window_observing_split_on_window_boundary_round_down_on_last_window( self.assertEqual(stop_index, 2) +class UtilitiesTest(unittest.TestCase): + def test_equal_environments_merged(self): + pipeline_proto = merge_common_environments( + beam_runner_api_pb2.Pipeline( + components=beam_runner_api_pb2.Components( + 
environments={
+                    'a1': beam_runner_api_pb2.Environment(urn='A'),
+                    'a2': beam_runner_api_pb2.Environment(urn='A'),
+                    'b1': beam_runner_api_pb2.Environment(
+                        urn='B', payload=b'x'),
+                    'b2': beam_runner_api_pb2.Environment(
+                        urn='B', payload=b'x'),
+                    'b3': beam_runner_api_pb2.Environment(
+                        urn='B', payload=b'y'),
+                },
+                transforms={
+                    't1': beam_runner_api_pb2.PTransform(
+                        unique_name='t1', environment_id='a1'),
+                    't2': beam_runner_api_pb2.PTransform(
+                        unique_name='t2', environment_id='a2'),
+                },
+                windowing_strategies={
+                    'w1': beam_runner_api_pb2.WindowingStrategy(
+                        environment_id='b1'),
+                    'w2': beam_runner_api_pb2.WindowingStrategy(
+                        environment_id='b2'),
+                })))
+    self.assertEqual(len(pipeline_proto.components.environments), 3)
+    self.assertTrue(('a1' in pipeline_proto.components.environments)
+                    ^ ('a2' in pipeline_proto.components.environments))
+    self.assertTrue(('b1' in pipeline_proto.components.environments)
+                    ^ ('b2' in pipeline_proto.components.environments))
+    self.assertEqual(
+        len(
+            set(
+                t.environment_id
+                for t in pipeline_proto.components.transforms.values())),
+        1)
+    self.assertEqual(
+        len(
+            set(
+                w.environment_id for w in
+                pipeline_proto.components.windowing_strategies.values())),
+        1)
+
+  def test_external_merged(self):
+    p = beam.Pipeline()
+    # This transform recursively creates several external environments.
+    _ = p | FibTransform(4)
+    pipeline_proto = p.to_runner_api()
+    # All our external environments are equal and consolidated.
+    # We also have a placeholder "default" environment that has not been
+    # resolved to anything concrete yet.
+ self.assertEqual(len(pipeline_proto.components.environments), 2) + + if __name__ == '__main__': unittest.main() diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py index db6a5235ac92..e428551ef028 100644 --- a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py +++ b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py @@ -43,6 +43,7 @@ from apache_beam.options.pipeline_options import WorkerOptions from apache_beam.portability import common_urns from apache_beam.runners.common import group_by_key_input_visitor +from apache_beam.runners.common import merge_common_environments from apache_beam.runners.dataflow.internal.clients import dataflow as dataflow_api from apache_beam.runners.runner import PipelineResult from apache_beam.runners.runner import PipelineRunner @@ -419,6 +420,7 @@ def run_pipeline(self, pipeline, options, pipeline_proto=None): self.proto_pipeline.components.environments[env_id].CopyFrom( environments.resolve_anyof_environment( env, common_urns.environments.DOCKER.urn)) + self.proto_pipeline = merge_common_environments(self.proto_pipeline) # Optimize the pipeline if it not streaming and the pre_optimize # experiment is set. 
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py index b3dd124216be..07569fe328d8 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py @@ -62,6 +62,7 @@ from apache_beam.portability.api import beam_runner_api_pb2 from apache_beam.runners import runner from apache_beam.runners.common import group_by_key_input_visitor +from apache_beam.runners.common import merge_common_environments from apache_beam.runners.common import validate_pipeline_graph from apache_beam.runners.portability import portable_metrics from apache_beam.runners.portability.fn_api_runner import execution @@ -221,7 +222,8 @@ def run_via_runner_api(self, pipeline_proto, options): ] if direct_options.direct_embed_docker_python: pipeline_proto = self.embed_default_docker_image(pipeline_proto) - pipeline_proto = self.resolve_any_environments(pipeline_proto) + pipeline_proto = merge_common_environments( + self.resolve_any_environments(pipeline_proto)) stage_context, stages = self.create_stages(pipeline_proto) return self.run_stages(stage_context, stages)