Skip to content

Commit

Permalink
Merge pull request #7 from leolivier:leolivier/issue3
Browse files Browse the repository at this point in the history
Leolivier/issue3
  • Loading branch information
leolivier authored Sep 10, 2024
2 parents b2fea56 + 593d634 commit 7e1af28
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 2 deletions.
2 changes: 2 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ CONTEXT_LANGUAGE=French
OPENAI_API_KEY=[your API key goes here]
# for Claude models, set:
ANTHROPIC_API_KEY=[your API key goes here]
# for Grok models, set:
XAI_API_KEY=[your API key goes here]
# for Gemini models, set:
GEMINI_API_KEY=[your API key goes here]

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ dependencies=[
"langcodes>=3.4.0",
"openai>=1.12.0",
"anthropic>=0.34.1",
"xai-sdk>=0.3.0",
"google-generativeai>=0.7.2",
]
classifiers = [
Expand Down
21 changes: 21 additions & 0 deletions src/auto_po_lyglot/clients/grok_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import xai_sdk
import asyncio
from .client_base import TranspoClient
import logging

logger = logging.getLogger(__name__)


class GrokClient(TranspoClient):
def __init__(self, params, target_language=None):
params.model = params.model or "" # default model given by Grok itself if not provided
super().__init__(params, target_language)
self.client = xai_sdk.Client(api_key=params.xai_api_key) if hasattr(params, 'xai_api_key') else xai_sdk.Client()

async def async_get_translation(self, system_prompt, user_prompt):
conversation = self.client.chat.create_conversation()
response = await conversation.add_response_no_stream(f'{system_prompt}\n{user_prompt}\n')
return response.message

def get_translation(self, system_prompt, user_prompt):
return asyncio.run(self.async_get_translation(system_prompt, user_prompt))
4 changes: 3 additions & 1 deletion src/auto_po_lyglot/getenv.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def parse_args(self, additional_args=None):
help='Le type of LLM you want to use. Can be openai, ollama, claude or claude_cached. '
'For openai or claude[_cached], you need to set the api key in the environment. '
'Supersedes LLM_CLIENT in .env. Default is ollama',
choices=['openai', 'ollama', 'claude', 'claude_cached', 'gemini'])
choices=['openai', 'ollama', 'claude', 'claude_cached', 'gemini', 'grok'])
parser.add_argument('-m', '--model',
type=str,
help='the name of the model to use. Supersedes LLM_MODEL in .env. If not provided at all, '
Expand Down Expand Up @@ -184,6 +184,8 @@ def get_client(self):
from .clients.claude_client import CachedClaudeClient as LLMClient
case 'gemini':
from .clients.gemini_client import GeminiClient as LLMClient
case 'grok':
from .clients.grok_client import GrokClient as LLMClient
case _:
raise Exception(
f"LLM_CLIENT must be one of 'ollama', 'openai', 'claude' or 'claude_cached', not '{self.llm_client}'"
Expand Down
2 changes: 1 addition & 1 deletion src/auto_po_lyglot/po_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def main():
sleep(1.0) # Sleep for 1 second to avoid rate limiting
nb_translations += 1
except Exception as e:
logger.info(f"Error: {e}")
logger.error(f"Error: {e}")
# Save the new .po file even if there was an error to not lose what was translated
po.save(output_file)
percent_translated = round(nb_translations / len(po) * 100, 2)
Expand Down

0 comments on commit 7e1af28

Please sign in to comment.