Skip to content

Commit

Permalink
Fold CastOp and ExpandShapeOp into air::DmaMemcpyOp (#744)
Browse files Browse the repository at this point in the history
  • Loading branch information
erwei-xilinx authored Oct 19, 2024
1 parent b1eeeb6 commit b5d30a3
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 47 deletions.
52 changes: 52 additions & 0 deletions mlir/lib/Conversion/ConvertToAIRPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1265,6 +1265,26 @@ static LogicalResult condenseMemrefDataReorderingToAIRDma(
src_offsets.clear();
for (unsigned i = 0; i < transposeOp.getPermutation().getNumInputs(); i++)
src_offsets.push_back(constZero);
} else if (auto expandShapeOp = dyn_cast<memref::ExpandShapeOp>(
src_ancestor_memref_ops[0])) {
// Init. memref type
src_memref_ty =
llvm::cast<MemRefType>(expandShapeOp.getViewSource().getType());
src = expandShapeOp.getViewSource();
// Init. offsets
src_offsets.clear();
for (unsigned i = 0; i < expandShapeOp.getReassociationIndices().size();
i++)
src_offsets.push_back(constZero);
} else if (auto castOp =
dyn_cast<memref::CastOp>(src_ancestor_memref_ops[0])) {
// Init. memref type
src_memref_ty = llvm::cast<MemRefType>(castOp.getViewSource().getType());
src = castOp.getViewSource();
// Init. offsets
src_offsets.clear();
for (unsigned i = 0; i < src_memref_ty.getRank(); i++)
src_offsets.push_back(constZero);
}
} else {
src_offsets = dmaOp.getSrcOffsets();
Expand All @@ -1290,6 +1310,26 @@ static LogicalResult condenseMemrefDataReorderingToAIRDma(
dst_offsets.clear();
for (unsigned i = 0; i < transposeOp.getPermutation().getNumInputs(); i++)
dst_offsets.push_back(constZero);
} else if (auto expandShapeOp = dyn_cast<memref::ExpandShapeOp>(
dst_ancestor_memref_ops[0])) {
// Init. memref type
dst_memref_ty =
llvm::cast<MemRefType>(expandShapeOp.getViewSource().getType());
dst = expandShapeOp.getViewSource();
// Init. offsets
dst_offsets.clear();
for (unsigned i = 0; i < expandShapeOp.getReassociationIndices().size();
i++)
dst_offsets.push_back(constZero);
} else if (auto castOp =
dyn_cast<memref::CastOp>(dst_ancestor_memref_ops[0])) {
// Init. memref type
dst_memref_ty = llvm::cast<MemRefType>(castOp.getViewSource().getType());
dst = castOp.getViewSource();
// Init. offsets
dst_offsets.clear();
for (unsigned i = 0; i < dst_memref_ty.getRank(); i++)
dst_offsets.push_back(constZero);
}
} else {
dst_offsets = dmaOp.getDstOffsets();
Expand Down Expand Up @@ -1341,6 +1381,9 @@ static LogicalResult condenseMemrefDataReorderingToAIRDma(
llvm::cast<MemRefType>(memref::SubViewOp::inferResultType(
src_memref_ty, subviewOp.getStaticOffsets(),
subviewOp.getStaticSizes(), subviewOp.getStaticStrides()));
} else if (auto castOp = dyn_cast<memref::CastOp>(memrefOp)) {
// Init. memref type
src_memref_ty = llvm::cast<MemRefType>(castOp.getResult().getType());
}
}

Expand Down Expand Up @@ -1387,6 +1430,9 @@ static LogicalResult condenseMemrefDataReorderingToAIRDma(
llvm::cast<MemRefType>(memref::SubViewOp::inferResultType(
dst_memref_ty, subviewOp.getStaticOffsets(),
subviewOp.getStaticSizes(), subviewOp.getStaticStrides()));
} else if (auto castOp = dyn_cast<memref::CastOp>(memrefOp)) {
// Init. memref type
dst_memref_ty = llvm::cast<MemRefType>(castOp.getResult().getType());
}
}

Expand Down Expand Up @@ -1537,6 +1583,9 @@ struct CopyToDmaPass : public air::impl::CopyToDmaBase<CopyToDmaPass> {
dyn_cast<memref::SubViewOp>(ancestor)) {
std::get<1>(log_entry).push_back(ancestor);
ancestor = subview_anc.getSource().getDefiningOp();
} else if (auto cast_anc = dyn_cast<memref::CastOp>(ancestor)) {
std::get<1>(log_entry).push_back(ancestor);
ancestor = cast_anc.getViewSource().getDefiningOp();
} else
exit = true;
}
Expand All @@ -1556,6 +1605,9 @@ struct CopyToDmaPass : public air::impl::CopyToDmaBase<CopyToDmaPass> {
dyn_cast<memref::SubViewOp>(ancestor)) {
std::get<2>(log_entry).push_back(ancestor);
ancestor = subview_anc.getSource().getDefiningOp();
} else if (auto cast_anc = dyn_cast<memref::CastOp>(ancestor)) {
std::get<2>(log_entry).push_back(ancestor);
ancestor = cast_anc.getViewSource().getDefiningOp();
} else
exit = true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

// RUN: air-opt %s -air-copy-to-dma -canonicalize -cse | FileCheck %s

// memref::SubViewOp, memref::ExpandShapeOp and memref::TransposeOp folding.

// CHECK: %[[CST128:.*]] = arith.constant 128 : index
// CHECK: %[[CST32:.*]] = arith.constant 32 : index
// CHECK: %[[CST8:.*]] = arith.constant 8 : index
Expand Down Expand Up @@ -36,54 +38,87 @@
#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>
#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>
#map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>
module {
func.func @func0(%0 : memref<8x16xi32>, %1 : memref<16x32xi32>, %2 : memref<8x32xi32>) {
%c2 = arith.constant 2 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
air.launch (%arg0, %arg1) in (%arg2=%c1, %arg3=%c2) args(%arg4=%0, %arg5=%1, %arg6=%2) : memref<8x16xi32>, memref<16x32xi32>, memref<8x32xi32> {
air.segment @segment_0 args(%arg7=%arg0, %arg8=%arg1, %arg9=%arg4, %arg10=%arg5, %arg11=%arg6) : index, index, memref<8x16xi32>, memref<16x32xi32>, memref<8x32xi32> {
%c1_0 = arith.constant 1 : index
%3 = affine.apply #map()[%arg7]
%4 = affine.apply #map1()[%arg8]
%subview = memref.subview %arg9[%3, 0] [8, 16] [1, 1] : memref<8x16xi32> to memref<8x16xi32, strided<[16, 1], offset: ?>>
%subview_1 = memref.subview %arg10[0, %4] [16, 16] [1, 1] : memref<16x32xi32> to memref<16x16xi32, strided<[32, 1], offset: ?>>
%subview_2 = memref.subview %arg11[%3, %4] [8, 16] [1, 1] : memref<8x32xi32> to memref<8x16xi32, strided<[32, 1], offset: ?>>
%alloc = memref.alloc() : memref<1x1x8x16xi32, 1>
%transpose = memref.transpose %subview (d0, d1) -> (d0, d1) : memref<8x16xi32, strided<[16, 1], offset: ?>> to memref<8x16xi32, strided<[16, 1], offset: ?>>
air.dma_memcpy_nd (%alloc[] [] [], %transpose[] [] []) : (memref<1x1x8x16xi32, 1>, memref<8x16xi32, strided<[16, 1], offset: ?>>)
%alloc_3 = memref.alloc() : memref<1x1x16x16xi32, 1>
%transpose_4 = memref.transpose %subview_1 (d0, d1) -> (d0, d1) : memref<16x16xi32, strided<[32, 1], offset: ?>> to memref<16x16xi32, strided<[32, 1], offset: ?>>
air.dma_memcpy_nd (%alloc_3[] [] [], %transpose_4[] [] []) : (memref<1x1x16x16xi32, 1>, memref<16x16xi32, strided<[32, 1], offset: ?>>)
%alloc_5 = memref.alloc() : memref<1x1x8x16xi32, 1>
air.herd @herd_0 tile (%arg12, %arg13) in (%arg14=%c1_0, %arg15=%c1_0) args(%arg16=%alloc, %arg17=%alloc_3, %arg18=%alloc_5) : memref<1x1x8x16xi32, 1>, memref<1x1x16x16xi32, 1>, memref<1x1x8x16xi32, 1> {
%c0_i32 = arith.constant 0 : i32
%subview_8 = memref.subview %arg16[%arg12, 0, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1] : memref<1x1x8x16xi32, 1> to memref<1x1x8x16xi32, strided<[128, 128, 16, 1], offset: ?>, 1>
%subview_9 = memref.subview %arg17[0, %arg13, 0, 0] [1, 1, 16, 16] [1, 1, 1, 1] : memref<1x1x16x16xi32, 1> to memref<1x1x16x16xi32, strided<[256, 256, 16, 1], offset: ?>, 1>
%subview_10 = memref.subview %arg18[%arg12, %arg13, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1] : memref<1x1x8x16xi32, 1> to memref<1x1x8x16xi32, strided<[128, 128, 16, 1], offset: ?>, 1>
%alloc_11 = memref.alloc() : memref<1x1x2x2x4x8xi32, 2>
%expand_shape = memref.expand_shape %subview_8 [[0], [1], [2, 3], [4, 5]] output_shape [1, 1, 2, 4, 2, 8]: memref<1x1x8x16xi32, strided<[128, 128, 16, 1], offset: ?>, 1> into memref<1x1x2x4x2x8xi32, strided<[128, 128, 64, 16, 8, 1], offset: ?>, 1>
%transpose_12 = memref.transpose %expand_shape (d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d2, d3, d5) : memref<1x1x2x4x2x8xi32, strided<[128, 128, 64, 16, 8, 1], offset: ?>, 1> to memref<1x1x2x2x4x8xi32, strided<[128, 128, 8, 64, 16, 1], offset: ?>, 1>
air.dma_memcpy_nd (%alloc_11[] [] [], %transpose_12[] [] []) : (memref<1x1x2x2x4x8xi32, 2>, memref<1x1x2x2x4x8xi32, strided<[128, 128, 8, 64, 16, 1], offset: ?>, 1>)
%alloc_13 = memref.alloc() : memref<1x1x2x2x8x8xi32, 2>
%expand_shape_14 = memref.expand_shape %subview_9 [[0], [1], [2, 3], [4, 5]] output_shape [1, 1, 2, 8, 2, 8] : memref<1x1x16x16xi32, strided<[256, 256, 16, 1], offset: ?>, 1> into memref<1x1x2x8x2x8xi32, strided<[256, 256, 128, 16, 8, 1], offset: ?>, 1>
%transpose_15 = memref.transpose %expand_shape_14 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d2, d3, d5) : memref<1x1x2x8x2x8xi32, strided<[256, 256, 128, 16, 8, 1], offset: ?>, 1> to memref<1x1x2x2x8x8xi32, strided<[256, 256, 8, 128, 16, 1], offset: ?>, 1>
air.dma_memcpy_nd (%alloc_13[] [] [], %transpose_15[] [] []) : (memref<1x1x2x2x8x8xi32, 2>, memref<1x1x2x2x8x8xi32, strided<[256, 256, 8, 128, 16, 1], offset: ?>, 1>)
%alloc_16 = memref.alloc() : memref<1x1x2x2x4x8xi32, 2>
%transpose_17 = memref.transpose %alloc_16 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4, d2, d5) : memref<1x1x2x2x4x8xi32, 2> to memref<1x1x2x4x2x8xi32, strided<[128, 128, 32, 8, 64, 1]>, 2>
air.dma_memcpy_nd (%subview_10[] [] [], %transpose_17[] [] []) : (memref<1x1x8x16xi32, strided<[128, 128, 16, 1], offset: ?>, 1>, memref<1x1x2x4x2x8xi32, strided<[128, 128, 32, 8, 64, 1]>, 2>)
memref.dealloc %alloc_11 : memref<1x1x2x2x4x8xi32, 2>
memref.dealloc %alloc_13 : memref<1x1x2x2x8x8xi32, 2>
memref.dealloc %alloc_16 : memref<1x1x2x2x4x8xi32, 2>
}
%subview_6 = memref.subview %alloc_5[0, 0, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1] : memref<1x1x8x16xi32, 1> to memref<8x16xi32, 1>
%transpose_7 = memref.transpose %subview_6 (d0, d1) -> (d0, d1) : memref<8x16xi32, 1> to memref<8x16xi32, strided<[16, 1]>, 1>
air.dma_memcpy_nd (%subview_2[] [] [], %transpose_7[] [] []) : (memref<8x16xi32, strided<[32, 1], offset: ?>>, memref<8x16xi32, strided<[16, 1]>, 1>)
memref.dealloc %alloc_3 : memref<1x1x16x16xi32, 1>
memref.dealloc %alloc : memref<1x1x8x16xi32, 1>
memref.dealloc %alloc_5 : memref<1x1x8x16xi32, 1>
// Test input: every air.dma_memcpy_nd in this function is written with empty
// offset/size/stride lists, and at least one of its operands is produced by a
// chain of memref.subview / memref.expand_shape / memref.transpose ops.
func.func @func0(%0 : memref<8x16xi32>, %1 : memref<16x32xi32>, %2 : memref<8x32xi32>) {
%c2 = arith.constant 2 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
air.launch (%arg0, %arg1) in (%arg2=%c1, %arg3=%c2) args(%arg4=%0, %arg5=%1, %arg6=%2) : memref<8x16xi32>, memref<16x32xi32>, memref<8x32xi32> {
air.segment @segment_0 args(%arg7=%arg0, %arg8=%arg1, %arg9=%arg4, %arg10=%arg5, %arg11=%arg6) : index, index, memref<8x16xi32>, memref<16x32xi32>, memref<8x32xi32> {
%c1_0 = arith.constant 1 : index
// %3 and %4 are affine functions of the launch indices (forwarded into the
// segment as %arg7 / %arg8); they serve as dynamic subview offsets below.
%3 = affine.apply #map()[%arg7]
%4 = affine.apply #map1()[%arg8]
%subview = memref.subview %arg9[%3, 0] [8, 16] [1, 1] : memref<8x16xi32> to memref<8x16xi32, strided<[16, 1], offset: ?>>
%subview_1 = memref.subview %arg10[0, %4] [16, 16] [1, 1] : memref<16x32xi32> to memref<16x16xi32, strided<[32, 1], offset: ?>>
%subview_2 = memref.subview %arg11[%3, %4] [8, 16] [1, 1] : memref<8x32xi32> to memref<8x16xi32, strided<[32, 1], offset: ?>>
// DMA source: identity-permutation transpose of a subview.
// DMA destination: a fresh allocation in memory space 1.
%alloc = memref.alloc() : memref<1x1x8x16xi32, 1>
%transpose = memref.transpose %subview (d0, d1) -> (d0, d1) : memref<8x16xi32, strided<[16, 1], offset: ?>> to memref<8x16xi32, strided<[16, 1], offset: ?>>
air.dma_memcpy_nd (%alloc[] [] [], %transpose[] [] []) : (memref<1x1x8x16xi32, 1>, memref<8x16xi32, strided<[16, 1], offset: ?>>)
%alloc_3 = memref.alloc() : memref<1x1x16x16xi32, 1>
%transpose_4 = memref.transpose %subview_1 (d0, d1) -> (d0, d1) : memref<16x16xi32, strided<[32, 1], offset: ?>> to memref<16x16xi32, strided<[32, 1], offset: ?>>
air.dma_memcpy_nd (%alloc_3[] [] [], %transpose_4[] [] []) : (memref<1x1x16x16xi32, 1>, memref<16x16xi32, strided<[32, 1], offset: ?>>)
%alloc_5 = memref.alloc() : memref<1x1x8x16xi32, 1>
air.herd @herd_0 tile (%arg12, %arg13) in (%arg14=%c1_0, %arg15=%c1_0) args(%arg16=%alloc, %arg17=%alloc_3, %arg18=%alloc_5) : memref<1x1x8x16xi32, 1>, memref<1x1x16x16xi32, 1>, memref<1x1x8x16xi32, 1> {
%c0_i32 = arith.constant 0 : i32
// Subviews of the herd arguments, offset by the tile ids %arg12 / %arg13.
%subview_8 = memref.subview %arg16[%arg12, 0, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1] : memref<1x1x8x16xi32, 1> to memref<1x1x8x16xi32, strided<[128, 128, 16, 1], offset: ?>, 1>
%subview_9 = memref.subview %arg17[0, %arg13, 0, 0] [1, 1, 16, 16] [1, 1, 1, 1] : memref<1x1x16x16xi32, 1> to memref<1x1x16x16xi32, strided<[256, 256, 16, 1], offset: ?>, 1>
%subview_10 = memref.subview %arg18[%arg12, %arg13, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1] : memref<1x1x8x16xi32, 1> to memref<1x1x8x16xi32, strided<[128, 128, 16, 1], offset: ?>, 1>
// DMA source chain: subview -> expand_shape (4-D to 6-D) -> transpose.
// DMA destination: an allocation in memory space 2.
%alloc_11 = memref.alloc() : memref<1x1x2x2x4x8xi32, 2>
%expand_shape = memref.expand_shape %subview_8 [[0], [1], [2, 3], [4, 5]] output_shape [1, 1, 2, 4, 2, 8]: memref<1x1x8x16xi32, strided<[128, 128, 16, 1], offset: ?>, 1> into memref<1x1x2x4x2x8xi32, strided<[128, 128, 64, 16, 8, 1], offset: ?>, 1>
%transpose_12 = memref.transpose %expand_shape (d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d2, d3, d5) : memref<1x1x2x4x2x8xi32, strided<[128, 128, 64, 16, 8, 1], offset: ?>, 1> to memref<1x1x2x2x4x8xi32, strided<[128, 128, 8, 64, 16, 1], offset: ?>, 1>
air.dma_memcpy_nd (%alloc_11[] [] [], %transpose_12[] [] []) : (memref<1x1x2x2x4x8xi32, 2>, memref<1x1x2x2x4x8xi32, strided<[128, 128, 8, 64, 16, 1], offset: ?>, 1>)
// Same subview -> expand_shape -> transpose chain, applied to %subview_9.
%alloc_13 = memref.alloc() : memref<1x1x2x2x8x8xi32, 2>
%expand_shape_14 = memref.expand_shape %subview_9 [[0], [1], [2, 3], [4, 5]] output_shape [1, 1, 2, 8, 2, 8] : memref<1x1x16x16xi32, strided<[256, 256, 16, 1], offset: ?>, 1> into memref<1x1x2x8x2x8xi32, strided<[256, 256, 128, 16, 8, 1], offset: ?>, 1>
%transpose_15 = memref.transpose %expand_shape_14 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d2, d3, d5) : memref<1x1x2x8x2x8xi32, strided<[256, 256, 128, 16, 8, 1], offset: ?>, 1> to memref<1x1x2x2x8x8xi32, strided<[256, 256, 8, 128, 16, 1], offset: ?>, 1>
air.dma_memcpy_nd (%alloc_13[] [] [], %transpose_15[] [] []) : (memref<1x1x2x2x8x8xi32, 2>, memref<1x1x2x2x8x8xi32, strided<[256, 256, 8, 128, 16, 1], offset: ?>, 1>)
// Here the source is a transpose applied directly to a memory-space-2
// allocation, and the destination is the subview %subview_10.
%alloc_16 = memref.alloc() : memref<1x1x2x2x4x8xi32, 2>
%transpose_17 = memref.transpose %alloc_16 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4, d2, d5) : memref<1x1x2x2x4x8xi32, 2> to memref<1x1x2x4x2x8xi32, strided<[128, 128, 32, 8, 64, 1]>, 2>
air.dma_memcpy_nd (%subview_10[] [] [], %transpose_17[] [] []) : (memref<1x1x8x16xi32, strided<[128, 128, 16, 1], offset: ?>, 1>, memref<1x1x2x4x2x8xi32, strided<[128, 128, 32, 8, 64, 1]>, 2>)
memref.dealloc %alloc_11 : memref<1x1x2x2x4x8xi32, 2>
memref.dealloc %alloc_13 : memref<1x1x2x2x8x8xi32, 2>
memref.dealloc %alloc_16 : memref<1x1x2x2x4x8xi32, 2>
}
// DMA source: rank-reducing subview (4-D to 2-D) followed by an identity
// transpose. DMA destination: the subview %subview_2.
%subview_6 = memref.subview %alloc_5[0, 0, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1] : memref<1x1x8x16xi32, 1> to memref<8x16xi32, 1>
%transpose_7 = memref.transpose %subview_6 (d0, d1) -> (d0, d1) : memref<8x16xi32, 1> to memref<8x16xi32, strided<[16, 1]>, 1>
air.dma_memcpy_nd (%subview_2[] [] [], %transpose_7[] [] []) : (memref<8x16xi32, strided<[32, 1], offset: ?>>, memref<8x16xi32, strided<[16, 1]>, 1>)
memref.dealloc %alloc_3 : memref<1x1x16x16xi32, 1>
memref.dealloc %alloc : memref<1x1x8x16xi32, 1>
memref.dealloc %alloc_5 : memref<1x1x8x16xi32, 1>
}
}
return
}

// memref::CastOp folding.

// CHECK: air.herd @herd_0 {{.*}} args(%[[ARG0:.*]]=%{{.*}}, %[[ARG1:.*]]=%{{.*}})
// CHECK-DAG: %[[CST4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[CST3:.*]] = arith.constant 3 : index
// CHECK-DAG: %[[CST1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[CST8:.*]] = arith.constant 8 : index
// CHECK-DAG: %[[CST64:.*]] = arith.constant 64 : index
// CHECK-DAG: %[[CST256:.*]] = arith.constant 256 : index
// CHECK-DAG: %[[CST768:.*]] = arith.constant 768 : index
// CHECK-DAG: %[[CST0:.*]] = arith.constant 0 : index
// CHECK: air.dma_memcpy_nd (%[[ARG1]][] [] [], %[[ARG0]][%[[CST0]], %[[CST0]], %[[CST0]], %[[CST0]], %[[CST0]], %[[CST0]]] [%[[CST3]], %[[CST3]], %[[CST4]], %[[CST1]], %[[CST8]], %[[CST8]]] [%[[CST768]], %[[CST256]], %[[CST64]], %[[CST8]], %[[CST8]], %[[CST1]]]) : (memref<3x3x4x1x8x8xi8, 2 : i32>, memref<3x3x32x8xi8, 1 : i32>)
// CHECK: }

// Test input for memref.cast folding: the DMA source is reached from a herd
// argument through memref.cast -> memref.expand_shape -> memref.transpose,
// and the DMA itself is written with empty offset/size/stride lists.
func.func @func1() {
%c8 = arith.constant 8 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
air.launch (%arg3, %arg4, %arg5, %arg6) in (%arg7=%c2, %arg8=%c3, %arg9=%c3, %arg10=%c8) {
air.segment @segment_0 {
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
%alloc = memref.alloc() : memref<3x3x4x1x8x8xi8, 2 : i32>
%alloc_0 = memref.alloc() : memref<3x3x32x8xi8, 1 : i32>
air.herd @herd_0 tile (%arg11, %arg12) in (%arg13=%c4, %arg14=%c1) args(%arg15=%alloc_0, %arg16=%alloc) : memref<3x3x32x8xi8, 1 : i32>, memref<3x3x4x1x8x8xi8, 2 : i32> {
// The cast keeps shape and element type and only switches the layout to an
// explicit strided form with a dynamic offset.
%cast = memref.cast %arg15 : memref<3x3x32x8xi8, 1 : i32> to memref<3x3x32x8xi8, strided<[768, 256, 8, 1], offset: ?>, 1 : i32>
// Split dims 2 and 3 (4-D to 6-D), then swap the two middle expanded dims.
%expand_shape = memref.expand_shape %cast [[0], [1], [2, 3], [4, 5]] output_shape [3, 3, 4, 8, 1, 8] : memref<3x3x32x8xi8, strided<[768, 256, 8, 1], offset: ?>, 1 : i32> into memref<3x3x4x8x1x8xi8, strided<[768, 256, 64, 8, 8, 1], offset: ?>, 1 : i32>
%transpose = memref.transpose %expand_shape (d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4, d3, d5) : memref<3x3x4x8x1x8xi8, strided<[768, 256, 64, 8, 8, 1], offset: ?>, 1 : i32> to memref<3x3x4x1x8x8xi8, strided<[768, 256, 64, 8, 8, 1], offset: ?>, 1 : i32>
air.dma_memcpy_nd (%arg16[] [] [], %transpose[] [] []) : (memref<3x3x4x1x8x8xi8, 2 : i32>, memref<3x3x4x1x8x8xi8, strided<[768, 256, 64, 8, 8, 1], offset: ?>, 1 : i32>)
}
}
}
return
}

0 comments on commit b5d30a3

Please sign in to comment.