Skip to content

Commit

Permalink
Update dependencies (#8)
Browse files Browse the repository at this point in the history
* Updated outdated packages (major versions)

* Updated package version

* Removed saved Jupyter notebook results

* Made Anthropic utilities work in tests when no API key is present
  • Loading branch information
CodexVeritas authored Dec 31, 2024
1 parent c3678ed commit a606f8a
Show file tree
Hide file tree
Showing 6 changed files with 586 additions and 1,889 deletions.
2 changes: 2 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,6 @@
"--profile",
"black"
],
"jupyter.debugJustMyCode": true,
"debugpy.debugJustMyCode": true,
}
12 changes: 12 additions & 0 deletions code_tests/unit_tests/test_ai_models/test_time_limited_models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import logging
import os
from unittest.mock import Mock

import pytest
Expand All @@ -12,6 +13,9 @@
from forecasting_tools.ai_models.basic_model_interfaces.time_limited_model import (
TimeLimitedModel,
)
from forecasting_tools.ai_models.model_archetypes.anthropic_text_model import (
AnthropicTextToTextModel,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -47,6 +51,14 @@ def test_ai_model_does_not_time_out_when_run_time_less_than_timeout_time(
if not issubclass(subclass, TimeLimitedModel):
raise ValueError(TIME_LIMITED_ERROR_MESSAGE)

if (
issubclass(subclass, AnthropicTextToTextModel)
and os.getenv("ANTHROPIC_API_KEY") is None
):
pytest.skip(
"Skipping test for AnthropicTextModel since API key is not set and is needed for token counting"
)

AiModelMockManager.mock_ai_model_direct_call_with_predefined_mock_value(
mocker, subclass
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import os
from abc import ABC

from langchain_anthropic import ChatAnthropic
Expand All @@ -7,6 +8,7 @@
_get_anthropic_claude_token_cost,
)
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from pydantic import SecretStr

from forecasting_tools.ai_models.ai_utils.response_types import (
TextTokenCostResponse,
Expand All @@ -19,6 +21,12 @@


class AnthropicTextToTextModel(TraditionalOnlineLlm, ABC):
API_KEY_MISSING = True if os.getenv("ANTHROPIC_API_KEY") is None else False
ANTHROPIC_API_KEY = SecretStr(
os.getenv("ANTHROPIC_API_KEY") # type: ignore
if not API_KEY_MISSING
else "fake-api-key-so-tests-dont-fail-to-initialize"
)

async def invoke(self, prompt: str) -> str:
response: TextTokenCostResponse = (
Expand Down Expand Up @@ -46,6 +54,7 @@ async def _call_online_model_using_api(
timeout=None,
stop=None,
base_url=None,
api_key=self.ANTHROPIC_API_KEY,
)
messages = self._turn_model_input_into_messages(prompt)
answer_message = await anthropic_llm.ainvoke(messages)
Expand Down Expand Up @@ -91,16 +100,23 @@ def _get_mock_return_for_direct_call_to_model_using_cheap_input(
probable_output = "Hello! How can I assist you today? Feel free to ask any questions or let me know if you need help with anything."

model = cls()
prompt_tokens = model.input_to_tokens(cheap_input)
prompt_tokens = (
model.input_to_tokens(cheap_input)
if not cls.API_KEY_MISSING
else 13
)
anthropic_llm = ChatAnthropic(
model_name=model.MODEL_NAME,
timeout=None,
stop=None,
base_url=None,
api_key=cls.ANTHROPIC_API_KEY,
)
completion_tokens = (
anthropic_llm.get_num_tokens(probable_output)
if not cls.API_KEY_MISSING
else 26
)
completion_tokens = anthropic_llm.get_num_tokens(probable_output)
adjustment = 9 # Through manual experimentation, it was found that the number of tokens returned by the API is 9 off for the completion response
completion_tokens += adjustment
total_cost = model.calculate_cost_from_tokens(
prompt_tkns=prompt_tokens, completion_tkns=completion_tokens
)
Expand All @@ -126,16 +142,10 @@ def input_to_tokens(self, prompt: str) -> int:
timeout=None,
stop=None,
base_url=None,
api_key=self.ANTHROPIC_API_KEY,
)
messages = self._turn_model_input_into_messages(prompt)
tokens = llm.get_num_tokens_from_messages(messages)
adjustment = 0
for message in messages:
if isinstance(message, HumanMessage):
adjustment += 5 # Through manual experimentation, it was found that the number of tokens returned by the API is 5 off for user messages
if isinstance(message, SystemMessage):
adjustment -= 2 # Through manual experimentation, it was found that the number of tokens returned by the API is 2 off for system messages
tokens += adjustment
return tokens

def calculate_cost_from_tokens(
Expand Down
Loading

0 comments on commit a606f8a

Please sign in to comment.