Skip to content

Commit

Permalink
add minimal test case
Browse files Browse the repository at this point in the history
  • Loading branch information
alex-treebeard committed Oct 24, 2023
1 parent f5f883c commit b08c3d6
Show file tree
Hide file tree
Showing 8 changed files with 162 additions and 104 deletions.
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ temperatures, etc. See [Langchain's API ref](https://api.python.langchain.com/en
Note! This is a confusing use case -- change it to something relevant to your work.
```yaml
task: |
task: |
Create a hello world notebook 'x.ipynb', use nbmake's NotebookRun class to test it from a Python application
steps:
- Create a hello world notebook using nbformat
Expand Down Expand Up @@ -132,3 +132,6 @@ using Phoenix.
```
1. In another termianl, run nbwrite with the following var set: `export NBWRITE_PHOENIX_TRACE=1`
1. Check the phoenix traces in the dashboard (default http://127.0.0.1:6060/)


## TODO make phoenix optional, fix empty packages bug
136 changes: 70 additions & 66 deletions poetry.lock

Large diffs are not rendered by default.

14 changes: 9 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,22 @@ nbwrite = "nbwrite.cli:cli"
[tool.poetry.dependencies]
pytest = ">=6.1.0"
python = ">=3.8.1,<3.12"
openai = {extras = ["embeddings"], version = "^0.28.1"}
openai = {optional = true, extras = ["embeddings"], version = "^0.28.1"}
pathlib = "^1.0.1"
click = "^8.1.7"
langchain = {optional = false, version = "^0.0.312", extras = ["openai"]}
langchain = {version = "^0.0.312", extras = ["openai"]}
nbformat = "^5.9.2"
chromadb = "^0.4.14"
pysqlite3-binary = {version = "^0.5.2", markers = "sys_platform == 'linux'"}
arize-phoenix = {extras = ["experimental"], version = "^0.0.49"}
chromadb = {optional = true, version = "^0.4.14"}
pysqlite3-binary = {optional = true, version = "^0.5.2", markers = "sys_platform == 'linux'"}
arize-phoenix = {optional = true, extras = ["experimental"], version = "^0.0.49"}
python-dotenv = "^1.0.0"
pyyaml = "^6.0.1"
rich = "^13.6.0"

[tool.poetry.extras]
tracing = ["arize-phoenix"]
rag = ["chromadb", "pysqlite3-binary", "openai"]

[tool.poetry.dev-dependencies]
pytest = "^7.1.0"
pre-commit = "^2.8.2"
Expand Down
8 changes: 4 additions & 4 deletions src/nbwrite/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class GenerationConfig(BaseModel):

class Config(BaseModel):
task: str
steps: List[str]
packages: List[str]
out: str
generation: GenerationConfig
steps: List[str] = []
packages: List[str] = []
out: str = "nbwrite-out"
generation: GenerationConfig = GenerationConfig()
14 changes: 6 additions & 8 deletions src/nbwrite/index.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import importlib.util
import os
import platform
from pathlib import Path
from typing import Any, Dict, List
Expand All @@ -10,13 +9,6 @@
from langchain.text_splitter import Language, RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma

if platform.system() == "Linux":
# https://docs.trychroma.com/troubleshooting#sqlite
__import__("pysqlite3")
import sys

sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")


def get_embeddings():
return OpenAIEmbeddings(disallowed_special=())
Expand All @@ -27,6 +19,12 @@ def create_index(
retriever_kwargs: Dict[str, Any],
text_splitter_kwargs: Dict[str, Any],
):
if platform.system() == "Linux":
# https://docs.trychroma.com/troubleshooting#sqlite
__import__("pysqlite3")
import sys

sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")

python_splitter = RecursiveCharacterTextSplitter.from_language(
language=Language.PYTHON, **text_splitter_kwargs
Expand Down
57 changes: 37 additions & 20 deletions src/nbwrite/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from langchain.schema import format_document
from langchain.schema.messages import SystemMessage
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnableParallel
from langchain.schema.runnable import RunnableLambda, RunnableParallel
from nbformat.v4 import (
new_code_cell,
new_markdown_cell,
Expand All @@ -42,13 +42,19 @@ def gen(

if os.getenv("NBWRITE_PHOENIX_TRACE"):
click.echo("Enabling Phoenix Trace")
from phoenix.trace.langchain import (
LangChainInstrumentor,
OpenInferenceTracer,
)

tracer = OpenInferenceTracer()
LangChainInstrumentor(tracer).instrument()
try:
from phoenix.trace.langchain import (
LangChainInstrumentor,
OpenInferenceTracer,
)

tracer = OpenInferenceTracer()
LangChainInstrumentor(tracer).instrument()
except ModuleNotFoundError:
click.echo(
"In order to use Phoenix Tracing you must `pip install 'nbwrite[tracing]'"
)
exit(1)

prompt = ChatPromptTemplate.from_messages(
[
Expand All @@ -57,22 +63,26 @@ def gen(
]
)

retriever = create_index(
config.packages,
config.generation.retriever_kwargs,
config.generation.text_splitter_kwargs,
)

def _combine_documents(
docs, document_prompt=PromptTemplate.from_template(template="{page_content}"), document_separator="\n\n" # type: ignore
):
doc_strings = [format_document(doc, document_prompt) for doc in docs]
return document_separator.join(doc_strings)
if len(config.packages) > 0:
try:
retriever = create_index(
config.packages,
config.generation.retriever_kwargs,
config.generation.text_splitter_kwargs,
)
context_chain = itemgetter("task") | retriever | _combine_documents
except ModuleNotFoundError:
click.echo(
"In order to use `packages`, you must `pip install 'nbwrite[rag]'`"
)
exit(1)
else:
context_chain = RunnableLambda(lambda _: "none")

llm = get_llm(**config.generation.llm_kwargs) | StrOutputParser()
chain = (
{
"context": itemgetter("task") | retriever | _combine_documents,
"context": context_chain,
"task": itemgetter("task"),
"steps": itemgetter("steps"),
"packages": itemgetter("packages"),
Expand Down Expand Up @@ -114,3 +124,10 @@ def _combine_documents(
click.echo(f"Wrote notebook to {filename}")
except Exception as e:
logger.error(f"Error writing notebook (generation {generation}): {e}")


def _combine_documents(
docs, document_prompt=PromptTemplate.from_template(template="{page_content}"), document_separator="\n\n" # type: ignore
):
doc_strings = [format_document(doc, document_prompt) for doc in docs]
return document_separator.join(doc_strings)
2 changes: 2 additions & 0 deletions tests/resources/nbwrite-in/minimal.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
task: |
Plot the iris dataset using pandas
30 changes: 30 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,33 @@ def test_cli(tmpdir: local):
assert "Hello, world!" in "".join(
output.text for output in cell.outputs
)


def test_cli_minimal(tmpdir: local):

if os.getenv("NBWRITE_DEBUG_MODE"):
outdir = "test-debug-out"
[pp.unlink() for pp in Path(outdir).glob("*.ipynb")]
else:
outdir = str(tmpdir)
runner = CliRunner()
args = [
"tests/resources/nbwrite-in/minimal.yaml",
"--out",
outdir,
]

shell_fmt = " \\\n ".join(["nbwrite", *args])
logger.warn(f"Running\n{shell_fmt}")

with patch("nbwrite.writer.get_llm") as mock_get_llm:
mock_get_llm.return_value = FakeListLLM(
responses=["Code:\n```python\nprint('Hello, world!')\n```\n"]
)
result = runner.invoke(cli, args)

assert result.exit_code == 0

logger.warn(f"Checking outputs in {outdir}")
outputs = list(Path(outdir).glob("*.ipynb"))
assert len(outputs) == 2

0 comments on commit b08c3d6

Please sign in to comment.