diff --git a/xla/client/xla_builder_test.cc b/xla/client/xla_builder_test.cc index 665f932b78e537..f9528266486a4a 100644 --- a/xla/client/xla_builder_test.cc +++ b/xla/client/xla_builder_test.cc @@ -1797,7 +1797,8 @@ TEST_P(XlaBuilderUnboundedUnaryOpTest, UnboundedUnaryOpTest) { TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape(GetParam().expected)); GetParam().unary_op(Parameter(&b, 0, operand, "operand")); - TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); + TF_ASSERT_OK_AND_ASSIGN(const std::unique_ptr module, + BuildHloModule(b)); EXPECT_THAT(GetRoot(*module), GmockMatch(m::Op().WithShapeEqualTo(&expected))); } @@ -1811,9 +1812,9 @@ TEST_P(XlaBuilderUnboundedBinaryOpTest, UnboundedBinaryOpTest) { GetParam().binary_op(Parameter(&b, 0, lhs, "lhs"), Parameter(&b, 1, rhs, "rhs"), GetParam().broadcast_dimensions); - if (auto result = BuildHloModule(b); result.ok()) { - const std::unique_ptr module = std::move(*result); - EXPECT_THAT(GetRoot(*module), + if (const auto result = BuildHloModule(b); result.ok()) { + ASSERT_NE(*result, nullptr); + EXPECT_THAT(GetRoot(**result), GmockMatch(m::Op().WithShapeEqualTo(&expected))); } else { ASSERT_TRUE(GetParam().error_message.has_value()); @@ -1856,6 +1857,42 @@ TEST(XlaBuilderTest, UnboundedAddUnsupportedImplicitBroadcast) { StatusIs(_, HasSubstr(kBroadcastDimensionMismatch))); } +TEST(XlaBuilderTest, UnboundedAllGather) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]")); + AllGather(Parameter(&b, 0, operand, "operand"), /*all_gather_dimension=*/0, + /*shard_count=*/2, + /*replica_groups=*/{}); + TF_ASSERT_OK_AND_ASSIGN(const std::unique_ptr module, + BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, UnboundedAllReduce) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]")); + + XlaComputation computation; + { + const std::unique_ptr sub_builder = b.CreateSubBuilder("add"); + Add(Parameter(sub_builder.get(), 0, ShapeUtil::MakeScalarShape(F32), + "arg0"), + Parameter(sub_builder.get(), 1, ShapeUtil::MakeScalarShape(F32), + "arg1")); + TF_ASSERT_OK_AND_ASSIGN(computation, sub_builder->Build()); + } + + AllReduce(Parameter(&b, 0, operand, "operand"), computation, + /*replica_groups=*/{}); + TF_ASSERT_OK_AND_ASSIGN(const std::unique_ptr module, + BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + TEST(XlaBuilderTest, UnboundedAnd) { XlaBuilder b(TestName()); TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, @@ -1865,7 +1902,7 @@ TEST(XlaBuilderTest, UnboundedAnd) { TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("s32[?, ?, 2, 2, <=2, <=2, ?]")); And(Parameter(&b, 0, lhs, "lhs"), Parameter(&b, 1, rhs, "rhs"), - /*broadcast_dimensions=*/absl::Span{}); + /*broadcast_dimensions=*/empty_array); TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); EXPECT_THAT(GetRoot(*module), GmockMatch(m::Op().WithShapeEqualTo(&expected))); @@ -1978,6 +2015,17 @@ TEST(XlaBuilderTest, UnboundedBroadcastInDimUnsupported) { "static or bounded dynamic"))); } +TEST(XlaBuilderTest, UnboundedCholesky) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape a, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]")); + Cholesky(Parameter(&b, 0, a, "a"), /*lower=*/true); + TF_ASSERT_OK_AND_ASSIGN(const std::unique_ptr module, + BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + TEST(XlaBuilderTest, UnboundedClamp) { XlaBuilder b(TestName()); TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, @@ -2059,6 +2107,17 @@ TEST(XlaBuilderTest, StatusIs(_, HasSubstr("Unimplemented implicit broadcast."))); } +TEST(XlaBuilderTest, UnboundedCollectivePermute) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]")); + CollectivePermute(/*operand=*/Parameter(&b, 0, operand, "operand"), + /*source_target_pairs=*/{std::make_pair(0, 1)}); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + TEST(XlaBuilderTest, UnboundedCompare) { XlaBuilder b(TestName()); TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, @@ -2160,6 +2219,40 @@ TEST(XlaBuilderTest, UnboundedDotGeneral) { GmockMatch(m::Op().WithShapeEqualTo(&expected))); } +TEST(XlaBuilderTest, UnboundedDynamicSlice) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape start_indices, ParseShape("s32[]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[2, 2]")); + DynamicSlice(Parameter(&b, 0, operand, "operand"), + /*start_indices=*/ + { + Parameter(&b, 1, start_indices, "start_indices0"), + Parameter(&b, 2, start_indices, "start_indices1"), + }, + /*slice_sizes=*/{2, 2}); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, UnboundedDynamicUpdateSlice) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape update, ParseShape("f32[?, 5]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape start_indices, ParseShape("s32[]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]")); + DynamicUpdateSlice(Parameter(&b, 0, operand, "operand"), + Parameter(&b, 1, update, "update"), + /*start_indices=*/ + {Parameter(&b, 2, start_indices, "start_indices0"), + Parameter(&b, 3, start_indices, "start_indices1")}); + TF_ASSERT_OK_AND_ASSIGN(const std::unique_ptr module, + BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + TEST(XlaBuilderTest, UnboundedGather) { XlaBuilder b(TestName()); TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[3, 4, 2]")); @@ -2183,6 +2276,79 @@ TEST(XlaBuilderTest, UnboundedGather) { GmockMatch(m::Op().WithShapeEqualTo(&expected))); } +TEST(XlaBuilderTest, UnboundedGetTupleElement) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]")); + GetTupleElement(Tuple(&b, {Parameter(&b, 0, operand, "operand")}), 0); + TF_ASSERT_OK_AND_ASSIGN(const std::unique_ptr module, + BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, UnboundedInfeed) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape shape, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]")); + Infeed(/*builder=*/&b, /*shape=*/shape, /*config=*/""); + TF_ASSERT_OK_AND_ASSIGN(const std::unique_ptr module, + BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, UnboundedInfeedWithToken) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape shape, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, + ParseShape("(f32[?, 10], token[])")); + InfeedWithToken(/*token=*/CreateToken(&b), /*shape=*/shape, /*config=*/""); + TF_ASSERT_OK_AND_ASSIGN(const std::unique_ptr module, + BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, UnboundedMap) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand0, ParseShape("f32[2, ?, ?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand1, ParseShape("f32[?, 3, ?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[2, ?, ?]")); + + XlaComputation computation; + { + const std::unique_ptr sub_builder = b.CreateSubBuilder("add"); + Add(Parameter(sub_builder.get(), 0, ShapeUtil::MakeScalarShape(F32), + "arg0"), + Parameter(sub_builder.get(), 1, ShapeUtil::MakeScalarShape(F32), + "arg1")); + TF_ASSERT_OK_AND_ASSIGN(computation, sub_builder->Build()); + } + + Map(&b, /*operands=*/ + {Parameter(&b, 0, operand0, "operand0"), + Parameter(&b, 1, operand1, "operand1")}, + computation, /*dimensions=*/{0, 1, 2}, + /*static_operands=*/{}); + + TF_ASSERT_OK_AND_ASSIGN(const std::unique_ptr module, + BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, UnboundedOptimizationBarrier) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]")); + OptimizationBarrier(Parameter(&b, 0, operand, "operand")); + TF_ASSERT_OK_AND_ASSIGN(const std::unique_ptr module, + BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + TEST(XlaBuilderTest, UnboundedOr) { XlaBuilder b(TestName()); TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, @@ -2193,11 +2359,34 @@ TEST(XlaBuilderTest, UnboundedOr) { ParseShape("s32[?, ?, 2, 2, <=2, <=2, ?]")); Or(Parameter(&b, 0, lhs, "lhs"), Parameter(&b, 1, rhs, "rhs"), /*broadcast_dimensions=*/empty_array); - TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); + TF_ASSERT_OK_AND_ASSIGN(const std::unique_ptr module, + BuildHloModule(b)); EXPECT_THAT(GetRoot(*module), GmockMatch(m::Op().WithShapeEqualTo(&expected))); } +TEST(XlaBuilderTest, UnboundedOutfeed) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape shape_with_layout, + ParseShape("f32[?, 10]")); + Outfeed(/*operand=*/Parameter(&b, 0, operand, "operand"), + /*shape_with_layout=*/shape_with_layout, /*outfeed_config=*/""); + EXPECT_OK(BuildHloModule(b)); +} + +TEST(XlaBuilderTest, UnboundedOutfeedWithToken) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape shape_with_layout, + ParseShape("f32[?, 10]")); + OutfeedWithToken(/*operand=*/Parameter(&b, 0, operand, "operand"), + /*token=*/CreateToken(&b), + /*shape_with_layout=*/shape_with_layout, + /*outfeed_config=*/""); + EXPECT_OK(BuildHloModule(b)); +} + TEST(XlaBuilderTest, UnboundedPad) { XlaBuilder b(TestName()); TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]")); @@ -2216,6 +2405,36 @@ TEST(XlaBuilderTest, UnboundedPad) { GmockMatch(m::Op().WithShapeEqualTo(&expected))); } +TEST(XlaBuilderTest, UnboundedRecv) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape shape, ParseShape("f32[?, 10]")); + ChannelHandle handle; + handle.set_handle(1); + handle.set_type(ChannelHandle::DEVICE_TO_DEVICE); + Recv(/*builder=*/&b, /*shape=*/shape, /*handle=*/handle); + EXPECT_OK(BuildHloModule(b)); +} + +TEST(XlaBuilderTest, UnboundedRecvFromHost) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape shape, ParseShape("f32[?, 10]")); + ChannelHandle handle; + handle.set_handle(1); + handle.set_type(ChannelHandle::HOST_TO_DEVICE); + RecvFromHost(/*token=*/CreateToken(&b), /*shape=*/shape, /*handle=*/handle); + EXPECT_OK(BuildHloModule(b)); +} + +TEST(XlaBuilderTest, UnboundedRecvWithToken) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape shape, ParseShape("f32[?, 10]")); + ChannelHandle handle; + handle.set_handle(1); + handle.set_type(ChannelHandle::DEVICE_TO_DEVICE); + RecvWithToken(/*token=*/CreateToken(&b), /*shape=*/shape, /*handle=*/handle); + EXPECT_OK(BuildHloModule(b)); +} + TEST(XlaBuilderTest, UnboundedReduce) { XlaBuilder b(TestName()); const Shape shape = ShapeUtil::MakeShape(F32, {7}, {false}); @@ -2247,6 +2466,17 @@ TEST(XlaBuilderTest, UnboundedReduce) { GmockMatch(m::Op().WithShapeEqualTo(&expected))); } +TEST(XlaBuilderTest, UnboundedReducePrecision) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]")); + ReducePrecision(Parameter(&b, 0, operand, "operand"), /*exponent_bits=*/2, + /*mantissa_bits=*/2); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + TEST(XlaBuilderTest, UnboundedReduceWindow) { XlaBuilder b(TestName()); TF_ASSERT_OK_AND_ASSIGN(const Shape input, ParseShape("f32[?, 4, 8]")); @@ -2300,6 +2530,55 @@ TEST(XlaBuilderTest, UnboundedReshapeUnsupportedInferredShape) { "Reshaping with unbounded result shape is not supported."))); } +TEST(XlaBuilderTest, UnboundedReverse) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]")); + Rev(Parameter(&b, 0, operand, "operand"), /*dimensions=*/{0, 1}); + TF_ASSERT_OK_AND_ASSIGN(const std::unique_ptr module, + BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, UnboundedRngBitGenerator) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape initial_state, ParseShape("u32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape shape, ParseShape("u32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, + ParseShape("(u32[?, 10], u32[?, 10])")); + RngBitGenerator(RandomAlgorithm::RNG_DEFAULT, + Parameter(&b, 0, initial_state, "initial_state"), shape); + TF_ASSERT_OK_AND_ASSIGN(const std::unique_ptr module, + BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, UnboundedRngNormal) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape shape, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]")); + RngNormal(Parameter(&b, 0, ShapeUtil::MakeScalarShape(F32), "mu"), + Parameter(&b, 1, ShapeUtil::MakeScalarShape(F32), "sigma"), shape); + TF_ASSERT_OK_AND_ASSIGN(const std::unique_ptr module, + BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, UnboundedRngUniform) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape shape, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]")); + RngUniform(Parameter(&b, 0, ShapeUtil::MakeScalarShape(F32), "a"), + Parameter(&b, 1, ShapeUtil::MakeScalarShape(F32), "b"), shape); + TF_ASSERT_OK_AND_ASSIGN(const std::unique_ptr module, + BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + TEST(XlaBuilderTest, UnboundedScatter) { XlaBuilder b(TestName()); TF_ASSERT_OK_AND_ASSIGN(const Shape input, ParseShape("f32[?, ?, ?]")); @@ -2311,11 +2590,10 @@ TEST(XlaBuilderTest, UnboundedScatter) { XlaComputation update_computation; { const std::unique_ptr sub_builder = b.CreateSubBuilder("add"); - const XlaOp arg0 = Parameter(sub_builder.get(), 0, - ShapeUtil::MakeScalarShape(F32), "arg0"); - const XlaOp arg1 = Parameter(sub_builder.get(), 1, - ShapeUtil::MakeScalarShape(F32), "arg1"); - Add(arg0, arg1); + Add(Parameter(sub_builder.get(), 0, ShapeUtil::MakeScalarShape(F32), + "arg0"), + Parameter(sub_builder.get(), 1, ShapeUtil::MakeScalarShape(F32), + "arg1")); TF_ASSERT_OK_AND_ASSIGN(update_computation, sub_builder->Build()); } @@ -2419,6 +2697,82 @@ TEST(XlaBuilderTest, StatusIs(_, HasSubstr("Unimplemented implicit broadcast."))); } +TEST(XlaBuilderTest, UnboundedSelectAndScatter) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape source, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape init_value, ParseShape("f32[]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]")); + + XlaComputation select; + { + const std::unique_ptr sub_builder = + b.CreateSubBuilder("compare"); + Compare(Parameter(sub_builder.get(), 0, ShapeUtil::MakeScalarShape(F32), + "arg0"), + Parameter(sub_builder.get(), 1, ShapeUtil::MakeScalarShape(F32), + "arg1"), + ComparisonDirection::kGe); + TF_ASSERT_OK_AND_ASSIGN(select, sub_builder->Build()); + } + + XlaComputation scatter; + { + const std::unique_ptr sub_builder = b.CreateSubBuilder("add"); + Add(Parameter(sub_builder.get(), 0, ShapeUtil::MakeScalarShape(F32), + "arg0"), + Parameter(sub_builder.get(), 1, ShapeUtil::MakeScalarShape(F32), + "arg1")); + TF_ASSERT_OK_AND_ASSIGN(scatter, sub_builder->Build()); + } + + SelectAndScatter(Parameter(&b, 0, operand, "operand"), select, + /*window_dimensions=*/ + std::array({3, 1}), + /*window_strides=*/std::array({2, 1}), + Padding::kValid, Parameter(&b, 1, source, "source"), + Parameter(&b, 2, init_value, "init_value"), scatter); + + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, UnboundedSend) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]")); + ChannelHandle handle; + handle.set_handle(1); + handle.set_type(ChannelHandle::DEVICE_TO_DEVICE); + Send(/*operand=*/Parameter(&b, 0, operand, "operand"), /*handle=*/handle); + EXPECT_OK(BuildHloModule(b)); +} + +TEST(XlaBuilderTest, UnboundedSendToHost) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape shape_with_layout, + ParseShape("f32[?, 10]")); + ChannelHandle handle; + handle.set_handle(1); + handle.set_type(ChannelHandle::DEVICE_TO_HOST); + SendToHost(/*operand=*/Parameter(&b, 0, operand, "operand"), + /*token=*/CreateToken(&b), /*shape_with_layout=*/shape_with_layout, + /*handle=*/handle); + EXPECT_OK(BuildHloModule(b)); +} + +TEST(XlaBuilderTest, UnboundedSendWithToken) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]")); + ChannelHandle handle; + handle.set_handle(1); + handle.set_type(ChannelHandle::DEVICE_TO_DEVICE); + SendWithToken(/*operand=*/Parameter(&b, 0, operand, "operand"), + /*token=*/CreateToken(&b), /*handle=*/handle); + EXPECT_OK(BuildHloModule(b)); +} + TEST(XlaBuilderTest, UnboundedSlice) { XlaBuilder b(TestName()); TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[1, <=3, ?]")); @@ -2427,7 +2781,33 @@ TEST(XlaBuilderTest, UnboundedSlice) { /*start_indices=*/{0, 1, 2}, /*limit_indices=*/{1, 3, 5}, /*strides=*/{1, 1, 1}); - TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); + TF_ASSERT_OK_AND_ASSIGN(const std::unique_ptr module, + BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, UnboundedSort) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]")); + + XlaComputation comparator; + { + const std::unique_ptr sub_builder = + b.CreateSubBuilder("compare"); + Compare(Parameter(sub_builder.get(), 0, ShapeUtil::MakeScalarShape(F32), + "arg0"), + Parameter(sub_builder.get(), 1, ShapeUtil::MakeScalarShape(F32), + "arg1"), + ComparisonDirection::kLt); + TF_ASSERT_OK_AND_ASSIGN(comparator, sub_builder->Build()); + } + + Sort({Parameter(&b, 0, operand, "operand")}, comparator, + /*dimension=*/0, /*is_stable=*/true); + TF_ASSERT_OK_AND_ASSIGN(const std::unique_ptr module, + BuildHloModule(b)); EXPECT_THAT(GetRoot(*module), GmockMatch(m::Op().WithShapeEqualTo(&expected))); } @@ -2445,6 +2825,91 @@ TEST(XlaBuilderTest, UnboundedTranspose) { GmockMatch(m::Op().WithShapeEqualTo(&expected))); } +TEST(XlaBuilderTest, UnboundedTriangularSolve) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape a_shape, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape b_shape, ParseShape("f32[10, ?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[10, ?]")); + TriangularSolveOptions options; + TriangularSolve(Parameter(&b, 0, a_shape, "a"), + Parameter(&b, 1, b_shape, "b"), + /*left_side=*/true, /*lower*/ true, /*unit_diagonal=*/false, + TriangularSolveOptions::TRANSPOSE); + TF_ASSERT_OK_AND_ASSIGN(const std::unique_ptr module, + BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, UnboundedTuple) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]")); + const Shape expected = ShapeUtil::MakeTupleShape({operand}); + Tuple(&b, {Parameter(&b, 0, operand, "operand")}); + TF_ASSERT_OK_AND_ASSIGN(const std::unique_ptr module, + BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, UnboundedWhile) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape init, ParseShape("f32[?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?]")); + + XlaComputation add; + { + const std::unique_ptr sub_builder = b.CreateSubBuilder("add"); + Add(Parameter(sub_builder.get(), 0, ShapeUtil::MakeScalarShape(F32), + "arg0"), + Parameter(sub_builder.get(), 1, ShapeUtil::MakeScalarShape(F32), + "arg1")); + TF_ASSERT_OK_AND_ASSIGN(add, sub_builder->Build()); + } + + XlaComputation condition; + { + const std::unique_ptr sub_builder = + b.CreateSubBuilder("compare"); + Ge(/*lhs=*/ConstantR0(sub_builder.get(), 10.0f), + /*rhs=*/Reduce(/*operand=*/Parameter(sub_builder.get(), 0, init, "prev"), + ConstantR0(sub_builder.get(), 0.0f), add, + /*dimensions_to_reduce=*/{0})); + TF_ASSERT_OK_AND_ASSIGN(condition, sub_builder->Build()); + } + + XlaComputation body; + { + const std::unique_ptr sub_builder = b.CreateSubBuilder("add"); + Add(ConstantR1(sub_builder.get(), {1.0f}), + Parameter(sub_builder.get(), 0, init, "prev"), + /*broadcast_dimensions=*/{0}); + TF_ASSERT_OK_AND_ASSIGN(body, sub_builder->Build()); + } + + While(condition, body, Parameter(&b, 0, init, "init")); + TF_ASSERT_OK_AND_ASSIGN(const std::unique_ptr module, + BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, UnboundedXor) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, + ParseShape("s32[1, ?, 2, ?, <=2, ?, ?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, + ParseShape("s32[?, 1, ?, 2, ?, <=2, ?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, + ParseShape("s32[?, ?, 2, 2, <=2, <=2, ?]")); + Xor(Parameter(&b, 0, lhs, "lhs"), Parameter(&b, 1, rhs, "rhs"), + /*broadcast_dimensions=*/empty_array); + TF_ASSERT_OK_AND_ASSIGN(const std::unique_ptr module, + BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + INSTANTIATE_TEST_SUITE_P(UnboundedDynamism, XlaBuilderUnboundedUnaryOpTest, ::testing::ValuesIn( {{"f32[?]", "f32[?]", &Abs}, @@ -2462,6 +2927,7 @@ INSTANTIATE_TEST_SUITE_P(UnboundedDynamism, XlaBuilderUnboundedUnaryOpTest, {"f32[?]", "f32[?]", &Log1p}, {"f32[?]", "f32[?]", &Logistic}, {"f32[?]", "f32[?]", &Neg}, + {"s32[?]", "s32[?]", &Not}, {"u32[?]", "u32[?]", &PopulationCount}, {"f32[?]", "f32[?]", &Real}, {"f32[?]", "f32[?]", &Round}, @@ -2483,6 +2949,11 @@ INSTANTIATE_TEST_SUITE_P( {"f32[1, ?, 2, ?, <=2, ?, ?]", "f32[?, 1, ?, 2, ?, <=2, ?]", /*broadcast_dimensions=*/empty_array, "f32[?, ?, 2, 2, <=2, <=2, ?]", &Atan2}, + {"f32[1, ?, 2, ?, <=2, ?, ?]", "f32[?, 1, ?, 2, ?, <=2, ?]", + /*broadcast_dimensions=*/empty_array, "c64[?, ?, 2, 2, <=2, <=2, ?]", + &Complex}, + {"f32[?, 10]", "f32[1]", /*broadcast_dimensions=*/zero_array, + "c64[?, 10]", &Complex}, {"f32[1, ?, 2, ?, <=2, ?, ?]", "f32[?, 1, ?, 2, ?, <=2, ?]", /*broadcast_dimensions=*/empty_array, "f32[?, ?, 2, 2, <=2, <=2, ?]", &Div}, @@ -2493,6 +2964,11 @@ INSTANTIATE_TEST_SUITE_P( &Max}, {"f32[?, 10]", "f32[1]", /*broadcast_dimensions=*/zero_array, "f32[?, 10]", &Max}, + {"f32[1, ?, 2, ?, <=2, ?, ?]", "f32[?, 1, ?, 2, ?, <=2, ?]", + /*broadcast_dimensions=*/empty_array, "f32[?, ?, 2, 2, <=2, <=2, ?]", + &Min}, + {"f32[?, 10]", "f32[1]", /*broadcast_dimensions=*/zero_array, + "f32[?, 10]", &Min}, {"f32[1, ?, 2, ?, <=2, ?, ?]", "f32[?, 1, ?, 2, ?, <=2, ?]", /*broadcast_dimensions=*/empty_array, "f32[?, ?, 2, 2, <=2, <=2, ?]", &Mul}, @@ -2505,6 +2981,26 @@ INSTANTIATE_TEST_SUITE_P( &Pow}, {"f32[?, 10]", "f32[1]", /*broadcast_dimensions=*/zero_array, "f32[?, 10]", &Pow}, + {"f32[1, ?, 2, ?, <=2, ?, ?]", "f32[?, 1, ?, 2, ?, <=2, ?]", + /*broadcast_dimensions=*/empty_array, "f32[?, ?, 2, 2, <=2, <=2, ?]", + &Rem}, + {"f32[?, 10]", "f32[1]", /*broadcast_dimensions=*/zero_array, + "f32[?, 10]", &Rem}, + {"f32[1, ?, 2, ?, <=2, ?, ?]", "f32[?, 1, ?, 2, ?, <=2, ?]", + /*broadcast_dimensions=*/empty_array, "f32[?, ?, 2, 2, <=2, <=2, ?]", + &ShiftLeft}, + {"f32[?, 10]", "f32[1]", /*broadcast_dimensions=*/zero_array, + "f32[?, 10]", &ShiftLeft}, + {"f32[1, ?, 2, ?, <=2, ?, ?]", "f32[?, 1, ?, 2, ?, <=2, ?]", + /*broadcast_dimensions=*/empty_array, "f32[?, ?, 2, 2, <=2, <=2, ?]", + &ShiftRightArithmetic}, + {"f32[?, 10]", "f32[1]", /*broadcast_dimensions=*/zero_array, + "f32[?, 10]", &ShiftRightArithmetic}, + {"f32[1, ?, 2, ?, <=2, ?, ?]", "f32[?, 1, ?, 2, ?, <=2, ?]", + /*broadcast_dimensions=*/empty_array, "f32[?, ?, 2, 2, <=2, <=2, ?]", + &ShiftRightLogical}, + {"f32[?, 10]", "f32[1]", /*broadcast_dimensions=*/zero_array, + "f32[?, 10]", &ShiftRightLogical}, {"f32[1, ?, 2, ?, <=2, ?, ?]", "f32[?, 1, ?, 2, ?, <=2, ?]", /*broadcast_dimensions=*/empty_array, "f32[?, ?, 2, 2, <=2, <=2, ?]", &Sub}, diff --git a/xla/service/BUILD b/xla/service/BUILD index df559d0118f5a0..d509fc6c37c2cb 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -599,12 +599,12 @@ xla_cc_test( "//xla:statusor", "//xla:test", "//xla:test_helpers", - "//xla:types", "//xla:xla_data_proto_cc", "//xla/client:padding", "//xla/hlo/ir:hlo", "//xla/tests:xla_internal_test_main", # fixdeps: keep "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", diff --git a/xla/service/shape_inference.cc b/xla/service/shape_inference.cc index 02e0cb595993e4..ae03d2ec14035f 100644 --- a/xla/service/shape_inference.cc +++ b/xla/service/shape_inference.cc @@ -1532,8 +1532,11 @@ ShapeInference::InferScalarBroadcastShape(absl::Span shapes) { } } - return ShapeUtil::MakeShape(output_shape.element_type(), - arg_shape->dimensions()); + return ShapeUtil::MakeShape( + output_shape.element_type(), arg_shape->dimensions(), + /*dynamic_dimensions=*/ + std::vector(arg_shape->dynamic_dimensions().begin(), + arg_shape->dynamic_dimensions().end())); } /* static */ absl::StatusOr ShapeInference::InferBatchNormTrainingShape( @@ -2323,13 +2326,15 @@ ShapeInference::InferScalarBroadcastShape(absl::Span shapes) { "Arguments to triangular solve must have equal rank; got %s and %s.", b.ToString(), a.ToString()); } - if (a.dimensions(a.rank() - 2) != a.dimensions(a.rank() - 1)) { + if (!CompatibleDimensionSizes(a.dimensions(a.rank() - 2), + a.dimensions(a.rank() - 1))) { return InvalidArgument( "The two minor dimensions of 'a' must have equal size, got %s.", a.ToString()); } - if (a.dimensions(a.rank() - 1) != - b.dimensions(b.rank() - (options.left_side() ? 2 : 1))) { + if (!CompatibleDimensionSizes( + a.dimensions(a.rank() - 1), + b.dimensions(b.rank() - (options.left_side() ? 2 : 1)))) { return InvalidArgument( "The shared dimension of 'a' and 'b' does not match, got shapes %s and " "%s", @@ -2367,9 +2372,10 @@ ShapeInference::InferScalarBroadcastShape(absl::Span shapes) { "The 'a' argument to Cholesky must have rank >= 2, got shape %s", a.ToString()); } - if (a.dimensions(a.rank() - 2) != a.dimensions(a.rank() - 1)) { + if (!CompatibleDimensionSizes(a.dimensions(a.rank() - 2), + a.dimensions(a.rank() - 1))) { return InvalidArgument( - "The two minor dimensions of 'a' must have equal size, got %s.", + "The two minor dimensions of 'a' must have compatible size, got %s.", a.ToString()); } return a; @@ -2388,9 +2394,12 @@ ShapeInference::InferScalarBroadcastShape(absl::Span shapes) { TF_RETURN_IF_ERROR(ExpectArray(*operand_shape, "operand of all-gather")); Shape output_shape = *operand_shape; - output_shape.set_dimensions( - all_gather_dimension, - shard_count * output_shape.dimensions(all_gather_dimension)); + int64_t output_shape_dimension = + output_shape.dimensions(all_gather_dimension); + output_shape.set_dimensions(all_gather_dimension, + IsUnboundedDynamicSize(output_shape_dimension) + ? Shape::kUnboundedSize + : shard_count * output_shape_dimension); output_shapes.push_back(output_shape); } if (output_shapes.size() == 1) { @@ -3033,7 +3042,8 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { return InvalidArgument("Negative size index to dynamic slice: %d.", slice_dim_size); } - if (slice_dim_size > input_dim_size) { + if (!IsUnboundedDynamicSize(input_dim_size) && + slice_dim_size > input_dim_size) { return InvalidArgument( "Slice dim size %d greater than dynamic slice dimension: %d.", slice_dim_size, input_dim_size); @@ -3157,7 +3167,7 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { for (int64_t dim = 0; dim < operand_shape.rank(); ++dim) { const int64_t input_dim_size = operand_shape.dimensions(dim); const int64_t update_dim_size = update_shape.dimensions(dim); - if (update_dim_size < 0) { + if (!IsUnboundedDynamicSize(update_dim_size) && update_dim_size < 0) { return InvalidArgument( "Size index %d to dynamic update slice must be >= 0.", update_dim_size); diff --git a/xla/service/shape_inference_test.cc b/xla/service/shape_inference_test.cc index 781852cd41d9dd..9bc23cdb217eaa 100644 --- a/xla/service/shape_inference_test.cc +++ b/xla/service/shape_inference_test.cc @@ -15,6 +15,8 @@ limitations under the License. #include "xla/service/shape_inference.h" +#include +#include #include #include #include @@ -22,7 +24,9 @@ limitations under the License. #include #include +#include "absl/log/check.h" #include "absl/log/log.h" +#include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/strings/substitute.h" #include "absl/types/span.h" @@ -32,10 +36,8 @@ limitations under the License. #include "xla/service/hlo_parser.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/statusor.h" #include "xla/test.h" #include "xla/test_helpers.h" -#include "xla/types.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" @@ -141,6 +143,10 @@ class UnboundedBinaryOpShapeInferenceTest class UnboundedCompareOpShapeInferenceTest : public ::testing::TestWithParam {}; +// Subclass for testing unbounded dynamic complex op +class UnboundedComplexOpShapeInferenceTest + : public ::testing::TestWithParam {}; + // Subclass for testing unbounded dynamic concatenate op class UnboundedConcatenateOpShapeInferenceTest : public ::testing::TestWithParam> {}; @@ -4028,6 +4034,28 @@ TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedAdd) { } } +TEST_F(ShapeInferenceTest, UnboundedAllGather) { + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN( + const Shape inferred_shape, + ShapeInference::InferAllGatherShape( + {&operand}, /*all_gather_dimension=*/0, /*shard_count=*/2)); + EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(inferred_shape) + << " expected: " << ShapeUtil::HumanString(expected); +} + +TEST_F(ShapeInferenceTest, UnboundedAllReduce) { + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape inferred_shape, + ShapeInference::InferAllReduceShape({&operand})); + EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(inferred_shape) + << " expected: " << ShapeUtil::HumanString(expected); +} + TEST_P(UnboundedLogicalOpShapeInferenceTest, UnboundedAnd) { TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape(GetParam().lhs)); TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape(GetParam().rhs)); @@ -4188,6 +4216,16 @@ TEST_F(ShapeInferenceTest, UnboundedBroadcastInDimUnsupported) { HasSubstr("Non-broadcast dimensions must not be dynamic.")); } +TEST_F(ShapeInferenceTest, UnboundedCholesky) { + TF_ASSERT_OK_AND_ASSIGN(const Shape a, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape inferred_shape, + ShapeInference::InferCholeskyShape(a)); + EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(inferred_shape) + << " expected: " << ShapeUtil::HumanString(expected); +} + TEST_P(UnboundedClampOpShapeInferenceTest, UnboundedClamp) { TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape(GetParam()[0])); TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape(GetParam()[1])); @@ -4217,6 +4255,17 @@ TEST_F(ShapeInferenceTest, UnboundedClampWithTuple) { "Expected array argument for clamp min, but got (f32[2], f32[?]).")); } +TEST_F(ShapeInferenceTest, UnboundedCollectivePermute) { + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape inferred_shape, + ShapeInference::InferCollectivePermuteShape( + /*operand_shapes=*/{&operand})); + EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(inferred_shape) + << " expected: " << ShapeUtil::HumanString(expected); +} + TEST_P(UnboundedCompareOpShapeInferenceTest, UnboundedCompare) { TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape(GetParam().lhs)); TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape(GetParam().rhs)); @@ -4236,6 +4285,25 @@ TEST_P(UnboundedCompareOpShapeInferenceTest, UnboundedCompare) { } } +TEST_P(UnboundedComplexOpShapeInferenceTest, UnboundedComplex) { + TF_ASSERT_OK_AND_ASSIGN(const Shape real, ParseShape(GetParam().lhs)); + TF_ASSERT_OK_AND_ASSIGN(const Shape imag, ParseShape(GetParam().rhs)); + const absl::StatusOr inferred_shape = + ShapeInference::InferBinaryOpShape(HloOpcode::kComplex, real, imag, + GetParam().broadcast_dimensions); + if (inferred_shape.ok()) { + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, + ParseShape(GetParam().expected)); + EXPECT_TRUE(ShapeUtil::Equal(*inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(*inferred_shape) + << " expected: " << ShapeUtil::HumanString(expected); + } else { + ASSERT_TRUE(GetParam().error_message.has_value()); + EXPECT_THAT(inferred_shape.status().message(), + HasSubstr(*GetParam().error_message)); + } +} + TEST_P(UnboundedConcatenateOpShapeInferenceTest, UnboundedConcatenate) { TF_ASSERT_OK_AND_ASSIGN(const Shape operand1, ParseShape(GetParam()[0])); TF_ASSERT_OK_AND_ASSIGN(const Shape operand2, ParseShape(GetParam()[1])); @@ -4380,6 +4448,35 @@ TEST_F(ShapeInferenceTest, UnboundedDotGeneral) { << " expected: " << ShapeUtil::HumanString(expected); } +TEST_F(ShapeInferenceTest, UnboundedDynamicSlice) { + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape start_index, ParseShape("s32[]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[2, 2]")); + TF_ASSERT_OK_AND_ASSIGN( + const Shape inferred_shape, + ShapeInference::InferDynamicSliceShape( + operand, /*start_index_shapes=*/{start_index, start_index}, + /*slice_sizes=*/{2, 2}, /*allow_scalar_indices=*/true)); + EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(inferred_shape) + << " expected: " << ShapeUtil::HumanString(expected); +} + +TEST_F(ShapeInferenceTest, UnboundedDynamicUpdateSlice) { + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape update, ParseShape("f32[?, 5]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape start_index, ParseShape("s32[]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN( + const Shape inferred_shape, + ShapeInference::InferDynamicUpdateSliceShape( + operand, update, /*start_index_shapes=*/{start_index, start_index}, + /*allow_scalar_indices=*/true)); + EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(inferred_shape) + << " expected: " << ShapeUtil::HumanString(expected); +} + TEST_F(ShapeInferenceTest, UnboundedGather) { TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[3, 4, 2]")); TF_ASSERT_OK_AND_ASSIGN(const Shape start_indices, @@ -4403,6 +4500,34 @@ TEST_F(ShapeInferenceTest, UnboundedGather) { << " expected: " << ShapeUtil::HumanString(expected); } +TEST(XlaBuilderTest, UnboundedGetTupleElement) { + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN( + const Shape inferred_shape, + ShapeInference::InferGetTupleElementShape( + ShapeUtil::MakeTupleShape({operand}), /*index=*/0)); + EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(inferred_shape) + << " expected: " << ShapeUtil::HumanString(expected); +} + +TEST_F(ShapeInferenceTest, UnboundedMap) { + TF_ASSERT_OK_AND_ASSIGN(const Shape operand0, ParseShape("f32[2, ?, ?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand1, ParseShape("f32[?, 3, ?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[2, ?, ?]")); + + const ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_); + + TF_ASSERT_OK_AND_ASSIGN( + const Shape result_shape, + ShapeInference::InferMapShape(/*arg_shapes=*/{&operand0, &operand1}, + to_apply, /*dimensions=*/{0, 1, 2})); + EXPECT_TRUE(ShapeUtil::Equal(result_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(result_shape) + << " expected: " << ShapeUtil::HumanString(expected); +} + TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedMax) { TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape(GetParam().lhs)); TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape(GetParam().rhs)); @@ -4422,6 +4547,25 @@ TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedMax) { } } +TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedMin) { + TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape(GetParam().lhs)); + TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape(GetParam().rhs)); + const absl::StatusOr inferred_status = + ShapeInference::InferBinaryOpShape(HloOpcode::kMinimum, lhs, rhs, + GetParam().broadcast_dimensions); + if (inferred_status.ok()) { + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, + ParseShape(GetParam().expected)); + EXPECT_TRUE(ShapeUtil::Equal(*inferred_status, expected)) + << "inferred: " << ShapeUtil::HumanString(*inferred_status) + << " expected: " << ShapeUtil::HumanString(expected); + } else { + ASSERT_TRUE(GetParam().error_message.has_value()); + EXPECT_THAT(inferred_status.status().message(), + HasSubstr(*GetParam().error_message)); + } +} + TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedMul) { TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape(GetParam().lhs)); TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape(GetParam().rhs)); @@ -4533,6 +4677,18 @@ TEST_F(ShapeInferenceTest, UnboundedReduceInvalidReduceDimension) { HasSubstr("All reduced tensors must have compatible dimension")); } +TEST_F(ShapeInferenceTest, UnboundedReducePrecision) { + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN( + const Shape inferred, + ShapeInference::InferReducePrecisionShape(operand, /*exponent_bits=*/2, + /*mantissa_bits=*/2)); + ASSERT_TRUE(ShapeUtil::Equal(inferred, expected)) + << "inferred: " << ShapeUtil::HumanString(inferred) + << " expected: " << ShapeUtil::HumanString(expected); +} + TEST_F(ShapeInferenceTest, UnboundedReduceWindow) { TF_ASSERT_OK_AND_ASSIGN(const Shape input, ParseShape("f32[?, 4, 8]")); TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 3, 5]")); @@ -4561,6 +4717,25 @@ TEST_F(ShapeInferenceTest, UnboundedReduceWindow) { << " expected: " << ShapeUtil::HumanString(expected); } +TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedRemainder) { + TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape(GetParam().lhs)); + TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape(GetParam().rhs)); + const absl::StatusOr inferred_status = + ShapeInference::InferBinaryOpShape(HloOpcode::kRemainder, lhs, rhs, + GetParam().broadcast_dimensions); + if (inferred_status.ok()) { + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, + ParseShape(GetParam().expected)); + EXPECT_TRUE(ShapeUtil::Equal(*inferred_status, expected)) + << "inferred: " << ShapeUtil::HumanString(*inferred_status) + << " expected: " << ShapeUtil::HumanString(expected); + } else { + ASSERT_TRUE(GetParam().error_message.has_value()); + EXPECT_THAT(inferred_status.status().message(), + HasSubstr(*GetParam().error_message)); + } +} + TEST_F(ShapeInferenceTest, UnboundedReshape) { TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?]")); TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[2,3]")); @@ -4595,6 +4770,42 @@ TEST_F(ShapeInferenceTest, UnboundedReshapeUnsupportedMixOfDynamism) { "not supported.")); } +TEST_F(ShapeInferenceTest, UnboundedReverse) { + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN( + const Shape inferred_shape, + ShapeInference::InferReverseShape(operand, /*dimensions=*/{0, 1})); + ASSERT_TRUE(ShapeUtil::Equal(inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(inferred_shape) + << " expected: " << ShapeUtil::HumanString(expected); +} + +TEST_F(ShapeInferenceTest, UnboundedScatter) { + TF_ASSERT_OK_AND_ASSIGN(Shape input, ParseShape("f32[?, ?, ?]")); + TF_ASSERT_OK_AND_ASSIGN(Shape scatter_indices, ParseShape("s32[?, ?, ?]")); + TF_ASSERT_OK_AND_ASSIGN(Shape updates, ParseShape("f32[?, ?, ?, ?]")); + TF_ASSERT_OK_AND_ASSIGN(Shape expected, ParseShape("f32[?, ?, ?]")); + + const ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_); + + ScatterDimensionNumbers dimension_numbers; + dimension_numbers.add_update_window_dims(2); + dimension_numbers.add_update_window_dims(3); + dimension_numbers.add_inserted_window_dims(0); + dimension_numbers.add_scatter_dims_to_operand_dims(1); + dimension_numbers.add_scatter_dims_to_operand_dims(0); + dimension_numbers.set_index_vector_dim(2); + + TF_ASSERT_OK_AND_ASSIGN( + Shape result, + ShapeInference::InferScatterShape({&input, &scatter_indices, &updates}, + to_apply, dimension_numbers)); + EXPECT_TRUE(ShapeUtil::Equal(result, expected)) + << "inferred: " << ShapeUtil::HumanString(result) + << " expected: " << ShapeUtil::HumanString(expected); +} + TEST_P(UnboundedSelectOpShapeInferenceTest, UnboundedSelect) { TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape(GetParam()[0])); TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape(GetParam()[1])); @@ -4623,6 +4834,103 @@ TEST_F(ShapeInferenceTest, UnboundedSelectWithTupleUnsupported) { "(pred[2], pred[?]).")); } +TEST_F(ShapeInferenceTest, UnboundedSelectAndScatter) { + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape source, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape init_value, ParseShape("f32[]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]")); + + Window window; + WindowDimension dim0; + dim0.set_base_dilation(1); + dim0.set_size(3); + dim0.set_stride(2); + dim0.set_padding_low(0); + dim0.set_padding_high(1); + dim0.set_window_dilation(1); + + WindowDimension dim1; + dim1.set_base_dilation(1); + dim1.set_size(1); + dim1.set_stride(1); + dim1.set_padding_low(0); + dim1.set_padding_high(0); + dim1.set_window_dilation(1); + + *window.add_dimensions() = dim0; + *window.add_dimensions() = dim1; + + TF_ASSERT_OK_AND_ASSIGN( + Shape result, + ShapeInference::InferSelectAndScatterShape( + operand, + /*select_shape=*/ShapeUtil::MakeProgramShape({f32_, f32_}, pred_), + window, source, init_value, + /*scatter_shape=*/ + ShapeUtil::MakeProgramShape({f32_, f32_}, f32_))); + + EXPECT_TRUE(ShapeUtil::Equal(result, expected)) + << "inferred: " << ShapeUtil::HumanString(result) + << " expected: " << ShapeUtil::HumanString(expected); +} + +TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedShiftLeft) { + TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape(GetParam().lhs)); + TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape(GetParam().rhs)); + const absl::StatusOr inferred_status = + ShapeInference::InferBinaryOpShape(HloOpcode::kShiftLeft, lhs, rhs, + GetParam().broadcast_dimensions); + if (inferred_status.ok()) { + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, + ParseShape(GetParam().expected)); + EXPECT_TRUE(ShapeUtil::Equal(*inferred_status, expected)) + << "inferred: " << ShapeUtil::HumanString(*inferred_status) + << " expected: " << ShapeUtil::HumanString(expected); + } else { + ASSERT_TRUE(GetParam().error_message.has_value()); + EXPECT_THAT(inferred_status.status().message(), + HasSubstr(*GetParam().error_message)); + } +} + +TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedShiftRightArithmetic) { + TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape(GetParam().lhs)); + TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape(GetParam().rhs)); + const absl::StatusOr inferred_status = + ShapeInference::InferBinaryOpShape(HloOpcode::kShiftRightArithmetic, lhs, + rhs, GetParam().broadcast_dimensions); + if (inferred_status.ok()) { + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, + ParseShape(GetParam().expected)); + EXPECT_TRUE(ShapeUtil::Equal(*inferred_status, expected)) + << "inferred: " << ShapeUtil::HumanString(*inferred_status) + << " expected: " << ShapeUtil::HumanString(expected); + } else { + ASSERT_TRUE(GetParam().error_message.has_value()); + EXPECT_THAT(inferred_status.status().message(), + HasSubstr(*GetParam().error_message)); + } +} + +TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedShiftRightLogical) { + TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape(GetParam().lhs)); + TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape(GetParam().rhs)); + const absl::StatusOr inferred_status = + ShapeInference::InferBinaryOpShape(HloOpcode::kShiftRightLogical, lhs, + rhs, GetParam().broadcast_dimensions); + if (inferred_status.ok()) { + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, + ParseShape(GetParam().expected)); + EXPECT_TRUE(ShapeUtil::Equal(*inferred_status, expected)) + << "inferred: " << ShapeUtil::HumanString(*inferred_status) + << " expected: " << ShapeUtil::HumanString(expected); + } else { + ASSERT_TRUE(GetParam().error_message.has_value()); + EXPECT_THAT(inferred_status.status().message(), + HasSubstr(*GetParam().error_message)); + } +} + TEST_F(ShapeInferenceTest, UnboundedSlice) { TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[1, <=3, ?]")); TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[1, <=2, 3]")); @@ -4636,6 +4944,17 @@ TEST_F(ShapeInferenceTest, UnboundedSlice) { << " expected: " << ShapeUtil::HumanString(expected); } +TEST_F(ShapeInferenceTest, UnboundedSort) { + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN( + const Shape inferred_shape, + ShapeInference::InferVariadicOpShape(HloOpcode::kSort, {&operand})); + EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(inferred_shape) + << " expected: " << ShapeUtil::HumanString(expected); +} + TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedSub) { TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape(GetParam().lhs)); TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape(GetParam().rhs)); @@ -4655,32 +4974,6 @@ TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedSub) { } } -TEST_F(ShapeInferenceTest, UnboundedScatter) { - TF_ASSERT_OK_AND_ASSIGN(const Shape input, ParseShape("f32[?, ?, ?]")); - TF_ASSERT_OK_AND_ASSIGN(const Shape scatter_indices, - ParseShape("s32[?, ?, ?]")); - TF_ASSERT_OK_AND_ASSIGN(const Shape updates, ParseShape("f32[?, ?, ?, ?]")); - TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, ?, ?]")); - - const ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_); - - ScatterDimensionNumbers dimension_numbers; - dimension_numbers.add_update_window_dims(2); - dimension_numbers.add_update_window_dims(3); - dimension_numbers.add_inserted_window_dims(0); - dimension_numbers.add_scatter_dims_to_operand_dims(1); - dimension_numbers.add_scatter_dims_to_operand_dims(0); - dimension_numbers.set_index_vector_dim(2); - - TF_ASSERT_OK_AND_ASSIGN( - const Shape result, - ShapeInference::InferScatterShape({&input, &scatter_indices, &updates}, - to_apply, dimension_numbers)); - EXPECT_TRUE(ShapeUtil::Equal(result, expected)) - << "inferred: " << ShapeUtil::HumanString(result) - << " expected: " << ShapeUtil::HumanString(expected); -} - TEST_F(ShapeInferenceTest, UnboundedTranspose) { TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[1, ?, 2, ?, <=2]{4,3,2,1,0}")); @@ -4705,6 +4998,69 @@ TEST_F(ShapeInferenceTest, UnboundedTransposeRank1) { << " expected: " << ShapeUtil::HumanString(expected); } +TEST_F(ShapeInferenceTest, UnboundedTriangularSolve) { + TF_ASSERT_OK_AND_ASSIGN(const Shape a, ParseShape("f32[?, 3, ?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape b, ParseShape("f32[?, ?, 4]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, ?, 4]")); + TriangularSolveOptions options; + options.set_left_side(true); + options.set_lower(true); + options.set_unit_diagonal(false); + options.set_transpose_a(TriangularSolveOptions::TRANSPOSE); + TF_ASSERT_OK_AND_ASSIGN( + const Shape result_shape, + ShapeInference::InferTriangularSolveShape(a, b, options)); + EXPECT_TRUE(ShapeUtil::Equal(result_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(result_shape) + << " expected: " << ShapeUtil::HumanString(expected); +} + +TEST_F(ShapeInferenceTest, UnboundedTuple) { + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]")); + const Shape expected = ShapeUtil::MakeTupleShape({operand}); + TF_ASSERT_OK_AND_ASSIGN( + const Shape result_shape, + ShapeInference::InferVariadicOpShape( + HloOpcode::kTuple, std::vector({&operand}))); + EXPECT_TRUE(ShapeUtil::Equal(result_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(result_shape) + << " expected: " << ShapeUtil::HumanString(expected); +} + +TEST_F(ShapeInferenceTest, UnboundedWhile) { + TF_ASSERT_OK_AND_ASSIGN(const Shape init, ParseShape("f32[?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape result_shape, ParseShape("f32[?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?]")); + TF_ASSERT_OK_AND_ASSIGN( + const Shape inferred_shape, + ShapeInference::InferWhileShape( + /*condition=*/ShapeUtil::MakeProgramShape({result_shape}, pred_), + /*body=*/ShapeUtil::MakeProgramShape({result_shape}, result_shape), + /*init=*/init)); + EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(inferred_shape) + << " expected: " << ShapeUtil::HumanString(expected); +} + +TEST_P(UnboundedLogicalOpShapeInferenceTest, UnboundedXor) { + TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape(GetParam().lhs)); + TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape(GetParam().rhs)); + const absl::StatusOr inferred_status = + ShapeInference::InferBinaryOpShape(HloOpcode::kXor, lhs, rhs, + GetParam().broadcast_dimensions); + if (inferred_status.ok()) { + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, + ParseShape(GetParam().expected)); + EXPECT_TRUE(ShapeUtil::Equal(*inferred_status, expected)) + << "inferred: " << ShapeUtil::HumanString(*inferred_status) + << " expected: " << ShapeUtil::HumanString(expected); + } else { + ASSERT_TRUE(GetParam().error_message.has_value()); + EXPECT_THAT(inferred_status.status().message(), + HasSubstr(*GetParam().error_message)); + } +} + INSTANTIATE_TEST_SUITE_P(UnboundedDynamism, UnboundedLogicalOpShapeInferenceTest, ::testing::ValuesIn( @@ -4794,6 +5150,36 @@ INSTANTIATE_TEST_SUITE_P(UnboundedDynamism, "", kIncompatibleBinaryOpShapeErrorMessage}})); +INSTANTIATE_TEST_SUITE_P(UnboundedDynamism, + UnboundedComplexOpShapeInferenceTest, + ::testing::ValuesIn( + {// LHS | RHS | bdims | Res + // 1 | ? | [] | ? + {"f32[1]", "f32[?]", {}, "c64[?]"}, + // ? | 1 | [] | ? + {"f32[?]", "f32[1]", {}, "c64[?]"}, + // 2 | ? | [] | 2 + {"f32[2]", "f32[?]", {}, "c64[2]"}, + // ? | 2 | [] | 2 + {"f32[?]", "f32[2]", {}, "c64[2]"}, + // <=2 | ? | [] | <=2 + {"f32[<=2]", "f32[?]", {}, "c64[<=2]"}, + // ? | <=2 | [] | <=2 + {"f32[?]", "f32[<=2]", {}, "c64[<=2]"}, + // ? | ? | [] | ? + {"f32[?]", "f32[?]", {}, "c64[?]"}, + // 1 | ?,3 | [0] | ?,3 + {"f32[1]", "f32[?,3]", zero_array, "c64[?,3]"}, + // 2 | ?,3 | [0] | err + {"f32[2]", "f32[?,3]", zero_array, "", + kBroadcastDimensionMismatchErrorMessage}, + // ?,2 | ?,3 | [] | err + {"f32[?,2]", + "f32[?,3]", + {}, + "", + kIncompatibleBinaryOpShapeErrorMessage}})); + INSTANTIATE_TEST_SUITE_P( UnboundedDynamism, UnboundedConcatenateOpShapeInferenceTest, ::testing::Values( @@ -4933,6 +5319,7 @@ INSTANTIATE_TEST_SUITE_P(UnboundedDynamism, UnboundedUnaryOpShapeInferenceTest, {"f32[?]", "f32[?]", HloOpcode::kLog1p}, {"f32[?]", "f32[?]", HloOpcode::kLogistic}, {"f32[?]", "f32[?]", HloOpcode::kNegate}, + {"s32[?]", "s32[?]", HloOpcode::kNot}, {"u32[?]", "u32[?]", HloOpcode::kPopulationCount}, {"f32[?]", "f32[?]", HloOpcode::kReal}, {"f32[?]", "f32[?]", HloOpcode::kRoundNearestAfz},