From 700dcd0f1c7e9468aa76048861e20238cc016ad4 Mon Sep 17 00:00:00 2001
From: waleedqk
Date: Mon, 8 Jul 2024 13:40:14 -0400
Subject: [PATCH] add script that pulls the latest proto and generates the pb2
 files

Signed-off-by: waleedqk
---
 caikit_tgis_backend/generation.proto | 542 +++++++++++++--------
 scripts/update_generation_proto.sh   |  11 +
 2 files changed, 282 insertions(+), 271 deletions(-)
 create mode 100755 scripts/update_generation_proto.sh

diff --git a/caikit_tgis_backend/generation.proto b/caikit_tgis_backend/generation.proto
index 82aae77..2bd07f5 100644
--- a/caikit_tgis_backend/generation.proto
+++ b/caikit_tgis_backend/generation.proto
@@ -2,274 +2,274 @@
 Internal service interface for FMaaS completions
 */
-    syntax = "proto3";
-    package fmaas;
-    
-    
-    service GenerationService {
-        // Generates text given a text prompt, for one or more inputs
-        rpc Generate (BatchedGenerationRequest) returns (BatchedGenerationResponse) {}
-        // Generates text given a single input prompt, streaming the response
-        rpc GenerateStream (SingleGenerationRequest) returns (stream GenerationResponse) {}
-        // Tokenize text
-        rpc Tokenize (BatchedTokenizeRequest) returns (BatchedTokenizeResponse) {}
-        // Model info
-        rpc ModelInfo (ModelInfoRequest) returns (ModelInfoResponse) {}
-    }
-    
-    // ============================================================================================================
-    // Generation API
-    
-    enum DecodingMethod {
-        GREEDY = 0;
-        SAMPLE = 1;
-    }
-    
-    message BatchedGenerationRequest {
-        string model_id = 1;
-        // Deprecated in favor of adapter_id
-        optional string prefix_id = 2;
-        optional string adapter_id = 4;
-        repeated GenerationRequest requests = 3;
-    
-        Parameters params = 10;
-    }
-    
-    message SingleGenerationRequest {
-        string model_id = 1;
-        // Deprecated in favor of adapter_id
-        optional string prefix_id = 2;
-        optional string adapter_id = 4;
-        GenerationRequest request = 3;
-    
-        Parameters params = 10;
-    }
-    
-    message BatchedGenerationResponse {
-        repeated GenerationResponse responses = 1;
-    }
-    
-    message GenerationRequest {
-        string text = 2;
-    }
-    
-    message GenerationResponse {
-        uint32 input_token_count = 6;
-        uint32 generated_token_count = 2;
-        string text = 4;
-        StopReason stop_reason = 7;
-        // The stop sequence encountered, iff stop_reason == STOP_SEQUENCE
-        string stop_sequence = 11;
-        // Random seed used, not applicable for greedy requests
-        uint64 seed = 10;
-    
-        // Individual generated tokens and associated details, if requested
-        repeated TokenInfo tokens = 8;
-    
-        // Input tokens and associated details, if requested
-        repeated TokenInfo input_tokens = 9;
-    }
-    
-    message Parameters {
-        // The high level decoding approach
-        DecodingMethod method = 1;
-        // Parameters related to sampling, applicable only when method == SAMPLING
-        SamplingParameters sampling = 2;
-        // Parameters controlling when generation should stop
-        StoppingCriteria stopping = 3;
-        // Flags to control what is returned in the response
-        ResponseOptions response = 4;
-        // Parameters for conditionally penalizing/boosting
-        // candidate tokens during decoding
-        DecodingParameters decoding = 5;
-        // Truncate to this many input tokens. Can be used to avoid requests
-        // failing due to input being longer than configured limits.
-        // Zero means don't truncate.
-        uint32 truncate_input_tokens = 6;
-    }
-    
-    message DecodingParameters {
-        message LengthPenalty {
-            // Start the decay after this number of tokens have been generated
-            uint32 start_index = 1;
-            // Factor of exponential decay
-            float decay_factor = 2;
-        }
-    
-        // Default (0.0) means no penalty (equivalent to 1.0)
-        // 1.2 is a recommended value
-        float repetition_penalty = 1;
-    
-        // Exponentially increases the score of the EOS token
-        // once start_index tokens have been generated
-        optional LengthPenalty length_penalty = 2;
-    
-        enum ResponseFormat {
-            // Plain text, no constraints
-            TEXT = 0;
-            // Valid json
-            JSON = 1;
-        }
-    
-        message StringChoices {
-            repeated string choices = 1;
-        }
-    
-        // Mutually-exclusive guided decoding options
-        oneof guided {
-            // Output will be in the specified format
-            ResponseFormat format = 3;
-            // Output will follow the provided JSON schema
-            string json_schema = 4;
-            // Output will follow the provided regex pattern
-            string regex = 5;
-            // Output will be exactly one of the specified choices
-            StringChoices choice = 6;
-            // Output will follow the provided context free grammar
-            string grammar = 7;
-        }
-    }
-    
-    
-    message SamplingParameters {
-        // Default (0.0) means disabled (equivalent to 1.0)
-        float temperature = 1;
-        // Default (0) means disabled
-        uint32 top_k = 2;
-        // Default (0) means disabled (equivalent to 1.0)
-        float top_p = 3;
-        // Default (0) means disabled (equivalent to 1.0)
-        float typical_p = 4;
-    
-        optional uint64 seed = 5;
-    }
-    
-    message StoppingCriteria {
-        // Default (0) is currently 20
-        uint32 max_new_tokens = 1;
-        // Default (0) means no minimum
-        uint32 min_new_tokens = 2;
-        // Default (0) means no time limit
-        uint32 time_limit_millis = 3;
-        repeated string stop_sequences = 4;
-        // If not specified, default behavior depends on server setting
-        optional bool include_stop_sequence = 5;
-    
-        //more to come
-    }
-    
-    message ResponseOptions {
-        // Include input text
-        bool input_text = 1;
-        // Include list of individual generated tokens
-        // "Extra" token information is included based on the other flags below
-        bool generated_tokens = 2;
-        // Include list of input tokens
-        // "Extra" token information is included based on the other flags here,
-        // but only for decoder-only models
-        bool input_tokens = 3;
-        // Include logprob for each returned token
-        // Applicable only if generated_tokens == true and/or input_tokens == true
-        bool token_logprobs = 4;
-        // Include rank of each returned token
-        // Applicable only if generated_tokens == true and/or input_tokens == true
-        bool token_ranks = 5;
-        // Include top n candidate tokens at the position of each returned token
-        // The maximum value permitted is 5, but more may be returned if there is a tie
-        // for nth place.
-        // Applicable only if generated_tokens == true and/or input_tokens == true
-        uint32 top_n_tokens = 6;
-    }
-    
-    enum StopReason {
-        // Possibly more tokens to be streamed
-        NOT_FINISHED = 0;
-        // Maximum requested tokens reached
-        MAX_TOKENS = 1;
-        // End-of-sequence token encountered
-        EOS_TOKEN = 2;
-        // Request cancelled by client
-        CANCELLED = 3;
-        // Time limit reached
-        TIME_LIMIT = 4;
-        // Stop sequence encountered
-        STOP_SEQUENCE = 5;
-        // Total token limit reached
-        TOKEN_LIMIT = 6;
-        // Decoding error
-        ERROR = 7;
-    }
-    
-    message TokenInfo {
-        // uint32 id = 1; // TBD
-        string text = 2;
-        // The logprob (log of normalized probability), if requested
-        float logprob = 3;
-        // One-based rank relative to other tokens, if requested
-        uint32 rank = 4;
-    
-        message TopToken {
-            // uint32 id = 1; // TBD
-            string text = 2;
-            float logprob = 3;
-        }
-    
-        // Top N candidate tokens at this position, if requested
-        // May or may not include this token
-        repeated TopToken top_tokens = 5;
-    }
-    
-    
-    // ============================================================================================================
-    // Tokenization API
-    
-    message BatchedTokenizeRequest {
-        string model_id = 1;
-        repeated TokenizeRequest requests = 2;
-        bool return_tokens = 3;
-        bool return_offsets = 4;
-    
-        // Zero means don't truncate.
-        uint32 truncate_input_tokens = 5;
-    }
-    
-    message BatchedTokenizeResponse {
-        repeated TokenizeResponse responses = 1;
-    }
-    
-    message TokenizeRequest {
-        string text = 1;
-    }
-    
-    message TokenizeResponse {
-        message Offset {
-            uint32 start = 1;
-            uint32 end = 2;
-        }
-    
-        uint32 token_count = 1;
-    
-        // if return_tokens = true
-        repeated string tokens = 2;
-        // if return_tokens = true
-        repeated Offset offsets = 3;
-    }
-    
-    
-    // ============================================================================================================
-    // Model Info API
-    
-    message ModelInfoRequest {
-        string model_id = 1;
-    }
-    
-    message ModelInfoResponse {
-        enum ModelKind {
-            DECODER_ONLY = 0;
-            ENCODER_DECODER = 1;
-        }
-    
-        ModelKind model_kind = 1;
-        uint32 max_sequence_length = 2;
-        uint32 max_new_tokens = 3;
-    }
+syntax = "proto3";
+package fmaas;
+
+
+service GenerationService {
+    // Generates text given a text prompt, for one or more inputs
+    rpc Generate (BatchedGenerationRequest) returns (BatchedGenerationResponse) {}
+    // Generates text given a single input prompt, streaming the response
+    rpc GenerateStream (SingleGenerationRequest) returns (stream GenerationResponse) {}
+    // Tokenize text
+    rpc Tokenize (BatchedTokenizeRequest) returns (BatchedTokenizeResponse) {}
+    // Model info
+    rpc ModelInfo (ModelInfoRequest) returns (ModelInfoResponse) {}
+}
+
+// ============================================================================================================
+// Generation API
+
+enum DecodingMethod {
+    GREEDY = 0;
+    SAMPLE = 1;
+}
+
+message BatchedGenerationRequest {
+    string model_id = 1;
+    // Deprecated in favor of adapter_id
+    optional string prefix_id = 2;
+    optional string adapter_id = 4;
+    repeated GenerationRequest requests = 3;
+
+    Parameters params = 10;
+}
+
+message SingleGenerationRequest {
+    string model_id = 1;
+    // Deprecated in favor of adapter_id
+    optional string prefix_id = 2;
+    optional string adapter_id = 4;
+    GenerationRequest request = 3;
+
+    Parameters params = 10;
+}
+
+message BatchedGenerationResponse {
+    repeated GenerationResponse responses = 1;
+}
+
+message GenerationRequest {
+    string text = 2;
+}
+
+message GenerationResponse {
+    uint32 input_token_count = 6;
+    uint32 generated_token_count = 2;
+    string text = 4;
+    StopReason stop_reason = 7;
+    // The stop sequence encountered, iff stop_reason == STOP_SEQUENCE
+    string stop_sequence = 11;
+    // Random seed used, not applicable for greedy requests
+    uint64 seed = 10;
+
+    // Individual generated tokens and associated details, if requested
+    repeated TokenInfo tokens = 8;
+
+    // Input tokens and associated details, if requested
+    repeated TokenInfo input_tokens = 9;
+}
+
+message Parameters {
+    // The high level decoding approach
+    DecodingMethod method = 1;
+    // Parameters related to sampling, applicable only when method == SAMPLING
+    SamplingParameters sampling = 2;
+    // Parameters controlling when generation should stop
+    StoppingCriteria stopping = 3;
+    // Flags to control what is returned in the response
+    ResponseOptions response = 4;
+    // Parameters for conditionally penalizing/boosting
+    // candidate tokens during decoding
+    DecodingParameters decoding = 5;
+    // Truncate to this many input tokens. Can be used to avoid requests
+    // failing due to input being longer than configured limits.
+    // Zero means don't truncate.
+    uint32 truncate_input_tokens = 6;
+}
+
+message DecodingParameters {
+    message LengthPenalty {
+        // Start the decay after this number of tokens have been generated
+        uint32 start_index = 1;
+        // Factor of exponential decay
+        float decay_factor = 2;
+    }
+
+    // Default (0.0) means no penalty (equivalent to 1.0)
+    // 1.2 is a recommended value
+    float repetition_penalty = 1;
+
+    // Exponentially increases the score of the EOS token
+    // once start_index tokens have been generated
+    optional LengthPenalty length_penalty = 2;
+
+    enum ResponseFormat {
+        // Plain text, no constraints
+        TEXT = 0;
+        // Valid json
+        JSON = 1;
+    }
+
+    message StringChoices {
+        repeated string choices = 1;
+    }
+
+    // Mutually-exclusive guided decoding options
+    oneof guided {
+        // Output will be in the specified format
+        ResponseFormat format = 3;
+        // Output will follow the provided JSON schema
+        string json_schema = 4;
+        // Output will follow the provided regex pattern
+        string regex = 5;
+        // Output will be exactly one of the specified choices
+        StringChoices choice = 6;
+        // Output will follow the provided context free grammar
+        string grammar = 7;
+    }
+}
+
+
+message SamplingParameters {
+    // Default (0.0) means disabled (equivalent to 1.0)
+    float temperature = 1;
+    // Default (0) means disabled
+    uint32 top_k = 2;
+    // Default (0) means disabled (equivalent to 1.0)
+    float top_p = 3;
+    // Default (0) means disabled (equivalent to 1.0)
+    float typical_p = 4;
+
+    optional uint64 seed = 5;
+}
+
+message StoppingCriteria {
+    // Default (0) is currently 20
+    uint32 max_new_tokens = 1;
+    // Default (0) means no minimum
+    uint32 min_new_tokens = 2;
+    // Default (0) means no time limit
+    uint32 time_limit_millis = 3;
+    repeated string stop_sequences = 4;
+    // If not specified, default behavior depends on server setting
+    optional bool include_stop_sequence = 5;
+
+    //more to come
+}
+
+message ResponseOptions {
+    // Include input text
+    bool input_text = 1;
+    // Include list of individual generated tokens
+    // "Extra" token information is included based on the other flags below
+    bool generated_tokens = 2;
+    // Include list of input tokens
+    // "Extra" token information is included based on the other flags here,
+    // but only for decoder-only models
+    bool input_tokens = 3;
+    // Include logprob for each returned token
+    // Applicable only if generated_tokens == true and/or input_tokens == true
+    bool token_logprobs = 4;
+    // Include rank of each returned token
+    // Applicable only if generated_tokens == true and/or input_tokens == true
+    bool token_ranks = 5;
+    // Include top n candidate tokens at the position of each returned token
+    // The maximum value permitted is 5, but more may be returned if there is a tie
+    // for nth place.
+    // Applicable only if generated_tokens == true and/or input_tokens == true
+    uint32 top_n_tokens = 6;
+}
+
+enum StopReason {
+    // Possibly more tokens to be streamed
+    NOT_FINISHED = 0;
+    // Maximum requested tokens reached
+    MAX_TOKENS = 1;
+    // End-of-sequence token encountered
+    EOS_TOKEN = 2;
+    // Request cancelled by client
+    CANCELLED = 3;
+    // Time limit reached
+    TIME_LIMIT = 4;
+    // Stop sequence encountered
+    STOP_SEQUENCE = 5;
+    // Total token limit reached
+    TOKEN_LIMIT = 6;
+    // Decoding error
+    ERROR = 7;
+}
+
+message TokenInfo {
+    // uint32 id = 1; // TBD
+    string text = 2;
+    // The logprob (log of normalized probability), if requested
+    float logprob = 3;
+    // One-based rank relative to other tokens, if requested
+    uint32 rank = 4;
+
+    message TopToken {
+        // uint32 id = 1; // TBD
+        string text = 2;
+        float logprob = 3;
+    }
+
+    // Top N candidate tokens at this position, if requested
+    // May or may not include this token
+    repeated TopToken top_tokens = 5;
+}
+
+
+// ============================================================================================================
+// Tokenization API
+
+message BatchedTokenizeRequest {
+    string model_id = 1;
+    repeated TokenizeRequest requests = 2;
+    bool return_tokens = 3;
+    bool return_offsets = 4;
+
+    // Zero means don't truncate.
+    uint32 truncate_input_tokens = 5;
+}
+
+message BatchedTokenizeResponse {
+    repeated TokenizeResponse responses = 1;
+}
+
+message TokenizeRequest {
+    string text = 1;
+}
+
+message TokenizeResponse {
+    message Offset {
+        uint32 start = 1;
+        uint32 end = 2;
+    }
+
+    uint32 token_count = 1;
+
+    // if return_tokens = true
+    repeated string tokens = 2;
+    // if return_tokens = true
+    repeated Offset offsets = 3;
+}
+
+
+// ============================================================================================================
+// Model Info API
+
+message ModelInfoRequest {
+    string model_id = 1;
+}
+
+message ModelInfoResponse {
+    enum ModelKind {
+        DECODER_ONLY = 0;
+        ENCODER_DECODER = 1;
+    }
+
+    ModelKind model_kind = 1;
+    uint32 max_sequence_length = 2;
+    uint32 max_new_tokens = 3;
+}
diff --git a/scripts/update_generation_proto.sh b/scripts/update_generation_proto.sh
new file mode 100755
index 0000000..3ce0e9c
--- /dev/null
+++ b/scripts/update_generation_proto.sh
@@ -0,0 +1,11 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+# Run from the repository root
+cd "$(dirname "${BASH_SOURCE[0]}")/.."
+TGIS_BACKEND_DIR="caikit_tgis_backend"
+
+# Pull the latest generation.proto from the vllm repo that this backend tracks
+curl -fsSL https://raw.githubusercontent.com/IBM/vllm/main/proto/generation.proto > "$TGIS_BACKEND_DIR/generation.proto"
+# Regenerate the pb2 files from the refreshed proto
+./scripts/gen_tgis_protos.sh
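
Usage sketch for reviewers: how the new script is expected to be run. This is
a minimal sketch, assuming a normal checkout of this repo in which the
pre-existing scripts/gen_tgis_protos.sh helper emits the protoc-generated
*_pb2.py / *_pb2_grpc.py modules:

    # Refresh the vendored proto and regenerate the pb2 files.
    # The script cd's to the repository root itself, so it can be
    # invoked from any directory inside the checkout.
    ./scripts/update_generation_proto.sh

    # Review what the upstream pull actually changed before committing
    git diff -- caikit_tgis_backend/generation.proto

Because the script overwrites the checked-in proto with whatever is currently
on main of IBM/vllm, running it on a clean working tree keeps upstream changes
easy to audit and, if necessary, revert.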