diff --git a/.github/workflows/kfp-samples.yml b/.github/workflows/kfp-samples.yml index 98fb96f899a..06cd69db560 100644 --- a/.github/workflows/kfp-samples.yml +++ b/.github/workflows/kfp-samples.yml @@ -36,9 +36,16 @@ jobs: with: k8s_version: ${{ matrix.k8s_version }} + - name: Build and upload the sample Modelcar image to Kind + run: | + docker build -f samples/v2/modelcar_import/Dockerfile -t registry.domain.local/modelcar:test . + kind --name kfp load docker-image registry.domain.local/modelcar:test + - name: Forward API port run: ./.github/resources/scripts/forward-port.sh "kubeflow" "ml-pipeline" 8888 8888 - name: Run Samples Tests + env: + PULL_NUMBER: ${{ github.event.pull_request.number }} run: | ./backend/src/v2/test/sample-test.sh diff --git a/backend/src/v2/component/importer_launcher.go b/backend/src/v2/component/importer_launcher.go index 1207ed89fea..d5649b7959a 100644 --- a/backend/src/v2/component/importer_launcher.go +++ b/backend/src/v2/component/importer_launcher.go @@ -4,6 +4,8 @@ import ( "context" "encoding/json" "fmt" + "strings" + "github.com/kubeflow/pipelines/backend/src/v2/objectstore" pb "github.com/kubeflow/pipelines/third_party/ml-metadata/go/ml_metadata" @@ -227,10 +229,6 @@ func (l *ImportLauncher) ImportSpecToMLMDArtifact(ctx context.Context) (artifact state := pb.Artifact_LIVE - provider, err := objectstore.ParseProviderFromPath(artifactUri) - if err != nil { - return nil, fmt.Errorf("No Provider scheme found in artifact Uri: %s", artifactUri) - } artifact = &pb.Artifact{ TypeId: &artifactTypeId, State: &state, @@ -248,6 +246,24 @@ func (l *ImportLauncher) ImportSpecToMLMDArtifact(ctx context.Context) (artifact } } + if strings.HasPrefix(artifactUri, "oci://") { + artifactType, err := metadata.SchemaToArtifactType(schema) + if err != nil { + return nil, fmt.Errorf("converting schema to artifact type failed: %w", err) + } + + if *artifactType.Name != "system.Model" { + return nil, fmt.Errorf("the %s artifact type does not support OCI registries", *artifactType.Name) + } + + return artifact, nil + } + + provider, err := objectstore.ParseProviderFromPath(artifactUri) + if err != nil { + return nil, fmt.Errorf("no provider scheme found in artifact URI: %s", artifactUri) + } + // Assume all imported artifacts will rely on execution environment for store provider session info storeSessionInfo := objectstore.SessionInfo{ Provider: provider, diff --git a/backend/src/v2/component/launcher_v2.go b/backend/src/v2/component/launcher_v2.go index 025440e1111..fcc186f001c 100644 --- a/backend/src/v2/component/launcher_v2.go +++ b/backend/src/v2/component/launcher_v2.go @@ -22,8 +22,10 @@ import ( "os" "os/exec" "path/filepath" + "regexp" "strconv" "strings" + "syscall" "time" "github.com/golang/glog" @@ -43,6 +45,12 @@ import ( "k8s.io/client-go/rest" ) +var findPIDRegex *regexp.Regexp + +func init() { + findPIDRegex = regexp.MustCompile(`\/proc\/(\d+)\/.+`) +} + type LauncherV2Options struct { Namespace, PodName, @@ -398,9 +406,15 @@ func execute( namespace string, k8sClient kubernetes.Interface, ) (*pipelinespec.ExecutorOutput, error) { - if err := downloadArtifacts(ctx, executorInput, bucket, bucketConfig, namespace, k8sClient); err != nil { + cleanUpFuncs, err := downloadArtifacts(ctx, executorInput, bucket, bucketConfig, namespace, k8sClient) + for _, cleanUpFunc := range cleanUpFuncs { + defer cleanUpFunc() + } + + if err != nil { return nil, err } + if err := prepareOutputFolders(executorInput); err != nil { return nil, err } @@ -441,7 +455,7 @@ func uploadOutputArtifacts(ctx context.Context, executorInput *pipelinespec.Exec } // Upload artifacts from local path to remote storages. - localDir, err := localPathForURI(outputArtifact.Uri) + localDir, err := LocalPathForURI(outputArtifact.Uri) if err != nil { glog.Warningf("Output Artifact %q does not have a recognized storage URI %q. Skipping uploading to remote storage.", name, outputArtifact.Uri) } else { @@ -477,7 +491,72 @@ func uploadOutputArtifacts(ctx context.Context, executorInput *pipelinespec.Exec return outputArtifacts, nil } -func downloadArtifacts(ctx context.Context, executorInput *pipelinespec.ExecutorInput, defaultBucket *blob.Bucket, defaultBucketConfig *objectstore.Config, namespace string, k8sClient kubernetes.Interface) error { +// waitForModelcar assumes the Modelcar has already been validated by the init container on the launcher +// pod. This waits for the Modelcar as a sidecar container to be ready. It returns a function that will stop +// the Modelcar container that is to be run after the user code has completed. +func waitForModelcar(artifactURI string, localPath string) func() { + glog.Infof("Waiting for the Modelcar %s to be available", artifactURI) + + for { + _, err := os.Stat(localPath) + if err != nil { + time.Sleep(500 * time.Millisecond) + + continue + } + + targetPath, err := os.Readlink(localPath) + if err != nil { + glog.Infof( + "Expected the Modelcar local path to be a symlink, will not stop Modelcar %s", artifactURI, + ) + + return nil + } + + matches := findPIDRegex.FindStringSubmatch(targetPath) + if len(matches) != 2 { + glog.Infof( + "Expected the Modelcar symlink (%s) to start with /proc/$pid, will not stop Modelcar %s", + targetPath, artifactURI, + ) + + return nil + } + + pidStr := matches[1] + + pid, err := strconv.Atoi(pidStr) + if err != nil { + glog.Infof( + "Expected the Modelcar symlink (%s) target to start with /proc/$pid, will not stop Modelcar %s", + targetPath, artifactURI, + ) + + return nil + } + + return func() { + glog.Infof("Stopping the Modelcar: %s", artifactURI) + + process, err := os.FindProcess(pid) + if err != nil { + // If the process stopped already, nothing to do + return + } + + err = process.Signal(syscall.SIGHUP) + if err != nil { + glog.Error("Error stopping the Modelcar %s due to: %v", artifactURI, err) + + return + } + } + } +} + +// downloadArtifacts returns a slice of functions to call for artifact clean up after user code has completed. +func downloadArtifacts(ctx context.Context, executorInput *pipelinespec.ExecutorInput, defaultBucket *blob.Bucket, defaultBucketConfig *objectstore.Config, namespace string, k8sClient kubernetes.Interface) ([]func(), error) { // Read input artifact metadata. nonDefaultBuckets, err := fetchNonDefaultBuckets(ctx, executorInput.GetInputs().GetArtifacts(), defaultBucketConfig, namespace, k8sClient) closeNonDefaultBuckets := func(buckets map[string]*blob.Bucket) { @@ -489,19 +568,35 @@ func downloadArtifacts(ctx context.Context, executorInput *pipelinespec.Executor } defer closeNonDefaultBuckets(nonDefaultBuckets) if err != nil { - return fmt.Errorf("failed to fetch non default buckets: %w", err) + return nil, fmt.Errorf("failed to fetch non default buckets: %w", err) } + + cleanUpFuncs := []func(){} + for name, artifactList := range executorInput.GetInputs().GetArtifacts() { // TODO(neuromage): Support concat-based placholders for arguments. if len(artifactList.Artifacts) == 0 { continue } inputArtifact := artifactList.Artifacts[0] - localPath, err := localPathForURI(inputArtifact.Uri) + + localPath, err := LocalPathForURI(inputArtifact.Uri) if err != nil { glog.Warningf("Input Artifact %q does not have a recognized storage URI %q. Skipping downloading to local path.", name, inputArtifact.Uri) + continue } + + // OCI artifacts are handled specially + if strings.HasPrefix(inputArtifact.Uri, "oci://") { + cleanUpFunc := waitForModelcar(inputArtifact.Uri, localPath) + if cleanUpFunc != nil { + cleanUpFuncs = append(cleanUpFuncs, cleanUpFunc) + } + + continue + } + // Copy artifact to local storage. copyErr := func(err error) error { return fmt.Errorf("failed to download input artifact %q from remote storage URI %q: %w", name, inputArtifact.Uri, err) @@ -513,25 +608,25 @@ func downloadArtifacts(ctx context.Context, executorInput *pipelinespec.Executor if !strings.HasPrefix(inputArtifact.Uri, defaultBucketConfig.PrefixedBucket()) { nonDefaultBucketConfig, err := objectstore.ParseBucketConfigForArtifactURI(inputArtifact.Uri) if err != nil { - return fmt.Errorf("failed to parse bucketConfig for output artifact %q with uri %q: %w", name, inputArtifact.GetUri(), err) + return cleanUpFuncs, fmt.Errorf("failed to parse bucketConfig for output artifact %q with uri %q: %w", name, inputArtifact.GetUri(), err) } nonDefaultBucket, ok := nonDefaultBuckets[nonDefaultBucketConfig.PrefixedBucket()] if !ok { - return fmt.Errorf("failed to get bucket when downloading input artifact %s with bucket key %s: %w", name, nonDefaultBucketConfig.PrefixedBucket(), err) + return cleanUpFuncs, fmt.Errorf("failed to get bucket when downloading input artifact %s with bucket key %s: %w", name, nonDefaultBucketConfig.PrefixedBucket(), err) } bucket = nonDefaultBucket bucketConfig = nonDefaultBucketConfig } blobKey, err := bucketConfig.KeyFromURI(inputArtifact.Uri) if err != nil { - return copyErr(err) + return cleanUpFuncs, copyErr(err) } if err := objectstore.DownloadBlob(ctx, bucket, localPath, blobKey); err != nil { - return copyErr(err) + return cleanUpFuncs, copyErr(err) } } - return nil + return cleanUpFuncs, nil } func fetchNonDefaultBuckets( @@ -548,6 +643,12 @@ func fetchNonDefaultBuckets( } // TODO: Support multiple artifacts someday, probably through the v2 engine. artifact := artifactList.Artifacts[0] + + // OCI artifacts are handled specially + if strings.HasPrefix(artifact.Uri, "oci://") { + continue + } + // The artifact does not belong under the object store path for this run. Cases: // 1. Artifact is cached from a different run, so it may still be in the default bucket, but under a different run id subpath // 2. Artifact is imported from the same bucket, but from a different path (re-use the same session) @@ -598,7 +699,7 @@ func getPlaceholders(executorInput *pipelinespec.ExecutorInput) (placeholders ma key := fmt.Sprintf(`{{$.inputs.artifacts['%s'].uri}}`, name) placeholders[key] = inputArtifact.Uri - localPath, err := localPathForURI(inputArtifact.Uri) + localPath, err := LocalPathForURI(inputArtifact.Uri) if err != nil { // Input Artifact does not have a recognized storage URI continue @@ -617,7 +718,7 @@ func getPlaceholders(executorInput *pipelinespec.ExecutorInput) (placeholders ma outputArtifact := artifactList.Artifacts[0] placeholders[fmt.Sprintf(`{{$.outputs.artifacts['%s'].uri}}`, name)] = outputArtifact.Uri - localPath, err := localPathForURI(outputArtifact.Uri) + localPath, err := LocalPathForURI(outputArtifact.Uri) if err != nil { return nil, fmt.Errorf("resolve output artifact %q's local path: %w", name, err) } @@ -720,7 +821,7 @@ func getExecutorOutputFile(path string) (*pipelinespec.ExecutorOutput, error) { return executorOutput, nil } -func localPathForURI(uri string) (string, error) { +func LocalPathForURI(uri string) (string, error) { if strings.HasPrefix(uri, "gs://") { return "/gcs/" + strings.TrimPrefix(uri, "gs://"), nil } @@ -730,6 +831,9 @@ func localPathForURI(uri string) (string, error) { if strings.HasPrefix(uri, "s3://") { return "/s3/" + strings.TrimPrefix(uri, "s3://"), nil } + if strings.HasPrefix(uri, "oci://") { + return "/oci/" + strings.ReplaceAll(strings.TrimPrefix(uri, "oci://"), "/", "\\/") + "/models", nil + } return "", fmt.Errorf("failed to generate local path for URI %s: unsupported storage scheme", uri) } @@ -747,7 +851,7 @@ func prepareOutputFolders(executorInput *pipelinespec.ExecutorInput) error { } outputArtifact := artifactList.Artifacts[0] - localPath, err := localPathForURI(outputArtifact.Uri) + localPath, err := LocalPathForURI(outputArtifact.Uri) if err != nil { return fmt.Errorf("failed to generate local storage path for output artifact %q: %w", name, err) } diff --git a/backend/src/v2/driver/driver.go b/backend/src/v2/driver/driver.go index e104fc1a9c9..ea37fcd502d 100644 --- a/backend/src/v2/driver/driver.go +++ b/backend/src/v2/driver/driver.go @@ -19,6 +19,7 @@ import ( "encoding/json" "fmt" "strconv" + "strings" "time" "github.com/kubeflow/pipelines/backend/src/common/util" @@ -544,9 +545,132 @@ func initPodSpecPatch( Env: userEnvVar, }}, } + + addModelcarsToPodSpec(executorInput.GetInputs().GetArtifacts(), userEnvVar, podSpec) + return podSpec, nil } +// addModelcarsToPodSpec will patch the pod spec if there are any input artifacts in the Modelcar format. +// Much of this logic is based on KServe: +// https://github.com/kserve/kserve/blob/v0.14.1/pkg/webhook/admission/pod/storage_initializer_injector.go#L131 +func addModelcarsToPodSpec( + artifacts map[string]*pipelinespec.ArtifactList, + userEnvVar []k8score.EnvVar, + podSpec *k8score.PodSpec, +) { + mainContainer := &podSpec.Containers[0] + modelcarAdded := false + + for name, artifactList := range artifacts { + if len(artifactList.Artifacts) == 0 { + continue + } + + // Following the convention of downloadArtifacts in the launcher to only look at the first in the list. + inputArtifact := artifactList.Artifacts[0] + + // This should ideally verify that this is also a model input artifact, but this metadata doesn't seem to + // be set on inputArtifact. + if !strings.HasPrefix(inputArtifact.Uri, "oci://") { + continue + } + + localPath, err := component.LocalPathForURI(inputArtifact.Uri) + if err != nil { + continue + } + + if !modelcarAdded { + modelcarAdded = true + + // If there is at least one Modelcar image, then shareProcessNamespace must be enabled. + trueVal := true + podSpec.ShareProcessNamespace = &trueVal + + if mainContainer.SecurityContext == nil { + mainContainer.SecurityContext = &k8score.SecurityContext{} + } + + if mainContainer.SecurityContext.Capabilities == nil { + mainContainer.SecurityContext.Capabilities = &k8score.Capabilities{} + } + + // The SYS_PTRACE capability is required for the main container to access the Modelcar container's + // filesystem. The SIGHUP capability is required for the main container to stop the Modelcar container + // when launcher is exiting. + mainContainer.SecurityContext.Capabilities.Add = append( + mainContainer.SecurityContext.Capabilities.Add, "SYS_PTRACE", "SIGHUP", + ) + } + + image := strings.TrimPrefix(inputArtifact.Uri, "oci://") + + podSpec.InitContainers = append( + podSpec.InitContainers, + k8score.Container{ + Name: "oci-prepull-" + name, + Image: image, + Command: []string{ + "sh", + "-c", + // Check that the expected models directory exists + // Taken from KServe: + // https://github.com/kserve/kserve/blob/v0.14.1/pkg/webhook/admission/pod/storage_initializer_injector.go#L732 + "echo 'Pre-fetching modelcar " + image + ": ' && [ -d /models ] && " + + "[ \"$$(ls -A /models)\" ] && echo 'OK ... Prefetched and valid (/models exists)' || " + + "(echo 'NOK ... Prefetched but modelcar is invalid (/models does not exist or is empty)' && " + + " exit 1)", + }, + Env: userEnvVar, + TerminationMessagePolicy: k8score.TerminationMessageFallbackToLogsOnError, + }, + ) + + volumeName := "oci-" + name + + podSpec.Volumes = append( + podSpec.Volumes, + k8score.Volume{ + Name: volumeName, + VolumeSource: k8score.VolumeSource{ + EmptyDir: &k8score.EmptyDirVolumeSource{}, + }, + }, + ) + + mountPath := strings.TrimSuffix(localPath, "/models") + + emptyDirVolumeMount := k8score.VolumeMount{ + Name: volumeName, + MountPath: mountPath, + SubPath: strings.TrimPrefix(mountPath, "/oci/"), + } + + mainContainer.VolumeMounts = append(mainContainer.VolumeMounts, emptyDirVolumeMount) + + podSpec.Containers = append( + podSpec.Containers, + k8score.Container{ + Name: "oci-" + name, + Image: image, + ImagePullPolicy: "IfNotPresent", + Env: userEnvVar, + VolumeMounts: []k8score.VolumeMount{emptyDirVolumeMount}, + Command: []string{ + "sh", + "-c", + // $$$$ gets escaped by YAML to $$, which is the current PID + // Taken from KServe: + // https://github.com/kserve/kserve/blob/v0.14.1/pkg/webhook/admission/pod/storage_initializer_injector.go#L732 + fmt.Sprintf("ln -s /proc/$$$$/root/models \"%s\" && sleep infinity", localPath), + }, + TerminationMessagePolicy: k8score.TerminationMessageFallbackToLogsOnError, + }, + ) + } +} + // Extends the PodSpec to include Kubernetes-specific executor config. func extendPodSpecPatch( podSpec *k8score.PodSpec, diff --git a/backend/src/v2/driver/driver_test.go b/backend/src/v2/driver/driver_test.go index c84896c41de..4fb05089ac5 100644 --- a/backend/src/v2/driver/driver_test.go +++ b/backend/src/v2/driver/driver_test.go @@ -391,6 +391,47 @@ func Test_initPodSpecPatch_legacy_resources(t *testing.T) { assert.Equal(t, k8sres.MustParse("1"), res.Limits[k8score.ResourceName("nvidia.com/gpu")]) } +func Test_initPodSpecPatch_modelcar_input_artifact(t *testing.T) { + containerSpec := &pipelinespec.PipelineDeploymentConfig_PipelineContainerSpec{ + Image: "python:3.9", + Args: []string{"--function_to_execute", "add"}, + Command: []string{"sh", "-ec", "python3 -m kfp.components.executor_main"}, + } + componentSpec := &pipelinespec.ComponentSpec{} + executorInput := &pipelinespec.ExecutorInput{ + Inputs: &pipelinespec.ExecutorInput_Inputs{ + Artifacts: map[string]*pipelinespec.ArtifactList{ + "my-model": { + Artifacts: []*pipelinespec.RuntimeArtifact{ + { + Uri: "oci://registry.domain.local/my-model:latest", + }, + }, + }, + }, + }, + } + + podSpec, err := initPodSpecPatch( + containerSpec, componentSpec, executorInput, 27, "test", "0254beba-0be4-4065-8d97-7dc5e3adf300", + ) + assert.Nil(t, err) + assert.Len(t, podSpec.Containers, 1) + assert.Len(t, podSpec.InitContainers, 1) + assert.Equal(t, podSpec.InitContainers[0].Name, "oci-my-model") + assert.Len(t, podSpec.InitContainers[0].Command, 3) + expectedCopyCmd := "mkdir -p '/oci/registry.domain.local\\/my-model:latest' && " + + "cp -R /models/* '/oci/registry.domain.local\\/my-model:latest'" + assert.Equal(t, podSpec.InitContainers[0].Command[2], expectedCopyCmd) + assert.Len(t, podSpec.Volumes, 1) + assert.Equal(t, podSpec.Volumes[0].Name, "oci-my-model") + assert.NotNil(t, podSpec.Volumes[0].EmptyDir) + assert.Len(t, podSpec.Containers[0].VolumeMounts, 1) + assert.Equal(t, podSpec.Containers[0].VolumeMounts[0].Name, "oci-my-model") + assert.Equal(t, podSpec.Containers[0].VolumeMounts[0].MountPath, "/oci/registry.domain.local\\/my-model:latest") + assert.Equal(t, podSpec.Containers[0].VolumeMounts[0].SubPath, "registry.domain.local\\/my-model:latest") +} + func Test_makeVolumeMountPatch(t *testing.T) { type args struct { pvcMount []*kubernetesplatform.PvcMount diff --git a/backend/src/v2/test/sample-test.sh b/backend/src/v2/test/sample-test.sh index b03b5ea9147..9af6e746686 100755 --- a/backend/src/v2/test/sample-test.sh +++ b/backend/src/v2/test/sample-test.sh @@ -23,6 +23,14 @@ python3 -m pip install -r ./requirements-sample-test.txt popd +if [[ -n "${PULL_NUMBER}" ]]; then + export KFP_PACKAGE_PATH="git+https://github.com/kubeflow/pipelines@refs/pull/${PULL_NUMBER}/merge#egg=kfp&subdirectory=sdk/python" +else + export KFP_PACKAGE_PATH='git+https://github.com/kubeflow/pipelines#egg=kfp&subdirectory=sdk/python' +fi + +python3 -m pip install $KFP_PACKAGE_PATH + # The -u flag makes python output unbuffered, so that we can see real time log. # Reference: https://stackoverflow.com/a/107717 python3 -u ./samples/v2/sample_test.py diff --git a/samples/v2/modelcar_import/Dockerfile b/samples/v2/modelcar_import/Dockerfile new file mode 100644 index 00000000000..c2166cd12cb --- /dev/null +++ b/samples/v2/modelcar_import/Dockerfile @@ -0,0 +1,18 @@ +FROM python:3.13-slim as base + +USER 0 + +RUN pip install huggingface-hub + +# Download a small model file from Hugging Face +RUN python -c "from huggingface_hub import snapshot_download; snapshot_download(repo_id='openai/whisper-tiny', local_dir='/models',allow_patterns=['*.safetensors', '*.json', '*.txt'], revision='169d4a4341b33bc18d8881c4b69c2e104e1cc0af')" + +# Final image containing only the essential model files +FROM alpine:3.19 + +RUN mkdir /models + +# Copy the model files from the base container +COPY --from=base /models /models + +USER 1001 diff --git a/samples/v2/modelcar_import/modelcar_import.py b/samples/v2/modelcar_import/modelcar_import.py new file mode 100755 index 00000000000..63a1bfdef20 --- /dev/null +++ b/samples/v2/modelcar_import/modelcar_import.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python3 +# Copyright 2025 The Kubeflow Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Pipeline that imports a model in the Modelcar format from an OCI registry.""" + +import os + +from kfp import compiler +from kfp import dsl +from kfp.dsl import component + +# In tests, we install a KFP package from the PR under test. Users should not +# normally need to specify `kfp_package_path` in their component definitions. +_KFP_PACKAGE_PATH = os.getenv("KFP_PACKAGE_PATH") + + +@dsl.component(kfp_package_path=_KFP_PACKAGE_PATH) +def get_model_files_list(model: dsl.Input[dsl.Model]) -> str: + import os + import os.path + + if not os.path.exists(model.path): + raise RuntimeError(f"The model does not exist at: {model.path}") + + expected_files = { + "added_tokens.json", + "config.json", + "generation_config.json", + "merges.txt", + "model.safetensors", + "normalizer.json", + "preprocessor_config.json", + "special_tokens_map.json", + "tokenizer.json", + "tokenizer_config.json", + "vocab.json", + } + + filesInPath = set(os.listdir(model.path)) + + if not filesInPath.issuperset(expected_files): + raise RuntimeError( + "The model does not have expected files: " + + ", ".join(sorted(expected_files.difference(filesInPath))) + ) + + return ", ".join(sorted(filesInPath)) + + +@dsl.pipeline(name="pipeline-with-modelcar-model") +def pipeline_modelcar_import( + model_uri: str = "oci://registry.domain.local/modelcar:test", +): + model_source_oci_task = dsl.importer( + artifact_uri=model_uri, artifact_class=dsl.Model + ).set_caching_options(False) + + get_model_files_list(model=model_source_oci_task.output).set_caching_options(False) + + +if __name__ == "__main__": + compiler.Compiler().compile( + pipeline_func=pipeline_modelcar_import, package_path=__file__ + ".yaml" + ) diff --git a/samples/v2/sample_test.py b/samples/v2/sample_test.py index 2af7c4fba7d..38920f33870 100644 --- a/samples/v2/sample_test.py +++ b/samples/v2/sample_test.py @@ -29,6 +29,7 @@ import producer_consumer_param import subdagio import two_step_pipeline_containerized +from modelcar_import import modelcar_import _MINUTE = 60 # seconds _DEFAULT_TIMEOUT = 5 * _MINUTE @@ -74,6 +75,7 @@ def test(self): TestCase(pipeline_func=subdagio.artifact.crust), TestCase( pipeline_func=subdagio.multiple_artifacts_namedtuple.crust), + TestCase(pipeline_func=modelcar_import.pipeline_modelcar_import), ] with ThreadPoolExecutor() as executor: diff --git a/sdk/python/kfp/dsl/types/artifact_types.py b/sdk/python/kfp/dsl/types/artifact_types.py index e91fe8081e1..1100db84025 100644 --- a/sdk/python/kfp/dsl/types/artifact_types.py +++ b/sdk/python/kfp/dsl/types/artifact_types.py @@ -20,10 +20,12 @@ _GCS_LOCAL_MOUNT_PREFIX = '/gcs/' _MINIO_LOCAL_MOUNT_PREFIX = '/minio/' _S3_LOCAL_MOUNT_PREFIX = '/s3/' +_OCI_LOCAL_MOUNT_PREFIX = '/oci/' GCS_REMOTE_PREFIX = 'gs://' MINIO_REMOTE_PREFIX = 'minio://' S3_REMOTE_PREFIX = 's3://' +OCI_REMOTE_PREFIX = 'oci://' class Artifact: @@ -94,6 +96,10 @@ def _get_path(self) -> Optional[str]: ):] elif self.uri.startswith(S3_REMOTE_PREFIX): return _S3_LOCAL_MOUNT_PREFIX + self.uri[len(S3_REMOTE_PREFIX):] + + elif self.uri.startswith(OCI_REMOTE_PREFIX): + escaped_uri = self.uri[len(OCI_REMOTE_PREFIX):].replace('/', '\\/') + return _OCI_LOCAL_MOUNT_PREFIX + escaped_uri # uri == path for local execution return self.uri @@ -108,6 +114,9 @@ def convert_local_path_to_remote_path(path: str) -> str: return MINIO_REMOTE_PREFIX + path[len(_MINIO_LOCAL_MOUNT_PREFIX):] elif path.startswith(_S3_LOCAL_MOUNT_PREFIX): return S3_REMOTE_PREFIX + path[len(_S3_LOCAL_MOUNT_PREFIX):] + elif path.startswith(_OCI_LOCAL_MOUNT_PREFIX): + return OCI_REMOTE_PREFIX + path[len(_OCI_LOCAL_MOUNT_PREFIX):].replace( + '\\/', '/') return path diff --git a/sdk/python/kfp/dsl/types/artifact_types_test.py b/sdk/python/kfp/dsl/types/artifact_types_test.py index c34f4a6bba0..cb6bc1e58c6 100644 --- a/sdk/python/kfp/dsl/types/artifact_types_test.py +++ b/sdk/python/kfp/dsl/types/artifact_types_test.py @@ -138,6 +138,8 @@ class TestConvertLocalPathToRemotePath(parameterized.TestCase): ('/gcs/foo/bar', 'gs://foo/bar'), ('/minio/foo/bar', 'minio://foo/bar'), ('/s3/foo/bar', 's3://foo/bar'), + ('/oci/quay.io\\/org\\/repo:latest/models', + 'oci://quay.io/org/repo:latest'), ('/tmp/kfp_outputs', '/tmp/kfp_outputs'), ('/some/random/path', '/some/random/path'), ]])