Skip to content

Latest commit

 

History

History
76 lines (65 loc) · 2.8 KB

3.8.md

File metadata and controls

76 lines (65 loc) · 2.8 KB

xobjdetect 下的 WaldBoost Detector

xobjdetect 模块只有一个检测器,叫做 WaldBoost Detector。但是我试了一下,训练之后根本检测不出图片中的目标,网上也没找到可用的示例代码,所以暂时先把 sample 代码放在这里,以后再研究。原始 sample 代码位于:opencv_contrib-4.x\modules\xobjdetect\tools\waldboost_detector。

#include "opencv2/xobjdetect.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"
#include <iostream>
#include <cstdio>
using namespace std;
using namespace cv;
using namespace cv::xobjdetect;

int main(int argc, char **argv)
{
    if (argc < 5) {
        cerr << "Usage: " << argv[0] << " train <model_filename> <pos_path> <neg_path>" << endl;
        cerr << "       " << argv[0] << " detect <model_filename> <img_filename> <out_filename> <labelling_filename>" << endl;
        return 0;
    }

    string mode = argv[1];
    Ptr<WBDetector> detector = WBDetector::create();
    if (mode == "train") {
        assert(argc == 5);
        detector->train(argv[3], argv[4]);
        FileStorage fs(argv[2], FileStorage::WRITE);
        fs << "waldboost";
        detector->write(fs);
    } else if (mode == "detect") {
        assert(argc == 6);
        vector<Rect> bboxes;
        vector<double> confidences;
        Mat img = imread(argv[3], IMREAD_GRAYSCALE);
        FileStorage fs(argv[2], FileStorage::READ);
        detector->read(fs.getFirstTopLevelNode());
        detector->detect(img, bboxes, confidences);

        FILE *fhandle = fopen(argv[5], "a");
        for (size_t i = 0; i < bboxes.size(); ++i) {
            Rect o = bboxes[i];
            fprintf(fhandle, "%s;%u;%u;%u;%u;%lf\n",
                argv[3], o.x, o.y, o.width, o.height, confidences[i]);
        }
        for (size_t i = 0; i < bboxes.size(); ++i) {
            rectangle(img, bboxes[i], Scalar(255, 0, 0));
        }
        imwrite(argv[4], img);
    }
}

其中 train 的说明里要求 pos 必须是裁剪(crop)后的图片,neg 则无所谓;但它没有说明 pos 应该裁剪成多大,看代码又感觉尺寸无所谓,而 neg 是按照 (24, 24) 的大小来裁剪的。实现代码在 wbdetector.cpp 中:

// Trains the detector from a path of positive samples and a path of negative
// images. Only the setup is shown in this excerpt: load both sample sets,
// require each to be non-empty, then initialise a feature evaluator for a
// 24x24 window and query its feature count. The rest of the body is elided.
void WBDetectorImpl::train( const string& pos_samples_path, const string& neg_imgs_path)
{
    // Load the positive images. Per the author's note, read_imgs simply loops
    // over the files and applies no preprocessing (helper not shown here).
    vector<Mat> pos_imgs = read_imgs(pos_samples_path);
    // Load the negatives. The 24, 24 arguments mean the negative images are
    // cropped into 24x24 patches; the last argument is presumably the number of
    // patches requested (10x the positive count) -- TODO confirm in sample_patches.
    vector<Mat> neg_imgs = sample_patches(neg_imgs_path, 24, 24, pos_imgs.size() * 10);

    // Training needs at least one sample of each class.
    CV_Assert(pos_imgs.size());
    CV_Assert(neg_imgs.size());

    int n_features;
    Mat pos_data, neg_data;

    // Feature evaluator initialised with a 24x24 window size, matching the
    // negative patch size above.
    Ptr<CvFeatureEvaluator> eval = CvFeatureEvaluator::create();
    eval->init(CvFeatureParams::create(), 1, Size(24, 24));
    n_features = eval->getNumFeatures();

    // ... (remainder of the implementation omitted in this excerpt)
}