Feat: Output Rails Streaming #966

Merged (22 commits) on Feb 5, 2025
Commits
5712766
feat: add streaming support for output rails
Pouyanpi Jan 27, 2025
b546e7a
test: add unit tests for BufferStrategy in llmrails
Pouyanpi Jan 27, 2025
404a2ec
feat: add example config
Pouyanpi Jan 27, 2025
5f9fad8
test: add self-check output streaming tests
Pouyanpi Jan 27, 2025
bcf429c
feat: add get_action_name_from_flow_id function
Pouyanpi Jan 29, 2025
0b84ce7
feat: update OutputRailsStreamingConfig with new fields
Pouyanpi Jan 31, 2025
88a0fde
refactor: refactor code and also integrate output_mapping decorator
Pouyanpi Jan 31, 2025
0b25a38
refactor(logging): change log level from warning to info
Pouyanpi Jan 31, 2025
f0e66c2
feat: add context message handling and action params
Pouyanpi Feb 2, 2025
987e09d
refactor: move buffer module to rails.llm directory
Pouyanpi Feb 3, 2025
586515f
refactor: extract output mapping logic to separate module
Pouyanpi Feb 3, 2025
96e16fd
fix _get_action_details_from_flow
Pouyanpi Feb 3, 2025
17a9c8f
test: add tests for output mapping functions
Pouyanpi Feb 3, 2025
95f5570
feat: yield tokens instead of chunks in streaming output
Pouyanpi Feb 3, 2025
01805e4
fix: apply review suggestions by mike
Pouyanpi Feb 3, 2025
6374198
refactor: update streaming config and handling
Pouyanpi Feb 4, 2025
edc834d
feat: add state parameter to LLMRails stream method
Pouyanpi Feb 4, 2025
c2fb77f
fix: finalize name changes and default values
Pouyanpi Feb 4, 2025
06ffbc2
fix supported prefixes
Pouyanpi Feb 4, 2025
4346671
fix: remove debug print stmts
Pouyanpi Feb 4, 2025
9a8ae0d
feat: handle ABORT SSE in streaming output
Pouyanpi Feb 5, 2025
56a3560
chore: delete abc_streaming config
Pouyanpi Feb 5, 2025
5 changes: 3 additions & 2 deletions nemoguardrails/actions/actions.py
@@ -37,8 +37,9 @@ def action(
name (Optional[str]): The name to associate with the action.
execute_async: Whether the function should be executed in async mode.
output_mapping (Optional[Callable[[Any], bool]]): A function to interpret the action's result.
It should accept the return value (e.g. the first element of a tuple) and return True if the output
should be considered blocked.
It accepts the return value (e.g. the first element of a tuple) and returns True if the output
is not safe.

Returns:
callable: The decorated function or class.
"""
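For context, a minimal sketch of attaching a custom output mapping through this decorator; the action name and scoring logic below are hypothetical, only the `output_mapping` parameter itself comes from this PR:

```python
from nemoguardrails.actions import action


def _block_low_scores(score: float) -> bool:
    # Return True when the output should be blocked (i.e. it is not safe).
    return score < 0.7


@action(name="hypothetical_safety_check", output_mapping=_block_low_scores)
async def hypothetical_safety_check(bot_message: str) -> float:
    # Dummy scorer for illustration; a real action would call a model or service.
    return 0.1 if "forbidden" in bot_message.lower() else 0.95
```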
12 changes: 11 additions & 1 deletion nemoguardrails/actions/llm/generation.py
@@ -762,6 +762,16 @@ async def generate_bot_message(

streaming_handler = streaming_handler_var.get()

# When output rails streaming is enabled, we must skip the output rails that run on
# $bot_message here, because they are executed separately in llmrails.py.
# Note that this does not work when the message is passed as context in `run_output_rails_in_streaming`.
# streaming_handler is set when the stream_async method is used.

if streaming_handler and len(self.config.rails.output.flows) > 0:
# if streaming_handler and self.config.rails.output.streaming.enabled:
context_updates["skip_output_rails"] = True

if bot_intent in self.config.bot_messages:
# Choose a message randomly from self.config.bot_messages[bot_message]
# However, in test mode, we always choose the first one, to keep it predictable.
@@ -779,7 +789,7 @@
context_updates["skip_output_rails"] = True

# Check if the output is supposed to be the content of a context variable
elif bot_intent[0] == "$" and bot_intent[1:] in context:
elif bot_intent and bot_intent[0] == "$" and bot_intent[1:] in context:
bot_utterance = context[bot_intent[1:]]

else:
51 changes: 51 additions & 0 deletions nemoguardrails/actions/output_mapping.py
@@ -0,0 +1,51 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Tuple


def default_output_mapping(result: Any) -> bool:
"""A fallback output mapping if an action does not provide one.

- For a boolean result: assume True means allowed (so block if False).
- For a numeric result: use 0.5 as a threshold (block if the value is less).
- Otherwise, assume the result is allowed.
"""
if isinstance(result, bool):
return not result
elif isinstance(result, (int, float)):
return result < 0.5
else:
return False


def is_output_blocked(result: Any, action_func: Any) -> bool:
"""Determines if an action result is not allowed using its attached mapping.

Args:
result: The value returned by the action.
action_func: The action function (whose metadata contains the mapping).

Returns:
True if the mapping indicates that the output should be blocked, False otherwise.
"""
mapping = getattr(action_func, "action_meta", {}).get("output_mapping")
if mapping is None:
mapping = default_output_mapping

if not isinstance(result, Tuple):
result = (result,)

return mapping(result[0])
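A short illustration of how the default mapping interprets different result types (module path as added in this PR):

```python
from nemoguardrails.actions.output_mapping import default_output_mapping

# Boolean results: True means "allowed", so it is not blocked.
assert default_output_mapping(True) is False
assert default_output_mapping(False) is True

# Numeric results: values below the 0.5 threshold are blocked.
assert default_output_mapping(0.2) is True
assert default_output_mapping(0.9) is False

# Anything else is treated as allowed.
assert default_output_mapping("some text") is False
```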
31 changes: 22 additions & 9 deletions nemoguardrails/cli/chat.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import json
import os
from dataclasses import dataclass, field
from typing import Dict, List, Optional, cast
@@ -82,17 +83,29 @@ async def _run_chat_v1_0(
if not server_url:
# If we have streaming from a locally loaded config, we initialize the handler.
if streaming and not server_url and rails_app.main_llm_supports_streaming:
streaming_handler = StreamingHandler(enable_print=True)
else:
streaming_handler = None
bot_message_list = []
async for chunk in rails_app.stream_async(messages=history):
if '{"event": "ABORT"' in chunk:
dict_chunk = json.loads(chunk)
console.print(
"\n\n[red]"
+ f"ABORT streaming. {dict_chunk['data']}"
+ "[/]"
)
break

bot_message = await rails_app.generate_async(
messages=history, streaming_handler=streaming_handler
)
console.print("[green]" + f"{chunk}" + "[/]", end="")
bot_message_list.append(chunk)

if not streaming or not rails_app.main_llm_supports_streaming:
# We print bot messages in green.
console.print("[green]" + f"{bot_message['content']}" + "[/]")
bot_message_text = "".join(bot_message_list)
bot_message = {"role": "assistant", "content": bot_message_text}

else:
bot_message = await rails_app.generate_async(messages=history)

if not streaming or not rails_app.main_llm_supports_streaming:
# We print bot messages in green.
console.print("[green]" + f"{bot_message['content']}" + "[/]")
else:
data = {
"config_id": config_id,
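For reference, a hedged sketch of the same consumption pattern outside the CLI, mirroring the ABORT handling above; the helper name is hypothetical and the SSE payload is assumed to carry only the `event` and `data` fields shown in the diff:

```python
import json


async def collect_streamed_response(rails_app, history):
    """Consume stream_async and stop early if an ABORT SSE arrives."""
    chunks = []
    async for chunk in rails_app.stream_async(messages=history):
        if '{"event": "ABORT"' in chunk:
            abort = json.loads(chunk)
            print(f"ABORT streaming. {abort['data']}")
            break
        chunks.append(chunk)
    return {"role": "assistant", "content": "".join(chunks)}
```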
2 changes: 1 addition & 1 deletion nemoguardrails/logging/callbacks.py
@@ -246,7 +246,7 @@ async def on_llm_end(
llm_call_info.completion_tokens = token_usage.get("completion_tokens", 0)

if not token_stats_found:
log.warning(
log.info(
"Token stats in LLM call info cannot be computed for current model!"
)

108 changes: 108 additions & 0 deletions nemoguardrails/rails/llm/buffer.py
@@ -0,0 +1,108 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from typing import AsyncGenerator, List, Tuple

from nemoguardrails.rails.llm.config import OutputRailsStreamingConfig


class BufferStrategy(ABC):
@classmethod
@abstractmethod
def from_config(cls, config: OutputRailsStreamingConfig) -> "BufferStrategy":
pass

# The abstract method is not async to ensure the return type
# matches the async generator in the concrete implementation.
@abstractmethod
def __call__(
self, streaming_handler
) -> AsyncGenerator[Tuple[List[str], str], None]:
pass

@abstractmethod
def generate_chunk_str(self, *args, **kwargs) -> str:
pass


class RollingBuffer(BufferStrategy):
"""A minimal buffer strategy that buffers chunks and yields them when the buffer is full.

Args:
buffer_context_size (int): The number of tokens carried over from the previous chunk to provide context for continuity in processing.
buffer_chunk_size (int): The number of tokens in each processing chunk. This is the size of the token block on which output rails are applied.
"""

def __init__(self, buffer_context_size: int = 5, buffer_chunk_size: int = 10):
self.buffer_context_size = buffer_context_size
self.buffer_chunk_size = buffer_chunk_size
self.last_index = 0

@classmethod
def from_config(cls, config: OutputRailsStreamingConfig):
return cls(
buffer_context_size=config.context_size, buffer_chunk_size=config.chunk_size
)

async def __call__(
self, streaming_handler
) -> AsyncGenerator[Tuple[List[str], str], None]:
buffer = []
index = 0

async for chunk in streaming_handler:
buffer.append(chunk)
index += 1

if len(buffer) >= self.buffer_chunk_size:
yield (
# we apply output rails on the buffer
buffer[-self.buffer_chunk_size - self.buffer_context_size :],
# generate_chunk_str is what gets printed in the console or yielded to the user,
# to avoid repeating the already streamed/printed chunk
self.generate_chunk_str(
buffer[-self.buffer_chunk_size - self.buffer_context_size :],
index,
),
)
buffer = buffer[-self.buffer_context_size :]

# Yield any remaining buffer if it's not empty
if buffer:
yield (
buffer,
self.generate_chunk_str(
buffer[-self.buffer_chunk_size - self.buffer_context_size :], index
),
)

def generate_chunk_str(self, buffer, current_index) -> str:
if current_index <= self.last_index:
return ""

new_chunks = buffer[self.last_index - current_index :]
self.last_index = current_index
# TODO: something causes duplicate whitespace between tokens; figure out why.
# If `return "".join(new_chunks)` works, the issue is likely elsewhere, in how the
# chunks are generated or processed. Ensure the chunks themselves do not contain extra spaces.
# Workaround: return "".join(new_chunks)
return "".join(new_chunks)


def get_buffer_strategy(config: OutputRailsStreamingConfig) -> BufferStrategy:
# TODO: use a factory function or class
# currently we only have RollingBuffer, in future we use a registry
return RollingBuffer.from_config(config)
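A self-contained sketch of how the rolling buffer behaves, using a fake async generator in place of the real streaming handler:

```python
import asyncio

from nemoguardrails.rails.llm.buffer import RollingBuffer


async def fake_token_stream():
    # Stand-in for the real streaming handler; yields one token at a time.
    for token in ["A ", "rolling ", "buffer ", "keeps ", "a ", "few ", "context ", "tokens."]:
        yield token


async def main():
    buffer = RollingBuffer(buffer_context_size=2, buffer_chunk_size=4)
    async for rails_window, user_chunk in buffer(fake_token_stream()):
        # rails_window is the token block the output rails are applied to (it
        # includes the carried-over context tokens); user_chunk is only the new
        # text that gets printed or yielded to the user.
        print(rails_window, "->", repr(user_chunk))


asyncio.run(main())
```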
41 changes: 35 additions & 6 deletions nemoguardrails/rails/llm/config.py
@@ -304,6 +304,27 @@ class InputRails(BaseModel):
)


class OutputRailsStreamingConfig(BaseModel):
"""Configuration for managing streaming output of LLM tokens."""

enabled: bool = Field(
default=False, description="Enables streaming mode when True."
)
chunk_size: int = Field(
default=200,
description="The number of tokens in each processing chunk. This is the size of the token block on which output rails are applied.",
)
context_size: int = Field(
default=50,
description="The number of tokens carried over from the previous chunk to provide context for continuity in processing.",
)
stream_first: bool = Field(
default=True,
description="If True, token chunks are streamed immediately before output rails are applied.",
)
model_config = ConfigDict(extra="allow")


class OutputRails(BaseModel):
"""Configuration of output rails."""

@@ -312,6 +333,11 @@ class OutputRails(BaseModel):
description="The names of all the flows that implement output rails.",
)

streaming: Optional[OutputRailsStreamingConfig] = Field(
default_factory=OutputRailsStreamingConfig,
description="Configuration for streaming output rails.",
)


class RetrievalRails(BaseModel):
"""Configuration of retrieval rails."""
@@ -1201,12 +1227,15 @@ def parse_object(cls, obj):

@property
def streaming_supported(self):
"""Whether the current config supports streaming or not.

Currently, we don't support streaming if there are output rails.
"""
if len(self.rails.output.flows) > 0:
return False
"""Whether the current config supports streaming or not."""

# if len(self.rails.output.flows) > 0:
# # if output rails streaming is enabled;
# # kept in case per-rail streaming support is needed later
# if self.rails.output.streaming.enabled:
# return True
# return False

return True

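A minimal sketch of the new streaming settings instantiated directly from the model added above; in a deployed config these values would typically live under `rails.output.streaming` in config.yml:

```python
from nemoguardrails.rails.llm.config import OutputRailsStreamingConfig

streaming_config = OutputRailsStreamingConfig(
    enabled=True,       # turn output rails streaming on (defaults to False)
    chunk_size=200,     # tokens per block on which output rails are applied
    context_size=50,    # tokens carried over between consecutive blocks
    stream_first=True,  # stream chunks to the user before the rails verdict
)
print(streaming_config.model_dump())
```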