diff --git a/UtilsRL/__init__.py b/UtilsRL/__init__.py index 6dbe537..a523e0d 100644 --- a/UtilsRL/__init__.py +++ b/UtilsRL/__init__.py @@ -1,2 +1,2 @@ -__version__ = "0.6.6" +__version__ = "0.6.7" diff --git a/UtilsRL/exp/argparse.py b/UtilsRL/exp/argparse.py index 23c24ac..2a1469e 100644 --- a/UtilsRL/exp/argparse.py +++ b/UtilsRL/exp/argparse.py @@ -29,26 +29,15 @@ def get_key(_key): else: return _key -def parse_cmd_args(convert=True): +def parse_cmd_args(): cmd_parser = argparse.ArgumentParser() cmd_parser.add_argument("--config", type=str, default=None, help="the path of config file") config_args, cmd_args = cmd_parser.parse_known_args() num = len(cmd_args)//2 - cmd_args = dict(zip([get_key(cmd_args[2*i]) for i in range(num)], [cmd_args[2*i+1] for i in range(num)])) - args = NameSpace("cmd_args", {}, nested=True) if convert else {} - for key, value in cmd_args.items(): - keys = key.split(".") - this = args - for subkey in keys[:-1]: - if convert: - this[subkey] = this.get(subkey, NameSpace(subkey, {}, nested=True)) - else: - this[subkey] = this.get(subkey, {}) - this = this[subkey] - this[keys[-1]] = safe_eval(value) - return config_args.config, args + cmd_args = dict(zip([get_key(cmd_args[2*i]) for i in range(num)], [safe_eval(cmd_args[2*i+1]) for i in range(num)])) + return config_args.config, cmd_args -def parse_file_args(path, convert=True): +def parse_file_args(path): # parse args from config files or modules if isinstance(path, str): if path.endswith(".json"): @@ -70,7 +59,6 @@ def parse_file_args(path, convert=True): elif path is None: file_args = {} - file_args = NameSpace("args", file_args, nested=True) if convert else file_args return file_args @@ -103,32 +91,54 @@ def parse_args( convert : Whether or not convert the parse arguments to NameSpace object. """ - cmd_path, cmd_args = parse_cmd_args(convert=convert) + cmd_path, cmd_args = parse_cmd_args() - file_args = parse_file_args(path or cmd_path, convert=convert) + file_args = parse_file_args(path or cmd_path) - # update with cmd args def traverse_add(old, new, current_key=""): for new_key, new_value in new.items(): - if new_key not in old: - logger.warning(f"parse_args: key {current_key + new_key} is not in the config file, setting it to {new_value}") - old[new_key] = new_value - elif type(old[new_key]) != type(new_value): - logger.warning(f"parse_args: overwriting key {current_key + new_key} with {new_value}") - old[new_key] = new_value - else: - if convert and isinstance(new_value, NameSpaceMeta) \ - or (not convert and isinstance(new_value, dict)): - traverse_add(old[new_key], new_value, current_key=current_key+new_key+".") + this = old + keys = new_key.split(".") + for kidx, k in enumerate(keys[:-1]): + if isinstance(this, list): + # if this is a list, check the key and go to next level + if not (k.isdigit() and int(k)<=len(this)): + logger.error(f"parse_args: {new_key} does not exist.") + raise ValueError(f"parse_args: {new_key} does not exist.") + k = int(k) + elif isinstance(this, dict): + # if this is a dict, check the key, add an empty dict if necessary and go to next level + if not k in this.keys(): + this[k] = {} + if not isinstance(this[k], (dict, list)): + logger.warning(f"parse_args: discarding key {'.'.join(keys[:kidx+1])}") + this[k] = {} + this = this[k] + + # finally, set the value with the last key + last_k = keys[-1] + if isinstance(this, list): + if not (last_k.isdigit() and int(last_k)<=len(this)): + logger.error(f"parse_args: {new_key} does not exist.") + raise ValueError(f"parse_args: {new_key} does not exist.") + logger.warning(f"parse_args: overwriting key {new_key} with {new_value}.") + this[int(last_k)] = new_value + elif isinstance(this, dict): + if not last_k in this.keys(): + logger.warning(f"parse_args: key {new_key} is not in the config file, setting it to {new_value}.") + this[last_k] = new_value else: - logger.warning(f"parse_args: overwriting key {current_key + new_key} with {new_value}") - old[new_key] = new_value + logger.warning(f"parse_args: overwriting key {new_key} with {new_value}.") + this[last_k] = new_value traverse_add(file_args, cmd_args) if post_init is not None: post_init(file_args) + + if convert: + file_args = NameSpace("args", file_args, nested=True) return file_args