Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
swolchok committed Feb 27, 2025
1 parent afcec1d commit c9bd251
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 2 deletions.
5 changes: 4 additions & 1 deletion kernels/portable/cpu/op_argmax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,10 @@ Tensor& argmax_out(
for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
std::tuple<CTYPE, long> acc = reduce_over_dim<CTYPE>(
[](CTYPE v, long ix, CTYPE acc_val, long acc_ix) {
if (!std::isnan(acc_val) && (std::isnan(v) || v > acc_val)) {
// the below condition as written is equivalent to
// !isnan(accval) && (isnan(v) || v > acc_val). See
// argument in op_argmin.cpp.
if (!std::isnan(acc_val) && !(v <= acc_val)) {
acc_val = v;
acc_ix = ix;
}
Expand Down
12 changes: 11 additions & 1 deletion kernels/portable/cpu/op_argmin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,17 @@ Tensor& argmin_out(
for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
std::tuple<CTYPE, long> acc = reduce_over_dim<CTYPE>(
[](CTYPE v, long ix, CTYPE acc_val, long acc_ix) {
if (!std::isnan(acc_val) && (std::isnan(v) || v < acc_val)) {
// the below condition as written is equivalent to !isnan(accval) && (isnan(v) || v < acc_val).
// cases:
// - if neither acc_val nor v is NaN, !(v >= acc_val) is
// trivially equivalent to v < acc_val.
// - if acc_val is NaN, the whole thing is trivially false.
// - if acc_val is not NaN and v is NaN, then v >= acc_val
// - is false because all comparisons involving NaN are
// - false, so the result is true. The result is trivially
// - true for the above condition that uses isnan(v) as
// - well.
if (!std::isnan(acc_val) && !(v >= acc_val)) {
acc_val = v;
acc_ix = ix;
}
Expand Down
13 changes: 13 additions & 0 deletions kernels/test/op_argmax_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,16 @@ TEST_F(OpArgmaxTest, SanityCheckNullDim) {
EXPECT_TENSOR_EQ(out, expected);
// clang-format on
}

TEST_F(OpArgmaxTest, FirstNaNWins) {
TensorFactory<ScalarType::Float> tf_float;
Tensor in = tf_float.make({4}, {1, NAN, -4, NAN});

TensorFactory<ScalarType::Long> tf_long;
Tensor out = tf_long.zeros({});
Tensor expected = tf_long.make({}, {1});

Tensor ret = op_argmax_out(in, {}, false, out);
EXPECT_TENSOR_EQ(out, ret);
EXPECT_TENSOR_EQ(out, expected);
}
13 changes: 13 additions & 0 deletions kernels/test/op_argmin_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,16 @@ TEST_F(OpArgminTest, SanityCheckNullDim) {
EXPECT_TENSOR_EQ(out, expected);
// clang-format on
}

TEST_F(OpArgminTest, FirstNaNWins) {
TensorFactory<ScalarType::Float> tf_float;
Tensor in = tf_float.make({4}, {1, NAN, -4, NAN});

TensorFactory<ScalarType::Long> tf_long;
Tensor out = tf_long.zeros({});
Tensor expected = tf_long.make({}, {1});

Tensor ret = op_argmin_out(in, {}, false, out);
EXPECT_TENSOR_EQ(out, ret);
EXPECT_TENSOR_EQ(out, expected);
}

0 comments on commit c9bd251

Please sign in to comment.