Skip to content

Commit

Permalink
v1.14 Backport 3086 - StructuredDatasets inside dc/pydantic (#3088)
Browse files Browse the repository at this point in the history
Signed-off-by: Yee Hing Tong <wild-endeavor@users.noreply.github.com>
  • Loading branch information
wild-endeavor authored Jan 25, 2025
1 parent 8a5edd3 commit e807e60
Show file tree
Hide file tree
Showing 6 changed files with 115 additions and 10 deletions.
2 changes: 1 addition & 1 deletion flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1358,7 +1358,7 @@ def to_literal(
) -> Literal:
"""
The current dance is because we are allowing users to call from an async function, this synchronous
to_literal function, and allowing this to_literal function, to then invoke yet another async functionl,
to_literal function, and allowing this to_literal function, to then invoke yet another async function,
namely an async transformer.
"""
from flytekit.core.promise import Promise
Expand Down
6 changes: 6 additions & 0 deletions flytekit/remote/remote_fs.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,19 @@ def _upload_chunk(self, final=False):
"""Only uploads the file at once from the buffer.
Not suitable for large files as the buffer will blow the memory for very large files.
Suitable for default values or local dataframes being uploaded all at once.
This function is called by fsspec.flush(). This will create a new file upload location.
"""
if final is False:
return False
self.buffer.seek(0)
data = self.buffer.read()

try:
# The inputs here are flipped a bit, it should be the filename is set to the filename and the filename root
# is something deterministic, like a hash. But since this is supposed to mimic open(), we can't hash.
# With the args currently below, the backend will create a random suffix for the filename.
# Since no hash is set on it, we will not be able to write to it again (which is totally fine).
res = self._remote.client.get_upload_signed_url(
self._remote.default_project,
self._remote.default_domain,
Expand Down
12 changes: 9 additions & 3 deletions flytekit/types/structured/structured_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from flytekit import lazy_module
from flytekit.core.constants import MESSAGEPACK
from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.type_engine import AsyncTypeTransformer, TypeEngine, TypeTransformerFailedError
from flytekit.core.type_engine import AsyncTypeTransformer, TypeEngine, TypeTransformerFailedError, modify_literal_uris
from flytekit.deck.renderer import Renderable
from flytekit.extras.pydantic_transformer.decorator import model_serializer, model_validator
from flytekit.loggers import developer_logger, logger
Expand Down Expand Up @@ -63,6 +63,7 @@ class (that is just a model, a Python class representation of the protobuf).
file_format: typing.Optional[str] = field(default=GENERIC_FORMAT, metadata=config(mm_field=fields.String()))

def _serialize(self) -> Dict[str, Optional[str]]:
# dataclass case
lv = StructuredDatasetTransformerEngine().to_literal(
FlyteContextManager.current_context(), self, type(self), None
)
Expand All @@ -85,7 +86,7 @@ def _deserialize(cls, value) -> "StructuredDataset":
FlyteContextManager.current_context(),
Literal(
scalar=Scalar(
structured_dataset=StructuredDataset(
structured_dataset=literals.StructuredDataset(
metadata=StructuredDatasetMetadata(
structured_dataset_type=StructuredDatasetType(format=file_format)
),
Expand All @@ -98,6 +99,7 @@ def _deserialize(cls, value) -> "StructuredDataset":

@model_serializer
def serialize_structured_dataset(self) -> Dict[str, Optional[str]]:
# pydantic case
lv = StructuredDatasetTransformerEngine().to_literal(
FlyteContextManager.current_context(), self, type(self), None
)
Expand All @@ -117,7 +119,7 @@ def deserialize_structured_dataset(self, info) -> StructuredDataset:
FlyteContextManager.current_context(),
Literal(
scalar=Scalar(
structured_dataset=StructuredDataset(
structured_dataset=literals.StructuredDataset(
metadata=StructuredDatasetMetadata(
structured_dataset_type=StructuredDatasetType(format=self.file_format)
),
Expand Down Expand Up @@ -807,6 +809,10 @@ def encode(
# with a format of "" is used.
sd_model.metadata._structured_dataset_type.format = handler.supported_format
lit = Literal(scalar=Scalar(structured_dataset=sd_model))

# Because the handler.encode may have uploaded something, and because the sd may end up living inside a
# dataclass, we need to modify any uploaded flyte:// urls here.
modify_literal_uris(lit)
sd._literal_sd = sd_model
sd._already_uploaded = True
return lit
Expand Down
35 changes: 34 additions & 1 deletion tests/flytekit/unit/core/test_dataclass.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
from enum import Enum
from dataclasses_json import DataClassJsonMixin
import mock
from pathlib import Path
from mashumaro.mixins.json import DataClassJSONMixin
import os
import sys
Expand All @@ -17,6 +18,9 @@
from flytekit.types.file import FlyteFile
from flytekit.types.structured import StructuredDataset

pd = pytest.importorskip("pandas")


@pytest.fixture
def local_dummy_txt_file():
fd, path = tempfile.mkstemp(suffix=".txt")
Expand Down Expand Up @@ -1132,3 +1136,32 @@ class B():
res = DataclassTransformer()._make_dataclass_serializable(b, Union[None, A, B])

assert res.x.path == "s3://my-bucket/my-file"


@mock.patch("flytekit.remote.remote_fs.FlytePathResolver")
def test_modify_literal_uris_call(mock_resolver):
ctx = FlyteContextManager.current_context()

sd = StructuredDataset(dataframe=pd.DataFrame(
{"a": [1, 2], "b": [3, 4]}))

@dataclass
class DC1:
s: StructuredDataset

bm = DC1(s=sd)

def mock_resolve_remote_path(flyte_uri: str):
p = Path(flyte_uri)
if p.exists():
return "/my/replaced/val"
return ""

mock_resolver.resolve_remote_path.side_effect = mock_resolve_remote_path
mock_resolver.protocol = "/"

lt = TypeEngine.to_literal_type(DC1)
lit = TypeEngine.to_literal(ctx, bm, DC1, lt)

bm_revived = TypeEngine.to_python_value(ctx, lit, DC1)
assert bm_revived.s.literal.uri == "/my/replaced/val"
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import os
import tempfile
from enum import Enum
from pathlib import Path
from typing import Dict, List, Optional, Union
from unittest.mock import patch

import mock
import pytest
from google.protobuf import json_format as _json_format
from google.protobuf import struct_pb2 as _struct
Expand All @@ -15,12 +16,13 @@
from flytekit.core.type_engine import TypeEngine
from flytekit.models.annotation import TypeAnnotation
from flytekit.models.literals import Literal, Scalar
from flytekit.models.types import LiteralType, SimpleType
from flytekit.types.directory import FlyteDirectory
from flytekit.types.file import FlyteFile
from flytekit.types.schema import FlyteSchema
from flytekit.types.structured import StructuredDataset

pd = pytest.importorskip("pandas")


class Status(Enum):
PENDING = "pending"
Expand Down Expand Up @@ -992,3 +994,31 @@ class BM(BaseModel):
c: str = "Hello, Flyte"

assert TypeEngine.to_literal_type(BM).annotation == TypeAnnotation({CACHE_KEY_METADATA: {SERIALIZATION_FORMAT: MESSAGEPACK}})


@mock.patch("flytekit.remote.remote_fs.FlytePathResolver")
def test_modify_literal_uris_call(mock_resolver):
ctx = FlyteContextManager.current_context()

sd = StructuredDataset(dataframe=pd.DataFrame(
{"a": [1, 2], "b": [3, 4]}))

class BM(BaseModel):
s: StructuredDataset

bm = BM(s=sd)

def mock_resolve_remote_path(flyte_uri: str):
p = Path(flyte_uri)
if p.exists():
return "/my/replaced/val"
return ""

mock_resolver.resolve_remote_path.side_effect = mock_resolve_remote_path
mock_resolver.protocol = "/"

lt = TypeEngine.to_literal_type(BM)
lit = TypeEngine.to_literal(ctx, bm, BM, lt)

bm_revived = TypeEngine.to_python_value(ctx, lit, BM)
assert bm_revived.s.literal.uri == "/my/replaced/val"
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import typing
from collections import OrderedDict
from pathlib import Path

import mock
import google.cloud.bigquery
import pytest
from fsspec.utils import get_protocol
Expand Down Expand Up @@ -661,7 +661,6 @@ def wf_with_input() -> pd.DataFrame:
pd.testing.assert_frame_equal(wf_with_input(), input_val)



def test_read_sd_from_local_uri(local_tmp_pqt_file):

@task
Expand All @@ -677,9 +676,40 @@ def read_sd_from_local_uri(uri: str) -> pd.DataFrame:

return df


df = generate_pandas()

# Read sd from local uri
df_local = read_sd_from_local_uri(uri=local_tmp_pqt_file)
pd.testing.assert_frame_equal(df, df_local)


@mock.patch("flytekit.remote.remote_fs.FlytePathResolver")
@mock.patch("flytekit.types.structured.structured_dataset.StructuredDatasetTransformerEngine.get_encoder")
def test_modify_literal_uris_call(mock_get_encoder, mock_resolver):

ctx = FlyteContextManager.current_context()

sd = StructuredDataset(dataframe=pd.DataFrame(
{"a": [1, 2], "b": [3, 4]}), uri="bq://blah", file_format="bq")

def mock_resolve_remote_path(flyte_uri: str) -> typing.Optional[str]:
if flyte_uri == "bq://blah":
return "bq://blah/blah/blah"
return ""

mock_resolver.resolve_remote_path.side_effect = mock_resolve_remote_path
mock_resolver.protocol = "bq"

dummy_encoder = mock.MagicMock()
sd_model = literals.StructuredDataset(uri="bq://blah", metadata=StructuredDatasetMetadata(StructuredDatasetType(format="parquet")))
dummy_encoder.encode.return_value = sd_model

mock_get_encoder.return_value = dummy_encoder

sdte = StructuredDatasetTransformerEngine()
lt = LiteralType(
structured_dataset_type=StructuredDatasetType()
)

lit = sdte.encode(ctx, sd, df_type=pd.DataFrame, protocol="bq", format="parquet", structured_literal_type=lt)
assert lit.scalar.structured_dataset.uri == "bq://blah/blah/blah"

0 comments on commit e807e60

Please sign in to comment.