Skip to content

Commit

Permalink
refactor: clean-up Ind defn
Browse files Browse the repository at this point in the history
  • Loading branch information
William Sørensen committed Jul 17, 2024
1 parent a8d228c commit 75ba863
Showing 1 changed file with 23 additions and 29 deletions.
52 changes: 23 additions & 29 deletions Qpf/Macro/Data/Ind.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import Mathlib.Tactic.ExtractGoal

open Lean.Parser (Parser)
open Lean Meta Elab.Command Elab.Term Parser.Term
open Lean.Parser.Tactic (inductionAlt)

/--
The recursive form encodes how a function argument is recursive.
Expand All @@ -30,11 +31,17 @@ def flattenForArg (n : Name) := Name.str .anonymous $ n.toStringWithSep "_" true
def containsStx (top : Term) (search : Term) : Bool :=
(top.raw.find? (· == search)).isSome

-- These parsers have to be declared as they have optional args breaking quotation
/-- Both `bracketedBinder` and `matchAlts` have optional arguments,
which cause them to not by recognized as parsers in quotation syntax
(that is, ``` `(bracketedBinder| ...) ``` does not work).
To work around this, we define aliases that force the optional argument to it's default value,
so that we can write ``` `(bb| ...) ```instead. -/
abbrev bb : Parser := bracketedBinder
abbrev matchAltExprs : Parser := matchAlts

instance : Coe (TSyntax ``bb) (TSyntax ``bracketedBinder) where coe x := ⟨x.raw⟩
/- Since `bb` and `matchAltExprs` are aliases for `bracketedBinder`, resp. `matchAlts`,
we can safely coerce syntax of these categories -/
instance : Coe (TSyntax ``bb) (TSyntax ``bracketedBinder) where coe x := ⟨x.raw⟩
instance : Coe (TSyntax ``matchAltExprs) (TSyntax ``matchAlts) where coe x := ⟨x.raw⟩

def addShapeToName : Name → Name
Expand All @@ -45,9 +52,9 @@ def addShapeToName : Name → Name
section
variable {m} [Monad m] [MonadQuotation m] [MonadError m] [MonadTrace m] [AddMessageContext m]

-- This function assumes the pre-processor has run
-- It also assumes you don't have polymorphic recursive types such as
-- data Ql α | nil | l : α → Ql Bool → Ql α
/-- This function assumes the pre-processor has run
It also assumes you don't have polymorphic recursive types such as
data Ql α | nil | l : α → Ql Bool → Ql α -/
def extract (topName : Name) (view : CtorView) (rec_type : Term) : m $ Name × List RecursionForm :=
(view.declName.replacePrefix topName .anonymous , ·) <$> (do
let some type := view.type? | pure []
Expand Down Expand Up @@ -88,10 +95,13 @@ def wrapIfNotSingle (arr : TSyntaxArray `term) : m Term :=
if let #[s] := arr then `($s)
else `(⟨$arr,*⟩)

def seq [Coe α (TSyntax kx)] (f : α → TSyntax kx → m (TSyntax kx)) : List α → m (TSyntax kx)
/-- This function behaves like reduce but is specialized for TSyntaxes.
It is used to insert ∧s between entries -/
def seq (f : TSyntax kx → TSyntax kx → m (TSyntax kx)) : List (TSyntax kx) → m (TSyntax kx)
| [hd] => pure hd
| hd :: tl => do f hd (← seq f tl)
| [] => pure ⟨.node .none `null #[]⟩
| [] => throwError "Expected at least one value for interspersing"


def generateIndBody (ctors : Array (Name × List RecursionForm)) : m $ TSyntax ``matchAlts := do
let deeper: (TSyntaxArray ``matchAlt) ← ctors.mapM fun ⟨outerCase, form⟩ => do
Expand All @@ -104,22 +114,7 @@ def generateIndBody (ctors : Array (Name × List RecursionForm)) : m $ TSyntax `
if 0 = rec_count then
return ← `(matchAltExpr| | $outerCaseId $names*, ih => ($callName $names*))

-- TODO: this is not a good way to deal with names
-- but, for some unknowable reason, this fixes a shadowing
-- issue

let names : Array Ident := ⟨
names.toList.enum.map fun ⟨i, _⟩ =>
mkIdent <| Name.str (.anonymous) s!"_PRIVATE_ y {i}"
-- ^^^^^^^^^^^^^^^
-- HACK: notice how the variable name contains spaces.
-- This is deliberate, to reduce the possibility of name collisions,
-- given that these names are unhygienic
-- The following would be better, with fresh names,
-- but for some reason this causes the variable `p` to be renamed to something new, which will then give errors when we use `p` later on
-- let names : Array Ident ← names.mapM fun _ => do
-- mkFreshIdent .missing
let names ← toEqLenNames names

let recs := names.zip (form.toArray)
|>.filter (·.snd == .directRec)
Expand All @@ -128,17 +123,16 @@ def generateIndBody (ctors : Array (Name × List RecursionForm)) : m $ TSyntax `
let p := mkIdent `proof
let w := mkIdent `witness

let cases ← ctors.mapM fun ⟨innerCase, _⟩ => do
let cases: TSyntaxArray _ ← ctors.mapM fun ⟨innerCase, _⟩ => do
let innerCaseTag := mkIdent innerCase
if innerCase != outerCase then
`(tactic| case $innerCaseTag:ident => contradiction)
`(inductionAlt| | $innerCaseTag:ident => contradiction)
else
let split : Array (TSyntax `tactic) ← recs.mapM fun n =>
`(tactic|rcases $n:ident with ⟨_, $n:ident⟩)
let injections ← toEqLenNames names

`(tactic|
case $innerCaseTag:ident $[$names:ident]* => (
`(inductionAlt| | $innerCaseTag:ident $[$names:ident]* => (
$split:tactic*
injection $p:ident with $injections*
subst $injections:ident*
Expand All @@ -159,8 +153,8 @@ def generateIndBody (ctors : Array (Name × List RecursionForm)) : m $ TSyntax `
] at ih

rcases ih with ⟨$w:ident, $p:ident⟩
cases $w:ident
$cases:tactic*
cases $w:ident with
$cases:inductionAlt*
$callName $names* $witnesses*
)

Expand Down

0 comments on commit 75ba863

Please sign in to comment.