diff --git a/src/deepsparse/transformers/pipelines/text_generation.py b/src/deepsparse/transformers/pipelines/text_generation.py index 694d11d664..1edaf63ef3 100644 --- a/src/deepsparse/transformers/pipelines/text_generation.py +++ b/src/deepsparse/transformers/pipelines/text_generation.py @@ -62,16 +62,21 @@ __all__ = ["TextGenerationPipeline"] +# Based off of https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig # noqa E501 class GenerationDefaults: - num_return_sequences = 1 + # Parameters that control the length of the output max_length = None max_new_tokens = 100 - output_scores = False - top_k = 0 - top_p = 0.0 - repetition_penalty = 0.0 + # Parameters that control the generation strategy used do_sample = False + # Parameters for manipulation of the model output logits temperature = 1.0 + top_k = 50 + top_p = 1.0 + repetition_penalty = 1.0 + # Parameters that define the outputs + num_return_sequences = 1 + output_scores = False class FinishReason(Enum):