From 9c05dce41b932b319171e0f940cbf6280ba284bc Mon Sep 17 00:00:00 2001
From: Trevor Morris <trevmorr@amazon.com>
Date: Wed, 17 Jun 2020 09:38:24 -0700
Subject: [PATCH] Backwards compatibile with numpy < 1.16.1. Add trt doc to
 index (#198)

* Make backwards compatible with numpy versions < 1.16.1. Add trt doc to index

* Use dict to convert dtype to ctype

* Remove quotes around dtype in error message

* Move dtype to ctype map to DLRModel class

* Move _get_ctype_from_dtype to global scope
---
 doc/index.rst           |  1 +
 python/dlr/dlr_model.py | 32 +++++++++++++++++++++++++++++---
 2 files changed, 30 insertions(+), 3 deletions(-)

diff --git a/doc/index.rst b/doc/index.rst
index 3d52dc2c5..978d2dfee 100644
--- a/doc/index.rst
+++ b/doc/index.rst
@@ -13,4 +13,5 @@ Contents
   install
   python-api
   c-api
+  tensorrt
   Internal docs <http://neo-ai-dlr.readthedocs.io/en/latest/dev/>
diff --git a/python/dlr/dlr_model.py b/python/dlr/dlr_model.py
index 1b3c47af2..24b1cada3 100644
--- a/python/dlr/dlr_model.py
+++ b/python/dlr/dlr_model.py
@@ -8,6 +8,32 @@
 
 from .libpath import find_lib_path
 
+# Map from dtype string to ctype type.
+# Equivalent to np.ctypeslib.as_ctypes_type which requires numpy>=1.16.1
+DTYPE_TO_CTYPE = {
+    "float32": ctypes.c_float,
+    "float64": ctypes.c_double,
+    "uint8": ctypes.c_ubyte,
+    "uint32": ctypes.c_uint,
+    "uint64": ctypes.c_ulong,
+    "int8": ctypes.c_byte,
+    "int32": ctypes.c_int,
+    "int64": ctypes.c_long,
+}
+
+def _get_ctype_from_dtype(dtype):
+    """
+    Convert type string to ctype type.
+
+    Parameters
+    ----------
+    dtype: str
+        Type as a string, e.g. "float32".
+    """
+    if dtype not in DTYPE_TO_CTYPE:
+        raise ValueError("Model has input or output datatype {} which is not supported.".format(dtype))
+    return DTYPE_TO_CTYPE[dtype]
+
 class DLRError(Exception):
     """Error thrown by DLR"""
     pass
@@ -299,7 +325,7 @@ def _set_input(self, name, data):
             The data to be set.
         """
         input_dtype = self._get_input_or_weight_dtype_by_name(name)
-        input_ctype = np.ctypeslib.as_ctypes_type(input_dtype)
+        input_ctype = _get_ctype_from_dtype(input_dtype)
         # float32 inputs can accept any data (backward compatibility).
         if input_dtype == "float32":
             type_match = True
@@ -392,7 +418,7 @@ def _get_output(self, index):
             raise ValueError("index is expected between 0 and "
                              "len(output_shapes)-1, but got %d" % index)
         output_dtype = self.get_output_dtype(index)
-        output_ctype = np.ctypeslib.as_ctypes_type(output_dtype)
+        output_ctype = _get_ctype_from_dtype(output_dtype)
         output = np.zeros(self.output_size_dim[index][0], dtype=output_dtype)
         _check_call(_LIB.GetDLROutput(byref(self.handle), c_int(index),
                     output.ctypes.data_as(ctypes.POINTER(output_ctype))))
@@ -462,7 +488,7 @@ def get_input(self, name, shape=None):
                              'input {}, we cannot infer its shape. '.format(name) +
                              'Shape parameter should be explicitly specified')
         input_dtype = self._get_input_or_weight_dtype_by_name(name)
-        input_ctype = np.ctypeslib.as_ctypes_type(input_dtype)
+        input_ctype = _get_ctype_from_dtype(input_dtype)
         if shape is None:
             shape = self.input_shapes[name]
         shape = np.array(shape)