[deepsparse.server] integrate pipeline logging into deepsparse server (#641)

* [deepsparse.server] utility functions for instantiating loggers

* default logger info log

* allow config to be dict type for server config support

* [deepsparse.server] integrate pipeline logging into deepsparse server

* update config test

* tmp_path fixture use

* review response

* suppress logging on existing tests, add test for default logging server

* fix tests after rebase

* prometheus integration fixes

* update ServerConfig.loggers to only allow 'default' as str arg, make None mean no loggers

* review responses

* update test name

* fix prometheus test

* review fix
bfineran authored Sep 14, 2022
1 parent a650b9a commit 16fed62
Showing 9 changed files with 173 additions and 29 deletions.
5 changes: 4 additions & 1 deletion .gitignore
@@ -1,6 +1,9 @@
# Created by https://www.toptal.com/developers/gitignore/api/windows,macos,linux,python,react,pycharm,emacs,vim,visualstudio,visualstudiocode
# Edit at https://www.toptal.com/developers/gitignore?templates=windows,macos,linux,python,react,pycharm,emacs,vim,visualstudio,visualstudiocode

+ # logging
+ **/*.prom

### Emacs ###
# -*- mode: gitignore; -*-
*~
@@ -766,4 +769,4 @@ MigrationBackup/
licenses/
engine-version.txt
src/deepsparse/generated_version.py
- .idea/*
+ .idea/*
19 changes: 14 additions & 5 deletions src/deepsparse/loggers/prometheus_logger.py
@@ -44,7 +44,7 @@ class PrometheusLogger(BaseLogger):

def __init__(
self,
- port: int = 8000,
+ port: int = 6100,
text_log_save_dir: str = os.getcwd(),
text_log_save_freq: int = 10,
text_log_file_name: Optional[str] = None,
@@ -123,8 +123,16 @@ def log_latency(
self._update_call_count(pipeline_name)
self._export_metrics_to_textfile()

- def log_data(self, inputs: Any, outputs: Any):
-     raise NotImplementedError()
+ def log_data(self, pipeline_name: str, inputs: Any, outputs: Any):
+     """
+     :param pipeline_name: The name of the inference pipeline from which the
+         logger consumes the inference information to be monitored
+     :param inputs: the data received and consumed by the inference
+         pipeline
+     :param outputs: the data returned by the inference pipeline
+     """
+     pass

def _log_latency(self, metric_name: str, value_to_log: float, pipeline_name: str):
"""
@@ -158,12 +166,13 @@ def _setup_metrics(
# Histograms track the size and number of events in buckets
for field_name, field_data in inference_timing.__fields__.items():
field_description = field_data.field_info.description
+ metric_name = f"{pipeline_name}:{field_name}".strip().replace(" ", "-")
  self.metrics[pipeline_name][field_name] = Histogram(
-     f"{field_name}_{pipeline_name}", field_description, registry=REGISTRY
+     metric_name, field_description, registry=REGISTRY
)
_LOGGER.info(
f"Prometheus client: set the metrics to track pipeline: {pipeline_name}. "
f"Tracked metrics: {[metric for metric in self.metrics]}"
f"Added metrics: {[metric for metric in self.metrics[pipeline_name]]}"
)

def _export_metrics_to_textfile(self):
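With the rename above, each pipeline timing histogram is now registered as "<pipeline_name>:<field_name>" instead of "<field_name>_<pipeline_name>". A minimal sketch of the resulting text exposition, using prometheus_client directly (the pipeline name "text_classification" and timing field "total_inference" are illustrative assumptions, not names taken from this commit):

from prometheus_client import CollectorRegistry, Histogram, generate_latest

registry = CollectorRegistry()
# mirror the commit's naming: f"{pipeline_name}:{field_name}".strip().replace(" ", "-")
metric_name = "text_classification:total_inference"
histogram = Histogram(metric_name, "total inference time in seconds", registry=registry)
histogram.observe(0.042)  # record one 42 ms inference
print(generate_latest(registry).decode())
# the exposition includes lines such as:
#   text_classification:total_inference_count 1.0
#   text_classification:total_inference_sum 0.042

This "<pipeline>:<field>_count" shape is exactly what the updated test further down searches for in the exported log lines.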
13 changes: 12 additions & 1 deletion src/deepsparse/server/cli.py
@@ -156,6 +156,15 @@ def config(config_path: str, host: str, port: int, log_level: str):
default="info",
help="Sets the logging level.",
)
+ @click.option(
+     "--no-loggers",
+     is_flag=True,
+     default=False,
+     help=(
+         "Set to not use any inference logging integration. Defaults to using "
+         "a default integration such as Prometheus."
+     ),
+ )
def task(
task: str,
model_path: str,
@@ -165,6 +174,7 @@ def task(
host: str,
port: int,
log_level: str,
+ no_loggers: bool,
):
"""
Run the server using configuration with CLI options,
@@ -176,12 +186,13 @@
endpoints=[
EndpointConfig(
task=task,
name=f"{task} inference model",
name=f"{task}",
route="/predict",
model=model_path,
batch_size=batch_size,
)
],
+ loggers=None if no_loggers else "default",
)

with TemporaryDirectory() as tmp_dir:
10 changes: 10 additions & 0 deletions src/deepsparse/server/config.py
@@ -139,6 +139,16 @@ class ServerConfig(BaseModel):

endpoints: List[EndpointConfig] = Field(description="The models to serve.")

+ loggers: Union[Dict[str, Dict[str, Any]], str, None] = Field(
+     default="default",
+     description=(
+         "Optional dictionary of logger integration names to initialization kwargs."
+         " Set to 'default' for default logger based on deployment. Set to None for"
+         " no loggers. Default is 'default'. Example: "
+         "{'prometheus': {'port': 8001}}."
+     ),
+ )


def _unpack_bucketing(
task: str, bucketing: Optional[Union[SequenceLengthsConfig, ImageSizesConfig]]
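A short sketch of the three accepted forms of the new loggers field (the constructor arguments mirror the test updates below; any further required ServerConfig fields are assumed to carry defaults):

from deepsparse.server.config import ServerConfig

# omitted -> loggers="default": a deployment-appropriate integration is picked
ServerConfig(num_cores=1, num_workers=1, endpoints=[])
# None -> disable inference logging entirely
ServerConfig(num_cores=1, num_workers=1, endpoints=[], loggers=None)
# dict -> logger integration name mapped to its initialization kwargs
ServerConfig(
    num_cores=1,
    num_workers=1,
    endpoints=[],
    loggers={"prometheus": {"port": 8001}},
)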
55 changes: 46 additions & 9 deletions src/deepsparse/server/server.py
@@ -23,6 +23,7 @@

import uvicorn
from deepsparse.engine import Context
+ from deepsparse.loggers import ManagerLogger
from deepsparse.pipeline import Pipeline
from deepsparse.server.config import (
INTEGRATION_LOCAL,
@@ -31,6 +32,7 @@
EndpointConfig,
ServerConfig,
)
+ from deepsparse.server.helpers import default_logger_manager, logger_manager_from_config
from fastapi import FastAPI, UploadFile
from starlette.responses import RedirectResponse

@@ -107,6 +109,7 @@ def _build_app(server_config: ServerConfig) -> FastAPI:
_LOGGER.info(f"Built ThreadPoolExecutor with {executor._max_workers} workers")

app = FastAPI()
+ pipeline_logger = _initialize_loggers(server_config)

@app.get("/", include_in_schema=False)
def _home():
@@ -125,7 +128,7 @@ def _health():

@app.post("/endpoints", tags=["endpoints"], response_model=bool)
def _add_endpoint_endpoint(cfg: EndpointConfig):
- _add_endpoint(app, server_config, cfg, executor, context)
+ _add_endpoint(app, server_config, cfg, executor, context, pipeline_logger)
# force regeneration of the docs
app.openapi_schema = None
return True
@@ -147,7 +150,9 @@ def _delete_endpoint(cfg: EndpointConfig):

# create pipelines & endpoints
for endpoint_config in server_config.endpoints:
- _add_endpoint(app, server_config, endpoint_config, executor, context)
+ _add_endpoint(
+     app, server_config, endpoint_config, executor, context, pipeline_logger
+ )

_LOGGER.info(f"Added endpoints: {[route.path for route in app.routes]}")

@@ -191,6 +196,7 @@ def _add_endpoint(
endpoint_config: EndpointConfig,
executor: ThreadPoolExecutor,
context: Context,
+ pipeline_logger: ManagerLogger,
):
pipeline_config = endpoint_config.to_pipeline_config()
pipeline_config.kwargs["executor"] = executor
@@ -199,42 +205,56 @@
pipeline = Pipeline.from_config(pipeline_config, context)

_LOGGER.info(f"Adding endpoints for '{endpoint_config.name}'")
- _add_pipeline_endpoint(app, endpoint_config, pipeline, server_config.integration)
+ _add_pipeline_endpoint(
+     app, endpoint_config, pipeline, pipeline_logger, server_config.integration
+ )


def _add_pipeline_endpoint(
app: FastAPI,
endpoint_config: EndpointConfig,
pipeline: Pipeline,
+ pipeline_logger: ManagerLogger,
integration: str = INTEGRATION_LOCAL,
):
input_schema = pipeline.input_schema
output_schema = pipeline.output_schema

- def _predict_from_schema(request: pipeline.input_schema):
-     return pipeline(request)
+ pipeline_name = pipeline.alias or pipeline.task

+ def _predict(request: pipeline.input_schema):
+     (
+         pipeline_outputs,
+         pipeline_inputs,
+         engine_inputs,
+         inference_timing,
+     ) = pipeline.run_with_monitoring(request)
+     pipeline_logger.log_latency(pipeline_name, inference_timing)
+     pipeline_logger.log_data(
+         pipeline_name, (pipeline_inputs, engine_inputs), pipeline_outputs
+     )
+     return pipeline_outputs

def _predict_from_files(request: List[UploadFile]):
request = pipeline.input_schema.from_files(
(file.file for file in request), from_server=True
)
- return pipeline(request)
+ return _predict(request)

routes_and_fns = []
if integration == INTEGRATION_LOCAL:
route = endpoint_config.route or "/predict"
if not route.startswith("/"):
route = "/" + route

- routes_and_fns.append((route, _predict_from_schema))
+ routes_and_fns.append((route, _predict))
if hasattr(input_schema, "from_files"):
routes_and_fns.append((route + "/from_files", _predict_from_files))
elif integration == INTEGRATION_SAGEMAKER:
route = "/invocations"
if hasattr(input_schema, "from_files"):
routes_and_fns.append((route, _predict_from_files))
else:
- routes_and_fns.append((route, _predict_from_schema))
+ routes_and_fns.append((route, _predict))

for route, endpoint_fn in routes_and_fns:
app.add_api_route(
@@ -245,3 +265,20 @@ def _predict_from_files(request: List[UploadFile]):
tags=["predict"],
)
_LOGGER.info(f"Added '{route}' endpoint")


+ def _initialize_loggers(server_config: ServerConfig) -> ManagerLogger:
+     loggers_config = server_config.loggers
+     if loggers_config is None:
+         return ManagerLogger([])
+     if isinstance(loggers_config, str):
+         if loggers_config != "default":
+             raise ValueError(
+                 f"given string {loggers_config} for ServerConfig.loggers; the "
+                 "only supported string is 'default'. Other configs should be "
+                 "specified as a dict of logging integration names to their "
+                 "initialization kwargs"
+             )
+         return default_logger_manager()
+     return logger_manager_from_config(loggers_config)
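The server code above relies on just two methods of ManagerLogger: log_latency and log_data. A minimal sketch of the assumed fan-out behavior (the real ManagerLogger lives in deepsparse.loggers and is not part of this diff; this stand-in is purely illustrative):

from typing import Any, List


class FanOutLogger:
    """Hypothetical stand-in for deepsparse.loggers.ManagerLogger."""

    def __init__(self, loggers: List[Any]):
        # ManagerLogger([]) is built when ServerConfig.loggers is None,
        # so an empty list must turn every call into a no-op
        self.loggers = loggers

    def log_latency(self, pipeline_name: str, inference_timing: Any):
        for logger in self.loggers:
            logger.log_latency(pipeline_name, inference_timing)

    def log_data(self, pipeline_name: str, inputs: Any, outputs: Any):
        for logger in self.loggers:
            logger.log_data(pipeline_name, inputs, outputs)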
2 changes: 1 addition & 1 deletion tests/deepsparse/loggers/test_prometheus_logger.py
@@ -108,5 +108,5 @@ def _check_logs(logger, pipelines, no_iterations, timings, port):
@staticmethod
def _check_correct_count(lines, timings, pipeline, no_iterations):
for name, value in dict(timings).items():
searched_line = f"{name}_{pipeline.name}_count {float(no_iterations)}"
searched_line = f"{pipeline.name}:{name}_count {float(no_iterations)}"
assert any([searched_line in line for line in lines])
50 changes: 43 additions & 7 deletions tests/server/test_app.py
@@ -37,6 +37,7 @@ def test_add_multiple_endpoints_with_no_route():
EndpointConfig(task="", model="", route=None),
EndpointConfig(task="", model="", route=None),
],
+ loggers=None,
)
)

@@ -51,6 +52,7 @@ def test_add_multiple_endpoints_with_same_route():
EndpointConfig(task="", model="", route="asdf"),
EndpointConfig(task="", model="", route="asdf"),
],
+ loggers=None,
)
)

@@ -63,7 +65,13 @@ def test_invalid_integration():
),
):
_build_app(
- ServerConfig(num_cores=1, num_workers=1, integration="asdf", endpoints=[])
+ ServerConfig(
+     num_cores=1,
+     num_workers=1,
+     integration="asdf",
+     endpoints=[],
+     loggers=None,
+ )
)


@@ -72,12 +80,24 @@ def test_pytorch_num_threads():

orig_num_threads = torch.get_num_threads()
_build_app(
- ServerConfig(num_cores=1, num_workers=1, pytorch_num_threads=None, endpoints=[])
+ ServerConfig(
+     num_cores=1,
+     num_workers=1,
+     pytorch_num_threads=None,
+     endpoints=[],
+     loggers=None,
+ )
)
assert torch.get_num_threads() == orig_num_threads

_build_app(
- ServerConfig(num_cores=1, num_workers=1, pytorch_num_threads=1, endpoints=[])
+ ServerConfig(
+     num_cores=1,
+     num_workers=1,
+     pytorch_num_threads=1,
+     endpoints=[],
+     loggers=None,
+ )
)
assert torch.get_num_threads() == 1

@@ -88,7 +108,11 @@ def test_thread_pinning_none():
os.environ.pop("NM_BIND_THREADS_TO_SOCKETS", None)
_build_app(
ServerConfig(
num_cores=1, num_workers=1, engine_thread_pinning="none", endpoints=[]
num_cores=1,
num_workers=1,
engine_thread_pinning="none",
endpoints=[],
loggers=None,
)
)
assert os.environ["NM_BIND_THREADS_TO_CORES"] == "0"
@@ -101,7 +125,11 @@ def test_thread_pinning_numa():
os.environ.pop("NM_BIND_THREADS_TO_SOCKETS", None)
_build_app(
ServerConfig(
num_cores=1, num_workers=1, engine_thread_pinning="numa", endpoints=[]
num_cores=1,
num_workers=1,
engine_thread_pinning="numa",
endpoints=[],
loggers=None,
)
)
assert os.environ["NM_BIND_THREADS_TO_CORES"] == "0"
@@ -114,7 +142,11 @@ def test_thread_pinning_cores():
os.environ.pop("NM_BIND_THREADS_TO_SOCKETS", None)
_build_app(
ServerConfig(
num_cores=1, num_workers=1, engine_thread_pinning="core", endpoints=[]
num_cores=1,
num_workers=1,
engine_thread_pinning="core",
endpoints=[],
loggers=None,
)
)
assert os.environ["NM_BIND_THREADS_TO_CORES"] == "1"
@@ -125,6 +157,10 @@ def test_invalid_thread_pinning():
with pytest.raises(ValueError, match='Expected one of {"core","numa","none"}.'):
_build_app(
ServerConfig(
num_cores=1, num_workers=1, engine_thread_pinning="asdf", endpoints=[]
num_cores=1,
num_workers=1,
engine_thread_pinning="asdf",
endpoints=[],
loggers=None,
)
)
1 change: 1 addition & 0 deletions tests/server/test_config.py
@@ -128,6 +128,7 @@ def test_yaml_load_config(tmp_path):
bucketing=SequenceLengthsConfig(sequence_lengths=[5, 6, 7]),
),
],
loggers="path/to/logging/config.yaml",
)

path = tmp_path / "config.yaml"