diff --git a/sdk/python/kfp/compiler/compiler_test.py b/sdk/python/kfp/compiler/compiler_test.py index 98f521b0fbb..1c9d6b6be10 100644 --- a/sdk/python/kfp/compiler/compiler_test.py +++ b/sdk/python/kfp/compiler/compiler_test.py @@ -43,6 +43,7 @@ from kfp.dsl import PipelineTaskFinalStatus from kfp.dsl import tasks_group from kfp.dsl import yaml_component +from kfp.dsl.pipeline_config import PipelineConfig from kfp.dsl.types import type_utils from kfp.pipeline_spec import pipeline_spec_pb2 import yaml @@ -3937,6 +3938,71 @@ def outer(): foo_platform_set_bar_feature(task, 12) +class TestPipelineSemaphoreMutex(unittest.TestCase): + + def test_pipeline_with_semaphore(self): + """Test that pipeline config correctly sets the semaphore key.""" + config = PipelineConfig() + config.set_semaphore_key('semaphore') + + @dsl.pipeline(pipeline_config=config) + def my_pipeline(): + task = comp() + + with tempfile.TemporaryDirectory() as tempdir: + output_yaml = os.path.join(tempdir, 'pipeline.yaml') + compiler.Compiler().compile( + pipeline_func=my_pipeline, package_path=output_yaml) + + with open(output_yaml, 'r') as f: + pipeline_docs = list(yaml.safe_load_all(f)) + + pipeline_spec = None + for doc in pipeline_docs: + if 'platforms' in doc: + pipeline_spec = doc + break + + # # Debug statement to print the contents of the pipeline_spec + # print("Pipeline spec:", pipeline_spec) + + if pipeline_spec: + kubernetes_spec = pipeline_spec['platforms']['kubernetes'][ + 'pipelineConfig'] + assert kubernetes_spec['semaphoreKey'] == 'semaphore' + + def test_pipeline_with_mutex(self): + """Test that pipeline config correctly sets the mutex name.""" + config = PipelineConfig() + config.set_mutex_name('mutex') + + @dsl.pipeline(pipeline_config=config) + def my_pipeline(): + task = comp() + + with tempfile.TemporaryDirectory() as tempdir: + output_yaml = os.path.join(tempdir, 'pipeline.yaml') + compiler.Compiler().compile( + pipeline_func=my_pipeline, package_path=output_yaml) + + with open(output_yaml, 'r') as f: + pipeline_docs = list(yaml.safe_load_all(f)) + + pipeline_spec = None + for doc in pipeline_docs: + if 'platforms' in doc: + pipeline_spec = doc + break + + # # Debug statement to print the contents of the pipeline_spec + # print("Pipeline spec:", pipeline_spec) + + if pipeline_spec: + kubernetes_spec = pipeline_spec['platforms']['kubernetes'][ + 'pipelineConfig'] + assert kubernetes_spec['mutexName'] == 'mutex' + + class ExtractInputOutputDescription(unittest.TestCase): def test_no_descriptions(self): diff --git a/sdk/python/kfp/dsl/pipeline_config.py b/sdk/python/kfp/dsl/pipeline_config.py index 8a730548d8b..b1d2f86a15f 100644 --- a/sdk/python/kfp/dsl/pipeline_config.py +++ b/sdk/python/kfp/dsl/pipeline_config.py @@ -24,8 +24,16 @@ def __init__(self): def set_semaphore_key(self, semaphore_key: str): """Set the name of the semaphore to control pipeline concurrency. + The semaphore is configured via a ConfigMap. By default, the ConfigMap is + named "semaphore-config", but this name can be specified through the APIServer + deployment manifests using an environment variable named SEMAPHORE_CONFIGMAP_NAME. + If the environment variable is not specified, the default name "semaphore-config" + is used. The semaphore key is provided through the pipeline configuration. + If a pipeline has a semaphore, the backend maps the semaphore to the ConfigMap + using the key provided by the user. + Args: - semaphore_key (str): Name of the semaphore. + semaphore_key (str): The key used to map to the ConfigMap. """ self.semaphore_key = semaphore_key.strip() diff --git a/test/presubmit-tests-sdk.sh b/test/presubmit-tests-sdk.sh index ae9411f78d6..3484a9023cb 100755 --- a/test/presubmit-tests-sdk.sh +++ b/test/presubmit-tests-sdk.sh @@ -21,6 +21,7 @@ source venv/bin/activate python3 -m pip install --upgrade pip python3 -m pip install setuptools +python3 -m pip install grpcio grpcio-tools python3 -m pip install coveralls==1.9.2 python3 -m pip install $(grep 'absl-py==' sdk/python/requirements-dev.txt) python3 -m pip install $(grep 'docker==' sdk/python/requirements-dev.txt) @@ -28,8 +29,10 @@ python3 -m pip install $(grep 'pytest==' sdk/python/requirements-dev.txt) python3 -m pip install $(grep 'pytest-xdist==' sdk/python/requirements-dev.txt) python3 -m pip install $(grep 'pytest-cov==' sdk/python/requirements-dev.txt) python3 -m pip install --upgrade protobuf +pushd api && make clean-python python && popd -python3 -m pip install sdk/python +python3 -m pip install -e api/v2alpha1/python +python3 -m pip install -e sdk/python pytest sdk/python/kfp --cov=kfp