feat(llm): extend node from llm
MorvanZhou committed Jul 10, 2024
1 parent c0e0be6 commit 71ae7c1
Showing 24 changed files with 404 additions and 59 deletions.
2 changes: 2 additions & 0 deletions src/retk/application.py
@@ -18,6 +18,7 @@
manager,
statistic,
notice,
ai,
)
from .routes.utils import on_shutdown, on_startup

@@ -59,6 +60,7 @@ async def lifespan(app: FastAPI):
manager,
statistic,
notice,
ai,
]:
app.include_router(r.router)

Empty file.
35 changes: 35 additions & 0 deletions src/retk/controllers/ai/knowledge.py
@@ -0,0 +1,35 @@
from retk import core
from retk.controllers import schemas
from retk.controllers.node.node_ops import get_node_data
from retk.controllers.utils import maybe_raise_json_exception
from retk.models.tps import AuthedUser


async def get_extended_nodes(
au: AuthedUser,
) -> schemas.ai.GetExtendedNodesResponse:
docs = await core.ai.llm.knowledge.extended.get_extended_nodes(uid=au.u.id)
return schemas.ai.GetExtendedNodesResponse(
requestId=au.request_id,
nodes=[schemas.ai.GetExtendedNodesResponse.Node(
id=str(doc["_id"]),
sourceNid=doc["sourceNid"],
sourceTitle=doc["sourceMd"].split("\n", 1)[0].strip(),
md=doc["extendMd"],
) for doc in docs]
)


async def accept_extended_node(
au: AuthedUser,
eid: str,
) -> schemas.node.NodeResponse:
n, code = await core.ai.llm.knowledge.extended.accept_extended_node(
au=au,
eid=eid,
)
maybe_raise_json_exception(au=au, code=code)
return schemas.node.NodeResponse(
requestId=au.request_id,
node=get_node_data(n),
)
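
The two controller functions above are thin wrappers, so the ai router that application.py now includes presumably just maps HTTP endpoints onto them. The sketch below is not the route module from this commit; the path names, prefix, and auth dependency are assumptions made only to show how the wiring could look.

from fastapi import APIRouter, Depends

from retk.controllers import schemas
from retk.controllers.ai import knowledge as ai_knowledge
from retk.models.tps import AuthedUser

router = APIRouter(prefix="/api/ai", tags=["ai"])  # prefix and tags are assumptions


async def authed_user() -> AuthedUser:
    # Placeholder for the project's real auth dependency (hypothetical).
    raise NotImplementedError


@router.get("/extended-nodes", response_model=schemas.ai.GetExtendedNodesResponse)
async def get_extended_nodes(au: AuthedUser = Depends(authed_user)):
    return await ai_knowledge.get_extended_nodes(au=au)


@router.post("/extended-nodes/{eid}/accept", response_model=schemas.node.NodeResponse)
async def accept_extended_node(eid: str, au: AuthedUser = Depends(authed_user)):
    return await ai_knowledge.accept_extended_node(au=au, eid=eid)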
8 changes: 4 additions & 4 deletions src/retk/controllers/node/node_ops.py
@@ -7,7 +7,7 @@
from retk.utils import contain_only_http_link, get_title_description_from_link, datetime2str


def __get_node_data(n: Node) -> schemas.node.NodeData:
def get_node_data(n: Node) -> schemas.node.NodeData:
from_nodes: List[schemas.node.NodeData.LinkedNode] = []
to_nodes: List[schemas.node.NodeData.LinkedNode] = []
for nodes, n_nodes in zip(
@@ -62,7 +62,7 @@ async def post_node(
)
return schemas.node.NodeResponse(
requestId=au.request_id,
node=__get_node_data(n),
node=get_node_data(n),
)


@@ -102,7 +102,7 @@ async def get_node(

return schemas.node.NodeResponse(
requestId=au.request_id,
node=__get_node_data(n),
node=get_node_data(n),
)


@@ -120,7 +120,7 @@ async def update_md(

return schemas.node.NodeResponse(
requestId=au.request_id,
node=__get_node_data(n),
node=get_node_data(n),
)


1 change: 1 addition & 0 deletions src/retk/controllers/schemas/__init__.py
@@ -12,6 +12,7 @@
manager,
statistic,
notice,
ai,
)


14 changes: 14 additions & 0 deletions src/retk/controllers/schemas/ai.py
@@ -0,0 +1,14 @@
from typing import List

from pydantic import BaseModel


class GetExtendedNodesResponse(BaseModel):
class Node(BaseModel):
id: str
sourceNid: str
sourceTitle: str
md: str

requestId: str
nodes: List[Node]
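
For orientation, a minimal sketch of the payload this model produces when the GET endpoint serializes it. All field values below are illustrative only, and model_dump_json assumes pydantic v2 (the v1 equivalent is .json()).

from retk.controllers.schemas.ai import GetExtendedNodesResponse

resp = GetExtendedNodesResponse(
    requestId="req-0123456789",  # echoed from AuthedUser.request_id
    nodes=[
        GetExtendedNodesResponse.Node(
            id="6690a1b2c3d4e5f6a7b8c9d0",   # str(ObjectId) of the stored llmExtendedNode doc
            sourceNid="nid-abc",              # node the extension was generated from
            sourceTitle="My note title",      # first line of the source markdown
            md="LLM-extended markdown body",  # extendMd from the stored document
        )
    ],
)
print(resp.model_dump_json(indent=2))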
1 change: 1 addition & 0 deletions src/retk/core/__init__.py
@@ -9,4 +9,5 @@
self_hosted,
notice,
analysis,
ai,
)
3 changes: 2 additions & 1 deletion src/retk/core/ai/llm/knowledge/__init__.py
@@ -2,7 +2,8 @@
from typing import Tuple

from retk import const
from .db_ops import extend_on_node_update, extend_on_node_post, LLM_SERVICES
from . import extended
from .extending import extend_on_node_update, extend_on_node_post, LLM_SERVICES
from ..api.base import BaseLLMService, MessagesType

system_summary_prompt = (Path(__file__).parent / "system_summary.md").read_text(encoding="utf-8")
45 changes: 45 additions & 0 deletions src/retk/core/ai/llm/knowledge/extended.py
@@ -0,0 +1,45 @@
from typing import List, Tuple

from bson import ObjectId

from retk import core
from retk.config import is_local_db
from retk.const import CodeEnum
from retk.models.client import client
from retk.models.tps import AuthedUser
from retk.models.tps.llm import ExtendedNode
from retk.models.tps.node import Node
from retk.utils import get_at_node_md_link


async def get_extended_nodes(
uid: str,
) -> List[ExtendedNode]:
docs = await client.coll.llm_extended_node.find({"uid": uid}).to_list(None)
return docs


async def accept_extended_node(
au: AuthedUser,
eid: str,
) -> Tuple[Node, CodeEnum]:
if not is_local_db():
doc = await client.coll.llm_extended_node.find_one_and_delete(
{"_id": ObjectId(eid)},
)
else:
doc = await client.coll.llm_extended_node.find_one(
{"_id": ObjectId(eid)},
)
await client.coll.llm_extended_node.delete_one(
{"_id": ObjectId(eid)},
)
title = doc["sourceMd"].split("\n", 1)[0].strip()
at_node = get_at_node_md_link(title, doc["sourceNid"])
md = doc["extendMd"] + "\n\n" + at_node
n, code = await core.node.post(
au=au,
md=md,
from_nid=doc["sourceNid"],
)
return n, code
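
get_at_node_md_link itself is not part of this diff. Inferring from the Obsidian importer hunk further down, where the inline f-string [@{filename}](/n/{nid}) is replaced by this helper, it most likely builds the @-mention markdown link appended here. A hedged sketch of that assumed behavior:

def get_at_node_md_link(title: str, nid: str) -> str:
    # Assumed implementation, reconstructed from the f-string it replaces in
    # src/retk/core/files/importing/async_tasks/obsidian/ops.py.
    return f"[@{title}](/n/{nid})"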
@@ -1,6 +1,7 @@
from datetime import timedelta

from bson import ObjectId
from bson.tz_util import utc

from retk.models.client import client
from retk.models.tps.llm import NodeExtendQueue
@@ -26,42 +27,45 @@ async def extend_on_node_post(data: Node):
_id=ObjectId(),
uid=data["uid"],
nid=data["id"],
modifiedAt=int(data["modifiedAt"].replace(tzinfo=utc).timestamp()),
summaryService="tencent",
summaryModel=api.TencentModelEnum.HUNYUAN_LITE.value,
extendService="ali",
extendModel=api.AliyunModelEnum.QWEN_PLUS.value,
extendService="tencent",
extendModel=api.TencentModelEnum.HUNYUAN_LITE.value,
)

# sort by _id desc
docs = await client.coll.llm_extend_node_queue.find(
filter={"uid": data["uid"]}
).sort("_id", -1).to_list(None)
).sort("modifiedAt", -1).to_list(None)
has_q = False
for doc in docs:
if doc["nid"] == data["id"]:
has_q = True
q["_id"] = doc["_id"]
# renew the creating time
await client.coll.llm_extend_knowledge_queue.update_one(
await client.coll.llm_extend_node_queue.update_one(
filter={"_id": doc["_id"]},
update={"_id": q["_id"]},
update={"$set": {"modifiedAt": q["modifiedAt"]}},
)
break

max_keep = 5
if not has_q:
# this is a new node in queue
if len(docs) >= max_keep:
# remove the oldest and only keep the latest 5
await client.coll.llm_extend_node_queue.delete_many(
{"_id": {"$in": [doc["_id"] for doc in docs[max_keep:]]}}
)

await client.coll.llm_extend_node_queue.insert_one(q)


async def extend_on_node_update(old_data: Node, new_data: Node):
async def extend_on_node_update(
old_data: Node,
new_data: Node,
cooling_time: int = 60,
):
# filter out frequent updates
if new_data["modifiedAt"] - old_data["modifiedAt"] < timedelta(seconds=60):
if new_data["modifiedAt"] - old_data["modifiedAt"] < timedelta(seconds=cooling_time):
return

await extend_on_node_post(new_data)
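
A quick illustration of the new cooling_time debounce, assuming modifiedAt is a timezone-aware datetime (consistent with the timedelta comparison and the replace(tzinfo=utc) call above); only the field relevant to the check is shown.

import datetime

from bson.tz_util import utc

t0 = datetime.datetime(2024, 7, 10, 12, 0, 0, tzinfo=utc)
old_data = {"modifiedAt": t0}

# Edited again 30 s later: below the default 60 s window, so
# extend_on_node_update returns early and nothing is re-queued.
fast_edit = {"modifiedAt": t0 + datetime.timedelta(seconds=30)}
assert fast_edit["modifiedAt"] - old_data["modifiedAt"] < datetime.timedelta(seconds=60)

# Edited 90 s later: past the window, so the node goes back into
# llmExtendNodeQueue via extend_on_node_post.
slow_edit = {"modifiedAt": t0 + datetime.timedelta(seconds=90)}
assert slow_edit["modifiedAt"] - old_data["modifiedAt"] >= datetime.timedelta(seconds=60)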
4 changes: 2 additions & 2 deletions src/retk/core/files/importing/async_tasks/obsidian/ops.py
@@ -8,7 +8,7 @@

from retk import regex, const
from retk.core.files.saver import saver, File
from retk.utils import short_uuid
from retk.utils import short_uuid, get_at_node_md_link


@dataclass
@@ -142,7 +142,7 @@ async def replace_inner_link_and_upload(
except KeyError:
nid = short_uuid()
exist_path2nid[path] = nid
md = f"{md[: span[0]]}[@{filename}](/n/{nid}){md[span[1]:]}"
md = f"{md[: span[0]]}{get_at_node_md_link(filename, nid)}{md[span[1]:]}"
return md


2 changes: 1 addition & 1 deletion src/retk/core/node/node.py
@@ -57,7 +57,7 @@ async def post(
type_=type_,
disabled=False,
in_trash=False,
modified_at=_id.generation_time,
modified_at=datetime.datetime.now(tz=utc),
in_trash_at=None,
from_node_ids=from_nids,
to_node_ids=new_to_node_ids,
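
For context on the modified_at change above: ObjectId.generation_time is the creation timestamp embedded in the ObjectId and only has whole-second resolution, whereas datetime.datetime.now(tz=utc) captures the current time with sub-second precision. A small standalone illustration; the printed values in the comments are examples only.

import datetime

from bson import ObjectId
from bson.tz_util import utc

_id = ObjectId()
print(_id.generation_time)            # e.g. 2024-07-10 07:21:05+00:00 (whole seconds, fixed at id creation)
print(datetime.datetime.now(tz=utc))  # e.g. 2024-07-10 07:21:05.312456+00:00 (actual current time)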
1 change: 1 addition & 0 deletions src/retk/core/scheduler/schedule.py
@@ -92,6 +92,7 @@ def init_tasks():
func=tasks.notice.deliver_unscheduled_system_notices,
second=0,
)

# check unscheduled extend node every hour
run_every_at(
job_id="deliver_unscheduled_node_extend",
57 changes: 40 additions & 17 deletions src/retk/core/scheduler/tasks/extend_node.py
@@ -1,13 +1,13 @@
import asyncio
import random
import time
from typing import List

from bson import ObjectId

from retk import const
from retk.core.ai.llm import knowledge
from retk.logger import logger
from retk.models.client import init_mongo
from retk.models.coll import CollNameEnum
from retk.models.tps.llm import NodeExtendQueue, ExtendedNode


@@ -22,45 +22,68 @@ def deliver_unscheduled_extend_nodes():
async def async_deliver_unscheduled_extend_nodes() -> str:
_, db = init_mongo(connection_timeout=5)
batch_size = 3
total_knowledge_extended = 0
total_success_count = 0
total_summary_time = 0
total_extend_time = 0
while True:
batch: List[NodeExtendQueue] = await db["llmExtendNodeQueue"].find().limit(batch_size).to_list(None)
done_id_list = []
batch: List[NodeExtendQueue] = await db[CollNameEnum.llm_extend_node_queue.value].find().limit(
batch_size).to_list(None)
if len(batch) == 0:
break

batch_result: List[ExtendedNode] = []
for item in batch:
req_id = "".join([str(random.randint(0, 9)) for _ in range(10)])
md = await db["node"].find_one({"id": item["nid"]})
node = await db[CollNameEnum.nodes.value].find_one({"id": item["nid"]})
# md = md[:int(8000 * 1.8)]
s0 = time.perf_counter()
_summary, code = await knowledge.summary(
llm_service=knowledge.LLM_SERVICES[item["summaryService"]],
model=item["summaryModel"],
md=md,
md=node["md"],
req_id=req_id,
)
s1 = time.perf_counter()
if code != const.CodeEnum.OK:
logger.error(f"knowledge summary error: {code}")
continue
logger.debug(f"summary: {_summary}")
e0 = time.perf_counter()
_extended, code = await knowledge.extend(
llm_service=knowledge.LLM_SERVICES[item["extendService"]],
model=item["extendModel"],
md=md,
md=_summary,
req_id=req_id,
)
e1 = time.perf_counter()
if code != const.CodeEnum.OK:
logger.error(f"knowledge extend error: {code}")
continue
batch_result.append(ExtendedNode(
_id=ObjectId(),
logger.debug(f"extended: {_extended}")
ext = ExtendedNode(
uid=item["uid"],
sourceNids=[item["nid"]],
sourceMd=[md],
sourceNid=item["nid"],
sourceMd=node["md"],
extendMd=_extended,
))
total_knowledge_extended += 1
)
await db[CollNameEnum.llm_extended_node.value].update_one(
{"uid": item["uid"], "sourceNid": item["nid"]},
{"$set": ext},
upsert=True
)
done_id_list.append(item["_id"])
total_summary_time += s1 - s0
total_extend_time += e1 - e0

if len(done_id_list) > 0:
res = await db[CollNameEnum.llm_extend_node_queue.value].delete_many({"_id": {"$in": done_id_list}})
total_success_count += res.deleted_count

if len(batch_result) > 0:
await db["llmExtendedNode"].insert_many(batch_result)
if total_success_count > 0:
logger.info(
f"llm extend knowledge task: "
f"avg_summary_time: {total_summary_time / total_success_count:.2f}s, "
f"avg_extend_time: {total_extend_time / total_success_count:.2f}s"
)

return f"successfully extent {total_knowledge_extended} node"
return f"successfully extent {len(done_id_list)} node"