diff --git a/ahk/_async/transport.py b/ahk/_async/transport.py index e0f14c8..71c5c48 100644 --- a/ahk/_async/transport.py +++ b/ahk/_async/transport.py @@ -59,6 +59,10 @@ else: from typing import TypeAlias, TypeGuard +if sys.version_info < (3, 11): + from typing_extensions import Self +else: + from typing import Self T_AsyncFuture = TypeVar('T_AsyncFuture') # unasync: remove T_SyncFuture = TypeVar('T_SyncFuture') @@ -110,6 +114,9 @@ def async_assert_send_nonblocking_type_correct( class Communicable(Protocol): runargs: List[str] + async def start(self, atexit_cleanup: bool = True) -> None: ... + def astart(self, *args: Any, **kwargs: Any) -> None: ... # unasync: remove + def communicate(self, input_bytes: Optional[bytes], timeout: Optional[int] = None) -> Tuple[bytes, bytes]: ... async def acommunicate( # unasync: remove @@ -119,6 +126,8 @@ async def acommunicate( # unasync: remove @property def returncode(self) -> Optional[int]: ... + def kill(self) -> None: ... + class AsyncAHKProcess: def __init__(self, runargs: List[str]): @@ -130,9 +139,12 @@ def returncode(self) -> Optional[int]: assert self._proc is not None return self._proc.returncode - async def start(self) -> None: + def astart(self, *args: Any, **kwargs: Any) -> None: ... # unasync: remove + + async def start(self, atexit_cleanup: bool = True) -> None: self._proc = await async_create_process(self.runargs) - atexit.register(kill, self._proc) + if atexit_cleanup: + atexit.register(kill, self._proc) return None async def adrain_stdin(self) -> None: # unasync: remove @@ -183,6 +195,17 @@ def communicate(self, input_bytes: Optional[bytes] = None, timeout: Optional[int assert isinstance(self._proc, subprocess.Popen) return self._proc.communicate(input=input_bytes, timeout=timeout) + async def __aenter__(self) -> Self: + await self.start(atexit_cleanup=False) + return self + + async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> Literal[False]: + try: + self.kill() + except Exception: + pass + return False + async def async_create_process(runargs: List[str]) -> asyncio.subprocess.Process: # unasync: remove return await asyncio.subprocess.create_subprocess_exec( @@ -635,7 +658,8 @@ async def start(self) -> None: assert self._proc is None, 'cannot start a process twice' with warnings.catch_warnings(record=True) as caught_warnings: async with self.lock: - self._proc = await self._create_process() + self._proc = self._create_process() + await self._proc.start() if caught_warnings: for warning in caught_warnings: warnings.warn(warning.message, warning.category, stacklevel=2) @@ -659,9 +683,7 @@ def lock(self) -> Any: return self._a_execution_lock # unasync: remove return self._execution_lock - async def _create_process( - self, template: Optional[jinja2.Template] = None, **template_kwargs: Any - ) -> AsyncAHKProcess: + def _create_process(self, template: Optional[jinja2.Template] = None, **template_kwargs: Any) -> AsyncAHKProcess: if template is None: if template_kwargs: raise ValueError('template kwargs were specified, but no template was provided') @@ -684,15 +706,13 @@ async def _create_process( atexit.register(try_remove, tempscript.name) runargs = [self._executable_path, '/CP65001', '/ErrorStdOut', daemon_script] proc = AsyncAHKProcess(runargs=runargs) - await proc.start() return proc async def _send_nonblocking( self, request: RequestMessage, engine: Optional[AsyncAHK[Any]] = None ) -> Union[None, Tuple[int, int], int, str, bool, AsyncWindow, List[AsyncWindow], List[AsyncControl]]: msg = request.format() - proc = await self._create_process() - try: + async with self._create_process() as proc: proc.write(msg) await proc.adrain_stdin() tom = await proc.readline() @@ -715,11 +735,6 @@ async def _send_nonblocking( part = await proc.readline() content_buffer.write(part) content = content_buffer.getvalue()[:-1] - finally: - try: - proc.kill() - except: # noqa - pass response = ResponseMessage.from_bytes(content, engine=engine) return response.unpack() # type: ignore @@ -781,11 +796,17 @@ async def _async_run_nonblocking( # unasync: remove loop = asyncio.get_running_loop() async def f() -> str: - stdout, stderr = await proc.acommunicate(script_bytes, timeout) + try: + await proc.start(atexit_cleanup=False) + stdout, stderr = await proc.acommunicate(script_bytes, timeout) + finally: + try: + proc.kill() + except Exception: + pass if proc.returncode != 0: assert proc.returncode is not None raise subprocess.CalledProcessError(proc.returncode, proc.runargs, stdout, stderr) - return stdout.decode('utf-8') task = loop.create_task(f()) @@ -797,15 +818,23 @@ def _sync_run_nonblocking( script_bytes: Optional[bytes], timeout: Optional[int] = None, ) -> FutureResult[str]: - pool = ThreadPoolExecutor(max_workers=1) + raise RuntimeError('This method can only be called from the sync API') # unasync: remove def f() -> str: - stdout, stderr = proc.communicate(script_bytes, timeout) + try: + proc.astart(atexit_cleanup=False) + stdout, stderr = proc.communicate(script_bytes, timeout) + finally: + try: + proc.kill() + except Exception: + pass if proc.returncode != 0: assert proc.returncode is not None raise subprocess.CalledProcessError(proc.returncode, proc.runargs, stdout, stderr) return stdout.decode('utf-8') + pool = ThreadPoolExecutor(max_workers=1) fut = pool.submit(f) pool.shutdown(wait=False) return FutureResult(fut) @@ -830,13 +859,13 @@ async def run_script( script_bytes = bytes(script_text_or_path, 'utf-8') runargs = [self._executable_path, '/CP65001', '/ErrorStdOut', '*'] proc = AsyncAHKProcess(runargs) - await proc.start() if blocking: - stdout, stderr = await proc.acommunicate(script_bytes, timeout=timeout) - if proc.returncode != 0: - assert proc.returncode is not None - raise subprocess.CalledProcessError(proc.returncode, proc.runargs, stdout, stderr) - return stdout.decode('utf-8') + async with proc: + stdout, stderr = await proc.acommunicate(script_bytes, timeout=timeout) + if proc.returncode != 0: + assert proc.returncode is not None + raise subprocess.CalledProcessError(proc.returncode, proc.runargs, stdout, stderr) + return stdout.decode('utf-8') else: return await self._async_run_nonblocking(proc, script_bytes, timeout=timeout) diff --git a/ahk/_sync/transport.py b/ahk/_sync/transport.py index ed369c5..74931b0 100644 --- a/ahk/_sync/transport.py +++ b/ahk/_sync/transport.py @@ -59,6 +59,10 @@ else: from typing import TypeAlias, TypeGuard +if sys.version_info < (3, 11): + from typing_extensions import Self +else: + from typing import Self T_SyncFuture = TypeVar('T_SyncFuture') @@ -102,12 +106,16 @@ def async_assert_send_nonblocking_type_correct( class Communicable(Protocol): runargs: List[str] + def start(self, atexit_cleanup: bool = True) -> None: ... + def communicate(self, input_bytes: Optional[bytes], timeout: Optional[int] = None) -> Tuple[bytes, bytes]: ... @property def returncode(self) -> Optional[int]: ... + def kill(self) -> None: ... + class SyncAHKProcess: def __init__(self, runargs: List[str]): @@ -119,9 +127,11 @@ def returncode(self) -> Optional[int]: assert self._proc is not None return self._proc.returncode - def start(self) -> None: + + def start(self, atexit_cleanup: bool = True) -> None: self._proc = sync_create_process(self.runargs) - atexit.register(kill, self._proc) + if atexit_cleanup: + atexit.register(kill, self._proc) return None @@ -160,6 +170,17 @@ def communicate(self, input_bytes: Optional[bytes] = None, timeout: Optional[int assert isinstance(self._proc, subprocess.Popen) return self._proc.communicate(input=input_bytes, timeout=timeout) + def __enter__(self) -> Self: + self.start(atexit_cleanup=False) + return self + + def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> Literal[False]: + try: + self.kill() + except Exception: + pass + return False + @@ -601,6 +622,7 @@ def start(self) -> None: with warnings.catch_warnings(record=True) as caught_warnings: with self.lock: self._proc = self._create_process() + self._proc.start() if caught_warnings: for warning in caught_warnings: warnings.warn(warning.message, warning.category, stacklevel=2) @@ -623,9 +645,7 @@ def _render_script(self, template: Optional[jinja2.Template] = None, **kwargs: A def lock(self) -> Any: return self._execution_lock - def _create_process( - self, template: Optional[jinja2.Template] = None, **template_kwargs: Any - ) -> SyncAHKProcess: + def _create_process(self, template: Optional[jinja2.Template] = None, **template_kwargs: Any) -> SyncAHKProcess: if template is None: if template_kwargs: raise ValueError('template kwargs were specified, but no template was provided') @@ -648,15 +668,13 @@ def _create_process( atexit.register(try_remove, tempscript.name) runargs = [self._executable_path, '/CP65001', '/ErrorStdOut', daemon_script] proc = SyncAHKProcess(runargs=runargs) - proc.start() return proc def _send_nonblocking( self, request: RequestMessage, engine: Optional[AHK[Any]] = None ) -> Union[None, Tuple[int, int], int, str, bool, Window, List[Window], List[Control]]: msg = request.format() - proc = self._create_process() - try: + with self._create_process() as proc: proc.write(msg) proc.drain_stdin() tom = proc.readline() @@ -679,11 +697,6 @@ def _send_nonblocking( part = proc.readline() content_buffer.write(part) content = content_buffer.getvalue()[:-1] - finally: - try: - proc.kill() - except: # noqa - pass response = ResponseMessage.from_bytes(content, engine=engine) return response.unpack() # type: ignore @@ -738,15 +751,22 @@ def _sync_run_nonblocking( script_bytes: Optional[bytes], timeout: Optional[int] = None, ) -> FutureResult[str]: - pool = ThreadPoolExecutor(max_workers=1) def f() -> str: - stdout, stderr = proc.communicate(script_bytes, timeout) + try: + proc.start(atexit_cleanup=False) + stdout, stderr = proc.communicate(script_bytes, timeout) + finally: + try: + proc.kill() + except Exception: + pass if proc.returncode != 0: assert proc.returncode is not None raise subprocess.CalledProcessError(proc.returncode, proc.runargs, stdout, stderr) return stdout.decode('utf-8') + pool = ThreadPoolExecutor(max_workers=1) fut = pool.submit(f) pool.shutdown(wait=False) return FutureResult(fut) @@ -771,13 +791,13 @@ def run_script( script_bytes = bytes(script_text_or_path, 'utf-8') runargs = [self._executable_path, '/CP65001', '/ErrorStdOut', '*'] proc = SyncAHKProcess(runargs) - proc.start() if blocking: - stdout, stderr = proc.communicate(script_bytes, timeout=timeout) - if proc.returncode != 0: - assert proc.returncode is not None - raise subprocess.CalledProcessError(proc.returncode, proc.runargs, stdout, stderr) - return stdout.decode('utf-8') + with proc: + stdout, stderr = proc.communicate(script_bytes, timeout=timeout) + if proc.returncode != 0: + assert proc.returncode is not None + raise subprocess.CalledProcessError(proc.returncode, proc.runargs, stdout, stderr) + return stdout.decode('utf-8') else: return self._sync_run_nonblocking(proc, script_bytes, timeout=timeout) diff --git a/buildunasync.py b/buildunasync.py index d53b399..40323af 100644 --- a/buildunasync.py +++ b/buildunasync.py @@ -19,6 +19,7 @@ 'AsyncFutureResult': 'FutureResult', '_async_run_nonblocking': '_sync_run_nonblocking', 'acommunicate': 'communicate', + 'astart': 'start', # "__aenter__": "__aenter__", }, ),