diff --git a/README.md b/README.md
index 530c6ff46..9c1f59f24 100644
--- a/README.md
+++ b/README.md
@@ -10,7 +10,7 @@ in JupyterLab and the Jupyter Notebook. More specifically, Jupyter AI offers:
This works anywhere the IPython kernel runs (JupyterLab, Jupyter Notebook, Google Colab, Kaggle, VSCode, etc.).
* A native chat UI in JupyterLab that enables you to work with generative AI as a conversational assistant.
* Support for a wide range of generative model providers, including AI21, Anthropic, AWS, Cohere,
- Gemini, Hugging Face, NVIDIA, and OpenAI.
+ Gemini, Hugging Face, MistralAI, NVIDIA, and OpenAI.
* Local model support through GPT4All, enabling use of generative AI models on consumer grade machines
with ease and privacy.
diff --git a/docs/source/_static/fix-error-cell-selected.png b/docs/source/_static/fix-error-cell-selected.png
new file mode 100644
index 000000000..a8e2d0b82
Binary files /dev/null and b/docs/source/_static/fix-error-cell-selected.png differ
diff --git a/docs/source/_static/fix-no-error-cell-selected.png b/docs/source/_static/fix-no-error-cell-selected.png
new file mode 100644
index 000000000..1f29a77e5
Binary files /dev/null and b/docs/source/_static/fix-no-error-cell-selected.png differ
diff --git a/docs/source/_static/fix-response.png b/docs/source/_static/fix-response.png
new file mode 100644
index 000000000..0127af325
Binary files /dev/null and b/docs/source/_static/fix-response.png differ
diff --git a/docs/source/index.md b/docs/source/index.md
index a2c060f6a..d3d29164e 100644
--- a/docs/source/index.md
+++ b/docs/source/index.md
@@ -8,7 +8,7 @@ in JupyterLab and the Jupyter Notebook. More specifically, Jupyter AI offers:
This works anywhere the IPython kernel runs (JupyterLab, Jupyter Notebook, Google Colab, VSCode, etc.).
* A native chat UI in JupyterLab that enables you to work with generative AI as a conversational assistant.
* Support for a wide range of generative model providers and models
- (AI21, Anthropic, Cohere, Gemini, Hugging Face, OpenAI, SageMaker, NVIDIA, etc.).
+ (AI21, Anthropic, Cohere, Gemini, Hugging Face, MistralAI, OpenAI, SageMaker, NVIDIA, etc.).
` will export the chat history to `-YYYY-MM-DD-HH-mm.md` instead. You can export chat history as many times as you like in a single session. Each successive export will include the entire chat history up to that point in the session.
+### Fixing a code cell with an error
+
+The `/fix` command can be used to fix any code cell with an error output in a
+Jupyter notebook file. To start, type `/fix` into the chat input. Jupyter AI
+will then prompt you to select a cell with error output before sending the
+request.
+
+<img src="../_static/fix-no-error-cell-selected.png" alt="Screenshot of the chat input containing /fix before an error cell is selected." class="screenshot" />
+
+Then click on a code cell with error output. A blue bar should appear
+immediately to the left of the code cell.
+
+<img src="../_static/fix-error-cell-selected.png" alt="Screenshot of a code cell with error output selected, indicated by a blue bar to its left." class="screenshot" />
+
+After this, the Send button to the right of the chat input will be enabled, and
+you can use your mouse or keyboard to send `/fix` to Jupyternaut. The code cell
+and its associated error output are included in the message automatically. When
+complete, Jupyternaut will reply with suggested code that should fix the error.
+You can use the action toolbar under each code block to quickly replace the
+contents of the failing cell.
+
+<img src="../_static/fix-response.png" alt="Screenshot of Jupyternaut's reply to /fix, with suggested code to fix the error." class="screenshot" />
+
+
### Additional chat commands
To clear the chat panel, use the `/clear` command. This does not reset the AI model; the model may still remember previous messages that you sent it, and it may use them to inform its responses.
diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/mistralai.py b/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/mistralai.py
new file mode 100644
index 000000000..6d7f84006
--- /dev/null
+++ b/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/mistralai.py
@@ -0,0 +1,32 @@
+from jupyter_ai_magics.providers import BaseProvider, EnvAuthStrategy
+from langchain_mistralai import ChatMistralAI, MistralAIEmbeddings
+
+from ..embedding_providers import BaseEmbeddingsProvider
+
+
+class MistralAIProvider(BaseProvider, ChatMistralAI):
+ id = "mistralai"
+ name = "MistralAI"
+ models = [
+ "open-mistral-7b",
+ "open-mixtral-8x7b",
+ "open-mixtral-8x22b",
+ "mistral-small-latest",
+ "mistral-medium-latest",
+ "mistral-large-latest",
+ "codestral-latest",
+ ]
+ model_id_key = "model"
+ auth_strategy = EnvAuthStrategy(name="MISTRAL_API_KEY")
+ pypi_package_deps = ["langchain-mistralai"]
+
+
+class MistralAIEmbeddingsProvider(BaseEmbeddingsProvider, MistralAIEmbeddings):
+ id = "mistralai"
+ name = "MistralAI"
+ models = [
+ "mistral-embed",
+ ]
+ model_id_key = "model"
+ pypi_package_deps = ["langchain-mistralai"]
+ auth_strategy = EnvAuthStrategy(name="MISTRAL_API_KEY")
diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
index bef7efacc..3ece92180 100644
--- a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
+++ b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
@@ -20,8 +20,6 @@
from jsonpath_ng import parse
from langchain.chat_models.base import BaseChatModel
-from langchain.llms.sagemaker_endpoint import LLMContentHandler
-from langchain.llms.utils import enforce_stop_tokens
from langchain.prompts import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
@@ -44,6 +42,7 @@
SagemakerEndpoint,
Together,
)
+from langchain_community.llms.sagemaker_endpoint import LLMContentHandler
# this is necessary because `langchain.pydantic_v1.main` does not include
# `ModelMetaclass`, as it is not listed in `__all__` by the `pydantic.main`
diff --git a/packages/jupyter-ai-magics/pyproject.toml b/packages/jupyter-ai-magics/pyproject.toml
index 422dfd60b..941d17ba6 100644
--- a/packages/jupyter-ai-magics/pyproject.toml
+++ b/packages/jupyter-ai-magics/pyproject.toml
@@ -41,13 +41,14 @@ all = [
"huggingface_hub",
"ipywidgets",
"langchain_anthropic",
+ "langchain-mistralai",
"langchain_nvidia_ai_endpoints",
+ "langchain-google-genai",
"langchain-openai",
"pillow",
"boto3",
"qianfan",
"together",
- "langchain-google-genai",
]
[project.entry-points."jupyter_ai.model_providers"]
@@ -67,10 +68,12 @@ qianfan = "jupyter_ai_magics:QianfanProvider"
nvidia-chat = "jupyter_ai_magics.partner_providers.nvidia:ChatNVIDIAProvider"
together-ai = "jupyter_ai_magics:TogetherAIProvider"
gemini = "jupyter_ai_magics.partner_providers.gemini:GeminiProvider"
+mistralai = "jupyter_ai_magics.partner_providers.mistralai:MistralAIProvider"
[project.entry-points."jupyter_ai.embeddings_model_providers"]
bedrock = "jupyter_ai_magics:BedrockEmbeddingsProvider"
cohere = "jupyter_ai_magics:CohereEmbeddingsProvider"
+mistralai = "jupyter_ai_magics.partner_providers.mistralai:MistralAIEmbeddingsProvider"
gpt4all = "jupyter_ai_magics:GPT4AllEmbeddingsProvider"
huggingface_hub = "jupyter_ai_magics:HfHubEmbeddingsProvider"
openai = "jupyter_ai_magics.partner_providers.openai:OpenAIEmbeddingsProvider"
diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/__init__.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/__init__.py
index b9046a6f8..a8fe9eb50 100644
--- a/packages/jupyter-ai/jupyter_ai/chat_handlers/__init__.py
+++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/__init__.py
@@ -3,6 +3,7 @@
from .clear import ClearChatHandler
from .default import DefaultChatHandler
from .export import ExportChatHandler
+from .fix import FixChatHandler
from .generate import GenerateChatHandler
from .help import HelpChatHandler
from .learn import LearnChatHandler
diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
index df288d409..3f936a142 100644
--- a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
+++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
@@ -43,19 +43,6 @@ def create_llm_chain(
llm=llm, prompt=prompt_template, verbose=True, memory=self.memory
)
- def clear_memory(self):
- # clear chain memory
- if self.memory:
- self.memory.clear()
-
- # clear transcript for existing chat clients
- reply_message = ClearMessage()
- self.reply(reply_message)
-
- # clear transcript for new chat clients
- if self._chat_history:
- self._chat_history.clear()
-
async def process_message(self, message: HumanChatMessage):
self.get_llm_chain()
response = await self.llm_chain.apredict(input=message.body, stop=["\nHuman:"])
diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py
new file mode 100644
index 000000000..0f62e5681
--- /dev/null
+++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py
@@ -0,0 +1,103 @@
+from typing import Dict, Type
+
+from jupyter_ai.models import CellWithErrorSelection, HumanChatMessage
+from jupyter_ai_magics.providers import BaseProvider
+from langchain.chains import LLMChain
+from langchain.prompts import PromptTemplate
+
+from .base import BaseChatHandler, SlashCommandRoutingType
+
+FIX_STRING_TEMPLATE = """
+You are Jupyternaut, a conversational assistant living in JupyterLab. Please fix
+the notebook cell described below.
+
+Additional instructions:
+
+{extra_instructions}
+
+Input cell:
+
+```
+{cell_content}
+```
+
+Output error:
+
+```
+{traceback}
+
+{error_name}: {error_value}
+```
+""".strip()
+
+FIX_PROMPT_TEMPLATE = PromptTemplate(
+ input_variables=[
+ "extra_instructions",
+ "cell_content",
+ "traceback",
+ "error_name",
+ "error_value",
+ ],
+ template=FIX_STRING_TEMPLATE,
+)
+
+
+class FixChatHandler(BaseChatHandler):
+ """
+ Accepts a `HumanChatMessage` that includes a cell with error output and
+ recommends a fix as a reply. If a cell with error output is not included,
+ this chat handler does nothing.
+
+ `/fix` also accepts additional instructions in natural language as an
+ arbitrary number of arguments, e.g.
+
+ ```
+ /fix use the numpy library to implement this function instead.
+ ```
+ """
+
+ id = "fix"
+ name = "Fix error cell"
+ help = "Fix an error cell selected in your notebook"
+ routing_type = SlashCommandRoutingType(slash_id="fix")
+ uses_llm = True
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def create_llm_chain(
+ self, provider: Type[BaseProvider], provider_params: Dict[str, str]
+ ):
+ unified_parameters = {
+ **provider_params,
+ **(self.get_model_parameters(provider, provider_params)),
+ }
+ llm = provider(**unified_parameters)
+
+ self.llm = llm
+ self.llm_chain = LLMChain(llm=llm, prompt=FIX_PROMPT_TEMPLATE, verbose=True)
+
+ async def process_message(self, message: HumanChatMessage):
+ if not (message.selection and message.selection.type == "cell-with-error"):
+ self.reply(
+ "`/fix` requires an active code cell with error output. Please click on a cell with error output and retry.",
+ message,
+ )
+ return
+
+ # hint type of selection
+ selection: CellWithErrorSelection = message.selection
+
+ # parse additional instructions specified after `/fix`
+ extra_instructions = message.body[4:].strip() or "None."
+
+ self.get_llm_chain()
+ response = await self.llm_chain.apredict(
+ extra_instructions=extra_instructions,
+ stop=["\nHuman:"],
+ cell_content=selection.source,
+ error_name=selection.error.name,
+ error_value=selection.error.value,
+ traceback="\n".join(selection.error.traceback),
+ )
+ self.reply(response, message)
diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py
index efb83b3db..848a6eed7 100644
--- a/packages/jupyter-ai/jupyter_ai/extension.py
+++ b/packages/jupyter-ai/jupyter_ai/extension.py
@@ -17,6 +17,7 @@
ClearChatHandler,
DefaultChatHandler,
ExportChatHandler,
+ FixChatHandler,
GenerateChatHandler,
HelpChatHandler,
LearnChatHandler,
@@ -264,6 +265,9 @@ def initialize_settings(self):
ask_chat_handler = AskChatHandler(**chat_handler_kwargs, retriever=retriever)
export_chat_handler = ExportChatHandler(**chat_handler_kwargs)
+
+ fix_chat_handler = FixChatHandler(**chat_handler_kwargs)
+
jai_chat_handlers = {
"default": default_chat_handler,
"/ask": ask_chat_handler,
@@ -271,6 +275,7 @@ def initialize_settings(self):
"/generate": generate_chat_handler,
"/learn": learn_chat_handler,
"/export": export_chat_handler,
+ "/fix": fix_chat_handler,
}
help_chat_handler = HelpChatHandler(
diff --git a/packages/jupyter-ai/jupyter_ai/handlers.py b/packages/jupyter-ai/jupyter_ai/handlers.py
index cd3ebabde..339863c3f 100644
--- a/packages/jupyter-ai/jupyter_ai/handlers.py
+++ b/packages/jupyter-ai/jupyter_ai/handlers.py
@@ -207,6 +207,7 @@ async def on_message(self, message):
id=chat_message_id,
time=time.time(),
body=chat_request.prompt,
+ selection=chat_request.selection,
client=self.chat_client,
)
diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py
index bdbbd81b6..84e2a524b 100644
--- a/packages/jupyter-ai/jupyter_ai/models.py
+++ b/packages/jupyter-ai/jupyter_ai/models.py
@@ -8,9 +8,29 @@
DEFAULT_CHUNK_OVERLAP = 100
+class CellError(BaseModel):
+ name: str
+ value: str
+ traceback: List[str]
+
+
+class CellWithErrorSelection(BaseModel):
+ type: Literal["cell-with-error"] = "cell-with-error"
+ source: str
+ error: CellError
+
+
+Selection = Union[CellWithErrorSelection]
+
+
# the type of message used to chat with the agent
class ChatRequest(BaseModel):
prompt: str
+ # TODO: This currently is only used when a user runs the /fix slash command.
+ # In the future, the frontend should set the text selection on this field in
+ # the `HumanChatMessage` it sends to JAI, instead of appending the text
+ # selection to `body` in the frontend.
+ selection: Optional[Selection]
class ChatUser(BaseModel):
@@ -55,6 +75,7 @@ class HumanChatMessage(BaseModel):
time: float
body: str
client: ChatClient
+ selection: Optional[Selection]
class ConnectionMessage(BaseModel):
diff --git a/packages/jupyter-ai/src/components/chat-input.tsx b/packages/jupyter-ai/src/components/chat-input.tsx
index 1be099f77..2386fe873 100644
--- a/packages/jupyter-ai/src/components/chat-input.tsx
+++ b/packages/jupyter-ai/src/components/chat-input.tsx
@@ -9,11 +9,9 @@ import {
FormGroup,
FormControlLabel,
Checkbox,
- IconButton,
InputAdornment,
Typography
} from '@mui/material';
-import SendIcon from '@mui/icons-material/Send';
import {
Download,
FindInPage,
@@ -21,15 +19,18 @@ import {
MoreHoriz,
MenuBook,
School,
- HideSource
+ HideSource,
+ AutoFixNormal
} from '@mui/icons-material';
import { AiService } from '../handler';
+import { SendButton, SendButtonProps } from './chat-input/send-button';
+import { useActiveCellContext } from '../contexts/active-cell-context';
type ChatInputProps = {
value: string;
onChange: (newValue: string) => unknown;
- onSend: () => unknown;
+ onSend: (selection?: AiService.Selection) => unknown;
hasSelection: boolean;
includeSelection: boolean;
toggleIncludeSelection: () => unknown;
@@ -56,6 +57,7 @@ const DEFAULT_SLASH_COMMAND_ICONS: Record<string, JSX.Element> = {
   ask: <FindInPage />,
   clear: <HideSource />,
   export: <Download />,
+  fix: <AutoFixNormal />,
   generate: <MenuBook />,
   help: <Help />,
   learn: <School />,
@@ -101,6 +103,8 @@ export function ChatInput(props: ChatInputProps): JSX.Element {
const [slashCommandOptions, setSlashCommandOptions] = useState<
SlashCommandOption[]
>([]);
+  const [currSlashCommand, setCurrSlashCommand] = useState<string | null>(null);
+ const activeCell = useActiveCellContext();
/**
* Effect: fetch the list of available slash commands from the backend on
@@ -129,8 +133,7 @@ export function ChatInput(props: ChatInputProps): JSX.Element {
/**
* Effect: Open the autocomplete when the user types a slash into an empty
- * chat input. Close the autocomplete and reset the last selected value when
- * the user clears the chat input.
+ * chat input. Close the autocomplete when the user clears the chat input.
*/
useEffect(() => {
if (props.value === '/') {
@@ -144,6 +147,35 @@ export function ChatInput(props: ChatInputProps): JSX.Element {
}
}, [props.value]);
+ /**
+ * Effect: Set current slash command
+ */
+ useEffect(() => {
+ const matchedSlashCommand = props.value.match(/^\s*\/\w+/);
+ setCurrSlashCommand(matchedSlashCommand && matchedSlashCommand[0]);
+ }, [props.value]);
+
+ // TODO: unify the `onSend` implementation in `chat.tsx` and here once text
+ // selection is refactored.
+ function onSend() {
+ // case: /fix
+ if (currSlashCommand === '/fix') {
+ const cellWithError = activeCell.manager.getContent(true);
+ if (!cellWithError) {
+ return;
+ }
+
+ props.onSend({
+ ...cellWithError,
+ type: 'cell-with-error'
+ });
+ return;
+ }
+
+ // default case
+ props.onSend();
+ }
+
function handleKeyDown(event: React.KeyboardEvent) {
if (event.key !== 'Enter') {
return;
@@ -160,7 +192,7 @@ export function ChatInput(props: ChatInputProps): JSX.Element {
((props.sendWithShiftEnter && event.shiftKey) ||
(!props.sendWithShiftEnter && !event.shiftKey))
) {
- props.onSend();
+ onSend();
event.stopPropagation();
event.preventDefault();
}
@@ -177,6 +209,15 @@ export function ChatInput(props: ChatInputProps): JSX.Element {
);
+ const inputExists = !!props.value.trim();
+ const sendButtonProps: SendButtonProps = {
+ onSend,
+ sendWithShiftEnter: props.sendWithShiftEnter,
+ inputExists,
+ activeCellHasError: activeCell.hasError,
+ currSlashCommand
+ };
+
return (
-
-
-
+
)
}}
diff --git a/packages/jupyter-ai/src/components/chat-input/send-button.tsx b/packages/jupyter-ai/src/components/chat-input/send-button.tsx
new file mode 100644
index 000000000..6ab79a355
--- /dev/null
+++ b/packages/jupyter-ai/src/components/chat-input/send-button.tsx
@@ -0,0 +1,45 @@
+import React from 'react';
+import SendIcon from '@mui/icons-material/Send';
+
+import { TooltippedIconButton } from '../mui-extras/tooltipped-icon-button';
+
+export type SendButtonProps = {
+ onSend: () => unknown;
+ sendWithShiftEnter: boolean;
+ currSlashCommand: string | null;
+ inputExists: boolean;
+ activeCellHasError: boolean;
+};
+
+export function SendButton(props: SendButtonProps): JSX.Element {
+ const disabled =
+ props.currSlashCommand === '/fix'
+ ? !props.inputExists || !props.activeCellHasError
+ : !props.inputExists;
+
+ const defaultTooltip = props.sendWithShiftEnter
+ ? 'Send message (SHIFT+ENTER)'
+ : 'Send message (ENTER)';
+
+ const tooltip =
+ props.currSlashCommand === '/fix' && !props.activeCellHasError
+ ? '/fix requires a code cell with an error output selected'
+ : !props.inputExists
+ ? 'Message must not be empty'
+ : defaultTooltip;
+
+  return (
+    <TooltippedIconButton
+      onClick={() => props.onSend()}
+      disabled={disabled}
+      tooltip={tooltip}
+      iconButtonProps={{
+        size: 'small',
+        color: 'primary',
+        title: defaultTooltip
+      }}
+    >
+      <SendIcon />
+    </TooltippedIconButton>
+  );
+}
diff --git a/packages/jupyter-ai/src/components/chat-messages.tsx b/packages/jupyter-ai/src/components/chat-messages.tsx
index 8a2e4b658..0559387f2 100644
--- a/packages/jupyter-ai/src/components/chat-messages.tsx
+++ b/packages/jupyter-ai/src/components/chat-messages.tsx
@@ -137,20 +137,27 @@ export function ChatMessages(props: ChatMessagesProps): JSX.Element {
}
}}
>
- {props.messages.map((message, i) => (
- // extra div needed to ensure each bubble is on a new line
-
-
-
-
- ))}
+ {props.messages.map((message, i) => {
+ // render selection in HumanChatMessage, if any
+ const markdownStr =
+ message.type === 'human' && message.selection
+ ? message.body + '\n\n```\n' + message.selection.source + '\n```\n'
+ : message.body;
+
+ return (
+
+
+
+
+ );
+ })}
);
}
diff --git a/packages/jupyter-ai/src/components/chat.tsx b/packages/jupyter-ai/src/components/chat.tsx
index 1ba8c717c..abf974054 100644
--- a/packages/jupyter-ai/src/components/chat.tsx
+++ b/packages/jupyter-ai/src/components/chat.tsx
@@ -42,7 +42,7 @@ function ChatBody({
const [includeSelection, setIncludeSelection] = useState(true);
const [replaceSelection, setReplaceSelection] = useState(false);
const [input, setInput] = useState('');
- const [selection, replaceSelectionFn] = useSelectionContext();
+ const [textSelection, replaceTextSelection] = useSelectionContext();
const [sendWithShiftEnter, setSendWithShiftEnter] = useState(true);
/**
@@ -91,25 +91,26 @@ function ChatBody({
// no need to append to messageGroups imperatively here. all of that is
// handled by the listeners registered in the effect hooks above.
- const onSend = async () => {
+ // TODO: unify how text selection & cell selection are handled
+ const onSend = async (selection?: AiService.Selection) => {
setInput('');
const prompt =
input +
- (includeSelection && selection?.text
- ? '\n\n```\n' + selection.text + '\n```'
+ (includeSelection && textSelection?.text
+ ? '\n\n```\n' + textSelection.text + '\n```'
: '');
// send message to backend
- const messageId = await chatHandler.sendMessage({ prompt });
+ const messageId = await chatHandler.sendMessage({ prompt, selection });
// await reply from agent
// no need to append to messageGroups state variable, since that's already
// handled in the effect hooks.
const reply = await chatHandler.replyFor(messageId);
- if (replaceSelection && selection) {
- const { cellId, ...selectionProps } = selection;
- replaceSelectionFn({
+ if (replaceSelection && textSelection) {
+ const { cellId, ...selectionProps } = textSelection;
+ replaceTextSelection({
...selectionProps,
...(cellId && { cellId }),
text: reply.body
@@ -161,7 +162,7 @@ function ChatBody({
value={input}
onChange={setInput}
onSend={onSend}
- hasSelection={!!selection?.text}
+ hasSelection={!!textSelection?.text}
includeSelection={includeSelection}
toggleIncludeSelection={() =>
setIncludeSelection(includeSelection => !includeSelection)
diff --git a/packages/jupyter-ai/src/components/code-blocks/code-toolbar.tsx b/packages/jupyter-ai/src/components/code-blocks/code-toolbar.tsx
index 9045fbcdd..08a9a3836 100644
--- a/packages/jupyter-ai/src/components/code-blocks/code-toolbar.tsx
+++ b/packages/jupyter-ai/src/components/code-blocks/code-toolbar.tsx
@@ -19,11 +19,11 @@ export type CodeToolbarProps = {
};
export function CodeToolbar(props: CodeToolbarProps): JSX.Element {
- const [activeCellExists, activeCellManager] = useActiveCellContext();
+ const activeCell = useActiveCellContext();
const sharedToolbarButtonProps = {
content: props.content,
- activeCellManager,
- activeCellExists
+ activeCellManager: activeCell.manager,
+ activeCellExists: activeCell.exists
};
return (
diff --git a/packages/jupyter-ai/src/components/mui-extras/tooltipped-icon-button.tsx b/packages/jupyter-ai/src/components/mui-extras/tooltipped-icon-button.tsx
index 7968e9e43..6f97c7462 100644
--- a/packages/jupyter-ai/src/components/mui-extras/tooltipped-icon-button.tsx
+++ b/packages/jupyter-ai/src/components/mui-extras/tooltipped-icon-button.tsx
@@ -1,5 +1,5 @@
import React from 'react';
-import { IconButton, TooltipProps } from '@mui/material';
+import { IconButton, IconButtonProps, TooltipProps } from '@mui/material';
import { ContrastingTooltip } from './contrasting-tooltip';
@@ -17,6 +17,10 @@ export type TooltippedIconButtonProps = {
*/
offset?: [number, number];
'aria-label'?: string;
+ /**
+ * Props passed directly to the MUI `IconButton` component.
+ */
+ iconButtonProps?: IconButtonProps;
};
/**
@@ -60,6 +64,7 @@ export function TooltippedIconButton(
*/}
{
+ this._pollActiveCell();
+ }, 200);
}
   get activeCellChanged(): Signal<this, Cell | null> {
return this._activeCellChanged;
}
+  get activeCellErrorChanged(): Signal<this, CellError | null> {
+ return this._activeCellErrorChanged;
+ }
+
+ /**
+ * Returns an `ActiveCellContent` object that describes the current active
+ * cell. If no active cell exists, this method returns `null`.
+ *
+ * When called with `withError = true`, this method returns `null` if the
+ * active cell does not have an error output. Otherwise it returns an
+ * `ActiveCellContentWithError` object that describes both the active cell and
+ * the error output.
+ */
+ getContent(withError: false): CellContent | null;
+ getContent(withError: true): CellWithErrorContent | null;
+ getContent(withError = false): CellContent | CellWithErrorContent | null {
+ const sharedModel = this._activeCell?.model.sharedModel;
+ if (!sharedModel) {
+ return null;
+ }
+
+ // case where withError = false
+ if (!withError) {
+ return {
+ type: sharedModel.cell_type,
+ source: sharedModel.getSource()
+ };
+ }
+
+ // case where withError = true
+ const error = this._activeCellError;
+ if (error) {
+ return {
+ type: 'code',
+ source: sharedModel.getSource(),
+ error: {
+ name: error.ename,
+ value: error.evalue,
+ traceback: error.traceback
+ }
+ };
+ }
+
+ return null;
+ }
+
/**
* Inserts `content` in a new cell above the active cell.
*/
@@ -64,7 +128,7 @@ export class ActiveCellManager {
// create a new cell above the active cell and mark new cell as active
NotebookActions.insertAbove(notebook);
// emit activeCellChanged event to consumers
- this._updateActiveCell();
+ this._pollActiveCell();
// replace content of this new active cell
this.replace(content);
}
@@ -81,7 +145,7 @@ export class ActiveCellManager {
// create a new cell below the active cell and mark new cell as active
NotebookActions.insertBelow(notebook);
// emit activeCellChanged event to consumers
- this._updateActiveCell();
+ this._pollActiveCell();
// replace content of this new active cell
this.replace(content);
}
@@ -116,27 +180,94 @@ export class ActiveCellManager {
activeCell.editor?.model.sharedModel.setSource(content);
}
- protected _updateActiveCell(): void {
+ protected _pollActiveCell(): void {
const prevActiveCell = this._activeCell;
const currActiveCell = getActiveCell(this._mainAreaWidget);
- if (prevActiveCell === currActiveCell) {
- return;
+ // emit activeCellChanged when active cell changes
+ if (prevActiveCell !== currActiveCell) {
+ this._activeCell = currActiveCell;
+ this._activeCellChanged.emit(currActiveCell);
}
- this._activeCell = currActiveCell;
- this._activeCellChanged.emit(currActiveCell);
+ const currSharedModel = currActiveCell?.model.sharedModel;
+ const prevExecutionCount = this._activeCellExecutionCount;
+ const currExecutionCount: number | null =
+ currSharedModel && 'execution_count' in currSharedModel
+ ? currSharedModel?.execution_count
+ : null;
+ this._activeCellExecutionCount = currExecutionCount;
+
+ // emit activeCellErrorChanged when active cell changes or when the
+ // execution count changes
+ if (
+ prevActiveCell !== currActiveCell ||
+ prevExecutionCount !== currExecutionCount
+ ) {
+ const prevActiveCellError = this._activeCellError;
+ let currActiveCellError: CellError | null = null;
+ if (currSharedModel && 'outputs' in currSharedModel) {
+ currActiveCellError =
+ currSharedModel.outputs.find(
+ (output): output is CellError => output.output_type === 'error'
+ ) || null;
+ }
+
+ // for some reason, the `CellError` object is not referentially stable,
+ // meaning that this condition always evaluates to `true` and the
+ // `activeCellErrorChanged` signal is emitted every 200ms, even when the
+ // error output is unchanged. this is why we have to rely on
+ // `execution_count` to track changes to the error output.
+ if (prevActiveCellError !== currActiveCellError) {
+ this._activeCellError = currActiveCellError;
+ this._activeCellErrorChanged.emit(this._activeCellError);
+ }
+ }
}
protected _shell: JupyterFrontEnd.IShell;
protected _mainAreaWidget: Widget | null = null;
+
+ /**
+ * The active cell.
+ */
protected _activeCell: Cell | null = null;
+ /**
+ * The execution count of the active cell. This is the number shown on the
+ * left in square brackets after running a cell. Changes to this indicate that
+ * the error output may have changed.
+ */
+ protected _activeCellExecutionCount: number | null = null;
+ /**
+ * The `CellError` output within the active cell, if any.
+ */
+ protected _activeCellError: CellError | null = null;
+
   protected _activeCellChanged = new Signal<this, Cell | null>(this);
+  protected _activeCellErrorChanged = new Signal<this, CellError | null>(this);
}
-const ActiveCellContext = React.createContext<
- [boolean, ActiveCellManager | null]
->([false, null]);
+type ActiveCellContextReturn = {
+ exists: boolean;
+ hasError: boolean;
+ manager: ActiveCellManager;
+};
+
+type ActiveCellContextValue = {
+ exists: boolean;
+ hasError: boolean;
+ manager: ActiveCellManager | null;
+};
+
+const defaultActiveCellContext: ActiveCellContextValue = {
+ exists: false,
+ hasError: false,
+ manager: null
+};
+
+const ActiveCellContext = React.createContext<ActiveCellContextValue>(
+  defaultActiveCellContext
+);
type ActiveCellContextProps = {
activeCellManager: ActiveCellManager;
@@ -146,17 +277,27 @@ type ActiveCellContextProps = {
export function ActiveCellContextProvider(
props: ActiveCellContextProps
): JSX.Element {
- const [activeCellExists, setActiveCellExists] = useState(false);
+ const [exists, setExists] = useState(false);
+ const [hasError, setHasError] = useState(false);
useEffect(() => {
- props.activeCellManager.activeCellChanged.connect((_, newActiveCell) => {
- setActiveCellExists(!!newActiveCell);
+ const manager = props.activeCellManager;
+
+ manager.activeCellChanged.connect((_, newActiveCell) => {
+ setExists(!!newActiveCell);
+ });
+ manager.activeCellErrorChanged.connect((_, newActiveCellError) => {
+ setHasError(!!newActiveCellError);
});
}, [props.activeCellManager]);
return (
{props.children}
@@ -164,16 +305,25 @@ export function ActiveCellContextProvider(
}
/**
- * Hook that returns the two-tuple `[activeCellExists, activeCellManager]`.
+ * Usage: `const activeCell = useActiveCellContext()`
+ *
+ * Returns an object `activeCell` with the following properties:
+ * - `activeCell.exists`: whether an active cell exists
+ * - `activeCell.hasError`: whether an active cell exists with an error output
+ * - `activeCell.manager`: the `ActiveCellManager` singleton
*/
-export function useActiveCellContext(): [boolean, ActiveCellManager] {
- const [activeCellExists, activeCellManager] = useContext(ActiveCellContext);
+export function useActiveCellContext(): ActiveCellContextReturn {
+ const { exists, hasError, manager } = useContext(ActiveCellContext);
- if (!activeCellManager) {
+ if (!manager) {
throw new Error(
'useActiveCellContext() cannot be called outside ActiveCellContextProvider.'
);
}
- return [activeCellExists, activeCellManager];
+ return {
+ exists,
+ hasError,
+ manager
+ };
}
diff --git a/packages/jupyter-ai/src/handler.ts b/packages/jupyter-ai/src/handler.ts
index 4db2e85ec..5d06691fe 100644
--- a/packages/jupyter-ai/src/handler.ts
+++ b/packages/jupyter-ai/src/handler.ts
@@ -51,8 +51,23 @@ export namespace AiService {
serverSettings?: ServerConnection.ISettings;
}
+ export type CellError = {
+ name: string;
+ value: string;
+ traceback: string[];
+ };
+
+ export type CellWithErrorSelection = {
+ type: 'cell-with-error';
+ source: string;
+ error: CellError;
+ };
+
+ export type Selection = CellWithErrorSelection;
+
export type ChatRequest = {
prompt: string;
+ selection?: Selection;
};
export type Collaborator = {
@@ -88,6 +103,7 @@ export namespace AiService {
time: number;
body: string;
client: ChatClient;
+ selection?: Selection;
};
export type ConnectionMessage = {