Skip to content

Commit

Permalink
adjust confidence threshold of nudenet detector
Browse files Browse the repository at this point in the history
  • Loading branch information
ladaapp committed Jan 13, 2025
1 parent b1bbbf1 commit 09b623f
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 12 deletions.
13 changes: 6 additions & 7 deletions lada/lib/nudenet_nsfw_detector.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from operator import itemgetter
from typing import Optional

from lada.lib.ultralytics_utils import disable_ultralytics_telemetry, convert_yolo_boxes
Expand All @@ -18,16 +17,16 @@ def __init__(self, model: YOLO, device):
self.model = model
self.device = device
self.batch_size = 4
self.min_confidence = 0.2
self.min_positive_detections = 4
self.min_confidence = 0.15
self.min_positive_detections = 6
self.sampling_rate = 0.3

def detect(self, images:list[Image], boxes:Optional[list[Box]]=None) -> tuple[bool, bool, bool]:
num_samples = min(len(images), int(len(images)*self.sampling_rate))
num_samples = min(len(images), max(1, int(len(images)*self.sampling_rate)))
indices_step_size = len(images) // num_samples
indices_of_nsfw_elements = list(range(0, num_samples*indices_step_size, indices_step_size))
samples = itemgetter(*indices_of_nsfw_elements)(images)
samples_boxes = itemgetter(*indices_of_nsfw_elements)(boxes) if boxes else None
samples = [images[i] for i in indices_of_nsfw_elements]
samples_boxes = [boxes[i] for i in indices_of_nsfw_elements] if boxes else None

batches = [samples[i:i + self.batch_size] for i in range(0, len(samples), self.batch_size)]
positive_detections = 0
Expand Down Expand Up @@ -62,7 +61,7 @@ def detect(self, images:list[Image], boxes:Optional[list[Box]]=None) -> tuple[bo
positive_male_detections += 1
if single_image_nsfw_female_detected:
positive_female_detections += 1
nsfw_detected = positive_detections > self.min_positive_detections
nsfw_detected = positive_detections >= self.min_positive_detections
nsfw_male_detected = positive_male_detections > self.min_positive_detections
nsfw_female_detected = positive_female_detections > self.min_positive_detections
#print(f"nudenet nsfw detector: nsfw {nsfw_detected}, detected {positive_detections}/{len(samples)}")
Expand Down
9 changes: 4 additions & 5 deletions lada/lib/watermark_detector.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from operator import itemgetter
from typing import Optional

from lada.lib.ultralytics_utils import disable_ultralytics_telemetry, convert_yolo_boxes
Expand All @@ -18,11 +17,11 @@ def __init__(self, model: YOLO, device):
self.sampling_rate = 0.3

def detect(self, images:list[Image], boxes:Optional[list[Box]]=None) -> bool:
num_samples = min(len(images), int(len(images)*self.sampling_rate))
num_samples = min(len(images), max(1, int(len(images) * self.sampling_rate)))
indices_step_size = len(images) // num_samples
indices = list(range(0, num_samples*indices_step_size, indices_step_size))
samples = itemgetter(*indices)(images)
samples_boxes = itemgetter(*indices)(boxes) if boxes else None
samples = [images[i] for i in indices]
samples_boxes = [boxes[i] for i in indices] if boxes else None

batches = [samples[i:i + self.batch_size] for i in range(0, len(samples), self.batch_size)]
positive_detections = 0
Expand All @@ -39,6 +38,6 @@ def detect(self, images:list[Image], boxes:Optional[list[Box]]=None) -> bool:
single_image_watermark_detected = any(conf > self.min_confidence and (not samples_boxes or box_overlap(detection_boxes[i], samples_boxes[sample_idx])) for i, conf in enumerate(detection_confidences))
if single_image_watermark_detected:
positive_detections += 1
watermark_detected = positive_detections > self.min_positive_detections
watermark_detected = positive_detections >= self.min_positive_detections
#print(f"watermark detector: watermark {watermark_detected}, detected {positive_detections}/{len(samples)}")
return watermark_detected

0 comments on commit 09b623f

Please sign in to comment.