Skip to content

Commit

Permalink
Backwards compatibile with numpy < 1.16.1. Add trt doc to index (#198)
Browse files Browse the repository at this point in the history
* Make backwards compatible with numpy versions < 1.16.1. Add trt doc to index

* Use dict to convert dtype to ctype

* Remove quotes around dtype in error message

* Move dtype to ctype map to DLRModel class

* Move _get_ctype_from_dtype to global scope
  • Loading branch information
Trevor Morris authored Jun 17, 2020
1 parent 06b1cc3 commit 9c05dce
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 3 deletions.
1 change: 1 addition & 0 deletions doc/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ Contents
install
python-api
c-api
tensorrt
Internal docs <http://neo-ai-dlr.readthedocs.io/en/latest/dev/>
32 changes: 29 additions & 3 deletions python/dlr/dlr_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,32 @@

from .libpath import find_lib_path

# Map from dtype string to ctype type.
# Equivalent to np.ctypeslib.as_ctypes_type which requires numpy>=1.16.1
DTYPE_TO_CTYPE = {
"float32": ctypes.c_float,
"float64": ctypes.c_double,
"uint8": ctypes.c_ubyte,
"uint32": ctypes.c_uint,
"uint64": ctypes.c_ulong,
"int8": ctypes.c_byte,
"int32": ctypes.c_int,
"int64": ctypes.c_long,
}

def _get_ctype_from_dtype(dtype):
"""
Convert type string to ctype type.
Parameters
----------
dtype: str
Type as a string, e.g. "float32".
"""
if dtype not in DTYPE_TO_CTYPE:
raise ValueError("Model has input or output datatype {} which is not supported.".format(dtype))
return DTYPE_TO_CTYPE[dtype]

class DLRError(Exception):
"""Error thrown by DLR"""
pass
Expand Down Expand Up @@ -299,7 +325,7 @@ def _set_input(self, name, data):
The data to be set.
"""
input_dtype = self._get_input_or_weight_dtype_by_name(name)
input_ctype = np.ctypeslib.as_ctypes_type(input_dtype)
input_ctype = _get_ctype_from_dtype(input_dtype)
# float32 inputs can accept any data (backward compatibility).
if input_dtype == "float32":
type_match = True
Expand Down Expand Up @@ -392,7 +418,7 @@ def _get_output(self, index):
raise ValueError("index is expected between 0 and "
"len(output_shapes)-1, but got %d" % index)
output_dtype = self.get_output_dtype(index)
output_ctype = np.ctypeslib.as_ctypes_type(output_dtype)
output_ctype = _get_ctype_from_dtype(output_dtype)
output = np.zeros(self.output_size_dim[index][0], dtype=output_dtype)
_check_call(_LIB.GetDLROutput(byref(self.handle), c_int(index),
output.ctypes.data_as(ctypes.POINTER(output_ctype))))
Expand Down Expand Up @@ -462,7 +488,7 @@ def get_input(self, name, shape=None):
'input {}, we cannot infer its shape. '.format(name) +
'Shape parameter should be explicitly specified')
input_dtype = self._get_input_or_weight_dtype_by_name(name)
input_ctype = np.ctypeslib.as_ctypes_type(input_dtype)
input_ctype = _get_ctype_from_dtype(input_dtype)
if shape is None:
shape = self.input_shapes[name]
shape = np.array(shape)
Expand Down

0 comments on commit 9c05dce

Please sign in to comment.