#9749: move repeat (#11215)

ntarafdar authored Aug 11, 2024
1 parent fcf94ac commit 5161b53
Showing 38 changed files with 298 additions and 249 deletions.
2 changes: 0 additions & 2 deletions docs/source/ttnn/ttnn/dependencies/tt_lib.rst
Original file line number Diff line number Diff line change
@@ -405,5 +405,3 @@ Other Operations
.. autofunction:: tt_lib.tensor.mean_hw

.. autofunction:: tt_lib.tensor.lamb_optimizer

.. autofunction:: tt_lib.tensor.repeat
6 changes: 3 additions & 3 deletions models/demos/falcon7b_common/tt/falcon_model.py
Original file line number Diff line number Diff line change
@@ -159,10 +159,10 @@ def model_preprocessing(self, llm_mode, input_ids, kv_cache_len, num_input_token
)
# Repeat attn masks for all heads
for i in range(self.num_devices):
tt_attention_mask[i] = ttnn.experimental.tensor.repeat(
tt_attention_mask[i] = ttnn.repeat(
tt_attention_mask[i],
[1, self.config.num_attention_heads, 1, 1],
output_mem_config=self.model_config["ATTN_MASK_MEMCFG"],
ttnn.Shape([1, self.config.num_attention_heads, 1, 1]),
memory_config=self.model_config["ATTN_MASK_MEMCFG"],
)
# Tilize attn masks
for i in range(self.num_devices):
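The hunk above shows the migration pattern that recurs throughout this commit: the deprecated ttnn.experimental.tensor.repeat / tt_lib.tensor.repeat call takes a plain list plus an output_mem_config keyword, while the replacement ttnn.repeat wraps the repeat dims in ttnn.Shape and takes memory_config. A minimal sketch of the change, assuming an already-tilized device tensor attn_mask, a head count num_heads, and a memory config mem_cfg (these names are illustrative, not taken from the diff):

    # before this commit (deprecated path)
    # attn_mask = ttnn.experimental.tensor.repeat(
    #     attn_mask, [1, num_heads, 1, 1], output_mem_config=mem_cfg
    # )
    # after this commit
    attn_mask = ttnn.repeat(attn_mask, ttnn.Shape([1, num_heads, 1, 1]), memory_config=mem_cfg)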
Original file line number Diff line number Diff line change
@@ -68,7 +68,7 @@ def scale_mask_softmax_decomposed(self, attn, scale, attn_mask):
attn = ttnn.multiply(attn, scale)

## Need to figure out how to broadcast in t dim
# attn_mask = tt_lib.tensor.repeat(attn_mask, [1, attn.shape()[1], 1, 1]) # this causes memory error as the broadcast result is too big
# attn_mask = ttnn.repeat(attn_mask, [1, attn.shape()[1], 1, 1]) # this causes memory error as the broadcast result is too big
# attn_mask = tt2torch_tensor(attn_mask)
# attn_mask = attn_mask.repeat(1, attn.shape()[1], 1, 1)
# attn_mask = torch2tt_tensor(attn_mask, self.device)
4 changes: 2 additions & 2 deletions models/experimental/llama2_70b/tests/test_llama_attention.py
Original file line number Diff line number Diff line change
@@ -217,8 +217,8 @@ def tt_llama_attention_prepare_inputs(llama_attention_model, x, start_pos):
attn_masks = ttnn.to_device(attn_masks, llama_attention_model.device_mesh)

repeat_shape = (1, batch, 1, 1)
attn_masks = tt_lib.tensor.repeat(
attn_masks, repeat_shape, output_mem_config=llama_attention_model.model_config["DRAM_MEMCFG"]
attn_masks = ttnn.repeat(
attn_masks, ttnn.Shape(repeat_shape), memory_config=llama_attention_model.model_config["DRAM_MEMCFG"]
)
return (
xs,
4 changes: 2 additions & 2 deletions models/experimental/llama2_70b/tests/test_llama_decoder.py
Original file line number Diff line number Diff line change
@@ -219,8 +219,8 @@ def tt_llama_decoder_prepare_inputs(llama_decoder_model, x, start_pos):
attn_masks = ttnn.to_device(attn_masks, llama_decoder_model.device_mesh)

repeat_shape = (1, batch, 1, 1)
attn_masks = tt_lib.tensor.repeat(
attn_masks, repeat_shape, output_mem_config=llama_decoder_model.model_config["DRAM_MEMCFG"]
attn_masks = ttnn.repeat(
attn_masks, ttnn.Shape(repeat_shape), memory_config=llama_decoder_model.model_config["DRAM_MEMCFG"]
)
return (
xs,
Original file line number Diff line number Diff line change
@@ -68,7 +68,7 @@ def scale_mask_softmax_decomposed(self, attn, scale, attn_mask):
attn = ttnn.multiply(attn, scale)

## Need to figure out how to broadcast in t dim
# attn_mask = tt_lib.tensor.repeat(attn_mask, [1, attn.shape()[1], 1, 1]) # this causes memory error as the broadcast result is too big
# attn_mask = ttnn.repeat(attn_mask, [1, attn.shape()[1], 1, 1]) # this causes memory error as the broadcast result is too big
# attn_mask = tt2torch_tensor(attn_mask)
# attn_mask = attn_mask.repeat(1, attn.shape()[1], 1, 1)
# attn_mask = torch2tt_tensor(attn_mask, self.device)
4 changes: 2 additions & 2 deletions models/experimental/llama2_70b/tt/llama_attention_galaxy.py
Original file line number Diff line number Diff line change
@@ -138,8 +138,8 @@ def prepare_inputs(self, x, start_pos):
repeat_shape = (attn_batch, 1, 1, 1)

for i in range(self.num_devices):
attn_masks[i] = tt_lib.tensor.repeat(
attn_masks[i], repeat_shape, output_mem_config=self.model_config["DRAM_MEMCFG"]
attn_masks[i] = ttnn.repeat(
attn_masks[i], ttnn.Shape(repeat_shape), memory_config=self.model_config["DRAM_MEMCFG"]
)
# Put attn_mask on the device with the sharded config
attention_mask_memconfig = self.model_config["ATTN_MASK_MEMCFG"]
4 changes: 2 additions & 2 deletions models/experimental/llama2_70b/tt/llama_decoder_galaxy.py
Original file line number Diff line number Diff line change
@@ -213,8 +213,8 @@ def prepare_inputs(self, x, start_pos):
repeat_shape = (1, self.n_local_heads, 1, 1)

for i in range(self.num_devices):
attn_masks[i] = tt_lib.tensor.repeat(
attn_masks[i], repeat_shape, output_mem_config=self.model_config["DRAM_MEMCFG"]
attn_masks[i] = ttnn.repeat(
attn_masks[i], ttnn.Shape(repeat_shape), memory_config=self.model_config["DRAM_MEMCFG"]
)
# Put attn_mask on the device with the sharded config
attention_mask_memconfig = self.model_config["ATTN_MASK_MEMCFG"]
4 changes: 2 additions & 2 deletions models/experimental/llama2_70b/tt/llama_model_optimized.py
Original file line number Diff line number Diff line change
@@ -278,8 +278,8 @@ def prepare_inputs(self, inp_ids, start_pos, valid_seq_len=None):
attn_masks = ttnn.to_device(attn_masks, self.device_mesh)

repeat_shape = (1, batch, 1, 1)
attn_masks = tt_lib.tensor.repeat(
attn_masks, repeat_shape, output_mem_config=self.model_config["DRAM_MEMCFG"]
attn_masks = ttnn.repeat(
attn_masks, ttnn.Shape(repeat_shape), memory_config=self.model_config["DRAM_MEMCFG"]
)
# Put attn_mask on the device with the sharded config
attention_mask_memconfig = self.model_config["ATTN_MASK_MEMCFG"]
6 changes: 3 additions & 3 deletions tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py
Original file line number Diff line number Diff line change
@@ -1115,9 +1115,9 @@ def repeat(
**kwargs,
):
t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
t1 = ttl.tensor.repeat(t0, repeat, output_mem_config=output_mem_config)

return tt2torch_tensor(t1)
t1 = ttnn.repeat(t0, ttnn.Shape(repeat), memory_config=output_mem_config)
output_tensor = ttnn.to_torch(t1)
return output_tensor


@setup_host_and_device
Original file line number Diff line number Diff line change
@@ -38,7 +38,7 @@ def run_repeat(input_shape, repeats, device, layout, dtype, input_mem_config, ou

tt_cpu = input.repeat(torch.Size(repeats))

tt = ttl.tensor.repeat(tt_input, ttl.tensor.Shape(repeats), output_mem_config)
tt = ttnn.repeat(tt_input, ttnn.Shape(repeats), memory_config=output_mem_config)

tt_dev = tt.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch().to(torch.bfloat16)
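The updated test keeps torch.Tensor.repeat as the host reference, so a standalone parity check against the migrated API looks roughly like the sketch below, assuming an open ttnn device handle named device and tile-aligned shapes (all names illustrative, not part of the commit):

    import torch
    import ttnn

    torch_input = torch.rand(1, 1, 32, 32, dtype=torch.bfloat16)
    repeats = (1, 2, 1, 1)

    expected = torch_input.repeat(torch.Size(repeats))  # host reference

    tt_input = ttnn.from_torch(torch_input, layout=ttnn.TILE_LAYOUT, device=device)
    tt_output = ttnn.repeat(tt_input, ttnn.Shape(repeats))  # default memory config
    actual = ttnn.to_torch(tt_output)

    assert actual.shape == expected.shape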

4 changes: 2 additions & 2 deletions tests/ttnn/profiling/ops_for_profiling.py
Original file line number Diff line number Diff line change
@@ -1520,7 +1520,7 @@ def swiglu_2(x):


def repeat(x):
tt_lib.tensor.repeat(x, (1, 1, 1, 4))
ttnn.repeat(x, ttnn.Shape((1, 1, 1, 4)))


def repeat_interleave_0(x):
@@ -2252,7 +2252,7 @@ def clone(x):
},
{
"op": repeat,
"name": "tt_lib.tensor.repeat",
"name": "ttnn.repeat",
},
{
"op": repeat_interleave_0,
2 changes: 1 addition & 1 deletion tests/ttnn/profiling/reference.txt
Original file line number Diff line number Diff line change
@@ -96,7 +96,7 @@ tt_lib.tensor.real,200,0.027,0.029,0.06,0.012
ttnn.real_bw,200,0.821,0.827,0.847
ttnn.reglu_dim_2,200,0.102,0.107,0.245,0.045
ttnn.reglu_dim_3,200,0.105,0.111,0.244,0.045
tt_lib.tensor.repeat,200,0.025,0.027,0.368,0.009
ttnn.repeat,200,0.025,0.027,0.368,0.009
ttnn.repeat_interleave_dim_0,200,0.039,0.043,0.375,0.01
ttnn.repeat_interleave_dim_1,80,0.42,0.429,323.298,0.219
ttnn.repeat_interleave_dim_2,80,0.152,0.154,150.628,0.076
Original file line number Diff line number Diff line change
@@ -32,7 +32,7 @@ def run_repeat_tests(

x = ttnn_ops.setup_ttnn_tensor(x, device, dlayout[0], in_mem_config[0], dtype[0])

tt_result = ttnn.repeat(x, ttnn.Shape(shape))
tt_result = ttnn.repeat(x, shape)

tt_result = ttnn_ops.ttnn_tensor_to_torch(tt_result, output_mem_config)

4 changes: 4 additions & 0 deletions ttnn/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -141,6 +141,10 @@ set(TTNN_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/split/split.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/split/device/split_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/split/device/split_program_factory.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/repeat/repeat.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/repeat/device/repeat_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/repeat/device/repeat_program_factory.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/repeat/repeat_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/non_zero_indices/non_zero_indices.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/non_zero_indices/non_zero_indices_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/non_zero_indices/device/non_zero_indices_op.cpp
2 changes: 0 additions & 2 deletions ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -120,8 +120,6 @@ set(TT_DNN_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/transformer_tms/multi_core_attn_matmul/multi_core_attn_matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/transformer_tms/multi_core_group_attn_matmul/multi_core_group_attn_matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/transformer_tms/multi_core_ssm_1d_sum_reduce/multi_core_ssm_1d_sum_reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/repeat/multi_core/repeat_op_multi_core.cpp
${CMAKE_CURRENT_SOURCE_DIR}/repeat/repeat_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/nlp_tms/nlp_tms.cpp
${CMAKE_CURRENT_SOURCE_DIR}/nlp_tms/nlp_create_qkv_heads_falcon7b.cpp
${CMAKE_CURRENT_SOURCE_DIR}/nlp_tms/nlp_create_qkv_heads_decode.cpp
90 changes: 0 additions & 90 deletions ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/repeat/repeat_op.cpp

This file was deleted.

Original file line number Diff line number Diff line change
@@ -7,7 +7,6 @@
#include "ttnn/deprecated/tt_dnn/op_library/move/move_op.hpp"
#include "ttnn/deprecated/tt_dnn/op_library/reshape/reshape_op.hpp"
#include "ttnn/deprecated/tt_dnn/op_library/fold/fold_op.hpp"
#include "ttnn/deprecated/tt_dnn/op_library/repeat/repeat_op.hpp"
#include "ttnn/deprecated/tt_dnn/op_library/bcast/bcast_op.hpp"
#include "ttnn/deprecated/tt_dnn/op_library/reduce/reduce_op.hpp"
#include "ttnn/deprecated/tt_dnn/op_library/copy/copy_op.hpp"
@@ -63,20 +62,6 @@ namespace tt::tt_metal::detail{
)doc"
);

m_tensor.def("repeat", &tt::tt_metal::repeat,
py::arg("input"), py::arg("size"), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
Returns a new tensor filled with repetition of input ``input`` tensor according to number of times specified in ``size``. The rank of ``size`` should be less than or equal to the rank of tensor ``input_a``.
Output tensor will have same data type as input.
.. csv-table::
:header: "Argument", "Description", "Data type", "Valid range", "Required"
"input", "Input tensor for which repetition is computed", "Tensor", "Tensor of any shape", "Yes"
"size", "The number of times to repeat this tensor along each dimension", "List[Int]", "Positive repetition values", "Yes"
"output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
)doc");

m_tensor.def("assign",
[](const Tensor& input_a, const Tensor& input_b, uint8_t queue_id){
return assign(queue_id, input_a, input_b); },
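The removed tt_lib binding documented repeat as tiling the input along each dimension by the corresponding entry of size, with the output keeping the input's dtype; the same shape math is expected to hold for the ttnn.repeat replacement. A small illustration of the documented behaviour (shapes only, tt_in is an assumed name):

    # a [1, 1, 32, 32] tensor repeated with [1, 2, 1, 1] should yield [1, 2, 32, 32]
    out = ttnn.repeat(tt_in, ttnn.Shape([1, 2, 1, 1]))
    # out.shape is expected to be ttnn.Shape([1, 2, 32, 32]), same dtype as tt_in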
37 changes: 0 additions & 37 deletions ttnn/cpp/ttnn/operations/data_movement.hpp

This file was deleted.

34 changes: 2 additions & 32 deletions ttnn/cpp/ttnn/operations/data_movement/data_movement_pybind.hpp
Original file line number Diff line number Diff line change
@@ -9,7 +9,6 @@
#include <pybind11/stl.h>

#include "ttnn/cpp/pybind11/decorators.hpp"
#include "ttnn/operations/data_movement.hpp"
#include "ttnn/operations/data_movement/concat/concat_pybind.hpp"
#include "ttnn/operations/data_movement/pad/pad_pybind.hpp"
#include "ttnn/operations/data_movement/permute/permute_pybind.hpp"
@@ -24,6 +23,7 @@
#include "ttnn/operations/data_movement/untilize_with_halo_v2/untilize_with_halo_v2_pybind.hpp"
#include "ttnn/operations/data_movement/non_zero_indices/non_zero_indices_pybind.hpp"
#include "ttnn/operations/data_movement/fill_rm/fill_rm_pybind.hpp"
#include "ttnn/operations/data_movement/repeat/repeat_pybind.hpp"


namespace py = pybind11;
@@ -32,43 +32,12 @@ namespace ttnn {
namespace operations {
namespace data_movement {

void bind_repeat(py::module& module) {
auto doc = R"doc(
repeat(input_tensor: ttnn.Tensor, shape : ttnn.Shape) -> ttnn.Tensor
Returns a new tensor filled with repetition of input :attr:`input_tensor` according to number of times specified in :attr:`shape`.
Args:
* :attr:`input_tensor`: the input_tensor to apply the repeate operation.
* :attr:`shape`: The number of repetitions for each element.
Keyword Args:
* :attr:`memory_config`: the memory configuration to use for the operation
Example:
>>> tensor = ttnn.repeat(ttnn.from_torch(torch.tensor([[1, 2], [3, 4]]), 2,)), device)
>>> print(tensor)
tensor([[1, 2],
[1, 2],
[3, 4],
[3, 4]])
)doc";

ttnn::bind_registered_operation(
module,
ttnn::repeat,
doc,
ttnn::pybind_arguments_t{
py::arg("input_tensor"), py::arg("shape"), py::kw_only(), py::arg("memory_config") = std::nullopt});
}

void py_module(py::module& module) {
detail::bind_permute(module);
detail::bind_concat(module);
detail::bind_pad(module);
detail::bind_slice(module);
bind_repeat(module);
detail::bind_repeat_interleave(module);
detail::bind_tilize(module);
detail::bind_tilize_with_val_padding(module);
@@ -80,6 +49,7 @@ void py_module(py::module& module) {
detail::bind_untilize_with_halo_v2(module);
bind_non_zero_indices(module);
bind_fill_rm(module);
py_bind_repeat(module);
}

} // namespace data_movement
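The inline example in the removed bind_repeat docstring above has unbalanced parentheses and a misplaced repeat shape. A corrected, hedged version of what it appears to intend, following the torch.Tensor.repeat semantics exercised by the tests in this commit (the device handle and layout choice are assumptions, and ttnn.repeat's actual rank/layout constraints may differ):

    import torch
    import ttnn

    input_tensor = ttnn.from_torch(
        torch.tensor([[1.0, 2.0], [3.0, 4.0]]), layout=ttnn.ROW_MAJOR_LAYOUT, device=device
    )
    # repeat twice along dim 0 and once along dim 1; illustrative only
    output = ttnn.repeat(input_tensor, ttnn.Shape((2, 1)))
    print(ttnn.to_torch(output))
    # expected under torch-style repeat semantics:
    # [[1, 2], [3, 4], [1, 2], [3, 4]]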