[mlir][sparse] Introduce batch level format. (llvm#83082)
commit-id:7cf83239
PeimingLiu authored and vzakhari committed Mar 14, 2024
1 parent aa737b2 commit 4a51fa8
Showing 11 changed files with 62 additions and 18 deletions.
9 changes: 5 additions & 4 deletions mlir/include/mlir-c/Dialect/SparseTensor.h
@@ -29,10 +29,11 @@ typedef uint64_t MlirSparseTensorLevelType;

enum MlirSparseTensorLevelFormat {
MLIR_SPARSE_TENSOR_LEVEL_DENSE = 0x000000010000,
- MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = 0x000000020000,
- MLIR_SPARSE_TENSOR_LEVEL_SINGLETON = 0x000000040000,
- MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED = 0x000000080000,
- MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M = 0x000000100000,
+ MLIR_SPARSE_TENSOR_LEVEL_BATCH = 0x000000020000,
+ MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = 0x000000040000,
+ MLIR_SPARSE_TENSOR_LEVEL_SINGLETON = 0x000000080000,
+ MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED = 0x000000100000,
+ MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M = 0x000000200000,
};

enum MlirSparseTensorLevelPropertyNondefault {
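The renumbering above slots batch in right after dense and shifts every later format up by one bit. In particular, 0x20000 (131072), which used to mean compressed, now means batch, and compressed moves to 0x40000 (262144); that single shift accounts for all of the 131072 -> 262144 constant updates in the tests further down. A minimal standalone C++ sketch (not part of the commit) of the new layout:

#include <cassert>
#include <cstdint>

// Format bits copied from the C-API enum above; each format is one
// distinct bit above the 16-bit property field.
enum Fmt : uint64_t {
  kDense = 0x000000010000,           // 65536, unchanged
  kBatch = 0x000000020000,           // 131072, new; compressed's old slot
  kCompressed = 0x000000040000,      // 262144, was 0x20000
  kSingleton = 0x000000080000,       // was 0x40000
  kLooseCompressed = 0x000000100000, // was 0x80000
  kNOutOfM = 0x000000200000,         // was 0x100000
};

int main() {
  // Each value is a power of two, so formats never share bits.
  assert((kBatch & (kBatch - 1)) == 0);
  // A plain compressed level type now prints as 262144, not 131072.
  assert(kCompressed == 262144 && kBatch == 131072);
  return 0;
}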
29 changes: 23 additions & 6 deletions mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -154,12 +154,26 @@ enum class Action : uint32_t {
enum class LevelFormat : uint64_t {
Undef = 0x00000000,
Dense = 0x00010000,
- Compressed = 0x00020000,
- Singleton = 0x00040000,
- LooseCompressed = 0x00080000,
- NOutOfM = 0x00100000,
+ Batch = 0x00020000,
+ Compressed = 0x00040000,
+ Singleton = 0x00080000,
+ LooseCompressed = 0x00100000,
+ NOutOfM = 0x00200000,
};

constexpr bool encPowOfTwo(LevelFormat fmt) {
auto enc = static_cast<std::underlying_type_t<LevelFormat>>(fmt);
return (enc & (enc - 1)) == 0;
}

// All LevelFormats must have only one bit set (power of two).
static_assert(encPowOfTwo(LevelFormat::Dense) &&
encPowOfTwo(LevelFormat::Batch) &&
encPowOfTwo(LevelFormat::Compressed) &&
encPowOfTwo(LevelFormat::Singleton) &&
encPowOfTwo(LevelFormat::LooseCompressed) &&
encPowOfTwo(LevelFormat::NOutOfM));

template <LevelFormat... targets>
constexpr bool isAnyOfFmt(LevelFormat fmt) {
return (... || (targets == fmt));
@@ -172,6 +186,8 @@ constexpr const char *toFormatString(LevelFormat lvlFmt) {
return "undef";
case LevelFormat::Dense:
return "dense";
case LevelFormat::Batch:
return "batch";
case LevelFormat::Compressed:
return "compressed";
case LevelFormat::Singleton:
@@ -225,10 +241,10 @@ struct LevelType {
static constexpr bool isValidLvlBits(uint64_t lvlBits) {
auto fmt = static_cast<LevelFormat>(lvlBits & 0xffff0000);
const uint64_t propertyBits = lvlBits & 0xffff;
- // If undefined/dense/NOutOfM, then must be unique and ordered.
+ // If undefined/dense/batch/NOutOfM, then must be unique and ordered.
// Otherwise, the format must be one of the known ones.
return (isAnyOfFmt<LevelFormat::Undef, LevelFormat::Dense,
- LevelFormat::NOutOfM>(fmt))
+ LevelFormat::Batch, LevelFormat::NOutOfM>(fmt))
? (propertyBits == 0)
: (isAnyOfFmt<LevelFormat::Compressed, LevelFormat::Singleton,
LevelFormat::LooseCompressed>(fmt));
@@ -375,6 +391,7 @@ inline std::optional<LevelType> buildLevelType(LevelFormat lf, bool ordered,
}
inline bool isUndefLT(LevelType lt) { return lt.isa<LevelFormat::Undef>(); }
inline bool isDenseLT(LevelType lt) { return lt.isa<LevelFormat::Dense>(); }
inline bool isBatchLT(LevelType lt) { return lt.isa<LevelFormat::Batch>(); }
inline bool isCompressedLT(LevelType lt) {
return lt.isa<LevelFormat::Compressed>();
}
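A small self-contained mirror of the isValidLvlBits rule above (an illustrative sketch, not the dialect's code; the 0xffff0000 format mask is taken from the function itself). The point of the change: batch joins undef, dense, and n-out-of-m as a format that must stay unique and ordered, so any non-default property bit makes it invalid, while compressed-style formats may carry properties.

#include <cstdint>

constexpr uint64_t kDense = 0x00010000, kBatch = 0x00020000,
                   kCompressed = 0x00040000;

// Simplified stand-in for LevelType::isValidLvlBits; only three of the
// formats are modeled here.
constexpr bool isValid(uint64_t lvlBits) {
  const uint64_t fmt = lvlBits & 0xffff0000;
  const uint64_t props = lvlBits & 0xffff;
  // Undef/dense/batch must carry no property bits.
  if (fmt == 0 || fmt == kDense || fmt == kBatch)
    return props == 0;
  // Otherwise the format must be a compressed-style one.
  return fmt == kCompressed;
}

static_assert(isValid(kBatch), "plain batch level is fine");
static_assert(!isValid(kBatch | 0x1), "batch rejects property bits");
static_assert(isValid(kCompressed | 0x1), "compressed may carry properties");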
@@ -141,7 +141,8 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",

The supported level-formats are the following:

- - **dense** : all entries along this level are stored
+ - **dense** : all entries along this level are stored and linearized.
+ - **batch** : all entries along this level are stored but not linearized.
- **compressed** : only nonzeros along this level are stored
- **loose_compressed** : as compressed, but allows for free space between regions
- **singleton** : a variant of the compressed format, where coordinates have no siblings
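The linearized versus not-linearized wording can be pictured with ordinary buffers (an illustrative analogy, not the dialect's actual storage layout): a dense level folds its index into one flat allocation shared with the levels below it, while a batch level keeps fully independent storage per index, as if holding an array of sparse tensors.

#include <cstddef>
#include <vector>

// dense x dense: one linearized buffer; the outer index is folded in.
double denseDenseAt(const std::vector<double> &vals, size_t ncols,
                    size_t i, size_t j) {
  return vals[i * ncols + j];
}

// batch x dense: an independent buffer per batch index; nothing is
// linearized across the batch level.
double batchDenseAt(const std::vector<std::vector<double>> &batches,
                    size_t i, size_t j) {
  return batches[i][j];
}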
2 changes: 2 additions & 0 deletions mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
@@ -62,6 +62,8 @@ FailureOr<uint64_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
// Set the base bit for properties.
if (base.compare("dense") == 0) {
properties |= static_cast<uint64_t>(LevelFormat::Dense);
} else if (base.compare("batch") == 0) {
properties |= static_cast<uint64_t>(LevelFormat::Batch);
} else if (base.compare("compressed") == 0) {
properties |= static_cast<uint64_t>(LevelFormat::Compressed);
} else if (base.compare("structured") == 0) {
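A toy version of the keyword dispatch added above (hypothetical; it uses std::string in place of the parser's AsmParser/StringRef machinery) shows what the new branch contributes: the base keyword only sets the format bits, with property bits OR'ed in separately.

#include <cstdint>
#include <optional>
#include <string>

// Toy mirror of the base-keyword dispatch in LvlTypeParser::parseLvlType.
std::optional<uint64_t> baseFormatBits(const std::string &base) {
  if (base == "dense")
    return UINT64_C(0x00010000); // LevelFormat::Dense
  if (base == "batch")
    return UINT64_C(0x00020000); // LevelFormat::Batch, new in this commit
  if (base == "compressed")
    return UINT64_C(0x00040000); // LevelFormat::Compressed
  return std::nullopt; // "singleton", "structured", ... elided here
}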
4 changes: 4 additions & 0 deletions mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -690,6 +690,10 @@ LogicalResult SparseTensorEncodingAttr::verify(
}
}

auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
if (!std::all_of(lastBatch, lvlTypes.rend(), isBatchLT))
return emitError() << "Batch lvlType can only be leading levels.";

// SoA property can only be applied on singleton level.
auto soaLvls = llvm::make_filter_range(lvlTypes, [](LevelType lt) {
return lt.isa<LevelPropNonDefault::SoA>();
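The verifier uses a compact reverse-iterator idiom: find_if over rbegin()/rend() locates the last batch level, and all_of from there to rend() requires that it and every level before it are batch, which is exactly the "leading levels only" rule. A standalone demo of the same idiom (hypothetical, with plain bools standing in for level types):

#include <algorithm>
#include <cassert>
#include <vector>

// True iff all 'true' entries (read: batch levels) form a leading prefix.
bool batchesAreLeading(const std::vector<bool> &isBatch) {
  auto pred = [](bool b) { return b; };
  // Scanning from the back, find the last batch level...
  auto lastBatch = std::find_if(isBatch.rbegin(), isBatch.rend(), pred);
  // ...then everything from it to the front must also be batch.
  return std::all_of(lastBatch, isBatch.rend(), pred);
}

int main() {
  assert(batchesAreLeading({true, true, false}));  // batch, batch, compressed
  assert(!batchesAreLeading({true, false, true})); // trailing batch: rejected
  assert(batchesAreLeading({false, false}));       // no batch levels at all
  return 0;
}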
@@ -1278,6 +1278,8 @@ sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
switch (lt.getLvlFmt()) {
case LevelFormat::Dense:
return std::make_unique<DenseLevel>(tid, lvl, sz, stt.hasEncoding());
case LevelFormat::Batch:
llvm_unreachable("not implemented");
case LevelFormat::Compressed: {
Value pos = genToPositions(b, l, t, lvl);
Value crd = genToCoordinates(b, l, t, lvl);
4 changes: 2 additions & 2 deletions mlir/test/CAPI/sparse_tensor.c
@@ -39,8 +39,8 @@ static int testRoundtripEncoding(MlirContext ctx) {
// CHECK: (d0, d1)[s0] -> (s0, d0, d1)
mlirAffineMapDump(dimToLvl);
// CHECK: level_type: 65536
- // CHECK: level_type: 131072
- // CHECK: level_type: 131072
+ // CHECK: level_type: 262144
+ // CHECK: level_type: 262144
MlirAffineMap lvlToDim =
mlirSparseTensorEncodingAttrGetLvlToDim(originalAttr);
int lvlRank = mlirSparseTensorEncodingGetLvlRank(originalAttr);
6 changes: 6 additions & 0 deletions mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
@@ -54,6 +54,12 @@ func.func private @tensor_dimlevel_size_mismatch(%arg0: tensor<8xi32, #a>) -> ()

// -----

// expected-error@+1 {{Batch lvlType can only be leading levels}}
#a = #sparse_tensor.encoding<{map = (d0, d1, d2) -> (d0 : batch, d1 : compressed, d2: batch)}>
func.func private @non_leading_batch(%arg0: tensor<?x?x?xi32, #a>) -> ()

// -----

// expected-error@+1 {{use of undeclared identifier}}
#a = #sparse_tensor.encoding<{map = (d0) -> (d0 : dense, d1 : compressed)}>
func.func private @tensor_sizes_mismatch(%arg0: tensor<8xi32, #a>) -> ()
11 changes: 11 additions & 0 deletions mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
@@ -22,6 +22,17 @@ func.func private @sparse_csr(tensor<?x?xf32, #CSR>)

// -----

#BCSR = #sparse_tensor.encoding<{
map = (d0, d1, d2) -> (d0 : batch, d1: dense, d2 : compressed),
}>

// CHECK: #[[$BCSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : batch, d1 : dense, d2 : compressed) }>
// CHECK-LABEL: func private @sparse_bcsr(
// CHECK-SAME: tensor<?x?x?xf32, #[[$BCSR]]>)
func.func private @sparse_bcsr(tensor<?x?x?xf32, #BCSR>)

// -----

#CSR_explicit = #sparse_tensor.encoding<{
map = {l0, l1} (d0 = l0, d1 = l1) -> (l0 = d0 : dense, l1 = d1 : compressed)
}>
2 changes: 1 addition & 1 deletion mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
@@ -14,7 +14,7 @@
// CHECK-DAG: %[[VAL_8:.*]] = arith.constant true
// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 100 : index
// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 300 : index
- // CHECK-DAG: %[[VAL_11:.*]] = arith.constant 131072 : i64
+ // CHECK-DAG: %[[VAL_11:.*]] = arith.constant 262144 : i64
// CHECK: %[[VAL_12:.*]] = memref.alloca() : memref<2xi64>
// CHECK: %[[VAL_13:.*]] = memref.cast %[[VAL_12]] : memref<2xi64> to memref<?xi64>
// CHECK: memref.store %[[VAL_11]], %[[VAL_12]]{{\[}}%[[VAL_5]]] : memref<2xi64>
8 changes: 4 additions & 4 deletions mlir/test/python/dialects/sparse_tensor/dialect.py
@@ -28,7 +28,7 @@ def testEncodingAttr1D():
# CHECK: equal: True
print(f"equal: {casted == parsed}")

- # CHECK: lvl_types: [131072]
+ # CHECK: lvl_types: [262144]
print(f"lvl_types: {casted.lvl_types}")
# CHECK: dim_to_lvl: (d0) -> (d0)
print(f"dim_to_lvl: {casted.dim_to_lvl}")
@@ -71,9 +71,9 @@ def testEncodingAttrStructure():
# CHECK: equal: True
print(f"equal: {casted == parsed}")

- # CHECK: lvl_types: [65536, 65536, 4406637494272]
+ # CHECK: lvl_types: [65536, 65536, 4406638542848]
print(f"lvl_types: {casted.lvl_types}")
- # CHECK: lvl_formats_enum: [<LevelFormat.dense: 65536>, <LevelFormat.dense: 65536>, <LevelFormat.n_out_of_m: 1048576>]
+ # CHECK: lvl_formats_enum: [<LevelFormat.dense: 65536>, <LevelFormat.dense: 65536>, <LevelFormat.n_out_of_m: 2097152>]
print(f"lvl_formats_enum: {casted.lvl_formats_enum}")
# CHECK: structured_n: 2
print(f"structured_n: {casted.structured_n}")
@@ -157,7 +157,7 @@ def testEncodingAttr2D():
# CHECK: equal: True
print(f"equal: {casted == parsed}")

- # CHECK: lvl_types: [65536, 131072]
+ # CHECK: lvl_types: [65536, 262144]
print(f"lvl_types: {casted.lvl_types}")
# CHECK: dim_to_lvl: (d0, d1) -> (d1, d0)
print(f"dim_to_lvl: {casted.dim_to_lvl}")
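The large integers in testEncodingAttrStructure decode the same way once the extra n and m payload of the structured format is peeled off. A hypothetical C++ decode, assuming n occupies bits 32-39 and m bits 40-47 (consistent with structured_n = 2 printed above): 4406638542848 is 0x40200200000, i.e. m = 4, n = 2, and format bits 0x200000 (n_out_of_m), one bit position higher than the old value 4406637494272 (0x40200100000).

#include <cassert>
#include <cstdint>

int main() {
  // lvl_types value printed by the updated Python test above.
  const uint64_t lt = 4406638542848; // 0x40200200000
  const uint64_t fmt = lt & 0xffff0000; // format field, per Enums.h above
  const uint64_t n = (lt >> 32) & 0xff; // assumed n position
  const uint64_t m = (lt >> 40) & 0xff; // assumed m position
  assert(fmt == 0x200000);  // n_out_of_m after renumbering (2097152)
  assert(n == 2 && m == 4); // n matches structured_n in the test
  return 0;
}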
