Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Fix ShellOperation running on Windows (#70)
Browse files Browse the repository at this point in the history
* Flush after write

* Fix issue with Windows commands (#71)

* Adds windows test

* Updates Windows tests to mark runs as failed with failed tests

* Add write trigger command

* Static analysis

* Fix asserts

* Try again

* Add changelog and strip newline

* Do not strip new line

* Fix tests?

* Try again

* Attempt

* Again

* Confident?

* Ok I think so

* I think this should do it...

---------

Co-authored-by: Alexander Streed <desertaxle@users.noreply.github.com>
  • Loading branch information
ahuang11 and desertaxle authored Feb 17, 2023
1 parent 761cf1f commit cac3c3a
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 51 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/windows_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,4 @@ jobs:
env:
PREFECT_ORION_DATABASE_CONNECTION_URL: "sqlite+aiosqlite:///./orion-tests.db"
run: |
coverage run --branch -m pytest tests -vv
coverage report
pytest tests -vv
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ Released on February 17, 2023.
### Changed
- Change the behavior of the `ShellOperation` `stream_output` parameter. Setting it to `False` will now only turn off the logging and not send `stdout` and `stderr` to `DEVNULL`. The previous behavior can be achieved by manually setting `stdout`/`stderr` to `DEVNULL` through the `open_kwargs` arguments. - [#67](https://github.com/PrefectHQ/prefect-shell/issues/67)

### Fixed
- Using `ShellOperation` on Windows - [#70](https://github.com/PrefectHQ/prefect-shell/issues/70)

## 0.1.4

Released on February 2nd, 2023.
Expand Down
69 changes: 41 additions & 28 deletions prefect_shell/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import subprocess
import sys
import tempfile
from contextlib import AsyncExitStack
from typing import Any, Dict, List, Optional, Union
from contextlib import AsyncExitStack, contextmanager
from typing import Any, Dict, Generator, List, Optional, Union

import anyio
from anyio.abc import Process
Expand Down Expand Up @@ -264,39 +264,52 @@ class ShellOperation(JobBlock):
default_factory=AsyncExitStack,
)

def _compile_kwargs(self, **open_kwargs: Dict[str, Any]) -> Dict[str, Any]:
@contextmanager
def _prep_trigger_command(self) -> Generator[str, None, None]:
"""
Helper method to compile the kwargs for `open_process` so it's not repeated
across the run and trigger methods.
Write the commands to a temporary file, handling all the details of
creating the file and cleaning it up afterwards. Then, return the command
to run the temporary file.
"""
extension = self.extension or (".ps1" if sys.platform == "win32" else ".sh")
temp_file = self._exit_stack.enter_context(
tempfile.NamedTemporaryFile(
try:
extension = self.extension or (".ps1" if sys.platform == "win32" else ".sh")
temp_file = tempfile.NamedTemporaryFile(
prefix="prefect-",
suffix=extension,
delete=False,
)
)

joined_commands = os.linesep.join(self.commands)
self.logger.debug(
f"Writing the following commands to "
f"{temp_file.name!r}:{os.linesep}{joined_commands}"
)
temp_file.write(joined_commands.encode())
temp_file.flush()

if self.shell is None and sys.platform == "win32" or extension == ".ps1":
shell = "powershell"
elif self.shell is None:
shell = "bash"
else:
shell = self.shell.lower()

if shell == "powershell":
# if powershell, set exit code to that of command
temp_file.write("\r\nExit $LastExitCode".encode())
joined_commands = os.linesep.join(self.commands)
self.logger.debug(
f"Writing the following commands to "
f"{temp_file.name!r}:{os.linesep}{joined_commands}"
)
temp_file.write(joined_commands.encode())

if self.shell is None and sys.platform == "win32" or extension == ".ps1":
shell = "powershell"
elif self.shell is None:
shell = "bash"
else:
shell = self.shell.lower()

if shell == "powershell":
# if powershell, set exit code to that of command
temp_file.write("\r\nExit $LastExitCode".encode())
temp_file.close()

trigger_command = [shell, temp_file.name]
yield trigger_command
finally:
if os.path.exists(temp_file.name):
os.remove(temp_file.name)

trigger_command = [shell, temp_file.name]
def _compile_kwargs(self, **open_kwargs: Dict[str, Any]) -> Dict[str, Any]:
"""
Helper method to compile the kwargs for `open_process` so it's not repeated
across the run and trigger methods.
"""
trigger_command = self._exit_stack.enter_context(self._prep_trigger_command())
input_env = os.environ.copy()
input_env.update(self.env)
input_open_kwargs = dict(
Expand Down
48 changes: 27 additions & 21 deletions tests/test_commands_windows.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,36 +19,36 @@ def test_shell_run_command_error_windows(prefect_task_runs_caplog):
def test_flow():
return shell_run_command(command="throw", return_all=True, shell="powershell")

with pytest.raises(RuntimeError, match="Exception:"):
with pytest.raises(RuntimeError, match="Exception"):
test_flow()

assert len(prefect_task_runs_caplog.records) == 7


def test_shell_run_command_windows(prefect_task_runs_caplog):
prefect_task_runs_caplog.set_level(logging.INFO)
echo_msg = "_THIS_ IS WORKING!!!!"
echo_msg = "WORKING"

@flow
def test_flow():
msg = shell_run_command(
command=f"echo {echo_msg}", return_all=True, shell="powershell"
)
return " ".join(word.replace("\r", "") for word in msg)
return msg

print(prefect_task_runs_caplog.text)

assert test_flow() == echo_msg
assert " ".join(test_flow()) == echo_msg
for record in prefect_task_runs_caplog.records:
if echo_msg in record.msg:
if "WORKING" in record.msg:
break # it's in the records
else:
raise AssertionError


def test_shell_run_command_stream_level_windows(prefect_task_runs_caplog):
prefect_task_runs_caplog.set_level(logging.WARNING)
echo_msg = "_THIS_ IS WORKING!!!!"
echo_msg = "WORKING"

@flow
def test_flow():
Expand All @@ -58,13 +58,13 @@ def test_flow():
return_all=True,
shell="powershell",
)
return " ".join(word.replace("\r", "") for word in msg)
return msg

print(prefect_task_runs_caplog.text)

assert test_flow() == echo_msg
assert " ".join(test_flow()) == echo_msg
for record in prefect_task_runs_caplog.records:
if echo_msg in record.msg:
if "WORKING" in record.msg:
break # it's in the records
else:
raise AssertionError
Expand All @@ -77,21 +77,23 @@ def test_flow():
command="Get-Location",
helper_command="cd $env:USERPROFILE",
shell="powershell",
return_all=True,
)

assert test_flow() == os.path.expandvars("$USERPROFILE")
assert os.path.expandvars("$USERPROFILE") in test_flow()


def test_shell_run_command_cwd():
@flow
def test_flow():
return shell_run_command(
command="Get-Location",
command="echo 'work!'; Get-Location",
shell="powershell",
cwd=Path.home(),
return_all=True,
)

assert test_flow() == os.fspath(Path.home())
assert os.fspath(Path.home()) in test_flow()


def test_shell_run_command_return_all():
Expand Down Expand Up @@ -137,8 +139,8 @@ def test_flow():
)

result = test_flow()
assert result[0] == os.environ["USERPROFILE"]
assert result[1] == "test value"
assert os.environ["USERPROFILE"] in " ".join(result)
assert "test value" in result


def test_shell_run_command_ensure_suffix_ps1():
Expand Down Expand Up @@ -214,6 +216,10 @@ async def execute(self, op, method):
await proc.wait_for_completion()
return await proc.fetch_result()

def test_echo(self):
op = ShellOperation(commands=["echo Hello"])
assert op.run() == ["Hello"]

@pytest.mark.parametrize("method", ["run", "trigger"])
async def test_error(self, method):
op = ShellOperation(commands=["throw"])
Expand All @@ -222,12 +228,12 @@ async def test_error(self, method):

@pytest.mark.parametrize("method", ["run", "trigger"])
async def test_output(self, prefect_task_runs_caplog, method):
op = ShellOperation(commands=["echo 'testing\nthe output'", "echo good"])
assert await self.execute(op, method) == ["testing", "the output", "good"]
op = ShellOperation(commands=["echo 'testing'"])
assert await self.execute(op, method) == ["testing"]
records = prefect_task_runs_caplog.records
assert len(records) == 3
assert "triggered with 2 commands running" in records[0].message
assert "stream output:\ntesting\nthe output\ngood" in records[1].message
assert "triggered with 1 commands running" in records[0].message
assert "testing" in records[1].message
assert "completed with return code 0" in records[2].message

@pytest.mark.parametrize("method", ["run", "trigger"])
Expand All @@ -240,12 +246,12 @@ async def test_updated_env(self, method):
op = ShellOperation(
commands=["echo $env:TEST_VAR"], env={"TEST_VAR": "test value"}
)
assert await self.execute(op, method) == ["test_value"]
assert await self.execute(op, method) == ["test value"]

@pytest.mark.parametrize("method", ["run", "trigger"])
async def test_cwd(self, method):
op = ShellOperation(commands=["Get-Location"], working_dir=Path.home())
assert await self.execute(op, method) == [os.fspath(Path.home())]
assert os.fspath(Path.home()) in (await self.execute(op, method))

async def test_context_manager(self):
async with ShellOperation(commands=["echo 'testing'"]) as op:
Expand All @@ -257,4 +263,4 @@ def test_async_context_manager(self):
with ShellOperation(commands=["echo 'testing'"]) as op:
proc = op.trigger()
proc.wait_for_completion()
proc.fetch_result() == ["testing"]
proc.fetch_result() == ["testing", ""]

0 comments on commit cac3c3a

Please sign in to comment.