Skip to content

Commit

Permalink
test: add test to the scripts of RLA
Browse files Browse the repository at this point in the history
BREAKING CHANGE: change the name of parameter from 'task' to 'task_table_name' in rla_scripts for better readability.
  • Loading branch information
xionghuichen committed Apr 25, 2022
1 parent 3ae1cab commit 5764414
Show file tree
Hide file tree
Showing 12 changed files with 128 additions and 49 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,6 @@ RLA.egg-info**
**/code/**
**/results/**
**/log/**
**/.ipynb_checkpoints/*
**/.ipynb_checkpoints/*
**/.DS_Store
test/target_data_root/*
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ Querying:
### Other principles

The second design principle is easy for integration. It still has a long way to go to achieve it. We give several example projects integrating with RLA in the directory example.
<<<<<<< HEAD

1. PPO with RLA based on the [stable_baselines (tensorflow)](https://github.com/Stable-Baselines-Team/stable-baselines): example/sb_ppo_example
2. PPO with RL based on the [stable_baselines3 (pytorch)](https://github.com/DLR-RM/stable-baselines3): example/sb3_ppo_example

Expand Down
82 changes: 43 additions & 39 deletions RLA/easy_log/log_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,31 +33,29 @@ def __init__(self, rla_config_path, proj_root, task, regex, *args, **kwargs):
fs = open(rla_config_path, encoding="UTF-8")
self.private_config = yaml.load(fs)
self.proj_root = proj_root
self.task = task
self.task_table_name = task
self.regex = regex

def _download_log(self, show=False):
from RLA.auto_ftp import FTPHandler

for log_type in self.log_types:
root_dir_regex = osp.join(self.proj_root, log_type, self.task, self.regex)
root_dir_regex = osp.join(self.proj_root, log_type, self.task_table_name, self.regex)
empty = True
for root_dir in glob.glob(root_dir_regex):

pass

class DeleteLogTool(BasicLogTool):
def __init__(self, proj_root, task, regex, filter, *args, **kwargs):
def __init__(self, proj_root, task_table_name, regex, filter, *args, **kwargs):
self.proj_root = proj_root
self.task = task
self.task_table_name = task_table_name
self.regex = regex
assert isinstance(filter, Filter)
self.filter = filter
self.small_timestep_regs = []
super(DeleteLogTool, self).__init__(*args, **kwargs)

def _find_small_timestep_log(self):
root_dir_regex = osp.join(self.proj_root, LOG, self.task, self.regex)
root_dir_regex = osp.join(self.proj_root, LOG, self.task_table_name, self.regex)
for root_dir in glob.glob(root_dir_regex):
print("searching dirs", root_dir)
if os.path.exists(root_dir):
Expand Down Expand Up @@ -107,14 +105,16 @@ def _find_small_timestep_log(self):
print("[delete] find an experiment without any files. ", file_list[0])

def _delete_related_log(self, regex, show=False):
log_found = 0
for log_type in self.log_types:
print(f"--- search {log_type} ---")
root_dir_regex = osp.join(self.proj_root, log_type, self.task, regex)
root_dir_regex = osp.join(self.proj_root, log_type, self.task_table_name, regex)
empty = True
for root_dir in glob.glob(root_dir_regex):
empty = False
if os.path.exists(root_dir):
print("find a matched experiment", root_dir)
log_found += 1
for file_list in os.walk(root_dir):
# walk into the leave of the file-tree.
for name in file_list[2]:
Expand All @@ -126,16 +126,6 @@ def _delete_related_log(self, regex, show=False):
print("skip the permission error file")
if not show:
print("delete sub-dir {}".format(file_list[0]))
# if not show:
# if len(os.listdir(file_list[0])) == 0:
# cur_dir = file_list[0]
# while True:
# shutil.rmtree(cur_dir, ignore_errors=True)
# print(" -- delete the empty dir", cur_dir, "---")
# cur_dir = os.path.abspath(os.path.join(cur_dir, ".."))
# if len(os.listdir(cur_dir)) != 0:
# break
# print("delete file {}".format(name))
if os.path.isdir(root_dir):
if not show:
try:
Expand All @@ -150,43 +140,57 @@ def _delete_related_log(self, regex, show=False):
else:
print("not dir {}".format(root_dir))
if empty: print("empty regex {}".format(root_dir_regex))
return log_found

def delete_related_log(self):
def delete_related_log(self, skip_ask=False):
self._delete_related_log(show=True, regex=self.regex)
s = input("delete these files? (y/n)")
if skip_ask:
s = 'y'
else:
s = input("delete these files? (y/n)")
if s == 'y':
print("do delete ...")
self._delete_related_log(show=False, regex=self.regex)
return self._delete_related_log(show=False, regex=self.regex)
else:
return 0

def delete_small_timestep_log(self):
def delete_small_timestep_log(self, skip_ask=False):
self._find_small_timestep_log()
print("complete searching.")
s = input("show files to be deleted? (y/n)")
if s == 'y':
if skip_ask:
s = 'y'
else:
s = input("show files to be deleted? (y/n)")
log_found = 0

if s == 'y' or skip_ask:
for reg in self.small_timestep_regs:
print("[delete small-timestep log] reg: ", reg)
self._delete_related_log(show=True, regex=reg + '*')
s = input("delete these files? (y/n)")
if s == 'y':
if skip_ask:
s = 'y'
else:
s = input("delete these files? (y/n)")
if s == 'y' or skip_ask:
for reg in self.small_timestep_regs:
print("do delete: ", reg)
self._delete_related_log(show=False, regex=reg + '*')

log_found += self._delete_related_log(show=False, regex=reg + '*')
return log_found

class ArchiveLogTool(BasicLogTool):
def __init__(self, proj_root, task, regex, archive_name_as_task, remove, *args, **kwargs):
def __init__(self, proj_root, task_table_name, regex, archive_table_name, remove, *args, **kwargs):
self.proj_root = proj_root
self.task = task
self.task_table_name = task_table_name
self.regex = regex
self.remove = remove
self.archive_name_as_task = archive_name_as_task
self.archive_table_name = archive_table_name
super(ArchiveLogTool, self).__init__(*args, **kwargs)

def _archive_log(self, show=False):
for log_type in self.log_types:
root_dir_regex = osp.join(self.proj_root, log_type, self.task, self.regex)
archive_root_dir = osp.join(self.proj_root, log_type, self.archive_name_as_task)
prefix_dir = osp.join(self.proj_root, log_type, self.task)
root_dir_regex = osp.join(self.proj_root, log_type, self.task_table_name, self.regex)
archive_root_dir = osp.join(self.proj_root, log_type, self.archive_table_name)
prefix_dir = osp.join(self.proj_root, log_type, self.task_table_name)
prefix_len = len(prefix_dir)
empty = True
# os.system("chmod +x -R \"{}\"".format(prefix_dir))
Expand Down Expand Up @@ -217,12 +221,12 @@ def _archive_log(self, show=False):
if empty: print("empty regex {}".format(root_dir_regex))
pass

def archive_log(self):
def archive_log(self, skip_ask=False):
self._archive_log(show=True)
warn = ''
if self.remove:
warn = '[WARN] You are in the \'\'remove\'\' setting, the original log files will be removed!!'
s = input("archive these files? (y/n) \n " + warn)
if skip_ask:
s = 'y'
else:
s = input("archive these files? (y/n) \n ")
if s == 'y':
print("do archive ...")
self._archive_log(show=False)
Expand Down
8 changes: 4 additions & 4 deletions rla_scripts/archive_log.py → rla_scripts/archive_expt.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
def argsparser():
parser = argparse.ArgumentParser("Archive Log")
# reduce setting
parser.add_argument('--task', type=str)
parser.add_argument('--archive_name_as_task', type=str, default='archived')
parser.add_argument('--task_table_name', type=str)
parser.add_argument('--archive_table_name', type=str, default=ARCHIVED_TABLE)
parser.add_argument('--regex', type=str)
parser.add_argument('--remove', action='store_true')

Expand All @@ -25,6 +25,6 @@ def argsparser():

if __name__=='__main__':
args = argsparser()
dlt = ArchiveLogTool(proj_root=DATA_ROOT, task=args.task, regex=args.regex,
archive_name_as_task=args.archive_name_as_task, remove=args.remove)
dlt = ArchiveLogTool(proj_root=DATA_ROOT, task_table_name=args.task_table_name, regex=args.regex,
archive_table_name=args.archive_table_name, remove=args.remove)
dlt.archive_log()
3 changes: 2 additions & 1 deletion rla_scripts/config.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
DATA_ROOT = '../example/simplest_code/'
DATA_ROOT = '../example/simplest_code/'
ARCHIVED_TABLE = 'archived'
4 changes: 2 additions & 2 deletions rla_scripts/delete_log.py → rla_scripts/delete_expt.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
def argsparser():
parser = argparse.ArgumentParser("Delete Log")
# reduce setting
parser.add_argument('--task', type=str, default="")
parser.add_argument('--task_table_name', type=str, default="")
parser.add_argument('--regex', type=str)
parser.add_argument('--timestep_bound', type=int, default=100)
parser.add_argument('--delete_type', type=str, default=Filter.ALL)
Expand All @@ -21,7 +21,7 @@ def argsparser():
args = argsparser()
filter = Filter()
filter.config(type=args.delete_type, timstep_bound=args.timestep_bound)
dlt = DeleteLogTool(proj_root=DATA_ROOT, task=args.task, regex=args.regex,
dlt = DeleteLogTool(proj_root=DATA_ROOT, task_table_name=args.task_table_name, regex=args.regex,
filter=filter)
if args.delete_type == Filter.ALL:
dlt.delete_related_log()
Expand Down
Empty file added rla_scripts/view_expt.py
Empty file.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setup(
name='RLA',
version="0.5.1",
version="0.5.2",
description=(
'RL assistant'
),
Expand Down
Empty file added test/__init__.py
Empty file.
18 changes: 18 additions & 0 deletions test/_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import unittest
import shutil
import os

class BaseTest(unittest.TestCase):
"""
Base test class.
"""
SOURCE_DATA_ROOT = './test_data_root'
TARGET_DATA_ROOT = './target_data_root'
TASK_NAME = 'demo_task'
def remove_and_copy_data(self):
"""
reset the experiment data for test.
"""
if os.path.exists(self.TARGET_DATA_ROOT):
shutil.rmtree(self.TARGET_DATA_ROOT)
shutil.copytree(self.SOURCE_DATA_ROOT, self.TARGET_DATA_ROOT)
Empty file added test/test_data_root/__init__.py
Empty file.
54 changes: 54 additions & 0 deletions test/test_scripts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from test._base import BaseTest
from RLA.easy_log.log_tools import DeleteLogTool, Filter
from RLA.easy_log.log_tools import ArchiveLogTool

class ScriptTest(BaseTest):

def test_delete_reg(self) -> None:
"""
test delete log filtered by regex.
"""
self.remove_and_copy_data()
filter = Filter()
filter.config(type=Filter.ALL, timstep_bound=1)
dlt = DeleteLogTool(proj_root=self.TARGET_DATA_ROOT, task_table_name=self.TASK_NAME, regex='2022/03/01/21-13*', filter=filter)
log_found = dlt.delete_related_log(skip_ask=True)
assert log_found == 10
log_found = dlt.delete_related_log(skip_ask=True)
assert log_found == 0

def test_delete_reg_small_ts(self):
"""
test delete log filtered by regex and threshold of time-step.
"""
self.remove_and_copy_data()
filter = Filter()
# none of the experiment satisfied timestep <=1
filter.config(type=Filter.SMALL_TIMESTEP, timstep_bound=1)
dlt = DeleteLogTool(proj_root=self.TARGET_DATA_ROOT, task_table_name=self.TASK_NAME, regex='2022/03/01/21-13*', filter=filter)
log_found = dlt.delete_small_timestep_log(skip_ask=True)
assert log_found == 0
# all the experiment satisfied timestep <=2000
filter.config(type=Filter.SMALL_TIMESTEP, timstep_bound=2000)
dlt = DeleteLogTool(proj_root=self.TARGET_DATA_ROOT, task_table_name='demo_task', regex='2022/03/01/21-13*', filter=filter)
log_found = dlt.delete_small_timestep_log(skip_ask=True)
assert log_found == 10
# nothing left
filter.config(type=Filter.SMALL_TIMESTEP, timstep_bound=2000)
dlt = DeleteLogTool(proj_root=self.TARGET_DATA_ROOT, task_table_name=self.TASK_NAME, regex='2022/03/01/21-13*', filter=filter)
log_found = dlt.delete_small_timestep_log(skip_ask=True)
assert log_found == 0

def test_archive(self):
self.remove_and_copy_data()
# archive experiments.
dlt = ArchiveLogTool(proj_root=self.TARGET_DATA_ROOT, task_table_name=self.TASK_NAME, regex='2022/03/01/21-13*',
archive_table_name='archived', remove=False)
dlt.archive_log(skip_ask=True)
# remove the archived experiments.
filter = Filter()
filter.config(type=Filter.ALL, timstep_bound=1)
dlt = DeleteLogTool(proj_root=self.TARGET_DATA_ROOT, task_table_name='archived', regex='2022/03/01/21-13*', filter=filter)
log_found = dlt.delete_related_log(skip_ask=True)
assert log_found == 10

0 comments on commit 5764414

Please sign in to comment.