diff --git a/Qpf/Macro/Comp.lean b/Qpf/Macro/Comp.lean index af95e95..5c6aaad 100644 --- a/Qpf/Macro/Comp.lean +++ b/Qpf/Macro/Comp.lean @@ -42,13 +42,13 @@ open Qq open TSyntax.Compat def synthMvFunctor {n : Nat} (F : Q(TypeFun.{u,u} $n)) : MetaM Q(MvFunctor $F) := do - let inst_type : Q(Type (u+1)) - := q(MvFunctor $F) + let inst_type : Q(Type (u+1)) := + q(MvFunctor $F) synthInstanceQ inst_type def synthQPF {n : Nat} (F : Q(TypeFun.{u,u} $n)) (_ : Q(MvFunctor $F)) : MetaM Q(MvQPF $F) := do - let inst_type : Q(Type (u+1)) - := q(MvQPF $F) + let inst_type : Q(Type (u+1)) := + q(MvQPF $F) synthInstanceQ inst_type @@ -154,7 +154,7 @@ def DVec.toExpr {n : Nat} {αs : Q(Fin2 $n → Type u)} (xs : DVec (fun (i : Fin structure ElabQpfResult (u : Level) (arity : Nat) where - F : Q(TypeFun.{u,u} $arity) := by exact q(by infer_instance) + F : Q(TypeFun.{u,u} $arity) functor : Q(MvFunctor $F) := by exact q(by infer_instance) qpf : Q(@MvQPF _ $F $functor) := by exact q(by infer_instance) deriving Inhabited @@ -183,7 +183,7 @@ partial def handleConst (target : Q(Type u)) : TermElabM (ElabQpfResult u arity pure { F := const, functor := q(Const.MvFunctor), qpf := q(Const.mvqpf)} -partial def handleApp (vars : Vector FVarId arity) (target : Q(Type u)) : TermElabM (ElabQpfResult u arity) := do +partial def handleApp (vars : Vector FVarId arity) (target : Q(Type u)) : TermElabM (ElabQpfResult u arity) := do let vars' := vars.toList let ⟨numArgs, F, args⟩ ← (Comp.parseApp (isLiveVar vars) target) @@ -226,10 +226,10 @@ partial def handleApp (vars : Vector FVarId arity) (target : Q(Type u)) : TermE let qpf := q(Comp.instMvQPFCompInstMvFunctorCompFin2 (fF := $Ffunctor) (q := $Fqpf) (fG := _) (q' := $GQpf) ) - return { F := comp, functor, qpf } -partial def handleArrow (binderType body : Expr) (vars : Vector FVarId arity) (targetStx : Option Term := none) (normalized := false): TermElabM (ElabQpfResult u arity) := do + +partial def handleArrow (binderType body : Expr) (vars : Vector FVarId arity) (targetStx : Option Term := none) (normalized := false) : TermElabM (ElabQpfResult u arity) := do let newTarget ← mkAppM ``MvQPF.Arrow.Arrow #[binderType, body] elabQpf vars newTarget targetStx normalized /-- diff --git a/Qpf/Macro/Data/Replace.lean b/Qpf/Macro/Data/Replace.lean index 9120cee..05800ca 100644 --- a/Qpf/Macro/Data/Replace.lean +++ b/Qpf/Macro/Data/Replace.lean @@ -17,12 +17,12 @@ structure CtorArgs where (args : Array Name) (per_type : Array (Array Name)) -/- TODO(@William): make these correspond by combining expr and vars into a product -/ structure Replace where - (expr: Array Term) - (vars: Array Name) + (newParameters : Array (Name × Term)) (ctor: CtorArgs) +def Replace.vars (r : Replace) : Array Name := r.newParameters.map Prod.fst +def Replace.expr (r : Replace) : Array Term := r.newParameters.map Prod.snd variable (m) [Monad m] [MonadQuotation m] [MonadError m] [MonadTrace m] [MonadOptions m] [AddMessageContext m] [MonadLiftT IO m] @@ -31,23 +31,21 @@ private abbrev ReplaceM := StateT Replace m variable {m} -private def Replace.new : m Replace := - do pure ⟨#[], #[], ⟨#[], #[]⟩⟩ +private def Replace.new : m Replace := + do pure ⟨#[], ⟨#[], #[]⟩⟩ private def CtorArgs.reset : ReplaceM m Unit := do let r ← StateT.get let n := r.vars.size let ctor: CtorArgs := ⟨#[], (Array.range n).map fun _ => #[]⟩ - StateT.set ⟨r.expr, r.vars, ctor⟩ + StateT.set { r with ctor } private def CtorArgs.get : ReplaceM m CtorArgs := do pure (←StateT.get).ctor -/-- - The arity of the shape type created *after* replacing, i.e., the size of `r.expr` --/ +/-- The arity of the shape type created *after* replacing, i.e., the size of `r.newParameters` -/ def Replace.arity (r : Replace) : Nat := - r.expr.size + r.newParameters.size def Replace.getBinderIdents (r : Replace) : Array Ident := r.vars.map mkIdent @@ -68,23 +66,23 @@ private def ReplaceM.identFor (stx : Term) : ReplaceM m Ident := do let r ← StateT.get let ctor := r.ctor let argName ← mkFreshBinderName - let ctor_args := ctor.args.push argName + let args := ctor.args.push argName -- dbg_trace "\nstx := {stx}\nr.expr := {r.expr}" let name ← match r.expr.indexOf? stx with - | some id => do - let ctor_per_type := ctor.per_type.set! id $ (ctor.per_type.get! id).push argName - let ctor := ⟨ctor_args, ctor_per_type⟩ - StateT.set ⟨r.expr, r.vars, ctor⟩ + | some id => do + let per_type := ctor.per_type.set! id $ (ctor.per_type.get! id).push argName + let ctor := { args, per_type } + StateT.set { r with ctor } pure $ r.vars.get! id | none => do - let ctor_per_type := ctor.per_type.push #[argName] + let per_type := ctor.per_type.push #[argName] let name ← mkFreshBinderName - StateT.set ⟨r.expr.push stx, r.vars.push name, ⟨ctor_args, ctor_per_type⟩⟩ + StateT.set { newParameters := r.newParameters.push (name, stx), ctor := { args, per_type } } pure name return mkIdent name - + @@ -96,13 +94,13 @@ open Lean.Parser in -/ private partial def shapeOf' : Syntax → ReplaceM m Syntax | Syntax.node _ ``Term.arrow #[arg, arrow, tail] => do - let ctor_arg ← ReplaceM.identFor ⟨arg⟩ + let ctor_arg ← ReplaceM.identFor ⟨arg⟩ let ctor_tail ← shapeOf' tail - -- dbg_trace ">> {arg} ==> {ctor_arg}" + -- dbg_trace ">> {arg} ==> {ctor_arg}" pure $ mkNode ``Term.arrow #[ctor_arg, arrow, ctor_tail] - | ctor_type => + | ctor_type => pure ctor_type @@ -115,7 +113,7 @@ private partial def setResultingType (res_type : Syntax) : Syntax → ReplaceM m | Syntax.node _ ``Term.arrow #[arg, arrow, tail] => do let tail ← setResultingType res_type tail pure $ mkNode ``Term.arrow #[arg, arrow, tail] - | _ => + | _ => pure res_type -- TODO: this should be deprecated in favour of {v with ...} syntax @@ -131,7 +129,7 @@ def CtorView.withType? (ctor : CtorView) (type? : Option Syntax) : CtorView := { We should instead detect which expressions are dead, and NOT introduce fresh variables for them. Instead, have the shape functor also take the same dead binders as the original. - This does mean that we should check for collisions between the original dead binders, and the + This does mean that we should check for collisions between the original dead binders, and the fresh variables. -/ @@ -139,8 +137,8 @@ def CtorView.withType? (ctor : CtorView) (type? : Option Syntax) : CtorView := { /-- Runs the given action with a fresh instance of `Replace` -/ -def Replace.run : ReplaceM m α → m (α × Replace) := - fun x => do +def Replace.run : ReplaceM m α → m (α × Replace) := + fun x => do let r ← Replace.new StateT.run x r @@ -178,9 +176,9 @@ def preProcessCtors (view : DataView) : m DataView := do of the same type map to a single variable, where "same" refers to a very simple notion of syntactic equality. E.g., it does not realize `Nat` and `ℕ` are the same. -/ -def Replace.shapeOfCtors (view : DataView) - (shapeIdent : Syntax) - : m ((Array CtorView × Array CtorArgs) × Replace) := +def Replace.shapeOfCtors (view : DataView) + (shapeIdent : Syntax) + : m ((Array CtorView × Array CtorArgs) × Replace) := Replace.run <| do for var in view.liveBinders do let varIdent : Ident := ⟨if var.raw.getKind == ``Parser.Term.binderIdent then @@ -225,8 +223,8 @@ Replace.run <| do let res := Syntax.mkApp (TSyntax.mk shapeIdent) r.getBinderIdents let ctors ← ctors.mapM fun ctor => do - let type? ← ctor.type?.mapM (setResultingType res) - pure $ CtorView.withType? ctor type? + let type? ← ctor.type?.mapM (setResultingType res) + pure { ctor with type? } pure (ctors, ctorArgs) @@ -234,12 +232,11 @@ Replace.run <| do /-- Replace syntax in *all* subexpressions -/ -partial def Replace.replaceAllStx (find replace : Syntax) : Syntax → Syntax := - fun stx => - if stx == find then - replace - else - stx.setArgs (stx.getArgs.map (replaceAllStx find replace)) +partial def Replace.replaceAllStx (find replace stx : Syntax) : Syntax := + if stx == find then + replace + else + stx.setArgs (stx.getArgs.map (replaceAllStx find replace)) @@ -255,7 +252,7 @@ partial def Replace.replaceStx (recType newParam : Term) : Term → MetaM Term pure <| TSyntax.mk <| stx.setArgs #[ replaceAllStx recType newParam arg, arrow, - ←replaceStx recType newParam ⟨tail⟩ + ←replaceStx recType newParam ⟨tail⟩ ] | stx@(Syntax.node _ ``Term.arrow _) => @@ -264,11 +261,11 @@ partial def Replace.replaceStx (recType newParam : Term) : Term → MetaM Term | stx@(Syntax.node _ ``Term.depArrow _) => throwErrorAt stx "Dependent arrows are not supported, please only use plain arrow types" - | ctor_type => + | ctor_type => if ctor_type != recType then throwErrorAt ctor_type m!"Unexpected constructor resulting type; expected {recType}, found {ctor_type}" else - pure ⟨ctor_type⟩ + pure ⟨ctor_type⟩