Increase dropout and add tooling to analyze noise resistance
DaWelter committed Oct 16, 2023
1 parent 3152c3e commit 3d9a3bd
Showing 8 changed files with 149 additions and 45 deletions.
2 changes: 1 addition & 1 deletion run.sh
@@ -3,4 +3,4 @@
python scripts/train_poseestimator.py --lr 1.e-3 --epochs 1500 --ds "repro_300_wlp+synface+lapa_megaface_lp+wflw_lp" --auglevel 2 \
--save-plot train.pdf \
--with-swa \
--backbone mobilenetv1
--backbone resnet18
21 changes: 1 addition & 20 deletions scripts/add_pose_pseudolabels.py
@@ -12,6 +12,7 @@
import gc

from trackertraincode.datasets.dshdf5pose import Hdf5PoseDataset
from trackertraincode.neuralnets.torchquaternion import quat_average
import trackertraincode.vis as vis
import trackertraincode.datatransformation as dtr
import trackertraincode.utils as utils
@@ -60,26 +61,6 @@ def fit_batch(net : InferenceNetwork, batch : List[Batch]):
return out


def quat_average(quats):
quats = np.asarray(quats)
# Ensemble size, number of samples, dimensionality
E, N, D = quats.shape
assert D==4
# Sum over ensemble to get an idea of the largest axis on average.
# Then find the actual longest axis, i.e. i,j,k or w.
pivot_axes = np.argmax(np.sum(np.abs(quats), axis=0), axis=-1)
assert pivot_axes.shape == (N,)
mask = np.take_along_axis(quats, pivot_axes[None,:,None], axis=-1) < 0.
mask = mask[...,0] # Skip quaternion dimension
quats[mask,:] *= -1
quats = np.average(quats, axis=0)
norms = np.linalg.norm(quats, axis=-1, keepdims=True)
if not np.all(norms > 0.5):
print("Oh oh either quat_average is bugged or rotations predictions differ wildly")
quats /= norms
return quats


def test_quats_average():
def positivereal(q):
s = np.sign(q[...,3])
78 changes: 60 additions & 18 deletions scripts/evaluate_pose_network.py
@@ -1,14 +1,21 @@
#!/usr/bin/env python
# coding: utf-8

from typing import Any, List
# Seems to run a bit faster than with the default settings and is less buggy.
# See https://github.com/pytorch/pytorch/issues/67864
import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')

from typing import Any, List, NamedTuple
import numpy as np
import argparse
import tqdm
import tabulate
from numpy.typing import NDArray
from matplotlib import pyplot
from os.path import basename
import functools
import torch

from trackertraincode.datasets.batch import Batch
import trackertraincode.datatransformation as dtr
@@ -19,15 +26,40 @@

from trackertraincode.eval import load_pose_network, predict

load_pose_network = functools.lru_cache(maxsize=1)(load_pose_network)

class RoiConfig(NamedTuple):
expansion_factor : float = 1.2
center_crop : bool = False

def __str__(self):
crop = ['ROI','CC'][self.center_crop]
return f'{crop}{self.expansion_factor:0.1f}'

normal_roi_configs = [ RoiConfig() ]
comprehensive_roi_configs = [ RoiConfig(*x) for x in [(1.2, False), (1.0, False), (0.6, True), (0.8, True) ] ]


def compute_predictions_and_targets(loader, net, keys):
def determine_roi(batch : Batch, use_center_crop : bool):
if not use_center_crop:
return batch['roi']
w,h = batch.meta.image_wh
b = batch.meta.batchsize
return torch.tensor([0,0,h,w], dtype=torch.float32).expand((b,4))


def compute_predictions_and_targets(loader, net, keys, roi_config : RoiConfig):
preds = []
targets = []
first = True
bar = tqdm.tqdm(total = len(loader.dataset))
for batch in utils.iter_batched(loader, 32):
batch = Batch.collate(batch)
pred = predict(net, batch['image'], rois=batch['roi'])
pred = predict(
net,
batch['image'],
rois=determine_roi(batch, roi_config.center_crop),
focus_roi_expansion_factor=roi_config.expansion_factor)
if first:
keys = list(frozenset(pred.keys()).intersection(frozenset(keys)))
first = False
@@ -55,67 +87,77 @@ def interleaved(a,b):


class TableBuilder:
data_name_table = {
'aflw2k3d' : 'AFLW 2k 3d',
'aflw2k3d_grimaces' : 'grimaces'
}

def __init__(self):
self._header = [ 'Model', 'Data', 'Yaw°', 'Pitch°', 'Roll°', 'Mean°', 'Geodesic°', 'XY%', 'S%' ]
self._entries = []

def add_row(self, model : str, data : str, euler_angles : List[float], geodesic : float, rmse_pos : float, rmse_size : float):
def add_row(self, model : str, data : str, euler_angles : List[float], geodesic : float, rmse_pos : float, rmse_size : float, data_aux_string = None):
maxlen = 30
if len(model) > maxlen+3:
model = '...'+model[-maxlen:]
data = self.data_name_table[data] + (data_aux_string if data_aux_string is not None else '')
self._entries.append([model, data] + euler_angles + [ np.average(euler_angles).tolist(), geodesic, rmse_pos, rmse_size] )

def build(self) -> str:
return tabulate.tabulate(self._entries, self._header, tablefmt='github', floatfmt=".2f")


def report(name, net, args, builder : TableBuilder):
loader = trackertraincode.pipelines.make_validation_loader(name)
preds, targets = compute_predictions_and_targets(loader, net, ['coord','pose', 'roi', 'pt3d_68'])

def report(net_filename, data_name, roi_config, args, builder : TableBuilder):
loader = trackertraincode.pipelines.make_validation_loader(data_name)
net = load_pose_network(net_filename, args.device)
preds, targets = compute_predictions_and_targets(loader, net, ['coord','pose', 'roi', 'pt3d_68'], roi_config)
# Position and size errors are measured relative to the ROI size. Hence in percent.
poseerrs = trackertraincode.eval.PoseErr()(preds, targets)
eulererrs = trackertraincode.eval.EulerAngleErrors()(preds, targets)
e_rot, e_posx, e_posy, e_size = np.array(poseerrs).T
rmse_pos = np.sqrt(np.average(np.sum(np.square(np.vstack([e_posx, e_posy]).T), axis=1), axis=0))
rmse_size = np.sqrt(np.average(np.square(e_size)))
builder.add_row(
model=basename(args.filename),
data=name,
model=basename(net_filename),
data=data_name,
euler_angles=(np.average(np.abs(eulererrs), axis=0)*utils.rad2deg).tolist(),
geodesic=(np.average(e_rot)*utils.rad2deg).tolist(),
rmse_pos=(rmse_pos*100.).tolist(),
rmse_size=(rmse_size*100.).tolist()
rmse_size=(rmse_size*100.).tolist(),
data_aux_string=' / ' + str(roi_config)
)

if args.vis:
order = interleaved(np.argsort(e_rot)[::-1], np.argsort(e_size)[::-1])
loader = trackertraincode.pipelines.make_validation_loader(name, order=order)
loader = trackertraincode.pipelines.make_validation_loader(data_name, order=order)
new_preds = Batch(preds.meta, **{k:v[order] for k,v in preds.items()})
new_preds.meta.batchsize = len(order)
worst_rot_iter = iterate_predictions(loader, new_preds)
fig, btn = vis.matplotlib_plot_iterable(worst_rot_iter, vis.draw_prediction)
fig.suptitle(name)
fig.suptitle(data_name + ' / ' + net_filename)
return [fig, btn]
else:
return []


def run(args):
net = load_pose_network(args.filename, args.device)
gui = []
table_builder = TableBuilder()
for name in [ 'aflw2k3d', 'aflw2k3d_grimaces']:
gui += report(name, net, args, table_builder)
roi_configs = comprehensive_roi_configs if args.comprehensive_roi else normal_roi_configs
for net_filename in args.filenames:
for name in [ 'aflw2k3d', 'aflw2k3d_grimaces']:
for roi_config in roi_configs:
gui += report(net_filename, name, roi_config, args, table_builder)
print (table_builder.build())
pyplot.show()


if __name__ == '__main__':
parser = argparse.ArgumentParser(description="Evaluate pose networks")
parser.add_argument('filename', help='filename of checkpoint or onnx model file', type=str)
parser.add_argument('filenames', help='filenames of checkpoints or onnx model files', type=str, nargs='*')
parser.add_argument('--no-vis', dest='vis', help='disable visualization', default=True, action='store_false')
parser.add_argument('--device', help='select device: cpu or cuda', default='cuda', type=str)
parser.add_argument('--res', dest='input_resolution', help='input resolution for loaded models where it cannot be inferred', default=129, type=int)
parser.add_argument('--auto-level', dest='auto_level', help='automatically adjust brightness levels for maximum contrast within the roi', default=False, action='store_true')
parser.add_argument('--comprehensive-roi', action='store_true', default=False)
args = parser.parse_args()
run(args)
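
The evaluation script now sweeps every checkpoint over one or more ROI configurations, and the RoiConfig label ('ROI' for the expanded annotated box, 'CC' for a full-image center crop) is appended to the dataset name in the result table. A minimal, self-contained sketch of how those labels come out; the class body is copied from the diff above and the loop values mirror comprehensive_roi_configs:

from typing import NamedTuple

class RoiConfig(NamedTuple):
    expansion_factor : float = 1.2
    center_crop : bool = False

    def __str__(self):
        # 'CC' marks a full-image center crop, 'ROI' the annotated face box.
        crop = ['ROI', 'CC'][self.center_crop]
        return f'{crop}{self.expansion_factor:0.1f}'

for cfg in [RoiConfig(1.2, False), RoiConfig(1.0, False), RoiConfig(0.6, True), RoiConfig(0.8, True)]:
    print(cfg)   # -> ROI1.2, ROI1.0, CC0.6, CC0.8

Because load_pose_network is wrapped in functools.lru_cache(maxsize=1), a network is loaded once per checkpoint filename and reused across the inner dataset and ROI loops.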
60 changes: 58 additions & 2 deletions scripts/evaluate_stability.py
@@ -6,10 +6,13 @@
import tqdm
from typing import NamedTuple, Optional, List, Tuple
from scipy.spatial.transform import Rotation
import tabulate
from matplotlib import pyplot
import itertools
from collections import defaultdict
import torch
from torchvision.transforms import Compose
from trackertraincode.neuralnets.torchquaternion import quat_average, geodesicdistance

from trackertraincode.datasets.batch import Batch
import trackertraincode.datatransformation as dtr
@@ -247,19 +250,72 @@ def predict_all_nets(loader):
pyplot.show()


def main_analyze_noise_resist(checkpoints : List[str]):
def noisify(img : torch.Tensor, noiselevel : float):
return (img+noiselevel*torch.randn_like(img, dtype=torch.float32)).clip(0., 255.).to(torch.uint8)

def predict_noisy_dataset(net, loader, noiselevel, predictions_to_keep, rounds = 10) -> Poses:
bar = tqdm.tqdm(total = len(loader.dataset))
def predict_sample_list(samples : List[Batch]):
rois = torch.stack([ s['roi'] for s in samples ])
outputs = defaultdict(list)
for _ in range(rounds):
images = [ noisify(s['image'], noiselevel) for s in samples ]
out = predict(net, images, rois, focus_roi_expansion_factor=1.1)
for k in predictions_to_keep:
outputs[k].append(out[k])
preds_out = Batch(meta=out.meta, **{
k:torch.stack(v,dim=1) for k,v in outputs.items()
})
gt_outputs = { k:torch.stack([s[k] for s in samples]) for k in predictions_to_keep }
bar.update(len(samples))
return preds_out, gt_outputs
preds, gt = zip(*(predict_sample_list(batch) for batch in utils.iter_batched(loader,32)))
preds = utils.list_of_dicts_to_dict_of_lists(preds)
for k,v in preds.items():
preds[k] = torch.concat(v)
gt = utils.list_of_dicts_to_dict_of_lists(gt)
for k,v in gt.items():
gt[k] = torch.concat(v)
return preds, gt

def compute_metrics_for_quats(gt : torch.Tensor, preds : torch.Tensor):
preds = preds.swapaxes(0,1) # batch <-> noise sample
mean_preds = torch.from_numpy(quat_average(preds))
preds_spread = geodesicdistance(mean_preds[None,...], preds).square().mean(dim=0).sqrt().mean(dim=0)
preds_error = geodesicdistance(mean_preds, gt).mean(dim=0)
return preds_error, preds_spread

loader = trackertraincode.pipelines.make_validation_loader('aflw2k3d', order = np.random.choice(1900,size=100))

rad2deg = 180./np.pi

for checkpoint in checkpoints:
net = load_pose_network(checkpoint, 'cuda')
table = [[ 'noise', 'err', 'spread' ]]
for noiselevel in [ 1., 8., 16., 32. ]:
preds, gt = predict_noisy_dataset(net, loader, noiselevel, ('coord','pose','roi'))
err, spread = compute_metrics_for_quats(gt['pose'], preds['pose'])
table.append([noiselevel, err*rad2deg, spread*rad2deg])
print (f"Checkpoint: {checkpoint}")
print (tabulate.tabulate(table[1:], table[0]))

if __name__ == '__main__':
np.seterr(all='raise')
parser = argparse.ArgumentParser(description="Evaluates pose network stability and noise resistance")
parser.add_argument('filename', nargs='+', help='filename of checkpoint or onnx model file', type=str)
parser.add_argument('--closed-loop', action='store_true', default=False)
parser.add_argument('--pitch-yaw', action='store_true', default=False)
parser.add_argument('--open-loop', action='store_true', default=False)
parser.add_argument('--noise-resist', action='store_true', default=False)
args = parser.parse_args()
if not (args.closed_loop or args.pitch_yaw):
if not (args.closed_loop or args.pitch_yaw or args.noise_resist):
args.open_loop = True
if args.open_loop:
main_open_loop(args.filename, 'cuda')
if args.closed_loop:
main_closed_loop(args.filename, 'cpu')
if args.pitch_yaw:
main_analyze_pitch_vs_yaw(args.filename)
main_analyze_pitch_vs_yaw(args.filename)
if args.noise_resist:
main_analyze_noise_resist(args.filename)
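
The new --noise-resist mode adds Gaussian pixel noise of increasing strength (1, 8, 16 and 32 on 0-255 pixel values) to 100 random AFLW2000-3D samples, runs 10 prediction rounds per sample, and reports two numbers per noise level: the rotation error of the averaged prediction against ground truth and the spread of the individual predictions around that average, both in degrees. A minimal sketch of the core metric, assuming unit quaternions of shape (rounds, batch, 4); the naive sign alignment and mean below stand in for quat_average from trackertraincode.neuralnets.torchquaternion:

import torch

def geodesic(a, b):
    # Rotation angle between unit quaternions, invariant to sign flips.
    return 2. * torch.acos((a * b).sum(dim=-1).abs().clamp(max=1.))

def noise_metrics(preds, gt):
    # preds: (rounds, batch, 4) predictions under noise, gt: (batch, 4) ground truth.
    signs = torch.sign((preds * preds[:1]).sum(dim=-1, keepdim=True))        # align signs to the first round
    mean = torch.nn.functional.normalize((preds * signs).mean(dim=0), dim=-1)
    spread = geodesic(mean[None], preds).square().mean(dim=0).sqrt().mean()  # RMS scatter around the mean
    error = geodesic(mean, gt).mean()                                        # error of the averaged prediction
    return error, spread

A small spread at a given noise level means the network's rotation output barely reacts to the injected pixel noise, independently of how accurate it is.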
1 change: 1 addition & 0 deletions trackertraincode/datatransformation/__init__.py
@@ -17,6 +17,7 @@
RandomGaussianNoiseWithClipping,
RandomContrast,
RandomBrightness,
RandomGaussianNoise
)

from trackertraincode.datatransformation.loader import (
2 changes: 1 addition & 1 deletion trackertraincode/neuralnets/models.py
@@ -233,7 +233,7 @@ def __init__(

self.convnet = create_pose_estimator_backbone(config)
num_features = self.convnet.num_features
self.dropout = nn.Dropout(0.1)
self.dropout = nn.Dropout(0.5)

self.boxnet = BoundingBox(num_features, enable_uncertainty)
self.posnet = PositionSizeOutput(num_features, enable_uncertainty)
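
Raising the feature dropout from 0.1 to 0.5 is the headline change of the commit. A tiny standalone illustration (not repo code) of what that means for the backbone feature vector the heads consume:

import torch
from torch import nn

drop = nn.Dropout(0.5)
features = torch.ones(1, 8)

drop.train()
print(drop(features))   # roughly half the entries zeroed, survivors scaled by 1/(1-p) = 2.0
drop.eval()
print(drop(features))   # identity at inference time: all ones

Inference is unaffected, since Dropout is a no-op in eval mode; only training sees the more aggressively thinned features.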
23 changes: 22 additions & 1 deletion trackertraincode/neuralnets/torchquaternion.py
@@ -5,6 +5,7 @@
pytorch3d and kornia are different.
'''
from typing import Union, Final
import numpy as np
import torch
from torch import Tensor

@@ -147,4 +148,24 @@ def distance(a, b):
#return torch.min(torch.norm(a-b,p=2,dim=-1), torch.norm(a+b,p=2,dim=-1))

def geodesicdistance(a,b):
return 2.*torch.acos(torch.sum(a * b, dim=-1).abs().min(torch.as_tensor(1.,dtype=a.dtype)))
return 2.*torch.acos(torch.sum(a * b, dim=-1).abs().min(torch.as_tensor(1.,dtype=a.dtype)))


def quat_average(quats):
quats = np.asarray(quats)
# Ensemble size, number of samples, dimensionality
E, N, D = quats.shape
assert D==4
# Sum over ensemble to get an idea of the largest axis on average.
# Then find the actual longest axis, i.e. i,j,k or w.
pivot_axes = np.argmax(np.sum(np.abs(quats), axis=0), axis=-1)
assert pivot_axes.shape == (N,)
mask = np.take_along_axis(quats, pivot_axes[None,:,None], axis=-1) < 0.
mask = mask[...,0] # Skip quaternion dimension
quats[mask,:] *= -1
quats = np.average(quats, axis=0)
norms = np.linalg.norm(quats, axis=-1, keepdims=True)
if not np.all(norms > 0.5):
print("Oh oh either quat_average is bugged or rotations predictions differ wildly")
quats /= norms
return quats
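
A tiny usage sketch for quat_average as moved here: it averages an ensemble of E predictions for N samples, shape (E, N, 4) with the real part last, after flipping signs so all members agree on the dominant component. Values below are made up for illustration; note that the function may flip signs of the input array in place:

import numpy as np
from trackertraincode.neuralnets.torchquaternion import quat_average

q = np.array([0., 0., 0., 1.])            # identity rotation, (x, y, z, w)
ensemble = np.stack([[q], [-q]])          # shape (2, 1, 4): same rotation, opposite signs
print(quat_average(ensemble))             # [[0. 0. 0. 1.]] -- signs aligned before averaging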
7 changes: 5 additions & 2 deletions trackertraincode/pipelines.py
@@ -309,21 +309,24 @@ def make_pose_estimation_loaders(inputsize, batchsize, datasets : Sequence[Id],
image_augs = [
dtr.RandomEqualize(p=0.2),
dtr.RandomPosterize((4.,6.), p=0.01),
dtr.RandomInvert(p=0.1),
dtr.RandomGamma((0.5, 2.0), p = 0.2),
dtr.RandomContrast((0.7, 1.5), p = 0.2),
dtr.RandomBrightness((0.7, 1.5), p = 0.2),
]
if auglevel in (2, 1, 3):
image_augs += [
dtr.RandomGaussianBlur(p=0.1, kernel_size=(5,5), sigma=(1.5,1.5)),
dtr.RandomGaussianNoiseWithClipping(std=4./255., p=0.1)
#dtr.RandomGaussianNoiseWithClipping(std=4./255., p=0.1)
]

loader_trafo_train = [
partial(dtr.normalize_batch, align_corners=False),
partial(dtr.to_device, 'cuda'),
dtr.KorniaImageDistortions(*image_augs, random_apply = 4),
dtr.KorniaImageDistortions(
dtr.RandomGaussianNoise(std=4./255., p=0.5),
dtr.RandomGaussianNoise(std=16./255., p=0.1),
),
whiten_batch
]
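
The Gaussian pixel noise augmentation is no longer one of the candidates in the random_apply-limited image_augs pool; it now always gets a chance to fire, in two strengths, right before whitening. A rough sketch of the intended per-batch behavior, assuming the images are normalized to [0, 1] at this point (hence the /255 noise std) and that the two RandomGaussianNoise transforms gate independently without clipping:

import torch

def add_training_noise(img):
    # img: float image batch, assumed normalized to [0, 1]
    if torch.rand(()) < 0.5:
        img = img + (4. / 255.) * torch.randn_like(img)    # mild noise on about half the batches
    if torch.rand(()) < 0.1:
        img = img + (16. / 255.) * torch.randn_like(img)   # strong noise on about one batch in ten
    return img

Under the independence assumption roughly 55% of training batches receive some noise and about 5% receive both levels stacked, a much higher rate than the old RandomGaussianNoiseWithClipping at p=0.1 inside the random_apply pool.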

