Skip to content

Commit

Permalink
add tree shaking (aka dead function removal)
Browse files Browse the repository at this point in the history
  • Loading branch information
robertmuth committed May 15, 2024
1 parent 79ca7fc commit 8f679f9
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 1 deletion.
65 changes: 65 additions & 0 deletions FrontEnd/dead_code.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import logging
import argparse
import dataclasses
import collections

from typing import Union, Any, Optional, Set, Dict, List

from FrontEnd import cwast


logger = logging.getLogger(__name__)

_Hell = cwast.DefFun("hell", [], cwast.TypeBase(cwast.BASE_TYPE_KIND.VOID), [])


def ShakeTree(mods: List[cwast.DefMod]):
# callgraph - map fun to its callers
cg: Dict[cwast.DefFun, Set[cwast.DefFun]] = collections.defaultdict(set)
cg[_Hell].add(_Hell) # force hell to be alive

def visitor(call, parents):
nonlocal cg
if isinstance(call, cwast.ExprCall):
if isinstance(call.callee, cwast.Id):
callee = call.callee.x_symbol
assert isinstance(
callee, cwast.DefFun), f"expected fun: {call} {callee}"
caller = parents[0]
if caller is not callee:
logging.info(f"@@@@ {caller.name} -> {callee.name}")
cg[callee].add(caller)

# create call graph
for m in mods:
for fun in m.body_mod:
if not isinstance(fun, cwast.DefFun):
continue
if fun.init or fun.fini or fun.ref or fun.cdecl and fun.name == "main":
cg[fun].add(_Hell)
else:
# make sure the function is recorded
dummy = cg[fun]
cwast.VisitAstRecursivelyWithAllParents(fun, [], visitor)

# compute dead functions
change = True
while change:
dead_funs: List[cwast.DefFun] = []
for fun, callers in cg.items():
has_live_caller = False
for c in callers:
if c in cg:
has_live_caller = True
break
if not has_live_caller:
logging.info(f"@@@ DEAD: {fun.name}")
dead_funs.append(fun)
for d in dead_funs:
del cg[d]
change = len(dead_funs) > 0
for fun in cg:
logging.info(f"@@@ ALIVE: {fun.name}")
for m in mods:
new_body = [f for f in m.body_mod if not isinstance(f, cwast.DefFun) or f in cg]
m.body_mod = new_body
8 changes: 7 additions & 1 deletion FrontEnd/emit_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from FrontEnd import identifier
from FrontEnd import pp_html
from FrontEnd import mod_pool
from FrontEnd import dead_code

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -919,6 +920,8 @@ def SanityCheckMods(phase_name: str, emit_ir: str, mods: list[cwast.DefMod], tc,

def main() -> int:
parser = argparse.ArgumentParser(description='pretty_printer')
parser.add_argument("-shake_tree",
action="store_true", help='remove unreachable functions')
parser.add_argument(
'-arch', help='architecture to generated IR for', default="x64")
parser.add_argument(
Expand Down Expand Up @@ -964,7 +967,6 @@ def main() -> int:
eliminated_nodes.add(cwast.ExprStringify)
eliminated_nodes.add(cwast.EphemeralList)
eliminated_nodes.add(cwast.ModParam)

for mod in mod_topo_order:
cwast.CheckAST(mod, eliminated_nodes)

Expand All @@ -975,6 +977,9 @@ def main() -> int:
for mod in mod_topo_order:
typify.VerifyTypesRecursively(mod, tc, verifier)

if args.shake_tree:
dead_code.ShakeTree(mod_topo_order)

logger.info("partial eval and static assert validation")
eval.DecorateASTWithPartialEvaluation(mod_topo_order)

Expand Down Expand Up @@ -1172,6 +1177,7 @@ def visitor(n, _):
EmitIRDefFun(node, tc, identifier.IdGenIR(node.name))
return 0


if __name__ == "__main__":
# import cProfile
# cProfile.run('main()')
Expand Down

0 comments on commit 8f679f9

Please sign in to comment.