diff --git a/arg_services/__init__.py b/arg_services/__init__.py index 2d53cd3..c8dd973 100644 --- a/arg_services/__init__.py +++ b/arg_services/__init__.py @@ -8,7 +8,7 @@ from grpc_reflection.v1alpha import reflection -def handle_except(ex: Exception, ctx: grpc.ServicerContext) -> None: +def handle_except(ex: Exception, ctx: t.Optional[grpc.ServicerContext]) -> None: """Handler that can be called when handling an exception. It will pass the traceback to the gRPC client and abort the context. @@ -19,13 +19,17 @@ def handle_except(ex: Exception, ctx: grpc.ServicerContext) -> None: """ msg = "".join(traceback.TracebackException.from_exception(ex).format()) - ctx.abort(grpc.StatusCode.UNKNOWN, msg) + + if ctx is None: + raise ex + else: + ctx.abort(grpc.StatusCode.UNKNOWN, msg) def require_any( attrs: t.Collection[str], obj: object, - ctx: grpc.ServicerContext, + ctx: t.Optional[grpc.ServicerContext], parent: str = "request", ) -> None: """Verify that any of the required arguments are supplied by the client. @@ -45,16 +49,18 @@ def require_any( attr_result = [attr_result] if not any(attr_result): - ctx.abort( - grpc.StatusCode.INVALID_ARGUMENT, - f"The message '{parent}' requires the following attributes: {attrs}.", - ) + msg = f"The message '{parent}' requires the following attributes: {attrs}." + + if ctx is None: + raise ValueError(msg) + else: + ctx.abort(grpc.StatusCode.INVALID_ARGUMENT, msg) def require_all( attrs: t.Collection[str], obj: object, - ctx: grpc.ServicerContext, + ctx: t.Optional[grpc.ServicerContext] = None, parent: str = "request", ) -> None: """Verify that all required arguments are supplied by the client. @@ -74,17 +80,19 @@ def require_all( attr_result = [attr_result] if not all(attr_result): - ctx.abort( - grpc.StatusCode.INVALID_ARGUMENT, - f"The message '{parent}' requires the following attributes: {attrs}.", - ) + msg = f"The message '{parent}' requires the following attributes: {attrs}." + + if ctx is None: + raise ValueError(msg) + else: + ctx.abort(grpc.StatusCode.INVALID_ARGUMENT, msg) def require_all_repeated( key: str, attrs: t.Collection[str], obj: object, - ctx: grpc.ServicerContext, + ctx: t.Optional[grpc.ServicerContext] = None, ) -> None: """Verify that all required arguments are supplied by the client. @@ -106,7 +114,7 @@ def require_any_repeated( key: str, attrs: t.Collection[str], obj: object, - ctx: grpc.ServicerContext, + ctx: t.Optional[grpc.ServicerContext] = None, ) -> None: """Verify that any required arguments are supplied by the client. @@ -127,7 +135,7 @@ def require_any_repeated( def forbid_all( attrs: t.Collection[str], obj: object, - ctx: grpc.ServicerContext, + ctx: t.Optional[grpc.ServicerContext] = None, parent: str = "request", ) -> None: """Verify that no illegal combination of arguments is provided by the client. @@ -145,13 +153,12 @@ def forbid_all( attr_result = [attr_result] if all(attr_result): - ctx.abort( - grpc.StatusCode.INVALID_ARGUMENT, - ( - f"The message '{parent}' is not allowed to allowed to have the" - f" following parameter combination: {attrs}." - ), - ) + error = f"The message '{parent}' is not allowed to allowed to have the following parameter combination: {attrs}." + + if ctx is None: + raise ValueError(error) + else: + ctx.abort(grpc.StatusCode.INVALID_ARGUMENT, error) def full_service_name(pkg, service: str) -> str: