Skip to content
This repository was archived by the owner on Feb 3, 2025. It is now read-only.

Commit 951f264

Browse files
Formatting applied
1 parent 65ed25c commit 951f264

13 files changed

+771
-357
lines changed

apply_py_formatting.sh

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
yapf --in-place --recursive --parallel --exclude 'tftrt/blog_posts/**/*' .

perflab_tftrt.sh

+115
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
export TF_TRT_SHOW_DETAILED_REPORT=1
2+
3+
LOGDIR="/workspace/bench_logs"
4+
mkdir -p ${LOGDIR}/
5+
6+
EXEC_MODES=(
7+
"TF_NATIVE"
8+
"TFTRT_FP32"
9+
"TFTRT_FP16"
10+
"TFTRT_INT8"
11+
)
12+
13+
IMAGE_CLASSIFICATION_MODELS=(
14+
# # Historical Models
15+
# "inception_v3"
16+
# "inception_v4"
17+
# "mobilenet_v1"
18+
# "mobilenet_v2"
19+
# "nasnet_large"
20+
# "nasnet_mobile"
21+
# "resnet_v1.5_50_tfv2"
22+
# "resnet_v1_50"
23+
# "resnet_v2_50"
24+
# "vgg_16"
25+
# "vgg_19"
26+
# # JOC Model
27+
# "resnet50-v1.5_tf1_ngc"
28+
# # Waymo Models
29+
"resnet50_v2_backbone"
30+
# "resnet50_v2_sparse_backbone"
31+
)
32+
33+
OBJECT_DETECTION_MODELS=(
34+
# "faster_rcnn_resnet50_coco"
35+
# "ssd_mobilenet_v1_fpn_coco"
36+
"ssd_resnet_50_fpn_coco"
37+
# "ssd_inception_v2_coco"
38+
# "ssd_mobilenet_v1_coco"
39+
# "ssd_mobilenet_v2_coco"
40+
# "ssdlite_mobilenet_v2_coco"
41+
)
42+
43+
TRANSFORMER_MODELS=(
44+
# "bart_base"
45+
# "bert_base_cased"
46+
"bert_base_uncased"
47+
)
48+
49+
COMMON_BENCH_FLAGS="--debug --use_synthetic_data --num_iterations=800"
50+
51+
for EXEC_MODE in "${EXEC_MODES[@]}"; do
52+
53+
if [[ ${exec_mode} == "TF_NATIVE" ]]; then
54+
ADDITIONAL_ARGUMENTS=""
55+
JOBNAME="tf_native"
56+
else
57+
TFTRT_PRECISION=${EXEC_MODE#*_}
58+
ADDITIONAL_ARGUMENTS="--use_tftrt --precision=${TFTRT_PRECISION}"
59+
60+
if [[ ${EXEC_MODE} != "TFTRT_INT8" ]]; then
61+
ADDITIONAL_ARGUMENTS="${ADDITIONAL_ARGUMENTS} --use_dynamic_shape"
62+
fi
63+
JOBNAME="tftrt_${TFTRT_PRECISION}"
64+
fi
65+
66+
# ========================= IMAGE CLASSIFICATION ========================= #
67+
68+
cd /workspace/tftrt/examples/image_classification/
69+
70+
RUN_ARGS="--data_dir=/data/imagenet --input_saved_model_dir=/models/image_classification"
71+
RUN_ARGS="${RUN_ARGS} ${COMMON_BENCH_FLAGS} --batch_size=128"
72+
73+
if [[ ${EXEC_MODE} == "TFTRT_INT8" ]]; then
74+
RUN_ARGS="${RUN_ARGS} --num_calib_inputs=1280"
75+
fi
76+
77+
for model_name in "${IMAGE_CLASSIFICATION_MODELS[@]}"; do
78+
script -q -c "./scripts/${model_name}.sh ${RUN_ARGS} ${ADDITIONAL_ARGUMENTS}" /dev/null | tee "${LOGDIR}/${JOBNAME}_${model_name}.log"
79+
done
80+
81+
# ========================= OBJECT DETECTION ========================= #
82+
83+
cd /workspace/tftrt/examples/object_detection/
84+
85+
RUN_ARGS="--data_dir=/data/coco2017 --input_saved_model_dir=/models/object_detection"
86+
RUN_ARGS="${RUN_ARGS} ${COMMON_BENCH_FLAGS} --batch_size=8"
87+
88+
if [[ ${EXEC_MODE} == "TFTRT_INT8" ]]; then
89+
RUN_ARGS="${RUN_ARGS} --num_calib_inputs=80"
90+
fi
91+
92+
for model_name in "${OBJECT_DETECTION_MODELS[@]}"; do
93+
script -q -c "./scripts/${model_name}.sh ${RUN_ARGS} ${ADDITIONAL_ARGUMENTS}" /dev/null | tee "${LOGDIR}/${JOBNAME}_${model_name}.log"
94+
done
95+
96+
# ========================= TRANSFORMERS ========================= #
97+
98+
if [[ ${EXEC_MODE} != "TFTRT_INT8" ]]; then
99+
100+
cd /workspace/tftrt/examples/transformers/
101+
102+
RUN_ARGS="--input_saved_model_dir=/models/transformers"
103+
RUN_ARGS="${RUN_ARGS} ${COMMON_BENCH_FLAGS} --batch_size=32"
104+
105+
if [[ "$EXEC_MODE" =~ ^TFTRT_.* ]]; then
106+
RUN_ARGS="${RUN_ARGS} --minimum_segment_size=20"
107+
fi
108+
109+
for model_name in "${TRANSFORMER_MODELS[@]}"; do
110+
script -q -c "./scripts/${model_name}.sh ${RUN_ARGS} ${ADDITIONAL_ARGUMENTS}" /dev/null | tee "${LOGDIR}/${JOBNAME}_${model_name}.log"
111+
done
112+
113+
fi
114+
115+
done

setup.cfg

+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
[tool:pytest]
2+
minversion = 6.0
3+
addopts = -ra -q -s
4+
testpaths =
5+
tests
6+
7+
[yapf]
8+
based_on_style = google
9+
10+
# The number of columns to use for indentation.
11+
indent_width = 4
12+
13+
# The column limit.
14+
column_limit = 80
15+
16+
# Place each dictionary entry onto its own line.
17+
each_dict_entry_on_separate_line = True
18+
19+
# Put closing brackets on a separate line, dedented, if the bracketed
20+
# expression can't fit in a single line. Applies to all kinds of brackets,
21+
# including function definitions and calls. For example:
22+
#
23+
# config = {
24+
# 'key1': 'value1',
25+
# 'key2': 'value2',
26+
# } # <--- this bracket is dedented and on a separate line
27+
#
28+
# time_series = self.remote_client.query_entity_counters(
29+
# entity='dev3246.region1',
30+
# key='dns.query_latency_tcp',
31+
# transform=Transformation.AVERAGE(window=timedelta(seconds=60)),
32+
# start_ts=now()-timedelta(days=3),
33+
# end_ts=now(),
34+
# ) # <--- this bracket is dedented and on a separate line
35+
dedent_closing_brackets=True

tests/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-

tests/test_yapf_format.py

+94
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
4+
import sys
5+
import unittest
6+
7+
import pygments
8+
from pygments import console
9+
10+
from tests.utils import list_all_py_files
11+
from tests.utils import CustomTestCase
12+
13+
from yapf.yapflib.yapf_api import FormatCode
14+
15+
16+
def _read_utf_8_file(filename):
17+
if sys.version_info.major == 2: ## Python 2 specific
18+
with open(filename, 'rb') as f:
19+
return unicode(f.read(), 'utf-8')
20+
else:
21+
with open(filename, encoding='utf-8') as f:
22+
return f.read()
23+
24+
25+
def print_color(msg, color):
26+
print(pygments.console.colorize(color, msg))
27+
28+
29+
class YAPF_Style_Test(unittest.TestCase):
30+
31+
@classmethod
32+
def setUpClass(cls):
33+
34+
cls.badly_formatted_files = list()
35+
cls.files_2_test = list_all_py_files()
36+
37+
def test_files_format(self):
38+
39+
total_analyzed_files = 0
40+
for file in list_all_py_files():
41+
42+
total_analyzed_files += 1
43+
44+
try:
45+
46+
print(f"Testing: {file:100s}", end="")
47+
code = _read_utf_8_file(file)
48+
49+
# https://pypi.python.org/pypi/yapf/0.20.2#example-as-a-module
50+
diff, changed = FormatCode(
51+
code,
52+
filename=file,
53+
style_config='setup.cfg',
54+
print_diff=True
55+
)
56+
57+
if changed:
58+
print_color("FAILURE", "red")
59+
self.badly_formatted_files.append(file)
60+
else:
61+
print_color("SUCCESS", "green")
62+
63+
except Exception as e:
64+
print_color("FAILURE", "red")("FAILURE")
65+
print(
66+
"Error while processing file: `%s`\n"
67+
"Error: %s" % (file, str(e))
68+
)
69+
70+
str_err = ""
71+
72+
if self.badly_formatted_files:
73+
for filename in self.badly_formatted_files:
74+
str_err += f"yapf -i --style=setup.cfg {filename}\n"
75+
76+
str_err = "\n======================================================================================\n" \
77+
f"Bad Coding Style: {len(self.badly_formatted_files)} file(s) need to be formatted, run the following commands to fix: \n" \
78+
f"{str_err}" \
79+
"======================================================================================"
80+
81+
passing_files = total_analyzed_files - len(self.badly_formatted_files)
82+
print_color(
83+
f"\nPASSING: {passing_files} / {total_analyzed_files}",
84+
"green" if str_err == "" else "red"
85+
)
86+
87+
if str_err != "":
88+
print_color(str_err, "red")
89+
90+
self.assertEqual(str_err, "")
91+
92+
93+
if __name__ == '__main__':
94+
unittest.main()

tests/utils.py

+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
4+
import os
5+
import unittest
6+
7+
from contextlib import contextmanager
8+
from glob import glob, iglob
9+
10+
__all__ = [
11+
'CustomTestCase',
12+
'list_all_py_files',
13+
]
14+
15+
16+
class CustomTestCase(unittest.TestCase):
17+
18+
@contextmanager
19+
def assertNotRaises(self, exc_type):
20+
try:
21+
yield None
22+
except exc_type:
23+
raise self.failureException('{} raised'.format(exc_type.__name__))
24+
25+
26+
_excludes_paths = ["tftrt/blog_posts/", "tftrt/examples/third_party"]
27+
28+
29+
def list_all_py_files():
30+
for _dir in ['tests', 'tftrt']:
31+
for _file in iglob(f"{_dir}/**/*.py", recursive=True):
32+
if any([path in _file for path in _excludes_paths]):
33+
continue
34+
yield _file

0 commit comments

Comments
 (0)