diff --git a/.clang-format b/.clang-format index b59d5fa..1fc33b8 100644 --- a/.clang-format +++ b/.clang-format @@ -2,3 +2,5 @@ BasedOnStyle: Google ColumnLimit: 100 DerivePointerAlignment: false +StatementMacros: + - _Pragma diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..17ccc9e --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "third_party/cub"] + path = third_party/cub + url = https://github.com/NVIDIA/cub.git diff --git a/.pylintrc b/.pylintrc index 1d69d7f..33ddab8 100644 --- a/.pylintrc +++ b/.pylintrc @@ -38,7 +38,7 @@ enable=indexing-exception,old-raise-syntax # --enable=similarities". If you want to run only the classes checker, but have # no Warning level messages displayed, use"--disable=all --enable=classes # --disable=W" -disable=design,similarities,no-self-use,attribute-defined-outside-init,locally-disabled,star-args,pointless-except,bad-option-value,global-statement,fixme,suppressed-message,useless-suppression,locally-enabled,no-member,no-name-in-module,import-error,unsubscriptable-object,unbalanced-tuple-unpacking,undefined-variable,not-context-manager,invalid-sequence-index,arguments-differ,missing-function-docstring,unexpected-keyword-arg,no-value-for-parameter +disable=design,similarities,no-self-use,attribute-defined-outside-init,locally-disabled,star-args,pointless-except,bad-option-value,global-statement,fixme,suppressed-message,useless-suppression,locally-enabled,no-member,no-name-in-module,import-error,unsubscriptable-object,unbalanced-tuple-unpacking,undefined-variable,not-context-manager,invalid-sequence-index,arguments-differ,missing-function-docstring,unexpected-keyword-arg,no-value-for-parameter,missing-return-type-doc,missing-type-doc,missing-param-doc # Set the cache size for astng objects. diff --git a/Makefile b/Makefile index a83070e..4b8e1f9 100644 --- a/Makefile +++ b/Makefile @@ -20,8 +20,14 @@ PYTHON_BIN_PATH = python TF_CFLAGS := $(shell $(PYTHON_BIN_PATH) -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))') TF_LFLAGS := $(shell $(PYTHON_BIN_PATH) -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))') - -CFLAGS = ${TF_CFLAGS} -O3 -std=c++14 +TF_VERSION := $(shell $(PYTHON_BIN_PATH) -c 'import tensorflow as tf; print(int(tf.__version__.split(".")[1]))') +ifeq ($(shell expr $(TF_VERSION) \>= 10), 1) + CPP_STD := 17 + else + CPP_STD := 14 + endif + +CFLAGS = ${TF_CFLAGS} -O3 -std=c++${CPP_STD} LDFLAGS = -shared ${TF_LFLAGS} SRC = embedding_lookup_kernels @@ -34,7 +40,7 @@ TARGET_LIB = distributed_embeddings/python/ops/_embedding_lookup_ops.so all: $(TARGET_LIB) %_kernels.cu.o: distributed_embeddings/cc/kernels/%_kernels.cu distributed_embeddings/cc/kernels/%.h - $(NVCC) -c -o $@ $< $(CFLAGS) -I. -DGOOGLE_CUDA=1 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_70,code=sm_70 -gencode arch=compute_80,code=sm_80 -gencode arch=compute_86,code=sm_86 -x cu -Xcompiler -fPIC --expt-relaxed-constexpr + $(NVCC) -c -o $@ $< -Ithird_party/cub $(CFLAGS) -I. 
-DGOOGLE_CUDA=1 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_70,code=sm_70 -gencode arch=compute_80,code=sm_80 -gencode arch=compute_86,code=sm_86 -x cu -Xcompiler -fPIC --expt-relaxed-constexpr %_kernels.cc.o: distributed_embeddings/cc/kernels/%_kernels.cc distributed_embeddings/cc/kernels/%.h $(CXX) -c -o $@ $< $(CFLAGS) -Wall -fPIC -I/usr/local/cuda/include diff --git a/distributed_embeddings/cc/kernels/embedding_lookup.h b/distributed_embeddings/cc/kernels/embedding_lookup.h index 4e08444..d3d699d 100644 --- a/distributed_embeddings/cc/kernels/embedding_lookup.h +++ b/distributed_embeddings/cc/kernels/embedding_lookup.h @@ -20,36 +20,32 @@ #include +#include "tensorflow/core/framework/op_kernel.h" + namespace tensorflow { enum class Combiner { Mean = 0, Sum = 1 }; inline Combiner StringToEnum(std::string combiner) { return combiner == "mean" ? Combiner::Mean : Combiner::Sum; } -template -struct EmbeddingLookupConstantHotnessFunctor { - void operator()(const Device& d, T* output_ptr, const T* param_ptr, const Tindices* ids_ptr, - Tindices nnz_per_row, Tindices num_rows, Tindices embedding_width, - Combiner combiner) const; -}; - -template -struct EmbeddingLookupConstantHotnessGradFunctor { - void operator()(const Device& d, T* output_ptr, const T* grad_ptr, Tindices nnz_per_row, - Tindices num_rows, Tindices embedding_width, Combiner combiner) const; +template +struct RowToSplitFunctor { + void operator()(const Device& d, Tindices* split_ptr, const Tindices* row_ptr, Tindices num_ids, + Tindices num_rows) const; }; template struct EmbeddingLookupVariableHotnessFunctor { void operator()(const Device& d, T* output_ptr, const T* param_ptr, const Tindices* ids_ptr, const Tindices* offsets_ptr, Tindices num_rows, Tindices embedding_width, - Combiner combiner) const; + Combiner combiner, Tindices ave_red_len) const; }; template struct EmbeddingLookupVariableHotnessGradFunctor { - void operator()(const Device& d, T* output_ptr, const T* grad_ptr, const Tindices* offsets_ptr, - Tindices num_rows, Tindices embedding_width, Combiner combiner) const; + void operator()(OpKernelContext* context, const Tindices* ids_ptr, const Tindices* row_ptr, + const T* grad_ptr, int64_t num_ids, Tindices embedding_width, Tindices num_rows, + int64_t dense_shape_dim0, int64_t max_red_len, Combiner combiner) const; }; } // namespace tensorflow diff --git a/distributed_embeddings/cc/kernels/embedding_lookup_kernels.cc b/distributed_embeddings/cc/kernels/embedding_lookup_kernels.cc index 319e6bd..e0e0350 100644 --- a/distributed_embeddings/cc/kernels/embedding_lookup_kernels.cc +++ b/distributed_embeddings/cc/kernels/embedding_lookup_kernels.cc @@ -16,7 +16,6 @@ */ #include "embedding_lookup.h" -#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/resource_var.h" @@ -45,62 +44,25 @@ class ReadVariableNoCopyOp : public OpKernel { DataType dtype_; }; -template -class EmbeddingLookupConstantHotnessOp : public OpKernel { - public: - explicit EmbeddingLookupConstantHotnessOp(OpKernelConstruction* context) : OpKernel(context) { - OP_REQUIRES_OK(context, context->GetAttr("combiner", &_combiner)); - } - - void Compute(OpKernelContext* context) override { - const Tensor& params = context->input(0); - const Tensor& ids = context->input(1); - - auto num_rows = ids.dim_size(0); - auto nnz_per_row = ids.dim_size(1); - auto embedding_width = params.dim_size(1); - - TensorShape output_shape({num_rows, embedding_width}); - Tensor* output 
= nullptr; - OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); - - EmbeddingLookupConstantHotnessFunctor()( - context->eigen_device(), output->flat().data(), params.flat().data(), - ids.flat().data(), nnz_per_row, num_rows, embedding_width, - StringToEnum(_combiner)); - } - - private: - string _combiner; -}; - -template -class EmbeddingLookupConstantHotnessGradOp : public OpKernel { +template +class RowToSplitOp : public OpKernel { public: - explicit EmbeddingLookupConstantHotnessGradOp(OpKernelConstruction* context) : OpKernel(context) { - OP_REQUIRES_OK(context, context->GetAttr("combiner", &_combiner)); - } + explicit RowToSplitOp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { - const Tensor& grad = context->input(0); - const Tensor& ids = context->input(1); + // [n, 2] + const Tensor& row = context->input(0); + auto num_ids = row.dim_size(0); + auto num_rows = context->input(1).scalar()(); - auto num_rows = ids.dim_size(0); - auto nnz_per_row = ids.dim_size(1); - auto nnz = num_rows * nnz_per_row; - auto embedding_width = grad.dim_size(1); - - TensorShape output_shape({nnz, embedding_width}); + TensorShape output_shape({num_rows + 1}); Tensor* output = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); - EmbeddingLookupConstantHotnessGradFunctor()( - context->eigen_device(), output->flat().data(), grad.flat().data(), - nnz_per_row, num_rows, embedding_width, StringToEnum(_combiner)); + RowToSplitFunctor()(context->eigen_device(), + output->flat().data(), + row.flat().data(), num_ids, num_rows); } - - private: - string _combiner; }; template @@ -118,6 +80,9 @@ class EmbeddingLookupVariableHotnessOp : public OpKernel { auto num_rows = offsets.dim_size(0) - 1; auto embedding_width = params.dim_size(1); + auto num_ids = ids.dim_size(0); + auto ave_red_len = num_ids / num_rows; + TensorShape output_shape({num_rows, embedding_width}); Tensor* output = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); @@ -125,7 +90,7 @@ class EmbeddingLookupVariableHotnessOp : public OpKernel { EmbeddingLookupVariableHotnessFunctor()( context->eigen_device(), output->flat().data(), params.flat().data(), ids.flat().data(), offsets.flat().data(), num_rows, embedding_width, - StringToEnum(_combiner)); + StringToEnum(_combiner), ave_red_len); } private: @@ -140,21 +105,20 @@ class EmbeddingLookupVariableHotnessGradOp : public OpKernel { } void Compute(OpKernelContext* context) override { - const Tensor& grad = context->input(0); - const Tensor& ids = context->input(1); - const Tensor& offsets = context->input(2); - - auto num_rows = offsets.dim_size(0) - 1; + const Tensor& ids = context->input(0); + const Tensor& offset_in = context->input(1); + const Tensor& grad = context->input(2); + const Tensor& param = context->input(3); + auto num_ids = ids.dim_size(0); + auto num_rows = offset_in.dim_size(0) - 1; auto embedding_width = grad.dim_size(1); - auto nnz = ids.dim_size(0); - - TensorShape output_shape({nnz, embedding_width}); - Tensor* output = nullptr; - OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); + auto max_red_len = grad.dim_size(0); + auto dense_shape_dim0 = param.dim_size(0); EmbeddingLookupVariableHotnessGradFunctor()( - context->eigen_device(), output->flat().data(), grad.flat().data(), - offsets.flat().data(), num_rows, embedding_width, StringToEnum(_combiner)); + context, ids.flat().data(), offset_in.flat().data(), + 
grad.flat().data(), num_ids, embedding_width, num_rows, dense_shape_dim0, max_red_len, + StringToEnum(_combiner)); } private: @@ -167,26 +131,21 @@ REGISTER_KERNEL_BUILDER(Name("ReadVariableNoCopy").Device(DEVICE_DEFAULT).HostMe REGISTER_KERNEL_BUILDER(Name("ReadVariableNoCopy").Device(DEVICE_GPU).HostMemory("resource"), ReadVariableNoCopyOp); -#define REGISTER_GPU(T, Tindices) \ - REGISTER_KERNEL_BUILDER(Name("EmbeddingLookupConstantHotness") \ - .Device(DEVICE_GPU) \ - .TypeConstraint("T") \ - .TypeConstraint("Tindices"), \ - EmbeddingLookupConstantHotnessOp); \ - REGISTER_KERNEL_BUILDER(Name("EmbeddingLookupConstantHotnessGrad") \ - .Device(DEVICE_GPU) \ - .TypeConstraint("T") \ - .TypeConstraint("Tindices"), \ - EmbeddingLookupConstantHotnessGradOp); \ - REGISTER_KERNEL_BUILDER(Name("EmbeddingLookupVariableHotness") \ - .Device(DEVICE_GPU) \ - .TypeConstraint("T") \ - .TypeConstraint("Tindices"), \ - EmbeddingLookupVariableHotnessOp); \ - REGISTER_KERNEL_BUILDER(Name("EmbeddingLookupVariableHotnessGrad") \ - .Device(DEVICE_GPU) \ - .TypeConstraint("T") \ - .TypeConstraint("Tindices"), \ +#define REGISTER_GPU(T, Tindices) \ + REGISTER_KERNEL_BUILDER(Name("RowToSplit") \ + .Device(DEVICE_GPU) \ + .TypeConstraint("Tindices") \ + .HostMemory("shape"), \ + RowToSplitOp); \ + REGISTER_KERNEL_BUILDER(Name("EmbeddingLookupVariableHotness") \ + .Device(DEVICE_GPU) \ + .TypeConstraint("T") \ + .TypeConstraint("Tindices"), \ + EmbeddingLookupVariableHotnessOp); \ + REGISTER_KERNEL_BUILDER(Name("EmbeddingLookupVariableHotnessGrad") \ + .Device(DEVICE_GPU) \ + .TypeConstraint("T") \ + .TypeConstraint("Tindices"), \ EmbeddingLookupVariableHotnessGradOp); REGISTER_GPU(float, int64_t) diff --git a/distributed_embeddings/cc/kernels/embedding_lookup_kernels.cu b/distributed_embeddings/cc/kernels/embedding_lookup_kernels.cu index 7183c9b..991f311 100644 --- a/distributed_embeddings/cc/kernels/embedding_lookup_kernels.cu +++ b/distributed_embeddings/cc/kernels/embedding_lookup_kernels.cu @@ -14,171 +14,625 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #define EIGEN_USE_GPU +#define THRUST_IGNORE_CUB_VERSION_CHECK +#define _CG_ABI_EXPERIMENTAL // enable experimental API +#define ILP 4 + +#include +#include "cub/cub.cuh" #include "embedding_lookup.h" +#include "tensorflow/core/lib/core/bits.h" #include "tensorflow/core/util/gpu_kernel_helper.h" +namespace cg = cooperative_groups; + namespace tensorflow { -template -__device__ void EmbeddingReduceByIndices(T* out, const T* params, Tindices embedding_width, - Tindices query_nnz, const Tindices* indices, - Tindices* shmem_indices, Combiner combiner) { - int tid = threadIdx.x; - T result = 0; +template +__device__ void EmbeddingReduceByIndices(cg::thread_block_tile g, T* out, const T* params, + int embedding_width, int query_nnz, const TIndex* indices, + TIndex* shmem_indices, Combiner combiner, + const T* weights) { + T weight = 1; + int tid = g.thread_rank(); + int row_off = tid / row * row; + int row_tid = tid % row; + T result[(row + tile - 1) / tile] = {0}; // Remainder is handled first - int remainder = query_nnz % blockDim.x; + int remainder = query_nnz % tile; // First stage, each CTA load one segment of indices in the sample into shared memory + g.sync(); if (tid < remainder) { shmem_indices[tid] = indices[tid]; } - __syncthreads(); + g.sync(); // Second stage - // A CTA first reads indices from shared memory and finds the corresponding entry in the embedding - // table. 
Then the CTA reads the embedding vector and accumulates into register file. Each thread - // in the CTA reads one element of the embedding vector -#pragma unroll 4 - for (int i = 0; i < remainder; ++i) { - result += params[shmem_indices[i] * static_cast(embedding_width) + tid]; + // A CTA first reads indices from shared memory and finds the corresponding entry in the + // embedding table. Then the CTA reads the embedding vector and accumulates into register file. + // Each thread in the CTA reads one element of the embedding vector + _Pragma("unroll 4") + for (int i = tid / row; i < remainder; i += (tile + row - 1) / row) { + if (weights != nullptr) weight = weights[shmem_indices[i]]; + _Pragma("unroll") + for (int j = 0; j < (row + tile - 1) / tile; ++j) { + if (j * tile + row_tid < embedding_width) { + result[j] += + weight * + params[shmem_indices[i] * static_cast(embedding_width) + j * tile + row_tid]; + } + } + } + + _Pragma("unroll") + for (int j = 0; j < (row + tile - 1) / tile; ++j) { + out[j] += result[j]; + result[j] = 0; } + g.sync(); // Repeat stages above and handle one block size of indices at a time - for (int processed = remainder; processed < query_nnz; processed += blockDim.x) { + for (int processed = remainder; processed < query_nnz; processed += tile) { shmem_indices[tid] = indices[processed + tid]; - __syncthreads(); - -#pragma unroll 4 - for (int i = 0; i < blockDim.x; ++i) { - result += params[shmem_indices[i] * static_cast(embedding_width) + tid]; + g.sync(); + _Pragma("unroll 4") + for (int i = 0; i < row && i < tile; ++i) { + if (weights != nullptr) weight = weights[shmem_indices[i + row_off]]; + _Pragma("unroll") + for (int j = 0; j < (row + tile - 1) / tile; ++j) { + if (j * tile + row_tid < embedding_width) { + result[j] += + weight * params[shmem_indices[i + row_off] * static_cast(embedding_width) + + j * tile + row_tid]; + } + } } + _Pragma("unroll") + for (int j = 0; j < (row + tile - 1) / tile; ++j) { + out[j] += result[j]; + result[j] = 0; + } + g.sync(); } - if (combiner == Combiner::Mean) { - result /= query_nnz; + // reduce down to row elements, only first row have correct result + for (int i = tile / 2; i >= row; i /= 2) { + _Pragma("unroll") + for (int j = 0; j < (row + tile - 1) / tile; ++j) { + out[j] += g.shfl_down(out[j], i); + } } - - out[tid] = result; } -template -__device__ void EmbeddingExpandGrad(T* value, const T* grad, Tindices embedding_width, - Tindices query_nnz, Combiner combiner) { - int tid = threadIdx.x; +template +__device__ void EmbeddingReduceByIndicesWide(cg::thread_block_tile g, T* out, const T* params, + int embedding_width, int query_nnz, + const TIndex* indices, TIndex* shmem_indices, + Combiner combiner, const T* weights, int rem_width) { + T weight = 1; + int tid = g.thread_rank(); + T result[(row + tile - 1) / tile] = {0}; + + // Remainder is handled first + int remainder = query_nnz % tile; + // First stage, each CTA load one segment of indices in the sample into shared memory + g.sync(); + if (tid < remainder) { + shmem_indices[tid] = indices[tid]; + } + g.sync(); + // Second stage + // A CTA first reads indices from shared memory and finds the corresponding entry in the + // embedding table. Then the CTA reads the embedding vector and accumulates into register file. 
+ // Each thread in the CTA reads one element of the embedding vector + _Pragma("unroll 4") + for (int i = 0; i < remainder; ++i) { + if (weights != nullptr) weight = weights[shmem_indices[i]]; + _Pragma("unroll") + for (int j = 0; j < (row + tile - 1) / tile; ++j) { + if (j * tile + tid < rem_width) { + result[j] += + weight * + params[shmem_indices[i] * static_cast(embedding_width) + j * tile + tid]; + } + } + } - T g = grad[tid]; - if (combiner == Combiner::Mean) { - g /= query_nnz; + _Pragma("unroll") + for (int j = 0; j < (row + tile - 1) / tile; ++j) { + out[j] += result[j]; + result[j] = 0; } -#pragma unroll 4 - for (int i = 0; i < query_nnz; ++i) { - value[i * static_cast(embedding_width) + tid] = g; + g.sync(); + // Repeat stages above and handle one block size of indices at a time + for (int processed = remainder; processed < query_nnz; processed += tile) { + shmem_indices[tid] = indices[processed + tid]; + g.sync(); + _Pragma("unroll 4") + for (int i = 0; i < tile; ++i) { + if (weights != nullptr) weight = weights[shmem_indices[i]]; + _Pragma("unroll") + for (int j = 0; j < (row + tile - 1) / tile; ++j) { + if (j * tile + tid < rem_width) { + result[j] += + weight * + params[shmem_indices[i] * static_cast(embedding_width) + j * tile + tid]; + } + } + } + _Pragma("unroll") + for (int j = 0; j < (row + tile - 1) / tile; ++j) { + out[j] += result[j]; + result[j] = 0; + } + g.sync(); } } -template -__global__ void EmbeddingLookUpConstantHot(const T* params, Tindices embedding_width, - Tindices query_nnz, const Tindices* indices, T* out, - Combiner combiner) { - int64_t block_ind_offset = blockIdx.x * query_nnz; - int64_t block_out_offset = blockIdx.x * embedding_width; +template +__global__ void EmbeddingLookUpVariableHot(const T* params, int embedding_width, + const TIndex* indptr, const TIndex* indices, T* out, + Combiner combiner, TIndex num_rows, const T* weights) { + auto row_group = cg::tiled_partition(cg::this_thread_block()); + // smem same size as block size. extern __shared__ char shmem[]; - Tindices* shmem_indices = reinterpret_cast(shmem); - // each block handle one query(output) of embedding - EmbeddingReduceByIndices(out + block_out_offset, params, embedding_width, query_nnz, - indices + block_ind_offset, shmem_indices, combiner); -} + TIndex* shmem_indices = reinterpret_cast(shmem); + T* shmem_values = reinterpret_cast(shmem); + shmem_indices += threadIdx.y * blockDim.x; -template -__global__ void EmbeddingLookUpGradConstantHot(const T* grad, Tindices embedding_width, - Tindices query_nnz, T* value, Combiner combiner) { - int64_t block_value_offset = blockIdx.x * query_nnz * embedding_width; - int64_t block_grad_offset = blockIdx.x * embedding_width; - EmbeddingExpandGrad(value + block_value_offset, grad + block_grad_offset, embedding_width, - query_nnz, combiner); + int num_step = num_rows / gridDim.x; + indptr += blockIdx.x; + out += blockIdx.x * embedding_width + threadIdx.x; + if (blockIdx.x < (num_rows % gridDim.x)) num_step += 1; + int step_counter = threadIdx.y; + for (int step = 0; step < num_step; step++) { + int64_t block_ind_offset = indptr[0]; + int query_nnz = indptr[1] - block_ind_offset; + // we only want break down skewed long reductions, i.e, power law input backward. + // These reduction length correlate strongly to batchsize. Let's say we care about perf + // beyond 1k batchsize in general, then we probably need this threshold <512 to be able + // to breakdown long reduction in these cases. 
+ // 128 is chosen so each warp have a full read into indptr when there are 4 of them. + // it seems works fine, but we can make it a function of launch config if needed + if (query_nnz > 128 && blockDim.y > 1) { + T result[(row + tile - 1) / tile] = {0}; + int prev_row_extra = + (query_nnz % blockDim.y) > threadIdx.y ? threadIdx.y : query_nnz % blockDim.y; + int row_extra = (query_nnz % blockDim.y) > threadIdx.y ? 1 : 0; + int row_offset = (query_nnz / blockDim.y) * threadIdx.y + prev_row_extra; + int row_nnz = (query_nnz / blockDim.y) + row_extra; + EmbeddingReduceByIndices( + row_group, result, params, embedding_width, row_nnz, + indices + block_ind_offset + row_offset, shmem_indices, combiner, weights); + __syncthreads(); + _Pragma("unroll") + for (int j = 0; j < (row + tile - 1) / tile; ++j) { + shmem_values[threadIdx.y * blockDim.x + threadIdx.x] = result[j]; + __syncthreads(); + if (threadIdx.y == 0) { + for (int i = 1; i < blockDim.y; i++) { + result[j] += shmem_values[i * blockDim.x + threadIdx.x]; + } + if (combiner == Combiner::Mean) { + result[j] /= query_nnz; + } + if (j * tile + threadIdx.x < embedding_width) out[j * tile] = result[j]; + } + __syncthreads(); + } + } else { + // only one row of threads handle one query(output) of embedding + // the rest of thread can proceed without stucking here + if (!step_counter) { + step_counter = blockDim.y; + T result[(row + tile - 1) / tile] = {0}; + EmbeddingReduceByIndices(row_group, result, params, embedding_width, + query_nnz, indices + block_ind_offset, + shmem_indices, combiner, weights); + _Pragma("unroll") + for (int j = 0; j < (row + tile - 1) / tile; ++j) { + if (combiner == Combiner::Mean) { + result[j] /= query_nnz; + } + if (j * tile + threadIdx.x < embedding_width) out[j * tile] = result[j]; + } + } + step_counter -= 1; + } + indptr += gridDim.x; + out += gridDim.x * embedding_width; + } } -template -__global__ void EmbeddingLookUpVariableHot(const T* params, Tindices embedding_width, - const Tindices* indptr, const Tindices* indices, T* out, - Combiner combiner) { - Tindices block_ind_offset = indptr[blockIdx.x]; - Tindices query_nnz = indptr[blockIdx.x + 1] - block_ind_offset; - int64_t block_out_offset = blockIdx.x * embedding_width; +// version for tile size greater than 32, differences are: +// each tile not within warp so no reduction with shfldown +// need experimental to support tile > 32 +// have an outer loop to handle arbitrary embedding_width +template +__global__ void EmbeddingLookUpVariableHotWide(const T* params, int embedding_width, + const TIndex* indptr, const TIndex* indices, T* out, + Combiner combiner, TIndex num_rows, + const T* weights) { + // reserve shared memory for thread_block_tile usage. + __shared__ cg::experimental::block_tile_memory shared_for_cg; + cg::thread_block thb = cg::experimental::this_thread_block(shared_for_cg); + auto row_group = cg::experimental::tiled_partition(thb); + // smem same size as block size. 
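Note (illustrative sketch, not part of the diff): the `query_nnz > 128` branch above (and again in the wide variant below) splits one long reduction across the `blockDim.y` thread rows, handing the first `query_nnz % blockDim.y` rows one extra index each, then combines the per-row partial sums through shared memory on `threadIdx.y == 0`. A minimal host-side Python sketch of that partition arithmetic; `row_partition` is a hypothetical helper whose variable names mirror the CUDA code:

def row_partition(query_nnz, num_thread_rows):
    """Return the (offset, length) chunk of indices each thread row reduces."""
    base, extra = divmod(query_nnz, num_thread_rows)
    chunks = []
    for y in range(num_thread_rows):        # y plays the role of threadIdx.y
        prev_row_extra = min(y, extra)      # rows before y that received one extra index
        row_extra = 1 if y < extra else 0
        row_offset = base * y + prev_row_extra
        row_nnz = base + row_extra
        chunks.append((row_offset, row_nnz))
    return chunks

# Example: row_partition(10, 4) == [(0, 3), (3, 3), (6, 2), (8, 2)];
# the chunks tile [0, query_nnz) exactly, so summing the per-row partial results
# reproduces the full reduction for the output row.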
extern __shared__ char shmem[]; - Tindices* shmem_indices = reinterpret_cast(shmem); - // each block handle one query(output) of embedding - EmbeddingReduceByIndices(out + block_out_offset, params, embedding_width, query_nnz, - indices + block_ind_offset, shmem_indices, combiner); -} + TIndex* shmem_indices = reinterpret_cast(shmem); + T* shmem_values = reinterpret_cast(shmem); + shmem_indices += threadIdx.y * blockDim.x; + + int rem_width = embedding_width; + for (int out_i = 0; out_i < (embedding_width + row - 1) / row; ++out_i) { + int cur_id = blockIdx.x; + while (cur_id < num_rows) { + TIndex block_ind_offset = indptr[cur_id]; + int query_nnz = indptr[cur_id + 1] - block_ind_offset; + int64_t block_out_offset = cur_id * embedding_width; + if (query_nnz > 128 && blockDim.y > 1) { + T result[(row + tile - 1) / tile] = {0}; + int prev_row_extra = + (query_nnz % blockDim.y) > threadIdx.y ? threadIdx.y : query_nnz % blockDim.y; + int row_extra = (query_nnz % blockDim.y) > threadIdx.y ? 1 : 0; + int row_offset = (query_nnz / blockDim.y) * threadIdx.y + prev_row_extra; + int row_nnz = (query_nnz / blockDim.y) + row_extra; + EmbeddingReduceByIndicesWide( + row_group, result, params, embedding_width, row_nnz, + indices + block_ind_offset + row_offset, shmem_indices, combiner, weights, rem_width); + __syncthreads(); -template -__global__ void EmbeddingLookUpGradVariableHot(const T* grad, Tindices embedding_width, - const Tindices* indptr, T* value, - Combiner combiner) { - Tindices block_ind_offset = indptr[blockIdx.x]; - Tindices query_nnz = indptr[blockIdx.x + 1] - block_ind_offset; - int64_t block_value_offset = block_ind_offset * embedding_width; - int64_t block_grad_offset = blockIdx.x * embedding_width; - EmbeddingExpandGrad(value + block_value_offset, grad + block_grad_offset, embedding_width, - query_nnz, combiner); + _Pragma("unroll") + for (int j = 0; j < (row + tile - 1) / tile; ++j) { + shmem_values[threadIdx.y * blockDim.x + threadIdx.x] = result[j]; + __syncthreads(); + if (threadIdx.y == 0) { + for (int i = 1; i < blockDim.y; i++) { + result[j] += shmem_values[i * blockDim.x + threadIdx.x]; + } + if (combiner == Combiner::Mean) { + result[j] /= query_nnz; + } + if (j * tile + threadIdx.x < rem_width) + out[block_out_offset + j * tile + threadIdx.x] = result[j]; + } + __syncthreads(); + } + } else { + // only one row of threads handle one query(output) of embedding + // the rest of thread can proceed without stucking here + if ((cur_id / gridDim.x) % blockDim.y == threadIdx.y) { + T result[(row + tile - 1) / tile] = {0}; + EmbeddingReduceByIndicesWide( + row_group, result, params, embedding_width, query_nnz, indices + block_ind_offset, + shmem_indices, combiner, weights, rem_width); + _Pragma("unroll") + for (int j = 0; j < (row + tile - 1) / tile; ++j) { + if (combiner == Combiner::Mean) { + result[j] /= query_nnz; + } + if (j * tile + threadIdx.x < rem_width) + out[block_out_offset + j * tile + threadIdx.x] = result[j]; + } + } + } + cur_id += gridDim.x; + } + params += row; + out += row; + rem_width -= row; + } +} +template +__global__ void RowToSplit(TIndex* split_ptr, const TIndex* row_ptr, TIndex num_ids, + TIndex num_rows) { + // effectively parallel binary search + auto tid = blockDim.x * blockIdx.x + threadIdx.x; + if (tid == num_rows) split_ptr[tid] = num_ids; + if (tid >= num_rows) return; + TIndex res, begin = 0, end = num_ids - 1; + while (begin < end) { + res = (begin + end) / 2; + if (row_ptr[res * 2] < tid) { + begin = res + 1; + } else if (row_ptr[res * 2] > tid) 
{ + end = res - 1; + } else { + end = res; + } + } + split_ptr[tid] = end; } -template -struct EmbeddingLookupConstantHotnessFunctor { - void operator()(const Eigen::GpuDevice& d, T* output_ptr, const T* param_ptr, - const Tindices* ids_ptr, Tindices nnz_per_row, Tindices num_rows, - Tindices embedding_width, Combiner combiner) const { - TF_CHECK_OK(GpuLaunchKernel(EmbeddingLookUpConstantHot, num_rows, embedding_width, - embedding_width * sizeof(Tindices), d.stream(), param_ptr, - embedding_width, nnz_per_row, ids_ptr, output_ptr, combiner)); +template +__global__ void OffsetToWeightsAndRowId(const TIndex* indptr, int32_t* out, T* weights) { + TIndex block_start_offset = indptr[blockIdx.x]; + TIndex block_end_offset = indptr[blockIdx.x + 1]; + for (TIndex i = block_start_offset + threadIdx.x; i < block_end_offset; i += blockDim.x) { + out[i] = blockIdx.x; } -}; + if (threadIdx.x == 0 && weights) + weights[blockIdx.x] = static_cast(1) / static_cast(block_end_offset - block_start_offset); +} -template -struct EmbeddingLookupConstantHotnessGradFunctor { - void operator()(const Eigen::GpuDevice& d, T* output_ptr, const T* grad_ptr, Tindices nnz_per_row, - Tindices num_rows, Tindices embedding_width, Combiner combiner) const { - TF_CHECK_OK(GpuLaunchKernel(EmbeddingLookUpGradConstantHot, num_rows, - embedding_width, 0, d.stream(), grad_ptr, embedding_width, - nnz_per_row, output_ptr, combiner)); +template +struct RowToSplitFunctor { + void operator()(const Eigen::GpuDevice& d, TIndex* split_ptr, const TIndex* row_ptr, + TIndex num_ids, TIndex num_rows) const { + TF_CHECK_OK(GpuLaunchKernel(RowToSplit, num_rows / 512 + 1, 512, 0, d.stream(), + split_ptr, row_ptr, num_ids, num_rows)); } }; -template -struct EmbeddingLookupVariableHotnessFunctor { +template +struct EmbeddingLookupVariableHotnessFunctor { void operator()(const Eigen::GpuDevice& d, T* output_ptr, const T* param_ptr, - const Tindices* ids_ptr, const Tindices* offsets_ptr, Tindices num_rows, - Tindices embedding_width, Combiner combiner) const { - TF_CHECK_OK(GpuLaunchKernel(EmbeddingLookUpVariableHot, num_rows, embedding_width, - embedding_width * sizeof(Tindices), d.stream(), param_ptr, - embedding_width, offsets_ptr, ids_ptr, output_ptr, combiner)); + const TIndex* ids_ptr, const TIndex* offsets_ptr, TIndex num_rows, + TIndex embedding_width, Combiner combiner, TIndex ave_red_len) const { + int next_power_of_two = 1 << Log2Ceiling64(embedding_width); + + // decide number of parallel tile base on reduction length + int parallel_tile = 1; + if (ave_red_len >= 256) parallel_tile = 2; + if (ave_red_len >= 1024) parallel_tile = 4; + + // decide number of threads per tile and adjust number of tile with CUDA limits + int blockX = next_power_of_two / ILP; + if (blockX < 32) blockX = 32; + if (blockX > 256) blockX = 256; + if (parallel_tile * blockX > 1024) parallel_tile = 1024 / blockX; + + // decide grid dimension and dynamic shared memory size + dim3 blockDim = dim3(blockX, parallel_tile); + int smem_size = sizeof(TIndex) > sizeof(T) ? 
sizeof(TIndex) : sizeof(T); + smem_size = blockX * parallel_tile * smem_size; + int gridDim = 32768 / (blockX / 32 * parallel_tile); + if (gridDim > num_rows) gridDim = num_rows; + + switch (next_power_of_two) { + case 1: + TF_CHECK_OK(GpuLaunchKernel(EmbeddingLookUpVariableHot, gridDim, blockDim, + smem_size, d.stream(), param_ptr, embedding_width, offsets_ptr, + ids_ptr, output_ptr, combiner, num_rows, nullptr)); + break; + case 2: + TF_CHECK_OK(GpuLaunchKernel(EmbeddingLookUpVariableHot, gridDim, blockDim, + smem_size, d.stream(), param_ptr, embedding_width, offsets_ptr, + ids_ptr, output_ptr, combiner, num_rows, nullptr)); + break; + case 4: + TF_CHECK_OK(GpuLaunchKernel(EmbeddingLookUpVariableHot, gridDim, blockDim, + smem_size, d.stream(), param_ptr, embedding_width, offsets_ptr, + ids_ptr, output_ptr, combiner, num_rows, nullptr)); + break; + case 8: + TF_CHECK_OK(GpuLaunchKernel(EmbeddingLookUpVariableHot, gridDim, blockDim, + smem_size, d.stream(), param_ptr, embedding_width, offsets_ptr, + ids_ptr, output_ptr, combiner, num_rows, nullptr)); + break; + case 16: + TF_CHECK_OK(GpuLaunchKernel(EmbeddingLookUpVariableHot, gridDim, + blockDim, smem_size, d.stream(), param_ptr, embedding_width, + offsets_ptr, ids_ptr, output_ptr, combiner, num_rows, nullptr)); + break; + case 32: + TF_CHECK_OK(GpuLaunchKernel(EmbeddingLookUpVariableHot, gridDim, + blockDim, smem_size, d.stream(), param_ptr, embedding_width, + offsets_ptr, ids_ptr, output_ptr, combiner, num_rows, nullptr)); + break; + case 64: + TF_CHECK_OK(GpuLaunchKernel(EmbeddingLookUpVariableHot, gridDim, + blockDim, smem_size, d.stream(), param_ptr, embedding_width, + offsets_ptr, ids_ptr, output_ptr, combiner, num_rows, nullptr)); + break; + case 128: + TF_CHECK_OK(GpuLaunchKernel(EmbeddingLookUpVariableHot, gridDim, + blockDim, smem_size, d.stream(), param_ptr, embedding_width, + offsets_ptr, ids_ptr, output_ptr, combiner, num_rows, nullptr)); + break; + case 256: + TF_CHECK_OK(GpuLaunchKernel(EmbeddingLookUpVariableHotWide, gridDim, + blockDim, smem_size, d.stream(), param_ptr, embedding_width, + offsets_ptr, ids_ptr, output_ptr, combiner, num_rows, nullptr)); + break; + case 512: + TF_CHECK_OK(GpuLaunchKernel(EmbeddingLookUpVariableHotWide, gridDim, + blockDim, smem_size, d.stream(), param_ptr, embedding_width, + offsets_ptr, ids_ptr, output_ptr, combiner, num_rows, nullptr)); + break; + default: + TF_CHECK_OK(GpuLaunchKernel(EmbeddingLookUpVariableHotWide, gridDim, + blockDim, smem_size, d.stream(), param_ptr, embedding_width, + offsets_ptr, ids_ptr, output_ptr, combiner, num_rows, nullptr)); + break; + } } }; -template -struct EmbeddingLookupVariableHotnessGradFunctor { - void operator()(const Eigen::GpuDevice& d, T* output_ptr, const T* grad_ptr, - const Tindices* offsets_ptr, Tindices num_rows, Tindices embedding_width, - Combiner combiner) const { - TF_CHECK_OK(GpuLaunchKernel(EmbeddingLookUpGradVariableHot, num_rows, - embedding_width, 0, d.stream(), grad_ptr, embedding_width, - offsets_ptr, output_ptr, combiner)); +template +struct EmbeddingLookupVariableHotnessGradFunctor { + void operator()(OpKernelContext* context, const TIndex* ids_ptr, const TIndex* offset_in_ptr, + const T* grad_ptr, int64_t num_ids, TIndex embedding_width, TIndex num_rows, + int64_t dense_shape_dim0, int64_t max_red_len, Combiner combiner) const { + const auto& cu_stream = GetGpuStream(context); + cub::CountingInputIterator itr(0); + + // allocate intermediate results buffer + Tensor tmp_unique_ids; + Tensor offsets; + Tensor num_unique_ids; + 
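Note (illustrative sketch, not part of the diff): the functor above derives its launch configuration from the embedding width and the average reduction length. The Python sketch below mirrors that heuristic for the forward path; `pick_launch_config` is a hypothetical helper, `ILP = 4` matches the define at the top of this file, and the backward path reuses the same shape with different `parallel_tile` thresholds keyed on `max_red_len`:

import math

ILP = 4  # matches the #define at the top of embedding_lookup_kernels.cu

def pick_launch_config(embedding_width, ave_red_len, num_rows,
                       index_bytes=8, value_bytes=4):
    """Mirror the block/grid/shared-memory sizing used by the lookup functor."""
    next_power_of_two = 1 << max(0, math.ceil(math.log2(embedding_width)))
    # longer average reductions get more parallel tiles per output row
    parallel_tile = 1
    if ave_red_len >= 256:
        parallel_tile = 2
    if ave_red_len >= 1024:
        parallel_tile = 4
    # threads per tile, clamped to CUDA-friendly bounds, then re-clamp the tile count
    block_x = min(max(next_power_of_two // ILP, 32), 256)
    if parallel_tile * block_x > 1024:
        parallel_tile = 1024 // block_x
    smem_bytes = block_x * parallel_tile * max(index_bytes, value_bytes)
    grid = min(32768 // (block_x // 32 * parallel_tile), num_rows)
    return {"block": (block_x, parallel_tile), "grid": grid, "smem_bytes": smem_bytes}

# Example: pick_launch_config(128, ave_red_len=300, num_rows=65536)
#          -> block (32, 2), grid 16384, smem_bytes 512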
Tensor sorted_ids; + context->allocate_temp(DataTypeToEnum::value, TensorShape({num_ids}), &tmp_unique_ids); + context->allocate_temp(DataTypeToEnum::value, TensorShape({num_ids}), &offsets); + context->allocate_temp(DataTypeToEnum::value, TensorShape({1}), &num_unique_ids); + context->allocate_temp(DataTypeToEnum::value, TensorShape({num_ids}), &sorted_ids); + auto tmp_unique_ids_ptr = tmp_unique_ids.flat().data(); + auto offsets_ptr = offsets.flat().data(); + auto num_unique_ids_ptr = num_unique_ids.flat().data(); + auto sorted_ids_ptr = sorted_ids.flat().data(); + + Tensor row; + Tensor sorted_row; + context->allocate_temp(DataTypeToEnum::value, TensorShape({num_ids}), &row); + context->allocate_temp(DataTypeToEnum::value, TensorShape({num_ids}), &sorted_row); + auto row_ptr = row.flat().data(); + auto sorted_row_ptr = sorted_row.flat().data(); + + T* weights_ptr = nullptr; + Tensor weights; + if (combiner == Combiner::Mean) { + context->allocate_temp(DataTypeToEnum::value, TensorShape({num_rows}), &weights); + weights_ptr = weights.flat().data(); + } + + TF_CHECK_OK(GpuLaunchKernel(OffsetToWeightsAndRowId, num_rows, 32, 0, cu_stream, + offset_in_ptr, row_ptr, weights_ptr)); + + // Determine temporary device storage requirements + size_t temp_sort = 0; + size_t temp_unique = 0; + cub::DeviceRadixSort::SortPairs(nullptr, temp_sort, ids_ptr, sorted_ids_ptr, row_ptr, + sorted_row_ptr, num_ids, 0, Log2Ceiling64(dense_shape_dim0), + cu_stream); + cub::DeviceSelect::UniqueByKey(nullptr, temp_unique, sorted_ids_ptr, itr, tmp_unique_ids_ptr, + offsets_ptr, num_unique_ids_ptr, num_ids, cu_stream); + Tensor temp_storage; + size_t temp_storage_bytes = temp_sort > temp_unique ? temp_sort : temp_unique; + context->allocate_temp(DT_INT8, TensorShape({static_cast(temp_storage_bytes)}), + &temp_storage); + + auto temp_storage_ptr = temp_storage.flat().data(); + cub::DeviceRadixSort::SortPairs(temp_storage_ptr, temp_sort, ids_ptr, sorted_ids_ptr, row_ptr, + sorted_row_ptr, num_ids, 0, Log2Ceiling64(dense_shape_dim0), + cu_stream); + cub::DeviceSelect::UniqueByKey(temp_storage_ptr, temp_unique, sorted_ids_ptr, itr, + tmp_unique_ids_ptr, offsets_ptr, num_unique_ids_ptr, num_ids, + cu_stream); + + // copy this back to host. 
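Note (reference semantics, not part of the diff): the SortPairs/UniqueByKey sequence below groups the flat lookup ids into one segment per unique id, with the selected offsets holding each segment's start into the sorted row indices; the lookup kernel is then reused with `Combiner::Sum` (plus the 1/hotness weights for the mean combiner) to accumulate the upstream gradient rows of each segment. The NumPy sketch that follows spells out the values `unique_ids`/`unique_grad` are expected to hold; `unique_grad_reference` is a hypothetical helper, and its outputs feed the IndexedSlices returned by the Python gradient so duplicate ids are pre-accumulated here rather than by the optimizer:

import numpy as np

def unique_grad_reference(ids, row_of_id, grad, combiner="sum"):
    """ids: flat lookup ids; row_of_id: output row each id contributed to;
    grad: [num_rows, embedding_width] upstream gradient."""
    order = np.argsort(ids, kind="stable")                   # DeviceRadixSort::SortPairs
    sorted_ids, sorted_rows = ids[order], row_of_id[order]
    unique_ids, seg_starts = np.unique(sorted_ids, return_index=True)  # UniqueByKey
    seg_bounds = np.append(seg_starts, len(ids))              # offsets, with num_ids appended
    hotness = np.bincount(row_of_id, minlength=grad.shape[0])  # forward reduction lengths
    unique_grad = np.zeros((len(unique_ids), grad.shape[1]), dtype=grad.dtype)
    for u, (lo, hi) in enumerate(zip(seg_bounds[:-1], seg_bounds[1:])):
        rows = sorted_rows[lo:hi]
        w = 1.0 / hotness[rows] if combiner == "mean" else np.ones(len(rows))
        unique_grad[u] = (w[:, None] * grad[rows]).sum(axis=0)
    return unique_ids, unique_grad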
should be ok to sync since there is not much to do in between + // TF way of doing it seems to be event query base + TIndex num_unique_ids_host = 0; + cudaMemcpyAsync(&num_unique_ids_host, num_unique_ids_ptr, sizeof(TIndex), + cudaMemcpyDeviceToHost, cu_stream); + + cudaMemcpyAsync(offsets_ptr + num_unique_ids_host, &num_ids, sizeof(int32_t), + cudaMemcpyHostToDevice, cu_stream); + // allocate output + Tensor* unique_ids = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, TensorShape({num_unique_ids_host}), &unique_ids)); + auto unique_ids_ptr = unique_ids->flat().data(); + cudaMemcpyAsync(unique_ids_ptr, tmp_unique_ids_ptr, num_unique_ids_host * sizeof(TIndex), + cudaMemcpyDeviceToDevice, cu_stream); + + Tensor* unique_grad = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(1, TensorShape({num_unique_ids_host, embedding_width}), + &unique_grad)); + auto unique_grad_ptr = unique_grad->flat().data(); + + int next_power_of_two = 1 << Log2Ceiling64(embedding_width); + + // decide number of parallel tile base on reduction length + int parallel_tile = 1; + if (max_red_len > 512) parallel_tile = 2; + if (max_red_len > 4096) parallel_tile = 4; + if (max_red_len > 65536) parallel_tile = 6; + + // decide number of threads per tile and adjust number of tile with CUDA limits + int blockX = next_power_of_two / ILP; + if (blockX < 32) blockX = 32; + if (blockX > 256) blockX = 256; + if (parallel_tile * blockX > 1024) parallel_tile = 1024 / blockX; + + // decide grid dimension and dynamic shared memory size + dim3 blockDim = dim3(blockX, parallel_tile); + int smem_size = sizeof(TIndex) > sizeof(T) ? sizeof(TIndex) : sizeof(T); + smem_size = blockX * parallel_tile * smem_size; + int gridDim = 32768 / (blockX / 32 * parallel_tile); + if (gridDim > num_unique_ids_host) gridDim = num_unique_ids_host; + + switch (next_power_of_two) { + case 1: + TF_CHECK_OK(GpuLaunchKernel(EmbeddingLookUpVariableHot, gridDim, + blockDim, smem_size, cu_stream, grad_ptr, embedding_width, + offsets_ptr, sorted_row_ptr, unique_grad_ptr, Combiner::Sum, + num_unique_ids_host, weights_ptr)); + break; + case 2: + TF_CHECK_OK(GpuLaunchKernel(EmbeddingLookUpVariableHot, gridDim, + blockDim, smem_size, cu_stream, grad_ptr, embedding_width, + offsets_ptr, sorted_row_ptr, unique_grad_ptr, Combiner::Sum, + num_unique_ids_host, weights_ptr)); + break; + case 4: + TF_CHECK_OK(GpuLaunchKernel(EmbeddingLookUpVariableHot, gridDim, + blockDim, smem_size, cu_stream, grad_ptr, embedding_width, + offsets_ptr, sorted_row_ptr, unique_grad_ptr, Combiner::Sum, + num_unique_ids_host, weights_ptr)); + break; + case 8: + TF_CHECK_OK(GpuLaunchKernel(EmbeddingLookUpVariableHot, gridDim, + blockDim, smem_size, cu_stream, grad_ptr, embedding_width, + offsets_ptr, sorted_row_ptr, unique_grad_ptr, Combiner::Sum, + num_unique_ids_host, weights_ptr)); + break; + case 16: + TF_CHECK_OK(GpuLaunchKernel(EmbeddingLookUpVariableHot, gridDim, + blockDim, smem_size, cu_stream, grad_ptr, embedding_width, + offsets_ptr, sorted_row_ptr, unique_grad_ptr, Combiner::Sum, + num_unique_ids_host, weights_ptr)); + break; + case 32: + TF_CHECK_OK(GpuLaunchKernel(EmbeddingLookUpVariableHot, gridDim, + blockDim, smem_size, cu_stream, grad_ptr, embedding_width, + offsets_ptr, sorted_row_ptr, unique_grad_ptr, Combiner::Sum, + num_unique_ids_host, weights_ptr)); + break; + case 64: + TF_CHECK_OK(GpuLaunchKernel(EmbeddingLookUpVariableHot, gridDim, + blockDim, smem_size, cu_stream, grad_ptr, embedding_width, + offsets_ptr, sorted_row_ptr, unique_grad_ptr, 
Combiner::Sum, + num_unique_ids_host, weights_ptr)); + break; + case 128: + TF_CHECK_OK(GpuLaunchKernel(EmbeddingLookUpVariableHot, gridDim, + blockDim, smem_size, cu_stream, grad_ptr, embedding_width, + offsets_ptr, sorted_row_ptr, unique_grad_ptr, Combiner::Sum, + num_unique_ids_host, weights_ptr)); + break; + case 256: + TF_CHECK_OK(GpuLaunchKernel(EmbeddingLookUpVariableHotWide, gridDim, + blockDim, smem_size, cu_stream, grad_ptr, embedding_width, + offsets_ptr, sorted_row_ptr, unique_grad_ptr, Combiner::Sum, + num_unique_ids_host, weights_ptr)); + break; + case 512: + TF_CHECK_OK(GpuLaunchKernel(EmbeddingLookUpVariableHotWide, gridDim, + blockDim, smem_size, cu_stream, grad_ptr, embedding_width, + offsets_ptr, sorted_row_ptr, unique_grad_ptr, Combiner::Sum, + num_unique_ids_host, weights_ptr)); + break; + default: + TF_CHECK_OK(GpuLaunchKernel(EmbeddingLookUpVariableHotWide, gridDim, + blockDim, smem_size, cu_stream, grad_ptr, embedding_width, + offsets_ptr, sorted_row_ptr, unique_grad_ptr, Combiner::Sum, + num_unique_ids_host, weights_ptr)); + break; + } } }; -template struct EmbeddingLookupConstantHotnessFunctor; -template struct EmbeddingLookupConstantHotnessGradFunctor; +template struct RowToSplitFunctor; +template struct RowToSplitFunctor; template struct EmbeddingLookupVariableHotnessFunctor; -template struct EmbeddingLookupVariableHotnessGradFunctor; -template struct EmbeddingLookupConstantHotnessFunctor; -template struct EmbeddingLookupConstantHotnessGradFunctor; template struct EmbeddingLookupVariableHotnessFunctor; +template struct EmbeddingLookupVariableHotnessGradFunctor; template struct EmbeddingLookupVariableHotnessGradFunctor; } // namespace tensorflow diff --git a/distributed_embeddings/cc/ops/embedding_lookup_ops.cc b/distributed_embeddings/cc/ops/embedding_lookup_ops.cc index c139e92..047bcc6 100644 --- a/distributed_embeddings/cc/ops/embedding_lookup_ops.cc +++ b/distributed_embeddings/cc/ops/embedding_lookup_ops.cc @@ -20,6 +20,7 @@ #include "tensorflow/core/framework/shape_inference.h" namespace tensorflow { + REGISTER_OP("ReadVariableNoCopy") .Input("resource: resource") .Output("value: dtype") @@ -31,39 +32,13 @@ REGISTER_OP("ReadVariableNoCopy") return Status::OK(); }); -REGISTER_OP("EmbeddingLookupConstantHotness") - .Attr("T: {float}") - .Attr("Tindices: {int32, int64}") - .Input("param: T") - .Input("ids: Tindices") - .Output("output_params: T") - .Attr("combiner: {'sum', 'mean'}") - .SetShapeFn([](shape_inference::InferenceContext* c) { - // param: [N,p], ids:[m, n], output: [m, p] - shape_inference::ShapeHandle params_shape; - shape_inference::ShapeHandle ids_shape; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, ¶ms_shape)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &ids_shape)); - c->set_output(0, c->Matrix(c->Dim(ids_shape, 0), c->Dim(params_shape, 1))); - return Status::OK(); - }); - -REGISTER_OP("EmbeddingLookupConstantHotnessGrad") - .Attr("T: {float}") +REGISTER_OP("RowToSplit") .Attr("Tindices: {int32, int64}") - .Input("grad: T") - .Input("ids: Tindices") - .Output("grad_param_value: T") - .Attr("combiner: {'sum', 'mean'}") + .Input("row_ids: Tindices") + .Input("shape: int32") + .Output("row_split: Tindices") .SetShapeFn([](shape_inference::InferenceContext* c) { - // param: [N,p], ids:[m,n], grad: [m,p], grad_param_value: [m*n, p] - // we used ids as input here just to do shape inference - shape_inference::ShapeHandle grad_shape; - shape_inference::ShapeHandle ids_shape; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &grad_shape)); - 
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &ids_shape)); - auto outdim_0 = c->Value(c->Dim(ids_shape, 0)) * c->Value(c->Dim(ids_shape, 1)); - c->set_output(0, c->Matrix(outdim_0, c->Dim(grad_shape, 1))); + // TODO return Status::OK(); }); @@ -96,23 +71,19 @@ REGISTER_OP("EmbeddingLookupVariableHotness") REGISTER_OP("EmbeddingLookupVariableHotnessGrad") .Attr("T: {float}") .Attr("Tindices: {int32, int64}") - .Input("grad: T") .Input("ids: Tindices") - .Input("offsets: Tindices") - .Output("grad_param_value: T") + .Input("offset: Tindices") + .Input("grad: T") + .Input("param: T") + .Output("unique_ids: Tindices") + .Output("unique_grad: T") .Attr("combiner: {'sum', 'mean'}") .SetShapeFn([](shape_inference::InferenceContext* c) { - // vitual input: [m,n], param: [N,p], ids:[nnz], offsets:[m+1] - // grad: [m, p], grad_param_value: [nnz, p] - // we used ids as input here just to do shape inference - shape_inference::ShapeHandle grad_shape; - shape_inference::ShapeHandle ids_shape; - shape_inference::ShapeHandle offsets_shape; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &grad_shape)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &ids_shape)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &offsets_shape)); - c->set_output(0, c->Matrix(c->Dim(ids_shape, 0), c->Dim(grad_shape, 1))); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &grad_shape)); + c->set_output(0, c->Vector(shape_inference::InferenceContext::kUnknownDim)); + c->set_output( + 1, c->Matrix(shape_inference::InferenceContext::kUnknownDim, c->Dim(grad_shape, 1))); return Status::OK(); }); diff --git a/distributed_embeddings/python/layers/dist_model_parallel.py b/distributed_embeddings/python/layers/dist_model_parallel.py index 7fe1135..641a00c 100644 --- a/distributed_embeddings/python/layers/dist_model_parallel.py +++ b/distributed_embeddings/python/layers/dist_model_parallel.py @@ -299,7 +299,9 @@ def _call_base(self, inputs): # pylint: disable=missing-param-doc,missing-type- # cast before alltoall according to dtype policy mp_outs = tf.cast(mp_outs, self.compute_dtype) dp_outs = hvd.alltoall(mp_outs, name='out_mp_to_dp') - local_bs = inputs[0].shape[0] // self.world_size + batch_size = tf.shape( + inputs[0], out_type=tf.int32)[0] if inputs[0].shape[0] is None else inputs[0].shape[0] + local_bs = batch_size // self.world_size num_elements = [local_bs * item for item in self.strategy.widths_list_flat] split_outs = tf.split(dp_outs, num_elements) worker_order_res = [tf.reshape(split_out, [local_bs, -1]) for split_out in split_outs] @@ -309,8 +311,7 @@ def _call_base(self, inputs): # pylint: disable=missing-param-doc,missing-type- return result def _concat_column_slice_outputs(self, outs): - """Concat sliced outputs result from column slicing back together - """ + """Concat sliced outputs result from column slicing back together""" for start, end in self.strategy.sliced_out_ranges: outs[start:end] = [tf.concat(outs[start:end], axis=-1)] return outs @@ -414,6 +415,8 @@ def get_weights(self, all_ranks=False): Args: all_ranks (bool): If true, return weights in all ranks, otherwise only in rank 0. Default False. + Returns: + result (list): List of weight tensors. 
""" # avoid copy-on-read on dense access local_weights = [read_var_no_copy(w) for w in self.weights] diff --git a/distributed_embeddings/python/layers/dist_model_parallel_test.py b/distributed_embeddings/python/layers/dist_model_parallel_test.py index 9660484..bea39f0 100644 --- a/distributed_embeddings/python/layers/dist_model_parallel_test.py +++ b/distributed_embeddings/python/layers/dist_model_parallel_test.py @@ -15,7 +15,6 @@ """Test of distributed model parallel""" import random import time -import numpy as np import tensorflow as tf from tensorflow.python.platform import test from tensorflow.python.keras import keras_parameterized @@ -53,7 +52,7 @@ def __init__(self, self.input_table_map = input_table_map def call(self, inputs): - if self.dist_embeddings: + if self.dist_embeddings is not None: outs = self.dist_embeddings(inputs) elif self.input_table_map: outs = [self.embeddings[j](i) for i, j in zip(inputs, self.input_table_map)] @@ -69,6 +68,14 @@ def get_config(self): """ return None + def train_step(self, data): + with tf.GradientTape() as tape: + out = tf.reduce_sum(self(data[0])) + tape = dmp.DistributedGradientTape(tape) + gradients = tape.gradient(out, self.trainable_variables) + self.optimizer.apply_gradients(zip(gradients, self.trainable_variables)) + return {"out": out} + class DistributedEmbeddingTest(keras_parameterized.TestCase): @@ -127,8 +134,7 @@ def gen_inputs(self, table_sizes, input_to_table_map=None, mp_input_ids=None): return dp_inputs, mp_inputs def run_and_test(self, ref_model, ref_inputs, test_model, test_inputs): - tf.random.set_seed(int(time.time()) + self.hvd_rank) - np.random.seed(int(time.time()) + self.hvd_rank) + tf.keras.utils.set_random_seed(int(time.time()) + self.hvd_rank) # run a batch to initialize weight tensors _ = ref_model(ref_inputs) _ = test_model(test_inputs) @@ -165,8 +171,7 @@ def run_and_test(self, ref_model, ref_inputs, test_model, test_inputs): self.assertAllClose(tf.convert_to_tensor(ref_w), tf.convert_to_tensor(test_w)) def test_broadcast(self): - tf.random.set_seed(int(time.time()) + self.hvd_rank) - np.random.seed(int(time.time()) + self.hvd_rank) + tf.keras.utils.set_random_seed(int(time.time()) + self.hvd_rank) num_tables = 7 table_sizes = [[11, 7], [5, 8], [3, 8], [5, 8], [12, 25], [3, 12], [7, 13]] @@ -281,6 +286,40 @@ def test_column_slice_dup_worker(self): dp_inputs, mp_inputs = self.gen_inputs(table_sizes, mp_input_ids=mp_input_ids) self.run_and_test(ref_model, dp_inputs, test_model, mp_inputs) + def test_model_fit_basic(self): + table_sizes = self.gen_table_sizes() + + ref_model = EmbeddingListModel(table_sizes, distribute=False) + test_model = EmbeddingListModel(table_sizes, distribute=True, strategy='basic') + optimizer = tf.keras.optimizers.SGD(learning_rate=1.5, momentum=0) + test_model.compile(optimizer=optimizer) + + dp_inputs, _ = self.gen_inputs(table_sizes) + ref_model(dp_inputs) + test_model(dp_inputs) + # broadcast ref model weights and set test model weights + hvd.broadcast_variables(ref_model.variables, root_rank=0) + ref_weights = ref_model.get_weights() + num_tables = len(ref_model.embeddings) + test_model.dist_embeddings.set_weights(ref_weights[:num_tables]) + test_model.dense.set_weights(ref_weights[num_tables:]) + + with tf.GradientTape() as tape: + ref_out = tf.reduce_sum(ref_model(dp_inputs)) + tape = hvd.DistributedGradientTape(tape) + ref_grads = tape.gradient(ref_out, ref_model.variables) + + optimizer.apply_gradients(zip(ref_grads, ref_model.variables)) + ref_weights = ref_model.get_weights() + + 
test_history = test_model.fit(dp_inputs, epochs=1, steps_per_epoch=1) + test_weights = test_model.dist_embeddings.get_weights(True) + test_model.dense.get_weights() + + self.assertAllClose(ref_out, test_history.history['out'][0]) + for ref_w, test_w in zip(ref_weights, test_weights): + # assert close here since order of accumulations(inputs and batch dim) might have changed + self.assertAllClose(tf.convert_to_tensor(ref_w), tf.convert_to_tensor(test_w)) + if __name__ == "__main__": test.main() diff --git a/distributed_embeddings/python/layers/embedding.py b/distributed_embeddings/python/layers/embedding.py index dccef3f..5378b4d 100644 --- a/distributed_embeddings/python/layers/embedding.py +++ b/distributed_embeddings/python/layers/embedding.py @@ -52,12 +52,11 @@ class Embedding(tf.keras.layers.Layer): the `embeddings` matrix (see `keras.constraints`). combiner (str): Reduction method, ['sum', 'mean'] or None. Default None. - When combiner is not None, only dense inputs with rank >=2 and ragged inputs with - rank==2 are supported. Embedding picked from last input dimension will be reduced. - - In other word, support one of the following input/output shape combination: - ND dense input: `(d1,...,dn)`, ND output: `(d1,...,dn-1,output_dim)`, N>=2 - 2D ragged input: `(batch_size, ragged_dim)`, 2D output: `(batch_size, output_dim)` + When combiner is not None, supported input and their respectively output shape are: + N-D `Tensor`: `(d1,...,dn)`, output shape: `(d1,...,dn-1,output_dim)`, N >= 2 + 2-D `RaggedTensor`: `(batch_size, ragged_dim)`, output shape: `(batch_size, output_dim)` + 2-D `SparseTensor`: `(batch_size, max_hotness)`, output shape: `(batch_size, output_dim)` + Embedding picked from last input dimension will be reduced with given combiner. 
""" def __init__(self, diff --git a/distributed_embeddings/python/layers/embedding_test.py b/distributed_embeddings/python/layers/embedding_test.py index 0677a2e..681a478 100644 --- a/distributed_embeddings/python/layers/embedding_test.py +++ b/distributed_embeddings/python/layers/embedding_test.py @@ -112,6 +112,24 @@ def test_ragged_input_with_mean_combiner(self): outputs = model.predict(ragged_factory_ops.constant([[1, 2, 2], [0], [1, 2]], ragged_rank=1)) self.assertAllEqual(outputs, [[5., 3.], [0., 3.], [4., 3.5]]) + @keras_parameterized.run_all_keras_modes + def test_sparse_input_with_mean_combiner(self): + layer = embedding.Embedding(input_dim=3, + output_dim=2, + combiner='mean', + weights=[np.array([[0., 3.], [1., 5.], [7., 2.]])]) + inputs = tf.keras.layers.Input(shape=(None,), dtype=tf.int64, sparse=True) + outputs = layer(inputs) + + model = tf.keras.Model(inputs, outputs) + model.run_eagerly = testing_utils.should_run_eagerly() + + outputs = model.predict( + tf.sparse.SparseTensor(indices=[[0, 0], [0, 1], [0, 2], [1, 0], [2, 0], [2, 1]], + values=[1, 2, 2, 0, 1, 2], + dense_shape=[3, 4])) + self.assertAllEqual(outputs, [[5., 3.], [0., 3.], [4., 3.5]]) + @combinations.generate(combinations.combine(mode=['eager'])) def test_2d_input_with_sum_combiner_with_grad(self): layer = embedding.Embedding(output_dim=2, input_dim=3, combiner='sum') diff --git a/distributed_embeddings/python/ops/embedding_lookup_ops.py b/distributed_embeddings/python/ops/embedding_lookup_ops.py index 4f797f6..88b829d 100644 --- a/distributed_embeddings/python/ops/embedding_lookup_ops.py +++ b/distributed_embeddings/python/ops/embedding_lookup_ops.py @@ -39,8 +39,8 @@ def embedding_lookup(param, ids, combiner=None): Args: param (Tensor): A single tensor representing the complete embedding tensor. - ids (Tensor or RaggedTensor): A 2D `int32` or `int64` `Tensor` containing - the ids to be looked up in `param`. + ids (Tensor): A 2D `int32` or `int64` `Tensor` containing the ids to be looked up + in `param`. Also support `RaggedTensor` and `SparseTensor`. combiner (string or None): Reduction method, ['sum', 'mean'] or None. Default None. Returns: @@ -66,29 +66,40 @@ def embedding_lookup(param, ids, combiner=None): if combiner is None: return tf.nn.embedding_lookup(param, ids) if isinstance(ids, ragged_tensor.RaggedTensor): + # assuming no empty sample. tf.shape may fail on earlier tf version with ragged input + try: + dim_0 = tf.shape(ids, out_type=tf.int32)[0] if ids.shape[0] is None else ids.shape[0] + except: # pylint: disable=bare-except + dim_0 = tf.shape(ids.row_splits, + out_type=tf.int32)[0] - 1 if ids.shape[0] is None else ids.shape[0] + num_input = tf.shape( + ids.values, out_type=tf.int32)[0] if ids.values.shape[0] is None else ids.values.shape[0] + if dim_0 == num_input: + return tf.nn.embedding_lookup(param, ids.values) return ops.embedding_lookup_variable_hotness(read_var_no_copy(param), ids.values, ids.row_splits, combiner) - return ops.embedding_lookup_constant_hotness(read_var_no_copy(param), ids, combiner) - - -@tf.RegisterGradient("EmbeddingLookupConstantHotness") -def _embedding_lookup_constant_hotness_grad(op, grad): - """The gradients for `embedding_lookup_constant_hotness`. - Args: - op (object): The `embedding_lookup_constant_hotness` `Operation` that we are differentiating, - which we can use to find the inputs and outputs of the original op. - grad (Tensor): Gradient with respect to the output of `embedding_lookup_constant_hotness`. 
- Returns: - IndexedSlices: A `IndexedSlices` contain sparse gradients with respect to - the embedding parameter of `embedding_lookup_constant_hotness`. - """ - param_shape = tf.shape(op.inputs[0]) - ids = op.inputs[1] - grad_param_value = ops.embedding_lookup_constant_hotness_grad(grad, - ids, - combiner=op.get_attr('combiner')) - - return (tf.IndexedSlices(grad_param_value, tf.reshape(ids, [-1]), param_shape), None) + if isinstance(ids, tf.SparseTensor): + # sparse is ordered but may not be right-ragged. so we generate offset here + # avoid d2h copy in eager mode by using sparsetensor's shape directly + dim_0 = tf.shape(ids, out_type=tf.int32)[0] if ids.shape[0] is None else ids.shape[0] + num_input = tf.shape( + ids.values, out_type=tf.int32)[0] if ids.values.shape[0] is None else ids.values.shape[0] + if dim_0 == num_input: + return tf.nn.embedding_lookup(param, ids.values) + # use custom op to avoid bad XLA bahavior and d2h copy caused by searchsorted + row_splits = ops.row_to_split(ids.indices, dim_0) + # we really want ids.values and row_splits to be same dtype to simplify things + # since max(row_splits) here is likely ~total hotness, int32 should be ok + # TODO(Deyu): fuse this cast into above row_to_split function and make always int32 + return ops.embedding_lookup_variable_hotness(read_var_no_copy(param), ids.values, + tf.cast(row_splits, dtype=ids.values.dtype), + combiner) + dim1 = tf.shape(ids, out_type=tf.int32)[1] if ids.shape[1] is None else ids.shape[1] + if dim1 == 1: + return tf.nn.embedding_lookup(param, tf.squeeze(ids, [1])) + if combiner == 'sum': + return tf.reduce_sum(tf.nn.embedding_lookup(param, ids), axis=1) + return tf.reduce_mean(tf.nn.embedding_lookup(param, ids), axis=1) @tf.RegisterGradient("EmbeddingLookupVariableHotness") @@ -102,12 +113,10 @@ def _embedding_lookup_variable_hotness_grad(op, grad): IndexedSlices: A `IndexedSlices` contain sparse gradients with respect to the embedding parameter of `embedding_lookup_variable_hotness`. 
""" - ids = op.inputs[1] + param_shape = tf.shape(op.inputs[0]) + flat_ids = tf.reshape(op.inputs[1], [-1]) offsets = op.inputs[2] - grad_param_value = ops.embedding_lookup_variable_hotness_grad(grad, - ids, - offsets, - combiner=op.get_attr('combiner')) + unique_ids, unique_grad = ops.embedding_lookup_variable_hotness_grad( + flat_ids, offsets, grad, op.inputs[0], combiner=op.get_attr('combiner')) - param_shape = tf.cast(op.inputs[0].shape, dtype=tf.int64) # pylint: disable=unexpected-keyword-arg,no-value-for-parameter - return (tf.IndexedSlices(grad_param_value, ids, param_shape), None, None) + return (tf.IndexedSlices(unique_grad, unique_ids, param_shape), None, None) diff --git a/distributed_embeddings/python/ops/embedding_lookup_ops_test.py b/distributed_embeddings/python/ops/embedding_lookup_ops_test.py index f506645..58a5f8c 100644 --- a/distributed_embeddings/python/ops/embedding_lookup_ops_test.py +++ b/distributed_embeddings/python/ops/embedding_lookup_ops_test.py @@ -80,6 +80,40 @@ def test_constant_hotness(self): self.assertAllEqual(ref_ret, ret) self.assertAllEqual(ref_g_dense, g_dense) + def test_sparse_tensor_input(self): + voc, emb, batch, max_hotness = 69, 64, 15, 207 + # create dense representation of index matrix + data_a = tf.random.uniform(shape=[batch, max_hotness], minval=1, maxval=max_hotness + 1) + data_b = tf.random.uniform(shape=[batch], minval=1, maxval=max_hotness + 1) + # make sure there is no empty row + data_c = tf.reshape(tf.eye(max_hotness, batch_shape=[batch // max_hotness + 1]), + [-1, max_hotness])[:batch] + + data_0 = tf.cast((data_a / tf.reshape(data_b, [-1, 1]) + data_c) > 1, tf.int64) + data_1 = tf.random.uniform(shape=[batch, max_hotness], minval=0, maxval=voc, dtype=tf.int64) + data = data_0 * data_1 + + # COO format for tf native API + ref_ids = tf.sparse.from_dense(data) + test_ids = tf.sparse.from_dense(data) + + initial_weight = tf.random.uniform([voc, emb], dtype=tf.float32) + param = tf.Variable(initial_weight) + + for red in ['sum', 'mean']: + with tf.GradientTape(persistent=True) as tape: + tape.watch(param) + ref_ret = tf.nn.embedding_lookup_sparse(param, ref_ids, sp_weights=None, combiner=red) + ret = embedding_lookup(param, test_ids, combiner=red) + ref_g = tape.gradient(ref_ret, param) + g = tape.gradient(ret, param) + + ref_g_dense = tf.convert_to_tensor(ref_g) + g_dense = tf.convert_to_tensor(g) + # Seems some ops in sparse lookup is running on CPU and rounding differently + self.assertAllClose(ref_ret, ret) + self.assertAllClose(ref_g_dense, g_dense) + if __name__ == '__main__': tf.test.main() diff --git a/examples/benchmarks/synthetic_models/main.py b/examples/benchmarks/synthetic_models/main.py index 8940cdf..cfc9223 100644 --- a/examples/benchmarks/synthetic_models/main.py +++ b/examples/benchmarks/synthetic_models/main.py @@ -23,10 +23,10 @@ import tensorflow as tf from tensorflow import keras -import horovod.tensorflow as hvd +import horovod.tensorflow.keras as hvd from config_v3 import synthetic_models_v3 -from synthetic_models import SyntheticModel, InputGenerator +from synthetic_models import SyntheticModelTFDE, SyntheticModelNative, InputGenerator from distributed_embeddings.python.layers import dist_model_parallel as dmp @@ -35,13 +35,16 @@ # pylint: disable=line-too-long # yapf: disable flags.DEFINE_integer("batch_size", 4, help="Global batch size") -flags.DEFINE_integer("num_data_batches", 10, help="Number of batches of synthetic data to generate") +flags.DEFINE_integer("num_data_batches", 1, help="Number of batches of 
diff --git a/examples/benchmarks/synthetic_models/main.py b/examples/benchmarks/synthetic_models/main.py
index 8940cdf..cfc9223 100644
--- a/examples/benchmarks/synthetic_models/main.py
+++ b/examples/benchmarks/synthetic_models/main.py
@@ -23,10 +23,10 @@
 import tensorflow as tf
 from tensorflow import keras

-import horovod.tensorflow as hvd
+import horovod.tensorflow.keras as hvd

 from config_v3 import synthetic_models_v3
-from synthetic_models import SyntheticModel, InputGenerator
+from synthetic_models import SyntheticModelTFDE, SyntheticModelNative, InputGenerator
 from distributed_embeddings.python.layers import dist_model_parallel as dmp

@@ -35,13 +35,16 @@
 # pylint: disable=line-too-long
 # yapf: disable
 flags.DEFINE_integer("batch_size", 4, help="Global batch size")
-flags.DEFINE_integer("num_data_batches", 10, help="Number of batches of synthetic data to generate")
+flags.DEFINE_integer("num_data_batches", 1, help="Number of batches of synthetic data to generate")
 flags.DEFINE_float("alpha", 1.05, help="Exponent to generate power law distributed data")
 flags.DEFINE_integer("num_steps", 100, help="Number of steps to benchmark")
 flags.DEFINE_bool("dp_input", False, help="Use data parallel input")
 flags.DEFINE_string("model", "tiny", help="Choose model size to run benchmark")
 flags.DEFINE_enum("optimizer", "sgd", ["sgd", "adagrad", "adam"], help="Optimizer")
 flags.DEFINE_integer("column_slice_threshold", None, help="Upper bound of elements count in each column slice")
+flags.DEFINE_bool("use_model_fit", False, help="Use Keras model.fit")
+flags.DEFINE_string("embedding_device", "/GPU:0", help="Device to place embeddings on. Inputs are placed on the same device")
+flags.DEFINE_enum("embedding_api", "tfde", ["native", "tfde"], help="Embedding API to use")
 # yapf: enable
 # pylint: enable=line-too-long

@@ -60,16 +63,30 @@ def main(_):

   if FLAGS.batch_size % hvd_size != 0:
     raise ValueError(F"Batch size ({FLAGS.batch_size}) is not divisible by world size ({hvd_size})")
+
   model_config = synthetic_models_v3[FLAGS.model]
-  model = SyntheticModel(model_config,
-                         column_slice_threshold=FLAGS.column_slice_threshold,
-                         dp_input=FLAGS.dp_input)
+  if FLAGS.embedding_api == "tfde":
+    if FLAGS.embedding_device != "/GPU:0":
+      raise ValueError(
+          F"distributed-embeddings api is not supported on device {FLAGS.embedding_device}.")
+    model = SyntheticModelTFDE(model_config,
+                               column_slice_threshold=FLAGS.column_slice_threshold,
+                               dp_input=FLAGS.dp_input)
+  elif FLAGS.embedding_api == "native":
+    if FLAGS.dp_input is False or FLAGS.column_slice_threshold is not None:
+      raise ValueError(
+          "Model parallel inputs and column slicing are not supported with native embedding api.")
+    model = SyntheticModelNative(model_config, embedding_device=FLAGS.embedding_device)
+  else:
+    raise ValueError(F"Unknown embedding api {FLAGS.embedding_api}.")
+
+  mp_input_ids = None if FLAGS.dp_input else model.embeddings.strategy.input_ids_list[hvd_rank]
   input_gen = InputGenerator(model_config,
                              FLAGS.batch_size,
                              alpha=FLAGS.alpha,
-                             input_ids_list=model.embeddings.strategy.input_ids_list,
+                             mp_input_ids=mp_input_ids,
                              num_batches=FLAGS.num_data_batches,
-                             dp_input=FLAGS.dp_input)
+                             embedding_device=FLAGS.embedding_device)

   if FLAGS.optimizer == "sgd":
     optimizer = tf.keras.optimizers.SGD(learning_rate=0.03, momentum=0)
@@ -80,41 +97,58 @@
   bce = keras.losses.BinaryCrossentropy(reduction=keras.losses.Reduction.NONE, from_logits=True)

-  @tf.function
-  def train_step(numerical_features, categorical_features, labels):
-    with tf.GradientTape() as tape:
-      predictions = model((numerical_features, categorical_features))
-      loss = tf.math.reduce_mean(bce(labels, predictions))
-    tape = dmp.DistributedGradientTape(tape)
-    gradients = tape.gradient(loss, model.trainable_variables)
-    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
-    return loss
-
-  # Run one step to warm up
-  numerical_features, cat_features, labels = input_gen[-1]
-  loss = train_step(numerical_features, cat_features, labels)
+  # Run one step to initialize and broadcast weights
+  (numerical_features, cat_features), labels = input_gen[-1]
+  model((numerical_features, cat_features))
   dmp.broadcast_variables(model.variables, root_rank=0)
-  _ = hvd.allreduce(loss, name="mean_loss", op=hvd.Average)
-
-  start = time()
-  # Input data consumes a lot of memory. Instead of generating num_steps batch of synthetic data,
-  # We generate smaller amount of data and loop over them
-  for step in range(FLAGS.num_steps):
-    inputs = input_gen[step % FLAGS.num_data_batches]
-    numerical_features, cat_features, labels = inputs
-    loss = train_step(numerical_features, cat_features, labels)
-    if step == 0:
-      dmp.broadcast_variables(model.variables, root_rank=0)
+
+  if not FLAGS.use_model_fit:
+
+    @tf.function
+    def train_step(numerical_features, categorical_features, labels):
+      with tf.GradientTape() as tape:
+        predictions = model((numerical_features, categorical_features))
+        loss = tf.math.reduce_mean(bce(labels, predictions))
+      tape = dmp.DistributedGradientTape(tape)
+      gradients = tape.gradient(loss, model.trainable_variables)
+      optimizer.apply_gradients(zip(gradients, model.trainable_variables))
+      return loss
+
+    # Run 5 steps to compile and warm up
+    (numerical_features, cat_features), labels = input_gen[-1]
+    for _ in range(5):
+      loss = train_step(numerical_features, cat_features, labels)
+    loss = hvd.allreduce(loss, name="mean_loss", op=hvd.Average)
+    # print the initial loss here to force a sync before we start the timer
+    print(F"Initial loss: {loss:.3f}")
+
+    start = time()
+    # Input data consumes a lot of memory. Instead of generating num_steps batches of synthetic
+    # data, we generate a smaller amount of data and loop over it.
+    for step in range(FLAGS.num_steps):
+      inputs = input_gen[step % FLAGS.num_data_batches]
+      (numerical_features, cat_features), labels = inputs
+      loss = train_step(numerical_features, cat_features, labels)
+      loss = hvd.allreduce(loss, name="mean_loss", op=hvd.Average)
+      if step % 50 == 0:
+        loss = hvd.allreduce(loss, name="mean_loss", op=hvd.Average)
+        if hvd_rank == 0:
+          print(F"Benchmark step [{step}/{FLAGS.num_steps}]")
+    loss = hvd.allreduce(loss, name="mean_loss", op=hvd.Average)
-    if step % 50 == 0 and hvd_rank == 0:
-      print(F"Benchmark step [{step}/{FLAGS.num_steps}]")
-
-  if hvd_rank == 0:
-    # printing GPU tensor forces a sync. loss was allreduced, printing on one GPU is enough
-    # for computing time so we don't print noisy messages from all ranks
-    print(F"loss: {loss:.3f}")
-  stop = time()
-  print(F"Iteration time: {(stop - start) * 1000 / FLAGS.num_steps:.3f} ms")
+    if hvd_rank == 0:
+      # printing a GPU tensor forces a sync. loss was allreduced, so printing on one GPU is
+      # enough for timing and we avoid noisy messages from all ranks
+      print(F"loss: {loss:.3f}")
+    stop = time()
+    print(F"Iteration time: {(stop - start) * 1000 / FLAGS.num_steps:.3f} ms")
+  else:
+    model.compile(optimizer=optimizer, loss=bce)
+
+    epochs = FLAGS.num_steps // FLAGS.num_data_batches
+    # A broadcast-variables callback should be registered once Horovod supports broadcasting
+    # only data-parallel variables.
+    model.fit(input_gen, epochs=epochs, batch_size=FLAGS.batch_size)


 if __name__ == '__main__':
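The `--use_model_fit` path only works because the model's own `train_step` applies the same gradient treatment the custom loop does by hand; that is exactly what `SyntheticModelNative.train_step` further below provides. A minimal sketch of that override pattern (toy model and layer contents are hypothetical; the tape wrapping mirrors the loop above):

import tensorflow as tf
from tensorflow import keras
from distributed_embeddings.python.layers import dist_model_parallel as dmp

class FitCompatibleModel(keras.Model):  # hypothetical minimal model
  def __init__(self):
    super().__init__()
    self.dense = keras.layers.Dense(1)

  def call(self, inputs):
    return self.dense(inputs)

  def train_step(self, data):
    x, y = data
    with tf.GradientTape() as tape:
      loss = tf.math.reduce_mean(self.compiled_loss(y, self(x)))
    tape = dmp.DistributedGradientTape(tape)  # same wrapping as the custom loop above
    grads = tape.gradient(loss, self.trainable_variables)
    self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
    return {"loss": loss}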
diff --git a/examples/benchmarks/synthetic_models/synthetic_models.py b/examples/benchmarks/synthetic_models/synthetic_models.py
index af11831..ad3b353 100644
--- a/examples/benchmarks/synthetic_models/synthetic_models.py
+++ b/examples/benchmarks/synthetic_models/synthetic_models.py
@@ -48,29 +48,28 @@ def gen_power_law_data(batch_size, hotness, num_rows, alpha):
 # pylint: enable=missing-type-doc


-class InputGenerator():
+class InputGenerator(keras.utils.Sequence):
   """Synthetic input generator

   Args:
     model_config (ModelConfig): A named tuple describes the synthetic model
     global_batch_size (int): Batch size.
     alpha (float): exponent to generate power law distributed input. 0 means uniform, default 0
-    input_ids_list (list of int): Nested list containing input indices by rank.
+    mp_input_ids (list of int): List containing model parallel input indices.
     num_batches (int): Number of batches to generate. Default 100.
-    dp_input (bool): If True, generate data parallel input. Default False.
+    embedding_device (string): device to put the embeddings and their inputs on
   """

   def __init__(self,
                model_config,
                global_batch_size,
                alpha=0,
-               input_ids_list=None,
+               mp_input_ids=None,
                num_batches=10,
-               dp_input=False):
-    self.global_batch_size = global_batch_size
+               embedding_device='/GPU:0'):
     self.dp_batch_size = global_batch_size // hvd.size()
+    self.cat_batch_size = global_batch_size if mp_input_ids is not None else self.dp_batch_size
     self.num_batches = num_batches
-    self.dp_input = dp_input

     input_count = 0
     embed_count = 0
@@ -87,22 +86,25 @@ def __init__(self,
     self.input_pool = []
     for _ in range(num_batches):
       cat_features = []
-      for input_id in input_ids_list[hvd.rank()]:
+      input_ids = mp_input_ids if mp_input_ids is not None else list(range(input_count))
+      for input_id in input_ids:
         hotness, num_rows = global_input_shapes[input_id]
-        if alpha == 0:
-          cat_features.append(
-              tf.random.uniform(shape=[global_batch_size, hotness], maxval=num_rows,
-                                dtype=tf.int64))
-        else:
-          cat_features.append(gen_power_law_data(global_batch_size, hotness, num_rows, alpha))
-
-      numerical_features = tf.random.uniform(
-          shape=[self.dp_batch_size, model_config.num_numerical_features],
-          maxval=100,
-          dtype=tf.float32)
-      labels = tf.random.uniform(shape=[self.dp_batch_size, 1], maxval=2, dtype=tf.int32)
-
-      self.input_pool.append((numerical_features, cat_features, labels))
+        with tf.device(embedding_device):
+          if alpha == 0:
+            cat_features.append(
+                tf.random.uniform(shape=[self.cat_batch_size, hotness],
+                                  maxval=num_rows,
+                                  dtype=tf.int64))
+          else:
+            cat_features.append(gen_power_law_data(self.cat_batch_size, hotness, num_rows, alpha))
+
+      numerical_features = tf.random.uniform(
+          shape=[self.dp_batch_size, model_config.num_numerical_features],
+          maxval=100,
+          dtype=tf.float32)
+      labels = tf.random.uniform(shape=[self.dp_batch_size, 1], maxval=2, dtype=tf.int32)
+
+      self.input_pool.append(((numerical_features, cat_features), labels))

   def __len__(self):
     return self.num_batches
@@ -111,7 +113,7 @@ def __getitem__(self, idx):
     return self.input_pool[idx]


-class SyntheticModel(keras.Model):  # pylint: disable=abstract-method
+class SyntheticModelTFDE(keras.Model):  # pylint: disable=abstract-method
   """Main synthetic model class

   Args:
@@ -122,9 +124,6 @@ class SyntheticModel(keras.Model):  # pylint: disable=abstract-method
   """

   def __init__(self, model_config, column_slice_threshold=None, dp_input=False):
-    if dp_input:
-      raise NotImplementedError
-
     super().__init__()

     self.num_numerical_features = model_config.num_numerical_features
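InputGenerator now subclasses keras.utils.Sequence so model.fit can iterate it directly; the contract is just __len__ plus __getitem__ returning the ((numerical, categorical), labels) batches that the updated input_pool entries above now hold. A stripped-down sketch of that contract (class name, shapes, and data are hypothetical):

import tensorflow as tf
from tensorflow import keras

class ToyBatches(keras.utils.Sequence):  # hypothetical stand-in for InputGenerator
  """Yields pre-built ((numerical, categorical), labels) batches, one per index."""

  def __init__(self, num_batches=2, batch_size=4):
    self.pool = []
    for _ in range(num_batches):
      numerical = tf.random.uniform([batch_size, 3])
      categorical = [tf.random.uniform([batch_size, 2], maxval=10, dtype=tf.int64)]
      labels = tf.random.uniform([batch_size, 1], maxval=2, dtype=tf.int32)
      self.pool.append(((numerical, categorical), labels))

  def __len__(self):
    return len(self.pool)

  def __getitem__(self, idx):
    return self.pool[idx]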
@@ -168,6 +167,63 @@ def call(self, inputs):
     numerical_features, cat_features = inputs
     x = self.embeddings(cat_features)

+    if self.interact is not None:
+      x = [tf.squeeze(self.interact(tf.expand_dims(tf.concat(x, 1), axis=0)))]
+    x = tf.concat(x + [numerical_features], 1)
+
+    x = self.mlp(x)
+    return x
+
+
+class SyntheticModelNative(keras.Model):  # pylint: disable=abstract-method
+  """Synthetic model built on native Keras embedding layers
+
+  Args:
+    model_config (ModelConfig): A named tuple describes the synthetic model
+    embedding_device (string): device to place the native embedding layers
+      and their categorical inputs on.
+      Default '/GPU:0'.
+  """
+
+  def __init__(self, model_config, embedding_device='/GPU:0'):
+    super().__init__()
+    self.num_numerical_features = model_config.num_numerical_features
+    self.embedding_device = embedding_device
+    # Expand embedding configs and create embeddings
+    self.embeddings = []
+    self.input_table_map = []
+    embed_count = 0
+    for config in model_config.embedding_configs:
+      if len(config.nnz) > 1 and not config.shared:
+        raise NotImplementedError("Nonshared multihot embedding is not implemented yet")
+
+      for _ in range(config.num_tables):
+        self.embeddings.append(tf.keras.layers.Embedding(config.num_rows, config.width))
+        for _ in range(len(config.nnz)):
+          self.input_table_map.append(embed_count)
+        embed_count += 1
+    logging.info("%d embedding tables created.", embed_count)
+
+    # Use a memory-bandwidth-limited pooling layer to emulate interaction (aka FM, pool, etc.)
+    if model_config.interact_stride is not None:
+      self.interact = keras.layers.AveragePooling1D(model_config.interact_stride,
+                                                    padding='same',
+                                                    data_format='channels_first')
+    else:
+      self.interact = None  # use concatenation
+
+    # Create MLP
+    self.mlp = keras.Sequential()
+    for size in model_config.mlp_sizes:
+      self.mlp.add(keras.layers.Dense(size, activation="relu"))
+    self.mlp.add(keras.layers.Dense(1, activation=None))
+
+  def call(self, inputs):
+    numerical_features, cat_features = inputs
+
+    with tf.device(self.embedding_device):
+      x = [self.embeddings[ind](inp) for ind, inp in zip(self.input_table_map, cat_features)]
+      x = [tf.reduce_sum(t, 1) for t in x]

     if self.interact is not None:
       x = [tf.squeeze(self.interact(tf.expand_dims(tf.concat(x, 1), axis=0)))]
@@ -175,3 +231,13 @@ def call(self, inputs):

     x = self.mlp(x)
     return x
+
+  def train_step(self, data):
+    x, y = data
+    with tf.GradientTape() as tape:
+      predictions = self(x)
+      loss = tf.math.reduce_mean(self.compiled_loss(y, predictions))
+    tape = dmp.DistributedGradientTape(tape)
+    gradients = tape.gradient(loss, self.trainable_variables)
+    self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
+    return {"loss": loss}
diff --git a/third_party/cub b/third_party/cub
new file mode 160000
index 0000000..cdaa955
--- /dev/null
+++ b/third_party/cub
@@ -0,0 +1 @@
+Subproject commit cdaa9558a85e45d849016e5fe7b6e4ee79113f95
diff --git a/version.txt b/version.txt
index 6e8bf73..0ea3a94 100644
--- a/version.txt
+++ b/version.txt
@@ -1 +1 @@
-0.1.0
+0.2.0