From fbde3684d9814443f2911284ae8f75a90bc44757 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?William=20S=C3=B8rensen?= Date: Thu, 4 Jul 2024 15:26:46 +0100 Subject: [PATCH] feat: add basic ind gen --- Qpf/Macro/Data.lean | 5 ++ Qpf/Macro/Data/Ind.lean | 188 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 193 insertions(+) create mode 100644 Qpf/Macro/Data/Ind.lean diff --git a/Qpf/Macro/Data.lean b/Qpf/Macro/Data.lean index 393935c..8561e40 100644 --- a/Qpf/Macro/Data.lean +++ b/Qpf/Macro/Data.lean @@ -4,6 +4,7 @@ import Mathlib.Data.QPF.Multivariate.Constructions.Fix import Qpf.Macro.Data.Replace import Qpf.Macro.Data.Count import Qpf.Macro.Data.View +import Qpf.Macro.Data.Ind import Qpf.Macro.Common import Qpf.Macro.Comp @@ -566,5 +567,9 @@ def elabData : CommandElab := fun stx => do mkType view base mkConstructors view shape + try mkInd view + catch e => + dbg_trace (← e.toMessageData.toString) + end Data.Command diff --git a/Qpf/Macro/Data/Ind.lean b/Qpf/Macro/Data/Ind.lean new file mode 100644 index 0000000..90a733b --- /dev/null +++ b/Qpf/Macro/Data/Ind.lean @@ -0,0 +1,188 @@ +import Qpf.Macro.Data.View +import Qpf.Macro.Common +import Mathlib.Data.QPF.Multivariate.Constructions.Fix +import Mathlib.Tactic.ExtractGoal + +open Lean Meta Elab.Command Elab.Term Parser.Term + +inductive RecursionForm := + | trivial (stx: Term) + | recursive -- Simple to infer +deriving Repr, BEq + +def motive := mkIdent $ .str .anonymous "motive" +def ih := mkIdent $ .str .anonymous "ih" +def ind := mkIdent $ ``_root_.MvQPF.Fix.ind + +section +open Lean.Parser in +partial def flattenArrow (v : Term) : List Term := match v.raw with + | .node _ ``Term.arrow #[arg, _, deeper] => + ⟨arg⟩ :: flattenArrow ⟨deeper⟩ + | rest => [⟨rest⟩] + +variable {m} [Monad m] [MonadQuotation m] [MonadError m] [MonadTrace m] [MonadOptions m] + [AddMessageContext m] [MonadLiftT IO m] + +def containsStx (top : Term) (search : Term) : Bool := + (top.raw.find? (· == search)).isSome + +def rip : Name → Name + | .str _ s => .str .anonymous s + | _ => .anonymous + +-- This function assumes the pre-processor has run +-- It also assumes you don't have polymorphic recursive types such as +-- data Ql α | nil | l : α → Ql Bool → Ql α +def extract (view : CtorView) (dv : DataView) : m $ Name × List RecursionForm := + (rip view.declName, ·) <$> (do + let rec_type := dv.getExpectedType + let some type := view.type? | pure [] + let type_ls := (flattenArrow ⟨type⟩).dropLast + + let transform ← type_ls.mapM fun v => + if v == rec_type then pure .recursive + else if containsStx v rec_type then + throwErrorAt v.raw "Cannot handle composed recursive types" + else pure $ .trivial v + + pure transform) + +def bb := bracketedBinder + +instance : Coe (TSyntax `bb) (TSyntax `Lean.Parser.Term.bracketedBinder ) where + coe x := ⟨x.raw⟩ + +open Syntax in +def mkIhType (dv : DataView) (name : Name) (form : List RecursionForm): + m (TSyntax `Lean.Parser.Term.bracketedBinder) := do + let form ← form.mapM fun x => do + let name := mkIdent $ ← mkFreshBinderName + pure (x, name) + let form := form.reverse + + -- Construct the motive type + let out := Syntax.mkApp motive #[ + Syntax.mkApp (mkIdent name) (form.map Prod.snd).toArray.reverse] + -- Add each of the motive hypothesis + let out ← (form.filter (·.fst == .recursive)).foldlM (fun acc ⟨_, name⟩ => `($motive $name → $acc)) out + + let rec_type := dv.getExpectedType + + -- Add the binders + let ty ← form.foldlM (fun acc => (match · with + | ⟨.trivial x, name⟩ => `(($name : $x) → $acc) + | ⟨.recursive, name⟩ => `(($name : $rec_type) → $acc) + )) out + + `(bb | ($(mkIdent name) : $ty)) + +open Lean.Parser.Term in +open Lean.Parser in +private abbrev matchAltExprs : Parser := matchAlts + +def toEqLenNames (x : Array α): m $ Array Ident := x.mapM (fun _ => mkIdent <$> mkFreshBinderName) +def listToEqLenNames (x : List α): m $ Array Ident := toEqLenNames x.toArray + +def wrapIfNotSingle (arr : TSyntaxArray `term) : m Term := + if let #[s] := arr then `($s) + else `(⟨$arr,*⟩) + +def seq [Coe α (TSyntax kx)] (f : α → TSyntax kx → m (TSyntax kx)) : List α → m (TSyntax kx) + | [hd] => pure hd + | hd :: tl => do f hd (← seq f tl) + | [] => pure ⟨.node .none `null #[]⟩ + +open Lean.Parser.Term in +def generate_body (values : Array (Name × List RecursionForm)) : m $ TSyntax `Lean.Parser.Term.matchAlts := do + let deeper: (TSyntaxArray `Lean.Parser.Term.matchAlt) ← values.mapM fun ⟨name, form⟩ => do + let rec_count := form.count .recursive + let names ← listToEqLenNames form + let recs := names.zip (form.toArray) + |>.filter (·.snd == .recursive) + |>.map Prod.fst + + let witnesses ← toEqLenNames recs + + let body : Term ← if 0 = rec_count + then `($(mkIdent name) $names*) + else + let name := mkIdent name + + let p := mkIdent `p + let w := mkIdent `w + + let cases ← values.mapM fun ⟨case, _⟩ => + let case := mkIdent case + if case != name then + `(tactic| case $case:ident => contradiction) + else do + let split: Syntax.TSepArray `tactic "" := .ofElems $ ← names.mapM fun n => + `(tactic|rcases $n:ident with ⟨_, $n:ident⟩) + + let injections ← listToEqLenNames form + + `(tactic|case $name:ident $[$names:ident]* => { + extract_goal; + $split* + + stop + injection $p with $injections* + subst $injections:ident* + + exact $(← wrapIfNotSingle recs) + }) + + trace[QPF] s!"count : {cases.size} {values.size}" + + let proofs ← wrapIfNotSingle witnesses + let ty ← seq (fun a b => `($a ∧ $b)) (← recs.mapM fun x => `($motive $x)).toList + `(have $proofs:term : $ty := by + simp [$(mkIdent ``MvFunctor.LiftP):ident, $(mkIdent ``MvFunctor.map):ident] at $ih:ident + + rcases $ih:ident with ⟨$w, $p⟩ + + /- sorry -/ + cases $w:ident + $cases:tactic;* + $name $names* $witnesses*) + + `(matchAltExpr| + | .$(mkIdent name) $names*, $ih => $body + ) + let x ← `(matchAltExprs| $deeper:matchAlt* ) + pure ⟨x.raw⟩ +end + + +def mkInd (view : DataView) : CommandElabM Unit := do + let mapped ← view.ctors.mapM (extract · view) + + + let ih_types ← mapped.mapM fun ⟨name, base⟩ => + mkIhType view name base + + let rec_type := view.getExpectedType + + let body ← generate_body mapped + + let nm := .str view.shortDeclName "ind" |> mkIdent + let out: Command ← `( + @[elab_as_elim, eliminator] + def $nm + { $(⟨motive⟩) : $rec_type → Prop} + $ih_types* + : (val : $rec_type) → $motive val + := + $ind + ($(mkIdent `p) := $motive) + (match ·,· with $body)) + + trace[QPF] "Recursor definition:" + trace[QPF] out + + Elab.Command.elabCommand out + + pure () + + /- type.ar -/