diff --git a/dargs/dargs.py b/dargs/dargs.py index 86eb8ee..482b512 100644 --- a/dargs/dargs.py +++ b/dargs/dargs.py @@ -16,6 +16,7 @@ We also need to pay special attention to flat the keys of its choices. """ +import difflib import fnmatch import json import re @@ -800,7 +801,12 @@ def get_choice(self, argdict: dict, path=None) -> "Argument": return self.choice_dict[self.choice_alias[tag]] else: raise ArgumentValueError( - path, f"get invalid choice `{tag}` for flag key `{self.flag_name}`." + path, + f"get invalid choice `{tag}` for flag key `{self.flag_name}`." + + did_you_mean( + tag, + list(self.choice_dict.keys()) + list(self.choice_alias.keys()), + ), ) elif self.optional: return self.choice_dict[self.default_tag] @@ -1042,3 +1048,22 @@ def default(self, obj) -> Dict[str, Union[str, bool, List]]: elif isinstance(obj, type): return obj.__name__ return json.JSONEncoder.default(self, obj) + + +def did_you_mean(choice: str, choices: List[str]) -> str: + """Get did you mean message. + + Parameters + ---------- + choice : str + the user's wrong choice + choices : list[str] + all the choices + + Returns + ------- + str + did you mean error message + """ + matches = difflib.get_close_matches(choice, choices) + return f"Did you mean: {matches[0]}?" if matches else "" diff --git a/tests/test_checker.py b/tests/test_checker.py index f00fc16..398215b 100644 --- a/tests/test_checker.py +++ b/tests/test_checker.py @@ -249,8 +249,9 @@ def test_sub_variants(self): "vnt2_1": 21, } } - with self.assertRaises(ArgumentValueError): + with self.assertRaises(ArgumentValueError) as cm: ca.check(err_dict2) + self.assertIn("Did you mean: type3?", str(cm.exception)) # test optional choice test_dict1["base"].pop("vnt_flag") with self.assertRaises(ArgumentKeyError):