@@ -132,6 +146,5 @@ export const RemovableTag: FC = ({
if (Wrapper) {
return <Wrapper>{nakedTag}</Wrapper>;
}
-
return nakedTag;
};
diff --git a/weave/integrations/langchain_nvidia_ai_endpoints/__init__.py b/weave/integrations/langchain_nvidia_ai_endpoints/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/weave/integrations/langchain_nvidia_ai_endpoints/langchain_nv_ai_endpoints.py b/weave/integrations/langchain_nvidia_ai_endpoints/langchain_nv_ai_endpoints.py
new file mode 100644
index 000000000000..0d376db51d91
--- /dev/null
+++ b/weave/integrations/langchain_nvidia_ai_endpoints/langchain_nv_ai_endpoints.py
@@ -0,0 +1,221 @@
+from __future__ import annotations
+
+import importlib
+import time
+from typing import Any, Callable
+
+import_failed = False
+
+try:
+ from langchain_core.messages import AIMessageChunk, convert_to_openai_messages
+ from langchain_core.outputs import ChatGenerationChunk, ChatResult
+except ImportError:
+ import_failed = True
+
+import weave
+from weave.trace.autopatch import IntegrationSettings, OpSettings
+from weave.trace.op import Op, ProcessedInputs
+from weave.trace.op_extensions.accumulator import add_accumulator
+from weave.trace.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher
+
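+# Module-level cache so repeated calls to get_nvidia_ai_patcher() reuse one patcher.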
+_lc_nvidia_patcher: MultiPatcher | None = None
+
+
+# NVIDIA-specific accumulator that merges streamed chunks into a single ChatGenerationChunk
+def nvidia_accumulator(acc: Any | None, value: Any) -> Any:
+ if acc is None:
+ acc = ChatGenerationChunk(message=AIMessageChunk(content=""))
+ acc = acc + value
+
+    # Overwrite usage_metadata on every chunk because the __add__ impl for the
+    # streaming response does not merge it correctly; the final chunk carries the
+    # full usage, so the accumulated value is eventually consistent.
+ acc.message.usage_metadata = value.message.usage_metadata
+
+ return acc
+
+
+# Post-processor that transforms the output into OpenAI's ChatCompletion format;
+# handles both streaming and non-streaming outputs.
+def postprocess_output_to_openai_format(output: Any) -> Any:
+    """
+    Convert the output reported to Weave into OpenAI's ChatCompletion format so that
+    the Weave frontend renders the chat view. This only affects what is sent to Weave.
+    """
+ from openai.types.chat import ChatCompletion
+
+    if isinstance(output, ChatResult):  # non-streaming result
+        message = output.llm_output or {}
+        enhanced_usage = message.get("token_usage", {})
+        enhanced_usage["output_tokens"] = enhanced_usage.get("completion_tokens", 0)
+        enhanced_usage["input_tokens"] = enhanced_usage.get("prompt_tokens", 0)
+
+ returnable = ChatCompletion(
+ id="None",
+ choices=[
+ {
+ "index": 0,
+ "message": {
+ "content": message.get("content", ""),
+ "role": message.get("role", ""),
+ "tool_calls": message.get("tool_calls", []),
+ },
+ "logprobs": None,
+ "finish_reason": message.get("finish_reason", ""),
+ }
+ ],
+ created=int(time.time()),
+ model=message.get("model_name", ""),
+ object="chat.completion",
+ tool_calls=message.get("tool_calls", []),
+ system_fingerprint=None,
+ usage=enhanced_usage,
+ )
+
+ return returnable.model_dump(exclude_unset=True, exclude_none=True)
+
+    elif isinstance(output, ChatGenerationChunk):  # streaming result
+        orig_message = output.message
+        openai_message = convert_to_openai_messages(output.message)
+        enhanced_usage = getattr(orig_message, "usage_metadata", None) or {}
+        enhanced_usage["completion_tokens"] = enhanced_usage.get("output_tokens", 0)
+        enhanced_usage["prompt_tokens"] = enhanced_usage.get("input_tokens", 0)
+
+ returnable = ChatCompletion(
+ id="None",
+ choices=[
+ {
+ "index": 0,
+ "message": {
+ "content": orig_message.content,
+ "role": getattr(orig_message, "role", "assistant"),
+ "tool_calls": openai_message.get("tool_calls", []),
+ },
+ "logprobs": None,
+ "finish_reason": getattr(orig_message, "response_metadata", {}).get(
+ "finish_reason", None
+ ),
+ }
+ ],
+ created=int(time.time()),
+ model=getattr(orig_message, "response_metadata", {}).get(
+ "model_name", None
+ ),
+ tool_calls=openai_message.get("tool_calls", []),
+ object="chat.completion",
+ system_fingerprint=None,
+ usage=enhanced_usage,
+ )
+
+ return returnable.model_dump(exclude_unset=True, exclude_none=True)
+ return output
+
+
+def postprocess_inputs_to_openai_format(
+ func: Op, args: tuple, kwargs: dict
+) -> ProcessedInputs:
+ """
+    Convert the inputs reported to Weave into OpenAI's chat-completion format so that
+    the Weave frontend renders the chat view. This only affects what is sent to Weave.
+ """
+ original_args = args
+ original_kwargs = kwargs
+
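+    # _generate/_stream are instance methods, so args[0] is the ChatNVIDIA instance
+    # and args[1] is the list of LangChain messages.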
+ chat_nvidia_obj = args[0]
+ messages_array = args[1]
+ messages_array = convert_to_openai_messages(messages_array)
+ n = len(messages_array)
+
+    # The streaming op's name contains "stream" (e.g. ChatNVIDIA._stream).
+    stream = "stream" in func.name
+
+ weave_report = {
+ "model": chat_nvidia_obj.model,
+ "messages": messages_array,
+ "max_tokens": chat_nvidia_obj.max_tokens,
+ "temperature": chat_nvidia_obj.temperature,
+ "top_p": chat_nvidia_obj.top_p,
+ "object": "ChatNVIDIA._generate",
+ "n": n,
+ "stream": stream,
+ }
+
+ return ProcessedInputs(
+ original_args=original_args,
+ original_kwargs=original_kwargs,
+ args=original_args,
+ kwargs=original_kwargs,
+ inputs=weave_report,
+ )
+
+
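+# Accumulate only when the processed inputs mark the call as streaming
+# (set by postprocess_inputs_to_openai_format above).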
+def should_use_accumulator(inputs: dict) -> bool:
+ return isinstance(inputs, dict) and bool(inputs.get("stream"))
+
+
+def nvidia_ai_endpoints_wrapper(settings: OpSettings) -> Callable[[Callable], Callable]:
+ def wrapper(fn: Callable) -> Callable:
+ op_kwargs = settings.model_dump()
+ op = weave.op(fn, **op_kwargs)
+ op._set_on_input_handler(postprocess_inputs_to_openai_format)
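+        # Merge streamed chunks with nvidia_accumulator and convert the final value
+        # to OpenAI format before it is logged to Weave.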
+ return add_accumulator(
+ op,
+ make_accumulator=lambda inputs: nvidia_accumulator,
+ should_accumulate=should_use_accumulator,
+ on_finish_post_processor=postprocess_output_to_openai_format,
+ )
+
+ return wrapper
+
+
+def get_nvidia_ai_patcher(
+ settings: IntegrationSettings | None = None,
+) -> MultiPatcher | NoOpPatcher:
+ if settings is None:
+ settings = IntegrationSettings()
+
+ if not settings.enabled:
+ return NoOpPatcher()
+
+ global _lc_nvidia_patcher
+ if _lc_nvidia_patcher is not None:
+ return _lc_nvidia_patcher
+
+ base = settings.op_settings
+
+ generate_settings: OpSettings = base.model_copy(
+ update={
+ "name": base.name or "langchain_nvidia_ai_endpoints.ChatNVIDIA._generate",
+ }
+ )
+ stream_settings: OpSettings = base.model_copy(
+ update={
+ "name": base.name or "langchain_nvidia_ai_endpoints.ChatNVIDIA._stream",
+ }
+ )
+
+ _lc_nvidia_patcher = MultiPatcher(
+ [
+            # Patch the non-streaming generate path (ChatNVIDIA._generate)
+ SymbolPatcher(
+ lambda: importlib.import_module("langchain_nvidia_ai_endpoints"),
+ "ChatNVIDIA._generate",
+ nvidia_ai_endpoints_wrapper(generate_settings),
+ ),
+ # Patch stream method
+ SymbolPatcher(
+ lambda: importlib.import_module("langchain_nvidia_ai_endpoints"),
+ "ChatNVIDIA._stream",
+ nvidia_ai_endpoints_wrapper(stream_settings),
+ ),
+ ]
+ )
+
+ return _lc_nvidia_patcher
diff --git a/weave/trace/autopatch.py b/weave/trace/autopatch.py
index c1c47d375127..bc77752957c2 100644
--- a/weave/trace/autopatch.py
+++ b/weave/trace/autopatch.py
@@ -46,6 +46,7 @@ class AutopatchSettings(BaseModel):
notdiamond: IntegrationSettings = Field(default_factory=IntegrationSettings)
openai: IntegrationSettings = Field(default_factory=IntegrationSettings)
vertexai: IntegrationSettings = Field(default_factory=IntegrationSettings)
+ chatnvidia: IntegrationSettings = Field(default_factory=IntegrationSettings)
@validate_call
@@ -60,6 +61,9 @@ def autopatch(settings: Optional[AutopatchSettings] = None) -> None:
from weave.integrations.groq.groq_sdk import get_groq_patcher
from weave.integrations.instructor.instructor_sdk import get_instructor_patcher
from weave.integrations.langchain.langchain import langchain_patcher
+ from weave.integrations.langchain_nvidia_ai_endpoints.langchain_nv_ai_endpoints import (
+ get_nvidia_ai_patcher,
+ )
from weave.integrations.litellm.litellm import get_litellm_patcher
from weave.integrations.llamaindex.llamaindex import llamaindex_patcher
from weave.integrations.mistral import get_mistral_patcher
@@ -82,6 +86,7 @@ def autopatch(settings: Optional[AutopatchSettings] = None) -> None:
get_google_genai_patcher(settings.google_ai_studio).attempt_patch()
get_notdiamond_patcher(settings.notdiamond).attempt_patch()
get_vertexai_patcher(settings.vertexai).attempt_patch()
+ get_nvidia_ai_patcher(settings.chatnvidia).attempt_patch()
llamaindex_patcher.attempt_patch()
langchain_patcher.attempt_patch()
@@ -98,6 +103,9 @@ def reset_autopatch() -> None:
from weave.integrations.groq.groq_sdk import get_groq_patcher
from weave.integrations.instructor.instructor_sdk import get_instructor_patcher
from weave.integrations.langchain.langchain import langchain_patcher
+ from weave.integrations.langchain_nvidia_ai_endpoints.langchain_nv_ai_endpoints import (
+ get_nvidia_ai_patcher,
+ )
from weave.integrations.litellm.litellm import get_litellm_patcher
from weave.integrations.llamaindex.llamaindex import llamaindex_patcher
from weave.integrations.mistral import get_mistral_patcher
@@ -117,6 +125,7 @@ def reset_autopatch() -> None:
get_google_genai_patcher().undo_patch()
get_notdiamond_patcher().undo_patch()
get_vertexai_patcher().undo_patch()
+ get_nvidia_ai_patcher().undo_patch()
llamaindex_patcher.undo_patch()
langchain_patcher.undo_patch()
diff --git a/weave/trace/sanitize.py b/weave/trace/sanitize.py
index 5125dbfb09dc..24e1fea5dd95 100644
--- a/weave/trace/sanitize.py
+++ b/weave/trace/sanitize.py
@@ -1,6 +1,11 @@
+# Always use lowercase entries in REDACT_KEYS; matching is case-insensitive (see should_redact below).
REDACT_KEYS = (
"api_key",
"auth_headers",
- "Authorization",
+ "authorization",
)
REDACTED_VALUE = "REDACTED"
+
+
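+# Case-insensitive membership check; serialize.py and weave_client.py call this
+# instead of comparing keys against REDACT_KEYS directly.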
+def should_redact(key: str) -> bool:
+ return key.lower() in REDACT_KEYS
diff --git a/weave/trace/serialize.py b/weave/trace/serialize.py
index 5e0c6006aa50..ae8afa0e12e8 100644
--- a/weave/trace/serialize.py
+++ b/weave/trace/serialize.py
@@ -8,7 +8,7 @@
from weave.trace import custom_objs
from weave.trace.object_record import ObjectRecord
from weave.trace.refs import ObjectRef, TableRef, parse_uri
-from weave.trace.sanitize import REDACT_KEYS, REDACTED_VALUE
+from weave.trace.sanitize import REDACTED_VALUE, should_redact
from weave.trace_server.interface.builtin_object_classes.builtin_object_registry import (
BUILTIN_OBJECT_REGISTRY,
)
@@ -148,7 +148,7 @@ def dictify(
elif isinstance(obj, dict):
dict_result = {}
for k, v in obj.items():
- if k in REDACT_KEYS:
+ if should_redact(k):
dict_result[k] = REDACTED_VALUE
else:
dict_result[k] = dictify(v, maxdepth, depth + 1, seen)
@@ -160,7 +160,7 @@ def dictify(
if isinstance(as_dict, dict):
to_dict_result = {}
for k, v in as_dict.items():
- if k in REDACT_KEYS:
+ if should_redact(k):
to_dict_result[k] = REDACTED_VALUE
elif maxdepth == 0 or depth < maxdepth:
to_dict_result[k] = dictify(v, maxdepth, depth + 1)
@@ -187,7 +187,7 @@ def dictify(
for attr in dir(obj):
if attr.startswith("_"):
continue
- if attr in REDACT_KEYS:
+ if should_redact(attr):
result[attr] = REDACTED_VALUE
continue
try:
diff --git a/weave/trace/weave_client.py b/weave/trace/weave_client.py
index 1d5d54b9b23c..87a445383b57 100644
--- a/weave/trace/weave_client.py
+++ b/weave/trace/weave_client.py
@@ -38,7 +38,7 @@
parse_op_uri,
parse_uri,
)
-from weave.trace.sanitize import REDACT_KEYS, REDACTED_VALUE
+from weave.trace.sanitize import REDACTED_VALUE, should_redact
from weave.trace.serialize import from_json, isinstance_namedtuple, to_json
from weave.trace.serializer import get_serializer_for_obj
from weave.trace.settings import client_parallelism
@@ -1648,7 +1648,7 @@ def redact_sensitive_keys(obj: Any) -> Any:
if isinstance(obj, dict):
dict_res = {}
for k, v in obj.items():
- if k in REDACT_KEYS:
+ if should_redact(k):
dict_res[k] = REDACTED_VALUE
else:
dict_res[k] = redact_sensitive_keys(v)
diff --git a/weave/trace_server/async_batch_processor.py b/weave/trace_server/async_batch_processor.py
index 03071607aee0..a8a183d94bfe 100644
--- a/weave/trace_server/async_batch_processor.py
+++ b/weave/trace_server/async_batch_processor.py
@@ -6,7 +6,6 @@
from typing import Callable, Generic, TypeVar
from weave.trace.context.tests_context import get_raise_on_captured_errors
-from weave.trace_server import requests
T = TypeVar("T")
logger = logging.getLogger(__name__)
@@ -61,14 +60,10 @@ def _process_batches(self) -> None:
if current_batch:
try:
self.processor_fn(current_batch)
- except requests.HTTPError as e:
- if e.response.status_code == 413:
- # 413: payload too large, don't raise just log
- if get_raise_on_captured_errors():
- raise
- logger.exception(f"Error processing batch: {e}")
- else:
- raise e
+ except Exception as e:
+ if get_raise_on_captured_errors():
+ raise
+ logger.exception(f"Error processing batch: {e}")
if self.stop_event.is_set() and self.queue.empty():
break