Commit d180db1 (parent: 571ba98): 11 changed files with 555 additions and 1 deletion.
New test file (104 additions):
```python
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import pytest

import torch
import ttnn

from tests.ttnn.utils_for_testing import assert_equal


@pytest.mark.parametrize(
    "input_shape",
    [
        [1, 3],  # single tile
        [32, 32],  # single tile
        [5, 17, 31],  # multiple tiles
    ],
)
@pytest.mark.parametrize(
    "fill_value",
    [3, -1],
)
def test_full_int(device, input_shape, fill_value):
    torch_any = torch.randint(0, 100, input_shape, dtype=torch.int32)
    torch_output_tensor = torch.full(input_shape, fill_value)

    # The "any" tensor only supplies defaults (device, dtype, layout); its values are ignored.
    any = ttnn.from_torch(torch_any, device=device, layout=ttnn.TILE_LAYOUT)
    output_tensor = ttnn.moreh_full(input_shape, fill_value, any)
    assert ttnn.is_tensor_storage_on_device(output_tensor)
    output_tensor = ttnn.to_torch(output_tensor)

    assert_equal(torch_output_tensor, output_tensor)


@pytest.mark.parametrize(
    "input_shape",
    [
        [1, 3],  # single tile
        [32, 32],  # single tile
        [5, 96, 64],  # multiple tiles
        [3, 91, 67, 77],  # not multiple of 32
    ],
)
@pytest.mark.parametrize(
    "fill_value",
    [0.15, -1.2],
)
@pytest.mark.parametrize(
    "dtype",
    [
        torch.bfloat16,
        torch.float32,
    ],
)
def test_full_float(device, input_shape, fill_value, dtype):
    torch_any = torch.rand(input_shape, dtype=dtype)

    torch_output_tensor = torch.full(input_shape, fill_value)
    any = ttnn.from_torch(torch_any, device=device, layout=ttnn.TILE_LAYOUT)
    output_tensor = ttnn.moreh_full(input_shape, fill_value, any)
    assert ttnn.is_tensor_storage_on_device(output_tensor)
    output_tensor = ttnn.to_torch(output_tensor)

    assert_equal(torch_output_tensor, output_tensor)


@pytest.mark.parametrize(
    "input_shape",
    [
        [32, 32],  # single tile
    ],
)
@pytest.mark.parametrize(
    "fill_value",
    [3],
)
@pytest.mark.parametrize(
    "layout",
    [
        ttnn.TILE_LAYOUT,  # currently only tile layout is supported
    ],
)
def test_full_callback(device, input_shape, fill_value, layout, use_program_cache):
    for i in range(2):
        torch_any = torch.randint(0, 100, input_shape, dtype=torch.int32)
        torch_output_tensor = torch.full(input_shape, fill_value)

        any = ttnn.from_torch(torch_any, device=device, layout=ttnn.TILE_LAYOUT)
        output_tensor = ttnn.moreh_full(input_shape, fill_value, any)
        assert ttnn.is_tensor_storage_on_device(output_tensor)
        output_tensor = ttnn.to_torch(output_tensor)
        # Allocate a dummy tensor so the second run sees different buffer
        # addresses, exercising runtime-arg overrides on a program cache hit.
        torch_dummy = torch.randn([32, 32])
        ttnn_dummy = ttnn.from_torch(torch_dummy, device=device)
        if i == 0:
            num_program_cache_entries = device.num_program_cache_entries()
            assert num_program_cache_entries > 0
        else:
            assert device.num_program_cache_entries() == num_program_cache_entries
        assert_equal(torch_output_tensor, output_tensor)
```
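For readers who want to try the op outside pytest, here is a minimal sketch of direct usage. It assumes `ttnn.open_device`/`ttnn.close_device` are available in this build and that `ttnn.moreh_full` is exposed exactly as in the tests above; the device ID and shape are illustrative.

```python
# Minimal sketch, not part of the commit: direct use of ttnn.moreh_full.
import torch
import ttnn

device = ttnn.open_device(device_id=0)
try:
    shape = [32, 32]
    # The "any" tensor only supplies defaults (device, dtype, layout); its values are ignored.
    reference = ttnn.from_torch(
        torch.zeros(shape, dtype=torch.int32), device=device, layout=ttnn.TILE_LAYOUT
    )
    filled = ttnn.moreh_full(shape, 7, reference)
    print(ttnn.to_torch(filled))  # expected: a 32x32 tensor filled with 7
finally:
    ttnn.close_device(device)
```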
ttnn/cpp/ttnn/operations/full/device/full_device_operation.cpp (89 additions):
```cpp
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "full_device_operation.hpp"

#include "ttnn/tensor/tensor.hpp"

namespace ttnn::operations::full {

void FullOperation::validate_inputs(
    const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
    auto any = tensor_args.any;
    TT_FATAL(any.storage_type() == StorageType::DEVICE, "Full operation error: Any tensor must be on device");
    TT_FATAL(
        operation_attributes.memory_config.memory_layout == TensorMemoryLayout::INTERLEAVED,
        "Full operation error: Not currently supporting sharding");
    TT_FATAL(
        operation_attributes.layout == Layout::TILE,
        "Full operation error: Not currently supporting row major layout");

    const auto shape = operation_attributes.shape;

    TT_FATAL(
        shape.size() > 1,
        "Full operation error: Shape size must be greater than 1, but got shape size = {}",
        shape.size());

    for (size_t i = 0; i < shape.size(); i++) {
        TT_FATAL(
            shape[i] > 0,
            "Full operation error: Invalid shape at index {}. Each dimension of the shape must be greater than 0, "
            "but got shape[{}] = {}",
            i,
            i,
            shape[i]);
    }
}

FullOperation::program_factory_t FullOperation::select_program_factory(
    const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
    return ProgramFactory{};
}

void FullOperation::validate_on_program_cache_miss(
    const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
    validate_inputs(operation_attributes, tensor_args);
}

void FullOperation::validate_on_program_cache_hit(
    const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
    validate_inputs(operation_attributes, tensor_args);
}

FullOperation::shape_return_value_t FullOperation::compute_output_shapes(
    const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
    return SimpleShape(operation_attributes.shape);
}

FullOperation::tensor_return_value_t FullOperation::create_output_tensors(
    const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
    auto output_shape = compute_output_shapes(operation_attributes, tensor_args);
    return create_device_tensor(
        output_shape,
        operation_attributes.dtype,
        operation_attributes.layout,
        tensor_args.any.device(),
        operation_attributes.memory_config);
}

std::tuple<FullOperation::operation_attributes_t, FullOperation::tensor_args_t> FullOperation::invoke(
    const std::vector<uint32_t> shape,
    const std::variant<float, int> fill_value,
    const Tensor& any,
    const std::optional<DataType>& dtype,
    const std::optional<Layout>& layout,
    const std::optional<MemoryConfig>& memory_config) {
    return {
        operation_attributes_t{
            shape,
            fill_value,
            dtype.value_or(any.get_dtype()),
            layout.value_or(any.get_layout()),
            memory_config.value_or(any.memory_config()),
        },
        tensor_args_t{
            any,
        },
    };
}

}  // namespace ttnn::operations::full
```
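The `TT_FATAL` checks above reject shapes of rank 1 or lower and any non-positive dimension. A sketch of how that surfaces from Python, assuming (as is typical for ttnn validation failures, though not shown in this commit) that `TT_FATAL` propagates as a `RuntimeError`:

```python
# Hypothetical negative test, not part of the commit.
import pytest
import torch
import ttnn

def test_full_rejects_rank_one_shape(device):
    # validate_inputs requires shape.size() > 1, so a rank-1 shape should fail.
    any_tensor = ttnn.from_torch(
        torch.zeros([32, 32], dtype=torch.int32), device=device, layout=ttnn.TILE_LAYOUT
    )
    with pytest.raises(RuntimeError):
        ttnn.moreh_full([32], 3, any_tensor)
```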
ttnn/cpp/ttnn/operations/full/device/full_device_operation.hpp (71 additions):
```cpp
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <variant>

#include "ttnn/decorators.hpp"
#include "ttnn/tensor/types.hpp"

namespace ttnn::operations::full {

struct FullOperation {
    struct operation_attributes_t {
        const std::vector<uint32_t> shape;
        const std::variant<float, int> fill_value;
        const DataType dtype;
        const Layout layout;
        const MemoryConfig memory_config;
    };

    struct tensor_args_t {
        const Tensor& any;
    };

    using shape_return_value_t = SimpleShape;
    using tensor_return_value_t = Tensor;

    struct ProgramFactory {
        struct shared_variables_t {
            KernelHandle writer_id;
            std::size_t num_cores;
            std::size_t core_h;
        };

        using cached_program_t = ttnn::device_operation::CachedProgram<shared_variables_t>;

        static cached_program_t create(
            const operation_attributes_t& operation_attributes,
            const tensor_args_t& tensor_args,
            tensor_return_value_t& output);

        static void override_runtime_arguments(
            cached_program_t& cached_program,
            const operation_attributes_t& operation_attributes,
            const tensor_args_t& tensor_args,
            tensor_return_value_t& output);
    };

    using program_factory_t = std::variant<ProgramFactory>;

    static void validate_inputs(const operation_attributes_t&, const tensor_args_t&);
    static program_factory_t select_program_factory(const operation_attributes_t&, const tensor_args_t&);
    static void validate_on_program_cache_miss(const operation_attributes_t&, const tensor_args_t&);
    static void validate_on_program_cache_hit(const operation_attributes_t&, const tensor_args_t&);
    static shape_return_value_t compute_output_shapes(const operation_attributes_t&, const tensor_args_t&);
    static tensor_return_value_t create_output_tensors(const operation_attributes_t&, const tensor_args_t&);

    static std::tuple<operation_attributes_t, tensor_args_t> invoke(
        const std::vector<uint32_t> shape,
        const std::variant<float, int> fill_value,
        const Tensor& any,
        const std::optional<DataType>& dtype,
        const std::optional<Layout>& layout,
        const std::optional<MemoryConfig>& memory_config);
};

}  // namespace ttnn::operations::full

namespace ttnn::prim {
constexpr auto full = ttnn::register_operation<"ttnn::prim::full", ttnn::operations::full::FullOperation>();
}  // namespace ttnn::prim
```
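Note that `invoke` fills every optional attribute from the `any` tensor via `value_or`, so omitted arguments inherit the reference tensor's dtype, layout, and memory config. A sketch of that fallback from Python, assuming the binding forwards these optionals and that ttnn tensors expose `.dtype` and `.layout` properties:

```python
# Hypothetical check of the value_or fallback, not part of the commit.
import torch
import ttnn

def test_full_inherits_defaults(device):
    # No dtype/layout/memory_config passed: invoke() falls back to the
    # "any" tensor's properties via value_or().
    reference = ttnn.from_torch(
        torch.zeros([32, 32], dtype=torch.bfloat16), device=device, layout=ttnn.TILE_LAYOUT
    )
    filled = ttnn.moreh_full([32, 32], 1.5, reference)
    assert filled.dtype == reference.dtype    # bfloat16, inherited
    assert filled.layout == reference.layout  # TILE_LAYOUT, inherited
```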