Skip to content

Commit

Permalink
#14186: Fixed moreh_adam and moreh_adamw (#14243)
Browse files Browse the repository at this point in the history
* #14186: Fixed moreh_adam

* #0: fixed adam too
  • Loading branch information
dmakoviichuk-tt authored Oct 25, 2024
1 parent 047da7a commit 59a2d5e
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -160,25 +160,10 @@ std::tuple<MorehAdamOperation::operation_attributes_t, MorehAdamOperation::tenso
auto MorehAdamOperation::compute_program_hash(
const MorehAdamOperation::operation_attributes_t& operation_attributes,
const MorehAdamOperation::tensor_args_t& tensor_args) -> tt::stl::hash::hash_t {
return operation::hash_operation<MorehAdamOperation>(
operation_attributes.beta1,
operation_attributes.beta2,
operation_attributes.eps,
operation_attributes.amsgrad,
operation_attributes.weight_decay,
operation_attributes.memory_config,
operation_attributes.compute_kernel_config,
tensor_args.param_in.memory_config(),
tensor_args.param_in.dtype(),
tensor_args.grad.memory_config(),
tensor_args.grad.dtype(),
tensor_args.exp_avg_in.memory_config(),
tensor_args.exp_avg_in.dtype(),
tensor_args.exp_avg_sq_in.memory_config(),
tensor_args.exp_avg_sq_in.dtype(),
tensor_args.max_exp_avg_sq_in.has_value() ? tensor_args.max_exp_avg_sq_in.value().memory_config()
: MemoryConfig{},
tensor_args.max_exp_avg_sq_in.has_value() ? tensor_args.max_exp_avg_sq_in.value().dtype() : DataType::INVALID);
auto operation_attributes_without_step_and_lr = operation_attributes;
operation_attributes_without_step_and_lr.step = 0;
operation_attributes_without_step_and_lr.lr = 0.0f;
return tt::stl::hash::hash_objects_with_default_seed(operation_attributes_without_step_and_lr, tensor_args);
}

} // namespace ttnn::operations::moreh::moreh_adam
Original file line number Diff line number Diff line change
Expand Up @@ -152,24 +152,9 @@ MorehAdamWDeviceOperation::invoke(

tt::stl::hash::hash_t MorehAdamWDeviceOperation::compute_program_hash(
const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
return operation::hash_operation<MorehAdamWDeviceOperation>(
operation_attributes.beta1,
operation_attributes.beta2,
operation_attributes.eps,
operation_attributes.amsgrad,
operation_attributes.weight_decay,
operation_attributes.memory_config,
operation_attributes.compute_kernel_config,
tensor_args.param_in.memory_config(),
tensor_args.param_in.dtype(),
tensor_args.grad.memory_config(),
tensor_args.grad.dtype(),
tensor_args.exp_avg_in.memory_config(),
tensor_args.exp_avg_in.dtype(),
tensor_args.exp_avg_sq_in.memory_config(),
tensor_args.exp_avg_sq_in.dtype(),
tensor_args.max_exp_avg_sq_in.has_value() ? tensor_args.max_exp_avg_sq_in.value().memory_config()
: MemoryConfig{},
tensor_args.max_exp_avg_sq_in.has_value() ? tensor_args.max_exp_avg_sq_in.value().dtype() : DataType::INVALID);
auto operation_attributes_without_step_and_lr = operation_attributes;
operation_attributes_without_step_and_lr.step = 0;
operation_attributes_without_step_and_lr.lr = 0.0f;
return tt::stl::hash::hash_objects_with_default_seed(operation_attributes_without_step_and_lr, tensor_args);
}
} // namespace ttnn::operations::moreh::moreh_adamw

0 comments on commit 59a2d5e

Please sign in to comment.