Skip to content

Commit

Permalink
[BUG] Remove token param vllm-project#10921 (vllm-project#11022)
Browse files Browse the repository at this point in the history
Signed-off-by: Flavia Beo <flavia.beo@ibm.com>
  • Loading branch information
flaviabeo authored Dec 10, 2024
1 parent 9b9cef3 commit 250ee65
Showing 1 changed file with 29 additions and 34 deletions.
63 changes: 29 additions & 34 deletions vllm/transformers_utils/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import enum
import json
import os
from pathlib import Path
from typing import Any, Dict, Optional, Type, Union

Expand Down Expand Up @@ -41,6 +42,7 @@
from transformers import AutoConfig

MISTRAL_CONFIG_NAME = "params.json"
HF_TOKEN = os.getenv('HF_TOKEN', None)

logger = init_logger(__name__)

Expand Down Expand Up @@ -77,8 +79,8 @@ class ConfigFormat(str, enum.Enum):
MISTRAL = "mistral"


def file_or_path_exists(model: Union[str, Path], config_name, revision,
token) -> bool:
def file_or_path_exists(model: Union[str, Path], config_name: str,
revision: Optional[str]) -> bool:
if Path(model).exists():
return (Path(model) / config_name).is_file()

Expand All @@ -93,7 +95,10 @@ def file_or_path_exists(model: Union[str, Path], config_name, revision,
# NB: file_exists will only check for the existence of the config file on
# hf_hub. This will fail in offline mode.
try:
return file_exists(model, config_name, revision=revision, token=token)
return file_exists(model,
config_name,
revision=revision,
token=HF_TOKEN)
except huggingface_hub.errors.OfflineModeIsEnabled:
# Don't raise in offline mode, all we know is that we don't have this
# file cached.
Expand Down Expand Up @@ -161,7 +166,6 @@ def get_config(
revision: Optional[str] = None,
code_revision: Optional[str] = None,
config_format: ConfigFormat = ConfigFormat.AUTO,
token: Optional[str] = None,
**kwargs,
) -> PretrainedConfig:
# Separate model folder from file path for GGUF models
Expand All @@ -173,19 +177,20 @@ def get_config(

if config_format == ConfigFormat.AUTO:
if is_gguf or file_or_path_exists(
model, HF_CONFIG_NAME, revision=revision, token=token):
model, HF_CONFIG_NAME, revision=revision):
config_format = ConfigFormat.HF
elif file_or_path_exists(model,
MISTRAL_CONFIG_NAME,
revision=revision,
token=token):
elif file_or_path_exists(model, MISTRAL_CONFIG_NAME,
revision=revision):
config_format = ConfigFormat.MISTRAL
else:
# If we're in offline mode and found no valid config format, then
# raise an offline mode error to indicate to the user that they
# don't have files cached and may need to go online.
# This is conveniently triggered by calling file_exists().
file_exists(model, HF_CONFIG_NAME, revision=revision, token=token)
file_exists(model,
HF_CONFIG_NAME,
revision=revision,
token=HF_TOKEN)

raise ValueError(f"No supported config format found in {model}")

Expand All @@ -194,7 +199,7 @@ def get_config(
model,
revision=revision,
code_revision=code_revision,
token=token,
token=HF_TOKEN,
**kwargs,
)

Expand All @@ -206,7 +211,7 @@ def get_config(
model,
revision=revision,
code_revision=code_revision,
token=token,
token=HF_TOKEN,
**kwargs,
)
else:
Expand All @@ -216,7 +221,7 @@ def get_config(
trust_remote_code=trust_remote_code,
revision=revision,
code_revision=code_revision,
token=token,
token=HF_TOKEN,
**kwargs,
)
except ValueError as e:
Expand All @@ -234,7 +239,7 @@ def get_config(
raise e

elif config_format == ConfigFormat.MISTRAL:
config = load_params_config(model, revision, token=token, **kwargs)
config = load_params_config(model, revision, token=HF_TOKEN, **kwargs)
else:
raise ValueError(f"Unsupported config format: {config_format}")

Expand All @@ -256,8 +261,7 @@ def get_config(

def get_hf_file_to_dict(file_name: str,
model: Union[str, Path],
revision: Optional[str] = 'main',
token: Optional[str] = None):
revision: Optional[str] = 'main'):
"""
Downloads a file from the Hugging Face Hub and returns
its contents as a dictionary.
Expand All @@ -266,7 +270,6 @@ def get_hf_file_to_dict(file_name: str,
- file_name (str): The name of the file to download.
- model (str): The name of the model on the Hugging Face Hub.
- revision (str): The specific version of the model.
- token (str): The Hugging Face authentication token.
Returns:
- config_dict (dict): A dictionary containing
Expand All @@ -276,8 +279,7 @@ def get_hf_file_to_dict(file_name: str,

if file_or_path_exists(model=model,
config_name=file_name,
revision=revision,
token=token):
revision=revision):

if not file_path.is_file():
try:
Expand All @@ -296,9 +298,7 @@ def get_hf_file_to_dict(file_name: str,
return None


def get_pooling_config(model: str,
revision: Optional[str] = 'main',
token: Optional[str] = None):
def get_pooling_config(model: str, revision: Optional[str] = 'main'):
"""
This function gets the pooling and normalize
config from the model - only applies to
Expand All @@ -315,8 +315,7 @@ def get_pooling_config(model: str,
"""

modules_file_name = "modules.json"
modules_dict = get_hf_file_to_dict(modules_file_name, model, revision,
token)
modules_dict = get_hf_file_to_dict(modules_file_name, model, revision)

if modules_dict is None:
return None
Expand All @@ -332,8 +331,7 @@ def get_pooling_config(model: str,
if pooling:

pooling_file_name = "{}/config.json".format(pooling["path"])
pooling_dict = get_hf_file_to_dict(pooling_file_name, model, revision,
token)
pooling_dict = get_hf_file_to_dict(pooling_file_name, model, revision)
pooling_type_name = next(
(item for item, val in pooling_dict.items() if val is True), None)

Expand Down Expand Up @@ -368,8 +366,8 @@ def get_pooling_config_name(pooling_name: str) -> Union[str, None]:


def get_sentence_transformer_tokenizer_config(model: str,
revision: Optional[str] = 'main',
token: Optional[str] = None):
revision: Optional[str] = 'main'
):
"""
Returns the tokenization configuration dictionary for a
given Sentence Transformer BERT model.
Expand All @@ -379,7 +377,6 @@ def get_sentence_transformer_tokenizer_config(model: str,
BERT model.
- revision (str, optional): The revision of the
model to use. Defaults to 'main'.
- token (str): A Hugging Face access token.
Returns:
- dict: A dictionary containing the configuration parameters
Expand All @@ -394,7 +391,7 @@ def get_sentence_transformer_tokenizer_config(model: str,
"sentence_xlm-roberta_config.json",
"sentence_xlnet_config.json",
]:
encoder_dict = get_hf_file_to_dict(config_name, model, revision, token)
encoder_dict = get_hf_file_to_dict(config_name, model, revision)
if encoder_dict:
break

Expand Down Expand Up @@ -474,16 +471,14 @@ def _reduce_config(config: VllmConfig):
exc_info=e)


def load_params_config(model: Union[str, Path],
revision: Optional[str],
token: Optional[str] = None,
def load_params_config(model: Union[str, Path], revision: Optional[str],
**kwargs) -> PretrainedConfig:
# This function loads a params.json config which
# should be used when loading models in mistral format

config_file_name = "params.json"

config_dict = get_hf_file_to_dict(config_file_name, model, revision, token)
config_dict = get_hf_file_to_dict(config_file_name, model, revision)
assert isinstance(config_dict, dict)

config_mapping = {
Expand Down

0 comments on commit 250ee65

Please sign in to comment.