diff --git a/lm_eval/models/gigachat_model.py b/lm_eval/models/gigachat_model.py index 091c13085d..e1bdb0a533 100644 --- a/lm_eval/models/gigachat_model.py +++ b/lm_eval/models/gigachat_model.py @@ -43,42 +43,44 @@ def _create_payload( **kwargs, ) -> dict: if generate: - max_tokens = gen_kwargs.pop("max_tokens", None) temperature = gen_kwargs.pop("temperature", None) - profanity_check = gen_kwargs.pop("profanity_check", True) + do_sample = gen_kwargs.pop("do_sample", None) - if ( - "do_sample" in gen_kwargs - ): # GigaChat API does not have do sample option. - do_sample = gen_kwargs.pop("do_sample") + if do_sample is not None: # GigaChat API does not have do sample option. if not do_sample: # Ensure greedy decoding if do_sample=False gen_kwargs["repetition_penalty"] = 1.0 gen_kwargs["top_p"] = 0.0 - elif temperature == 0: + elif temperature == 0.0: eval_logger.warning( "You cannot set do_sample=True and temperature=0. Automatically setting temperature=1." ) temperature = 1.0 if ( - temperature == 0 + temperature == 0.0 ): # Ensure greedy decoding by setting top_p=0 and repetition_penalty = 1 temperature = ( 1.0 # temperature cannot be set to zero. Use top_p instead ) - gen_kwargs["repetition_penalty"] = 1 - gen_kwargs["top_p"] = 0 + gen_kwargs["repetition_penalty"] = 1.0 + gen_kwargs["top_p"] = 0.0 + print( + { + "messages": messages, + "model": self.model, + "temperature": temperature, + **gen_kwargs, + } + ) return { "messages": messages, "model": self.model, - "max_tokens": max_tokens, "temperature": temperature, - "profanity_check": profanity_check, **gen_kwargs, } else: return None - @property # Don't use cached_property as we need to check that the acess_token has not expired. + @property # Don't use cached_property as we need to check that the access_token has not expired. def header(self) -> dict: """Override this property to return the headers for the API request.""" @@ -90,6 +92,11 @@ def header(self) -> dict: @property # Don't use cached_property as we need to check that the acess_token has not expired. def api_key(self): + self.key = os.environ.get( + "GIGACHAT_CREDENTIALS", None + ) # GigaChat access token. + if self.key: + return self.key # If access token is available, return access token. RqUID = os.environ.get( "GIGACHAT_RQUID", None ) # Unique identification request. Complies with uuid4 format. Value must match regular expression (([0-9a-fA-F-])36)