diff --git a/kernels/portable/cpu/op_argmax.cpp b/kernels/portable/cpu/op_argmax.cpp index 5eb656d5b7..7e95c305e8 100644 --- a/kernels/portable/cpu/op_argmax.cpp +++ b/kernels/portable/cpu/op_argmax.cpp @@ -49,7 +49,10 @@ Tensor& argmax_out( for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) { std::tuple acc = reduce_over_dim( [](CTYPE v, long ix, CTYPE acc_val, long acc_ix) { - if (!std::isnan(acc_val) && (std::isnan(v) || v > acc_val)) { + // the below condition as written is equivalent to + // !isnan(accval) && (isnan(v) || v > acc_val). See + // argument in op_argmin.cpp. + if (!std::isnan(acc_val) && !(v <= acc_val)) { acc_val = v; acc_ix = ix; } diff --git a/kernels/portable/cpu/op_argmin.cpp b/kernels/portable/cpu/op_argmin.cpp index 1c4a2572ea..0223a643e9 100644 --- a/kernels/portable/cpu/op_argmin.cpp +++ b/kernels/portable/cpu/op_argmin.cpp @@ -49,7 +49,17 @@ Tensor& argmin_out( for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) { std::tuple acc = reduce_over_dim( [](CTYPE v, long ix, CTYPE acc_val, long acc_ix) { - if (!std::isnan(acc_val) && (std::isnan(v) || v < acc_val)) { + // the below condition as written is equivalent to !isnan(accval) && (isnan(v) || v < acc_val). + // cases: + // - if neither acc_val nor v is NaN, !(v >= acc_val) is + // trivially equivalent to v < acc_val. + // - if acc_val is NaN, the whole thing is trivially false. + // - if acc_val is not NaN and v is NaN, then v >= acc_val + // - is false because all comparisons involving NaN are + // - false, so the result is true. The result is trivially + // - true for the above condition that uses isnan(v) as + // - well. + if (!std::isnan(acc_val) && !(v >= acc_val)) { acc_val = v; acc_ix = ix; } diff --git a/kernels/test/op_argmax_test.cpp b/kernels/test/op_argmax_test.cpp index 66c79cefff..4d68dfe88b 100644 --- a/kernels/test/op_argmax_test.cpp +++ b/kernels/test/op_argmax_test.cpp @@ -90,3 +90,16 @@ TEST_F(OpArgmaxTest, SanityCheckNullDim) { EXPECT_TENSOR_EQ(out, expected); // clang-format on } + +TEST_F(OpArgmaxTest, FirstNaNWins) { + TensorFactory tf_float; + Tensor in = tf_float.make({4}, {1, NAN, -4, NAN}); + + TensorFactory tf_long; + Tensor out = tf_long.zeros({}); + Tensor expected = tf_long.make({}, {1}); + + Tensor ret = op_argmax_out(in, {}, false, out); + EXPECT_TENSOR_EQ(out, ret); + EXPECT_TENSOR_EQ(out, expected); +} diff --git a/kernels/test/op_argmin_test.cpp b/kernels/test/op_argmin_test.cpp index 250fe4f7e1..a0b2699a28 100644 --- a/kernels/test/op_argmin_test.cpp +++ b/kernels/test/op_argmin_test.cpp @@ -90,3 +90,16 @@ TEST_F(OpArgminTest, SanityCheckNullDim) { EXPECT_TENSOR_EQ(out, expected); // clang-format on } + +TEST_F(OpArgminTest, FirstNaNWins) { + TensorFactory tf_float; + Tensor in = tf_float.make({4}, {1, NAN, -4, NAN}); + + TensorFactory tf_long; + Tensor out = tf_long.zeros({}); + Tensor expected = tf_long.make({}, {1}); + + Tensor ret = op_argmin_out(in, {}, false, out); + EXPECT_TENSOR_EQ(out, ret); + EXPECT_TENSOR_EQ(out, expected); +}