forked from LiuLeif/hello_tensorrt
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtensorrt_onnx.diff
111 lines (108 loc) · 4.88 KB
/
tensorrt_onnx.diff
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
diff --git a/ModelImporter.cpp b/ModelImporter.cpp
index 01bff22..54e88bb 100644
--- a/ModelImporter.cpp
+++ b/ModelImporter.cpp
@@ -142,8 +142,19 @@ Status parseGraph(IImporterContext* ctx, const ::ONNX_NAMESPACE::GraphProto& gra
// Dispatch to appropriate converter.
const NodeImporter* importFunc{nullptr};
- if (opImporters.count(node.op_type()))
+ if (node.op_type() == "Softmax")
{
+ LOG_INFO("Custom op: " << node.op_type() << ". Attempting to import Softmax as plugin.");
+ importFunc = &opImporters.at("HxdSoftmaxPlugin");
+ }
+ /*else if (node.op_type() == "Add")
+ {
+ LOG_INFO("Custom op: " << node.op_type() << ". Attempting to import Add as plugin.");
+ importFunc = &opImporters.at("HxdAddPlugin");
+ }*/
+ else if (opImporters.count(node.op_type()))
+ {
+ LOG_INFO("node.op_type: " << node.op_type());
importFunc = &opImporters.at(node.op_type());
}
else
diff --git a/builtin_op_importers.cpp b/builtin_op_importers.cpp
index 6977804..591c2d7 100644
--- a/builtin_op_importers.cpp
+++ b/builtin_op_importers.cpp
@@ -5237,6 +5237,81 @@ DEFINE_BUILTIN_OP_IMPORTER(TRT_AveragePool)
return importAveragePool(ctx, node, inputs);
}
+DEFINE_BUILTIN_OP_IMPORTER(HxdSoftmaxPlugin)
+{
+ OnnxAttrs attrs(node, ctx);
+ std::string pluginName{node.op_type()};
+ std::transform(pluginName.begin(), pluginName.end(), pluginName.begin(), ::toupper);
+ const std::string pluginVersion{attrs.get<std::string>("plugin_version", "1")};
+ const std::string pluginNamespace{attrs.get<std::string>("plugin_namespace", "")};
+
+ LOG_INFO("Searching for plugin: " << pluginName << ", plugin_version: " << pluginVersion << ", plugin_namespace: " << pluginNamespace);
+ nvinfer1::IPluginCreator* creator = importPluginCreator(pluginName, pluginVersion, pluginNamespace);
+ ASSERT(creator && "Plugin not found, are the plugin name, version, and namespace correct?", ErrorCode::kUNSUPPORTED_NODE);
+
+
+ std::vector<nvinfer1::PluginField> fields;
+ const int defaultAxis = (ctx->getOpsetVersion() >= 13) ? -1 : 1;
+ int axis = attrs.get("axis", defaultAxis);
+ fields.emplace_back("axis", &axis, nvinfer1::PluginFieldType::kINT32, 1);
+
+ const auto plugin = createPlugin(getNodeName(node), creator, fields);
+ ASSERT(plugin && "Could not create plugin", ErrorCode::kUNSUPPORTED_NODE);
+
+ std::vector<nvinfer1::ITensor*> pluginInputs{};
+ for (auto& input : inputs)
+ {
+ pluginInputs.emplace_back(&convertToTensor(input, ctx));
+ }
+ LOG_INFO("Successfully created plugin: " << pluginName);
+ auto* layer = ctx->network()->addPluginV2(pluginInputs.data(), pluginInputs.size(), *plugin);
+ ctx->registerLayer(layer, getNodeName(node));
+ RETURN_ALL_OUTPUTS(layer);
+}
+
+// DEFINE_BUILTIN_OP_IMPORTER(HxdAddPlugin)
+// {
+// OnnxAttrs attrs(node, ctx);
+// std::string pluginName{"ELTWISE"};
+// std::transform(pluginName.begin(), pluginName.end(), pluginName.begin(), ::toupper);
+// const std::string pluginVersion{attrs.get<std::string>("plugin_version", "1")};
+// const std::string pluginNamespace{attrs.get<std::string>("plugin_namespace", "")};
+
+// LOG_INFO("Searching for plugin: " << pluginName << ", plugin_version: " << pluginVersion << ", plugin_namespace: " << pluginNamespace);
+// nvinfer1::IPluginCreator* creator = importPluginCreator(pluginName, pluginVersion, pluginNamespace);
+// ASSERT(creator && "Plugin not found, are the plugin name, version, and namespace correct?", ErrorCode::kUNSUPPORTED_NODE);
+
+
+// std::vector<nvinfer1::PluginField> fields;
+// int operation = 3;
+// fields.emplace_back("operation", &operation, nvinfer1::PluginFieldType::kINT32, 1);
+
+// const auto plugin = createPlugin(getNodeName(node), creator, fields);
+// ASSERT(plugin && "Could not create plugin", ErrorCode::kUNSUPPORTED_NODE);
+
+// std::vector<nvinfer1::ITensor*> inputTensors;
+// int maxNbDims = -1;
+// for (auto input : inputs)
+// {
+// maxNbDims = std::max(maxNbDims, input.shape().nbDims);
+// }
+
+// for (auto input : inputs)
+// {
+// auto* tensor_ptr = &convertToTensor(input, ctx);
+
+// // Broadcast all input tensors to size of maxNbDims
+// broadcastTensor(ctx, tensor_ptr, maxNbDims);
+// ASSERT(tensor_ptr->getDimensions().nbDims == maxNbDims && "Failed to broadcast tensors elementwise!",
+// ErrorCode::kUNSUPPORTED_NODE);
+// inputTensors.emplace_back(tensor_ptr);
+// }
+// LOG_INFO("Successfully created plugin: " << pluginName);
+// auto* layer = ctx->network()->addPluginV2(inputTensors.data(), inputTensors.size(), *plugin);
+// ctx->registerLayer(layer, getNodeName(node));
+// RETURN_FIRST_OUTPUT(layer);
+// }
+
} // namespace
} // namespace onnx2trt