Skip to content

Commit

Permalink
Refactor Python SDK (#568)
Browse files Browse the repository at this point in the history
* add some comments

* remove unused import; add license to dsl_bridge

* move convert_k8s_obj_to_dic from compiler to k8s_helper

* move unit test
  • Loading branch information
gaoning777 authored and k8s-ci-robot committed Dec 20, 2018
1 parent 549a366 commit 85c6413
Show file tree
Hide file tree
Showing 11 changed files with 135 additions and 87 deletions.
1 change: 0 additions & 1 deletion sdk/python/kfp/compiler/_component_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import logging
from collections import OrderedDict
from pathlib import PurePath, Path
from .. import dsl
from ..components._components import _create_task_factory_from_component_spec

class GCSHelper(object):
Expand Down
49 changes: 49 additions & 0 deletions sdk/python/kfp/compiler/_k8s_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,52 @@ def run_job(self, yaml_spec, timeout=600):
# print(self._read_pod_log(pod_name, yaml_spec))
self._delete_k8s_job(pod_name, yaml_spec)
return succ

@staticmethod
def convert_k8s_obj_to_json(k8s_obj):
"""
Builds a JSON K8s object.
If obj is None, return None.
If obj is str, int, long, float, bool, return directly.
If obj is datetime.datetime, datetime.date
convert to string in iso8601 format.
If obj is list, sanitize each element in the list.
If obj is dict, return the dict.
If obj is swagger model, return the properties dict.
Args:
obj: The data to serialize.
Returns: The serialized form of data.
"""

from six import text_type, integer_types, iteritems
PRIMITIVE_TYPES = (float, bool, bytes, text_type) + integer_types
from datetime import date, datetime
if k8s_obj is None:
return None
elif isinstance(k8s_obj, PRIMITIVE_TYPES):
return k8s_obj
elif isinstance(k8s_obj, list):
return [K8sHelper.convert_k8s_obj_to_json(sub_obj)
for sub_obj in k8s_obj]
elif isinstance(k8s_obj, tuple):
return tuple(K8sHelper.convert_k8s_obj_to_json(sub_obj)
for sub_obj in obj)
elif isinstance(k8s_obj, (datetime, date)):
return k8s_obj.isoformat()

if isinstance(k8s_obj, dict):
obj_dict = k8s_obj
else:
# Convert model obj to dict except
# attributes `swagger_types`, `attribute_map`
# and attributes which value is not None.
# Convert attribute name to json key in
# model definition for request.
obj_dict = {k8s_obj.attribute_map[attr]: getattr(k8s_obj, attr)
for attr, _ in iteritems(k8s_obj.swagger_types)
if getattr(k8s_obj, attr) is not None}

return {key: K8sHelper.convert_k8s_obj_to_json(val)
for key, val in iteritems(obj_dict)}
98 changes: 33 additions & 65 deletions sdk/python/kfp/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,13 @@


from collections import defaultdict
import copy
import inspect
import re
import string
import tarfile
import tempfile
import yaml

from .. import dsl

from ._k8s_helper import K8sHelper

class Compiler(object):
"""DSL Compiler.
Expand All @@ -42,9 +39,17 @@ def my_pipeline(a: dsl.PipelineParam, b: dsl.PipelineParam):
"""

def _sanitize_name(self, name):
return re.sub('-+', '-', re.sub('[^-0-9a-z]+', '-', name.lower())).lstrip('-').rstrip('-') #from _make_kubernetes_name
"""From _make_kubernetes_name
_sanitize_name cleans and converts the names in the workflow.
"""
return re.sub('-+', '-', re.sub('[^-0-9a-z]+', '-', name.lower())).lstrip('-').rstrip('-')

def _param_full_name(self, param):
def _pipelineparam_full_name(self, param):
"""_pipelineparam_full_name
Args:
param(PipelineParam): pipeline parameter
"""
if param.op_name:
return param.op_name + '-' + param.name
return self._sanitize_name(param.name)
Expand Down Expand Up @@ -79,12 +84,12 @@ def _op_to_template(self, op):
for i, _ in enumerate(processed_args):
if op.argument_inputs:
for param in op.argument_inputs:
full_name = self._param_full_name(param)
full_name = self._pipelineparam_full_name(param)
processed_args[i] = re.sub(str(param), '{{inputs.parameters.%s}}' % full_name,
processed_args[i])
input_parameters = []
for param in op.inputs:
one_parameter = {'name': self._param_full_name(param)}
one_parameter = {'name': self._pipelineparam_full_name(param)}
if param.value:
one_parameter['value'] = str(param.value)
input_parameters.append(one_parameter)
Expand All @@ -94,7 +99,7 @@ def _op_to_template(self, op):
output_parameters = []
for param in op.outputs.values():
output_parameters.append({
'name': self._param_full_name(param),
'name': self._pipelineparam_full_name(param),
'valueFrom': {'path': op.file_outputs[param.name]}
})
output_parameters.sort(key=lambda x: x['name'])
Expand Down Expand Up @@ -140,9 +145,9 @@ def _op_to_template(self, op):
template['nodeSelector'] = op.node_selector

if op.env_variables:
template['container']['env'] = list(map(self._convert_k8s_obj_to_dic, op.env_variables))
template['container']['env'] = list(map(K8sHelper.convert_k8s_obj_to_json, op.env_variables))
if op.volume_mounts:
template['container']['volumeMounts'] = list(map(self._convert_k8s_obj_to_dic, op.volume_mounts))
template['container']['volumeMounts'] = list(map(K8sHelper.convert_k8s_obj_to_json, op.volume_mounts))

if op.pod_annotations or op.pod_labels:
template['metadata'] = {}
Expand Down Expand Up @@ -222,7 +227,7 @@ def _get_inputs_outputs(self, pipeline, root_group, op_groups):
if param.value:
continue

full_name = self._param_full_name(param)
full_name = self._pipelineparam_full_name(param)
if param.op_name:
upstream_op = pipeline.ops[param.op_name]
upstream_groups, downstream_groups = self._get_uncommon_ancestors(
Expand Down Expand Up @@ -297,10 +302,16 @@ def _get_dependencies(self, pipeline, root_group, op_groups):
dependencies[downstream_groups[0]].add(upstream_groups[0])
return dependencies

def _resolve_value_or_reference(self, value_or_reference, inputs):
def _resolve_value_or_reference(self, value_or_reference, potential_references):
"""_resolve_value_or_reference resolves values and PipelineParams, which could be task parameters or input parameters.
Args:
value_or_reference: value or reference to be resolved. It could be basic python types or PipelineParam
potential_references(dict{str->str}): a dictionary of parameter names to task names
"""
if isinstance(value_or_reference, dsl.PipelineParam):
parameter_name = self._param_full_name(value_or_reference)
task_names = [task_name for param_name, task_name in inputs if param_name == parameter_name]
parameter_name = self._pipelineparam_full_name(value_or_reference)
task_names = [task_name for param_name, task_name in potential_references if param_name == parameter_name]
if task_names:
task_name = task_names[0]
return '{{tasks.%s.outputs.parameters.%s}}' % (task_name, parameter_name)
Expand Down Expand Up @@ -381,7 +392,6 @@ def _group_to_template(self, group, inputs, outputs, dependencies):
template['dag'] = {'tasks': tasks}
return template


def _create_templates(self, pipeline):
"""Create all groups and ops templates in the pipeline."""

Expand Down Expand Up @@ -411,30 +421,36 @@ def _create_volumes(self, pipeline):
#TODO: check for duplicity based on the serialized volumes instead of just name.
if v.name not in volume_name_set:
volume_name_set.add(v.name)
volumes.append(self._convert_k8s_obj_to_dic(v))
volumes.append(K8sHelper.convert_k8s_obj_to_json(v))
volumes.sort(key=lambda x: x['name'])
return volumes

def _create_pipeline_workflow(self, args, pipeline):
"""Create workflow for the pipeline."""

# Input Parameters
input_params = []
for arg in args:
param = {'name': arg.name}
if arg.value is not None:
param['value'] = str(arg.value)
input_params.append(param)

# Templates
templates = self._create_templates(pipeline)
templates.sort(key=lambda x: x['name'])

# Exit Handler
exit_handler = None
if pipeline.groups[0].groups:
first_group = pipeline.groups[0].groups[0]
if first_group.type == 'exit_handler':
exit_handler = first_group.exit_op

# Volumes
volumes = self._create_volumes(pipeline)

# The whole pipeline workflow
workflow = {
'apiVersion': 'argoproj.io/v1alpha1',
'kind': 'Workflow',
Expand Down Expand Up @@ -503,54 +519,6 @@ def _compile(self, pipeline_func):
workflow = self._create_pipeline_workflow(args_list_with_defaults, p)
return workflow

def _convert_k8s_obj_to_dic(self, obj):
"""
Builds a JSON K8s object.
If obj is None, return None.
If obj is str, int, long, float, bool, return directly.
If obj is datetime.datetime, datetime.date
convert to string in iso8601 format.
If obj is list, sanitize each element in the list.
If obj is dict, return the dict.
If obj is swagger model, return the properties dict.
Args:
obj: The data to serialize.
Returns: The serialized form of data.
"""

from six import text_type, integer_types, iteritems
PRIMITIVE_TYPES = (float, bool, bytes, text_type) + integer_types
from datetime import date, datetime
if obj is None:
return None
elif isinstance(obj, PRIMITIVE_TYPES):
return obj
elif isinstance(obj, list):
return [self._convert_k8s_obj_to_dic(sub_obj)
for sub_obj in obj]
elif isinstance(obj, tuple):
return tuple(self._convert_k8s_obj_to_dic(sub_obj)
for sub_obj in obj)
elif isinstance(obj, (datetime, date)):
return obj.isoformat()

if isinstance(obj, dict):
obj_dict = obj
else:
# Convert model obj to dict except
# attributes `swagger_types`, `attribute_map`
# and attributes which value is not None.
# Convert attribute name to json key in
# model definition for request.
obj_dict = {obj.attribute_map[attr]: getattr(obj, attr)
for attr, _ in iteritems(obj.swagger_types)
if getattr(obj, attr) is not None}

return {key: self._convert_k8s_obj_to_dic(val)
for key, val in iteritems(obj_dict)}

def compile(self, pipeline_func, package_path):
"""Compile the given pipeline function into workflow yaml.
Expand Down
1 change: 0 additions & 1 deletion sdk/python/kfp/components/_component_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from pathlib import Path
import requests
import warnings
from . import _components as comp

class ComponentStore:
Expand Down
2 changes: 1 addition & 1 deletion sdk/python/kfp/components/_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

import sys
from collections import OrderedDict
from ._yaml_utils import load_yaml, dump_yaml
from ._yaml_utils import load_yaml
from ._structures import ComponentSpec


Expand Down
14 changes: 14 additions & 0 deletions sdk/python/kfp/components/_dynamic.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable, Mapping, Sequence
import types
from inspect import Parameter, Signature
Expand Down
2 changes: 1 addition & 1 deletion sdk/python/kfp/components/_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

import copy
from collections import OrderedDict
from typing import Union, List, Sequence, Mapping, Tuple
from typing import Union, List, Mapping, Tuple


class InputOrOutputSpec:
Expand Down
2 changes: 1 addition & 1 deletion sdk/python/kfp/dsl/_ops_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class OpsGroup(object):
def __init__(self, group_type: str, name: str=None):
"""Create a new instance of OpsGroup.
Args:
group_type: usually one of 'exit_handler', 'condition', and 'loop'.
group_type: one of 'pipeline', 'exit_handler', 'condition', and 'loop'.
"""
self.type = group_type
self.ops = list()
Expand Down
17 changes: 0 additions & 17 deletions sdk/python/tests/compiler/compiler_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import kfp.compiler as compiler
import kfp.dsl as dsl
import os
Expand All @@ -23,7 +22,6 @@
import tempfile
import unittest
import yaml
import datetime

class TestCompiler(unittest.TestCase):

Expand Down Expand Up @@ -142,21 +140,6 @@ def test_basic_workflow(self):
shutil.rmtree(tmpdir)
# print(tmpdir)

def test_convert_k8s_obj_to_dic_accepts_dict(self):
now = datetime.datetime.now()
converted = compiler.Compiler()._convert_k8s_obj_to_dic({
"ENV": "test",
"number": 3,
"list": [1,2,3],
"time": now
})
self.assertEqual(converted, {
"ENV": "test",
"number": 3,
"list": [1,2,3],
"time": now.isoformat()
})

def test_composing_workflow(self):
"""Test compiling a simple workflow, and a bigger one composed from the simple one."""

Expand Down
34 changes: 34 additions & 0 deletions sdk/python/tests/compiler/k8s_helper_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from kfp.compiler._k8s_helper import K8sHelper
from datetime import datetime
import unittest


class TestCompiler(unittest.TestCase):
  # NOTE(review): this class exercises K8sHelper, not the compiler; the
  # class name is kept for continuity, but TestK8sHelper would be clearer.

  def test_convert_k8s_obj_to_json_accepts_dict(self):
    """convert_k8s_obj_to_json passes a dict through unchanged, except
    that datetime values are serialized to ISO-8601 strings.

    Renamed from test_convert_k8s_obj_to_dic_accepts_dict for consistency
    with the function under test, convert_k8s_obj_to_json.
    """
    now = datetime.now()
    converted = K8sHelper.convert_k8s_obj_to_json({
        "ENV": "test",
        "number": 3,
        "list": [1, 2, 3],
        "time": now
    })
    self.assertEqual(converted, {
        "ENV": "test",
        "number": 3,
        "list": [1, 2, 3],
        "time": now.isoformat()
    })
Loading

0 comments on commit 85c6413

Please sign in to comment.