This repository was archived by the owner on Feb 3, 2025. It is now read-only.

Commit 6cf16c4

Author: DEKHTIARJonathan (committed)

[Benchmarking-Py] Adding TF Model - SpineNet49 Mobile

1 parent a337598 commit 6cf16c4

File tree

7 files changed: +439 -1 lines changed

tftrt/benchmarking-python/tf_hub/albert/base_run_inference.sh

+1 -1

@@ -1,4 +1,4 @@
-#!/bin/#!/usr/bin/env bash
+#!/usr/bin/env bash
 
 nvidia-smi
 
@@ -0,0 +1,60 @@
# --experiment_type=retinanet_mobile_coco
runtime:
  distribution_strategy: 'tpu'
  mixed_precision_dtype: 'bfloat16'
task:
  losses:
    l2_weight_decay: 3.0e-05
  model:
    anchor:
      anchor_size: 3
      aspect_ratios: [0.5, 1.0, 2.0]
      num_scales: 3
    backbone:
      spinenet_mobile:
        stochastic_depth_drop_rate: 0.2
        model_id: '49'
        se_ratio: 0.2
      type: 'spinenet_mobile'
    decoder:
      type: 'identity'
    head:
      num_convs: 4
      num_filters: 48
      use_separable_conv: true
    input_size: [384, 384, 3]
    max_level: 7
    min_level: 3
    norm_activation:
      activation: 'swish'
      norm_epsilon: 0.001
      norm_momentum: 0.99
      use_sync_bn: true
  train_data:
    dtype: 'bfloat16'
    global_batch_size: 256
    is_training: true
    parser:
      aug_rand_hflip: true
      aug_scale_max: 2.0
      aug_scale_min: 0.5
  validation_data:
    dtype: 'bfloat16'
    global_batch_size: 8
    is_training: false
trainer:
  checkpoint_interval: 462
  optimizer_config:
    learning_rate:
      stepwise:
        boundaries: [263340, 272580]
        values: [0.32, 0.032, 0.0032]
      type: 'stepwise'
    warmup:
      linear:
        warmup_learning_rate: 0.0067
        warmup_steps: 2000
  steps_per_loop: 462
  train_steps: 277200
  validation_interval: 462
  validation_steps: 625
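
This is the stock TF Model Garden RetinaNet / SpineNet-49 Mobile TPU experiment config; the export script further down fetches it as coco_spinenet49_mobile_tpu_fp16.yaml. As a quick sanity check before exporting, a minimal sketch (assuming PyYAML is installed and the file sits in the working directory) that confirms the fields the benchmark relies on:

# Hypothetical sanity check, not part of this commit: inspect the downloaded config.
import yaml

with open("coco_spinenet49_mobile_tpu_fp16.yaml") as f:
    cfg = yaml.safe_load(f)

# The benchmark feeds 384x384x3 uint8 images, so input_size must agree with --input_size.
print("input_size:", cfg["task"]["model"]["input_size"])      # [384, 384, 3]
print("backbone:", cfg["task"]["model"]["backbone"]["type"])  # 'spinenet_mobile'
print("dtype:", cfg["task"]["train_data"]["dtype"])           # 'bfloat16' in this variant
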
@@ -0,0 +1,60 @@
# --experiment_type=retinanet_mobile_coco
runtime:
  distribution_strategy: 'tpu'
  mixed_precision_dtype: 'float32'
task:
  losses:
    l2_weight_decay: 3.0e-05
  model:
    anchor:
      anchor_size: 3
      aspect_ratios: [0.5, 1.0, 2.0]
      num_scales: 3
    backbone:
      spinenet_mobile:
        stochastic_depth_drop_rate: 0.2
        model_id: '49'
        se_ratio: 0.2
      type: 'spinenet_mobile'
    decoder:
      type: 'identity'
    head:
      num_convs: 4
      num_filters: 48
      use_separable_conv: true
    input_size: [384, 384, 3]
    max_level: 7
    min_level: 3
    norm_activation:
      activation: 'swish'
      norm_epsilon: 0.001
      norm_momentum: 0.99
      use_sync_bn: true
  train_data:
    dtype: 'float32'
    global_batch_size: 256
    is_training: true
    parser:
      aug_rand_hflip: true
      aug_scale_max: 2.0
      aug_scale_min: 0.5
  validation_data:
    dtype: 'float32'
    global_batch_size: 8
    is_training: false
trainer:
  checkpoint_interval: 462
  optimizer_config:
    learning_rate:
      stepwise:
        boundaries: [263340, 272580]
        values: [0.32, 0.032, 0.0032]
      type: 'stepwise'
    warmup:
      linear:
        warmup_learning_rate: 0.0067
        warmup_steps: 2000
  steps_per_loop: 462
  train_steps: 277200
  validation_interval: 462
  validation_steps: 625
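
The float32 variant is identical to the bfloat16 config except for the mixed_precision_dtype and dtype fields, which is exactly what the sed 's/bfloat16/float32/g' step in the export script below produces. A small sketch (assuming both files exist locally under the names used by that script) to verify nothing else differs:

# Hypothetical check, not part of this commit: diff the two configs.
import difflib

with open("coco_spinenet49_mobile_tpu_fp16.yaml") as f16, \
     open("coco_spinenet49_mobile_tpu_fp32.yaml") as f32:
    diff = list(difflib.unified_diff(f16.readlines(), f32.readlines(), n=0))

# Keep only changed lines; only the dtype-related lines should show up.
changed = [l for l in diff if l[:1] in "+-" and not l.startswith(("+++", "---"))]
print("".join(changed))
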
@@ -0,0 +1,145 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import math
import os
import sys

import numpy as np
import tensorflow as tf

# Allow import of top level python files
import inspect

currentdir = os.path.dirname(
    os.path.abspath(inspect.getfile(inspect.currentframe()))
)
parentdir = os.path.dirname(currentdir)
parentdir = os.path.dirname(parentdir)

sys.path.insert(0, parentdir)

from benchmark_args import BaseCommandLineAPI
from benchmark_runner import BaseBenchmarkRunner


class CommandLineAPI(BaseCommandLineAPI):

    def __init__(self):
        super(CommandLineAPI, self).__init__()

        self._parser.add_argument(
            '--input_size',
            type=int,
            default=384,
            help='Size of input images expected by the model'
        )

    def _validate_args(self, args):
        super(CommandLineAPI, self)._validate_args(args)

        # TODO: Remove when proper dataloading is implemented.
        if not args.use_synthetic_data:
            raise ValueError(
                "This benchmark does not currently support non-synthetic data "
                "--use_synthetic_data"
            )

        # This model requires that the batch size is 1.
        if args.batch_size != 1:
            raise ValueError(
                "This benchmark does not currently support "
                "--batch_size != 1"
            )


class BenchmarkRunner(BaseBenchmarkRunner):

    def get_dataset_batches(self):
        """Returns a list of batches of input samples.

        Each batch should be in the form [x, y], where
        x is a numpy array of the input samples for the batch, and
        y is a numpy array of the expected model outputs for the batch.

        Returns:
        - dataset: a TF Dataset object
        - bypass_data_to_eval: any object type that will be passed unmodified to
          `evaluate_result()`. If not necessary: `None`

        Note: script arguments can be accessed using `self._args.attr`
        """

        tf.random.set_seed(10)

        inputs = tf.random.uniform(
            shape=(1, self._args.input_size, self._args.input_size, 3),
            maxval=255,
            dtype=tf.int32
        )

        dataset = tf.data.Dataset.from_tensor_slices(inputs)

        dataset = dataset.map(
            lambda x: {"inputs": tf.cast(x, tf.uint8)},
            num_parallel_calls=tf.data.AUTOTUNE
        )

        dataset = dataset.repeat()
        dataset = dataset.batch(self._args.batch_size)
        dataset = dataset.prefetch(tf.data.AUTOTUNE)

        return dataset, None

    def preprocess_model_inputs(self, data_batch):
        """This function prepares the `data_batch` generated from the dataset.

        Returns:
        x: input of the model
        y: data to be used for model evaluation

        Note: script arguments can be accessed using `self._args.attr`
        """

        return data_batch, None

    def postprocess_model_outputs(self, predictions, expected):
        """Post process if needed the predictions and expected tensors. At the
        minimum, this function transforms all TF Tensors into numpy arrays.
        Most models will not need to modify this function.

        Note: script arguments can be accessed using `self._args.attr`
        """

        # NOTE: DO NOT MODIFY FOR NOW => We do not measure accuracy right now.

        return predictions.numpy(), expected.numpy()

    def evaluate_model(self, predictions, expected, bypass_data_to_eval):
        """Evaluate result predictions for entire dataset.

        This computes overall accuracy, mAP, etc. Returns the
        metric value and a metric_units string naming the metric.

        Note: script arguments can be accessed using `self._args.attr`
        """
        return None, "Raw Pitch Accuracy"


if __name__ == '__main__':

    cmdline_api = CommandLineAPI()
    args = cmdline_api.parse_args()

    runner = BenchmarkRunner(args)
    runner.execute_benchmark()
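
The benchmark runs entirely on synthetic data: get_dataset_batches() draws one random image, casts it to uint8, and repeats it forever. A standalone sketch of that pipeline (the "inputs" key is carried over from the mapping above and is assumed to match the SavedModel serving signature) shows what each batch looks like:

# Hypothetical standalone reproduction of the synthetic input pipeline above.
import tensorflow as tf

tf.random.set_seed(10)
images = tf.random.uniform(shape=(1, 384, 384, 3), maxval=255, dtype=tf.int32)

dataset = (
    tf.data.Dataset.from_tensor_slices(images)
    .map(lambda x: {"inputs": tf.cast(x, tf.uint8)},
         num_parallel_calls=tf.data.AUTOTUNE)
    .repeat()
    .batch(1)
    .prefetch(tf.data.AUTOTUNE)
)

batch = next(iter(dataset))
print(batch["inputs"].shape, batch["inputs"].dtype)  # (1, 384, 384, 3) tf.uint8
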
@@ -0,0 +1,35 @@
#!/usr/bin/env bash

pip install tf-models-official==2.9.2

wget https://raw.githubusercontent.com/tensorflow/models/v2.9.2/official/vision/configs/experiments/retinanet/coco_spinenet49_mobile_tpu.yaml \
    -O coco_spinenet49_mobile_tpu_fp16.yaml

sed 's/bfloat16/float32/g' coco_spinenet49_mobile_tpu_fp16.yaml > coco_spinenet49_mobile_tpu_fp32.yaml

BATCH_SIZES=(
    "1"
    "8"
    "16"
    "32"
    "64"
    "128"
)

MODEL_DIR="/models/tf_models/spinetnet49_mobile"

for batch_size in "${BATCH_SIZES[@]}"; do

    python -m official.vision.serving.export_saved_model \
        --experiment="retinanet_mobile_coco" \
        --checkpoint_path="${MODEL_DIR}/checkpoint/" \
        --config_file="coco_spinenet49_mobile_tpu_fp32.yaml" \
        --export_dir="${MODEL_DIR}/" \
        --export_saved_model_subdir="saved_model_bs${batch_size}" \
        --input_image_size=384,384 \
        --batch_size="${batch_size}"

    saved_model_cli show --dir "${MODEL_DIR}/saved_model_bs${batch_size}/" --all 2>&1 \
        | tee "${MODEL_DIR}/saved_model_bs${batch_size}/analysis.txt"

done
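
Each exported SavedModel can be smoke-tested outside the benchmark harness before running the full sweep. A minimal sketch, assuming the bs1 export above succeeded; the "serving_default" signature name and the "inputs" key are assumptions here (the saved_model_cli output captured in analysis.txt gives the authoritative names):

# Hypothetical smoke test, not part of this commit.
import tensorflow as tf

saved_model_dir = "/models/tf_models/spinetnet49_mobile/saved_model_bs1"  # path from the script above
model = tf.saved_model.load(saved_model_dir)
infer = model.signatures["serving_default"]

images = tf.zeros([1, 384, 384, 3], dtype=tf.uint8)  # matches --input_image_size and batch size 1
outputs = infer(inputs=images)

for name, tensor in outputs.items():
    print(name, tensor.shape)
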
@@ -0,0 +1,48 @@
#!/bin/bash

SCRIPT_DIR=""

EXPERIMENT_NAME="spinetnet49_mobile"

BASE_BENCHMARK_DATA_EXPORT_DIR="/workspace/benchmark_data/${EXPERIMENT_NAME}"
rm -rf ${BASE_BENCHMARK_DATA_EXPORT_DIR}
mkdir -p ${BASE_BENCHMARK_DATA_EXPORT_DIR}

# EXPERIMENT_FLAG="--experiment_name=${EXPERIMENT_NAME} --upload_metrics_endpoint=http://10.31.241.12:5000/record_metrics/"
EXPERIMENT_FLAG=""

#########################

BASE_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"

BENCHMARK_DATA_EXPORT_DIR="${BASE_BENCHMARK_DATA_EXPORT_DIR}/tf_models/"
mkdir -p ${BENCHMARK_DATA_EXPORT_DIR}

model_name="spinetnet49_mobile"

RUN_ARGS="${EXPERIMENT_FLAG} --data_dir=/tmp --input_saved_model_dir=/models/tf_models/${model_name}/saved_model_bs1/ "
RUN_ARGS="${RUN_ARGS} --debug --batch_size=1 --display_every=5 --use_synthetic_data --num_warmup_iterations=200 --num_iterations=500"
TF_TRT_ARGS="--use_tftrt --use_dynamic_shape --num_calib_batches=10"
TF_XLA_ARGS="--use_xla_auto_jit"

export TF_TRT_SHOW_DETAILED_REPORT=1
# export TF_TRT_BENCHMARK_EARLY_QUIT=1

MODEL_DATA_EXPORT_DIR="${BENCHMARK_DATA_EXPORT_DIR}/${model_name}"
mkdir -p ${MODEL_DATA_EXPORT_DIR}

SCRIPT_PATH="${BASE_DIR}/run_inference.sh"
METRICS_JSON_FLAG="--export_metrics_json_path=${MODEL_DATA_EXPORT_DIR}"

# TF Native
script -q -c "${SCRIPT_PATH} ${RUN_ARGS} --precision=FP32" /dev/null | tee ${MODEL_DATA_EXPORT_DIR}/inference_tf_fp32.log
script -q -c "${SCRIPT_PATH} ${RUN_ARGS} --precision=FP16" /dev/null | tee ${MODEL_DATA_EXPORT_DIR}/inference_tf_fp16.log

# TF-XLA manual
script -q -c "${SCRIPT_PATH} ${RUN_ARGS} ${TF_XLA_ARGS} --precision=FP32" /dev/null | tee ${MODEL_DATA_EXPORT_DIR}/inference_tfxla_fp32.log
script -q -c "${SCRIPT_PATH} ${RUN_ARGS} ${TF_XLA_ARGS} --precision=FP16" /dev/null | tee ${MODEL_DATA_EXPORT_DIR}/inference_tfxla_fp16.log

# TF-TRT
script -q -c "TF_TRT_EXPORT_GRAPH_VIZ_PATH=${MODEL_DATA_EXPORT_DIR}/tftrt_fp32.dot ${SCRIPT_PATH} ${RUN_ARGS} ${TF_TRT_ARGS} --precision=FP32" /dev/null | tee ${MODEL_DATA_EXPORT_DIR}/inference_tftrt_fp32.log
script -q -c "TF_TRT_EXPORT_GRAPH_VIZ_PATH=${MODEL_DATA_EXPORT_DIR}/tftrt_fp16.dot ${SCRIPT_PATH} ${RUN_ARGS} ${TF_TRT_ARGS} --precision=FP16" /dev/null | tee ${MODEL_DATA_EXPORT_DIR}/inference_tftrt_fp16.log
script -q -c "TF_TRT_EXPORT_GRAPH_VIZ_PATH=${MODEL_DATA_EXPORT_DIR}/tftrt_int8.dot ${SCRIPT_PATH} ${RUN_ARGS} ${TF_TRT_ARGS} --precision=INT8" /dev/null | tee ${MODEL_DATA_EXPORT_DIR}/inference_tftrt_int8.log
