From 7955439bbc4f7c9117117d2360bbbca39d027c97 Mon Sep 17 00:00:00 2001
From: Olivier LEVILLAIN
Date: Tue, 10 Sep 2024 19:16:32 +0200
Subject: [PATCH 1/2] Implement a client for Grok

Fixes #3
NOT TESTED YET, CANNOT GET GROK API KEY
---
 .env.example                              |  2 ++
 pyproject.toml                            |  1 +
 src/auto_po_lyglot/clients/grok_client.py | 21 +++++++++++++++++++++
 src/auto_po_lyglot/getenv.py              |  4 +++-
 4 files changed, 27 insertions(+), 1 deletion(-)
 create mode 100644 src/auto_po_lyglot/clients/grok_client.py

diff --git a/.env.example b/.env.example
index 4099bfc..aac94d6 100644
--- a/.env.example
+++ b/.env.example
@@ -30,6 +30,8 @@ CONTEXT_LANGUAGE=French
 OPENAI_API_KEY=[your API key goes here]
 # for Claude models, set:
 ANTHROPIC_API_KEY=[your API key goes here]
+# for Grok models, set:
+XAI_API_KEY=[your API key goes here]
 # for Gemini models, set:
 GEMINI_API_KEY=[your API key goes here]
 
diff --git a/pyproject.toml b/pyproject.toml
index 11c900d..099156b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -18,6 +18,7 @@ dependencies=[
   "langcodes>=3.4.0",
   "openai>=1.12.0",
   "anthropic>=0.34.1",
+  "xai-sdk>=0.3.0",
   "google-generativeai>=0.7.2",
 ]
 classifiers = [
diff --git a/src/auto_po_lyglot/clients/grok_client.py b/src/auto_po_lyglot/clients/grok_client.py
new file mode 100644
index 0000000..5cea623
--- /dev/null
+++ b/src/auto_po_lyglot/clients/grok_client.py
@@ -0,0 +1,21 @@
+import xai_sdk
+import asyncio
+from .client_base import TranspoClient
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class GrokClient(TranspoClient):
+  def __init__(self, params, target_language=None):
+    params.model = params.model or ""  # default model given by Grok itself if not provided
+    super().__init__(params, target_language)
+    self.client = xai_sdk.Client(api_key=params.xai_api_key) if hasattr(params, 'xai_api_key') else xai_sdk.Client()
+
+  async def async_get_translation(self, system_prompt, user_prompt):
+    conversation = self.client.chat.create_conversation()
+    response = await conversation.add_response_no_stream(f'{system_prompt}\n{user_prompt}\n')
+    return response.message
+
+  def get_translation(self, system_prompt, user_prompt):
+    return asyncio.run(self.async_get_translation(system_prompt, user_prompt))
diff --git a/src/auto_po_lyglot/getenv.py b/src/auto_po_lyglot/getenv.py
index 8e3a137..2bedfef 100755
--- a/src/auto_po_lyglot/getenv.py
+++ b/src/auto_po_lyglot/getenv.py
@@ -60,7 +60,7 @@ def parse_args(self, additional_args=None):
                         help='Le type of LLM you want to use. Can be openai, ollama, claude or claude_cached. '
                              'For openai or claude[_cached], you need to set the api key in the environment. '
                              'Supersedes LLM_CLIENT in .env. Default is ollama',
-                        choices=['openai', 'ollama', 'claude', 'claude_cached', 'gemini'])
+                        choices=['openai', 'ollama', 'claude', 'claude_cached', 'gemini', 'grok'])
     parser.add_argument('-m', '--model',
                         type=str,
                         help='the name of the model to use. Supersedes LLM_MODEL in .env. If not provided at all, '
@@ -184,6 +184,8 @@ def get_client(self):
         from .clients.claude_client import CachedClaudeClient as LLMClient
       case 'gemini':
         from .clients.gemini_client import GeminiClient as LLMClient
+      case 'grok':
+        from .clients.grok_client import GrokClient as LLMClient
       case _:
         raise Exception(
           f"LLM_CLIENT must be one of 'ollama', 'openai', 'claude' or 'claude_cached', not '{self.llm_client}'"

From 593d634be044767cf30480424546f4b968f2682f Mon Sep 17 00:00:00 2001
From: Olivier LEVILLAIN
Date: Tue, 10 Sep 2024 19:17:19 +0200
Subject: [PATCH 2/2] log errors as errors, not infos

---
 src/auto_po_lyglot/po_main.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/auto_po_lyglot/po_main.py b/src/auto_po_lyglot/po_main.py
index 156487c..eda8819 100755
--- a/src/auto_po_lyglot/po_main.py
+++ b/src/auto_po_lyglot/po_main.py
@@ -118,7 +118,7 @@ def main():
         sleep(1.0)  # Sleep for 1 second to avoid rate limiting
         nb_translations += 1
       except Exception as e:
-        logger.info(f"Error: {e}")
+        logger.error(f"Error: {e}")
       # Save the new .po file even if there was an error to not lose what was translated
       po.save(output_file)
       percent_translated = round(nb_translations / len(po) * 100, 2)