-
Notifications
You must be signed in to change notification settings - Fork 458
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge ArithmeticNode into ExecuteNode (#2247)
Summary: bypass-github-export-checks Pull Request resolved: #2247 This diff moves the logic of `ArithmeticNode` into its corresponding OpFunction `add_arithmetic_node()` and the `ExecuteNode` class. Our aim is to remove all derived classes of `ExecuteNode`, i.e., to make `ExecuteNode` a final class. All operator-specific logic will be handled in the OpFunction. Note the next change will move `StagingNode` into its OpFunction + this new ExecuteNode implementation. Until then, we can't tidy up the `ExecuteNode` class fully. Finally, we leave a few task TODOs. ghstack-source-id: 217439330 exported-using-ghexport Reviewed By: SS-JIA Differential Revision: D53982441 fbshipit-source-id: b8a51eee538b679e4168864a4870f3921c9ba333
- Loading branch information
1 parent
fae9ef0
commit 862f755
Showing
9 changed files
with
212 additions
and
77 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
/* | ||
* Copyright (c) Meta Platforms, Inc. and affiliates. | ||
* All rights reserved. | ||
* | ||
* This source code is licensed under the BSD-style license found in the | ||
* LICENSE file in the root directory of this source tree. | ||
*/ | ||
|
||
#include <executorch/backends/vulkan/runtime/graph/ops/ExecuteNode.h> | ||
|
||
#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h> | ||
|
||
#include <executorch/backends/vulkan/runtime/graph/ops/Utils.h> | ||
|
||
namespace at { | ||
namespace native { | ||
namespace vulkan { | ||
|
||
void ExecuteNode::encode(ComputeGraph* graph) { | ||
api::Context* const context = graph->context(); | ||
api::PipelineBarrier pipeline_barrier{}; | ||
|
||
std::unique_lock<std::mutex> cmd_lock = context->dispatch_lock(); | ||
|
||
api::DescriptorSet descriptor_set = | ||
context->get_descriptor_set(shader_, local_workgroup_size_); | ||
|
||
uint32_t idx = 0; | ||
idx = bind_values_to_descriptor_set( | ||
graph, | ||
outputs_, | ||
pipeline_barrier, | ||
api::MemoryAccessType::WRITE, | ||
descriptor_set, | ||
idx); | ||
idx = bind_values_to_descriptor_set( | ||
graph, | ||
inputs_, | ||
pipeline_barrier, | ||
api::MemoryAccessType::READ, | ||
descriptor_set, | ||
idx); | ||
descriptor_set.bind(idx, params_.buffer()); | ||
|
||
context->register_shader_dispatch( | ||
descriptor_set, pipeline_barrier, shader_, global_workgroup_size_); | ||
} | ||
|
||
} // namespace vulkan | ||
} // namespace native | ||
} // namespace at |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
/* | ||
* Copyright (c) Meta Platforms, Inc. and affiliates. | ||
* All rights reserved. | ||
* | ||
* This source code is licensed under the BSD-style license found in the | ||
* LICENSE file in the root directory of this source tree. | ||
*/ | ||
|
||
#include <executorch/backends/vulkan/runtime/graph/ops/Utils.h> | ||
|
||
namespace at { | ||
namespace native { | ||
namespace vulkan { | ||
|
||
api::utils::ivec4 get_size_as_ivec4(const vTensor& t) { | ||
return api::utils::make_ivec4( | ||
{dim_at<Dim4D::Width>(t), | ||
dim_at<Dim4D::Height>(t), | ||
dim_at<Dim4D::Channel>(t), | ||
dim_at<Dim4D::Batch>(t)}); | ||
} | ||
|
||
void bind_tensor_to_descriptor_set( | ||
vTensor& tensor, | ||
api::PipelineBarrier& pipeline_barrier, | ||
const api::MemoryAccessType accessType, | ||
api::DescriptorSet& descriptor_set, | ||
const uint32_t idx) { | ||
if (tensor.buffer()) { | ||
api::VulkanBuffer& buffer = tensor.buffer( | ||
pipeline_barrier, api::PipelineStage::COMPUTE, accessType); | ||
descriptor_set.bind(idx, buffer); | ||
} else { | ||
api::VulkanImage& image = | ||
tensor.image(pipeline_barrier, api::PipelineStage::COMPUTE, accessType); | ||
descriptor_set.bind(idx, image); | ||
} | ||
} | ||
|
||
uint32_t bind_values_to_descriptor_set( | ||
ComputeGraph* graph, | ||
const std::vector<ValueRef>& args, | ||
api::PipelineBarrier& pipeline_barrier, | ||
const api::MemoryAccessType accessType, | ||
api::DescriptorSet& descriptor_set, | ||
const uint32_t base_idx) { | ||
uint32_t idx = base_idx; | ||
for (auto& arg : args) { | ||
Value& val = graph->get_val(arg); | ||
if (val.isTensor()) { | ||
vTensor& tensor = val.toTensor(); | ||
bind_tensor_to_descriptor_set( | ||
tensor, pipeline_barrier, accessType, descriptor_set, idx++); | ||
} else { | ||
VK_THROW("Unsupported type: ", val.type()); | ||
} | ||
} | ||
return idx; | ||
} | ||
|
||
} // namespace vulkan | ||
} // namespace native | ||
} // namespace at |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.