Commit
version 0.2 updates
FDecaYed committed Oct 8, 2022
1 parent 427f869 commit 6e7b613
Showing 18 changed files with 962 additions and 368 deletions.
2 changes: 2 additions & 0 deletions .clang-format
@@ -2,3 +2,5 @@
BasedOnStyle: Google
ColumnLimit: 100
DerivePointerAlignment: false
StatementMacros:
- _Pragma
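
Declaring _Pragma a statement macro tells clang-format the macro is a complete statement, so it stays on its own line instead of being folded onto the code that follows. A small illustration of the kind of kernel code this protects (hypothetical snippet, not part of this commit):

// Without the StatementMacros entry, clang-format may merge the _Pragma
// with the loop header below; with it, the pragma keeps its own line.
float Sum(const float* vals) {
  float acc = 0.f;
  _Pragma("unroll")
  for (int i = 0; i < 8; ++i) {
    acc += vals[i];
  }
  return acc;
}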
3 changes: 3 additions & 0 deletions .gitmodules
@@ -0,0 +1,3 @@
[submodule "third_party/cub"]
path = third_party/cub
url = https://github.com/NVIDIA/cub.git
2 changes: 1 addition & 1 deletion .pylintrc
@@ -38,7 +38,7 @@ enable=indexing-exception,old-raise-syntax
# --enable=similarities". If you want to run only the classes checker, but have
# no Warning level messages displayed, use"--disable=all --enable=classes
# --disable=W"
disable=design,similarities,no-self-use,attribute-defined-outside-init,locally-disabled,star-args,pointless-except,bad-option-value,global-statement,fixme,suppressed-message,useless-suppression,locally-enabled,no-member,no-name-in-module,import-error,unsubscriptable-object,unbalanced-tuple-unpacking,undefined-variable,not-context-manager,invalid-sequence-index,arguments-differ,missing-function-docstring,unexpected-keyword-arg,no-value-for-parameter
disable=design,similarities,no-self-use,attribute-defined-outside-init,locally-disabled,star-args,pointless-except,bad-option-value,global-statement,fixme,suppressed-message,useless-suppression,locally-enabled,no-member,no-name-in-module,import-error,unsubscriptable-object,unbalanced-tuple-unpacking,undefined-variable,not-context-manager,invalid-sequence-index,arguments-differ,missing-function-docstring,unexpected-keyword-arg,no-value-for-parameter,missing-return-type-doc,missing-type-doc,missing-param-doc


# Set the cache size for astng objects.
12 changes: 9 additions & 3 deletions Makefile
@@ -20,8 +20,14 @@ PYTHON_BIN_PATH = python

TF_CFLAGS := $(shell $(PYTHON_BIN_PATH) -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))')
TF_LFLAGS := $(shell $(PYTHON_BIN_PATH) -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))')

CFLAGS = ${TF_CFLAGS} -O3 -std=c++14
TF_VERSION := $(shell $(PYTHON_BIN_PATH) -c 'import tensorflow as tf; print(int(tf.__version__.split(".")[1]))')
ifeq ($(shell expr $(TF_VERSION) \>= 10), 1)
CPP_STD := 17
else
CPP_STD := 14
endif

CFLAGS = ${TF_CFLAGS} -O3 -std=c++${CPP_STD}
LDFLAGS = -shared ${TF_LFLAGS}

SRC = embedding_lookup_kernels
@@ -34,7 +40,7 @@ TARGET_LIB = distributed_embeddings/python/ops/_embedding_lookup_ops.so
all: $(TARGET_LIB)

%_kernels.cu.o: distributed_embeddings/cc/kernels/%_kernels.cu distributed_embeddings/cc/kernels/%.h
$(NVCC) -c -o $@ $< $(CFLAGS) -I. -DGOOGLE_CUDA=1 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_70,code=sm_70 -gencode arch=compute_80,code=sm_80 -gencode arch=compute_86,code=sm_86 -x cu -Xcompiler -fPIC --expt-relaxed-constexpr
$(NVCC) -c -o $@ $< -Ithird_party/cub $(CFLAGS) -I. -DGOOGLE_CUDA=1 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_70,code=sm_70 -gencode arch=compute_80,code=sm_80 -gencode arch=compute_86,code=sm_86 -x cu -Xcompiler -fPIC --expt-relaxed-constexpr

%_kernels.cc.o: distributed_embeddings/cc/kernels/%_kernels.cc distributed_embeddings/cc/kernels/%.h
$(CXX) -c -o $@ $< $(CFLAGS) -Wall -fPIC -I/usr/local/cuda/include
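
Two things change in this Makefile: the C++ standard now tracks the TensorFlow minor version (TF 2.10 moved its public headers to C++17, while older releases still build with C++14), and the new cub submodule is added to the NVCC include path. A rough compile-time equivalent of the version gate, for illustration only (not part of this commit):

// Fail fast if the -std flag chosen by the Makefile is too old for the
// TensorFlow headers (TF >= 2.10 requires C++17).
#if __cplusplus < 201703L
#error "TensorFlow 2.10+ headers require -std=c++17"
#endif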
24 changes: 10 additions & 14 deletions distributed_embeddings/cc/kernels/embedding_lookup.h
@@ -20,36 +20,32 @@

#include <string>

#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {
enum class Combiner { Mean = 0, Sum = 1 };
inline Combiner StringToEnum(std::string combiner) {
return combiner == "mean" ? Combiner::Mean : Combiner::Sum;
}

template <typename Device, typename T, typename Tindices>
struct EmbeddingLookupConstantHotnessFunctor {
void operator()(const Device& d, T* output_ptr, const T* param_ptr, const Tindices* ids_ptr,
Tindices nnz_per_row, Tindices num_rows, Tindices embedding_width,
Combiner combiner) const;
};

template <typename Device, typename T, typename Tindices>
struct EmbeddingLookupConstantHotnessGradFunctor {
void operator()(const Device& d, T* output_ptr, const T* grad_ptr, Tindices nnz_per_row,
Tindices num_rows, Tindices embedding_width, Combiner combiner) const;
template <typename Device, typename Tindices>
struct RowToSplitFunctor {
void operator()(const Device& d, Tindices* split_ptr, const Tindices* row_ptr, Tindices num_ids,
Tindices num_rows) const;
};

template <typename Device, typename T, typename Tindices>
struct EmbeddingLookupVariableHotnessFunctor {
void operator()(const Device& d, T* output_ptr, const T* param_ptr, const Tindices* ids_ptr,
const Tindices* offsets_ptr, Tindices num_rows, Tindices embedding_width,
Combiner combiner) const;
Combiner combiner, Tindices ave_red_len) const;
};

template <typename Device, typename T, typename Tindices>
struct EmbeddingLookupVariableHotnessGradFunctor {
void operator()(const Device& d, T* output_ptr, const T* grad_ptr, const Tindices* offsets_ptr,
Tindices num_rows, Tindices embedding_width, Combiner combiner) const;
void operator()(OpKernelContext* context, const Tindices* ids_ptr, const Tindices* row_ptr,
const T* grad_ptr, int64_t num_ids, Tindices embedding_width, Tindices num_rows,
int64_t dense_shape_dim0, int64_t max_red_len, Combiner combiner) const;
};

} // namespace tensorflow
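
These declarations only name the device entry points; the semantics are easiest to see in a scalar reference. A minimal CPU sketch of the variable-hotness forward lookup, assuming CSR-style offsets of length num_rows + 1 (illustrative only; ave_red_len, computed as num_ids / num_rows, presumably steers GPU kernel selection and does not change the result):

// CPU reference for the variable-hotness forward pass. The ids of row r
// occupy [offsets[r], offsets[r+1]); mean divides by the row length.
template <typename T, typename Tindices>
void VariableHotnessReference(T* out, const T* params, const Tindices* ids,
                              const Tindices* offsets, Tindices num_rows,
                              Tindices width, bool mean) {
  for (Tindices r = 0; r < num_rows; ++r) {
    const Tindices len = offsets[r + 1] - offsets[r];
    for (Tindices c = 0; c < width; ++c) {
      T acc = T(0);
      for (Tindices i = offsets[r]; i < offsets[r + 1]; ++i)
        acc += params[ids[i] * width + c];
      out[r * width + c] = (mean && len > 0) ? acc / static_cast<T>(len) : acc;
    }
  }
}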
123 changes: 41 additions & 82 deletions distributed_embeddings/cc/kernels/embedding_lookup_kernels.cc
@@ -16,7 +16,6 @@
*/

#include "embedding_lookup.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/resource_var.h"

@@ -45,62 +44,25 @@ class ReadVariableNoCopyOp : public OpKernel {
DataType dtype_;
};

template <typename Device, typename T, typename Tindices>
class EmbeddingLookupConstantHotnessOp : public OpKernel {
public:
explicit EmbeddingLookupConstantHotnessOp(OpKernelConstruction* context) : OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("combiner", &_combiner));
}

void Compute(OpKernelContext* context) override {
const Tensor& params = context->input(0);
const Tensor& ids = context->input(1);

auto num_rows = ids.dim_size(0);
auto nnz_per_row = ids.dim_size(1);
auto embedding_width = params.dim_size(1);

TensorShape output_shape({num_rows, embedding_width});
Tensor* output = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));

EmbeddingLookupConstantHotnessFunctor<Device, T, Tindices>()(
context->eigen_device<Device>(), output->flat<T>().data(), params.flat<T>().data(),
ids.flat<Tindices>().data(), nnz_per_row, num_rows, embedding_width,
StringToEnum(_combiner));
}

private:
string _combiner;
};

template <typename Device, typename T, typename Tindices>
class EmbeddingLookupConstantHotnessGradOp : public OpKernel {
template <typename Device, typename Tindices>
class RowToSplitOp : public OpKernel {
public:
explicit EmbeddingLookupConstantHotnessGradOp(OpKernelConstruction* context) : OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("combiner", &_combiner));
}
explicit RowToSplitOp(OpKernelConstruction* context) : OpKernel(context) {}

void Compute(OpKernelContext* context) override {
const Tensor& grad = context->input(0);
const Tensor& ids = context->input(1);
// [n, 2]
const Tensor& row = context->input(0);
auto num_ids = row.dim_size(0);
auto num_rows = context->input(1).scalar<int32>()();

auto num_rows = ids.dim_size(0);
auto nnz_per_row = ids.dim_size(1);
auto nnz = num_rows * nnz_per_row;
auto embedding_width = grad.dim_size(1);

TensorShape output_shape({nnz, embedding_width});
TensorShape output_shape({num_rows + 1});
Tensor* output = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));

EmbeddingLookupConstantHotnessGradFunctor<Device, T, Tindices>()(
context->eigen_device<Device>(), output->flat<T>().data(), grad.flat<T>().data(),
nnz_per_row, num_rows, embedding_width, StringToEnum(_combiner));
RowToSplitFunctor<Device, Tindices>()(context->eigen_device<Device>(),
output->flat<Tindices>().data(),
row.flat<Tindices>().data(), num_ids, num_rows);
}

private:
string _combiner;
};
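
RowToSplit converts the row coordinates of sorted COO indices into the CSR-style split array the variable-hotness ops consume. A CPU reference of the mapping (illustrative; the registered kernel runs on GPU):

#include <cstdint>
#include <vector>

// rows: the row coordinate of each id, ascending (first column of the
// [n, 2] indices tensor). Returns num_rows + 1 split points such that the
// ids of row r occupy [split[r], split[r+1]).
std::vector<int64_t> RowToSplitReference(const std::vector<int64_t>& rows,
                                         int64_t num_rows) {
  std::vector<int64_t> split(num_rows + 1, 0);
  for (int64_t r : rows) split[r + 1] += 1;                        // ids per row
  for (int64_t i = 0; i < num_rows; ++i) split[i + 1] += split[i]; // prefix sum
  return split;
}

For example, rows = {0, 0, 2} with num_rows = 3 yields {0, 2, 2, 3}: row 0 owns ids 0-1, row 1 is empty, row 2 owns id 2.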

template <typename Device, typename T, typename Tindices>
@@ -118,14 +80,17 @@ class EmbeddingLookupVariableHotnessOp : public OpKernel {
auto num_rows = offsets.dim_size(0) - 1;
auto embedding_width = params.dim_size(1);

auto num_ids = ids.dim_size(0);
auto ave_red_len = num_ids / num_rows;

TensorShape output_shape({num_rows, embedding_width});
Tensor* output = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));

EmbeddingLookupVariableHotnessFunctor<Device, T, Tindices>()(
context->eigen_device<Device>(), output->flat<T>().data(), params.flat<T>().data(),
ids.flat<Tindices>().data(), offsets.flat<Tindices>().data(), num_rows, embedding_width,
StringToEnum(_combiner));
StringToEnum(_combiner), ave_red_len);
}

private:
@@ -140,21 +105,20 @@ class EmbeddingLookupVariableHotnessGradOp : public OpKernel {
}

void Compute(OpKernelContext* context) override {
const Tensor& grad = context->input(0);
const Tensor& ids = context->input(1);
const Tensor& offsets = context->input(2);

auto num_rows = offsets.dim_size(0) - 1;
const Tensor& ids = context->input(0);
const Tensor& offset_in = context->input(1);
const Tensor& grad = context->input(2);
const Tensor& param = context->input(3);
auto num_ids = ids.dim_size(0);
auto num_rows = offset_in.dim_size(0) - 1;
auto embedding_width = grad.dim_size(1);
auto nnz = ids.dim_size(0);

TensorShape output_shape({nnz, embedding_width});
Tensor* output = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
auto max_red_len = grad.dim_size(0);
auto dense_shape_dim0 = param.dim_size(0);

EmbeddingLookupVariableHotnessGradFunctor<Device, T, Tindices>()(
context->eigen_device<Device>(), output->flat<T>().data(), grad.flat<T>().data(),
offsets.flat<Tindices>().data(), num_rows, embedding_width, StringToEnum(_combiner));
context, ids.flat<Tindices>().data(), offset_in.flat<Tindices>().data(),
grad.flat<T>().data(), num_ids, embedding_width, num_rows, dense_shape_dim0, max_red_len,
StringToEnum(_combiner));
}

private:
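
The rewritten gradient functor now receives the OpKernelContext itself, so it can allocate its outputs internally rather than filling a preallocated dense [nnz, width] buffer; the cub submodule added above suggests a sort-and-segment-reduce over ids on the GPU. Setting that deduplication aside, a scalar sketch of the per-occurrence value, assuming the mean combiner scales each row's incoming gradient by the row's reduction length:

#include <cstdint>

// Each id occurrence i in row r (i.e., i in [row_split[r], row_split[r+1]))
// receives grad_in[r, :], scaled by 1/len(r) when the combiner is mean.
template <typename T>
void ScatterRowGrad(T* grad_out, const T* grad_in, const int64_t* row_split,
                    int64_t num_rows, int64_t width, bool mean) {
  for (int64_t r = 0; r < num_rows; ++r) {
    const int64_t len = row_split[r + 1] - row_split[r];
    const T scale = (mean && len > 0) ? T(1) / static_cast<T>(len) : T(1);
    for (int64_t i = row_split[r]; i < row_split[r + 1]; ++i)
      for (int64_t c = 0; c < width; ++c)
        grad_out[i * width + c] = grad_in[r * width + c] * scale;
  }
}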
@@ -167,26 +131,21 @@ REGISTER_KERNEL_BUILDER(Name("ReadVariableNoCopy").Device(DEVICE_DEFAULT).HostMe
REGISTER_KERNEL_BUILDER(Name("ReadVariableNoCopy").Device(DEVICE_GPU).HostMemory("resource"),
ReadVariableNoCopyOp);

#define REGISTER_GPU(T, Tindices) \
REGISTER_KERNEL_BUILDER(Name("EmbeddingLookupConstantHotness") \
.Device(DEVICE_GPU) \
.TypeConstraint<T>("T") \
.TypeConstraint<Tindices>("Tindices"), \
EmbeddingLookupConstantHotnessOp<Eigen::GpuDevice, T, Tindices>); \
REGISTER_KERNEL_BUILDER(Name("EmbeddingLookupConstantHotnessGrad") \
.Device(DEVICE_GPU) \
.TypeConstraint<T>("T") \
.TypeConstraint<Tindices>("Tindices"), \
EmbeddingLookupConstantHotnessGradOp<Eigen::GpuDevice, T, Tindices>); \
REGISTER_KERNEL_BUILDER(Name("EmbeddingLookupVariableHotness") \
.Device(DEVICE_GPU) \
.TypeConstraint<T>("T") \
.TypeConstraint<Tindices>("Tindices"), \
EmbeddingLookupVariableHotnessOp<Eigen::GpuDevice, T, Tindices>); \
REGISTER_KERNEL_BUILDER(Name("EmbeddingLookupVariableHotnessGrad") \
.Device(DEVICE_GPU) \
.TypeConstraint<T>("T") \
.TypeConstraint<Tindices>("Tindices"), \
#define REGISTER_GPU(T, Tindices) \
REGISTER_KERNEL_BUILDER(Name("RowToSplit") \
.Device(DEVICE_GPU) \
.TypeConstraint<Tindices>("Tindices") \
.HostMemory("shape"), \
RowToSplitOp<Eigen::GpuDevice, Tindices>); \
REGISTER_KERNEL_BUILDER(Name("EmbeddingLookupVariableHotness") \
.Device(DEVICE_GPU) \
.TypeConstraint<T>("T") \
.TypeConstraint<Tindices>("Tindices"), \
EmbeddingLookupVariableHotnessOp<Eigen::GpuDevice, T, Tindices>); \
REGISTER_KERNEL_BUILDER(Name("EmbeddingLookupVariableHotnessGrad") \
.Device(DEVICE_GPU) \
.TypeConstraint<T>("T") \
.TypeConstraint<Tindices>("Tindices"), \
EmbeddingLookupVariableHotnessGradOp<Eigen::GpuDevice, T, Tindices>);

REGISTER_GPU(float, int64_t)
