[NFC] Switch mhlo dialect to MLIR properties.
PiperOrigin-RevId: 621258659
chsigg authored and copybara-github committed Apr 5, 2024
1 parent be5c637 commit 0fe0bd3
Showing 59 changed files with 1,192 additions and 1,199 deletions.
243 changes: 122 additions & 121 deletions xla/mlir_hlo/mhlo/IR/hlo_ops.cc

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion xla/mlir_hlo/mhlo/IR/hlo_ops_common.td
@@ -22,7 +22,6 @@ def MHLO_Dialect : Dialect {

let useDefaultAttributePrinterParser = 0;
let useDefaultTypePrinterParser = 0;
- let usePropertiesForAttributes = 0;
}

include "mhlo/IR/hlo_base.td"
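
Context for the change above: MLIR's TableGen default for usePropertiesForAttributes now flips to 1, so deleting the explicit `= 0` override is what opts every mhlo op into properties, moving inherent attributes out of the op's attribute dictionary and into a generated per-op Properties struct. A minimal sketch of what access looks like once properties are enabled (illustrative, not from this commit; assumes the generated mhlo API):

    #include "mhlo/IR/hlo_ops.h"  // Pulls in the generated op classes.

    void inspect(mlir::mhlo::TransposeOp op) {
      // Typed access: `permutation` is a member of the generated
      // Properties struct rather than a dictionary lookup by name.
      auto perm = op.getProperties().permutation;
      (void)perm;

      // Type-erased access, which generic code (adaptors, type
      // inference) now threads through explicitly.
      mlir::OpaqueProperties opaque = op->getPropertiesStorage();
      (void)opaque;
    }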
10 changes: 6 additions & 4 deletions xla/mlir_hlo/mhlo/transforms/map_mhlo_to_scalar_op.h
@@ -1163,7 +1163,7 @@ inline Value mapMhloOpToStdScalarOp<mhlo::PowOp>(Location loc,
// The accum is correct when the rhs is non-negative. When rhs is
// negative, we return 0 for integer, with the exception of lhs values of 1
// and -1 which have integer results for negative exponents. Specifically, the
- // calulation is the following:
+ // calculation is the following:
//
// - Return accum if the rhs is not negative.
// - Return 1 or -1 depending on the parity of rhs when the lhs is -1.
@@ -1313,9 +1313,11 @@ struct MhloOpToStdScalarOp {
ArrayRef<Type> argTypes, ValueRange args,
OpBuilder* b) {
static_assert(!std::is_same<MhloOpTy, mhlo::ConvertOp>::value);
- return mapOpOfType<MhloOpTy>(
- op.getLoc(), resultTypes, argTypes,
- typename MhloOpTy::Adaptor(args, op->getAttrDictionary()), b);
+ typename MhloOpTy::Adaptor adaptor(args, op->getAttrDictionary(),
+ op->getPropertiesStorage(),
+ op->getRegions());
+ return mapOpOfType<MhloOpTy>(op.getLoc(), resultTypes, argTypes, adaptor,
+ b);
}
// Overload for mhlo::ConvertOp.
static Value mapOpWithArgTypes(mhlo::ConvertOp op, ArrayRef<Type> resultTypes,
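
The adaptor change above is the mechanical consequence of the switch: with properties enabled, op->getAttrDictionary() no longer carries inherent attributes, so an adaptor built from values and attributes alone would not see them. Call sites must also hand over the properties storage and regions, as in this sketch of the same pattern (helper name is illustrative):

    #include "mlir/IR/Operation.h"

    // Builds a generated op adaptor for an op whose dialect uses
    // properties; mirrors the updated call site in this commit.
    template <typename OpTy>
    typename OpTy::Adaptor makeAdaptor(OpTy op, mlir::ValueRange operands) {
      return typename OpTy::Adaptor(operands, op->getAttrDictionary(),
                                    op->getPropertiesStorage(),
                                    op->getRegions());
    }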
@@ -46,8 +46,9 @@ struct InferReturnTypesPattern : public RewritePattern {
SmallVector<Type, 4> types;
if (failed(definingOpInt.inferReturnTypes(
op->getContext(), op->getLoc(), definingOp->getOperands(),
- definingOp->getAttrDictionary(), op->getPropertiesStorage(),
- definingOp->getRegions(), types))) {
+ definingOpInt->getAttrDictionary(),
+ definingOpInt->getPropertiesStorage(), definingOpInt->getRegions(),
+ types))) {
return failure();
}
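
Besides the migration, the hunk above folds in a consistency fix: the attribute dictionary, properties storage, and regions handed to inferReturnTypes must all come from the op whose result types are being inferred. The old code mixed in op->getPropertiesStorage() from the matched op; once properties actually carry the inherent attributes, that would infer against the wrong data. A sketch of the consistent call through MLIR's InferTypeOpInterface (helper name is illustrative):

    #include "mlir/Interfaces/InferTypeOpInterface.h"

    // Infers the result types of `opInt`, sourcing attributes,
    // properties, and regions from that same op.
    mlir::LogicalResult inferTypesFor(mlir::InferTypeOpInterface opInt,
                                      llvm::SmallVectorImpl<mlir::Type> &types) {
      return opInt.inferReturnTypes(
          opInt->getContext(), opInt->getLoc(), opInt->getOperands(),
          opInt->getAttrDictionary(), opInt->getPropertiesStorage(),
          opInt->getRegions(), types);
    }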

2 changes: 1 addition & 1 deletion xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir
@@ -238,7 +238,7 @@ func.func @constant_like_static_shape(%arg : tensor<1x2xi64>) -> tensor<1x2xf32>
func.func @constant_like_dynamic_shape(%arg : tensor<?x?xi64>) -> tensor<?x?xf32> {
// CHECK: %[[CONSTANT:.*]] = mhlo.constant dense<3.200000e+00> : tensor<f32>
// CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG]] : tensor<?x?xi64> -> tensor<2xindex>
- // CHECK: %[[BROADCASTED_CONSTANT:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[CONSTANT]], %[[SHAPE]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<2xindex>) -> tensor<?x?xf32>
+ // CHECK: %[[BROADCASTED_CONSTANT:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[CONSTANT]], %[[SHAPE]]) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor<f32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK: return %[[BROADCASTED_CONSTANT]] : tensor<?x?xf32>
%result = "chlo.constant_like"(%arg) { value = 3.2 : f32 }
: (tensor<?x?xi64>) -> tensor<?x?xf32>
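
The CHECK-line churn in this and the remaining test files is purely syntactic: with properties enabled, the generic printer emits inherent attributes in a <{...}> segment after the operands instead of in the trailing {...} attribute dictionary, which is now reserved for discardable attributes. Unchanged test inputs that still spell inherent attributes in the old {...} form (like the chlo.constant_like above) keep parsing, since the parser accepts both. An illustrative before/after of the same op (demo function, not from the commit):

    func.func @properties_syntax_demo(%arg : tensor<f32>, %shape : tensor<2xindex>) -> tensor<?x?xf32> {
      // Old generic form; still accepted by the parser.
      %0 = "mhlo.dynamic_broadcast_in_dim"(%arg, %shape) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<2xindex>) -> tensor<?x?xf32>
      // Form the printer now emits, with inherent attributes in <{...}>.
      %1 = "mhlo.dynamic_broadcast_in_dim"(%arg, %shape) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor<f32>, tensor<2xindex>) -> tensor<?x?xf32>
      func.return %1 : tensor<?x?xf32>
    }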
16 changes: 8 additions & 8 deletions xla/mlir_hlo/tests/Dialect/mhlo/broadcast_propagation.mlir
@@ -52,8 +52,8 @@ func.func @single_bcast_ensure_order(%arg0 : tensor<16x?xf32>, %arg1 : tensor<16
func.func @double_bcasts(%arg0 : tensor<16x?xf32>, %arg1 : tensor<16x?xf32>,
%shape0 : tensor<3xindex>, %shape1 : tensor<3xindex>)
-> (tensor<?x16x?xf32>, tensor<?x16x?xf32>) {
- // CHECK-DAG: %[[BCASTED_ARG00:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[SHAPE0]]) [[BCAST_DIMS0:{broadcast_dimensions = dense<\[1, 2\]> : tensor<2xi64>}]]
- // CHECK-DAG: %[[BCASTED_ARG01:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[SHAPE1]]) [[BCAST_DIMS1:{broadcast_dimensions = dense<\[0, 2\]> : tensor<2xi64>}]]
+ // CHECK-DAG: %[[BCASTED_ARG00:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[SHAPE0]]) [[BCAST_DIMS0:<{broadcast_dimensions = dense<\[1, 2\]> : tensor<2xi64>}>]]
+ // CHECK-DAG: %[[BCASTED_ARG01:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[SHAPE1]]) [[BCAST_DIMS1:<{broadcast_dimensions = dense<\[0, 2\]> : tensor<2xi64>}>]]
// CHECK-DAG: %[[BCASTED_ARG10:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[SHAPE0]]) [[BCAST_DIMS0]]
// CHECK-DAG: %[[BCASTED_ARG11:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[SHAPE1]]) [[BCAST_DIMS1]]
// CHECK-DAG: %[[ADD0:.*]] = mhlo.add %[[BCASTED_ARG00]], %[[BCASTED_ARG10]] : [[BCAST_TY:tensor<\?x16x\?xf32>]]
@@ -85,8 +85,8 @@ func.func @double_bcasts(%arg0 : tensor<16x?xf32>, %arg1 : tensor<16x?xf32>,
func.func @late_output_dimensions(%arg0 : tensor<?x32xf32>, %arg1 : tensor<?x32xf32>,
%arg2 : tensor<?x?x?xf32>) -> tensor<?x?x32xf32> {
// CHECK-DAG: %[[SHAPE:.*]] = shape.shape_of %[[ARG2]]
- // CHECK-DAG: %[[BCASTED_ARG0:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[SHAPE]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
- // CHECK-DAG: %[[BCASTED_ARG1:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[SHAPE]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
+ // CHECK-DAG: %[[BCASTED_ARG0:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[SHAPE]]) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}>
+ // CHECK-DAG: %[[BCASTED_ARG1:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[SHAPE]]) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}>
// CHECK-DAG: %[[SUB:.*]] = mhlo.subtract %[[BCASTED_ARG0]], %[[BCASTED_ARG1]] : [[BCAST_TY:tensor<\?x\?x32xf32>]]
// CHECK-DAG: %[[ADD:.*]] = mhlo.add %[[SUB]], %[[SUB]] : [[BCAST_TY]]
// CHECK: return %[[ADD]] : [[BCAST_TY]]
@@ -118,7 +118,7 @@ func.func @very_late_output_dimensions(%arg0 : tensor<?x32xf32>,
%acc2 = mhlo.subtract %acc1, %arg1 : tensor<?x32xf32>
%acc3 = mhlo.divide %acc2, %arg1 : tensor<?x32xf32>
%1 = shape.shape_of %arg2 : tensor<?x?x?xf32> -> tensor<3xindex>
- %3 = "mhlo.dynamic_broadcast_in_dim"(%acc3, %1) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x32xf32>, tensor<3xindex>) -> tensor<?x?x32xf32>
+ %3 = "mhlo.dynamic_broadcast_in_dim"(%acc3, %1) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> : (tensor<?x32xf32>, tensor<3xindex>) -> tensor<?x?x32xf32>
func.return %3 : tensor<?x?x32xf32>
}

@@ -176,7 +176,7 @@ func.func @propagate_within_block_2(%arg : tensor<?x?x?xf32>,
// CHECK-SAME: %[[ARG:.*]]: tensor<1xindex>
func.func @propagate_across_bcasts_cst_src(%s : tensor<1xindex>) -> tensor<?xi1> {
// CHECK-DAG: %[[C1:.*]] = mhlo.constant dense<true> : tensor<i1>
- // CHECK-DAG: %[[RES:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[C1]], %[[ARG]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i1>, tensor<1xindex>) -> tensor<?xi1>
+ // CHECK-DAG: %[[RES:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[C1]], %[[ARG]]) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor<i1>, tensor<1xindex>) -> tensor<?xi1>
// CHECK: return %[[RES]]
%0 = mhlo.constant dense<true> : tensor<i1>
%1 = "mhlo.dynamic_broadcast_in_dim"(%0, %s)
@@ -193,7 +193,7 @@ func.func @propagate_across_bcasts_cst_src(%s : tensor<1xindex>) -> tensor<?xi1>
// CHECK-LABEL: @compose_bcast_dims
// CHECK-SAME: %[[ARG:.*]]: tensor<?x?xi1>, %[[S0:.*]]: tensor<3xindex>, %[[S1:.*]]: tensor<4xindex>
func.func @compose_bcast_dims(%arg : tensor<?x?xi1>, %s0 : tensor<3xindex>, %s1 : tensor<4xindex>) -> tensor<1x?x1x?xi1> {
- // CHECK-DAG: %[[RES:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG]], %[[S1]]) {broadcast_dimensions = dense<[1, 3]> : tensor<2xi64>} : (tensor<?x?xi1>, tensor<4xindex>) -> tensor<1x?x1x?xi1>
+ // CHECK-DAG: %[[RES:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG]], %[[S1]]) <{broadcast_dimensions = dense<[1, 3]> : tensor<2xi64>}> : (tensor<?x?xi1>, tensor<4xindex>) -> tensor<1x?x1x?xi1>
// CHECK: return %[[RES]]
%1 = "mhlo.dynamic_broadcast_in_dim"(%arg, %s0)
{broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
@@ -209,7 +209,7 @@ func.func @compose_bcast_dims(%arg : tensor<?x?xi1>, %s0 : tensor<3xindex>, %s1
// CHECK-LABEL: @propagate_across_bcasts
// CHECK-SAME: %[[ARG:.*]]: tensor<?x?x?xf32>, %[[S:.*]]: tensor<3xindex>
func.func @propagate_across_bcasts(%arg : tensor<?x?x?xf32>, %shape : tensor<3xindex>) -> tensor<?x?x?xf32> {
- // CHECK-DAG: %[[RES:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG]], %[[S]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<?x?x?xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
+ // CHECK-DAG: %[[RES:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG]], %[[S]]) <{broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>}> : (tensor<?x?x?xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
// CHECK: return %[[RES]]
%0 = "mhlo.dynamic_broadcast_in_dim"(%arg, %shape)
{broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>}