#13741: Develop Full Operation (#13743)
ngohoang34 authored Oct 21, 2024
1 parent 571ba98 commit d180db1
Showing 11 changed files with 555 additions and 1 deletion.
104 changes: 104 additions & 0 deletions tests/ttnn/unit_tests/operations/test_full.py
@@ -0,0 +1,104 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import pytest

import torch
import torch.nn as nn
import ttnn
from models.utility_functions import comp_allclose
from loguru import logger

from tests.ttnn.utils_for_testing import assert_equal


@pytest.mark.parametrize(
"input_shape",
[
[1, 3], # single tile
[32, 32], # single tile
[5, 17, 31], # multiple tiles
],
)
@pytest.mark.parametrize(
"fill_value",
[3, -1],
)
def test_full_int(device, input_shape, fill_value):
torch_any = torch.randint(0, 100, (input_shape), dtype=torch.int32)
torch_output_tensor = torch.full(input_shape, fill_value)

any = ttnn.from_torch(torch_any, device=device, layout=ttnn.TILE_LAYOUT)
output_tensor = ttnn.moreh_full(input_shape, fill_value, any)
assert ttnn.is_tensor_storage_on_device(output_tensor)
output_tensor = ttnn.to_torch(output_tensor)

assert_equal(torch_output_tensor, output_tensor)


@pytest.mark.parametrize(
"input_shape",
[
[1, 3], # single tile
[32, 32], # single tile
[5, 96, 64], # multiple tiles
[3, 91, 67, 77], # not multiple of 32
],
)
@pytest.mark.parametrize(
"fill_value",
[0.15, -1.2],
)
@pytest.mark.parametrize(
"dtype",
[
torch.bfloat16,
torch.float32,
],
)
def test_full_float(device, input_shape, fill_value, dtype):
torch_any = torch.rand((input_shape), dtype=dtype)

torch_output_tensor = torch.full(input_shape, fill_value)
any = ttnn.from_torch(torch_any, device=device, layout=ttnn.TILE_LAYOUT)
output_tensor = ttnn.moreh_full(input_shape, fill_value, any)
assert ttnn.is_tensor_storage_on_device(output_tensor)
output_tensor = ttnn.to_torch(output_tensor)

assert_equal(torch_output_tensor, output_tensor)


@pytest.mark.parametrize(
"input_shape",
[
[32, 32], # single tile
],
)
@pytest.mark.parametrize(
"fill_value",
[3],
)
@pytest.mark.parametrize(
"layout",
[
ttnn.TILE_LAYOUT,  # Currently only supports tile layout
],
)
def test_full_callback(device, input_shape, fill_value, layout, use_program_cache):
for i in range(2):
torch_any = torch.randint(0, 100, (input_shape), dtype=torch.int32)
torch_output_tensor = torch.full(input_shape, fill_value)

any = ttnn.from_torch(torch_any, device=device, layout=ttnn.TILE_LAYOUT)
output_tensor = ttnn.moreh_full(input_shape, fill_value, any)
assert ttnn.is_tensor_storage_on_device(output_tensor)
output_tensor = ttnn.to_torch(output_tensor)
torch_dummy = torch.randn([32, 32])
ttnn_dummy = ttnn.from_torch(torch_dummy, device=device)
if i == 0:
num_program_cache_entries = device.num_program_cache_entries()
assert num_program_cache_entries > 0
else:
assert device.num_program_cache_entries() == num_program_cache_entries
assert_equal(torch_output_tensor, output_tensor)
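Taken together, these tests pin down the user-facing call pattern for the new op. The minimal sketch below restates it outside of pytest; the explicit ttnn.open_device/ttnn.close_device bring-up is an assumption (the tests get device from a fixture), while every other call mirrors the test bodies above.

import torch
import ttnn

# Assumed device bring-up; the tests receive `device` from a pytest fixture instead.
device = ttnn.open_device(device_id=0)

# The "any" tensor only donates defaults (device, dtype, layout); its values are ignored.
reference = ttnn.from_torch(
    torch.randint(0, 100, (32, 32), dtype=torch.int32),
    device=device,
    layout=ttnn.TILE_LAYOUT,
)

# shape, fill_value, reference tensor -- returns a device tensor filled with 3.
filled = ttnn.moreh_full([32, 32], 3, reference)
assert ttnn.is_tensor_storage_on_device(filled)

# Round-trip to torch; expected to equal torch.full([32, 32], 3).
result = ttnn.to_torch(filled)

ttnn.close_device(device)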
4 changes: 4 additions & 0 deletions ttnn/CMakeLists.txt
@@ -249,6 +249,10 @@ set(ALL_TTNN_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/plusone/plusone_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/plusone/device/plusone_program_factory.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/auto_format/auto_format.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/full/device/full_device_operation.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/full/device/full_program_factory.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/full/full_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/full/full.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/loss/loss.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/loss/loss_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/matmul/device/matmul_op.cpp
5 changes: 4 additions & 1 deletion ttnn/cpp/pybind11/operations/__init__.hpp
@@ -28,6 +28,7 @@
#include "ttnn/operations/embedding_backward/embedding_backward_pybind.hpp"
#include "ttnn/operations/examples/examples_pybind.hpp"
#include "ttnn/operations/experimental/experimental_pybind.hpp"
#include "ttnn/operations/full/full_pybind.hpp"
#include "ttnn/operations/kv_cache/kv_cache_pybind.hpp"
#include "ttnn/operations/loss/loss_pybind.hpp"
#include "ttnn/operations/matmul/matmul_pybind.hpp"
@@ -54,7 +55,6 @@ void py_module(py::module& module) {
auto m_examples = module.def_submodule("examples", "examples of operations");
examples::py_module(m_examples);


// Eltwise operations: unary, binary, ternary, backward, complex
auto m_unary = module.def_submodule("unary", "unary operations");
unary::py_module(m_unary);
@@ -96,6 +96,9 @@ void py_module(py::module& module) {
auto m_embedding_backward = module.def_submodule("embedding_backward", "embedding backward operations");
embedding_backward::py_bind_embedding_backward(m_embedding_backward);

auto m_full = module.def_submodule("full", "full operation");
full::bind_full_operation(m_full);

auto m_loss = module.def_submodule("loss", "loss operations");
loss::py_bind_loss_functions(m_loss);

89 changes: 89 additions & 0 deletions ttnn/cpp/ttnn/operations/full/device/full_device_operation.cpp
@@ -0,0 +1,89 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "full_device_operation.hpp"

#include "ttnn/tensor/tensor.hpp"

namespace ttnn::operations::full {
void FullOperation::validate_inputs(
const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
auto any = tensor_args.any;
TT_FATAL(any.storage_type() == StorageType::DEVICE, "Full operation error: Any tensor must be on device");
TT_FATAL(
operation_attributes.memory_config.memory_layout == TensorMemoryLayout::INTERLEAVED,
"Full operation error: Not currently supporting sharding");
TT_FATAL(
operation_attributes.layout == Layout::TILE, "Full operation error: Not currently supporting row major layout");

const auto shape = operation_attributes.shape;

TT_FATAL(
shape.size() > 1,
"Full operation error: Shape size must be greater than 1, but got shape size = {}",
shape.size());

for (int i = 0; i < shape.size(); i++) {
TT_FATAL(
shape[i] > 0,
"Full operation error: Invalid shape at index {}. Each dimension of the shape must be greater than 0, but"
"got shape[{}] = {}",
i,
i,
shape[i]);
}
}

FullOperation::program_factory_t FullOperation::select_program_factory(
const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
return ProgramFactory{};
}

void FullOperation::validate_on_program_cache_miss(
const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
validate_inputs(operation_attributes, tensor_args);
};

void FullOperation::validate_on_program_cache_hit(
const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
validate_inputs(operation_attributes, tensor_args);
};

FullOperation::shape_return_value_t FullOperation::compute_output_shapes(
const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
return SimpleShape(operation_attributes.shape);
};

FullOperation::tensor_return_value_t FullOperation::create_output_tensors(
const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
auto output_shape = compute_output_shapes(operation_attributes, tensor_args);
return create_device_tensor(
output_shape,
operation_attributes.dtype,
operation_attributes.layout,
tensor_args.any.device(),
operation_attributes.memory_config);
}

std::tuple<FullOperation::operation_attributes_t, FullOperation::tensor_args_t> FullOperation::invoke(
const std::vector<uint32_t> shape,
const std::variant<float, int> fill_value,
const Tensor& any,
const std::optional<DataType>& dtype,
const std::optional<Layout>& layout,
const std::optional<MemoryConfig>& memory_config) {
return {
operation_attributes_t{
shape,
fill_value,
dtype.value_or(any.get_dtype()),
layout.value_or(any.get_layout()),
memory_config.value_or(any.memory_config()),
},
tensor_args_t{
any,
},
};
}
} // namespace ttnn::operations::full
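One behavior worth calling out from invoke(): when dtype, layout, or memory_config is not supplied, each falls back to the corresponding property of the reference ("any") tensor via value_or. A plain-torch sketch of the resulting semantics (illustration only, not the ttnn API):

import torch

# Reference ("any") tensor: only its properties matter here, not its values.
any_ref = torch.randint(0, 100, (5, 17, 31), dtype=torch.int32)

# With no explicit dtype, the output inherits the reference tensor's dtype.
expected = torch.full((5, 17, 31), 3, dtype=any_ref.dtype)
assert expected.dtype == torch.int32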
71 changes: 71 additions & 0 deletions ttnn/cpp/ttnn/operations/full/device/full_device_operation.hpp
@@ -0,0 +1,71 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include <variant>

#include "ttnn/decorators.hpp"
#include "ttnn/tensor/types.hpp"

namespace ttnn::operations::full {

struct FullOperation {
struct operation_attributes_t {
const std::vector<uint32_t> shape;
const std::variant<float, int> fill_value;
const DataType dtype;
const Layout layout;
const MemoryConfig memory_config;
};

struct tensor_args_t {
const Tensor& any;
};

using shape_return_value_t = SimpleShape;
using tensor_return_value_t = Tensor;

struct ProgramFactory {
struct shared_variables_t {
KernelHandle writer_id;
std::size_t num_cores;
std::size_t core_h;
};

using cached_program_t = ttnn::device_operation::CachedProgram<shared_variables_t>;

static cached_program_t create(
const operation_attributes_t& operation_attributes,
const tensor_args_t& tensor_args,
tensor_return_value_t& output);

static void override_runtime_arguments(
cached_program_t& cached_program,
const operation_attributes_t& operation_attributes,
const tensor_args_t& tensor_args,
tensor_return_value_t& output);
};

using program_factory_t = std::variant<ProgramFactory>;

static void validate_inputs(const operation_attributes_t&, const tensor_args_t&);
static program_factory_t select_program_factory(const operation_attributes_t&, const tensor_args_t&);
static void validate_on_program_cache_miss(const operation_attributes_t&, const tensor_args_t&);
static void validate_on_program_cache_hit(const operation_attributes_t&, const tensor_args_t&);
static shape_return_value_t compute_output_shapes(const operation_attributes_t&, const tensor_args_t&);
static tensor_return_value_t create_output_tensors(const operation_attributes_t&, const tensor_args_t&);

static std::tuple<operation_attributes_t, tensor_args_t> invoke(
const std::vector<uint32_t> shape,
const std::variant<float, int> fill_value,
const Tensor& any,
const std::optional<DataType>& dtype,
const std::optional<Layout>& layout,
const std::optional<MemoryConfig>& memory_config);
};

} // namespace ttnn::operations::full

namespace ttnn::prim {
constexpr auto full = ttnn::register_operation<"ttnn::prim::full", ttnn::operations::full::FullOperation>();
} // namespace ttnn::prim