Skip to content

Commit

Permalink
Change format of detections to work with MOT matlab devkit
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelmurray committed May 12, 2017
1 parent c27b102 commit ffa1172
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 36 deletions.
5 changes: 3 additions & 2 deletions cpp/src/examples/DetectAndTrackDemo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,14 @@ std::pair<std::chrono::duration<double, std::milli>, int> detectAndTrack(const s

for (auto trackingIt = trackings.begin(); trackingIt != trackings.end(); ++trackingIt) {
outputStream << frame << ","
<< trackingIt->label << ","
<< trackingIt->ID << ","
<< trackingIt->bb.x1() << ","
<< trackingIt->bb.y1() << ","
<< trackingIt->bb.width << ","
<< trackingIt->bb.height << ","
<< "1,-1,-1,-1\n";
<< "1" << "," // Confidence
<< "-1,-1,-1" << "," // Unused
<< trackingIt->label << "\n";
}
if (realTime) {
int framesToSkip = int((originalFrameRate * duration.count()) / 1000) - 1;
Expand Down
6 changes: 3 additions & 3 deletions cpp/src/examples/DetectDemo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,14 @@ std::pair<std::chrono::duration<double, std::milli>, int> detect(const std::shar

for (auto detectionIt = detections.begin(); detectionIt != detections.end(); ++detectionIt) {
outputStream << frameCount << ","
<< detectionIt->label << ","
<< "-1,"
<< "-1" << "," // ID
<< detectionIt->bb.x1() << ","
<< detectionIt->bb.y1() << ","
<< detectionIt->bb.width << ","
<< detectionIt->bb.height << ","
<< detectionIt->confidence << ","
<< "-1,-1,-1\n";
<< "-1,-1,-1" << "," // Unused
<< detectionIt->label << "\n";
}
++frameCount;
if (frameCount % 100 == 0) {
Expand Down
31 changes: 19 additions & 12 deletions cpp/src/examples/TrackDemo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,24 @@ const char *USAGE_MESSAGE = "Usage: %s "
"-s sequenceMap "
"-m modelType "
"-f detectionFormat "
"[-i frameInterval (default 1)]\n";
"[-i frameInterval (default 1)] "
"[-x (overwrites result)]\n";
const char *OPEN_FILE_MESSAGE = "Could not open file %s\n";
const char *FILE_EXISTS_MESSAGE = "Output file %s already exists; won't overwrite\n";

std::pair<std::chrono::duration<double, std::milli>, int> track(const boost::filesystem::path &sequencePath,
const std::string &modelType,
const std::string &detectionFormat,
const int frameInterval) {
const int frameInterval,
const bool doOverwriteResult) {
typedef std::chrono::duration<double, std::milli> msduration;

// Make sure input file exists
boost::filesystem::path inputPath = dataDirPath / sequencePath / modelType / "det.txt";
std::ifstream inputStream(inputPath.string());
if (!inputStream.is_open()) {
fprintf(stderr, OPEN_FILE_MESSAGE, inputPath.c_str());
exit(EXIT_FAILURE);
return std::pair<msduration, int>(msduration(0), 0);
}

// Create output directory if not exists
Expand All @@ -42,17 +44,17 @@ std::pair<std::chrono::duration<double, std::milli>, int> track(const boost::fil

// Make sure output file does not exist
boost::filesystem::path outputPath = outputDirPath / (sequencePath.filename().string() + ".txt");
if (boost::filesystem::exists(outputPath)) {
fprintf(stderr, FILE_EXISTS_MESSAGE, outputPath.c_str()); // FIXME:
//return std::pair<msduration, int>(msduration(0), 0);
if (boost::filesystem::exists(outputPath) && !doOverwriteResult) {
fprintf(stderr, FILE_EXISTS_MESSAGE, outputPath.c_str());
return std::pair<msduration, int>(msduration(0), 0);
}

// Make sure output file can be opened
std::ofstream outputStream;
outputStream.open(outputPath.string());
if (!outputStream.is_open()) {
fprintf(stderr, OPEN_FILE_MESSAGE, outputPath.c_str());
exit(EXIT_FAILURE);
return std::pair<msduration, int>(msduration(0), 0);
}

PAOT tracker;
Expand All @@ -70,7 +72,6 @@ std::pair<std::chrono::duration<double, std::milli>, int> track(const boost::fil
msduration cumulativeDuration = std::chrono::milliseconds::zero();
int frameCount = 0;
for (int frame = 0; frame < frameToDetections.rbegin()->first; ++frame) {
std::cout << frame << std::endl;
if (frame % frameInterval == 0 && frameToDetections.find(frame) != frameToDetections.end()) {
auto startTime = std::chrono::high_resolution_clock::now();
std::vector<Tracking> trackings = tracker.track(frameToDetections.at(frame));
Expand All @@ -81,13 +82,14 @@ std::pair<std::chrono::duration<double, std::milli>, int> track(const boost::fil

for (auto trackingIt = trackings.begin(); trackingIt != trackings.end(); ++trackingIt) {
outputStream << frame << ","
<< trackingIt->label << ","
<< trackingIt->ID << ","
<< trackingIt->bb.x1() << ","
<< trackingIt->bb.y1() << ","
<< trackingIt->bb.width << ","
<< trackingIt->bb.height << ","
<< "1,-1,-1,-1\n";
<< "1" << "," // Confidence
<< "-1,-1,-1" << "," // Unused
<< trackingIt->label << "\n";
}
++frameCount;
}
Expand All @@ -102,9 +104,10 @@ int main(int argc, char **argv) {
std::string modelType;
std::string detectionFormat;
int frameInterval = 1;
bool doOverwriteResult = false;

int opt;
while ((opt = getopt(argc, argv, "s:m:f:i:")) != -1) {
while ((opt = getopt(argc, argv, "s:m:f:i:x")) != -1) {
switch (opt) {
case 's':
sequenceMapName = optarg;
Expand All @@ -118,6 +121,9 @@ int main(int argc, char **argv) {
case 'i':
frameInterval = atoi(optarg);
break;
case 'x':
doOverwriteResult = true;
break;
default:
fprintf(stderr, USAGE_MESSAGE, argv[0]);
exit(EXIT_FAILURE);
Expand All @@ -143,7 +149,8 @@ int main(int argc, char **argv) {
std::string sequencePathString;
while (getline(sequenceMap, sequencePathString)) {
std::cout << "Sequence: " << sequencePathString << std::endl;
auto durationFrameCount = track(sequencePathString, modelType, detectionFormat, frameInterval);
auto durationFrameCount = track(sequencePathString, modelType, detectionFormat, frameInterval,
doOverwriteResult);
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(durationFrameCount.first).count();
std::cout << "Duration: " << duration << "ms"
<< " (" << double(durationFrameCount.second * 1000) / duration << "fps)\n";
Expand Down
6 changes: 3 additions & 3 deletions cpp/src/util/DetectionFileParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ std::pair<int, Detection> DetectionFileParser::parseOkutamaLine(const std::strin
double x1, y1, width, height, confidence;
double x, y, z; // Unused
if (!(is >> frame && is.ignore() &&
is >> label && is.ignore() &&
is >> id && is.ignore() &&
is >> x1 && is.ignore() &&
is >> y1 && is.ignore() &&
Expand All @@ -68,10 +67,11 @@ std::pair<int, Detection> DetectionFileParser::parseOkutamaLine(const std::strin
is >> confidence && is.ignore() &&
is >> x && is.ignore() &&
is >> y && is.ignore() &&
is >> z)) {
is >> z && is.ignore() &&
is >> label)) {
throw std::invalid_argument(
"Each line must be on following format: "
"<frame>,<label>,<id>,<x_topleft>,<y_topleft>,<width>,<height>,<confidence>,<x>,<y>,<z>");
"<frame>,<id>,<x_topleft>,<y_topleft>,<width>,<height>,<confidence>,<x>,<y>,<z>,<label>");
}
return std::pair<int, Detection>(frame, Detection(label, confidence,
BoundingBox(x1 + width / 2.0, y1 + height / 2.0, width, height)));
Expand Down
24 changes: 8 additions & 16 deletions python/generate_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import os
import argparse

CONFIDENCE_THRESHOLD = 0.5
LABEL_ACTION_MAP = {
0 : "Background",
1 : "Walking",
Expand All @@ -22,7 +21,7 @@
}


def generate(sequence_path, model_type, tracking_format, frame_interval):
def generate(sequence_path, model_type, frame_interval, confidence_threshold):
from PIL import Image, ImageDraw, ImageFont
path_until_sequence, sequence_name = os.path.split(sequence_path)
tracking_file_path = "../data/results/{}/{}/{}.txt".format(path_until_sequence, model_type, sequence_name)
Expand All @@ -41,26 +40,19 @@ def generate(sequence_path, model_type, tracking_format, frame_interval):

with Image.open(image_path) as image:
draw = ImageDraw.Draw(image, 'RGBA')
# Format: [frame, ID, x1, y1, width, height, confidence, _, _, _, (label)]
dets = detections[detections[:,0]==frame, 1:7]
labels = None
dets = None
if tracking_format == "okutama":
# Format: [frame, label, ID, x1, y1, width, height, confidence, _, _, _]
if model_type == "action-detection":
labels = detections[detections[:,0]==frame, 1]
dets = detections[detections[:,0]==frame, 2:8]
elif tracking_format == "mot":
# Format: [frame, ID, x1, y1, width, height, confidence, _, _, _]
dets = detections[detections[:,0]==frame, 1:7]
pass
else:
return
dets[:, 3:5] += dets[:, 1:3] # Convert from [x1, y1, width, height] to [x1, y1, x2, y2]
for i, d in enumerate(dets):
if d[5] < CONFIDENCE_THRESHOLD:
continue
d = d.astype(np.int32)
c = tuple(colours[d[0]%32, :])
draw.rectangle([d[1], d[2], d[3], d[4]], fill=(c + (100,)), outline=c)
if tracking_format == "okutama" and model_type == "action-detection":
if model_type == "action-detection":
draw.text((d[1], d[2] + 4), LABEL_ACTION_MAP[labels[i]], fill=(255,255,255))
image.save("{}/{}".format(output_dir_path, image_name))

Expand All @@ -71,11 +63,11 @@ def generate(sequence_path, model_type, tracking_format, frame_interval):
help='a textfile specifying paths to sequences')
parser.add_argument('-m', '--modelType', metavar='modelType', required=True,
help='name of the type of model')
parser.add_argument('-f', '--trackingFormat', metavar='trackingFormat', required=True, choices=['okutama', 'mot'],
help='format of the tracking files')
parser.add_argument('-i', '--frameInterval', metavar='frameInterval', required=False, type=int, default=1,
help='number of frames to skip to skip between each produced frame')
parser.add_argument('-c', '--confidenceThreshold', metavar='confidenceThreshold', required=False, type=float, default=0.5,
help='don\'t display detections with confidence below this value')
args = parser.parse_args()
with open("../data/seqmaps/{}".format(args.sequenceMap)) as f:
for sequence_path in f:
generate(sequence_path, args.modelType, args.trackingFormat, args.frameInterval)
generate(sequence_path, args.modelType, args.frameInterval, args.confidenceThreshold)

0 comments on commit ffa1172

Please sign in to comment.