diff --git a/.github/workflows/kfp-samples.yml b/.github/workflows/kfp-samples.yml index 98fb96f899a..06cd69db560 100644 --- a/.github/workflows/kfp-samples.yml +++ b/.github/workflows/kfp-samples.yml @@ -36,9 +36,16 @@ jobs: with: k8s_version: ${{ matrix.k8s_version }} + - name: Build and upload the sample Modelcar image to Kind + run: | + docker build -f samples/v2/modelcar_import/Dockerfile -t registry.domain.local/modelcar:test . + kind --name kfp load docker-image registry.domain.local/modelcar:test + - name: Forward API port run: ./.github/resources/scripts/forward-port.sh "kubeflow" "ml-pipeline" 8888 8888 - name: Run Samples Tests + env: + PULL_NUMBER: ${{ github.event.pull_request.number }} run: | ./backend/src/v2/test/sample-test.sh diff --git a/backend/src/v2/component/importer_launcher.go b/backend/src/v2/component/importer_launcher.go index 1207ed89fea..d5649b7959a 100644 --- a/backend/src/v2/component/importer_launcher.go +++ b/backend/src/v2/component/importer_launcher.go @@ -4,6 +4,8 @@ import ( "context" "encoding/json" "fmt" + "strings" + "github.com/kubeflow/pipelines/backend/src/v2/objectstore" pb "github.com/kubeflow/pipelines/third_party/ml-metadata/go/ml_metadata" @@ -227,10 +229,6 @@ func (l *ImportLauncher) ImportSpecToMLMDArtifact(ctx context.Context) (artifact state := pb.Artifact_LIVE - provider, err := objectstore.ParseProviderFromPath(artifactUri) - if err != nil { - return nil, fmt.Errorf("No Provider scheme found in artifact Uri: %s", artifactUri) - } artifact = &pb.Artifact{ TypeId: &artifactTypeId, State: &state, @@ -248,6 +246,24 @@ func (l *ImportLauncher) ImportSpecToMLMDArtifact(ctx context.Context) (artifact } } + if strings.HasPrefix(artifactUri, "oci://") { + artifactType, err := metadata.SchemaToArtifactType(schema) + if err != nil { + return nil, fmt.Errorf("converting schema to artifact type failed: %w", err) + } + + if *artifactType.Name != "system.Model" { + return nil, fmt.Errorf("the %s artifact type does not support OCI registries", *artifactType.Name) + } + + return artifact, nil + } + + provider, err := objectstore.ParseProviderFromPath(artifactUri) + if err != nil { + return nil, fmt.Errorf("no provider scheme found in artifact URI: %s", artifactUri) + } + // Assume all imported artifacts will rely on execution environment for store provider session info storeSessionInfo := objectstore.SessionInfo{ Provider: provider, diff --git a/backend/src/v2/component/launcher_v2.go b/backend/src/v2/component/launcher_v2.go index 025440e1111..017f8339b73 100644 --- a/backend/src/v2/component/launcher_v2.go +++ b/backend/src/v2/component/launcher_v2.go @@ -127,12 +127,52 @@ func NewLauncherV2(ctx context.Context, executionID int64, executorInputJSON, co }, nil } +// stopWaitingArtifacts will create empty files to tell Modelcar sidecar containers to stop. Any errors encountered are +// logged since this is meant as a deferred function at the end of the launcher's execution. +func stopWaitingArtifacts(artifacts map[string]*pipelinespec.ArtifactList) { + for _, artifactList := range artifacts { + if len(artifactList.Artifacts) == 0 { + continue + } + + // Following the convention of downloadArtifacts in the launcher to only look at the first in the list. + inputArtifact := artifactList.Artifacts[0] + + // This should ideally verify that this is also a model input artifact, but this metadata doesn't seem to + // be set on inputArtifact. + if !strings.HasPrefix(inputArtifact.Uri, "oci://") { + continue + } + + localPath, err := LocalPathForURI(inputArtifact.Uri) + if err != nil { + continue + } + + glog.Infof("Stopping artifact %s", inputArtifact.Uri) + + // launcher-complete + launcherCompleteFile := strings.TrimSuffix(localPath, "/models") + "/launcher-complete" + _, err = os.Create(launcherCompleteFile) + if err != nil { + glog.Errorf( + "Failed to stop the artifact %s by creating %s: %v", inputArtifact.Uri, launcherCompleteFile, err, + ) + + continue + } + } +} + func (l *LauncherV2) Execute(ctx context.Context) (err error) { defer func() { if err != nil { err = fmt.Errorf("failed to execute component: %w", err) } }() + + defer stopWaitingArtifacts(l.executorInput.GetInputs().GetArtifacts()) + // publish execution regardless the task succeeds or not var execution *metadata.Execution var executorOutput *pipelinespec.ExecutorOutput @@ -401,6 +441,7 @@ func execute( if err := downloadArtifacts(ctx, executorInput, bucket, bucketConfig, namespace, k8sClient); err != nil { return nil, err } + if err := prepareOutputFolders(executorInput); err != nil { return nil, err } @@ -441,7 +482,7 @@ func uploadOutputArtifacts(ctx context.Context, executorInput *pipelinespec.Exec } // Upload artifacts from local path to remote storages. - localDir, err := localPathForURI(outputArtifact.Uri) + localDir, err := LocalPathForURI(outputArtifact.Uri) if err != nil { glog.Warningf("Output Artifact %q does not have a recognized storage URI %q. Skipping uploading to remote storage.", name, outputArtifact.Uri) } else { @@ -477,6 +518,27 @@ func uploadOutputArtifacts(ctx context.Context, executorInput *pipelinespec.Exec return outputArtifacts, nil } +// waitForModelcar assumes the Modelcar has already been validated by the init container on the launcher +// pod. This waits for the Modelcar as a sidecar container to be ready. +func waitForModelcar(artifactURI string, localPath string) error { + glog.Infof("Waiting for the Modelcar %s to be available", artifactURI) + + for { + _, err := os.Stat(localPath) + if err == nil { + glog.Infof("The Modelcar is now available at %s", localPath) + + return nil + } + + if !os.IsNotExist(err) { + return fmt.Errorf("failed to see if the artifact %s was ready at %s: %v", artifactURI, localPath, err) + } + + time.Sleep(500 * time.Millisecond) + } +} + func downloadArtifacts(ctx context.Context, executorInput *pipelinespec.ExecutorInput, defaultBucket *blob.Bucket, defaultBucketConfig *objectstore.Config, namespace string, k8sClient kubernetes.Interface) error { // Read input artifact metadata. nonDefaultBuckets, err := fetchNonDefaultBuckets(ctx, executorInput.GetInputs().GetArtifacts(), defaultBucketConfig, namespace, k8sClient) @@ -491,17 +553,31 @@ func downloadArtifacts(ctx context.Context, executorInput *pipelinespec.Executor if err != nil { return fmt.Errorf("failed to fetch non default buckets: %w", err) } + for name, artifactList := range executorInput.GetInputs().GetArtifacts() { // TODO(neuromage): Support concat-based placholders for arguments. if len(artifactList.Artifacts) == 0 { continue } inputArtifact := artifactList.Artifacts[0] - localPath, err := localPathForURI(inputArtifact.Uri) + + localPath, err := LocalPathForURI(inputArtifact.Uri) if err != nil { glog.Warningf("Input Artifact %q does not have a recognized storage URI %q. Skipping downloading to local path.", name, inputArtifact.Uri) + continue } + + // OCI artifacts are handled specially + if strings.HasPrefix(inputArtifact.Uri, "oci://") { + err := waitForModelcar(inputArtifact.Uri, localPath) + if err != nil { + return err + } + + continue + } + // Copy artifact to local storage. copyErr := func(err error) error { return fmt.Errorf("failed to download input artifact %q from remote storage URI %q: %w", name, inputArtifact.Uri, err) @@ -548,6 +624,12 @@ func fetchNonDefaultBuckets( } // TODO: Support multiple artifacts someday, probably through the v2 engine. artifact := artifactList.Artifacts[0] + + // OCI artifacts are handled specially + if strings.HasPrefix(artifact.Uri, "oci://") { + continue + } + // The artifact does not belong under the object store path for this run. Cases: // 1. Artifact is cached from a different run, so it may still be in the default bucket, but under a different run id subpath // 2. Artifact is imported from the same bucket, but from a different path (re-use the same session) @@ -598,7 +680,7 @@ func getPlaceholders(executorInput *pipelinespec.ExecutorInput) (placeholders ma key := fmt.Sprintf(`{{$.inputs.artifacts['%s'].uri}}`, name) placeholders[key] = inputArtifact.Uri - localPath, err := localPathForURI(inputArtifact.Uri) + localPath, err := LocalPathForURI(inputArtifact.Uri) if err != nil { // Input Artifact does not have a recognized storage URI continue @@ -617,7 +699,7 @@ func getPlaceholders(executorInput *pipelinespec.ExecutorInput) (placeholders ma outputArtifact := artifactList.Artifacts[0] placeholders[fmt.Sprintf(`{{$.outputs.artifacts['%s'].uri}}`, name)] = outputArtifact.Uri - localPath, err := localPathForURI(outputArtifact.Uri) + localPath, err := LocalPathForURI(outputArtifact.Uri) if err != nil { return nil, fmt.Errorf("resolve output artifact %q's local path: %w", name, err) } @@ -720,7 +802,7 @@ func getExecutorOutputFile(path string) (*pipelinespec.ExecutorOutput, error) { return executorOutput, nil } -func localPathForURI(uri string) (string, error) { +func LocalPathForURI(uri string) (string, error) { if strings.HasPrefix(uri, "gs://") { return "/gcs/" + strings.TrimPrefix(uri, "gs://"), nil } @@ -730,6 +812,9 @@ func localPathForURI(uri string) (string, error) { if strings.HasPrefix(uri, "s3://") { return "/s3/" + strings.TrimPrefix(uri, "s3://"), nil } + if strings.HasPrefix(uri, "oci://") { + return "/oci/" + strings.ReplaceAll(strings.TrimPrefix(uri, "oci://"), "/", "\\/") + "/models", nil + } return "", fmt.Errorf("failed to generate local path for URI %s: unsupported storage scheme", uri) } @@ -747,7 +832,7 @@ func prepareOutputFolders(executorInput *pipelinespec.ExecutorInput) error { } outputArtifact := artifactList.Artifacts[0] - localPath, err := localPathForURI(outputArtifact.Uri) + localPath, err := LocalPathForURI(outputArtifact.Uri) if err != nil { return fmt.Errorf("failed to generate local storage path for output artifact %q: %w", name, err) } diff --git a/backend/src/v2/driver/driver.go b/backend/src/v2/driver/driver.go index 8cd02d46508..46aad7eadaa 100644 --- a/backend/src/v2/driver/driver.go +++ b/backend/src/v2/driver/driver.go @@ -19,6 +19,7 @@ import ( "encoding/json" "fmt" "strconv" + "strings" "time" "github.com/kubeflow/pipelines/backend/src/common/util" @@ -549,9 +550,117 @@ func initPodSpecPatch( Env: userEnvVar, }}, } + + addModelcarsToPodSpec(executorInput.GetInputs().GetArtifacts(), userEnvVar, podSpec) + return podSpec, nil } +// addModelcarsToPodSpec will patch the pod spec if there are any input artifacts in the Modelcar format. +// Much of this logic is based on KServe: +// https://github.com/kserve/kserve/blob/v0.14.1/pkg/webhook/admission/pod/storage_initializer_injector.go#L131 +func addModelcarsToPodSpec( + artifacts map[string]*pipelinespec.ArtifactList, + userEnvVar []k8score.EnvVar, + podSpec *k8score.PodSpec, +) { + for name, artifactList := range artifacts { + if len(artifactList.Artifacts) == 0 { + continue + } + + // Following the convention of downloadArtifacts in the launcher to only look at the first in the list. + inputArtifact := artifactList.Artifacts[0] + + // This should ideally verify that this is also a model input artifact, but this metadata doesn't seem to + // be set on inputArtifact. + if !strings.HasPrefix(inputArtifact.Uri, "oci://") { + continue + } + + localPath, err := component.LocalPathForURI(inputArtifact.Uri) + if err != nil { + continue + } + + // If there is at least one Modelcar image, then shareProcessNamespace must be enabled. + trueVal := true + podSpec.ShareProcessNamespace = &trueVal + + image := strings.TrimPrefix(inputArtifact.Uri, "oci://") + + podSpec.InitContainers = append( + podSpec.InitContainers, + k8score.Container{ + Name: "oci-prepull-" + name, + Image: image, + Command: []string{ + "sh", + "-c", + // Check that the expected models directory exists + // Taken from KServe: + // https://github.com/kserve/kserve/blob/v0.14.1/pkg/webhook/admission/pod/storage_initializer_injector.go#L732 + "echo 'Pre-fetching modelcar " + image + ": ' && [ -d /models ] && " + + "[ \"$$(ls -A /models)\" ] && echo 'OK ... Prefetched and valid (/models exists)' || " + + "(echo 'NOK ... Prefetched but modelcar is invalid (/models does not exist or is empty)' && " + + " exit 1)", + }, + Env: userEnvVar, + TerminationMessagePolicy: k8score.TerminationMessageFallbackToLogsOnError, + }, + ) + + volumeName := "oci-" + name + + podSpec.Volumes = append( + podSpec.Volumes, + k8score.Volume{ + Name: volumeName, + VolumeSource: k8score.VolumeSource{ + EmptyDir: &k8score.EmptyDirVolumeSource{}, + }, + }, + ) + + mountPath := strings.TrimSuffix(localPath, "/models") + + emptyDirVolumeMount := k8score.VolumeMount{ + Name: volumeName, + MountPath: mountPath, + SubPath: strings.TrimPrefix(mountPath, "/oci/"), + } + + podSpec.Containers[0].VolumeMounts = append(podSpec.Containers[0].VolumeMounts, emptyDirVolumeMount) + + podSpec.Containers = append( + podSpec.Containers, + k8score.Container{ + Name: "oci-" + name, + Image: image, + ImagePullPolicy: "IfNotPresent", + Env: userEnvVar, + VolumeMounts: []k8score.VolumeMount{emptyDirVolumeMount}, + Command: []string{ + "sh", + "-c", + // $$$$ gets escaped by YAML to $$, which is the current PID + // Mostly taken from KServe, but sleeps until the existence of a file that gets created by launcher + // on exit. This approach is taken instead of having the main container send a SIGHUP to the + // sleep process to avoid the need for the SYS_PTRACE capability which is not always available + // depending on the security context restrictions. + // https://github.com/kserve/kserve/blob/v0.14.1/pkg/webhook/admission/pod/storage_initializer_injector.go#L732 + fmt.Sprintf( + "ln -s /proc/$$$$/root/models \"%s\" && "+ + "echo \"Running...\" && until [ -f \"%s/launcher-complete\" ]; do sleep 1; done", + localPath, mountPath, + ), + }, + TerminationMessagePolicy: k8score.TerminationMessageFallbackToLogsOnError, + }, + ) + } +} + // Extends the PodSpec to include Kubernetes-specific executor config. func extendPodSpecPatch( podSpec *k8score.PodSpec, diff --git a/backend/src/v2/driver/driver_test.go b/backend/src/v2/driver/driver_test.go index c84896c41de..9696e0ba928 100644 --- a/backend/src/v2/driver/driver_test.go +++ b/backend/src/v2/driver/driver_test.go @@ -391,6 +391,54 @@ func Test_initPodSpecPatch_legacy_resources(t *testing.T) { assert.Equal(t, k8sres.MustParse("1"), res.Limits[k8score.ResourceName("nvidia.com/gpu")]) } +func Test_initPodSpecPatch_modelcar_input_artifact(t *testing.T) { + containerSpec := &pipelinespec.PipelineDeploymentConfig_PipelineContainerSpec{ + Image: "python:3.9", + Args: []string{"--function_to_execute", "add"}, + Command: []string{"sh", "-ec", "python3 -m kfp.components.executor_main"}, + } + componentSpec := &pipelinespec.ComponentSpec{} + executorInput := &pipelinespec.ExecutorInput{ + Inputs: &pipelinespec.ExecutorInput_Inputs{ + Artifacts: map[string]*pipelinespec.ArtifactList{ + "my-model": { + Artifacts: []*pipelinespec.RuntimeArtifact{ + { + Uri: "oci://registry.domain.local/my-model:latest", + }, + }, + }, + }, + }, + } + + podSpec, err := initPodSpecPatch( + containerSpec, componentSpec, executorInput, 27, "test", "0254beba-0be4-4065-8d97-7dc5e3adf300", + ) + assert.Nil(t, err) + + assert.Len(t, podSpec.InitContainers, 1) + assert.Equal(t, podSpec.InitContainers[0].Name, "oci-prepull-my-model") + assert.Equal(t, podSpec.InitContainers[0].Image, "registry.domain.local/my-model:latest") + + assert.Len(t, podSpec.Volumes, 1) + assert.Equal(t, podSpec.Volumes[0].Name, "oci-my-model") + assert.NotNil(t, podSpec.Volumes[0].EmptyDir) + + assert.Len(t, podSpec.Containers, 2) + assert.Len(t, podSpec.Containers[0].VolumeMounts, 1) + assert.Equal(t, podSpec.Containers[0].VolumeMounts[0].Name, "oci-my-model") + assert.Equal(t, podSpec.Containers[0].VolumeMounts[0].MountPath, "/oci/registry.domain.local\\/my-model:latest") + assert.Equal(t, podSpec.Containers[0].VolumeMounts[0].SubPath, "registry.domain.local\\/my-model:latest") + + assert.Equal(t, podSpec.Containers[1].Name, "oci-my-model") + assert.Equal(t, podSpec.Containers[1].Image, "registry.domain.local/my-model:latest") + assert.Len(t, podSpec.Containers[1].VolumeMounts, 1) + assert.Equal(t, podSpec.Containers[1].VolumeMounts[0].Name, "oci-my-model") + assert.Equal(t, podSpec.Containers[1].VolumeMounts[0].MountPath, "/oci/registry.domain.local\\/my-model:latest") + assert.Equal(t, podSpec.Containers[1].VolumeMounts[0].SubPath, "registry.domain.local\\/my-model:latest") +} + func Test_makeVolumeMountPatch(t *testing.T) { type args struct { pvcMount []*kubernetesplatform.PvcMount diff --git a/backend/src/v2/test/sample-test.sh b/backend/src/v2/test/sample-test.sh index b03b5ea9147..9af6e746686 100755 --- a/backend/src/v2/test/sample-test.sh +++ b/backend/src/v2/test/sample-test.sh @@ -23,6 +23,14 @@ python3 -m pip install -r ./requirements-sample-test.txt popd +if [[ -n "${PULL_NUMBER}" ]]; then + export KFP_PACKAGE_PATH="git+https://github.com/kubeflow/pipelines@refs/pull/${PULL_NUMBER}/merge#egg=kfp&subdirectory=sdk/python" +else + export KFP_PACKAGE_PATH='git+https://github.com/kubeflow/pipelines#egg=kfp&subdirectory=sdk/python' +fi + +python3 -m pip install $KFP_PACKAGE_PATH + # The -u flag makes python output unbuffered, so that we can see real time log. # Reference: https://stackoverflow.com/a/107717 python3 -u ./samples/v2/sample_test.py diff --git a/samples/v2/modelcar_import/Dockerfile b/samples/v2/modelcar_import/Dockerfile new file mode 100644 index 00000000000..c2166cd12cb --- /dev/null +++ b/samples/v2/modelcar_import/Dockerfile @@ -0,0 +1,18 @@ +FROM python:3.13-slim as base + +USER 0 + +RUN pip install huggingface-hub + +# Download a small model file from Hugging Face +RUN python -c "from huggingface_hub import snapshot_download; snapshot_download(repo_id='openai/whisper-tiny', local_dir='/models',allow_patterns=['*.safetensors', '*.json', '*.txt'], revision='169d4a4341b33bc18d8881c4b69c2e104e1cc0af')" + +# Final image containing only the essential model files +FROM alpine:3.19 + +RUN mkdir /models + +# Copy the model files from the base container +COPY --from=base /models /models + +USER 1001 diff --git a/samples/v2/modelcar_import/modelcar_import.py b/samples/v2/modelcar_import/modelcar_import.py new file mode 100755 index 00000000000..9fcd4baeace --- /dev/null +++ b/samples/v2/modelcar_import/modelcar_import.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python3 +# Copyright 2025 The Kubeflow Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Pipeline that imports a model in the Modelcar format from an OCI registry.""" + +import os + +from kfp import compiler +from kfp import dsl +from kfp.dsl import component + +# In tests, we install a KFP package from the PR under test. Users should not +# normally need to specify `kfp_package_path` in their component definitions. +_KFP_PACKAGE_PATH = os.getenv("KFP_PACKAGE_PATH") + + +@dsl.component(kfp_package_path=_KFP_PACKAGE_PATH) +def get_model_files_list(model: dsl.Input[dsl.Model]) -> str: + import os + import os.path + + if not os.path.exists(model.path): + raise RuntimeError(f"The model does not exist at: {model.path}") + + expected_files = { + "added_tokens.json", + "config.json", + "generation_config.json", + "merges.txt", + "model.safetensors", + "normalizer.json", + "preprocessor_config.json", + "special_tokens_map.json", + "tokenizer.json", + "tokenizer_config.json", + "vocab.json", + } + + filesInPath = set(os.listdir(model.path)) + + if not filesInPath.issuperset(expected_files): + raise RuntimeError( + "The model does not have expected files: " + + ", ".join(sorted(expected_files.difference(filesInPath))) + ) + + return ", ".join(sorted(filesInPath)) + + +@dsl.pipeline(name="pipeline-with-modelcar-model") +def pipeline_modelcar_import( + model_uri: str = "oci://registry.domain.local/modelcar:test", +): + model_source_oci_task = dsl.importer( + artifact_uri=model_uri, artifact_class=dsl.Model + ) + + get_model_files_list(model=model_source_oci_task.output).set_caching_options(False) + + +if __name__ == "__main__": + compiler.Compiler().compile( + pipeline_func=pipeline_modelcar_import, package_path=__file__ + ".yaml" + ) diff --git a/samples/v2/sample_test.py b/samples/v2/sample_test.py index 9b357e1cf5e..a3466b448a9 100644 --- a/samples/v2/sample_test.py +++ b/samples/v2/sample_test.py @@ -30,6 +30,7 @@ import subdagio import two_step_pipeline_containerized import pipeline_with_placeholders +from modelcar_import import modelcar_import _MINUTE = 60 # seconds _DEFAULT_TIMEOUT = 5 * _MINUTE @@ -76,6 +77,7 @@ def test(self): TestCase( pipeline_func=subdagio.multiple_artifacts_namedtuple.crust), TestCase(pipeline_func=pipeline_with_placeholders.pipeline_with_placeholders), + TestCase(pipeline_func=modelcar_import.pipeline_modelcar_import), ] with ThreadPoolExecutor() as executor: diff --git a/sdk/python/kfp/dsl/types/artifact_types.py b/sdk/python/kfp/dsl/types/artifact_types.py index e91fe8081e1..54f9d17e24c 100644 --- a/sdk/python/kfp/dsl/types/artifact_types.py +++ b/sdk/python/kfp/dsl/types/artifact_types.py @@ -20,10 +20,12 @@ _GCS_LOCAL_MOUNT_PREFIX = '/gcs/' _MINIO_LOCAL_MOUNT_PREFIX = '/minio/' _S3_LOCAL_MOUNT_PREFIX = '/s3/' +_OCI_LOCAL_MOUNT_PREFIX = '/oci/' GCS_REMOTE_PREFIX = 'gs://' MINIO_REMOTE_PREFIX = 'minio://' S3_REMOTE_PREFIX = 's3://' +OCI_REMOTE_PREFIX = 'oci://' class Artifact: @@ -94,6 +96,10 @@ def _get_path(self) -> Optional[str]: ):] elif self.uri.startswith(S3_REMOTE_PREFIX): return _S3_LOCAL_MOUNT_PREFIX + self.uri[len(S3_REMOTE_PREFIX):] + + elif self.uri.startswith(OCI_REMOTE_PREFIX): + escaped_uri = self.uri[len(OCI_REMOTE_PREFIX):].replace('/', '\\/') + return _OCI_LOCAL_MOUNT_PREFIX + escaped_uri # uri == path for local execution return self.uri @@ -108,6 +114,14 @@ def convert_local_path_to_remote_path(path: str) -> str: return MINIO_REMOTE_PREFIX + path[len(_MINIO_LOCAL_MOUNT_PREFIX):] elif path.startswith(_S3_LOCAL_MOUNT_PREFIX): return S3_REMOTE_PREFIX + path[len(_S3_LOCAL_MOUNT_PREFIX):] + elif path.startswith(_OCI_LOCAL_MOUNT_PREFIX): + remotePath = OCI_REMOTE_PREFIX + path[len(_OCI_LOCAL_MOUNT_PREFIX + ):].replace('\\/', '/') + if remotePath.endswith("/models"): + return remotePath[:-len("/models")] + + return remotePath + return path @@ -128,6 +142,13 @@ def framework(self) -> str: def _get_framework(self) -> str: return self.metadata.get('framework', '') + @property + def path(self) -> str: + if self.uri.startswith("oci://"): + return self._get_path() + "/models" + + return self._get_path() + @framework.setter def framework(self, framework: str) -> None: self._set_framework(framework) diff --git a/sdk/python/kfp/dsl/types/artifact_types_test.py b/sdk/python/kfp/dsl/types/artifact_types_test.py index c34f4a6bba0..f21031d3a2a 100644 --- a/sdk/python/kfp/dsl/types/artifact_types_test.py +++ b/sdk/python/kfp/dsl/types/artifact_types_test.py @@ -138,6 +138,9 @@ class TestConvertLocalPathToRemotePath(parameterized.TestCase): ('/gcs/foo/bar', 'gs://foo/bar'), ('/minio/foo/bar', 'minio://foo/bar'), ('/s3/foo/bar', 's3://foo/bar'), + ('/oci/quay.io\\/org\\/repo:latest/models', + 'oci://quay.io/org/repo:latest'), + ('/oci/quay.io\\/org\\/repo:latest', 'oci://quay.io/org/repo:latest'), ('/tmp/kfp_outputs', '/tmp/kfp_outputs'), ('/some/random/path', '/some/random/path'), ]])