diff --git a/build.sbt b/build.sbt index 6792d109..9fca7f68 100644 --- a/build.sbt +++ b/build.sbt @@ -4,7 +4,7 @@ version := "0.8-SNAPSHOT" name := "essent" -scalaVersion := "2.12.18" +scalaVersion := "2.13.12" scalacOptions ++= Seq("-deprecation", "-unchecked") @@ -16,12 +16,11 @@ libraryDependencies += "org.json4s" %% "json4s-native" % "3.6.12" libraryDependencies += "edu.berkeley.cs" %% "firrtl" % "1.5.6" - // Assembly -assemblyJarName in assembly := "essent.jar" +assembly / assemblyJarName := "essent.jar" -assemblyOutputPath in assembly := file("./utils/bin/essent.jar") +assembly / assemblyOutputPath:= file("./utils/bin/essent.jar") // Ignore disabled .scala files @@ -31,7 +30,7 @@ unmanagedSources / excludeFilter := HiddenFileFilter || "*disabled*.scala" // Publishing setup publishMavenStyle := true -publishArtifact in Test := false +Test / publishArtifact := false pomIncludeRepository := { x => false } // POM info diff --git a/project/build.properties b/project/build.properties index c8fcab54..27430827 100644 --- a/project/build.properties +++ b/project/build.properties @@ -1 +1 @@ -sbt.version=1.6.2 +sbt.version=1.9.6 diff --git a/src/main/scala/ActivityTracker.scala b/src/main/scala/ActivityTracker.scala index a3abc22b..cf740057 100644 --- a/src/main/scala/ActivityTracker.scala +++ b/src/main/scala/ActivityTracker.scala @@ -31,7 +31,7 @@ class ActivityTracker(w: Writer, opt: OptFlags) { } def declareSigTracking(sg: StatementGraph, topName: String): Unit = { - val allNamesAndTypes = sg.collectValidStmts(sg.nodeRange) flatMap findStmtNameAndType + val allNamesAndTypes = sg.collectValidStmts(sg.nodeRange()) flatMap findStmtNameAndType sigNameToID = (allNamesAndTypes map { _._1 }).zipWithIndex.toMap diff --git a/src/main/scala/AcyclicPart.scala b/src/main/scala/AcyclicPart.scala index aa635435..a7fb4123 100644 --- a/src/main/scala/AcyclicPart.scala +++ b/src/main/scala/AcyclicPart.scala @@ -35,14 +35,14 @@ class AcyclicPart(val mg: MergeGraph, excludeSet: Set[NodeID]) extends LazyLoggi totalInDegree + totalOutDegree - (mergedInDegree + mergedOutDegree) } - def coarsenWithMFFCs() { + def coarsenWithMFFCs(): Unit = { val mffcResults = MFFC(mg, excludeSet) mg.applyInitialAssignments(mffcResults) logger.info(s" #mffcs found: ${mg.mergeIDToMembers.size - excludeSet.size}") logger.info(s" largest mffc: ${(mg.mergeIDToMembers.values.map{_.size}).max}") } - def mergeSingleInputPartsIntoParents(smallPartCutoff: Int = 20) { + def mergeSingleInputPartsIntoParents(smallPartCutoff: Int = 20): Unit = { val smallPartIDs = findSmallParts(smallPartCutoff) val singleInputIDs = smallPartIDs filter { id => (mg.inNeigh(id).size == 1) } val singleInputParents = (singleInputIDs flatMap mg.inNeigh).distinct @@ -57,7 +57,7 @@ class AcyclicPart(val mg: MergeGraph, excludeSet: Set[NodeID]) extends LazyLoggi mergeSingleInputPartsIntoParents(smallPartCutoff) } - def mergeSmallSiblings(smallPartCutoff: Int = 10) { + def mergeSmallSiblings(smallPartCutoff: Int = 10): Unit = { val smallPartIDs = findSmallParts(smallPartCutoff) val inputsAndIDPairs = smallPartIDs map { id => { val inputsCanonicalized = mg.inNeigh(id).toSeq.sorted @@ -75,11 +75,11 @@ class AcyclicPart(val mg: MergeGraph, excludeSet: Set[NodeID]) extends LazyLoggi } } - def mergeSmallParts(smallPartCutoff: Int = 20, mergeThreshold: Double = 0.5) { + def mergeSmallParts(smallPartCutoff: Int = 20, mergeThreshold: Double = 0.5): Unit = { val smallPartIDs = findSmallParts(smallPartCutoff) val mergesToConsider = smallPartIDs flatMap { id => { val numInputs = mg.inNeigh(id).size.toDouble - val siblings = (mg.inNeigh(id) flatMap mg.outNeigh).distinct - id + val siblings = (mg.inNeigh(id) flatMap mg.outNeigh).distinct.filter(_ != id) val legalSiblings = siblings filter { sibID => !excludeSet.contains(sibID) } val orderConstrSibs = legalSiblings filter { _ < id } val myInputSet = mg.inNeigh(id).toSet @@ -101,7 +101,7 @@ class AcyclicPart(val mg: MergeGraph, excludeSet: Set[NodeID]) extends LazyLoggi } } - def mergeSmallPartsDown(smallPartCutoff: Int = 20) { + def mergeSmallPartsDown(smallPartCutoff: Int = 20): Unit = { val smallPartIDs = findSmallParts(smallPartCutoff) val mergesToConsider = smallPartIDs flatMap { id => { val mergeableChildren = mg.outNeigh(id) filter { @@ -122,7 +122,7 @@ class AcyclicPart(val mg: MergeGraph, excludeSet: Set[NodeID]) extends LazyLoggi } } - def partition(smallPartCutoff: Int = 20) { + def partition(smallPartCutoff: Int = 20): Unit = { val toApply = Seq( ("mffc", {ap: AcyclicPart => ap.coarsenWithMFFCs()}), ("single", {ap: AcyclicPart => ap.mergeSingleInputPartsIntoParents()}), @@ -141,16 +141,16 @@ class AcyclicPart(val mg: MergeGraph, excludeSet: Set[NodeID]) extends LazyLoggi assert(checkPartioning()) } - def iterParts() = mg.iterGroups + def iterParts() = mg.iterGroups() def checkPartioning() = { val includedSoFar = HashSet[NodeID]() - val disjoint = mg.iterGroups forall { case (macroID, memberIDs) => { + val disjoint = mg.iterGroups() forall { case (macroID, memberIDs) => { val overlap = includedSoFar.intersect(memberIDs.toSet).nonEmpty includedSoFar ++= memberIDs !overlap }} - val complete = includedSoFar == mg.nodeRange.toSet + val complete = includedSoFar == mg.nodeRange().toSet disjoint && complete } } diff --git a/src/main/scala/ArgsParser.scala b/src/main/scala/ArgsParser.scala index 3c463cb0..5768e5ec 100644 --- a/src/main/scala/ArgsParser.scala +++ b/src/main/scala/ArgsParser.scala @@ -23,7 +23,7 @@ case class OptFlags( essentLogLevel: String = "warn", firrtlLogLevel: String = "warn") { def inputFileDir() = firInputFile.getParent - def outputDir() = if (inputFileDir == null) "" else inputFileDir() + def outputDir() = if (inputFileDir() == null) "" else inputFileDir() } class ArgsParser { diff --git a/src/main/scala/Compiler.scala b/src/main/scala/Compiler.scala index ea12704f..53a0c6fc 100644 --- a/src/main/scala/Compiler.scala +++ b/src/main/scala/Compiler.scala @@ -17,12 +17,12 @@ import logger._ class EssentEmitter(initialOpt: OptFlags, w: Writer, circuit: Circuit) extends LazyLogging { val flagVarName = "PARTflags" - implicit val rn = new Renamer + implicit val rn: Renamer = new Renamer val actTrac = new ActivityTracker(w, initialOpt) - val vcd = if (initialOpt.withVCD) Some(new Vcd(circuit,initialOpt,w,rn)) else None + val vcd: Option[Vcd] = if (initialOpt.withVCD) Some(new Vcd(circuit,initialOpt,w,rn)) else None // Declaring Modules //---------------------------------------------------------------------------- - def declareModule(m: Module, topName: String) { + def declareModule(m: Module, topName: String): Unit = { val registers = findInstancesOf[DefRegister](m.body) val memories = findInstancesOf[DefMemory](m.body) val registerDecs = registers flatMap {d: DefRegister => { @@ -57,7 +57,7 @@ class EssentEmitter(initialOpt: OptFlags, w: Writer, circuit: Circuit) extends L } } - def declareExtModule(m: ExtModule) { + def declareExtModule(m: ExtModule): Unit = { val modName = m.name w.writeLines(0, "") w.writeLines(0, s"typedef struct $modName {") @@ -70,8 +70,8 @@ class EssentEmitter(initialOpt: OptFlags, w: Writer, circuit: Circuit) extends L //---------------------------------------------------------------------------- // TODO: move specialized CondMux emitter elsewhere? def writeBodyInner(indentLevel: Int, sg: StatementGraph, opt: OptFlags, - keepAvail: Set[String] = Set()) { - sg.stmtsOrdered foreach { stmt => stmt match { + keepAvail: Set[String] = Set()): Unit = { + sg.stmtsOrdered() foreach { stmt => stmt match { case cm: CondMux => { if (rn.nameToMeta(cm.name).decType == MuxOut) w.writeLines(indentLevel, s"${genCppType(cm.mux.tpe)} ${rn.emit(cm.name)};") @@ -91,8 +91,8 @@ class EssentEmitter(initialOpt: OptFlags, w: Writer, circuit: Circuit) extends L }} } - def checkRegResetSafety(sg: StatementGraph) { - val updatesWithResets = sg.allRegDefs filter { r => emitExpr(r.reset) != "UInt<1>(0x0)" } + def checkRegResetSafety(sg: StatementGraph): Unit = { + val updatesWithResets = sg.allRegDefs() filter { r => emitExpr(r.reset) != "UInt<1>(0x0)" } assert(updatesWithResets.isEmpty) } @@ -117,7 +117,7 @@ class EssentEmitter(initialOpt: OptFlags, w: Writer, circuit: Circuit) extends L condPartWorker: MakeCondPart, topName: String, extIOtypes: Map[String, Type], - opt: OptFlags) { + opt: OptFlags): Unit = { // predeclare part outputs val outputPairs = condPartWorker.getPartOutputsToDeclare() val outputConsumers = condPartWorker.getPartInputMap() @@ -134,7 +134,7 @@ class EssentEmitter(initialOpt: OptFlags, w: Writer, circuit: Circuit) extends L w.writeLines(1, s"bool done_reset;") w.writeLines(1, s"bool verbose;") w.writeLines(0, "") - sg.stmtsOrdered foreach { stmt => stmt match { + sg.stmtsOrdered() foreach { stmt => stmt match { case cp: CondPart => { w.writeLines(1, s"void ${genEvalFuncName(cp.id)}() {") if (!cp.alwaysActive) @@ -171,7 +171,7 @@ class EssentEmitter(initialOpt: OptFlags, w: Writer, circuit: Circuit) extends L w.writeLines(0, "") } - def writeZoningBody(sg: StatementGraph, condPartWorker: MakeCondPart, opt: OptFlags) { + def writeZoningBody(sg: StatementGraph, condPartWorker: MakeCondPart, opt: OptFlags): Unit = { w.writeLines(2, "if (reset || !done_reset) {") w.writeLines(3, "sim_cached = false;") w.writeLines(3, "regs_set = false;") @@ -192,7 +192,7 @@ class EssentEmitter(initialOpt: OptFlags, w: Writer, circuit: Circuit) extends L sigName => s"${rn.emit(sigName + condPartWorker.cacheSuffix)} = ${rn.emit(sigName)};" } w.writeLines(2, extIOCaches.toSeq) - sg.stmtsOrdered foreach { stmt => stmt match { + sg.stmtsOrdered() foreach { stmt => stmt match { case cp: CondPart => { if (!cp.alwaysActive) w.writeLines(2, s"if (UNLIKELY($flagVarName[${cp.id}])) ${genEvalFuncName(cp.id)}();") @@ -210,7 +210,7 @@ class EssentEmitter(initialOpt: OptFlags, w: Writer, circuit: Circuit) extends L // General Structure (and Compiler Boilerplate) //---------------------------------------------------------------------------- - def execute(circuit: Circuit) { + def execute(circuit: Circuit): Unit = { val opt = initialOpt val topName = circuit.main val headerGuardName = topName.toUpperCase + "_H_" @@ -234,7 +234,7 @@ class EssentEmitter(initialOpt: OptFlags, w: Writer, circuit: Circuit) extends L w.writeLines(1,s"""char VCD_BUF[2000];""") } val sg = StatementGraph(circuit, opt.removeFlatConnects) - logger.info(sg.makeStatsString) + logger.info(sg.makeStatsString()) val containsAsserts = sg.containsStmtOfType[Stop]() val extIOMap = findExternalPorts(circuit) val condPartWorker = MakeCondPart(sg, rn, extIOMap) @@ -332,10 +332,10 @@ class EssentCompiler(opt: OptFlags) { Dependency(essent.passes.ReplaceRsvdKeywords) ) - def compileAndEmit(circuit: Circuit) { + def compileAndEmit(circuit: Circuit): Unit = { val topName = circuit.main if (opt.writeHarness) { - val harnessFilename = new File(opt.outputDir, s"$topName-harness.cc") + val harnessFilename = new File(opt.outputDir(), s"$topName-harness.cc") val harnessWriter = new FileWriter(harnessFilename) if (opt.withVCD) { HarnessGenerator.topFile(topName, harnessWriter," | dut.genWaveHeader();") } else { HarnessGenerator.topFile(topName, harnessWriter, "")} @@ -344,11 +344,11 @@ class EssentCompiler(opt: OptFlags) { val firrtlCompiler = new transforms.Compiler(readyForEssent) val resultState = firrtlCompiler.execute(CircuitState(circuit, Seq())) if (opt.dumpLoFirrtl) { - val debugWriter = new FileWriter(new File(opt.outputDir, s"$topName.lo.fir")) + val debugWriter = new FileWriter(new File(opt.outputDir(), s"$topName.lo.fir")) debugWriter.write(resultState.circuit.serialize) debugWriter.close() } - val dutWriter = new FileWriter(new File(opt.outputDir, s"$topName.h")) + val dutWriter = new FileWriter(new File(opt.outputDir(), s"$topName.h")) val emitter = new EssentEmitter(opt, dutWriter,resultState.circuit) emitter.execute(resultState.circuit) dutWriter.close() diff --git a/src/main/scala/Driver.scala b/src/main/scala/Driver.scala index abe62c2e..e9e2aa40 100644 --- a/src/main/scala/Driver.scala +++ b/src/main/scala/Driver.scala @@ -7,18 +7,18 @@ import logger._ object Driver { - def main(args: Array[String]) { - (new ArgsParser).getConfig(args) match { + def main(args: Array[String]): Unit = { + (new ArgsParser).getConfig(args.toSeq) match { case Some(config) => generate(config) case None => } } - def generate(opt: OptFlags) { + def generate(opt: OptFlags): Unit = { Logger.setClassLogLevels(Map("essent" -> logger.LogLevel(opt.essentLogLevel))) Logger.setClassLogLevels(Map("firrtl" -> logger.LogLevel(opt.firrtlLogLevel))) val sourceReader = Source.fromFile(opt.firInputFile) - val circuit = firrtl.Parser.parse(sourceReader.getLines, firrtl.Parser.IgnoreInfo) + val circuit = firrtl.Parser.parse(sourceReader.getLines(), firrtl.Parser.IgnoreInfo) sourceReader.close() val compiler = new EssentCompiler(opt) compiler.compileAndEmit(circuit) diff --git a/src/main/scala/Emitter.scala b/src/main/scala/Emitter.scala index f5bbb24f..27ee45b9 100644 --- a/src/main/scala/Emitter.scala +++ b/src/main/scala/Emitter.scala @@ -224,7 +224,7 @@ object Emitter { val printWidth = math.ceil(width.toDouble/4).toInt (format, s"""%0${printWidth}" PRIx64 """") } else { - val printWidth = math.ceil(math.log10((1l< Seq(HyperedgeDep(emitExpr(ru.regRef)+"$final", findDependencesExpr(ru.expr), s)) case mw: MemWrite => val deps = Seq(mw.wrEn, mw.wrMask, mw.wrAddr, mw.wrData) flatMap findDependencesExpr - Seq(HyperedgeDep(mw.nodeName, deps.distinct, s)) + Seq(HyperedgeDep(mw.nodeName(), deps.distinct, s)) case p: Print => val deps = (Seq(p.en) ++ p.args) flatMap findDependencesExpr val uniqueName = "PRINTF" + emitExpr(p.clk) + deps.mkString("$") + Util.tidyString(p.string.serialize) @@ -216,7 +216,7 @@ object Extract extends LazyLogging { namesToExclude: Set[String]): Seq[Statement] = { def isRef(e: Expression): Boolean = e.isInstanceOf[WRef] || e.isInstanceOf[WSubField] def findChainRenames(sg: StatementGraph): Map[String, String] = { - val sourceIDs = sg.nodeRange filter { sg.inNeigh(_).isEmpty } + val sourceIDs = sg.nodeRange() filter { sg.inNeigh(_).isEmpty } def reachableIDs(id: Int): Seq[Int] = { Seq(id) ++ (sg.outNeigh(id) flatMap reachableIDs) } diff --git a/src/main/scala/Graph.scala b/src/main/scala/Graph.scala index 7cdc062f..f2d51348 100644 --- a/src/main/scala/Graph.scala +++ b/src/main/scala/Graph.scala @@ -20,7 +20,7 @@ class Graph { // Graph building //---------------------------------------------------------------------------- - def growNeighsIfNeeded(id: NodeID) { + def growNeighsIfNeeded(id: NodeID): Unit = { assert(id >= 0) if (id >= outNeigh.size) { val numElemsToGrow = id - outNeigh.size + 1 @@ -29,13 +29,13 @@ class Graph { } } - def addEdge(sourceID: NodeID, destID: NodeID) { + def addEdge(sourceID: NodeID, destID: NodeID): Unit = { growNeighsIfNeeded(math.max(sourceID, destID)) outNeigh(sourceID) += destID inNeigh(destID) += sourceID } - def addEdgeIfNew(sourceID: NodeID, destID: NodeID) { + def addEdgeIfNew(sourceID: NodeID, destID: NodeID): Unit = { if ((sourceID >= outNeigh.size) || !outNeigh(sourceID).contains(destID)) addEdge(sourceID, destID) } @@ -79,16 +79,16 @@ class Graph { // Mutators //---------------------------------------------------------------------------- - def removeDuplicateEdges() { + def removeDuplicateEdges(): Unit = { // will not remove self-loops - def uniquifyNeighs(neighs: AdjacencyList) { - (0 until neighs.size) foreach { id => neighs(id) = neighs(id).distinct } + def uniquifyNeighs(neighs: AdjacencyList): Unit = { + neighs.indices foreach { id => neighs(id) = neighs(id).distinct } } uniquifyNeighs(outNeigh) uniquifyNeighs(inNeigh) } - def mergeNodesMutably(mergeDest: NodeID, mergeSources: Seq[NodeID]) { + def mergeNodesMutably(mergeDest: NodeID, mergeSources: Seq[NodeID]): Unit = { val mergedID = mergeDest val idsToRemove = mergeSources val idsToMerge = mergeSources :+ mergeDest @@ -103,8 +103,8 @@ class Graph { inNeigh(outNeighID) --= idsToRemove if (!inNeigh(outNeighID).contains(mergedID)) inNeigh(outNeighID) += mergedID }} - inNeigh(mergedID) = combinedInNeigh.to[ArrayBuffer] - outNeigh(mergedID) = combinedOutNeigh.to[ArrayBuffer] + inNeigh(mergedID) = combinedInNeigh.to(ArrayBuffer) + outNeigh(mergedID) = combinedOutNeigh.to(ArrayBuffer) idsToRemove foreach { deleteID => { inNeigh(deleteID).clear() outNeigh(deleteID).clear() diff --git a/src/main/scala/Harness.scala b/src/main/scala/Harness.scala index 5675f061..16bc2c41 100644 --- a/src/main/scala/Harness.scala +++ b/src/main/scala/Harness.scala @@ -4,6 +4,7 @@ import firrtl._ import firrtl.ir._ import java.io.Writer +import scala.collection.mutable.ArrayBuffer object HarnessGenerator { def harnessConnections(m: Module) = { @@ -54,9 +55,9 @@ object HarnessGenerator { val mapConnects = (internalNames.zipWithIndex) map { case (label: String, index: Int) => s"""comm->map_signal("$modName.$label", $index);""" } - (origOrderInputNames ++ reorderPorts(inputNames) map connectSignal("in_")) ++ - (reorderPorts(outputNames) map connectSignal("out_")) ++ - (reorderPorts(signalNames) map connectSignal("")) ++ mapConnects + ((origOrderInputNames ++ reorderPorts(inputNames.toSeq) map {connectSignal("in_")(_)}) ++ + (reorderPorts(outputNames.toSeq) map {connectSignal("out_")(_)}) ++ + (reorderPorts(signalNames.toSeq) map {connectSignal("")(_)}) ++ mapConnects).toSeq } def topFile(circuitName: String, writer: Writer , vcdHeader: String) = { diff --git a/src/main/scala/MFFC.scala b/src/main/scala/MFFC.scala index 6d5407cc..f0ec9e8b 100644 --- a/src/main/scala/MFFC.scala +++ b/src/main/scala/MFFC.scala @@ -9,18 +9,18 @@ class MFFC(val g: Graph) { import MFFC.{Unclaimed,Excluded} // numeric vertex ID -> MFFC ID - val mffc = ArrayBuffer.fill(g.numNodes)(Unclaimed) + val mffc = ArrayBuffer.fill(g.numNodes())(Unclaimed) - def overrideMFFCs(newAssignments: ArrayBuffer[NodeID]) { + def overrideMFFCs(newAssignments: ArrayBuffer[NodeID]): Unit = { mffc.clear() - newAssignments.copyToBuffer(mffc) + mffc ++= newAssignments } def findMFFCs(): ArrayBuffer[NodeID] = { - val unvisitedSinks = g.nodeRange filter { + val unvisitedSinks = g.nodeRange() filter { id => mffc(id) == Unclaimed && g.outNeigh(id).isEmpty } - val visited = g.nodeRange filter { id => mffc(id) != Unclaimed } + val visited = g.nodeRange() filter { id => mffc(id) != Unclaimed } val fringe = (visited flatMap(g.inNeigh)).distinct val unvisitedFringe = fringe filter { mffc(_) == Unclaimed } val newMFFCseeds = unvisitedSinks.toSet ++ unvisitedFringe @@ -33,7 +33,7 @@ class MFFC(val g: Graph) { } } - def maximizeFFCs(fringe: Set[NodeID]) { + def maximizeFFCs(fringe: Set[NodeID]): Unit = { val fringeAncestors = fringe flatMap g.inNeigh filter { mffc(_) == Unclaimed } val newMembers = fringeAncestors flatMap { parent => { val childrenMFFCs = (g.outNeigh(parent) map mffc).distinct diff --git a/src/main/scala/MergeGraph.scala b/src/main/scala/MergeGraph.scala index 3e64e4cd..74be9040 100644 --- a/src/main/scala/MergeGraph.scala +++ b/src/main/scala/MergeGraph.scala @@ -19,23 +19,23 @@ class MergeGraph extends Graph { // inherits outNeigh and inNeigh from Graph - def buildFromGraph(g: Graph) { + def buildFromGraph(g: Graph): Unit = { // FUTURE: cleaner way to do this with clone on superclass? - outNeigh.appendAll(ArrayBuffer.fill(g.numNodes)(ArrayBuffer[NodeID]())) - inNeigh.appendAll(ArrayBuffer.fill(g.numNodes)(ArrayBuffer[NodeID]())) - g.nodeRange foreach { id => { - g.outNeigh(id).copyToBuffer(outNeigh(id)) - g.inNeigh(id).copyToBuffer(inNeigh(id)) + outNeigh.appendAll(ArrayBuffer.fill(g.numNodes())(ArrayBuffer[NodeID]())) + inNeigh.appendAll(ArrayBuffer.fill(g.numNodes())(ArrayBuffer[NodeID]())) + g.nodeRange() foreach { id => { + outNeigh(id) ++= g.outNeigh(id) + inNeigh(id) ++= g.inNeigh(id) }} - ArrayBuffer.range(0, numNodes()).copyToBuffer(idToMergeID) + idToMergeID ++= ArrayBuffer.range(0, numNodes()) nodeRange() foreach { id => mergeIDToMembers(id) = Seq(id) } } - def applyInitialAssignments(initialAssignments: ArrayBuffer[NodeID]) { + def applyInitialAssignments(initialAssignments: ArrayBuffer[NodeID]): Unit = { // FUTURE: support negative (unassigned) initial assignments idToMergeID.clear() mergeIDToMembers.clear() - initialAssignments.copyToBuffer(idToMergeID) + idToMergeID ++= initialAssignments val asMap = Util.groupIndicesByValue(initialAssignments) asMap foreach { case (mergeID, members) => { assert(members.contains(mergeID)) @@ -44,8 +44,8 @@ class MergeGraph extends Graph { }} } - def mergeGroups(mergeDest: NodeID, mergeSources: Seq[NodeID]) { - val newMembers = (mergeSources map mergeIDToMembers).flatten + def mergeGroups(mergeDest: NodeID, mergeSources: Seq[NodeID]): Unit = { + val newMembers = mergeSources flatMap mergeIDToMembers newMembers foreach { id => idToMergeID(id) = mergeDest} mergeIDToMembers(mergeDest) ++= newMembers mergeSources foreach { id => mergeIDToMembers.remove(id) } diff --git a/src/main/scala/OptElideRegUpdates.scala b/src/main/scala/OptElideRegUpdates.scala index c0fbee56..a0cfc322 100644 --- a/src/main/scala/OptElideRegUpdates.scala +++ b/src/main/scala/OptElideRegUpdates.scala @@ -8,7 +8,7 @@ import firrtl.ir._ object OptElideRegUpdates extends LazyLogging { - def apply(sg: StatementGraph) { + def apply(sg: StatementGraph): Unit = { def safeToMergeWithParentNextNode(u: NodeID): Boolean = { sg.inNeigh(u).nonEmpty && // node u isn't floating (parentless) sg.idToName(sg.inNeigh(u).head).endsWith("$next") && // first parent assigns $next diff --git a/src/main/scala/OptMakeCondMux.scala b/src/main/scala/OptMakeCondMux.scala index 7bb2f280..da0b4674 100644 --- a/src/main/scala/OptMakeCondMux.scala +++ b/src/main/scala/OptMakeCondMux.scala @@ -34,7 +34,7 @@ class MakeCondMux(val sg: StatementGraph, rn: Renamer, keepAvail: Set[NodeID]) { sg.idToStmt(muxID) mapExpr replaceMux(muxWay) } - def makeCondMuxesTopDown(muxIDsRemaining: Set[NodeID], muxIDToWays: Map[NodeID,(Seq[NodeID],Seq[NodeID])]) { + def makeCondMuxesTopDown(muxIDsRemaining: Set[NodeID], muxIDToWays: Map[NodeID,(Seq[NodeID],Seq[NodeID])]): Unit = { val muxesWithMuxesInside = muxIDToWays collect { case (muxID, (tWay, fWay)) if ((tWay ++ fWay) exists muxIDsRemaining) => muxID } @@ -55,7 +55,7 @@ class MakeCondMux(val sg: StatementGraph, rn: Renamer, keepAvail: Set[NodeID]) { } } - def doOpt() { + def doOpt(): Unit = { val muxIDs = (sg.idToStmt.zipWithIndex collect { case (DefNode(_, _, m: Mux), id) => id case (Connect(_, _, m: Mux), id) => id @@ -79,7 +79,7 @@ class MakeCondMux(val sg: StatementGraph, rn: Renamer, keepAvail: Set[NodeID]) { object MakeCondMux { // FUTURE: pull mux chains into if else chains to reduce indent depth // FUTURE: consider mux size threshold - def apply(sg: StatementGraph, rn: Renamer, keepAvailNames: Set[String] = Set()) { + def apply(sg: StatementGraph, rn: Renamer, keepAvailNames: Set[String] = Set()): Unit = { val keepAvailIDs = keepAvailNames map sg.nameToID val optimizer = new MakeCondMux(sg, rn, keepAvailIDs) optimizer.doOpt() diff --git a/src/main/scala/OptMakeCondPart.scala b/src/main/scala/OptMakeCondPart.scala index c972c2b5..a1062b9f 100644 --- a/src/main/scala/OptMakeCondPart.scala +++ b/src/main/scala/OptMakeCondPart.scala @@ -16,8 +16,8 @@ class MakeCondPart(sg: StatementGraph, rn: Renamer, extIOtypes: Map[String, Type val alreadyDeclared = sg.stateElemNames().toSet - def convertIntoCPStmts(ap: AcyclicPart, excludedIDs: Set[NodeID]) { - val idToMemberIDs = ap.iterParts + def convertIntoCPStmts(ap: AcyclicPart, excludedIDs: Set[NodeID]): Unit = { + val idToMemberIDs = ap.iterParts() val idToMemberStmts = (idToMemberIDs map { case (id, members) => { val memberStmts = sg.idToStmt(id) match { case cp: CondPart => cp.memberStmts @@ -25,7 +25,7 @@ class MakeCondPart(sg: StatementGraph, rn: Renamer, extIOtypes: Map[String, Type } (id -> memberStmts) }}).toMap - val idToProducedOutputs = idToMemberStmts mapValues { _ flatMap findResultName } + val idToProducedOutputs = idToMemberStmts.view.mapValues { _ flatMap findResultName } val idToInputNames = idToMemberStmts map { case (id, memberStmts) => { val partDepNames = memberStmts flatMap findDependencesStmt flatMap { _.deps } val externalDepNames = partDepNames.toSet -- (idToProducedOutputs(id).toSet -- alreadyDeclared) @@ -66,16 +66,16 @@ class MakeCondPart(sg: StatementGraph, rn: Renamer, extIOtypes: Map[String, Type else { val newGroupID = matchingIDs.min val memberStmts = matchingIDs map sg.idToStmt - val tempCPstmt = CondPart(newGroupID, true, Seq(), memberStmts, Map()) - sg.mergeStmtsMutably(newGroupID, matchingIDs diff Seq(newGroupID), tempCPstmt) + val tempCPstmt = CondPart(newGroupID, true, Seq(), memberStmts.toSeq, Map()) + sg.mergeStmtsMutably(newGroupID, (matchingIDs diff Seq(newGroupID)).toSeq, tempCPstmt) Some(newGroupID) } } - def doOpt(smallPartCutoff: Int = 20) { + def doOpt(smallPartCutoff: Int = 20): Unit = { val excludedIDs = ArrayBuffer[Int]() clumpByStmtType[Print]() foreach { excludedIDs += _ } - excludedIDs ++= (sg.nodeRange filterNot sg.validNodes) + excludedIDs ++= (sg.nodeRange() filterNot sg.validNodes) val ap = AcyclicPart(sg, excludedIDs.toSet) ap.partition(smallPartCutoff) convertIntoCPStmts(ap, excludedIDs.toSet) @@ -97,7 +97,7 @@ class MakeCondPart(sg: StatementGraph, rn: Renamer, extIOtypes: Map[String, Type case cp: CondPart => cp.outputsToDeclare.toSeq case _ => Seq() }} - allPartOutputTypes + allPartOutputTypes.toSeq } def getExternalPartInputNames(): Seq[String] = { diff --git a/src/main/scala/Renamer.scala b/src/main/scala/Renamer.scala index 70abe07d..1929e6c9 100644 --- a/src/main/scala/Renamer.scala +++ b/src/main/scala/Renamer.scala @@ -20,9 +20,9 @@ class Renamer { val nameToEmitName = HashMap[String,String]() val nameToMeta = HashMap[String,SigMeta]() - def populateFromSG(sg: StatementGraph, extIOMap: Map[String,Type]) { - val stateNames = sg.stateElemNames.toSet - sg.nodeRange foreach { id => { + def populateFromSG(sg: StatementGraph, extIOMap: Map[String,Type]): Unit = { + val stateNames = sg.stateElemNames().toSet + sg.nodeRange() foreach { id => { val name = sg.idToName(id) val decType = if (stateNames.contains(name)) RegSet else if (extIOMap.contains(name)) ExtIO @@ -40,7 +40,7 @@ class Renamer { fixEmitNames() } - def fixEmitNames() { + def fixEmitNames(): Unit = { def shouldBeLocal(meta: SigMeta) = meta.decType match { case Local | MuxOut | PartOut => true case _ => false @@ -53,13 +53,13 @@ class Renamer { } } - def mutateDecTypeIfLocal(name: String, newDecType: SigDecType) { + def mutateDecTypeIfLocal(name: String, newDecType: SigDecType): Unit = { val currentMeta = nameToMeta(name) if (currentMeta.decType == Local) nameToMeta(name) = currentMeta.copy(decType = newDecType) } - def addPartCache(name: String, sigType: firrtl.ir.Type) { + def addPartCache(name: String, sigType: firrtl.ir.Type): Unit = { nameToEmitName(name) = removeDots(name) nameToMeta(name) = SigMeta(PartCache, sigType) } diff --git a/src/main/scala/StatementGraph.scala b/src/main/scala/StatementGraph.scala index c9638b40..f05699c9 100644 --- a/src/main/scala/StatementGraph.scala +++ b/src/main/scala/StatementGraph.scala @@ -47,16 +47,16 @@ class StatementGraph extends Graph { } } - def addEdge(sourceName: String, destName: String) { + def addEdge(sourceName: String, destName: String): Unit = { super.addEdge(getID(sourceName), getID(destName)) } - def addEdgeIfNew(sourceName: String, destName: String) { + def addEdgeIfNew(sourceName: String, destName: String): Unit = { super.addEdgeIfNew(getID(sourceName), getID(destName)) } def addStatementNode(resultName: String, depNames: Seq[String], - stmt: Statement = EmptyStmt) = { + stmt: Statement = EmptyStmt): Unit = { val potentiallyNewDestID = getID(resultName) depNames foreach {depName : String => addEdge(depName, resultName)} if (potentiallyNewDestID >= idToStmt.size) { @@ -69,7 +69,7 @@ class StatementGraph extends Graph { validNodes += potentiallyNewDestID } - def buildFromBodies(bodies: Seq[Statement]) { + def buildFromBodies(bodies: Seq[Statement]): Unit = { val bodyHE = bodies flatMap { case b: Block => b.stmts flatMap findDependencesStmt case s => findDependencesStmt(s) @@ -82,26 +82,26 @@ class StatementGraph extends Graph { //---------------------------------------------------------------------------- def collectValidStmts(ids: Seq[NodeID]): Seq[Statement] = ids filter validNodes map idToStmt - def stmtsOrdered(): Seq[Statement] = collectValidStmts(TopologicalSort(this)) + def stmtsOrdered(): Seq[Statement] = collectValidStmts(TopologicalSort(this).toSeq) def containsStmtOfType[T <: Statement]()(implicit tag: ClassTag[T]): Boolean = { (idToStmt collectFirst { case s: T => s }).isDefined } def findIDsOfStmtOfType[T <: Statement]()(implicit tag: ClassTag[T]): Seq[NodeID] = { - idToStmt.zipWithIndex collect { case (s: T , id: Int) => id } + (idToStmt.zipWithIndex collect { case (s: T , id: Int) => id }).toSeq } - def allRegDefs(): Seq[DefRegister] = idToStmt collect { + def allRegDefs(): Seq[DefRegister] = (idToStmt collect { case dr: DefRegister => dr - } + }).toSeq - def stateElemNames(): Seq[String] = idToStmt collect { + def stateElemNames(): Seq[String] = (idToStmt collect { case dr: DefRegister => dr.name case dm: DefMemory => dm.name - } + }).toSeq - def stateElemIDs() = findIDsOfStmtOfType[DefRegister] ++ findIDsOfStmtOfType[DefMemory] + def stateElemIDs() = findIDsOfStmtOfType[DefRegister]() ++ findIDsOfStmtOfType[DefMemory]() def mergeIsAcyclic(nameA: String, nameB: String): Boolean = { val idA = nameToID(nameA) @@ -114,8 +114,8 @@ class StatementGraph extends Graph { // Mutation //---------------------------------------------------------------------------- - def addOrderingDepsForStateUpdates() { - def addOrderingEdges(writerID: NodeID, readerTargetID: NodeID) { + def addOrderingDepsForStateUpdates(): Unit = { + def addOrderingEdges(writerID: NodeID, readerTargetID: NodeID): Unit = { outNeigh(readerTargetID) foreach { readerID => if (readerID != writerID) addEdgeIfNew(readerID, writerID) } @@ -132,7 +132,7 @@ class StatementGraph extends Graph { }} } - def mergeStmtsMutably(mergeDest: NodeID, mergeSources: Seq[NodeID], mergeStmt: Statement) { + def mergeStmtsMutably(mergeDest: NodeID, mergeSources: Seq[NodeID], mergeStmt: Statement): Unit = { val mergedID = mergeDest val idsToRemove = mergeSources idsToRemove foreach { id => idToStmt(id) = EmptyStmt } @@ -151,7 +151,7 @@ class StatementGraph extends Graph { def numNodeRefs() = idToName.size def makeStatsString() = - s"Graph has $numNodes nodes ($numValidNodes valid) and $numEdges edges" + s"Graph has ${numNodes()} nodes (${numValidNodes()} valid) and ${numEdges()} edges" } diff --git a/src/main/scala/TopologicalSort.scala b/src/main/scala/TopologicalSort.scala index d26c9831..f1616349 100644 --- a/src/main/scala/TopologicalSort.scala +++ b/src/main/scala/TopologicalSort.scala @@ -9,7 +9,7 @@ object TopologicalSort { val finalOrdering = ArrayBuffer[NodeID]() val inStack = BitSet() val finished = BitSet() - def visit(v: NodeID) { + def visit(v: NodeID): Unit = { if (inStack(v)) { findCyclesByTopoSort(g) match { case None => throw new Exception("Was a cycle but couldn't reproduce") @@ -26,7 +26,7 @@ object TopologicalSort { finalOrdering += v } } - g.nodeRange foreach { startingID => visit(startingID) } + g.nodeRange() foreach { startingID => visit(startingID) } finalOrdering } @@ -35,7 +35,7 @@ object TopologicalSort { var cycleFound: Option[Seq[NodeID]] = None val inStack = BitSet() val finished = BitSet() - val callerIDs = ArrayBuffer.fill(bg.numNodes)(-1) + val callerIDs = ArrayBuffer.fill(bg.numNodes())(-1) def backtrackToFindCycle(v: NodeID, cycleSoFar: Seq[NodeID]): Seq[NodeID] = { if (callerIDs(v) == -1) cycleSoFar @@ -48,7 +48,7 @@ object TopologicalSort { } } - def visit(v: NodeID, callerID: NodeID) { + def visit(v: NodeID, callerID: NodeID): Unit = { if (inStack(v)) { val cycle = backtrackToFindCycle(callerID, Seq(v)) cycleFound = Some(cycle) @@ -61,7 +61,7 @@ object TopologicalSort { inStack.remove(v) } } - bg.nodeRange foreach { startingID => visit(startingID, startingID) } + bg.nodeRange() foreach { startingID => visit(startingID, startingID) } cycleFound } } diff --git a/src/main/scala/Util.scala b/src/main/scala/Util.scala index bb40a08f..d445e795 100644 --- a/src/main/scala/Util.scala +++ b/src/main/scala/Util.scala @@ -7,12 +7,12 @@ import java.io.Writer object Util { // Given an array, returns a map of value to all indices that had that value (CAM-like) def groupIndicesByValue[T](a: ArrayBuffer[T]): Map[T, Seq[Int]] = { - a.zipWithIndex.groupBy{ _._1 }.mapValues{ v => v.toSeq map { _._2 }} + a.zipWithIndex.groupBy{ _._1 }.view.mapValues{ v => v.toSeq map { _._2 }}.toMap } // Given a list of pairs, returns a map of value of first element to all second values (CAM-like) def groupByFirst[T,Y](l: Seq[(T,Y)]): Map[T, Seq[Y]] = { - l.groupBy{ _._1 }.mapValues{ v => v map { _._2 }} + l.groupBy{ _._1 }.view.mapValues{ v => v map { _._2 }}.toMap } def selectFromMap[K,V](selectors: Seq[K], targetMap: Map[K,V]): Map[K,V] = { @@ -26,7 +26,7 @@ object Util { str filter { !charsToRemove.contains(_) } } - def sortHashMapValues[K](hm: HashMap[K,Seq[Int]]) { + def sortHashMapValues[K](hm: HashMap[K,Seq[Int]]): Unit = { hm.keys foreach { k => hm(k) = hm(k).sorted } } diff --git a/src/main/scala/Vcd.scala b/src/main/scala/Vcd.scala index cefdabce..c4b6d141 100644 --- a/src/main/scala/Vcd.scala +++ b/src/main/scala/Vcd.scala @@ -18,11 +18,11 @@ class Vcd(circuit: Circuit, initopt: OptFlags, w: Writer, rn: Renamer) { val opt = initopt val topName = circuit.main val sg = StatementGraph(circuit, opt.removeFlatConnects) - val allNamesAndTypes = sg.stmtsOrdered flatMap findStmtNameAndType + val allNamesAndTypes = sg.stmtsOrdered() flatMap findStmtNameAndType var hashMap = HashMap[String,String]() var last_used_index = BigInt(1) - def displayNameIdentifierSize(m: Module, topName: String) { + def displayNameIdentifierSize(m: Module, topName: String): Unit = { val registers = findInstancesOf[DefRegister](m.body) val memories = findInstancesOf[DefMemory](m.body) var depth = 0 @@ -71,7 +71,7 @@ class Vcd(circuit: Circuit, initopt: OptFlags, w: Writer, rn: Renamer) { } } - def declareOldValues(m: Module) { + def declareOldValues(m: Module): Unit = { val registers = findInstancesOf[DefRegister](m.body) val registerDecs = registers map { r: DefRegister => s"${genCppType(r.tpe)} ${rn.vcdOldValue(r.name)};" @@ -88,7 +88,7 @@ class Vcd(circuit: Circuit, initopt: OptFlags, w: Writer, rn: Renamer) { w.writeLines(1, portDecs) } - def compareOldNewSignal(m: Module) { + def compareOldNewSignal(m: Module): Unit = { val registers = findInstancesOf[DefRegister](m.body) val registerComps = registers map { r: DefRegister => compSig(r.name, rn.vcdOldValue(r.name)) @@ -106,7 +106,7 @@ class Vcd(circuit: Circuit, initopt: OptFlags, w: Writer, rn: Renamer) { w.writeLines(2, portComps) } - def assignOldValue(m: Module) { + def assignOldValue(m: Module): Unit = { val registers = findInstancesOf[DefRegister](m.body) val registerAssigns = registers map { r: DefRegister => s"${rn.vcdOldValue(r.name)} = ${r.name};" @@ -128,7 +128,7 @@ class Vcd(circuit: Circuit, initopt: OptFlags, w: Writer, rn: Renamer) { } //function for vcd multiple hierarchy - def hierScope(allNamesAndTypes: Seq[(String, Type)],splitted: Seq[Seq[String]], indentlevel: Int, iden_code_hier: String) { + def hierScope(allNamesAndTypes: Seq[(String, Type)],splitted: Seq[Seq[String]], indentlevel: Int, iden_code_hier: String): Unit = { // This groups returns a Map( key -> Seq[Seq[String]] val grouped = splitted groupBy {_.head } @@ -170,7 +170,7 @@ class Vcd(circuit: Circuit, initopt: OptFlags, w: Writer, rn: Renamer) { case m: Module => if (m.name == topName) declareOldValues(m) case m: ExtModule => Seq() } - allNamesAndTypes map { case(name, tpe) => + allNamesAndTypes foreach { case(name, tpe) => if (localSignalToTrack(name)) w.writeLines(1, s"""${genCppType(tpe)} ${rn.vcdOldValue(rn.removeDots(name))};""") } @@ -224,10 +224,10 @@ class Vcd(circuit: Circuit, initopt: OptFlags, w: Writer, rn: Renamer) { case m: Module => displayNameIdentifierSize(m, topName) case m: ExtModule => Seq() } - val name = sg.stmtsOrdered flatMap findResultName + val name = sg.stmtsOrdered() flatMap findResultName val debug_name = name map { n => if ( !n.contains(".")) n else ""} var up_index = last_used_index - debug_name.zipWithIndex map { case(sn, index ) => { + debug_name.zipWithIndex foreach { case(sn, index ) => { val iden_code = genIdenCode(index + last_used_index) val sig_name = rn.removeDots(sn) if ( !hashMap.contains(sig_name)) { @@ -237,7 +237,7 @@ class Vcd(circuit: Circuit, initopt: OptFlags, w: Writer, rn: Renamer) { last_used_index = up_index val non_und_name = name map { n => if (!n.contains("._") && !n.contains("$next") && n.contains(".")) n else "" } val splitted = non_und_name map { _.split('.').toSeq} - non_und_name.zipWithIndex map { case(sn , index ) => { + non_und_name.zipWithIndex foreach { case(sn , index ) => { val sig_name = rn.removeDots(sn) val iden_code = genIdenCode(index + last_used_index) hashMap(sig_name) = iden_code @@ -250,7 +250,7 @@ class Vcd(circuit: Circuit, initopt: OptFlags, w: Writer, rn: Renamer) { w.writeLines(0, "") } - def writeFprintf(s: String) { + def writeFprintf(s: String): Unit = { w.writeLines(2,s"fprintf(outfile,$s);") } diff --git a/src/test/scala/AcyclicPartTest.scala b/src/test/scala/AcyclicPartTest.scala index b43ba9a2..c6a3814b 100644 --- a/src/test/scala/AcyclicPartTest.scala +++ b/src/test/scala/AcyclicPartTest.scala @@ -44,65 +44,65 @@ class AcyclicPartSpec extends AnyFlatSpec { } "An AcyclicPart" should "be built from a Graph" in { - val ap = AcyclicPart(buildStartingGraph1) + val ap = AcyclicPart(buildStartingGraph1()) assertResult(ArrayBuffer(0,1,2,3,4,5,6,7)){ ap.mg.idToMergeID } } it should "coarsen by MFFCs" in { val expected = Map((2,Seq(0,1,2)), (4,Seq(3,4)), (5,Seq(5)), (7,Seq(6,7))) - val ap = AcyclicPart(buildStartingGraph1) + val ap = AcyclicPart(buildStartingGraph1()) ap.coarsenWithMFFCs() assertResult(ArrayBuffer(2,2,2,4,4,5,7,7)){ ap.mg.idToMergeID } - assertResult(expected){ ap.iterParts } + assertResult(expected){ ap.iterParts() } } it should "coarsen by MFFCs w/ exclude set" in { val expected = Map((1,Seq(0,1)), (2,Seq(2)), (3,Seq(3)), (4,Seq(4)), (5,Seq(5)), (6,Seq(6)), (7,Seq(7))) - val ap = AcyclicPart(buildStartingGraph1, Set(2,4,6)) + val ap = AcyclicPart(buildStartingGraph1(), Set(2,4,6)) ap.coarsenWithMFFCs() assertResult(ArrayBuffer(1,1,2,3,4,5,6,7)){ ap.mg.idToMergeID } - assertResult(expected){ ap.iterParts } + assertResult(expected){ ap.iterParts() } } // TODO: should actually test smallZoneCutoff argument it should "merge single-input partitions into their parents" in { val expected = Map((0,Seq(0,1,8)), (2,Seq(2,3,4)), (5,Seq(5)), (6,Seq(6)), (7,Seq(7))) - val ap = AcyclicPart(buildStartingGraph2) + val ap = AcyclicPart(buildStartingGraph2()) ap.mergeSingleInputPartsIntoParents() assertResult(ArrayBuffer(0,0,2,2,2,5,6,7,0)){ ap.mg.idToMergeID } Util.sortHashMapValues(ap.mg.mergeIDToMembers) - assertResult(expected){ ap.iterParts } + assertResult(expected){ ap.iterParts() } } it should "merge single-input partitions into their parents w/ exclude set" in { val expected = Map((0,Seq(0)), (1,Seq(1,8)), (2,Seq(2,4)), (3,Seq(3)), (5,Seq(5)), (6,Seq(6)), (7,Seq(7))) - val ap = AcyclicPart(buildStartingGraph2, Set(0,3,6)) + val ap = AcyclicPart(buildStartingGraph2(), Set(0,3,6)) ap.mergeSingleInputPartsIntoParents() assertResult(ArrayBuffer(0,1,2,3,2,5,6,7,1)){ ap.mg.idToMergeID } - assertResult(expected){ ap.iterParts } + assertResult(expected){ ap.iterParts() } } it should "merge single-input MFFCs with their parents" in { val expected = Map((4,Seq(0,1,2,3,4,5)), (7,Seq(6,7))) - val ap = AcyclicPart(buildStartingGraph1) + val ap = AcyclicPart(buildStartingGraph1()) ap.coarsenWithMFFCs() ap.mergeSingleInputPartsIntoParents() assertResult(ArrayBuffer(4,4,4,4,4,4,7,7)){ ap.mg.idToMergeID } Util.sortHashMapValues(ap.mg.mergeIDToMembers) - assertResult(expected){ ap.iterParts } + assertResult(expected){ ap.iterParts() } } it should "merge single-input MFFCs with their parents w/ exclude set" in { val expected = Map((0,Seq(0)), (1,Seq(1)), (2,Seq(2)), (4,Seq(3,4,5)), (7,Seq(6,7))) - val ap = AcyclicPart(buildStartingGraph1, Set(1)) + val ap = AcyclicPart(buildStartingGraph1(), Set(1)) ap.coarsenWithMFFCs() ap.mergeSingleInputPartsIntoParents() assertResult(ArrayBuffer(0,1,2,4,4,4,7,7)){ ap.mg.idToMergeID } Util.sortHashMapValues(ap.mg.mergeIDToMembers) - assertResult(expected){ ap.iterParts } + assertResult(expected){ ap.iterParts() } } } diff --git a/src/test/scala/GraphTest.scala b/src/test/scala/GraphTest.scala index d36ebcee..45e931a9 100644 --- a/src/test/scala/GraphTest.scala +++ b/src/test/scala/GraphTest.scala @@ -6,34 +6,34 @@ class GraphSpec extends AnyFlatSpec { "A Graph" should "grow as necessary for new edges" in { val g = new Graph g.addEdge(0,1) - assertResult(2) { g.numNodes } - assertResult(1) { g.numEdges } + assertResult(2) { g.numNodes() } + assertResult(1) { g.numEdges() } g.addEdge(2,4) - assertResult(5) { g.numNodes } - assertResult(2) { g.numEdges } + assertResult(5) { g.numNodes() } + assertResult(2) { g.numEdges() } } it should "not add duplicate edges (if requested)" in { val g = new Graph g.addEdgeIfNew(0,1) - assertResult(2) { g.numNodes } - assertResult(1) { g.numEdges } + assertResult(2) { g.numNodes() } + assertResult(1) { g.numEdges() } g.addEdgeIfNew(0,1) - assertResult(2) { g.numNodes } - assertResult(1) { g.numEdges } + assertResult(2) { g.numNodes() } + assertResult(1) { g.numEdges() } } it should "remove duplicate edges from graph" in { val g = new Graph g.addEdge(0,1) - assertResult(2) { g.numNodes } - assertResult(1) { g.numEdges } + assertResult(2) { g.numNodes() } + assertResult(1) { g.numEdges() } g.addEdge(0,1) - assertResult(2) { g.numNodes } - assertResult(2) { g.numEdges } + assertResult(2) { g.numNodes() } + assertResult(2) { g.numEdges() } g.removeDuplicateEdges() - assertResult(2) { g.numNodes } - assertResult(1) { g.numEdges } + assertResult(2) { g.numNodes() } + assertResult(1) { g.numEdges() } } it should "be able to merge nodes mutably" in { diff --git a/src/test/scala/MergeGraphTest.scala b/src/test/scala/MergeGraphTest.scala index f0c18bca..b67dd7ac 100644 --- a/src/test/scala/MergeGraphTest.scala +++ b/src/test/scala/MergeGraphTest.scala @@ -23,7 +23,7 @@ class MergeGraphSpec extends AnyFlatSpec { "A MergeGraph" should "be built from a Graph with initialAssignments" in { val mg = MergeGraph(buildStartingGraph(), initialAssignments) assert(mg.idToMergeID == initialAssignments) - assert(mg.iterGroups == Map( + assert(mg.iterGroups() == Map( (1,Seq(0,1,2)), (3,Seq(3)), (4,Seq(4)), (6,Seq(5,6)))) assert(mg.outNeigh(0).isEmpty) assert(mg.outNeigh(1) == Seq(6)) @@ -45,7 +45,7 @@ class MergeGraphSpec extends AnyFlatSpec { val mg = MergeGraph(buildStartingGraph()) mg.applyInitialAssignments(initialAssignments) assert(mg.idToMergeID == initialAssignments) - assert(mg.iterGroups == Map( + assert(mg.iterGroups() == Map( (1,Seq(0,1,2)), (3,Seq(3)), (4,Seq(4)), (6,Seq(5,6)))) assert(mg.outNeigh(0).isEmpty) assert(mg.outNeigh(1) == Seq(6)) @@ -67,7 +67,7 @@ class MergeGraphSpec extends AnyFlatSpec { val mg = buildStartingMG() mg.mergeGroups(6, Seq(1)) assert(mg.idToMergeID == ArrayBuffer(6,6,6,3,4,6,6)) - assert(mg.iterGroups == Map( + assert(mg.iterGroups() == Map( (3,Seq(3)), (4,Seq(4)), (6,Seq(5,6,0,1,2)))) assert(mg.outNeigh(0).isEmpty) assert(mg.outNeigh(1).isEmpty) @@ -90,7 +90,7 @@ class MergeGraphSpec extends AnyFlatSpec { val mg = buildStartingMG() mg.mergeGroups(1, Seq(3)) assert(mg.idToMergeID == ArrayBuffer(1,1,1,1,4,6,6)) - assert(mg.iterGroups == Map( + assert(mg.iterGroups() == Map( (1,Seq(0,1,2,3)), (4,Seq(4)), (6,Seq(5,6)))) assert(mg.outNeigh(0).isEmpty) assert(mg.outNeigh(1) == Seq(6)) @@ -112,7 +112,7 @@ class MergeGraphSpec extends AnyFlatSpec { val mg = buildStartingMG() mg.mergeGroups(1, Seq(4)) assert(mg.idToMergeID == ArrayBuffer(1,1,1,3,1,6,6)) - assert(mg.iterGroups == Map( + assert(mg.iterGroups() == Map( (1,Seq(0,1,2,4)), (3,Seq(3)), (6,Seq(5,6)))) assert(mg.outNeigh(0).isEmpty) assert(mg.outNeigh(1) == Seq(6)) diff --git a/src/test/scala/ReplaceRsvdKeyTest.scala b/src/test/scala/ReplaceRsvdKeyTest.scala index dffab6ed..0e6aba42 100644 --- a/src/test/scala/ReplaceRsvdKeyTest.scala +++ b/src/test/scala/ReplaceRsvdKeyTest.scala @@ -10,13 +10,13 @@ import scala.io.Source class ReplaceRsvdKeyTest extends AnyFlatSpec{ "Mypass" should "Replace all reserve keyword" in { val sourceReader = Source.fromURL(getClass.getResource("/ReplacedRsvdKey.fir")) - val circuit = firrtl.Parser.parse(sourceReader.getLines, firrtl.Parser.IgnoreInfo) + val circuit = firrtl.Parser.parse(sourceReader.getLines(), firrtl.Parser.IgnoreInfo) sourceReader.close() val deps = firrtl.stage.Forms.LowFormOptimized ++ Seq(Dependency(ReplaceRsvdKeywords)) val firrtlCompiler = new firrtl.stage.transforms.Compiler(deps) val resultState = firrtlCompiler.execute(CircuitState(circuit, Seq())) val CorrectReader = Source.fromURL(getClass.getResource("/ReplacedRsvdKey_correct.fir")) - val correctString = CorrectReader.getLines.mkString("\n") + val correctString = CorrectReader.getLines().mkString("\n") assert(correctString == resultState.circuit.serialize) } } diff --git a/src/test/scala/StatementGraphTest.scala b/src/test/scala/StatementGraphTest.scala index ffb2b9b9..b22a95f8 100644 --- a/src/test/scala/StatementGraphTest.scala +++ b/src/test/scala/StatementGraphTest.scala @@ -8,38 +8,38 @@ class StatementGraphSpec extends AnyFlatSpec { "A NamedGraph" should "grow as necessary for new edges" in { val sg = new StatementGraph sg.addEdge("alpha", "beta") - assertResult(0) { sg.numValidNodes } - assertResult(2) { sg.numNodeRefs } - assertResult(1) { sg.numEdges } + assertResult(0) { sg.numValidNodes() } + assertResult(2) { sg.numNodeRefs() } + assertResult(1) { sg.numEdges() } sg.addEdge("gamma", "zeta") - assertResult(0) { sg.numValidNodes } - assertResult(4) { sg.numNodeRefs } - assertResult(2) { sg.numEdges } + assertResult(0) { sg.numValidNodes() } + assertResult(4) { sg.numNodeRefs() } + assertResult(2) { sg.numEdges() } } it should "not add duplicate edges (if requested)" in { val sg = new StatementGraph sg.addEdgeIfNew("alpha", "beta") - assertResult(0) { sg.numValidNodes } - assertResult(2) { sg.numNodeRefs } - assertResult(1) { sg.numEdges } + assertResult(0) { sg.numValidNodes() } + assertResult(2) { sg.numNodeRefs() } + assertResult(1) { sg.numEdges() } sg.addEdgeIfNew("alpha", "beta") - assertResult(0) { sg.numValidNodes } - assertResult(2) { sg.numNodeRefs } - assertResult(1) { sg.numEdges } + assertResult(0) { sg.numValidNodes() } + assertResult(2) { sg.numNodeRefs() } + assertResult(1) { sg.numEdges() } } it should "be buildable from hyperedges" in { val sg = new StatementGraph sg.addStatementNode("child", Seq("parent0","parent1")) - assertResult(1) { sg.numValidNodes } - assertResult(3) { sg.numNodeRefs } - assertResult(2) { sg.numEdges } + assertResult(1) { sg.numValidNodes() } + assertResult(3) { sg.numNodeRefs() } + assertResult(2) { sg.numEdges() } assert(sg.idToStmt(sg.nameToID("child")) == EmptyStmt) sg.addStatementNode("sibling", Seq("parent0","parent1"), Block(Seq())) - assertResult(2) { sg.numValidNodes } - assertResult(4) { sg.numNodeRefs } - assertResult(4) { sg.numEdges } + assertResult(2) { sg.numValidNodes() } + assertResult(4) { sg.numNodeRefs() } + assertResult(4) { sg.numEdges() } assert(sg.idToStmt(sg.nameToID("sibling")) == Block(Seq())) } @@ -104,6 +104,6 @@ class StatementGraphSpec extends AnyFlatSpec { it should "be able to handle a 1 node graph with no edges" in { val stmt = DefNode(NoInfo,"dummy",UIntLiteral(0,IntWidth(1))) val sg = StatementGraph(Seq(stmt)) - assertResult(Seq(stmt)) { sg.stmtsOrdered } + assertResult(Seq(stmt)) { sg.stmtsOrdered() } } }