Skip to content

Commit

Permalink
Fix: workflow().with_resources(...) properly copies default source …
Browse files Browse the repository at this point in the history
…and dependency imports (#393)

# The problem

`with_resources` function omitted copying of the imports causing
problems when this parameter was used

# This PR's solution

properly copy all parameters

# Checklist

_Check that this PR satisfies the following items:_

- [x] Tests have been added for new features/changed behavior (if no new
features have been added, check the box).
- [x] The [changelog file](CHANGELOG.md) has been updated with a
user-readable description of the changes (if the change isn't visible to
the user in any way, check the box).
- [x] The PR's title is prefixed with
`<feat/fix/chore/imp[rovement]/int[ernal]/docs>[!]:`
- [x] The PR is linked to a JIRA ticket (if there's no suitable ticket,
check the box).
  • Loading branch information
SebastianMorawiec authored Apr 23, 2024
1 parent bb12529 commit e851657
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

* `sdk.workflow(fn, resources=...)` will no longer show type errors from linters.
* CLI log dumping now correctly saves stdout and stderr logs
* `workflow().with_resources(...)` properly copies default source and dependency imports

💅 *Improvements*

Expand Down
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@ dependencies = [
# Capture stdout/stderr
"wurlitzer>=3.0",
# For dremio client
"pyarrow>=10.0",
# pyarrow 16.0 crashed ray workers for unknown reason. Crashes were not
# reproducable on mac - so carefull with taking that restriction away.
"pyarrow>=10.0,<16.0",
"pandas>=1.4",
]

Expand Down
2 changes: 2 additions & 0 deletions src/orquestra/sdk/_client/_base/_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,8 @@ def with_resources(
data_aggregation=self._data_aggregation,
workflow_args=self._workflow_args,
workflow_kwargs=self._workflow_kwargs,
default_source_import=self.default_source_import,
default_dependency_imports=self.default_dependency_imports,
)


Expand Down
17 changes: 13 additions & 4 deletions tests/sdk/test_consistent_return_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,10 @@ def test_consistent_returns_for_single_value(
m = re.match(
r"Workflow Submitted! Run ID: (?P<run_id>.*)", run_ray.stdout.decode()
)
assert m is not None

assert (
m is not None
), f"STDOUT: {run_ray.stdout.decode()},\n\nSTDERR: {run_ray.stderr.decode()}"
run_id_ray = m.group("run_id").strip()
assert "Workflow Submitted!" in run_ce.stdout.decode()

Expand Down Expand Up @@ -520,7 +523,9 @@ def test_consistent_returns_for_multiple_values(
m = re.match(
r"Workflow Submitted! Run ID: (?P<run_id>.*)", run_ray.stdout.decode()
)
assert m is not None
assert (
m is not None
), f"STDOUT: {run_ray.stdout.decode()},\n\nSTDERR: {run_ray.stderr.decode()}"
run_id_ray = m.group("run_id").strip()
assert "Workflow Submitted!" in run_ce.stdout.decode()

Expand Down Expand Up @@ -591,7 +596,9 @@ def test_consistent_downloads_for_single_value(
m = re.match(
r"Workflow Submitted! Run ID: (?P<run_id>.*)", run_ray.stdout.decode()
)
assert m is not None
assert (
m is not None
), f"STDOUT: {run_ray.stdout.decode()},\n\nSTDERR: {run_ray.stderr.decode()}"
run_id_ray = m.group("run_id").strip()
assert mock_ce_run_single in run_ce.stdout.decode()

Expand Down Expand Up @@ -669,7 +676,9 @@ def test_consistent_downloads_for_multiple_values(
m = re.match(
r"Workflow Submitted! Run ID: (?P<run_id>.*)", run_ray.stdout.decode()
)
assert m is not None
assert (
m is not None
), f"STDOUT: {run_ray.stdout.decode()},\n\nSTDERR: {run_ray.stderr.decode()}"
run_id_ray = m.group("run_id").strip()
assert mock_ce_run_multiple in run_ce.stdout.decode()

Expand Down
21 changes: 21 additions & 0 deletions tests/sdk/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,3 +399,24 @@ def my_workflow():
pass

assert my_workflow._default_dependency_imports == expected_imports


def test_with_resources_copies_imports():
@sdk.workflow(
default_dependency_imports=[sdk.PythonImports("abc")],
default_source_import=sdk.GitImport(repo_url="abc", git_ref="xyz"),
)
def my_workflow():
pass

initial_workflow = my_workflow()
modified_workflow = initial_workflow.with_resources(cpu="xyz")

assert (
initial_workflow.default_dependency_imports
== modified_workflow.default_dependency_imports
)
assert (
initial_workflow.default_source_import
== modified_workflow.default_source_import
)

0 comments on commit e851657

Please sign in to comment.