Commit a5726c7

feat: support arbitrary scale factors
1 parent 69416af commit a5726c7

6 files changed: +59 -33 lines

src/Magpie.Core/CudaInferenceBackend.cpp  (+13 -12)

@@ -7,7 +7,6 @@
 #include "BackendDescriptorStore.h"
 #include "Logger.h"
 #include "DirectXHelper.h"
-#include <onnxruntime/core/session/onnxruntime_session_options_config_keys.h>
 #include "Utils.h"
 
 #pragma comment(lib, "cudart.lib")
@@ -29,6 +28,7 @@ CudaInferenceBackend::~CudaInferenceBackend() {
 
 bool CudaInferenceBackend::Initialize(
     const wchar_t* modelPath,
+    uint32_t scale,
    DeviceResources& deviceResources,
     BackendDescriptorStore& descriptorStore,
     ID3D11Texture2D* input,
@@ -59,7 +59,6 @@ bool CudaInferenceBackend::Initialize(
 
     Ort::SessionOptions sessionOptions;
     sessionOptions.SetIntraOpNumThreads(1);
-    sessionOptions.AddConfigEntry(kOrtSessionOptionsDisableCPUEPFallback, "1");
 
     Ort::ThrowOnError(ortApi.AddFreeDimensionOverride(sessionOptions, "DATA_BATCH", 1));
 
@@ -83,13 +82,14 @@ bool CudaInferenceBackend::Initialize(
     _d3dDC = deviceResources.GetD3DDC();
 
     _inputSize = DirectXHelper::GetTextureSize(input);
+    _outputSize = SIZE{ _inputSize.cx * (LONG)scale, _inputSize.cy * (LONG)scale };
 
     // 创建输出纹理
     winrt::com_ptr<ID3D11Texture2D> outputTex = DirectXHelper::CreateTexture2D(
         d3dDevice,
         DXGI_FORMAT_R8G8B8A8_UNORM,
-        _inputSize.cx * 2,
-        _inputSize.cy * 2,
+        _outputSize.cx,
+        _outputSize.cy,
         D3D11_BIND_SHADER_RESOURCE | D3D11_BIND_UNORDERED_ACCESS
     );
     if (!outputTex) {
@@ -98,13 +98,14 @@ bool CudaInferenceBackend::Initialize(
     }
     *output = outputTex.get();
 
-    const uint32_t elemCount = uint32_t(_inputSize.cx * _inputSize.cy * 3);
+    const uint32_t inputElemCount = uint32_t(_inputSize.cx * _inputSize.cy * 3);
+    const uint32_t outputElemCount = uint32_t(_outputSize.cx * _outputSize.cy * 3);
 
     winrt::com_ptr<ID3D11Buffer> inputBuffer;
     winrt::com_ptr<ID3D11Buffer> outputBuffer;
     {
         D3D11_BUFFER_DESC desc{
-            .ByteWidth = _isFP16Data ? ((elemCount + 1) / 2 * 4) : (elemCount * 4),
+            .ByteWidth = _isFP16Data ? ((inputElemCount + 1) / 2 * 4) : (inputElemCount * 4),
             .BindFlags = D3D11_BIND_UNORDERED_ACCESS
         };
         HRESULT hr = d3dDevice->CreateBuffer(&desc, nullptr, inputBuffer.put());
@@ -113,7 +114,7 @@ bool CudaInferenceBackend::Initialize(
             return false;
         }
 
-        desc.ByteWidth = elemCount * 4 * (_isFP16Data ? 2 : 4);
+        desc.ByteWidth = _isFP16Data ? ((outputElemCount + 1) / 2 * 4) : (outputElemCount * 4);
         desc.BindFlags = D3D11_BIND_SHADER_RESOURCE;
         hr = d3dDevice->CreateBuffer(&desc, nullptr, outputBuffer.put());
         if (FAILED(hr)) {
@@ -140,7 +141,7 @@ bool CudaInferenceBackend::Initialize(
             .Format = _isFP16Data ? DXGI_FORMAT_R16_FLOAT : DXGI_FORMAT_R32_FLOAT,
             .ViewDimension = D3D11_UAV_DIMENSION_BUFFER,
             .Buffer{
-                .NumElements = elemCount
+                .NumElements = inputElemCount
             }
         };
 
@@ -157,7 +158,7 @@ bool CudaInferenceBackend::Initialize(
             .Format = _isFP16Data ? DXGI_FORMAT_R16_FLOAT : DXGI_FORMAT_R32_FLOAT,
             .ViewDimension = D3D11_SRV_DIMENSION_BUFFER,
             .Buffer{
-                .NumElements = elemCount * 4
+                .NumElements = outputElemCount
             }
         };
 
@@ -202,8 +203,8 @@ bool CudaInferenceBackend::Initialize(
         (_inputSize.cy + TEX_TO_TENSOR_BLOCK_SIZE.second - 1) / TEX_TO_TENSOR_BLOCK_SIZE.second
     };
     _tensorToTexDispatchCount = {
-        (_inputSize.cx * 2 + TENSOR_TO_TEX_BLOCK_SIZE.first - 1) / TENSOR_TO_TEX_BLOCK_SIZE.first,
-        (_inputSize.cy * 2 + TENSOR_TO_TEX_BLOCK_SIZE.second - 1) / TENSOR_TO_TEX_BLOCK_SIZE.second
+        (_outputSize.cx + TENSOR_TO_TEX_BLOCK_SIZE.first - 1) / TENSOR_TO_TEX_BLOCK_SIZE.first,
+        (_outputSize.cy + TENSOR_TO_TEX_BLOCK_SIZE.second - 1) / TENSOR_TO_TEX_BLOCK_SIZE.second
    };
 
     cudaResult = cudaGraphicsD3D11RegisterResource(
@@ -275,7 +276,7 @@ void CudaInferenceBackend::Evaluate() noexcept {
         std::size(inputShape),
         _isFP16Data ? ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 : ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT
     );
-    const int64_t outputShape[]{ 1,3,_inputSize.cy * 2,_inputSize.cx * 2 };
+    const int64_t outputShape[]{ 1,3,_outputSize.cy,_outputSize.cx };
     Ort::Value outputValue = Ort::Value::CreateTensor(
         _cudaMemInfo,
         outputMem,
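For reference, the buffers above hold packed 1x3xHxW tensors, so an element count is always width * height * 3 and the byte width depends on whether the model uses FP16 or FP32 data. A minimal sketch of that sizing rule; the helper name TensorByteWidth is illustrative, not something from the commit:

#include <cstdint>

// Illustrative helper mirroring the sizing logic in the diff above:
// 3 channels per pixel; FP16 elements are packed two per 4-byte slot,
// so the FP16 byte width is rounded up to a multiple of 4.
static uint32_t TensorByteWidth(uint32_t width, uint32_t height, bool isFP16Data) {
    const uint32_t elemCount = width * height * 3;
    return isFP16Data ? ((elemCount + 1) / 2 * 4) : (elemCount * 4);
}

// e.g. for a 1920x1080 input at scale = 3:
//   input buffer  = TensorByteWidth(1920, 1080, isFP16Data)
//   output buffer = TensorByteWidth(1920 * 3, 1080 * 3, isFP16Data)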

src/Magpie.Core/CudaInferenceBackend.h  (+2)

@@ -15,6 +15,7 @@ class CudaInferenceBackend : public InferenceBackendBase {
 
     bool Initialize(
         const wchar_t* modelPath,
+        uint32_t scale,
         DeviceResources& deviceResources,
         BackendDescriptorStore& descriptorStore,
         ID3D11Texture2D* input,
@@ -56,6 +57,7 @@ class CudaInferenceBackend : public InferenceBackendBase {
     Ort::MemoryInfo _cudaMemInfo{ nullptr };
 
     SIZE _inputSize{};
+    SIZE _outputSize{};
 
     const char* _inputName = nullptr;
     const char* _outputName = nullptr;

src/Magpie.Core/DirectMLInferenceBackend.cpp  (+20 -17)

@@ -5,7 +5,6 @@
 #include "shaders/TensorToTextureCS.h"
 #include "shaders/TextureToTensorCS.h"
 #include "Logger.h"
-#include <onnxruntime/core/session/onnxruntime_session_options_config_keys.h>
 #include <onnxruntime/core/providers/dml/dml_provider_factory.h>
 #include "Win32Utils.h"
 
@@ -100,6 +99,7 @@ static winrt::com_ptr<IUnknown> AllocateD3D12Resource(const OrtDmlApi* ortDmlApi
 
 bool DirectMLInferenceBackend::Initialize(
     const wchar_t* modelPath,
+    uint32_t scale,
     DeviceResources& deviceResources,
     BackendDescriptorStore& /*descriptorStore*/,
     ID3D11Texture2D* input,
@@ -109,13 +109,14 @@ bool DirectMLInferenceBackend::Initialize(
     _d3d11DC = deviceResources.GetD3DDC();
 
     const SIZE inputSize = DirectXHelper::GetTextureSize(input);
+    const SIZE outputSize{ inputSize.cx * (LONG)scale, inputSize.cy * (LONG)scale };
 
     // 创建输出纹理
     _outputTex = DirectXHelper::CreateTexture2D(
         d3d11Device,
         DXGI_FORMAT_R8G8B8A8_UNORM,
-        inputSize.cx * 2,
-        inputSize.cy * 2,
+        outputSize.cx,
+        outputSize.cy,
         D3D11_BIND_SHADER_RESOURCE | D3D11_BIND_UNORDERED_ACCESS,
         D3D11_USAGE_DEFAULT,
         D3D11_RESOURCE_MISC_SHARED | D3D11_RESOURCE_MISC_SHARED_NTHANDLE
@@ -126,7 +127,8 @@ bool DirectMLInferenceBackend::Initialize(
     }
     *output = _outputTex.get();
 
-    const uint32_t elemCount = uint32_t(inputSize.cx * inputSize.cy * 3);
+    const uint32_t inputElemCount = uint32_t(inputSize.cx * inputSize.cy * 3);
+    const uint32_t outputElemCount = uint32_t(outputSize.cx * outputSize.cy * 3);
 
     winrt::com_ptr<ID3D12Device> d3d12Device = CreateD3D12Device(deviceResources.GetGraphicsAdapter());
     if (!d3d12Device) {
@@ -160,7 +162,6 @@ bool DirectMLInferenceBackend::Initialize(
     sessionOptions.SetIntraOpNumThreads(1);
     sessionOptions.SetExecutionMode(ExecutionMode::ORT_SEQUENTIAL);
     sessionOptions.DisableMemPattern();
-    sessionOptions.AddConfigEntry(kOrtSessionOptionsDisableCPUEPFallback, "1");
 
     Ort::ThrowOnError(ortApi.AddFreeDimensionOverride(sessionOptions, "DATA_BATCH", 1));
 
@@ -187,7 +188,7 @@ bool DirectMLInferenceBackend::Initialize(
     };
     D3D12_RESOURCE_DESC resDesc{
         .Dimension = D3D12_RESOURCE_DIMENSION_BUFFER,
-        .Width = elemCount * (isFP16Data ? 2 : 4),
+        .Width = inputElemCount * (isFP16Data ? 2 : 4),
         .Height = 1,
         .DepthOrArraySize = 1,
         .MipLevels = 1,
@@ -209,7 +210,7 @@ bool DirectMLInferenceBackend::Initialize(
         return false;
     }
 
-    resDesc.Width *= 4;
+    resDesc.Width = UINT64(outputElemCount * (isFP16Data ? 2 : 4));
     hr = d3d12Device->CreateCommittedResource(
         &heapDesc,
         D3D12_HEAP_FLAG_CREATE_NOT_ZEROED,
@@ -241,18 +242,18 @@ bool DirectMLInferenceBackend::Initialize(
     _ioBinding.BindInput("input", Ort::Value::CreateTensor(
         memoryInfo,
         _allocatedInput.get(),
-        size_t(elemCount * (isFP16Data ? 2 : 4)),
+        size_t(inputElemCount * (isFP16Data ? 2 : 4)),
         inputShape,
         std::size(inputShape),
         dataType
     ));
 
-    const int64_t outputShape[]{ 1,3,inputSize.cy * 2,inputSize.cx * 2 };
+    const int64_t outputShape[]{ 1,3,outputSize.cy,outputSize.cx };
     _allocatedOutput = AllocateD3D12Resource(ortDmlApi, _outputBuffer.get());
     _ioBinding.BindOutput("output", Ort::Value::CreateTensor(
         memoryInfo,
         _allocatedOutput.get(),
-        size_t(elemCount * 4 * (isFP16Data ? 2 : 4)),
+        size_t(outputElemCount * (isFP16Data ? 2 : 4)),
         outputShape,
         std::size(outputShape),
         dataType
@@ -276,7 +277,7 @@ bool DirectMLInferenceBackend::Initialize(
     }
 
     UINT descriptorSize;
-    if (!_CreateCBVHeap(d3d12Device.get(), elemCount, isFP16Data, descriptorSize)) {
+    if (!_CreateCBVHeap(d3d12Device.get(), inputElemCount, outputElemCount, isFP16Data, descriptorSize)) {
         Logger::Get().Error("_CreateCBVHeap 失败");
         return false;
     }
@@ -286,7 +287,7 @@ bool DirectMLInferenceBackend::Initialize(
         return false;
     }
 
-    if (!_CalcCommandLists(d3d12Device.get(), inputSize, descriptorSize)) {
+    if (!_CalcCommandLists(d3d12Device.get(), inputSize, outputSize, descriptorSize)) {
         Logger::Get().Error("_CalcCommandLists 失败");
         return false;
     }
@@ -368,7 +369,8 @@ bool DirectMLInferenceBackend::_CreateFence(ID3D11Device5* d3d11Device, ID3D12De
 
 bool DirectMLInferenceBackend::_CreateCBVHeap(
     ID3D12Device* d3d12Device,
-    uint32_t elemCount,
+    uint32_t inputElemCount,
+    uint32_t outputElemCount,
     bool isFP16Data,
     UINT& descriptorSize
 ) noexcept {
@@ -398,7 +400,7 @@ bool DirectMLInferenceBackend::_CreateCBVHeap(
         .Format = isFP16Data ? DXGI_FORMAT_R16_FLOAT : DXGI_FORMAT_R32_FLOAT,
         .ViewDimension = D3D12_UAV_DIMENSION_BUFFER,
         .Buffer{
-            .NumElements = elemCount
+            .NumElements = inputElemCount
         }
     };
     d3d12Device->CreateUnorderedAccessView(_inputBuffer.get(), nullptr, &desc, cbvHandle);
@@ -411,7 +413,7 @@ bool DirectMLInferenceBackend::_CreateCBVHeap(
         .ViewDimension = D3D12_SRV_DIMENSION_BUFFER,
         .Shader4ComponentMapping = D3D12_DEFAULT_SHADER_4_COMPONENT_MAPPING,
         .Buffer{
-            .NumElements = elemCount * 4
+            .NumElements = outputElemCount
         }
     };
     d3d12Device->CreateShaderResourceView(_outputBuffer.get(), &desc, cbvHandle);
@@ -511,6 +513,7 @@ bool DirectMLInferenceBackend::_CreatePipelineStates(ID3D12Device* d3d12Device)
 bool DirectMLInferenceBackend::_CalcCommandLists(
     ID3D12Device* d3d12Device,
     SIZE inputSize,
+    SIZE outputSize,
     UINT descriptorSize
 ) noexcept {
     winrt::com_ptr<ID3D12CommandAllocator> d3d12CommandAllocator;
@@ -579,8 +582,8 @@ bool DirectMLInferenceBackend::_CalcCommandLists(
 
     static constexpr std::pair<uint32_t, uint32_t> TENSOR_TO_TEX_BLOCK_SIZE{ 8, 8 };
     _tensor2TexCommandList->Dispatch(
-        (inputSize.cx * 2 + TENSOR_TO_TEX_BLOCK_SIZE.first - 1) / TENSOR_TO_TEX_BLOCK_SIZE.first,
-        (inputSize.cy * 2 + TENSOR_TO_TEX_BLOCK_SIZE.second - 1) / TENSOR_TO_TEX_BLOCK_SIZE.second,
+        (outputSize.cx + TENSOR_TO_TEX_BLOCK_SIZE.first - 1) / TENSOR_TO_TEX_BLOCK_SIZE.first,
+        (outputSize.cy + TENSOR_TO_TEX_BLOCK_SIZE.second - 1) / TENSOR_TO_TEX_BLOCK_SIZE.second,
         1
     );
     hr = _tensor2TexCommandList->Close();
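The tensor-to-texture pass runs in 8x8 thread groups, so the dispatch counts are the output dimensions divided by 8 and rounded up; the change above simply substitutes the scale-derived outputSize for the old hard-coded inputSize * 2. A small standalone sketch of that ceiling division, assuming only the 8x8 block size visible in the diff:

#include <cstdint>
#include <utility>

// Illustrative only: ceiling division of the output dimensions by the 8x8
// block size, matching the Dispatch() arguments above.
static std::pair<uint32_t, uint32_t> DispatchCount(uint32_t outputWidth, uint32_t outputHeight) {
    constexpr std::pair<uint32_t, uint32_t> TENSOR_TO_TEX_BLOCK_SIZE{ 8, 8 };
    return {
        (outputWidth + TENSOR_TO_TEX_BLOCK_SIZE.first - 1) / TENSOR_TO_TEX_BLOCK_SIZE.first,
        (outputHeight + TENSOR_TO_TEX_BLOCK_SIZE.second - 1) / TENSOR_TO_TEX_BLOCK_SIZE.second
    };
}

// e.g. a 3840x2160 output dispatches 480 x 270 thread groups.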

src/Magpie.Core/DirectMLInferenceBackend.h  (+4 -1)

@@ -14,6 +14,7 @@ class DirectMLInferenceBackend : public InferenceBackendBase {
 
     bool Initialize(
         const wchar_t* modelPath,
+        uint32_t scale,
         DeviceResources& deviceResources,
         BackendDescriptorStore& descriptorStore,
         ID3D11Texture2D* input,
@@ -27,7 +28,8 @@ class DirectMLInferenceBackend : public InferenceBackendBase {
 
     bool _CreateCBVHeap(
         ID3D12Device* d3d12Device,
-        uint32_t elemCount,
+        uint32_t inputElemCount,
+        uint32_t outputElemCount,
         bool isFP16Data,
         UINT& descriptorSize
     ) noexcept;
@@ -37,6 +39,7 @@ class DirectMLInferenceBackend : public InferenceBackendBase {
     bool _CalcCommandLists(
         ID3D12Device* d3d12Device,
         SIZE inputSize,
+        SIZE outputSize,
         UINT descriptorSize
     ) noexcept;
 

src/Magpie.Core/InferenceBackendBase.h  (+1)

@@ -16,6 +16,7 @@ class InferenceBackendBase {
 
     virtual bool Initialize(
         const wchar_t* modelPath,
+        uint32_t scale,
         DeviceResources& deviceResources,
         BackendDescriptorStore& descriptorStore,
         ID3D11Texture2D* input,

src/Magpie.Core/OnnxEffectDrawer.cpp  (+19 -3)

@@ -14,7 +14,12 @@ OnnxEffectDrawer::OnnxEffectDrawer() {}
 
 OnnxEffectDrawer::~OnnxEffectDrawer() {}
 
-static bool ReadJson(const rapidjson::Document& doc, std::string& modelPath, std::string& backend) noexcept {
+static bool ReadJson(
+    const rapidjson::Document& doc,
+    std::string& modelPath,
+    uint32_t& scale,
+    std::string& backend
+) noexcept {
     if (!doc.IsObject()) {
         Logger::Get().Error("根元素不是 Object");
         return false;
@@ -32,6 +37,16 @@ static bool ReadJson(const rapidjson::Document& doc, std::string& modelPath, std
         modelPath = node->value.GetString();
     }
 
+    {
+        auto node = root.FindMember("scale");
+        if (node == root.MemberEnd() || !node->value.IsUint()) {
+            Logger::Get().Error("解析 scale 失败");
+            return false;
+        }
+
+        scale = node->value.GetUint();
+    }
+
     {
         auto node = root.FindMember("backend");
         if (node == root.MemberEnd() || !node->value.IsString()) {
@@ -62,6 +77,7 @@ bool OnnxEffectDrawer::Initialize(
     }
 
     std::string modelPath;
+    uint32_t scale = 1;
     std::string backend;
     {
         rapidjson::Document doc;
@@ -71,7 +87,7 @@ bool OnnxEffectDrawer::Initialize(
             return false;
         }
 
-        if (!ReadJson(doc, modelPath, backend)) {
+        if (!ReadJson(doc, modelPath, scale, backend)) {
             Logger::Get().Error("ReadJson 失败");
             return false;
         }
@@ -90,7 +106,7 @@ bool OnnxEffectDrawer::Initialize(
     }
 
     std::wstring modelPathW = StrUtils::UTF8ToUTF16(modelPath);
-    if (!_inferenceBackend->Initialize(modelPathW.c_str(), deviceResources, descriptorStore, *inOutTexture, inOutTexture)) {
+    if (!_inferenceBackend->Initialize(modelPathW.c_str(), scale, deviceResources, descriptorStore, *inOutTexture, inOutTexture)) {
         return false;
     }
 
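With this commit the effect's JSON description must carry an unsigned "scale" member next to "backend"; ReadJson rejects the file when it is missing or not a uint, and the value then flows into InferenceBackendBase::Initialize. A minimal rapidjson sketch of that lookup; the example JSON string and the "DirectML" backend value are assumptions for illustration, not taken from the commit:

#include <cstdint>
#include <rapidjson/document.h>

// Assumed, illustrative JSON: only the "scale" and "backend" keys are visible
// in this commit, and the accepted backend strings are not shown here.
static constexpr char kExampleJson[] = R"({ "scale": 4, "backend": "DirectML" })";

// Mirrors the new "scale" handling in ReadJson above.
static bool ReadScale(const rapidjson::Document& doc, uint32_t& scale) noexcept {
    if (!doc.IsObject()) {
        return false;
    }
    const auto node = doc.FindMember("scale");
    if (node == doc.MemberEnd() || !node->value.IsUint()) {
        return false;  // missing or non-uint "scale" fails, as in the commit
    }
    scale = node->value.GetUint();
    return true;
}

// Usage:
//   rapidjson::Document doc;
//   doc.Parse(kExampleJson);
//   uint32_t scale = 1;
//   if (ReadScale(doc, scale)) { /* scale == 4 */ }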