From 27aa3b2090fae093bc1dc54235bd7a51c2eb642f Mon Sep 17 00:00:00 2001 From: Kimiyuki Onaka Date: Tue, 20 Jul 2021 14:45:58 +0900 Subject: [PATCH 1/4] feat(core): Allow `for i in range(len(xs)):` in main functions --- src/Jikka/Common/IOFormat.hs | 39 +++++++++---------- .../RestrictedPython/Convert/ParseMain.hs | 23 ++++++----- 2 files changed, 32 insertions(+), 30 deletions(-) diff --git a/src/Jikka/Common/IOFormat.hs b/src/Jikka/Common/IOFormat.hs index 3b2ff159..fb027623 100644 --- a/src/Jikka/Common/IOFormat.hs +++ b/src/Jikka/Common/IOFormat.hs @@ -57,21 +57,16 @@ normalizeFormatTree = \case Seq formats -> formats format -> [format] in Seq (concatMap (unSeq . normalizeFormatTree) formats) - Loop i n body -> Loop i n (normalizeFormatTree body) - -substFormatExpr :: String -> String -> FormatExpr -> FormatExpr -substFormatExpr name value = go - where - go = \case - Var x -> Var (if x == name then value else x) - Plus e k -> Plus (go e) k - At e i -> At (go e) (if i == name then value else i) - Len e -> Len (go e) - -substFormatTree :: String -> String -> FormatTree -> FormatTree -substFormatTree name value = mapFormatTree $ \case - Exp e -> Exp (substFormatExpr name value e) - format -> format + Loop i n body -> case normalizeFormatTree body of + Seq [] -> Seq [] + body -> Loop i n body + +normalizeIOFormat :: IOFormat -> IOFormat +normalizeIOFormat format = + format + { inputTree = normalizeFormatTree (inputTree format), + outputTree = normalizeFormatTree (outputTree format) + } hasNewline :: FormatTree -> Bool hasNewline = \case @@ -95,16 +90,17 @@ formatFormatTree = go text | patt `isPrefixOf` text = subst ++ go (drop (length patt) text) go [] = [] go (c : s) = c : go s - unwords' = replace "\n " "\n" . replace " \n" "\n" . unwords + unwords' = replace "\n\n" "\n" . replace "\n " "\n" . replace " \n" "\n" . unwords in \case Exp e -> formatFormatExpr e - Newline -> "\n" + Newline -> "(newline)\n" Seq formats -> unwords' (map formatFormatTree formats) Loop i n body -> - let first = substFormatTree i "0" body - last = substFormatTree i (formatFormatExpr n) body - dots = if hasNewline body then "...\n" else "..." - in unwords' [formatFormatTree first, dots, formatFormatTree last] + unwords' + [ "for " ++ i ++ " < " ++ formatFormatExpr n ++ " {\n", + formatFormatTree body ++ "\n", + "}" + ] formatIOFormat :: IOFormat -> String formatIOFormat format = @@ -172,6 +168,7 @@ makeReadValueIO toInt fromInt toList fromList format = wrapError' "Jikka.Common. n <- case n of Var n -> toInt =<< lookup n Plus (Var n) k -> (+ k) <$> (toInt =<< lookup n) + Len (Var xs) -> toInteger . V.length <$> (toList =<< lookup xs) _ -> throwInternalError $ "invalid loop size in input tree: " ++ formatFormatExpr n liftIO $ modifyIORef' sizes (M.insert i n) forM_ [0 .. n -1] $ \i' -> do diff --git a/src/Jikka/RestrictedPython/Convert/ParseMain.hs b/src/Jikka/RestrictedPython/Convert/ParseMain.hs index 5b0c69a2..cde61db0 100644 --- a/src/Jikka/RestrictedPython/Convert/ParseMain.hs +++ b/src/Jikka/RestrictedPython/Convert/ParseMain.hs @@ -136,18 +136,22 @@ parseFor go x e body = do n <- case e of CallBuiltin BuiltinRange1 [n] -> return n _ -> throwSemanticErrorAt' (loc' e) $ "for loops in main function must use `range' like `for i in range(n): ...': " ++ formatExpr e - (n, k) <- case value' n of - Name n -> return (n, 0) - BinOp (WithLoc' _ (Name n)) Add (WithLoc' _ (Constant (ConstInt k))) -> return (n, k) - BinOp (WithLoc' _ (Name n)) Sub (WithLoc' _ (Constant (ConstInt k))) -> return (n, - k) - _ -> throwSemanticErrorAt' (loc' n) $ "for loops in main function must use `range(x)', `range(x + k)' or `range(x - k)'" ++ formatExpr n + n <- case value' n of + Name n -> return $ Right (n, 0) + BinOp (WithLoc' _ (Name n)) Add (WithLoc' _ (Constant (ConstInt k))) -> return $ Right (n, k) + BinOp (WithLoc' _ (Name n)) Sub (WithLoc' _ (Constant (ConstInt k))) -> return $ Right (n, - k) + Call (WithLoc' _ (Constant (ConstBuiltin (BuiltinLen _)))) [WithLoc' _ (Name xs)] -> return $ Left xs + _ -> throwSemanticErrorAt' (loc' n) $ "for loops in main function must use `range(x)', `range(x + k)', `range(x - k)', `range(len(xs))`: " ++ formatExpr n + n <- return $ case n of + Right (n, k) -> + let n' = Var (unVarName (value' n)) + in if k == 0 then n' else Plus n' k + Left xs -> Len (Var (unVarName (value' xs))) (input, solve, output) <- go body when (isJust solve) $ do throwSemanticError "cannot call `solve(...)' in for loop" let x' = unVarName (value' x) - let n' = Var (unVarName (value' n)) - let n'' = if k == 0 then n' else Plus n' k - return (Loop x' n'' input, Loop x' n'' output) + return (Loop x' n input, Loop x' n output) parseExprStatement :: MonadError Error m => Expr' -> m FormatTree parseExprStatement e = do @@ -203,5 +207,6 @@ run prog = wrapError' "Jikka.RestrictedPython.Convert.ParseMain" $ do (main, prog) <- return $ splitMain prog main <- forM main $ \main -> do checkMainType main - parseMain main + main <- parseMain main + return $ normalizeIOFormat main return (main, prog) From 6e8c0896abc3f486555448728f79839929b49ec4 Mon Sep 17 00:00:00 2001 From: Kimiyuki Onaka Date: Tue, 20 Jul 2021 15:05:13 +0900 Subject: [PATCH 2/4] fix(core): Don't use MakeEager --- src/Jikka/Core/Convert/MakeEager.hs | 81 ------------------------ src/Jikka/Core/Evaluate.hs | 8 ++- test/Jikka/Core/Convert/MakeEagerSpec.hs | 37 ----------- 3 files changed, 6 insertions(+), 120 deletions(-) delete mode 100644 src/Jikka/Core/Convert/MakeEager.hs delete mode 100644 test/Jikka/Core/Convert/MakeEagerSpec.hs diff --git a/src/Jikka/Core/Convert/MakeEager.hs b/src/Jikka/Core/Convert/MakeEager.hs deleted file mode 100644 index f46e8a35..00000000 --- a/src/Jikka/Core/Convert/MakeEager.hs +++ /dev/null @@ -1,81 +0,0 @@ -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE LambdaCase #-} - --- | --- Module : Jikka.Core.Convert.MakeEager --- Description : wraps some exprs with empty lambdas to allow eager evaluation. / 正格評価をするためにいくつかの式を空の lambda で包みます。 --- Copyright : (c) Kimiyuki Onaka, 2020 --- License : Apache License 2.0 --- Maintainer : kimiyuki95@gmail.com --- Stability : experimental --- Portability : portable -module Jikka.Core.Convert.MakeEager - ( run, - ) -where - -import Jikka.Common.Alpha -import Jikka.Common.Error -import Jikka.Core.Language.Beta -import Jikka.Core.Language.BuiltinPatterns -import Jikka.Core.Language.Expr -import Jikka.Core.Language.Lint -import Jikka.Core.Language.Util - -runExpr :: MonadAlpha m => Expr -> m Expr -runExpr = \case - If' t p a b -> case t of - FunTy _ _ -> do - return $ If' t p a b - _ -> do - x <- genVarName' - y <- genVarName' - return $ App (If' (FunTy UnitTy t) p (Lam x UnitTy a) (Lam y UnitTy b)) (Tuple' []) - e -> return e - -runToplevelExpr :: (MonadAlpha m, MonadError Error m) => ToplevelExpr -> m ToplevelExpr -runToplevelExpr = \case - ResultExpr e -> ResultExpr <$> runExpr e - ToplevelLet x t e cont -> ToplevelLet x t <$> runExpr e <*> runToplevelExpr cont - ToplevelLetRec f args ret body cont -> case args of - [] -> do - x <- genVarName' - let g = App (Var f) (Tuple' []) - body <- substitute f g body - cont <- substituteToplevelExpr f g cont - ToplevelLetRec f [(x, UnitTy)] ret <$> runExpr body <*> runToplevelExpr cont - args -> ToplevelLetRec f args ret <$> runExpr body <*> runToplevelExpr cont - -runProgram :: (MonadAlpha m, MonadError Error m) => Program -> m Program -runProgram = runToplevelExpr - --- | `run` wraps some exprs with lambda redundant things from AST. --- Specifically, this converts @if p then a else b@ to something like @(if p then (lambda. a) else (lambda. b))()@. --- --- For example, this converts: --- --- > let rec fact n = --- > if n == 0 then --- > 1 --- > else --- > n * fact (n - 1) --- > in fact 10 --- --- to: --- --- > let rec fact n = --- > (if n == 0 then --- > fun -> 1 --- > else --- > fun -> n * fact (n - 1) --- > )() --- > in fact 10 -run :: (MonadAlpha m, MonadError Error m) => Program -> m Program -run prog = wrapError' "Jikka.Core.Convert.MakeEager" $ do - precondition $ do - ensureWellTyped prog - prog <- runProgram prog - postcondition $ do - ensureWellTyped prog - ensureEagerlyEvaluatable prog - return prog diff --git a/src/Jikka/Core/Evaluate.hs b/src/Jikka/Core/Evaluate.hs index ea873241..bcf91656 100644 --- a/src/Jikka/Core/Evaluate.hs +++ b/src/Jikka/Core/Evaluate.hs @@ -29,8 +29,8 @@ import qualified Data.Vector as V import Jikka.Common.Alpha import Jikka.Common.Error import Jikka.Common.Matrix -import qualified Jikka.Core.Convert.MakeEager as MakeEager import Jikka.Core.Format (formatBuiltinIsolated) +import Jikka.Core.Language.BuiltinPatterns import Jikka.Core.Language.Expr import Jikka.Core.Language.Lint import Jikka.Core.Language.Runtime @@ -243,6 +243,11 @@ evaluateExpr env = \case Nothing -> throwInternalError $ "undefined variable: " ++ unVarName x Just val -> return val Lit lit -> literalToValue lit + If' _ p e1 e2 -> do + p <- valueToBool =<< evaluateExpr env p + if p + then evaluateExpr env e1 + else evaluateExpr env e2 e@(App _ _) -> do let (f, args) = curryApp e f <- evaluateExpr env f @@ -279,5 +284,4 @@ callProgram prog args = wrapError' "Jikka.Core.Evaluate" $ do run :: (MonadAlpha m, MonadFix m, MonadError Error m) => Program -> [Value] -> m Value run prog args = do - prog <- MakeEager.run prog callProgram prog args diff --git a/test/Jikka/Core/Convert/MakeEagerSpec.hs b/test/Jikka/Core/Convert/MakeEagerSpec.hs deleted file mode 100644 index b7278507..00000000 --- a/test/Jikka/Core/Convert/MakeEagerSpec.hs +++ /dev/null @@ -1,37 +0,0 @@ -{-# LANGUAGE OverloadedStrings #-} - -module Jikka.Core.Convert.MakeEagerSpec (spec) where - -import Jikka.Common.Alpha -import Jikka.Common.Error -import Jikka.Core.Convert.MakeEager (run) -import Jikka.Core.Language.BuiltinPatterns -import Jikka.Core.Language.Expr -import Test.Hspec - -run' :: Program -> Either Error Program -run' = flip evalAlphaT 0 . run - -spec :: Spec -spec = describe "run" $ do - it "works" $ do - let prog = - ResultExpr - ( If' - IntTy - LitTrue - Lit0 - Lit1 - ) - let expected = - ResultExpr - ( App - ( If' - (FunTy (TupleTy []) IntTy) - LitTrue - (Lam "$0" (TupleTy []) Lit0) - (Lam "$1" (TupleTy []) Lit1) - ) - (Tuple' []) - ) - run' prog `shouldBe` Right expected From 22da850fc549d8dda616f572be11f9871e29c4ce Mon Sep 17 00:00:00 2001 From: Kimiyuki Onaka Date: Tue, 20 Jul 2021 15:12:16 +0900 Subject: [PATCH 3/4] feat: Make `xs.append(x)` available --- docs/language.ja.md | 5 ++++ docs/language.md | 5 ++++ examples/README.md | 2 +- examples/data/point_add_range_sum.sample-1.in | 7 ++++++ .../data/point_add_range_sum.sample-1.out | 4 ++++ examples/{wip => }/point_add_range_sum.py | 16 +++++++++++++ src/Jikka/CPlusPlus/Convert/FromCore.hs | 9 +++++++ src/Jikka/Core/Evaluate.hs | 1 + src/Jikka/Core/Format.hs | 1 + src/Jikka/Core/Language/BuiltinPatterns.hs | 2 ++ src/Jikka/Core/Language/Expr.hs | 2 ++ src/Jikka/Core/Language/TypeCheck.hs | 1 + src/Jikka/Core/Language/Util.hs | 2 ++ src/Jikka/RestrictedPython/Convert/ToCore.hs | 13 +++++++++- .../RestrictedPython/Convert/TypeInfer.hs | 4 +++- src/Jikka/RestrictedPython/Evaluate.hs | 24 ++++++++++++++++--- src/Jikka/RestrictedPython/Format.hs | 1 + .../RestrictedPython/Language/Builtin.hs | 7 +++++- src/Jikka/RestrictedPython/Language/Expr.hs | 7 ++++++ src/Jikka/RestrictedPython/Language/Util.hs | 10 ++++++++ .../Language/VariableAnalysis.hs | 5 ++++ 21 files changed, 121 insertions(+), 7 deletions(-) create mode 100644 examples/data/point_add_range_sum.sample-1.in create mode 100644 examples/data/point_add_range_sum.sample-1.out rename examples/{wip => }/point_add_range_sum.py (50%) diff --git a/docs/language.ja.md b/docs/language.ja.md index 0bac6550..1bd5af52 100644 --- a/docs/language.ja.md +++ b/docs/language.ja.md @@ -185,6 +185,11 @@ Python とほとんど同じです。 - 処理系は `assert` 文を最適化のヒントとして利用できます。 +### expr statements + +式文は副作用のある処理を実行します。 +現在は `xs.append(x)` の形の式文のみが利用可能です。 + ## Semantics (Exprs) diff --git a/docs/language.md b/docs/language.md index 8fd7ec0d..8bf1e1e0 100644 --- a/docs/language.md +++ b/docs/language.md @@ -182,6 +182,11 @@ The augmented assignment `x @= a` statement binds the value `x @ a` to the varia - Implementations can use `assert` statements as hints for optimization. +### expr statements + +Expr statements runs processes with side effects. +Now only `xs.append(x)` is available. + ## Semantics (Exprs) diff --git a/examples/README.md b/examples/README.md index 030a2e1e..396c4535 100644 --- a/examples/README.md +++ b/examples/README.md @@ -35,7 +35,7 @@ - CODE FESTIVAL 2015 決勝 [D - 足ゲームII](https://atcoder.jp/contests/code-festival-2015-final-open/tasks/codefestival_2015_final_d) - :hourglass: TLE `m_solutions2019_e.py` - M-SOLUTIONS Programming Contest [E - Product of Arithmetic Progression](https://atcoder.jp/contests/m-solutions2019/tasks/m_solutions2019_e?lang=ja) -- :warning: CE `wip/point_add_range_sum.py` +- :hourglass: TLE `wip/point_add_range_sum.py` - Library Checker [Point Add Range Sum](https://judge.yosupo.jp/problem/point_add_range_sum) - A normal segment tree / 通常のセグメント木 - :warning: CE `wip/dp_w.py` diff --git a/examples/data/point_add_range_sum.sample-1.in b/examples/data/point_add_range_sum.sample-1.in new file mode 100644 index 00000000..291b92d1 --- /dev/null +++ b/examples/data/point_add_range_sum.sample-1.in @@ -0,0 +1,7 @@ +5 5 +1 2 3 4 5 +1 0 5 +1 2 4 +0 3 10 +1 0 5 +1 0 3 diff --git a/examples/data/point_add_range_sum.sample-1.out b/examples/data/point_add_range_sum.sample-1.out new file mode 100644 index 00000000..bcc0016c --- /dev/null +++ b/examples/data/point_add_range_sum.sample-1.out @@ -0,0 +1,4 @@ +15 +7 +25 +6 diff --git a/examples/wip/point_add_range_sum.py b/examples/point_add_range_sum.py similarity index 50% rename from examples/wip/point_add_range_sum.py rename to examples/point_add_range_sum.py index 03c070b8..066dc947 100644 --- a/examples/wip/point_add_range_sum.py +++ b/examples/point_add_range_sum.py @@ -14,3 +14,19 @@ def solve(n: int, q: int, a: List[int], t: List[int], args1: List[int], args2: L r = args2[i] ans.append(sum(a[l:r])) return ans + +def main() -> None: + n, q = map(int, input().split()) + a = list(map(int, input().split())) + assert len(a) == n + t = list(range(q)) + args1 = list(range(q)) + args2 = list(range(q)) + for i in range(q): + t[i], args1[i], args2[i] = map(int, input().split()) + ans = solve(n, q, a, t, args1, args2) + for i in range(len(ans)): + print(ans[i]) + +if __name__ == '__main__': + main() diff --git a/src/Jikka/CPlusPlus/Convert/FromCore.hs b/src/Jikka/CPlusPlus/Convert/FromCore.hs index 8d4c1463..3edaebe0 100644 --- a/src/Jikka/CPlusPlus/Convert/FromCore.hs +++ b/src/Jikka/CPlusPlus/Convert/FromCore.hs @@ -222,6 +222,15 @@ runAppBuiltin env f args = wrapError' ("converting builtin " ++ X.formatBuiltinI ], Y.Var ys ) + X.Snoc t -> go2' $ \xs x -> do + t <- runType t + ys <- Y.newFreshName Y.LocalNameKind + return + ( [ Y.Declare (Y.TyVector t) ys (Just xs), + Y.callMethod' (Y.Var ys) "push_back" [x] + ], + Y.Var ys + ) X.Foldl t1 t2 -> go3'' $ \f init xs -> do (stmtsInit, init) <- runExpr env init (stmtsXs, xs) <- runExpr env xs diff --git a/src/Jikka/Core/Evaluate.hs b/src/Jikka/Core/Evaluate.hs index bcf91656..3057ebd9 100644 --- a/src/Jikka/Core/Evaluate.hs +++ b/src/Jikka/Core/Evaluate.hs @@ -181,6 +181,7 @@ callBuiltin builtin args = wrapError' ("while calling builtin " ++ formatBuiltin ModMatPow _ -> go3' pure valueToInt valueToInt valueFromModMatrix $ \f k m -> join (matpow' <$> valueToModMatrix m f <*> pure k) -- list functions Cons _ -> go2 pure valueToList ValList V.cons + Snoc _ -> go2 valueToList pure ValList V.snoc Foldl _ _ -> go3' pure pure valueToList id $ \f x a -> V.foldM (\x y -> callValue f [x, y]) x a Scanl _ _ -> go3' pure pure valueToList ValList $ \f x a -> scanM (\x y -> callValue f [x, y]) x a Len _ -> go1 valueToList ValInt (fromIntegral . V.length) diff --git a/src/Jikka/Core/Format.hs b/src/Jikka/Core/Format.hs index 4009f1e9..ecf4dde6 100644 --- a/src/Jikka/Core/Format.hs +++ b/src/Jikka/Core/Format.hs @@ -110,6 +110,7 @@ analyzeBuiltin = \case ModMatPow _ -> fun "modmatpow" -- list functions Cons t -> Fun [t] "cons" + Snoc t -> Fun [t] "snoc" Foldl t1 t2 -> Fun [t1, t2] "foldl" Scanl t1 t2 -> Fun [t1, t2] "scanl" Iterate t -> Fun [t] "iterate" diff --git a/src/Jikka/Core/Language/BuiltinPatterns.hs b/src/Jikka/Core/Language/BuiltinPatterns.hs index 47d38c60..5dfaabaa 100644 --- a/src/Jikka/Core/Language/BuiltinPatterns.hs +++ b/src/Jikka/Core/Language/BuiltinPatterns.hs @@ -111,6 +111,8 @@ pattern Nil' t = Lit (LitNil t) pattern Cons' t e1 e2 = AppBuiltin2 (Cons t) e1 e2 +pattern Snoc' t e1 e2 = AppBuiltin2 (Snoc t) e1 e2 + pattern Foldl' t1 t2 e1 e2 e3 = AppBuiltin3 (Foldl t1 t2) e1 e2 e3 pattern Scanl' t1 t2 e1 e2 e3 = AppBuiltin3 (Scanl t1 t2) e1 e2 e3 diff --git a/src/Jikka/Core/Language/Expr.hs b/src/Jikka/Core/Language/Expr.hs index fe009dfa..983e0634 100644 --- a/src/Jikka/Core/Language/Expr.hs +++ b/src/Jikka/Core/Language/Expr.hs @@ -159,6 +159,8 @@ data Builtin -- | \(: \forall \alpha. \alpha \to \list(\alpha) \to \list(\alpha)\) Cons Type + | -- | \(: \forall \alpha. \list(alpha) \to \alpha \to \list(\alpha)\) + Snoc Type | -- | \(: \forall \alpha \beta. (\beta \to \alpha \to \beta) \to \beta \to \list(\alpha) \to \beta\) Foldl Type Type | -- | \(: \forall \alpha \beta. (\beta \to \alpha \to \beta) \to \beta \to \list(\alpha) \to \list(\beta)\) diff --git a/src/Jikka/Core/Language/TypeCheck.hs b/src/Jikka/Core/Language/TypeCheck.hs index 696dc470..6824f130 100644 --- a/src/Jikka/Core/Language/TypeCheck.hs +++ b/src/Jikka/Core/Language/TypeCheck.hs @@ -70,6 +70,7 @@ builtinToType = \case ModMatPow n -> Fun3Ty (matrixTy n n) IntTy IntTy (matrixTy n n) -- list functions Cons t -> Fun2Ty t (ListTy t) (ListTy t) + Snoc t -> Fun2Ty (ListTy t) t (ListTy t) Foldl t1 t2 -> Fun3Ty (Fun2Ty t2 t1 t2) t2 (ListTy t1) t2 Scanl t1 t2 -> Fun3Ty (Fun2Ty t2 t1 t2) t2 (ListTy t1) (ListTy t2) Len t -> FunTy (ListTy t) IntTy diff --git a/src/Jikka/Core/Language/Util.hs b/src/Jikka/Core/Language/Util.hs index 8a25a1e7..84a64418 100644 --- a/src/Jikka/Core/Language/Util.hs +++ b/src/Jikka/Core/Language/Util.hs @@ -75,6 +75,7 @@ mapTypeInBuiltin f = \case ModMatPow n -> ModMatPow n -- list functionslist Cons t -> Cons (f t) + Snoc t -> Snoc (f t) Foldl t1 t2 -> Foldl (f t1) (f t2) Scanl t1 t2 -> Scanl (f t1) (f t2) Len t -> Len (f t) @@ -272,6 +273,7 @@ isConstantTimeBuiltin = \case ModMatPow _ -> True -- list functions Cons _ -> False + Snoc _ -> False Foldl _ _ -> False Scanl _ _ -> False Len _ -> True diff --git a/src/Jikka/RestrictedPython/Convert/ToCore.hs b/src/Jikka/RestrictedPython/Convert/ToCore.hs index 5ea7a86a..8fc25c4b 100644 --- a/src/Jikka/RestrictedPython/Convert/ToCore.hs +++ b/src/Jikka/RestrictedPython/Convert/ToCore.hs @@ -58,6 +58,7 @@ runType = \case X.TupleTy ts -> Y.TupleTy <$> mapM runType ts X.CallableTy args ret -> Y.curryFunTy <$> mapM runType args <*> runType ret X.StringTy -> throwSemanticError "cannot use `str' type out of main function" + X.SideEffectTy -> throwSemanticError "side-effect type must be used only as expr-statement" -- TODO: check in Jikka.RestrictedPython.Language.Lint runConstant :: MonadError Error m => X.Constant -> m Y.Expr runConstant = \case @@ -164,6 +165,7 @@ runAttribute a = wrapAt' (loc' a) $ do X.BuiltinCopy t -> do t <- runType t return $ Y.Lam "x" t (Y.Var "x") + X.BuiltinAppend _ -> throwSemanticError "cannot use `append' out of expr-statements" X.BuiltinSplit -> throwSemanticError "cannot use `split' out of main function" runBoolOp :: X.BoolOp -> Y.Builtin @@ -377,7 +379,16 @@ runStatements (stmt : stmts) cont = case stmt of X.For x iter body -> runForStatement x iter body stmts cont X.If e body1 body2 -> runIfStatement e body1 body2 stmts cont X.Assert _ -> runStatements stmts cont - X.Expr' _ -> runStatements stmts cont + X.Append loc t x e -> do + case X.exprToTarget x of + Nothing -> throwSemanticErrorAt' loc "invalid `append` method" + Just x -> do + t <- runType t + y <- runTargetExpr x + e <- runExpr e + runAssign x (Y.Snoc' t y e) $ do + runStatements stmts cont + X.Expr' e -> throwSemanticErrorAt' (loc' e) "invalid expr-statement" runToplevelStatements :: (MonadState Env m, MonadAlpha m, MonadError Error m) => [X.ToplevelStatement] -> m Y.ToplevelExpr runToplevelStatements [] = return $ Y.ResultExpr (Y.Var "solve") diff --git a/src/Jikka/RestrictedPython/Convert/TypeInfer.hs b/src/Jikka/RestrictedPython/Convert/TypeInfer.hs index 40d2fcdb..d66f8b32 100644 --- a/src/Jikka/RestrictedPython/Convert/TypeInfer.hs +++ b/src/Jikka/RestrictedPython/Convert/TypeInfer.hs @@ -174,7 +174,7 @@ formularizeStatement ret = \case Assert e -> do formularizeExpr' e BoolTy Expr' e -> do - formularizeExpr' e NoneTy + formularizeExpr' e SideEffectTy formularizeToplevelStatement :: (MonadWriter Eqns m, MonadAlpha m) => ToplevelStatement -> m () formularizeToplevelStatement = \case @@ -222,6 +222,7 @@ subst sigma = \case TupleTy ts -> TupleTy (map (subst sigma) ts) CallableTy ts ret -> CallableTy (map (subst sigma) ts) (subst sigma ret) StringTy -> StringTy + SideEffectTy -> SideEffectTy unifyTyVar :: (MonadState Subst m, MonadError Error m) => TypeName -> Type -> m () unifyTyVar x t = @@ -275,6 +276,7 @@ substUnit = \case TupleTy ts -> TupleTy (map substUnit ts) CallableTy ts ret -> CallableTy (map substUnit ts) (substUnit ret) StringTy -> StringTy + SideEffectTy -> SideEffectTy -- | `subst'` does `subst` and replaces all undetermined type variables with the unit type. subst' :: Subst -> Type -> Type diff --git a/src/Jikka/RestrictedPython/Evaluate.hs b/src/Jikka/RestrictedPython/Evaluate.hs index 418227bf..070e6a51 100644 --- a/src/Jikka/RestrictedPython/Evaluate.hs +++ b/src/Jikka/RestrictedPython/Evaluate.hs @@ -34,6 +34,7 @@ import Jikka.Common.Error import Jikka.RestrictedPython.Format (formatAttribute, formatBuiltin, formatOperator) import Jikka.RestrictedPython.Language.Expr import Jikka.RestrictedPython.Language.Lint +import Jikka.RestrictedPython.Language.Util import Jikka.RestrictedPython.Language.Value assign :: MonadState Local m => VarName -> Value -> m () @@ -416,6 +417,16 @@ evalCall' f actualArgs = case f of -- \mathbf{assert}~ e \mid \mu \Downarrow \mathbf{err} -- } -- \] +-- +-- === Rules for \(e\) +-- +-- \[ +-- \cfrac{ +-- e \mid \mu \Downarrow v +-- }{ +-- x.\mathrm{append}(e) \mid \mu \Downarrow \mathbf{stop} \mid (x \mapsto \mathrm{snoc}(\mu(x), v); \mu) +-- } +-- \] evalStatement :: (MonadReader Global m, MonadState Local m, MonadError Error m) => Statement -> m (Maybe Value) evalStatement = \case Return e -> do @@ -455,9 +466,15 @@ evalStatement = \case when (v == BoolVal False) $ do throwRuntimeError "assertion failure" return Nothing - Expr' e -> do - _ <- evalExpr e - return Nothing + Append loc _ x e -> case exprToTarget x of + Nothing -> throwSemanticErrorAt' loc "wrong `append` method" + Just x -> do + v1 <- evalTarget x + v2 <- evalExpr e + v <- ListVal <$> (V.snoc <$> toList v1 <*> pure v2) + assignTarget x v + return Nothing + Expr' e -> throwSemanticErrorAt' (loc' e) "wrong expr-statement" -- | `evalStatements` evaluates sequences of statements of our restricted Python-like language. -- @@ -659,4 +676,5 @@ evalAttribute v0 a args = wrapError' ("calling " ++ formatAttribute a) $ do Nothing -> throwRuntimeError $ "not in list: " ++ formatValue x Just i -> return (toInteger i) BuiltinCopy _ -> go0 toList ListVal id + BuiltinAppend _ -> throwSemanticError "cannot use `append' out of expr-statements" BuiltinSplit -> throwSemanticError "cannot use `split' out of main function" diff --git a/src/Jikka/RestrictedPython/Format.hs b/src/Jikka/RestrictedPython/Format.hs index fd782fbd..12efea0e 100644 --- a/src/Jikka/RestrictedPython/Format.hs +++ b/src/Jikka/RestrictedPython/Format.hs @@ -38,6 +38,7 @@ formatType t = case t of TupleTy ts -> "Tuple[" ++ intercalate ", " (map formatType ts) ++ "]" CallableTy ts ret -> "Callable[[" ++ intercalate ", " (map formatType ts) ++ "], " ++ formatType ret ++ "]" StringTy -> "str" + SideEffectTy -> "side-effect" formatConstant :: Constant -> String formatConstant = \case diff --git a/src/Jikka/RestrictedPython/Language/Builtin.hs b/src/Jikka/RestrictedPython/Language/Builtin.hs index 262a363a..5f87ba83 100644 --- a/src/Jikka/RestrictedPython/Language/Builtin.hs +++ b/src/Jikka/RestrictedPython/Language/Builtin.hs @@ -214,7 +214,7 @@ typeBuiltin = \case BuiltinTuple ts -> CallableTy [TupleTy ts] (TupleTy ts) BuiltinZip ts -> CallableTy (map ListTy ts) (TupleTy ts) BuiltinInput -> CallableTy [] StringTy - BuiltinPrint ts -> CallableTy ts NoneTy + BuiltinPrint ts -> CallableTy ts SideEffectTy mapTypeBuiltin :: (Type -> Type) -> Builtin -> Builtin mapTypeBuiltin f = \case @@ -266,6 +266,7 @@ attributeNames = [ "count", "index", "copy", + "append", "split" ] @@ -279,6 +280,7 @@ resolveAttribute' x = wrapAt' (loc' x) $ case value' x of "count" -> BuiltinCount <$> genType "index" -> BuiltinIndex <$> genType "copy" -> BuiltinCopy <$> genType + "append" -> BuiltinAppend <$> genType "split" -> return BuiltinSplit _ -> throwInternalError $ "not exhaustive: " ++ unAttributeName x' _ -> return x @@ -305,6 +307,7 @@ formatAttribute = \case BuiltinCount _ -> "count" BuiltinIndex _ -> "index" BuiltinCopy _ -> "copy" + BuiltinAppend _ -> "append" BuiltinSplit -> "split" typeAttribute :: Attribute -> (Type, Type) @@ -313,6 +316,7 @@ typeAttribute = \case BuiltinCount t -> (ListTy t, CallableTy [t] IntTy) BuiltinIndex t -> (ListTy t, CallableTy [t] IntTy) BuiltinCopy t -> (ListTy t, CallableTy [] (ListTy t)) + BuiltinAppend t -> (ListTy t, CallableTy [t] SideEffectTy) BuiltinSplit -> (StringTy, CallableTy [] (ListTy StringTy)) mapTypeAttribute :: (Type -> Type) -> Attribute -> Attribute @@ -321,4 +325,5 @@ mapTypeAttribute f = \case BuiltinCount t -> BuiltinCount (f t) BuiltinIndex t -> BuiltinIndex (f t) BuiltinCopy t -> BuiltinCopy (f t) + BuiltinAppend t -> BuiltinAppend (f t) BuiltinSplit -> BuiltinSplit diff --git a/src/Jikka/RestrictedPython/Language/Expr.hs b/src/Jikka/RestrictedPython/Language/Expr.hs index c164030f..7e6123fb 100644 --- a/src/Jikka/RestrictedPython/Language/Expr.hs +++ b/src/Jikka/RestrictedPython/Language/Expr.hs @@ -42,6 +42,7 @@ module Jikka.RestrictedPython.Language.Expr Target (..), Target', Statement (..), + pattern Append, ToplevelStatement (..), Program, ) @@ -83,6 +84,7 @@ unAttributeName (AttributeName x) = x -- \vert & \tau \times \tau \times \dots \times \tau \\ -- \vert & \tau \times \tau \times \dots \times \tau \to \tau -- \vert & \string +-- \vert & \mathbf{side-effect} -- \end{array} -- \] -- @@ -95,6 +97,7 @@ data Type | TupleTy [Type] | CallableTy [Type] Type | StringTy + | SideEffectTy deriving (Eq, Ord, Show, Read) pattern NoneTy = TupleTy [] @@ -199,6 +202,8 @@ data Attribute BuiltinIndex Type | -- | "list.copy" \(: \forall \alpha. \list(\alpha) \to \epsilon \to \list(\alpha)\) BuiltinCopy Type + | -- | "list.append" \(: \forall \alpha. \list(\alpha) \to \alpha \to \mathbf{side-effect}\) + BuiltinAppend Type | -- | "str.split" \(: \forall \alpha. \string \to \epsilon \to \list(\string)\) BuiltinSplit deriving (Eq, Ord, Show, Read) @@ -293,6 +298,8 @@ data Statement Expr' Expr' deriving (Eq, Ord, Show, Read) +pattern Append loc t e1 e2 <- Expr' (WithLoc' loc (Call (WithLoc' _ (Attribute e1 (WithLoc' _ (BuiltinAppend t)))) [e2])) + -- | `TopLevelStatement` represents the statements of our restricted Python-like language. -- They appear in the toplevel of programs. -- diff --git a/src/Jikka/RestrictedPython/Language/Util.hs b/src/Jikka/RestrictedPython/Language/Util.hs index 47d5336d..1ea1de33 100644 --- a/src/Jikka/RestrictedPython/Language/Util.hs +++ b/src/Jikka/RestrictedPython/Language/Util.hs @@ -46,6 +46,7 @@ module Jikka.RestrictedPython.Language.Util targetVars', hasSubscriptTrg, hasBareNameTrg, + exprToTarget, -- * programs toplevelMainDef, @@ -84,6 +85,7 @@ freeTyVars = nub . go TupleTy ts -> concat $ mapM go ts CallableTy ts ret -> concat $ mapM go (ret : ts) StringTy -> [] + SideEffectTy -> [] -- | `freeVars'` reports all free variables. freeVars :: Expr' -> [VarName] @@ -322,5 +324,13 @@ hasBareNameTrg (WithLoc' _ x) = case x of NameTrg _ -> True TupleTrg xs -> any hasSubscriptTrg xs +exprToTarget :: Expr' -> Maybe Target' +exprToTarget e = + WithLoc' (loc' e) <$> case value' e of + Name x -> Just $ NameTrg x + Tuple es -> TupleTrg <$> mapM exprToTarget es + Subscript e1 e2 -> SubscriptTrg <$> exprToTarget e1 <*> pure e2 + _ -> Nothing + toplevelMainDef :: [Statement] -> Program toplevelMainDef body = [ToplevelFunctionDef (WithLoc' Nothing (VarName "main")) [] IntTy body] diff --git a/src/Jikka/RestrictedPython/Language/VariableAnalysis.hs b/src/Jikka/RestrictedPython/Language/VariableAnalysis.hs index eb4093f6..9959c6cf 100644 --- a/src/Jikka/RestrictedPython/Language/VariableAnalysis.hs +++ b/src/Jikka/RestrictedPython/Language/VariableAnalysis.hs @@ -52,6 +52,11 @@ analyzeStatementGeneric isMax = \case then (ReadList (nub $ r ++ r1 ++ r2), WriteList (nub $ w1 ++ w2)) else (ReadList (nub $ r ++ intersect r1 r2), WriteList (nub $ w1 `intersect` w2)) Assert e -> (analyzeExpr e, WriteList []) + Append _ _ x e -> + let w = maybe (WriteList []) analyzeTargetWrite (exprToTarget x) + (ReadList r) = maybe (ReadList []) analyzeTargetRead (exprToTarget x) + (ReadList r') = analyzeExpr e + in (ReadList (nub $ r ++ r'), w) Expr' e -> (analyzeExpr e, WriteList []) analyzeStatementsGeneric :: Bool -> [Statement] -> (ReadList, WriteList) From 9fdf5c8254d28ebb7c701baefb479b870c1533fe Mon Sep 17 00:00:00 2001 From: Kimiyuki Onaka Date: Tue, 20 Jul 2021 15:17:35 +0900 Subject: [PATCH 4/4] test: Update examples/wip/dp_w.py --- examples/data/dp_w.sample-1.in | 4 ++++ examples/data/dp_w.sample-1.out | 1 + examples/data/dp_w.sample-2.in | 5 +++++ examples/data/dp_w.sample-2.out | 1 + examples/data/dp_w.sample-3.in | 2 ++ examples/data/dp_w.sample-3.out | 1 + examples/data/dp_w.sample-4.in | 6 ++++++ examples/data/dp_w.sample-4.out | 1 + examples/data/dp_w.sample-5.in | 9 +++++++++ examples/data/dp_w.sample-5.out | 1 + examples/wip/dp_w.py | 18 ++++++++++++++++-- 11 files changed, 47 insertions(+), 2 deletions(-) create mode 100644 examples/data/dp_w.sample-1.in create mode 100644 examples/data/dp_w.sample-1.out create mode 100644 examples/data/dp_w.sample-2.in create mode 100644 examples/data/dp_w.sample-2.out create mode 100644 examples/data/dp_w.sample-3.in create mode 100644 examples/data/dp_w.sample-3.out create mode 100644 examples/data/dp_w.sample-4.in create mode 100644 examples/data/dp_w.sample-4.out create mode 100644 examples/data/dp_w.sample-5.in create mode 100644 examples/data/dp_w.sample-5.out diff --git a/examples/data/dp_w.sample-1.in b/examples/data/dp_w.sample-1.in new file mode 100644 index 00000000..aa7c6279 --- /dev/null +++ b/examples/data/dp_w.sample-1.in @@ -0,0 +1,4 @@ +5 3 +1 3 10 +2 4 -10 +3 5 10 diff --git a/examples/data/dp_w.sample-1.out b/examples/data/dp_w.sample-1.out new file mode 100644 index 00000000..209e3ef4 --- /dev/null +++ b/examples/data/dp_w.sample-1.out @@ -0,0 +1 @@ +20 diff --git a/examples/data/dp_w.sample-2.in b/examples/data/dp_w.sample-2.in new file mode 100644 index 00000000..fa93223a --- /dev/null +++ b/examples/data/dp_w.sample-2.in @@ -0,0 +1,5 @@ +3 4 +1 3 100 +1 1 -10 +2 2 -20 +3 3 -30 diff --git a/examples/data/dp_w.sample-2.out b/examples/data/dp_w.sample-2.out new file mode 100644 index 00000000..d61f00d8 --- /dev/null +++ b/examples/data/dp_w.sample-2.out @@ -0,0 +1 @@ +90 diff --git a/examples/data/dp_w.sample-3.in b/examples/data/dp_w.sample-3.in new file mode 100644 index 00000000..3f1e1928 --- /dev/null +++ b/examples/data/dp_w.sample-3.in @@ -0,0 +1,2 @@ +1 1 +1 1 -10 diff --git a/examples/data/dp_w.sample-3.out b/examples/data/dp_w.sample-3.out new file mode 100644 index 00000000..573541ac --- /dev/null +++ b/examples/data/dp_w.sample-3.out @@ -0,0 +1 @@ +0 diff --git a/examples/data/dp_w.sample-4.in b/examples/data/dp_w.sample-4.in new file mode 100644 index 00000000..b93b8f39 --- /dev/null +++ b/examples/data/dp_w.sample-4.in @@ -0,0 +1,6 @@ +1 5 +1 1 1000000000 +1 1 1000000000 +1 1 1000000000 +1 1 1000000000 +1 1 1000000000 diff --git a/examples/data/dp_w.sample-4.out b/examples/data/dp_w.sample-4.out new file mode 100644 index 00000000..d007f6b2 --- /dev/null +++ b/examples/data/dp_w.sample-4.out @@ -0,0 +1 @@ +5000000000 diff --git a/examples/data/dp_w.sample-5.in b/examples/data/dp_w.sample-5.in new file mode 100644 index 00000000..bf0293c9 --- /dev/null +++ b/examples/data/dp_w.sample-5.in @@ -0,0 +1,9 @@ +6 8 +5 5 3 +1 1 10 +1 6 -8 +3 6 5 +3 4 9 +5 5 -2 +1 3 -6 +4 6 -7 diff --git a/examples/data/dp_w.sample-5.out b/examples/data/dp_w.sample-5.out new file mode 100644 index 00000000..f599e28b --- /dev/null +++ b/examples/data/dp_w.sample-5.out @@ -0,0 +1 @@ +10 diff --git a/examples/wip/dp_w.py b/examples/wip/dp_w.py index 97e1b801..2eafe299 100644 --- a/examples/wip/dp_w.py +++ b/examples/wip/dp_w.py @@ -1,7 +1,8 @@ # https://atcoder.jp/contests/dp/tasks/dp_w -import math from typing import * +INF = 10 ** 18 + def solve(n: int, m: int, l: List[int], r: List[int], a: List[int]) -> int: assert 1 <= n <= 2 * 10 ** 5 assert 1 <= m <= 2 * 10 ** 5 @@ -11,7 +12,7 @@ def solve(n: int, m: int, l: List[int], r: List[int], a: List[int]) -> int: assert all(1 <= l[i] <= r[i] <= n for i in range(m)) assert all(abs(a_i) <= 10 ** 9 for a_i in a) - dp = [-math.inf for _ in range(n + 1)] + dp = [-INF for _ in range(n + 1)] dp[0] = 0 for x in range(1, n + 1): for y in range(x): @@ -21,3 +22,16 @@ def solve(n: int, m: int, l: List[int], r: List[int], a: List[int]) -> int: b += a[i] dp[x] = min(dp[x], b) return dp[n] + +def main() -> None: + n, m = map(int, input().split()) + l = list(range(m)) + r = list(range(m)) + a = list(range(m)) + for i in range(m): + l[i], r[i], a[i] = map(int, input().split()) + ans = solve(n, m, l, r, a) + print(ans) + +if __name__ == '__main__': + main()