Skip to content

Commit

Permalink
update: argparse logic
Browse files Browse the repository at this point in the history
  • Loading branch information
typoverflow committed Mar 8, 2024
1 parent 0a66b50 commit 1f2dadb
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 32 deletions.
2 changes: 1 addition & 1 deletion UtilsRL/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@

__version__ = "0.6.6"
__version__ = "0.6.7"
72 changes: 41 additions & 31 deletions UtilsRL/exp/argparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,26 +29,15 @@ def get_key(_key):
else:
return _key

def parse_cmd_args(convert=True):
def parse_cmd_args():
cmd_parser = argparse.ArgumentParser()
cmd_parser.add_argument("--config", type=str, default=None, help="the path of config file")
config_args, cmd_args = cmd_parser.parse_known_args()
num = len(cmd_args)//2
cmd_args = dict(zip([get_key(cmd_args[2*i]) for i in range(num)], [cmd_args[2*i+1] for i in range(num)]))
args = NameSpace("cmd_args", {}, nested=True) if convert else {}
for key, value in cmd_args.items():
keys = key.split(".")
this = args
for subkey in keys[:-1]:
if convert:
this[subkey] = this.get(subkey, NameSpace(subkey, {}, nested=True))
else:
this[subkey] = this.get(subkey, {})
this = this[subkey]
this[keys[-1]] = safe_eval(value)
return config_args.config, args
cmd_args = dict(zip([get_key(cmd_args[2*i]) for i in range(num)], [safe_eval(cmd_args[2*i+1]) for i in range(num)]))
return config_args.config, cmd_args

def parse_file_args(path, convert=True):
def parse_file_args(path):
# parse args from config files or modules
if isinstance(path, str):
if path.endswith(".json"):
Expand All @@ -70,7 +59,6 @@ def parse_file_args(path, convert=True):
elif path is None:
file_args = {}

file_args = NameSpace("args", file_args, nested=True) if convert else file_args
return file_args


Expand Down Expand Up @@ -103,32 +91,54 @@ def parse_args(
convert : Whether or not convert the parse arguments to NameSpace object.
"""

cmd_path, cmd_args = parse_cmd_args(convert=convert)
cmd_path, cmd_args = parse_cmd_args()

file_args = parse_file_args(path or cmd_path, convert=convert)
file_args = parse_file_args(path or cmd_path)


# update with cmd args
def traverse_add(old, new, current_key=""):
for new_key, new_value in new.items():
if new_key not in old:
logger.warning(f"parse_args: key {current_key + new_key} is not in the config file, setting it to {new_value}")
old[new_key] = new_value
elif type(old[new_key]) != type(new_value):
logger.warning(f"parse_args: overwriting key {current_key + new_key} with {new_value}")
old[new_key] = new_value
else:
if convert and isinstance(new_value, NameSpaceMeta) \
or (not convert and isinstance(new_value, dict)):
traverse_add(old[new_key], new_value, current_key=current_key+new_key+".")
this = old
keys = new_key.split(".")
for kidx, k in enumerate(keys[:-1]):
if isinstance(this, list):
# if this is a list, check the key and go to next level
if not (k.isdigit() and int(k)<=len(this)):
logger.error(f"parse_args: {new_key} does not exist.")
raise ValueError(f"parse_args: {new_key} does not exist.")
k = int(k)
elif isinstance(this, dict):
# if this is a dict, check the key, add an empty dict if necessary and go to next level
if not k in this.keys():
this[k] = {}
if not isinstance(this[k], (dict, list)):
logger.warning(f"parse_args: discarding key {'.'.join(keys[:kidx+1])}")
this[k] = {}
this = this[k]

# finally, set the value with the last key
last_k = keys[-1]
if isinstance(this, list):
if not (last_k.isdigit() and int(last_k)<=len(this)):
logger.error(f"parse_args: {new_key} does not exist.")
raise ValueError(f"parse_args: {new_key} does not exist.")
logger.warning(f"parse_args: overwriting key {new_key} with {new_value}.")
this[int(last_k)] = new_value
elif isinstance(this, dict):
if not last_k in this.keys():
logger.warning(f"parse_args: key {new_key} is not in the config file, setting it to {new_value}.")
this[last_k] = new_value
else:
logger.warning(f"parse_args: overwriting key {current_key + new_key} with {new_value}")
old[new_key] = new_value
logger.warning(f"parse_args: overwriting key {new_key} with {new_value}.")
this[last_k] = new_value

traverse_add(file_args, cmd_args)

if post_init is not None:
post_init(file_args)

if convert:
file_args = NameSpace("args", file_args, nested=True)

return file_args

Expand Down

0 comments on commit 1f2dadb

Please sign in to comment.