-
-
Notifications
You must be signed in to change notification settings - Fork 362
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
23 changed files
with
521 additions
and
85 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
32 changes: 32 additions & 0 deletions
32
packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/mistralai.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
from jupyter_ai_magics.providers import BaseProvider, EnvAuthStrategy | ||
from langchain_mistralai import ChatMistralAI, MistralAIEmbeddings | ||
|
||
from ..embedding_providers import BaseEmbeddingsProvider | ||
|
||
|
||
class MistralAIProvider(BaseProvider, ChatMistralAI): | ||
id = "mistralai" | ||
name = "MistralAI" | ||
models = [ | ||
"open-mistral-7b", | ||
"open-mixtral-8x7b", | ||
"open-mixtral-8x22b", | ||
"mistral-small-latest", | ||
"mistral-medium-latest", | ||
"mistral-large-latest", | ||
"codestral-latest", | ||
] | ||
model_id_key = "model" | ||
auth_strategy = EnvAuthStrategy(name="MISTRAL_API_KEY") | ||
pypi_package_deps = ["langchain-mistralai"] | ||
|
||
|
||
class MistralAIEmbeddingsProvider(BaseEmbeddingsProvider, MistralAIEmbeddings): | ||
id = "mistralai" | ||
name = "MistralAI" | ||
models = [ | ||
"mistral-embed", | ||
] | ||
model_id_key = "model" | ||
pypi_package_deps = ["langchain-mistralai"] | ||
auth_strategy = EnvAuthStrategy(name="MISTRAL_API_KEY") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
from typing import Dict, Type | ||
|
||
from jupyter_ai.models import CellWithErrorSelection, HumanChatMessage | ||
from jupyter_ai_magics.providers import BaseProvider | ||
from langchain.chains import LLMChain | ||
from langchain.prompts import PromptTemplate | ||
|
||
from .base import BaseChatHandler, SlashCommandRoutingType | ||
|
||
FIX_STRING_TEMPLATE = """ | ||
You are Jupyternaut, a conversational assistant living in JupyterLab. Please fix | ||
the notebook cell described below. | ||
Additional instructions: | ||
{extra_instructions} | ||
Input cell: | ||
``` | ||
{cell_content} | ||
``` | ||
Output error: | ||
``` | ||
{traceback} | ||
{error_name}: {error_value} | ||
``` | ||
""".strip() | ||
|
||
FIX_PROMPT_TEMPLATE = PromptTemplate( | ||
input_variables=[ | ||
"extra_instructions", | ||
"cell_content", | ||
"traceback", | ||
"error_name", | ||
"error_value", | ||
], | ||
template=FIX_STRING_TEMPLATE, | ||
) | ||
|
||
|
||
class FixChatHandler(BaseChatHandler): | ||
""" | ||
Accepts a `HumanChatMessage` that includes a cell with error output and | ||
recommends a fix as a reply. If a cell with error output is not included, | ||
this chat handler does nothing. | ||
`/fix` also accepts additional instructions in natural language as an | ||
arbitrary number of arguments, e.g. | ||
``` | ||
/fix use the numpy library to implement this function instead. | ||
``` | ||
""" | ||
|
||
id = "fix" | ||
name = "Fix error cell" | ||
help = "Fix an error cell selected in your notebook" | ||
routing_type = SlashCommandRoutingType(slash_id="fix") | ||
uses_llm = True | ||
|
||
def __init__(self, *args, **kwargs): | ||
super().__init__(*args, **kwargs) | ||
|
||
def create_llm_chain( | ||
self, provider: Type[BaseProvider], provider_params: Dict[str, str] | ||
): | ||
unified_parameters = { | ||
**provider_params, | ||
**(self.get_model_parameters(provider, provider_params)), | ||
} | ||
llm = provider(**unified_parameters) | ||
|
||
self.llm = llm | ||
self.llm_chain = LLMChain(llm=llm, prompt=FIX_PROMPT_TEMPLATE, verbose=True) | ||
|
||
async def process_message(self, message: HumanChatMessage): | ||
if not (message.selection and message.selection.type == "cell-with-error"): | ||
self.reply( | ||
"`/fix` requires an active code cell with error output. Please click on a cell with error output and retry.", | ||
message, | ||
) | ||
return | ||
|
||
# hint type of selection | ||
selection: CellWithErrorSelection = message.selection | ||
|
||
# parse additional instructions specified after `/fix` | ||
extra_instructions = message.body[4:].strip() or "None." | ||
|
||
self.get_llm_chain() | ||
response = await self.llm_chain.apredict( | ||
extra_instructions=extra_instructions, | ||
stop=["\nHuman:"], | ||
cell_content=selection.source, | ||
error_name=selection.error.name, | ||
error_value=selection.error.value, | ||
traceback="\n".join(selection.error.traceback), | ||
) | ||
self.reply(response, message) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.