models.py
"""
Models
Fred Zhang <frederic.zhang@anu.edu.au>
The Australian National University
Australian Centre for Robotic Vision
"""
import torch
import torchvision.ops.boxes as box_ops

from torch import nn, Tensor
from torchvision.ops._utils import _cat
from typing import Optional, List, Tuple
from torchvision.ops import MultiScaleRoIAlign
from torchvision.models.detection import transform

import pocket.models as models

from transforms import HOINetworkTransform
from interaction_head import InteractionHead, GraphHead

class GenericHOINetwork(nn.Module):
    """A generic architecture for HOI classification

    Parameters:
    -----------
        backbone: nn.Module
        interaction_head: nn.Module
        transform: nn.Module
        postprocess: bool
            If True, rescale bounding boxes to original image size
    """
    def __init__(self,
        backbone: nn.Module, interaction_head: nn.Module,
        transform: nn.Module, postprocess: bool = True
    ) -> None:
        super().__init__()
        self.backbone = backbone
        self.interaction_head = interaction_head
        self.transform = transform

        self.postprocess = postprocess

    def preprocess(self,
        images: List[Tensor],
        detections: List[dict],
        targets: Optional[List[dict]] = None
    ) -> Tuple[
        List[Tensor], List[dict],
        List[dict], List[Tuple[int, int]]
    ]:
        original_image_sizes = [img.shape[-2:] for img in images]
        images, targets = self.transform(images, targets)

        # Rescale the detected boxes from the original image coordinates
        # to the coordinates of the resized (transformed) images
        for det, o_im_s, im_s in zip(
            detections, original_image_sizes, images.image_sizes
        ):
            boxes = det['boxes']
            boxes = transform.resize_boxes(boxes, o_im_s, im_s)
            det['boxes'] = boxes

        return images, detections, targets, original_image_sizes
    def forward(self,
        images: List[Tensor],
        detections: List[dict],
        targets: Optional[List[dict]] = None
    ) -> List[dict]:
        """
        Parameters:
        -----------
            images: List[Tensor]
            detections: List[dict]
            targets: List[dict]

        Returns:
        --------
            results: List[dict]
        """
        if self.training and targets is None:
            raise ValueError("In training mode, targets should be passed")
        images, detections, targets, original_image_sizes = self.preprocess(
            images, detections, targets)

        # Extract backbone features and run the interaction head on the
        # supplied detections
        features = self.backbone(images.tensors)
        results = self.interaction_head(features, detections,
            images.image_sizes, targets)

        if self.postprocess and results is not None:
            return self.transform.postprocess(
                results,
                images.image_sizes,
                original_image_sizes
            )
        else:
            return results
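
# Note (illustrative, not part of the original file): the forward pass
# mirrors the torchvision detection API. Each entry of `detections` is
# assumed to be a dict with "boxes" (N x 4, in (x1, y1, x2, y2) format),
# "labels" (N,) and "scores" (N,), given in the coordinates of the
# original image; `preprocess` above rescales the boxes to the resized
# images before the interaction head consumes them.
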
class SpatiallyConditionedGraph(GenericHOINetwork):
    def __init__(self,
        object_to_action: List[list],
        human_idx: int,
        # Backbone parameters
        backbone_name: str = "resnet50",
        pretrained: bool = True,
        # Pooler parameters
        output_size: int = 7,
        sampling_ratio: int = 2,
        # Box pair head parameters
        node_encoding_size: int = 1024,
        representation_size: int = 1024,
        num_classes: int = 117,
        box_score_thresh: float = 0.2,
        fg_iou_thresh: float = 0.5,
        num_iterations: int = 2,
        distributed: bool = False,
        # Transformation parameters
        min_size: int = 800, max_size: int = 1333,
        image_mean: Optional[List[float]] = None,
        image_std: Optional[List[float]] = None,
        postprocess: bool = True,
        # Preprocessing parameters
        box_nms_thresh: float = 0.5,
        max_human: int = 15,
        max_object: int = 15
    ) -> None:
        detector = models.fasterrcnn_resnet_fpn(backbone_name,
            pretrained=pretrained)
        backbone = detector.backbone

        box_roi_pool = MultiScaleRoIAlign(
            featmap_names=['0', '1', '2', '3'],
            output_size=output_size,
            sampling_ratio=sampling_ratio
        )

        box_pair_head = GraphHead(
            out_channels=backbone.out_channels,
            roi_pool_size=output_size,
            node_encoding_size=node_encoding_size,
            representation_size=representation_size,
            num_cls=num_classes,
            human_idx=human_idx,
            object_class_to_target_class=object_to_action,
            fg_iou_thresh=fg_iou_thresh,
            num_iter=num_iterations
        )

        box_pair_predictor = nn.Linear(representation_size * 2, num_classes)
        box_pair_suppressor = nn.Linear(representation_size * 2, 1)

        interaction_head = InteractionHead(
            box_roi_pool=box_roi_pool,
            box_pair_head=box_pair_head,
            box_pair_suppressor=box_pair_suppressor,
            box_pair_predictor=box_pair_predictor,
            num_classes=num_classes,
            human_idx=human_idx,
            box_nms_thresh=box_nms_thresh,
            box_score_thresh=box_score_thresh,
            max_human=max_human,
            max_object=max_object,
            distributed=distributed
        )

        if image_mean is None:
            image_mean = [0.485, 0.456, 0.406]
        if image_std is None:
            image_std = [0.229, 0.224, 0.225]
        transform = HOINetworkTransform(min_size, max_size,
            image_mean, image_std)

        super().__init__(backbone, interaction_head, transform, postprocess)
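

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original module. The
    # object-to-action mapping below is a dummy placeholder; the real
    # HICO-DET mapping is built by the accompanying data utilities, and
    # human_idx=49 follows the object class ordering used in this repo.
    dummy_object_to_action = [[0] for _ in range(80)]
    net = SpatiallyConditionedGraph(
        dummy_object_to_action, human_idx=49,
        # Skip the detector weight download for this illustration
        pretrained=False
    )
    net.eval()

    # One dummy image with a human and an object detection. In practice
    # the boxes, labels and scores come from a pre-trained detector.
    images = [torch.rand(3, 480, 640)]
    detections = [{
        'boxes': torch.tensor([
            [50., 60., 200., 300.],
            [220., 100., 400., 350.]
        ]),
        'labels': torch.tensor([49, 17]),
        'scores': torch.tensor([0.95, 0.90])
    }]
    with torch.no_grad():
        results = net(images, detections)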