Skip to content

Commit

Permalink
Revert "refactor: replace.lean"
Browse files Browse the repository at this point in the history
  • Loading branch information
alexkeizer authored Jun 26, 2024
1 parent a0f240c commit b5f4d49
Showing 1 changed file with 28 additions and 23 deletions.
51 changes: 28 additions & 23 deletions Qpf/Macro/Data/Replace.lean
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,10 @@ structure CtorArgs where

/- TODO(@William): make these correspond by combining expr and vars into a product -/
structure Replace where
(vals: Array (Name × Term))
(expr: Array Term)
(vars: Array Name)
(ctor: CtorArgs)

def Replace.vars (r : Replace): Array Name := r.vals.map Prod.fst
def Replace.expr (r : Replace): Array Term := r.vals.map Prod.snd

variable (m) [Monad m] [MonadQuotation m] [MonadError m] [MonadTrace m] [MonadOptions m]
[AddMessageContext m] [MonadLiftT IO m]
Expand All @@ -33,13 +32,13 @@ private abbrev ReplaceM := StateT Replace m
variable {m}

private def Replace.new : m Replace :=
do pure ⟨#[], ⟨#[], #[]⟩⟩
do pure ⟨#[], #[], ⟨#[], #[]⟩⟩

private def CtorArgs.reset : ReplaceM m Unit := do
let r ← StateT.get
let n := r.vars.size
let ctor: CtorArgs := ⟨#[], (Array.range n).map fun _ => #[]⟩
StateT.set { r with ctor }
StateT.set ⟨r.expr, r.vars, ctor

private def CtorArgs.get : ReplaceM m CtorArgs := do
pure (←StateT.get).ctor
Expand All @@ -48,7 +47,7 @@ private def CtorArgs.get : ReplaceM m CtorArgs := do
The arity of the shape type created *after* replacing, i.e., the size of `r.expr`
-/
def Replace.arity (r : Replace) : Nat :=
r.vals.size
r.expr.size

def Replace.getBinderIdents (r : Replace) : Array Ident :=
r.vars.map mkIdent
Expand All @@ -69,19 +68,19 @@ private def ReplaceM.identFor (stx : Term) : ReplaceM m Ident := do
let r ← StateT.get
let ctor := r.ctor
let argName ← mkFreshBinderName
let args := ctor.args.push argName
let ctor_args := ctor.args.push argName
-- dbg_trace "\nstx := {stx}\nr.expr := {r.expr}"

let name ← match r.expr.indexOf? stx with
| some id => do
let per_type := ctor.per_type.set! id $ (ctor.per_type.get! id).push argName
let ctor := { args, per_type }
StateT.set { r with ctor }
| some id => do
let ctor_per_type := ctor.per_type.set! id $ (ctor.per_type.get! id).push argName
let ctor := ⟨ctor_args, ctor_per_type⟩
StateT.set ⟨r.expr, r.vars, ctor
pure $ r.vars.get! id
| none => do
let per_type := ctor.per_type.push #[argName]
let ctor_per_type := ctor.per_type.push #[argName]
let name ← mkFreshBinderName
StateT.set { vals := r.vals.push (name, stx), ctor := { args, per_type } }
StateT.set ⟨r.expr.push stx, r.vars.push name, ⟨ctor_args, ctor_per_type⟩⟩
pure name

return mkIdent name
Expand Down Expand Up @@ -119,6 +118,11 @@ private partial def setResultingType (res_type : Syntax) : Syntax → ReplaceM m
| _ =>
pure res_type

-- TODO: this should be deprecated in favour of {v with ...} syntax
def CtorView.withType? (ctor : CtorView) (type? : Option Syntax) : CtorView := {
ctor with type?
}

/-
TODO: currently these functions ignore dead variables, everything is replaced.
This is OK, we can supply a "dead" value to a live variable, but we lose the ability to have
Expand Down Expand Up @@ -201,7 +205,7 @@ Replace.run <| do

let type? ← ctor.type?.mapM $ shapeOf'

pure ({ ctor with type? }, ←CtorArgs.get)
pure $ (CtorView.withType? ctor type?, ←CtorArgs.get)

let r ← StateT.get
let ctors := pairs.map Prod.fst;
Expand All @@ -212,8 +216,8 @@ Replace.run <| do

-- HACK: It seems that `Array.append` causes a stack overflow, so we go through `List` for now
-- TODO: fix this after updating to newer Lean version
let per_type := per_type.appendList $ List.replicate diff (#[] : Array Name)
{ ctorArgs with per_type }
let per_type := per_type.appendList $ (List.range diff).map (fun _ => (#[] : Array Name));
ctorArgs.args, per_type

-- Now that we know how many free variables were introduced, we can fix up the resulting type
-- of each constructor to be `Shape α_0 α_1 ... α_n`
Expand All @@ -222,19 +226,20 @@ Replace.run <| do

let ctors ← ctors.mapM fun ctor => do
let type? ← ctor.type?.mapM (setResultingType res)
pure { ctor with type? }
pure $ CtorView.withType? ctor type?

pure (ctors, ctorArgs)




/-- Replace syntax in *all* subexpressions -/
partial def Replace.replaceAllStx (find replace stx : Syntax) : Syntax :=
if stx == find then
replace
else
stx.setArgs (stx.getArgs.map (replaceAllStx find replace))
partial def Replace.replaceAllStx (find replace : Syntax) : Syntax → Syntax :=
fun stx =>
if stx == find then
replace
else
stx.setArgs (stx.getArgs.map (replaceAllStx find replace))



Expand Down Expand Up @@ -287,7 +292,7 @@ def makeNonRecursive (view : DataView) : MetaM (DataView × Name) := do

let ctors ← view.ctors.mapM fun ctor => do
let type? ← ctor.type?.mapM (Replace.replaceStx expected recId <| TSyntax.mk ·)
pure { ctor with type? }
return CtorView.withType? ctor type?

let view := view.setCtors ctors
pure (view, rec)

0 comments on commit b5f4d49

Please sign in to comment.