Skip to content

Commit

Permalink
Update scrapegraphai to v1.12.2.
Browse files Browse the repository at this point in the history
Use LangChain's init_chat_model function to create the LLM.
Rewrite the create_model_instance_config function.
  • Loading branch information
lrbmike committed Aug 9, 2024
1 parent 34ae59a commit 10c0e8e
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 49 deletions.
7 changes: 4 additions & 3 deletions .env
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
GOOGLE_API_KEY=xxxx
OPENAI_API_KEY=xxxx
OPENAI_API_BASE=xxxx
GOOGLE_API_KEY=
GOOGLE_API_ENDPOINT=
API_KEY=
API_BASE_URL=
4 changes: 1 addition & 3 deletions app/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from fastapi import FastAPI
# 路由

from app.routers import crawl

import sys
Expand All @@ -8,10 +8,8 @@
from dotenv import load_dotenv, find_dotenv
load_dotenv(find_dotenv())

# 防止相对路径导入出错
sys.path.append(os.path.join(os.path.dirname(__file__)))

app = FastAPI()

# 将其余单独模块进行整合
app.include_router(crawl.router)
61 changes: 29 additions & 32 deletions app/modules/scrapegraphai.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from scrapegraphai.graphs import SmartScraperGraph, SearchGraph
from scrapegraphai.helpers import models_tokens
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain.chat_models import init_chat_model
import os
import asyncio
from concurrent.futures import ThreadPoolExecutor
Expand All @@ -15,18 +14,19 @@ async def run_blocking_code_in_thread(blocking_func, *args):


class ScrapeGraphAiEngine:

"""
You can find the valid model_provider values on the LangChain website:
https://python.langchain.com/v0.2/docs/how_to/chat_models_universal_init/#inferring-model-provider
"""
def __init__(
self,
llm_name: str,
model_provider: str,
model_name: str,
embeddings_name: str,
temperature: float = 0,
model_instance: bool = False
):
self.llm_name = llm_name
self.model_provider = model_provider
self.model_name = model_name
self.embeddings_name = embeddings_name
self.temperature = temperature
self.model_instance = model_instance

Expand Down Expand Up @@ -62,46 +62,46 @@ async def search(
def create_llm(
self
):
if self.llm_name == "Gemini":
# special treatment of Gemini models
if self.model_provider == "google_genai":
if models_tokens["gemini"][self.model_name]:
self.model_tokens = models_tokens["gemini"][self.model_name]

return ChatGoogleGenerativeAI(model=self.model_name, temperature=self.temperature,
google_api_key=os.getenv('GOOGLE_API_KEY'))
# use api endpoint
if os.getenv("GOOGLE_API_ENDPOINT"):

elif self.llm_name == "OpenAI":
if models_tokens["openai"][self.model_name]:
self.model_tokens = models_tokens["openai"][self.model_name]
return init_chat_model(
self.model_name, model_provider=self.model_provider, temperature=self.temperature,
api_key=os.environ["GOOGLE_API_KEY"], transport="rest",
client_options={"api_endpoint": os.environ['GOOGLE_API_ENDPOINT']}
)

return ChatOpenAI(model=self.model_name, temperature=self.temperature,
api_key=os.getenv("OPENAI_API_KEY"), base_url=os.getenv("OPENAI_API_BASE"))
return init_chat_model(
self.model_name, model_provider=self.model_provider, temperature=self.temperature,
api_key=os.environ["GOOGLE_API_KEY"]
)

def create_embeddings(
self
):
if self.llm_name == "Gemini":
return GoogleGenerativeAIEmbeddings(model=self.embeddings_name,
google_api_key=os.environ['GOOGLE_API_KEY'])
else:
if models_tokens[self.model_provider][self.model_name]:
self.model_tokens = models_tokens[self.model_provider][self.model_name]

elif self.llm_name == "OpenAI":
return OpenAIEmbeddings(api_key=os.getenv("OPENAI_API_KEY"), base_url=os.getenv("OPENAI_API_BASE"))
return init_chat_model(
self.model_name, model_provider=self.model_provider, temperature=self.temperature,
api_key=os.environ["API_KEY"], base_url=os.environ["API_BASE_URL"]
)

def create_model_instance_config(
self
):

if self.model_instance:
llm = self.create_llm()
embeddings = self.create_embeddings()

graph_config = {
"llm": {
"model_instance": llm,
"model_tokens": self.model_tokens,
},
"embeddings": {
"model_instance": embeddings,
},
"verbose": True,
}

Expand All @@ -112,11 +112,8 @@ def create_model_instance_config(
graph_config = {
"llm": {
"model": self.model_name,
"api_key": os.getenv("OPENAI_API_KEY"),
"base_url": os.getenv("OPENAI_API_BASE")
},
"embeddings": {
"model": self.embeddings_name,
"api_key": os.getenv("API_KEY"),
"base_url": os.getenv("API_BASE_URL")
},
"verbose": True,
}
Expand Down
16 changes: 6 additions & 10 deletions app/routers/crawl.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,30 +11,26 @@
async def scraper_graph(
prompt: str = Body(embed=True),
url: str = Body(embed=True),
llm_name: str = Body(embed=True),
model_provider: str = Body(embed=True),
model_name: str = Body(embed=True),
embeddings_name: str = Body(embed=True),
temperature: Optional[float] = Body(embed=False),
model_instance: Optional[bool] = Body(embed=False),
):
engine = ScrapeGraphAiEngine(llm_name=llm_name, model_name=model_name,
embeddings_name=embeddings_name, temperature=temperature,
model_instance=model_instance)
engine = ScrapeGraphAiEngine(model_provider=model_provider, model_name=model_name,
temperature=temperature, model_instance=model_instance)

return await engine.crawl(prompt=prompt, source=url)


@router.post("/search_graph")
async def search_graph(
prompt: str = Body(embed=True),
llm_name: str = Body(embed=True),
model_provider: str = Body(embed=True),
model_name: str = Body(embed=True),
embeddings_name: str = Body(embed=True),
temperature: float = Body(embed=False),
model_instance: bool = Body(embed=False),
):
engine = ScrapeGraphAiEngine(llm_name=llm_name, model_name=model_name,
embeddings_name=embeddings_name, temperature=temperature,
model_instance=model_instance)
engine = ScrapeGraphAiEngine(model_provider=model_provider, model_name=model_name,
temperature=temperature, model_instance=model_instance)

return await engine.search(prompt=prompt)
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
scrapegraphai==1.11.3
scrapegraphai==1.12.2
fastapi==0.111.1
python-dotenv==1.0.1
nest-asyncio==1.6.0

0 comments on commit 10c0e8e

Please sign in to comment.