add more in-line docs
justinthelaw committed Aug 2, 2024
1 parent c92e56c commit 0acedbe
Showing 1 changed file with 10 additions and 3 deletions: packages/vllm/src/main.py
@@ -35,7 +35,7 @@ def clamp(n: float | int, smallest: float | int, largest: float | int):


 class RandomAsyncIterator:
-    """Manages multiple async iterables and allows iterating over them in a random order."""
+    """Manages multiple async iterables and iterates over them in a random order."""

     def __init__(self, async_iterables):
         # Convert each iterable into an async iterator
@@ -57,6 +57,8 @@ async def __anext__(self):
         except StopAsyncIteration:
             # If the selected iterator is exhausted, remove it from the list
             del self.async_iterators[random_index]
+
+            # Continue to the next iterator
             if self.async_iterators:
                 return await self.__anext__()
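
As an aside to the hunk above: the class only appears here in fragments, so the following is a minimal, self-contained sketch (not taken from the repository) of how a random-order async iterator like this one behaves and is driven. The stand-in class name, the letters() helper, and the random.randrange() selection call are assumptions for illustration.

import asyncio
import random


class RandomAsyncIteratorSketch:
    """Illustrative stand-in: yields items from several async iterables in random order."""

    def __init__(self, async_iterables):
        # Convert each iterable into an async iterator
        self.async_iterators = [it.__aiter__() for it in async_iterables]

    def __aiter__(self):
        return self

    async def __anext__(self):
        if not self.async_iterators:
            raise StopAsyncIteration
        random_index = random.randrange(len(self.async_iterators))
        try:
            return await self.async_iterators[random_index].__anext__()
        except StopAsyncIteration:
            # The selected iterator is exhausted: drop it and try the rest
            del self.async_iterators[random_index]
            return await self.__anext__()


async def letters(prefix, n):
    for i in range(n):
        yield f"{prefix}{i}"


async def main():
    mixed = RandomAsyncIteratorSketch([letters("a", 3), letters("b", 3)])
    async for item in mixed:
        print(item)  # a0..a2 and b0..b2, interleaved in a random order


asyncio.run(main())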

@@ -80,6 +82,8 @@ def remove_iterator(self, async_iterable):


 def get_backend_configs():
+    """Get vLLM generation parameters from Confz environment variables"""
+
     # Manually load "STOP_TOKENS", since Confz can't handle complex types (i.e., lists)
     stop_tokens: str | None = os.getenv("LAI_STOP_TOKENS")
     processed_stop_tokens = json.loads(stop_tokens) if stop_tokens else []
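
The comment above explains that list-valued settings are loaded around Confz by hand. A small hedged illustration of that pattern, reusing the LAI_STOP_TOKENS variable name from the diff; the example value is made up:

import json
import os

# Example value only; in the real deployment this comes from the environment
os.environ["LAI_STOP_TOKENS"] = '["</s>", "<|im_end|>"]'

stop_tokens = os.getenv("LAI_STOP_TOKENS")
processed_stop_tokens = json.loads(stop_tokens) if stop_tokens else []
print(processed_stop_tokens)  # ['</s>', '<|im_end|>']
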
@@ -113,6 +117,7 @@ def get_backend_configs():


 def get_config_from_request(request: ChatCompletionRequest | CompletionRequest):
+    """Get vLLM generation parameters from the request object"""
     return GenerationConfig(
         max_new_tokens=request.max_new_tokens,
         temperature=request.temperature,
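
Only the first two fields of this mapping are visible before the hunk is cut off. As a rough sketch of the same request-to-config pattern, here is a hypothetical stand-in dataclass; the real GenerationConfig imported by main.py is not shown here and carries more fields:

from dataclasses import dataclass


@dataclass
class GenerationConfigSketch:
    max_new_tokens: int
    temperature: float


def config_from_request(request) -> GenerationConfigSketch:
    # Copy generation parameters straight off the request object
    return GenerationConfigSketch(
        max_new_tokens=request.max_new_tokens,
        temperature=request.temperature,
    )
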
@@ -151,11 +156,11 @@ def __init__(self):
         self.backend_config = get_backend_configs()
         self.model = self.backend_config.model_path
         self.engine_args = AsyncEngineArgs(
-            # Taken from LFAI SDK config pattern
+            # Taken from LFAI SDK environment variables
             model=self.backend_config.model_path,
             max_model_len=self.backend_config.max_context_length,
             max_seq_len_to_capture=self.backend_config.max_context_length,
-            # Taken from Confz
+            # Taken from Confz environment variables
             engine_use_ray=AppConfig().backend_options.engine_use_ray,
             worker_use_ray=AppConfig().backend_options.worker_use_ray,
             gpu_memory_utilization=AppConfig().backend_options.gpu_memory_utilization,
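
As a side note to the truncated constructor above, here is a hedged, standalone sketch that builds similar engine arguments with literal placeholder values instead of the Confz-backed AppConfig. The argument names follow the diff, but the exact set accepted by AsyncEngineArgs varies across vLLM releases (engine_use_ray, for instance, may not exist in newer versions), so treat this as illustrative only:

from vllm import AsyncEngineArgs, AsyncLLMEngine

engine_args = AsyncEngineArgs(
    # Placeholder values standing in for the LFAI SDK / Confz settings
    model="/models/my-model",   # placeholder path
    max_model_len=4096,         # example context length
    max_seq_len_to_capture=4096,
    worker_use_ray=False,
    gpu_memory_utilization=0.90,
)
engine = AsyncLLMEngine.from_engine_args(engine_args)
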
@@ -209,6 +214,7 @@ async def iterate_outputs(self):
                 request_output.outputs[0].token_ids
             )

+            # Add the result to the queue for this request
             await self.delta_queue_by_id[request_id].put(text_delta)
             await asyncio.sleep(0)
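
The new comment documents the per-request delta queue. Sketched in isolation, the pattern looks roughly like the snippet below; the names delta_queue_by_id and text_delta match the diff, while the sentinel-based consumer and the literal deltas are assumptions made so the example runs on its own:

import asyncio

delta_queue_by_id: dict[str, asyncio.Queue] = {}


async def produce_deltas(request_id: str):
    for text_delta in ["Hel", "lo", ", ", "world"]:
        # Add the result to the queue for this request
        await delta_queue_by_id[request_id].put(text_delta)
        await asyncio.sleep(0)
    await delta_queue_by_id[request_id].put(None)  # sentinel: generation finished


async def consume_deltas(request_id: str):
    while True:
        delta = await delta_queue_by_id[request_id].get()
        if delta is None:
            break
        print(delta, end="")


async def main():
    request_id = "example-request"
    delta_queue_by_id[request_id] = asyncio.Queue()
    await asyncio.gather(produce_deltas(request_id), consume_deltas(request_id))


asyncio.run(main())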

@@ -254,6 +260,7 @@ async def generate(
         request_id = random_uuid()
         self.done_by_id[request_id] = False

+        # Create new request in the queue
         await self.generate_session(request_id, prompt, config)

         logging.info(f"Begin reading the output for request {request_id}")