Skip to content

Commit

Permalink
Fix @metric_scope for generator and async generator functions (#113)
Browse files Browse the repository at this point in the history
Co-authored-by: Alin RADU <alinra@amazon.com>
  • Loading branch information
acradu and Alin RADU authored Oct 2, 2024
1 parent 66e8b6b commit 1836dd8
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 10 deletions.
55 changes: 45 additions & 10 deletions aws_embedded_metrics/metric_scope/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,35 +18,70 @@


def metric_scope(fn): # type: ignore
if inspect.isasyncgenfunction(fn):
@wraps(fn)
async def async_gen_wrapper(*args, **kwargs): # type: ignore
logger = create_metrics_logger()
if "metrics" in inspect.signature(fn).parameters:
kwargs["metrics"] = logger

try:
fn_gen = fn(*args, **kwargs)
while True:
result = await fn_gen.__anext__()
await logger.flush()
yield result
except Exception as ex:
await logger.flush()
if not isinstance(ex, StopIteration):
raise

return async_gen_wrapper

elif inspect.isgeneratorfunction(fn):
@wraps(fn)
def gen_wrapper(*args, **kwargs): # type: ignore
logger = create_metrics_logger()
if "metrics" in inspect.signature(fn).parameters:
kwargs["metrics"] = logger

try:
fn_gen = fn(*args, **kwargs)
while True:
result = next(fn_gen)
asyncio.run(logger.flush())
yield result
except Exception as ex:
asyncio.run(logger.flush())
if not isinstance(ex, StopIteration):
raise

if asyncio.iscoroutinefunction(fn):
return gen_wrapper

elif asyncio.iscoroutinefunction(fn):
@wraps(fn)
async def wrapper(*args, **kwargs): # type: ignore
async def async_wrapper(*args, **kwargs): # type: ignore
logger = create_metrics_logger()
if "metrics" in inspect.signature(fn).parameters:
kwargs["metrics"] = logger

try:
return await fn(*args, **kwargs)
except Exception as e:
raise e
finally:
await logger.flush()

return wrapper
else:
return async_wrapper

else:
@wraps(fn)
def wrapper(*args, **kwargs): # type: ignore
logger = create_metrics_logger()
if "metrics" in inspect.signature(fn).parameters:
kwargs["metrics"] = logger

try:
return fn(*args, **kwargs)
except Exception as e:
raise e
finally:
loop = asyncio.get_event_loop()
loop.run_until_complete(logger.flush())
asyncio.run(logger.flush())

return wrapper
37 changes: 37 additions & 0 deletions tests/metric_scope/test_metric_scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,43 @@ def my_handler(metrics):
actual_timestamp_second = int(round(logger.context.meta["Timestamp"] / 1000))
assert expected_timestamp_second == actual_timestamp_second


def test_sync_scope_iterates_generator(mock_logger):
expected_results = [1, 2]

@metric_scope
def my_handler():
yield from expected_results
raise Exception("test exception")

actual_results = []
with pytest.raises(Exception, match="test exception"):
for result in my_handler():
actual_results.append(result)

assert actual_results == expected_results
assert InvocationTracker.invocations == 3


@pytest.mark.asyncio
async def test_async_scope_iterates_async_generator(mock_logger):
expected_results = [1, 2]

@metric_scope
async def my_handler():
for item in expected_results:
yield item
await asyncio.sleep(1)
raise Exception("test exception")

actual_results = []
with pytest.raises(Exception, match="test exception"):
async for result in my_handler():
actual_results.append(result)

assert actual_results == expected_results
assert InvocationTracker.invocations == 3

# Test helpers


Expand Down

0 comments on commit 1836dd8

Please sign in to comment.