Skip to content

Commit

Permalink
Fix return (#46)
Browse files Browse the repository at this point in the history
* refactor

* add benchmark

* a simpler test

* lame fix (#47)

* use an index to mark the return cont

* refactor and nicify things

* further refactor

* rename file

---------

Co-authored-by: ahuoguo <ahuoguo@gmail.com>
Co-authored-by: ahuoguo <52595524+ahuoguo@users.noreply.github.com>
  • Loading branch information
3 people authored Oct 4, 2024
1 parent cc32b0d commit f0e2046
Show file tree
Hide file tree
Showing 6 changed files with 166 additions and 107 deletions.
18 changes: 18 additions & 0 deletions benchmarks/wasm/even_odd.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#[no_mangle]
#[inline(never)]
fn is_even(n: u32) -> bool {
if n == 0 { true }
else { is_odd(n - 1) }
}

#[no_mangle]
#[inline(never)]
fn is_odd(n: u32) -> bool {
if n == 0 { false }
else { is_even(n - 1) }
}

#[no_mangle]
fn real_main() -> bool {
is_even(12)
}
31 changes: 31 additions & 0 deletions benchmarks/wasm/even_odd.wat
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
(module
(type (;0;) (func (param i32) (result i32)))
(type (;1;) (func (result i32)))
(func (;0;) (type 0) (param i32) (result i32)
block ;; label = @1
local.get 0
br_if 0 (;@1;)
i32.const 1
return
end
local.get 0
i32.const -1
i32.add
call 1)
(func (;1;) (type 0) (param i32) (result i32)
block ;; label = @1
local.get 0
br_if 0 (;@1;)
i32.const 0
return
end
local.get 0
i32.const -1
i32.add
call 0)
(func (;2;) (type 1) (result i32)
i32.const 13
call 1)
(start 2)
(memory (;0;) 16)
)
13 changes: 13 additions & 0 deletions benchmarks/wasm/return.wat
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
(module
(type (;0;) (func))
(func (;0;) (type 0)
block ;; label = @1
return
end
unreachable
)
(func (;1;) (type 0)
call 0
)
(start 1)
)
143 changes: 97 additions & 46 deletions src/main/scala/wasm/MiniWasm.scala
Original file line number Diff line number Diff line change
Expand Up @@ -168,14 +168,14 @@ case class Frame(module: ModuleInstance, locals: ArrayBuffer[Value])
object Evaluator {
import Primtives._

type RetCont = List[Value] => Unit
type Cont = List[Value] => Unit
type Cont[A] = List[Value] => A

def eval(insts: List[Instr],
stack: List[Value],
frame: Frame,
trail: List[Cont])
(implicit kont: Cont): Unit = {
def eval[Ans](insts: List[Instr],
stack: List[Value],
frame: Frame,
kont: Cont[Ans],
trail: List[Cont[Ans]],
ret: Int): Ans = {
if (insts.isEmpty) return kont(stack)

val inst = insts.head
Expand All @@ -184,23 +184,23 @@ object Evaluator {
//println(f"stack size: ${stack.size}")
//println(s"eval: $inst")
inst match {
case Drop => eval(rest, stack.tail, frame, trail)
case Drop => eval(rest, stack.tail, frame, kont, trail, ret)
case Select(_) =>
val I32V(cond) :: v2 :: v1 :: newStack = stack
val value = if (cond == 0) v1 else v2
eval(rest, value :: newStack, frame, trail)
eval(rest, value :: newStack, frame, kont, trail, ret)
case LocalGet(i) =>
eval(rest, frame.locals(i) :: stack, frame, trail)
eval(rest, frame.locals(i) :: stack, frame, kont, trail, ret)
case LocalSet(i) =>
val value :: newStack = stack
frame.locals(i) = value
eval(rest, newStack, frame, trail)
eval(rest, newStack, frame, kont, trail, ret)
case LocalTee(i) =>
val value :: newStack = stack
frame.locals(i) = value
eval(rest, stack, frame, trail)
eval(rest, stack, frame, kont, trail, ret)
case GlobalGet(i) =>
eval(rest, frame.module.globals(i).value :: stack, frame, trail)
eval(rest, frame.module.globals(i).value :: stack, frame, kont, trail, ret)
case GlobalSet(i) =>
val value :: newStack = stack
frame.module.globals(i).ty match {
Expand All @@ -209,111 +209,162 @@ object Evaluator {
case GlobalType(_, true) => throw new Exception("Invalid type")
case _ => throw new Exception("Cannot set immutable global")
}
eval(rest, newStack, frame, trail)
eval(rest, newStack, frame, kont, trail, ret)
case MemorySize =>
eval(rest,
I32V(frame.module.memory.head.size) :: stack,
frame,
trail)
eval(rest, I32V(frame.module.memory.head.size) :: stack, frame, kont, trail, ret)
case MemoryGrow =>
val I32V(delta) :: newStack = stack
val mem = frame.module.memory.head
val oldSize = mem.size
mem.grow(delta) match {
case Some(e) => eval(rest, I32V(-1) :: newStack, frame, trail)
case _ => eval(rest, I32V(oldSize) :: newStack, frame, trail)
case Some(e) => eval(rest, I32V(-1) :: newStack, frame, kont, trail, ret)
case _ => eval(rest, I32V(oldSize) :: newStack, frame, kont, trail, ret)
}
case MemoryFill =>
val I32V(value) :: I32V(offset) :: I32V(size) :: newStack = stack
if (memOutOfBound(frame, 0, offset, size))
throw new Exception("Out of bounds memory access") // GW: turn this into a `trap`?
else {
frame.module.memory.head.fill(offset, size, value.toByte)
eval(rest, newStack, frame, trail)
eval(rest, newStack, frame, kont, trail, ret)
}
case MemoryCopy =>
val I32V(n) :: I32V(src) :: I32V(dest) :: newStack = stack
if (memOutOfBound(frame, 0, src, n) || memOutOfBound(frame, 0, dest, n))
throw new Exception("Out of bounds memory access")
else {
frame.module.memory.head.copy(dest, src, n)
eval(rest, newStack, frame, trail)
eval(rest, newStack, frame, kont, trail, ret)
}
case Const(n) => eval(rest, n :: stack, frame, trail)
case Const(n) => eval(rest, n :: stack, frame, kont, trail, ret)
case Binary(op) =>
val v2 :: v1 :: newStack = stack
eval(rest, evalBinOp(op, v1, v2) :: newStack, frame, trail)
eval(rest, evalBinOp(op, v1, v2) :: newStack, frame, kont, trail, ret)
case Unary(op) =>
val v :: newStack = stack
eval(rest, evalUnaryOp(op, v) :: newStack, frame, trail)
eval(rest, evalUnaryOp(op, v) :: newStack, frame, kont, trail, ret)
case Compare(op) =>
val v2 :: v1 :: newStack = stack
eval(rest, evalRelOp(op, v1, v2) :: newStack, frame, trail)
eval(rest, evalRelOp(op, v1, v2) :: newStack, frame, kont, trail, ret)
case Test(op) =>
val v :: newStack = stack
eval(rest, evalTestOp(op, v) :: newStack, frame, trail)
eval(rest, evalTestOp(op, v) :: newStack, frame, kont, trail, ret)
case Store(StoreOp(align, offset, ty, None)) =>
val I32V(v) :: I32V(addr) :: newStack = stack
frame.module.memory(0).storeInt(addr + offset, v)
eval(rest, newStack, frame, trail)
eval(rest, newStack, frame, kont, trail, ret)
case Load(LoadOp(align, offset, ty, None, None)) =>
val I32V(addr) :: newStack = stack
val value = frame.module.memory(0).loadInt(addr + offset)
eval(rest, I32V(value) :: newStack, frame, trail)
eval(rest, I32V(value) :: newStack, frame, kont, trail, ret)
case Nop =>
eval(rest, stack, frame, trail)
eval(rest, stack, frame, kont, trail, ret)
case Unreachable => throw new RuntimeException("Unreachable")
case Block(ty, inner) =>
val k: Cont = (retStack) =>
eval(rest, retStack.take(ty.toList.size) ++ stack, frame, trail)
val k: Cont[Ans] = (retStack) =>
eval(rest, retStack.take(ty.toList.size) ++ stack, frame, kont, trail, ret)
// TODO: block can take inputs too
eval(inner, List(), frame, k :: trail)(k)
eval(inner, List(), frame, k, k :: trail, ret+1)
case Loop(ty, inner) =>
// We construct two continuations, one for the break (to the begining of the loop),
// and one for fall-through to the next instruction following the syntactic structure
// of the program.
val restK: Cont = (retStack) => eval(rest, retStack.take(ty.toList.size) ++ stack, frame, trail)
def loop(stack: List[Value]): Unit = {
val k: Cont = (retStack) => loop(retStack.take(ty.toList.size))
eval(inner, stack, frame, k :: trail)(restK)
val restK: Cont[Ans] = (retStack) => eval(rest, retStack.take(ty.toList.size) ++ stack, frame, kont, trail, ret)
def loop(stack: List[Value]): Ans = {
val k: Cont[Ans] = (retStack) => loop(retStack.take(ty.toList.size))
eval(inner, stack, frame, restK, k :: trail, ret+1)
}
loop(List())
case If(ty, thn, els) =>
val I32V(cond) :: newStack = stack
val inner = if (cond != 0) thn else els
val k: Cont = (retStack) =>
eval(rest, retStack.take(ty.toList.size) ++ newStack, frame, trail)
eval(inner, List(), frame, k :: trail)(k)
val k: Cont[Ans] = (retStack) =>
eval(rest, retStack.take(ty.toList.size) ++ newStack, frame, kont, trail, ret)
eval(inner, List(), frame, k, k :: trail, ret+1)
case Br(label) =>
trail(label)(stack)
case BrIf(label) =>
val I32V(cond) :: newStack = stack
if (cond != 0) trail(label)(newStack)
else eval(rest, newStack, frame, trail)
case Return => kont(stack)
else eval(rest, newStack, frame, kont, trail, ret)
case Return => trail(ret)(stack)
case Call(f) if frame.module.funcs(f).isInstanceOf[FuncDef] =>
val FuncDef(_, FuncBodyDef(ty, _, locals, body)) = frame.module.funcs(f)
val args = stack.take(ty.inps.size).reverse
val newStack = stack.drop(ty.inps.size)
val frameLocals = args ++ locals.map(_ => I32V(0)) // GW: always I32? or depending on their types?
val newFrame = Frame(frame.module, ArrayBuffer(frameLocals: _*))
val newK: RetCont = (retStack) =>
eval(rest, retStack.take(ty.out.size) ++ newStack, frame, trail)
val newK: Cont[Ans] = (retStack) =>
eval(rest, retStack.take(ty.out.size) ++ newStack, frame, kont, trail, ret)
// We push newK on the trail since function creates a new block to escape
// (more or less like `return`)
eval(body, List(), newFrame, newK :: trail)(newK)
eval(body, List(), newFrame, newK, newK :: trail, ret+1)
case Call(f) if frame.module.funcs(f).isInstanceOf[Import] =>
frame.module.funcs(f) match {
case Import("console", "log", _) =>
//println(s"[DEBUG] current stack: $stack")
val I32V(v) :: newStack = stack
println(v)
eval(rest, newStack, frame, trail)
eval(rest, newStack, frame, kont, trail, ret)
case f => throw new Exception(s"Unknown import $f")
}
case _ =>
println(inst)
throw new Exception(s"instruction $inst not implemented")
}
}

// If `main` is given, then we use that function as the entry point of the program;
// otherwise, we look up the top-level `start` instruction to locate the entry point.
def evalTop[Ans](module: Module, halt: Cont[Ans], main: Option[String] = None): Ans = {
val instrs = main match {
case Some(_) => module.definitions.flatMap({
case FuncDef(`main`, FuncBodyDef(_, _, _, body)) =>
println(s"Entering function $main")
body
case _ => List()
})
case None => module.definitions.flatMap({
case Start(id) => module.funcEnv(id) match {
case FuncDef(_, FuncBodyDef(_, _, _, body)) =>
println(s"Entering unnamed function $id")
body
case _ => throw new Exception("Start function has no concrete definition")
}
case _ => List()
})
}

val types = List()
val funcs = module.definitions
.collect({
case FuncDef(_, fndef @ FuncBodyDef(_, _, _, _)) => fndef
})
.toList

val globals = module.definitions
.collect({
case Global(_, GlobalValue(ty, e)) =>
(e.head) match {
case Const(c) => RTGlobal(ty, c)
// Q: What is the default behavior if case in non exhaustive
case _ => ???
}
})
.toList

// TODO: correct the behavior for memory
val memory = module.definitions
.collect({
case Memory(id, MemoryType(min, max_opt)) =>
RTMemory(min, max_opt)
})
.toList

val moduleInst = ModuleInstance(types, module.funcEnv, memory, globals)

Evaluator.eval(instrs, List(), Frame(moduleInst, ArrayBuffer(I32V(0))), halt, List(halt), 0)
}

def evalTop(m: Module): Unit = evalTop(m, stack => ())
}
File renamed without changes.
Loading

0 comments on commit f0e2046

Please sign in to comment.