#16066: Add seed param to uniform and bernoulli ops (#16179)
### Ticket
Link to GitHub Issue: #16066

### Problem description
The uniform and bernoulli operation unit tests behave inconsistently
because a random seed is created inside the program factory.

### What's changed

- Add a seed param to the uniform and bernoulli operations, which allows
passing a seed to these ops from the Python side. If seed = 0 or not
passed, a random seed is generated in the program factory (see the usage
sketch after this list).
- Add a custom compute program hash function for these ops so that the
seed param is not included in the program hash.
- Update the pytests for these operations.
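
A minimal usage sketch of the new parameter (illustrative only: it assumes
an already-opened `device` handle, and the shape and seed values are
arbitrary). A fixed nonzero seed makes the op deterministic across calls,
while omitting the seed (or passing 0) falls back to a random seed:

```python
import torch
import ttnn

# `device` is assumed to be an already-opened ttnn device.
probs = ttnn.from_torch(
    torch.rand([32, 32]),  # per-element Bernoulli probabilities in [0, 1)
    device=device,
    layout=ttnn.TILE_LAYOUT,
)

# Same nonzero seed -> same samples on both calls.
out_a = ttnn.bernoulli(probs, 1234)
out_b = ttnn.bernoulli(probs, 1234)

# Seed omitted (defaults to 0) -> the program factory picks a random seed.
out_c = ttnn.bernoulli(probs)
```

Because the custom compute program hash zeroes the seed before hashing,
all three calls can share one compiled program; the seed only changes a
runtime argument.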

### Checklist
- [x] Post commit CI passes
https://github.com/tenstorrent/tt-metal/actions/runs/12411880762
- [ ] Blackhole Post commit (if applicable)
- [ ] Model regression CI testing passes (if applicable)
- [ ] Device performance regression CI testing passes (if applicable)
- [ ] **(For models and ops writers)** Full [new models](https://github.com/tenstorrent/tt-metal/actions/workflows/full-new-models-suite.yaml) tests pass
- [x] New/Existing tests provide coverage for changes

---------

Co-authored-by: Michael Chiou <156848643+ttmchiou@users.noreply.github.com>
BuiChiTrung and ttmchiou authored Dec 27, 2024
1 parent 9879b37 commit 2a86ff7
Showing 14 changed files with 123 additions and 83 deletions.
100 changes: 47 additions & 53 deletions tests/ttnn/unit_tests/operations/test_bernoulli.py
@@ -17,7 +17,8 @@
from loguru import logger


def run_bernoulli(shape, in_dtype, out_dtype, device, is_out_alloc=False, compute_kernel_options=None):
# Due to an issue with the tensix instruction that generates pseudo-random numbers (#13904), the seed is temporarily fixed to make the test results consistent.
def run_bernoulli(shape, in_dtype, out_dtype, device, seed=0, is_out_alloc=False, compute_kernel_options=None):
compute_kernel_config = get_compute_kernel_options(compute_kernel_options)
cpu_input = torch.rand(shape, dtype=get_lib_dtype(torch, in_dtype))
npu_input = ttnn.from_torch(cpu_input, device=device, dtype=get_lib_dtype(ttnn, in_dtype), layout=ttnn.TILE_LAYOUT)
@@ -30,62 +31,50 @@ def run_bernoulli(shape, in_dtype, out_dtype, device, is_out_alloc=False, comput
)

one_probs = []
for _ in range(10):
if is_out_alloc:
ttnn.bernoulli(
npu_input,
output=npu_output,
dtype=get_lib_dtype(ttnn, out_dtype),
compute_kernel_config=compute_kernel_config,
)
else:
npu_output = ttnn.bernoulli(
npu_input,
dtype=get_lib_dtype(ttnn, out_dtype),
compute_kernel_config=compute_kernel_config,
)

tt_output = ttnn.to_torch(npu_output).reshape(shape)
tt_output_list = tt_output.flatten().tolist()

c = Counter(tt_output_list)
one_probs.append(c[1] / len(tt_output_list))

if is_out_alloc:
ttnn.bernoulli(
npu_input,
seed,
output=npu_output,
dtype=get_lib_dtype(ttnn, out_dtype),
compute_kernel_config=compute_kernel_config,
)
else:
npu_output = ttnn.bernoulli(
npu_input,
seed,
dtype=get_lib_dtype(ttnn, out_dtype),
compute_kernel_config=compute_kernel_config,
)

tt_output = ttnn.to_torch(npu_output).reshape(shape)
tt_output_list = tt_output.flatten().tolist()

c = Counter(tt_output_list)
one_probs.append(c[1] / len(tt_output_list))
logger.info(f"one_probs={one_probs}")

expected_one_prob = 0.5
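# torch.rand draws the Bernoulli probabilities uniformly from [0, 1), so the mean fraction of ones across runs should be close to E[p] = 0.5.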
assert np.allclose(expected_one_prob, np.mean(one_probs), rtol=0.05)


# fmt: off
@skip_for_grayskull("Requires wormhole_b0 to run")
@pytest.mark.parametrize("shape",
@pytest.mark.parametrize(
"shape",
[
[2003],
[500, 500],
[1, 512, 2, 256],
],
)
@pytest.mark.parametrize("in_dtype",
[
"bfloat16",
"float32"
]
)
@pytest.mark.parametrize("out_dtype",
[
"bfloat16",
"float32"
]
)
@pytest.mark.parametrize("is_out_alloc",
[
True,
False
]
)
# fmt: on
def test_bernoulli(shape, in_dtype, out_dtype, device, is_out_alloc):
torch.manual_seed(0)
run_bernoulli(shape, in_dtype, out_dtype, device, is_out_alloc)
@pytest.mark.parametrize("seed", [6296, 3501, 1712])
@pytest.mark.parametrize("in_dtype", ["bfloat16", "float32"])
@pytest.mark.parametrize("out_dtype", ["bfloat16", "float32"])
@pytest.mark.parametrize("is_out_alloc", [True, False])
def test_bernoulli(shape, seed, in_dtype, out_dtype, device, is_out_alloc):
torch.manual_seed(seed)
run_bernoulli(shape, in_dtype, out_dtype, device, seed=seed, is_out_alloc=is_out_alloc)


@skip_for_grayskull("Requires wormhole_b0 to run")
@@ -95,17 +84,21 @@ def test_bernoulli(shape, in_dtype, out_dtype, device, is_out_alloc):
[1, 21, 123, 24],
],
)
@pytest.mark.parametrize("seed", [1408])
@pytest.mark.parametrize("in_dtype", ["float32"])
@pytest.mark.parametrize("out_dtype", ["float32"])
@pytest.mark.parametrize("is_out_alloc", [True, False])
def test_bernoulli_callback(shape, in_dtype, out_dtype, device, is_out_alloc, use_program_cache):
torch.manual_seed(0)
def test_bernoulli_callback(shape, seed, in_dtype, out_dtype, device, is_out_alloc, use_program_cache):
torch.manual_seed(seed)
num_program_cache_entries_list = []
for i in range(2):
run_bernoulli(shape, in_dtype, out_dtype, device, is_out_alloc)
for _ in range(2):
run_bernoulli(shape, in_dtype, out_dtype, device, seed=seed, is_out_alloc=is_out_alloc)
# Add a dummy tensor to make sure the tensors created in the 2 iterations don't share the same address
tt_dummy_tensor = ttnn.empty([1, 1, 32, 32], ttnn.bfloat16, ttnn.TILE_LAYOUT, device)
num_program_cache_entries_list.append(device.num_program_cache_entries())
# The cache must still hit when we change the seed, since the seed runtime arg is overridden
seed = seed + 1

logger.info(f"num_program_cache_entries_list={num_program_cache_entries_list}")
assert num_program_cache_entries_list[0] > 0
assert num_program_cache_entries_list[0] == num_program_cache_entries_list[1]
@@ -114,11 +107,12 @@ def test_bernoulli_callback(shape, in_dtype, out_dtype, device, is_out_alloc, us
@skip_for_grayskull("Requires wormhole_b0 to run")
@pytest.mark.parametrize(
"shape",
[[512, 512], [5, 4, 70, 40]],
[[512, 512], [5, 8, 70, 40]],
)
@pytest.mark.parametrize("in_dtype", ["float32"])
@pytest.mark.parametrize("out_dtype", ["float32"])
@pytest.mark.parametrize("seed", [1408])
@pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids)
def test_uniform_with_compute_kernel_options(shape, in_dtype, out_dtype, device, compute_kernel_options):
torch.manual_seed(0)
run_bernoulli(shape, in_dtype, out_dtype, device, compute_kernel_options)
def test_bernoulli_with_compute_kernel_options(shape, seed, in_dtype, out_dtype, device, compute_kernel_options):
torch.manual_seed(seed)
run_bernoulli(shape, in_dtype, out_dtype, device, seed=seed, compute_kernel_options=compute_kernel_options)
36 changes: 22 additions & 14 deletions tests/ttnn/unit_tests/operations/test_uniform.py
@@ -52,8 +52,8 @@ def benchmark_uniform(cpu_input, npu_input, rand_from, rand_to):
logger.info(f"NPU avg time: {npu_total_time / iter_num}ns")


def validate_uniform(npu_input, shape, rand_from, rand_to, dtype, compute_kernel_config):
ttnn.uniform(npu_input, rand_from, rand_to, compute_kernel_config=compute_kernel_config)
def validate_uniform(npu_input, shape, rand_from, rand_to, seed, dtype, compute_kernel_config):
ttnn.uniform(npu_input, rand_from, rand_to, seed, compute_kernel_config=compute_kernel_config)
tt_input = ttnn.to_torch(npu_input).reshape(shape)
elem_cnt = Counter(tt_input.flatten().tolist())

@@ -75,7 +75,8 @@ def validate_uniform(npu_input, shape, rand_from, rand_to, dtype, compute_kernel
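# For U(rand_from, rand_to) the variance is (rand_to - rand_from)^2 / 12; the loose rtol reflects that npu_var is a sample estimate.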
assert np.allclose(npu_var, expected_var, rtol=0.5)


def run_uniform(shape, rand_range, dtype, device, compute_kernel_options=None, mode=TestMode.VALIDATE):
# Due to an issue with the tensix instruction that generates pseudo-random numbers (#13904), the seed is temporarily fixed to make the test results consistent.
def run_uniform(shape, rand_range, dtype, device, seed=0, compute_kernel_options=None, mode=TestMode.VALIDATE):
compute_kernel_config = get_compute_kernel_options(compute_kernel_options)
rand_from, rand_to = rand_range[0], rand_range[1]
cpu_input = torch.ones(shape, dtype=get_lib_dtype(torch, dtype))
@@ -89,12 +90,12 @@ def run_uniform(shape, rand_range, dtype, device, compute_kernel_options=None, m
shape=shape,
rand_from=rand_from,
rand_to=rand_to,
seed=seed,
dtype=dtype,
compute_kernel_config=compute_kernel_config,
)


@pytest.mark.skip("#16066: Undefined behaviour. It will fail on some runs and pass on others since it's stochastic.")
@skip_for_grayskull("Requires wormhole_b0 to run")
@pytest.mark.parametrize(
"shape",
@@ -108,9 +109,10 @@ def run_uniform(shape, rand_range, dtype, device, compute_kernel_options=None, m
)
@pytest.mark.parametrize("rand_range", [[0, 1], [2.1, 9], [-5.1, 1.2]])
@pytest.mark.parametrize("dtype", ["bfloat16", "float32"])
def test_uniform(shape, rand_range, dtype, device):
torch.manual_seed(0)
run_uniform(shape, rand_range, dtype, device)
@pytest.mark.parametrize("seed", [2024, 19, 522021])
def test_uniform(shape, rand_range, dtype, seed, device):
torch.manual_seed(seed)
run_uniform(shape, rand_range, dtype, device, seed=seed)


@skip_for_grayskull("Requires wormhole_b0 to run")
@@ -120,14 +122,19 @@ def test_uniform(shape, rand_range, dtype, device):
)
@pytest.mark.parametrize("rand_range", [[-3, 4]])
@pytest.mark.parametrize("dtype", ["bfloat16", "float32"])
def test_uniform_callback(shape, rand_range, dtype, device, use_program_cache):
torch.manual_seed(0)
@pytest.mark.parametrize("seed", [0])
def test_uniform_callback(shape, rand_range, dtype, seed, device, use_program_cache):
torch.manual_seed(seed)
num_program_cache_entries_list = []
for i in range(2):
run_uniform(shape, rand_range, dtype, device)
for _ in range(2):
run_uniform(shape, rand_range, dtype, device, seed=seed)
# Add a dummy tensor to make sure the tensors created in the 2 iterations don't share the same address
tt_dummy_tensor = ttnn.empty([1, 1, 32, 32], ttnn.bfloat16, ttnn.TILE_LAYOUT, device)
num_program_cache_entries_list.append(device.num_program_cache_entries())

# The cache must still hit when we change the seed, since the seed runtime arg is overridden
seed = seed + 1

logger.info(f"num_program_cache_entries_list={num_program_cache_entries_list}")
assert num_program_cache_entries_list[0] > 0
assert num_program_cache_entries_list[0] == num_program_cache_entries_list[1]
@@ -138,9 +145,10 @@ def test_uniform_callback(shape, rand_range, dtype, device, use_program_cache):
"shape",
[[512, 512], [5, 2, 4, 70, 40]],
)
@pytest.mark.parametrize("seed", [1408])
@pytest.mark.parametrize("rand_range", [[0, 1]])
@pytest.mark.parametrize("dtype", ["bfloat16", "float32"])
@pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids)
def test_uniform_with_compute_kernel_options(shape, rand_range, dtype, device, compute_kernel_options):
torch.manual_seed(0)
run_uniform(shape, rand_range, dtype, device, compute_kernel_options)
def test_uniform_with_compute_kernel_options(shape, seed, rand_range, dtype, device, compute_kernel_options):
torch.manual_seed(seed)
run_uniform(shape, rand_range, dtype, device, seed=seed, compute_kernel_options=compute_kernel_options)
3 changes: 2 additions & 1 deletion ttnn/cpp/ttnn/operations/bernoulli/bernoulli.cpp
@@ -10,10 +10,11 @@
namespace ttnn::operations::bernoulli {
Tensor Bernoulli::invoke(
const Tensor& input,
const uint32_t seed,
const std::optional<Tensor>& output,
const std::optional<DataType>& dtype,
const std::optional<MemoryConfig>& memory_config,
const std::optional<DeviceComputeKernelConfig>& compute_kernel_config) {
return ttnn::prim::bernoulli(input, output, dtype, memory_config, compute_kernel_config);
return ttnn::prim::bernoulli(input, seed, output, dtype, memory_config, compute_kernel_config);
}
} // namespace ttnn::operations::bernoulli
1 change: 1 addition & 0 deletions ttnn/cpp/ttnn/operations/bernoulli/bernoulli.hpp
@@ -10,6 +10,7 @@ namespace ttnn::operations::bernoulli {
struct Bernoulli {
static Tensor invoke(
const Tensor& input,
const uint32_t seed,
const std::optional<Tensor>& output,
const std::optional<DataType>& dtype,
const std::optional<MemoryConfig>& memory_config,
1 change: 1 addition & 0 deletions ttnn/cpp/ttnn/operations/bernoulli/bernoulli_pybind.cpp
@@ -37,6 +37,7 @@ void bind_bernoulli_operation(py::module& module) {
doc,
ttnn::pybind_arguments_t{
py::arg("input"),
py::arg("seed") = 0,
py::kw_only(),
py::arg("output") = std::nullopt,
py::arg("dtype") = std::nullopt,
ttnn/cpp/ttnn/operations/bernoulli/device/bernoulli_device_operation.cpp
@@ -66,15 +66,23 @@ BernoulliDeviceOperation::tensor_return_value_t BernoulliDeviceOperation::create
return create_device_tensor(compute_output_specs(operation_attributes, tensor_args), tensor_args.input.device());
}

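// The seed is consumed via runtime args, so it is zeroed out of the hash below: changing only the seed must still hit the program cache.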
tt::stl::hash::hash_t BernoulliDeviceOperation::compute_program_hash(const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
auto cached_operation_attributes = operation_attributes;
cached_operation_attributes.seed = 0;
return tt::stl::hash::hash_objects_with_default_seed(cached_operation_attributes, tensor_args);
}

std::tuple<BernoulliDeviceOperation::operation_attributes_t, BernoulliDeviceOperation::tensor_args_t>
BernoulliDeviceOperation::invoke(
const Tensor& input,
const uint32_t seed,
const std::optional<Tensor>& output,
const std::optional<DataType>& dtype,
const std::optional<MemoryConfig>& memory_config,
const std::optional<DeviceComputeKernelConfig>& compute_kernel_config) {
return {
operation_attributes_t{
seed,
dtype.value_or(DataType::FLOAT32),
memory_config.value_or(input.memory_config()),
init_device_compute_kernel_config(input.device()->arch(), compute_kernel_config, MathFidelity::HiFi4)},
ttnn/cpp/ttnn/operations/bernoulli/device/bernoulli_device_operation.hpp
@@ -11,6 +11,7 @@ namespace ttnn::operations::bernoulli {

struct BernoulliDeviceOperation {
struct operation_attributes_t {
uint32_t seed;
const DataType dtype;
const MemoryConfig memory_config;
const DeviceComputeKernelConfig compute_kernel_config;
@@ -57,10 +58,13 @@ struct BernoulliDeviceOperation {

static std::tuple<operation_attributes_t, tensor_args_t> invoke(
const Tensor& input,
const uint32_t seed,
const std::optional<Tensor>& output,
const std::optional<DataType>& dtype,
const std::optional<MemoryConfig>& memory_config,
const std::optional<DeviceComputeKernelConfig>& compute_kernel_config);

static tt::stl::hash::hash_t compute_program_hash(const operation_attributes_t&, const tensor_args_t&);
};

} // namespace ttnn::operations::bernoulli
ttnn/cpp/ttnn/operations/bernoulli/device/bernoulli_program_factory.cpp
@@ -103,7 +103,8 @@ BernoulliDeviceOperation::ProgramFactory::cached_program_t BernoulliDeviceOperat
});

uint32_t tile_offset = 0;
for (const auto& core : cores) {
for (int i = 0; i < cores.size(); ++i) {
const auto& core = cores[i];
uint32_t units_per_core;
if (core_group_1.contains(core)) {
units_per_core = units_per_core_group_1;
@@ -116,7 +117,10 @@ BernoulliDeviceOperation::ProgramFactory::cached_program_t BernoulliDeviceOperat
std::vector<uint32_t> reader_runtime_args = {input.buffer()->address(), tile_offset, units_per_core};
SetRuntimeArgs(program, reader_kernel_id, core, reader_runtime_args);

std::vector<uint32_t> compute_runtime_args = {get_random_seed(), tile_offset, units_per_core};
// Each core gets its own seed (base seed + core index) so the cores generate distinct random sequences
uint32_t seed = operation_attributes.seed != 0 ? operation_attributes.seed + i : get_random_seed();

std::vector<uint32_t> compute_runtime_args = {seed, tile_offset, units_per_core};
SetRuntimeArgs(program, compute_kernel_id, core, compute_runtime_args);

std::vector<uint32_t> writer_runtime_args = {output.buffer()->address(), tile_offset, units_per_core};
@@ -147,17 +151,17 @@ void BernoulliDeviceOperation::ProgramFactory::override_runtime_arguments(
const uint32_t input_addr = tensor_args.input.buffer()->address();
const uint32_t output_addr = output.buffer()->address();

for (const auto& core : cores) {
for (int i = 0; i < cores.size(); ++i) {
{
auto& runtime_args = GetRuntimeArgs(program, reader_kernel_id, core);
auto& runtime_args = GetRuntimeArgs(program, reader_kernel_id, cores[i]);
runtime_args[0] = input_addr;
}
{
auto& runtime_args = GetRuntimeArgs(program, compute_kernel_id, core);
runtime_args[0] = get_random_seed();
auto& runtime_args = GetRuntimeArgs(program, compute_kernel_id, cores[i]);
runtime_args[0] = operation_attributes.seed != 0 ? operation_attributes.seed + i : get_random_seed();
}
{
auto& runtime_args = GetRuntimeArgs(program, writer_kernel_id, core);
auto& runtime_args = GetRuntimeArgs(program, writer_kernel_id, cores[i]);
runtime_args[0] = output_addr;
}
}
ttnn/cpp/ttnn/operations/uniform/device/uniform_device_operation.cpp
@@ -43,17 +43,25 @@ UniformDeviceOperation::tensor_return_value_t UniformDeviceOperation::create_out
return tensor_args.input;
}

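// Same scheme as bernoulli: zero the seed before hashing so a new seed reuses the cached program.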
tt::stl::hash::hash_t UniformDeviceOperation::compute_program_hash(const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
auto cached_operation_attributes = operation_attributes;
cached_operation_attributes.seed = 0;
return tt::stl::hash::hash_objects_with_default_seed(cached_operation_attributes, tensor_args);
}

std::tuple<UniformDeviceOperation::operation_attributes_t, UniformDeviceOperation::tensor_args_t>
UniformDeviceOperation::invoke(
const Tensor& input,
const float from,
const float to,
const uint32_t seed,
const std::optional<MemoryConfig>& memory_config,
const std::optional<DeviceComputeKernelConfig>& compute_kernel_config) {
return {
operation_attributes_t{
from,
to,
seed,
memory_config.value_or(input.memory_config()),
init_device_compute_kernel_config(input.device()->arch(), compute_kernel_config, MathFidelity::HiFi4)},
tensor_args_t{input}};