Skip to content

Commit

Permalink
chore: refactor bulky add_cve_items method
Browse files Browse the repository at this point in the history
which got even larger because of the RAM fix
  • Loading branch information
jstucke committed Nov 27, 2024
1 parent 091876f commit d6ae1cd
Showing 1 changed file with 48 additions and 39 deletions.
87 changes: 48 additions & 39 deletions src/plugins/analysis/cve_lookup/internal/database/db_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

import re
from pathlib import Path
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Iterable

from ..helper_functions import CveEntry, replace_wildcards
from .schema import Association, Cpe, Cve
from .schema import Association, Base, Cpe, Cve

if TYPE_CHECKING:
from .db_connection import DbConnection
Expand All @@ -20,6 +20,8 @@ def __init__(self, connection: DbConnection):
self.connection = connection
self.connection.create_tables()
self.session = self.connection.create_session()
self.existing_cve_ids = set()
self.existing_cpe_ids = set()

def create_cve(self, cve_item: CveEntry) -> Cve:
"""
Expand All @@ -33,8 +35,8 @@ def create_cve(self, cve_item: CveEntry) -> Cve:
cve_id=cve_item.cve_id,
year=year,
summary=cve_item.summary,
cvss_v2_score=score_v2,
cvss_v3_score=score_v3,
cvss_v2_score=str(score_v2),
cvss_v3_score=str(score_v3),
)

def create_cpe(self, cpe_id: str):
Expand All @@ -55,44 +57,51 @@ def create_cpe(self, cpe_id: str):
update=update,
)

def add_cve_items(self, cve_list: list[CveEntry], chunk_size: int = 2**12):
def create_association(self, cve_id: str, cpe_entry: tuple[str, str, str, str, str]) -> Association:
"""
Add CVE items to the database.
Create an Association object from a CVE ID and a CPE entry.
"""
existing_cve_ids = set()
existing_cpe_ids = set()
(
cpe_id,
version_start_including,
version_start_excluding,
version_end_including,
version_end_excluding,
) = cpe_entry
return Association(
cve_id=cve_id,
cpe_id=cpe_id,
version_start_including=version_start_including,
version_start_excluding=version_start_excluding,
version_end_including=version_end_including,
version_end_excluding=version_end_excluding,
)

db_objects = []
def add_cve_items(self, cve_list: Iterable[CveEntry], chunk_size: int = 2**12):
"""
Add CVE items to the database chunk-wise.
"""

db_objects: list[Base] = []

for cve_item in cve_list:
if cve_item.cve_id not in existing_cve_ids:
db_objects.append(self.create_cve(cve_item))
for cpe_entry in cve_item.cpe_entries:
(
cpe_id,
version_start_including,
version_start_excluding,
version_end_including,
version_end_excluding,
) = cpe_entry
if cpe_id not in existing_cpe_ids:
db_objects.append(self.create_cpe(cpe_id))
existing_cpe_ids.add(cpe_id)
db_objects.append(
Association(
cve_id=cve_item.cve_id,
cpe_id=cpe_id,
version_start_including=version_start_including,
version_start_excluding=version_start_excluding,
version_end_including=version_end_including,
version_end_excluding=version_end_excluding,
)
)
existing_cve_ids.add(cve_item.cve_id)
if len(db_objects) >= chunk_size:
self.session.bulk_save_objects(db_objects)
self.session.commit()
db_objects.clear()
if cve_item.cve_id not in self.existing_cve_ids:
db_objects.extend(self._create_db_objects_for_cve(cve_item))
if len(db_objects) >= chunk_size:
self._save_objects(db_objects)
db_objects.clear()
if db_objects:
self.session.bulk_save_objects(db_objects)
self.session.commit()
self._save_objects(db_objects)

def _create_db_objects_for_cve(self, cve_item: CveEntry) -> Iterable[Base]:
yield self.create_cve(cve_item)
for cpe_entry in cve_item.cpe_entries:
if (cpe_id := cpe_entry[0]) not in self.existing_cpe_ids:
yield self.create_cpe(cpe_id)
self.existing_cpe_ids.add(cpe_id)
yield self.create_association(cve_item.cve_id, cpe_entry)
self.existing_cve_ids.add(cve_item.cve_id)

def _save_objects(self, objects: list[Base]):
self.session.bulk_save_objects(objects)
self.session.commit()

0 comments on commit d6ae1cd

Please sign in to comment.