
Commit f46de6f

change create model function
lrbmike committed Aug 10, 2024
1 parent be89027 commit f46de6f
Showing 3 changed files with 52 additions and 83 deletions.
38 changes: 32 additions & 6 deletions README.md
@@ -8,6 +8,8 @@ Project base on [ScrapeGraphAI](https://github.com/ScrapeGraphAI/Scrapegraph-ai.
pip install -r requirements.txt
# Browser driver install
playwright install
# If you see "ImportError: burr package is not installed. Please install it with 'pip install scrapegraphai[burr]'", run:
pip install "scrapegraphai[burr]"
```

## Environment
@@ -44,8 +46,7 @@ curl -X POST https://your-domain/crawl/scraper_graph \
"url": "https://techcrunch.com/category/artificial-intelligence/",
"model_provider": "google_genai",
"model_name": "gemini-1.5-flash-latest",
"temperature": 0,
"model_instance": true
"temperature": 0
}'

```
@@ -65,7 +66,7 @@ curl -X POST https://your-domain/crawl/search_graph \

### OpenAI Model

You need to set `API_KEY` or `API_BASE_URL` in `.env` file first. If you set `API_BASE_URL` , it will be configured into the Gemini model.
You need to set `API_KEY` or `API_BASE_URL` in the `.env` file first. If you set `API_BASE_URL`, it is used as the base URL for the OpenAI model.
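
For reference, a minimal `.env` sketch for this provider (placeholder values only), using the variable names read in `app/modules/scrapegraphai.py`:

```shell
# Read by the service for OpenAI-compatible providers (model_provider "openai")
API_KEY=your-api-key
# Optional: custom OpenAI-compatible endpoint; passed through as base_url
API_BASE_URL=https://your-openai-compatible-endpoint/v1
```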

#### scraper graph

@@ -77,8 +78,7 @@ curl -X POST https://your-domain/crawl/scraper_graph \
"url": "https://techcrunch.com/category/artificial-intelligence/",
"model_provider": "openai",
"model_name": "gpt-4o-mini",
"temperature": 0,
"model_instance": false
"temperature": 0
}'
```

@@ -95,7 +95,34 @@ curl -X POST https://your-domain/crawl/search_graph \
}'
```

### Ollama

#### scraper graph

```shell
curl -X POST https://your-domain/crawl/scraper_graph \
-H "Content-Type: application/json" \
-d '{
"prompt": "List me all the articles with their title、description、link、published",
"url": "https://next.ithome.com/",
"model_provider": "ollama",
"model_name": "ollama/llama3.1",
"temperature": 0
}'
```

#### search graph

```shell
curl -X POST https://your-domain/crawl/search_graph \
-H "Content-Type: application/json" \
-d '{
"prompt": "List me all the traditional recipes from Chioggia",
"model_provider": "ollama",
"model_name": "ollama/llama3.1",
"temperature": 0
}'
```
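
A quick sketch of the local setup this assumes (Ollama installed and the model pulled); by default Ollama listens on `http://localhost:11434`, which matches the commented-out `base_url` in `app/modules/scrapegraphai.py`:

```shell
# Pull the model used above, assuming the API's model_name "ollama/llama3.1"
# corresponds to the local Ollama tag "llama3.1"
ollama pull llama3.1
# Start the Ollama server if it is not already running as a service
ollama serve
```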

## Docker

@@ -106,4 +133,3 @@ Or you can publish to [Render](https://render.com/)
## Known issues
> The current support for models is not perfect, and there are quite a few such problems in [Scrapegraph-ai](https://github.com/ScrapeGraphAI/Scrapegraph-ai/issues).
- When using the Search Graph method, you can't initialize the model using the `model_instance`
87 changes: 16 additions & 71 deletions app/modules/scrapegraphai.py
@@ -1,6 +1,4 @@
from scrapegraphai.graphs import SmartScraperGraph, SearchGraph
from scrapegraphai.helpers import models_tokens
from langchain.chat_models import init_chat_model
import os
import asyncio
from concurrent.futures import ThreadPoolExecutor
@@ -23,13 +21,11 @@ def __init__(
self,
model_provider: str,
model_name: str,
temperature: float = 0,
model_instance: bool = False
temperature: float = 0
):
self.model_provider = model_provider
self.model_name = model_name
self.temperature = temperature
self.model_instance = model_instance

self.model_tokens = 128000
self.graph_config = self.create_model_instance_config()
@@ -62,37 +58,6 @@ async def search(
result = search_graph.run()
return result

def create_llm(
self
):
# special treatment of Gemini models
if self.model_provider == "google_genai":
if models_tokens["gemini"][self.model_name]:
self.model_tokens = models_tokens["gemini"][self.model_name]

# use api endpoint
if os.getenv("GOOGLE_API_ENDPOINT"):

return init_chat_model(
self.model_name, model_provider=self.model_provider, temperature=self.temperature,
api_key=os.environ["GOOGLE_API_KEY"], transport="rest",
client_options={"api_endpoint": os.environ['GOOGLE_API_ENDPOINT']}
)

return init_chat_model(
self.model_name, model_provider=self.model_provider, temperature=self.temperature,
api_key=os.environ["GOOGLE_API_KEY"]
)

else:
if models_tokens[self.model_provider][self.model_name]:
self.model_tokens = models_tokens[self.model_provider][self.model_name]

return init_chat_model(
self.model_name, model_provider=self.model_provider, temperature=self.temperature,
api_key=os.environ["API_KEY"], base_url=os.environ["API_BASE_URL"]
)

def create_model_instance_config(
self
):
@@ -101,43 +66,23 @@ def create_model_instance_config(
"verbose": True,
}

if self.model_instance:
llm = self.create_llm()

if self.model_provider == "google_genai":
graph_config["llm"] = {
"model_instance": llm,
"model_tokens": self.model_tokens,
"model": self.model_name,
"api_key": os.environ["GOOGLE_API_KEY"]
}
if os.getenv("GOOGLE_API_ENDPOINT"):
graph_config["llm"]["transport"] = "rest"
graph_config["llm"]["client_options"] = {"api_endpoint": os.environ['GOOGLE_API_ENDPOINT']}

else:
graph_config["llm"] = {
"model": self.model_name,
"api_key": os.getenv("API_KEY")
}
if os.getenv("API_BASE_URL"):
graph_config["llm"]["base_url"] = os.getenv("API_BASE_URL")

return graph_config


# special treatment of Gemini models
if self.model_provider == "google_genai":
# use api endpoint
if os.getenv("GOOGLE_API_ENDPOINT"):
graph_config["llm"] = {
"model": self.model_name,
"api_key": os.environ["GOOGLE_API_KEY"],
"transport": "rest",
"client_options": {"api_endpoint": os.environ['GOOGLE_API_ENDPOINT']}
}
else:
graph_config["llm"] = {
"model": self.model_name,
"api_key": os.environ["GOOGLE_API_KEY"]
}
# special treatment of local models
elif self.model_provider == "ollama":
graph_config["llm"] = {
"model": self.model_name,
"format": "json", # Ollama needs the format to be specified explicitly
# "base_url": "http://localhost:11434",
}
else:
graph_config["llm"] = {
"model": self.model_name,
"api_key": os.getenv("API_KEY"),
"base_url": os.getenv("API_BASE_URL")
}

return graph_config
10 changes: 4 additions & 6 deletions app/routers/crawl.py
@@ -13,11 +13,10 @@ async def scraper_graph(
url: str = Body(embed=True),
model_provider: str = Body(embed=True),
model_name: str = Body(embed=True),
temperature: Optional[float] = Body(embed=False),
model_instance: Optional[bool] = Body(embed=False),
temperature: Optional[float] = Body(embed=False)
):
engine = ScrapeGraphAiEngine(model_provider=model_provider, model_name=model_name,
temperature=temperature, model_instance=model_instance)
temperature=temperature)

return await engine.crawl(prompt=prompt, source=url)

@@ -27,10 +26,9 @@ async def search_graph(
prompt: str = Body(embed=True),
model_provider: str = Body(embed=True),
model_name: str = Body(embed=True),
temperature: float = Body(embed=False),
model_instance: bool = Body(embed=False),
temperature: float = Body(embed=False)
):
engine = ScrapeGraphAiEngine(model_provider=model_provider, model_name=model_name,
temperature=temperature, model_instance=model_instance)
temperature=temperature)

return await engine.search(prompt=prompt)
