refact(ml_backend): separate servers
Showing 5 changed files with 408 additions and 263 deletions.
@@ -0,0 +1,176 @@
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Dict, Any, Optional
import httpx
import torch
from PIL import Image
import base64
import io
from transformers import AutoProcessor, AutoModelForCausalLM
import time
from memos_ml_backends.schemas import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ModelData,
    ModelsResponse,
    get_image_from_url,
)

MODEL_INFO = {"name": "florence2-base-ft", "max_model_len": 2048}

# Detect the available device
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Use float32 on older CUDA GPUs (compute capability <= 6) and on CPU;
# otherwise use float16
torch_dtype = (
    torch.float32
    if (torch.cuda.is_available() and torch.cuda.get_device_capability()[0] <= 6)
    or (not torch.cuda.is_available() and not torch.backends.mps.is_available())
    else torch.float16
)
print(f"Using device: {device}")

# Load Florence-2 model
florence_model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Florence-2-base-ft",
    torch_dtype=torch_dtype,
    attn_implementation="sdpa",
    trust_remote_code=True,
).to(device, torch_dtype)
florence_processor = AutoProcessor.from_pretrained(
    "microsoft/Florence-2-base-ft", trust_remote_code=True
)

app = FastAPI()


async def generate_florence_result(text_input, image_input, max_tokens):
    # Florence-2 captioning uses a fixed task prompt; the user's text input
    # is not appended to it.
    task_prompt = "<MORE_DETAILED_CAPTION>"
    prompt = task_prompt

    inputs = florence_processor(
        text=prompt, images=image_input, return_tensors="pt"
    ).to(device, torch_dtype)

    generated_ids = florence_model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=max_tokens or 1024,
        do_sample=False,
        num_beams=3,
    )

    generated_texts = florence_processor.batch_decode(
        generated_ids, skip_special_tokens=False
    )

    parsed_answer = florence_processor.post_process_generation(
        generated_texts[0],
        task=task_prompt,
        image_size=(image_input.width, image_input.height),
    )

    return parsed_answer.get(task_prompt, "")


@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def chat_completions(request: ChatCompletionRequest):
    try:
        last_message = request.messages[-1]
        text_input = last_message.get("content", "")
        image_input = None

        # OpenAI-style content may be a list of parts; take the first image
        # part and join the text parts.
        if isinstance(text_input, list):
            for content in text_input:
                if content.get("type") == "image_url":
                    image_url = content["image_url"].get("url")
                    image_input = await get_image_from_url(image_url)
                    break
            text_input = " ".join(
                [
                    content["text"]
                    for content in text_input
                    if content.get("type") == "text"
                ]
            )

        if image_input is None:
            raise ValueError("Image input is required")

        parsed_answer = await generate_florence_result(
            text_input, image_input, request.max_tokens
        )

        result = ChatCompletionResponse(
            id=str(int(time.time())),
            object="chat.completion",
            created=int(time.time()),
            model=request.model,
            choices=[
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": parsed_answer,
                    },
                    "finish_reason": "stop",
                }
            ],
            usage={
                "prompt_tokens": 0,
                "total_tokens": 0,
                "completion_tokens": 0,
            },
        )

        return result
    except Exception as e:
        print(f"Error generating chat completion: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Error generating chat completion: {str(e)}"
        )


@app.get("/v1/models", response_model=ModelsResponse)
async def get_models():
    model_data = ModelData(
        id=MODEL_INFO["name"],
        created=int(time.time()),
        max_model_len=MODEL_INFO["max_model_len"],
        permission=[
            {
                "id": f"modelperm-{MODEL_INFO['name']}",
                "object": "model_permission",
                "created": int(time.time()),
                "allow_create_engine": False,
                "allow_sampling": False,
                "allow_logprobs": False,
                "allow_search_indices": False,
                "allow_view": False,
                "allow_fine_tuning": False,
                "organization": "*",
                "group": None,
                "is_blocking": False,
            }
        ],
    )

    return ModelsResponse(data=[model_data])


if __name__ == "__main__":
    import argparse
    import uvicorn

    parser = argparse.ArgumentParser(description="Run the Florence-2 server")
    parser.add_argument(
        "--port", type=int, default=8000, help="Port to run the server on"
    )
    args = parser.parse_args()

    print("Using Florence-2 model")
    uvicorn.run(app, host="0.0.0.0", port=args.port)
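
The endpoint follows the OpenAI chat-completions shape, so a client sends the image as an image_url content part. Below is a minimal client sketch, not part of the commit: it assumes the server above is running on localhost:8000, that a local file sample.png exists, and that the get_image_from_url helper from memos_ml_backends.schemas accepts base64 data URLs (if it only fetches HTTP URLs, pass a plain image URL instead).

# Minimal client sketch (assumptions: localhost:8000, sample.png on disk,
# data-URL support in get_image_from_url).
import base64
import httpx

with open("sample.png", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode()

payload = {
    "model": "florence2-base-ft",
    "max_tokens": 256,
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this image."},
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/png;base64,{image_b64}"},
                },
            ],
        }
    ],
}

resp = httpx.post(
    "http://localhost:8000/v1/chat/completions", json=payload, timeout=120
)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])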
@@ -0,0 +1,182 @@
from fastapi import FastAPI, HTTPException
import torch
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
from qwen_vl_utils import process_vision_info
import time
from memos_ml_backends.schemas import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ModelData,
    ModelsResponse,
    get_image_from_url,
)

MODEL_INFO = {"name": "Qwen2-VL-2B-Instruct", "max_model_len": 32768}

# Detect the available device
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Use float32 on older CUDA GPUs (compute capability <= 6) and on CPU;
# otherwise use float16
torch_dtype = (
    torch.float32
    if (torch.cuda.is_available() and torch.cuda.get_device_capability()[0] <= 6)
    or (not torch.cuda.is_available() and not torch.backends.mps.is_available())
    else torch.float16
)
print(f"Using device: {device}")

# Load Qwen2-VL model and processor (both from the same checkpoint)
qwen2vl_model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct",
    torch_dtype=torch_dtype,
    device_map="auto",
).to(device, torch_dtype)
qwen2vl_processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")

app = FastAPI()


async def generate_qwen2vl_result(text_input, image_input, max_tokens):
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image_input},
                {"type": "text", "text": text_input},
            ],
        }
    ]

    text = qwen2vl_processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    image_inputs, video_inputs = process_vision_info(messages)

    inputs = qwen2vl_processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(device)

    generated_ids = qwen2vl_model.generate(**inputs, max_new_tokens=(max_tokens or 512))

    # Drop the prompt tokens so only newly generated tokens are decoded
    generated_ids_trimmed = [
        out_ids[len(in_ids) :]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]

    output_text = qwen2vl_processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )

    return output_text[0] if output_text else ""


@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def chat_completions(request: ChatCompletionRequest):
    try:
        last_message = request.messages[-1]
        text_input = last_message.get("content", "")
        image_input = None

        # OpenAI-style content may be a list of parts; take the first image
        # part and join the text parts.
        if isinstance(text_input, list):
            for content in text_input:
                if content.get("type") == "image_url":
                    image_url = content["image_url"].get("url")
                    image_input = await get_image_from_url(image_url)
                    break
            text_input = " ".join(
                [
                    content["text"]
                    for content in text_input
                    if content.get("type") == "text"
                ]
            )

        if image_input is None:
            raise ValueError("Image input is required")

        parsed_answer = await generate_qwen2vl_result(
            text_input, image_input, request.max_tokens
        )

        result = ChatCompletionResponse(
            id=str(int(time.time())),
            object="chat.completion",
            created=int(time.time()),
            model=request.model,
            choices=[
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": parsed_answer,
                    },
                    "finish_reason": "stop",
                }
            ],
            usage={
                "prompt_tokens": 0,
                "total_tokens": 0,
                "completion_tokens": 0,
            },
        )

        return result
    except Exception as e:
        print(f"Error generating chat completion: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Error generating chat completion: {str(e)}"
        )


# Add a GET /v1/models endpoint
@app.get("/v1/models", response_model=ModelsResponse)
async def get_models():
    model_data = ModelData(
        id=MODEL_INFO["name"],
        created=int(time.time()),
        max_model_len=MODEL_INFO["max_model_len"],
        permission=[
            {
                "id": f"modelperm-{MODEL_INFO['name']}",
                "object": "model_permission",
                "created": int(time.time()),
                "allow_create_engine": False,
                "allow_sampling": False,
                "allow_logprobs": False,
                "allow_search_indices": False,
                "allow_view": False,
                "allow_fine_tuning": False,
                "organization": "*",
                "group": None,
                "is_blocking": False,
            }
        ],
    )

    return ModelsResponse(data=[model_data])


if __name__ == "__main__":
    import argparse
    import uvicorn

    parser = argparse.ArgumentParser(description="Run the Qwen2VL server")
    parser.add_argument(
        "--port", type=int, default=8000, help="Port to run the server on"
    )
    args = parser.parse_args()

    print("Using Qwen2VL model")
    uvicorn.run(app, host="0.0.0.0", port=args.port)
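
Both backends expose the same GET /v1/models endpoint, which is a quick way to confirm which server is actually running. A minimal check, assuming a backend is listening on localhost:8000 (adjust the port if you started it elsewhere):

# Sanity check for a running backend (assumes localhost:8000).
import httpx

models = httpx.get("http://localhost:8000/v1/models").json()
print([m["id"] for m in models["data"]])  # e.g. ["Qwen2-VL-2B-Instruct"]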