diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py index 8998cf39de3..15aa701cb0f 100644 --- a/python/sglang/srt/openai_api/adapter.py +++ b/python/sglang/srt/openai_api/adapter.py @@ -117,7 +117,7 @@ def create_streaming_error_response( return json_str -def load_chat_template_for_openai_api(chat_template_arg): +def load_chat_template_for_openai_api(tokenizer_manager, chat_template_arg): global chat_template_name print(f"Use chat template: {chat_template_arg}") @@ -127,27 +127,38 @@ def load_chat_template_for_openai_api(chat_template_arg): f"Chat template {chat_template_arg} is not a built-in template name " "or a valid chat template file path." ) - with open(chat_template_arg, "r") as filep: - template = json.load(filep) - try: - sep_style = SeparatorStyle[template["sep_style"]] - except KeyError: - raise ValueError( - f"Unknown separator style: {template['sep_style']}" - ) from None - register_conv_template( - Conversation( - name=template["name"], - system_template=template["system"] + "\n{system_message}", - system_message=template.get("system_message", ""), - roles=(template["user"], template["assistant"]), - sep_style=sep_style, - sep=template.get("sep", "\n"), - stop_str=template["stop_str"], - ), - override=True, + if chat_template_arg.endswith(".jinja"): + with open(chat_template_arg, "r") as f: + chat_template = "".join(f.readlines()).strip("\n") + tokenizer_manager.tokenizer.chat_template = chat_template.replace( + "\\n", "\n" ) - chat_template_name = template["name"] + chat_template_name = None + else: + assert chat_template_arg.endswith( + ".json" + ), "unrecognized format of chat template file" + with open(chat_template_arg, "r") as filep: + template = json.load(filep) + try: + sep_style = SeparatorStyle[template["sep_style"]] + except KeyError: + raise ValueError( + f"Unknown separator style: {template['sep_style']}" + ) from None + register_conv_template( + Conversation( + name=template["name"], + system_template=template["system"] + "\n{system_message}", + system_message=template.get("system_message", ""), + roles=(template["user"], template["assistant"]), + sep_style=sep_style, + sep=template.get("sep", "\n"), + stop_str=template["stop_str"], + ), + override=True, + ) + chat_template_name = template["name"] else: chat_template_name = chat_template_arg diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 8f735ac0c74..973f9c8e120 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -288,6 +288,8 @@ def launch_server( # Launch processes tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args) + if server_args.chat_template: + load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template) pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False) pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False) @@ -375,11 +377,6 @@ def _set_envs_and_config(server_args: ServerArgs): # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency. maybe_set_triton_cache_manager() - # Set global chat template - if server_args.chat_template: - # TODO: replace this with huggingface transformers template - load_chat_template_for_openai_api(server_args.chat_template) - # Check flashinfer version if not server_args.disable_flashinfer: assert_pkg_version(