From 2a0360040616a5aceecd2456c23186fe2dea3b1f Mon Sep 17 00:00:00 2001 From: pytorchbot Date: Tue, 30 Apr 2024 11:34:53 +0000 Subject: [PATCH] 2024-04-30 nightly release (38fad23b19bacb84f58a828e9d581db65c2f82f5) --- .github/workflows/doc-build.yml | 4 +- .github/workflows/pull.yml | 64 +++ CMakeLists.txt | 2 +- .../graph/ops/glsl/copy_channel_offset.glsl | 78 +++ .../graph/ops/glsl/copy_channel_offset.yaml | 10 + .../runtime/graph/ops/glsl/copy_offset.glsl | 16 +- .../vulkan/runtime/graph/ops/impl/Cat.cpp | 95 ++++ .../vulkan/runtime/graph/ops/impl/Copy.cpp | 197 +++++++- backends/vulkan/runtime/graph/ops/impl/Copy.h | 28 ++ .../runtime/graph/ops/impl/utils/DimUtils.h | 46 ++ backends/vulkan/test/op_tests/cases.py | 47 ++ .../vulkan/test/op_tests/generate_op_tests.py | 4 + .../vulkan/test/op_tests/utils/codegen.py | 82 ++- .../test/op_tests/utils/codegen_base.py | 34 +- backends/vulkan/test/utils/test_utils.cpp | 6 +- backends/vulkan/test/utils/test_utils.h | 22 +- .../vulkan/test/vulkan_compute_api_test.cpp | 269 +++++++++- backends/xnnpack/test/ops/add.py | 3 +- backends/xnnpack/test/tester/tester.py | 2 +- docs/source/debug-backend-delegate.md | 6 +- docs/source/llm/getting-started.md | 6 +- .../sdk-integration-tutorial.py | 22 +- examples/models/llama2/README.md | 2 +- examples/models/llama2/builder.py | 6 +- exir/backend/test/test_utils.py | 4 +- exir/backend/utils.py | 11 +- kernels/README.md | 473 +++++++++++++++++ kernels/portable/README.md | 474 +----------------- kernels/portable/cpu/op_convolution.cpp | 13 +- kernels/portable/cpu/op_copy.cpp | 8 +- kernels/portable/cpu/op_index_put.cpp | 4 +- kernels/portable/cpu/op_slice_scatter.cpp | 41 +- profiler/parse_profiler_results.py | 8 +- 33 files changed, 1494 insertions(+), 593 deletions(-) create mode 100644 backends/vulkan/runtime/graph/ops/glsl/copy_channel_offset.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/copy_channel_offset.yaml create mode 100644 backends/vulkan/runtime/graph/ops/impl/Cat.cpp create mode 100644 kernels/README.md diff --git a/.github/workflows/doc-build.yml b/.github/workflows/doc-build.yml index ccc852c24f..7a3b862b21 100644 --- a/.github/workflows/doc-build.yml +++ b/.github/workflows/doc-build.yml @@ -94,11 +94,11 @@ jobs: # Get github.ref for the output doc folder. By default "main" # If matches a tag like refs/tags/v1.12.0-rc3 or # refs/tags/v1.12.0 convert to 1.12 - GITHUB_REF=${{ github.ref }} + export GITHUB_REF=${{ github.ref }} # Convert refs/tags/v1.12.0rc3 into 1.12. # Adopted from https://github.com/pytorch/pytorch/blob/main/.github/workflows/_docs.yml#L150C11-L155C13 - if [[ "${GITHUB_REF}" =~ ^refs/tags/v([0-9]+\\.[0-9]+)\\. 
]]; then + if [[ "${GITHUB_REF}" =~ ^refs/tags/v([0-9]+\.[0-9]+) ]]; then TARGET_FOLDER="${BASH_REMATCH[1]}" else TARGET_FOLDER="main" diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index f650fc7920..e2cf7e6121 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -239,6 +239,70 @@ jobs: # see if we can import the module successfully python -c "from executorch.extension.pybindings import portable_lib; print('success!')" + test-binary-size-linux-gcc: + name: test-binary-size-linux-gcc + uses: pytorch/test-infra/.github/workflows/linux_job.yml@main + strategy: + fail-fast: false + with: + runner: linux.2xlarge + docker-image: executorch-ubuntu-22.04-gcc9 + submodules: 'true' + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + timeout: 90 + script: | + # The generic Linux job chooses to use base env, not the one setup by the image + CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]") + conda activate "${CONDA_ENV}" + + # build module for executorch.extension.pybindings.portable_lib + bash test/build_size_test.sh + strip cmake-out/test/size_test + output=$(ls -la cmake-out/test/size_test) + arr=($output) + size=${arr[4]} + # threshold=48120 on devserver with gcc11.4 + # todo(lfq): update once binary size is below 50kb. + threshold="51504" + if [[ "$size" -le "$threshold" ]]; then + echo "Success $size <= $threshold" + else + echo "Fail $size > $threshold" + exit 1 + fi + + test-binary-size-linux: + name: test-binary-size-linux + uses: pytorch/test-infra/.github/workflows/linux_job.yml@main + strategy: + fail-fast: false + with: + runner: linux.2xlarge + docker-image: executorch-ubuntu-22.04-clang12 + submodules: 'true' + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + timeout: 90 + script: | + # The generic Linux job chooses to use base env, not the one setup by the image + CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]") + conda activate "${CONDA_ENV}" + + # build module for executorch.extension.pybindings.portable_lib + bash test/build_size_test.sh + strip cmake-out/test/size_test + output=$(ls -la cmake-out/test/size_test) + arr=($output) + size=${arr[4]} + # threshold=48120 on devserver with gcc11.4 + # todo(lfq): update once binary size is below 50kb. + threshold="51768" + if [[ "$size" -le "$threshold" ]]; then + echo "Success $size <= $threshold" + else + echo "Fail $size > $threshold" + exit 1 + fi + unittest: uses: ./.github/workflows/_unittest.yml with: diff --git a/CMakeLists.txt b/CMakeLists.txt index 0610462aed..f5e0937757 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -120,7 +120,7 @@ endif() # disables exceptions and runtime type. set(CMAKE_CXX_FLAGS_RELEASE "-ffunction-sections -fdata-sections -fno-exceptions -fno-rtti") -if(NOT APPLE) +if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -s") endif() diff --git a/backends/vulkan/runtime/graph/ops/glsl/copy_channel_offset.glsl b/backends/vulkan/runtime/graph/ops/glsl/copy_channel_offset.glsl new file mode 100644 index 0000000000..78e698fa7e --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/copy_channel_offset.glsl @@ -0,0 +1,78 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define VEC4_T ${texel_type(DTYPE)} + +layout(std430) buffer; + +#include "indexing_utils.h" + +layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out; +layout(set = 0, binding = 1) uniform PRECISION sampler3D existing_out; +layout(set = 0, binding = 2) uniform PRECISION sampler3D image_in; + +layout(set = 0, binding = 3) uniform PRECISION restrict CopyArgs { + ivec4 out_sizes; + ivec4 in_sizes; + // Analogus to range variable in copy. It defines the # of channel being + // copied. + int channel_range; + int src_channel_offset; + int dst_channel_offset; + int unused; + // Operates on (x, y, z) extents. + ivec3 range; + int unused1; + ivec3 dst_offset; + int unused2; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +layout(constant_id = 3) const int packed_dim = C_DIM; + +void main() { + // Note: Unlike other shaders, the range is often not equal to the destination + // texture extent. + const ivec3 pos = ivec3(gl_GlobalInvocationID); + if (any(greaterThanEqual(pos, range))) { + return; + } + + const ivec3 out_pos = pos + dst_offset; + + const ivec4 out_whcn = to_tensor_idx(out_pos, out_sizes, packed_dim); + + // First read the existing values to make sure the boundary values stay. + VEC4_T v = VEC4_T(texelFetch(existing_out, out_pos, 0)); + + for (int i=0; i<4; i++) { + ivec4 in_whcn = out_whcn; + + in_whcn.z = out_whcn.z - dst_channel_offset + i; + + // Handle the partial update for begining of channel in an existing tensor. + // If the source channel index is below zero or exceeds the range, we skip + // updating the element to avoid overwriting existing data. + if ((in_whcn.z < 0) || (in_whcn.z >= channel_range)) { + continue; + } + + // Readjust for the source offset. 
+ in_whcn.z = in_whcn.z + src_channel_offset; + + ivec4 in_elem_pos = to_texture_elem_pos(in_whcn, in_sizes, packed_dim); + v[i] = VEC4_T(texelFetch(image_in, in_elem_pos.xyz, 0))[in_elem_pos.w]; + } + + imageStore(image_out, out_pos, v); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/copy_channel_offset.yaml b/backends/vulkan/runtime/graph/ops/glsl/copy_channel_offset.yaml new file mode 100644 index 0000000000..3887647ff8 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/copy_channel_offset.yaml @@ -0,0 +1,10 @@ +copy_channel_offset: + parameter_names_with_default_values: + DTYPE: float + NDIM: 3 + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: copy_channel_offset diff --git a/backends/vulkan/runtime/graph/ops/glsl/copy_offset.glsl b/backends/vulkan/runtime/graph/ops/glsl/copy_offset.glsl index 17b3e06e61..0d1d3420a5 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/copy_offset.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/copy_offset.glsl @@ -10,26 +10,12 @@ #define PRECISION ${PRECISION} -#define VEC4_T ${texel_type(DTYPE)} - layout(std430) buffer; -#include "indexing_utils.h" - layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out; layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in; -layout(set = 0, binding = 2) uniform PRECISION restrict OutLimits { - ivec3 out_limits; -}; - -layout(set = 0, binding = 3) uniform PRECISION restrict InLimits { - ivec3 in_limits; -}; - - - -layout(set = 0, binding = 4) uniform PRECISION restrict CopyArgs { +layout(set = 0, binding = 2) uniform PRECISION restrict CopyArgs { ivec3 range; int unused0; ivec3 src_offset; diff --git a/backends/vulkan/runtime/graph/ops/impl/Cat.cpp b/backends/vulkan/runtime/graph/ops/impl/Cat.cpp new file mode 100644 index 0000000000..08363fa71e --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Cat.cpp @@ -0,0 +1,95 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +#include +#include +#include +#include +#include + +namespace vkcompute { + +void add_cat_default_node( + ComputeGraph& graph, + ValueRef in_list_ref, + ValueRef dim_ref, + ValueRef out) { + ValueListPtr input_list = graph.get_value_list(in_list_ref); + + for (ValueRef input_ref : *input_list) { + vTensorPtr t_in = graph.get_tensor(input_ref); + VK_CHECK_COND(check_memory_layout_is(*t_in, api::kChannelsPacked)); + } + + int64_t dim = graph.extract_scalar(dim_ref); + vTensorPtr t_out = graph.get_tensor(out); + + NchwDim nchw_dim = normalize_to_nchw_dim(*t_out, dim); + + // TODO: Find ways to factor out the similar code for width, height, and batch + if (nchw_dim == DimWidth) { + api::utils::ivec3 src_offset = api::utils::make_ivec3({0, 0, 0}, false); + api::utils::ivec3 dst_offset = api::utils::make_ivec3({0, 0, 0}, false); + + for (ValueRef input_ref : *input_list) { + vTensorPtr t_in = graph.get_tensor(input_ref); + api::utils::ivec3 range = t_in->texture_limits(); + add_copy_offset_node( + graph, input_ref, range, src_offset, dst_offset, out); + dst_offset.data[0] += range.data[0]; + } + + } else if (nchw_dim == DimHeight) { + api::utils::ivec3 src_offset = api::utils::make_ivec3({0, 0, 0}, false); + api::utils::ivec3 dst_offset = api::utils::make_ivec3({0, 0, 0}, false); + + for (ValueRef input_ref : *input_list) { + vTensorPtr t_in = graph.get_tensor(input_ref); + api::utils::ivec3 range = t_in->texture_limits(); + add_copy_offset_node( + graph, input_ref, range, src_offset, dst_offset, out); + dst_offset.data[1] += range.data[1]; + } + } else if (nchw_dim == DimBatch) { + api::utils::ivec3 src_offset = api::utils::make_ivec3({0, 0, 0}, false); + api::utils::ivec3 dst_offset = api::utils::make_ivec3({0, 0, 0}, false); + + for (ValueRef input_ref : *input_list) { + vTensorPtr t_in = graph.get_tensor(input_ref); + api::utils::ivec3 range = t_in->texture_limits(); + add_copy_offset_node( + graph, input_ref, range, src_offset, dst_offset, out); + dst_offset.data[2] += range.data[2]; + } + } else if (nchw_dim == DimChannel) { + int32_t src_offset = 0; + int32_t dst_offset = 0; + + for (ValueRef input_ref : *input_list) { + vTensorPtr t_in = graph.get_tensor(input_ref); + int32_t range = dim_at(t_in->sizes()); + add_copy_channel_offset_node( + graph, input_ref, range, src_offset, dst_offset, out); + dst_offset += range; + } + } else { + VK_THROW("Unexpected value of nchw_dim=", nchw_dim); + } +} + +void cat_default(ComputeGraph& graph, const std::vector& args) { + add_cat_default_node(graph, args[0], args[1], args[2]); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(aten.cat.default, cat_default); +} + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Copy.cpp b/backends/vulkan/runtime/graph/ops/impl/Copy.cpp index 0a5e20e4f7..5ca4973e56 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Copy.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Copy.cpp @@ -8,38 +8,39 @@ #include +#include #include #include #include namespace vkcompute { +using api::utils::ivec3; +using api::utils::uvec3; + void add_copy_offset_node( ComputeGraph& graph, const ValueRef in, - const api::utils::ivec3& range, - const api::utils::ivec3& src_offset, - const api::utils::ivec3& dst_offset, + const ivec3& range, + const ivec3& src_offset, + const ivec3& dst_offset, const ValueRef out) { vTensorPtr t_in = graph.get_tensor(in); vTensorPtr t_out = graph.get_tensor(out); - VK_CHECK_COND(check_memory_layout_is(*t_in, api::kChannelsPacked)); - 
VK_CHECK_COND(check_memory_layout_is(*t_out, api::kChannelsPacked)); - std::string kernel_name = "copy_offset"; kernel_name.reserve(kShaderNameReserve); add_dtype_suffix(kernel_name, *t_out); - api::utils::uvec3 global_size = api::utils::make_uvec3(range); - api::utils::uvec3 local_size = adaptive_work_group_size(global_size); + uvec3 global_size = api::utils::make_uvec3(range); + uvec3 local_size = adaptive_work_group_size(global_size); const struct Block final { - api::utils::ivec3 range; + ivec3 range; int32_t unused0; - api::utils::ivec3 src_offset; + ivec3 src_offset; int32_t unused1; - api::utils::ivec3 dst_offset; + ivec3 dst_offset; int32_t unused2; } offset_params{ range, @@ -58,13 +59,179 @@ void add_copy_offset_node( global_size, local_size, // Inputs and Outputs - {{out, api::MemoryAccessType::WRITE}, {in, api::MemoryAccessType::READ}}, + { + {out, api::MemoryAccessType::WRITE}, + {in, api::MemoryAccessType::READ}, + }, // Parameter buffers - {t_out->texture_limits_ubo(), - t_in->texture_limits_ubo(), - graph.create_params_buffer(offset_params)}, + {graph.create_params_buffer(offset_params)}, // Specialization Constants {})); } +void add_copy_channel_offset_node( + ComputeGraph& graph, + const ValueRef in, + int32_t channel_range, + int32_t src_channel_offset, + int32_t dst_channel_offset, + const ValueRef out) { + vTensorPtr t_in = graph.get_tensor(in); + vTensorPtr t_out = graph.get_tensor(out); + + // Likely need to prepad these numbers. + std::vector in_sizes = t_in->sizes(); + std::vector out_sizes = t_out->sizes(); + + VK_CHECK_COND(check_memory_layout_is(*t_in, api::kChannelsPacked)); + VK_CHECK_COND(check_memory_layout_is(*t_out, api::kChannelsPacked)); + + // NOTE: This function should be able to support 1d and 2d tensors when + // range=1, src_offset=dst_offset=1. + VK_CHECK_COND(t_in->dim() >= 3, "Src dim should be at least 3"); + VK_CHECK_COND(t_out->dim() >= 3, "Dst dim should be at least 3"); + + VK_CHECK_COND( + dim_at(in_sizes) >= src_channel_offset + channel_range, + "Src channel (", + src_channel_offset, + ") and range (", + channel_range, + ") should be less than or equal to input tensor's channel size (", + dim_at(in_sizes), + ")"); + + VK_CHECK_COND( + dim_at(out_sizes) >= dst_channel_offset + channel_range, + "Dst channel (", + dst_channel_offset, + ") and range (", + channel_range, + ") should be less than or equal to input tensor's channel size (", + dim_at(out_sizes), + ")"); + + VK_CHECK_COND(channel_range >= 0, "Channel range must be non-negative"); + VK_CHECK_COND( + src_channel_offset >= 0, "Src channel offset must be non-negative"); + VK_CHECK_COND( + dst_channel_offset >= 0, "Dst channel offset must be non-negative"); + + std::string kernel_name = "copy_channel_offset"; + kernel_name.reserve(kShaderNameReserve); + add_dtype_suffix(kernel_name, *t_out); + + int32_t out_channels = dim_at(out_sizes); + + // Copy one batch at a time. + for (int batch_idx = 0; batch_idx < dim_at(in_sizes); + batch_idx++) { + // Mapping the tensor NCHW coordinates into texture XYZ coordinates + int32_t dst_first_z = dst_channel_offset / 4; + int32_t dst_last_z = (dst_channel_offset + channel_range - 1) / 4; + + // We copy the entire width and height dimension. For the channel dimension, + // we use the z-dimension of the global_size to specify the texture range. + // The shader combines the global invocation id and the dst_offset to get + // the actual coordinate. 
+ + ivec3 dst_offset{ + 0, 0, dst_first_z + batch_idx * api::utils::div_up(out_channels, 4)}; + + uvec3 global_size{ + dim_at(in_sizes), + dim_at(in_sizes), + api::utils::safe_downcast(dst_last_z - dst_first_z + 1)}; + + uvec3 local_size = adaptive_work_group_size(global_size); + + const struct Block final { + api::utils::ivec4 out_sizes; + api::utils::ivec4 in_sizes; + int32_t channel_range; + int32_t src_channel_offset; + int32_t dst_channel_offset; + int32_t unused; + ivec3 range; + int32_t unused1; + ivec3 dst_offset; + int32_t unused2; + + } channel_offset_params{ + api::utils::make_whcn_ivec4(out_sizes), + api::utils::make_whcn_ivec4(in_sizes), + channel_range, + src_channel_offset, + dst_channel_offset, + 0, + api::utils::make_ivec3(global_size), + 0, + dst_offset, + 0, + }; + + auto shader = VK_KERNEL_FROM_STR(kernel_name); + + graph.execute_nodes().emplace_back(new ExecuteNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + global_size, + local_size, + // Inputs and Outputs + { + {out, api::MemoryAccessType::WRITE}, + {out, api::MemoryAccessType::READ}, + {in, api::MemoryAccessType::READ}, + }, + // Parameter buffers + {graph.create_params_buffer(channel_offset_params)}, + // Specialization Constants + {})); + } +} + +void add_copy_offset_node( + ComputeGraph& graph, + ValueRef in, + ValueRef range_ref, + ValueRef src_offset_ref, + ValueRef dst_offset_ref, + ValueRef out) { + ivec3 range = api::utils::make_ivec3(*graph.get_int_list(range_ref)); + ivec3 src_offset = + api::utils::make_ivec3(*graph.get_int_list(src_offset_ref)); + ivec3 dst_offset = + api::utils::make_ivec3(*graph.get_int_list(dst_offset_ref)); + + add_copy_offset_node(graph, in, range, src_offset, dst_offset, out); +} + +void copy_offset(ComputeGraph& graph, const std::vector& args) { + add_copy_offset_node(graph, args[0], args[1], args[2], args[3], args[4]); +} + +void copy_channel_offset( + ComputeGraph& graph, + const std::vector& args) { + ValueRef in = args[0]; + ValueRef channel_range_ref = args[1]; + ValueRef src_channel_offset_ref = args[2]; + ValueRef dst_channel_offset_ref = args[3]; + ValueRef out = args[4]; + + auto channel_range = graph.extract_scalar(channel_range_ref); + auto src_channel_offset = + graph.extract_scalar(src_channel_offset_ref); + auto dst_channel_offset = + graph.extract_scalar(dst_channel_offset_ref); + + add_copy_channel_offset_node( + graph, in, channel_range, src_channel_offset, dst_channel_offset, out); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(etvk.copy_offset, copy_offset); + VK_REGISTER_OP(etvk.copy_channel_offset, copy_channel_offset); +} + } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Copy.h b/backends/vulkan/runtime/graph/ops/impl/Copy.h index 6e0deb6b74..60a58b2fa8 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Copy.h +++ b/backends/vulkan/runtime/graph/ops/impl/Copy.h @@ -14,6 +14,13 @@ namespace vkcompute { +// add_copy_offset_node resumes the vkCmdCopyImage command. It copies the +// texture extents specified by the range, src_offset, and dst_offset (all are +// in texture coordinate (x, y, z) from the input image to the output image. +// +// It is possible to have input and output to point to the same image +// object. But when the source range and destination range overlap, the behavior +// is undefined. 
void add_copy_offset_node( ComputeGraph& graph, const ValueRef in, @@ -22,4 +29,25 @@ void add_copy_offset_node( const api::utils::ivec3& dst_offset, const ValueRef out); +// add_copy_channel_offset_node behaves similar to add_copy_node, except that it +// works on the channel dimensions of the tensor (up to 4 dimensions in NCHW). +// The range and offset arguments are in the tensor coordinate. It assumes the +// underlying texture is channel-packed. +// +// This function is specialized implementation for copying +// channel packed values. The complication comes from when reading / writing the +// channel dimension on indices that are not aligned to packing, we will need +// be careful about the boundaries. +// +// It achieves the following: +// out[:, dst_channel_offset:dst_channel_offset + channel_range, :, :] = +// in [:, src_channel_offset:src_channel_offset + channel_range, :, :] +void add_copy_channel_offset_node( + ComputeGraph& graph, + const ValueRef in, + int32_t channel_range, + int32_t src_channel_offset, + int32_t dst_channel_offset, + const ValueRef out); + } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h b/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h index 92eba407d8..e7b9a614e2 100644 --- a/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h +++ b/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h @@ -70,4 +70,50 @@ uint32_t dim_at(const vTensor& v_in) { return dim_at(v_in.sizes()); } +// A canonical way to represent dimensions as enum. Intended to use the same +// value as Dim4D for potential future refactoring. + +enum NchwDim { + DimWidth = 1, + DimHeight = 2, + DimChannel = 3, + DimBatch = 4, +}; + +/* This function return a NchwDim + * given a Tensor and a user provided dim. The reason for this normalization is + * that in the user tensor coordinate, it is using a "big-endian" mechanism when + * referring to a nchw dimension, in that dim=0 refers to the batch dimension in + * a 4d tensor but dim=0 reference to height in a 2d tensor. Despite in a common + * texture representation of channel packing, a 2d tensor has exactly the same + * layout as a 4d with the batch and channel size equals to 1. This function + * returns a canonical dimension to simplify dimension reasoning in the code. 
+ * + */ + +inline NchwDim normalize_to_nchw_dim(const vTensor& v_in, int32_t dim) { + return static_cast(v_in.dim() - dim); +} + +inline std::ostream& operator<<(std::ostream& os, NchwDim nchw_dim) { + switch (nchw_dim) { + case DimWidth: + os << "DimWidth"; + break; + case DimHeight: + os << "DimHeight"; + break; + case DimChannel: + os << "DimChannel"; + break; + case DimBatch: + os << "DimBatch"; + break; + default: + os << "DimUnknown"; + break; + } + return os; +} + } // namespace vkcompute diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index 2a100b92e3..f0659ad823 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -428,6 +428,52 @@ def get_repeat_inputs(): return test_suite +def get_cat_inputs(): + # TensorList must be specified as list of tuples + test_suite = VkTestSuite( + [ + # Cat on Height + ([(S1, S1, 3, 5), (S1, S1, 4, 5)], 2), + ([(S1, 3, 5), (S1, 4, 5)], 1), + ([(3, 5), (4, 5)], 0), + ([(3, 5), (4, 5), (1, 5)], 0), + ( + [ + (3, 5), + ], + 0, + ), + # Cat on Width + ([(S1, S1, 5, 3), (S1, S1, 5, 4)], 3), + ([(S1, 5, 3), (S1, 5, 4)], 2), + ([(5, 3), (5, 4)], 1), + ([(5, 3), (5, 4), (5, 1)], 1), + ( + [ + (5, 4), + ], + 1, + ), + ([(5,), (6,)], 0), + # Cat on Batch + ([(S, S1, 5, 4), (S1, S1, 5, 4)], 0), + ([(S, XS, 5, 4), (S1, XS, 5, 4)], 0), + ([(S, S2, 5, 4), (S1, S2, 5, 4)], 0), + # Cat on Channel + ([(S, 5, 4), (S1, 5, 4), (S2, 5, 4)], 0), + ([(XS, 5, 4), (XS, 5, 4), (S2, 5, 4)], 0), + ([(XS, S, 5, 4), (XS, S1, 5, 4), (XS, S2, 5, 4)], 1), + ([(XS, XS, 5, 4), (XS, XS, 5, 4), (XS, S2, 5, 4)], 1), + ] + ) + test_suite.layouts = [ + "api::kChannelsPacked", + ] + test_suite.data_gen = "make_seq_tensor" + test_suite.dtypes = ["at::kFloat"] + return test_suite + + test_suites = { "aten.add.Tensor": get_binary_elementwise_inputs(), "aten.sub.Tensor": get_binary_elementwise_inputs(), @@ -447,4 +493,5 @@ def get_repeat_inputs(): "aten.unsqueeze_copy.default": get_unsqueeze_inputs(), "aten.clone.default": get_clone_inputs(), "aten.repeat.default": get_repeat_inputs(), + "aten.cat.default": get_cat_inputs(), } diff --git a/backends/vulkan/test/op_tests/generate_op_tests.py b/backends/vulkan/test/op_tests/generate_op_tests.py index ef4dc0af91..71047ac6f4 100644 --- a/backends/vulkan/test/op_tests/generate_op_tests.py +++ b/backends/vulkan/test/op_tests/generate_op_tests.py @@ -16,6 +16,7 @@ TestSuite, TestSuiteGen, ) +from torchgen import local from torchgen.gen import parse_native_yaml, ParsedYaml from torchgen.model import DispatchKey, NativeFunction @@ -45,6 +46,9 @@ def process_test_suites( cpp_generator.add_suite(registry_name, f, op_test_suite) +@local.parametrize( + use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False +) def generate_cpp( native_functions_yaml_path: str, tags_path: str, output_dir: str ) -> None: diff --git a/backends/vulkan/test/op_tests/utils/codegen.py b/backends/vulkan/test/op_tests/utils/codegen.py index f0e5547b4f..ac5e25fa59 100644 --- a/backends/vulkan/test/op_tests/utils/codegen.py +++ b/backends/vulkan/test/op_tests/utils/codegen.py @@ -12,6 +12,7 @@ AT_INT_ARRAY_REF, AT_SCALAR, AT_TENSOR, + AT_TENSOR_LIST, BOOL, CppTestFileGen, DOUBLE, @@ -28,6 +29,7 @@ THREE_TENSOR_TUPLE, TWO_TENSOR_TUPLE, ) + from torchgen.api import cpp from torchgen.api.types import CppSignatureGroup @@ -75,6 +77,8 @@ class ValueRef: ValueRefList = Union[ValueRef, List[ValueRef]] +InableCppType = frozenset([AT_TENSOR, AT_TENSOR_LIST]) + class ComputeGraphGen: 
def __init__(self, op_reg_name: str, f: NativeFunction, suite_def: TestSuite): @@ -114,7 +118,7 @@ def __init__(self, op_reg_name: str, f: NativeFunction, suite_def: TestSuite): name=f"{arg.name}_ref", src_cpp_name=arg.name, src_cpp_type=cpp_type, - is_in=(cpp_type == AT_TENSOR), + is_in=(cpp_type in InableCppType), requires_prepack=requires_prepack, supports_prepack=supports_prepack, ) @@ -244,6 +248,25 @@ def create_value_for(self, ref: ValueRefList) -> str: # noqa: C901 ret_str += f"{self.graph}{self.dot}add_scalar" ret_str += f"({ref.src_cpp_name}.value());\n" return ret_str + elif ref.src_cpp_type == AT_TENSOR_LIST: + assert ref.is_in, "AT_TENSOR_LIST must be an input" + # This logic is a bit convoluted. We need to create a IOValueRef for + # each tensor, to facilate staging. On the other hand, we will + # use the .value tensor to create a ValueList, which will be passed + # to the corresponding ops. + ret_str = f"std::vector {ref.name}_io_value_refs;\n" + ret_str += f"std::vector {ref.name}_value_refs;\n" + ret_str += f"for (int i=0; i < {ref.src_cpp_name}.size(); i++) {{\n" + ret_str += f" {cpp_type} io_value_ref = {self.graph}{self.dot}add_input_tensor(\n" + ret_str += f" {ref.src_cpp_name}[i].sizes().vec(),\n" + ret_str += ( + f" from_at_scalartype({ref.src_cpp_name}[i].scalar_type())); \n" + ) + ret_str += f" {ref.name}_value_refs.emplace_back(io_value_ref.value);\n" + ret_str += f" {ref.name}_io_value_refs.emplace_back(io_value_ref);\n" + ret_str += "}\n" + ret_str += f"ValueRef {ref.name} = {self.graph}{self.dot}add_value_list(std::move({ref.name}_value_refs));\n" + return ret_str ret_str = f"{cpp_type} {ref.name} = {self.graph}{self.dot}" if ref.src_cpp_type == AT_TENSOR and not prepack: @@ -288,11 +311,16 @@ def create_op_call(self) -> str: for aten_arg in self.args: ref = self.refs[aten_arg.name] - op_create_code += ( - f"{ref.name}.value, " - if (ref.is_in and not self.prepack_ref(ref)) or ref.is_out - else f"{ref.name}, " - ) + if ref.src_cpp_type == AT_TENSOR_LIST: + # Special case. Underlying tensors are input tensors, but the + # container itself is just a normal value. 
+ op_create_code += f"{ref.name}, " + else: + op_create_code += ( + f"{ref.name}.value, " + if (ref.is_in and not self.prepack_ref(ref)) or ref.is_out + else f"{ref.name}, " + ) op_create_code += "out_ref});\n" return op_create_code @@ -311,22 +339,46 @@ def set_output(self, ref: ValueRefList) -> str: def virtual_resize(self, ref: ValueRefList) -> str: assert isinstance(ref, ValueRef) - assert ref.src_cpp_type == AT_TENSOR and ref.is_in + assert ref.src_cpp_type in InableCppType and ref.is_in if self.prepack_ref(ref): return "" - ret_str = f"{self.graph}{self.dot}get_tensor({ref.name}.value)" - ret_str += f"->virtual_resize({ref.src_cpp_name}.sizes().vec());\n" + + if ref.src_cpp_type == AT_TENSOR: + ret_str = f"{self.graph}{self.dot}get_tensor({ref.name}.value)" + ret_str += f"->virtual_resize({ref.src_cpp_name}.sizes().vec());\n" + elif ref.src_cpp_type == AT_TENSOR_LIST: + ret_str = "" + ret_str += f"for (int i=0; i < {ref.name}_io_value_refs.size(); i++) {{\n" + ret_str += f" {self.graph}{self.dot}get_tensor({ref.name}_io_value_refs[i].value)" + ret_str += f"->virtual_resize({ref.src_cpp_name}[i].sizes().vec());\n" + ret_str += "}\n" + else: + raise AssertionError(f"{ref.src_cpp_type} not expected") + return ret_str def copy_into_staging(self, ref: ValueRefList) -> str: assert isinstance(ref, ValueRef) - assert ref.src_cpp_type == AT_TENSOR and ref.is_in + assert ref.src_cpp_type in InableCppType and ref.is_in + if self.prepack_ref(ref): return "" - ret_str = f"{self.graph}{self.dot}copy_into_staging(" - ret_str += f"{ref.name}.staging, " - ret_str += f"{ref.src_cpp_name}.const_data_ptr(), " - ret_str += f"{ref.src_cpp_name}.numel());\n" + + if ref.src_cpp_type == AT_TENSOR: + ret_str = f"{self.graph}{self.dot}copy_into_staging(" + ret_str += f"{ref.name}.staging, " + ret_str += f"{ref.src_cpp_name}.const_data_ptr(), " + ret_str += f"{ref.src_cpp_name}.numel());\n" + elif ref.src_cpp_type == AT_TENSOR_LIST: + ret_str = "" + ret_str += f"for (int i=0; i < {ref.name}_io_value_refs.size(); i++) {{\n" + ret_str += f" {self.graph}{self.dot}copy_into_staging(" + ret_str += f"{ref.name}_io_value_refs[i].staging, " + ret_str += f"{ref.src_cpp_name}[i].const_data_ptr(), " + ret_str += f"{ref.src_cpp_name}[i].numel());\n" + ret_str += "}\n" + else: + raise AssertionError(f"{ref.src_cpp_type} not expected") return ret_str def declare_vk_out_for(self, ref: Union[ValueRef, List[ValueRef]]) -> str: @@ -547,8 +599,10 @@ def gen_parameterization(self) -> str: if (!is_close && t1.numel() < 500) { std::cout << "reference: " << std::endl; print(t1, 150); + std::cout << std::endl; std::cout << "vulkan: " << std::endl; print(t2, 150); + std::cout << std::endl; } return is_close; } diff --git a/backends/vulkan/test/op_tests/utils/codegen_base.py b/backends/vulkan/test/op_tests/utils/codegen_base.py index d5feada1df..e9fbe0b1b2 100644 --- a/backends/vulkan/test/op_tests/utils/codegen_base.py +++ b/backends/vulkan/test/op_tests/utils/codegen_base.py @@ -4,7 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-from typing import Any, List +import re +from typing import Any, List, Tuple from torchgen.api import cpp from torchgen.api.types import CppSignatureGroup @@ -17,6 +18,7 @@ AT_INT_ARRAY_REF = "at::IntArrayRef" AT_SCALAR = "at::Scalar" AT_TENSOR = "at::Tensor" +AT_TENSOR_LIST = "at::TensorList" BOOL = "bool" DOUBLE = "double" INT = "int64_t" @@ -57,8 +59,8 @@ class GeneratedOpsTest_{op_name} : public ::testing::Test {{ test_suite_template = """ TEST_P(GeneratedOpsTest_{op_name}, {case_name}) {{ - {create_ref_data} - {create_and_check_out} +{create_ref_data} +{create_and_check_out} }} """ @@ -97,6 +99,9 @@ def __init__(self, f: NativeFunction, test_suite: TestSuite): self.f, method=False, fallback_binding=self.f.manual_cpp_binding ).most_faithful_signature() + def gen_case_name_tuple(self, t: Tuple) -> str: + return "x".join([str(e) for e in t]) + def gen_case_name(self, inputs: List[Any], prepack: bool = False) -> str: name_str = self.op_name if prepack: @@ -104,13 +109,15 @@ def gen_case_name(self, inputs: List[Any], prepack: bool = False) -> str: for arg_sizes_or_val in inputs: name_str += "_" if isinstance(arg_sizes_or_val, tuple): - for size in arg_sizes_or_val: - name_str += str(size) + "x" - name_str = name_str[:-1] + name_str += self.gen_case_name_tuple(arg_sizes_or_val) elif isinstance(arg_sizes_or_val, list): + lst = [] for size in arg_sizes_or_val: - name_str += str(size) + "c" - name_str = name_str[:-1] + if isinstance(size, tuple): + lst.append(self.gen_case_name_tuple(size)) + else: + lst.append(str(size)) + name_str += "c".join(lst) else: name_str += str(arg_sizes_or_val).replace(".", "p") @@ -122,6 +129,15 @@ def create_input_data(self, arg: Argument, data: Any) -> str: # noqa: C901 ctype = cpp.argumenttype_type(arg.type, mutable=arg.is_write, binds=arg.name) cpp_type = ctype.cpp_type(strip_ref=True) + # Short cut exit for TENSORLIST, because it needs multiple lines of + # construction, deviates from the rest. 
+ if cpp_type == AT_TENSOR_LIST: + ret_str = f"std::vector<{AT_TENSOR}> tensor_vec;\n" + for elem in data: + ret_str += f"tensor_vec.emplace_back({self.suite_def.data_gen}({init_list_str(elem)}, test_dtype));\n" + ret_str += f"{cpp_type} {arg.name} = tensor_vec;\n" + return ret_str + "\n" + if cpp_type == AT_INT_ARRAY_REF: ret_str = f"std::vector {arg.name} = " else: @@ -169,6 +185,7 @@ def gen_create_ref_data(self, inputs: List[Any]) -> str: arg_data = get_or_return_default(arg, inputs, i) ref_code += self.create_input_data(arg, arg_data) + ref_code = re.sub(r"^", " ", ref_code, flags=re.M) return ref_code def gen_create_and_check_out(self, prepack=False) -> str: @@ -179,6 +196,7 @@ def gen_create_and_check_out(self, prepack=False) -> str: arg = binding.argument test_str += f"{arg.name}, " test_str = test_str[:-2] + ");" + test_str = re.sub(r"^", " ", test_str, flags=re.M) return test_str def gen_parameterization(self) -> str: diff --git a/backends/vulkan/test/utils/test_utils.cpp b/backends/vulkan/test/utils/test_utils.cpp index db966b6a7c..37ced363b6 100644 --- a/backends/vulkan/test/utils/test_utils.cpp +++ b/backends/vulkan/test/utils/test_utils.cpp @@ -64,7 +64,6 @@ void record_conv2d_prepack_weights_op( api::VulkanBuffer& src_buffer, vTensor& v_dst, const std::vector& original_sizes, - const std::vector& padded_sizes, const bool transposed) { api::PipelineBarrier pipeline_barrier{}; @@ -80,8 +79,6 @@ void record_conv2d_prepack_weights_op( api::UniformParamsBuffer original_sizes_ubo( context, api::utils::make_ivec4(original_sizes, /*reverse = */ true)); - api::UniformParamsBuffer padded_sizes_ubo( - context, api::utils::make_ivec2(padded_sizes, /*reverse = */ true)); api::SpecVarList specialization_constants = {}; context->submit_compute_job( @@ -97,8 +94,7 @@ void record_conv2d_prepack_weights_op( api::MemoryAccessType::WRITE), src_buffer, v_dst.sizes_ubo(), - original_sizes_ubo.buffer(), - padded_sizes_ubo.buffer()); + original_sizes_ubo.buffer()); } void record_binary_op( diff --git a/backends/vulkan/test/utils/test_utils.h b/backends/vulkan/test/utils/test_utils.h index a1f3b93dc3..8f23b6c407 100644 --- a/backends/vulkan/test/utils/test_utils.h +++ b/backends/vulkan/test/utils/test_utils.h @@ -12,6 +12,7 @@ #include +#include #include #include @@ -86,7 +87,6 @@ void record_conv2d_prepack_weights_op( api::VulkanBuffer& src_buffer, vTensor& v_dst, const std::vector& original_sizes, - const std::vector& padded_sizes, const bool transposed); void record_binary_op( @@ -153,6 +153,26 @@ check_staging_buffer(api::StorageBuffer& staging, float val, int numel = -1) { } } +inline int64_t get_buf_idx( + ComputeGraph& graph, + IOValueRef ref, + const std::vector& tensor_coor) { + vTensorPtr vten_ptr = graph.get_tensor(ref.value); + + const std::vector& sizes = vten_ptr->sizes(); + + int64_t c = dim_at(sizes); + int64_t h = dim_at(sizes); + int64_t w = dim_at(sizes); + + int64_t ni = dim_at(tensor_coor); + int64_t ci = dim_at(tensor_coor); + int64_t hi = dim_at(tensor_coor); + int64_t wi = dim_at(tensor_coor); + + return (ni * c * h * w + ci * h * w + hi * w + wi); +} + // // Context Management // diff --git a/backends/vulkan/test/vulkan_compute_api_test.cpp b/backends/vulkan/test/vulkan_compute_api_test.cpp index 4955d0537e..f9a149a9df 100644 --- a/backends/vulkan/test/vulkan_compute_api_test.cpp +++ b/backends/vulkan/test/vulkan_compute_api_test.cpp @@ -876,6 +876,274 @@ TEST(VulkanComputeGraphTest, test_large_graph) { } } +TEST(VulkanComputeGraphTest, test_etvk_copy_offset_node) { + 
GraphConfig config; + ComputeGraph graph(config); + + int64_t n = 6; + int64_t c = 12; + int64_t h = 4; + int64_t w = 8; + api::GPUMemoryLayout memory_layout = + api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED; + + std::vector size = {n, c, h, w}; + + IOValueRef a = graph.add_input_tensor(size, api::kFloat, memory_layout); + + IOValueRef out = {}; + out.value = graph.add_tensor(size, api::kFloat, memory_layout); + + // Notice that copy_node operates on in texture's x, y, z dimension. In the + // comment, we provide the cooresponding coordinate in nchw. + + // src_offset is (n=0, c=4, h=1, w=1) + ValueRef src_offset_ref = graph.add_scalar_list({1, 1, 1}); + + // dst_offset is (n=1, c=8, h=2, w=0) in nchw coordinate + // Argument is {x, y, z}. + // x = 0 since w = 0 + // y = 2 since h = 2 + // z = c / 4 + 2 since + // 1. there c/4 planes per batch, n=1 means we are on the first batch; + // 2. +2 because c = 8, with channel packing it means two texels. + ValueRef dst_offset_ref = graph.add_scalar_list({0, 2, c / 4 + 2}); + + // range is (n=1, c=8, h=2, w=4) + // Argument is {x, y, z}. + // x = 4 since w = 4 + // y = 2 since h = 2 + // z = 2 since we are only copying 8 channels, hence 2 texel. n = 1 can be a + // bit misleading here, since it gives the impression that we are copying the + // entire channel. However, remember when we copy, we are trying to + // dst[dst_offset:dst_offset + range] = src[src_offset:src_offset + range], + // range must be non zero. + ValueRef range_ref = graph.add_scalar_list({4, 2, 2}); + + auto copyFn = VK_GET_OP_FN("etvk.copy_offset"); + copyFn( + graph, {a.value, range_ref, src_offset_ref, dst_offset_ref, out.value}); + + out.staging = graph.set_output_tensor(out.value); + + graph.prepare(); + graph.encode_execute(); + + fill_vtensor(graph, a, 0.0f, /*iota = */ true); + + graph.execute(); + + EXTRACT_TENSOR(out); + EXTRACT_TENSOR(a); + + // We will examine the results in the dst_range + // The value in the cooresponding coordinate should match between the source + // and destination tensor. We loop thru the range, calculate both the src and + // dst index using the offsets, and compare the values in the extracted + // vector. They should match. 
+ int n_idx = 0; + // at each nested loop, index range from dst_offset to dst_offset + range + + for (int c_idx = 0; c_idx < 8; c_idx++) { + for (int h_idx = 0; h_idx < 2; h_idx++) { + for (int w_idx = 0; w_idx < 4; w_idx++) { + auto dst_idx = + get_buf_idx(graph, out, {n_idx + 1, c_idx + 8, h_idx + 2, w_idx}); + auto src_idx = + get_buf_idx(graph, a, {n_idx, c_idx + 4, h_idx + 1, w_idx + 1}); + + EXPECT_TRUE(data_out[dst_idx] == data_a[src_idx]); + } + } + } +} + +TEST(VulkanComputeGraphTest, test_etvk_copy_channel_offset_node) { + GraphConfig config; + ComputeGraph graph(config); + + int64_t n = 2; + int64_t c = 12; + int64_t h = 4; + int64_t w = 8; + api::GPUMemoryLayout memory_layout = + api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED; + + std::vector size = {n, c, h, w}; + + IOValueRef a = graph.add_input_tensor(size, api::kFloat, memory_layout); + + IOValueRef out = {}; + out.value = graph.add_tensor(size, api::kFloat, memory_layout); + + int64_t src_offset = 2; + int64_t dst_offset = 3; + int64_t range = 7; + + ValueRef src_offset_ref = graph.add_scalar(src_offset); + ValueRef dst_offset_ref = graph.add_scalar(dst_offset); + ValueRef range_ref = graph.add_scalar(range); + + auto copyFn = VK_GET_OP_FN("etvk.copy_channel_offset"); + copyFn( + graph, {a.value, range_ref, src_offset_ref, dst_offset_ref, out.value}); + + out.staging = graph.set_output_tensor(out.value); + + graph.prepare(); + graph.encode_execute(); + + fill_vtensor(graph, a, 0.0f, true); + + graph.execute(); + + EXTRACT_TENSOR(out); + EXTRACT_TENSOR(a); + + for (int n_idx = 0; n_idx < n; n_idx++) { + for (int c_idx = 0; c_idx < range; c_idx++) { + for (int h_idx = 0; h_idx < h; h_idx++) { + for (int w_idx = 0; w_idx < w; w_idx++) { + auto src_idx = + get_buf_idx(graph, a, {n_idx, c_idx + src_offset, h_idx, w_idx}); + auto dst_idx = get_buf_idx( + graph, out, {n_idx, c_idx + dst_offset, h_idx, w_idx}); + EXPECT_TRUE(data_out[dst_idx] == data_a[src_idx]); + } + } + } + } +} + +TEST( + VulkanComputeGraphTest, + test_etvk_copy_channel_offset_node_clean_boundary) { + // Tricky part for channel copy is handling the boundary across multiple copy. + // For example, when we concat two [3, 1, 1] nchw-tensors along the channel + // dimension, due to channel packing, elements from different source texel + // will be packed into same destination texel at the boundaries. + GraphConfig config; + ComputeGraph graph(config); + + int64_t n = 2; + int64_t c = 12; + int64_t h = 4; + int64_t w = 8; + api::GPUMemoryLayout memory_layout = + api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED; + + std::vector size = {n, c, h, w}; + + IOValueRef zero = graph.add_input_tensor(size, api::kFloat, memory_layout); + IOValueRef a = graph.add_input_tensor(size, api::kFloat, memory_layout); + IOValueRef b = graph.add_input_tensor(size, api::kFloat, memory_layout); + + IOValueRef out = {}; + out.value = graph.add_tensor(size, api::kFloat, memory_layout); + + auto copyFn = VK_GET_OP_FN("etvk.copy_channel_offset"); + + // Make sure entire out tensor is zeroed. The zero tensor will be filled with + // zero later. 
+ copyFn( + graph, + {zero.value, + graph.add_scalar(c), + graph.add_scalar(0), + graph.add_scalar(0), + out.value}); + + int64_t a_src_offset = 0; + int64_t a_dst_offset = 2; + int64_t a_range = 5; + // a will write to channge [2, 7) + copyFn( + graph, + {a.value, + graph.add_scalar(a_range), + graph.add_scalar(a_src_offset), + graph.add_scalar(a_dst_offset), + out.value}); + + // b will write to channel [6, 11) + // Intentional for b to override channel=6 + int64_t b_src_offset = 0; + int64_t b_dst_offset = 6; + int64_t b_range = 5; + + copyFn( + graph, + {b.value, + graph.add_scalar(b_range), + graph.add_scalar(b_src_offset), + graph.add_scalar(b_dst_offset), + out.value}); + + out.staging = graph.set_output_tensor(out.value); + + graph.prepare(); + graph.encode_execute(); + + float a_value = 1.0f; + float b_value = 2.0f; + float zero_value = 0.0f; + fill_vtensor(graph, a, a_value); + fill_vtensor(graph, b, b_value); + fill_vtensor(graph, zero, zero_value); + + graph.execute(); + + EXTRACT_TENSOR(out); + + for (int n_idx = 0; n_idx < n; n_idx++) { + // c_idx only up to a_range-1 because the expected overwrite by b + for (int c_idx = a_dst_offset; c_idx < a_dst_offset + a_range - 1; + c_idx++) { + for (int h_idx = 0; h_idx < h; h_idx++) { + for (int w_idx = 0; w_idx < w; w_idx++) { + auto dst_idx = get_buf_idx(graph, out, {n_idx, c_idx, h_idx, w_idx}); + EXPECT_TRUE(data_out[dst_idx] == a_value); + } + } + } + } + + for (int n_idx = 0; n_idx < n; n_idx++) { + for (int c_idx = b_dst_offset; c_idx < b_dst_offset + b_range; c_idx++) { + for (int h_idx = 0; h_idx < h; h_idx++) { + for (int w_idx = 0; w_idx < w; w_idx++) { + auto dst_idx = get_buf_idx(graph, out, {n_idx, c_idx, h_idx, w_idx}); + EXPECT_TRUE(data_out[dst_idx] == b_value); + } + } + } + } + + // Also verify that data before a_dst_offset and after b_dst_offset + b_range + // are untouched. 
+ for (int n_idx = 0; n_idx < n; n_idx++) { + for (int c_idx = 0; c_idx < a_dst_offset; c_idx++) { + for (int h_idx = 0; h_idx < h; h_idx++) { + for (int w_idx = 0; w_idx < w; w_idx++) { + auto dst_idx = get_buf_idx(graph, out, {n_idx, c_idx, h_idx, w_idx}); + EXPECT_TRUE(data_out[dst_idx] == zero_value); + } + } + } + } + + for (int n_idx = 0; n_idx < n; n_idx++) { + for (int c_idx = b_dst_offset + b_range; c_idx < c; c_idx++) { + for (int h_idx = 0; h_idx < h; h_idx++) { + for (int w_idx = 0; w_idx < w; w_idx++) { + auto dst_idx = get_buf_idx(graph, out, {n_idx, c_idx, h_idx, w_idx}); + EXPECT_TRUE(data_out[dst_idx] == zero_value); + } + } + } + } +} + class VulkanToFromGPUShaderTest : public ::testing::Test { public: void SetUp() override { @@ -1365,7 +1633,6 @@ void test_conv2d( staging_buffer_in.buffer(), vten, original_sizes, - padded_sizes, transposed); record_image_to_nchw_op(api::context(), vten, staging_buffer_out.buffer()); diff --git a/backends/xnnpack/test/ops/add.py b/backends/xnnpack/test/ops/add.py index 8b0d0c6234..8d75729f72 100644 --- a/backends/xnnpack/test/ops/add.py +++ b/backends/xnnpack/test/ops/add.py @@ -95,8 +95,7 @@ def test_qs8_add_constant(self): .check_not(["executorch_exir_dialects_edge__ops_aten_add_Tensor"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_compare_outputs() ) def test_qs8_add(self): diff --git a/backends/xnnpack/test/tester/tester.py b/backends/xnnpack/test/tester/tester.py index 8812d5e501..f4891a0018 100644 --- a/backends/xnnpack/test/tester/tester.py +++ b/backends/xnnpack/test/tester/tester.py @@ -189,7 +189,7 @@ def run( ) -> None: self.exported_program = export( artifact, inputs, dynamic_shapes=self.dynamic_shapes - ) + ).run_decompositions() @property def artifact(self) -> ExportedProgram: diff --git a/docs/source/debug-backend-delegate.md b/docs/source/debug-backend-delegate.md index ebcf94136c..17e4afe82a 100644 --- a/docs/source/debug-backend-delegate.md +++ b/docs/source/debug-backend-delegate.md @@ -39,12 +39,12 @@ Number of non-delegated nodes: 430 From the table, the operator `aten_view_copy_default` appears 170 times in delegate graphs and 48 times in non-delegated graphs. Users can use information like this to debug. ## Visualize delegated graph -To see a more detailed view, use the `print_delegated_graph()` method to display a printout of the whole graph: +To see a more detailed view, use the `format_delegated_graph()` method to get a str of printout of the whole graph or use `print_delegated_graph()` to print directly: ```python -from executorch.exir.backend.utils import print_delegated_graph +from executorch.exir.backend.utils import format_delegated_graph graph_module = edge_manager.exported_program().graph_module -print(print_delegated_graph(graph_module)) +print(format_delegated_graph(graph_module)) # or call print_delegated_graph(graph_module) ``` It will print the whole model as well as the subgraph consumed by the backend. The generic debug function provided by fx like `print_tabular()` or `print_readable()` will only show `call_delegate` but hide the the subgraph consumes by the backend, while this function exposes the contents inside the subgraph. 
diff --git a/docs/source/llm/getting-started.md b/docs/source/llm/getting-started.md index ae743e8e6d..bdbdbddab4 100644 --- a/docs/source/llm/getting-started.md +++ b/docs/source/llm/getting-started.md @@ -721,12 +721,12 @@ Number of non-delegated nodes: 430 | 26 | Total | 473 | 430 | From the table, the operator `aten_view_copy_default` appears 170 times in delegate graphs and 48 times in non-delegated graphs. -To see a more detailed view, use the `print_delegated_graph()` method to display a printout of the whole graph. +To see a more detailed view, use the `format_delegated_graph()` method to get a formatted str of printout of the whole graph or use `print_delegated_graph()` to print directly: ```python -from executorch.exir.backend.utils import print_delegated_graph +from executorch.exir.backend.utils import format_delegated_graph graph_module = edge_manager.exported_program().graph_module -print(print_delegated_graph(graph_module)) +print(format_delegated_graph(graph_module)) ``` This may generate a large amount of output for large models. Consider using "Control+F" or "Command+F" to locate the operator you’re interested in (e.g. “aten_view_copy_default”). Observe which instances are not under lowered graphs. diff --git a/docs/source/tutorials_source/sdk-integration-tutorial.py b/docs/source/tutorials_source/sdk-integration-tutorial.py index 27474c2251..cd45f806fb 100644 --- a/docs/source/tutorials_source/sdk-integration-tutorial.py +++ b/docs/source/tutorials_source/sdk-integration-tutorial.py @@ -172,10 +172,24 @@ def forward(self, x): # Use CMake (follow `these instructions <../runtime-build-and-cross-compilation.html#configure-the-cmake-build>`__ to set up cmake) to execute the Bundled Program to generate the ``ETDump``:: # # cd executorch -# rm -rf cmake-out && mkdir cmake-out && cd cmake-out && cmake -DEXECUTORCH_BUILD_SDK=1 -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=1 .. -# cd .. -# cmake --build cmake-out -j8 -t sdk_example_runner -# ./cmake-out/examples/sdk/sdk_example_runner --bundled_program_path +# rm -rf cmake-out +# cmake -DCMAKE_INSTALL_PREFIX=cmake-out \ +# -DCMAKE_BUILD_TYPE=Release \ +# -DEXECUTORCH_BUILD_SDK=ON \ +# -DEXECUTORCH_ENABLE_EVENT_TRACER=ON \ +# -Bcmake-out . +# cmake --build cmake-out -j9 --target install --config Release +# +# local example_dir=examples/sdk +# local build_dir=cmake-out/${example_dir} +# CMAKE_PREFIX_PATH="${PWD}/cmake-out/lib/cmake/ExecuTorch;${PWD}/cmake-out/third-party/gflags" +# rm -rf ${build_dir} +# cmake -DCMAKE_PREFIX_PATH="$CMAKE_PREFIX_PATH" \ +# -DCMAKE_BUILD_TYPE=Release \ +# -B${build_dir} \ +# ${example_dir} +# cmake --build ${build_dir} -j9 --config Release +# ${build_dir}/sdk_example_runner --bundled_program_path="bundled_program.bp" ###################################################################### # Creating an Inspector diff --git a/examples/models/llama2/README.md b/examples/models/llama2/README.md index f3c3951b4c..19f386eb31 100644 --- a/examples/models/llama2/README.md +++ b/examples/models/llama2/README.md @@ -37,7 +37,7 @@ Note that groupsize less than 128 was not enabled, since such model were still t We have verified running Llama 2 7B [mobile applications](#step-6-build-mobile-apps) efficiently on select devices including the iPhone 15 Pro, iPhone 15 Pro Max, Samsung Galaxy S22 and S24, and OnePlus 12. -For Llama 3 8B, we have verified so far on iPhone 15 Pro Max and OnePlus 12 (with 16GB RAM). +For Llama 3 8B, we have verified so far on iPhone 15 Pro Max, Samsung Galaxy S24+ and OnePlus 12 (with 16GB RAM). 
## Performance diff --git a/examples/models/llama2/builder.py b/examples/models/llama2/builder.py index b05dc19bfc..4e2ec1922b 100644 --- a/examples/models/llama2/builder.py +++ b/examples/models/llama2/builder.py @@ -21,7 +21,7 @@ from executorch.exir import EdgeProgramManager from executorch.exir.backend.partitioner import Partitioner -from executorch.exir.backend.utils import print_delegated_graph +from executorch.exir.backend.utils import format_delegated_graph from executorch.exir.capture._config import EdgeCompileConfig, ExecutorchBackendConfig from executorch.exir.passes import MemoryPlanningPass @@ -283,7 +283,7 @@ def export_to_edge( dynamic_shapes=dynamic_shape, edge_constant_methods=metadata, edge_compile_config=edge_config, - verbose=True, + verbose=self.verbose, ) return self @@ -308,7 +308,7 @@ def to_backend( self.edge_manager = self.edge_manager.to_backend(partitioner) if self.verbose: logging.info( - print_delegated_graph( + format_delegated_graph( self.edge_manager.exported_program().graph_module ) ) diff --git a/exir/backend/test/test_utils.py b/exir/backend/test/test_utils.py index c27db23626..2f24b734b7 100644 --- a/exir/backend/test/test_utils.py +++ b/exir/backend/test/test_utils.py @@ -16,11 +16,11 @@ from executorch.exir.backend.test.op_partitioner_demo import AddMulPartitionerDemo from executorch.exir.backend.utils import ( DelegationBreakdown, + format_delegated_graph, get_delegates, get_delegation_info, get_non_lowered_nodes, is_identical_graph, - print_delegated_graph, ) from executorch.exir.dialects._ops import bind_pattern_to_op, ops as exir_ops @@ -266,7 +266,7 @@ def forward(self, a, x, b): edge = to_edge(export(m, inputs)).to_backend(AddMulPartitionerDemo()) - graph_str = print_delegated_graph(edge.exported_program().graph_module) + graph_str = format_delegated_graph(edge.exported_program().graph_module) self.assertIn( "BackendWithCompilerDemo", graph_str, diff --git a/exir/backend/utils.py b/exir/backend/utils.py index b299ba4be8..6b85e4b603 100644 --- a/exir/backend/utils.py +++ b/exir/backend/utils.py @@ -448,9 +448,16 @@ def _insert_op_occurrences_dict(node_name: str, delegated: bool) -> None: ) -def print_delegated_graph(graph_module: torch.fx.GraphModule) -> str: +def print_delegated_graph(graph_module: torch.fx.GraphModule) -> None: """ - Print the graph of including lowered_module (both backend id and original graph) together with the graph module. Example output: + Print the formatted graph string. + """ + print(format_delegated_graph(graph_module)) + + +def format_delegated_graph(graph_module: torch.fx.GraphModule) -> str: + """ + Return the formatted graph string of including lowered_module (both backend id and original graph) together with the graph module. Example output: graph(): %arg0_1 : [num_users=2] = placeholder[target=arg0_1] %arg1_1 : [num_users=2] = placeholder[target=arg1_1] diff --git a/kernels/README.md b/kernels/README.md new file mode 100644 index 0000000000..bf1d1e0f8e --- /dev/null +++ b/kernels/README.md @@ -0,0 +1,473 @@ +This subtree contains operator implementations that ExecuTorch clients can +use and contribute to. + +## Layout + +- `kernels`: Contains implementations and tests for the operators defined + in the YAML files. + - `kernels/portable/cpu`: Pure C++ implementations of the operators defined in the + YAML files. + - `kernels/optimized/cpu`: Optimized C++ implementations of the operators defined in the + YAML files, for specific hardware platforms. 
+ - `kernels/aten`: A thin wrapper layer to hookup ATen library into ExecuTorch. + - `kernels/test`: Tests for all operator implementations. Since all + implementations should behave identically, the same tests should pass for + all target types. + +## Help & Improvements + +If you have problems or questions, or have suggestions for ways to make +implementation and testing better, please contact [Dave +Bort](https://fb.workplace.com/profile.php?id=100042415022179), [Mengwei +Liu](https://fb.workplace.com/profile.php?id=100024007250862), or [Martin + Yuan](https://fb.workplace.com/profile.php?id=100020734910364) on the PyTorch +Edge team. + +## Contributing + +Please follow these steps and guidelines when adding a new operator +implementation to this library. The goals of these guidelines are to: +- Make it straightforward to add new operator implementations. +- Ensure that the operator implementations are of high quality, and are easy to + maintain. +- Make it easy for users to find available operator implementations, and to + trust in their quality and behavioral stability. + +### Your code must be compatible with ExecuTorch types + +ExecuTorch does not use `at::Tensor`, `at::ScalarType`, `c10::Scalar`, or any of +the types defined by PyTorch core in the `at` or `c10` namespaces. To retain +tigher control over CPU and memory runtime behavior, ExecuTorch reimplements +compatible but restricted subsets of those types. + +[`//runtime/core/exec_aten/exec_aten.h`](https://github.com/pytorch/executorch/blob/main/runtime/core/exec_aten/exec_aten.h) +contains the mapping between ATen/c10 types and the ExecuTorch types. The +ExecuTorch types are defined in other headers in that same directory, +[`//runtime/core/portable_type/`](https://github.com/pytorch/executorch/tree/main/runtime/core/portable_type). + +The ExecuTorch types are source-compatible with the ATen/c10 types; if you write +code that works with the ExecuTorch types, then that same code should work when +built against ATen/c10. But, there are features of `at::Tensor` and other +ATen/c10 types that may not be present. In many cases this is intentional, but +in other cases we can consider adding the missing features. + +### Do your initial work in fbcode (skip this if in OSS) + +Although ExecuTorch is mapped into both `xplat` and `fbcode`, we recommend +setting up the initial targets while working from `fbcode`. Once everything's in +place, you should be able to build from either spot. + +The most important thing is to consistently work out of one root or the other. +And, if you're getting weird build failures, `hg commit` your edited files +locally to make sure that both `xplat` and `fbcode` are in sync with each other. + +### Declare the operator in a YAML file + +We use yaml files to declare the ATen operators or custom operators being implemented by this kernel library. + +Before implementing, the operator must be declared in exactly one of the +operator YAML files: +- [`//kernels/portable/functions.yaml`](https://github.com/pytorch/executorch/blob/main/kernels/portable/functions.yaml) + - Add your entry here if your operator overload (e.g., `op: add.out`) + appears in the core pytorch file + [`pytorch/aten/src/ATen/native/native_functions.yaml`](https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml). + - Also add your entry to [`//kernels/aten/functions.yaml`](https://github.com/pytorch/executorch/blob/main/kernels/aten/functions.yaml) for test coverage. 
+- [`//kernels/portable/custom_ops.yaml`](https://github.com/pytorch/executorch/blob/main/kernels/portable/custom_ops.yaml)
+  - Add your entry here if your operator overload does *not* appear in the core PyTorch `native_functions.yaml`.
+
+The next sections describe how to add a YAML entry.
+
+#### YAML Schema
+
+This YAML file schema is a DSL to describe the operators and the kernels that implement them. The YAML file is a contract between AOT model export and runtime execution: if followed correctly, it ensures that the ExecuTorch runtime can link the C++ implementation of an operator to the exported model artifact. Here are some rules for writing your own YAML files.
+
+**Out variants only**
+
+ExecuTorch only supports out-style operators, where:
+- The caller provides the output Tensor or Tensor list in the final position
+  with the name `out`.
+- The C++ function modifies and returns the same `out` argument.
+  - If the return type in the YAML file is `()` (which maps to void), the C++
+    function should still modify `out` but does not need to return anything.
+- The `out` argument must be keyword-only, which means it needs to follow an
+  argument named `*` like in the `add.out` example below.
+- Conventionally, these out operators are named using the pattern `<name>.out`
+  or `<name>.<overload>_out`.
+
+Since all output values are returned via an `out` parameter, ExecuTorch ignores
+the actual C++ function return value. But, to be consistent, functions should
+always return `out` when the return type is non-`void`.
+
+**Can only return `Tensor` or `()`**
+
+ExecuTorch only supports operators that return a single `Tensor`, or the unit
+type `()` (which maps to `void`). It does not support returning any other types,
+including lists, optionals, tuples, or scalars like `bool`.
+
+**Supported argument types**
+
+ExecuTorch does not support all of the argument types that core PyTorch
+supports. See [this
+spreadsheet](https://docs.google.com/spreadsheets/d/1uArc0r1Yq1QSeyRJZKzZ8Wkz0eS9TsM39ghmMAZCXDA/edit#gid=0)
+for the list of supported and unsupported types.
+
+**Functions only, no methods**
+
+ExecuTorch does not support Tensor methods, and assumes `variants: function` for
+all operators. Entries like `variants: method` or `variants: function, method`
+will be ignored.
+
+#### Add your operator entry
+
+Some examples of operator entries:
+
+ATen operator with a default kernel
+```
+- op: add.out
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::add_out
+```
+
+ATen operator with a dtype/dim order specialized kernel (works for the `Double` dtype, and the dim order needs to be (0, 1, 2, 3))
+```
+- op: add.out
+  type_alias:
+    T0: [Double]
+  dim_order_alias:
+    D0: [[0, 1, 2, 3]]
+  kernels:
+    - arg_meta:
+        self: [T0, D0]
+        other: [T0, D0]
+        out: [T0, D0]
+      kernel_name: torch::executor::add_out
+```
+
+Custom operator with a default kernel
+```
+- func: allclose.out(Tensor self, Tensor other, float rtol=1e-05, float atol=1e-08, bool equal_nan=False, bool dummy_param=False, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::allclose_out
+```
+
+Top level attributes:
+* `op` (if the operator appears in `native_functions.yaml`) or `func` for a custom operator. The value for this key needs to be the full operator name (including overload name) for the `op` key, or a full operator schema (namespace, operator name, operator overload name and schema string).
For schema syntax please refer to this [instruction](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md).
+
+* `kernels`: this entry is used to define the information of kernels. It consists of `arg_meta` and `kernel_name`, which are bound together to describe "for input tensors with these metadata, use this kernel".
+* `type_alias` (optional): we are giving aliases to possible dtype options. `T0: [Double, Float]` means `T0` can be one of `Double` or `Float`.
+* `dim_order_alias` (optional): similar to `type_alias`, we are giving names to possible dim order options.
+
+Attributes under `kernels`:
+* `arg_meta`: a list of "tensor arg name" entries. The values for these keys are dtype and dim order aliases that are implemented by the corresponding `kernel_name`. This being `null` means the kernel will be used for all types of input.
+* `kernel_name`: the expected name of the
+C++ function that will implement this operator. You can put whatever you want
+here, but you should follow the convention of replacing the `.` in the overload
+name with an underscore, and lowercasing all characters. In this example,
+`add.out` uses the C++ function named `add_out`. `add.Scalar_out` would become `add_scalar_out`, with a lowercase `s`. We support namespaces for kernels, but note that we insert `native::` as the last level of the namespace. So `custom::add_out` in the `kernel_name` will point to `custom::native::add_out`.
+
+### Find operator base name
+
+The base name is the part of the operator name before the `.`, excluding any
+trailing underscores. The rest of this document refers to this as `<name>`.
+
+E.g., these operator overloads all have a base name of `add`:
+- `add.Scalar`
+- `add.Tensor`
+- `add.out`
+- `add_.Tensor`
+
+So, if you were implementing `add.out` then your operator base name would be
+`add`, and you would replace `<name>` with `add` everywhere below.
+
+### Selective build
+
+When using macros that require a `NAME` argument, e.g. `#define ET_SWITCH_REAL_TYPES_AND(ADDITIONAL, TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...)`, make sure to pass in the same operator name defined in `functions.yaml`. This is the base name + variant, e.g. `add.out`, `add.Scalar_out`. This operator name is required for dtype selective build, which matches against the operator names and dtypes present in a model. A minimal sketch of this usage appears after the file overview below.
+
+### Overview of files and targets
+
+For the operator base name `<name>`, you should work with these files and Buck
+targets. Sections below give more details about what they should contain.
+
+- `./kernels/portable/cpu/op_<name>.cpp`: The implementations of operator overloads
+  with base name `<name>`. This is the file that clients will link into their
+  runtimes.
+- `//executorch/kernels/portable/cpu:op_<name>`: The build target for
+  `op_<name>.cpp`, defined in `targets.bzl` in the same directory.
+- `./kernels/test/op_<name>_test.cpp`: Unit tests for the operator overloads
+  with base name `<name>`.
+  - Note that tests under this directory are specific to the portable kernel. To
+    share tests between multiple kernels, we can put tests in ../test.
+  - Note that the tests do not live under `cpu`; tests should be
+    implementation-agnostic. This will let us run the same tests against all
+    implementations of a given operator, which should behave identically.
+- `//executorch/kernels/portable/test:op_<name>_test`: The test target for
+  `op_<name>_test.cpp`, defined in `targets.bzl` in the same directory.
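+
+To illustrate the selective-build note above: the string passed as `NAME` to the
+`ET_SWITCH_*` macros is the operator name from `functions.yaml`, not the C++
+function name. The following is only a rough sketch (the real `add_out`
+signature and argument checks are simplified away), assuming a
+`RuntimeContext& ctx` parameter is available:
+
+```
+// Hypothetical excerpt from op_add.cpp -- a sketch, not the actual kernel.
+Tensor& add_out(
+    RuntimeContext& ctx,
+    const Tensor& self,
+    const Tensor& other,
+    Tensor& out) {
+  ScalarType in_type = self.scalar_type();
+  // "add.out" must match the operator name declared in functions.yaml so that
+  // dtype selective build can keep only the dtypes a model actually uses.
+  ET_SWITCH_REAL_TYPES_AND(Bool, in_type, ctx, "add.out", CTYPE, [&]() {
+    // ... elementwise computation over CTYPE ...
+  });
+  return out;
+}
+```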
+
+For an example, see the `add` operator (note that these are slightly different
+from the `add` examples in this doc):
+- [`//executorch/kernels/portable/cpu/op_add.cpp`](https://www.internalfb.com/code/fbsource/fbcode/executorch/kernels/portable/cpu/op_add.cpp):
+  Implementations.
+- [`//executorch/kernels/portable/cpu/targets.bzl`](https://www.internalfb.com/code/fbsource/fbcode/executorch/kernels/portable/cpu/targets.bzl):
+  Definition of the `:op_add` target.
+- [`//executorch/kernels/portable/test/op_add_test.cpp`](https://www.internalfb.com/code/fbsource/fbcode/executorch/kernels/portable/test/op_add_test.cpp):
+  Unit tests.
+- [`//executorch/kernels/portable/test/targets.bzl`](https://www.internalfb.com/code/fbsource/fbcode/executorch/kernels/portable/test/targets.bzl):
+  Definition of the `:op_add_test` target.
+
+### Define the build target for the operator implementation
+
+Define a build target by adding an entry to
+`//executorch/kernels/portable/cpu/targets.bzl`, inside
+`define_common_targets()`, in sorted order with other `_op_target` entries:
+```
+_op_target(name = "op_<name>")
+```
+
+If your operator overload group is ATen-compatible, its `_op_target` entry
+belongs in the `_ATEN_OPS` list, otherwise it belongs in the `_CUSTOM_OPS` list.
+Note that this means that a given `op_<name>` target cannot implement both
+ATen-compatible and non-ATen-compatible (i.e., custom) operators. We suggest
+adding the suffix `_custom` if necessary: e.g., `op_add` for ATen-compatible
+overloads of the `add` operator, and `op_add_custom` for non-ATen-compatible
+overloads.
+
+By default, this target will depend on the core ExecuTorch types, but you can
+add additional deps if you want to.
+
+NOTE: An `op_<name>` target may not depend on another `op_<name>` target. If two
+`op_<name>` targets need to share code, define a separate `runtime.cxx_library`
+target under `//executorch/kernels/portable/cpu/lib` that they both depend on.
+This keeps the dependencies more manageable, especially for selective builds
+where only a subset of operators are used.
+
+NOTE: An `op_<name>` target may not depend on targets outside of `//executorch`.
+This library is intended to be portable, open-sourceable, and self-contained.
+
+### Create a skeleton .cpp file for the operator implementation
+
+If not already present, create the file
+`//executorch/kernels/portable/cpu/op_<name>.cpp`, which should follow the
+pattern:
+```
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+#include <executorch/runtime/kernel/kernel_includes.h>
+
+namespace torch {
+namespace executor {
+namespace native {
+
+namespace {
+  // <helper code>
+} // namespace
+
+// <operator implementations>
+
+} // namespace native
+} // namespace executor
+} // namespace torch
+```
+
+With the target and cpp file in place, you should be able to build it:
+```
+cd ${HOME}/fbsource/fbcode/executorch
+buck build fbcode//executorch/kernels/portable/cpu:op_<name>
+```
+
+### Find the function signature for the operator overload
+
+When you add an entry to the YAML file, the codegen tools will generate an
+expected function signature for you to implement in a file called
+`NativeFunctions.h`.
+
+To build and find that generated header, run the script
+`fbsource/fbcode/executorch/kernels/portable/find_op_header.sh`. It will print
+output like:
+```
+===== Generating header files =====
+File changed: fbcode//executorch/kernels/portable/functions.yaml
+Buck UI: https://www.internalfb.com/buck2/e5a6f22a-5b6e-4931-9a7f-df18bdf97ab6
+RE Session: reSessionID-4b735cfa-e66f-43d8-a73b-94f22d5936c5
+Jobs completed: 3. Time elapsed: 0.2s. Cache hits: 100%.
Commands: 1 (cached: 1, remote: 0, local: 0)
+BUILD SUCCEEDED
+
+Header file: /data/users/USER/fbsource/buck-out/v2/gen/fbcode/d839c731f5505c62/executorch/codegen/__generated_lib_generate__/out/NativeFunctions.h
+```
+The path will be different in your environment, so be sure to use the output
+from the script instead of copy-pasting this path. And, since this header is
+generated from the YAML files, re-run the script if you have modified your
+operator's entry in those files.
+
+Open that file and look for the function whose name matches the `kernel_name`
+you specified in the YAML file. For `add_out`, this might look like
+```
+TORCH_API torch::executor::Tensor & add_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out);
+```
+
+This is the function signature that you will need to implement.
+
+### Add a stub implementation
+
+Now that you have your function signature, add a stub to the `op_<name>.cpp`
+file that just returns the `out` argument. For example:
+```
+Tensor& add_out(
+    const Tensor& self,
+    const Tensor& other,
+    Tensor& out) {
+  return out;
+}
+```
+
+Note that you should drop the `TORCH_API` attribute, and should drop the `at::`
+prefixes.
+
+Try building again with
+```
+cd ${HOME}/fbsource/fbcode/executorch
+buck build fbcode//executorch/kernels/portable/cpu:op_<name>
+```
+
+### Create a test build target
+
+Define a test build target by adding an entry to
+`//executorch/kernels/portable/test/targets.bzl`, inside
+`define_common_targets()`, in sorted order with other `_op_test` entries:
+```
+_op_target(name = "op_<name>_test")
+```
+
+By default, this target will depend on
+`//executorch/kernels/portable/cpu:op_<name>`, the core ExecuTorch types, and
+some helper test utilities ([see
+headers](https://www.internalfb.com/code/fbsource/fbcode/executorch/runtime/core/exec_aten/testing_util/)),
+but you can add additional deps if you want to.
+
+### Create a skeleton test .cpp file
+
+If not already present, create the file
+`//executorch/kernels/portable/test/op_<name>_test.cpp`. Here's a suggested
+starting point:
+```
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+
+#include  // Declares the operator
+#include <executorch/runtime/core/exec_aten/exec_aten.h>
+#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
+#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
+
+#include <gtest/gtest.h>
+
+using namespace ::testing;
+using exec_aten::ScalarType;
+using exec_aten::Tensor;
+using torch::executor::native::<name>_out;
+using torch::executor::testing::IsCloseTo;
+using torch::executor::testing::TensorFactory;
+
+TEST(OpTest, SmokeTest) {
+  TensorFactory<ScalarType::Int> tf;
+
+  Tensor a = tf.make(/*sizes=*/{2, 2}, /*data=*/{1, 1, 1, 1});
+  Tensor b = tf.ones(/*sizes=*/{2, 2});
+  Tensor z = tf.zeros(/*sizes=*/{2, 2});
+
+  EXPECT_EQ(a, b); // Exact equality
+  EXPECT_THAT(a, IsCloseTo(b)); // For floating-point tensors
+
+  EXPECT_NE(a, z);
+  EXPECT_THAT(a, Not(IsCloseTo(z)));
+}
+```
+
+Try running the test:
+```
+cd ${HOME}/fbsource/fbcode/executorch
+buck test fbcode//executorch/kernels/portable/test:op_<name>_test
+```
+
+### Implement and test the operator
+
+You should now be able to implement and test your operator.
It's helpful to see
+how other operators do it, so take a look at `op_add`:
+- [`//executorch/kernels/portable/cpu/op_add.cpp`](https://www.internalfb.com/code/fbsource/fbcode/executorch/kernels/portable/cpu/op_add.cpp)
+- [`//executorch/kernels/portable/test/op_add_test.cpp`](https://www.internalfb.com/code/fbsource/fbcode/executorch/kernels/portable/test/op_add_test.cpp)
+
+Check out how it uses helper macros like `ET_CHECK_SAME_SHAPE_AND_DTYPE` and
+`ET_FORALL_REAL_TYPES` when implementing the operator, and test helpers like
+`TensorFactory` and `IsCloseTo()` when testing.
+
+#### Implementation restrictions
+
+To reduce dependencies and size, to ensure portability, and to conform to the
+restrictions of embedded environments, your operator implementations:
+
+- Must not include C++ stdlib headers, or use C++ stdlib types. For example,
+  `string`/`basic_string`, `vector`, `unordered_map`, `cout`, `unique_ptr`
+  must not be used.
+- Must not dynamically allocate memory, or cause memory to be dynamically
+  allocated. All non-stack memory must be provided as a function parameter by
+  the caller, typically via an `out` parameter or another tensor parameter to be
+  used as scratch space.
+  - This includes direct calls to `new`, `malloc`, `realloc`, etc., as well as
+    operations that allocate under the hood like `make_unique`, or the creation
+    of `vector` or `string`, for example.
+- Must be stateless.
+- Must be thread-safe. Note that the ExecuTorch environment does not provide
+  a locking construct, so this means that operator implementations must not
+  modify global memory.
+- Must work in an environment without threads. This, along with the stateless
+  requirement, means that thread local storage must not be used.
+- Must not use `stdout`, `stderr`, or other file/stream IO via `printf`/`cout`
+  etc.; instead, use `ET_LOG` from `executorch/runtime/platform/log.h`.
+- Must not use `assert()`. Instead use `ET_CHECK` and other macros from
+  `executorch/runtime/platform/assert.h`.
+- Must not raise exceptions. Instead use `ET_CHECK` and other macros from
+  `executorch/runtime/platform/assert.h`.
+
+(A short sketch illustrating the logging and check macros in place of `printf`
+and `assert` appears just before the shared test details below.)
+
+Note that not all of these apply to *every* ExecuTorch-compatible operator
+implementation, only those included in this portable library.
+
+For example, a target-specific custom operator that initiates a DMA copy would
+be stateful, and would probably modify global memory, but it would need to use
+target-specific APIs to do so. But, since this library is only for portable
+operator implementations, the operators it contains can't depend on
+target-specific APIs like that.
+
+### Shared kernel tests (//executorch/kernels/test)
+The portable kernel implementation and its corresponding tests can be used as a
+reference for other kernels. We can also share the test cases in
+`//executorch/kernels/test`, which contains common resources for kernel testing.
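+
+Before moving on to the shared test infrastructure, here is the sketch promised
+in the implementation restrictions above. It is a hypothetical kernel fragment
+(not a real operator), assuming a `RuntimeContext& ctx` parameter and `in`/`out`
+tensors are in scope; it shows `ET_LOG` and `ET_KERNEL_CHECK` (the same macro
+used by the portable kernels touched in this change) standing in for `printf`
+and `assert`:
+
+```
+// Hypothetical fragment -- for illustration only.
+// Instead of: printf("bad sizes\n"); assert(in.numel() == out.numel());
+ET_LOG(Info, "running kernel on %zu elements", (size_t)in.numel());
+// On failure this logs, marks the kernel as failed via ctx, and returns `out`
+// from the enclosing kernel instead of raising an exception or aborting.
+ET_KERNEL_CHECK(ctx, in.numel() == out.numel(), InvalidArgument, out);
+```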
+
+*util.bzl* contains common BUCK targets for other test libs to include:
+- op_test(): Defines a cxx_test() for an "op_*_test.cpp" file
+- define_supported_features_lib(): Defines the corresponding supported features library
+
+*targets.bzl* has targets shared by other kernels tests:
+- supported_features_header: header file for SupportedFeatures
+- supported_features_header_aten: ATen implementation of SupportedFeatures
+- _codegen_function_header_wrapper: a wrapper to include the right Functions.h header
+- _common_op_test: generates _
+
+*_codegen_function_header_wrapper* generates a header FunctionHeaderWrapper.h, which simply
+includes the corresponding Functions.h file for the specified kernel:
+`#include `. With that, the test sources don't need to know
+which kernel we are testing or which Functions.h we should use.
+
+With *_common_op_test* we use a single test source file (op__test.cpp) in this directory.
+We automatically find the corresponding registered dispatch function through Functions.h, so
+it can be used to test multiple kernels.
+
+In /test/ we can put kernel-specific test cases.
+
+*SupportedFeatures* is used to distinguish between different kernel features. For example,
+ATen supports mixing input and output dtypes while portable doesn't. When we expect death in
+portable testing in such cases, we can check the supported features of the running kernel and
+bypass the death test if the feature is supported.
+- The default value of supported features is in test/supported_features.yaml
+- Each kernel needs to override its supported features in /test/supported_features_def.yaml.
+  See example in supported_features_def_example.yaml.
+- This ensures that all kernels can share the same C++ test case source.
diff --git a/kernels/portable/README.md b/kernels/portable/README.md
index bf1d1e0f8e..955738ecab 100644
--- a/kernels/portable/README.md
+++ b/kernels/portable/README.md
@@ -1,473 +1 @@
-This subtree contains operator implementations that ExecuTorch clients can
-use and contribute to.
-
-## Layout
-
-- `kernels`: Contains implementations and tests for the operators defined
-  in the YAML files.
-  - `kernels/portable/cpu`: Pure C++ implementations of the operators defined in the
-    YAML files.
-  - `kernels/optimized/cpu`: Optimized C++ implementations of the operators defined in the
-    YAML files, for specific hardware platforms.
-  - `kernels/aten`: A thin wrapper layer to hookup ATen library into ExecuTorch.
-  - `kernels/test`: Tests for all operator implementations. Since all
-    implementations should behave identically, the same tests should pass for
-    all target types.
-
-## Help & Improvements
-
-If you have problems or questions, or have suggestions for ways to make
-implementation and testing better, please contact [Dave
-Bort](https://fb.workplace.com/profile.php?id=100042415022179), [Mengwei
-Liu](https://fb.workplace.com/profile.php?id=100024007250862), or [Martin
- Yuan](https://fb.workplace.com/profile.php?id=100020734910364) on the PyTorch
-Edge team.
-
-## Contributing
-
-Please follow these steps and guidelines when adding a new operator
-implementation to this library. The goals of these guidelines are to:
-- Make it straightforward to add new operator implementations.
-- Ensure that the operator implementations are of high quality, and are easy to
-  maintain.
-- Make it easy for users to find available operator implementations, and to
-  trust in their quality and behavioral stability.
- -### Your code must be compatible with ExecuTorch types - -ExecuTorch does not use `at::Tensor`, `at::ScalarType`, `c10::Scalar`, or any of -the types defined by PyTorch core in the `at` or `c10` namespaces. To retain -tigher control over CPU and memory runtime behavior, ExecuTorch reimplements -compatible but restricted subsets of those types. - -[`//runtime/core/exec_aten/exec_aten.h`](https://github.com/pytorch/executorch/blob/main/runtime/core/exec_aten/exec_aten.h) -contains the mapping between ATen/c10 types and the ExecuTorch types. The -ExecuTorch types are defined in other headers in that same directory, -[`//runtime/core/portable_type/`](https://github.com/pytorch/executorch/tree/main/runtime/core/portable_type). - -The ExecuTorch types are source-compatible with the ATen/c10 types; if you write -code that works with the ExecuTorch types, then that same code should work when -built against ATen/c10. But, there are features of `at::Tensor` and other -ATen/c10 types that may not be present. In many cases this is intentional, but -in other cases we can consider adding the missing features. - -### Do your initial work in fbcode (skip this if in OSS) - -Although ExecuTorch is mapped into both `xplat` and `fbcode`, we recommend -setting up the initial targets while working from `fbcode`. Once everything's in -place, you should be able to build from either spot. - -The most important thing is to consistently work out of one root or the other. -And, if you're getting weird build failures, `hg commit` your edited files -locally to make sure that both `xplat` and `fbcode` are in sync with each other. - -### Declare the operator in a YAML file - -We use yaml files to declare the ATen operators or custom operators being implemented by this kernel library. - -Before implementing, the operator must be declared in exactly one of the -operator YAML files: -- [`//kernels/portable/functions.yaml`](https://github.com/pytorch/executorch/blob/main/kernels/portable/functions.yaml) - - Add your entry here if your operator overload (e.g., `op: add.out`) - appears in the core pytorch file - [`pytorch/aten/src/ATen/native/native_functions.yaml`](https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml). - - Also add your entry to [`//kernels/aten/functions.yaml`](https://github.com/pytorch/executorch/blob/main/kernels/aten/functions.yaml) for test coverage. -- [`//kernels/portable/custom_ops.yaml`](https://github.com/pytorch/executorch/blob/main/kernels/portable/custom_ops.yaml) - - Add your entry here if your operator overload does *not* appear in the core pytorch `native_functions.yaml`. - -The next sections describe how to add a yaml entry. - -#### YAML Schema - -This YAML file schema is a DSL to decribe the operators and the kernels that implement them. This YAML file is a contract between AOT model export and runtime execution, that if followed correctly, can make sure ExecuTorch runtime be able to link the C++ implementation of an operator to the exported model artifact. Here are some rules of writing up your own YAML files. - -**Out variants only** - -ExecuTorch only supports out-style operators, where: -- The caller provides the output Tensor or Tensor list in the final position - with the name `out`. -- The C++ function modifies and returns the same `out` argument. - - If the return type in the YAML file is `()` (which maps to void), the C++ - function should still modify `out` but does not need to return anything. 
-- The `out` argument must be keyword-only, which means it needs to follow an - argument named `*` like in the `add.out` example below. -- Conventionally, these out operators are named using the pattern `.out` - or `._out`. - -Since all output values are returned via an `out` parameter, ExecuTorch ignores -the actual C++ function return value. But, to be consistent, functions should -always return `out` when the return type is non-`void`. - -**Can only return `Tensor` or `()`** - -ExecuTorch only supports operators that return a single `Tensor`, or the unit -type `()` (which maps to `void`). It does not support returning any other types, -including lists, optionals, tuples, or scalars like `bool`. - -**Supported argument types** - -ExecuTorch does not support all of the argument types that core PyTorch -supports. See [this -spreadsheet](https://docs.google.com/spreadsheets/d/1uArc0r1Yq1QSeyRJZKzZ8Wkz0eS9TsM39ghmMAZCXDA/edit#gid=0) -for the list of supported and unsupported types. - - -**Functions only, no methods** - -ExecuTorch does not support Tensor methods, and assumes `variants: function` for -all operators. Entries like `variants: method` or `variants: function, method` -will be ignored. - -#### Add your operator entry - -Some examples of operator entry: - -ATen operator with a default kernel -``` -- op: add.out - kernels: - - arg_meta: null - kernel_name: torch::executor::add_out -``` - -ATen operator with a dtype/dim order specialized kernel (works for `Double` dtype and dim order needs to be (0, 1, 2, 3)) -``` -- op: add.out - type_alias: - T0: [Double] - dim_order_alias: - D0: [[0, 1, 2, 3]] - kernels: - - arg_meta: - self: [T0, D0] - other: [T0 , D0] - out: [T0, D0] - kernel_name: torch::executor::add_out -``` - -Custom operator with a default kernel -``` -- func: allclose.out(Tensor self, Tensor other, float rtol=1e-05, float atol=1e-08, bool equal_nan=False, bool dummy_param=False, *, Tensor(a!) out) -> Tensor(a!) - kernels: - - arg_meta: null - kernel_name: torch::executor::allclose_out -``` - -Top level attributes: -* `op` (if the operator appears in `native_functions.yaml`) or `func` for custom operator. The value for this key needs to be the full operator name (including overload name) for `op` key, or a full operator schema (namespace, operator name, operator overload name and schema string). For schema syntax please refer to this [instruction](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md). - -* `kernels`: this entry is used to define the information of kernels. It consists of `arg_meta` and `kernel_name`, they are bound together to describe "for input tensors with these metadata, use this kernel". -* `type_alias`(optional): we are giving aliases to possible dtype options. `T0: [Double, Float]` means `T0` can be one of `Double` or `Float`. -* `dim_order_alias`(optional): similar to `type_alias`, we are giving names to possible dim order options. - -Attributes under `kernels`: -* `arg_meta`: a list of "tensor arg name" entries. The value for these keys are dtypes and dim orders alias, that are implemented by the corresponding `kernel_name`. This being `null` means the kernel will be used for all types of input. -* `kernel_name`: the expected name of the -C++ function that will implement this operator. You can put whatever you want to -here, but you should follow the convention of replacing the `.` in the overload -name with an underscore, and lowercasing all characters. In this example, -`add.out` uses the C++ function named `add_out`. 
`add.Scalar_out` would become `add_scalar_out`, with a lowercase `S`. We support namespace for kernels, but note that we will be inserting a `native::` to the last level of namespace. So `custom::add_out` in the `kernel_name` will point to `custom::native::add_out`. - -### Find operator base name - -The base name is the part of the operator name before the `.`, excluding any -trailing underscores. The rest of this document refer to this as ``. - -E.g., these operator overloads all have a base name of `add`: -- `add.Scalar` -- `add.Tensor` -- `add.out` -- `add_.Tensor` - -So, if you were implementing `add.out` then your operator base name would be -`add`, and you would replace `` with `add` everywhere below. - -### Selective build - -When using macros that require a `NAME` argument, eg. `#define ET_SWITCH_REAL_TYPES_AND(ADDITIONAL, TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...)`, make sure to pass in the same operator name defined in `functions.yaml`. This is the base name + variant, eg. `add.out`, `add.Scalar_out`. The function name is required for dtype selective build, which matches against the operator names and dtypes present in a model. - -### Overview of files and targets - -For the operator base name ``, you should work with these files and Buck -targets. Sections below give more details about what they should contain. - -- `./kernels/portable/cpu/op_.cpp`: The implementations of operator overloads - with base name ``. This is the file that clients will link into their - runtimes. -- `//executorch/kernels/portable/cpu:op_`: The build target for - `op_.cpp`, defined in `targets.bzl` in the same directory. -- `./kernels/test/op__test.cpp`: Unit tests for the operator overloads - with base name ``. - - Note that tests under this directory are for portable kernel specific. To - share tests between multiple kernels, we can put tests in ../test. - - Note that the tests do not live under `cpu`; tests should be - implementation-agnostic. This will let us run the same tests against all - implementations of a given operator, which should behave identically. -- `//executorch/kernels/portable/test:op__test`: The test target for - `op__test.cpp`, defined in `targets.bzl` in the same directory. - -For an example, see the `add` operator (note that these are slightly different -from the `add` examples in this doc): -- [`//executorch/kernels/portable/cpu/op_add.cpp`](https://www.internalfb.com/code/fbsource/fbcode/executorch/kernels/portable/cpu/op_add.cpp): - Implementations. -- [`//executorch/kernels/portable/cpu/targets.bzl`](https://www.internalfb.com/code/fbsource/fbcode/executorch/kernels/portable/cpu/targets.bzl): - Definition of the `:op_add` target. -- [`//executorch/kernels/portable/test/op_add_test.cpp`](https://www.internalfb.com/code/fbsource/fbcode/executorch/kernels/portable/test/op_add_test.cpp): - Unit tests. -- [`//executorch/kernels/portable/test/targets.bzl`](https://www.internalfb.com/code/fbsource/fbcode/executorch/kernels/portable/test/targets.bzl): - Definition of the `:op_add_test` target. - -### Define the build target for the operator implementation - -Define a build target by adding an entry to -`//executorch/kernels/portable/cpu/targets.bzl`, inside -`define_common_targets()`, in sorted order with other `_op_target` entries: -``` -_op_target(name = "op_") -``` - -If your operator overload group is ATen-compatible, its `_op_target` entry -belongs in the `_ATEN_OPS` list, otherwise it belongs in the `_CUSTOM_OPS` list. 
-Note that this means that a given `op_` cannot implement both -ATen-compatible and non-ATen-compatible (i.e., custom) operators. We suggest -adding the suffix `_custom` if necessary: e.g., `op_add` for ATen-compatible -overloads of the `add` operator, and `op_add_custom` for non-ATen-compatible -overloads. - -By default, this target will depend on the core ExecuTorch types, but you can -add additional deps if you want to. - -NOTE: An `op_` target may not depend on another `op_` target. If two -`op_` targets need to share code, define a separate `runtime.cxx_library` target -under `//executorch/kernels/portable/cpu/lib` that they both depend on. This -keeps the dependencies more managable, especially for selective builds where -only a subset of operators are used. - -NOTE: An `op_` target may not depend on targets outside of `//executorch`. -This library is intended to be portable, open-sourceable, and self-contained. - -### Create a skeleton .cpp file for the operator implementation - -If not already present, create the file -`//executorch/kernels/portable/cpu/op_.cpp`, which should follow the -pattern: -``` -// Copyright (c) Meta Platforms, Inc. and affiliates. -#include - -namespace torch { -namespace executor { -namespace native { - -namespace { - // -} // namespace - -// - -} // namespace native -} // namespace executor -} // namespace torch -``` - -With the target and cpp file in place, you should be able to build it: -``` -cd ${HOME}/fbsource/fbcode/executorch -buck build fbcode//executorch/kernels/portable/cpu:op_ -``` - -### Find the function signature for the operator overload - -When you add an entry to the YAML file, the codegen tools will generate an -expected function signature for you to implement in a file called -`NativeFunctions.h`. - -To build and find that generated header, run the script -`fbsource/fbcode/executorch/kernels/portable/find_op_header.sh`. It will print -output like: -``` -===== Generating header files ===== -File changed: fbcode//executorch/kernels/portable/functions.yaml -Buck UI: https://www.internalfb.com/buck2/e5a6f22a-5b6e-4931-9a7f-df18bdf97ab6 -RE Session: reSessionID-4b735cfa-e66f-43d8-a73b-94f22d5936c5 -Jobs completed: 3. Time elapsed: 0.2s. Cache hits: 100%. Commands: 1 (cached: 1, remote: 0, local: 0) -BUILD SUCCEEDED - -Header file: /data/users/USER/fbsource/buck-out/v2/gen/fbcode/d839c731f5505c62/executorch/codegen/__generated_lib_generate__/out/NativeFunctions.h -``` -The path will be different in your environment, so be sure to use the output -from the script instead of copy-pasting this path. And, since this header is -generated from the YAML files, re-run the script if you have modified your -operator's entry in those files. - -Open that file and look for the function with the same name that you earlier -added under `CPU: dispatch:` in the YAML file. For `add_out`, this might look -like -``` -TORCH_API torch::executor::Tensor & add_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); -``` - -This is the function signature that you will need to implement. - -### Add a stub implementation - -Now that you have your function signature, add a stub to the `op_.cpp` -file that just returns the `out` argument. For example: -``` -Tensor& add_out( - const Tensor& self, - const Tensor& other, - Tensor& out) { - return out; -} -``` - -Note that you should drop the `TORCH_API` attribute, and should drop `at::`. 
- -Try building again with -``` -cd ${HOME}/fbsource/fbcode/executorch -buck build fbcode//executorch/kernels/portable/cpu:op_ -``` - -### Create a test build target - -Define a test build target by adding an entry to -`//executorch/kernels/portable/test/targets.bzl`, inside -`define_common_targets()`, in sorted order with other `_op_test` entries: -``` -_op_target(name = "op__test") -``` - -By default, this target will depend on -`//executorch/kernels/portable/cpu:op_`, the core Executor types, and -some helper test utilities ([see -headers](https://www.internalfb.com/code/fbsource/fbcode/executorch/runtime/core/exec_aten/testing_util/)), -but you can add additional deps if you want to. - -### Create a skeleton test .cpp file - -If not already present, create the file -`//executorch/kernels/portable/test/op__test.cpp`. Here's a suggested -starting point: -``` -// Copyright (c) Meta Platforms, Inc. and affiliates. - -#include // Declares the operator -#include -#include -#include - -#include - -using namespace ::testing; -using exec_aten::ScalarType; -using exec_aten::Tensor; -using torch::executor::native::; -using torch::executor::testing::IsCloseTo; -using torch::executor::testing::TensorFactory; - -TEST(OpTest, SmokeTest) { - TensorFactory tf; - - Tensor a = tf.make(/*sizes=*/{2, 2}, /*data=*/{1, 1, 1, 1}): - Tensor b = tf.ones(/*sizes=*/{2, 2}): - Tensor z = tf.zeros(/*sizes=*/{2, 2}): - - EXPECT_EQ(a, b); // Exact equality - EXPECT_THAT(a, IsCloseTo(b)); // For floating-point tensors - - EXPECT_NE(a, z); - EXPECT_THAT(a, Not(IsCloseTo(z))); -} -``` - -Try running the test: -``` -cd ${HOME}/fbsource/fbcode/executorch -buck test fbcode//executorch/kernels/test:op__test -``` - -### Implement and test the operator - -You should now be able to implement and test your operator. It's helpful to see -how other operators do it, so take a look at `op_add`: -- [`//executorch/kernels/portable/cpu/op_add.cpp`](https://www.internalfb.com/code/fbsource/fbcode/executorch/kernels/portable/cpu/op_add.cpp) -- [`//executorch/kernels/portable/test/op_add_test.cpp`](https://www.internalfb.com/code/fbsource/fbcode/executorch/kernels/portable/test/op_add_test.cpp) - -Check out how it uses helper macros like `ET_CHECK_SAME_SHAPE_AND_DTYPE` and -`ET_FORALL_REAL_TYPES` when implementing the operator, and test helpers like -`TensorFactory` and `IsCloseTo()` when testing. - -#### Implementation restrictions - -To reduce dependencies and size, to ensure portability, and to conform to the -restrictions of embedded environments, your operator implementations: - -- Must not include C++ stdlib headers, or use C++ stdlib types. For example, - `string`/`basic_string`, `vector`, `unordered_map`, `cout`, `unique_pointer` - must not be used. -- Must not dynamically allocate memory, or cause memory to be dynamically - allocated. All non-stack memory must be provided as a function parameter by - the caller, typically via an `out` parameter or another tensor parameter to be - used as scratch space. - - This includes direct calls to `new`, `malloc`, `realloc`, etc., as well as - operations that allocate under the hood like `make_unique`, or the creation - of `vector` or `string`, for example. -- Must be stateless. -- Must be thread-safe. Note that the ExecuTorch environment does not provide - a locking construct, so this means that operator implementations must not - modify global memory. -- Must work in an environment without threads. 
This, along with the stateless - requirement, means that thread local storage must not be used. -- Must not use `stdout`, `stderr`, or other file/stream IO via `printf`/`cout` - etc.; instead, use `ET_LOG` from `executorch/runtime/platform/log.h`. -- Must not use `assert()`. Instead use `ET_CHECK` and other macros from - `executorch/runtime/platform/assert.h`. -- Must not raise exceptions. Instead use `ET_CHECK` and other macros from - `executorch/runtime/platform/assert.h`. - -Note that not all of these apply to *every* ExecuTorch-compatible operator -implementation, only those included in this portable library. - -For example, a target-specfic custom operator that initiates a DMA copy would be -stateful, and would probaby modify global memory, but it would need to use -target-specific APIs to do so. But, since this library is only for portable -operator implementations, the operators it contains can't depend on -target-specific APIs like that. - -### Shared kernel tests (//executorch/kernels/test) -The portable kernel impelemntation and its corresponding tests can be used as a -reference for other kernels. We can also share the test cases in -`//executorch/kernels/test`, which contains common resources for kernel testing. - -*util.bzl* contains common BUCK targets for other test libs to include: -- op_test(): Defines a cxx_test() for an "op_*_test.cpp" file -- define_supported_features_lib(): Defines the corresponding supported features library - -*targets.bzl* has targets shared by other kernels tests: -- supported_features_header: header file for SupportedFeatures -- supported_features_header_aten: ATen implementation of SupportedFeatures -- _codegen_function_header_wrapper: a wrapper to include the right Functions.h header -- _common_op_test: generates _ - -*_codegen_function_header_wrapper* generates a header FunctionHeaderWrapper.h, which simply -includes the corresponding Functions.h file for the specified kernel: -`#include `. With that, the test sources don't need to know -about which kernel we are testing and which Functions.h we should use. - -With *_common_op_test* we use a single test source file (op__test.cpp) at this directory. -We automatically find the corresponding registered dispatch function through Funcitons.h, so -it can be used to test multiple kernels. - -In /test/ we can put kernel-specific test cases. - -*SupportedFeatures* is used to distinguish between different kernel features. For example, -ATen supports mixing input and output dtype while portable doesn't. When we expect death in -portable testing in such case, we can check the supported features by the running kernel and -bypass if it's supported. -- The default value of supported features is in test/supported_features.yaml -- Each kernel needs to override its supported features in /test/supported_features_def.yaml. - See example in supported_features_def_example.yaml. -- This ensures that all kernels can share the same c++ test case source +See README.md in the parent directory. 
diff --git a/kernels/portable/cpu/op_convolution.cpp b/kernels/portable/cpu/op_convolution.cpp index 1aa38948ba..91fea25cac 100644 --- a/kernels/portable/cpu/op_convolution.cpp +++ b/kernels/portable/cpu/op_convolution.cpp @@ -72,6 +72,13 @@ void conv2d_impl( exec_aten::SizesType w_coord[kTensorDimensionLimit]; w_coord[0] = out_c; + const int64_t stride_y = val_at(stride, 0); + const int64_t padding_y = val_at(padding, 0, /*default_value=*/0); + const int64_t dilation_y = val_at(dilation, 0); + const int64_t stride_x = val_at(stride, 1); + const int64_t padding_x = val_at(padding, 1, /*default_value=*/0); + const int64_t dilation_x = val_at(dilation, 1); + // Compute 2D output region for (size_t out_y = 0; out_y < out_H; ++out_y) { out_coord[2] = out_y; @@ -87,9 +94,6 @@ void conv2d_impl( for (size_t w_y = 0; w_y < w_H; ++w_y) { w_coord[2] = w_y; - int64_t stride_y = val_at(stride, 0); - int64_t padding_y = val_at(padding, 0, /*default_value=*/0); - int64_t dilation_y = val_at(dilation, 0); size_t in_y = stride_y * out_y + dilation_y * w_y - padding_y; in_coord[2] = in_y; // Only proceed if input y coordinate is within bounds @@ -97,9 +101,6 @@ void conv2d_impl( for (size_t w_x = 0; w_x < w_W; ++w_x) { w_coord[3] = w_x; - int64_t stride_x = val_at(stride, 1); - int64_t padding_x = val_at(padding, 1, /*default_value=*/0); - int64_t dilation_x = val_at(dilation, 1); size_t in_x = stride_x * out_x + dilation_x * w_x - padding_x; in_coord[3] = in_x; diff --git a/kernels/portable/cpu/op_copy.cpp b/kernels/portable/cpu/op_copy.cpp index 8abf6f9722..900b6e39d3 100644 --- a/kernels/portable/cpu/op_copy.cpp +++ b/kernels/portable/cpu/op_copy.cpp @@ -42,8 +42,8 @@ Tensor& copy_out( ScalarType in_type = in.scalar_type(); ScalarType src_type = src.scalar_type(); - ET_SWITCH_REAL_TYPES_AND(Bool, in_type, ctx, "copy.out", CTYPE, [&]() { - ET_SWITCH_REAL_TYPES_AND(Bool, src_type, ctx, "copy.out", CTYPE_SRC, [&]() { + ET_SWITCH_REALHB_TYPES(in_type, ctx, "copy.out", CTYPE, [&]() { + ET_SWITCH_REALHB_TYPES(src_type, ctx, "copy.out", CTYPE_SRC, [&]() { apply_binary_elementwise_fn( [](const CTYPE val_in, const CTYPE_SRC val_src) { return convert(val_src); @@ -69,8 +69,8 @@ copy_(RuntimeContext& ctx, Tensor& in, const Tensor& src, bool non_blocking) { ScalarType in_type = in.scalar_type(); ScalarType src_type = src.scalar_type(); - ET_SWITCH_REAL_TYPES_AND(Bool, in_type, ctx, "copy_", CTYPE, [&]() { - ET_SWITCH_REAL_TYPES_AND(Bool, src_type, ctx, "copy_", CTYPE_SRC, [&]() { + ET_SWITCH_REALHB_TYPES(in_type, ctx, "copy_", CTYPE, [&]() { + ET_SWITCH_REALHB_TYPES(src_type, ctx, "copy_", CTYPE_SRC, [&]() { apply_binary_elementwise_fn( [](const CTYPE val_in, const CTYPE_SRC val_src) { return convert(val_src); diff --git a/kernels/portable/cpu/op_index_put.cpp b/kernels/portable/cpu/op_index_put.cpp index 59a258eb00..88cffe1bce 100644 --- a/kernels/portable/cpu/op_index_put.cpp +++ b/kernels/portable/cpu/op_index_put.cpp @@ -48,7 +48,7 @@ Tensor& index_put_out( ET_KERNEL_CHECK( ctx, tensor_is_broadcastable_to(values, out), InvalidArgument, out); - ET_SWITCH_REAL_TYPES_AND(Bool, in_type, ctx, "index_put.out", CTYPE, [&]() { + ET_SWITCH_REALHB_TYPES(in_type, ctx, "index_put.out", CTYPE, [&]() { apply_binary_elementwise_fn( [accumulate](const CTYPE val_in, const CTYPE val) { return accumulate ? 
val_in + val : val; @@ -115,7 +115,7 @@ Tensor& index_put_out( x_numel *= x_sizes[i]; } - ET_SWITCH_REAL_TYPES_AND(Bool, in_type, ctx, "index_put.out", CTYPE, [&]() { + ET_SWITCH_REALHB_TYPES(in_type, ctx, "index_put.out", CTYPE, [&]() { const CTYPE* const values_data = values.const_data_ptr(); CTYPE* const out_data = out.mutable_data_ptr(); diff --git a/kernels/portable/cpu/op_slice_scatter.cpp b/kernels/portable/cpu/op_slice_scatter.cpp index dfa7dfb690..367b626696 100644 --- a/kernels/portable/cpu/op_slice_scatter.cpp +++ b/kernels/portable/cpu/op_slice_scatter.cpp @@ -74,28 +74,27 @@ Tensor& slice_scatter_out( ScalarType in_type = input.scalar_type(); ScalarType src_type = src.scalar_type(); - ET_SWITCH_REAL_TYPES_AND( - Bool, in_type, ctx, "slice_scatter.out", CTYPE, [&]() { - ET_SWITCH_REAL_TYPES_AND( - Bool, src_type, ctx, "slice_scatter.out", CTYPE_SRC, [&]() { - CTYPE* out_data = out.mutable_data_ptr(); - const CTYPE_SRC* src_data = src.const_data_ptr(); - - size_t src_offset = 0; - - for (int i = 0; i < leading_dims; i++) { - size_t out_offset = (i * dim_length + start) * trailing_dims; - for (int j = 0; j < num_values; j++) { - for (size_t k = 0; k < trailing_dims; ++k) { - out_data[out_offset + k] = - convert(src_data[src_offset + k]); - } - src_offset += trailing_dims; - out_offset += step * trailing_dims; - } + ET_SWITCH_REALHB_TYPES(in_type, ctx, "slice_scatter.out", CTYPE, [&]() { + ET_SWITCH_REALHB_TYPES( + src_type, ctx, "slice_scatter.out", CTYPE_SRC, [&]() { + CTYPE* out_data = out.mutable_data_ptr(); + const CTYPE_SRC* src_data = src.const_data_ptr(); + + size_t src_offset = 0; + + for (int i = 0; i < leading_dims; i++) { + size_t out_offset = (i * dim_length + start) * trailing_dims; + for (int j = 0; j < num_values; j++) { + for (size_t k = 0; k < trailing_dims; ++k) { + out_data[out_offset + k] = + convert(src_data[src_offset + k]); } - }); - }); + src_offset += trailing_dims; + out_offset += step * trailing_dims; + } + } + }); + }); return out; } diff --git a/profiler/parse_profiler_results.py b/profiler/parse_profiler_results.py index 88598b9dc3..3fc1a69176 100644 --- a/profiler/parse_profiler_results.py +++ b/profiler/parse_profiler_results.py @@ -434,19 +434,19 @@ def profile_framework_tax_table( def deserialize_profile_results_files( profile_results_path: str, - model_ff_path: str, + bundled_program_ff_path: str, time_scale: TimeScale = TimeScale.TIME_IN_NS, ): with open(profile_results_path, "rb") as prof_res_file, open( - model_ff_path, "rb" + bundled_program_ff_path, "rb" ) as model_ff_file: prof_res_buf = prof_res_file.read() - model_ff_buf = model_ff_file.read() + bundled_program_ff_buf = model_ff_file.read() prof_data, mem_allocations = deserialize_profile_results(prof_res_buf, time_scale) framework_tax_data = profile_aggregate_framework_tax(prof_data) - prof_tables = profile_table(prof_data, model_ff_buf) + prof_tables = profile_table(prof_data, bundled_program_ff_buf) for table in prof_tables: print(table)