From 59a2d5e522783fd0d0c2a07ad4af9d2203d4935c Mon Sep 17 00:00:00 2001 From: Denys Makoviichuk Date: Fri, 25 Oct 2024 11:14:19 -0700 Subject: [PATCH] #14186: Fixed moreh_adam and moreh_adamw (#14243) * #14186: Fixed moreh_adam * #0: fixed adam too --- .../device/moreh_adam_device_operation.cpp | 23 ++++--------------- .../device/moreh_adamw_device_operation.cpp | 23 ++++--------------- 2 files changed, 8 insertions(+), 38 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_adam/device/moreh_adam_device_operation.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_adam/device/moreh_adam_device_operation.cpp index cf0faa72b4d..3cc32ff7ed1 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_adam/device/moreh_adam_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_adam/device/moreh_adam_device_operation.cpp @@ -160,25 +160,10 @@ std::tuple tt::stl::hash::hash_t { - return operation::hash_operation( - operation_attributes.beta1, - operation_attributes.beta2, - operation_attributes.eps, - operation_attributes.amsgrad, - operation_attributes.weight_decay, - operation_attributes.memory_config, - operation_attributes.compute_kernel_config, - tensor_args.param_in.memory_config(), - tensor_args.param_in.dtype(), - tensor_args.grad.memory_config(), - tensor_args.grad.dtype(), - tensor_args.exp_avg_in.memory_config(), - tensor_args.exp_avg_in.dtype(), - tensor_args.exp_avg_sq_in.memory_config(), - tensor_args.exp_avg_sq_in.dtype(), - tensor_args.max_exp_avg_sq_in.has_value() ? tensor_args.max_exp_avg_sq_in.value().memory_config() - : MemoryConfig{}, - tensor_args.max_exp_avg_sq_in.has_value() ? tensor_args.max_exp_avg_sq_in.value().dtype() : DataType::INVALID); + auto operation_attributes_without_step_and_lr = operation_attributes; + operation_attributes_without_step_and_lr.step = 0; + operation_attributes_without_step_and_lr.lr = 0.0f; + return tt::stl::hash::hash_objects_with_default_seed(operation_attributes_without_step_and_lr, tensor_args); } } // namespace ttnn::operations::moreh::moreh_adam diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_adamw/device/moreh_adamw_device_operation.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_adamw/device/moreh_adamw_device_operation.cpp index 774d5d63885..1084b7de99e 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_adamw/device/moreh_adamw_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_adamw/device/moreh_adamw_device_operation.cpp @@ -152,24 +152,9 @@ MorehAdamWDeviceOperation::invoke( tt::stl::hash::hash_t MorehAdamWDeviceOperation::compute_program_hash( const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { - return operation::hash_operation( - operation_attributes.beta1, - operation_attributes.beta2, - operation_attributes.eps, - operation_attributes.amsgrad, - operation_attributes.weight_decay, - operation_attributes.memory_config, - operation_attributes.compute_kernel_config, - tensor_args.param_in.memory_config(), - tensor_args.param_in.dtype(), - tensor_args.grad.memory_config(), - tensor_args.grad.dtype(), - tensor_args.exp_avg_in.memory_config(), - tensor_args.exp_avg_in.dtype(), - tensor_args.exp_avg_sq_in.memory_config(), - tensor_args.exp_avg_sq_in.dtype(), - tensor_args.max_exp_avg_sq_in.has_value() ? tensor_args.max_exp_avg_sq_in.value().memory_config() - : MemoryConfig{}, - tensor_args.max_exp_avg_sq_in.has_value() ? tensor_args.max_exp_avg_sq_in.value().dtype() : DataType::INVALID); + auto operation_attributes_without_step_and_lr = operation_attributes; + operation_attributes_without_step_and_lr.step = 0; + operation_attributes_without_step_and_lr.lr = 0.0f; + return tt::stl::hash::hash_objects_with_default_seed(operation_attributes_without_step_and_lr, tensor_args); } } // namespace ttnn::operations::moreh::moreh_adamw