temporary fix for stalling the mesh when sram is not write ready
richardyrh committed Nov 8, 2024
1 parent 2916cde commit 6beac35
Showing 3 changed files with 40 additions and 22 deletions.
src/main/scala/gemmini/ExecuteController.scala (29 changes: 21 additions & 8 deletions)
@@ -7,6 +7,7 @@ import GemminiISA._
 import Util._
 import org.chipsalliance.cde.config.Parameters
 import midas.targetutils.PerfCounter
+import freechips.rocketchip.util.ClockGate

 // TODO do we still need to flush when the dataflow is weight stationary? Won't the result just keep travelling through on its own?
 class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: Int, config: GemminiArrayConfig[T, U, V])
@@ -182,9 +183,14 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
 val cntl_valid = mesh_cntl_signals_q.io.deq.valid
 val cntl = mesh_cntl_signals_q.io.deq.bits

-// Instantiate the actual mesh
-val mesh = Module(new MeshWithDelays(spatialArrayInputType, spatialArrayWeightType, spatialArrayOutputType, accType, mesh_tag, dataflow, tree_reduction, tile_latency, mesh_output_delay,
-  tileRows, tileColumns, meshRows, meshColumns, shifter_banks, shifter_banks))
+val sram_write_ready = Wire(Bool())
+val gated_clock = ClockGate(clock, sram_write_ready, "mesh_stall_gate")
+
+val mesh = withClock(gated_clock) {
+  // Instantiate the actual mesh
+  Module(new MeshWithDelays(spatialArrayInputType, spatialArrayWeightType, spatialArrayOutputType, accType, mesh_tag, dataflow, tree_reduction, tile_latency, mesh_output_delay,
+    tileRows, tileColumns, meshRows, meshColumns, shifter_banks, shifter_banks))
+}

 mesh.io.a.valid := false.B
 mesh.io.b.valid := false.B
@@ -747,6 +753,9 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
 val first = Bool()
 }

+sram_write_ready := VecInit(io.srams.write.map(_.ready)).reduceTree(_ && _)
+mesh.io.resp.ready := sram_write_ready
+
 mesh_cntl_signals_q.io.enq.valid := computing

 mesh_cntl_signals_q.io.enq.bits.perform_mul_pre := performing_mul_pre
@@ -811,7 +820,7 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
 mesh_cntl_signals_q.io.deq.ready := (!cntl.a_fire || mesh.io.a.fire || !mesh.io.a.ready) &&
   (!cntl.b_fire || mesh.io.b.fire || !mesh.io.b.ready) &&
   (!cntl.d_fire || mesh.io.d.fire || !mesh.io.d.ready) &&
-  (!cntl.first || mesh.io.req.ready)
+  (!cntl.first || mesh.io.req.ready) // && sram_write_ready

 val dataA_valid = cntl.a_garbage || cntl.a_unpadded_cols === 0.U || Mux(cntl.im2colling, im2ColValid, Mux(cntl.a_read_from_acc, accReadValid(cntl.a_bank_acc), readValid(cntl.a_bank)))

@@ -835,7 +844,7 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In

 val dataA = VecInit(dataA_unpadded.asTypeOf(Vec(block_size, inputType)).zipWithIndex.map { case (d, i) => Mux(i.U < cntl.a_unpadded_cols, d, inputType.zero)}.map(d => d.asTypeOf(inputType).withWidthOf(spatialArrayInputType)))
 val dataB = VecInit(dataB_unpadded.asTypeOf(Vec(block_size, inputType)).zipWithIndex.map { case (d, i) => Mux(i.U < cntl.b_unpadded_cols, d, inputType.zero)}.map(d => d.asTypeOf(inputType).withWidthOf(spatialArrayWeightType)))
-val dataD = VecInit(dataD_unpadded.asTypeOf(Vec(block_size, inputType)).zipWithIndex.map { case (d, i) => Mux(i.U < cntl.d_unpadded_cols, d, inputType.zero)}.map(d => d.asTypeOf(inputType).withWidthOf(spatialArrayWeightType)))
+val dataD = VecInit(dataD_unpadded.asTypeOf(Vec(block_size, inputType)).zipWithIndex.map { case (d, i) => Mux(i.U < cntl.d_unpadded_cols, d, inputType.zero)})

 // Pop responses off the scratchpad io ports
 when (mesh_cntl_signals_q.io.deq.fire) {
@@ -883,7 +892,10 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
 mesh.io.b.bits := dataB.asTypeOf(Vec(meshColumns, Vec(tileColumns, spatialArrayWeightType)))
 mesh.io.d.bits := dataD.asTypeOf(Vec(meshColumns, Vec(tileColumns, spatialArrayWeightType)))

-mesh.io.req.valid := mesh_cntl_signals_q.io.deq.fire && (cntl.a_fire || cntl.b_fire || cntl.d_fire)
+// gate this req valid
+// gate mesh control signal fires
+// gate control a b d fires
+mesh.io.req.valid := mesh_cntl_signals_q.io.deq.fire && (cntl.a_fire || cntl.b_fire || cntl.d_fire) // && sram_write_ready

 mesh.io.req.bits.tag.addr := cntl.c_addr

@@ -934,12 +946,13 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
 })))

 if (ex_write_to_spad) {
-  io.srams.write(i).en := start_array_outputting && w_bank === i.U && !write_to_acc && !is_garbage_addr && write_this_row
+  io.srams.write(i).valid := start_array_outputting && w_bank === i.U && !write_to_acc && !is_garbage_addr && write_this_row
+  // assert(io.srams.write(i).ready || !io.srams.write(i).valid)
   io.srams.write(i).addr := w_row
   io.srams.write(i).data := activated_wdata.asUInt
   io.srams.write(i).mask := w_mask.flatMap(b => Seq.fill(inputType.getWidth / (aligned_to * 8))(b))
 } else {
-  io.srams.write(i).en := false.B
+  io.srams.write(i).valid := false.B
   io.srams.write(i).addr := DontCare
   io.srams.write(i).data := DontCare
   io.srams.write(i).mask := DontCare
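The change above wraps the mesh in a gated clock: sram_write_ready ANDs together the ready signals of every scratchpad write port, and ClockGate stops the mesh's clock whenever any bank cannot accept a write. Below is a minimal, self-contained sketch of that pattern, assuming rocket-chip's ClockGate(clock, enable, name) helper exactly as used in the diff; the module and signal names are illustrative, not Gemmini's.

import chisel3._
import freechips.rocketchip.util.ClockGate

// Hypothetical example (not from Gemmini): a counter that freezes
// whenever its consumer applies backpressure.
class GatedCounter extends Module {
  val io = IO(new Bundle {
    val downstream_ready = Input(Bool()) // e.g. "all SRAM banks can take a write"
    val count = Output(UInt(8.W))
  })

  // Pass clock edges through only while the consumer is ready, so all
  // state behind the gate holds its value during a stall.
  val gated_clock = ClockGate(clock, io.downstream_ready, "stall_gate")

  io.count := withClock(gated_clock) {
    val c = RegInit(0.U(8.W)) // clocked by gated_clock: holds its value when stalled
    c := c + 1.U
    c
  }
}

Gating the clock stalls every register in the mesh at once, which is presumably why the commit can leave the finer-grained "// && sram_write_ready" terms commented out on the dequeue-ready and req.valid paths.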
src/main/scala/gemmini/MeshWithDelays.scala (4 changes: 2 additions & 2 deletions)
@@ -62,7 +62,7 @@ class MeshWithDelays[T <: Data: Arithmetic, U <: TagQueueTag with Data]

 val req = Flipped(Decoupled(new MeshWithDelaysReq(accType, tagType.cloneType, block_size)))

-val resp = Valid(new MeshWithDelaysResp(outputType, meshColumns, tileColumns, block_size, tagType.cloneType))
+val resp = Decoupled(new MeshWithDelaysResp(outputType, meshColumns, tileColumns, block_size, tagType.cloneType))

 val tags_in_progress = Output(Vec(tagqlen, tagType))
 })
@@ -245,7 +245,7 @@ class MeshWithDelays[T <: Data: Arithmetic, U <: TagQueueTag with Data]

 total_rows_q.io.deq.ready := io.resp.valid && io.resp.bits.last && out_matmul_id === total_rows_q.io.deq.bits.id

-io.req.ready := (!req.valid || last_fire) && tagq.io.enq.ready && total_rows_q.io.enq.ready
+io.req.ready := (!req.valid || last_fire) && tagq.io.enq.ready && total_rows_q.io.enq.ready && io.resp.ready
 io.tags_in_progress := tagq.io.all.map(_.tag)

 when (reset.asBool) {
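Here the response port changes from Valid to Decoupled, and io.req.ready is additionally qualified by io.resp.ready, so a new request is only accepted while the response side can drain. A minimal sketch of that interface move, with illustrative names (not Gemmini's):

import chisel3._
import chisel3.util._

// Hypothetical single-entry unit: accepts a request only when the
// (now Decoupled) response side is ready.
class BackpressuredUnit extends Module {
  val io = IO(new Bundle {
    val req  = Flipped(Decoupled(UInt(8.W)))
    val resp = Decoupled(UInt(8.W)) // was Valid(UInt(8.W)) before the change
  })

  val busy = RegInit(false.B)
  val data = Reg(UInt(8.W))

  // Same move as the diff: fold the consumer's readiness into req.ready.
  io.req.ready := !busy && io.resp.ready
  when (io.req.fire)  { busy := true.B; data := io.req.bits }
  when (io.resp.fire) { busy := false.B }

  io.resp.valid := busy
  io.resp.bits  := data
}

Note that this handshake only throttles new requests; results already in flight inside a deep pipeline could still arrive while resp.ready is low, which is why the commit also stalls the mesh clock rather than relying on this handshake alone.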
src/main/scala/gemmini/Scratchpad.scala (29 changes: 17 additions & 12 deletions)
@@ -89,10 +89,12 @@ class ScratchpadReadIO(val n: Int, val w: Int) extends Bundle {
 }

 class ScratchpadWriteIO(val n: Int, val w: Int, val mask_len: Int) extends Bundle {
-  val en = Output(Bool())
+  val valid = Output(Bool())
+  val ready = Input(Bool())
   val addr = Output(UInt(log2Ceil(n).W))
   val mask = Output(Vec(mask_len, Bool()))
   val data = Output(UInt(w.W))
+  def fire = valid && ready
 }

 class ScratchpadBank(n: Int, w: Int, aligned_to: Int, single_ported: Boolean, use_shared_ext_mem: Boolean, is_dummy: Boolean) extends Module {
@@ -115,7 +117,7 @@ class ScratchpadBank(n: Int, w: Int, aligned_to: Int, single_ported: Boolean, us
 val q = Module(new Queue(new ScratchpadReadResp(w), 1, true, true))
 val q_will_be_empty = (q.io.count +& q.io.enq.fire) - q.io.deq.fire === 0.U
 // When the scratchpad is single-ported, the writes take precedence
-val singleport_busy_with_write = single_ported.B && io.write.en
+val singleport_busy_with_write = single_ported.B && io.write.fire

 if (is_dummy) {
   q.io.enq.valid := RegNext(ren)
@@ -147,22 +149,23 @@ class ScratchpadBank(n: Int, w: Int, aligned_to: Int, single_ported: Boolean, us
 val wq = Module(new Queue(ext_mem.write_req.bits.cloneType, 4, pipe=true, flow=true))
 ext_mem.write_req <> wq.io.deq

-wq.io.enq.valid := io.write.en
+wq.io.enq.valid := io.write.valid
+io.write.ready := wq.io.enq.ready
 wq.io.enq.bits.addr := io.write.addr
 wq.io.enq.bits.data := io.write.data
 if (aligned_to >= w) {
   wq.io.enq.bits.mask := VecInit((~(0.U(mask_len.W))).asBools).asUInt
 } else {
   wq.io.enq.bits.mask := io.write.mask.asUInt
 }
-assert(wq.io.enq.ready || (!io.write.en), "TODO (richard): fix this if triggered")
+// assert(wq.io.enq.ready || (!io.write.en), "TODO (richard): fix this if triggered")
 } else { // use valid only interface
 val mem = SyncReadMem(n, Vec(mask_len, mask_elem))

 val raddr = io.read.req.bits.addr
 val rdata = if (single_ported) {
-  assert(!(ren && io.write.en))
-  mem.read(raddr, ren && !io.write.en).asUInt
+  assert(!(ren && io.write.fire))
+  mem.read(raddr, ren && !io.write.fire).asUInt
 } else {
   mem.read(raddr, ren).asUInt
 }
@@ ... @@

 io.read.req.ready := q_will_be_empty && !singleport_busy_with_write

-when(io.write.en) {
+io.write.ready := true.B
+when(io.write.fire) {
   if (aligned_to >= w)
     mem.write(io.write.addr, io.write.data.asTypeOf(Vec(mask_len, mask_elem)), VecInit((~(0.U(mask_len.W))).asBools))
   else
@@ -502,7 +506,7 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T,
 // TODO we tie the write dispatch queue's, and write issue queue's, ready and valid signals together here
 val dmawrite = write_dispatch_q.valid && write_norm_q.io.enq.ready &&
   !write_dispatch_q.bits.laddr.is_garbage() &&
-  !(bio.write.en && config.sp_singleported.B) &&
+  !(bio.write.fire && config.sp_singleported.B) &&
   !write_dispatch_q.bits.laddr.is_acc_addr && write_dispatch_q.bits.laddr.sp_bank() === i.U

 bio.read.req.valid := exread || dmawrite
@@ -554,7 +558,8 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T,

 // Writing to the SRAM banks
 bank_ios.zipWithIndex.foreach { case (bio, i) =>
-  val exwrite = io.srams.write(i).en
+  val exwrite = io.srams.write(i).valid
+  io.srams.write(i).ready := bio.write.ready

   // val laddr = mvin_scale_out.bits.tag.addr.asTypeOf(local_addr_t) + mvin_scale_out.bits.row
   val laddr = mvin_scale_pixel_repeater.io.resp.bits.laddr
@@ -572,7 +577,7 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T,
   // !((mvin_scale_out.valid && mvin_scale_out.bits.last) || (mvin_scale_acc_out.valid && mvin_scale_acc_out.bits.last))
   !((mvin_scale_pixel_repeater.io.resp.valid && mvin_scale_pixel_repeater.io.resp.bits.last) || (mvin_scale_acc_out.valid && mvin_scale_acc_out.bits.last))

-bio.write.en := exwrite || dmaread || zerowrite
+bio.write.valid := exwrite || dmaread || zerowrite

 when (exwrite) {
   bio.write.addr := io.srams.write(i).addr
@@ ... @@
   bio.write.data := mvin_scale_pixel_repeater.io.resp.bits.out.asUInt
   bio.write.mask := mvin_scale_pixel_repeater.io.resp.bits.mask take ((spad_w / (aligned_to * 8)) max 1)

-  mvin_scale_pixel_repeater.io.resp.ready := true.B // TODO we combinationally couple valid and ready signals
+  mvin_scale_pixel_repeater.io.resp.ready := bio.write.ready
 }.elsewhen (zerowrite) {
   bio.write.addr := zero_writer_pixel_repeater.io.resp.bits.laddr.sp_row()
   bio.write.data := 0.U
   bio.write.mask := zero_writer_pixel_repeater.io.resp.bits.mask

-  zero_writer_pixel_repeater.io.resp.ready := true.B // TODO we combinationally couple valid and ready signals
+  zero_writer_pixel_repeater.io.resp.ready := bio.write.ready
 }.otherwise {
   bio.write.addr := DontCare
   bio.write.data := DontCare
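The Scratchpad change replaces the write port's bare en strobe with a valid/ready handshake plus a fire helper, so a bank that cannot take a write this cycle can push back. A condensed sketch of the new-style port and an always-ready SRAM sink; the names and widths here are illustrative, not Gemmini's:

import chisel3._
import chisel3.util._

// Hypothetical write port mirroring the reworked ScratchpadWriteIO:
// the requester drives valid, the bank answers with ready.
class WritePort(n: Int, w: Int) extends Bundle {
  val valid = Output(Bool())
  val ready = Input(Bool())
  val addr  = Output(UInt(log2Ceil(n).W))
  val data  = Output(UInt(w.W))
  def fire: Bool = valid && ready
}

class SimpleBank(n: Int, w: Int) extends Module {
  val io = IO(Flipped(new WritePort(n, w))) // bank-side view of the port

  val mem = SyncReadMem(n, UInt(w.W))
  // A plain SRAM can always accept a write, so ready is tied high,
  // mirroring `io.write.ready := true.B` in the valid-only path above;
  // the shared-ext-mem path instead forwards its write queue's enq.ready.
  io.ready := true.B
  when (io.fire) {
    mem.write(io.addr, io.data)
  }
}

On the Scratchpad side, bio.write.ready is forwarded both to the ExecuteController (io.srams.write(i).ready) and to the mvin/zero-writer response queues, replacing the old unconditional ready := true.B coupling.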
