OpenAI API via LitGPT CLI #1865

Open: wants to merge 3 commits into base: main
123 changes: 112 additions & 11 deletions litgpt/deploy/serve.py
@@ -10,10 +10,10 @@

from litgpt.utils import auto_download_checkpoint


_LITSERVE_AVAILABLE = RequirementCache("litserve")
if _LITSERVE_AVAILABLE:
    from litserve import LitAPI, LitServer, OpenAISpec, OpenAIEmbeddingSpec
    from litserve.specs.openai import ChatMessage
else:
    # Stub out every imported name so the module still imports when litserve
    # is not installed; the API classes raise ImportError at construction time.
    LitAPI, LitServer, OpenAISpec, OpenAIEmbeddingSpec, ChatMessage = (object,) * 5

@@ -28,9 +28,9 @@ def __init__(
        top_k: int = 50,
        top_p: float = 1.0,
        max_new_tokens: int = 50,
        devices: int = 1,
    ) -> None:
        print("BaseLitAPI : __init__")
        if not _LITSERVE_AVAILABLE:
            raise ImportError(str(_LITSERVE_AVAILABLE))

@@ -45,6 +45,7 @@ def __init__(
        self.devices = devices

    def setup(self, device: str) -> None:
        print("BaseLitAPI : setup")
        if ":" in device:
            accelerator, device = device.split(":")
            device = f"[{int(device)}]"
@@ -69,6 +70,7 @@ def setup(self, device: str) -> None:

    def decode_request(self, request: Dict[str, Any]) -> Any:
        # Convert the request payload to your model input.
        print("BaseLitAPI : decode_request")
        prompt = str(request["prompt"])
        return prompt

@@ -83,7 +85,7 @@ def __init__(
        top_k: int = 50,
        top_p: float = 1.0,
        max_new_tokens: int = 50,
        devices: int = 1,
    ):
        super().__init__(checkpoint_dir, quantize, precision, temperature, top_k, top_p, max_new_tokens, devices)

@@ -117,13 +119,16 @@ def __init__(
        max_new_tokens: int = 50,
        devices: int = 1,
    ):
        print("StreamLitAPI : __init__")
        super().__init__(checkpoint_dir, quantize, precision, temperature, top_k, top_p, max_new_tokens, devices)

    def setup(self, device: str):
        print("StreamLitAPI : setup")
        super().setup(device)

    def predict(self, inputs: torch.Tensor) -> Any:
        # Run the model on the input and return the output.
        print("StreamLitAPI : predict")
        yield from self.llm.generate(
            inputs,
            temperature=self.temperature,
@@ -134,9 +139,73 @@ def predict(self, inputs: torch.Tensor) -> Any:
        )

    def encode_response(self, output):
        print("StreamLitAPI : encode_response")
        for out in output:
            yield {"output": out}


class OpenAIAPI(LitAPI):
    def __init__(
        self,
        checkpoint_dir: Path,
        quantize: Optional[str] = None,
        precision: Optional[str] = None,
        temperature: float = 0.8,
        top_k: int = 50,
        top_p: float = 1.0,
        max_new_tokens: int = 50,
        devices: int = 1,
    ):
        print("OpenAIAPI : __init__")
        if not _LITSERVE_AVAILABLE:
            raise ImportError(str(_LITSERVE_AVAILABLE))

        super().__init__()
        self.checkpoint_dir = checkpoint_dir
        self.quantize = quantize
        self.precision = precision
        self.temperature = temperature
        self.top_k = top_k
        self.max_new_tokens = max_new_tokens
        self.top_p = top_p
        self.devices = devices

    def setup(self, device: str):
        print("OpenAIAPI : setup")
        if ":" in device:
            accelerator, device = device.split(":")
            device = f"[{int(device)}]"
        else:
            accelerator = device
            device = 1

        print("OpenAIAPI : Initializing model...")
        self.llm = LLM.load(
            model=self.checkpoint_dir,
            distribute=None,
        )

        self.llm.distribute(
            devices=self.devices,
            accelerator=accelerator,
            quantize=self.quantize,
            precision=self.precision,
            generate_strategy="sequential" if self.devices is not None and self.devices > 1 else None,
        )
        print("OpenAIAPI : Model successfully initialized.")

    def predict(self, inputs) -> Any:
        # Run the model on the decoded input and stream the output.
        # With OpenAISpec, `inputs` is the decoded chat payload: a list of
        # message dicts with "role" and "content" keys, not a tensor.
        for chunk in inputs:
            yield from self.llm.generate(
                chunk["content"],
                temperature=self.temperature,
                top_k=self.top_k,
                top_p=self.top_p,
                max_new_tokens=self.max_new_tokens,
                stream=True,
            )


def run_server(
    checkpoint_dir: Path,
@@ -151,6 +220,8 @@ def run_server(
    port: int = 8000,
    stream: bool = False,
    access_token: Optional[str] = None,
    spec: Optional[Literal["openaispec", "openaiembeddingspec"]] = None,
) -> None:
"""Serve a LitGPT model using LitServe.

@@ -189,11 +260,29 @@ def run_server(
        port: The network port number on which the model is configured to be served.
        stream: Whether to stream the responses.
        access_token: Optional API token to access models with restrictions.
        spec: Name of the LitServe protocol spec to use (case-insensitive).
            Options: "openaispec", "openaiembeddingspec".

            The specs are defined in:
            https://github.com/Lightning-AI/LitServe/tree/main/src/litserve/specs
            Details on the OpenAI spec in LitServe:
            https://lightning.ai/docs/litserve/features/open-ai-spec#enable-openai-api-in-litserve
    """
    checkpoint_dir = auto_download_checkpoint(model_name=checkpoint_dir, access_token=access_token)
    pprint(locals())

    spec_impl = None
    if not stream:
        if spec is not None:
            if spec.lower() == "openaispec":
                raise ValueError("OpenAISpec is only supported when stream=True")
            elif spec.lower() == "openaiembeddingspec":
                spec_impl = OpenAIEmbeddingSpec()
            else:
                raise ValueError(f"LitServe protocol spec {spec} not known")

        server = LitServer(
            SimpleLitAPI(
                checkpoint_dir=checkpoint_dir,
@@ -206,24 +295,36 @@ def run_server(
                devices=devices,
            ),
            accelerator=accelerator,
            devices=1,  # We need to use the devices inside the `SimpleLitAPI` class
            spec=spec_impl,
        )

    else:
        if spec is not None:
            if spec.lower() == "openaispec":
                spec_impl = OpenAISpec()
            elif spec.lower() == "openaiembeddingspec":
                raise ValueError("OpenAIEmbeddingSpec is only supported when stream=False")
            else:
                raise ValueError(f"LitServe protocol spec {spec} not known")

        server = LitServer(
            OpenAIAPI(
                checkpoint_dir=checkpoint_dir,
                quantize=quantize,
                precision=precision,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                max_new_tokens=max_new_tokens,
                devices=devices,  # We need to use the devices inside the `OpenAIAPI` class
            ),
            accelerator=accelerator,
            devices=1,
            stream=True,
            spec=spec_impl,
        )

    server.run(port=port, generate_client_file=False)
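
For reference, a minimal client sketch for the new OpenAI-compatible chat route. It assumes the server was started with something like `litgpt serve <checkpoint_dir> --stream true --spec openaispec` on the default port 8000 and that the `openai` Python package is installed; the host, port, and `model` value below are assumptions for illustration, not part of this diff.

from openai import OpenAI

# LitServe's OpenAISpec exposes /v1/chat/completions, so the stock OpenAI
# client works once pointed at the local server. The api_key is unused by
# LitServe but required by the client constructor.
client = OpenAI(base_url="http://127.0.0.1:8000/v1", api_key="not-needed")

stream = client.chat.completions.create(
    model="litgpt",  # placeholder name; the server hosts a single checkpoint
    messages=[{"role": "user", "content": "Hello, who are you?"}],
    stream=True,
)
for chunk in stream:
    # Guard against keep-alive chunks with no delta content.
    if chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="", flush=True)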
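A corresponding sketch for the embedding route, assuming a non-streaming server started with `litgpt serve <checkpoint_dir> --spec openaiembeddingspec`. Note the hedge: the diff wires up `OpenAIEmbeddingSpec`, but `SimpleLitAPI.predict` as shown still generates text, so a `predict` that actually returns embedding vectors is assumed here.

from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:8000/v1", api_key="not-needed")

# LitServe's OpenAIEmbeddingSpec exposes /v1/embeddings in the OpenAI shape.
response = client.embeddings.create(
    model="litgpt",  # placeholder, as above
    input=["LitGPT served locally"],
)
print(len(response.data[0].embedding))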