Skip to content

Commit

Permalink
Merge branch 'master' into refactor-macro-comp-2
Browse files Browse the repository at this point in the history
  • Loading branch information
Equilibris authored Jun 27, 2024
2 parents bfc07e0 + b3c2270 commit 6e78b77
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 44 deletions.
10 changes: 5 additions & 5 deletions Qpf/Macro/Comp.lean
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,13 @@ open Qq
open TSyntax.Compat

def synthMvFunctor {n : Nat} (F : Q(TypeFun.{u,u} $n)) : MetaM Q(MvFunctor $F) := do
let inst_type : Q(Type (u+1))
:= q(MvFunctor $F)
let inst_type : Q(Type (u+1)) :=
q(MvFunctor $F)
synthInstanceQ inst_type

def synthQPF {n : Nat} (F : Q(TypeFun.{u,u} $n)) (_ : Q(MvFunctor $F)) : MetaM Q(MvQPF $F) := do
let inst_type : Q(Type (u+1))
:= q(MvQPF $F)
let inst_type : Q(Type (u+1)) :=
q(MvQPF $F)
synthInstanceQ inst_type


Expand Down Expand Up @@ -233,9 +233,9 @@ partial def handleApp (vars : Vector FVarId arity) (target : Q(Type u)) : TermE
let qpf := q(Comp.instMvQPFCompInstMvFunctorCompFin2
(fF := $Ffunctor) (q := $Fqpf) (fG := _) (q' := $GQpf)
)

return { F := comp, functor, qpf }


partial def handleArrow (binderType body : Expr) (vars : Vector FVarId arity) (targetStx : Option Term := none) (normalized := false): TermElabM (ElabQpfResult u arity) := do
let newTarget ← mkAppM ``MvQPF.Arrow.Arrow #[binderType, body]
elabQpf vars newTarget targetStx normalized
Expand Down
75 changes: 36 additions & 39 deletions Qpf/Macro/Data/Replace.lean
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@ structure CtorArgs where
(args : Array Name)
(per_type : Array (Array Name))

/- TODO(@William): make these correspond by combining expr and vars into a product -/
structure Replace where
(expr: Array Term)
(vars: Array Name)
(newParameters : Array (Name × Term))
(ctor: CtorArgs)

def Replace.vars (r : Replace) : Array Name := r.newParameters.map Prod.fst
def Replace.expr (r : Replace) : Array Term := r.newParameters.map Prod.snd

variable (m) [Monad m] [MonadQuotation m] [MonadError m] [MonadTrace m] [MonadOptions m]
[AddMessageContext m] [MonadLiftT IO m]
Expand All @@ -31,23 +31,21 @@ private abbrev ReplaceM := StateT Replace m

variable {m}

private def Replace.new : m Replace :=
do pure ⟨#[], #[], ⟨#[], #[]⟩⟩
private def Replace.new : m Replace :=
do pure ⟨#[], ⟨#[], #[]⟩⟩

private def CtorArgs.reset : ReplaceM m Unit := do
let r ← StateT.get
let n := r.vars.size
let ctor: CtorArgs := ⟨#[], (Array.range n).map fun _ => #[]⟩
StateT.set ⟨r.expr, r.vars, ctor
StateT.set { r with ctor }

private def CtorArgs.get : ReplaceM m CtorArgs := do
pure (←StateT.get).ctor

/--
The arity of the shape type created *after* replacing, i.e., the size of `r.expr`
-/
/-- The arity of the shape type created *after* replacing, i.e., the size of `r.newParameters` -/
def Replace.arity (r : Replace) : Nat :=
r.expr.size
r.newParameters.size

def Replace.getBinderIdents (r : Replace) : Array Ident :=
r.vars.map mkIdent
Expand All @@ -68,23 +66,23 @@ private def ReplaceM.identFor (stx : Term) : ReplaceM m Ident := do
let r ← StateT.get
let ctor := r.ctor
let argName ← mkFreshBinderName
let ctor_args := ctor.args.push argName
let args := ctor.args.push argName
-- dbg_trace "\nstx := {stx}\nr.expr := {r.expr}"

let name ← match r.expr.indexOf? stx with
| some id => do
let ctor_per_type := ctor.per_type.set! id $ (ctor.per_type.get! id).push argName
let ctor := ⟨ctor_args, ctor_per_type⟩
StateT.set ⟨r.expr, r.vars, ctor
| some id => do
let per_type := ctor.per_type.set! id $ (ctor.per_type.get! id).push argName
let ctor := { args, per_type }
StateT.set { r with ctor }
pure $ r.vars.get! id
| none => do
let ctor_per_type := ctor.per_type.push #[argName]
let per_type := ctor.per_type.push #[argName]
let name ← mkFreshBinderName
StateT.set ⟨r.expr.push stx, r.vars.push name, ⟨ctor_args, ctor_per_type⟩⟩
StateT.set { newParameters := r.newParameters.push (name, stx), ctor := { args, per_type } }
pure name

return mkIdent name




Expand All @@ -96,13 +94,13 @@ open Lean.Parser in
-/
private partial def shapeOf' : Syntax → ReplaceM m Syntax
| Syntax.node _ ``Term.arrow #[arg, arrow, tail] => do
let ctor_arg ← ReplaceM.identFor ⟨arg⟩
let ctor_arg ← ReplaceM.identFor ⟨arg⟩
let ctor_tail ← shapeOf' tail

-- dbg_trace ">> {arg} ==> {ctor_arg}"
-- dbg_trace ">> {arg} ==> {ctor_arg}"
pure $ mkNode ``Term.arrow #[ctor_arg, arrow, ctor_tail]

| ctor_type =>
| ctor_type =>
pure ctor_type


Expand All @@ -115,7 +113,7 @@ private partial def setResultingType (res_type : Syntax) : Syntax → ReplaceM m
| Syntax.node _ ``Term.arrow #[arg, arrow, tail] => do
let tail ← setResultingType res_type tail
pure $ mkNode ``Term.arrow #[arg, arrow, tail]
| _ =>
| _ =>
pure res_type

-- TODO: this should be deprecated in favour of {v with ...} syntax
Expand All @@ -131,16 +129,16 @@ def CtorView.withType? (ctor : CtorView) (type? : Option Syntax) : CtorView := {
We should instead detect which expressions are dead, and NOT introduce fresh variables for them.
Instead, have the shape functor also take the same dead binders as the original.
This does mean that we should check for collisions between the original dead binders, and the
This does mean that we should check for collisions between the original dead binders, and the
fresh variables.
-/




/-- Runs the given action with a fresh instance of `Replace` -/
def Replace.run : ReplaceM m α → m (α × Replace) :=
fun x => do
def Replace.run : ReplaceM m α → m (α × Replace) :=
fun x => do
let r ← Replace.new
StateT.run x r

Expand Down Expand Up @@ -178,9 +176,9 @@ def preProcessCtors (view : DataView) : m DataView := do
of the same type map to a single variable, where "same" refers to a very simple notion of
syntactic equality. E.g., it does not realize `Nat` and `ℕ` are the same.
-/
def Replace.shapeOfCtors (view : DataView)
(shapeIdent : Syntax)
: m ((Array CtorView × Array CtorArgs) × Replace) :=
def Replace.shapeOfCtors (view : DataView)
(shapeIdent : Syntax)
: m ((Array CtorView × Array CtorArgs) × Replace) :=
Replace.run <| do
for var in view.liveBinders do
let varIdent : Ident := ⟨if var.raw.getKind == ``Parser.Term.binderIdent then
Expand Down Expand Up @@ -225,21 +223,20 @@ Replace.run <| do
let res := Syntax.mkApp (TSyntax.mk shapeIdent) r.getBinderIdents

let ctors ← ctors.mapM fun ctor => do
let type? ← ctor.type?.mapM (setResultingType res)
pure $ CtorView.withType? ctor type?
let type? ← ctor.type?.mapM (setResultingType res)
pure { ctor with type? }

pure (ctors, ctorArgs)




/-- Replace syntax in *all* subexpressions -/
partial def Replace.replaceAllStx (find replace : Syntax) : Syntax → Syntax :=
fun stx =>
if stx == find then
replace
else
stx.setArgs (stx.getArgs.map (replaceAllStx find replace))
partial def Replace.replaceAllStx (find replace stx : Syntax) : Syntax :=
if stx == find then
replace
else
stx.setArgs (stx.getArgs.map (replaceAllStx find replace))



Expand All @@ -255,7 +252,7 @@ partial def Replace.replaceStx (recType newParam : Term) : Term → MetaM Term
pure <| TSyntax.mk <| stx.setArgs #[
replaceAllStx recType newParam arg,
arrow,
←replaceStx recType newParam ⟨tail⟩
←replaceStx recType newParam ⟨tail⟩
]

| stx@(Syntax.node _ ``Term.arrow _) =>
Expand All @@ -264,11 +261,11 @@ partial def Replace.replaceStx (recType newParam : Term) : Term → MetaM Term
| stx@(Syntax.node _ ``Term.depArrow _) =>
throwErrorAt stx "Dependent arrows are not supported, please only use plain arrow types"

| ctor_type =>
| ctor_type =>
if ctor_type != recType then
throwErrorAt ctor_type m!"Unexpected constructor resulting type; expected {recType}, found {ctor_type}"
else
pure ⟨ctor_type⟩
pure ⟨ctor_type⟩



Expand Down

0 comments on commit 6e78b77

Please sign in to comment.