Skip to content

Commit

Permalink
Implement face recognition training in UI (blakeblackshear#15786)
Browse files Browse the repository at this point in the history
* Rename debug to train

* Add api to train image as person

* Cleanup model running

* Formatting

* Fix

* Set face recognition page title
  • Loading branch information
NickM-27 authored Jan 2, 2025
1 parent b307208 commit 05f9ae5
Show file tree
Hide file tree
Showing 5 changed files with 231 additions and 72 deletions.
44 changes: 43 additions & 1 deletion frigate/api/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

import logging
import os
import random
import shutil
import string

from fastapi import APIRouter, Request, UploadFile
from fastapi.responses import JSONResponse
Expand All @@ -22,7 +25,13 @@ def get_faces():

for name in os.listdir(FACE_DIR):
face_dict[name] = []
for file in os.listdir(os.path.join(FACE_DIR, name)):

face_dir = os.path.join(FACE_DIR, name)

if not os.path.isdir(face_dir):
continue

for file in os.listdir(face_dir):
face_dict[name].append(file)

return JSONResponse(status_code=200, content=face_dict)
Expand All @@ -38,6 +47,39 @@ async def register_face(request: Request, name: str, file: UploadFile):
)


@router.post("/faces/train/{name}/classify")
def train_face(name: str, body: dict = None):
json: dict[str, any] = body or {}
training_file = os.path.join(
FACE_DIR, f"train/{sanitize_filename(json.get('training_file', ''))}"
)

if not training_file or not os.path.isfile(training_file):
return JSONResponse(
content=(
{
"success": False,
"message": f"Invalid filename or no file exists: {training_file}",
}
),
status_code=404,
)

rand_id = "".join(random.choices(string.ascii_lowercase + string.digits, k=6))
new_name = f"{name}-{rand_id}.webp"
new_file = os.path.join(FACE_DIR, f"{name}/{new_name}")
shutil.move(training_file, new_file)
return JSONResponse(
content=(
{
"success": True,
"message": f"Successfully saved {training_file} as {new_name}.",
}
),
status_code=200,
)


@router.post("/faces/{name}/delete")
def deregister_faces(request: Request, name: str, body: dict = None):
json: dict[str, any] = body or {}
Expand Down
2 changes: 1 addition & 1 deletion frigate/embeddings/maintainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,7 @@ def _process_face(self, obj_data: dict[str, any], frame: np.ndarray) -> None:

if self.config.face_recognition.save_attempts:
# write face to library
folder = os.path.join(FACE_DIR, "debug")
folder = os.path.join(FACE_DIR, "train")
file = os.path.join(folder, f"{id}-{sub_label}-{score}-{face_score}.webp")
os.makedirs(folder, exist_ok=True)
cv2.imwrite(file, face_frame)
Expand Down
19 changes: 16 additions & 3 deletions frigate/util/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,12 @@ def __init__(self, config: FaceRecognitionConfig, db: SqliteQueueDatabase):
self.config = config
self.db = db
self.landmark_detector = cv2.face.createFacemarkLBF()
self.landmark_detector.loadModel("/config/model_cache/facedet/landmarkdet.yaml")

if os.path.isfile("/config/model_cache/facedet/landmarkdet.yaml"):
self.landmark_detector.loadModel(
"/config/model_cache/facedet/landmarkdet.yaml"
)

self.recognizer: cv2.face.LBPHFaceRecognizer = (
cv2.face.LBPHFaceRecognizer_create(
radius=2, threshold=(1 - config.min_score) * 1000
Expand All @@ -178,13 +183,21 @@ def __build_classifier(self) -> None:

dir = "/media/frigate/clips/faces"
for idx, name in enumerate(os.listdir(dir)):
if name == "debug":
if name == "train":
continue

self.label_map[idx] = name
face_folder = os.path.join(dir, name)

if not os.path.isdir(face_folder):
continue

self.label_map[idx] = name
for image in os.listdir(face_folder):
img = cv2.imread(os.path.join(face_folder, image))

if img is None:
continue

img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
img = self.__align_face(img, img.shape[1], img.shape[0])
faces.append(img)
Expand Down
25 changes: 25 additions & 0 deletions web/src/components/icons/AddFaceIcon.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import { forwardRef } from "react";
import { LuPlus, LuScanFace } from "react-icons/lu";
import { cn } from "@/lib/utils";

type AddFaceIconProps = {
className?: string;
onClick?: () => void;
};

const AddFaceIcon = forwardRef<HTMLDivElement, AddFaceIconProps>(
({ className, onClick }, ref) => {
return (
<div
ref={ref}
className={cn("relative flex items-center", className)}
onClick={onClick}
>
<LuScanFace className="size-full" />
<LuPlus className="absolute size-4 translate-x-3 translate-y-3" />
</div>
);
},
);

export default AddFaceIcon;
Loading

0 comments on commit 05f9ae5

Please sign in to comment.