Skip to content

Commit

Permalink
refactor: clean-up Ind defn
Browse files Browse the repository at this point in the history
  • Loading branch information
William Sørensen committed Jul 17, 2024
1 parent 75ba863 commit 40d350e
Showing 1 changed file with 37 additions and 66 deletions.
103 changes: 37 additions & 66 deletions Qpf/Macro/Data/Ind.lean
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ open Lean.Parser.Tactic (inductionAlt)
inductive RecursionForm :=
| nonRec (stx: Term)
| directRec
/- | composed -/
-- | composed -- Not supported yet
deriving Repr, BEq

partial def getArgTypes (v : Term) : List Term := match v.raw with
Expand All @@ -44,6 +44,8 @@ we can safely coerce syntax of these categories -/
instance : Coe (TSyntax ``bb) (TSyntax ``bracketedBinder) where coe x := ⟨x.raw⟩
instance : Coe (TSyntax ``matchAltExprs) (TSyntax ``matchAlts) where coe x := ⟨x.raw⟩

/-- When we want to operate on patterns the names we need must start with shape.
This is done as if theres a constructor called `mk` dot notation breaks. -/
def addShapeToName : Name → Name
| .anonymous => .str .anonymous "Shape"
| .str a b => .str (addShapeToName a) b
Expand All @@ -52,23 +54,24 @@ def addShapeToName : Name → Name
section
variable {m} [Monad m] [MonadQuotation m] [MonadError m] [MonadTrace m] [AddMessageContext m]

/-- This function assumes the pre-processor has run
/-- Extract takes a constructor and extracts its recursive forms.
This function assumes the pre-processor has run
It also assumes you don't have polymorphic recursive types such as
data Ql α | nil | l : α → Ql Bool → Ql α -/
def extract (topName : Name) (view : CtorView) (rec_type : Term) : m $ Name × List RecursionForm :=
(view.declName.replacePrefix topName .anonymous , ·) <$> (do
let some type := view.type? | pure []
let type_ls := (getArgTypes ⟨type⟩).dropLast

let transform ← type_ls.mapM fun v =>
type_ls.mapM fun v =>
if v == rec_type then pure .directRec
else if containsStx v rec_type then
throwErrorAt v.raw "Cannot handle composed recursive types"
else pure $ .nonRec v

pure transform)
else pure $ .nonRec v)

def mkConstructorType
/-- Generate the binders for the different recursors -/
def mkRecursorBinder
(rec_type : Term) (name : Name)
(form : List RecursionForm)
(inclMotives : Bool) : m (TSyntax ``bracketedBinder) := do
Expand Down Expand Up @@ -103,15 +106,15 @@ def seq (f : TSyntax kx → TSyntax kx → m (TSyntax kx)) : List (TSyntax kx)
| [] => throwError "Expected at least one value for interspersing"


def generateIndBody (ctors : Array (Name × List RecursionForm)) : m $ TSyntax ``matchAlts := do
def generateIndBody (ctors : Array (Name × List RecursionForm)) (includeMotive : Bool) : m $ TSyntax ``matchAlts := do
let deeper: (TSyntaxArray ``matchAlt) ← ctors.mapM fun ⟨outerCase, form⟩ => do
let callName := mkIdent $ flattenForArg outerCase
let outerCaseId := mkIdent $ addShapeToName outerCase
let rec_count := form.count .directRec

let names ← listToEqLenNames form

if 0 = rec_count then
if 0 = rec_count || !includeMotive then
return ← `(matchAltExpr| | $outerCaseId $names*, ih => ($callName $names*))

let names ← toEqLenNames names
Expand All @@ -120,9 +123,6 @@ def generateIndBody (ctors : Array (Name × List RecursionForm)) : m $ TSyntax `
|>.filter (·.snd == .directRec)
|>.map Prod.fst

let p := mkIdent `proof
let w := mkIdent `witness

let cases: TSyntaxArray _ ← ctors.mapM fun ⟨innerCase, _⟩ => do
let innerCaseTag := mkIdent innerCase
if innerCase != outerCase then
Expand All @@ -134,7 +134,7 @@ def generateIndBody (ctors : Array (Name × List RecursionForm)) : m $ TSyntax `

`(inductionAlt| | $innerCaseTag:ident $[$names:ident]* => (
$split:tactic*
injection $p:ident with $injections*
injection proof with $injections*
subst $injections:ident*
exact $(← wrapIfNotSingle recs)
))
Expand All @@ -152,8 +152,8 @@ def generateIndBody (ctors : Array (Name × List RecursionForm)) : m $ TSyntax `
$(mkIdent ``Fin2.instOfNatFin2HAddNatInstHAddInstAddNatOfNat):ident
] at ih

rcases ih with$w:ident, $p:ident
cases $w:ident with
rcases ih withw, proof
cases w with
$cases:inductionAlt*
$callName $names* $witnesses*
)
Expand Down Expand Up @@ -189,17 +189,13 @@ def generateRecBody (ctors : Array (Name × List RecursionForm)) (includeMotive

`(matchAltExprs| $deeper:matchAlt*)

def natToTerm : ℕ → m Term
| 0 => pure $ mkIdent ``Nat.zero
| .succ n => do `($(mkIdent ``Nat.succ) $(← natToTerm n))

def genRecursors (view : DataView) : CommandElabM Unit := do
let rec_type := view.getExpectedType

let mapped ← view.ctors.mapM (extract view.declName · rec_type)

let ih_types ← mapped.mapM fun ⟨name, base⟩ =>
mkConstructorType (rec_type) (name) base true
mkRecursorBinder (rec_type) (name) base true

let indDef : Command ← `(
@[elab_as_elim, eliminator]
Expand All @@ -208,75 +204,50 @@ def genRecursors (view : DataView) : CommandElabM Unit := do
$ih_types*
: (val : $rec_type) → motive val
:=
$(mkIdent $ ``_root_.MvQPF.Fix.ind)
$(mkIdent ``_root_.MvQPF.Fix.ind)
($(mkIdent `p) := motive)
(match ·,· with $(← generateIndBody mapped)))

Elab.Command.elabCommand indDef
(match ·,· with $(← generateIndBody mapped true)))

let recDef : Command ← `(
@[elab_as_elim]
def $(.str view.shortDeclName "rec" |> mkIdent):ident
{ motive : $rec_type → Type _}
$ih_types*
: (val : $rec_type) → motive val
:=
$(mkIdent $ ``MvQPF.Fix.drec)
:= $(mkIdent ``MvQPF.Fix.drec)
(match · with $(← generateRecBody mapped true)))


Elab.Command.elabCommand recDef

let casesOnTypes ← mapped.mapM fun ⟨name, base⟩ =>
mkConstructorType (rec_type) (name) base false
mkRecursorBinder (rec_type) (name) base false

let casesDef : Command ← `(
@[elab_as_elim]
def $(.str view.shortDeclName "cases" |> mkIdent):ident
{ motive : $rec_type → Type _}
{ motive : $rec_type → Prop}
$casesOnTypes*
: (val : $rec_type) → motive val
:=
$(mkIdent $ ``_root_.MvQPF.Fix.drec)
(match · with $(← generateRecBody mapped false)))

Elab.Command.elabCommand casesDef
:= $(mkIdent ``_root_.MvQPF.Fix.ind)
($(mkIdent `p) := motive)
(match ·,· with $(← generateIndBody mapped false)))

let ixs ← (List.range mapped.size).mapM natToTerm
let ixs := ixs.toArray

let toCtorDef : Command ← `(
def $(.str view.shortDeclName "toCtorIdx" |> mkIdent) (t : $rec_type)
: $(mkIdent ``Nat) :=
$(.str view.shortDeclName "cases" |> mkIdent) $ixs* t
)
Elab.Command.elabCommand toCtorDef

let noConfusionType : Command ← `(
abbrev $(.str view.shortDeclName "noConfusionType" |> mkIdent)
(v : Sort _) (a b : $rec_type) :=
$(mkIdent ``noConfusionTypeEnum)
$(.str view.shortDeclName "toCtorIdx" |> mkIdent)
v a b
)
Elab.Command.elabCommand noConfusionType

let noConfusion : Command ← `(
abbrev $(.str view.shortDeclName "noConfusion" |> mkIdent)
{P : Sort _} {a b : $rec_type} :
a = b → $(.str view.shortDeclName "noConfusionType" |> mkIdent) P a b :=
$(mkIdent ``noConfusionEnum) $(.str view.shortDeclName "toCtorIdx" |> mkIdent)
)
Elab.Command.elabCommand noConfusion
let casesTypeDef : Command ← `(
@[elab_as_elim]
def $(.str view.shortDeclName "casesType" |> mkIdent):ident
{ motive : $rec_type → Type}
$casesOnTypes*
: (val : $rec_type) → motive val
:= $(mkIdent ``_root_.MvQPF.Fix.drec)
(match · with $(← generateRecBody mapped false)))

trace[QPF] "Rec definitions:"
trace[QPF] indDef
trace[QPF] recDef
Elab.Command.elabCommand indDef
Elab.Command.elabCommand recDef

trace[QPF] casesDef

trace[QPF] toCtorDef
trace[QPF] noConfusionType
trace[QPF] noConfusion
trace[QPF] casesTypeDef
Elab.Command.elabCommand casesDef
Elab.Command.elabCommand casesTypeDef

pure ()

0 comments on commit 40d350e

Please sign in to comment.