-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathsolver_gridms_multiple_kps_covar2.py
3899 lines (3384 loc) · 253 KB
/
solver_gridms_multiple_kps_covar2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torch.optim as optim
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter
import numpy as np
import matplotlib.pyplot as plt
import datetime
import time
import os
import cv2
import copy
import logging
import datasets.dataset_utils
import datasets.transforms as mytransforms
from datasets.AnimalPoseDataset.animalpose_dataset import EpisodeGenerator, AnimalPoseDataset, KEYPOINT_TYPE_IDS, KEYPOINT_TYPES, save_episode_before_preprocess, kp_connections
from datasets.AnimalPoseDataset.animalpose_dataset import horizontal_swap_keypoints, get_symmetry_keypoints, HFLIP, FBFLIP, DFLIP, get_auxiliary_paths
from datasets.dataset_utils import draw_skeletons, draw_instance, draw_markers
from network.models_gridms2 import Encoder, EncoderMultiScaleType1, EncoderMultiScaleType2
from network.models_gridms2 import feature_modulator2, extract_representations, average_representations, average_representations2, feature_modulator3
from network.models_gridms2 import DescriptorNet
from network.models_gridms2 import RegressorVanila, GridBasedLocator, GridBasedLocator2, GridBasedLocator3, GridBasedLocator4, GridBasedLocatorX, GridBasedLocatorX2, GridBasedLocatorX3, GridBasedLocatorXCovar, GridBasedLocatorX2Covar
from network.models_gridms2 import CovarNet, OffsetLearningNet
from network.models_gridms2 import PatchUncertaintyModule
from network.models_gridms2 import ClassifierGRL
from network.functions import ReverseLayerF
from loss import masked_mse_loss, masked_l1_loss, masked_nll_gaussian, masked_nll_laplacian, instance_weighted_nllloss, instance_weighted_nlllossV0, masked_nll_gaussian_covar, masked_nll_gaussian_covar2
from coco_eval_funs import compute_recall_ap
from utils import print_weights, image_normalize, make_grid_images, make_uncertainty_map, compute_eigenvalues, mean_confidence_interval, mean_confidence_interval_multiple
# torch.autograd.set_detect_anomaly(True)
# import pdb
class FSLKeypointNet(object):
# Build the few-shot keypoint network and everything needed to train it:
# stores the generators/options, constructs the encoder, the optional
# domain-confusion classifiers, the descriptor net and one regressor +
# covariance branch per grid scale, loads weights, creates the losses and the
# optimizer, and (optionally) prepares the body-part prototype memory.
#
# Parameters
#   episode_generator       : generator stored as self.episode_generator (the
#                             default episode source used by train()).
#   opts                    : option dict. Keys read here: 'layer_to_freezing',
#                             'downsize_factor', 'use_domain_confusion', 'N_way',
#                             'grid_length', 'sample_times', 'use_body_part_protos',
#                             'load_proto', 'use_interpolated_kps',
#                             'auxiliary_path_mode', 'interpolation_knots',
#                             'proto_compute_method', 'memorize_fibers'.
#   logging                 : stored as self.logging. NOTE(review): the parameter
#                             name shadows the `logging` module imported at file top.
#   writer                  : optional SummaryWriter stored as self.writer.
#   *episode_generator_test : exactly one or two test generators; any other count
#                             prints an error and terminates the process.
def __init__(self, episode_generator: EpisodeGenerator, opts: dict, logging = None, writer: SummaryWriter = None,
*episode_generator_test: EpisodeGenerator):
self.episode_generator = episode_generator
self.opts = opts
self.logging = logging
self.writer = writer
# One or two test generators are accepted; the second one is optional.
if len(episode_generator_test) == 2:
self.episode_generator_test = episode_generator_test[0]
self.episode_generator_test2 = episode_generator_test[1]
elif len(episode_generator_test) == 1:
self.episode_generator_test = episode_generator_test[0]
self.episode_generator_test2 = None
else:
print('error for the input of episode_generator_test!')
# NOTE(review): this is an error path but the process exits with status 0;
# consider raising ValueError or exiting with a non-zero status instead.
exit(0)
# B x 512 x 46 x 46 (1/8)
# B x 1024 x 23 x 23 (1/16)
# B x 2048 x 12 x 12 (1/32)
# B x 3072 x 8 x 8 (1/46)
# B x 4096 x 6 x 6 (1/64)
self.encoder_type = 0 # 0, Encoder, EncoderMultiScaleType1; 1, EncoderMultiScaleType2
self.encoder = Encoder(trunk='resnet50', layer_to_freezing=opts['layer_to_freezing'], downsize_factor=opts['downsize_factor'], use_ppm=False, ppm_remain_dim=True, specify_output_channel=None, lateral_output_mode_layer=(1, 6)) # default, 2048
print(self.encoder)
# Optional gradient-reversal classifiers (ClassifierGRL), built only when
# domain confusion is enabled in the options.
if opts['use_domain_confusion']:
self.auxiliary_classifier = ClassifierGRL(input_channel=2048, output_classes=4, architecture_type='type3')
self.auxiliary_classifier_kp = ClassifierGRL(input_channel=1, output_classes=opts['N_way'], architecture_type='type3')
print(self.auxiliary_classifier)
print(self.auxiliary_classifier_kp)
# used in feature modulator
# self.context_mode = 'hard_fiber'
# self.context_mode = 'soft_fiber_bilinear'
self.context_mode = 'soft_fiber_gaussian'
# self.context_mode = 'soft_fiber_gaussian2'
# input feature map: 1024 x 23 x 23 (1/16), 2048 x 12 x 12 (1/32), 3072 x 8 x 8 (1/46), 4096 x 6 x 6 (1/64)
self.descriptor_net = DescriptorNet(2048, 3, output_fiber=True, specify_fc_out_units=None, architecture_type='type3', net_config=(512, 1024))
print(self.descriptor_net)
# The regression type is hard-coded, so the branch below always runs and
# self.regressor / self.covar_branch are always created.
self.regression_type = 'direct_regression_gridMSE'
if self.regression_type == 'direct_regression_gridMSE':
self.grid_length = opts['grid_length'] # [4, 8, 12]
modules = []
covar_branch_module = []
# One locator and one covariance net per grid scale listed in opts['grid_length'].
for grid_scale in self.grid_length:
# a) independent variance or covariance for each kp
modules += [GridBasedLocatorX2Covar(grid_length=grid_scale, reg_uncertainty=True, cls_uncertainty=False, ureg_type=1, covar_Q_size=(2,3), sample_times=opts['sample_times'], negative_noise=False, sigma_computing_mode=0, alpha=5, beta=0.2)]
# b) covariance for all kps
# covar_branch_module += [CovarNet(grid_length=grid_scale, covar_Q_size=(2, 2, 2, 3))] # covar for 2 keypoints
covar_branch_module += [CovarNet(grid_length=grid_scale, covar_Q_size=(3, 3, 2, 3))] # covar for 3 keypoints
self.regressor = nn.ModuleList(modules) # main regressor
self.covar_branch = nn.ModuleList(covar_branch_module)
print(self.regressor)
# if self.opts['use_pum'] == True:
# self.pum = PatchUncertaintyModule(input_channel=2048 * 2, conv_layers=2) # 512, 2048, 3072, 4096
# loading model based on the configurations said in self.opts
# (load_model is defined elsewhere in the class, outside this part of the file)
self.load_model()
self.loss_fun_mse = nn.MSELoss(reduction='sum') # Used in keypoint regression
# loss_fun_mse = nn.L1Loss(reduction='sum')
self.loss_fun_nll = nn.NLLLoss(ignore_index=-1) # Used in animal class classification
# The loss modules are moved to the GPU only when CUDA is available.
if torch.cuda.is_available():
self.loss_fun_mse = self.loss_fun_mse.cuda()
self.loss_fun_nll = self.loss_fun_nll.cuda()
# optimizer_init is defined elsewhere in the class, outside this part of the file.
self.optimizer_init(lr=0.0001, lr_auxiliary = 0.0001, weight_decay=0, optimization_algorithm='Adam')
# self.optimizer_alex = optim.Adam(self.alexnet.parameters(), lr=0.0001)
# recall_stack keeps the previous and the latest evaluation recall; train()
# averages the two and compares the average with recall_best before saving.
self.recall_stack=[0, 0]
self.recall_best = 0 # used to record the best recall
# Body-part prototype memory: either restored from disk or created empty.
if self.opts['use_body_part_protos']:
if self.opts['load_proto']:
self.load_proto_memory()
else:
# compute the number of interpolated kps
if self.opts['use_interpolated_kps'] == True:
auxiliary_paths = get_auxiliary_paths(self.opts['auxiliary_path_mode'], self.episode_generator_test.support_kp_categories)
N_paths = len(auxiliary_paths)
N_knots = len(self.opts['interpolation_knots'])
# T = total number of auxiliary keypoints (every knot on every path)
T = N_paths * N_knots
else:
T = 0
self.init_proto_memory(stat_episode_num=100, fiber_dim=2048, part_num=len(self.episode_generator_test.support_kp_categories), auxiliary_kp_num=T, method=self.opts['proto_compute_method'])
# Fiber recording mode: init_fiber_memory() reads sizes from self.memory, so the
# prototype memory must already exist when this branch runs.
if self.opts['memorize_fibers']:
self.memorized_episode_cnt = 0
# torch.set_grad_enabled(False) changes the global autograd mode, not just this object.
torch.set_grad_enabled(False) # disable grad computation when recording fibers
print('Stop grad computation!')
self.init_fiber_memory()
def init_proto_memory(self, stat_episode_num=300, fiber_dim=2048, part_num=15, auxiliary_kp_num=0, method='ws'):
    """Create an empty body-part prototype memory in ``self.memory``.

    Any previous ``self.memory`` is replaced. Prototypes start as zeros and
    both masks start as ones.

    Args:
        stat_episode_num: number of episodes over which statistics are
            collected (stored as-is; used as the first dimension of the
            buffers created by ``init_fiber_memory``).
        fiber_dim: channel size C of one prototype vector.
        part_num: number of main body parts N.
        auxiliary_kp_num: number of auxiliary keypoints T (may be 0, which
            yields empty ``C x 0`` / ``0`` tensors).
        method: prototype computing method, 'm' (mean) or 'ws' (weighted_sum).
    """
    # Allocate directly on the GPU when one is available; fall back to the CPU
    # instead of crashing on CUDA-less hosts (mirrors the guard in __init__).
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    self.memory = {
        'stat_episode_num': stat_episode_num,
        'fiber_dim': fiber_dim,
        'part_num': part_num,
        'auxiliary_kp_num': auxiliary_kp_num,
        'proto_compute_method': method,  # 'm': mean; 'ws': weighted_sum
        'proto': torch.zeros(fiber_dim, part_num, device=device),  # C x N
        'proto_mask': torch.ones(part_num, device=device),  # N
        'aux_proto': torch.zeros(fiber_dim, auxiliary_kp_num, device=device),  # C x T
        'aux_proto_mask': torch.ones(auxiliary_kp_num, device=device),  # T
    }
def init_fiber_memory(self):
    """Add zero-filled per-episode recording buffers to ``self.memory``.

    Reads 'stat_episode_num' (K), 'fiber_dim' (C), 'part_num' (N) and
    'auxiliary_kp_num' (T) from ``self.memory``, so the prototype memory must
    have been created or loaded first. Adds 'fibers' (K x C x N), 'distance'
    (K x N) and 'mask' (K x N); the 'aux_*' counterparts (with T in place of
    N) are added only when T > 0.
    """
    memory = self.memory
    episode_num = memory['stat_episode_num']
    fiber_dim = memory['fiber_dim']
    part_num = memory['part_num']
    auxiliary_kp_num = memory['auxiliary_kp_num']

    # Allocate directly on the GPU when one is available; fall back to the CPU
    # instead of crashing on CUDA-less hosts (mirrors the guard in __init__).
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    memory['fibers'] = torch.zeros(episode_num, fiber_dim, part_num, device=device)  # K x C x N, N parts
    memory['distance'] = torch.zeros(episode_num, part_num, device=device)  # K x N
    memory['mask'] = torch.zeros(episode_num, part_num, device=device)  # K x N
    if auxiliary_kp_num > 0:
        memory['aux_fibers'] = torch.zeros(episode_num, fiber_dim, auxiliary_kp_num, device=device)  # K x C x T
        memory['aux_distance'] = torch.zeros(episode_num, auxiliary_kp_num, device=device)  # K x T
        memory['aux_mask'] = torch.zeros(episode_num, auxiliary_kp_num, device=device)  # K x T
def save_proto_memory(self, path=None):
    """Serialize the prototype part of ``self.memory`` with ``torch.save``.

    Only 'stat_episode_num', 'proto_compute_method', 'proto', 'proto_mask',
    'aux_proto' and 'aux_proto_mask' are written; the four tensors are moved
    to the CPU first.

    Args:
        path: destination handed to ``torch.save``. When None, the file name
            'animal_fiber_protos_and_mask_<proto_compute_method>.pt' is used.
    """
    if path is None:
        # e.g., 'animal_fiber_protos_and_mask_ws.pt', 'animal_fiber_protos_and_mask_m.pt'
        path = 'animal_fiber_protos_and_mask_' + self.memory['proto_compute_method'] + '.pt'

    memory = self.memory
    memory_dict = {
        'stat_episode_num': memory['stat_episode_num'],
        'proto_compute_method': memory['proto_compute_method'],  # 'm': mean; 'ws': weighted_sum
        'proto': memory['proto'].cpu(),
        'proto_mask': memory['proto_mask'].cpu(),
        'aux_proto': memory['aux_proto'].cpu(),
        'aux_proto_mask': memory['aux_proto_mask'].cpu(),
    }
    torch.save(memory_dict, path)
def load_proto_memory(self, path=None):
    """Rebuild ``self.memory`` from a file written by ``save_proto_memory``.

    'fiber_dim' and 'part_num' are taken from the shape of the stored 'proto'
    tensor, 'auxiliary_kp_num' from the length of 'aux_proto_mask', and
    'proto_compute_method' from ``self.opts`` (not from the file).

    Args:
        path: source handed to ``torch.load``. When None, the file name
            'animal_fiber_protos_and_mask_<opts['proto_compute_method']>.pt'
            is used.
    """
    if path is None:
        # e.g., 'animal_fiber_protos_and_mask_ws.pt', 'animal_fiber_protos_and_mask_m.pt'
        path = 'animal_fiber_protos_and_mask_' + self.opts['proto_compute_method'] + '.pt'
    self.memory = {}

    # Use the GPU when one is available; otherwise stay on the CPU instead of
    # crashing on CUDA-less hosts (mirrors the guard in __init__).
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # NOTE: torch.load unpickles the file -- only load prototype files from a
    # trusted source. map_location='cpu' lets files that were saved from a GPU
    # be opened on any host; the tensors are moved to `device` right after.
    memory_dict = torch.load(path, map_location='cpu')

    for key in ('proto', 'proto_mask', 'aux_proto', 'aux_proto_mask'):
        self.memory[key] = memory_dict[key].to(device)
    self.memory['stat_episode_num'] = memory_dict['stat_episode_num']

    fiber_dim, part_num = self.memory['proto'].shape  # proto is C x N
    self.memory['fiber_dim'] = fiber_dim
    self.memory['part_num'] = part_num
    self.memory['auxiliary_kp_num'] = len(self.memory['aux_proto_mask'])
    self.memory['proto_compute_method'] = self.opts['proto_compute_method']  # 'm': mean; 'ws': weighted_sum
def train(self, *multiple_episode_generators: EpisodeGenerator):
episode_i = 0
sample_failure_cnt = 0 # count failedly sampled episodes
using_multiple_episodes = False
sample_failure_cnt2 = 0 # count those wrongly labeled images
using_interpolated_kps = self.opts['use_interpolated_kps']
interpolation_knots = self.opts['interpolation_knots']
alpha = 0
loss_adapt_total, loss_adapt_kp_total = 0, 0
loss_symmetry_total = 0
loss_interpolation_total = 0
loss_total = 0
# t_count = np.array([0]*7, dtype=np.float)
if len(multiple_episode_generators) > 0: # if multiple_episode_generators is not empty
using_multiple_episodes = True
num_multiple_episode = len(multiple_episode_generators)
prob = np.array([len(multiple_episode_generators[i].samples) for i in range(num_multiple_episode)])
prob = prob / np.sum(prob)
while episode_i < self.opts['num_episodes']:
if episode_i % 1600 == 0 and episode_i >= 0:
eval_results = self.validate(self.episode_generator_test, eval_method=self.opts['eval_method'])
recall = eval_results[0] # parse eval results
# recall, _, recall_aux, _ = self.validate2(self.episode_generator_test, eval_method=self.opts['eval_method']) # testing for aux kps
if self.episode_generator_test2 != None:
eval_results2 = self.validate(self.episode_generator_test2, eval_method=self.opts['eval_method'])
recall2 = eval_results2[0] # parse eval results
if self.writer != None:
self.writer.add_scalar('accuracy', recall[0], episode_i) # recall is a list which is corresponding to different thresholds
if self.episode_generator_test2 != None:
self.writer.add_scalar('accuracy2', recall2[0], episode_i)
# save model based on the configurations said in self.opts
self.recall_stack[0] = self.recall_stack[1]
self.recall_stack[1] = recall2[0]
avg_recall = np.mean(self.recall_stack)
if avg_recall > self.recall_best:
if episode_i >= 400:
self.save_model()
self.recall_best = avg_recall
print('BEST:', self.recall_best)
if self.opts['use_body_part_protos'] and self.opts['memorize_fibers']:
# disable grad computation when recording fibers
# since each time the grad will be enabled after using self.validate, we have to put stop grad function here
torch.set_grad_enabled(False)
# roll-out an episode
if using_multiple_episodes == False: # training by using single episode generator, which is default case
episode_generator = self.episode_generator
else:
random_episode_ind = np.random.randint(0, num_multiple_episode, 1)
# random_episode_ind = np.random.choice(range(0, num_multiple_episode), size=1, p=prob)
# print('index: ', random_episode_ind)
episode_generator = multiple_episode_generators[random_episode_ind[0]]
while (False == episode_generator.episode_next()):
sample_failure_cnt += 1
if sample_failure_cnt % 500 == 0:
print('sample failure times: {}'.format(sample_failure_cnt))
continue
# print(episode_generator.support_kp_categories)
preprocess = mytransforms.Compose([
# color transform
# mytransforms.RandomApply(mytransforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), p=0.8),
# mytransforms.RandomGrayscale(p=0.01),
# geometry transform
mytransforms.RandomApply(mytransforms.HFlip(swap=horizontal_swap_keypoints), p=0.5), # 0.5
mytransforms.RandomApply(mytransforms.RandomRotation(max_rotate_degree=15), p=0.25), # 0.25
mytransforms.RelativeResize((0.75, 1.25)),
mytransforms.RandomCrop(crop_bbox=False),
# mytransforms.RandomApply(mytransforms.RandomTranslation(), p=0.5),
mytransforms.Resize(longer_length=self.opts['square_image_length']), # 368
mytransforms.CenterPad(target_size=self.opts['square_image_length']),
mytransforms.CoordinateNormalize(normalize_keypoints=True, normalize_bbox=False)
])
image_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# define a list containing the paths, each path is represented by kp index pair [index1, index2]
# our paths are subject to support keypoints; then we will interpolated kps for each path. The paths is possible to be empty list []
num_random_paths = self.opts['num_random_paths'] # only used when auxiliary_path_mode='random'
# path_mode: 'exhaust', 'predefined', 'random'
auxiliary_paths = get_auxiliary_paths(path_mode=self.opts['auxiliary_path_mode'], support_keypoint_categories=episode_generator.support_kp_categories, num_random_paths=num_random_paths)
support_dataset = AnimalPoseDataset(episode_generator.supports,
episode_generator.support_kp_categories,
using_auxiliary_keypoints=using_interpolated_kps,
interpolation_knots=interpolation_knots,
interpolation_mode=self.opts['interpolation_mode'],
auxiliary_path=auxiliary_paths,
hdf5_images_path=self.opts['hdf5_images_path'],
saliency_maps_root=self.opts['saliency_maps_root'],
output_saliency_map=self.opts['use_pum'],
preprocess=preprocess,
input_transform=image_transform
)
query_dataset = AnimalPoseDataset(episode_generator.queries,
episode_generator.support_kp_categories,
using_auxiliary_keypoints=using_interpolated_kps,
interpolation_knots=interpolation_knots,
interpolation_mode=self.opts['interpolation_mode'],
auxiliary_path=auxiliary_paths,
hdf5_images_path=self.opts['hdf5_images_path'],
saliency_maps_root=self.opts['saliency_maps_root'],
output_saliency_map=self.opts['use_pum'],
preprocess=preprocess,
input_transform=image_transform
)
support_loader = DataLoader(support_dataset, batch_size=self.opts['K_shot'], shuffle=False)
query_loader = DataLoader(query_dataset, batch_size=self.opts['M_query'], shuffle=False)
support_loader_iter = iter(support_loader)
query_loader_iter = iter(query_loader)
(supports, support_labels, support_kp_mask, _, support_aux_kps, support_aux_kp_mask, support_saliency, _, _) = support_loader_iter.next()
(queries, query_labels, query_kp_mask, _, query_aux_kps, query_aux_kp_mask, query_saliency, _, _) = query_loader_iter.next()
# print_weights(support_kp_mask)
# print_weights(query_kp_mask)
# make_grid_images(supports, denormalize=True, save_path='grid_image_s.jpg')
# make_grid_images(queries, denormalize=True, save_path='grid_image_q.jpg')
# make_grid_images(support_saliency.cuda(), denormalize=False, save_path='./ss.jpg')
# make_grid_images(query_saliency.cuda(), denormalize=False, save_path='./sq.jpg')
# print(episode_generator.supports)
## 'exhaust', 'predefined'
# save_episode_before_preprocess(episode_generator, episode_i, delete_old_files=False, draw_interpolated_kps=using_interpolated_kps, interpolation_knots=interpolation_knots, interpolation_mode=self.opts['interpolation_mode'], path_mode='predefined')
# show_save_episode(supports, support_labels, support_kp_mask, queries, query_labels, query_kp_mask, episode_generator, episode_i,
# support_aux_kps, support_aux_kp_mask, query_aux_kps, query_aux_kp_mask, is_show=False, is_save=True, delete_old_files=False)
# show_save_episode(supports, support_labels, support_kp_mask, queries, query_labels, query_kp_mask, episode_generator, episode_i,
# support_aux_kps=None, support_aux_kp_mask=None, query_aux_kps=None, query_aux_kp_mask=None, is_show=False, is_save=True, delete_old_files=False)
# print_weights(episode_generator.support_kp_mask)
# print_weights(episode_generator.query_kp_mask)
# support_kp_mask = episode_generator.support_kp_mask # B1 x N
# query_kp_mask = episode_generator.query_kp_mask # B2 x N
# if torch.cuda.is_available():
supports, queries = supports.cuda(), queries.cuda() # B1 x C x H x W, B2 x C x H x W
support_labels, query_labels = support_labels.float().cuda(), query_labels.float().cuda() # B1 x N x 2, B2 x N x 2
# support_labels.requires_grad = True
support_kp_mask = support_kp_mask.cuda() # B1 x N
query_kp_mask = query_kp_mask.cuda() # B2 x N
if using_interpolated_kps:
support_aux_kps = support_aux_kps.float().cuda() # B1 x T x 2, T = (N_paths * N_knots), total number of auxiliary keypoints
support_aux_kp_mask = support_aux_kp_mask.cuda() # B1 x T
query_aux_kps = query_aux_kps.float().cuda() # B2 x T x 2
query_aux_kp_mask = query_aux_kp_mask.cuda() # B2 x T
# compute the union of keypoint types in sampled images, N(union) <= N_way, tensor([True, False, True, ...])
union_support_kp_mask = torch.sum(support_kp_mask, dim=0) > 0 # N
# compute the valid query keypoints, using broadcast
valid_kp_mask = (query_kp_mask * union_support_kp_mask.reshape(1, -1)) # B2 x N
num_valid_kps = torch.sum(valid_kp_mask)
if using_interpolated_kps:
num_valid_support_aux_kps = torch.sum(support_aux_kp_mask) # only for support auxiliary kps
union_support_aux_kp_mask = torch.sum(support_aux_kp_mask, dim=0) > 0 # T
valid_aux_kp_mask = query_aux_kp_mask * union_support_aux_kp_mask.reshape(1, -1) # B x T
num_valid_aux_kps = torch.sum(valid_aux_kp_mask)
# print(num_valid_kps+num_valid_aux_kps)
num_valid_kps_for_samples = torch.sum(valid_kp_mask, dim=1) # B2
gp_valid_samples_mask = num_valid_kps_for_samples >= episode_generator.least_query_kps_num # B2
gp_num_valid_kps = torch.sum(num_valid_kps_for_samples * gp_valid_samples_mask)
gp_valid_kp_mask = valid_kp_mask * gp_valid_samples_mask.reshape(-1, 1) # B2 x N
# check #1 there may exist some wrongly labeled images where the keypoints are outside the boundary
if torch.any(support_labels > 1) or torch.any(support_labels < -1) or torch.any(query_labels > 1) or torch.any(query_labels < -1):
sample_failure_cnt2 += 1
if (sample_failure_cnt2 % 50 == 0):
print('count for wrongly labelled images: {}'.format(sample_failure_cnt))
continue # skip current episode directly
# check #2
if num_valid_kps == 0: # flip transform may lead zero intersecting keypoints between support and query, namely zero valid kps
continue # skip current episode directly
if self.encoder_type == 0: # Encoder, EncoderMultiScaleType1
# feature and semantic distinctiveness, note that x2 = support_saliency.cuda() or query_saliency.cuda()
support_features, support_lateral_out = self.encoder(x=supports, x2=None, enable_lateral_output=True) # B1 x C x H x W, B1 x 1 x H x W
query_features, query_lateral_out = self.encoder(x=queries, x2=None, enable_lateral_output=True) # B2 x C x H x W, B2 x 1 x H x W
if self.opts['use_pum']:
# computing_inv_w = True
w_computing_mode = 2
p_support_lateral_out = self.numerical_transformation(support_lateral_out, w_computing_mode=w_computing_mode) # B1 x 1 x H' x W', transform into positive number
p_query_lateral_out = self.numerical_transformation(query_lateral_out, w_computing_mode=w_computing_mode) # B2 x 1 x H' x W', transform into positive number
# B1 x C x N
support_repres, conv_query_features = extract_representations(support_features, support_labels, support_kp_mask, context_mode=self.context_mode,
sigma=self.opts['sigma'], downsize_factor=self.opts['downsize_factor'], image_length=self.opts['square_image_length'], together_trans_features=None)
avg_support_repres = average_representations2(support_repres, support_kp_mask) # C x N
attentive_features, attention_maps_l2, attention_maps_l1, similarity_map = feature_modulator3(avg_support_repres, query_features, \
fused_attention=self.opts['use_fused_attention'], output_attention_maps=False, compute_similarity=False)
# loss_similarity = torch.tensor(0) # self.compute_similarity_loss(similarity_map, query_labels, valid_kp_mask)
# show_save_attention_maps(attention_maps_l2, queries, support_kp_mask, query_kp_mask, episode_generator.support_kp_categories,
# episode_num=episode_i, is_show=False, is_save=True, delete_old=False, T=query_scale_trans)
# show_save_attention_maps(attention_maps_l1, queries, support_kp_mask, query_kp_mask, episode_generator.support_kp_categories,
# episode_num=episode_i, is_show=False, is_save=True, save_image_root='./attention_maps2', delete_old=False, T=query_scale_trans)
if using_interpolated_kps and num_valid_support_aux_kps > 0:
# B1 x C x M
support_repres_aux, _ = extract_representations(support_features, support_aux_kps, support_aux_kp_mask, context_mode=self.context_mode,
sigma=self.opts['sigma'], downsize_factor=self.opts['downsize_factor'], image_length=self.opts['square_image_length'])
avg_support_repres_aux = average_representations2(support_repres_aux, support_aux_kp_mask)
attentive_features_aux, _, _, _ = feature_modulator3(avg_support_repres_aux, query_features, \
fused_attention=self.opts['use_fused_attention'], output_attention_maps=False, compute_similarity=False)
if self.opts['use_pum']: # extracting semantic distinctiveness (SD) per keypoint (i.e., modeling semantic uncertainty)
# B1 x C x T
# support_saliency_repres_aux, _ = extract_representations(support_saliency_features, support_aux_kps, support_aux_kp_mask, context_mode=self.context_mode,
# sigma=self.opts['sigma'], downsize_factor=self.opts['downsize_factor'], image_length=self.opts['square_image_length'])
# support_patches_features = torch.cat([support_repres_aux, support_saliency_repres_aux], dim=1) # B1 x (2C) x T
# patches_rho_aux_origin = self.pum(support_patches_features) # B1 x T, T is the number of interpolated keypoints
# # patches_rho may need to do average to get T values, here we asssume the shot number is 1
# # patches_rho_aux = patches_rho_aux_origin.mean(dim=0) # T
# B1, T = patches_rho_aux_origin.shape
# patches_rho_aux = average_representations(patches_rho_aux_origin.view(B1, 1, T), support_aux_kp_mask).squeeze()
# regarding main training kps
# inv_w_patches, w_patches = self.get_distinctiveness_for_parts(p_support_lateral_out, p_query_lateral_out, support_labels, support_kp_mask, query_labels)
inv_w_patches = None
# p = float(episode_i) / self.opts['num_episodes']
# alpha = 2. / (1. + np.exp(-10 * p)) - 1 # ranges from 0 to 1
# # support_labels_for_pum = ReverseLayerF.apply(support_labels, alpha)
# # query_labels_for_pum = ReverseLayerF.apply(query_labels, alpha)
# support_labels_for_pum = support_labels.detach()
# query_labels_for_pum = query_labels.detach()
# inv_w_patches, w_patches = self.get_distinctiveness_for_parts(p_support_lateral_out, p_query_lateral_out, support_labels_for_pum, support_kp_mask, query_labels_for_pum)
# regarding auxiliary kps
if using_interpolated_kps and num_valid_support_aux_kps > 0:
inv_w_patches_aux, w_patches_aux = self.get_distinctiveness_for_parts(p_support_lateral_out, p_query_lateral_out, support_aux_kps, support_aux_kp_mask, query_aux_kps)
# # support_aux_kps_for_pum = ReverseLayerF.apply(support_aux_kps, alpha)
# # query_aux_kps_for_pum = ReverseLayerF.apply(query_aux_kps, alpha)
# support_aux_kps_for_pum = support_aux_kps.detach()
# query_aux_kps_for_pum = query_aux_kps.detach()
# inv_w_patches_aux, w_patches_aux = self.get_distinctiveness_for_parts(p_support_lateral_out, p_query_lateral_out, support_aux_kps_for_pum, support_aux_kp_mask, query_aux_kps_for_pum)
else:
inv_w_patches_aux, w_patches_aux = None, None
else:
# patches_rho_aux = None
inv_w_patches_aux, w_patches_aux = None, None
inv_w_patches, w_patches = None, None
elif self.encoder_type == 1: # can only converge when using feature_modulator2()
attentive_features = self.encoder(supports, queries, support_labels, support_kp_mask)
keypoint_descriptors = self.descriptor_net(attentive_features) # B2 x N x D or B2 x N x c x h x w
if using_interpolated_kps and num_valid_support_aux_kps > 0:
keypoint_descriptors_aux = self.descriptor_net(attentive_features_aux) # B2 x T x D or B2 x T x c x h x w
if self.regression_type == 'direct_regression_gridMSE':
B2 = self.opts['M_query']
N = self.opts['N_way']
w_dev = 1.0
if w_dev ==0 and self.opts['loss_weight'][2] == 0: # because cls uses w (semantic uncertainty) to weight loss
inv_w_patches = inv_w_patches_aux = None
# don't use interpolated kps or don't have interpolated kps
if using_interpolated_kps == False or (using_interpolated_kps == True and num_valid_support_aux_kps == 0):
loss_grid_class, loss_deviation, mean_predictions = self.multiscale_regression(keypoint_descriptors, query_labels, valid_kp_mask, self.grid_length, weight=None, inv_w_patches=inv_w_patches, w_patches=None)
loss = loss_grid_class + w_dev * loss_deviation
loss_interpolation = torch.tensor(0.).cuda()
loss_aux_grid, loss_aux_deviation = torch.tensor(0.).cuda(), torch.tensor(0.).cuda()
loss_covar_multiple_kps, loss_covar_multiple_kps_total = torch.tensor(0.).cuda(), torch.tensor(0.).cuda()
loss_covar_multiple_kps_aux = torch.tensor(0.).cuda()
loss_offset = torch.tensor(0.).cuda()
else: # using interpolated kps
T = valid_aux_kp_mask.shape[1] # T = (N_knots * N_paths)
N_curves = (int)(T / len(interpolation_knots))
# aux_kps_weights1 = torch.ones(2*3).cuda()
# aux_kps_weights2 = torch.tensor([0.75, 0.5, 0.75]).repeat(N_curves-2, 1).reshape(-1).cuda()
# aux_kps_weights2 = torch.tensor([0.75, 0.25, 0.75]).repeat(N_curves-2, 1).reshape(-1).cuda()
# aux_kps_weights = torch.cat([aux_kps_weights1, aux_kps_weights2], dim=0)
# aux_kps_weights = torch.tensor([0.75, 0.5, 0.75]).repeat(N_curves, 1).reshape(-1).cuda()
# aux_kps_weights = torch.tensor([0.75, 0.25, 0.75]).repeat(N_curves, 1).reshape(-1).cuda()
aux_kps_weights = None
loss_grid_class, loss_deviation, loss_aux_grid, loss_aux_deviation, loss_covar_multiple_kps, loss_covar_multiple_kps_aux, mean_predictions, mean_predictions_aux = \
self.multiscale_regression2(keypoint_descriptors, query_labels, valid_kp_mask, keypoint_descriptors_aux, query_aux_kps, valid_aux_kp_mask, self.grid_length, \
episode_generator.support_kp_categories, query_dataset.auxiliary_paths, weight=None, inv_w_patches_aux=inv_w_patches_aux, w_patches_aux=None, inv_w_patches=inv_w_patches, w_patches=None)
# ---------------main keypoints-------------------
loss = loss_grid_class + w_dev * loss_deviation
# ---------------auxiliary keypoints-------------------
loss_interpolation = loss_aux_grid + w_dev * loss_aux_deviation
loss_interpolation_total += loss_interpolation
loss_covar_multiple_kps_total = loss_covar_multiple_kps + loss_covar_multiple_kps_aux
# square_image_length = self.opts['square_image_length']
# show_save_predictions(queries, query_labels.cpu().detach(), valid_kp_mask, query_scale_trans, episode_generator, square_image_length, kp_var=None,
# version='original', is_show=False, is_save=True, folder_name='query_gt', save_file_prefix='eps{}'.format(episode_i))
# show_save_predictions(queries, mean_predictions.cpu().detach(), valid_kp_mask, query_scale_trans, episode_generator, square_image_length, kp_var=None, confidence_scale=3,
# version='original', is_show=False, is_save=True, folder_name='query_predict', save_file_prefix='eps{}'.format(episode_i))
if self.opts['use_body_part_protos']:
if self.opts['memorize_fibers']:
if episode_i <= self.memory['stat_episode_num'] - 1: # record fibers
# rectify the current body part order to the standard order in case "order_fixed = False", namely the dynamic support kp categories
order_index = [episode_generator.support_kp_categories.index(kp_type) for kp_type in KEYPOINT_TYPES]
self.memory['fibers'][episode_i] = (avg_support_repres[:, order_index]).detach().clone() # C x N
kp_weights = valid_kp_mask.sum(dim=0) # N
d_sum = torch.sum((mean_predictions.detach() - query_labels)**2, dim=2) # B x N
d = torch.sum(d_sum * valid_kp_mask, dim=0) / (kp_weights+1e-6) # N
self.memory['distance'][episode_i] = d[order_index]
self.memory['mask'][episode_i] = (kp_weights > 0)[order_index] # N
# here we suppose the aux kp order is fixed (should set the 'order_fixed = True' in main.py)
if self.memory['auxiliary_kp_num'] > 0 and num_valid_support_aux_kps > 0: # in case there is no aux kps
self.memory['aux_fibers'][episode_i] = avg_support_repres_aux.detach().clone() # C x T
kp_weights_aux = valid_aux_kp_mask.sum(dim=0) # T
d_sum_aux = torch.sum((mean_predictions_aux.detach() - query_aux_kps) ** 2, dim=2) # B x T
d_aux = torch.sum(d_sum_aux * valid_aux_kp_mask, dim=0) / (kp_weights_aux + 1e-6) # T
self.memory['aux_distance'][episode_i] = d_aux
self.memory['aux_mask'][episode_i] = (kp_weights_aux > 0) # T
if episode_i == self.memory['stat_episode_num'] - 1: # build body part prototypes which is subject the standard order
d_total = self.memory['distance'] # K x N, l2 distance
d_total = torch.sqrt(d_total) # K x N, square root
m_total = self.memory['mask'] # K x N
masked_exp_neg_d = torch.exp(-d_total) * m_total
p_total = masked_exp_neg_d / torch.sum(masked_exp_neg_d, dim=0).view(1, -1) # K x N
self.memory['proto_mask'] = m_total.sum(dim=0) > 0 # N
if self.memory['auxiliary_kp_num'] > 0 and num_valid_support_aux_kps > 0: # in case there is no aux kps
d_total_aux = self.memory['aux_distance'] # K x T, l2 distance
d_total_aux = torch.sqrt(d_total_aux) # K x T, square root
m_total_aux = self.memory['aux_mask'] # K x T
masked_exp_neg_d_aux = torch.exp(-d_total_aux) * m_total_aux
p_total_aux = masked_exp_neg_d_aux / torch.sum(masked_exp_neg_d_aux, dim=0).view(1, -1) # K x T
self.memory['aux_proto_mask'] = m_total_aux.sum(dim=0) > 0 # N
if self.memory['proto_compute_method'] == 'ws': # weighted sum
# ---
# Method 1, use prediction's deviation to GT to serve as weight
universal_body_part_protos = self.memory['fibers'] * p_total.view(self.memory['stat_episode_num'], 1, self.memory['part_num']) # K x C x N
self.memory['proto'] = torch.sum(universal_body_part_protos, dim=0) # C x N
if self.memory['auxiliary_kp_num'] > 0 and num_valid_support_aux_kps > 0: # in case there is no aux kps
universal_protos_aux = self.memory['aux_fibers'] * p_total_aux.view(self.memory['stat_episode_num'], 1, self.memory['auxiliary_kp_num']) # K x C x T
self.memory['aux_proto'] = torch.sum(universal_protos_aux, dim=0) # C x T
else: # mean
# ---
# Method 2, simple mean
universal_body_part_protos = self.memory['fibers'] * m_total.view(self.memory['stat_episode_num'], 1, self.memory['part_num']) # K x C x N
self.memory['proto'] = torch.sum(universal_body_part_protos, dim=0) / (m_total.sum(dim=0) + 1e-6).view(1, -1) # C x N
if self.memory['auxiliary_kp_num'] > 0 and num_valid_support_aux_kps > 0: # in case there is no aux kps
universal_protos_aux = self.memory['aux_fibers'] * m_total_aux.view(self.memory['stat_episode_num'], 1, self.memory['auxiliary_kp_num']) # K x C x T
self.memory['aux_proto'] = torch.sum(universal_protos_aux, dim=0) / (m_total_aux.sum(dim=0) + 1e-6).view(1, -1) # C x T
#---
self.save_proto_memory()
torch.set_grad_enabled(True) # enable grad computation when finishing recording fibers
print('Open grad computation!')
print('Initial universal body part prototypes built!')
exit(0) # exit when body part protos are built
print(episode_i)
episode_i += 1 # skip loss functions and backwards, jump to next episode
continue
loss_total += loss
if episode_i % 1 == (1 - 1):
loss_interpolation_total /= 1.0
loss_total /= 1.0
final_combined_loss = self.opts['loss_weight'][0]*loss_total + self.opts['loss_weight'][1]*loss_interpolation_total + self.opts['loss_weight'][2]*loss_covar_multiple_kps_total # + loss_similarity # + 0.1 * loss_symmetry_total + 0.01*loss_adapt_total + 0.01*loss_adapt_kp_total
#=================
# sometimes will happen due to sdm (semantic distinctiveness) when beta_semantic = 0
if torch.isnan(final_combined_loss):
loss_interpolation_total = 0
loss_total = 0
print('loss to be nan')
continue
#==================
self.optimizer_step(final_combined_loss)
self.lr_scheduler_step(episode_i)
loss_interpolation_total = 0
loss_total = 0
if episode_i % 8 == (8 - 1):
# print('time: {}'.format(datetime.datetime.now()))
if self.regression_type == 'direct_regression_gridMSE':
print('episode: {}, loss_kp: {:.5f} (G: {:.5f}/D: {:.5f}), loss_aux: {:.5f} (G: {:.5f}/D: {:.5f}), mcovar: {:.5f} (M: {:.5f}/A: {:.5f}), time: {}'.format(episode_i, loss.item(), loss_grid_class.item(),
loss_deviation.item(), loss_interpolation.item(), loss_aux_grid.item(), loss_aux_deviation.item(), loss_covar_multiple_kps_total.item(), loss_covar_multiple_kps.item(), loss_covar_multiple_kps_aux.item(),\
datetime.datetime.now()))
if self.writer != None:
self.writer.add_scalar('loss', loss.cpu().detach().numpy(), episode_i)
# increment in episode_i
episode_i += 1
def multiscale_regression(self, keypoint_descriptors, query_kp_label, valid_kp_mask, grid_length_list, weight=None, inv_w_patches=None, w_patches=None):
    '''
    Multi-scale keypoint regression: per scale, classify the grid cell that contains the keypoint and
    regress the deviation from that cell's centre; losses and decoded predictions are averaged over scales.

    :param keypoint_descriptors: forwarded unchanged to self.regressor[scale_i]
    :param query_kp_label: B2 x N x 2, keypoint coordinates in -1~1
    :param valid_kp_mask: B2 x N, entries < 1 mark keypoints whose grid label is ignored
    :param grid_length_list: grid side length for each scale; self.regressor is indexed by the same position
    :param weight: optional per-keypoint weight of size N; when given, the weighted NLL and masked L1 losses are used
    :param inv_w_patches: optional patch uncertainty of size N (support only) or B2 x N (support & query);
                          only consulted when weight is None
    :param w_patches: unused, kept for interface compatibility
    :return loss_grid_class, grid classification loss averaged over scales
    :return loss_deviation, deviation loss averaged over scales
    :return mean_predictions, B2 x N x 2, decoded keypoint locations averaged over scales
    '''
    B2, N = query_kp_label.shape[:2]  # B2 images, N keypoints for each image
    scale_num = len(grid_length_list)
    invalid_kp = valid_kp_mask < 1  # B2 x N

    loss_grid_class = 0
    loss_deviation = 0
    mean_predictions = 0
    for scale_i, grid_length in enumerate(grid_length_list):
        # compute grid groundtruth and deviation
        gridxy = (query_kp_label / 2 + 0.5) * grid_length  # coordinate -1~1 --> 0~grid_length, B2 x N x 2
        gridxy_quantized = gridxy.long().clamp(0, grid_length - 1)  # B2 x N x 2
        # deviation w.r.t. the cell centre, scaled so that it ranges -1~1, B2 x N x 2
        label_deviations = (gridxy - (gridxy_quantized + 0.5)) * 2
        label_grids = gridxy_quantized[:, :, 1] * grid_length + gridxy_quantized[:, :, 0]  # 0 ~ grid_length**2 - 1, B2 x N

        # local deviation: the regressor receives the groundtruth grid index (not a one-hot map)
        # B2 x N x (grid_length ** 2), B2 x N x 2, B2 x N x 2 (or B2 x N x 2d)
        predict_grids, predict_deviations, rho, _ = self.regressor[scale_i](keypoint_descriptors, training_phase=True, one_hot_grid_label=label_grids)

        predict_grids2 = predict_grids.view(B2 * N, -1)  # (B2 * N) x (grid_length * grid_length)
        # Invalid keypoints get label -1 (the ignore index of the NLL losses below). masked_fill builds a new
        # tensor, so the label_grids handed to the regressor above is not modified behind its back, and no
        # per-element Python loop (one device sync per keypoint) is needed.
        label_grids2 = label_grids.masked_fill(invalid_kp.to(label_grids.device), -1).reshape(-1)  # (B2 * N)

        # ------------------ grid classification loss ------------------
        if weight is None:
            if inv_w_patches is None:
                # NOTE(review): relies on self.loss_fun_nll being configured with ignore_index=-1 -- confirm
                loss_grid_class += self.loss_fun_nll(predict_grids2, label_grids2)
            else:
                # take patch uncertainty into consideration; the square root of inv_w is used as instance weight
                sqrt_inv_w = torch.sqrt(inv_w_patches)
                if len(inv_w_patches.shape) == 1:  # only support patches are used to compute inv_w, size is N
                    instance_weight = sqrt_inv_w.repeat(B2, 1).reshape(-1)  # (B2 * N)
                elif len(inv_w_patches.shape) == 2:  # both support & query patches are used, size is B2 x N
                    instance_weight = sqrt_inv_w.reshape(-1)  # (B2 * N)
                else:
                    raise ValueError('inv_w_patches must have size N or B2 x N, got shape {}'.format(tuple(inv_w_patches.shape)))
                loss_grid_class += instance_weighted_nllloss(predict_grids2, label_grids2, instance_weight=instance_weight, ignore_index=-1)
        else:
            instance_weight = weight.repeat(B2, 1).reshape(-1)  # weight: N --> (B2 * N)
            loss_grid_class += instance_weighted_nllloss(predict_grids2, label_grids2, instance_weight=instance_weight, ignore_index=-1)

        # ------------------ deviation loss (type1, direct regressing deviation) ------------------
        # Other candidates that were tried here: masked_nll_gaussian, masked_nll_laplacian,
        # masked_nll_gaussian_covar2 (covariance for all keypoints), masked_mse_loss, masked_l1_loss.
        if weight is None:
            if inv_w_patches is None:
                # covariance for each keypoint
                loss_deviation += masked_nll_gaussian_covar(predict_deviations, label_deviations, rho, valid_kp_mask)
            else:
                # covariance for each keypoint, combined with the patch uncertainty
                loss_fun_mode = 0
                penalty_mode = 0  # 0 or 2 are better, namely log(det(W^-1)) or 1/ W^-1 - 1
                beta = 1.0
                beta_loc_uc = 1.0
                loss_deviation += masked_nll_gaussian_covar(predict_deviations, label_deviations, rho, valid_kp_mask, patches_rho=inv_w_patches, loss_fun_mode=loss_fun_mode, penalty_mode=penalty_mode, beta=beta, beta_loc_uc=beta_loc_uc)
        else:
            # weight: N
            combine_mask_weight = (valid_kp_mask * weight.reshape(1, -1)).view(B2, N, 1)
            loss_deviation += masked_l1_loss(predict_deviations, label_deviations, combine_mask_weight)

        # ------------------ decode predicted keypoint locations from grids and deviations ------------------
        out_predict_grids = torch.max(predict_grids, dim=2)[1]  # B2 x N, 0 ~ grid_length * grid_length - 1
        out_predict_gridxy = torch.stack(
            (out_predict_grids % grid_length,    # grid x
             out_predict_grids // grid_length),  # grid y
            dim=2).to(device=predict_deviations.device, dtype=torch.float32)  # B2 x N x 2
        # deviation -1~1 --> location -1~1
        predictions = (((out_predict_gridxy + 0.5) + predict_deviations / 2.0) / grid_length - 0.5) * 2
        mean_predictions += predictions

    # average over scales
    mean_predictions /= scale_num
    loss_grid_class /= scale_num
    loss_deviation /= scale_num

    return loss_grid_class, loss_deviation, mean_predictions
def construct_patches_grids(self, kp_labels, kp_mask, patch_size=3, interval_control=(2 / 12)):
    '''
    Build a patch_size x patch_size lattice of sampling points centred on every valid keypoint.

    Points are ordered row by row (y offset is the slow index, x offset the fast one). Keypoints whose
    mask entry is 0 keep all-zero points.

    :param kp_labels: B x T x 2, keypoint (x, y)
    :param kp_mask: B x T, 0 marks an invalid keypoint
    :param patch_size: odd number of lattice points per side
    :param interval_control: spacing between neighbouring lattice points
    :return patch_grids, B x (T*patch_size*patch_size) x 2, e.g., B x (T*9) x 2, float32; placed on the
            GPU when CUDA is available, otherwise on the device of kp_labels
    :raises ValueError: when patch_size is not a positive odd number
    '''
    half_patch_size = patch_size // 2
    side = 2 * half_patch_size + 1
    if patch_size < 1 or side != patch_size:
        # an even size would generate more points than the patch_size**2 slots reserved per keypoint
        raise ValueError('patch_size must be a positive odd number, got {}'.format(patch_size))
    per_patch_grids_num = patch_size ** 2

    B, T = kp_labels.shape[:2]
    device = torch.device('cuda') if torch.cuda.is_available() else kp_labels.device
    labels = kp_labels[:, :, :2].to(device=device, dtype=torch.float32)  # B x T x 2
    valid = (kp_mask != 0).to(device).reshape(B, T, 1, 1)

    # offsets are computed in Python floats first, exactly like `r * interval_control` in a scalar loop
    offsets = torch.tensor([k * interval_control for k in range(-half_patch_size, half_patch_size + 1)],
                           dtype=torch.float32, device=device)  # side
    dx = offsets.repeat(side)             # x offset, fast index
    dy = offsets.repeat_interleave(side)  # y offset, slow index
    delta = torch.stack((dx, dy), dim=1)  # per_patch_grids_num x 2

    patch_grids = labels.unsqueeze(2) + delta  # B x T x per_patch_grids_num x 2
    patch_grids = patch_grids.masked_fill(~valid, 0.0)  # invalid keypoints stay at zero
    return patch_grids.reshape(B, T * per_patch_grids_num, 2)
def construct_patches_grids2(self, kp_labels, kp_mask, patch_size=3, interval_control=(2/12), label_max=1, label_min=-1):
    '''
    Build a patch_size x patch_size lattice of sampling points centred on every valid keypoint, together
    with a mask telling which generated points fall inside [label_min, label_max] on both axes.

    Points are ordered row by row (y offset is the slow index, x offset the fast one). Keypoints whose
    mask entry is 0 keep all-zero points and an all-zero mask.

    :param kp_labels: B x T x 2, keypoint (x, y)
    :param kp_mask: B x T, 0 marks an invalid keypoint
    :param patch_size: odd number of lattice points per side
    :param interval_control: spacing between neighbouring lattice points
    :param label_max: largest coordinate that still counts as inside
    :param label_min: smallest coordinate that still counts as inside
    :return patch_grids, B x (T*patch_size*patch_size) x 2, e.g., B x (T*9) x 2
    :return patch_grids_mask, B x (T*patch_size*patch_size), e.g., B x (T*9), 1 for in-range points of valid keypoints
            (both float32; placed on the GPU when CUDA is available, otherwise on the device of kp_labels)
    :raises ValueError: when patch_size is not a positive odd number
    '''
    half_patch_size = patch_size // 2
    side = 2 * half_patch_size + 1
    if patch_size < 1 or side != patch_size:
        # an even size would generate more points than the patch_size**2 slots reserved per keypoint
        raise ValueError('patch_size must be a positive odd number, got {}'.format(patch_size))
    per_patch_grids_num = patch_size ** 2

    B, T = kp_labels.shape[:2]
    device = torch.device('cuda') if torch.cuda.is_available() else kp_labels.device
    labels = kp_labels[:, :, :2].to(device=device, dtype=torch.float32)  # B x T x 2
    valid = (kp_mask != 0).to(device).reshape(B, T, 1)

    # offsets are computed in Python floats first, exactly like `r * interval_control` in a scalar loop
    offsets = torch.tensor([k * interval_control for k in range(-half_patch_size, half_patch_size + 1)],
                           dtype=torch.float32, device=device)  # side
    dx = offsets.repeat(side)             # x offset, fast index
    dy = offsets.repeat_interleave(side)  # y offset, slow index
    delta = torch.stack((dx, dy), dim=1)  # per_patch_grids_num x 2

    points = labels.unsqueeze(2) + delta  # B x T x per_patch_grids_num x 2
    x, y = points[..., 0], points[..., 1]

    # build mask for generated grids (evaluated before invalid keypoints are zeroed out)
    inside = (x >= label_min) & (x <= label_max) & (y >= label_min) & (y <= label_max) & valid
    patch_grids_mask = inside.to(torch.float32).reshape(B, T * per_patch_grids_num)

    patch_grids = points.masked_fill(~valid.unsqueeze(3), 0.0).reshape(B, T * per_patch_grids_num, 2)
    return patch_grids, patch_grids_mask
def compute_average_for_patch(self, map, points_set, keypoint_num, patch_size_square=9):
    '''
    Sample `map` at every patch point and average the sampled values within each keypoint's patch.

    :param map: B x 1 x H x W
    :param points_set: B x (keypoint_num * patch_size_square) x 2
    :param keypoint_num: number of keypoints per image
    :param patch_size_square: number of points that belong to one keypoint's patch
    :return: B x keypoint_num, per-patch mean of the sampled values
    '''
    # B x 1 x (keypoint_num * patch_size_square)
    sampled_values, _unused = extract_representations(map, points_set, None, context_mode='soft_fiber_bilinear')
    # group the samples patch by patch, then reduce over the points of each patch
    per_patch_values = sampled_values.view(-1, keypoint_num, patch_size_square)
    return per_patch_values.mean(dim=2)  # B x keypoint_num
def find_valid_pair_kps(self, mask):
    '''
    Given kp mask, find every pair of valid keypoints inside the same image.

    :param mask: B x N, 0 marks an invalid keypoint
    :return: valid_inds, k x 3 (long), each row is something like (0, 1, 3), which means kp indices (1, 3)
             in image 0 is a valid pair; rows are ordered by image, then first index, then second index.
             The shape is 0 x 3 when there is no valid pair. Placed on the GPU when CUDA is available,
             otherwise on the device of mask.
    '''
    from itertools import combinations

    device = torch.device('cuda') if torch.cuda.is_available() else mask.device
    # one transfer to the host instead of one device read per mask element
    valid_rows = (mask != 0).cpu().tolist()

    valid_inds_list = []
    for image_ind, row in enumerate(valid_rows):
        valid_kp_ids = [kp_ind for kp_ind, is_valid in enumerate(row) if is_valid]
        valid_inds_list.extend([image_ind, j, k] for j, k in combinations(valid_kp_ids, 2))

    # explicit row count keeps the k x 3 layout even when the list is empty
    return torch.tensor(valid_inds_list, dtype=torch.long, device=device).reshape(len(valid_inds_list), 3)
def find_valid_pair_neighbor_kps(self, main_kp_mask, auxiliary_kp_mask, support_kp_categories, auxiliary_path):
    '''
    Given kp mask, find valid pairs of neighboring kps along every interpolation path.

    Main keypoints are addressed by their index in support_kp_categories (0 ~ N-1) and auxiliary keypoint j
    by N + j. For each valid auxiliary keypoint the following pairs are emitted:
      first knot of a path: (start kp, knot) and (knot, next knot), or (knot, end kp) when N_knots == 1;
      last knot of a path:  (knot, end kp);
      middle knot:          (knot, next knot).
    NOTE(review): only the auxiliary mask entry of the current knot is checked; the neighbouring knot and
    the two main keypoints are presumably valid whenever the knot is -- confirm against the caller.

    :param main_kp_mask: B x N (only its shape is used)
    :param auxiliary_kp_mask: B x T, T = (N_paths * N_knots), 0 marks an invalid auxiliary keypoint
    :param support_kp_categories: list of kp type names, gives the main kp order of the current episode
    :param auxiliary_path: N_paths x 2, each row is a pair of body parts, original kp ids in KEYPOINT_TYPE
    :return: valid inds, k x 3 (long), each row is (image index, kp index, neighboring kp index); the shape
             is 0 x 3 when there is no valid pair. Placed on the GPU when CUDA is available, otherwise on
             the device of auxiliary_kp_mask.
    '''
    B, N = main_kp_mask.shape
    _, T = auxiliary_kp_mask.shape
    N_knots = len(self.opts['interpolation_knots'])  # interpolated knots per path
    N_paths = int(T / N_knots)  # the number of body part pairs served as path for interpolation

    # translate the two end points of every path into indices of the current support kp order
    auxiliary_path_in_support = []
    for path_ind in range(N_paths):
        start_kp_id, end_kp_id = auxiliary_path[path_ind]
        start_kp_type, end_kp_type = KEYPOINT_TYPES[start_kp_id], KEYPOINT_TYPES[end_kp_id]
        auxiliary_path_in_support.append([support_kp_categories.index(start_kp_type), support_kp_categories.index(end_kp_type)])

    device = torch.device('cuda') if torch.cuda.is_available() else auxiliary_kp_mask.device
    # one transfer to the host instead of one device read per mask element
    aux_valid = (auxiliary_kp_mask != 0).cpu().tolist()

    valid_inds_list = []  # each path has N_knots+1 valid pairs when this path is valid
    for i in range(B):
        for j in range(T):
            if not aux_valid[i][j]:
                continue
            knot_ind = j % N_knots
            if knot_ind == 0:  # first knot
                start_kp_id_in_support, end_kp_id_in_support = auxiliary_path_in_support[j // N_knots]
                valid_inds_list.append([i, start_kp_id_in_support, N + j])
                if N_knots >= 2:
                    valid_inds_list.append([i, N + j, N + j + 1])
                else:  # N_knots == 1, for this case the first knot is also the last knot
                    valid_inds_list.append([i, N + j, end_kp_id_in_support])
            elif knot_ind == N_knots - 1:  # last knot
                _, end_kp_id_in_support = auxiliary_path_in_support[j // N_knots]
                valid_inds_list.append([i, N + j, end_kp_id_in_support])
            else:  # middle knots
                valid_inds_list.append([i, N + j, N + j + 1])

    # explicit row count keeps the k x 3 layout even when the list is empty
    return torch.tensor(valid_inds_list, dtype=torch.long, device=device).reshape(len(valid_inds_list), 3)
def find_valid_triplet_kps(self, main_kp_mask, auxiliary_kp_mask, support_kp_categories, auxiliary_path):
    '''
    Given kp mask, find valid triplets of indices: one row per valid auxiliary keypoint, made of the two
    main keypoints that span its path and the auxiliary keypoint itself.

    :param main_kp_mask: B x N (only its shape is used)
    :param auxiliary_kp_mask: B x T, T = (N_paths * N_knots), 0 marks an invalid auxiliary keypoint
    :param support_kp_categories: list of kp type names, gives the main kp order of the current episode
    :param auxiliary_path: N_paths x 2, each row is a pair of body parts, original kp ids in KEYPOINT_TYPE
    :return: valid inds, k x 4 (long, on CPU), each row is something like (0, 1, N+j, 3), which means main kp
             indices (1, 3) and auxiliary kp index j in image 0 is a valid triplet
    '''
    B, N = main_kp_mask.shape
    _, T = auxiliary_kp_mask.shape
    N_knots = len(self.opts['interpolation_knots'])  # interpolated knots per path
    N_paths = int(T / N_knots)  # the number of body part pairs served as path for interpolation

    # translate the two end points of every path into indices of the current support kp order
    auxiliary_path_in_support = []
    for path_ind in range(N_paths):
        start_kp_id, end_kp_id = auxiliary_path[path_ind]
        start_kp_type, end_kp_type = KEYPOINT_TYPES[start_kp_id], KEYPOINT_TYPES[end_kp_id]
        auxiliary_path_in_support.append([support_kp_categories.index(start_kp_type), support_kp_categories.index(end_kp_type)])

    # one transfer to the host instead of one device read per mask element
    aux_valid = (auxiliary_kp_mask != 0).cpu().tolist()

    # Rows are collected in a list so that the row count always equals the number of non-zero mask entries
    # (a count derived from the mask sum would disagree with it for a non-binary mask).
    valid_inds_list = []
    for i in range(B):
        for j in range(T):
            if not aux_valid[i][j]:
                continue
            start_kp_id_in_support, end_kp_id_in_support = auxiliary_path_in_support[j // N_knots]  # path 0 ~ N_paths-1
            valid_inds_list.append([i, start_kp_id_in_support, N + j, end_kp_id_in_support])

    # explicit row count keeps the k x 4 layout even when the list is empty
    return torch.tensor(valid_inds_list, dtype=torch.long).reshape(len(valid_inds_list), 4)
def numerical_transformation(self, lateral_out, w_computing_mode=2, offset1=1.0, offset2=0.5):
## transform patches_rho into positive number
# Method 0
if w_computing_mode == 0:
positive_lateral_out = torch.exp(lateral_out) # patches_rho = log(w^-1)
# ---------------------------------------
# Method 1
elif w_computing_mode == 1:
# offset1 = 1 # or 6
# offset2 = 0.5 # or 0.5
positive_lateral_out = offset1 * torch.sigmoid(lateral_out) + offset2 # w = a*sigmoid(patches_rho)+b
# ---------------------------------------
# Method 2
elif w_computing_mode == 2:
parameter_a = 4
positive_lateral_out = (lateral_out + torch.sqrt(lateral_out ** 2 + parameter_a)) / 2 # f(x) = (x + sqrt(x*x + a)) / 2
# log_w_patches = -torch.log(inv_w_patches)
# ---------------------------------------
# Method 3
elif w_computing_mode == 3:
positive_lateral_out = lateral_out ** 2 # inv_w = rho * rho
# ---------------------------------------
# Method 4
elif w_computing_mode == 4: