From 85c6413a2e13da4b8f198aeac1abc2f3a74fe789 Mon Sep 17 00:00:00 2001 From: Ning Date: Thu, 20 Dec 2018 09:51:09 -0800 Subject: [PATCH] Refactor Python SDK (#568) * add some comments * remove unused import; add license to dsl_bridge * move_convert_k8s_obj_to_dic from compiler to k8s_helper * move unit test --- sdk/python/kfp/compiler/_component_builder.py | 1 - sdk/python/kfp/compiler/_k8s_helper.py | 49 ++++++++++ sdk/python/kfp/compiler/compiler.py | 98 +++++++------------ sdk/python/kfp/components/_component_store.py | 1 - sdk/python/kfp/components/_components.py | 2 +- sdk/python/kfp/components/_dynamic.py | 14 +++ sdk/python/kfp/components/_structures.py | 2 +- sdk/python/kfp/dsl/_ops_group.py | 2 +- sdk/python/tests/compiler/compiler_tests.py | 17 ---- sdk/python/tests/compiler/k8s_helper_tests.py | 34 +++++++ sdk/python/tests/compiler/main.py | 2 + 11 files changed, 135 insertions(+), 87 deletions(-) create mode 100644 sdk/python/tests/compiler/k8s_helper_tests.py diff --git a/sdk/python/kfp/compiler/_component_builder.py b/sdk/python/kfp/compiler/_component_builder.py index c9d00d18a15..b911596b1f2 100644 --- a/sdk/python/kfp/compiler/_component_builder.py +++ b/sdk/python/kfp/compiler/_component_builder.py @@ -22,7 +22,6 @@ import logging from collections import OrderedDict from pathlib import PurePath, Path -from .. import dsl from ..components._components import _create_task_factory_from_component_spec class GCSHelper(object): diff --git a/sdk/python/kfp/compiler/_k8s_helper.py b/sdk/python/kfp/compiler/_k8s_helper.py index 429ec25c66a..000ebb55703 100644 --- a/sdk/python/kfp/compiler/_k8s_helper.py +++ b/sdk/python/kfp/compiler/_k8s_helper.py @@ -118,3 +118,52 @@ def run_job(self, yaml_spec, timeout=600): # print(self._read_pod_log(pod_name, yaml_spec)) self._delete_k8s_job(pod_name, yaml_spec) return succ + + @staticmethod + def convert_k8s_obj_to_json(k8s_obj): + """ + Builds a JSON K8s object. + + If obj is None, return None. 
+ If obj is str, int, long, float, bool, return directly. + If obj is datetime.datetime, datetime.date + convert to string in iso8601 format. + If obj is list, sanitize each element in the list. + If obj is dict, return the dict. + If obj is swagger model, return the properties dict. + + Args: + obj: The data to serialize. + Returns: The serialized form of data. + """ + + from six import text_type, integer_types, iteritems + PRIMITIVE_TYPES = (float, bool, bytes, text_type) + integer_types + from datetime import date, datetime + if k8s_obj is None: + return None + elif isinstance(k8s_obj, PRIMITIVE_TYPES): + return k8s_obj + elif isinstance(k8s_obj, list): + return [K8sHelper.convert_k8s_obj_to_json(sub_obj) + for sub_obj in k8s_obj] + elif isinstance(k8s_obj, tuple): + return tuple(K8sHelper.convert_k8s_obj_to_json(sub_obj) + for sub_obj in obj) + elif isinstance(k8s_obj, (datetime, date)): + return k8s_obj.isoformat() + + if isinstance(k8s_obj, dict): + obj_dict = k8s_obj + else: + # Convert model obj to dict except + # attributes `swagger_types`, `attribute_map` + # and attributes which value is not None. + # Convert attribute name to json key in + # model definition for request. + obj_dict = {k8s_obj.attribute_map[attr]: getattr(k8s_obj, attr) + for attr, _ in iteritems(k8s_obj.swagger_types) + if getattr(k8s_obj, attr) is not None} + + return {key: K8sHelper.convert_k8s_obj_to_json(val) + for key, val in iteritems(obj_dict)} \ No newline at end of file diff --git a/sdk/python/kfp/compiler/compiler.py b/sdk/python/kfp/compiler/compiler.py index 99dab6de04a..70d085dcdfa 100644 --- a/sdk/python/kfp/compiler/compiler.py +++ b/sdk/python/kfp/compiler/compiler.py @@ -14,16 +14,13 @@ from collections import defaultdict -import copy import inspect import re -import string import tarfile -import tempfile import yaml from .. import dsl - +from ._k8s_helper import K8sHelper class Compiler(object): """DSL Compiler. 
@@ -42,9 +39,17 @@ def my_pipeline(a: dsl.PipelineParam, b: dsl.PipelineParam): """ def _sanitize_name(self, name): - return re.sub('-+', '-', re.sub('[^-0-9a-z]+', '-', name.lower())).lstrip('-').rstrip('-') #from _make_kubernetes_name + """From _make_kubernetes_name + _sanitize_name cleans and converts the names in the workflow. + """ + return re.sub('-+', '-', re.sub('[^-0-9a-z]+', '-', name.lower())).lstrip('-').rstrip('-') - def _param_full_name(self, param): + def _pipelineparam_full_name(self, param): + """_pipelineparam_full_name + + Args: + param(PipelineParam): pipeline parameter + """ if param.op_name: return param.op_name + '-' + param.name return self._sanitize_name(param.name) @@ -79,12 +84,12 @@ def _op_to_template(self, op): for i, _ in enumerate(processed_args): if op.argument_inputs: for param in op.argument_inputs: - full_name = self._param_full_name(param) + full_name = self._pipelineparam_full_name(param) processed_args[i] = re.sub(str(param), '{{inputs.parameters.%s}}' % full_name, processed_args[i]) input_parameters = [] for param in op.inputs: - one_parameter = {'name': self._param_full_name(param)} + one_parameter = {'name': self._pipelineparam_full_name(param)} if param.value: one_parameter['value'] = str(param.value) input_parameters.append(one_parameter) @@ -94,7 +99,7 @@ def _op_to_template(self, op): output_parameters = [] for param in op.outputs.values(): output_parameters.append({ - 'name': self._param_full_name(param), + 'name': self._pipelineparam_full_name(param), 'valueFrom': {'path': op.file_outputs[param.name]} }) output_parameters.sort(key=lambda x: x['name']) @@ -140,9 +145,9 @@ def _op_to_template(self, op): template['nodeSelector'] = op.node_selector if op.env_variables: - template['container']['env'] = list(map(self._convert_k8s_obj_to_dic, op.env_variables)) + template['container']['env'] = list(map(K8sHelper.convert_k8s_obj_to_json, op.env_variables)) if op.volume_mounts: - template['container']['volumeMounts'] = 
list(map(self._convert_k8s_obj_to_dic, op.volume_mounts)) + template['container']['volumeMounts'] = list(map(K8sHelper.convert_k8s_obj_to_json, op.volume_mounts)) if op.pod_annotations or op.pod_labels: template['metadata'] = {} @@ -222,7 +227,7 @@ def _get_inputs_outputs(self, pipeline, root_group, op_groups): if param.value: continue - full_name = self._param_full_name(param) + full_name = self._pipelineparam_full_name(param) if param.op_name: upstream_op = pipeline.ops[param.op_name] upstream_groups, downstream_groups = self._get_uncommon_ancestors( @@ -297,10 +302,16 @@ def _get_dependencies(self, pipeline, root_group, op_groups): dependencies[downstream_groups[0]].add(upstream_groups[0]) return dependencies - def _resolve_value_or_reference(self, value_or_reference, inputs): + def _resolve_value_or_reference(self, value_or_reference, potential_references): + """_resolve_value_or_reference resolves values and PipelineParams, which could be task parameters or input parameters. + + Args: + value_or_reference: value or reference to be resolved. 
It could be basic python types or PipelineParam + potential_references(dict{str->str}): a dictionary of parameter names to task names + """ if isinstance(value_or_reference, dsl.PipelineParam): - parameter_name = self._param_full_name(value_or_reference) - task_names = [task_name for param_name, task_name in inputs if param_name == parameter_name] + parameter_name = self._pipelineparam_full_name(value_or_reference) + task_names = [task_name for param_name, task_name in potential_references if param_name == parameter_name] if task_names: task_name = task_names[0] return '{{tasks.%s.outputs.parameters.%s}}' % (task_name, parameter_name) @@ -381,7 +392,6 @@ def _group_to_template(self, group, inputs, outputs, dependencies): template['dag'] = {'tasks': tasks} return template - def _create_templates(self, pipeline): """Create all groups and ops templates in the pipeline.""" @@ -411,13 +421,14 @@ def _create_volumes(self, pipeline): #TODO: check for duplicity based on the serialized volumes instead of just name. 
if v.name not in volume_name_set: volume_name_set.add(v.name) - volumes.append(self._convert_k8s_obj_to_dic(v)) + volumes.append(K8sHelper.convert_k8s_obj_to_json(v)) volumes.sort(key=lambda x: x['name']) return volumes def _create_pipeline_workflow(self, args, pipeline): """Create workflow for the pipeline.""" + # Input Parameters input_params = [] for arg in args: param = {'name': arg.name} @@ -425,16 +436,21 @@ def _create_pipeline_workflow(self, args, pipeline): param['value'] = str(arg.value) input_params.append(param) + # Templates templates = self._create_templates(pipeline) templates.sort(key=lambda x: x['name']) + # Exit Handler exit_handler = None if pipeline.groups[0].groups: first_group = pipeline.groups[0].groups[0] if first_group.type == 'exit_handler': exit_handler = first_group.exit_op + # Volumes volumes = self._create_volumes(pipeline) + + # The whole pipeline workflow workflow = { 'apiVersion': 'argoproj.io/v1alpha1', 'kind': 'Workflow', @@ -503,54 +519,6 @@ def _compile(self, pipeline_func): workflow = self._create_pipeline_workflow(args_list_with_defaults, p) return workflow - def _convert_k8s_obj_to_dic(self, obj): - """ - Builds a JSON K8s object. - - If obj is None, return None. - If obj is str, int, long, float, bool, return directly. - If obj is datetime.datetime, datetime.date - convert to string in iso8601 format. - If obj is list, sanitize each element in the list. - If obj is dict, return the dict. - If obj is swagger model, return the properties dict. - - Args: - obj: The data to serialize. - Returns: The serialized form of data. 
- """ - - from six import text_type, integer_types, iteritems - PRIMITIVE_TYPES = (float, bool, bytes, text_type) + integer_types - from datetime import date, datetime - if obj is None: - return None - elif isinstance(obj, PRIMITIVE_TYPES): - return obj - elif isinstance(obj, list): - return [self._convert_k8s_obj_to_dic(sub_obj) - for sub_obj in obj] - elif isinstance(obj, tuple): - return tuple(self._convert_k8s_obj_to_dic(sub_obj) - for sub_obj in obj) - elif isinstance(obj, (datetime, date)): - return obj.isoformat() - - if isinstance(obj, dict): - obj_dict = obj - else: - # Convert model obj to dict except - # attributes `swagger_types`, `attribute_map` - # and attributes which value is not None. - # Convert attribute name to json key in - # model definition for request. - obj_dict = {obj.attribute_map[attr]: getattr(obj, attr) - for attr, _ in iteritems(obj.swagger_types) - if getattr(obj, attr) is not None} - - return {key: self._convert_k8s_obj_to_dic(val) - for key, val in iteritems(obj_dict)} - def compile(self, pipeline_func, package_path): """Compile the given pipeline function into workflow yaml. diff --git a/sdk/python/kfp/components/_component_store.py b/sdk/python/kfp/components/_component_store.py index 0c354f3db66..cff2187287c 100644 --- a/sdk/python/kfp/components/_component_store.py +++ b/sdk/python/kfp/components/_component_store.py @@ -4,7 +4,6 @@ from pathlib import Path import requests -import warnings from . 
import _components as comp class ComponentStore: diff --git a/sdk/python/kfp/components/_components.py b/sdk/python/kfp/components/_components.py index 2aeafd2c27b..ff0d2dbb423 100644 --- a/sdk/python/kfp/components/_components.py +++ b/sdk/python/kfp/components/_components.py @@ -21,7 +21,7 @@ import sys from collections import OrderedDict -from ._yaml_utils import load_yaml, dump_yaml +from ._yaml_utils import load_yaml from ._structures import ComponentSpec diff --git a/sdk/python/kfp/components/_dynamic.py b/sdk/python/kfp/components/_dynamic.py index 6e8275a9e69..f35a693d19d 100644 --- a/sdk/python/kfp/components/_dynamic.py +++ b/sdk/python/kfp/components/_dynamic.py @@ -1,3 +1,17 @@ +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ from typing import Any, Callable, Mapping, Sequence import types from inspect import Parameter, Signature diff --git a/sdk/python/kfp/components/_structures.py b/sdk/python/kfp/components/_structures.py index 16c1853209d..c23f2efdf6b 100644 --- a/sdk/python/kfp/components/_structures.py +++ b/sdk/python/kfp/components/_structures.py @@ -30,7 +30,7 @@ import copy from collections import OrderedDict -from typing import Union, List, Sequence, Mapping, Tuple +from typing import Union, List, Mapping, Tuple class InputOrOutputSpec: diff --git a/sdk/python/kfp/dsl/_ops_group.py b/sdk/python/kfp/dsl/_ops_group.py index 1e277b398dd..321fe4792af 100644 --- a/sdk/python/kfp/dsl/_ops_group.py +++ b/sdk/python/kfp/dsl/_ops_group.py @@ -28,7 +28,7 @@ class OpsGroup(object): def __init__(self, group_type: str, name: str=None): """Create a new instance of OpsGroup. Args: - group_type: usually one of 'exit_handler', 'condition', and 'loop'. + group_type: one of 'pipeline', 'exit_handler', 'condition', and 'loop'. """ self.type = group_type self.ops = list() diff --git a/sdk/python/tests/compiler/compiler_tests.py b/sdk/python/tests/compiler/compiler_tests.py index 6363d6537a9..b0faafd5750 100644 --- a/sdk/python/tests/compiler/compiler_tests.py +++ b/sdk/python/tests/compiler/compiler_tests.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
- import kfp.compiler as compiler import kfp.dsl as dsl import os @@ -23,7 +22,6 @@ import tempfile import unittest import yaml -import datetime class TestCompiler(unittest.TestCase): @@ -142,21 +140,6 @@ def test_basic_workflow(self): shutil.rmtree(tmpdir) # print(tmpdir) - def test_convert_k8s_obj_to_dic_accepts_dict(self): - now = datetime.datetime.now() - converted = compiler.Compiler()._convert_k8s_obj_to_dic({ - "ENV": "test", - "number": 3, - "list": [1,2,3], - "time": now - }) - self.assertEqual(converted, { - "ENV": "test", - "number": 3, - "list": [1,2,3], - "time": now.isoformat() - }) - def test_composing_workflow(self): """Test compiling a simple workflow, and a bigger one composed from the simple one.""" diff --git a/sdk/python/tests/compiler/k8s_helper_tests.py b/sdk/python/tests/compiler/k8s_helper_tests.py new file mode 100644 index 00000000000..5eaebc03a80 --- /dev/null +++ b/sdk/python/tests/compiler/k8s_helper_tests.py @@ -0,0 +1,34 @@ +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class TestCompiler(unittest.TestCase):
  """Unit tests for K8sHelper.convert_k8s_obj_to_json."""

  def test_convert_k8s_obj_to_json_accepts_dict(self):
    """A plain dict passes through with datetime values ISO-8601 encoded.

    Renamed from test_convert_k8s_obj_to_dic_accepts_dict: the method under
    test is now K8sHelper.convert_k8s_obj_to_json, not the removed
    Compiler._convert_k8s_obj_to_dic.
    """
    now = datetime.now()
    converted = K8sHelper.convert_k8s_obj_to_json({
      "ENV": "test",
      "number": 3,
      "list": [1, 2, 3],
      "time": now,
    })
    self.assertEqual(converted, {
      "ENV": "test",
      "number": 3,
      "list": [1, 2, 3],
      "time": now.isoformat(),
    })