Skip to content

Commit

Permalink
Implemented Knowledge component basic functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
antoninoLorenzo committed Jun 17, 2024
1 parent f5415ef commit e76fe6b
Show file tree
Hide file tree
Showing 4 changed files with 175 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/agent/knowledge/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from src.agent.knowledge.store import Store
from src.agent.knowledge.chunker import chunk, chunk_str
from src.agent.knowledge.collections import Collection, Document, Topic
37 changes: 37 additions & 0 deletions src/agent/knowledge/chunker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import json

import spacy

from src.agent.knowledge.collections import Document


# Shared spaCy pipeline loaded once at import time. en_core_web_lg ships
# word vectors, which chunk_str relies on for Span.similarity scores
# (smaller models would warn and give poor similarities).
nlp = spacy.load("en_core_web_lg")


def chunk_str(document: str):
    """Chunk a text string into groups of semantically similar sentences.

    Consecutive sentences stay in the same chunk while their spaCy vector
    similarity is at or above a threshold; a chunk is also capped at
    ``max_sent`` sentences.

    :param document: raw text to split.
    :return: list of chunk strings; chunks of 100 characters or fewer are
        dropped as low-value noise. Empty input yields an empty list.
    """
    doc = nlp(document)
    # Drop placeholder "sentences" such as a bare '*' left by scraping.
    sentences = [sent for sent in doc.sents if str(sent).strip() not in ['*']]
    if not sentences:
        # Fix: original raised IndexError on sentences[0] for empty input.
        return []

    # similarities[i-1] relates sentences[i-1] and sentences[i].
    similarities = [
        sentences[i - 1].similarity(sentences[i])
        for i in range(1, len(sentences))
    ]

    threshold = 0.5  # minimum similarity to keep adjacent sentences together
    max_sent = 4     # maximum number of sentences per chunk
    texts = [str(sent) for sent in sentences]
    groups = [[texts[0]]]
    for i in range(1, len(texts)):
        # Fix: original used `>`, allowing max_sent + 1 sentences per group.
        if len(groups[-1]) >= max_sent or similarities[i - 1] < threshold:
            groups.append([texts[i]])
        else:
            groups[-1].append(texts[i])

    _chunks = [" ".join(g) for g in groups]
    # Very short chunks carry little retrieval value.
    return [_ch for _ch in _chunks if len(_ch) > 100]


def chunk(document: Document):
    """Split a Document's content into chunks for the Vector Database."""
    text = document.content
    return chunk_str(text)
40 changes: 40 additions & 0 deletions src/agent/knowledge/collections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from dataclasses import dataclass
from enum import StrEnum
from typing import List, Optional


class Topic(StrEnum):
"""One of the possible Penetration Testing topics, used to choose a collection
and to filter documents"""
WebPenetrationTesting = 'Web Penetration Testing'


@dataclass
class Document:
    """A processed data source (e.g. an HTML or PDF document).

    It gets chunked and added to a Vector Database.
    """
    name: str
    content: str
    topic: Optional[Topic]

    def __str__(self):
        header = f'{self.name} [{self.topic}]'
        return f'{header}\n{self.content}'


@dataclass
class Collection:
    """A Qdrant collection together with its tracked documents and topics."""
    id: int
    title: str
    documents: List[Document]
    topics: List[Topic]
    size: Optional[int] = 0  # number of stored chunks (points)

    def __str__(self):
        doc_lines = ''.join(f' | - {doc.name}\n' for doc in self.documents)
        docs = "| - Documents\n" + doc_lines
        topic_list = ", ".join(self.topics)
        return f'Title: {self.title} ({self.id})\n| - Topics: {topic_list}\n{docs}'



95 changes: 95 additions & 0 deletions src/agent/knowledge/store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from typing import Dict

import ollama
from qdrant_client import QdrantClient, models

from src.agent.knowledge.chunker import chunk
from src.agent.knowledge.collections import Document, Collection


class Store:
    """Interface towards the Qdrant vector database.

    Holds an in-memory Qdrant instance plus local ``Collection`` metadata,
    and uses Ollama ('nomic-embed-text') to embed text chunks.
    """

    def __init__(self):
        self._connection = QdrantClient(":memory:")
        # Collection metadata indexed by collection title.
        # Fix: original annotation `Dict[str: Collection]` was a slice
        # expression, not a valid (key, value) type pair.
        self._collections: Dict[str, Collection] = {}

        self._encoder = ollama.embeddings
        self._embedding_model: str = 'nomic-embed-text'
        # Probe the model once to learn the embedding vector size, which
        # is required to configure Qdrant collections.
        self._embedding_size: int = len(
            self._encoder(
                self._embedding_model,
                prompt='init'
            )['embedding']
        )

    def create_collection(self, collection: Collection):
        """Create a new Qdrant collection and register its metadata.

        Documents already attached to ``collection`` are uploaded
        immediately after creation.

        :param collection: collection descriptor; its title is used as the
            Qdrant collection name.
        """
        done = self._connection.create_collection(
            collection_name=collection.title,
            vectors_config=models.VectorParams(
                size=self._embedding_size,
                distance=models.Distance.COSINE
            )
        )
        if done:
            self._collections[collection.title] = collection

            # upload documents that were attached before creation
            for document in collection.documents:
                self.upload(document, collection.title)

    def upload(self, document: Document, collection_name: str):
        """Chunk, embed, and upload a document to the given collection.

        :param document: the document to ingest.
        :param collection_name: title of a collection previously created
            with :meth:`create_collection`.
        :raises ValueError: if the collection is unknown.
        """
        if collection_name not in self._collections:
            raise ValueError('Collection does not exist')

        # create the Qdrant data points
        doc_chunks = chunk(document)
        emb_chunks = [{
            'title': document.name,
            # 'topic': str(document.topic)
            'text': ch,
            'embedding': self._encoder(self._embedding_model, ch)
        } for ch in doc_chunks]
        # Point ids continue from the collection's current size so they
        # remain unique across multiple uploads.
        current_len = self._collections[collection_name].size

        points = [
            models.PointStruct(
                id=current_len + i,
                vector=item['embedding']['embedding'],
                payload={'text': item['text'], 'title': item['title']}
            )
            for i, item in enumerate(emb_chunks)
        ]

        # upload Points to Qdrant and update Collection metadata
        self._connection.upload_points(
            collection_name=collection_name,
            points=points
        )

        self._collections[collection_name].documents.append(document)
        self._collections[collection_name].size = current_len + len(emb_chunks)

    def retrieve(self, query: str, collection_name: str, limit: int = 1):
        """Retrieve the chunks most similar to ``query``.

        :param query: natural-language query; queries shorter than 3
            characters are rejected.
        :param collection_name: target Qdrant collection.
        :param limit: maximum number of hits to return.
        :return: list of Qdrant scored points, or None for a too-short query.
        """
        if len(query) < 3:
            return None

        hits = self._connection.search(
            collection_name=collection_name,
            query_vector=self._encoder(self._embedding_model, query)['embedding'],
            limit=limit
        )
        return hits

    @property
    def collections(self):
        """Mapping of collection title -> Collection metadata."""
        return self._collections

    def get_collection(self, name):
        """Return the Collection titled ``name``, or None if absent."""
        if name not in self.collections:
            return None
        return self._collections[name]

0 comments on commit e76fe6b

Please sign in to comment.