diff --git a/litgpt/deploy/serve.py b/litgpt/deploy/serve.py
index 651810b4aa..8906cc7e1d 100644
--- a/litgpt/deploy/serve.py
+++ b/litgpt/deploy/serve.py
@@ -10,10 +10,10 @@
 from litgpt.utils import auto_download_checkpoint
 
-
 _LITSERVE_AVAILABLE = RequirementCache("litserve")
 if _LITSERVE_AVAILABLE:
-    from litserve import LitAPI, LitServer
+    from litserve import LitAPI, LitServer, OpenAISpec, OpenAIEmbeddingSpec
+    from litserve.specs.openai import ChatMessage
 else:
     LitAPI, LitServer = object, object
 
 
@@ -28,9 +28,9 @@ def __init__(
         top_k: int = 50,
         top_p: float = 1.0,
         max_new_tokens: int = 50,
-        devices: int = 1
+        devices: int = 1,
     ) -> None:
-
+        print("BaseLitAPI: __init__")
         if not _LITSERVE_AVAILABLE:
             raise ImportError(str(_LITSERVE_AVAILABLE))
 
@@ -45,6 +45,7 @@ def __init__(
         self.devices = devices
 
     def setup(self, device: str) -> None:
+        print("BaseLitAPI: setup")
         if ":" in device:
             accelerator, device = device.split(":")
             device = f"[{int(device)}]"
@@ -69,6 +70,7 @@ def setup(self, device: str) -> None:
 
     def decode_request(self, request: Dict[str, Any]) -> Any:
         # Convert the request payload to your model input.
+        print("BaseLitAPI: decode_request")
         prompt = str(request["prompt"])
         return prompt
 
@@ -83,7 +85,7 @@ def __init__(
         top_k: int = 50,
         top_p: float = 1.0,
         max_new_tokens: int = 50,
-        devices: int = 1
+        devices: int = 1,
     ):
         super().__init__(checkpoint_dir, quantize, precision, temperature, top_k, top_p, max_new_tokens, devices)
 
@@ -117,13 +119,16 @@ def __init__(
         max_new_tokens: int = 50,
         devices: int = 1
     ):
+        print("StreamLitAPI: __init__")
         super().__init__(checkpoint_dir, quantize, precision, temperature, top_k, top_p, max_new_tokens, devices)
 
     def setup(self, device: str):
+        print("StreamLitAPI: setup")
         super().setup(device)
 
     def predict(self, inputs: torch.Tensor) -> Any:
         # Run the model on the input and return the output.
+        print("StreamLitAPI: predict")
         yield from self.llm.generate(
             inputs,
             temperature=self.temperature,
@@ -134,9 +139,73 @@ def predict(self, inputs: torch.Tensor) -> Any:
         )
 
     def encode_response(self, output):
+        print("StreamLitAPI: encode_response")
         for out in output:
             yield {"output": out}
 
+class OpenAIAPI(LitAPI):
+    def __init__(
+        self,
+        checkpoint_dir: Path,
+        quantize: Optional[str] = None,
+        precision: Optional[str] = None,
+        temperature: float = 0.8,
+        top_k: int = 50,
+        top_p: float = 1.0,
+        max_new_tokens: int = 50,
+        devices: int = 1
+    ):
+        print("OpenAIAPI: __init__")
+        if not _LITSERVE_AVAILABLE:
+            raise ImportError(str(_LITSERVE_AVAILABLE))
+
+        super().__init__()
+        self.checkpoint_dir = checkpoint_dir
+        self.quantize = quantize
+        self.precision = precision
+        self.temperature = temperature
+        self.top_k = top_k
+        self.max_new_tokens = max_new_tokens
+        self.top_p = top_p
+        self.devices = devices
+
+    def setup(self, device: str):
+        print("OpenAIAPI: setup")
+        if ":" in device:
+            accelerator, device = device.split(":")
+            device = f"[{int(device)}]"
+        else:
+            accelerator = device
+            device = 1
+
+        print("OpenAIAPI: Initializing model...")
+        self.llm = LLM.load(
+            model=self.checkpoint_dir,
+            distribute=None
+        )
+
+        self.llm.distribute(
+            devices=self.devices,
+            accelerator=accelerator,
+            quantize=self.quantize,
+            precision=self.precision,
+            generate_strategy="sequential" if self.devices is not None and self.devices > 1 else None
+        )
+        print("OpenAIAPI: Model successfully initialized.")
+
+    def predict(self, inputs: Any) -> Any:
+        # Run the model on each incoming chat message and stream the output.
+
+        for chunk in inputs:
+            yield from self.llm.generate(
+                chunk["content"],
+                temperature=self.temperature,
+                top_k=self.top_k,
+                top_p=self.top_p,
+                max_new_tokens=self.max_new_tokens,
+                stream=True
+            )
+
 
 def run_server(
     checkpoint_dir: Path,
@@ -151,6 +220,8 @@ def run_server(
     port: int = 8000,
     stream: bool = False,
     access_token: Optional[str] = None,
+    spec: Optional[Literal["openaispec", "openaiembeddingspec"]] = None,
+
 ) -> None:
     """Serve a LitGPT model using LitServe.
 
@@ -189,11 +260,29 @@
         port: The network port number on which the model is configured to be served.
         stream: Whether to stream the responses.
         access_token: Optional API token to access models with restrictions.
+        spec: Name of the LitServe protocol spec to use. Options: "openaispec", "openaiembeddingspec".
+
+            The available specs are defined in:
+            https://github.com/Lightning-AI/LitServe/tree/main/src/litserve/specs
+            OpenAI API support in LitServe is documented at:
+            https://lightning.ai/docs/litserve/features/open-ai-spec#enable-openai-api-in-litserve
     """
     checkpoint_dir = auto_download_checkpoint(model_name=checkpoint_dir, access_token=access_token)
     pprint(locals())
 
+    spec_impl = None
     if not stream:
+
+        spec_impl = None
+
+        if spec is not None:
+            if spec.lower() == "openaispec":
+                raise ValueError("OpenAISpec is only supported when stream=True")
+            elif spec.lower() == "openaiembeddingspec":
+                spec_impl = OpenAIEmbeddingSpec()
+            else:
+                raise ValueError(f"Unknown LitServe protocol spec: {spec}")
+
         server = LitServer(
             SimpleLitAPI(
                 checkpoint_dir=checkpoint_dir,
@@ -206,12 +295,23 @@ def run_server(
                 devices=devices
             ),
             accelerator=accelerator,
-            devices=1  # We need to use the devives inside the `SimpleLitAPI` class
+            devices=1,  # We need to use the devices inside the `SimpleLitAPI` class
+            spec=spec_impl
         )
-
+
     else:
+
+        spec_impl = None
+        if spec is not None:
+            if spec.lower() == "openaispec":
+                spec_impl = OpenAISpec()
+            elif spec.lower() == "openaiembeddingspec":
+                raise ValueError("OpenAIEmbeddingSpec is only supported when stream=False")
+            else:
+                raise ValueError(f"Unknown LitServe protocol spec: {spec}")
+
         server = LitServer(
-            StreamLitAPI(
+            OpenAIAPI(
                 checkpoint_dir=checkpoint_dir,
                 quantize=quantize,
                 precision=precision,
@@ -219,11 +319,12 @@ def run_server(
                 top_k=top_k,
                 top_p=top_p,
                 max_new_tokens=max_new_tokens,
-                devices=devices  # We need to use the devives inside the `StreamLitAPI` class
+                devices=devices  # We need to use the devices inside the `OpenAIAPI` class
             ),
             accelerator=accelerator,
             devices=1,
-            stream=True
+            stream=True,
+            spec=spec_impl
        )
 
     server.run(port=port, generate_client_file=False)
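
The hunks above wire `spec=spec_impl` into `LitServer`, so a model served with streaming enabled and `spec="openaispec"` is meant to answer OpenAI-style chat-completion requests. The following client-side sketch is illustrative only and is not part of the diff; it assumes the `litgpt serve` CLI exposes the new `spec` argument under that name (e.g. `litgpt serve <checkpoint_dir> --stream true --spec openaispec`), that the server listens on the default port 8000 from `run_server`, and that LitServe's OpenAISpec serves the standard `/v1/chat/completions` route described in the documentation linked in the docstring.

# Hedged usage sketch: exercise the OpenAI-compatible endpoint with plain `requests`.
import requests

resp = requests.post(
    "http://127.0.0.1:8000/v1/chat/completions",
    json={
        "model": "litgpt",  # placeholder; the server hosts a single model regardless of this field
        "messages": [{"role": "user", "content": "Hello, who are you?"}],
        "stream": True,
    },
    stream=True,
)
for line in resp.iter_lines():
    if line:
        # Each non-empty line is a server-sent event carrying an OpenAI-style chunk.
        print(line.decode("utf-8"))

Plain `requests` is used here rather than the `openai` client to keep the sketch dependency-light; any OpenAI-compatible client should work against the same route once the server is up.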