From 6df3613e55d0570e75fd5927136942712645ddef Mon Sep 17 00:00:00 2001 From: Nanguage Date: Tue, 13 Sep 2022 12:10:12 +0800 Subject: [PATCH] fix bug --- oneface/arg.py | 10 +++++++--- tests/test_types.py | 8 ++++---- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/oneface/arg.py b/oneface/arg.py index 4766ba5..c8bc902 100644 --- a/oneface/arg.py +++ b/oneface/arg.py @@ -41,7 +41,7 @@ def check(self, val): if not self.type_checker(val, self.type): raise TypeError( f"Input value {val} is not in valid type({self.type})") - if (self.range is not None) and (self.range_checker is not None): + if (self.range_checker is not None): if (not self.range_checker(val, self.range)): raise ValueError(f"Input value {val} is not in a valid range.") @@ -56,11 +56,15 @@ def register_type_check(cls, type, checker=None): # register basic types +def _check_number_in_range(v, range): + return (range is None) or (range[0] <= v <= range[1]) + + Arg.register_type_check(Empty, lambda v, range: True) Arg.register_type_check(str) -Arg.register_range_check(int, lambda v, range: range[0] <= v <= range[1]) +Arg.register_range_check(int, _check_number_in_range) Arg.register_type_check(int) -Arg.register_range_check(float, lambda v, range: range[0] <= v <= range[1]) +Arg.register_range_check(float, _check_number_in_range) Arg.register_type_check(float) Arg.register_type_check(bool) diff --git a/tests/test_types.py b/tests/test_types.py index 91cdec9..f5db24c 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -5,7 +5,7 @@ def test_selection(): - @one + @one(print_args=False) def func1(a: Arg(Selection, ["a", 2, 3])): print(a) return a @@ -18,7 +18,7 @@ def func1(a: Arg(Selection, ["a", 2, 3])): def test_subset(): - @one + @one(print_args=False) def func1(s: Arg(SubSet, [1,2,3])): print(s) return s @@ -31,7 +31,7 @@ def func1(s: Arg(SubSet, [1,2,3])): def test_inputpath(): - @one + @one(print_args=False) def func1(s: Arg(InputPath)): print(s) return s @@ -46,7 +46,7 @@ def func1(s: Arg(InputPath)): def test_outputpath(): - @one + @one(print_args=False) def func1(s: Arg(OutputPath)): print(s) return s