diff --git a/examples/flask/client.py b/examples/flask/client.py index 70a290aed3..ce8ea8aa26 100644 --- a/examples/flask/client.py +++ b/examples/flask/client.py @@ -42,22 +42,17 @@ """ import argparse -import os +import logging import time from typing import Any, Callable, List import numpy import requests -from deepsparse.utils import ( - arrays_to_bytes, - bytes_to_arrays, - generate_random_inputs, - log_init, -) +from deepsparse.utils import arrays_to_bytes, bytes_to_arrays, generate_random_inputs -_LOGGER = log_init(os.path.basename(__file__)) +_LOGGER = logging.getLogger(__name__) class EngineFlaskClient: diff --git a/examples/flask/server.py b/examples/flask/server.py index 8c3917ebfa..3a470cfc78 100644 --- a/examples/flask/server.py +++ b/examples/flask/server.py @@ -48,16 +48,16 @@ """ import argparse -import os +import logging import flask from flask_cors import CORS from deepsparse import Scheduler, compile_model -from deepsparse.utils import arrays_to_bytes, bytes_to_arrays, log_init +from deepsparse.utils import arrays_to_bytes, bytes_to_arrays -_LOGGER = log_init(os.path.basename(__file__)) +_LOGGER = logging.getLogger(__name__) def engine_flask_server( diff --git a/src/deepsparse/tasks.py b/src/deepsparse/tasks.py index 0946fe8c5b..6ffaad7ec3 100644 --- a/src/deepsparse/tasks.py +++ b/src/deepsparse/tasks.py @@ -12,6 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +""" +Classes and implementations for supported tasks in the DeepSparse pipeline and system +""" + from collections import namedtuple from typing import List @@ -20,23 +24,50 @@ class AliasedTask: + """ + A task that can have multiple aliases to match to. + For example, question_answering which can alias to qa as well + + :param name: the name of the task such as question_answering or text_classification + :param aliases: the aliases the task can go by in addition to the name such as + qa, glue, sentiment_analysis, etc + """ + def __init__(self, name: str, aliases: List[str]): self._name = name self._aliases = aliases @property def name(self) -> str: + """ + :return: the name of the task such as question_answering + """ return self._name @property def aliases(self) -> List[str]: + """ + :return: the aliases the task can go by such as qa, glue, sentiment_analysis + """ return self._aliases def matches(self, task: str) -> bool: + """ + :param task: the name of the task to check whether the given instance matches. + Checks the current name as well as any aliases. + Everything is compared at lower case and "-" are replaced with "_". + :return: True if task does match the current instance, False otherwise + """ + task = task.lower().replace("-", "_") + return task == self.name or task in self.aliases class SupportedTasks: + """ + The supported tasks in the DeepSparse pipeline and system + """ + nlp = namedtuple( "nlp", ["question_answering", "text_classification", "token_classification"] )( @@ -49,6 +80,11 @@ class SupportedTasks: @classmethod def is_nlp(cls, task: str) -> bool: + """ + :param task: the name of the task to check whether it is an nlp task + such as question_answering + :return: True if it is an nlp task, False otherwise + """ return ( cls.nlp.question_answering.matches(task) or cls.nlp.text_classification.matches(task)