From 71ae7c14f20dbf4f84107133a931284108e91069 Mon Sep 17 00:00:00 2001 From: morvanzhou Date: Wed, 10 Jul 2024 21:27:05 +0800 Subject: [PATCH] feat(llm): extend node from llm --- src/retk/application.py | 2 + src/retk/controllers/ai/__init__.py | 0 src/retk/controllers/ai/knowledge.py | 35 ++++++++++++ src/retk/controllers/node/node_ops.py | 8 +-- src/retk/controllers/schemas/__init__.py | 1 + src/retk/controllers/schemas/ai.py | 14 +++++ src/retk/core/__init__.py | 1 + src/retk/core/ai/llm/knowledge/__init__.py | 3 +- src/retk/core/ai/llm/knowledge/extended.py | 45 +++++++++++++++ .../llm/knowledge/{db_ops.py => extending.py} | 24 ++++---- .../importing/async_tasks/obsidian/ops.py | 4 +- src/retk/core/node/node.py | 2 +- src/retk/core/scheduler/schedule.py | 1 + src/retk/core/scheduler/tasks/extend_node.py | 57 +++++++++++++------ src/retk/core/scheduler/tasks/notice.py | 19 ++++--- src/retk/models/client.py | 20 +++---- src/retk/models/coll.py | 16 ++++++ src/retk/models/indexing.py | 12 +++- src/retk/models/tps/llm.py | 7 ++- src/retk/routes/ai.py | 44 ++++++++++++++ src/retk/utils.py | 6 +- tests/test_api.py | 53 +++++++++++++++++ tests/test_core_local.py | 36 ++++++++++++ tests/test_core_remote.py | 53 +++++++++++++++++ 24 files changed, 404 insertions(+), 59 deletions(-) create mode 100644 src/retk/controllers/ai/__init__.py create mode 100644 src/retk/controllers/ai/knowledge.py create mode 100644 src/retk/controllers/schemas/ai.py create mode 100644 src/retk/core/ai/llm/knowledge/extended.py rename src/retk/core/ai/llm/knowledge/{db_ops.py => extending.py} (77%) create mode 100644 src/retk/routes/ai.py diff --git a/src/retk/application.py b/src/retk/application.py index 0223b6d..1f97dc8 100644 --- a/src/retk/application.py +++ b/src/retk/application.py @@ -18,6 +18,7 @@ manager, statistic, notice, + ai, ) from .routes.utils import on_shutdown, on_startup @@ -59,6 +60,7 @@ async def lifespan(app: FastAPI): manager, statistic, notice, + ai, ]: app.include_router(r.router) diff --git a/src/retk/controllers/ai/__init__.py b/src/retk/controllers/ai/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/retk/controllers/ai/knowledge.py b/src/retk/controllers/ai/knowledge.py new file mode 100644 index 0000000..84b8ae5 --- /dev/null +++ b/src/retk/controllers/ai/knowledge.py @@ -0,0 +1,35 @@ +from retk import core +from retk.controllers import schemas +from retk.controllers.node.node_ops import get_node_data +from retk.controllers.utils import maybe_raise_json_exception +from retk.models.tps import AuthedUser + + +async def get_extended_nodes( + au: AuthedUser, +) -> schemas.ai.GetExtendedNodesResponse: + docs = await core.ai.llm.knowledge.extended.get_extended_nodes(uid=au.u.id) + return schemas.ai.GetExtendedNodesResponse( + requestId=au.request_id, + nodes=[schemas.ai.GetExtendedNodesResponse.Node( + id=str(doc["_id"]), + sourceNid=doc["sourceNid"], + sourceTitle=doc["sourceMd"].split("\n", 1)[0].strip(), + md=doc["extendMd"], + ) for doc in docs] + ) + + +async def accept_extended_node( + au: AuthedUser, + eid: str, +) -> schemas.node.NodeResponse: + n, code = await core.ai.llm.knowledge.extended.accept_extended_node( + au=au, + eid=eid, + ) + maybe_raise_json_exception(au=au, code=code) + return schemas.node.NodeResponse( + requestId=au.request_id, + node=get_node_data(n), + ) diff --git a/src/retk/controllers/node/node_ops.py b/src/retk/controllers/node/node_ops.py index 0b0d671..d578861 100644 --- a/src/retk/controllers/node/node_ops.py +++ b/src/retk/controllers/node/node_ops.py @@ -7,7 +7,7 @@ from retk.utils import contain_only_http_link, get_title_description_from_link, datetime2str -def __get_node_data(n: Node) -> schemas.node.NodeData: +def get_node_data(n: Node) -> schemas.node.NodeData: from_nodes: List[schemas.node.NodeData.LinkedNode] = [] to_nodes: List[schemas.node.NodeData.LinkedNode] = [] for nodes, n_nodes in zip( @@ -62,7 +62,7 @@ async def post_node( ) return schemas.node.NodeResponse( requestId=au.request_id, - node=__get_node_data(n), + node=get_node_data(n), ) @@ -102,7 +102,7 @@ async def get_node( return schemas.node.NodeResponse( requestId=au.request_id, - node=__get_node_data(n), + node=get_node_data(n), ) @@ -120,7 +120,7 @@ async def update_md( return schemas.node.NodeResponse( requestId=au.request_id, - node=__get_node_data(n), + node=get_node_data(n), ) diff --git a/src/retk/controllers/schemas/__init__.py b/src/retk/controllers/schemas/__init__.py index e6d0c8f..2e38646 100644 --- a/src/retk/controllers/schemas/__init__.py +++ b/src/retk/controllers/schemas/__init__.py @@ -12,6 +12,7 @@ manager, statistic, notice, + ai, ) diff --git a/src/retk/controllers/schemas/ai.py b/src/retk/controllers/schemas/ai.py new file mode 100644 index 0000000..317998f --- /dev/null +++ b/src/retk/controllers/schemas/ai.py @@ -0,0 +1,14 @@ +from typing import List + +from pydantic import BaseModel + + +class GetExtendedNodesResponse(BaseModel): + class Node(BaseModel): + id: str + sourceNid: str + sourceTitle: str + md: str + + requestId: str + nodes: List[Node] diff --git a/src/retk/core/__init__.py b/src/retk/core/__init__.py index eea0823..a91d235 100644 --- a/src/retk/core/__init__.py +++ b/src/retk/core/__init__.py @@ -9,4 +9,5 @@ self_hosted, notice, analysis, + ai, ) diff --git a/src/retk/core/ai/llm/knowledge/__init__.py b/src/retk/core/ai/llm/knowledge/__init__.py index 0073131..9c81001 100644 --- a/src/retk/core/ai/llm/knowledge/__init__.py +++ b/src/retk/core/ai/llm/knowledge/__init__.py @@ -2,7 +2,8 @@ from typing import Tuple from retk import const -from .db_ops import extend_on_node_update, extend_on_node_post, LLM_SERVICES +from . import extended +from .extending import extend_on_node_update, extend_on_node_post, LLM_SERVICES from ..api.base import BaseLLMService, MessagesType system_summary_prompt = (Path(__file__).parent / "system_summary.md").read_text(encoding="utf-8") diff --git a/src/retk/core/ai/llm/knowledge/extended.py b/src/retk/core/ai/llm/knowledge/extended.py new file mode 100644 index 0000000..1b5630c --- /dev/null +++ b/src/retk/core/ai/llm/knowledge/extended.py @@ -0,0 +1,45 @@ +from typing import List, Tuple + +from bson import ObjectId + +from retk import core +from retk.config import is_local_db +from retk.const import CodeEnum +from retk.models.client import client +from retk.models.tps import AuthedUser +from retk.models.tps.llm import ExtendedNode +from retk.models.tps.node import Node +from retk.utils import get_at_node_md_link + + +async def get_extended_nodes( + uid: str, +) -> List[ExtendedNode]: + docs = await client.coll.llm_extended_node.find({"uid": uid}).to_list(None) + return docs + + +async def accept_extended_node( + au: AuthedUser, + eid: str, +) -> Tuple[Node, CodeEnum]: + if not is_local_db(): + doc = await client.coll.llm_extended_node.find_one_and_delete( + {"_id": ObjectId(eid)}, + ) + else: + doc = await client.coll.llm_extended_node.find_one( + {"_id": ObjectId(eid)}, + ) + await client.coll.llm_extended_node.delete_one( + {"_id": ObjectId(eid)}, + ) + title = doc["sourceMd"].split("\n", 1)[0].strip() + at_node = get_at_node_md_link(title, doc["sourceNid"]) + md = doc["extendMd"] + "\n\n" + at_node + n, code = await core.node.post( + au=au, + md=md, + from_nid=doc["sourceNid"], + ) + return n, code diff --git a/src/retk/core/ai/llm/knowledge/db_ops.py b/src/retk/core/ai/llm/knowledge/extending.py similarity index 77% rename from src/retk/core/ai/llm/knowledge/db_ops.py rename to src/retk/core/ai/llm/knowledge/extending.py index 1f55ba3..f95e18f 100644 --- a/src/retk/core/ai/llm/knowledge/db_ops.py +++ b/src/retk/core/ai/llm/knowledge/extending.py @@ -1,6 +1,7 @@ from datetime import timedelta from bson import ObjectId +from bson.tz_util import utc from retk.models.client import client from retk.models.tps.llm import NodeExtendQueue @@ -26,42 +27,45 @@ async def extend_on_node_post(data: Node): _id=ObjectId(), uid=data["uid"], nid=data["id"], + modifiedAt=int(data["modifiedAt"].replace(tzinfo=utc).timestamp()), summaryService="tencent", summaryModel=api.TencentModelEnum.HUNYUAN_LITE.value, - extendService="ali", - extendModel=api.AliyunModelEnum.QWEN_PLUS.value, + extendService="tencent", + extendModel=api.TencentModelEnum.HUNYUAN_LITE.value, ) # sort by _id desc docs = await client.coll.llm_extend_node_queue.find( filter={"uid": data["uid"]} - ).sort("_id", -1).to_list(None) + ).sort("modifiedAt", -1).to_list(None) has_q = False for doc in docs: if doc["nid"] == data["id"]: has_q = True - q["_id"] = doc["_id"] # renew the creating time - await client.coll.llm_extend_knowledge_queue.update_one( + await client.coll.llm_extend_node_queue.update_one( filter={"_id": doc["_id"]}, - update={"_id": q["_id"]}, + update={"$set": {"modifiedAt": q["modifiedAt"]}}, ) break max_keep = 5 if not has_q: + # this is a new node in queue if len(docs) >= max_keep: # remove the oldest and only keep the latest 5 await client.coll.llm_extend_node_queue.delete_many( {"_id": {"$in": [doc["_id"] for doc in docs[max_keep:]]}} ) - await client.coll.llm_extend_node_queue.insert_one(q) -async def extend_on_node_update(old_data: Node, new_data: Node): +async def extend_on_node_update( + old_data: Node, + new_data: Node, + cooling_time: int = 60, +): # filter out frequent updates - if new_data["modifiedAt"] - old_data["modifiedAt"] < timedelta(seconds=60): + if new_data["modifiedAt"] - old_data["modifiedAt"] < timedelta(seconds=cooling_time): return - await extend_on_node_post(new_data) diff --git a/src/retk/core/files/importing/async_tasks/obsidian/ops.py b/src/retk/core/files/importing/async_tasks/obsidian/ops.py index 96ec668..fdd8c14 100644 --- a/src/retk/core/files/importing/async_tasks/obsidian/ops.py +++ b/src/retk/core/files/importing/async_tasks/obsidian/ops.py @@ -8,7 +8,7 @@ from retk import regex, const from retk.core.files.saver import saver, File -from retk.utils import short_uuid +from retk.utils import short_uuid, get_at_node_md_link @dataclass @@ -142,7 +142,7 @@ async def replace_inner_link_and_upload( except KeyError: nid = short_uuid() exist_path2nid[path] = nid - md = f"{md[: span[0]]}[@{filename}](/n/{nid}){md[span[1]:]}" + md = f"{md[: span[0]]}{get_at_node_md_link(filename, nid)}{md[span[1]:]}" return md diff --git a/src/retk/core/node/node.py b/src/retk/core/node/node.py index 874561b..0970a1b 100644 --- a/src/retk/core/node/node.py +++ b/src/retk/core/node/node.py @@ -57,7 +57,7 @@ async def post( type_=type_, disabled=False, in_trash=False, - modified_at=_id.generation_time, + modified_at=datetime.datetime.now(tz=utc), in_trash_at=None, from_node_ids=from_nids, to_node_ids=new_to_node_ids, diff --git a/src/retk/core/scheduler/schedule.py b/src/retk/core/scheduler/schedule.py index 4551aa9..acd866d 100644 --- a/src/retk/core/scheduler/schedule.py +++ b/src/retk/core/scheduler/schedule.py @@ -92,6 +92,7 @@ def init_tasks(): func=tasks.notice.deliver_unscheduled_system_notices, second=0, ) + # check unscheduled extend node every hour run_every_at( job_id="deliver_unscheduled_node_extend", diff --git a/src/retk/core/scheduler/tasks/extend_node.py b/src/retk/core/scheduler/tasks/extend_node.py index 252c580..e570049 100644 --- a/src/retk/core/scheduler/tasks/extend_node.py +++ b/src/retk/core/scheduler/tasks/extend_node.py @@ -1,13 +1,13 @@ import asyncio import random +import time from typing import List -from bson import ObjectId - from retk import const from retk.core.ai.llm import knowledge from retk.logger import logger from retk.models.client import init_mongo +from retk.models.coll import CollNameEnum from retk.models.tps.llm import NodeExtendQueue, ExtendedNode @@ -22,45 +22,68 @@ def deliver_unscheduled_extend_nodes(): async def async_deliver_unscheduled_extend_nodes() -> str: _, db = init_mongo(connection_timeout=5) batch_size = 3 - total_knowledge_extended = 0 + total_success_count = 0 + total_summary_time = 0 + total_extend_time = 0 while True: - batch: List[NodeExtendQueue] = await db["llmExtendNodeQueue"].find().limit(batch_size).to_list(None) + done_id_list = [] + batch: List[NodeExtendQueue] = await db[CollNameEnum.llm_extend_node_queue.value].find().limit( + batch_size).to_list(None) if len(batch) == 0: break - batch_result: List[ExtendedNode] = [] for item in batch: req_id = "".join([str(random.randint(0, 9)) for _ in range(10)]) - md = await db["node"].find_one({"id": item["nid"]}) + node = await db[CollNameEnum.nodes.value].find_one({"id": item["nid"]}) # md = md[:int(8000 * 1.8)] + s0 = time.perf_counter() _summary, code = await knowledge.summary( llm_service=knowledge.LLM_SERVICES[item["summaryService"]], model=item["summaryModel"], - md=md, + md=node["md"], req_id=req_id, ) + s1 = time.perf_counter() if code != const.CodeEnum.OK: logger.error(f"knowledge summary error: {code}") continue + logger.debug(f"summary: {_summary}") + e0 = time.perf_counter() _extended, code = await knowledge.extend( llm_service=knowledge.LLM_SERVICES[item["extendService"]], model=item["extendModel"], - md=md, + md=_summary, req_id=req_id, ) + e1 = time.perf_counter() if code != const.CodeEnum.OK: logger.error(f"knowledge extend error: {code}") continue - batch_result.append(ExtendedNode( - _id=ObjectId(), + logger.debug(f"extended: {_extended}") + ext = ExtendedNode( uid=item["uid"], - sourceNids=[item["nid"]], - sourceMd=[md], + sourceNid=item["nid"], + sourceMd=node["md"], extendMd=_extended, - )) - total_knowledge_extended += 1 + ) + await db[CollNameEnum.llm_extended_node.value].update_one( + {"uid": item["uid"], "sourceNid": item["nid"]}, + {"$set": ext}, + upsert=True + ) + done_id_list.append(item["_id"]) + total_summary_time += s1 - s0 + total_extend_time += e1 - e0 + + if len(done_id_list) > 0: + res = await db[CollNameEnum.llm_extend_node_queue.value].delete_many({"_id": {"$in": done_id_list}}) + total_success_count += res.deleted_count - if len(batch_result) > 0: - await db["llmExtendedNode"].insert_many(batch_result) + if total_success_count > 0: + logger.info( + f"llm extend knowledge task: " + f"avg_summary_time: {total_summary_time / total_success_count:.2f}s, " + f"avg_extend_time: {total_extend_time / total_success_count:.2f}s" + ) - return f"successfully extent {total_knowledge_extended} node" + return f"successfully extent {len(done_id_list)} node" diff --git a/src/retk/core/scheduler/tasks/notice.py b/src/retk/core/scheduler/tasks/notice.py index fd31f8f..378afdc 100644 --- a/src/retk/core/scheduler/tasks/notice.py +++ b/src/retk/core/scheduler/tasks/notice.py @@ -7,19 +7,20 @@ from retk import const, config from retk.models.client import init_mongo +from retk.models.coll import CollNameEnum async def __get_users_in_batches(db, batch_size=100): # Get the total number of users - total_users = await db["users"].count_documents({}) + total_users = await db[CollNameEnum.users.value].count_documents({}) if config.is_local_db(): - fn = db["users"].find( + fn = db[CollNameEnum.users.value].find( {}, ).sort( [("_id", -1)] ) else: - fn = db["users"].find( + fn = db[CollNameEnum.users.value].find( {}, projection=["id"] ).sort( [("_id", -1)] @@ -47,7 +48,7 @@ async def __deliver_scheduled_system_notices_batch( "readTime": None, } for user in users] # Insert all notices at once - doc = await db["noticeSystem"].insert_many(notices) + doc = await db[CollNameEnum.notice_system.value].insert_many(notices) return len(doc.inserted_ids) @@ -61,7 +62,7 @@ def deliver_unscheduled_system_notices(): async def async_deliver_unscheduled_system_notices(): _, db = init_mongo(connection_timeout=5) - unscheduled = await db["noticeManagerDelivery"].find({ + unscheduled = await db[CollNameEnum.notice_manager_delivery.value].find({ "scheduled": False, }).sort("publishAt", -1).to_list(None) total_users = 0 @@ -96,12 +97,12 @@ async def async_deliver_unscheduled_system_notices(): } for user_id in batch_type_ids] # Insert all notices at once if len(notices) > 0: - docs = await db["noticeSystem"].insert_many(notices) + docs = await db[CollNameEnum.notice_system.value].insert_many(notices) success_users += len(docs.inserted_ids) total_users += len(batch_type_ids) elif recipient_type == const.notice.RecipientTypeEnum.ADMIN.value: # Get all admins - admins = await db["users"].find( + admins = await db[CollNameEnum.users.value].find( {"type": const.USER_TYPE.ADMIN.id}, {"id", 1}).to_list(None) success_users_count = await __deliver_scheduled_system_notices_batch( db=db, @@ -113,7 +114,7 @@ async def async_deliver_unscheduled_system_notices(): success_users += success_users_count elif recipient_type == const.notice.RecipientTypeEnum.MANAGER.value: # Get all managers - managers = await db["users"].find( + managers = await db[CollNameEnum.users.value].find( {"type": const.USER_TYPE.MANAGER.id}, projection=["id"]).to_list(None) success_users_count = await __deliver_scheduled_system_notices_batch( db=db, @@ -127,7 +128,7 @@ async def async_deliver_unscheduled_system_notices(): raise ValueError(f"Unknown recipient type: {recipient_type}") # Update the notice to indicate that it has been scheduled - await db["noticeManagerDelivery"].update_one( + await db[CollNameEnum.notice_manager_delivery.value].update_one( {"_id": notice_id}, {"$set": {"scheduled": True}} ) diff --git a/src/retk/models/client.py b/src/retk/models/client.py index 88738f8..d1e8d62 100644 --- a/src/retk/models/client.py +++ b/src/retk/models/client.py @@ -12,7 +12,7 @@ from retk.logger import logger from retk.models.search_engine.engine import BaseEngine, SearchDoc, RestoreSearchDoc from retk.models.search_engine.engine_local import LocalSearcher -from .coll import Collections +from .coll import Collections, CollNameEnum from .indexing import remote_try_build_index from .tps import UserFile, ImportData, UserMeta, Node, AuthedUser, convert_user_dict_to_authed_user @@ -73,15 +73,15 @@ async def init(self): def init_mongo(self): self.mongo, db = init_mongo(self.connection_timeout) - self.coll.users = db["users"] - self.coll.nodes = db["nodes"] - self.coll.import_data = db["importData"] - self.coll.user_file = db["userFile"] - self.coll.user_behavior = db["userBehavior"] - self.coll.notice_manager_delivery = db["noticeManagerDelivery"] - self.coll.notice_system = db["noticeSystem"] - self.coll.llm_extend_node_queue = db["llmExtendNodeQueue"] - self.coll.llm_extended_node = db["llmExtendedNode"] + self.coll.users = db[CollNameEnum.users.value] + self.coll.nodes = db[CollNameEnum.nodes.value] + self.coll.import_data = db[CollNameEnum.import_data.value] + self.coll.user_file = db[CollNameEnum.user_file.value] + self.coll.user_behavior = db[CollNameEnum.user_behavior.value] + self.coll.notice_manager_delivery = db[CollNameEnum.notice_manager_delivery.value] + self.coll.notice_system = db[CollNameEnum.notice_system.value] + self.coll.llm_extend_node_queue = db[CollNameEnum.llm_extend_node_queue.value] + self.coll.llm_extended_node = db[CollNameEnum.llm_extended_node.value] async def init_search(self): conf = config.get_settings() diff --git a/src/retk/models/coll.py b/src/retk/models/coll.py index 74ad0c4..11c567f 100644 --- a/src/retk/models/coll.py +++ b/src/retk/models/coll.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from enum import Enum from typing import Union, TYPE_CHECKING from retk.depend.mongita.collection import Collection @@ -25,3 +26,18 @@ class Collections: # llm llm_extend_node_queue: Union[Collection, "AsyncIOMotorCollection"] = None llm_extended_node: Union[Collection, "AsyncIOMotorCollection"] = None + + +class CollNameEnum(str, Enum): + users = "users" + nodes = "nodes" + import_data = "importData" + user_file = "userFile" + notice_manager_delivery = "noticeManagerDelivery" + notice_system = "noticeSystem" + user_behavior = "userBehavior" + llm_extend_node_queue = "llmExtendNodeQueue" + llm_extended_node = "llmExtendedNode" + + def __str__(self): + return self.value diff --git a/src/retk/models/indexing.py b/src/retk/models/indexing.py index 1be33e7..7640230 100644 --- a/src/retk/models/indexing.py +++ b/src/retk/models/indexing.py @@ -136,6 +136,16 @@ async def llm_extend_node_queue_coll(coll: "AsyncIOMotorCollection"): await not_in_and_create_index( coll=coll, index_info=index_info, - keys=["uid"], + keys=["uid", "modifiedAt"], unique=False, ) + + +async def llm_extended_node_coll(coll: "AsyncIOMotorCollection"): + index_info = await coll.index_information() + await not_in_and_create_index( + coll=coll, + index_info=index_info, + keys=["uid", "sourceNid"], + unique=True, + ) diff --git a/src/retk/models/tps/llm.py b/src/retk/models/tps/llm.py index 00c41e2..38be61a 100644 --- a/src/retk/models/tps/llm.py +++ b/src/retk/models/tps/llm.py @@ -1,4 +1,4 @@ -from typing import TypedDict, List +from typing import TypedDict from bson import ObjectId @@ -7,6 +7,7 @@ class NodeExtendQueue(TypedDict): _id: ObjectId uid: str nid: str + modifiedAt: int summaryService: str summaryModel: str extendService: str @@ -16,6 +17,6 @@ class NodeExtendQueue(TypedDict): class ExtendedNode(TypedDict): _id: ObjectId uid: str - sourceNids: List[str] - sourceMd: List[str] + sourceNid: str + sourceMd: str extendMd: str diff --git a/src/retk/routes/ai.py b/src/retk/routes/ai.py new file mode 100644 index 0000000..71f6900 --- /dev/null +++ b/src/retk/routes/ai.py @@ -0,0 +1,44 @@ +from typing import Optional + +from fastapi import APIRouter + +from retk.controllers import schemas +from retk.controllers.ai import knowledge +from retk.routes import utils + +router = APIRouter( + prefix="/api/ai", + tags=["node"], + responses={404: {"description": "Not found"}}, +) + + +@router.get( + path="/extended-nodes", + status_code=200, + response_model=schemas.ai.GetExtendedNodesResponse, +) +@utils.measure_time_spend +async def get_extended_nodes( + au: utils.ANNOTATED_AUTHED_USER, + referer: Optional[str] = utils.DEPENDS_REFERER, +) -> schemas.ai.GetExtendedNodesResponse: + return await knowledge.get_extended_nodes( + au=au, + ) + + +@router.post( + path="/extended-nodes/accept/{eid}", + status_code=201, + response_model=schemas.node.NodeResponse, +) +async def accept_extended_node( + au: utils.ANNOTATED_AUTHED_USER, + eid: str, + referer: Optional[str] = utils.DEPENDS_REFERER, +) -> schemas.node.NodeResponse: + return await knowledge.accept_extended_node( + au=au, + eid=eid, + ) diff --git a/src/retk/utils.py b/src/retk/utils.py index 03ea7da..f214f82 100644 --- a/src/retk/utils.py +++ b/src/retk/utils.py @@ -156,10 +156,14 @@ def md2html(md: str) -> str: return _html +def get_at_node_md_link(title: str, nid: str) -> str: + return f"[@{title}](/n/{nid})" + + def change_link_title(md: str, nid: str, new_title: str) -> str: new_md = re.sub( r"\[@[^].]*?]\(/n/{}/?\)".format(nid), - f"[@{new_title}](/n/{nid})", + get_at_node_md_link(new_title, nid), md, ) return new_md diff --git a/tests/test_api.py b/tests/test_api.py index 335fe29..ba69bb8 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -10,6 +10,7 @@ from zipfile import ZipFile from PIL import Image +from bson import ObjectId from bson.tz_util import utc from fastapi.testclient import TestClient from httpx import Response @@ -19,6 +20,7 @@ from retk.core import account, scheduler from retk.models.client import client from retk.models.tps import convert_user_dict_to_authed_user +from retk.models.tps.llm import ExtendedNode from retk.plugins.register import register_official_plugins, unregister_official_plugins from retk.utils import jwt_decode from . import utils @@ -1564,3 +1566,54 @@ async def test_user_notice(self): self.assertIsNotNone(sn["readTime"]) await self.clear_default_manager(admin_uid) + + async def test_node_extend(self): + resp = self.client.post( + "/api/nodes", + json={ + "md": "node1\ntext", + "type": const.NodeTypeEnum.MARKDOWN.value, + }, + headers=self.default_headers, + ) + rj = self.check_ok_response(resp, 201) + node = rj["node"] + self.assertFalse(node["favorite"]) + uid = (await client.coll.users.find_one({"email": const.DEFAULT_USER["email"]})).get("id") + + await client.coll.llm_extended_node.insert_one( + ExtendedNode( + _id=ObjectId(), + uid=uid, + sourceNid=node["id"], + sourceMd=node["md"], + extendMd="this is extended md", + ), + ) + + resp = self.client.get( + f"/api/ai/extended-nodes", + headers=self.default_headers, + ) + rj = self.check_ok_response(resp, 200) + self.assertEqual(1, len(rj["nodes"])) + n = rj["nodes"][0] + self.assertEqual("this is extended md", n["md"]) + self.assertEqual(node["id"], n["sourceNid"]) + + resp = self.client.post( + f"/api/ai/extended-nodes/accept/{n['id']}", + headers=self.default_headers, + ) + rj = self.check_ok_response(resp, 201) + self.assertEqual( + f"this is extended md\n\n[@node1](/n/{n['sourceNid']})", + rj["node"]["md"] + ) + + resp = self.client.get( + f"/api/ai/extended-nodes", + headers=self.default_headers, + ) + rj = self.check_ok_response(resp, 200) + self.assertEqual(0, len(rj["nodes"])) diff --git a/tests/test_core_local.py b/tests/test_core_local.py index 755193f..6c32ab0 100644 --- a/tests/test_core_local.py +++ b/tests/test_core_local.py @@ -17,6 +17,7 @@ from retk import const, core, config from retk.controllers.schemas.user import PatchUserRequest +from retk.core.ai.llm.knowledge.extending import extend_on_node_update from retk.core.files.importing.async_tasks.utils import update_process from retk.core.scheduler import tasks from retk.models import db_ops @@ -213,6 +214,41 @@ async def test_node(self, mock_send): self.assertEqual(2, len(nodes)) self.assertEqual(2, total) + async def test_node_in_ai_extend_queue(self, mock_send): + node, code = await core.node.post( + au=self.au, md="knowledge test\nthis is a knowledge test" + ) + self.assertEqual(const.CodeEnum.OK, code) + self.assertIsNotNone(node) + + q = await client.coll.llm_extend_node_queue.find().to_list(None) + self.assertEqual(1, len(q)) + self.assertEqual(node["id"], q[0]["nid"]) + q_time = q[0]["modifiedAt"] + + time.sleep(0.1) + + new_node, old_node, code = await core.node.update_md( + au=self.au, nid=node["id"], md="knowledge test\nthis is a knowledge test\n\nnew line", + ) + self.assertEqual(const.CodeEnum.OK, code) + self.assertEqual(node["id"], new_node["id"]) + self.assertEqual("knowledge test", new_node["title"]) + self.assertEqual("knowledge test\nthis is a knowledge test\n\nnew line", new_node["md"]) + + q_ = await client.coll.llm_extend_node_queue.find().to_list(None) + self.assertEqual(1, len(q)) + self.assertEqual(node["id"], q[0]["nid"]) + self.assertEqual(q_[0]["modifiedAt"], q_time) + time.sleep(1) + new_node, old_node, code = await core.node.update_md( + au=self.au, nid=node["id"], md="knowledge test\nthis is a knowledge test\n\nnew line\n\nnewline2", + ) + await extend_on_node_update(old_node, new_node, cooling_time=1) + q_ = await client.coll.llm_extend_node_queue.find().to_list(None) + self.assertEqual(1, len(q)) + self.assertGreater(q_[0]["modifiedAt"], q_time) + async def test_parse_at(self, mock_send): nid1, _ = await core.node.post( au=self.au, md="c", type_=const.NodeTypeEnum.MARKDOWN.value, diff --git a/tests/test_core_remote.py b/tests/test_core_remote.py index 559bbf3..f364ebf 100644 --- a/tests/test_core_remote.py +++ b/tests/test_core_remote.py @@ -12,6 +12,7 @@ from retk import const, config, core from retk.controllers.schemas.user import PatchUserRequest from retk.core.account.manager import signup +from retk.core.ai.llm.knowledge.extending import extend_on_node_update from retk.core.scheduler import tasks from retk.models import db_ops from retk.models.client import client @@ -278,6 +279,58 @@ async def test_node( self.assertEqual(2, len(nodes)) self.assertEqual(2, total) + @utils.skip_no_connect + @patch("retk.core.node.backup.__remove_md_all_versions_from_cos") + @patch("retk.core.node.backup.__remove_md_from_cos") + @patch("retk.core.node.backup.__get_md_from_cos") + @patch("retk.core.node.backup.__save_md_to_cos") + async def test_node_in_ai_extend_queue( + self, + mock_send, + mock_save_md_to_cos, + mock_get_md_from_cos, + mock_remove_md_from_cos, + mock_remove_md_all_versions_from_cos, + ): + mock_save_md_to_cos.return_value = const.CodeEnum.OK + mock_get_md_from_cos.return_value = ("", const.CodeEnum.OK) + mock_remove_md_from_cos.return_value = const.CodeEnum.OK + mock_remove_md_all_versions_from_cos.return_value = const.CodeEnum.OK + + node, code = await core.node.post( + au=self.au, md="knowledge test\nthis is a knowledge test" + ) + self.assertEqual(const.CodeEnum.OK, code) + self.assertIsNotNone(node) + + qs = await client.coll.llm_extend_node_queue.find().to_list(None) + q = [q for q in qs if q["nid"] == node["id"]] + self.assertEqual(node["id"], q[0]["nid"]) + q_time = q[0]["modifiedAt"] + time.sleep(0.1) + + new_node, old_node, code = await core.node.update_md( + au=self.au, nid=node["id"], md="knowledge test\nthis is a knowledge test\n\nnew line", + ) + self.assertEqual(const.CodeEnum.OK, code) + self.assertEqual(node["id"], new_node["id"]) + self.assertEqual("knowledge test", new_node["title"]) + self.assertEqual("knowledge test\nthis is a knowledge test\n\nnew line", new_node["md"]) + + qs_ = await client.coll.llm_extend_node_queue.find().to_list(None) + q_ = [q for q in qs_ if q["nid"] == node["id"]] + self.assertEqual(node["id"], q_[0]["nid"]) + self.assertEqual(q_[0]["modifiedAt"], q_time) + + time.sleep(1) + new_node, old_node, code = await core.node.update_md( + au=self.au, nid=node["id"], md="knowledge test\nthis is a knowledge test\n\nnew line\n\nnewline2", + ) + await extend_on_node_update(old_data=old_node, new_data=new_node, cooling_time=1) + qs_ = await client.coll.llm_extend_node_queue.find().to_list(None) + q_ = [q for q in qs_ if q["nid"] == node["id"]] + self.assertGreater(q_[0]["modifiedAt"], q_time) + @utils.skip_no_connect @patch("retk.core.node.backup.__remove_md_all_versions_from_cos") @patch("retk.core.node.backup.__remove_md_from_cos")