From 3d687e3a22fbaf95e9189863836e1f9c050f984a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?William=20S=C3=B8rensen?= Date: Thu, 4 Jul 2024 15:26:46 +0100 Subject: [PATCH 1/5] feat: add basic ind gen Tentative, hacky fix for shadowing issue refactor: Ind.lean refactor: clean-up Ind defn --- Qpf/Macro/Data.lean | 5 ++ Qpf/Macro/Data/Ind.lean | 185 ++++++++++++++++++++++++++++++++++++++++ Test/ListInduction.lean | 24 ++++++ 3 files changed, 214 insertions(+) create mode 100644 Qpf/Macro/Data/Ind.lean create mode 100644 Test/ListInduction.lean diff --git a/Qpf/Macro/Data.lean b/Qpf/Macro/Data.lean index 393935c..8561e40 100644 --- a/Qpf/Macro/Data.lean +++ b/Qpf/Macro/Data.lean @@ -4,6 +4,7 @@ import Mathlib.Data.QPF.Multivariate.Constructions.Fix import Qpf.Macro.Data.Replace import Qpf.Macro.Data.Count import Qpf.Macro.Data.View +import Qpf.Macro.Data.Ind import Qpf.Macro.Common import Qpf.Macro.Comp @@ -566,5 +567,9 @@ def elabData : CommandElab := fun stx => do mkType view base mkConstructors view shape + try mkInd view + catch e => + dbg_trace (← e.toMessageData.toString) + end Data.Command diff --git a/Qpf/Macro/Data/Ind.lean b/Qpf/Macro/Data/Ind.lean new file mode 100644 index 0000000..8b5143b --- /dev/null +++ b/Qpf/Macro/Data/Ind.lean @@ -0,0 +1,185 @@ +import Qpf.Macro.Data.View +import Qpf.Macro.Common +import Mathlib.Data.QPF.Multivariate.Constructions.Fix +import Mathlib.Tactic.ExtractGoal + +open Lean.Parser (Parser) +open Lean Meta Elab.Command Elab.Term Parser.Term + +/-- + The recursive form encodes how a function argument is recursive. + + Examples ty R α: + + α → R α → List (R α) → R α + [nonRec, directRec, composed ] +-/ +inductive RecursionForm := + | nonRec (stx: Term) + | directRec + /- | composed -/ +deriving Repr, BEq + +partial def getArgTypes (v : Term) : List Term := match v.raw with + | .node _ ``arrow #[arg, _, deeper] => + ⟨arg⟩ :: getArgTypes ⟨deeper⟩ + | rest => [⟨rest⟩] + + +def containsStx (top : Term) (search : Term) : Bool := + (top.raw.find? (· == search)).isSome + +def ripSuffix : Name → Name + | .str _ s => .str .anonymous s + | _ => .anonymous -- Unhandled case + +-- These parsers have to be declared as they have optional args breaking quotation +abbrev bb : Parser := bracketedBinder +abbrev matchAltExprs : Parser := matchAlts + +instance : Coe (TSyntax ``bb) (TSyntax ``bracketedBinder) where coe x := ⟨x.raw⟩ +instance : Coe (TSyntax ``matchAltExprs) (TSyntax ``matchAlts) where coe x := ⟨x.raw⟩ + +section +variable {m} [Monad m] [MonadQuotation m] [MonadError m] [MonadTrace m] [AddMessageContext m] + +-- This function assumes the pre-processor has run +-- It also assumes you don't have polymorphic recursive types such as +-- data Ql α | nil | l : α → Ql Bool → Ql α +def extract (view : CtorView) (rec_type : Term) : m $ Name × List RecursionForm := + (ripSuffix view.declName, ·) <$> (do + let some type := view.type? | pure [] + let type_ls := (getArgTypes ⟨type⟩).dropLast + + let transform ← type_ls.mapM fun v => + if v == rec_type then pure .directRec + else if containsStx v rec_type then + throwErrorAt v.raw "Cannot handle composed recursive types" + else pure $ .nonRec v + + pure transform) + +def mkConstructorType (rec_type : Term) (name : Ident) (form : List RecursionForm): m (TSyntax ``bracketedBinder) := do + let form ← form.mapM fun x => (x, mkIdent ·) <$> mkFreshBinderName + let form := form.reverse + + let out := Syntax.mkApp (← `(motive)) #[Syntax.mkApp name (form.map Prod.snd).toArray.reverse] + let out ← (form.filter (·.fst == .directRec)).foldlM (fun acc ⟨_, name⟩ => `(motive $name → $acc)) out + + let ty ← form.foldlM (fun acc => (match · with + | ⟨.nonRec x, name⟩ => `(($name : $x) → $acc) + | ⟨.directRec, name⟩ => `(($name : $rec_type) → $acc) + )) out + + `(bb | ($name : $ty)) + +def toEqLenNames (x : Array α) : m $ Array Ident := x.mapM (fun _ => mkIdent <$> mkFreshBinderName) +def listToEqLenNames (x : List α) : m $ Array Ident := toEqLenNames x.toArray + +def wrapIfNotSingle (arr : TSyntaxArray `term) : m Term := + if let #[s] := arr then `($s) + else `(⟨$arr,*⟩) + +def seq [Coe α (TSyntax kx)] (f : α → TSyntax kx → m (TSyntax kx)) : List α → m (TSyntax kx) + | [hd] => pure hd + | hd :: tl => do f hd (← seq f tl) + | [] => pure ⟨.node .none `null #[]⟩ + +def generate_body (values : Array (Name × List RecursionForm)) : m $ TSyntax ``matchAlts := do + let deeper: (TSyntaxArray ``matchAlt) ← values.mapM fun ⟨outerCase, form⟩ => do + let outerCaseId := mkIdent outerCase + let rec_count := form.count .directRec + + let names ← listToEqLenNames form + + if 0 = rec_count then + return ← `(matchAltExpr| | .$outerCaseId $names*, ih => ($outerCaseId $names*)) + + -- TODO: this is not a good way to deal with names + -- but, for some unknowable reason, this fixes a shadowing + -- issue + + let names : Array Ident := ⟨ + names.toList.enum.map fun ⟨i, _⟩ => + mkIdent <| Name.str (.anonymous) s!"_PRIVATE_ y {i}" + -- ^^^^^^^^^^^^^^^ + -- HACK: notice how the variable name contains spaces. + -- This is deliberate, to reduce the possibility of name collisions, + -- given that these names are unhygienic + ⟩ + -- The following would be better, with fresh names, + -- but for some reason this causes the variable `p` to be renamed to something new, which will then give errors when we use `p` later on + -- let names : Array Ident ← names.mapM fun _ => do + -- mkFreshIdent .missing + + let recs := names.zip (form.toArray) + |>.filter (·.snd == .directRec) + |>.map Prod.fst + + let p := mkIdent `proof + let w := mkIdent `witness + + let cases ← values.mapM fun ⟨innerCase, _⟩ => do + let innerCaseTag := mkIdent innerCase + if innerCase != outerCase then + `(tactic| case $innerCaseTag:ident => contradiction) + else + let split : Array (TSyntax `tactic) ← names.mapM fun n => + `(tactic|rcases $n:ident with ⟨_, $n:ident⟩) + let injections ← toEqLenNames names + + `(tactic| + case $innerCaseTag:ident $[$names:ident]* => ( + $split:tactic* + injection $p:ident with $injections* + subst $injections:ident* + exact $(← wrapIfNotSingle recs) + )) + + let witnesses ← toEqLenNames recs + let proofs ← wrapIfNotSingle witnesses + let type ← seq (fun a b => `($a ∧ $b)) (← recs.mapM fun x => `(motive $x)).toList + + `(matchAltExpr| + | .$outerCaseId $names*, ih => + have $proofs:term : $type := by + simp only [ + $(mkIdent ``MvFunctor.LiftP):ident, + $(mkIdent ``TypeVec.PredLast):ident, + $(mkIdent ``Fin2.instOfNatFin2HAddNatInstHAddInstAddNatOfNat):ident + ] at ih + + rcases ih with ⟨$w:ident, $p:ident⟩ + cases $w:ident + $cases:tactic* + $outerCaseId:ident $names* $witnesses* + ) + + `(matchAltExprs| $deeper:matchAlt* ) +end + + +def mkInd (view : DataView) : CommandElabM Unit := do + let rec_type := view.getExpectedType + + let mapped ← view.ctors.mapM (extract · rec_type) + let ih_types ← mapped.mapM fun ⟨name, base⟩ => + mkConstructorType rec_type (mkIdent name) base + + let out: Command ← `( + @[elab_as_elim, eliminator] + def $(.str view.shortDeclName "ind" |> mkIdent):ident + { motive : $rec_type → Prop} + $ih_types* + : (val : $rec_type) → motive val + := + $(mkIdent $ ``_root_.MvQPF.Fix.ind) + ($(mkIdent `p) := motive) + (match ·,· with $(← generate_body mapped))) + + trace[QPF] "Recursor definition:" + trace[QPF] out + + Elab.Command.elabCommand out + + pure () diff --git a/Test/ListInduction.lean b/Test/ListInduction.lean new file mode 100644 index 0000000..97f7553 --- /dev/null +++ b/Test/ListInduction.lean @@ -0,0 +1,24 @@ +import Qpf.Macro.Data + +/-! +## Test for induction principle generation +-/ + +namespace Test + +data QpfList α + | nil + | cons : α → QpfList α → QpfList α + +/-- info: 'Test.QpfList.ind' depends on axioms: [Quot.sound, propext] -/ +#guard_msgs in #print axioms QpfList.ind + +-- The following test might be a bit brittle. +-- Feel free to remove if it gives too many false positives +/-- +info: Test.QpfList.ind {α : Type} {motive : QpfList α → Prop} (nil : motive QpfList.nil) + (cons : ∀ (x : α) (x_1 : QpfList α), motive x_1 → motive (QpfList.cons x x_1)) (val✝ : QpfList α) : motive val✝ +-/ +#guard_msgs in #check QpfList.ind + +end Test From a8d228c4bc898a049e71e938a5aa71dc739e3085 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?William=20S=C3=B8rensen?= Date: Mon, 15 Jul 2024 13:09:44 +0100 Subject: [PATCH 2/5] feat: add noConfusion principle add more stuff --- Qpf/Macro/Data.lean | 6 +- Qpf/Macro/Data/Ind.lean | 157 +++++++++++++++++++++++++++++++++------- Test/List.lean | 2 +- Test/ListInduction.lean | 4 +- 4 files changed, 136 insertions(+), 33 deletions(-) diff --git a/Qpf/Macro/Data.lean b/Qpf/Macro/Data.lean index 8561e40..23c4eed 100644 --- a/Qpf/Macro/Data.lean +++ b/Qpf/Macro/Data.lean @@ -567,9 +567,9 @@ def elabData : CommandElab := fun stx => do mkType view base mkConstructors view shape - try mkInd view - catch e => - dbg_trace (← e.toMessageData.toString) + if let .Data := view.command then + try genRecursors view + catch e => trace[QPF] (← e.toMessageData.toString) end Data.Command diff --git a/Qpf/Macro/Data/Ind.lean b/Qpf/Macro/Data/Ind.lean index 8b5143b..834e095 100644 --- a/Qpf/Macro/Data/Ind.lean +++ b/Qpf/Macro/Data/Ind.lean @@ -25,14 +25,11 @@ partial def getArgTypes (v : Term) : List Term := match v.raw with ⟨arg⟩ :: getArgTypes ⟨deeper⟩ | rest => [⟨rest⟩] +def flattenForArg (n : Name) := Name.str .anonymous $ n.toStringWithSep "_" true def containsStx (top : Term) (search : Term) : Bool := (top.raw.find? (· == search)).isSome -def ripSuffix : Name → Name - | .str _ s => .str .anonymous s - | _ => .anonymous -- Unhandled case - -- These parsers have to be declared as they have optional args breaking quotation abbrev bb : Parser := bracketedBinder abbrev matchAltExprs : Parser := matchAlts @@ -40,14 +37,19 @@ abbrev matchAltExprs : Parser := matchAlts instance : Coe (TSyntax ``bb) (TSyntax ``bracketedBinder) where coe x := ⟨x.raw⟩ instance : Coe (TSyntax ``matchAltExprs) (TSyntax ``matchAlts) where coe x := ⟨x.raw⟩ +def addShapeToName : Name → Name + | .anonymous => .str .anonymous "Shape" + | .str a b => .str (addShapeToName a) b + | .num a b => .num (addShapeToName a) b + section variable {m} [Monad m] [MonadQuotation m] [MonadError m] [MonadTrace m] [AddMessageContext m] -- This function assumes the pre-processor has run -- It also assumes you don't have polymorphic recursive types such as -- data Ql α | nil | l : α → Ql Bool → Ql α -def extract (view : CtorView) (rec_type : Term) : m $ Name × List RecursionForm := - (ripSuffix view.declName, ·) <$> (do +def extract (topName : Name) (view : CtorView) (rec_type : Term) : m $ Name × List RecursionForm := + (view.declName.replacePrefix topName .anonymous , ·) <$> (do let some type := view.type? | pure [] let type_ls := (getArgTypes ⟨type⟩).dropLast @@ -59,19 +61,25 @@ def extract (view : CtorView) (rec_type : Term) : m $ Name × List RecursionForm pure transform) -def mkConstructorType (rec_type : Term) (name : Ident) (form : List RecursionForm): m (TSyntax ``bracketedBinder) := do +def mkConstructorType + (rec_type : Term) (name : Name) + (form : List RecursionForm) + (inclMotives : Bool) : m (TSyntax ``bracketedBinder) := do let form ← form.mapM fun x => (x, mkIdent ·) <$> mkFreshBinderName let form := form.reverse - let out := Syntax.mkApp (← `(motive)) #[Syntax.mkApp name (form.map Prod.snd).toArray.reverse] - let out ← (form.filter (·.fst == .directRec)).foldlM (fun acc ⟨_, name⟩ => `(motive $name → $acc)) out + let out := Syntax.mkApp (← `(motive)) #[Syntax.mkApp (mkIdent name) (form.map Prod.snd).toArray.reverse] + let out ← + if inclMotives then + (form.filter (·.fst == .directRec)).foldlM (fun acc ⟨_, name⟩ => `(motive $name → $acc)) out + else pure out let ty ← form.foldlM (fun acc => (match · with | ⟨.nonRec x, name⟩ => `(($name : $x) → $acc) | ⟨.directRec, name⟩ => `(($name : $rec_type) → $acc) )) out - `(bb | ($name : $ty)) + `(bb | ($(mkIdent $ flattenForArg name) : $ty)) def toEqLenNames (x : Array α) : m $ Array Ident := x.mapM (fun _ => mkIdent <$> mkFreshBinderName) def listToEqLenNames (x : List α) : m $ Array Ident := toEqLenNames x.toArray @@ -85,15 +93,16 @@ def seq [Coe α (TSyntax kx)] (f : α → TSyntax kx → m (TSyntax kx)) : List | hd :: tl => do f hd (← seq f tl) | [] => pure ⟨.node .none `null #[]⟩ -def generate_body (values : Array (Name × List RecursionForm)) : m $ TSyntax ``matchAlts := do - let deeper: (TSyntaxArray ``matchAlt) ← values.mapM fun ⟨outerCase, form⟩ => do - let outerCaseId := mkIdent outerCase +def generateIndBody (ctors : Array (Name × List RecursionForm)) : m $ TSyntax ``matchAlts := do + let deeper: (TSyntaxArray ``matchAlt) ← ctors.mapM fun ⟨outerCase, form⟩ => do + let callName := mkIdent $ flattenForArg outerCase + let outerCaseId := mkIdent $ addShapeToName outerCase let rec_count := form.count .directRec let names ← listToEqLenNames form if 0 = rec_count then - return ← `(matchAltExpr| | .$outerCaseId $names*, ih => ($outerCaseId $names*)) + return ← `(matchAltExpr| | $outerCaseId $names*, ih => ($callName $names*)) -- TODO: this is not a good way to deal with names -- but, for some unknowable reason, this fixes a shadowing @@ -119,12 +128,12 @@ def generate_body (values : Array (Name × List RecursionForm)) : m $ TSyntax `` let p := mkIdent `proof let w := mkIdent `witness - let cases ← values.mapM fun ⟨innerCase, _⟩ => do + let cases ← ctors.mapM fun ⟨innerCase, _⟩ => do let innerCaseTag := mkIdent innerCase if innerCase != outerCase then `(tactic| case $innerCaseTag:ident => contradiction) else - let split : Array (TSyntax `tactic) ← names.mapM fun n => + let split : Array (TSyntax `tactic) ← recs.mapM fun n => `(tactic|rcases $n:ident with ⟨_, $n:ident⟩) let injections ← toEqLenNames names @@ -141,7 +150,7 @@ def generate_body (values : Array (Name × List RecursionForm)) : m $ TSyntax `` let type ← seq (fun a b => `($a ∧ $b)) (← recs.mapM fun x => `(motive $x)).toList `(matchAltExpr| - | .$outerCaseId $names*, ih => + | $outerCaseId $names*, ih => have $proofs:term : $type := by simp only [ $(mkIdent ``MvFunctor.LiftP):ident, @@ -152,21 +161,53 @@ def generate_body (values : Array (Name × List RecursionForm)) : m $ TSyntax `` rcases ih with ⟨$w:ident, $p:ident⟩ cases $w:ident $cases:tactic* - $outerCaseId:ident $names* $witnesses* + $callName $names* $witnesses* ) `(matchAltExprs| $deeper:matchAlt* ) -end +def generateRecBody (ctors : Array (Name × List RecursionForm)) (includeMotive : Bool) : m $ TSyntax ``matchAlts := do + let deeper: (TSyntaxArray ``matchAlt) ← ctors.mapM fun ⟨outerCase, form⟩ => do + let callName := mkIdent $ flattenForArg outerCase + let outerCaseId := mkIdent $ addShapeToName outerCase + + let names ← listToEqLenNames form + let names := names.zip form.toArray + + let desArgs ← names.mapM fun ⟨nm, f⟩ => + match f with + | .directRec => `(⟨_, $nm⟩) + | .nonRec _ => `(_) + + let nonMotiveArgs ← names.mapM fun _ => `(_) + let motiveArgs ← if includeMotive then + names.filterMapM fun ⟨nm, f⟩ => + match f with + | .directRec => some <$> `($nm) + | .nonRec _ => pure none + else pure #[] + + + `(matchAltExpr| + | $outerCaseId $desArgs* => + $callName $nonMotiveArgs* $motiveArgs* + ) + + `(matchAltExprs| $deeper:matchAlt*) -def mkInd (view : DataView) : CommandElabM Unit := do +def natToTerm : ℕ → m Term + | 0 => pure $ mkIdent ``Nat.zero + | .succ n => do `($(mkIdent ``Nat.succ) $(← natToTerm n)) + +def genRecursors (view : DataView) : CommandElabM Unit := do let rec_type := view.getExpectedType - let mapped ← view.ctors.mapM (extract · rec_type) + let mapped ← view.ctors.mapM (extract view.declName · rec_type) + let ih_types ← mapped.mapM fun ⟨name, base⟩ => - mkConstructorType rec_type (mkIdent name) base + mkConstructorType (rec_type) (name) base true - let out: Command ← `( + let indDef : Command ← `( @[elab_as_elim, eliminator] def $(.str view.shortDeclName "ind" |> mkIdent):ident { motive : $rec_type → Prop} @@ -175,11 +216,73 @@ def mkInd (view : DataView) : CommandElabM Unit := do := $(mkIdent $ ``_root_.MvQPF.Fix.ind) ($(mkIdent `p) := motive) - (match ·,· with $(← generate_body mapped))) + (match ·,· with $(← generateIndBody mapped))) + + Elab.Command.elabCommand indDef - trace[QPF] "Recursor definition:" - trace[QPF] out + let recDef : Command ← `( + @[elab_as_elim] + def $(.str view.shortDeclName "rec" |> mkIdent):ident + { motive : $rec_type → Type _} + $ih_types* + : (val : $rec_type) → motive val + := + $(mkIdent $ ``MvQPF.Fix.drec) + (match · with $(← generateRecBody mapped true))) - Elab.Command.elabCommand out + + Elab.Command.elabCommand recDef + + let casesOnTypes ← mapped.mapM fun ⟨name, base⟩ => + mkConstructorType (rec_type) (name) base false + + let casesDef : Command ← `( + @[elab_as_elim] + def $(.str view.shortDeclName "cases" |> mkIdent):ident + { motive : $rec_type → Type _} + $casesOnTypes* + : (val : $rec_type) → motive val + := + $(mkIdent $ ``_root_.MvQPF.Fix.drec) + (match · with $(← generateRecBody mapped false))) + + Elab.Command.elabCommand casesDef + + let ixs ← (List.range mapped.size).mapM natToTerm + let ixs := ixs.toArray + + let toCtorDef : Command ← `( + def $(.str view.shortDeclName "toCtorIdx" |> mkIdent) (t : $rec_type) + : $(mkIdent ``Nat) := + $(.str view.shortDeclName "cases" |> mkIdent) $ixs* t + ) + Elab.Command.elabCommand toCtorDef + + let noConfusionType : Command ← `( + abbrev $(.str view.shortDeclName "noConfusionType" |> mkIdent) + (v : Sort _) (a b : $rec_type) := + $(mkIdent ``noConfusionTypeEnum) + $(.str view.shortDeclName "toCtorIdx" |> mkIdent) + v a b + ) + Elab.Command.elabCommand noConfusionType + + let noConfusion : Command ← `( + abbrev $(.str view.shortDeclName "noConfusion" |> mkIdent) + {P : Sort _} {a b : $rec_type} : + a = b → $(.str view.shortDeclName "noConfusionType" |> mkIdent) P a b := + $(mkIdent ``noConfusionEnum) $(.str view.shortDeclName "toCtorIdx" |> mkIdent) + ) + Elab.Command.elabCommand noConfusion + + trace[QPF] "Rec definitions:" + trace[QPF] indDef + trace[QPF] recDef + + trace[QPF] casesDef + + trace[QPF] toCtorDef + trace[QPF] noConfusionType + trace[QPF] noConfusion pure () diff --git a/Test/List.lean b/Test/List.lean index 027b487..add4502 100644 --- a/Test/List.lean +++ b/Test/List.lean @@ -6,4 +6,4 @@ data QpfList α | nil | cons : α → QpfList α → QpfList α -end Test \ No newline at end of file +end Test diff --git a/Test/ListInduction.lean b/Test/ListInduction.lean index 97f7553..94fb4be 100644 --- a/Test/ListInduction.lean +++ b/Test/ListInduction.lean @@ -16,8 +16,8 @@ data QpfList α -- The following test might be a bit brittle. -- Feel free to remove if it gives too many false positives /-- -info: Test.QpfList.ind {α : Type} {motive : QpfList α → Prop} (nil : motive QpfList.nil) - (cons : ∀ (x : α) (x_1 : QpfList α), motive x_1 → motive (QpfList.cons x x_1)) (val✝ : QpfList α) : motive val✝ +info: Test.QpfList.ind {α : Type} {motive✝ : QpfList α → Prop} (nil : motive✝ QpfList.nil) + (cons : ∀ (x : α) (x_1 : QpfList α), motive✝ x_1 → motive✝ (QpfList.cons x x_1)) (val✝ : QpfList α) : motive✝ val✝ -/ #guard_msgs in #check QpfList.ind From 75ba863f28f3cefa58501a00046a160a1d3d86c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?William=20S=C3=B8rensen?= Date: Wed, 17 Jul 2024 11:02:30 +0100 Subject: [PATCH 3/5] refactor: clean-up Ind defn --- Qpf/Macro/Data/Ind.lean | 52 ++++++++++++++++++----------------------- 1 file changed, 23 insertions(+), 29 deletions(-) diff --git a/Qpf/Macro/Data/Ind.lean b/Qpf/Macro/Data/Ind.lean index 834e095..9fea4c6 100644 --- a/Qpf/Macro/Data/Ind.lean +++ b/Qpf/Macro/Data/Ind.lean @@ -5,6 +5,7 @@ import Mathlib.Tactic.ExtractGoal open Lean.Parser (Parser) open Lean Meta Elab.Command Elab.Term Parser.Term +open Lean.Parser.Tactic (inductionAlt) /-- The recursive form encodes how a function argument is recursive. @@ -30,11 +31,17 @@ def flattenForArg (n : Name) := Name.str .anonymous $ n.toStringWithSep "_" true def containsStx (top : Term) (search : Term) : Bool := (top.raw.find? (· == search)).isSome --- These parsers have to be declared as they have optional args breaking quotation +/-- Both `bracketedBinder` and `matchAlts` have optional arguments, +which cause them to not by recognized as parsers in quotation syntax +(that is, ``` `(bracketedBinder| ...) ``` does not work). +To work around this, we define aliases that force the optional argument to it's default value, +so that we can write ``` `(bb| ...) ```instead. -/ abbrev bb : Parser := bracketedBinder abbrev matchAltExprs : Parser := matchAlts -instance : Coe (TSyntax ``bb) (TSyntax ``bracketedBinder) where coe x := ⟨x.raw⟩ +/- Since `bb` and `matchAltExprs` are aliases for `bracketedBinder`, resp. `matchAlts`, +we can safely coerce syntax of these categories -/ +instance : Coe (TSyntax ``bb) (TSyntax ``bracketedBinder) where coe x := ⟨x.raw⟩ instance : Coe (TSyntax ``matchAltExprs) (TSyntax ``matchAlts) where coe x := ⟨x.raw⟩ def addShapeToName : Name → Name @@ -45,9 +52,9 @@ def addShapeToName : Name → Name section variable {m} [Monad m] [MonadQuotation m] [MonadError m] [MonadTrace m] [AddMessageContext m] --- This function assumes the pre-processor has run --- It also assumes you don't have polymorphic recursive types such as --- data Ql α | nil | l : α → Ql Bool → Ql α +/-- This function assumes the pre-processor has run +It also assumes you don't have polymorphic recursive types such as +data Ql α | nil | l : α → Ql Bool → Ql α -/ def extract (topName : Name) (view : CtorView) (rec_type : Term) : m $ Name × List RecursionForm := (view.declName.replacePrefix topName .anonymous , ·) <$> (do let some type := view.type? | pure [] @@ -88,10 +95,13 @@ def wrapIfNotSingle (arr : TSyntaxArray `term) : m Term := if let #[s] := arr then `($s) else `(⟨$arr,*⟩) -def seq [Coe α (TSyntax kx)] (f : α → TSyntax kx → m (TSyntax kx)) : List α → m (TSyntax kx) +/-- This function behaves like reduce but is specialized for TSyntaxes. +It is used to insert ∧s between entries -/ +def seq (f : TSyntax kx → TSyntax kx → m (TSyntax kx)) : List (TSyntax kx) → m (TSyntax kx) | [hd] => pure hd | hd :: tl => do f hd (← seq f tl) - | [] => pure ⟨.node .none `null #[]⟩ + | [] => throwError "Expected at least one value for interspersing" + def generateIndBody (ctors : Array (Name × List RecursionForm)) : m $ TSyntax ``matchAlts := do let deeper: (TSyntaxArray ``matchAlt) ← ctors.mapM fun ⟨outerCase, form⟩ => do @@ -104,22 +114,7 @@ def generateIndBody (ctors : Array (Name × List RecursionForm)) : m $ TSyntax ` if 0 = rec_count then return ← `(matchAltExpr| | $outerCaseId $names*, ih => ($callName $names*)) - -- TODO: this is not a good way to deal with names - -- but, for some unknowable reason, this fixes a shadowing - -- issue - - let names : Array Ident := ⟨ - names.toList.enum.map fun ⟨i, _⟩ => - mkIdent <| Name.str (.anonymous) s!"_PRIVATE_ y {i}" - -- ^^^^^^^^^^^^^^^ - -- HACK: notice how the variable name contains spaces. - -- This is deliberate, to reduce the possibility of name collisions, - -- given that these names are unhygienic - ⟩ - -- The following would be better, with fresh names, - -- but for some reason this causes the variable `p` to be renamed to something new, which will then give errors when we use `p` later on - -- let names : Array Ident ← names.mapM fun _ => do - -- mkFreshIdent .missing + let names ← toEqLenNames names let recs := names.zip (form.toArray) |>.filter (·.snd == .directRec) @@ -128,17 +123,16 @@ def generateIndBody (ctors : Array (Name × List RecursionForm)) : m $ TSyntax ` let p := mkIdent `proof let w := mkIdent `witness - let cases ← ctors.mapM fun ⟨innerCase, _⟩ => do + let cases: TSyntaxArray _ ← ctors.mapM fun ⟨innerCase, _⟩ => do let innerCaseTag := mkIdent innerCase if innerCase != outerCase then - `(tactic| case $innerCaseTag:ident => contradiction) + `(inductionAlt| | $innerCaseTag:ident => contradiction) else let split : Array (TSyntax `tactic) ← recs.mapM fun n => `(tactic|rcases $n:ident with ⟨_, $n:ident⟩) let injections ← toEqLenNames names - `(tactic| - case $innerCaseTag:ident $[$names:ident]* => ( + `(inductionAlt| | $innerCaseTag:ident $[$names:ident]* => ( $split:tactic* injection $p:ident with $injections* subst $injections:ident* @@ -159,8 +153,8 @@ def generateIndBody (ctors : Array (Name × List RecursionForm)) : m $ TSyntax ` ] at ih rcases ih with ⟨$w:ident, $p:ident⟩ - cases $w:ident - $cases:tactic* + cases $w:ident with + $cases:inductionAlt* $callName $names* $witnesses* ) From 40d350e0b2367450df822cf39aa67b9d052f59ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?William=20S=C3=B8rensen?= Date: Wed, 17 Jul 2024 16:35:51 +0100 Subject: [PATCH 4/5] refactor: clean-up Ind defn --- Qpf/Macro/Data/Ind.lean | 103 +++++++++++++++------------------------- 1 file changed, 37 insertions(+), 66 deletions(-) diff --git a/Qpf/Macro/Data/Ind.lean b/Qpf/Macro/Data/Ind.lean index 9fea4c6..90e101d 100644 --- a/Qpf/Macro/Data/Ind.lean +++ b/Qpf/Macro/Data/Ind.lean @@ -18,7 +18,7 @@ open Lean.Parser.Tactic (inductionAlt) inductive RecursionForm := | nonRec (stx: Term) | directRec - /- | composed -/ + -- | composed -- Not supported yet deriving Repr, BEq partial def getArgTypes (v : Term) : List Term := match v.raw with @@ -44,6 +44,8 @@ we can safely coerce syntax of these categories -/ instance : Coe (TSyntax ``bb) (TSyntax ``bracketedBinder) where coe x := ⟨x.raw⟩ instance : Coe (TSyntax ``matchAltExprs) (TSyntax ``matchAlts) where coe x := ⟨x.raw⟩ +/-- When we want to operate on patterns the names we need must start with shape. +This is done as if theres a constructor called `mk` dot notation breaks. -/ def addShapeToName : Name → Name | .anonymous => .str .anonymous "Shape" | .str a b => .str (addShapeToName a) b @@ -52,7 +54,9 @@ def addShapeToName : Name → Name section variable {m} [Monad m] [MonadQuotation m] [MonadError m] [MonadTrace m] [AddMessageContext m] -/-- This function assumes the pre-processor has run +/-- Extract takes a constructor and extracts its recursive forms. + +This function assumes the pre-processor has run It also assumes you don't have polymorphic recursive types such as data Ql α | nil | l : α → Ql Bool → Ql α -/ def extract (topName : Name) (view : CtorView) (rec_type : Term) : m $ Name × List RecursionForm := @@ -60,15 +64,14 @@ def extract (topName : Name) (view : CtorView) (rec_type : Term) : m $ Name × L let some type := view.type? | pure [] let type_ls := (getArgTypes ⟨type⟩).dropLast - let transform ← type_ls.mapM fun v => + type_ls.mapM fun v => if v == rec_type then pure .directRec else if containsStx v rec_type then throwErrorAt v.raw "Cannot handle composed recursive types" - else pure $ .nonRec v - - pure transform) + else pure $ .nonRec v) -def mkConstructorType +/-- Generate the binders for the different recursors -/ +def mkRecursorBinder (rec_type : Term) (name : Name) (form : List RecursionForm) (inclMotives : Bool) : m (TSyntax ``bracketedBinder) := do @@ -103,7 +106,7 @@ def seq (f : TSyntax kx → TSyntax kx → m (TSyntax kx)) : List (TSyntax kx) | [] => throwError "Expected at least one value for interspersing" -def generateIndBody (ctors : Array (Name × List RecursionForm)) : m $ TSyntax ``matchAlts := do +def generateIndBody (ctors : Array (Name × List RecursionForm)) (includeMotive : Bool) : m $ TSyntax ``matchAlts := do let deeper: (TSyntaxArray ``matchAlt) ← ctors.mapM fun ⟨outerCase, form⟩ => do let callName := mkIdent $ flattenForArg outerCase let outerCaseId := mkIdent $ addShapeToName outerCase @@ -111,7 +114,7 @@ def generateIndBody (ctors : Array (Name × List RecursionForm)) : m $ TSyntax ` let names ← listToEqLenNames form - if 0 = rec_count then + if 0 = rec_count || !includeMotive then return ← `(matchAltExpr| | $outerCaseId $names*, ih => ($callName $names*)) let names ← toEqLenNames names @@ -120,9 +123,6 @@ def generateIndBody (ctors : Array (Name × List RecursionForm)) : m $ TSyntax ` |>.filter (·.snd == .directRec) |>.map Prod.fst - let p := mkIdent `proof - let w := mkIdent `witness - let cases: TSyntaxArray _ ← ctors.mapM fun ⟨innerCase, _⟩ => do let innerCaseTag := mkIdent innerCase if innerCase != outerCase then @@ -134,7 +134,7 @@ def generateIndBody (ctors : Array (Name × List RecursionForm)) : m $ TSyntax ` `(inductionAlt| | $innerCaseTag:ident $[$names:ident]* => ( $split:tactic* - injection $p:ident with $injections* + injection proof with $injections* subst $injections:ident* exact $(← wrapIfNotSingle recs) )) @@ -152,8 +152,8 @@ def generateIndBody (ctors : Array (Name × List RecursionForm)) : m $ TSyntax ` $(mkIdent ``Fin2.instOfNatFin2HAddNatInstHAddInstAddNatOfNat):ident ] at ih - rcases ih with ⟨$w:ident, $p:ident⟩ - cases $w:ident with + rcases ih with ⟨w, proof⟩ + cases w with $cases:inductionAlt* $callName $names* $witnesses* ) @@ -189,17 +189,13 @@ def generateRecBody (ctors : Array (Name × List RecursionForm)) (includeMotive `(matchAltExprs| $deeper:matchAlt*) -def natToTerm : ℕ → m Term - | 0 => pure $ mkIdent ``Nat.zero - | .succ n => do `($(mkIdent ``Nat.succ) $(← natToTerm n)) - def genRecursors (view : DataView) : CommandElabM Unit := do let rec_type := view.getExpectedType let mapped ← view.ctors.mapM (extract view.declName · rec_type) let ih_types ← mapped.mapM fun ⟨name, base⟩ => - mkConstructorType (rec_type) (name) base true + mkRecursorBinder (rec_type) (name) base true let indDef : Command ← `( @[elab_as_elim, eliminator] @@ -208,11 +204,9 @@ def genRecursors (view : DataView) : CommandElabM Unit := do $ih_types* : (val : $rec_type) → motive val := - $(mkIdent $ ``_root_.MvQPF.Fix.ind) + $(mkIdent ``_root_.MvQPF.Fix.ind) ($(mkIdent `p) := motive) - (match ·,· with $(← generateIndBody mapped))) - - Elab.Command.elabCommand indDef + (match ·,· with $(← generateIndBody mapped true))) let recDef : Command ← `( @[elab_as_elim] @@ -220,63 +214,40 @@ def genRecursors (view : DataView) : CommandElabM Unit := do { motive : $rec_type → Type _} $ih_types* : (val : $rec_type) → motive val - := - $(mkIdent $ ``MvQPF.Fix.drec) + := $(mkIdent ``MvQPF.Fix.drec) (match · with $(← generateRecBody mapped true))) - - Elab.Command.elabCommand recDef - let casesOnTypes ← mapped.mapM fun ⟨name, base⟩ => - mkConstructorType (rec_type) (name) base false + mkRecursorBinder (rec_type) (name) base false let casesDef : Command ← `( @[elab_as_elim] def $(.str view.shortDeclName "cases" |> mkIdent):ident - { motive : $rec_type → Type _} + { motive : $rec_type → Prop} $casesOnTypes* : (val : $rec_type) → motive val - := - $(mkIdent $ ``_root_.MvQPF.Fix.drec) - (match · with $(← generateRecBody mapped false))) - - Elab.Command.elabCommand casesDef + := $(mkIdent ``_root_.MvQPF.Fix.ind) + ($(mkIdent `p) := motive) + (match ·,· with $(← generateIndBody mapped false))) - let ixs ← (List.range mapped.size).mapM natToTerm - let ixs := ixs.toArray - - let toCtorDef : Command ← `( - def $(.str view.shortDeclName "toCtorIdx" |> mkIdent) (t : $rec_type) - : $(mkIdent ``Nat) := - $(.str view.shortDeclName "cases" |> mkIdent) $ixs* t - ) - Elab.Command.elabCommand toCtorDef - - let noConfusionType : Command ← `( - abbrev $(.str view.shortDeclName "noConfusionType" |> mkIdent) - (v : Sort _) (a b : $rec_type) := - $(mkIdent ``noConfusionTypeEnum) - $(.str view.shortDeclName "toCtorIdx" |> mkIdent) - v a b - ) - Elab.Command.elabCommand noConfusionType - - let noConfusion : Command ← `( - abbrev $(.str view.shortDeclName "noConfusion" |> mkIdent) - {P : Sort _} {a b : $rec_type} : - a = b → $(.str view.shortDeclName "noConfusionType" |> mkIdent) P a b := - $(mkIdent ``noConfusionEnum) $(.str view.shortDeclName "toCtorIdx" |> mkIdent) - ) - Elab.Command.elabCommand noConfusion + let casesTypeDef : Command ← `( + @[elab_as_elim] + def $(.str view.shortDeclName "casesType" |> mkIdent):ident + { motive : $rec_type → Type} + $casesOnTypes* + : (val : $rec_type) → motive val + := $(mkIdent ``_root_.MvQPF.Fix.drec) + (match · with $(← generateRecBody mapped false))) trace[QPF] "Rec definitions:" trace[QPF] indDef trace[QPF] recDef + Elab.Command.elabCommand indDef + Elab.Command.elabCommand recDef trace[QPF] casesDef - - trace[QPF] toCtorDef - trace[QPF] noConfusionType - trace[QPF] noConfusion + trace[QPF] casesTypeDef + Elab.Command.elabCommand casesDef + Elab.Command.elabCommand casesTypeDef pure () From c51fbdafe75a016ade407ba8e86ed8fb31fe4233 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?William=20S=C3=B8rensen?= Date: Fri, 19 Jul 2024 11:45:23 +0100 Subject: [PATCH 5/5] docs: comment some code --- Qpf/Macro/Data/Ind.lean | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Qpf/Macro/Data/Ind.lean b/Qpf/Macro/Data/Ind.lean index 90e101d..a052154 100644 --- a/Qpf/Macro/Data/Ind.lean +++ b/Qpf/Macro/Data/Ind.lean @@ -94,6 +94,8 @@ def mkRecursorBinder def toEqLenNames (x : Array α) : m $ Array Ident := x.mapM (fun _ => mkIdent <$> mkFreshBinderName) def listToEqLenNames (x : List α) : m $ Array Ident := toEqLenNames x.toArray +/-- If the array is a singleton then this can be yielded by the proof, +otherwise it will be a n-ary product -/ def wrapIfNotSingle (arr : TSyntaxArray `term) : m Term := if let #[s] := arr then `($s) else `(⟨$arr,*⟩)