Skip to content

Commit

Permalink
Minor fixes for the new deepsparse.server integration (#276)
Browse files Browse the repository at this point in the history
* Minor fixes for the new deepsparse.server integration

* run make style
  • Loading branch information
markurtz authored Mar 2, 2022
1 parent 16c00dd commit 0972a84
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 11 deletions.
11 changes: 3 additions & 8 deletions examples/flask/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,22 +42,17 @@
"""

import argparse
import os
import logging
import time
from typing import Any, Callable, List

import numpy
import requests

from deepsparse.utils import (
arrays_to_bytes,
bytes_to_arrays,
generate_random_inputs,
log_init,
)
from deepsparse.utils import arrays_to_bytes, bytes_to_arrays, generate_random_inputs


_LOGGER = log_init(os.path.basename(__file__))
_LOGGER = logging.getLogger(__name__)


class EngineFlaskClient:
Expand Down
6 changes: 3 additions & 3 deletions examples/flask/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,16 @@
"""

import argparse
import os
import logging

import flask
from flask_cors import CORS

from deepsparse import Scheduler, compile_model
from deepsparse.utils import arrays_to_bytes, bytes_to_arrays, log_init
from deepsparse.utils import arrays_to_bytes, bytes_to_arrays


_LOGGER = log_init(os.path.basename(__file__))
_LOGGER = logging.getLogger(__name__)


def engine_flask_server(
Expand Down
36 changes: 36 additions & 0 deletions src/deepsparse/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Classes and implementations for supported tasks in the DeepSparse pipeline and system
"""

from collections import namedtuple
from typing import List

Expand All @@ -20,23 +24,50 @@


class AliasedTask:
"""
A task that can have multiple aliases to match to.
For example, question_answering which can alias to qa as well
:param name: the name of the task such as question_answering or text_classification
:param aliases: the aliases the task can go by in addition to the name such as
qa, glue, sentiment_analysis, etc
"""

def __init__(self, name: str, aliases: List[str]):
self._name = name
self._aliases = aliases

@property
def name(self) -> str:
"""
:return: the name of the task such as question_answering
"""
return self._name

@property
def aliases(self) -> List[str]:
"""
:return: the aliases the task can go by such as qa, glue, sentiment_analysis
"""
return self._aliases

def matches(self, task: str) -> bool:
"""
:param task: the name of the task to check whether the given instance matches.
Checks the current name as well as any aliases.
Everything is compared at lower case and "-" are replaced with "_".
:return: True if task does match the current instance, False otherwise
"""
task = task.lower().replace("-", "_")

return task == self.name or task in self.aliases


class SupportedTasks:
"""
The supported tasks in the DeepSparse pipeline and system
"""

nlp = namedtuple(
"nlp", ["question_answering", "text_classification", "token_classification"]
)(
Expand All @@ -49,6 +80,11 @@ class SupportedTasks:

@classmethod
def is_nlp(cls, task: str) -> bool:
"""
:param task: the name of the task to check whether it is an nlp task
such as question_answering
:return: True if it is an nlp task, False otherwise
"""
return (
cls.nlp.question_answering.matches(task)
or cls.nlp.text_classification.matches(task)
Expand Down

0 comments on commit 0972a84

Please sign in to comment.