Add support for negative constant (#2288)
### Ticket
#2287

### Problem description
LLVM disabled `implicitTrunc` as the default option in the `APInt` constructor, which triggered an assertion for negative constants. It also caused failures for 64-bit constants, since those are implicitly truncated to 32-bit values during conversion.
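
A minimal sketch of the failure mode and of the fix, assuming LLVM's `llvm::APInt` API (the `sketch` function and the commented-out `bad` construction are illustrative, not code from this repository):

```cpp
#include "llvm/ADT/APInt.h"
#include <cstdint>

void sketch() {
  int8_t negative = -3;
  // Passed as uint64_t, -3 sign-extends to 0xff...fd. With implicitTrunc
  // now defaulting to false, that value no longer fits in 8 bits, so the
  // APInt constructor asserts:
  //   llvm::APInt bad(/*numBits=*/8, negative); // assertion failure
  // Declaring the value signed keeps the construction valid:
  llvm::APInt ok(/*numBits=*/8, negative, /*isSigned=*/true);

  // 64-bit values are narrowed to 32 bits with signed saturation rather
  // than by implicit truncation:
  llvm::APInt wide(/*numBits=*/64, INT64_MIN, /*isSigned=*/true);
  llvm::APInt narrowed = wide.truncSSat(32); // saturates to INT32_MIN
}
```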

### What's changed
- Use an `APInt` object to extract constant values instead of basic data types, which covers both positive and negative values.
- Use truncation with signed saturation instead of implicit truncation.
- The TTIR->TTNN conversion for `ttir.constant` is also updated to treat signless integers as signed values, since they can also store negative values; a sketch of this path follows the list.
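
A hedged sketch of that fill-value path, assuming LLVM's `llvm::APFloat`/`llvm::APInt` APIs (`fillFromSplat` is an illustrative helper, not a function in this repository):

```cpp
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/FloatingPointMode.h"

llvm::APFloat fillFromSplat(const llvm::APInt &splat, bool isSigned) {
  llvm::APFloat fill(llvm::APFloat::IEEEsingle());
  // For an i32 splat holding -1: with isSigned == false the bit pattern is
  // read as 4294967295; treating signless integers as signed preserves -1.0f.
  fill.convertFromAPInt(splat, isSigned, llvm::RoundingMode::TowardZero);
  return fill;
}
```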

### Checklist
- [X] New/Existing tests provide coverage for changes
mmanzoorTT authored Feb 27, 2025
1 parent 2175af5 commit 340e05f
Showing 4 changed files with 102 additions and 75 deletions.
130 changes: 57 additions & 73 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
@@ -620,7 +620,13 @@ class StableHLOToTTIRConstantOpConversionPattern
auto outputType = mlir::cast<RankedTensorType>(
getTypeConverter()->convertType(srcOp.getResult().getType()));

mlir::ElementsAttr valueAttr = getValueAttr(srcOp.getValue());
mlir::ElementsAttr valueAttr;
LogicalResult valueAttrLegalityResult =
getValueAttr(srcOp, rewriter, valueAttr);

if (!valueAttrLegalityResult.succeeded()) {
return valueAttrLegalityResult;
}

rewriter.replaceOpWithNewOp<mlir::tt::ttir::ConstantOp>(srcOp, outputType,
valueAttr);
@@ -644,104 +650,82 @@ class StableHLOToTTIRConstantOpConversionPattern
// 2. Boolean tensor: TTNN does not support boolean data. So they are
// converted to bfloat16 tensors.
// 3. Integer tensor: TTNN does not support 64 bit integer. So they are
// converted to 32 bit tensor.
// converted to 32 bit tensor (with signed saturation).
// 4. Float tensor: TTNN does not support 64 bit float. So they are converted
// to 32 bit tensor.
mlir::ElementsAttr getValueAttr(mlir::ElementsAttr valueAttr) const {
LogicalResult getValueAttr(mlir::stablehlo::ConstantOp &srcOp,
ConversionPatternRewriter &rewriter,
mlir::ElementsAttr &newValueAttr) const {
mlir::ElementsAttr valueAttr = srcOp.getValue();
Type elementType = valueAttr.getElementType();
size_t bitWidth = elementType.getIntOrFloatBitWidth();
bool isTensor = !valueAttr.getShapedType().getShape().empty();
bool isIntTensor = isTensor && isa<IntegerType>(elementType) &&
bitWidth != 1 && bitWidth != 64;
bool isFloatTensor = isTensor && isa<FloatType>(elementType) &&
bitWidth != 1 && bitWidth != 64;

if (isTensor && (isIntTensor || isFloatTensor)) {
return valueAttr;
bool isScalar = valueAttr.getShapedType().getShape().empty();
bool isElementTypeSupported =
(isa<IntegerType>(elementType) || isa<FloatType>(elementType)) &&
bitWidth != 1 && bitWidth != 64;
if (!isScalar && isElementTypeSupported) {
newValueAttr = valueAttr;
return success();
}

mlir::ShapedType valueType = mlir::cast<mlir::ShapedType>(
getTypeConverter()->convertType(valueAttr.getShapedType()));
if (isa<IntegerType>(elementType)) {
switch (bitWidth) {
case 1: {
return rebuildValueAttr<bool>(valueAttr, 1);
}
case 8: {
return elementType.isUnsignedInteger()
? rebuildValueAttr<uint8_t>(valueAttr, 8)
: rebuildValueAttr<int8_t>(valueAttr, 8);
}
case 16: {
return elementType.isUnsignedInteger()
? rebuildValueAttr<uint16_t>(valueAttr, 16)
: rebuildValueAttr<int16_t>(valueAttr, 16);
}
case 32: {
return elementType.isUnsignedInteger()
? rebuildValueAttr<uint32_t>(valueAttr, 32)
: rebuildValueAttr<int32_t>(valueAttr, 32);
}
case 64: {
return elementType.isUnsignedInteger()
? rebuildValueAttr<uint64_t>(valueAttr, 32)
: rebuildValueAttr<int64_t>(valueAttr, 32);
}
default: {
assert(false && "Unsupported integer type.");
}
}
newValueAttr = rebuildIntValueAttr(valueAttr, valueType, bitWidth);
return success();
}
if (isa<FloatType>(elementType)) {
// Convert 64 bit floating point numbers to 32 bit floating point numbers.
if (bitWidth == 64) {
std::vector<mlir::APFloat> floatValues;
for (mlir::APFloat value : valueAttr.getValues<mlir::APFloat>()) {
float fl = static_cast<float>(value.convertToDouble());
mlir::APFloat input = mlir::APFloat(fl);
floatValues.emplace_back(input);
}
return mlir::DenseElementsAttr::get(valueType, floatValues);
}
// In case of float values llvm has a bug where not all float types are
// supported for iterating in DenseElementsAttr, so we have to use a
// different constructor.
std::vector<mlir::APFloat> floatValues(
valueAttr.getValues<mlir::APFloat>().begin(),
valueAttr.getValues<mlir::APFloat>().end());
return mlir::DenseElementsAttr::get(valueType, floatValues);
newValueAttr = rebuildFloatValueAttr(valueAttr, valueType, bitWidth);
return success();
}
assert(false && "Unsupported data type.");
return rewriter.notifyMatchFailure(srcOp, "Unsupported data type.");
}

// Extract the values (using the given ElementType) and create new data
// structure. This is used to convert scalars (of type boolean, int8, int16,
// int32, int64, uint8, uint16, uint32, uint64) and tensors (of type boolean
// and int64).
template <typename ElementType>
mlir::ElementsAttr rebuildValueAttr(mlir::ElementsAttr valueAttr,
size_t bitWidth) const {
mlir::ShapedType valueType = mlir::cast<mlir::ShapedType>(
getTypeConverter()->convertType(valueAttr.getShapedType()));

// Extract the values and create new ElementsAttr data structure. This is used
// to convert scalars boolean to bfloat16 and 64 bit integer to 32 bit integer
// by truncating with signed saturation.
mlir::ElementsAttr rebuildIntValueAttr(mlir::ElementsAttr valueAttr,
mlir::ShapedType valueType,
size_t bitWidth) const {
// Create data structure for boolean type with bfloat16.
if (bitWidth == 1) {
std::vector<mlir::APFloat> booleanValue = {};
for (ElementType value : valueAttr.getValues<ElementType>()) {
mlir::APFloat input(mlir::APFloat::BFloat(), value);
booleanValue.emplace_back(input);
for (bool value : valueAttr.getValues<bool>()) {
booleanValue.emplace_back(mlir::APFloat::BFloat(), value);
}
return mlir::DenseElementsAttr::get(valueType, booleanValue);
}

// Create data structure for other types.
std::vector<mlir::APInt> IntegerValue = {};
for (ElementType value : valueAttr.getValues<ElementType>()) {
mlir::APInt input(bitWidth, value);
IntegerValue.emplace_back(input);
for (mlir::APInt value : valueAttr.getValues<mlir::APInt>()) {
// Truncate to 32 bits with signed saturation in case of 64 bit integers.
IntegerValue.emplace_back(bitWidth == 64 ? value.truncSSat(32) : value);
}
return mlir::DenseElementsAttr::get(valueType, IntegerValue);
}

mlir::ElementsAttr rebuildFloatValueAttr(mlir::ElementsAttr valueAttr,
mlir::ShapedType valueType,
size_t bitWidth) const {
// Convert 64 bit floating point numbers to 32 bit floating point numbers.
if (bitWidth == 64) {
std::vector<mlir::APFloat> floatValues;
for (mlir::APFloat value : valueAttr.getValues<mlir::APFloat>()) {
float fl = static_cast<float>(value.convertToDouble());
mlir::APFloat input = mlir::APFloat(fl);
floatValues.emplace_back(input);
}
return mlir::DenseElementsAttr::get(valueType, floatValues);
}
// In case of float values llvm has a bug where not all float types are
// supported for iterating in DenseElementsAttr, so we have to use a
// different constructor.
std::vector<mlir::APFloat> floatValues(
valueAttr.getValues<mlir::APFloat>().begin(),
valueAttr.getValues<mlir::APFloat>().end());
return mlir::DenseElementsAttr::get(valueType, floatValues);
}
};
} // namespace

6 changes: 4 additions & 2 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -883,9 +883,11 @@ class ConstantOpConversionPattern

mlir::APFloat fillValue(mlir::APFloat::IEEEsingle());
if (valueAttr.getElementType().isInteger()) {
// Both signed and signless integer can have negative values.
bool isSigned = valueAttr.getElementType().isSignedInteger() ||
valueAttr.getElementType().isSignlessInteger();
fillValue.convertFromAPInt(valueAttr.getSplatValue<llvm::APInt>(),
valueAttr.getElementType().isSignedInteger(),
llvm::RoundingMode::TowardZero);
isSigned, llvm::RoundingMode::TowardZero);
} else {
fillValue = valueAttr.getSplatValue<mlir::APFloat>();
}
33 changes: 33 additions & 0 deletions test/ttmlir/Conversion/StableHLOToTTIR/constant_op.mlir
@@ -306,4 +306,37 @@ module @jit_constant attributes {} {
// CHECK: return %{{[0-9]+}} : tensor<2x2xui32>
return %0 : tensor<2x2xui64>
}

func.func public @test_int8_negative_scalar() -> tensor<i8> {
// CHECK-LABEL: func.func public @test_int8_negative_scalar
// CHECK: %[[CONSTANT:[0-9]+]] = "ttir.constant"() <{value = dense<-3> : tensor<1xi8>}> : () -> tensor<1xi8>
%0 = stablehlo.constant dense<-3> : tensor<i8>
// CHECK: return %[[CONSTANT]] : tensor<1xi8>
return %0 : tensor<i8>
}

func.func public @test_int16_negative_splat() -> tensor<64xi16> {
// CHECK-LABEL: func.func public @test_int16_negative_splat
// CHECK: %[[CONSTANT:[0-9]+]] = "ttir.constant"() <{value = dense<-3> : tensor<64xi16>}> : () -> tensor<64xi16>
%0 = stablehlo.constant dense<-3> : tensor<64xi16>
// CHECK: return %[[CONSTANT]] : tensor<64xi16>
return %0 : tensor<64xi16>
}

func.func public @test_int32_negative_multiple() -> tensor<2x2xi32> {
// The ugly regex after `dense` is necessary because double square opening
// brackets indicate substitution block in FileCheck syntax.
// CHECK: %[[CONSTANT:[0-9]+]] = "ttir.constant"() <{value = dense<{{([[])}}[0, -1], [-2, 3]]> : tensor<2x2xi32>}> : () -> tensor<2x2xi32>
%0 = stablehlo.constant dense<[[0, -1], [-2, 3]]> : tensor<2x2xi32>
// CHECK: return %[[CONSTANT]] : tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}

func.func public @test_int64_negative_min_scalar() -> tensor<i64> {
// CHECK-LABEL: func.func public @test_int64_negative_min_scalar
// CHECK: %[[CONSTANT:[0-9]+]] = "ttir.constant"() <{value = dense<-2147483648> : tensor<1xi32>}> : () -> tensor<1xi32>
%0 = stablehlo.constant dense<-9223372036854775808> : tensor<i64>
// CHECK: return %[[CONSTANT]] : tensor<1xi32>
return %0 : tensor<i64>
}
}
8 changes: 8 additions & 0 deletions test/ttmlir/Dialect/TTNN/simple_constant.mlir
@@ -127,4 +127,12 @@ module attributes {} {
%0 = "ttir.constant"() <{value = dense<[[1], [2], [3]]> : tensor<3x1xui8>}> : () -> tensor<3x1xui8>
return %0 : tensor<3x1xui8>
}

func.func @test_constant_i32_negative() -> tensor<1x1x3xi32> {
// CHECK: "ttnn.constant"
// CHECK-SAME: value = dense
// CHECK-SAME: -1, 2, 3
%0 = "ttir.constant"() <{value = dense<[[[-1, 2, 3]]]> : tensor<1x1x3xi32>}> : () -> tensor<1x1x3xi32>
return %0 : tensor<1x1x3xi32>
}
}
