Skip to content

Commit

Permalink
Merge pull request #592 from HigherOrderCO/pat-names-nat
Browse files Browse the repository at this point in the history
Ignore redundant match cases in type checker; Add Nat pattern matching syntax; Fix var names when flattening variables that match constructors
  • Loading branch information
VictorTaelin authored Oct 21, 2024
2 parents 5085285 + a3b098f commit c02879e
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 43 deletions.
35 changes: 18 additions & 17 deletions src/Kind/Check.hs
Original file line number Diff line number Diff line change
Expand Up @@ -197,23 +197,12 @@ check src val typ dep = debug ("check: " ++ termShower False val dep ++ "\n :
(All typ_nam typ_inp typ_bod) -> do
case reduce book fill 2 typ_inp of
(Dat adt_scp adt_cts adt_typ) -> do
let adt_cts_map = M.fromList (map (\ (Ctr cnm tele) -> (cnm, tele)) adt_cts)
-- Check if all cases are present
let hasDefaultCase = any (\(cnm, _) -> cnm == "_") cse
unless hasDefaultCase $ do
let presentCases = M.fromList cse
forM_ adt_cts $ \ (Ctr cnm _) -> do
unless (M.member cnm presentCases) $ do
envLog (Error src (Hol ("missing_case:" ++ cnm) []) (Hol "incomplete_match" []) (Mat cse) dep)
envFail
-- If there is a default case, check that it is well-typed
when hasDefaultCase $ do
let defaultCase = snd $ head $ filter (\(cnm, _) -> cnm == "_") cse
check Nothing defaultCase (All "" typ_inp typ_bod) dep
-- Check if all concrete cases are well-typed
forM_ cse $ \ (cnm, cbod) -> do
when (cnm /= "_") $ case M.lookup cnm adt_cts_map of
Just tele -> do
-- Check every expected case of the match
-- Skips redundant cases
let presentCases = M.fromList $ reverse cse
forM_ adt_cts $ \ (Ctr cnm tele) -> do
case M.lookup cnm presentCases of
Just cbod -> do
let a_r = teleToTerm tele dep
let eqs = extractEqualities (reduce book fill 2 typ_inp) (reduce book fill 2 (snd a_r)) dep
let rt0 = teleToType tele (typ_bod (Ann False (Con cnm (fst a_r)) typ_inp)) dep
Expand All @@ -222,6 +211,18 @@ check src val typ dep = debug ("check: " ++ termShower False val dep ++ "\n :
unreachable Nothing cbod dep
else
check Nothing cbod rt1 dep
Nothing -> case M.lookup "_" presentCases of
Just defaultCase -> do
check Nothing defaultCase (All "" typ_inp typ_bod) dep
Nothing -> do
envLog (Error src (Hol ("missing_case:" ++ cnm) []) (Hol "incomplete_match" []) (Mat cse) dep)
envFail

-- Check if all cases refer to an expected constructor
let adt_cts_map = M.fromList (map (\ (Ctr cnm tele) -> (cnm, tele)) adt_cts)
forM_ cse $ \ (cnm, cbod) -> do
when (cnm /= "_") $ case M.lookup cnm adt_cts_map of
Just _ -> return ()
Nothing -> do
envLog (Error src (Hol ("constructor_not_found:"++cnm) []) (Hol "unknown_type" []) (Mat cse) dep)
envFail
Expand Down
80 changes: 54 additions & 26 deletions src/Kind/Parse.hs
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,7 @@ data Name (p0: P0) (p1: P1) ... : (i0: I0) (i1: I1) -> (Name p0 p1 ... i0 i1 ...
this should desugar to:
Name
: ∀(p0: P0) ∀(p1: P1) ... ∀(i0: I0) ∀(i1: I1) : (Name p0 p1 ... i0 i1 ...)
: ∀(p0: P0) ∀(p1: P1) ... ∀(i0: I0) ∀(i1: I1) : (Name p0 p1 ... i0 i1 ...)
= data[i0 i1] {
#Ctr0 { x0: T0 x1: T1 ... } : (Name p0 p1 ... i0 i1 ...)
#Ctr1 { x0: T0 x1: T1 ... } : (Name p0 p1 ... i0 i1 ...)
Expand All @@ -563,7 +563,7 @@ parseDefADT = do
ptype <- parseTerm
char_skp ')'
return (pname, ptype)
indices <- P.choice
indices <- P.choice
[ do
P.try $ char_skp '~'
P.many $ do
Expand Down Expand Up @@ -595,7 +595,7 @@ parseDefADT = do
fillTeleRet ret (TRet (Met _ _)) = TRet ret
fillTeleRet _ (TRet ret) = TRet ret
fillTeleRet ret (TExt nm tm bod) = TExt nm tm (\x -> fillTeleRet ret (bod x))

parseDefFun :: Parser (String, Term)
parseDefFun = do
numb <- P.optionMaybe $ char_skp '#'
Expand Down Expand Up @@ -642,19 +642,35 @@ parseDef = P.choice
parsePattern :: Parser Pattern
parsePattern = do
P.choice [
do
name <- name_skp
return (PVar name),
do
char_skp '#'
name <- name_skp
args <- P.option [] $ P.try $ do
char_skp '{'
args <- P.many parsePattern
char_skp '}'
return args
return (PCtr name args)
]
parsePatternNat,
parsePatternCtr,
parsePatternVar
] <* skip

parsePatternNat :: Parser Pattern
parsePatternNat = do
num <- P.try $ do
char_skp '#'
P.many1 digit
let n = read num
return $ (foldr (\_ acc -> PCtr "Succ" [acc]) (PCtr "Zero" []) [1..n])

parsePatternCtr :: Parser Pattern
parsePatternCtr = do
name <- P.try $ do
char_skp '#'
name_skp
args <- P.option [] $ P.try $ do
char_skp '{'
args <- P.many parsePattern
char_skp '}'
return args
return $ (PCtr name args)

parsePatternVar :: Parser Pattern
parsePatternVar = do
name <- P.try $ name_skp
return $ (PVar name)

parseUses :: Parser Uses
parseUses = P.many $ P.try $ do
Expand Down Expand Up @@ -799,7 +815,7 @@ parseDoFor monad = do
-- If-Then-Else
-- ------------

-- if cond { t } else { f }
-- if cond { t } else { f }
-- --------------------------------- desugars to
-- match cond { #True: t #False: f }

Expand Down Expand Up @@ -926,7 +942,7 @@ flattenVarCol col mat bods fresh depth =

-- Flattens a column with constructors and possibly variables
flattenAdtCol :: [Pattern] -> [[Pattern]] -> [Term] -> Int -> Int -> Term
flattenAdtCol col mat bods fresh depth =
flattenAdtCol col mat bods fresh depth =
-- trace (replicate (depth * 2) ' ' ++ "flattenAdtCol: col = " ++ show col ++ ", fresh = " ++ show fresh) $
let nam = maybe ("%f" ++ show fresh) id (getColName col)
ctr = map (makeCtrCase col mat bods (fresh+1) nam depth) (getColCtrs col)
Expand All @@ -935,18 +951,24 @@ flattenAdtCol col mat bods fresh depth =

-- Creates a constructor case: '#Name: body'
makeCtrCase :: [Pattern] -> [[Pattern]] -> [Term] -> Int -> String -> Int -> String -> (String, Term)
makeCtrCase col mat bods fresh var depth ctr =
makeCtrCase col mat bods fresh var depth ctr =
-- trace (replicate (depth * 2) ' ' ++ "makeCtrCase: col = " ++ show col ++ ", mat = " ++ show mat ++ ", bods = " ++ show (map termShow bods) ++ ", fresh = " ++ show fresh ++ ", var = " ++ var ++ ", ctr = " ++ ctr) $
let (mat', bods') = foldr go ([], []) (zip3 col mat bods)
bod = flattenDef mat' bods' fresh (depth + 1)
in (ctr, bod)
where go ((PCtr nam ps), pats, bod) (mat, bods)
| nam == ctr = ((ps ++ pats):mat, bod:bods)
| otherwise = (mat, bods)
go ((PVar "_"), pats, bod) (mat, bods) =
let ari = getCtrArity col ctr
pat = [PVar "_" | _ <- [0..ari-1]]
in ((pat ++ pats):mat, bod:bods)
go ((PVar nam), pats, bod) (mat, bods) =
let ar = getArity $ fromJust $ find (\case PCtr c _ | c == ctr -> True ; _ -> False) col
ps = [PVar (nam++"."++show i) | i <- [0..ar-1]]
in ((ps ++ pats) : mat, bod:bods)
let ari = getCtrArity col ctr
var = [nam++"."++show i | i <- [0..ari-1]]
pat = map PVar var
bo2 = Let nam (foldl (\f a -> App f (Ref a)) (Ref ctr) var) (\x -> bod)
in ((pat ++ pats):mat, bo2:bods)

-- Creates a default case: '#_: body'
makeDflCase :: [Pattern] -> [[Pattern]] -> [Term] -> Int -> Int -> [(String, Term)]
Expand All @@ -963,9 +985,12 @@ isVar :: Pattern -> Bool
isVar (PVar _) = True
isVar _ = False

getArity :: Pattern -> Int
getArity (PCtr _ pats) = length pats
getArity _ = 0
getCtrArity :: [Pattern] -> String -> Int
getCtrArity ((PCtr nam ps):pats) ctr
| nam == ctr = length ps
| otherwise = getCtrArity pats ctr
getCtrArity (_:pats) ctr = getCtrArity pats ctr
getCtrArity [] _ = 0

getCol :: [[Pattern]] -> ([Pattern], [[Pattern]])
getCol (pats:mat) = unzip (catMaybes (map uncons (pats:mat)))
Expand All @@ -974,4 +999,7 @@ getColCtrs :: [Pattern] -> [String]
getColCtrs col = toList . fromList $ foldr (\pat acc -> case pat of (PCtr nam _) -> nam:acc ; _ -> acc) [] col

getColName :: [Pattern] -> Maybe String
getColName col = foldr (A.<|>) Nothing $ map (\case PVar nam -> Just nam; _ -> Nothing) col
getColName col = foldr (A.<|>) Nothing $ map go col
where go (PVar "_") = Nothing
go (PVar nam) = Just nam
go _ = Nothing

0 comments on commit c02879e

Please sign in to comment.