Showing 15 changed files with 1,850 additions and 0 deletions.
@@ -127,3 +127,5 @@ dmypy.json

# Pyre type checker
.pyre/
.idea/
weights/
@@ -0,0 +1,87 @@

# Author: Zylo117

import math

import torch
from torch import nn

from efficientdet.model import BiFPN, Regressor, Classifier, EfficientNet
from efficientdet.utils import Anchors


class EfficientDetBackbone(nn.Module):
    def __init__(self, num_anchors=9, num_classes=80, compound_coef=0, load_weights=False, **kwargs):
        super(EfficientDetBackbone, self).__init__()
        self.compound_coef = compound_coef

        self.fpn_num_filters = [64, 88, 112, 160, 224, 288, 384, 384]
        self.fpn_cell_repeats = [3, 4, 5, 6, 7, 7, 8, 8]
        self.input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536]
        self.box_class_repeats = [3, 3, 3, 4, 4, 4, 5, 5]
        self.anchor_scale = [4, 4, 3, 4, 4, 4, 4, 5]
        self.aspect_ratios = [(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)]
        self.num_scales = 3
        # note: this scalar assignment shadows the per-coefficient list above
        self.anchor_scale = 4.0
        conv_channel_coef = {
            # TODO: I have only tested on D0/D2, if you want to try it on other coefficients,
            # fill it in with the channels of P3/P4/P5 like this.
            2: [48, 120, 352],
        }

        # prefer anchor ratios/scales passed via kwargs; otherwise derive the
        # anchor count from the defaults above
        new_num_anchors = len(kwargs.get('ratios', [])) * len(kwargs.get('scales', []))
        if new_num_anchors > 0:
            num_anchors = new_num_anchors
        else:
            num_anchors = len(self.aspect_ratios) * self.num_scales

        self.bifpn = nn.Sequential(
            *[BiFPN(self.fpn_num_filters[self.compound_coef],
                    conv_channel_coef[compound_coef],
                    True if _ == 0 else False) for _ in range(self.fpn_cell_repeats[compound_coef])])

        self.num_classes = num_classes
        self.regressor = Regressor(in_channels=self.fpn_num_filters[self.compound_coef], num_anchors=num_anchors,
                                   num_layers=self.box_class_repeats[self.compound_coef])
        self.classifier = Classifier(in_channels=self.fpn_num_filters[self.compound_coef], num_anchors=num_anchors,
                                     num_classes=num_classes,
                                     num_layers=self.box_class_repeats[self.compound_coef])

        self.anchors = Anchors(image_size=self.input_sizes[compound_coef], **kwargs)

        # He-style initialization for conv layers, constant initialization for BN
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

        self.backbone_net = EfficientNet(compound_coef, load_weights)

    def freeze_bn(self):
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    def forward(self, inputs):
        max_size = inputs.shape[-1]

        _, p3, p4, p5 = self.backbone_net(inputs)

        features = (p3, p4, p5)
        features = self.bifpn(features)

        regression = self.regressor(features)
        classification = self.classifier(features)
        anchors = self.anchors(inputs, inputs.dtype)

        return features, regression, classification, anchors

    def init_backbone(self, path):
        state_dict = torch.load(path)
        try:
            ret = self.load_state_dict(state_dict, strict=False)
            print(ret)
        except RuntimeError as e:
            print('Ignoring ' + str(e))
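
A hedged usage sketch of the class above (not part of the commit): the module name `backbone` is an assumption, since the file name is not shown in this diff, and only `compound_coef=2` has an entry in `conv_channel_coef` at this point.

# Illustrative only; `backbone` as a module name is an assumption.
import torch

from backbone import EfficientDetBackbone

model = EfficientDetBackbone(num_classes=80, compound_coef=2)  # only the D2 channels are filled in above
model.eval()

# D2 expects 768x768 inputs (self.input_sizes[2])
dummy = torch.randn(1, 3, 768, 768)
with torch.no_grad():
    features, regression, classification, anchors = model(dummy)
print(regression.shape, classification.shape)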
@@ -0,0 +1,26 @@

COCO_CLASSES = ["person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat",
                "traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog",
                "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella",
                "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite",
                "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
                "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange",
                "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant",
                "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone",
                "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors",
                "teddy bear", "hair drier", "toothbrush"]

colors = [(39, 129, 113), (164, 80, 133), (83, 122, 114), (99, 81, 172), (95, 56, 104), (37, 84, 86), (14, 89, 122),
          (80, 7, 65), (10, 102, 25), (90, 185, 109), (106, 110, 132), (169, 158, 85), (188, 185, 26), (103, 1, 17),
          (82, 144, 81), (92, 7, 184), (49, 81, 155), (179, 177, 69), (93, 187, 158), (13, 39, 73), (12, 50, 60),
          (16, 179, 33), (112, 69, 165), (15, 139, 63), (33, 191, 159), (182, 173, 32), (34, 113, 133), (90, 135, 34),
          (53, 34, 86), (141, 35, 190), (6, 171, 8), (118, 76, 112), (89, 60, 55), (15, 54, 88), (112, 75, 181),
          (42, 147, 38), (138, 52, 63), (128, 65, 149), (106, 103, 24), (168, 33, 45), (28, 136, 135), (86, 91, 108),
          (52, 11, 76), (142, 6, 189), (57, 81, 168), (55, 19, 148), (182, 101, 89), (44, 65, 179), (1, 33, 26),
          (122, 164, 26), (70, 63, 134), (137, 106, 82), (120, 118, 52), (129, 74, 42), (182, 147, 112), (22, 157, 50),
          (56, 50, 20), (2, 22, 177), (156, 100, 106), (21, 35, 42), (13, 8, 121), (142, 92, 28), (45, 118, 33),
          (105, 118, 30), (7, 185, 124), (46, 34, 146), (105, 184, 169), (22, 18, 5), (147, 71, 73), (181, 64, 91),
          (31, 39, 184), (164, 179, 33), (96, 50, 18), (95, 15, 106), (113, 68, 54), (136, 116, 112), (119, 139, 130),
          (31, 139, 34), (66, 6, 127), (62, 39, 2), (49, 99, 180), (49, 119, 155), (153, 50, 183), (125, 38, 3),
          (129, 87, 143), (49, 87, 40), (128, 62, 120), (73, 85, 148), (28, 144, 118), (29, 9, 24), (175, 45, 108),
          (81, 175, 64), (178, 19, 157), (74, 188, 190), (18, 114, 2), (62, 128, 96), (21, 3, 150), (0, 6, 95),
          (2, 20, 184), (122, 37, 185)]
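
Both tables are indexed by the classifier's label id, so a predicted class maps to a name and a stable color. A small illustrative helper (the function name and arguments here are hypothetical, not from the commit) shows the intended use when drawing predictions:

import cv2

# Hypothetical helper: class_id comes from the classifier head,
# box is [x1, y1, x2, y2] in pixels.
def draw_box(image, box, class_id, score):
    x1, y1, x2, y2 = map(int, box)
    color = colors[class_id % len(colors)]
    label = f'{COCO_CLASSES[class_id]} {score:.2f}'
    cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)
    cv2.putText(image, label, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
    return image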
@@ -0,0 +1,181 @@

import os
import torch
import numpy as np

from torch.utils.data import Dataset, DataLoader
from pycocotools.coco import COCO
import cv2


class CocoDataset(Dataset):
    def __init__(self, root_dir, set='train2017', transform=None):
        self.root_dir = root_dir
        self.set_name = set
        self.transform = transform

        self.coco = COCO(os.path.join(self.root_dir, 'annotations', 'instances_' + self.set_name + '.json'))
        self.image_ids = self.coco.getImgIds()

        self.load_classes()

    def load_classes(self):
        # load class names (name -> label)
        categories = self.coco.loadCats(self.coco.getCatIds())
        categories.sort(key=lambda x: x['id'])

        self.classes = {}
        self.coco_labels = {}
        self.coco_labels_inverse = {}
        for c in categories:
            self.coco_labels[len(self.classes)] = c['id']
            self.coco_labels_inverse[c['id']] = len(self.classes)
            self.classes[c['name']] = len(self.classes)

        # also load the reverse (label -> name)
        self.labels = {}
        for key, value in self.classes.items():
            self.labels[value] = key

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        img = self.load_image(idx)
        annot = self.load_annotations(idx)
        sample = {'img': img, 'annot': annot}
        if self.transform:
            sample = self.transform(sample)
        return sample

    def load_image(self, image_index):
        image_info = self.coco.loadImgs(self.image_ids[image_index])[0]
        path = os.path.join(self.root_dir, self.set_name, image_info['file_name'])
        img = cv2.imread(path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        return img.astype(np.float32) / 255.

    def load_annotations(self, image_index):
        # get ground truth annotations
        annotations_ids = self.coco.getAnnIds(imgIds=self.image_ids[image_index], iscrowd=False)
        annotations = np.zeros((0, 5))

        # some images appear to miss annotations
        if len(annotations_ids) == 0:
            return annotations

        # parse annotations
        coco_annotations = self.coco.loadAnns(annotations_ids)
        for idx, a in enumerate(coco_annotations):
            # some annotations have basically no width / height, skip them
            if a['bbox'][2] < 1 or a['bbox'][3] < 1:
                continue

            annotation = np.zeros((1, 5))
            annotation[0, :4] = a['bbox']
            annotation[0, 4] = self.coco_label_to_label(a['category_id'])
            annotations = np.append(annotations, annotation, axis=0)

        # transform from [x, y, w, h] to [x1, y1, x2, y2]
        annotations[:, 2] = annotations[:, 0] + annotations[:, 2]
        annotations[:, 3] = annotations[:, 1] + annotations[:, 3]

        return annotations

    def coco_label_to_label(self, coco_label):
        return self.coco_labels_inverse[coco_label]

    def label_to_coco_label(self, label):
        return self.coco_labels[label]

    def num_classes(self):
        return 80


def collater(data):
    imgs = [s['img'] for s in data]
    annots = [s['annot'] for s in data]
    scales = [s['scale'] for s in data]

    imgs = torch.from_numpy(np.stack(imgs, axis=0))

    max_num_annots = max(annot.shape[0] for annot in annots)

    # pad every image's annotations to the same length with -1 so they can be
    # stacked into one (N, max_num_annots, 5) batch tensor
    if max_num_annots > 0:
        annot_padded = torch.ones((len(annots), max_num_annots, 5)) * -1
        for idx, annot in enumerate(annots):
            if annot.shape[0] > 0:
                annot_padded[idx, :annot.shape[0], :] = annot
    else:
        annot_padded = torch.ones((len(annots), 1, 5)) * -1

    imgs = imgs.permute(0, 3, 1, 2)

    return {'img': imgs, 'annot': annot_padded, 'scale': scales}


class Resizer(object):
    """Resize the image to fit a common_size x common_size canvas, scaling annotations to match."""

    def __call__(self, sample, common_size=512):
        image, annots = sample['img'], sample['annot']
        height, width, _ = image.shape
        if height > width:
            scale = common_size / height
            resized_height = common_size
            resized_width = int(width * scale)
        else:
            scale = common_size / width
            resized_height = int(height * scale)
            resized_width = common_size

        image = cv2.resize(image, (resized_width, resized_height), interpolation=cv2.INTER_LINEAR)

        # pad the resized image onto a square canvas; float32 so the resulting
        # tensor matches the model's weight dtype
        new_image = np.zeros((common_size, common_size, 3), dtype=np.float32)
        new_image[0:resized_height, 0:resized_width] = image

        annots[:, :4] *= scale

        return {'img': torch.from_numpy(new_image), 'annot': torch.from_numpy(annots), 'scale': scale}


class Augmenter(object):
    """Randomly flip the image and its annotations horizontally."""

    def __call__(self, sample, flip_x=0.5):
        if np.random.rand() < flip_x:
            image, annots = sample['img'], sample['annot']
            image = image[:, ::-1, :]

            rows, cols, channels = image.shape

            x1 = annots[:, 0].copy()
            x2 = annots[:, 2].copy()

            x_tmp = x1.copy()

            # mirror the box x-coordinates around the vertical axis
            annots[:, 0] = cols - x2
            annots[:, 2] = cols - x_tmp

            sample = {'img': image, 'annot': annots}

        return sample


class Normalizer(object):

    def __init__(self):
        self.mean = np.array([[[0.485, 0.456, 0.406]]])
        self.std = np.array([[[0.229, 0.224, 0.225]]])

    def __call__(self, sample):
        image, annots = sample['img'], sample['annot']

        return {'img': ((image.astype(np.float32) - self.mean) / self.std), 'annot': annots}
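
A hedged sketch of how these pieces compose into a training loader (the dataset path is a placeholder, and the import of `CocoDataset` and friends assumes this file is importable). Normalizer and Augmenter operate on numpy arrays, and Resizer converts to tensors and adds the 'scale' key that `collater` expects, so it must come last:

from torchvision import transforms
from torch.utils.data import DataLoader

train_set = CocoDataset(root_dir='datasets/coco', set='train2017',
                        transform=transforms.Compose([Normalizer(), Augmenter(), Resizer()]))
train_loader = DataLoader(train_set, batch_size=4, shuffle=True,
                          num_workers=4, collate_fn=collater)

for batch in train_loader:
    imgs, annots = batch['img'], batch['annot']  # imgs: (N, 3, 512, 512), annots padded with -1
    break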