diff --git a/.gitignore b/.gitignore index a1df0aa..c328553 100644 --- a/.gitignore +++ b/.gitignore @@ -8,4 +8,6 @@ RLA.egg-info** **/code/** **/results/** **/log/** -**/.ipynb_checkpoints/* \ No newline at end of file +**/.ipynb_checkpoints/* +**/.DS_Store +test/target_data_root/* diff --git a/README.md b/README.md index 8a95617..14e491f 100644 --- a/README.md +++ b/README.md @@ -72,7 +72,7 @@ Querying: ### Other principles The second design principle is easy for integration. It still has a long way to go to achieve it. We give several example projects integrating with RLA in the directory example. -<<<<<<< HEAD + 1. PPO with RLA based on the [stable_baselines (tensorflow)](https://github.com/Stable-Baselines-Team/stable-baselines): example/sb_ppo_example 2. PPO with RL based on the [stable_baselines3 (pytorch)](https://github.com/DLR-RM/stable-baselines3): example/sb3_ppo_example diff --git a/RLA/easy_log/log_tools.py b/RLA/easy_log/log_tools.py index 6cf370a..9db3d7c 100644 --- a/RLA/easy_log/log_tools.py +++ b/RLA/easy_log/log_tools.py @@ -33,23 +33,21 @@ def __init__(self, rla_config_path, proj_root, task, regex, *args, **kwargs): fs = open(rla_config_path, encoding="UTF-8") self.private_config = yaml.load(fs) self.proj_root = proj_root - self.task = task + self.task_table_name = task self.regex = regex def _download_log(self, show=False): - from RLA.auto_ftp import FTPHandler - for log_type in self.log_types: - root_dir_regex = osp.join(self.proj_root, log_type, self.task, self.regex) + root_dir_regex = osp.join(self.proj_root, log_type, self.task_table_name, self.regex) empty = True for root_dir in glob.glob(root_dir_regex): pass class DeleteLogTool(BasicLogTool): - def __init__(self, proj_root, task, regex, filter, *args, **kwargs): + def __init__(self, proj_root, task_table_name, regex, filter, *args, **kwargs): self.proj_root = proj_root - self.task = task + self.task_table_name = task_table_name self.regex = regex assert isinstance(filter, Filter) self.filter = filter @@ -57,7 +55,7 @@ def __init__(self, proj_root, task, regex, filter, *args, **kwargs): super(DeleteLogTool, self).__init__(*args, **kwargs) def _find_small_timestep_log(self): - root_dir_regex = osp.join(self.proj_root, LOG, self.task, self.regex) + root_dir_regex = osp.join(self.proj_root, LOG, self.task_table_name, self.regex) for root_dir in glob.glob(root_dir_regex): print("searching dirs", root_dir) if os.path.exists(root_dir): @@ -107,14 +105,16 @@ def _find_small_timestep_log(self): print("[delete] find an experiment without any files. ", file_list[0]) def _delete_related_log(self, regex, show=False): + log_found = 0 for log_type in self.log_types: print(f"--- search {log_type} ---") - root_dir_regex = osp.join(self.proj_root, log_type, self.task, regex) + root_dir_regex = osp.join(self.proj_root, log_type, self.task_table_name, regex) empty = True for root_dir in glob.glob(root_dir_regex): empty = False if os.path.exists(root_dir): print("find a matched experiment", root_dir) + log_found += 1 for file_list in os.walk(root_dir): # walk into the leave of the file-tree. for name in file_list[2]: @@ -126,16 +126,6 @@ def _delete_related_log(self, regex, show=False): print("skip the permission error file") if not show: print("delete sub-dir {}".format(file_list[0])) - # if not show: - # if len(os.listdir(file_list[0])) == 0: - # cur_dir = file_list[0] - # while True: - # shutil.rmtree(cur_dir, ignore_errors=True) - # print(" -- delete the empty dir", cur_dir, "---") - # cur_dir = os.path.abspath(os.path.join(cur_dir, "..")) - # if len(os.listdir(cur_dir)) != 0: - # break - # print("delete file {}".format(name)) if os.path.isdir(root_dir): if not show: try: @@ -150,43 +140,57 @@ def _delete_related_log(self, regex, show=False): else: print("not dir {}".format(root_dir)) if empty: print("empty regex {}".format(root_dir_regex)) + return log_found - def delete_related_log(self): + def delete_related_log(self, skip_ask=False): self._delete_related_log(show=True, regex=self.regex) - s = input("delete these files? (y/n)") + if skip_ask: + s = 'y' + else: + s = input("delete these files? (y/n)") if s == 'y': print("do delete ...") - self._delete_related_log(show=False, regex=self.regex) + return self._delete_related_log(show=False, regex=self.regex) + else: + return 0 - def delete_small_timestep_log(self): + def delete_small_timestep_log(self, skip_ask=False): self._find_small_timestep_log() print("complete searching.") - s = input("show files to be deleted? (y/n)") - if s == 'y': + if skip_ask: + s = 'y' + else: + s = input("show files to be deleted? (y/n)") + log_found = 0 + + if s == 'y' or skip_ask: for reg in self.small_timestep_regs: print("[delete small-timestep log] reg: ", reg) self._delete_related_log(show=True, regex=reg + '*') - s = input("delete these files? (y/n)") - if s == 'y': + if skip_ask: + s = 'y' + else: + s = input("delete these files? (y/n)") + if s == 'y' or skip_ask: for reg in self.small_timestep_regs: print("do delete: ", reg) - self._delete_related_log(show=False, regex=reg + '*') - + log_found += self._delete_related_log(show=False, regex=reg + '*') + return log_found class ArchiveLogTool(BasicLogTool): - def __init__(self, proj_root, task, regex, archive_name_as_task, remove, *args, **kwargs): + def __init__(self, proj_root, task_table_name, regex, archive_table_name, remove, *args, **kwargs): self.proj_root = proj_root - self.task = task + self.task_table_name = task_table_name self.regex = regex self.remove = remove - self.archive_name_as_task = archive_name_as_task + self.archive_table_name = archive_table_name super(ArchiveLogTool, self).__init__(*args, **kwargs) def _archive_log(self, show=False): for log_type in self.log_types: - root_dir_regex = osp.join(self.proj_root, log_type, self.task, self.regex) - archive_root_dir = osp.join(self.proj_root, log_type, self.archive_name_as_task) - prefix_dir = osp.join(self.proj_root, log_type, self.task) + root_dir_regex = osp.join(self.proj_root, log_type, self.task_table_name, self.regex) + archive_root_dir = osp.join(self.proj_root, log_type, self.archive_table_name) + prefix_dir = osp.join(self.proj_root, log_type, self.task_table_name) prefix_len = len(prefix_dir) empty = True # os.system("chmod +x -R \"{}\"".format(prefix_dir)) @@ -217,12 +221,12 @@ def _archive_log(self, show=False): if empty: print("empty regex {}".format(root_dir_regex)) pass - def archive_log(self): + def archive_log(self, skip_ask=False): self._archive_log(show=True) - warn = '' - if self.remove: - warn = '[WARN] You are in the \'\'remove\'\' setting, the original log files will be removed!!' - s = input("archive these files? (y/n) \n " + warn) + if skip_ask: + s = 'y' + else: + s = input("archive these files? (y/n) \n ") if s == 'y': print("do archive ...") self._archive_log(show=False) diff --git a/rla_scripts/archive_log.py b/rla_scripts/archive_expt.py similarity index 64% rename from rla_scripts/archive_log.py rename to rla_scripts/archive_expt.py index 355a400..a488364 100644 --- a/rla_scripts/archive_log.py +++ b/rla_scripts/archive_expt.py @@ -14,8 +14,8 @@ def argsparser(): parser = argparse.ArgumentParser("Archive Log") # reduce setting - parser.add_argument('--task', type=str) - parser.add_argument('--archive_name_as_task', type=str, default='archived') + parser.add_argument('--task_table_name', type=str) + parser.add_argument('--archive_table_name', type=str, default=ARCHIVED_TABLE) parser.add_argument('--regex', type=str) parser.add_argument('--remove', action='store_true') @@ -25,6 +25,6 @@ def argsparser(): if __name__=='__main__': args = argsparser() - dlt = ArchiveLogTool(proj_root=DATA_ROOT, task=args.task, regex=args.regex, - archive_name_as_task=args.archive_name_as_task, remove=args.remove) + dlt = ArchiveLogTool(proj_root=DATA_ROOT, task_table_name=args.task_table_name, regex=args.regex, + archive_table_name=args.archive_table_name, remove=args.remove) dlt.archive_log() \ No newline at end of file diff --git a/rla_scripts/config.py b/rla_scripts/config.py index 50ba426..b211564 100644 --- a/rla_scripts/config.py +++ b/rla_scripts/config.py @@ -1 +1,2 @@ -DATA_ROOT = '../example/simplest_code/' \ No newline at end of file +DATA_ROOT = '../example/simplest_code/' +ARCHIVED_TABLE = 'archived' \ No newline at end of file diff --git a/rla_scripts/delete_log.py b/rla_scripts/delete_expt.py similarity index 83% rename from rla_scripts/delete_log.py rename to rla_scripts/delete_expt.py index afb6128..0c7120f 100644 --- a/rla_scripts/delete_log.py +++ b/rla_scripts/delete_expt.py @@ -9,7 +9,7 @@ def argsparser(): parser = argparse.ArgumentParser("Delete Log") # reduce setting - parser.add_argument('--task', type=str, default="") + parser.add_argument('--task_table_name', type=str, default="") parser.add_argument('--regex', type=str) parser.add_argument('--timestep_bound', type=int, default=100) parser.add_argument('--delete_type', type=str, default=Filter.ALL) @@ -21,7 +21,7 @@ def argsparser(): args = argsparser() filter = Filter() filter.config(type=args.delete_type, timstep_bound=args.timestep_bound) - dlt = DeleteLogTool(proj_root=DATA_ROOT, task=args.task, regex=args.regex, + dlt = DeleteLogTool(proj_root=DATA_ROOT, task_table_name=args.task_table_name, regex=args.regex, filter=filter) if args.delete_type == Filter.ALL: dlt.delete_related_log() diff --git a/rla_scripts/view_expt.py b/rla_scripts/view_expt.py new file mode 100644 index 0000000..e69de29 diff --git a/setup.py b/setup.py index 8d99be6..eefb385 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setup( name='RLA', - version="0.5.1", + version="0.5.2", description=( 'RL assistant' ), diff --git a/test/__init__.py b/test/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/_base.py b/test/_base.py new file mode 100644 index 0000000..eb8175e --- /dev/null +++ b/test/_base.py @@ -0,0 +1,18 @@ +import unittest +import shutil +import os + +class BaseTest(unittest.TestCase): + """ + Base test class. + """ + SOURCE_DATA_ROOT = './test_data_root' + TARGET_DATA_ROOT = './target_data_root' + TASK_NAME = 'demo_task' + def remove_and_copy_data(self): + """ + reset the experiment data for test. + """ + if os.path.exists(self.TARGET_DATA_ROOT): + shutil.rmtree(self.TARGET_DATA_ROOT) + shutil.copytree(self.SOURCE_DATA_ROOT, self.TARGET_DATA_ROOT) diff --git a/test/test_data_root/__init__.py b/test/test_data_root/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/test_scripts.py b/test/test_scripts.py new file mode 100644 index 0000000..ae877ee --- /dev/null +++ b/test/test_scripts.py @@ -0,0 +1,54 @@ +from test._base import BaseTest +from RLA.easy_log.log_tools import DeleteLogTool, Filter +from RLA.easy_log.log_tools import ArchiveLogTool + +class ScriptTest(BaseTest): + + def test_delete_reg(self) -> None: + """ + test delete log filtered by regex. + """ + self.remove_and_copy_data() + filter = Filter() + filter.config(type=Filter.ALL, timstep_bound=1) + dlt = DeleteLogTool(proj_root=self.TARGET_DATA_ROOT, task_table_name=self.TASK_NAME, regex='2022/03/01/21-13*', filter=filter) + log_found = dlt.delete_related_log(skip_ask=True) + assert log_found == 10 + log_found = dlt.delete_related_log(skip_ask=True) + assert log_found == 0 + + def test_delete_reg_small_ts(self): + """ + test delete log filtered by regex and threshold of time-step. + """ + self.remove_and_copy_data() + filter = Filter() + # none of the experiment satisfied timestep <=1 + filter.config(type=Filter.SMALL_TIMESTEP, timstep_bound=1) + dlt = DeleteLogTool(proj_root=self.TARGET_DATA_ROOT, task_table_name=self.TASK_NAME, regex='2022/03/01/21-13*', filter=filter) + log_found = dlt.delete_small_timestep_log(skip_ask=True) + assert log_found == 0 + # all the experiment satisfied timestep <=2000 + filter.config(type=Filter.SMALL_TIMESTEP, timstep_bound=2000) + dlt = DeleteLogTool(proj_root=self.TARGET_DATA_ROOT, task_table_name='demo_task', regex='2022/03/01/21-13*', filter=filter) + log_found = dlt.delete_small_timestep_log(skip_ask=True) + assert log_found == 10 + # nothing left + filter.config(type=Filter.SMALL_TIMESTEP, timstep_bound=2000) + dlt = DeleteLogTool(proj_root=self.TARGET_DATA_ROOT, task_table_name=self.TASK_NAME, regex='2022/03/01/21-13*', filter=filter) + log_found = dlt.delete_small_timestep_log(skip_ask=True) + assert log_found == 0 + + def test_archive(self): + self.remove_and_copy_data() + # archive experiments. + dlt = ArchiveLogTool(proj_root=self.TARGET_DATA_ROOT, task_table_name=self.TASK_NAME, regex='2022/03/01/21-13*', + archive_table_name='archived', remove=False) + dlt.archive_log(skip_ask=True) + # remove the archived experiments. + filter = Filter() + filter.config(type=Filter.ALL, timstep_bound=1) + dlt = DeleteLogTool(proj_root=self.TARGET_DATA_ROOT, task_table_name='archived', regex='2022/03/01/21-13*', filter=filter) + log_found = dlt.delete_related_log(skip_ask=True) + assert log_found == 10 + \ No newline at end of file