From 9c05dce41b932b319171e0f940cbf6280ba284bc Mon Sep 17 00:00:00 2001 From: Trevor Morris <trevmorr@amazon.com> Date: Wed, 17 Jun 2020 09:38:24 -0700 Subject: [PATCH] Backwards compatibile with numpy < 1.16.1. Add trt doc to index (#198) * Make backwards compatible with numpy versions < 1.16.1. Add trt doc to index * Use dict to convert dtype to ctype * Remove quotes around dtype in error message * Move dtype to ctype map to DLRModel class * Move _get_ctype_from_dtype to global scope --- doc/index.rst | 1 + python/dlr/dlr_model.py | 32 +++++++++++++++++++++++++++++--- 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/doc/index.rst b/doc/index.rst index 3d52dc2c5..978d2dfee 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -13,4 +13,5 @@ Contents install python-api c-api + tensorrt Internal docs <http://neo-ai-dlr.readthedocs.io/en/latest/dev/> diff --git a/python/dlr/dlr_model.py b/python/dlr/dlr_model.py index 1b3c47af2..24b1cada3 100644 --- a/python/dlr/dlr_model.py +++ b/python/dlr/dlr_model.py @@ -8,6 +8,32 @@ from .libpath import find_lib_path +# Map from dtype string to ctype type. +# Equivalent to np.ctypeslib.as_ctypes_type which requires numpy>=1.16.1 +DTYPE_TO_CTYPE = { + "float32": ctypes.c_float, + "float64": ctypes.c_double, + "uint8": ctypes.c_ubyte, + "uint32": ctypes.c_uint, + "uint64": ctypes.c_ulong, + "int8": ctypes.c_byte, + "int32": ctypes.c_int, + "int64": ctypes.c_long, +} + +def _get_ctype_from_dtype(dtype): + """ + Convert type string to ctype type. + + Parameters + ---------- + dtype: str + Type as a string, e.g. "float32". + """ + if dtype not in DTYPE_TO_CTYPE: + raise ValueError("Model has input or output datatype {} which is not supported.".format(dtype)) + return DTYPE_TO_CTYPE[dtype] + class DLRError(Exception): """Error thrown by DLR""" pass @@ -299,7 +325,7 @@ def _set_input(self, name, data): The data to be set. """ input_dtype = self._get_input_or_weight_dtype_by_name(name) - input_ctype = np.ctypeslib.as_ctypes_type(input_dtype) + input_ctype = _get_ctype_from_dtype(input_dtype) # float32 inputs can accept any data (backward compatibility). if input_dtype == "float32": type_match = True @@ -392,7 +418,7 @@ def _get_output(self, index): raise ValueError("index is expected between 0 and " "len(output_shapes)-1, but got %d" % index) output_dtype = self.get_output_dtype(index) - output_ctype = np.ctypeslib.as_ctypes_type(output_dtype) + output_ctype = _get_ctype_from_dtype(output_dtype) output = np.zeros(self.output_size_dim[index][0], dtype=output_dtype) _check_call(_LIB.GetDLROutput(byref(self.handle), c_int(index), output.ctypes.data_as(ctypes.POINTER(output_ctype)))) @@ -462,7 +488,7 @@ def get_input(self, name, shape=None): 'input {}, we cannot infer its shape. '.format(name) + 'Shape parameter should be explicitly specified') input_dtype = self._get_input_or_weight_dtype_by_name(name) - input_ctype = np.ctypeslib.as_ctypes_type(input_dtype) + input_ctype = _get_ctype_from_dtype(input_dtype) if shape is None: shape = self.input_shapes[name] shape = np.array(shape)