diff --git a/Qpf/Macro/Data/Replace.lean b/Qpf/Macro/Data/Replace.lean index 9120cee..89de3bb 100644 --- a/Qpf/Macro/Data/Replace.lean +++ b/Qpf/Macro/Data/Replace.lean @@ -19,10 +19,11 @@ structure CtorArgs where /- TODO(@William): make these correspond by combining expr and vars into a product -/ structure Replace where - (expr: Array Term) - (vars: Array Name) + (vals: Array (Name × Term)) (ctor: CtorArgs) +def Replace.vars (r : Replace): Array Name := r.vals.map Prod.fst +def Replace.expr (r : Replace): Array Term := r.vals.map Prod.snd variable (m) [Monad m] [MonadQuotation m] [MonadError m] [MonadTrace m] [MonadOptions m] [AddMessageContext m] [MonadLiftT IO m] @@ -32,13 +33,13 @@ private abbrev ReplaceM := StateT Replace m variable {m} private def Replace.new : m Replace := - do pure ⟨#[], #[], ⟨#[], #[]⟩⟩ + do pure ⟨#[], ⟨#[], #[]⟩⟩ private def CtorArgs.reset : ReplaceM m Unit := do let r ← StateT.get let n := r.vars.size let ctor: CtorArgs := ⟨#[], (Array.range n).map fun _ => #[]⟩ - StateT.set ⟨r.expr, r.vars, ctor⟩ + StateT.set { r with ctor } private def CtorArgs.get : ReplaceM m CtorArgs := do pure (←StateT.get).ctor @@ -47,7 +48,7 @@ private def CtorArgs.get : ReplaceM m CtorArgs := do The arity of the shape type created *after* replacing, i.e., the size of `r.expr` -/ def Replace.arity (r : Replace) : Nat := - r.expr.size + r.vals.size def Replace.getBinderIdents (r : Replace) : Array Ident := r.vars.map mkIdent @@ -68,19 +69,19 @@ private def ReplaceM.identFor (stx : Term) : ReplaceM m Ident := do let r ← StateT.get let ctor := r.ctor let argName ← mkFreshBinderName - let ctor_args := ctor.args.push argName + let args := ctor.args.push argName -- dbg_trace "\nstx := {stx}\nr.expr := {r.expr}" let name ← match r.expr.indexOf? stx with - | some id => do - let ctor_per_type := ctor.per_type.set! id $ (ctor.per_type.get! id).push argName - let ctor := ⟨ctor_args, ctor_per_type⟩ - StateT.set ⟨r.expr, r.vars, ctor⟩ + | some id => do + let per_type := ctor.per_type.set! id $ (ctor.per_type.get! id).push argName + let ctor := { args, per_type } + StateT.set { r with ctor } pure $ r.vars.get! id | none => do - let ctor_per_type := ctor.per_type.push #[argName] + let per_type := ctor.per_type.push #[argName] let name ← mkFreshBinderName - StateT.set ⟨r.expr.push stx, r.vars.push name, ⟨ctor_args, ctor_per_type⟩⟩ + StateT.set { vals := r.vals.push (name, stx), ctor := { args, per_type } } pure name return mkIdent name @@ -118,11 +119,6 @@ private partial def setResultingType (res_type : Syntax) : Syntax → ReplaceM m | _ => pure res_type --- TODO: this should be deprecated in favour of {v with ...} syntax -def CtorView.withType? (ctor : CtorView) (type? : Option Syntax) : CtorView := { - ctor with type? -} - /- TODO: currently these functions ignore dead variables, everything is replaced. This is OK, we can supply a "dead" value to a live variable, but we lose the ability to have @@ -205,7 +201,7 @@ Replace.run <| do let type? ← ctor.type?.mapM $ shapeOf' - pure $ (CtorView.withType? ctor type?, ←CtorArgs.get) + pure ({ ctor with type? }, ←CtorArgs.get) let r ← StateT.get let ctors := pairs.map Prod.fst; @@ -216,8 +212,8 @@ Replace.run <| do -- HACK: It seems that `Array.append` causes a stack overflow, so we go through `List` for now -- TODO: fix this after updating to newer Lean version - let per_type := per_type.appendList $ (List.range diff).map (fun _ => (#[] : Array Name)); - ⟨ctorArgs.args, per_type⟩ + let per_type := per_type.appendList $ List.replicate diff (#[] : Array Name) + { ctorArgs with per_type } -- Now that we know how many free variables were introduced, we can fix up the resulting type -- of each constructor to be `Shape α_0 α_1 ... α_n` @@ -226,7 +222,7 @@ Replace.run <| do let ctors ← ctors.mapM fun ctor => do let type? ← ctor.type?.mapM (setResultingType res) - pure $ CtorView.withType? ctor type? + pure { ctor with type? } pure (ctors, ctorArgs) @@ -234,12 +230,11 @@ Replace.run <| do /-- Replace syntax in *all* subexpressions -/ -partial def Replace.replaceAllStx (find replace : Syntax) : Syntax → Syntax := - fun stx => - if stx == find then - replace - else - stx.setArgs (stx.getArgs.map (replaceAllStx find replace)) +partial def Replace.replaceAllStx (find replace stx : Syntax) : Syntax := + if stx == find then + replace + else + stx.setArgs (stx.getArgs.map (replaceAllStx find replace)) @@ -292,7 +287,7 @@ def makeNonRecursive (view : DataView) : MetaM (DataView × Name) := do let ctors ← view.ctors.mapM fun ctor => do let type? ← ctor.type?.mapM (Replace.replaceStx expected recId <| TSyntax.mk ·) - return CtorView.withType? ctor type? + pure { ctor with type? } let view := view.setCtors ctors pure (view, rec)