diff --git a/examples/efficientpose/debugger.py b/examples/efficientpose/debugger.py index d29838c4c..880853989 100644 --- a/examples/efficientpose/debugger.py +++ b/examples/efficientpose/debugger.py @@ -131,8 +131,7 @@ def __init__(self, score_thresh=0.60, nms_thresh=0.45, show_boxes2D=False, show_poses6D=True): names = get_class_names('LINEMOD_EFFICIENTPOSE_DRILLER') model = EFFICIENTPOSEA(num_classes=len(names), base_weights='COCO', - head_weights=None, momentum=0.99, - epsilon=0.001, activation='softmax') + head_weights=None) super(EFFICIENTPOSEALINEMODDEBUG, self).__init__( model, names, score_thresh, nms_thresh, LINEMOD_CAMERA_MATRIX, LINEMOD_OBJECT_SIZES, diff --git a/examples/efficientpose/demo.py b/examples/efficientpose/demo.py index e479b0c9f..54648ed9a 100644 --- a/examples/efficientpose/demo.py +++ b/examples/efficientpose/demo.py @@ -1,12 +1,13 @@ from paz.backend.image import load_image, show_image -from pose import EFFICIENTPOSEALINEMOD +from pose import EFFICIENTPOSEALINEMODDRILLER -IMAGE_PATH = ('/home/manummk95/Desktop/ybkscht_efficientpose/EfficientPose/' - 'Datasets/Linemod_preprocessed/data/02/rgb/0000.png') +IMAGE_PATH = ('/home/manummk95/Desktop/paz/paz/examples/efficientpose/' + 'Linemod_preprocessed/data/08/rgb/0002.png') -detect = EFFICIENTPOSEALINEMOD(score_thresh=0.60, nms_thresh=0.45, - show_boxes2D=False, show_poses6D=True) +detect = EFFICIENTPOSEALINEMODDRILLER(score_thresh=0.60, nms_thresh=0.45, + show_boxes2D=True, show_poses6D=True) +detect.model.load_weights('weights.4583-1.55.hdf5') image = load_image(IMAGE_PATH) inferences = detect(image) show_image(inferences['image']) diff --git a/examples/efficientpose/demo_1.py b/examples/efficientpose/demo_1.py deleted file mode 100644 index f127fc0c7..000000000 --- a/examples/efficientpose/demo_1.py +++ /dev/null @@ -1,13 +0,0 @@ -from paz.backend.image import load_image, show_image -from pose import EFFICIENTPOSEALINEMODDRILLER - -IMAGE_PATH = ('/home/manummk95/Desktop/paz/paz/examples/efficientpose/' - 'Linemod_preprocessed/data/08/rgb/0002.png') - - -detect = EFFICIENTPOSEALINEMODDRILLER(score_thresh=0.60, nms_thresh=0.45, - show_boxes2D=False, show_poses6D=True) -detect.model.load_weights('weights.100-2.38.hdf5') -image = load_image(IMAGE_PATH) -inferences = detect(image) -show_image(inferences['image']) diff --git a/examples/efficientpose/efficientdet_blocks.py b/examples/efficientpose/efficientdet_blocks.py deleted file mode 100644 index 34edda74a..000000000 --- a/examples/efficientpose/efficientdet_blocks.py +++ /dev/null @@ -1,289 +0,0 @@ -import numpy as np -import tensorflow as tf -import tensorflow.keras.backend as K -from tensorflow.keras.layers import Activation, Concatenate, Reshape -from tensorflow.keras.layers import (BatchNormalization, Conv2D, Flatten, - MaxPooling2D, SeparableConv2D, - UpSampling2D) -from paz.models.detection.efficientdet.layers import ( - FuseFeature, GetDropConnect) - - -def build_detector_head(middles, num_classes, num_dims, aspect_ratios, - num_scales, FPN_num_filters, box_class_repeats, - survival_rate, momentum=0.99, epsilon=0.001, - activation='softmax'): - """Builds EfficientDet object detector's head. - The built head includes ClassNet and BoxNet for classification and - regression respectively. - - # Arguments - middles: List, BiFPN layer output. - num_classes: Int, number of object classes. - num_dims: Int, number of output dimensions to regress. - aspect_ratios: List, anchor boxes aspect ratios. - num_scales: Int, number of anchor box scales. 
- FPN_num_filters: Int, number of FPN filters. - box_class_repeats: Int, Number of regression - and classification blocks. - survival_rate: Float, used in drop connect. - - # Returns - outputs: Tensor of shape `[num_boxes, num_classes+num_dims]` - """ - num_anchors = len(aspect_ratios) * num_scales - args = (middles, momentum, epsilon, num_anchors, - FPN_num_filters, box_class_repeats, survival_rate) - class_outputs = ClassNet(*args, num_classes) - boxes_outputs = BoxesNet(*args, num_dims) - classes = Concatenate(axis=1)(class_outputs) - regressions = Concatenate(axis=1)(boxes_outputs) - num_boxes = K.int_shape(regressions)[-1] // num_dims - classes = Reshape((num_boxes, num_classes))(classes) - classes = Activation(activation)(classes) - regressions = Reshape((num_boxes, num_dims))(regressions) - outputs = Concatenate(axis=2, name='boxes')([regressions, classes]) - return outputs - - -def ClassNet(features, momentum, epsilon, num_anchors=9, num_filters=32, - num_blocks=4, survival_rate=None, num_classes=90): - """Initializes ClassNet. - - # Arguments - features: List, input features. - num_anchors: Int, number of anchors. - num_filters: Int, number of intermediate layer filters. - num_blocks: Int, Number of intermediate layers. - survival_rate: Float, used in drop connect. - num_classes: Int, number of object classes. - - # Returns - class_outputs: List, ClassNet outputs per level. - """ - bias_initializer = tf.constant_initializer(-np.log((1 - 0.01) / 0.01)) - num_filters = [num_filters, num_classes * num_anchors] - return build_head(features, num_blocks, num_filters, survival_rate, - bias_initializer, momentum, epsilon) - - -def BoxesNet(features, momentum, epsilon, num_anchors=9, num_filters=32, - num_blocks=4, survival_rate=None, num_dims=4): - """Initializes BoxNet. - - # Arguments - features: List, input features. - num_anchors: Int, number of anchors. - num_filters: Int, number of intermediate layer filters. - num_blocks: Int, Number of intermediate layers. - survival_rate: Float, used by drop connect. - num_dims: Int, number of output dimensions to regress. - - # Returns - boxes_outputs: List, BoxNet outputs per level. - """ - bias_initializer = tf.zeros_initializer() - num_filters = [num_filters, num_dims * num_anchors] - return build_head(features, num_blocks, num_filters, survival_rate, - bias_initializer, momentum, epsilon) - - -def build_head(middle_features, num_blocks, num_filters, - survival_rate, bias_initializer, momentum, epsilon): - """Builds ClassNet/BoxNet head. - - # Arguments - middle_features: Tuple. input features. - num_blocks: Int, number of intermediate layers. - num_filters: Int, number of intermediate layer filters. - survival_rate: Float, used by drop connect. - bias_initializer: Callable, bias initializer. - - # Returns - head_outputs: List, with head outputs. - """ - conv_blocks = build_head_conv2D( - num_blocks, num_filters[0], tf.zeros_initializer()) - final_head_conv = build_head_conv2D(1, num_filters[1], bias_initializer)[0] - head_outputs = [] - for x in middle_features: - for block_arg in range(num_blocks): - x = conv_blocks[block_arg](x) - x = BatchNormalization(momentum=momentum, epsilon=epsilon)(x) - x = tf.nn.swish(x) - if block_arg > 0 and survival_rate: - x = x + GetDropConnect(survival_rate=survival_rate)(x) - x = final_head_conv(x) - x = Flatten()(x) - head_outputs.append(x) - return head_outputs - - -def build_head_conv2D(num_blocks, num_filters, bias_initializer): - """Builds head convolutional blocks. 
- - # Arguments - num_blocks: Int, number of intermediate layers. - num_filters: Int, number of intermediate layer filters. - bias_initializer: Callable, bias initializer. - - # Returns - conv_blocks: List, head convolutional blocks. - """ - conv_blocks = [] - args_1 = (num_filters, 3, (1, 1), 'same', 'channels_last', (1, 1), - 1, None, True) - for _ in range(num_blocks): - args_2 = (tf.initializers.variance_scaling(), - tf.initializers.variance_scaling(), bias_initializer) - conv_blocks.append(SeparableConv2D(*args_1, *args_2)) - return conv_blocks - - -def EfficientNet_to_BiFPN(branches, num_filters, - momentum=0.99, epsilon=0.001): - """Preprocess EfficientNet branches prior to feeding BiFPN block. - The branches generated by the EfficientNet backbone consists of - features P1, P2, P3, P4, and P5. However, the BiFPN block requires - features P3, P4, P5, P6, and P7. This function generates features - P3 to P7 from EfficientNet branches that can be fed to the BiFPN - block. - - # Arguments - branches: List, EfficientNet feature maps. - num_filters: Int, number of intermediate layer filters. - - # Returns - branches, middles, skips: List, extended branch - and preprocessed feature maps. - """ - args = num_filters, momentum, epsilon - branches = extend_branch(branches, *args) - P3, P4, P5, P6, P7 = branches - P3_middle = conv_batchnorm_block(P3, *args) - P4_middle = conv_batchnorm_block(P4, *args) - P5_middle = conv_batchnorm_block(P5, *args) - middles = [P3_middle, P4_middle, P5_middle, P6, P7] - - P4_skip = conv_batchnorm_block(P4, *args) - P5_skip = conv_batchnorm_block(P5, *args) - skips = [None, P4_skip, P5_skip, P6, None] - return [branches, middles, skips] - - -def extend_branch(branches, num_filters, momentum, epsilon): - """Extends branches to comply with BiFPN. - The input branchs includes features P1-P5. This function extends the - EfficientNet backbone generated branch. The extended branch contains - features P3-P7. - - # Arguments - branches: List, EfficientNet feature maps. - num_filters: Int, number of intermediate layer filters. - - # Returns - middles, skips: List, modified branch. - """ - _, _, P3, P4, P5 = branches - P6, P7 = build_branch(P5, num_filters, momentum, epsilon) - branches_extended = [P3, P4, P5, P6, P7] - return branches_extended - - -def build_branch(P5, num_filters, momentum, epsilon): - """Builds feature maps P6 and P7 from P5. - - # Arguments - P5: Tensor of shape `(batch_size, 16, 16, 320)`, - EfficientNet's 5th layer output. - num_filters: Int, number of intermediate layer filters. - - # Returns - P6, P7: List, EfficientNet's 6th and 7th layer output. - """ - P6 = conv_batchnorm_block(P5, num_filters, momentum, epsilon) - P6 = MaxPooling2D(3, 2, 'same')(P6) - P7 = MaxPooling2D(3, 2, 'same')(P6) - return [P6, P7] - - -def conv_batchnorm_block(x, num_filters, momentum, epsilon): - """Builds 2D convolution and batch normalization layers. - - # Arguments - x: Tensor, input feature map. - num_filters: Int, number of intermediate layer filters. - - # Returns - x: Tensor. Feature after convolution and batch normalization. - """ - x = Conv2D(num_filters, 1, 1, 'same')(x) - x = BatchNormalization(momentum=momentum, epsilon=epsilon)(x) - return x - - -def BiFPN(middles, skips, num_filters, fusion, momentum=0.99, epsilon=0.001): - """BiFPN block. - BiFPN stands for Bidirectional Feature Pyramid Network. - - # Arguments - middles: List, BiFPN layer output. - skips: List, skip feature map from BiFPN node. - num_filters: Int, number of intermediate layer filters. 
- fusion: Str, feature fusion method. - - # Returns - middles, middles: List, BiFPN block output. - """ - P3_middle, P4_middle, P5_middle, P6_middle, P7_middle = middles - _, P4_skip, P5_skip, P6_skip, _ = skips - - # Downpropagation --------------------------------------------------------- - args = (num_filters, fusion, momentum, epsilon) - P7_up = UpSampling2D()(P7_middle) - P6_top_down = node_BiFPN(P7_up, P6_middle, None, None, *args) - P6_up = UpSampling2D()(P6_top_down) - P5_top_down = node_BiFPN(P6_up, P5_middle, None, None, *args) - P5_up = UpSampling2D()(P5_top_down) - P4_top_down = node_BiFPN(P5_up, P4_middle, None, None, *args) - P4_up = UpSampling2D()(P4_top_down) - P3_out = node_BiFPN(P4_up, P3_middle, None, None, *args) - - # Upward propagation ------------------------------------------------------ - P3_down = MaxPooling2D(3, 2, 'same')(P3_out) - P4_out = node_BiFPN(None, P4_top_down, P3_down, P4_skip, *args) - P4_down = MaxPooling2D(3, 2, 'same')(P4_out) - P5_out = node_BiFPN(None, P5_top_down, P4_down, P5_skip, *args) - P5_down = MaxPooling2D(3, 2, 'same')(P5_out) - P6_out = node_BiFPN(None, P6_top_down, P5_down, P6_skip, *args) - P6_down = MaxPooling2D(3, 2, 'same')(P6_out) - P7_out = node_BiFPN(None, P7_middle, P6_down, None, *args) - - middles = [P3_out, P4_out, P5_out, P6_out, P7_out] - return [middles, middles] - - -def node_BiFPN(up, middle, down, skip, num_filters, fusion, momentum, epsilon): - """Simulates a single node of BiFPN block. - - # Arguments - up: Tensor, upsampled feature map. - middle: Tensor, preprocessed feature map. - down: Tensor, downsampled feature map. - skip: Tensor, skip feature map. - num_filters: Int, number of intermediate layer filters. - fusion: Str, feature fusion method. - - # Returns - middle: Tensor, BiFPN node output. - """ - is_layer_one = down is None - if is_layer_one: - to_fuse = [middle, up] - else: - to_fuse = [middle, down] if skip is None else [skip, middle, down] - middle = FuseFeature(fusion=fusion)(to_fuse, fusion) - middle = tf.nn.swish(middle) - middle = SeparableConv2D(num_filters, 3, 1, 'same', use_bias=True)(middle) - middle = BatchNormalization(momentum=momentum, epsilon=epsilon)(middle) - return middle diff --git a/examples/efficientpose/efficientdet_blocks_with_bug.py b/examples/efficientpose/efficientdet_blocks_with_bug.py deleted file mode 100644 index a9c44c35a..000000000 --- a/examples/efficientpose/efficientdet_blocks_with_bug.py +++ /dev/null @@ -1,289 +0,0 @@ -import numpy as np -import tensorflow as tf -import tensorflow.keras.backend as K -from tensorflow.keras.layers import Activation, Concatenate, Reshape -from tensorflow.keras.layers import (BatchNormalization, Conv2D, Flatten, - MaxPooling2D, SeparableConv2D, - UpSampling2D) -from paz.models.detection.efficientdet.layers import ( - FuseFeature, GetDropConnect) - - -def build_detector_head(middles, num_classes, num_dims, aspect_ratios, - num_scales, FPN_num_filters, box_class_repeats, - survival_rate, momentum=0.99, epsilon=0.001, - activation='softmax'): - """Builds EfficientDet object detector's head. - The built head includes ClassNet and BoxNet for classification and - regression respectively. - - # Arguments - middles: List, BiFPN layer output. - num_classes: Int, number of object classes. - num_dims: Int, number of output dimensions to regress. - aspect_ratios: List, anchor boxes aspect ratios. - num_scales: Int, number of anchor box scales. - FPN_num_filters: Int, number of FPN filters. 
- box_class_repeats: Int, Number of regression - and classification blocks. - survival_rate: Float, used in drop connect. - - # Returns - outputs: Tensor of shape `[num_boxes, num_classes+num_dims]` - """ - num_anchors = len(aspect_ratios) * num_scales - args = (middles, momentum, epsilon, num_anchors, - FPN_num_filters, box_class_repeats, survival_rate) - class_outputs = ClassNet(*args, num_classes) - boxes_outputs = BoxesNet(*args, num_dims) - classes = Concatenate(axis=1)(class_outputs) - regressions = Concatenate(axis=1)(boxes_outputs) - num_boxes = K.int_shape(regressions)[-1] // num_dims - classes = Reshape((num_boxes, num_classes))(classes) - classes = Activation(activation)(classes) - regressions = Reshape((num_boxes, num_dims))(regressions) - outputs = Concatenate(axis=2, name='boxes')([regressions, classes]) - return outputs - - -def ClassNet(features, momentum, epsilon, num_anchors=9, num_filters=32, - num_blocks=4, survival_rate=None, num_classes=90): - """Initializes ClassNet. - - # Arguments - features: List, input features. - num_anchors: Int, number of anchors. - num_filters: Int, number of intermediate layer filters. - num_blocks: Int, Number of intermediate layers. - survival_rate: Float, used in drop connect. - num_classes: Int, number of object classes. - - # Returns - class_outputs: List, ClassNet outputs per level. - """ - bias_initializer = tf.constant_initializer(-np.log((1 - 0.01) / 0.01)) - num_filters = [num_filters, num_classes * num_anchors] - return build_head(features, num_blocks, num_filters, survival_rate, - bias_initializer, momentum, epsilon) - - -def BoxesNet(features, momentum, epsilon, num_anchors=9, num_filters=32, - num_blocks=4, survival_rate=None, num_dims=4): - """Initializes BoxNet. - - # Arguments - features: List, input features. - num_anchors: Int, number of anchors. - num_filters: Int, number of intermediate layer filters. - num_blocks: Int, Number of intermediate layers. - survival_rate: Float, used by drop connect. - num_dims: Int, number of output dimensions to regress. - - # Returns - boxes_outputs: List, BoxNet outputs per level. - """ - bias_initializer = tf.zeros_initializer() - num_filters = [num_filters, num_dims * num_anchors] - return build_head(features, num_blocks, num_filters, survival_rate, - bias_initializer, momentum, epsilon) - - -def build_head(middle_features, num_blocks, num_filters, - survival_rate, bias_initializer, momentum, epsilon): - """Builds ClassNet/BoxNet head. - - # Arguments - middle_features: Tuple. input features. - num_blocks: Int, number of intermediate layers. - num_filters: Int, number of intermediate layer filters. - survival_rate: Float, used by drop connect. - bias_initializer: Callable, bias initializer. - - # Returns - head_outputs: List, with head outputs. - """ - conv_blocks = build_head_conv2D( - num_blocks, num_filters[0], tf.zeros_initializer()) - final_head_conv = build_head_conv2D(1, num_filters[1], bias_initializer)[0] - head_outputs = [] - for x in middle_features: - for block_arg in range(num_blocks): - x = conv_blocks[block_arg](x) - x = BatchNormalization(momentum=momentum, epsilon=epsilon)(x) - x = tf.nn.swish(x) - if block_arg > 0 and survival_rate: - x = x + GetDropConnect(survival_rate=survival_rate)(x) - x = final_head_conv(x) - x = Flatten()(x) - head_outputs.append(x) - return head_outputs - - -def build_head_conv2D(num_blocks, num_filters, bias_initializer): - """Builds head convolutional blocks. - - # Arguments - num_blocks: Int, number of intermediate layers. 
- num_filters: Int, number of intermediate layer filters. - bias_initializer: Callable, bias initializer. - - # Returns - conv_blocks: List, head convolutional blocks. - """ - conv_blocks = [] - args_1 = (num_filters, 3, (1, 1), 'same', 'channels_last', (1, 1), - 1, None, True) - for _ in range(num_blocks): - args_2 = (tf.initializers.variance_scaling(), - tf.initializers.variance_scaling(), bias_initializer) - conv_blocks.append(SeparableConv2D(*args_1, *args_2)) - return conv_blocks - - -def EfficientNet_to_BiFPN(branches, num_filters, - momentum=0.99, epsilon=0.001): - """Preprocess EfficientNet branches prior to feeding BiFPN block. - The branches generated by the EfficientNet backbone consists of - features P1, P2, P3, P4, and P5. However, the BiFPN block requires - features P3, P4, P5, P6, and P7. This function generates features - P3 to P7 from EfficientNet branches that can be fed to the BiFPN - block. - - # Arguments - branches: List, EfficientNet feature maps. - num_filters: Int, number of intermediate layer filters. - - # Returns - branches, middles, skips: List, extended branch - and preprocessed feature maps. - """ - args = num_filters, momentum, epsilon - branches = extend_branch(branches, *args) - P3, P4, P5, P6, P7 = branches - P3_middle = conv_batchnorm_block(P3, *args) - P4_middle = conv_batchnorm_block(P4, *args) - P5_middle = conv_batchnorm_block(P5, *args) - middles = [P3_middle, P4_middle, P5_middle, P6, P7] - - P4_skip = conv_batchnorm_block(P4, *args) - P5_skip = conv_batchnorm_block(P5, *args) - skips = [None, P4_skip, P5_skip, P6, None] - return [branches, middles, skips] - - -def extend_branch(branches, num_filters, momentum, epsilon): - """Extends branches to comply with BiFPN. - The input branchs includes features P1-P5. This function extends the - EfficientNet backbone generated branch. The extended branch contains - features P3-P7. - - # Arguments - branches: List, EfficientNet feature maps. - num_filters: Int, number of intermediate layer filters. - - # Returns - middles, skips: List, modified branch. - """ - _, _, P3, P4, P5 = branches - P6, P7 = build_branch(P5, num_filters, momentum, epsilon) - branches_extended = [P3, P4, P5, P6, P7] - return branches_extended - - -def build_branch(P5, num_filters, momentum, epsilon): - """Builds feature maps P6 and P7 from P5. - - # Arguments - P5: Tensor of shape `(batch_size, 16, 16, 320)`, - EfficientNet's 5th layer output. - num_filters: Int, number of intermediate layer filters. - - # Returns - P6, P7: List, EfficientNet's 6th and 7th layer output. - """ - P6 = conv_batchnorm_block(P5, num_filters, momentum, epsilon) - P6 = MaxPooling2D(3, 2, 'same')(P6) - P7 = MaxPooling2D(3, 2, 'same')(P6) - return [P6, P7] - - -def conv_batchnorm_block(x, num_filters, momentum, epsilon): - """Builds 2D convolution and batch normalization layers. - - # Arguments - x: Tensor, input feature map. - num_filters: Int, number of intermediate layer filters. - - # Returns - x: Tensor. Feature after convolution and batch normalization. - """ - x = Conv2D(num_filters, 1, 1, 'same')(x) - x = BatchNormalization(momentum=momentum, epsilon=epsilon)(x) - return x - - -def BiFPN(middles, skips, num_filters, fusion, momentum=0.99, epsilon=0.001): - """BiFPN block. - BiFPN stands for Bidirectional Feature Pyramid Network. - - # Arguments - middles: List, BiFPN layer output. - skips: List, skip feature map from BiFPN node. - num_filters: Int, number of intermediate layer filters. - fusion: Str, feature fusion method. 
- - # Returns - middles, middles: List, BiFPN block output. - """ - P3_middle, P4_middle, P5_middle, P6_middle, P7_middle = middles - _, P4_skip, P5_skip, P6_skip, _ = skips - - # Downpropagation --------------------------------------------------------- - args = (num_filters, fusion, momentum, epsilon) - P7_up = UpSampling2D()(P7_middle) - P6_top_down = node_BiFPN(P7_up, P6_middle, None, None, *args) - P6_up = UpSampling2D()(P6_top_down) - P5_top_down = node_BiFPN(P6_up, P5_middle, None, None, *args) - P5_up = UpSampling2D()(P5_top_down) - P4_top_down = node_BiFPN(P5_up, P4_middle, None, None, *args) - P4_up = UpSampling2D()(P4_top_down) - P3_out = node_BiFPN(P4_up, P3_middle, None, None, *args) - - # Upward propagation ------------------------------------------------------ - P3_down = MaxPooling2D(3, 2, 'same')(P3_out) - P4_out = node_BiFPN(None, P4_top_down, P3_down, P4_skip, *args) - P4_down = MaxPooling2D(3, 2, 'same')(P4_out) - P5_out = node_BiFPN(None, P5_top_down, P4_down, P5_skip, *args) - P5_down = MaxPooling2D(3, 2, 'same')(P5_out) - P6_out = node_BiFPN(None, P6_top_down, P5_down, P6_skip, *args) - P6_down = MaxPooling2D(3, 2, 'same')(P6_out) - P7_out = node_BiFPN(None, P7_middle, P6_down, None, *args) - - middles = [P3_out, P4_top_down, P5_top_down, P6_top_down, P7_out] - return [middles, middles] - - -def node_BiFPN(up, middle, down, skip, num_filters, fusion, momentum, epsilon): - """Simulates a single node of BiFPN block. - - # Arguments - up: Tensor, upsampled feature map. - middle: Tensor, preprocessed feature map. - down: Tensor, downsampled feature map. - skip: Tensor, skip feature map. - num_filters: Int, number of intermediate layer filters. - fusion: Str, feature fusion method. - - # Returns - middle: Tensor, BiFPN node output. - """ - is_layer_one = down is None - if is_layer_one: - to_fuse = [middle, up] - else: - to_fuse = [middle, down] if skip is None else [skip, middle, down] - middle = FuseFeature(fusion=fusion)(to_fuse, fusion) - middle = tf.nn.swish(middle) - middle = SeparableConv2D(num_filters, 3, 1, 'same', use_bias=True)(middle) - middle = BatchNormalization(momentum=momentum, epsilon=epsilon)(middle) - return middle diff --git a/examples/efficientpose/efficientpose.py b/examples/efficientpose/efficientpose.py index 409ad6f6d..42aea652d 100644 --- a/examples/efficientpose/efficientpose.py +++ b/examples/efficientpose/efficientpose.py @@ -3,12 +3,11 @@ from paz.backend.anchors import build_anchors from paz.models.detection.efficientdet.efficientnet import EFFICIENTNET from anchors import build_translation_anchors -from efficientdet_blocks import build_detector_head, EfficientNet_to_BiFPN -from efficientdet_blocks import BiFPN +from paz.models.detection.efficientdet.efficientdet_blocks import ( + build_detector_head, EfficientNet_to_BiFPN, BiFPN) from efficientpose_blocks import build_pose_estimator_head -WEIGHT_PATH = ( - '/home/manummk95/Desktop/paz/paz/examples/efficientpose/weights/') +WEIGHT_PATH = 'weights/' def EFFICIENTPOSE(image, num_classes, base_weights, head_weights, @@ -16,10 +15,9 @@ def EFFICIENTPOSE(image, num_classes, base_weights, head_weights, anchor_scale, fusion, return_base, model_name, EfficientNet, subnet_iterations=1, subnet_repeats=3, num_scales=3, aspect_ratios=[1.0, 2.0, 0.5], survival_rate=None, - num_dims=4, momentum=0.99, epsilon=0.001, - activation='softmax', num_anchors=9, num_filters=64, + num_dims=4, num_anchors=9, num_filters=64, num_pose_dims=3): - """Creates EfficientPose model. + """Builds EfficientPose model. 
# Arguments image: Tensor of shape `(batch_size, input_shape)`. @@ -35,19 +33,15 @@ def EFFICIENTPOSE(image, num_classes, base_weights, head_weights, return_base: Bool, whether to return base or not. model_name: Str, EfficientDet model name. EfficientNet: List, containing branch tensors. - subnet_iterations: Int, number of iterative refinement - steps used in rotation and translation subnets. + subnet_iterations: Int, number of iterative refinement steps + used in rotation and translation subnets. subnet_repeats: Int, number of layers used in subnetworks. num_scales: Int, number of anchor box scales. aspect_ratios: List, anchor boxes aspect ratios. survival_rate: Float, specifying survival probability. num_dims: Int, number of output dimensions to regress. - momentum: Float, batch normalization moving average momentum. - epsilon: Float, small float added to - variance to avoid division by zero. - activation: Str, activation function for classes. - num_anchors: List, number of combinations of - anchor box's scale and aspect ratios. + num_anchors: List, number of combinations of anchor box's scale + and aspect ratios. num_filters: Int, number of subnet filters. num_pose_dims: Int, number of pose dimensions. @@ -55,8 +49,9 @@ def EFFICIENTPOSE(image, num_classes, base_weights, head_weights, model: EfficientPose model. # References - [ybkscht repository implementation of EfficientPose]( - https://github.com/ybkscht/EfficientPose) + [EfficientPose: An efficient, accurate and scalable end-to-end + 6D multi object pose estimation approach]( + https://arxiv.org/pdf/2011.04307.pdf) """ if base_weights not in ['COCO', None]: raise ValueError('Invalid base_weights: ', base_weights) @@ -66,18 +61,16 @@ def EFFICIENTPOSE(image, num_classes, base_weights, head_weights, raise NotImplementedError('Invalid `base_weights` with head_weights') branches, middles, skips = EfficientNet_to_BiFPN( - EfficientNet, FPN_num_filters, momentum, epsilon) + EfficientNet, FPN_num_filters) for _ in range(FPN_cell_repeats): - middles, skips = BiFPN(middles, skips, FPN_num_filters, - fusion, momentum, epsilon) + middles, skips = BiFPN(middles, skips, FPN_num_filters, fusion) if return_base: outputs = middles else: detection_outputs = build_detector_head( middles, num_classes, num_dims, aspect_ratios, num_scales, - FPN_num_filters, box_class_repeats, survival_rate, - momentum, epsilon, activation) + FPN_num_filters, box_class_repeats, survival_rate) pose_outputs = build_pose_estimator_head( middles, subnet_iterations, subnet_repeats, @@ -97,7 +90,14 @@ def EFFICIENTPOSE(image, num_classes, base_weights, head_weights, if not ((base_weights is None) and (head_weights is None)): weights_path = WEIGHT_PATH + model_filename - finetunning_model_names = ['efficientpose-a-COCO-None_weights.hdf5'] + finetunning_model_names = ['efficientpose-a-COCO-None_weights.hdf5', + 'efficientpose-b-COCO-None_weights.hdf5', + 'efficientpose-c-COCO-None_weights.hdf5', + 'efficientpose-d-COCO-None_weights.hdf5', + 'efficientpose-e-COCO-None_weights.hdf5', + 'efficientpose-f-COCO-None_weights.hdf5', + 'efficientpose-g-COCO-None_weights.hdf5', + 'efficientpose-h-COCO-None_weights.hdf5'] by_name = True if model_filename in finetunning_model_names else False print('Loading %s model weights' % weights_path) model.load_weights(weights_path, by_name=by_name) @@ -113,11 +113,11 @@ def EFFICIENTPOSE(image, num_classes, base_weights, head_weights, def EFFICIENTPOSEA(num_classes=8, base_weights='COCO', head_weights='LINEMOD_OCCLUDED', input_shape=(512, 512, 3), 
- FPN_num_filters=64, FPN_cell_repeats=3, subnet_repeats=3, + FPN_num_filters=64, FPN_cell_repeats=3, subnet_repeats=2, subnet_iterations=1, box_class_repeats=3, anchor_scale=4.0, fusion='fast', return_base=False, - model_name='efficientpose-a', momentum=0.99, epsilon=0.001, - activation='softmax', scaling_coefficients=(1.0, 1.0, 0.8)): + model_name='efficientpose-a', + scaling_coefficients=(1.0, 1.0, 0.8)): """Instantiates EfficientPose-A model. # Arguments @@ -128,18 +128,14 @@ def EFFICIENTPOSEA(num_classes=8, base_weights='COCO', FPN_num_filters: Int, number of FPN filters. FPN_cell_repeats: Int, number of FPN blocks. subnet_repeats: Int, number of layers used in subnetworks. - subnet_iterations: Int, number of iterative refinement - steps used in rotation and translation subnets. + subnet_iterations: Int, number of iterative refinement steps + used in rotation and translation subnets. box_class_repeats: Int, Number of regression and classification blocks. anchor_scale: Int, number of anchor scales. fusion: Str, feature fusion weighting method. return_base: Bool, whether to return base or not. model_name: Str, EfficientDet model name. - momentum: Float, batch normalization moving average momentum. - epsilon: Float, small float added to - variance to avoid division by zero. - activation: Str, activation function for classes. scaling_coefficients: Tuple, EfficientNet scaling coefficients. # Returns @@ -150,7 +146,278 @@ def EFFICIENTPOSEA(num_classes=8, base_weights='COCO', model = EFFICIENTPOSE(image, num_classes, base_weights, head_weights, FPN_num_filters, FPN_cell_repeats, box_class_repeats, anchor_scale, fusion, return_base, model_name, - EfficientNetb0, subnet_iterations, subnet_repeats, - momentum=momentum, epsilon=epsilon, - activation=activation) + EfficientNetb0, subnet_iterations, subnet_repeats) + return model + + +def EFFICIENTPOSEB(num_classes=8, base_weights='COCO', + head_weights='LINEMOD_OCCLUDED', input_shape=(640, 640, 3), + FPN_num_filters=88, FPN_cell_repeats=4, subnet_repeats=2, + subnet_iterations=1, box_class_repeats=3, anchor_scale=4.0, + fusion='fast', return_base=False, + model_name='efficientpose-b', + scaling_coefficients=(1.0, 1.0, 0.8)): + """Instantiates EfficientPose-B model. + + # Arguments + num_classes: Int, number of object classes. + base_weights: Str, base weights name. + head_weights: Str, head weights name. + input_shape: Tuple, holding input image size. + FPN_num_filters: Int, number of FPN filters. + FPN_cell_repeats: Int, number of FPN blocks. + subnet_repeats: Int, number of layers used in subnetworks. + subnet_iterations: Int, number of iterative refinement steps + used in rotation and translation subnets. + box_class_repeats: Int, Number of regression + and classification blocks. + anchor_scale: Int, number of anchor scales. + fusion: Str, feature fusion weighting method. + return_base: Bool, whether to return base or not. + model_name: Str, EfficientDet model name. + scaling_coefficients: Tuple, EfficientNet scaling coefficients. + + # Returns + model: EfficientPose-B model. 
+ """ + image = Input(shape=input_shape, name='image') + EfficientNetb1 = EFFICIENTNET(image, scaling_coefficients) + model = EFFICIENTPOSE(image, num_classes, base_weights, head_weights, + FPN_num_filters, FPN_cell_repeats, box_class_repeats, + anchor_scale, fusion, return_base, model_name, + EfficientNetb1, subnet_iterations, subnet_repeats) + return model + + +def EFFICIENTPOSEC(num_classes=8, base_weights='COCO', + head_weights='LINEMOD_OCCLUDED', input_shape=(768, 768, 3), + FPN_num_filters=112, FPN_cell_repeats=5, subnet_repeats=2, + subnet_iterations=1, box_class_repeats=3, anchor_scale=4.0, + fusion='fast', return_base=False, + model_name='efficientpose-c', + scaling_coefficients=(1.1, 1.2, 0.7)): + """Instantiates EfficientPose-C model. + + # Arguments + num_classes: Int, number of object classes. + base_weights: Str, base weights name. + head_weights: Str, head weights name. + input_shape: Tuple, holding input image size. + FPN_num_filters: Int, number of FPN filters. + FPN_cell_repeats: Int, number of FPN blocks. + subnet_repeats: Int, number of layers used in subnetworks. + subnet_iterations: Int, number of iterative refinement steps + used in rotation and translation subnets. + box_class_repeats: Int, Number of regression + and classification blocks. + anchor_scale: Int, number of anchor scales. + fusion: Str, feature fusion weighting method. + return_base: Bool, whether to return base or not. + model_name: Str, EfficientDet model name. + scaling_coefficients: Tuple, EfficientNet scaling coefficients. + + # Returns + model: EfficientPose-C model. + """ + image = Input(shape=input_shape, name='image') + EfficientNetb2 = EFFICIENTNET(image, scaling_coefficients) + model = EFFICIENTPOSE(image, num_classes, base_weights, head_weights, + FPN_num_filters, FPN_cell_repeats, box_class_repeats, + anchor_scale, fusion, return_base, model_name, + EfficientNetb2, subnet_iterations, subnet_repeats) + return model + + +def EFFICIENTPOSED(num_classes=8, base_weights='COCO', + head_weights='LINEMOD_OCCLUDED', input_shape=(896, 896, 3), + FPN_num_filters=160, FPN_cell_repeats=6, subnet_repeats=3, + subnet_iterations=2, box_class_repeats=4, anchor_scale=4.0, + fusion='fast', return_base=False, + model_name='efficientpose-d', + scaling_coefficients=(1.2, 1.4, 0.7)): + """Instantiates EfficientPose-D model. + + # Arguments + num_classes: Int, number of object classes. + base_weights: Str, base weights name. + head_weights: Str, head weights name. + input_shape: Tuple, holding input image size. + FPN_num_filters: Int, number of FPN filters. + FPN_cell_repeats: Int, number of FPN blocks. + subnet_repeats: Int, number of layers used in subnetworks. + subnet_iterations: Int, number of iterative refinement steps + used in rotation and translation subnets. + box_class_repeats: Int, Number of regression + and classification blocks. + anchor_scale: Int, number of anchor scales. + fusion: Str, feature fusion weighting method. + return_base: Bool, whether to return base or not. + model_name: Str, EfficientDet model name. + scaling_coefficients: Tuple, EfficientNet scaling coefficients. + + # Returns + model: EfficientPose-D model. 
+ """ + image = Input(shape=input_shape, name='image') + EfficientNetb3 = EFFICIENTNET(image, scaling_coefficients) + model = EFFICIENTPOSE(image, num_classes, base_weights, head_weights, + FPN_num_filters, FPN_cell_repeats, box_class_repeats, + anchor_scale, fusion, return_base, model_name, + EfficientNetb3, subnet_iterations, subnet_repeats) + return model + + +def EFFICIENTPOSEE(num_classes=8, base_weights='COCO', + head_weights='LINEMOD_OCCLUDED', + input_shape=(1024, 1024, 3), FPN_num_filters=224, + FPN_cell_repeats=7, subnet_repeats=3, subnet_iterations=2, + box_class_repeats=4, anchor_scale=4.0, fusion='fast', + return_base=False, model_name='efficientpose-e', + scaling_coefficients=(1.2, 1.4, 0.7)): + """Instantiates EfficientPose-E model. + + # Arguments + num_classes: Int, number of object classes. + base_weights: Str, base weights name. + head_weights: Str, head weights name. + input_shape: Tuple, holding input image size. + FPN_num_filters: Int, number of FPN filters. + FPN_cell_repeats: Int, number of FPN blocks. + subnet_repeats: Int, number of layers used in subnetworks. + subnet_iterations: Int, number of iterative refinement steps + used in rotation and translation subnets. + box_class_repeats: Int, Number of regression + and classification blocks. + anchor_scale: Int, number of anchor scales. + fusion: Str, feature fusion weighting method. + return_base: Bool, whether to return base or not. + model_name: Str, EfficientDet model name. + scaling_coefficients: Tuple, EfficientNet scaling coefficients. + + # Returns + model: EfficientPose-E model. + """ + image = Input(shape=input_shape, name='image') + EfficientNetb4 = EFFICIENTNET(image, scaling_coefficients) + model = EFFICIENTPOSE(image, num_classes, base_weights, head_weights, + FPN_num_filters, FPN_cell_repeats, box_class_repeats, + anchor_scale, fusion, return_base, model_name, + EfficientNetb4, subnet_iterations, subnet_repeats) + return model + + +def EFFICIENTPOSEF(num_classes=8, base_weights='COCO', + head_weights='LINEMOD_OCCLUDED', + input_shape=(1280, 1280, 3), FPN_num_filters=288, + FPN_cell_repeats=7, subnet_repeats=3, subnet_iterations=2, + box_class_repeats=4, anchor_scale=4.0, fusion='fast', + return_base=False, model_name='efficientpose-f', + scaling_coefficients=(1.6, 2.2, 0.6)): + """Instantiates EfficientPose-F model. + + # Arguments + num_classes: Int, number of object classes. + base_weights: Str, base weights name. + head_weights: Str, head weights name. + input_shape: Tuple, holding input image size. + FPN_num_filters: Int, number of FPN filters. + FPN_cell_repeats: Int, number of FPN blocks. + subnet_repeats: Int, number of layers used in subnetworks. + subnet_iterations: Int, number of iterative refinement steps + used in rotation and translation subnets. + box_class_repeats: Int, Number of regression + and classification blocks. + anchor_scale: Int, number of anchor scales. + fusion: Str, feature fusion weighting method. + return_base: Bool, whether to return base or not. + model_name: Str, EfficientDet model name. + scaling_coefficients: Tuple, EfficientNet scaling coefficients. + + # Returns + model: EfficientPose-F model. 
+ """ + image = Input(shape=input_shape, name='image') + EfficientNetb5 = EFFICIENTNET(image, scaling_coefficients) + model = EFFICIENTPOSE(image, num_classes, base_weights, head_weights, + FPN_num_filters, FPN_cell_repeats, box_class_repeats, + anchor_scale, fusion, return_base, model_name, + EfficientNetb5, subnet_iterations, subnet_repeats) + return model + + +def EFFICIENTPOSEG(num_classes=8, base_weights='COCO', + head_weights='LINEMOD_OCCLUDED', + input_shape=(1280, 1280, 3), FPN_num_filters=384, + FPN_cell_repeats=8, subnet_repeats=4, subnet_iterations=3, + box_class_repeats=5, anchor_scale=5.0, fusion='sum', + return_base=False, model_name='efficientpose-g', + scaling_coefficients=(1.8, 2.6, 0.5)): + """Instantiates EfficientPose-G model. + + # Arguments + num_classes: Int, number of object classes. + base_weights: Str, base weights name. + head_weights: Str, head weights name. + input_shape: Tuple, holding input image size. + FPN_num_filters: Int, number of FPN filters. + FPN_cell_repeats: Int, number of FPN blocks. + subnet_repeats: Int, number of layers used in subnetworks. + subnet_iterations: Int, number of iterative refinement steps + used in rotation and translation subnets. + box_class_repeats: Int, Number of regression + and classification blocks. + anchor_scale: Int, number of anchor scales. + fusion: Str, feature fusion weighting method. + return_base: Bool, whether to return base or not. + model_name: Str, EfficientDet model name. + scaling_coefficients: Tuple, EfficientNet scaling coefficients. + + # Returns + model: EfficientPose-G model. + """ + image = Input(shape=input_shape, name='image') + EfficientNetb6 = EFFICIENTNET(image, scaling_coefficients) + model = EFFICIENTPOSE(image, num_classes, base_weights, head_weights, + FPN_num_filters, FPN_cell_repeats, box_class_repeats, + anchor_scale, fusion, return_base, model_name, + EfficientNetb6, subnet_iterations, subnet_repeats) + return model + + +def EFFICIENTPOSEH(num_classes=8, base_weights='COCO', + head_weights='LINEMOD_OCCLUDED', + input_shape=(1536, 1536, 3), FPN_num_filters=384, + FPN_cell_repeats=8, subnet_repeats=4, subnet_iterations=3, + box_class_repeats=5, anchor_scale=5.0, fusion='sum', + return_base=False, model_name='efficientpose-h', + scaling_coefficients=(1.8, 2.6, 0.5)): + """Instantiates EfficientPose-H model. + + # Arguments + num_classes: Int, number of object classes. + base_weights: Str, base weights name. + head_weights: Str, head weights name. + input_shape: Tuple, holding input image size. + FPN_num_filters: Int, number of FPN filters. + FPN_cell_repeats: Int, number of FPN blocks. + subnet_repeats: Int, number of layers used in subnetworks. + subnet_iterations: Int, number of iterative refinement steps + used in rotation and translation subnets. + box_class_repeats: Int, Number of regression + and classification blocks. + anchor_scale: Int, number of anchor scales. + fusion: Str, feature fusion weighting method. + return_base: Bool, whether to return base or not. + model_name: Str, EfficientDet model name. + scaling_coefficients: Tuple, EfficientNet scaling coefficients. + + # Returns + model: EfficientPose-H model. 
+    """
+    image = Input(shape=input_shape, name='image')
+    EfficientNetb6 = EFFICIENTNET(image, scaling_coefficients)
+    model = EFFICIENTPOSE(image, num_classes, base_weights, head_weights,
+                          FPN_num_filters, FPN_cell_repeats, box_class_repeats,
+                          anchor_scale, fusion, return_base, model_name,
+                          EfficientNetb6, subnet_iterations, subnet_repeats)
     return model
diff --git a/examples/efficientpose/efficientpose_blocks.py b/examples/efficientpose/efficientpose_blocks.py
index 5f4ddfbe7..62ab18b46 100644
--- a/examples/efficientpose/efficientpose_blocks.py
+++ b/examples/efficientpose/efficientpose_blocks.py
@@ -2,245 +2,228 @@
 from tensorflow.keras.layers import (GroupNormalization, Concatenate,
                                      Add, Reshape)
 from paz.models.detection.efficientdet.efficientdet_blocks import (
-    build_head_conv2D)
+    build_head_conv2D, build_head)
 
 
 def build_pose_estimator_head(middles, subnet_iterations, subnet_repeats,
                               num_anchors, num_filters, num_dims):
-    """Builds EfficientPose pose estimator's head.
-    The built head includes RotationNet and TranslationNet
-    for estimating rotation and translation respectively.
+    """Builds the EfficientPose pose estimator head, containing
+    RotationNet and TranslationNet for estimating the object's
+    rotation and translation, respectively.
 
     # Arguments
         middles: List, BiFPN layer output.
-        subnet_iterations: Int, number of iterative refinement
-            steps used in rotation and translation subnets.
+        subnet_iterations: Int, number of iterative refinement steps
+            used in rotation and translation subnets.
         subnet_repeats: Int, number of layers used in subnetworks.
-        num_anchors: List, number of combinations of
-            anchor box's scale and aspect ratios.
+        num_anchors: List, number of combinations of anchor box's scale
+            and aspect ratios.
         num_filters: Int, number of subnet filters.
         num_dims: Int, number of pose dimensions.
 
     # Returns
-        List: Containing estimated rotations and translations of shape
-            `(None, num_boxes, num_dims)` and
-            `(None, num_boxes, num_dims)` respectively.
+        Tensor: Concatenation of estimated rotations and translations
+            of shape `(None, num_boxes, num_dims + num_dims)`.
     """
     args = (middles, subnet_iterations, subnet_repeats, num_anchors)
     rotations = RotationNet(*args, num_filters, num_dims)
-    rotations = Concatenate(axis=1,)(rotations)
+    rotations = Concatenate(axis=1)(rotations)
     translations = TranslationNet(*args, num_filters)
     translations = Concatenate(axis=1)(translations)
     concatenate_transformation = Concatenate(axis=-1, name='transformation')
-    transformations = concatenate_transformation([rotations, translations])
-    return transformations
+    return concatenate_transformation([rotations, translations])
 
 
-def RotationNet(middles, subnet_iterations, subnet_repeats,
-                num_anchors, num_filters, num_dims):
+def RotationNet(middles, subnet_iterations, subnet_repeats, num_anchors,
+                num_filters, num_dims, survival_rate=None):
     """Initializes RotationNet.
 
     # Arguments
         middles: List, BiFPN layer output.
-        subnet_iterations: Int, number of iterative refinement
-            steps used in rotation and translation subnets.
+        subnet_iterations: Int, number of iterative refinement steps
+            used in rotation and translation subnets.
         subnet_repeats: Int, number of layers used in subnetworks.
         num_anchors: List, number of combinations of anchor box's scale
             and aspect ratios.
         num_filters: Int, number of subnet filters.
         num_dims: Int, number of pose dimensions.
+        survival_rate: Float, used by drop connect.
 
     # Returns
        List: containing rotation estimates from every feature level.
""" - num_filters = [num_filters, num_dims * num_anchors] bias_initializer = tf.zeros_initializer() + num_filters = [num_filters, num_dims * num_anchors] args = (subnet_repeats, num_filters, bias_initializer) - rotations = build_rotation_head(middles, *args) - return build_iterative_rotation_subnet(*rotations, subnet_iterations, - *args, num_dims) + initial_regressions = build_head(middles, *args, survival_rate, + normalization='group') + return refine_rotation_iteratively(*initial_regressions, subnet_iterations, + *args, num_dims) -def build_rotation_head(middles, subnet_repeats, num_filters, - bias_initializer, gn_groups=4, gn_axis=-1): - """Builds RotationNet head. +def refine_rotation_iteratively(rotation_features, initial_rotations, + subnet_iterations, subnet_repeats, + num_filters, bias_initializer, num_dims): + """Iteratively refines rotation. # Arguments - middles: List, BiFPN layer output. + rotation_features: List, containing features from rotation head. + initial_rotations: List, containing initial rotation values. + subnet_iterations: Int, number of iterative refinement steps + used in rotation and translation subnets. subnet_repeats: Int, number of layers used in subnetworks. num_filters: Int, number of subnet filters. bias_initializer: Callable, bias initializer. - gn_groups: Int, number of groups in group normalization. - gn_axis: Int, group normalization axis. + num_dims: Int, number of pose dimensions. # Returns - List: Containing rotation_features and initial_rotations. + rotations: List, containing final rotation values from every + feature level. """ - conv_blocks = build_head_conv2D(subnet_repeats, num_filters[0], - bias_initializer) - head_conv = build_head_conv2D(1, num_filters[1], bias_initializer)[0] - args = (conv_blocks, subnet_repeats, gn_groups, gn_axis) - rotation_features, initial_rotations = [], [] - for x in middles: - x = conv2D_norm_activation(x, *args) - initial_rotation = head_conv(x) - rotation_features.append(x) - initial_rotations.append(initial_rotation) - return [rotation_features, initial_rotations] + rotations = [] + iterator = zip(rotation_features, initial_rotations) + args = (subnet_repeats, num_filters, bias_initializer) + for rotation_feature, initial_rotation in iterator: + for _ in range(subnet_iterations): + x = Concatenate(axis=-1)([rotation_feature, initial_rotation]) + delta_rotation = refine_rotation(x, *args) + initial_rotation = Add()([initial_rotation, delta_rotation]) + rotation = Reshape((-1, num_dims))(initial_rotation) + rotations.append(rotation) + return rotations -def conv2D_norm_activation(x, conv_blocks, repeats, gn_groups, gn_axis): - """Builds group normalization blocks followed by activation. +def refine_rotation(x, repeats, num_filters, bias_initializer, + channels_per_group=16): + """Builds rotation refinement module. # Arguments x: Tensor, BiFPN layer output. - conv_blocks: List, containing convolutional blocks. repeats: Int, number of layers used in subnetworks. - gn_groups: Int, number of groups in group normalization. - gn_axis: Int, group normalization axis. + num_filters: Int, number of subnet filters. + bias_initializer: Callable, bias initializer. + channels_per_group: Int, number of channels per group + of Batchnormalization. # Returns - x: Tensor, after repeated convolution, + delta_rotation: Tensor, after repeated convolution, group normalization and activation. 
""" + conv_body = build_head_conv2D(repeats, num_filters[0], bias_initializer) + conv_head = build_head_conv2D(1, num_filters[1], bias_initializer)[0] + num_groups = int(num_filters[0] / channels_per_group) for block_arg in range(repeats): - x = conv_blocks[block_arg](x) - x = GroupNormalization(groups=gn_groups, axis=gn_axis)(x) + x = conv_body[block_arg](x) + x = GroupNormalization(groups=num_groups)(x) x = tf.nn.swish(x) - return x + return conv_head(x) -def build_iterative_rotation_subnet(rotation_features, initial_rotations, - subnet_iterations, subnet_repeats, - num_filters, bias_initializer, - num_dims, gn_groups=4, gn_axis=-1): - """Builds iterative rotation subnets. +def TranslationNet(middles, subnet_iterations, subnet_repeats, + num_anchors, num_filters): + """Initializes TranslationNet. # Arguments - rotation_features: List, containing features from rotation head. - initial_rotations: List, containing initial rotation values. - subnet_iterations: Int, number of iterative refinement - steps used in rotation and translation subnets. + middles: List, BiFPN layer output. + subnet_iterations: Int, number of iterative refinement steps + used in rotation and translation subnets. subnet_repeats: Int, number of layers used in subnetworks. + num_anchors: List, number of combinations of anchor box's scale + and aspect ratios. num_filters: Int, number of subnet filters. - bias_initializer: Callable, bias initializer. - num_dims: Int, number of pose dimensions. - gn_groups: Int, number of groups in group normalization. - gn_axis: Int, group normalization axis. # Returns - rotations: List, containing final rotation values. + List: containing translation estimates from every feature level. """ - conv_blocks = build_head_conv2D(subnet_repeats - 1, num_filters[0], - bias_initializer) - head_conv = build_head_conv2D(1, num_filters[1], bias_initializer)[0] - args = (conv_blocks, subnet_repeats - 1, gn_groups, gn_axis) - rotations = [] - for x, initial_rotation in zip(rotation_features, initial_rotations): - for _ in range(subnet_iterations): - x = Concatenate(axis=-1)([x, initial_rotation]) - x = conv2D_norm_activation(x, *args) - delta_rotation = head_conv(x) - initial_rotation = Add()([initial_rotation, delta_rotation]) - rotation = Reshape((-1, num_dims))(initial_rotation) - rotations.append(rotation) - return rotations + bias_initializer = tf.zeros_initializer() + num_filters = [num_filters, num_anchors * 2, num_anchors] + args = (subnet_repeats, num_filters, bias_initializer) + initial_regressions = regress_initial_translations(middles, *args) + return refine_translation_iteratively(*initial_regressions, + *args, subnet_iterations) -def TranslationNet(middles, subnet_iterations, subnet_repeats, - num_anchors, num_filters): - """Initializes TranslationNet. +def regress_initial_translations(middles, subnet_repeats, num_filters, + bias_initializer): + """Builds TranslationNet head. # Arguments middles: List, BiFPN layer output. - subnet_iterations: Int, number of iterative refinement - steps used in rotation and translation subnets. subnet_repeats: Int, number of layers used in subnetworks. - num_anchors: List, number of combinations of - anchor box's scale and aspect ratios. num_filters: Int, number of subnet filters. + bias_initializer: Callable, bias initializer. # Returns - List: containing translation estimates from every feature level. + List: Containing initial_features, initial_xy and initial_z. 
""" - num_filters = [num_filters, num_anchors * 2, num_anchors] - bias_initializer = tf.zeros_initializer() + initial_features, initial_xy, initial_z = [], [], [] args = (subnet_repeats, num_filters, bias_initializer) - translations = build_translation_head(middles, *args) - return build_iterative_translation_subnet(*translations, *args, - subnet_iterations) + for x in middles: + initial_translations = build_translation_subnets(x, *args) + x, initial_translation_xy, initial_translation_z = initial_translations + initial_features.append(x) + initial_xy.append(initial_translation_xy) + initial_z.append(initial_translation_z) + return [initial_features, initial_xy, initial_z] -def build_translation_head(middles, subnet_repeats, num_filters, - bias_initializer, gn_groups=4, gn_axis=-1): +def build_translation_subnets(x, repeats, num_filters, bias_initializer, + channels_per_group=16): """Builds TranslationNet head. # Arguments - middles: List, BiFPN layer output. - subnet_repeats: Int, number of layers used in subnetworks. + x: Tensor, BiFPN layer output. + repeats: Int, number of layers used in subnetworks. num_filters: Int, number of subnet filters. bias_initializer: Callable, bias initializer. - gn_groups: Int, number of groups in group normalization. - gn_axis: Int, group normalization axis. + channels_per_group: Int, number of channels per group + of Batchnormalization. # Returns - List: Containing translation_features, - translations_xy and translations_z. + List: Containing x, initial_xy and initial_z. """ - conv_blocks = build_head_conv2D(subnet_repeats, num_filters[0], - bias_initializer) - head_xy_conv = build_head_conv2D(1, num_filters[1], bias_initializer)[0] - head_z_conv = build_head_conv2D(1, num_filters[2], bias_initializer)[0] - args = (conv_blocks, subnet_repeats, gn_groups, gn_axis) - translation_features, translations_xy, translations_z = [], [], [] - for x in middles: - x = conv2D_norm_activation(x, *args) - translation_xy = head_xy_conv(x) - translation_z = head_z_conv(x) - translation_features.append(x) - translations_xy.append(translation_xy) - translations_z.append(translation_z) - return [translation_features, translations_xy, translations_z] - - -def build_iterative_translation_subnet(translation_features, translations_xy, - translations_z, subnet_repeats, - num_filters, bias_initializer, - subnet_iterations, gn_groups=4, - gn_axis=-1): - """Builds iterative translation subnets. + conv_body = build_head_conv2D(repeats, num_filters[0], bias_initializer) + conv_head_xy = build_head_conv2D(1, num_filters[1], bias_initializer)[0] + conv_head_z = build_head_conv2D(1, num_filters[2], bias_initializer)[0] + num_groups = int(num_filters[0] / channels_per_group) + for block_arg in range(repeats): + x = conv_body[block_arg](x) + x = GroupNormalization(groups=num_groups)(x) + x = tf.nn.swish(x) + return [x, conv_head_xy(x), conv_head_z(x)] + + +def refine_translation_iteratively(translation_features, translations_xy, + translations_z, subnet_repeats, num_filters, + bias_initializer, subnet_iterations): + """Refines translation iteratively. # Arguments - translation_features: List, containing - features from translation head. - translations_xy: List, containing translations - in XY directions from translation head. - translations_z: List, containing translations - in Z directions from translation head. + translation_features: List, containing features + from translation head. + translations_xy: List, containing translations in XY directions + from translation head. 
+        translations_z: List, containing translations in Z directions
+            from translation head.
         subnet_repeats: Int, number of layers used in subnetworks.
         num_filters: Int, number of subnet filters.
         bias_initializer: Callable, bias initializer.
-        subnet_iterations: Int, number of iterative refinement
-            steps used in rotation and translation subnets.
-        gn_groups: Int, number of groups in group normalization.
-        gn_axis: Int, group normalization axis.
+        subnet_iterations: Int, number of iterative refinement steps
+            used in rotation and translation subnets.
 
     # Returns
-        translations: List, containing final translation values.
+        translations: List, containing final translation values
+            from every feature level.
     """
-    conv_blocks = build_head_conv2D(subnet_repeats - 1, num_filters[0],
-                                    bias_initializer)
-    head_xy = build_head_conv2D(1, num_filters[1], bias_initializer)[0]
-    head_z = build_head_conv2D(1, num_filters[2], bias_initializer)[0]
-    args = (conv_blocks, subnet_repeats - 1, gn_groups, gn_axis)
     translations = []
+    args = (subnet_repeats, num_filters, bias_initializer)
     iterator = zip(translation_features, translations_xy, translations_z)
-    for x, translation_xy, translation_z in iterator:
+    for translation_feature, translation_xy, translation_z in iterator:
         for _ in range(subnet_iterations):
-            x = Concatenate(axis=-1)([x, translation_xy, translation_z])
-            x = conv2D_norm_activation(x, *args)
-            delta_translation_xy = head_xy(x)
-            delta_translation_z = head_z(x)
+            x = Concatenate(axis=-1)([translation_feature,
+                                      translation_xy, translation_z])
+            delta_translations = refine_translation(x, *args)
+            delta_translation_xy, delta_translation_z = delta_translations
             translation_xy = Add()([translation_xy, delta_translation_xy])
             translation_z = Add()([translation_z, delta_translation_z])
         translation_xy = Reshape((-1, 2))(translation_xy)
@@ -248,3 +231,29 @@ def build_iterative_translation_subnet(translation_features, translations_xy,
         translation = Concatenate(axis=-1)([translation_xy, translation_z])
         translations.append(translation)
     return translations
+
+
+def refine_translation(x, repeats, num_filters, bias_initializer,
+                       channels_per_group=16):
+    """Builds translation refinement module.
+
+    # Arguments
+        x: Tensor, BiFPN layer output.
+        repeats: Int, number of layers used in subnetworks.
+        num_filters: Int, number of subnet filters.
+        bias_initializer: Callable, bias initializer.
+        channels_per_group: Int, number of channels per group
+            of GroupNormalization.
+
+    # Returns
+        List: Containing delta_xy and delta_z.
+    """
+    conv_body = build_head_conv2D(repeats, num_filters[0], bias_initializer)
+    conv_head_xy = build_head_conv2D(1, num_filters[1], bias_initializer)[0]
+    conv_head_z = build_head_conv2D(1, num_filters[2], bias_initializer)[0]
+    num_groups = int(num_filters[0] / channels_per_group)
+    for block_arg in range(repeats):
+        x = conv_body[block_arg](x)
+        x = GroupNormalization(groups=num_groups)(x)
+        x = tf.nn.swish(x)
+    return [conv_head_xy(x), conv_head_z(x)]
diff --git a/examples/efficientpose/evaluate_ADD.py b/examples/efficientpose/evaluate_ADD.py
deleted file mode 100644
index e99eca5cf..000000000
--- a/examples/efficientpose/evaluate_ADD.py
+++ /dev/null
@@ -1,180 +0,0 @@
-import os
-import numpy as np
-from paz.backend.image import load_image
-from scipy import spatial
-from paz.backend.groups import quaternion_to_rotation_matrix
-
-
-def transform_mesh_points(mesh_points, rotation, translation):
-    """Transforms the object points
-
-    # Arguments
-        mesh_points: nx3 ndarray with 3D model points.
- rotaion: Rotation matrix - translation: Translation vector - - # Returns - Transformed model - """ - assert (mesh_points.shape[1] == 3) - pts_t = rotation.dot(mesh_points.T) + translation.reshape((3, 1)) - return pts_t.T - - -def compute_ADD(true_pose, pred_pose, mesh_points): - """Calculate The ADD error. - - # Arguments - true_pose: Real pose - pred_pose: Predicted pose - mesh_pts: nx3 ndarray with 3D model points. - - # Returns - Return ADD error - """ - quaternion = pred_pose.quaternion - pred_translation = pred_pose.translation - pred_rotation = quaternion_to_rotation_matrix(quaternion) - pred_mesh = transform_mesh_points(mesh_points, pred_rotation, - pred_translation) - - true_rotation = true_pose[:3, :3] - true_translation = true_pose[:3, 3] - true_mesh = transform_mesh_points(mesh_points, true_rotation, - true_translation) - - error = np.linalg.norm(pred_mesh - true_mesh, axis=1).mean() - return error - - -def check_ADD(ADD_error, diameter, diameter_threshold=0.1): - """Check if ADD error is within the diameter's tolerance. - - # Arguments - ADD_error: Float, ADD error value. - diameter: Float, diameter of the object. - diameter_threshold: Float, thhreshold for diameter tolerance. - - # Returns - is_correct: Bool flag indicating if pose is correct. - """ - if ADD_error <= (diameter * diameter_threshold): - is_correct = True - else: - is_correct = False - return is_correct - - -def compute_ADI(true_pose, pred_pose, mesh_points): - """Calculate The ADI error. - - Calculate distances to the nearest neighbors from vertices in the - ground-truth pose to vertices in the estimated pose. - - # Arguments - true_pose: Real pose - pred_pose: Predicted pose - mesh_pts: nx3 ndarray with 3D model points. - - # Returns - Return ADI error - """ - - quaternion = pred_pose.quaternion - pred_translation = pred_pose.translation - pred_rotation = quaternion_to_rotation_matrix(quaternion) - - pred_mesh = transform_mesh_points(mesh_points, pred_rotation, - pred_translation) - - true_rotation = true_pose[:3, :3] - true_translation = true_pose[:3, 3] - true_mesh = transform_mesh_points(mesh_points, true_rotation, - true_translation) - nn_index = spatial.cKDTree(pred_mesh) - nn_dists, _ = nn_index.query(true_mesh, k=1) - - error = nn_dists.mean() - return error - - -class EvaluatePoseError: - """Callback for evaluating the pose error on ADD and ADI metric. - - # Arguments - experiment_path: String. Path in which the images will be saved. - images: List of numpy arrays of shape. - pipeline: Function that takes as input an element of ''images'' - and outputs a ''Dict'' with inferences. - mesh_points: nx3 ndarray with 3D model points. - topic: Key to the ''inferences'' dictionary containing as value the - drawn inferences. - verbose: Integer. If is bigger than 1 messages would be displayed. 
- """ - def __init__(self, experiment_path, evaluation_data_manager, pipeline, - mesh_points, object_diameter, topic='poses6D', verbose=1): - self.experiment_path = experiment_path - self.evaluation_data_manager = evaluation_data_manager - self.images = self._load_test_images() - self.gt_poses = self._load_gt_poses() - self.pipeline = pipeline - self.mesh_points = mesh_points - self.object_diameter = object_diameter - self.topic = topic - self.verbose = verbose - - def _load_test_images(self): - evaluation_data = self.evaluation_data_manager.load_data() - evaluation_images = [] - for evaluation_datum in evaluation_data: - evaluation_image = load_image(evaluation_datum['image']) - evaluation_images.append(evaluation_image) - return evaluation_images - - def _load_gt_poses(self): - evaluation_data = self.evaluation_data_manager.load_data() - gt_poses = [] - for evaluation_datum in evaluation_data: - rotation = evaluation_datum['rotation'] - rotation_matrix = rotation.reshape((3, 3)) - translation = evaluation_datum['translation_raw'] - gt_pose = np.concatenate((rotation_matrix, translation.T), axis=1) - gt_poses.append(gt_pose) - return gt_poses - - def on_epoch_end(self, epoch, logs=None): - sum_ADD = 0.0 - sum_ADI = 0.0 - sum_ADD_accuracy = 0 - valid_predictions = 0 - for image, gt_pose in zip(self.images, self.gt_poses): - inferences = self.pipeline(image.copy()) - pose6D = inferences[self.topic] - if pose6D: - add_error = compute_ADD(gt_pose, pose6D[0], self.mesh_points) - is_correct = check_ADD(add_error, self.object_diameter) - sum_ADD_accuracy = sum_ADD_accuracy + int(is_correct) - adi_error = compute_ADI(gt_pose, pose6D[0], self.mesh_points) - sum_ADD = sum_ADD + add_error - sum_ADI = sum_ADI + adi_error - valid_predictions = valid_predictions + 1 - - error_path = os.path.join(self.experiment_path, 'error.txt') - if valid_predictions > 0: - average_ADD = sum_ADD / valid_predictions - average_ADD_accuracy = sum_ADD_accuracy / len(self.gt_poses) - average_ADI = sum_ADI / valid_predictions - with open(error_path, 'a') as filer: - filer.write('epoch: %d\n' % epoch) - filer.write('Estimated ADD error: %f\n' % average_ADD) - filer.write( - 'Estimated ADD accuracy: %f\n\n' % average_ADD_accuracy) - filer.write('Estimated ADI error: %f\n\n' % average_ADI) - else: - average_ADD = None - average_ADI = None - average_ADD_accuracy = None - if self.verbose: - print('Estimated ADD error:', average_ADD) - print('Estimated ADD accuracy:', average_ADD_accuracy) - print('Estimated ADI error:', average_ADI) diff --git a/examples/efficientpose/linemod.py b/examples/efficientpose/linemod.py index aea31aab8..408f56463 100644 --- a/examples/efficientpose/linemod.py +++ b/examples/efficientpose/linemod.py @@ -15,10 +15,10 @@ class LINEMOD(Loader): e.g. `train`, `val` or `test` name: Str, or list indicating with dataset or datasets to load. e.g. ``VOC2007`` or ``[''VOC2007'', VOC2012]``. - evaluate: Bool, If ``True`` returned data will be loaded without - normalization for a direct evaluation. - image_size: Dict, containing keys 'width' and 'height' with - values equal to the input size of the model. + evaluate: Bool, If ``True`` returned data will be loaded + without normalization for a direct evaluation. + image_size: Dict, containing keys 'width' and 'height' + with values equal to the input size of the model. # Return data: List of dictionaries with keys corresponding to the image @@ -68,28 +68,28 @@ class LINEMODParser(object): """ Preprocess the LINEMOD yaml annotations data. 
# Arguments - object_id_to_class_arg: Dict, containing a mappning from objet - ID to class arg. + object_id_to_class_arg: Dict, containing a mapping + from object ID to class arg. dataset_name: Str, or list indicating with dataset or datasets to load. e.g. ``VOC2007`` or ``[''VOC2007'', VOC2012]``. split: Str, determining the data split to load. e.g. `train`, `val` or `test` dataset_path: Str, data path to LINEMOD annotations. - evaluate: Bool, If ``True`` returned data will be loaded without - normalization for a direct evaluation. + evaluate: Bool, If ``True`` returned data will be loaded + without normalization for a direct evaluation. object_id: Str, ID of the object to train. class_names: List of strings indicating class names. - image_size: Dict, containing keys 'width' and 'height' with - values equal to the input size of the model. - ground_truth_file: Str, name of the file containing - ground truths. + image_size: Dict, containing keys 'width' and 'height' + with values equal to the input size of the model. + ground_truth_file: Str, name of the file + containing ground truths. info_file: Str, name of the file containing info. data: Str, name of the directory containing object data. # Return data: Dict, with keys correspond to the image names and values - are numpy arrays for boxes, rotation, translation and - integer for class. + are numpy arrays for boxes, rotation, translation + and integer for class. """ def __init__(self, object_id_to_class_arg, dataset_name='LINEMOD', split='train', dataset_path='/Linemod_preprocessed/', @@ -176,6 +177,7 @@ def _preprocess_files(self): # Get mask path mask_path = (self.split_prefix + self.object_id + '/' + 'mask' + '/' + datum_file + '.png') + # Append class to box data box_data = np.concatenate( (box_data, np.array([[class_arg]])), axis=-1) diff --git a/examples/efficientpose/losses.py b/examples/efficientpose/losses.py index 2780d4667..35587421d 100644 --- a/examples/efficientpose/losses.py +++ b/examples/efficientpose/losses.py @@ -1,19 +1,42 @@ import tensorflow as tf -from tensorflow import keras import numpy as np from plyfile import PlyData -import math from pose import LINEMOD_CAMERA_MATRIX class MultiPoseLoss(object): + """Multi-pose loss for a single-shot 6D object pose estimation + architecture. + + # Arguments + object_id: Str, ID of object to train in LINEMOD dataset, + e.g. powerdrill has an `object_id` of `08`. + translation_priors: Array of shape `(num_boxes, 3)`, + translation anchors. + data_path: Str, root directory of LINEMOD dataset. + target_num_points: Int, number of points of 3D model of object + to consider for loss calculation. + num_pose_dims: Int, number of pose dimensions. + model_path: Directory containing ply files of LINEMOD objects. + translation_scale_norm: Float, factor to change units. + EfficientPose internally works with meter and if the + dataset unit is mm for example, then this parameter + should be set to 1000.
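+
+    # Example
+        Illustrative usage only (`data_path` and `optimizer` are
+        placeholders):
+        loss = MultiPoseLoss('08', model.translation_priors, data_path)
+        model.compile(optimizer, loss=loss.compute_loss)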
+ + # References + - [EfficientPose: An efficient, accurate and scalable + end-to-end 6D multi object pose estimation approach]( + https://arxiv.org/abs/2011.04307) + - [EfficientPose](https://github.com/ybkscht/EfficientPose) + """ def __init__(self, object_id, translation_priors, data_path, target_num_points=500, num_pose_dims=3, model_path='models/', translation_scale_norm=1000): self.object_id = object_id self.translation_priors = translation_priors self.num_pose_dims = num_pose_dims - self.translation_scale_norm = translation_scale_norm + self.tz_scale = tf.convert_to_tensor(translation_scale_norm, + dtype=tf.float32) self.model_path = data_path + model_path + 'obj_' + object_id + '.ply' self.model_points = self._load_model_file() self.model_points = self._filter_model_points(target_num_points) @@ -22,8 +45,7 @@ def _load_model_file(self): model_data = PlyData.read(self.model_path) vertex = model_data['vertex'][:] vertices = [vertex['x'], vertex['y'], vertex['z']] - model_points = np.stack(vertices, axis=-1) - return model_points + return np.stack(vertices, axis=-1) def _filter_model_points(self, target_num_points): num_points = self.model_points.shape[0] @@ -42,89 +64,67 @@ def _filter_model_points(self, target_num_points): return tf.convert_to_tensor(points) def _compute_translation(self, translation_raw_pred, scale): + camera_matrix = tf.convert_to_tensor(LINEMOD_CAMERA_MATRIX) translation_pred = self._regress_translation(translation_raw_pred) - camera_parameter = self._compute_camera_parameter( - scale, LINEMOD_CAMERA_MATRIX) - translation_pred = self._compute_tx_ty(translation_pred, - camera_parameter) - return translation_pred + return self._compute_tx_ty_tz(translation_pred, camera_matrix, scale) def _regress_translation(self, translation_raw): stride = self.translation_priors[:, -1] x = self.translation_priors[:, 0] + (translation_raw[:, :, 0] * stride) y = self.translation_priors[:, 1] + (translation_raw[:, :, 1] * stride) + x, y = x[:, :, tf.newaxis], y[:, :, tf.newaxis] Tz = translation_raw[:, :, 2] - x = tf.expand_dims(x, axis=-1) - y = tf.expand_dims(y, axis=-1) - Tz = tf.expand_dims(Tz, axis=-1) - translations_predicted = tf.concat([x, y, Tz], axis=-1) - return translations_predicted - - def _compute_camera_parameter(self, image_scale, camera_matrix): - camera_parameter = tf.convert_to_tensor([camera_matrix[0, 0], - camera_matrix[1, 1], - camera_matrix[0, 2], - camera_matrix[1, 2], - self.translation_scale_norm, - image_scale]) - return camera_parameter - - def _compute_tx_ty(self, translation_xy_Tz, camera_parameter): - fx, fy = camera_parameter[0], camera_parameter[1], - px, py = camera_parameter[2], camera_parameter[3], - tz_scale, image_scale = camera_parameter[4], camera_parameter[5] - - x = translation_xy_Tz[:, :, 0] / image_scale - y = translation_xy_Tz[:, :, 1] / image_scale - tz = translation_xy_Tz[:, :, 2] * tz_scale + Tz = Tz[:, :, tf.newaxis] + return tf.concat([x, y, Tz], axis=-1) + def _compute_tx_ty_tz(self, translation_xy_Tz, camera_matrix, scale): + fx, fy = camera_matrix[0, 0], camera_matrix[1, 1] + px, py = camera_matrix[0, 2], camera_matrix[1, 2] + + x = translation_xy_Tz[:, :, 0] / scale + y = translation_xy_Tz[:, :, 1] / scale + tz = translation_xy_Tz[:, :, 2] * self.tz_scale x = x - px y = y - py tx = tf.math.multiply(x, tz) / fx ty = tf.math.multiply(y, tz) / fy - - tx = tf.expand_dims(tx, axis=-1) - ty = tf.expand_dims(ty, axis=-1) - tz = tf.expand_dims(tz, axis=-1) - - translations = tf.concat([tx, ty, tz], axis=-1) - return translations + tx, 
ty = tx[:, :, tf.newaxis], ty[:, :, tf.newaxis] + tz = tz[:, :, tf.newaxis] + return tf.concat([tx, ty, tz], axis=-1) def _separate_axis_from_angle(self, axis_angle): squared = tf.math.square(axis_angle) sum = tf.math.reduce_sum(squared, axis=-1) angle = tf.expand_dims(tf.math.sqrt(sum), axis=-1) axis = tf.math.divide_no_nan(axis_angle, angle) - return axis, angle + return [axis, angle] def _rotate(self, point, axis, angle): cos_angle = tf.cos(angle) axis_dot_point = self._dot(axis, point) - return (point * cos_angle + self._cross(axis, point) * - tf.sin(angle) + axis * axis_dot_point * (1.0 - cos_angle)) + return (point * cos_angle + self._cross(axis, point) * tf.sin(angle) + + axis * axis_dot_point * (1.0 - cos_angle)) def _dot(self, vector1, vector2, axis=-1, keepdims=True): return tf.reduce_sum(input_tensor=vector1 * vector2, axis=axis, keepdims=keepdims) def _cross(self, vector1, vector2): - vector1_x = vector1[:, :, 0] - vector1_y = vector1[:, :, 1] + vector1_x, vector1_y, = vector1[:, :, 0], vector1[:, :, 1] vector1_z = vector1[:, :, 2] - vector2_x = vector2[:, :, 0] - vector2_y = vector2[:, :, 1] + vector2_x, vector2_y = vector2[:, :, 0], vector2[:, :, 1] vector2_z = vector2[:, :, 2] n_x = vector1_y * vector2_z - vector1_z * vector2_y n_y = vector1_z * vector2_x - vector1_x * vector2_z n_z = vector1_x * vector2_y - vector1_y * vector2_x return tf.stack((n_x, n_y, n_z), axis=-1) - def _calc_sym_distances(self, sym_points_pred, sym_points_target): - sym_points_pred = tf.expand_dims(sym_points_pred, axis=2) - sym_points_target = tf.expand_dims(sym_points_target, axis=1) - distances = tf.reduce_min(tf.norm( - sym_points_pred - sym_points_target, axis=-1), axis=-1) + def _calc_sym_distances(self, sym_points_pred, sym_points_true): + sym_points_pred = sym_points_pred[:, :, tf.newaxis] + sym_points_true = sym_points_true[:, tf.newaxis] + distances = tf.reduce_min(tf.norm(sym_points_pred - sym_points_true, + axis=-1), axis=-1) return tf.reduce_mean(distances, axis=-1) def _calc_asym_distances(self, asym_points_pred, asym_points_target): @@ -132,71 +132,79 @@ def _calc_asym_distances(self, asym_points_pred, asym_points_target): return tf.reduce_mean(distances, axis=-1) def compute_loss(self, y_true, y_pred): + """Computes pose loss. + + # Arguments + y_true: Tensor of shape '[batch_size, num_boxes, 11]' + with correct labels. + y_pred: Tensor of shape '[batch_size, num_boxes, 6]' + with predicted inferences. + + # Returns + Tensor with loss per sample in batch. 
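+
+        # Note
+            Per the indexing below, the eleven columns of `y_true` carry
+            the axis-angle rotation (0:3), symmetry flag (3), class
+            index (4), translation (6:9), anchor flag (-2) and the
+            image scale (-1).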
+ """ + rotation_pred = y_pred[:, :, :self.num_pose_dims] rotation_true = y_true[:, :, :self.num_pose_dims] translation_true = y_true[:, :, 2 * self.num_pose_dims:2 * self.num_pose_dims + self.num_pose_dims] - - rotation_pred = y_pred[:, :, :self.num_pose_dims] - scale = y_true[0, 0, -1] translation_raw_pred = y_pred[:, :, self.num_pose_dims:] + scale = y_true[0, 0, -1] translation_pred = self._compute_translation(translation_raw_pred, scale) - is_symmetric = y_true[:, :, self.num_pose_dims] - class_indices = y_true[:, :, self.num_pose_dims + 1] anchor_flags = y_true[:, :, -2] anchor_state = tf.cast(tf.math.round(anchor_flags), tf.int32) - indices = tf.where(tf.equal(anchor_state, 1)) - rotation_pred = tf.gather_nd(rotation_pred, indices) * math.pi - translation_pred = tf.gather_nd(translation_pred, indices) - rotation_true = tf.gather_nd(rotation_true, indices) * math.pi + rotation_pred = tf.gather_nd(rotation_pred, indices) + rotation_pred = rotation_pred * np.pi + rotation_true = tf.gather_nd(rotation_true, indices) + rotation_true = rotation_true * np.pi + translation_pred = tf.gather_nd(translation_pred, indices) translation_true = tf.gather_nd(translation_true, indices) + is_symmetric = y_true[:, :, self.num_pose_dims] is_symmetric = tf.gather_nd(is_symmetric, indices) is_symmetric = tf.cast(tf.math.round(is_symmetric), tf.int32) + class_indices = y_true[:, :, self.num_pose_dims + 1] class_indices = tf.gather_nd(class_indices, indices) class_indices = tf.cast(tf.math.round(class_indices), tf.int32) axis_pred, angle_pred = self._separate_axis_from_angle(rotation_pred) - axis_target, angle_target = self._separate_axis_from_angle( - rotation_true) + axis_true, angle_true = self._separate_axis_from_angle(rotation_true) - axis_pred = tf.expand_dims(axis_pred, axis=1) - angle_pred = tf.expand_dims(angle_pred, axis=1) - axis_target = tf.expand_dims(axis_target, axis=1) - angle_target = tf.expand_dims(angle_target, axis=1) + axis_pred = axis_pred[:, tf.newaxis, :] + axis_true = axis_true[:, tf.newaxis, :] + angle_pred = angle_pred[:, tf.newaxis, :] + angle_true = angle_true[:, tf.newaxis, :] - translation_pred = tf.expand_dims(translation_pred, axis=1) - translation_true = tf.expand_dims(translation_true, axis=1) + translation_pred = translation_pred[:, tf.newaxis, :] + translation_true = translation_true[:, tf.newaxis, :] selected_model_points = tf.gather(self.model_points, class_indices, axis=0) transformed_points_pred = self._rotate( selected_model_points, axis_pred, angle_pred) + translation_pred - transformed_points_target = (self._rotate( - selected_model_points, axis_target, angle_target) + - translation_true) - - sym_indices = tf.where(keras.backend.equal(is_symmetric, 1)) - asym_indices = tf.where(keras.backend.not_equal(is_symmetric, 1)) + transformed_points_true = (self._rotate( + selected_model_points, axis_true, angle_true) + translation_true) num_points = selected_model_points.shape[1] + sym_indices = tf.where(tf.math.equal(is_symmetric, 1)) sym_points_pred = tf.reshape(tf.gather_nd( transformed_points_pred, sym_indices), (-1, num_points, 3)) + sym_points_true = tf.reshape(tf.gather_nd( + transformed_points_true, sym_indices), (-1, num_points, 3)) + + asym_indices = tf.where(tf.math.not_equal(is_symmetric, 1)) asym_points_pred = tf.reshape(tf.gather_nd( transformed_points_pred, asym_indices), (-1, num_points, 3)) + asym_points_true = tf.reshape(tf.gather_nd( + transformed_points_true, asym_indices), (-1, num_points, 3)) - sym_points_target = tf.reshape(tf.gather_nd( - 
transformed_points_target, sym_indices), (-1, num_points, 3)) - asym_points_target = tf.reshape(tf.gather_nd( - transformed_points_target, asym_indices), (-1, num_points, 3)) - - sym_distances = self._calc_sym_distances(sym_points_pred, - sym_points_target) - asym_distances = self._calc_asym_distances(asym_points_pred, - asym_points_target) + sym_distances = self._calc_sym_distances( + sym_points_pred, sym_points_true) + asym_distances = self._calc_asym_distances( + asym_points_pred, asym_points_true) distances = tf.concat([sym_distances, asym_distances], axis=0) loss = tf.math.reduce_mean(distances) diff --git a/examples/efficientpose/pose.py b/examples/efficientpose/pose.py index edfb556b5..19f4e6be2 100644 --- a/examples/efficientpose/pose.py +++ b/examples/efficientpose/pose.py @@ -6,11 +6,10 @@ from paz.pipelines.detection import PreprocessBoxes from efficientpose import EFFICIENTPOSEA from processors import (ComputeResizingShape, PadImage, ComputeCameraParameter, - RegressTranslation, ComputeTxTy, DrawPose6D, + RegressTranslation, ComputeTxTyTz, DrawPose6D, ComputeSelectedIndices, ScaleBoxes2D, ToPose6D, MatchPoses, TransformRotation, ConcatenatePoses, - ConcatenateScale, AugmentImageAndPose, - AugmentColorspace) + ConcatenateScale, Augment6DOF, AugmentColorspace) B_LINEMOD_MEAN, G_LINEMOD_MEAN, R_LINEMOD_MEAN = 103.53, 116.28, 123.675 @@ -65,6 +64,8 @@ class AugmentPose(SequentialProcessor): IOU: Float. Intersection over union used to match boxes. variances: List of two floats indicating variances to be encoded for encoding bounding boxes. + probability: Float indicating the probability + of data augmentation. num_pose_dims: Int, number of dimensions for pose. """ def __init__(self, model, split=pr.TRAIN, num_classes=8, size=512, @@ -73,10 +74,10 @@ def __init__(self, model, split=pr.TRAIN, num_classes=8, size=512, num_pose_dims=3): super(AugmentPose, self).__init__() self.augment_colorspace = AugmentColorspace() - self.augment_6DOF = AugmentImageAndPose( - probability=probability, input_size=size) - self.preprocess_image = EfficientPosePreprocess( - model, mean, camera_matrix) + self.augment_6DOF = Augment6DOF(probability=probability, + input_size=size) + self.preprocess_image = EfficientPosePreprocess(model, mean, + camera_matrix) # box processors self.scale_boxes = pr.ScaleBox() @@ -95,16 +96,16 @@ def __init__(self, model, split=pr.TRAIN, num_classes=8, size=512, self.add(pr.ControlMap(pr.LoadImage(), [5], [5])) if split == pr.TRAIN: self.add(pr.ControlMap(self.augment_colorspace, [0], [0])) - self.add(pr.ControlMap(self.augment_6DOF, - [0, 1, 2, 3, 5], [0, 1, 2, 3, 5])) + self.add(pr.ControlMap(self.augment_6DOF, [0, 1, 2, 3, 5], + [0, 1, 2, 3, 5])) self.add(pr.ControlMap(self.preprocess_image, [0], [0, 1, 2])) self.add(pr.ControlMap(self.scale_boxes, [3, 1], [3], keep={1: 1})) self.add(pr.ControlMap(self.preprocess_boxes, [4], [5], keep={4: 4})) self.add(pr.ControlMap(TransformRotation(num_pose_dims), [3], [3])) self.add(pr.ControlMap(self.match_poses, [4, 3], [3], keep={4: 4})) self.add(pr.ControlMap(self.match_poses, [4, 5], [7], keep={4: 4})) - self.add(pr.ControlMap(self.concatenate_poses, - [3, 8], [8], keep={3: 3})) + self.add(pr.ControlMap(self.concatenate_poses, [3, 8], [8], + keep={3: 3})) self.add(pr.ControlMap(self.concatenate_scale, [8, 1], [8])) self.add(pr.SequenceWrapper( {0: {'image': [size, size, 3]}}, @@ -127,10 +128,10 @@ class DetectAndEstimatePose(Processor): preprocess: Callable, preprocessing pipeline. postprocess: Callable, postprocessing pipeline. 
variances: List of float values. - show_boxes2D: Boolean. If ``True`` prediction - are drawn in the returned image. - show_poses6D: Boolean. If ``True`` estimated poses - are drawn in the returned image. + show_boxes2D: Boolean. If ``True`` prediction are drawn + in the returned image. + show_poses6D: Boolean. If ``True`` estimated poses are drawn + in the returned image. # Properties model: Keras model. @@ -196,7 +197,6 @@ def call(self, image): outputs, image_scale, camera_parameter) if self.show_boxes2D: image = self.draw_boxes2D(image, boxes2D) - if self.show_poses6D: self.draw_pose6D = self._build_draw_pose6D( self.class_to_sizes, self.camera_matrix) @@ -251,8 +251,6 @@ class EfficientPosePreprocess(Processor): # Arguments model: Keras model. mean: Tuple, containing mean per channel on ImageNet. - standard_deviation: Tuple, containing standard deviations - per channel on ImageNet. camera_matrix: Array of shape `(3, 3)` camera matrix. translation_scale_norm: Float, factor to change units. EfficientPose internally works with meter and if the @@ -313,7 +311,7 @@ def __init__(self, model, class_names, score_thresh, nms_thresh, self.to_boxes2D = pr.ToBoxes2D(class_names) self.round_boxes = pr.RoundBoxes2D() self.regress_translation = RegressTranslation(model.translation_priors) - self.compute_tx_ty = ComputeTxTy() + self.compute_tx_ty_tz = ComputeTxTyTz() self.compute_selections = ComputeSelectedIndices() self.squeeze = pr.Squeeze(axis=0) self.transform_rotations = pr.Scale(np.pi) @@ -327,7 +325,6 @@ def call(self, model_output, image_scale, camera_parameter): box_data = self.postprocess_2(box_data) boxes2D = self.to_boxes2D(box_data) boxes2D = self.round_boxes(boxes2D) - rotations = transformations[:, :, :self.num_pose_dims] translations = transformations[:, :, self.num_pose_dims:] poses6D = [] @@ -336,12 +333,10 @@ def call(self, model_output, image_scale, camera_parameter): rotations = self.squeeze(rotations) rotations = rotations[selected_indices] rotations = self.transform_rotations(rotations) - translation_xy_Tz = self.regress_translation(translations) - translation = self.compute_tx_ty(translation_xy_Tz, - camera_parameter) + translation = self.compute_tx_ty_tz(translation_xy_Tz, + camera_parameter) translations = translation[selected_indices] - poses6D = self.to_pose_6D(box_data, rotations, translations) return boxes2D, poses6D @@ -376,7 +371,7 @@ def __init__(self, model, class_names, score_thresh, nms_thresh, self.round_boxes = pr.RoundBoxes2D() self.denormalize = pr.DenormalizeBoxes2D() self.regress_translation = RegressTranslation(model.translation_priors) - self.compute_tx_ty = ComputeTxTy() + self.compute_tx_ty_tz = ComputeTxTyTz() self.compute_selections = ComputeSelectedIndices() self.squeeze = pr.Squeeze(axis=0) self.transform_rotations = pr.Scale(np.pi) @@ -391,7 +386,6 @@ def call(self, image, model_output, image_scale, camera_parameter): boxes2D = self.denormalize(image, boxes2D) boxes2D = self.scale_boxes2D(boxes2D, 1 / image_scale) boxes2D = self.round_boxes(boxes2D) - rotations = transformations[:, :, :self.num_pose_dims] translations = transformations[:, :, self.num_pose_dims:] poses6D = [] @@ -400,12 +394,10 @@ def call(self, image, model_output, image_scale, camera_parameter): rotations = self.squeeze(rotations) rotations = rotations[selected_indices] rotations = self.transform_rotations(rotations) - translation_xy_Tz = self.regress_translation(translations) - translation = self.compute_tx_ty(translation_xy_Tz, - camera_parameter) + translation = 
self.compute_tx_ty_tz(translation_xy_Tz, + camera_parameter) translations = translation[selected_indices] - poses6D = self.to_pose_6D(box_data, rotations, translations) return boxes2D, poses6D @@ -491,7 +483,6 @@ def call(self, image): preprocessed_image[0], outputs, image_scale, camera_parameter) if self.show_boxes2D: image = self.draw_boxes2D(image, boxes2D) - if self.show_poses6D: self.draw_pose6D = self._build_draw_pose6D( self.class_to_sizes, self.camera_matrix) @@ -511,16 +502,16 @@ class EFFICIENTPOSEALINEMOD(DetectAndEstimatePose): show_poses6D: Boolean. If ``True`` estimated poses are drawn in the returned image. - # References - [ybkscht repository implementation of EfficientPose]( - https://github.com/ybkscht/EfficientPose) + # References + [EfficientPose: An efficient, accurate and scalable end-to-end + 6D multi object pose estimation approach]( + https://arxiv.org/pdf/2011.04307.pdf) """ def __init__(self, score_thresh=0.60, nms_thresh=0.45, show_boxes2D=False, show_poses6D=True): names = get_class_names('LINEMOD_EFFICIENTPOSE') model = EFFICIENTPOSEA(num_classes=len(names), base_weights='COCO', - head_weights='LINEMOD_OCCLUDED', momentum=0.997, - epsilon=0.0001, activation='sigmoid') + head_weights='LINEMOD_OCCLUDED') super(EFFICIENTPOSEALINEMOD, self).__init__( model, names, score_thresh, nms_thresh, LINEMOD_CAMERA_MATRIX, LINEMOD_OBJECT_SIZES, @@ -538,16 +529,16 @@ class EFFICIENTPOSEALINEMODDRILLER(DetectAndEstimateEfficientPose): show_poses6D: Boolean. If ``True`` estimated poses are drawn in the returned image. - # References - [ybkscht repository implementation of EfficientPose]( - https://github.com/ybkscht/EfficientPose) + # References + [EfficientPose: An efficient, accurate and scalable end-to-end + 6D multi object pose estimation approach]( + https://arxiv.org/pdf/2011.04307.pdf) """ def __init__(self, score_thresh=0.60, nms_thresh=0.45, show_boxes2D=False, show_poses6D=True): names = get_class_names('LINEMOD_EFFICIENTPOSE_DRILLER') model = EFFICIENTPOSEA(num_classes=len(names), base_weights='COCO', - head_weights=None, momentum=0.99, - epsilon=0.001, activation='softmax') + head_weights=None) super(EFFICIENTPOSEALINEMODDRILLER, self).__init__( model, names, score_thresh, nms_thresh, LINEMOD_CAMERA_MATRIX, LINEMOD_OBJECT_SIZES, diff --git a/examples/efficientpose/pose_error.py b/examples/efficientpose/pose_error.py index 592a9a15b..b965fc932 100644 --- a/examples/efficientpose/pose_error.py +++ b/examples/efficientpose/pose_error.py @@ -7,7 +7,8 @@ def transform_mesh_points(mesh_points, rotation, translation): - """Transforms the object points + """Transforms object points + # Arguments mesh_points: nx3 ndarray with 3D model points. rotaion: Rotation matrix @@ -22,11 +23,13 @@ def transform_mesh_points(mesh_points, rotation, translation): def compute_ADD(true_pose, pred_pose, mesh_points): - """Calculate The ADD error. + """Calculates ADD error. + # Arguments true_pose: Real pose pred_pose: Predicted pose mesh_pts: nx3 ndarray with 3D model points. 
+ # Returns Return ADD error """ @@ -35,7 +38,6 @@ def compute_ADD(true_pose, pred_pose, mesh_points): pred_rotation = quaternion_to_rotation_matrix(quaternion) pred_mesh = transform_mesh_points(mesh_points, pred_rotation, pred_translation) - true_rotation = true_pose[:3, :3] true_translation = true_pose[:3, 3] true_mesh = transform_mesh_points(mesh_points, true_rotation, @@ -68,17 +70,14 @@ def compute_ADI(true_pose, pred_pose, mesh_points): quaternion = pred_pose.quaternion pred_translation = pred_pose.translation pred_rotation = quaternion_to_rotation_matrix(quaternion) - pred_mesh = transform_mesh_points(mesh_points, pred_rotation, pred_translation) - true_rotation = true_pose[:3, :3] true_translation = true_pose[:3, 3] true_mesh = transform_mesh_points(mesh_points, true_rotation, true_translation) nn_index = spatial.cKDTree(pred_mesh) nn_dists, _ = nn_index.query(true_mesh, k=1) - error = nn_dists.mean() return error @@ -92,12 +91,14 @@ class EvaluatePoseError(Callback): pipeline: Function that takes as input an element of ''images'' and outputs a ''Dict'' with inferences. mesh_points: nx3 ndarray with 3D model points. - topic: Key to the ''inferences'' dictionary containing as value the - drawn inferences. - verbose: Integer. If is bigger than 1 messages would be displayed. + topic: Key to the ''inferences'' dictionary containing as value + the drawn inferences. + verbose: Integer. If bigger than 1, + messages will be displayed. """ def __init__(self, experiment_path, evaluation_data_manager, pipeline, - mesh_points, object_diameter, topic='poses6D', verbose=1): + mesh_points, object_diameter, evaluation_period, + topic='poses6D', verbose=1): self.experiment_path = experiment_path self.evaluation_data_manager = evaluation_data_manager self.images = self._load_test_images() @@ -105,6 +106,7 @@ def __init__(self, experiment_path, evaluation_data_manager, pipeline, self.pipeline = pipeline self.mesh_points = mesh_points self.object_diameter = object_diameter + self.evaluation_period = evaluation_period self.topic = topic self.verbose = verbose @@ -128,38 +130,41 @@ def _load_gt_poses(self): return gt_poses def on_epoch_end(self, epoch, logs=None): - sum_ADD = 0.0 - sum_ADI = 0.0 - sum_ADD_accuracy = 0.0 - valid_predictions = 0 - for image, gt_pose in zip(self.images, self.gt_poses): - inferences = self.pipeline(image.copy()) - pose6D = inferences[self.topic] - if pose6D: - add_error = compute_ADD(gt_pose, pose6D[0], self.mesh_points) - is_correct = check_ADD(add_error, self.object_diameter) - sum_ADD_accuracy = sum_ADD_accuracy + float(is_correct) - adi_error = compute_ADI(gt_pose, pose6D[0], self.mesh_points) - sum_ADD = sum_ADD + add_error - sum_ADI = sum_ADI + adi_error - valid_predictions = valid_predictions + 1 - - error_path = os.path.join(self.experiment_path, 'error.txt') - if valid_predictions > 0: - average_ADD = sum_ADD / valid_predictions - average_ADD_accuracy = sum_ADD_accuracy / len(self.gt_poses) - average_ADI = sum_ADI / valid_predictions - with open(error_path, 'a') as filer: - filer.write('epoch: %d\n' % epoch) - filer.write('Estimated ADD error: %f\n' % average_ADD) - filer.write( - 'Estimated ADD accuracy: %f\n\n' % average_ADD_accuracy) - filer.write('Estimated ADI error: %f\n\n' % average_ADI) - else: - average_ADD = None - average_ADI = None - average_ADD_accuracy = None - if self.verbose: - print('Estimated ADD error:', average_ADD) - print('Estimated ADD accuracy:', average_ADD_accuracy) - print('Estimated ADI error:', average_ADI) + if (epoch + 1) %
self.evaluation_period == 0: + sum_ADD = 0.0 + sum_ADI = 0.0 + sum_ADD_accuracy = 0.0 + valid_predictions = 0 + for image, gt_pose in zip(self.images, self.gt_poses): + inferences = self.pipeline(image.copy()) + pose6D = inferences[self.topic] + if pose6D: + add_error = compute_ADD(gt_pose, pose6D[0], + self.mesh_points) + is_correct = check_ADD(add_error, self.object_diameter) + sum_ADD_accuracy = sum_ADD_accuracy + float(is_correct) + adi_error = compute_ADI(gt_pose, pose6D[0], + self.mesh_points) + sum_ADD = sum_ADD + add_error + sum_ADI = sum_ADI + adi_error + valid_predictions = valid_predictions + 1 + + error_path = os.path.join(self.experiment_path, 'error.txt') + if valid_predictions > 0: + average_ADD = sum_ADD / valid_predictions + average_ADD_accuracy = sum_ADD_accuracy / len(self.gt_poses) + average_ADI = sum_ADI / valid_predictions + with open(error_path, 'a') as filer: + filer.write('epoch: %d\n' % epoch) + filer.write('Estimated ADD error: %f\n' % average_ADD) + filer.write(('Estimated ADD accuracy: %f\n\n' % + average_ADD_accuracy)) + filer.write('Estimated ADI error: %f\n\n' % average_ADI) + else: + average_ADD = None + average_ADI = None + average_ADD_accuracy = None + if self.verbose: + print('Estimated ADD error:', average_ADD) + print('Estimated ADD accuracy:', average_ADD_accuracy) + print('Estimated ADI error:', average_ADI) diff --git a/examples/efficientpose/processors.py b/examples/efficientpose/processors.py index 10938abca..4032b1ebc 100644 --- a/examples/efficientpose/processors.py +++ b/examples/efficientpose/processors.py @@ -15,7 +15,7 @@ class ComputeResizingShape(Processor): """Computes the final size of the image to be scaled by `size` - such that the maximum dimension of the image is equal to `size`. + such that the largest dimension of the image is equal to `size`. # Arguments size: Int, final size of maximum dimension of the image. @@ -29,12 +29,22 @@ def call(self, image): def compute_resizing_shape(image, size): + """Computes the final size of the image to be scaled by `size` + such that the largest dimension of the image is equal to `size`. + + # Arguments + image: Array, raw image to be scaled. + size: Int, final size of the image. + + # Returns + List: Containing final shape of image and scale. + """ H, W = image.shape[:2] image_scale = size / max(H, W) resizing_W = int(W * image_scale) resizing_H = int(H * image_scale) resizing_shape = (resizing_W, resizing_H) - return resizing_shape, np.array(image_scale) + return [resizing_shape, np.array(image_scale)] class PadImage(Processor): @@ -54,12 +64,21 @@ def call(self, image): def pad_image(image, size, mode): + """Pads the image to the final size `size`. + + # Arguments + image: Array, image to be padded. + size: Int, final size of the image. + mode: Str, specifying the type of padding. + + # Returns + Array: Padded image. + """ H, W = image.shape[:2] pad_H = size - H pad_W = size - W pad_shape = [(0, pad_H), (0, pad_W), (0, 0)] - image = np.pad(image, pad_shape, mode=mode) - return image + return np.pad(image, pad_shape, mode=mode) class ComputeCameraParameter(Processor): @@ -85,18 +104,28 @@ def call(self, image_scale): def compute_camera_parameter(image_scale, camera_matrix, translation_scale_norm): - camera_parameter = np.array([camera_matrix[0, 0], - camera_matrix[1, 1], - camera_matrix[0, 2], - camera_matrix[1, 2], - translation_scale_norm, - image_scale]) - return camera_parameter + """Computes camera parameter given camera matrix + and scale normalization factor of translation. 
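+    The six entries are unpacked downstream by `compute_tx_ty_tz`.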
+ + # Arguments + image_scale: Array, scale of image. + camera_matrix: Array, camera matrix. + translation_scale_norm: Float, factor to change units. + EfficientPose internally works with meter and if the + dataset unit is mm for example, then this parameter + should be set to 1000. + + # Returns + Array: of shape `(6,)`, camera parameter. + """ + return np.array([camera_matrix[0, 0], camera_matrix[1, 1], + camera_matrix[0, 2], camera_matrix[1, 2], + translation_scale_norm, image_scale]) class RegressTranslation(Processor): - """Applies regression offset values to translation - anchors to get the 2D translation center-point and Tz. + """Applies regression offset values to translation anchors + to get the 2D translation center-point and Tz. # Arguments translation_priors: Array of shape `(num_boxes, 3)`, @@ -111,26 +140,46 @@ def call(self, translation_raw): def regress_translation(translation_raw, translation_priors): + """Applies regression offset values to translation anchors + to get the 2D translation center-point and Tz. + + # Arguments + translation_raw: Array of shape `(1, num_boxes, 3)`, + raw translation regression. + translation_priors: Array of shape `(num_boxes, 3)`, + translation anchors. + + # Returns + Array: of shape `(num_boxes, 3)`. + """ stride = translation_priors[:, -1] x = translation_priors[:, 0] + (translation_raw[:, :, 0] * stride) y = translation_priors[:, 1] + (translation_raw[:, :, 1] * stride) Tz = translation_raw[:, :, 2] - translations_predicted = np.concatenate((x, y, Tz), axis=0) - return translations_predicted.T + return np.concatenate((x, y, Tz), axis=0).T -class ComputeTxTy(Processor): +class ComputeTxTyTz(Processor): """Computes the Tx and Ty components of the translation vector with a given 2D-point and the intrinsic camera parameters. """ def __init__(self): - super(ComputeTxTy, self).__init__() + super(ComputeTxTyTz, self).__init__() def call(self, translation_xy_Tz, camera_parameter): - return compute_tx_ty(translation_xy_Tz, camera_parameter) + return compute_tx_ty_tz(translation_xy_Tz, camera_parameter) -def compute_tx_ty(translation_xy_Tz, camera_parameter): +def compute_tx_ty_tz(translation_xy_Tz, camera_parameter): + """Computes Tx, Ty and Tz components of the translation vector + with a given 2D-point and the intrinsic camera parameters. + + # Arguments + translation_xy_Tz: Array of shape `(num_boxes, 3)`. + camera_parameter: Array of shape `(6,)`, camera parameter. + + # Returns + Array: of shape `(num_boxes, 3)`. + """ fx, fy = camera_parameter[0], camera_parameter[1], px, py = camera_parameter[2], camera_parameter[3], tz_scale, image_scale = camera_parameter[4], camera_parameter[5] @@ -141,14 +190,10 @@ def compute_tx_ty(translation_xy_Tz, camera_parameter): x = x - px y = y - py - tx = np.multiply(x, tz) / fx ty = np.multiply(y, tz) / fy - tx, ty, tz = tx[np.newaxis, :], ty[np.newaxis, :], tz[np.newaxis, :] - - translations = np.concatenate((tx, ty, tz), axis=0) - return translations.T + return np.concatenate((tx, ty, tz), axis=0).T class ComputeSelectedIndices(Processor): @@ -163,9 +208,18 @@ def call(self, box_data_raw, box_data): def compute_selected_indices(box_data_all, box_data): + """Computes row-wise intersection between two given + arrays and returns the indices of the intersections. + + # Arguments + box_data_all: Array of shape `(num_boxes, 5)`. + box_data: Array of shape `(n, 5)`, box data. + + # Returns + Array: of shape `(n, )`.
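+
+    # Note
+        If `box_data` holds two of the rows of `box_data_all`, the
+        returned array holds the two positions at which those rows
+        occur in `box_data_all`.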
+ """ box_data_all_tuple = [tuple(row) for row in box_data_all[:, :4]] box_data_tuple = [tuple(row) for row in box_data[:, :4]] - location_indices = [] for tuple_element in box_data_tuple: location_index = box_data_all_tuple.index(tuple_element) @@ -178,8 +232,8 @@ class ToPose6D(Processor): translations into `Pose6D` messages. # Arguments - class_names: List of class names ordered with respect to the - class indices from the dataset ``boxes``. + class_names: List of class names ordered with respect + to the class indices from the dataset ``boxes``. one_hot_encoded: Bool, indicating if scores are one hot vectors. default_score: Float, score to set. default_class: Str, class to set. @@ -325,8 +379,8 @@ def draw_pose6D(image, pose6D, points3D, intrinsics, thickness, color): # Arguments image: Array (H, W). pose6D: paz.abstract.Pose6D instance. - intrinsics: Array (3, 3). Camera intrinsics for projecting - 3D rays into 2D image. + intrinsics: Array (3, 3). Camera intrinsics + for projecting 3D rays into 2D image. points3D: Array (num_points, 3). thickness: Positive integer indicating line thickness. color: List, the color to draw 3D bounding boxes. @@ -357,11 +411,21 @@ def __init__(self, prior_boxes, iou=.5): super(MatchPoses, self).__init__() def call(self, boxes, poses): - matched_poses = match_poses(boxes, poses, self.prior_boxes, self.iou) - return matched_poses + return match_poses(boxes, poses, self.prior_boxes, self.iou) def match_poses(boxes, poses, prior_boxes, iou_threshold): + """Match prior boxes with poses with ground truth boxes and poses. + + # Arguments + boxes: Array of shape `(n, 5)`. + poses: Array of shape `(n, 5)`. + prior_boxes: Array of shape `(num_boxes, 4)`. + iou_threshold: Floats, IOU threshold value. + + # Returns + matched_poses: Array of shape `(num_boxes, 6)`. + """ matched_poses = np.zeros((prior_boxes.shape[0], poses.shape[1] + 1)) ious = compute_ious(boxes, to_corner_form(np.float32(prior_boxes))) per_prior_which_box_iou = np.max(ious, axis=0) @@ -371,7 +435,6 @@ def match_poses(boxes, poses, prior_boxes, iou_threshold): for box_arg in range(len(per_box_which_prior_arg)): best_prior_box_arg = per_box_which_prior_arg[box_arg] per_prior_which_box_arg[best_prior_box_arg] = box_arg - matched_poses[:, :-1] = poses[per_prior_which_box_arg] matched_poses[per_prior_which_box_iou >= iou_threshold, -1] = 1 return matched_poses @@ -384,20 +447,27 @@ class TransformRotation(Processor): num_pose_dims: Int, number of dimensions of pose. # Returns: - transformed_rotations: Array of shape (5,) containing the - transformed rotation. + transformed_rotations: Array of shape (5,) + containing transformed rotation. """ def __init__(self, num_pose_dims): self.num_pose_dims = num_pose_dims super(TransformRotation, self).__init__() def call(self, rotations): - transformed_rotations = transform_rotation(rotations, - self.num_pose_dims) - return transformed_rotations + return transform_rotation(rotations, self.num_pose_dims) def transform_rotation(rotations, num_pose_dims): + """Computes axis angle rotation vector from a rotation matrix. + + # Arguments: + rotation: Array, of shape `(n, 9)`. + num_pose_dims: Int, number of pose dimensions. + + # Returns: + Array: of shape (n, 5) containing axis angle vector. 
+ """ final_axis_angles = [] for rotation in rotations: final_axis_angle = np.zeros((num_pose_dims + 2)) @@ -407,26 +477,33 @@ def transform_rotation(rotations, num_pose_dims): final_axis_angle[:3] = axis_angle final_axis_angle = np.expand_dims(final_axis_angle, axis=0) final_axis_angles.append(final_axis_angle) - final_axis_angles = np.concatenate(final_axis_angles, axis=0) - return final_axis_angles + return np.concatenate(final_axis_angles, axis=0) class ConcatenatePoses(Processor): """Concatenates rotations and translations into a single array. # Returns: - poses_combined: Array of shape `(num_prior_boxes, 10)` + poses_combined: Array of shape `(num_boxes, 10)` containing the transformed rotation. """ def __init__(self): super(ConcatenatePoses, self).__init__() def call(self, rotations, translations): - poses_combined = concatenate_poses(rotations, translations) - return poses_combined + return concatenate_poses(rotations, translations) def concatenate_poses(rotations, translations): + """Concatenates rotations and translations into a single array. + + # Arguments: + rotations: Array, of shape `(num_boxes, 6)`. + translations: Array, of shape `(num_boxes, 4)`. + + # Returns: + Array: of shape (num_boxes, 10) + """ return np.concatenate((rotations, translations), axis=-1) @@ -441,15 +518,22 @@ def __init__(self): super(ConcatenateScale, self).__init__() def call(self, poses, scale): - poses_combined = concatenate_scale(poses, scale) - return poses_combined + return concatenate_scale(poses, scale) def concatenate_scale(poses, scale): + """Concatenates poses and scale into a single array. + + # Arguments: + poses: Array, of shape `(num_boxes, 10)`. + scale: Array, of shape `()`. + + # Returns: + Array: of shape (num_boxes, 11) + """ scale = np.repeat(scale, poses.shape[0]) scale = scale[np.newaxis, :] - poses = np.concatenate((poses, scale.T), axis=1) - return poses + return np.concatenate((poses, scale.T), axis=1) class ScaleBoxes2D(Processor): @@ -462,17 +546,25 @@ def __init__(self): super(ScaleBoxes2D, self).__init__() def call(self, boxes2D, scale): - boxes2D = scale_boxes2D(boxes2D, scale) - return boxes2D + return scale_boxes2D(boxes2D, scale) def scale_boxes2D(boxes2D, scale): + """Scales coordinates of Boxes2D. + + # Arguments: + boxes2D: List, of Box2D objects. + scale: Foat, scale value. + + # Returns: + boxes2D: List, of Box2D objects with scale coordinates. + """ for box2D in boxes2D: box2D.coordinates = tuple(np.array(box2D.coordinates) * scale) return boxes2D -class AugmentImageAndPose(Processor): +class Augment6DOF(Processor): """Augment images, boxes, rotation and translation vector for pose estimation. 
@@ -495,11 +587,11 @@ def __init__(self, scale_min=0.7, scale_max=1.3, angle_min=0, self.probability = probability self.mask_value = mask_value self.input_size = input_size - super(AugmentImageAndPose, self).__init__() + super(Augment6DOF, self).__init__() def call(self, image, boxes, rotation, translation_raw, mask): if np.random.rand() < self.probability: - augmented_data = augment_image_and_pose( + augmented_data = augment_6DOF( image, boxes, rotation, translation_raw, mask, self.scale_min, self.scale_max, self.angle_min, self.angle_max, self.mask_value, self.input_size) @@ -508,16 +600,35 @@ def call(self, image, boxes, rotation, translation_raw, mask): return augmented_data -def augment_image_and_pose(image, boxes, rotation, translation_raw, mask, - scale_min, scale_max, angle_min, angle_max, - mask_value, input_size): +def augment_6DOF(image, boxes, rotation, translation_raw, mask, + scale_min, scale_max, angle_min, angle_max, + mask_value, input_size): + """Performs 6 degree of freedom augmentation of image + and its corresponding poses. + + # Arguments + image: Array raw image. + boxes: Array of shape `(n, 5)` + rotation: Array of shape `(n, 9)` + translation_raw: Array of shape `(n, 3)` + mask: Array mask corresponding to raw image. + scale_min: Float, minimum value to scale image. + scale_max: Float, maximum value to scale image. + angle_min: Int, minimum degree to rotate image. + angle_max: Int, maximum degree to rotate image. + mask_value: Int, pixel gray value of foreground in mask image. + input_size: Int, input image size of the model. + + # Returns: + List: Containing augmented_image, augmented_boxes, + augmented_rotation, augmented_translation, augmented_mask + """ transformation, angle, scale = generate_random_transformation( scale_min, scale_max, angle_min, angle_max) augmented_image = apply_transformation( image, transformation, cv2.INTER_CUBIC) augmented_mask = apply_transformation( mask, transformation, cv2.INTER_NEAREST) - num_annotations = boxes.shape[0] augmented_boxes, is_valid = [], [] rotation_vector = compute_rotation_vector(angle) @@ -558,6 +669,17 @@ def augment_image_and_pose(image, boxes, rotation, translation_raw, mask, def generate_random_transformation(scale_min, scale_max, angle_min, angle_max): + """Generates random affine transformation matrix. + + # Arguments + scale_min: Float, minimum value to scale image. + scale_max: Float, maximum value to scale image. + angle_min: Int, minimum degree to rotate image. + angle_max: Int, maximum degree to rotate image. + + # Returns: + List: Containing transformation matrix, angle, scale + """ cx = LINEMOD_CAMERA_MATRIX[0, 2] cy = LINEMOD_CAMERA_MATRIX[1, 2] angle = np.random.uniform(angle_min, angle_max) @@ -566,37 +688,92 @@ def generate_random_transformation(scale_min, scale_max, def apply_transformation(image, transformation, interpolation): + """Applies random affine to raw image. + + # Arguments + image: Array raw image. + transformation: Array of shape `(2, 3)`. + interpolation: Int, type of pixel interpolation. + + # Returns: + Array: of affine transformed image. + """ H, W, _ = image.shape return cv2.warpAffine(image, transformation, (W, H), flags=interpolation) def compute_box_from_mask(mask, mask_value): - segmentation = np.where(mask == mask_value) - segmentation_x, segmentation_y = segmentation[1], segmentation[0] - if segmentation_x.size <= 0 or segmentation_y.size <= 0: + """Computes bounding box from mask image. + + # Arguments + mask: Array mask corresponding to raw image. 
+ mask_value: Int, pixel gray value of foreground in mask image. + + # Returns: + box: List containing box coordinates. + """ + masked = np.where(mask == mask_value) + mask_x, mask_y = masked[1], masked[0] + if mask_x.size <= 0 or mask_y.size <= 0: box = [0, 0, 0, 0] else: - x_min, y_min = np.min(segmentation_x), np.min(segmentation_y) - x_max, y_max = np.max(segmentation_x), np.max(segmentation_y) + x_min, y_min = np.min(mask_x), np.min(mask_y) + x_max, y_max = np.max(mask_x), np.max(mask_y) box = [x_min, y_min, x_max, y_max] return box def compute_rotation_vector(angle): + """Computes rotation vector that results from rotation + by angle `angle` around Z axis. + + # Arguments + angle: Float, angle of rotation in degrees. + + # Returns: + rotation_vector: Array of shape `(3, )`. + """ rotation_vector = np.zeros((3, )) rotation_vector[2] = angle / 180 * np.pi return rotation_vector def transform_rotation_matrix(rotation_matrix, transformation): + """Computes augmented rotation matrix. + + # Arguments + rotation_matrix: Array, of shape `(3, 3)`. + transformation: Array, of shape `(3, 3)`. + + # Returns: + Array: of shape `(3, 3)`. + """ return np.dot(transformation, rotation_matrix) def transform_translation_vector(translation, transformation): - return np.dot(translation, transformation.T) + """Computes augmented translation vector. + + # Arguments + translation: Array, of shape `(3, )`. + transformation: Array, of shape `(3, 3)`. + + # Returns: + Array: of shape `(3, )`. + """ + return np.dot(transformation, translation.T) def scale_translation_vector(translation, scale): + """Scales translation vector. + + # Arguments + translation: Array, of shape `(3, )`. + scale: Float, scaling factor. + + # Returns: + Array: of shape `(3, )`. + """ translation[2] = translation[2] / scale return translation @@ -620,6 +797,21 @@ def __init__(self): class AutoContrast(Processor): + """Performs autocontrast or automatic contrast enhancement in a + given image. This method achieves this by computing the image + histogram and removing a certain `cutoff` percent from the lighter + and darker part of the histogram and then stretching the histogram + such that the lightest pixel gray value becomes 255 and the darkest + ones become 0. + + # Arguments + probability: Float, probability of data transformation. + + # References: + [Python Pillow autocontrast]( + https://github.com/python-pillow/Pillow/blob/main/src/PIL/ImageOps.py) + """ def __init__(self, probability=0.50): self.probability = probability super(AutoContrast, self).__init__() @@ -631,6 +823,24 @@ def call(self, image): def auto_contrast(image): + """Performs autocontrast or automatic contrast enhancement in a + given image. This method achieves this by computing the image + histogram and removing a certain `cutoff` percent from the lighter + and darker part of the histogram and then stretching the histogram + such that the lightest pixel gray value becomes 255 and the darkest + ones become 0. + + # Arguments + image: Array, raw image. + + # Returns: + contrasted: Array, contrast enhanced image. + + # References: + [Python Pillow autocontrast]( + https://github.com/python-pillow/Pillow/blob/main/src/PIL/ImageOps.py) + """ contrasted = np.empty_like(image) num_channels = image.shape[2] @@ -668,9 +878,10 @@ def auto_contrast(image): class EqualizeHistogram(Processor): - """The paper uses Histogram euqlaization algorithm from PIL. - This version of Histogram equalization produces slightly different - results from that in the paper.
+ """The Efficientpose implementation uses Histogram equalization + algorithm from python Pillow library. This version of Histogram + equalization produces slightly different results from that used in + the paper. """ def __init__(self, probability=0.50): self.probability = probability @@ -683,9 +894,16 @@ def call(self, image): def equalize_histogram(image): + """Performs histogram equalization on a given image. + + # Arguments + image: Array, raw image. + + # Returns: + equalized: Array, histogram equalized image. + """ equalized = np.empty_like(image) num_channels = image.shape[2] - for channel_arg in range(num_channels): image_per_channel = image[:, :, channel_arg] equalized_per_channel = cv2.equalizeHist(image_per_channel) @@ -694,6 +912,11 @@ def equalize_histogram(image): class InvertColors(Processor): + """Performs color / gray value inversion on a given image. + + # Arguments + probability: Float, probability of data transformation. + """ def __init__(self, probability=0.50): self.probability = probability super(InvertColors, self).__init__() @@ -705,11 +928,25 @@ def call(self, image): def invert_colors(image): - image_inverted = 255 - image - return image_inverted + """Performs color / gray value inversion on a given image. + + # Arguments + image: Array, raw image. + + # Returns: + Array: Color inverted image. + """ + return 255 - image class Posterize(Processor): + """Performs posterization on a given image. This is achieved + by reducing the bit depth of the gray value. + + # Arguments + probability: Float, probability of data transformation. + num_bits: Int, final bit depth after posterization. + """ def __init__(self, probability=0.50, num_bits=4): self.probability = probability self.num_bits = num_bits @@ -722,18 +959,30 @@ def call(self, image): def posterize(image, num_bits): - posterized = np.empty_like(image) - num_channels = image.shape[2] - for channel_arg in range(num_channels): - image_per_channel = image[:, :, channel_arg] - scale_factor = 2 ** (8 - num_bits) - posterized_per_channel = np.round(image_per_channel / - scale_factor) * scale_factor - posterized[:, :, channel_arg] = posterized_per_channel.astype(np.uint8) - return posterized + """Performs posterization on a given image. This is achieved + by reducing the bit depth of the gray value. + + # Arguments + image: Array, raw image. + num_bits: Int, final bit depth after posterization. + + # Returns: + Array: Posterized image. + """ + scale_factor = 2 ** (8 - num_bits) + posterized = np.round(image / scale_factor) * scale_factor + return posterized.astype(np.uint8) class Solarize(Processor): + """Performs solarization on a given image. This is achieved + by inverting those pixels whose gray values lie above + a certain `threshold`. + + # Arguments + probability: Float, probability of data transformation. + threshold: Int, threshold value. + """ def __init__(self, probability=0.50, threshold=225): self.probability = probability self.threshold = threshold @@ -746,11 +995,27 @@ def call(self, image): def solarize(image, threshold): - solarized = np.where(image < threshold, image, 255 - image) - return solarized + """Performs solarization on a given image. This is achieved + by inverting those pixels whose gray values lie above + a certain `threshold`. + + # Arguments + probability: Float, probability of data transformation. + threshold: Int, threshold value. + + # Returns: + Array: Solarized image. 
+ """ + return np.where(image < threshold, image, 255 - image) class SharpenImage(Processor): + """Performs image sharpening by applying a high pass filter. + + # Arguments + probability: Float, probability of data transformation. + kernel: Array, the high pass filter. + """ def __init__(self, probability=0.50): self.probability = probability self.kernel = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]]) @@ -758,16 +1023,32 @@ def __init__(self, probability=0.50): def call(self, image): if self.probability > np.random.rand(): - image = sharpen_image(image, self.kernel) + image = convolve_image(image, self.kernel) return image -def sharpen_image(image, kernel): - sharpened = cv2.filter2D(image, -1, kernel) - return sharpened +def convolve_image(image, kernel): + """Convolves image by applying a `kernel`. + + # Arguments + image: Array, raw image. + kernel: Array, the convolution kernel. + + # Returns: + Array: Solarized image. + """ + return cv2.filter2D(image, -1, kernel) class Cutout(Processor): + """Cuts out a square of size `size` x `size` at a random location + in the image and fills it with `fill` value. + + # Arguments + probability: Float, probability of data transformation. + size: Int, size of cutout square. + fill: Int, value to fill cutout with. + """ def __init__(self, probability=0.50, size=16, fill=128): self.probability = probability self.size = size @@ -781,6 +1062,17 @@ def call(self, image): def cutout(image, size, fill): + """Cuts out a square of size `size` x `size` at a random location + in the `image` and fills it with `fill` value. + + # Arguments + image: Array, raw image. + size: Int, size of cutout square. + fill: Int, value to fill cutout with. + + # Returns: + image: Array, cutout image. + """ H, W, _ = image.shape y = np.random.randint(0, H - size) x = np.random.randint(0, W - size) @@ -789,10 +1081,18 @@ def cutout(image, size, fill): class AddGaussianNoise(Processor): - def __init__(self, probability=0.50, mean=0, scale=20): + """Adds Gaussian noise defined by `mean` and `scale` to the image. + + # Arguments + probability: Float, probability of data transformation. + mean: Int, mean of Gaussian noise. + scale: Int, percent of variance relative to 255 + (max gray value of 8 bit image). + """ + def __init__(self, probability=0.50, mean=0, scale=0.20): self.probability = probability self.mean = mean - self.variance = (scale / 100.0) * 255 + self.variance = scale * 255 self.sigma = self.variance ** 0.5 super(AddGaussianNoise, self).__init__() @@ -803,6 +1103,16 @@ def call(self, image): def add_gaussian_noise(image, mean, sigma): + """Adds Gaussian noise defined by `mean` and `scale` to the `image`. + + # Arguments + image: Array, raw image. + mean: Int, mean of Gaussian noise. + sigma: Float, standard deviation of Gaussian noise. + + # Returns: + Array: Image added with Gaussian noise. 
+ """ H, W, num_channels = image.shape noise = np.random.normal(mean, sigma, (H, W, num_channels)) noisy_image = image + noise diff --git a/examples/efficientpose/train.py b/examples/efficientpose/train.py index 4dcfa2369..d0dc76511 100644 --- a/examples/efficientpose/train.py +++ b/examples/efficientpose/train.py @@ -10,7 +10,8 @@ from paz.optimization.callbacks import LearningRateScheduler from paz.processors import TRAIN, VAL from linemod import LINEMOD -from pose import AugmentPose, EFFICIENTPOSEA, EFFICIENTPOSEALINEMODDRILLER +from efficientpose import EFFICIENTPOSEA +from pose import AugmentPose, EFFICIENTPOSEALINEMODDRILLER from losses import MultiPoseLoss from pose_error import EvaluatePoseError @@ -20,8 +21,10 @@ description = 'Training script for single-shot object detection models' parser = argparse.ArgumentParser(description=description) -parser.add_argument('-bs', '--batch_size', default=1, type=int, +parser.add_argument('-bs', '--batch_size', default=16, type=int, help='Batch size for training') +parser.add_argument('-et', '--evaluation_period', default=10, type=int, + help='evaluation frequency') parser.add_argument('-lr', '--learning_rate', default=1e-4, type=float, help='Initial learning rate for SGD') parser.add_argument('-m', '--momentum', default=0.9, type=float, @@ -127,7 +130,8 @@ # Pose accuracy calculation pipeline pose_error = EvaluatePoseError(args.save_path, evaluation_data_managers[0], - inference, mesh_points, object_diameter) + inference, mesh_points, object_diameter, + args.evaluation_period) # training model.fit( diff --git a/paz/models/detection/efficientdet/efficientdet_blocks.py b/paz/models/detection/efficientdet/efficientdet_blocks.py index ee8c372fe..eba79a03c 100644 --- a/paz/models/detection/efficientdet/efficientdet_blocks.py +++ b/paz/models/detection/efficientdet/efficientdet_blocks.py @@ -4,7 +4,7 @@ from tensorflow.keras.layers import Activation, Concatenate, Reshape from tensorflow.keras.layers import (BatchNormalization, Conv2D, Flatten, MaxPooling2D, SeparableConv2D, - UpSampling2D) + UpSampling2D, GroupNormalization) from .layers import FuseFeature, GetDropConnect @@ -32,8 +32,10 @@ def build_detector_head(middles, num_classes, num_dims, aspect_ratios, num_anchors = len(aspect_ratios) * num_scales args = (middles, num_anchors, FPN_num_filters, box_class_repeats, survival_rate) - class_outputs = ClassNet(*args, num_classes) - boxes_outputs = BoxesNet(*args, num_dims) + _, class_outputs = ClassNet(*args, num_classes) + class_outputs = [Flatten()(class_output) for class_output in class_outputs] + _, boxes_outputs = BoxesNet(*args, num_dims) + boxes_outputs = [Flatten()(boxes_output) for boxes_output in boxes_outputs] classes = Concatenate(axis=1)(class_outputs) regressions = Concatenate(axis=1)(boxes_outputs) num_boxes = K.int_shape(regressions)[-1] // num_dims @@ -61,8 +63,8 @@ def ClassNet(features, num_anchors=9, num_filters=32, num_blocks=4, """ bias_initializer = tf.constant_initializer(-np.log((1 - 0.01) / 0.01)) num_filters = [num_filters, num_classes * num_anchors] - return build_head(features, num_blocks, num_filters, survival_rate, - bias_initializer) + return build_head(features, num_blocks, num_filters, + bias_initializer, survival_rate) def BoxesNet(features, num_anchors=9, num_filters=32, num_blocks=4, @@ -82,20 +84,20 @@ def BoxesNet(features, num_anchors=9, num_filters=32, num_blocks=4, """ bias_initializer = tf.zeros_initializer() num_filters = [num_filters, num_dims * num_anchors] - return build_head(features, num_blocks, 
diff --git a/tests/paz/models/detection/efficientdet/efficientdet_test.py b/tests/paz/models/detection/efficientdet/efficientdet_test.py
index 214273bb5..ad59eac25 100644
--- a/tests/paz/models/detection/efficientdet/efficientdet_test.py
+++ b/tests/paz/models/detection/efficientdet/efficientdet_test.py
@@ -1,7 +1,7 @@
 import pytest
 import numpy as np
 import tensorflow as tf
-from tensorflow.keras.layers import Input
+from tensorflow.keras.layers import Input, Flatten
 from tensorflow.keras.utils import get_file
 from paz.models.detection.efficientdet import (
     EFFICIENTDETD0, EFFICIENTDETD1, EFFICIENTDETD2, EFFICIENTDETD3,
@@ -354,7 +354,8 @@ def test_EfficientDet_ClassNet(input_shape, scaling_coefficients,
     num_anchors = len(aspect_ratios) * num_scales
     args = (middles, num_anchors, FPN_num_filters, box_class_repeats,
             survival_rate)
-    class_outputs = ClassNet(*args, num_classes)
+    _, class_outputs = ClassNet(*args, num_classes)
+    class_outputs = [Flatten()(class_output) for class_output in class_outputs]
     assert len(class_outputs) == 5, 'Class outputs length fail'
     for class_output, output_shape in zip(class_outputs, output_shapes):
         assert class_output.shape == (None, output_shape), (
@@ -400,7 +401,8 @@ def test_EfficientDet_BoxesNet(input_shape, scaling_coefficients,
     num_anchors = len(aspect_ratios) * num_scales
     args = (middles, num_anchors, FPN_num_filters, box_class_repeats,
             survival_rate)
-    boxes_outputs = BoxesNet(*args, num_dims)
+    _, boxes_outputs = BoxesNet(*args, num_dims)
+    boxes_outputs = [Flatten()(boxes_output) for boxes_output in boxes_outputs]
     assert len(boxes_outputs) == 5
     for boxes_output, output_shape in zip(boxes_outputs, output_shapes):
         assert boxes_output.shape == (None, output_shape), (
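The updated tests reflect the new contract: ClassNet and BoxesNet now return a pair of per-level lists (pre-head features and raw head outputs), and callers flatten each pyramid level themselves before concatenating. A sketch of that shape bookkeeping with illustrative tensors for a 5-level pyramid (36 channels standing in for 9 anchors times 4 box dimensions):

import tensorflow as tf
from tensorflow.keras.layers import Flatten

# Illustrative per-level head outputs; real ones come from BoxesNet.
head_outputs = [tf.random.normal((1, size, size, 36))
                for size in (64, 32, 16, 8, 4)]

# Callers now flatten each pyramid level before concatenating.
flattened = [Flatten()(output) for output in head_outputs]
print([tuple(out.shape) for out in flattened])
# [(1, 147456), (1, 36864), (1, 9216), (1, 2304), (1, 576)]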