Skip to content

Commit

Permalink
Merge pull request #81 from talebolano/feature/fix_ddp
Browse files Browse the repository at this point in the history
Fix: fix ddp
  • Loading branch information
Peterande authored Nov 27, 2024
2 parents aeebefe + 9b949c6 commit b9da032
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
7 changes: 6 additions & 1 deletion src/zoo/dfine/denoising.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,12 @@ def get_contrastive_denoising_training_group(targets,

max_gt_num = max(num_gts)
if max_gt_num == 0:
return None, None, None, None
dn_meta = {
"dn_positive_idx": None,
"dn_num_group": 0,
"dn_num_split": [0, num_queries]
}
return None, None, None, dn_meta

num_group = num_denoising // max_gt_num
num_group = 1 if num_group == 0 else num_group
Expand Down
5 changes: 4 additions & 1 deletion src/zoo/dfine/dfine_criterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,9 @@ def loss_local(self, outputs, targets, indices, num_boxes, T=5):
if 'teacher_corners' in outputs:
pred_corners = outputs['pred_corners'].reshape(-1, (self.reg_max+1))
target_corners = outputs['teacher_corners'].reshape(-1, (self.reg_max+1))
if not torch.equal(pred_corners, target_corners):
if torch.equal(pred_corners, target_corners):
losses['loss_ddf'] = pred_corners.sum() * 0
else:
weight_targets_local = outputs['teacher_logits'].sigmoid().max(dim=-1)[0]

mask = torch.zeros_like(weight_targets_local, dtype=torch.bool)
Expand Down Expand Up @@ -332,6 +334,7 @@ def forward(self, outputs, targets, **kwargs):
assert 'dn_meta' in outputs, ''
indices_dn = self.get_cdn_matched_indices(outputs['dn_meta'], targets)
dn_num_boxes = num_boxes * outputs['dn_meta']['dn_num_group']
dn_num_boxes = dn_num_boxes if dn_num_boxes > 0 else 1

for i, aux_outputs in enumerate(outputs['dn_outputs']):
aux_outputs['is_dn'] = True
Expand Down

0 comments on commit b9da032

Please sign in to comment.