5
5
#include " shaders/TensorToTextureCS.h"
6
6
#include " shaders/TextureToTensorCS.h"
7
7
#include " Logger.h"
8
- #include < onnxruntime/core/session/onnxruntime_session_options_config_keys.h>
9
8
#include < onnxruntime/core/providers/dml/dml_provider_factory.h>
10
9
#include " Win32Utils.h"
11
10
@@ -100,6 +99,7 @@ static winrt::com_ptr<IUnknown> AllocateD3D12Resource(const OrtDmlApi* ortDmlApi
100
99
101
100
bool DirectMLInferenceBackend::Initialize (
102
101
const wchar_t * modelPath,
102
+ uint32_t scale,
103
103
DeviceResources& deviceResources,
104
104
BackendDescriptorStore& /* descriptorStore*/ ,
105
105
ID3D11Texture2D* input,
@@ -109,13 +109,14 @@ bool DirectMLInferenceBackend::Initialize(
109
109
_d3d11DC = deviceResources.GetD3DDC ();
110
110
111
111
const SIZE inputSize = DirectXHelper::GetTextureSize (input);
112
+ const SIZE outputSize{ inputSize.cx * (LONG)scale, inputSize.cy * (LONG)scale };
112
113
113
114
// 创建输出纹理
114
115
_outputTex = DirectXHelper::CreateTexture2D (
115
116
d3d11Device,
116
117
DXGI_FORMAT_R8G8B8A8_UNORM,
117
- inputSize .cx * 2 ,
118
- inputSize .cy * 2 ,
118
+ outputSize .cx ,
119
+ outputSize .cy ,
119
120
D3D11_BIND_SHADER_RESOURCE | D3D11_BIND_UNORDERED_ACCESS,
120
121
D3D11_USAGE_DEFAULT,
121
122
D3D11_RESOURCE_MISC_SHARED | D3D11_RESOURCE_MISC_SHARED_NTHANDLE
@@ -126,7 +127,8 @@ bool DirectMLInferenceBackend::Initialize(
126
127
}
127
128
*output = _outputTex.get ();
128
129
129
- const uint32_t elemCount = uint32_t (inputSize.cx * inputSize.cy * 3 );
130
+ const uint32_t inputElemCount = uint32_t (inputSize.cx * inputSize.cy * 3 );
131
+ const uint32_t outputElemCount = uint32_t (outputSize.cx * outputSize.cy * 3 );
130
132
131
133
winrt::com_ptr<ID3D12Device> d3d12Device = CreateD3D12Device (deviceResources.GetGraphicsAdapter ());
132
134
if (!d3d12Device) {
@@ -160,7 +162,6 @@ bool DirectMLInferenceBackend::Initialize(
160
162
sessionOptions.SetIntraOpNumThreads (1 );
161
163
sessionOptions.SetExecutionMode (ExecutionMode::ORT_SEQUENTIAL);
162
164
sessionOptions.DisableMemPattern ();
163
- sessionOptions.AddConfigEntry (kOrtSessionOptionsDisableCPUEPFallback , " 1" );
164
165
165
166
Ort::ThrowOnError (ortApi.AddFreeDimensionOverride (sessionOptions, " DATA_BATCH" , 1 ));
166
167
@@ -187,7 +188,7 @@ bool DirectMLInferenceBackend::Initialize(
187
188
};
188
189
D3D12_RESOURCE_DESC resDesc{
189
190
.Dimension = D3D12_RESOURCE_DIMENSION_BUFFER,
190
- .Width = elemCount * (isFP16Data ? 2 : 4 ),
191
+ .Width = inputElemCount * (isFP16Data ? 2 : 4 ),
191
192
.Height = 1 ,
192
193
.DepthOrArraySize = 1 ,
193
194
.MipLevels = 1 ,
@@ -209,7 +210,7 @@ bool DirectMLInferenceBackend::Initialize(
209
210
return false ;
210
211
}
211
212
212
- resDesc.Width *= 4 ;
213
+ resDesc.Width = UINT64 (outputElemCount * (isFP16Data ? 2 : 4 )) ;
213
214
hr = d3d12Device->CreateCommittedResource (
214
215
&heapDesc,
215
216
D3D12_HEAP_FLAG_CREATE_NOT_ZEROED,
@@ -241,18 +242,18 @@ bool DirectMLInferenceBackend::Initialize(
241
242
_ioBinding.BindInput (" input" , Ort::Value::CreateTensor (
242
243
memoryInfo,
243
244
_allocatedInput.get (),
244
- size_t (elemCount * (isFP16Data ? 2 : 4 )),
245
+ size_t (inputElemCount * (isFP16Data ? 2 : 4 )),
245
246
inputShape,
246
247
std::size (inputShape),
247
248
dataType
248
249
));
249
250
250
- const int64_t outputShape[]{ 1 ,3 ,inputSize .cy * 2 ,inputSize .cx * 2 };
251
+ const int64_t outputShape[]{ 1 ,3 ,outputSize .cy ,outputSize .cx };
251
252
_allocatedOutput = AllocateD3D12Resource (ortDmlApi, _outputBuffer.get ());
252
253
_ioBinding.BindOutput (" output" , Ort::Value::CreateTensor (
253
254
memoryInfo,
254
255
_allocatedOutput.get (),
255
- size_t (elemCount * 4 * (isFP16Data ? 2 : 4 )),
256
+ size_t (outputElemCount * (isFP16Data ? 2 : 4 )),
256
257
outputShape,
257
258
std::size (outputShape),
258
259
dataType
@@ -276,7 +277,7 @@ bool DirectMLInferenceBackend::Initialize(
276
277
}
277
278
278
279
UINT descriptorSize;
279
- if (!_CreateCBVHeap (d3d12Device.get (), elemCount , isFP16Data, descriptorSize)) {
280
+ if (!_CreateCBVHeap (d3d12Device.get (), inputElemCount, outputElemCount , isFP16Data, descriptorSize)) {
280
281
Logger::Get ().Error (" _CreateCBVHeap 失败" );
281
282
return false ;
282
283
}
@@ -286,7 +287,7 @@ bool DirectMLInferenceBackend::Initialize(
286
287
return false ;
287
288
}
288
289
289
- if (!_CalcCommandLists (d3d12Device.get (), inputSize, descriptorSize)) {
290
+ if (!_CalcCommandLists (d3d12Device.get (), inputSize, outputSize, descriptorSize)) {
290
291
Logger::Get ().Error (" _CalcCommandLists 失败" );
291
292
return false ;
292
293
}
@@ -368,7 +369,8 @@ bool DirectMLInferenceBackend::_CreateFence(ID3D11Device5* d3d11Device, ID3D12De
368
369
369
370
bool DirectMLInferenceBackend::_CreateCBVHeap (
370
371
ID3D12Device* d3d12Device,
371
- uint32_t elemCount,
372
+ uint32_t inputElemCount,
373
+ uint32_t outputElemCount,
372
374
bool isFP16Data,
373
375
UINT& descriptorSize
374
376
) noexcept {
@@ -398,7 +400,7 @@ bool DirectMLInferenceBackend::_CreateCBVHeap(
398
400
.Format = isFP16Data ? DXGI_FORMAT_R16_FLOAT : DXGI_FORMAT_R32_FLOAT,
399
401
.ViewDimension = D3D12_UAV_DIMENSION_BUFFER,
400
402
.Buffer {
401
- .NumElements = elemCount
403
+ .NumElements = inputElemCount
402
404
}
403
405
};
404
406
d3d12Device->CreateUnorderedAccessView (_inputBuffer.get (), nullptr , &desc, cbvHandle);
@@ -411,7 +413,7 @@ bool DirectMLInferenceBackend::_CreateCBVHeap(
411
413
.ViewDimension = D3D12_SRV_DIMENSION_BUFFER,
412
414
.Shader4ComponentMapping = D3D12_DEFAULT_SHADER_4_COMPONENT_MAPPING,
413
415
.Buffer {
414
- .NumElements = elemCount * 4
416
+ .NumElements = outputElemCount
415
417
}
416
418
};
417
419
d3d12Device->CreateShaderResourceView (_outputBuffer.get (), &desc, cbvHandle);
@@ -511,6 +513,7 @@ bool DirectMLInferenceBackend::_CreatePipelineStates(ID3D12Device* d3d12Device)
511
513
bool DirectMLInferenceBackend::_CalcCommandLists (
512
514
ID3D12Device* d3d12Device,
513
515
SIZE inputSize,
516
+ SIZE outputSize,
514
517
UINT descriptorSize
515
518
) noexcept {
516
519
winrt::com_ptr<ID3D12CommandAllocator> d3d12CommandAllocator;
@@ -579,8 +582,8 @@ bool DirectMLInferenceBackend::_CalcCommandLists(
579
582
580
583
static constexpr std::pair<uint32_t , uint32_t > TENSOR_TO_TEX_BLOCK_SIZE{ 8 , 8 };
581
584
_tensor2TexCommandList->Dispatch (
582
- (inputSize .cx * 2 + TENSOR_TO_TEX_BLOCK_SIZE.first - 1 ) / TENSOR_TO_TEX_BLOCK_SIZE.first ,
583
- (inputSize .cy * 2 + TENSOR_TO_TEX_BLOCK_SIZE.second - 1 ) / TENSOR_TO_TEX_BLOCK_SIZE.second ,
585
+ (outputSize .cx + TENSOR_TO_TEX_BLOCK_SIZE.first - 1 ) / TENSOR_TO_TEX_BLOCK_SIZE.first ,
586
+ (outputSize .cy + TENSOR_TO_TEX_BLOCK_SIZE.second - 1 ) / TENSOR_TO_TEX_BLOCK_SIZE.second ,
584
587
1
585
588
);
586
589
hr = _tensor2TexCommandList->Close ();
0 commit comments