From 18122a4e1006704f7476e2fd26ae34b93afea102 Mon Sep 17 00:00:00 2001 From: ARATA Mizuki Date: Wed, 11 Sep 2024 18:08:30 +0900 Subject: [PATCH] Implement eta conversion of continuations --- Makefile | 1 + src/cps/eta.sml | 106 +++++++++++++++++++++++++++++++++++++++++ src/lunarml-common.mlb | 1 + src/main.sml | 1 + 4 files changed, 109 insertions(+) create mode 100644 src/cps/eta.sml diff --git a/Makefile b/Makefile index 2b7436e..6c9c413 100644 --- a/Makefile +++ b/Makefile @@ -57,6 +57,7 @@ sources = \ src/cps/inline.sml \ src/cps/decompose-recursive.sml \ src/cps/unpack-record-parameter.sml \ + src/cps/eta.sml \ src/nested.sml \ src/lua-syntax.sml \ src/lua-transform.sml \ diff --git a/src/cps/eta.sml b/src/cps/eta.sml new file mode 100644 index 0000000..090a010 --- /dev/null +++ b/src/cps/eta.sml @@ -0,0 +1,106 @@ +(* + * Copyright (c) 2024 ARATA Mizuki + * This file is part of LunarML. + *) +structure CpsEtaConvert: +sig + val go: CpsSimplify.Context * CSyntax.CExp -> CSyntax.CExp +end = +struct + structure C = CSyntax + fun goCont (env, cont) = + (case C.CVarMap.find (env, cont) of + SOME cont' => cont' + | NONE => cont) + fun goDec (dec, (env, acc)) = + case dec of + C.ValDec + {exp = C.Abs {contParam, params, body, attr}, results as [SOME _]} => + let + (* Eta conversion of a function is not implemented yet *) + val dec' = C.ValDec + { exp = C.Abs + { contParam = contParam + , params = params + , body = goFunction body + , attr = attr + } + , results = results + } + in + (env, dec' :: acc) + end + | C.ValDec {exp = _, results = _} => (env, dec :: acc) + | C.RecDec defs => + let + val defs = + List.map + (fn {name, contParam, params, body, attr} => + { name = name + , contParam = contParam + , params = params + , body = goFunction body + , attr = attr + }) defs + in + (env, C.RecDec defs :: acc) + end + | C.ContDec {name, params, body as C.AppCont {applied, args}} => + (* Eta conversion *) + if + ListPair.allEq (fn (SOME p, C.Var q) => p = q | _ => false) + (params, args) + then + (C.CVarMap.insert (env, name, goCont (env, applied)), acc) + else + ( env + , C.ContDec {name = name, params = params, body = goCExp (env, body)} + :: acc + ) + | C.ContDec {name, params, body} => + ( env + , C.ContDec {name = name, params = params, body = goCExp (env, body)} + :: acc + ) + | C.RecContDec defs => + let + val defs = + List.map + (fn (name, params, body) => (name, params, goCExp (env, body))) + defs + in + (env, C.RecContDec defs :: acc) + end + | C.ESImportDec _ => (env, dec :: acc) + and goCExp (env, exp) = + case exp of + C.Let {decs, cont} => + let val (env, revDecs) = List.foldl goDec (env, []) decs + in C.Let {decs = List.rev revDecs, cont = goCExp (env, cont)} + end + | C.App {applied, cont, args, attr} => + C.App + { applied = applied + , cont = goCont (env, cont) + , args = args + , attr = attr + } + | C.AppCont {applied, args} => + C.AppCont {applied = goCont (env, applied), args = args} + | C.If {cond, thenCont, elseCont} => + C.If + { cond = cond + , thenCont = goCExp (env, thenCont) + , elseCont = goCExp (env, elseCont) + } + | C.Handle {body, handler = (e, h), successfulExitIn, successfulExitOut} => + C.Handle + { body = goFunction body + , handler = (e, goCExp (env, h)) + , successfulExitIn = successfulExitIn + , successfulExitOut = goCont (env, successfulExitOut) + } + | C.Unreachable => exp + and goFunction exp = goCExp (C.CVarMap.empty, exp) + fun go (_: CpsSimplify.Context, exp) = goFunction exp +end; (* structure CpsEtaConvert *) diff --git a/src/lunarml-common.mlb b/src/lunarml-common.mlb index e9529af..db21a1c 100644 --- a/src/lunarml-common.mlb +++ b/src/lunarml-common.mlb @@ -47,6 +47,7 @@ cps/ref-cell.sml cps/inline.sml cps/decompose-recursive.sml cps/unpack-record-parameter.sml +cps/eta.sml nested.sml lua-transform.sml codegen-lua.sml diff --git a/src/main.sml b/src/main.sml index 6998f8e..4ad6541 100644 --- a/src/main.sml +++ b/src/main.sml @@ -112,6 +112,7 @@ struct val cexp = CpsDecomposeRecursive.goCExp (ctx', cexp) val cexp = CpsConstantRefCell.goCExp (ctx', cexp) val cexp = CpsInline.goCExp (ctx', cexp) + val cexp = CpsEtaConvert.go (ctx', cexp) in if #printTimings ctx then print