Merge branch 'main' into research/ppl_refactor
anmarques authored Oct 24, 2023
2 parents fe4b267 + 67327cf commit c55a05e
Showing 10 changed files with 574 additions and 718 deletions.
2 changes: 1 addition & 1 deletion src/deepsparse/transformers/helpers.py

@@ -73,7 +73,7 @@ def get_deployment_path(model_path: str) -> Tuple[str, str]:

     elif model_path.startswith("zoo:"):
         zoo_model = Model(model_path)
-        deployment_path = zoo_model.deployment_directory_path
+        deployment_path = zoo_model.deployment.path
         return deployment_path, os.path.join(deployment_path, _MODEL_DIR_ONNX_NAME)
     elif model_path.startswith("hf:"):
         from huggingface_hub import snapshot_download
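The one-line change swaps the removed sparsezoo attribute for its replacement. A minimal usage sketch, assuming the sparsezoo `Model` class already imported by helpers.py; the zoo stub is borrowed from the test configs added later in this commit:

    from sparsezoo import Model  # assumption: top-level import path

    stub = "zoo:nlg/text_generation/opt-1.3b/pytorch/huggingface/opt_pretrain/base-none"
    zoo_model = Model(stub)
    # deployment.path fetches the deployment artifacts if needed and returns
    # the local directory, replacing the old deployment_directory_path attribute
    deployment_path = zoo_model.deployment.path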
3 changes: 3 additions & 0 deletions src/deepsparse/transformers/pipelines/text_generation.py

@@ -670,6 +670,9 @@ def process_engine_outputs(
             input_tokens=input_tokens,
         )

+        if "session_ids" in kwargs:
+            outputs["session_ids"] = kwargs["session_ids"]
+
         if self._debug:
             debug_params = dict(
                 kv_cache_state=kv_cache_state,
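The added lines follow a simple echo-back pattern: optional metadata supplied via **kwargs is copied onto the outputs only when present. A standalone sketch (the function name is illustrative, not the pipeline's real signature):

    def attach_session_ids(outputs: dict, **kwargs) -> dict:
        # Copy session_ids onto the output only when the caller supplied it,
        # so the key stays absent (rather than None) for sessionless runs.
        if "session_ids" in kwargs:
            outputs["session_ids"] = kwargs["session_ids"]
        return outputs

    attach_session_ids({"text": "hi"})                      # {'text': 'hi'}
    attach_session_ids({"text": "hi"}, session_ids=[0, 1])  # adds 'session_ids'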
10 changes: 5 additions & 5 deletions src/deepsparse/transformers/utils/token_generator.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import List
+from typing import List, Optional

 import numpy

@@ -29,7 +29,7 @@ class TokenGenerator:
     def __init__(
         self,
         logits_shape: int,
-        tokens: List[int] = [],
+        tokens: Optional[List[int]] = None,
         deterministic: bool = True,
         sampling_temperature: float = 1.0,
         top_k: int = 0,
@@ -61,7 +61,7 @@ def __init__(
         self.top_p = top_p
         self.frequency_penalty = frequency_penalty
         self.presence_penalty = presence_penalty
-        self.tokens = tokens
+        self.tokens = [] if tokens is None else tokens

         self._initialize_token_frequencies()

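These two hunks fix a classic Python pitfall: a mutable default argument is evaluated once, at function definition, and the same list object is then shared by every call that omits the argument. A self-contained illustration of the bug and of the None-sentinel fix used above:

    from typing import List, Optional

    def bad(tokens: List[int] = []) -> List[int]:
        tokens.append(1)  # mutates the single shared default list
        return tokens

    def good(tokens: Optional[List[int]] = None) -> List[int]:
        tokens = [] if tokens is None else tokens  # fresh list per call
        tokens.append(1)
        return tokens

    print(bad())   # [1]
    print(bad())   # [1, 1]  <- state leaked from the first call
    print(good())  # [1]
    print(good())  # [1]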
@@ -168,5 +168,5 @@ def _update_frequencies(self, token: numpy.ndarray):

     def _initialize_token_frequencies(self):
         unique_tokens, frequencies = numpy.unique(self.tokens, return_counts=True)
-        for token, frequnecies in zip(unique_tokens, frequencies):
-            self.token_frequencies[token] += frequnecies
+        for token, freq in zip(unique_tokens, frequencies):
+            self.token_frequencies[token] += freq
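The final hunk renames the misspelled loop variable `frequnecies`; the counting logic is unchanged. For reference, `numpy.unique` with `return_counts=True` returns the distinct tokens alongside how often each appears:

    import numpy

    tokens = [5, 3, 5, 5, 7]
    unique_tokens, counts = numpy.unique(tokens, return_counts=True)
    print(unique_tokens)  # [3 5 7]
    print(counts)         # [1 3 1]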
New file (13 additions; filename not captured):

@@ -0,0 +1,13 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
New file (8 additions; filename not captured):

@@ -0,0 +1,8 @@
+cadence: "nightly"
+model_path: "zoo:nlg/text_generation/codegen_mono-350m/pytorch/huggingface/bigpython_bigquery_thepile/base-none"
+torch_model_name: "salesforce/codegen-350m-mono"
+task: ["text-generation"]#, "chat"]
+prompt: "\ndef Fibonacci(n):\n # Check if input is 0 then it will\n # print incorrect input"
+has_bos_token: False
+precision: 0.0001
+internal_kv_cache: [True, False]
New file (8 additions; filename not captured):

@@ -0,0 +1,8 @@
+cadence: "nightly"
+model_path: "zoo:nlg/text_generation/opt-1.3b/pytorch/huggingface/opt_pretrain/base-none"
+torch_model_name: "facebook/opt-1.3b"
+task: ["text-generation"]
+prompt: "Didn't know what time it was, the lights were low\n I leaned back on my radio"
+has_bos_token: True
+precision: 0.0001
+internal_kv_cache: [True, False]
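Both configs share one schema and are consumed by the `parse_params` helper added below. A brief sketch of what loading such a file yields (the filename is hypothetical):

    import yaml

    with open("opt-1.3b.yaml") as f:  # hypothetical name for the config above
        config = yaml.safe_load(f)

    print(config["cadence"])            # nightly
    print(config["task"])               # ['text-generation']
    print(config["internal_kv_cache"])  # [True, False]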
Changed file (test helpers; filename not captured):

@@ -12,11 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import List, Tuple
+import logging
+import os
+from typing import Any, Dict, List, Tuple, Union

 import numpy
+import yaml
 from transformers import AutoModelForCausalLM, AutoTokenizer

+import pytest
+

 class TorchGroundTruthSource:
     """
@@ -36,7 +41,6 @@ def __init__(self, num_tokens_to_generate: int, model_name: str):
         self.tokenizer = self._create_tokenizer(model_name)

         self.num_tokens_to_generate = num_tokens_to_generate
-        self.model_name = model_name

     def tokenize(self, prompt: str):
         return self.tokenizer(prompt, return_tensors="pt")
@@ -82,3 +86,61 @@ def _create_tokenizer(model_name):
         tokenizer.pad_token = tokenizer.eos_token

     return tokenizer
+
+
+def parse_params(configs_directory: str) -> List[Dict[str, Any]]:
+    # parses the config file provided
+    assert os.path.isdir(
+        configs_directory
+    ), f"Config_directory {configs_directory} is not a directory"
+
+    config_dicts = []
+    for file in os.listdir(configs_directory):
+        if file.endswith(".yaml"):
+            config_path = os.path.join(configs_directory, file)
+            # reads the yaml file
+            with open(config_path, "r") as f:
+                config = yaml.safe_load(f)
+
+            cadence = os.environ.get("CADENCE", "commit")
+            expected_cadence = config["cadence"]
+
+            if not isinstance(expected_cadence, list):
+                expected_cadence = [expected_cadence]
+            if cadence in expected_cadence:
+                config_dicts.append(config)
+            else:
+                logging.info(
+                    f"Skipping testing model: {config['model_path']} "
+                    f"for cadence: {config['cadence']}"
+                )
+        else:
+            raise FileNotFoundError(
+                f"Could not find a yaml file in {configs_directory}"
+            )
+    return config_dicts
+
+
+def validate_internal_kv_cache(
+    internal_kv_cache, available_kv_cache_types: Union[str, List[str]]
+) -> bool:
+    if internal_kv_cache and True not in available_kv_cache_types:
+        pytest.skip(
+            "The tests for running the pipeline with "
+            "internal kv cache management are disabled."
+        )
+    if not internal_kv_cache and False not in available_kv_cache_types:
+        pytest.skip(
+            "The tests for running the pipeline with "
+            "external kv cache management are disabled."
+        )
+    return internal_kv_cache
+
+
+def validate_task(task: str, available_tasks: Union[str, List[str]]) -> bool:
+    if task not in available_tasks:
+        pytest.skip(
+            f"The tests for running the pipeline with task: {task} are disabled. "
+            f"The available tasks, as specified in the config are: {available_tasks}"
+        )
+    return task
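A hedged sketch of how these helpers could be wired into a test module, with `parse_params`, `validate_task`, and `validate_internal_kv_cache` in scope from the module above; the configs path and the test body are illustrative assumptions, not taken from this commit:

    import pytest

    # Hypothetical location of the yaml configs added in this commit.
    CONFIGS_DIRECTORY = "tests/deepsparse/transformers/pipelines/configs"

    @pytest.mark.parametrize("config", parse_params(CONFIGS_DIRECTORY))
    @pytest.mark.parametrize("internal_kv_cache", [True, False])
    def test_text_generation(config, internal_kv_cache):
        # Each validator calls pytest.skip() when the config disables the combination.
        task = validate_task("text-generation", config["task"])
        internal_kv_cache = validate_internal_kv_cache(
            internal_kv_cache, config["internal_kv_cache"]
        )
        # ... build the pipeline for config["model_path"] with the validated
        # settings and compare outputs against TorchGroundTruthSource ...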