diff --git a/src/deepsparse/transformers/pipelines.py b/src/deepsparse/transformers/pipelines.py
index 467ecb0ac9..15187ca2d1 100644
--- a/src/deepsparse/transformers/pipelines.py
+++ b/src/deepsparse/transformers/pipelines.py
@@ -50,9 +50,10 @@
         SquadFeatures,
         squad_convert_examples_to_features,
     )
+    from transformers.file_utils import ExplicitEnum
     from transformers.models.auto import AutoConfig, AutoTokenizer
     from transformers.tokenization_utils import PreTrainedTokenizer
-    from transformers.tokenization_utils_base import PaddingStrategy
+    from transformers.tokenization_utils_base import PaddingStrategy, TruncationStrategy
     from transformers.utils import logging

     transformers_import_error = None
@@ -61,10 +62,12 @@
     SquadExample = object
     SquadFeatures = object
     squad_convert_examples_to_features = None
+    ExplicitEnum = object
     AutoConfig = object
     AutoTokenizer = object
     PreTrainedTokenizer = object
     PaddingStrategy = object
+    TruncationStrategy = object
     logging = None
     transformers_import_error = transformers_import_err
@@ -72,6 +75,8 @@
 __all__ = [
     "ArgumentHandler",
     "Pipeline",
+    "TextClassificationPipeline",
+    "TokenClassificationPipeline",
     "QuestionAnsweringPipeline",
     "pipeline",
     "overwrite_transformer_onnx_model_inputs",
@@ -245,14 +250,15 @@ def _parse_and_tokenize(
             inputs,
             add_special_tokens=add_special_tokens,
             return_tensors=self._framework,
-            padding=padding,
+            padding=PaddingStrategy.MAX_LENGTH.value,
+            truncation=TruncationStrategy.LONGEST_FIRST.value,
         )

         return inputs

     def __call__(self, *args, **kwargs):
         inputs = self._parse_and_tokenize(*args, **kwargs)
-        self._forward(inputs)
+        return self._forward(inputs)

     def _forward(self, inputs):
         if not all(name in inputs for name in self.input_names):
@@ -277,6 +283,35 @@ def _forward(self, inputs):
     #     return predictions.numpy()


+class TokenClassificationArgumentHandler(ArgumentHandler):
+    """
+    Handles arguments for token classification.
+    """
+
+    def __call__(self, inputs: Union[str, List[str]], **kwargs):
+        if inputs is not None and isinstance(inputs, (list, tuple)) and len(inputs) > 0:
+            inputs = list(inputs)
+            batch_size = len(inputs)
+        elif isinstance(inputs, str):
+            inputs = [inputs]
+            batch_size = 1
+        else:
+            raise ValueError("At least one input is required.")
+
+        offset_mapping = kwargs.get("offset_mapping")
+        if offset_mapping:
+            if isinstance(offset_mapping, list) and isinstance(
+                offset_mapping[0], tuple
+            ):
+                offset_mapping = [offset_mapping]
+            if len(offset_mapping) != batch_size:
+                raise ValueError(
+                    "offset_mapping should have the same batch size as the input"
+                )
+        return inputs, offset_mapping
+
+
 class QuestionAnsweringArgumentHandler(ArgumentHandler):
     """
     QuestionAnsweringPipeline requires the user to provide multiple arguments
@@ -345,6 +380,380 @@ def __call__(self, *args, **kwargs):
         return inputs


+class TextClassificationPipeline(Pipeline):
+    """
+    Text classification pipeline using any `ModelForSequenceClassification`.
+
+    This text classification pipeline can currently be loaded from `pipeline()`
+    using the following task identifier: `"text-classification"`.
+
+    The models that this pipeline can use are models that have been fine-tuned on
+    a text classification task.
+
+    :param return_all_scores: set True to return scores for all labels. Default False
+    """
+
+    def __init__(self, return_all_scores: bool = False, **kwargs):
+        super().__init__(**kwargs)
+
+        self.return_all_scores = return_all_scores
+
+    def __call__(self, *args, **kwargs):
+        """
+        Classify the text(s) given as inputs.
+
+        :param args: One or several texts (or one list of texts) to classify
+        :param kwargs: additional kwargs for the inner call function
+        :return: A list or a list of lists of dicts: Each result comes as a list
+            of dicts with the following keys:
+            - `label` -- The label predicted.
+            - `score` -- The corresponding probability.
+            If ``self.return_all_scores=True``, one dictionary is returned per label
+        """
+        outputs = super().__call__(*args, **kwargs)
+
+        if isinstance(outputs, list) and outputs:
+            outputs = outputs[0]
+
+        if self.config.num_labels == 1:
+            scores = 1.0 / (1.0 + np.exp(-outputs))
+        else:
+            scores = np.exp(outputs) / np.exp(outputs).sum(-1, keepdims=True)
+        if self.return_all_scores:
+            return [
+                [
+                    {"label": self.config.id2label[i], "score": score.item()}
+                    for i, score in enumerate(item)
+                ]
+                for item in scores
+            ]
+        else:
+            return [
+                {
+                    "label": self.config.id2label[item.argmax()],
+                    "score": item.max().item(),
+                }
+                for item in scores
+            ]
+
+
+class AggregationStrategy(ExplicitEnum):
+    """
+    All the valid aggregation strategies for TokenClassificationPipeline
+    """
+
+    NONE = "none"
+    SIMPLE = "simple"
+    FIRST = "first"
+    AVERAGE = "average"
+    MAX = "max"
+
+
+class TokenClassificationPipeline(Pipeline):
+    """
+    Named Entity Recognition pipeline using any `ModelForTokenClassification`.
+
+    This token classification pipeline can currently be loaded from `pipeline()`
+    using the following task identifier: `"token-classification"`.
+
+    The models that this pipeline can use are models that have been fine-tuned on
+    a token classification task.
+
+    :param args_parser: argument parser to use. Default is
+        TokenClassificationArgumentHandler
+    :param aggregation_strategy: AggregationStrategy Enum object to determine
+        the pipeline aggregation strategy. Default is AggregationStrategy.NONE
+    :param ignore_labels: list of labels to ignore. Default is `["O"]`
+    """
+
+    default_input_names = "sequences"
+
+    def __init__(
+        self,
+        args_parser: ArgumentHandler = None,
+        aggregation_strategy: AggregationStrategy = AggregationStrategy.NONE,
+        ignore_labels: Optional[List[str]] = None,
+        **kwargs,
+    ):
+        super().__init__(
+            args_parser=args_parser or TokenClassificationArgumentHandler(),
+            **kwargs,
+        )
+
+        self.ignore_labels = ignore_labels or ["O"]
+
+        if isinstance(aggregation_strategy, str):
+            aggregation_strategy = AggregationStrategy[aggregation_strategy.upper()]
+
+        if (
+            aggregation_strategy
+            in {
+                AggregationStrategy.FIRST,
+                AggregationStrategy.MAX,
+                AggregationStrategy.AVERAGE,
+            }
+            and not self.tokenizer.is_fast
+        ):
+            raise ValueError(
+                "Slow tokenizers cannot handle subwords. Please set the "
+                '`aggregation_strategy` option to `"simple"` or use a fast tokenizer.'
+            )
+
+        self.aggregation_strategy = aggregation_strategy
+
+    def __call__(self, inputs: Union[str, List[str]], **kwargs):
+        """
+        Classify each token of the text(s) given as inputs.
+
+        :param inputs: One or several texts (or one list of texts) for token
+            classification
+        :return: A list or a list of lists of dicts: Each result comes as a list
+            of dictionaries (one for each token in the corresponding input, or each
+            entity if this pipeline was instantiated with an aggregation_strategy)
+            with the following keys:
+            - `word` -- The token/word classified.
+            - `score` -- The corresponding probability for `entity`.
+            - `entity` -- The entity predicted for that token/word (it is named
+              `entity_group` when `aggregation_strategy` is not `"none"`).
+            - `index` -- The index of the corresponding token in the sentence.
+            - `start` -- The index of the start of the corresponding entity in the
+              sentence. Only exists if the offsets are available within the tokenizer
+            - `end` -- The index of the end of the corresponding entity in the
+              sentence. Only exists if the offsets are available within the tokenizer
+        """
+
+        _inputs, offset_mappings = self._args_parser(inputs, **kwargs)
+
+        answers = []
+
+        for i, sentence in enumerate(_inputs):
+            tokens = self.tokenizer(
+                sentence,
+                return_tensors=self._framework,
+                truncation=TruncationStrategy.LONGEST_FIRST.value,
+                padding=PaddingStrategy.MAX_LENGTH.value,
+                return_special_tokens_mask=True,
+                return_offsets_mapping=self.tokenizer.is_fast,
+            )
+            if self.tokenizer.is_fast:
+                offset_mapping = tokens.pop("offset_mapping")[0]
+            elif offset_mappings:
+                offset_mapping = offset_mappings[i]
+            else:
+                offset_mapping = None
+
+            special_tokens_mask = tokens.pop("special_tokens_mask")[0]
+
+            # Forward pass through the model
+            entities = self._forward(tokens)[0][0]
+            input_ids = tokens["input_ids"][0]
+
+            scores = np.exp(entities) / np.exp(entities).sum(-1, keepdims=True)
+            pre_entities = self.gather_pre_entities(
+                sentence, input_ids, scores, offset_mapping, special_tokens_mask
+            )
+            grouped_entities = self.aggregate(pre_entities, self.aggregation_strategy)
+            # Filter anything that is in self.ignore_labels
+            entities = [
+                entity
+                for entity in grouped_entities
+                if entity.get("entity", None) not in self.ignore_labels
+                and entity.get("entity_group", None) not in self.ignore_labels
+            ]
+            answers.append(entities)
+
+        if len(answers) == 1:
+            return answers[0]
+        return answers
+
+    def gather_pre_entities(
+        self,
+        sentence: str,
+        input_ids: np.ndarray,
+        scores: np.ndarray,
+        offset_mapping: Optional[List[Tuple[int, int]]],
+        special_tokens_mask: np.ndarray,
+    ) -> List[dict]:
+        pre_entities = []
+        for idx, token_scores in enumerate(scores):
+            # Filter special tokens; they should only occur at the sentence
+            # boundaries since we are not encoding pairs of sentences, so we
+            # don't have to keep track of those.
+            if special_tokens_mask[idx]:
+                continue
+
+            word = self.tokenizer.convert_ids_to_tokens(int(input_ids[idx]))
+            if offset_mapping is not None:
+                start_ind, end_ind = offset_mapping[idx]
+                word_ref = sentence[start_ind:end_ind]
+                is_subword = len(word_ref) != len(word)
+
+                if int(input_ids[idx]) == self.tokenizer.unk_token_id:
+                    word = word_ref
+                    is_subword = False
+            else:
+                start_ind = None
+                end_ind = None
+                is_subword = False
+
+            pre_entity = {
+                "word": word,
+                "scores": token_scores,
+                "start": start_ind,
+                "end": end_ind,
+                "index": idx,
+                "is_subword": is_subword,
+            }
+            pre_entities.append(pre_entity)
+        return pre_entities
+
+    def aggregate(
+        self, pre_entities: List[dict], aggregation_strategy: AggregationStrategy
+    ) -> List[dict]:
+        if aggregation_strategy in {
+            AggregationStrategy.NONE,
+            AggregationStrategy.SIMPLE,
+        }:
+            entities = []
+            for pre_entity in pre_entities:
+                entity_idx = pre_entity["scores"].argmax()
+                score = pre_entity["scores"][entity_idx]
+                entity = {
+                    "entity": self.config.id2label[entity_idx],
+                    "score": score,
+                    "index": pre_entity["index"],
+                    "word": pre_entity["word"],
+                    "start": pre_entity["start"],
+                    "end": pre_entity["end"],
+                }
+                entities.append(entity)
+        else:
+            entities = self.aggregate_words(pre_entities, aggregation_strategy)
+
+        if aggregation_strategy == AggregationStrategy.NONE:
+            return entities
+        return self.group_entities(entities)
+
+    def aggregate_word(
+        self, entities: List[dict], aggregation_strategy: AggregationStrategy
+    ) -> dict:
+        word = self.tokenizer.convert_tokens_to_string(
+            [entity["word"] for entity in entities]
+        )
+        if aggregation_strategy == AggregationStrategy.FIRST:
+            scores = entities[0]["scores"]
+            idx = scores.argmax()
+            score = scores[idx]
+            entity = self.config.id2label[idx]
+        elif aggregation_strategy == AggregationStrategy.MAX:
+            max_entity = max(entities, key=lambda entity: entity["scores"].max())
+            scores = max_entity["scores"]
+            idx = scores.argmax()
+            score = scores[idx]
+            entity = self.config.id2label[idx]
+        elif aggregation_strategy == AggregationStrategy.AVERAGE:
+            scores = np.stack([entity["scores"] for entity in entities])
+            average_scores = np.nanmean(scores, axis=0)
+            entity_idx = average_scores.argmax()
+            entity = self.config.id2label[entity_idx]
+            score = average_scores[entity_idx]
+        else:
+            raise ValueError("Invalid aggregation_strategy")
+        new_entity = {
+            "entity": entity,
+            "score": score,
+            "word": word,
+            "start": entities[0]["start"],
+            "end": entities[-1]["end"],
+        }
+        return new_entity
+
+    def aggregate_words(
+        self, entities: List[dict], aggregation_strategy: AggregationStrategy
+    ) -> List[dict]:
+        assert aggregation_strategy not in {
+            AggregationStrategy.NONE,
+            AggregationStrategy.SIMPLE,
+        }, "NONE and SIMPLE strategies are invalid"
+
+        word_entities = []
+        word_group = None
+        for entity in entities:
+            if word_group is None:
+                word_group = [entity]
+            elif entity["is_subword"]:
+                word_group.append(entity)
+            else:
+                word_entities.append(
+                    self.aggregate_word(word_group, aggregation_strategy)
+                )
+                word_group = [entity]
+        # Aggregate the last word group
+        word_entities.append(self.aggregate_word(word_group, aggregation_strategy))
+        return word_entities
+
+    def group_sub_entities(self, entities: List[dict]) -> dict:
+        # Take the entity type of the first entity in the group
+        entity = entities[0]["entity"].split("-")[-1]
+        # nanmean already reduces the per-token scores to a single scalar
+        scores = np.nanmean([entity["score"] for entity in entities])
+        tokens = [entity["word"] for entity in entities]
+
+        entity_group = {
+            "entity_group": entity,
+            "score": scores,
+            "word": self.tokenizer.convert_tokens_to_string(tokens),
+            "start": entities[0]["start"],
+            "end": entities[-1]["end"],
+        }
+        return entity_group
+
+    def get_tag(self, entity_name: str) -> Tuple[str, str]:
+        if entity_name.startswith("B-"):
+            bi = "B"
+            tag = entity_name[2:]
+        elif entity_name.startswith("I-"):
+            bi = "I"
+            tag = entity_name[2:]
+        else:
+            # It's not in B-, I- format
+            bi = "B"
+            tag = entity_name
+        return bi, tag
+
+    def group_entities(self, entities: List[dict]) -> List[dict]:
+        entity_groups = []
+        entity_group_disagg = []
+
+        for entity in entities:
+            if not entity_group_disagg:
+                entity_group_disagg.append(entity)
+                continue
+
+            # If the current entity is similar and adjacent to the previous entity,
+            # append it to the disaggregated entity group.
+            # The split is meant to account for the "B" and "I" prefixes;
+            # two B-type entities should not be merged
+            bi, tag = self.get_tag(entity["entity"])
+            last_bi, last_tag = self.get_tag(entity_group_disagg[-1]["entity"])
+
+            if tag == last_tag and bi != "B":
+                # Modify subword type to be previous_type
+                entity_group_disagg.append(entity)
+            else:
+                # The current entity is different from the previous entity;
+                # aggregate the disaggregated entity group
+                entity_groups.append(self.group_sub_entities(entity_group_disagg))
+                entity_group_disagg = [entity]
+        if entity_group_disagg:
+            # Add the last remaining entity group
+            entity_groups.append(self.group_sub_entities(entity_group_disagg))
+
+        return entity_groups
+
+
 class QuestionAnsweringPipeline(Pipeline):
     """
     Question Answering pipeline using any `ModelForQuestionAnswering`
@@ -806,6 +1215,10 @@ class TaskInfo:

 # Register all the supported tasks here
 SUPPORTED_TASKS = {
+    "ner": TaskInfo(
+        pipeline_constructor=TokenClassificationPipeline,
+        default_model_name="bert-base-uncased",
+    ),
     "question-answering": TaskInfo(
         pipeline_constructor=QuestionAnsweringPipeline,
         default_model_name="bert-base-uncased",
@@ -816,7 +1229,19 @@ class TaskInfo:
             "zoo:nlp/question_answering/bert-base/pytorch/huggingface/squad/"
             "pruned-aggressive_98"
         ),
-    )
+    ),
+    "sentiment-analysis": TaskInfo(
+        pipeline_constructor=TextClassificationPipeline,
+        default_model_name="bert-base-uncased",
+    ),
+    "text-classification": TaskInfo(
+        pipeline_constructor=TextClassificationPipeline,
+        default_model_name="bert-base-uncased",
+    ),
+    "token-classification": TaskInfo(
+        pipeline_constructor=TokenClassificationPipeline,
+        default_model_name="bert-base-uncased",
+    ),
 }

 DEEPSPARSE_ENGINE = "deepsparse"
@@ -834,7 +1259,6 @@ def pipeline(
     tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None,
     max_length: int = 128,
     num_cores: Optional[int] = None,
-    num_sockets: Optional[int] = None,
     **kwargs,
 ) -> Pipeline:
     """
@@ -847,12 +1271,11 @@ def pipeline(
     :param engine_type: inference engine name to use. Supported options are
         'deepsparse' and 'onnxruntime'
     :param config: huggingface model config, if none provided, default will be used
+        which will be loaded from the model name or the sparsezoo stub given as
+        the model path
     :param tokenizer: huggingface tokenizer, if none provided, default will be used
     :param max_length: maximum sequence length of model inputs. default is 128
     :param num_cores: number of CPU cores to run engine with. Default is the
         maximum available
-    :param num_sockets: number of CPU sockets to run engine with. Default is the maximum
-        available
     :param kwargs: additional key word arguments for task specific pipeline
         constructor
     :return: Pipeline object for the given task and model
     """
@@ -878,23 +1301,31 @@ def pipeline(
     config = config or model_name

     # create model
+    zoo_model_path = None
     if model_path.startswith("zoo:"):
+        zoo_model_path = model_path
         model_path = _download_zoo_model(model_path)
-    model, input_names = _create_model(
-        model_path, engine_type, num_cores, num_sockets, max_length
-    )
+    model, input_names = _create_model(model_path, engine_type, num_cores, max_length)

     # Instantiate tokenizer if needed
     if isinstance(tokenizer, (str, tuple)):
         if isinstance(tokenizer, tuple):
             # For tuple we have (tokenizer name, {kwargs})
+            tokenizer_kwargs = tokenizer[1]
+            tokenizer_kwargs["model_max_length"] = max_length
-            tokenizer = AutoTokenizer.from_pretrained(tokenizer[0], **tokenizer[1])
+            tokenizer = AutoTokenizer.from_pretrained(tokenizer[0], **tokenizer_kwargs)
         else:
-            tokenizer = AutoTokenizer.from_pretrained(tokenizer)
+            tokenizer = AutoTokenizer.from_pretrained(
+                tokenizer, model_max_length=max_length
+            )

     # Instantiate config if needed
     if config is not None and isinstance(config, str):
-        config = AutoConfig.from_pretrained(config)
+        if zoo_model_path:
+            zoo_config = _download_zoo_config(zoo_model_path)
+            if zoo_config:
+                config = zoo_config
+        config = AutoConfig.from_pretrained(config, finetuning_task=task)

     return task_info.pipeline_constructor(
         model=model,
@@ -962,11 +1393,19 @@ def _download_zoo_model(model_path: str) -> str:
     return model.onnx_file.downloaded_path()


+def _download_zoo_config(model_path: str) -> Optional[str]:
+    model = Zoo.load_model_from_stub(model_path)
+    config_file = None
+    for framework_file in model.framework_files:
+        if framework_file.display_name == "config.json":
+            config_file = framework_file
+    return config_file.downloaded_path() if config_file else None
+
+
 def _create_model(
     model_path: str,
     engine_type: str,
     num_cores: Optional[int],
-    num_sockets: Optional[int],
     max_length: int = 128,
 ) -> Tuple[Union[Engine, "onnxruntime.InferenceSession"], List[str]]:
     onnx_path, input_names, _ = overwrite_transformer_onnx_model_inputs(
@@ -974,9 +1413,7 @@ def _create_model(
     )

     if engine_type == DEEPSPARSE_ENGINE:
-        model = compile_model(
-            onnx_path, batch_size=1, num_cores=num_cores, num_sockets=num_sockets
-        )
+        model = compile_model(onnx_path, batch_size=1, num_cores=num_cores)
     elif engine_type == ORT_ENGINE:
         _validate_ort_import()
         sess_options = onnxruntime.SessionOptions()
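
Note on the TextClassificationPipeline post-processing above: the scores are a plain
numpy softmax over the logits, or an element-wise sigmoid when num_labels == 1. A
standalone sketch of that computation with made-up logits, not part of the patch:

import numpy as np

# Raw logits for a batch of two sequences over three labels (made-up numbers).
outputs = np.array([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])

# Multi-label-count path from TextClassificationPipeline.__call__: row-wise softmax.
scores = np.exp(outputs) / np.exp(outputs).sum(-1, keepdims=True)
assert np.allclose(scores.sum(-1), 1.0)  # each row is a probability distribution

# Single-label path (num_labels == 1): element-wise sigmoid instead.
single = 1.0 / (1.0 + np.exp(-np.array([[1.5]])))
print(scores.argmax(-1), single)  # -> [0 2] [[0.81757448]]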
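The merge rule implemented by get_tag/group_entities can be summarized as: a token
joins the current entity group only if it carries the same tag and is not a "B-"
(beginning) label. A self-contained sketch of that rule on label strings alone,
simplified from the pipeline (which also carries scores and character offsets):

from typing import List, Tuple

def get_tag(entity_name: str) -> Tuple[str, str]:
    # Same rule as TokenClassificationPipeline.get_tag: split the "B-"/"I-" prefix,
    # treating unprefixed labels as the start of a span.
    if entity_name.startswith("B-"):
        return "B", entity_name[2:]
    if entity_name.startswith("I-"):
        return "I", entity_name[2:]
    return "B", entity_name

def group(labels: List[str]) -> List[List[str]]:
    groups, current = [], []
    for label in labels:
        bi, tag = get_tag(label)
        if current and tag == get_tag(current[-1])[1] and bi != "B":
            current.append(label)  # same tag, continuation token: extend the span
        else:
            if current:
                groups.append(current)
            current = [label]  # a "B-" label or a new tag opens a fresh span
    if current:
        groups.append(current)
    return groups

# "B-PER I-PER" merges into one span; the second "B-PER" starts a new one.
print(group(["B-PER", "I-PER", "B-PER", "I-LOC"]))
# -> [['B-PER', 'I-PER'], ['B-PER'], ['I-LOC']]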
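Usage note, not part of the patch: a minimal sketch of how the newly registered
tasks could be invoked through pipeline(). The import path assumes pipeline is
re-exported from deepsparse.transformers, and the model stubs below are
hypothetical placeholders rather than real SparseZoo paths:

from deepsparse.transformers import pipeline

# "ner" and "token-classification" both map to TokenClassificationPipeline.
ner = pipeline(
    task="ner",
    model_path="zoo:nlp/token_classification/placeholder-stub",  # hypothetical stub
    max_length=128,
)
# Returns one list of dicts per sentence; a single string returns a single list.
print(ner("My name is John and I live in New York"))

# "text-classification" and "sentiment-analysis" map to TextClassificationPipeline.
classify = pipeline(
    task="text-classification",
    model_path="zoo:nlp/text_classification/placeholder-stub",  # hypothetical stub
)
# Each result is a dict with "label" and "score" keys
# (or a list of per-label dicts when return_all_scores=True).
print(classify(["great product", "terrible service"]))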