From 8a5edd39288bc8d431e2f334897e01a641877c6a Mon Sep 17 00:00:00 2001 From: "Han-Ru Chen (Future-Outlier)" Date: Fri, 24 Jan 2025 13:39:06 -0800 Subject: [PATCH] Fix pydantic basemodel default input (#3013) (#3084) * Fix pydantic default input * add pydantic integration test * Use duck typing by Thomas's advice * lint --------- Signed-off-by: Future-Outlier Co-authored-by: Thomas J. Fan --- flytekit/clis/sdk_in_container/run.py | 11 ++++++++-- .../integration/remote/test_remote.py | 8 ++++++++ .../remote/workflows/basic/pydantic_wf.py | 20 +++++++++++++++++++ 3 files changed, 37 insertions(+), 2 deletions(-) create mode 100644 tests/flytekit/integration/remote/workflows/basic/pydantic_wf.py diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index 5c768792c1..81f05c9139 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -475,8 +475,15 @@ def to_click_option( If no custom logic exists, fall back to json.dumps. """ with FlyteContextManager.with_context(flyte_ctx.new_builder()): - encoder = JSONEncoder(python_type) - default_val = encoder.encode(default_val) + if hasattr(default_val, "model_dump_json"): + # pydantic v2 + default_val = default_val.model_dump_json() + elif hasattr(default_val, "json"): + # pydantic v1 + default_val = default_val.json() + else: + encoder = JSONEncoder(python_type) + default_val = encoder.encode(default_val) if literal_var.type.metadata: description_extra = f": {json.dumps(literal_var.type.metadata)}" diff --git a/tests/flytekit/integration/remote/test_remote.py b/tests/flytekit/integration/remote/test_remote.py index 94946237ef..1cc373f6ff 100644 --- a/tests/flytekit/integration/remote/test_remote.py +++ b/tests/flytekit/integration/remote/test_remote.py @@ -109,6 +109,14 @@ def test_remote_run(): # run twice to make sure it will register a new version of the workflow. run("default_lp.py", "my_wf") +def test_pydantic_default_input_with_map_task(): + execution_id = run("pydantic_wf.py", "wf") + remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN) + execution = remote.fetch_execution(name=execution_id) + execution = remote.wait(execution=execution, timeout=datetime.timedelta(minutes=5)) + print("Execution Error:", execution.error) + assert execution.closure.phase == WorkflowExecutionPhase.SUCCEEDED, f"Execution failed with phase: {execution.closure.phase}" + def test_generic_idl_flytetypes(): os.environ["FLYTE_USE_OLD_DC_FORMAT"] = "true" diff --git a/tests/flytekit/integration/remote/workflows/basic/pydantic_wf.py b/tests/flytekit/integration/remote/workflows/basic/pydantic_wf.py new file mode 100644 index 0000000000..d5e9c32170 --- /dev/null +++ b/tests/flytekit/integration/remote/workflows/basic/pydantic_wf.py @@ -0,0 +1,20 @@ +from pydantic import BaseModel + +from flytekit import map_task +from typing import List +from flytekit import task, workflow + + +class MyBaseModel(BaseModel): + my_floats: List[float] = [1.0, 2.0, 5.0, 10.0] + +@task +def print_float(my_float: float): + print(f"my_float: {my_float}") + +@workflow +def wf(bm: MyBaseModel = MyBaseModel()): + map_task(print_float)(my_float=bm.my_floats) + +if __name__ == "__main__": + wf()