Skip to content

Commit

Permalink
Fix pydantic basemodel default input (#3013) (#3084)
Browse files Browse the repository at this point in the history
* Fix pydantic default input



* add pydantic integration test



* Use duck typing by Thomas's advice




* lint



---------

Signed-off-by: Future-Outlier <eric901201@gmail.com>
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
  • Loading branch information
Future-Outlier and thomasjpfan authored Jan 24, 2025
1 parent 29d024c commit 8a5edd3
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 2 deletions.
11 changes: 9 additions & 2 deletions flytekit/clis/sdk_in_container/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,8 +475,15 @@ def to_click_option(
If no custom logic exists, fall back to json.dumps.
"""
with FlyteContextManager.with_context(flyte_ctx.new_builder()):
encoder = JSONEncoder(python_type)
default_val = encoder.encode(default_val)
if hasattr(default_val, "model_dump_json"):
# pydantic v2
default_val = default_val.model_dump_json()
elif hasattr(default_val, "json"):
# pydantic v1
default_val = default_val.json()
else:
encoder = JSONEncoder(python_type)
default_val = encoder.encode(default_val)
if literal_var.type.metadata:
description_extra = f": {json.dumps(literal_var.type.metadata)}"

Expand Down
8 changes: 8 additions & 0 deletions tests/flytekit/integration/remote/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,14 @@ def test_remote_run():
# run twice to make sure it will register a new version of the workflow.
run("default_lp.py", "my_wf")

def test_pydantic_default_input_with_map_task():
execution_id = run("pydantic_wf.py", "wf")
remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN)
execution = remote.fetch_execution(name=execution_id)
execution = remote.wait(execution=execution, timeout=datetime.timedelta(minutes=5))
print("Execution Error:", execution.error)
assert execution.closure.phase == WorkflowExecutionPhase.SUCCEEDED, f"Execution failed with phase: {execution.closure.phase}"


def test_generic_idl_flytetypes():
os.environ["FLYTE_USE_OLD_DC_FORMAT"] = "true"
Expand Down
20 changes: 20 additions & 0 deletions tests/flytekit/integration/remote/workflows/basic/pydantic_wf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from pydantic import BaseModel

from flytekit import map_task
from typing import List
from flytekit import task, workflow


class MyBaseModel(BaseModel):
my_floats: List[float] = [1.0, 2.0, 5.0, 10.0]

@task
def print_float(my_float: float):
print(f"my_float: {my_float}")

@workflow
def wf(bm: MyBaseModel = MyBaseModel()):
map_task(print_float)(my_float=bm.my_floats)

if __name__ == "__main__":
wf()

0 comments on commit 8a5edd3

Please sign in to comment.