From b2b558f788efdb171cba163189eaf9979f8c9dbb Mon Sep 17 00:00:00 2001
From: Doyeon Kim
Date: Thu, 4 Apr 2024 07:44:16 -0700
Subject: [PATCH] Insert symmetric quantization parameters for weights

PiperOrigin-RevId: 621856550
---
 .../mhlo_quant_legalize_to_int.cc             |  45 ++++--
 .../mhlo/mhlo-quant-legalize-to-int.mlir      | 128 ++++++++++++++++++
 2 files changed, 160 insertions(+), 13 deletions(-)

diff --git a/xla/mlir_hlo/mhlo/transforms/mhlo_quant_legalize_to_int/mhlo_quant_legalize_to_int.cc b/xla/mlir_hlo/mhlo/transforms/mhlo_quant_legalize_to_int/mhlo_quant_legalize_to_int.cc
index d0da0f12e098f..7cdadd2af8dbe 100644
--- a/xla/mlir_hlo/mhlo/transforms/mhlo_quant_legalize_to_int/mhlo_quant_legalize_to_int.cc
+++ b/xla/mlir_hlo/mhlo/transforms/mhlo_quant_legalize_to_int/mhlo_quant_legalize_to_int.cc
@@ -592,6 +592,18 @@ struct DotLikeDimensionNumbers {
   SmallVector<int64_t> rhsContractingDims;
 };
 
+// Checks if zero points of the given quantized type are zero.
+bool isZeroPointZero(QuantType type) {
+  if (isPerTensorType(type)) {
+    return getPerTensorType(type).getZeroPoint() == 0;
+  }
+  if (isPerChannelType(type)) {
+    ArrayRef<int64_t> zeroPoints = getPerChannelType(type).getZeroPoints();
+    return llvm::all_of(zeroPoints, [](int64_t zp) { return zp == 0; });
+  }
+  return false;
+}
+
 // A shared matchAndRewrite implementation for dot-like hybrid quantized
 // operators. Hybrid ops are currently only interpreted as weight-only
 // quantization ops, this might change in the future.
@@ -611,8 +623,10 @@ LogicalResult matchAndRewriteDotLikeHybridOp(
       adaptor.getRhs());
   Operation::result_range resultRange = barrier.getResults();
   Value rhs = resultRange.front();
-  auto rhsElementType = getElementTypeOrSelf(op.getRhs().getType())
-                            .template cast<quant::UniformQuantizedType>();
+  auto rhsElementQuantType = getQuantType(op.getRhs().getType());
+  if (failed(rhsElementQuantType)) {
+    return failure();
+  }
   auto resFloat32TensorType =
       op.getResult().getType().template cast<TensorType>();
   auto rhsFloat32TensorType =
@@ -620,21 +634,25 @@ LogicalResult matchAndRewriteDotLikeHybridOp(
           rewriter.getF32Type());
 
   // Get scales and zero points for rhs.
-  Value rhsZeroPoint = rewriter.create<mhlo::ConstantOp>(
-      op->getLoc(), rewriter.getF32FloatAttr((rhsElementType.getZeroPoint())));
-  Value rhsScaleConstant = rewriter.create<mhlo::ConstantOp>(
-      op->getLoc(),
-      rewriter.getF32FloatAttr(static_cast<float>(rhsElementType.getScale())));
+  Value rhsScale, rhsZeroPoint;
+  DenseI64ArrayAttr broadcastDims;
+  getQuantizationParams(rewriter, op->getLoc(), *rhsElementQuantType, rhsScale,
+                        rhsZeroPoint,
+                        /*outputZeroPointInFp=*/true, broadcastDims);
 
   // Dequantize rhs_float32_tensor.
   Value rhsFloat32Tensor =
       rewriter.create<mhlo::ConvertOp>(op->getLoc(), rhsFloat32TensorType, rhs);
-  rhsFloat32Tensor = rewriter.create<chlo::BroadcastSubOp>(
-      op->getLoc(), rhsFloat32TensorType, rhsFloat32Tensor, rhsZeroPoint,
-      nullptr);
+
+  // Subtract zero points only when it is not zero.
+  if (!isZeroPointZero(*rhsElementQuantType)) {
+    rhsFloat32Tensor = rewriter.create<chlo::BroadcastSubOp>(
+        op->getLoc(), rhsFloat32TensorType, rhsFloat32Tensor, rhsZeroPoint,
+        broadcastDims);
+  }
   rhsFloat32Tensor = rewriter.create<chlo::BroadcastMulOp>(
-      op->getLoc(), rhsFloat32TensorType, rhsFloat32Tensor, rhsScaleConstant,
-      nullptr);
+      op->getLoc(), rhsFloat32TensorType, rhsFloat32Tensor, rhsScale,
+      broadcastDims);
 
   // Execute conversion target op.
   SmallVector<Value> operands{lhsFloat32Tensor, rhsFloat32Tensor};
@@ -1045,7 +1063,8 @@ FailureOr<bool> isDotLikeOpHybrid(DotLikeOp op) {
     // both per-tensor quantized.
     return false;
   }
-  if (!isLhsQuant && !isLhsQuantPerChannel && isRhsQuant && !isResQuant &&
+  if (!isLhsQuant && !isLhsQuantPerChannel &&
+      (isRhsQuant || isRhsQuantPerChannel) && !isResQuant &&
      !isResQuantPerChannel) {
     return true;
   }
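For orientation, the rewritten hybrid path above amounts to: convert the quantized weight to f32, subtract the (per-tensor or per-channel) zero point only when isZeroPointZero reports a nonzero zero point, then multiply by the scale along the quantized dimension. Below is a minimal standalone sketch of that arithmetic, assuming a [rows, channels] weight layout; the function name and layout are illustrative assumptions, not part of the pass.

#include <cstdint>
#include <vector>

// Illustrative only: mirrors convert -> (optional chlo.broadcast_subtract)
// -> chlo.broadcast_multiply emitted by matchAndRewriteDotLikeHybridOp for a
// weight stored as [rows, channels] with per-channel parameters.
std::vector<float> dequantizeWeight(const std::vector<int8_t>& weight,
                                    const std::vector<float>& scales,
                                    const std::vector<int64_t>& zeroPoints,
                                    int64_t rows, int64_t channels) {
  // Same condition as isZeroPointZero: symmetric weights skip the subtract.
  bool allZero = true;
  for (int64_t zp : zeroPoints) allZero = allZero && (zp == 0);

  std::vector<float> result(weight.size());
  for (int64_t r = 0; r < rows; ++r) {
    for (int64_t c = 0; c < channels; ++c) {
      float value = static_cast<float>(weight[r * channels + c]);
      if (!allZero) value -= static_cast<float>(zeroPoints[c]);
      result[r * channels + c] = value * scales[c];
    }
  }
  return result;
}

For the symmetric per-channel tests below (all zero points zero), only the convert and multiply survive, which is what the CHECK-NOT: chlo.broadcast_subtract lines assert.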
diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/mhlo-quant-legalize-to-int.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/mhlo-quant-legalize-to-int.mlir
index 01ac1b8d13bc4..02e0693a772d8 100644
--- a/xla/mlir_hlo/tests/Dialect/mhlo/mhlo-quant-legalize-to-int.mlir
+++ b/xla/mlir_hlo/tests/Dialect/mhlo/mhlo-quant-legalize-to-int.mlir
@@ -1779,6 +1779,61 @@ func.func @dot_hybrid(
 
 // -----
 
+// CHECK-LABEL: func @dot_general_hybrid_per_channel
+// CHECK-SAME: %[[ARG0:.*]]: tensor<3x2xf32>
+// CHECK-SAME: %[[ARG1:.*]]: tensor<2x2xi8>
+func.func @dot_general_hybrid_per_channel(
+    %arg0: tensor<3x2xf32>,
+    %arg1: tensor<2x2x!quant.uniform<i8<-127:127>:f32:1, {3.000000e+00, 4.000000e+00}>>
+  ) -> tensor<3x2xf32> {
+  // CHECK-DAG: %[[BARRIER:.*]] = mhlo.optimization_barrier %[[ARG1]] : tensor<2x2xi8>
+  // CHECK-DAG: %[[SCALES:.*]] = mhlo.constant dense<[3.000000e+00, 4.000000e+00]> : tensor<2xf32>
+  // CHECK-DAG: %[[CONVERT:.*]] = mhlo.convert %[[BARRIER]] : (tensor<2x2xi8>) -> tensor<2x2xf32>
+  // CHECK-NOT: chlo.broadcast_subtract
+  // CHECK: %[[MUL:.*]] = chlo.broadcast_multiply %[[CONVERT]], %[[SCALES]] {broadcast_dimensions = array<i64: 1>} : (tensor<2x2xf32>, tensor<2xf32>) -> tensor<2x2xf32>
+  // CHECK: %[[DOT:.*]] = "mhlo.dot_general"(%[[ARG0]], %[[MUL]])
+  // CHECK: {dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [0]>} : (tensor<3x2xf32>, tensor<2x2xf32>) -> tensor<3x2xf32>
+  // CHECK: return %[[DOT]]
+
+  %0 = "mhlo.dot_general"(%arg0, %arg1) {
+    dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [0]>} : (
+      tensor<3x2xf32>,
+      tensor<2x2x!quant.uniform<i8<-127:127>:f32:1, {3.000000e+00, 4.000000e+00}>>
+    ) -> tensor<3x2xf32>
+  return %0 : tensor<3x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @dot_general_hybrid_per_channel_asymmetric
+// CHECK-SAME: %[[ARG0:.*]]: tensor<3x2xf32>
+// CHECK-SAME: %[[ARG1:.*]]: tensor<2x2xi8>
+func.func @dot_general_hybrid_per_channel_asymmetric(
+    %arg0: tensor<3x2xf32>,
+    %arg1: tensor<2x2x!quant.uniform<i8<-127:127>:f32:1, {3.000000e+00:10, 4.000000e+00:20}>>
+  ) -> tensor<3x2xf32> {
+  // CHECK-DAG: %[[BARRIER:.*]] = mhlo.optimization_barrier %[[ARG1]] : tensor<2x2xi8>
+  // CHECK-DAG: %[[SCALES:.*]] = mhlo.constant dense<[3.000000e+00, 4.000000e+00]> : tensor<2xf32>
+  // CHECK-DAG: %[[ZPS:.*]] = mhlo.constant dense<[1.000000e+01, 2.000000e+01]> : tensor<2xf32>
+  // CHECK-DAG: %[[CONVERT:.*]] = mhlo.convert %[[BARRIER]] : (tensor<2x2xi8>) -> tensor<2x2xf32>
+  // CHECK: %[[SUB:.*]] = chlo.broadcast_subtract %[[CONVERT]], %[[ZPS]] {broadcast_dimensions = array<i64: 1>} : (tensor<2x2xf32>, tensor<2xf32>) -> tensor<2x2xf32>
+  // CHECK: %[[MUL:.*]] = chlo.broadcast_multiply %[[SUB]], %[[SCALES]] {broadcast_dimensions = array<i64: 1>} : (tensor<2x2xf32>, tensor<2xf32>) -> tensor<2x2xf32>
+  // CHECK: %[[DOT:.*]] = "mhlo.dot_general"(%[[ARG0]], %[[MUL]])
+  // CHECK: {dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [0]>} : (tensor<3x2xf32>, tensor<2x2xf32>) -> tensor<3x2xf32>
+  // CHECK: return %[[DOT]]
+
+  %0 = "mhlo.dot_general"(%arg0, %arg1) {
+    dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [0]>} : (
+      tensor<3x2xf32>,
+      tensor<2x2x!quant.uniform<i8<-127:127>:f32:1, {3.000000e+00:10, 4.000000e+00:20}>>
+    ) -> tensor<3x2xf32>
+  return %0 : tensor<3x2xf32>
+}
+
+// -----
+
 func.func @dot_hybrid_result_type_not_float(
     %arg0: tensor,
     %arg1: tensor>) {
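The two per-channel dot tests above only reach the hybrid path because isDotLikeOpHybrid (the last hunk of the C++ change) now also accepts a per-channel quantized rhs. A hedged standalone restatement of that predicate, with hypothetical struct and function names:

// Hypothetical restatement of the relaxed hybrid check: the op is hybrid
// (weight-only) when lhs and result are float and the rhs is quantized
// either per tensor or per channel.
struct OperandKinds {
  bool lhsQuant, lhsQuantPerChannel;
  bool rhsQuant, rhsQuantPerChannel;
  bool resQuant, resQuantPerChannel;
};

bool isHybridDotLike(const OperandKinds& k) {
  return !k.lhsQuant && !k.lhsQuantPerChannel &&
         (k.rhsQuant || k.rhsQuantPerChannel) && !k.resQuant &&
         !k.resQuantPerChannel;
}

The convolution tests in the next hunk exercise the same predicate through mhlo.convolution, with the quantized dimension being the output-feature axis of the filter.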
@@ -1839,6 +1894,79 @@ func.func @conv2d_static_hybrid(
 
 // -----
 
+// CHECK-LABEL: func @conv2d_hybrid_per_channel
+// CHECK-SAME: %[[ARG0:.*]]: tensor<128x28x28x1xf32>
+// CHECK-SAME: %[[ARG1:.*]]: tensor<3x3x1x2xi8>
+func.func @conv2d_hybrid_per_channel(
+    %arg0: tensor<128x28x28x1xf32>,
+    %arg1: tensor<3x3x1x2x!quant.uniform<i8:f32:3, {2.000000e+00, 1.000000e+00}>>
+  ) -> tensor<128x26x26x2xf32> {
+  // CHECK-DAG: %[[BARRIER:.*]] = mhlo.optimization_barrier %[[ARG1]] : tensor<3x3x1x2xi8>
+  // CHECK-DAG: %[[SCALES:.*]] = mhlo.constant dense<[2.000000e+00, 1.000000e+00]> : tensor<2xf32>
+  // CHECK-DAG: %[[CONVERT:.*]] = mhlo.convert %[[BARRIER]] : (tensor<3x3x1x2xi8>) -> tensor<3x3x1x2xf32>
+  // CHECK-NOT: chlo.broadcast_subtract
+  // CHECK: %[[MUL:.*]] = chlo.broadcast_multiply %[[CONVERT]], %[[SCALES]] {broadcast_dimensions = array<i64: 3>} : (tensor<3x3x1x2xf32>, tensor<2xf32>) -> tensor<3x3x1x2xf32>
+  // CHECK: %[[CONV:.*]] = mhlo.convolution(%[[ARG0]], %[[MUL]])
+  // CHECK-SAME{LITERAL}: dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+  // CHECK-SAME: {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<128x28x28x1xf32>, tensor<3x3x1x2xf32>) -> tensor<128x26x26x2xf32>
+  // CHECK: return %[[CONV]]
+
+  %0 = mhlo.convolution(%arg0, %arg1)
+    dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+    window = {
+      stride = [1, 1], pad = [[0, 0], [0, 0]],
+      lhs_dilate = [1, 1],
+      rhs_dilate = [1, 1]
+    }
+    {
+      batch_group_count = 1 : i64,
+      feature_group_count = 1 : i64
+    } : (
+      tensor<128x28x28x1xf32>,
+      tensor<3x3x1x2x!quant.uniform<i8:f32:3, {2.000000e+00, 1.000000e+00}>>)
+    -> tensor<128x26x26x2xf32>
+  return %0 : tensor<128x26x26x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @conv2d_hybrid_per_channel_asymmetric
+// CHECK-SAME: %[[ARG0:.*]]: tensor<128x28x28x1xf32>
+// CHECK-SAME: %[[ARG1:.*]]: tensor<3x3x1x2xi8>
+func.func @conv2d_hybrid_per_channel_asymmetric(
+    %arg0: tensor<128x28x28x1xf32>,
+    %arg1: tensor<3x3x1x2x!quant.uniform<i8:f32:3, {2.000000e+00:10, 1.000000e+00:20}>>
+  ) -> tensor<128x26x26x2xf32> {
+  // CHECK-DAG: %[[BARRIER:.*]] = mhlo.optimization_barrier %[[ARG1]] : tensor<3x3x1x2xi8>
+  // CHECK-DAG: %[[SCALES:.*]] = mhlo.constant dense<[2.000000e+00, 1.000000e+00]> : tensor<2xf32>
+  // CHECK-DAG: %[[ZPS:.*]] = mhlo.constant dense<[1.000000e+01, 2.000000e+01]> : tensor<2xf32>
+  // CHECK-DAG: %[[CONVERT:.*]] = mhlo.convert %[[BARRIER]] : (tensor<3x3x1x2xi8>) -> tensor<3x3x1x2xf32>
+  // CHECK: %[[SUB:.*]] = chlo.broadcast_subtract %[[CONVERT]], %[[ZPS]] {broadcast_dimensions = array<i64: 3>} : (tensor<3x3x1x2xf32>, tensor<2xf32>) -> tensor<3x3x1x2xf32>
+  // CHECK: %[[MUL:.*]] = chlo.broadcast_multiply %[[SUB]], %[[SCALES]] {broadcast_dimensions = array<i64: 3>} : (tensor<3x3x1x2xf32>, tensor<2xf32>) -> tensor<3x3x1x2xf32>
+  // CHECK: %[[CONV:.*]] = mhlo.convolution(%[[ARG0]], %[[MUL]])
+  // CHECK-SAME{LITERAL}: dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+  // CHECK-SAME: {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<128x28x28x1xf32>, tensor<3x3x1x2xf32>) -> tensor<128x26x26x2xf32>
+  // CHECK: return %[[CONV]]
+
+  %0 = mhlo.convolution(%arg0, %arg1)
+    dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+    window = {
+      stride = [1, 1], pad = [[0, 0], [0, 0]],
+      lhs_dilate = [1, 1],
+      rhs_dilate = [1, 1]
+    }
+    {
+      batch_group_count = 1 : i64,
+      feature_group_count = 1 : i64
+    } : (
+      tensor<128x28x28x1xf32>,
+      tensor<3x3x1x2x!quant.uniform<i8:f32:3, {2.000000e+00:10, 1.000000e+00:20}>>)
+    -> tensor<128x26x26x2xf32>
+  return %0 : tensor<128x26x26x2xf32>
+}
+
+// -----
+
 func.func @conv2d_hybrid_result_not_float(
     %arg0: tensor<128x28x28x1xf32>,
     %arg1: tensor<3x3x1x128x!quant.uniform>) {