forked from oneapi-src/oneDNN
ocl, softmax: Initial gen9 softmax implementation
Showing 5 changed files with 314 additions and 1 deletion.
@@ -0,0 +1,127 @@
/*******************************************************************************
* Copyright 2020 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "gpu/ocl/ocl_types.h"

#define OFF(dim, idx) \
    (dim % CONCAT2(DATA_B, idx)) * CONCAT2(DATA_SB, idx) \
            + (dim / CONCAT2(DATA_B, idx)) * CONCAT2(DATA_S, idx)
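
// OFF() turns a logical index along tensor dimension `idx` into a physical
// offset for a (possibly blocked) layout: the index is split into an
// intra-block part scaled by the block stride DATA_SB<idx> and an
// inter-block part scaled by the outer stride DATA_S<idx>, with DATA_B<idx>
// the block size. These macros are injected at compile time by the
// host-side set_offsets() call in gen9_softmax.hpp.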

#if SOFTMAX_AXIS_IDX == 0
#define DATA_OFF(dim0, dim1, dim2, dim3, dim4, softmax_dim) \
    OFF(softmax_dim, 0) + OFF(dim0, 1) + OFF(dim1, 2) + OFF(dim2, 3) \
            + OFF(dim3, 4) + OFF(dim4, 5)
#elif SOFTMAX_AXIS_IDX == 1
#define DATA_OFF(dim0, dim1, dim2, dim3, dim4, softmax_dim) \
    OFF(dim0, 0) + OFF(softmax_dim, 1) + OFF(dim1, 2) + OFF(dim2, 3) \
            + OFF(dim3, 4) + OFF(dim4, 5)
#elif SOFTMAX_AXIS_IDX == 2
#define DATA_OFF(dim0, dim1, dim2, dim3, dim4, softmax_dim) \
    OFF(dim0, 0) + OFF(dim1, 1) + OFF(softmax_dim, 2) + OFF(dim2, 3) \
            + OFF(dim3, 4) + OFF(dim4, 5)
#elif SOFTMAX_AXIS_IDX == 3
#define DATA_OFF(dim0, dim1, dim2, dim3, dim4, softmax_dim) \
    OFF(dim0, 0) + OFF(dim1, 1) + OFF(dim2, 2) + OFF(softmax_dim, 3) \
            + OFF(dim3, 4) + OFF(dim4, 5)
#elif SOFTMAX_AXIS_IDX == 4
#define DATA_OFF(dim0, dim1, dim2, dim3, dim4, softmax_dim) \
    OFF(dim0, 0) + OFF(dim1, 1) + OFF(dim2, 2) + OFF(dim3, 3) \
            + OFF(softmax_dim, 4) + OFF(dim4, 5)
#elif SOFTMAX_AXIS_IDX == 5
#define DATA_OFF(dim0, dim1, dim2, dim3, dim4, softmax_dim) \
    OFF(dim0, 0) + OFF(dim1, 1) + OFF(dim2, 2) + OFF(dim3, 3) + OFF(dim4, 4) \
            + OFF(softmax_dim, 5)
#else
#error unsupported softmax dimension
#endif
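
// DATA_OFF() splices the softmax-axis coordinate into the slot selected by
// SOFTMAX_AXIS_IDX while the remaining five coordinates fill the other
// tensor dimensions in order, supporting tensors of up to six dimensions.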

#define LOAD_DATA_8x16(ptr) \
    CONVERT_FLOAT8_T( \
            AS_DATA8_T(BLOCK_READ8((const __global BLOCK_DATA_T *)(ptr))))

#define STORE_DATA_8x16(ptr, val) \
    BLOCK_WRITE8((__global BLOCK_DATA_T *)ptr, \
            AS_BLOCK_DATA8_T(CONVERT_DATA8_T(val)))

#define VECT_SIZE 8
#define NUM_BUF (SOFTMAX_AXIS_SIZE / SUB_GROUP_SIZE / VECT_SIZE)
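
// Each block read/write moves 8 values per work-item across a 16-wide
// subgroup (8 x 16 = 128 elements), converting between the storage type and
// float for the arithmetic. NUM_BUF is the number of such 128-element chunks
// needed to cover the softmax axis, which is why pd_t::init() on the host
// side only accepts axis sizes that are multiples of 128.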

#if IS_FWD

__attribute__((reqd_work_group_size(GROUP_SIZE, 1, 1)))
__attribute__((intel_reqd_sub_group_size(SUB_GROUP_SIZE))) __kernel void
gen9_softmax_fwd(__global DATA_T *src, __global DATA_T *dst) {

    const int dim[] = {
            (get_global_id(0) / GROUP_SIZE) % BLOCK_0,
            get_global_id(1) % BLOCK_1,
            get_global_id(2) % BLOCK_2,
            (get_global_id(0) / GROUP_SIZE) / BLOCK_0,
            get_global_id(1) / BLOCK_1,
            get_global_id(2) / BLOCK_2,
    };
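
    // Recover the tensor coordinates from the 3D ND-range: the host folds
    // the non-axis dimensions pairwise into the three global sizes, so each
    // pair is undone here with a mod/div by BLOCK_n. Dimension 0 is first
    // divided by GROUP_SIZE because one whole work-group (a single 16-lane
    // subgroup) cooperates on each softmax row.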

    float8 d[NUM_BUF];

    int local_id = get_sub_group_local_id();
    int begin = local_id * (SOFTMAX_AXIS_SIZE / VECT_SIZE);

    float max_ = -FLT_MAX;
    float denom_ = 0.f;

    size_t data_off = DATA_OFF(dim[0], dim[1], dim[2], dim[3], dim[4], begin);
    src += data_off;

    for (int k = 0; k < NUM_BUF; ++k) {
        d[k] = LOAD_DATA_8x16(&src[k * VECT_SIZE * SUB_GROUP_SIZE]);
        for (int i = 0; i < VECT_SIZE; ++i) {
            max_ = max(d[k][i], max_);
        }
    }

    max_ = sub_group_reduce_max(max_);
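
    // First pass: each lane reduces over its private registers, then
    // sub_group_reduce_max() combines the lane maxima into the row maximum.
    // Subtracting this maximum before exponentiation keeps exp() in range
    // and makes the softmax numerically stable.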

    for (int k = 0; k < NUM_BUF; ++k) {
#if LOGSOFTMAX
        for (int i = 0; i < VECT_SIZE; ++i)
            denom_ += exp(d[k][i] - max_);
#else
        d[k] = exp(d[k] - max_);
        for (int i = 0; i < VECT_SIZE; ++i)
            denom_ += d[k][i];
#endif
    }

    denom_ = sub_group_reduce_add(denom_);

#if LOGSOFTMAX
    denom_ = log(denom_);
#else
    denom_ = 1.0f / denom_;
#endif
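
    // Second pass: accumulate the denominator. For logsoftmax, denom_ ends
    // up as log(sum(exp(x - max))), so the final result below is
    // x - max_ - denom_. For plain softmax the reciprocal is taken once
    // here, turning the per-element normalization into a multiply.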

    dst += data_off;
    for (int k = 0; k < NUM_BUF; ++k) {
#if LOGSOFTMAX
        d[k] = d[k] - max_ - denom_;
#else
        d[k] = d[k] * denom_;
#endif
        STORE_DATA_8x16(&dst[k * VECT_SIZE * SUB_GROUP_SIZE], d[k]);
    }
}

#endif
@@ -0,0 +1,46 @@
/*******************************************************************************
* Copyright 2020 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "gpu/ocl/gen9_softmax.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {

status_t gen9_softmax_fwd_t::execute_generic(const exec_ctx_t &ctx) const {
    if (memory_desc_wrapper(pd()->desc()->data_desc).has_zero_dim())
        return status::success;

    auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC);
    auto &dst = CTX_OUT_STORAGE(DNNL_ARG_DST);

    compute::kernel_arg_list_t arg_list;
    arg_list.set(0, src);
    arg_list.set(1, dst);
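    // Positions 0 and 1 match the kernel signature
    // gen9_softmax_fwd(__global DATA_T *src, __global DATA_T *dst).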

    auto nd_range = compute::nd_range_t(pd()->gws, pd()->lws);

    status_t status = parallel_for(ctx, nd_range, kernel_, arg_list);
    return status;
}

} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl

// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
@@ -0,0 +1,138 @@
/*******************************************************************************
* Copyright 2020 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef GPU_OCL_GEN9_SOFTMAX_HPP
#define GPU_OCL_GEN9_SOFTMAX_HPP

#include "common/c_types_map.hpp"
#include "common/nstl.hpp"
#include "common/primitive.hpp"
#include "gpu/compute/compute.hpp"
#include "gpu/gpu_primitive.hpp"
#include "gpu/gpu_resource.hpp"
#include "gpu/gpu_softmax_pd.hpp"
#include "gpu/ocl/ocl_stream.hpp"
#include "gpu/ocl/ocl_utils.hpp"
#include "gpu/primitive_conf.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {

struct gen9_softmax_fwd_t : public gpu_primitive_t {
    struct pd_t : public gpu_softmax_fwd_pd_t {
        pd_t(const softmax_desc_t *adesc, const primitive_attr_t *attr,
                const softmax_fwd_pd_t *hint_fwd_pd)
            : gpu_softmax_fwd_pd_t(adesc, attr, hint_fwd_pd) {}

        DECLARE_COMMON_PD_T("ocl:gen9", gen9_softmax_fwd_t);

        status_t init(engine_t *engine) {
            auto *compute_engine
                    = utils::downcast<compute::compute_engine_t *>(engine);

            const int nelems = desc()->data_desc.dims[desc()->softmax_axis];
            const memory_desc_wrapper src_d(src_md());
            bool ok = true && nelems % 128 == 0
                    && desc()->softmax_axis == src_md()->ndims - 1
                    && src_d.is_plain()
                    && utils::one_of(desc()->prop_kind,
                            prop_kind::forward_inference,
                            prop_kind::forward_training)
                    && utils::one_of(desc()->data_desc.data_type,
                            data_type::f32, data_type::f16, data_type::bf16)
                    && IMPLICATION(
                            desc()->data_desc.data_type == data_type::f16,
                            compute_engine->mayiuse(
                                    compute::device_ext_t::khr_fp16))
                    && attr()->has_default_values();
            if (!ok) return status::unimplemented;

            gws[0] = gws[1] = gws[2] = 1;
            lws[0] = lws[1] = lws[2] = 1;
            block[0] = block[1] = block[2] = 1;

            for (int i = 0, j = 0; i < src_md()->ndims; ++i) {
                if (i != desc()->softmax_axis) {
                    const auto dim = src_md()->dims[i];
                    gws[j % 3] *= dim;
                    if (j < 3) block[j] = dim;
                    ++j;
                }
            }
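
            // The non-axis dimensions are folded round-robin into the three
            // ND-range sizes; BLOCK_0..2 record the first dimension mapped
            // to each slot so the kernel can split a global id back into
            // two tensor coordinates with a mod/div.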

            group_size = 16;

            lws[0] = group_size;
            gws[0] *= group_size;

            return status::success;
        }

        size_t gws[3] = {};
        size_t lws[3] = {};
        size_t block[3] = {};
        size_t group_size = 0;
    };

    gen9_softmax_fwd_t(const pd_t *apd) : gpu_primitive_t(apd) {}

    status_t init(engine_t *engine) override {
        if (memory_desc_wrapper(pd()->desc()->data_desc).has_zero_dim())
            return status::success;

        compute::kernel_ctx_t kernel_ctx;

        const auto *desc = pd()->desc();
        kernel_ctx.define_int("SOFTMAX_AXIS_IDX", desc->softmax_axis);
        kernel_ctx.define_int(
                "SOFTMAX_AXIS_SIZE", desc->data_desc.dims[desc->softmax_axis]);
        kernel_ctx.define_int("GROUP_SIZE", pd()->group_size);
        kernel_ctx.define_int("SUB_GROUP_SIZE", pd()->group_size);
        kernel_ctx.define_int("IS_FWD", 1);
        kernel_ctx.add_option("-cl-std=CL2.0");
        kernel_ctx.define_int("LOGSOFTMAX",
                desc->primitive_kind == primitive_kind::logsoftmax ? 1 : 0);
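
        // The axis index, axis size, and (log)softmax flavor become
        // compile-time constants, so a specialized kernel binary is
        // JIT-compiled per problem configuration instead of branching at
        // run time.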

        kernel_ctx.set_data_type(desc->data_desc.data_type);
        set_offsets(kernel_ctx, pd()->dst_md(), "DATA");

        for (int i = 0; i < 3; ++i)
            kernel_ctx.define_int(utils::format("BLOCK_%d", i), pd()->block[i]);

        create_kernel(engine, &kernel_, "gen9_softmax_fwd", kernel_ctx);
        if (!kernel_) return status::runtime_error;

        return status::success;
    }

    status_t execute(const exec_ctx_t &ctx) const override {
        return execute_generic(ctx);
    }

protected:
    status_t execute_generic(const exec_ctx_t &ctx) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
    compute::kernel_t kernel_;
};

} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl

#endif
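
For reference, below is a minimal host-side sketch (not part of this commit) showing how a softmax primitive that this implementation can serve is created and run through the public oneDNN C++ API. It assumes a GPU engine is available and an f32 tensor whose innermost dimension is a multiple of 128, so the checks in pd_t::init() above can pass; whether this particular kernel is ultimately selected still depends on the implementation dispatch order.

// Hypothetical usage sketch, not part of the commit.
#include "dnnl.hpp"

int main() {
    using namespace dnnl;
    engine eng(engine::kind::gpu, 0); // requires a GPU runtime
    stream strm(eng);

    // 16 rows, softmax over the innermost axis; 1024 % 128 == 0.
    memory::desc md(
            {16, 1024}, memory::data_type::f32, memory::format_tag::nc);
    memory src_mem(md, eng), dst_mem(md, eng);

    auto desc = softmax_forward::desc(
            prop_kind::forward_inference, md, /*softmax_axis=*/1);
    auto pd = softmax_forward::primitive_desc(desc, eng);
    auto softmax = softmax_forward(pd);

    softmax.execute(
            strm, {{DNNL_ARG_SRC, src_mem}, {DNNL_ARG_DST, dst_mem}});
    strm.wait();
    return 0;
}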