LLM interface & vertexai: add response_schema support, add location parameter and fix some bugs (#3268)

* Add response_schema to vertexai and fix multiple bugs
* Update timesketch.conf
itsmvd authored Jan 24, 2025
1 parent b542cee commit d1690f5
Showing 3 changed files with 45 additions and 13 deletions.
7 changes: 4 additions & 3 deletions data/timesketch.conf
@@ -362,7 +362,9 @@ LLM_PROVIDER_CONFIGS = {
# To use the Vertex AI provider you need to:
# 1. Create and export a Service Account Key from the Google Cloud Console.
# 2. Set the GOOGLE_APPLICATION_CREDENTIALS environment variable to the full path
# to your service account private key file.
# to your service account private key file by adding it to the docker-compose.yml
# under environment:
# GOOGLE_APPLICATION_CREDENTIALS=/usr/local/src/timesketch/<key_file>.json
# 3. Install the python libraries: $ pip3 install google-cloud-aiplatform
#
# IMPORTANT: Private keys must be kept secret. If you expose your private key it is
@@ -372,7 +374,7 @@ LLM_PROVIDER_CONFIGS = {
'project_id': '',
},
# To use Google's AI Studio simply obtain an API key from https://aistudio.google.com/
# pip install google-generativeai
# pip3 install google-generativeai
'aistudio': {
'api_key': '',
'model': 'gemini-2.0-flash-exp',
@@ -384,4 +386,3 @@ DATA_TYPES_PATH = '/etc/timesketch/nl2q/data_types.csv'
PROMPT_NL2Q = '/etc/timesketch/nl2q/prompt_nl2q'
EXAMPLES_NL2Q = '/etc/timesketch/nl2q/examples_nl2q'
LLM_PROVIDER = ''
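
For orientation, a filled-in 'vertexai' entry in LLM_PROVIDER_CONFIGS might look like the sketch below. Only 'project_id' is visible in the hunk above; the model name, the example project ID, and the optional 'location' key (which would feed the new location parameter added in interface.py) are illustrative assumptions, not values taken from the commit.

# Hypothetical example; other provider entries (e.g. 'aistudio') stay alongside it.
LLM_PROVIDER_CONFIGS = {
    'vertexai': {
        'model': 'gemini-1.5-flash-001',   # assumed model identifier
        'project_id': 'my-gcp-project',    # your Google Cloud project ID
        'location': 'us-central1',         # assumed optional region for aiplatform.init()
    },
}
# LLM_PROVIDER would presumably be set to 'vertexai' to select this provider.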

6 changes: 5 additions & 1 deletion timesketch/lib/llms/interface.py
@@ -20,9 +20,10 @@

DEFAULT_TEMPERATURE = 0.1
DEFAULT_TOP_P = 0.1
DEFAULT_TOP_K = 0
DEFAULT_TOP_K = 1
DEFAULT_MAX_OUTPUT_TOKENS = 2048
DEFAULT_STREAM = False
DEFAULT_LOCATION = None


class LLMProvider:
@@ -37,6 +38,7 @@ def __init__(
top_k: int = DEFAULT_TOP_K,
max_output_tokens: int = DEFAULT_MAX_OUTPUT_TOKENS,
stream: bool = DEFAULT_STREAM,
location: Optional[str] = DEFAULT_LOCATION,
):
"""Initialize the LLM provider.
@@ -46,6 +48,7 @@ def __init__(
top_k: The top_k to use for the response.
max_output_tokens: The maximum number of output tokens to generate.
stream: Whether to stream the response.
location: The cloud location/region to use for the provider.
Attributes:
config: The configuration for the LLM provider.
@@ -59,6 +62,7 @@ def __init__(
config["top_k"] = top_k
config["max_output_tokens"] = max_output_tokens
config["stream"] = stream
config["location"] = location

# Load the LLM provider config from the Flask app config
config_from_flask = current_app.config.get("LLM_PROVIDER_CONFIGS").get(
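
As a rough illustration of how the constructor keywords travel, the sketch below defines a hypothetical provider that reads the merged configuration. The EchoProvider class, its NAME, and its behaviour are invented for this example; only the LLMProvider base class, the self.config keys set above, and the new location parameter come from the diff, and the Flask-side override logic is truncated in the hunk, so treat this as an assumption about how a subclass consumes it.

from typing import Optional

from timesketch.lib.llms import interface


class EchoProvider(interface.LLMProvider):
    """Hypothetical provider used only to show how self.config is consumed."""

    NAME = "echo"  # would need a matching entry in LLM_PROVIDER_CONFIGS

    def generate(self, prompt: str, response_schema: Optional[dict] = None) -> str:
        # Keyword arguments passed to LLMProvider.__init__ (temperature, top_p,
        # top_k, max_output_tokens, stream and the new location) end up in
        # self.config, possibly overridden by the Flask LLM_PROVIDER_CONFIGS.
        region = self.config.get("location")  # None unless configured
        return f"[{region or 'no-region'}] {prompt}"

Instantiating it (e.g. EchoProvider(location='europe-west4')) would still require a Flask application context, since __init__ reads LLM_PROVIDER_CONFIGS from current_app.
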
45 changes: 36 additions & 9 deletions timesketch/lib/llms/vertexai.py
@@ -1,4 +1,4 @@
# Copyright 2024 Google Inc. All rights reserved.
# Copyright 2025 Google Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,6 +13,9 @@
# limitations under the License.
"""Vertex AI LLM provider."""

import json
from typing import Optional

from timesketch.lib.llms import interface
from timesketch.lib.llms import manager

@@ -21,6 +24,7 @@
try:
from google.cloud import aiplatform
from vertexai.preview.generative_models import GenerativeModel
from vertexai.preview.generative_models import GenerationConfig
except ImportError:
has_required_deps = False

@@ -30,31 +34,54 @@ class VertexAI(interface.LLMProvider):

NAME = "vertexai"

def generate(self, prompt: str) -> str:
def generate(self, prompt: str, response_schema: Optional[dict] = None) -> str:
"""
Generate text using the Vertex AI service.
Args:
prompt: The prompt to use for the generation.
temperature: The temperature to use for the generation.
stream: Whether to stream the generation or not.
response_schema: An optional JSON schema to define the expected
response format.
Returns:
The generated text as a string.
The generated text as a string (or parsed data if
response_schema is provided).
"""
aiplatform.init(
project=self.config.get("project_id"),
location=self.config.get("location"),
)
model = GenerativeModel(self.config.get("model"))

if response_schema:
generation_config = GenerationConfig(
temperature=self.config.get("temperature"),
top_k=self.config.get("top_k"),
top_p=self.config.get("top_p"),
response_mime_type="application/json",
response_schema=response_schema,
)
else:
generation_config = GenerationConfig(
temperature=self.config.get("temperature"),
top_k=self.config.get("top_k"),
top_p=self.config.get("top_p"),
)

response = model.generate_content(
prompt,
generation_config={
"max_output_tokens": self.config.get("max_output_tokens"),
"temperature": self.config.get("temperature"),
},
generation_config=generation_config,
stream=self.config.get("stream"),
)

if response_schema:
try:
return json.loads(response.text)
except Exception as error:
raise ValueError(
f"Error JSON parsing text: {response.text}: {error}"
) from error

return response.text
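
A hedged usage sketch of the new parameter follows: the schema, prompt, and region below are invented, and the call assumes an active Flask application context (so LLM_PROVIDER_CONFIGS can be read) plus GOOGLE_APPLICATION_CREDENTIALS pointing at a valid service-account key. It shows the behaviour this hunk introduces: with a response_schema the provider requests application/json output from Gemini and returns the json.loads-parsed result; without one it returns plain text.

from timesketch.lib.llms.vertexai import VertexAI

# Illustrative JSON schema: ask the model to return a list of findings.
findings_schema = {
    "type": "object",
    "properties": {
        "findings": {
            "type": "array",
            "items": {"type": "string"},
        },
    },
    "required": ["findings"],
}

# The location keyword is passed through the LLMProvider constructor shown
# in interface.py and ends up in aiplatform.init() above.
provider = VertexAI(location="us-central1")

structured = provider.generate(
    "List any suspicious logins mentioned in this timeline summary: ...",
    response_schema=findings_schema,
)
# With a schema, generate() returns parsed data (a dict here), not a string.
print(structured["findings"])

summary = provider.generate("Summarize the sketch in two sentences.")
print(summary)  # plain response text when no schema is supplied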

