diff --git a/src/main/scala/gemmini/ExecuteController.scala b/src/main/scala/gemmini/ExecuteController.scala index 6c01422d..0f463557 100644 --- a/src/main/scala/gemmini/ExecuteController.scala +++ b/src/main/scala/gemmini/ExecuteController.scala @@ -7,6 +7,7 @@ import GemminiISA._ import Util._ import org.chipsalliance.cde.config.Parameters import midas.targetutils.PerfCounter +import freechips.rocketchip.util.ClockGate // TODO do we still need to flush when the dataflow is weight stationary? Won't the result just keep travelling through on its own? class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: Int, config: GemminiArrayConfig[T, U, V]) @@ -182,9 +183,14 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In val cntl_valid = mesh_cntl_signals_q.io.deq.valid val cntl = mesh_cntl_signals_q.io.deq.bits - // Instantiate the actual mesh - val mesh = Module(new MeshWithDelays(spatialArrayInputType, spatialArrayWeightType, spatialArrayOutputType, accType, mesh_tag, dataflow, tree_reduction, tile_latency, mesh_output_delay, - tileRows, tileColumns, meshRows, meshColumns, shifter_banks, shifter_banks)) + val sram_write_ready = Wire(Bool()) + val gated_clock = ClockGate(clock, sram_write_ready, "mesh_stall_gate") + + val mesh = withClock(gated_clock) { + // Instantiate the actual mesh + Module(new MeshWithDelays(spatialArrayInputType, spatialArrayWeightType, spatialArrayOutputType, accType, mesh_tag, dataflow, tree_reduction, tile_latency, mesh_output_delay, + tileRows, tileColumns, meshRows, meshColumns, shifter_banks, shifter_banks)) + } mesh.io.a.valid := false.B mesh.io.b.valid := false.B @@ -747,6 +753,9 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In val first = Bool() } + sram_write_ready := VecInit(io.srams.write.map(_.ready)).reduceTree(_ && _) + mesh.io.resp.ready := sram_write_ready + mesh_cntl_signals_q.io.enq.valid := computing mesh_cntl_signals_q.io.enq.bits.perform_mul_pre := performing_mul_pre @@ -811,7 +820,7 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In mesh_cntl_signals_q.io.deq.ready := (!cntl.a_fire || mesh.io.a.fire || !mesh.io.a.ready) && (!cntl.b_fire || mesh.io.b.fire || !mesh.io.b.ready) && (!cntl.d_fire || mesh.io.d.fire || !mesh.io.d.ready) && - (!cntl.first || mesh.io.req.ready) + (!cntl.first || mesh.io.req.ready) // && sram_write_ready val dataA_valid = cntl.a_garbage || cntl.a_unpadded_cols === 0.U || Mux(cntl.im2colling, im2ColValid, Mux(cntl.a_read_from_acc, accReadValid(cntl.a_bank_acc), readValid(cntl.a_bank))) @@ -835,7 +844,7 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In val dataA = VecInit(dataA_unpadded.asTypeOf(Vec(block_size, inputType)).zipWithIndex.map { case (d, i) => Mux(i.U < cntl.a_unpadded_cols, d, inputType.zero)}.map(d => d.asTypeOf(inputType).withWidthOf(spatialArrayInputType))) val dataB = VecInit(dataB_unpadded.asTypeOf(Vec(block_size, inputType)).zipWithIndex.map { case (d, i) => Mux(i.U < cntl.b_unpadded_cols, d, inputType.zero)}.map(d => d.asTypeOf(inputType).withWidthOf(spatialArrayWeightType))) - val dataD = VecInit(dataD_unpadded.asTypeOf(Vec(block_size, inputType)).zipWithIndex.map { case (d, i) => Mux(i.U < cntl.d_unpadded_cols, d, inputType.zero)}.map(d => d.asTypeOf(inputType).withWidthOf(spatialArrayWeightType))) + val dataD = VecInit(dataD_unpadded.asTypeOf(Vec(block_size, inputType)).zipWithIndex.map { case (d, i) => Mux(i.U < cntl.d_unpadded_cols, d, inputType.zero)}) // Pop responses off the scratchpad io ports when (mesh_cntl_signals_q.io.deq.fire) { @@ -883,7 +892,10 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In mesh.io.b.bits := dataB.asTypeOf(Vec(meshColumns, Vec(tileColumns, spatialArrayWeightType))) mesh.io.d.bits := dataD.asTypeOf(Vec(meshColumns, Vec(tileColumns, spatialArrayWeightType))) - mesh.io.req.valid := mesh_cntl_signals_q.io.deq.fire && (cntl.a_fire || cntl.b_fire || cntl.d_fire) + // gate this req valid + // gate mesh control signal fires + // gate control a b d fires + mesh.io.req.valid := mesh_cntl_signals_q.io.deq.fire && (cntl.a_fire || cntl.b_fire || cntl.d_fire) // && sram_write_ready mesh.io.req.bits.tag.addr := cntl.c_addr @@ -934,12 +946,13 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In }))) if (ex_write_to_spad) { - io.srams.write(i).en := start_array_outputting && w_bank === i.U && !write_to_acc && !is_garbage_addr && write_this_row + io.srams.write(i).valid := start_array_outputting && w_bank === i.U && !write_to_acc && !is_garbage_addr && write_this_row + // assert(io.srams.write(i).ready || !io.srams.write(i).valid) io.srams.write(i).addr := w_row io.srams.write(i).data := activated_wdata.asUInt io.srams.write(i).mask := w_mask.flatMap(b => Seq.fill(inputType.getWidth / (aligned_to * 8))(b)) } else { - io.srams.write(i).en := false.B + io.srams.write(i).valid := false.B io.srams.write(i).addr := DontCare io.srams.write(i).data := DontCare io.srams.write(i).mask := DontCare diff --git a/src/main/scala/gemmini/MeshWithDelays.scala b/src/main/scala/gemmini/MeshWithDelays.scala index 7ac4e934..7c91c7fa 100644 --- a/src/main/scala/gemmini/MeshWithDelays.scala +++ b/src/main/scala/gemmini/MeshWithDelays.scala @@ -62,7 +62,7 @@ class MeshWithDelays[T <: Data: Arithmetic, U <: TagQueueTag with Data] val req = Flipped(Decoupled(new MeshWithDelaysReq(accType, tagType.cloneType, block_size))) - val resp = Valid(new MeshWithDelaysResp(outputType, meshColumns, tileColumns, block_size, tagType.cloneType)) + val resp = Decoupled(new MeshWithDelaysResp(outputType, meshColumns, tileColumns, block_size, tagType.cloneType)) val tags_in_progress = Output(Vec(tagqlen, tagType)) }) @@ -245,7 +245,7 @@ class MeshWithDelays[T <: Data: Arithmetic, U <: TagQueueTag with Data] total_rows_q.io.deq.ready := io.resp.valid && io.resp.bits.last && out_matmul_id === total_rows_q.io.deq.bits.id - io.req.ready := (!req.valid || last_fire) && tagq.io.enq.ready && total_rows_q.io.enq.ready + io.req.ready := (!req.valid || last_fire) && tagq.io.enq.ready && total_rows_q.io.enq.ready && io.resp.ready io.tags_in_progress := tagq.io.all.map(_.tag) when (reset.asBool) { diff --git a/src/main/scala/gemmini/Scratchpad.scala b/src/main/scala/gemmini/Scratchpad.scala index fc949c31..15d21d0d 100644 --- a/src/main/scala/gemmini/Scratchpad.scala +++ b/src/main/scala/gemmini/Scratchpad.scala @@ -89,10 +89,12 @@ class ScratchpadReadIO(val n: Int, val w: Int) extends Bundle { } class ScratchpadWriteIO(val n: Int, val w: Int, val mask_len: Int) extends Bundle { - val en = Output(Bool()) + val valid = Output(Bool()) + val ready = Input(Bool()) val addr = Output(UInt(log2Ceil(n).W)) val mask = Output(Vec(mask_len, Bool())) val data = Output(UInt(w.W)) + def fire = valid && ready } class ScratchpadBank(n: Int, w: Int, aligned_to: Int, single_ported: Boolean, use_shared_ext_mem: Boolean, is_dummy: Boolean) extends Module { @@ -115,7 +117,7 @@ class ScratchpadBank(n: Int, w: Int, aligned_to: Int, single_ported: Boolean, us val q = Module(new Queue(new ScratchpadReadResp(w), 1, true, true)) val q_will_be_empty = (q.io.count +& q.io.enq.fire) - q.io.deq.fire === 0.U // When the scratchpad is single-ported, the writes take precedence - val singleport_busy_with_write = single_ported.B && io.write.en + val singleport_busy_with_write = single_ported.B && io.write.fire if (is_dummy) { q.io.enq.valid := RegNext(ren) @@ -147,7 +149,8 @@ class ScratchpadBank(n: Int, w: Int, aligned_to: Int, single_ported: Boolean, us val wq = Module(new Queue(ext_mem.write_req.bits.cloneType, 4, pipe=true, flow=true)) ext_mem.write_req <> wq.io.deq - wq.io.enq.valid := io.write.en + wq.io.enq.valid := io.write.valid + io.write.ready := wq.io.enq.ready wq.io.enq.bits.addr := io.write.addr wq.io.enq.bits.data := io.write.data if (aligned_to >= w) { @@ -155,14 +158,14 @@ class ScratchpadBank(n: Int, w: Int, aligned_to: Int, single_ported: Boolean, us } else { wq.io.enq.bits.mask := io.write.mask.asUInt } - assert(wq.io.enq.ready || (!io.write.en), "TODO (richard): fix this if triggered") + // assert(wq.io.enq.ready || (!io.write.en), "TODO (richard): fix this if triggered") } else { // use valid only interface val mem = SyncReadMem(n, Vec(mask_len, mask_elem)) val raddr = io.read.req.bits.addr val rdata = if (single_ported) { - assert(!(ren && io.write.en)) - mem.read(raddr, ren && !io.write.en).asUInt + assert(!(ren && io.write.fire)) + mem.read(raddr, ren && !io.write.fire).asUInt } else { mem.read(raddr, ren).asUInt } @@ -172,7 +175,8 @@ class ScratchpadBank(n: Int, w: Int, aligned_to: Int, single_ported: Boolean, us io.read.req.ready := q_will_be_empty && !singleport_busy_with_write - when(io.write.en) { + io.write.ready := true.B + when(io.write.fire) { if (aligned_to >= w) mem.write(io.write.addr, io.write.data.asTypeOf(Vec(mask_len, mask_elem)), VecInit((~(0.U(mask_len.W))).asBools)) else @@ -502,7 +506,7 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, // TODO we tie the write dispatch queue's, and write issue queue's, ready and valid signals together here val dmawrite = write_dispatch_q.valid && write_norm_q.io.enq.ready && !write_dispatch_q.bits.laddr.is_garbage() && - !(bio.write.en && config.sp_singleported.B) && + !(bio.write.fire && config.sp_singleported.B) && !write_dispatch_q.bits.laddr.is_acc_addr && write_dispatch_q.bits.laddr.sp_bank() === i.U bio.read.req.valid := exread || dmawrite @@ -554,7 +558,8 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, // Writing to the SRAM banks bank_ios.zipWithIndex.foreach { case (bio, i) => - val exwrite = io.srams.write(i).en + val exwrite = io.srams.write(i).valid + io.srams.write(i).ready := bio.write.ready // val laddr = mvin_scale_out.bits.tag.addr.asTypeOf(local_addr_t) + mvin_scale_out.bits.row val laddr = mvin_scale_pixel_repeater.io.resp.bits.laddr @@ -572,7 +577,7 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, // !((mvin_scale_out.valid && mvin_scale_out.bits.last) || (mvin_scale_acc_out.valid && mvin_scale_acc_out.bits.last)) !((mvin_scale_pixel_repeater.io.resp.valid && mvin_scale_pixel_repeater.io.resp.bits.last) || (mvin_scale_acc_out.valid && mvin_scale_acc_out.bits.last)) - bio.write.en := exwrite || dmaread || zerowrite + bio.write.valid := exwrite || dmaread || zerowrite when (exwrite) { bio.write.addr := io.srams.write(i).addr @@ -583,13 +588,13 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, bio.write.data := mvin_scale_pixel_repeater.io.resp.bits.out.asUInt bio.write.mask := mvin_scale_pixel_repeater.io.resp.bits.mask take ((spad_w / (aligned_to * 8)) max 1) - mvin_scale_pixel_repeater.io.resp.ready := true.B // TODO we combinationally couple valid and ready signals + mvin_scale_pixel_repeater.io.resp.ready := bio.write.ready }.elsewhen (zerowrite) { bio.write.addr := zero_writer_pixel_repeater.io.resp.bits.laddr.sp_row() bio.write.data := 0.U bio.write.mask := zero_writer_pixel_repeater.io.resp.bits.mask - zero_writer_pixel_repeater.io.resp.ready := true.B // TODO we combinationally couple valid and ready signals + zero_writer_pixel_repeater.io.resp.ready := bio.write.ready }.otherwise { bio.write.addr := DontCare bio.write.data := DontCare