Skip to content

Commit

Permalink
[server] Fix tests + allow server to parse generation kwargs (#1311)
Browse files Browse the repository at this point in the history
* update predict to call the pipeline directly

* fix tests
  • Loading branch information
dsikka authored Oct 11, 2023
1 parent 130e27d commit 2eb9d3c
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 13 deletions.
18 changes: 16 additions & 2 deletions src/deepsparse/server/sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,25 @@ def _add_inference_endpoints(

if hasattr(pipeline.input_schema, "from_files"):
routes_and_fns.append(
(route, partial(Server.predict_from_files, ProxyPipeline(pipeline)))
(
route + "/from_files",
partial(
Server.predict_from_files,
ProxyPipeline(pipeline),
self.server_config.system_logging,
),
)
)
else:
routes_and_fns.append(
(route, partial(Server.predict, ProxyPipeline(pipeline)))
(
route,
partial(
Server.predict,
ProxyPipeline(pipeline),
self.server_config.system_logging,
),
)
)

self._update_routes(
Expand Down
20 changes: 14 additions & 6 deletions src/deepsparse/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,19 +246,27 @@ async def predict(
system_logging_config: SystemLoggingConfig,
raw_request: Request,
):
request = proxy_pipeline.pipeline.input_schema(**await raw_request.json())
pipeline_outputs = proxy_pipeline.pipeline(request)
pipeline_outputs = proxy_pipeline.pipeline(**await raw_request.json())
server_logger = proxy_pipeline.pipeline.logger
if server_logger:
log_system_information(
server_logger=server_logger, system_logging_config=system_logging_config
)
pipeline_outputs = prep_outputs_for_serialization(pipeline_outputs)
return pipeline_outputs
return prep_outputs_for_serialization(pipeline_outputs)

@staticmethod
def predict_from_files(proxy_pipeline: ProxyPipeline, request: List[UploadFile]):
def predict_from_files(
proxy_pipeline: ProxyPipeline,
system_logging_config: SystemLoggingConfig,
request: List[UploadFile],
):
request = proxy_pipeline.pipeline.input_schema.from_files(
(file.file for file in request), from_server=True
)
return Server.predict(request)
pipeline_outputs = proxy_pipeline.pipeline(request)
server_logger = proxy_pipeline.pipeline.logger
if server_logger:
log_system_information(
server_logger=server_logger, system_logging_config=system_logging_config
)
return prep_outputs_for_serialization(pipeline_outputs)
15 changes: 10 additions & 5 deletions tests/server/test_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ class StrSchema(BaseModel):
value: str


def parse(v: StrSchema) -> int:
return int(v.value)
def parse(value) -> int:
return int(value)


class TestStatusEndpoints:
Expand Down Expand Up @@ -106,7 +106,7 @@ def test_add_model_endpoint(
):
mock_pipeline = Mock(
side_effect=parse,
input_schema=StrSchema,
input_schema=str,
output_schema=int,
logger=MultiLogger([]),
)
Expand Down Expand Up @@ -146,6 +146,7 @@ def test_add_model_endpoint_with_from_files(self, server, app):
assert app.routes[-1].path == "/v2/models/predict/parse_int/infer/from_files"
assert app.routes[-1].endpoint.func.__annotations__ == {
"proxy_pipeline": ProxyPipeline,
"system_logging_config": SystemLoggingConfig,
"request": List[UploadFile],
}
assert app.routes[-1].response_model is int
Expand All @@ -159,9 +160,12 @@ def test_sagemaker_only_adds_one_endpoint(self, sagemaker_server, app):
pipeline=Mock(input_schema=FromFilesSchema, output_schema=int),
)
assert len(app.routes) == num_routes + 1
assert app.routes[-1].path == "/invocations/predict/parse_int/infer"
num_routes = len(app.routes)

assert app.routes[-1].path == "/invocations/predict/parse_int/infer/from_files"
assert app.routes[-1].endpoint.func.__annotations__ == {
"proxy_pipeline": ProxyPipeline,
"system_logging_config": SystemLoggingConfig,
"request": List[UploadFile],
}

Expand All @@ -174,7 +178,8 @@ def test_sagemaker_only_adds_one_endpoint(self, sagemaker_server, app):
assert app.routes[-1].path == "/invocations/predict/parse_int/infer"
assert app.routes[-1].endpoint.func.__annotations__ == {
"proxy_pipeline": ProxyPipeline,
"request": List[UploadFile],
"system_logging_config": SystemLoggingConfig,
"raw_request": Request,
}

def test_add_endpoint_with_no_route_specified(self, server, app):
Expand Down

0 comments on commit 2eb9d3c

Please sign in to comment.