diff --git a/README.md b/README.md
index f793f32..4740777 100644
--- a/README.md
+++ b/README.md
@@ -17,7 +17,7 @@ In the provided ```ObliqueDTPolicy``` class, the method get_oblique_data generat
 # Usage
 
 ```bash
-pip install git+https://github.com/KohlerHECTOR/interpreter-py.git@v0.2.1
+pip install git+https://github.com/KohlerHECTOR/interpreter-py.git@v0.3.0
 ```
 
 ```python
@@ -51,11 +51,11 @@ print(evaluate_policy(oracle, Monitor(env))[0])
 clf = DecisionTreeRegressor(
     max_leaf_nodes=32
 )  # Change to DecisionTreeClassifier for discrete Actions.
-tree_policy = ObliqueDTPolicy(clf, env)  #
+learner = ObliqueDTPolicy(clf, env)  #
 # You can replace by DTPolicy(clf, env) for interpretable axis-parallel DTs.
 
 # Start the imitation learning
-interpret = Interpreter(oracle, tree_policy, env)
+interpret = Interpreter(oracle, learner, env)
 interpret.fit(10)
 
 # Eval and save the best tree
diff --git a/docs/conf.py b/docs/conf.py
index 39a5943..85813c8 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -9,7 +9,7 @@ project = "interpreter"
 copyright = "2024, Hector Kohler"
 author = "Hector Kohler"
 
-release = "0.2.1"
+release = "0.3.0"
 
 # -- General configuration ---------------------------------------------------
 # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
diff --git a/docs/usage.md b/docs/usage.md
index 1c27afc..2c4404d 100644
--- a/docs/usage.md
+++ b/docs/usage.md
@@ -1,6 +1,6 @@
 ## Installation
 
 ```bash
-pip install git+https://github.com/KohlerHECTOR/interpreter-py.git@v0.2.1
+pip install git+https://github.com/KohlerHECTOR/interpreter-py.git@v0.3.0
 ```
 
@@ -36,11 +36,11 @@ print(evaluate_policy(oracle, Monitor(env))[0])
 clf = DecisionTreeRegressor(
     max_leaf_nodes=32
 )  # Change to DecisionTreeClassifier for discrete Actions.
-tree_policy = ObliqueDTPolicy(clf, env)  #
+learner = ObliqueDTPolicy(clf, env)  #
 # You can replace by DTPolicy(clf, env) for interpretable axis-parallel DTs.
 
 # Start the imitation learning
-interpret = Interpreter(oracle, tree_policy, env)
+interpret = Interpreter(oracle, learner, env)
 interpret.fit(10)
 
 # Eval and save the best tree
diff --git a/examples/half_cheetah.py b/examples/half_cheetah.py
index 4a7965b..aed8866 100644
--- a/examples/half_cheetah.py
+++ b/examples/half_cheetah.py
@@ -28,11 +28,11 @@
 clf = DecisionTreeRegressor(
     max_leaf_nodes=32
 )  # Change to DecisionTreeClassifier for discrete Actions.
-tree_policy = ObliqueDTPolicy(clf, env)  #
+learner = ObliqueDTPolicy(clf, env)  #
 # You can replace by DTPolicy(clf, env) for interpretable axis-parallel DTs.
 
 # Start the imitation learning
-interpret = Interpreter(oracle, tree_policy, env)
+interpret = Interpreter(oracle, learner, env)
 interpret.fit(10)
 
 # Eval and save the best tree
diff --git a/interpreter/__init__.py b/interpreter/__init__.py
index b82aaf9..1dd7cae 100644
--- a/interpreter/__init__.py
+++ b/interpreter/__init__.py
@@ -1,2 +1,2 @@
-from .policies import ObliqueDTPolicy, SB3Policy, DTPolicy
+from .policies import ObliqueDTPolicy, SB3Policy, DTPolicy, SymbPolicy
 from .interpreter import Interpreter
diff --git a/interpreter/interpreter.py b/interpreter/interpreter.py
index 6c69503..735c930 100644
--- a/interpreter/interpreter.py
+++ b/interpreter/interpreter.py
@@ -1,8 +1,9 @@
+from .policies import DTPolicy, SB3Policy, ObliqueDTPolicy, SymbPolicy
+
 from stable_baselines3.common.evaluation import evaluate_policy
 from stable_baselines3.common.utils import check_for_correct_spaces
 from stable_baselines3.common.monitor import Monitor
-from .policies import DTPolicy, SB3Policy, ObliqueDTPolicy
 
 from rlberry.agents import AgentWithSimplePolicy
 from gymnasium.spaces import Discrete, Box
 
@@ -24,8 +25,8 @@ class Interpreter(AgentWithSimplePolicy):
     oracle : object
         The oracle model that generates the data for training.
         Usually a stable-baselines3 model from the hugging face hub.
-    tree_policy : object
-        The decision tree policy to be trained.
+    learner : object
+        The decision tree policy or symbolic equation to be trained.
     env : object
         The environment in which the policies are evaluated (gym.Env).
     data_per_iter : int, optional
@@ -36,7 +37,7 @@
     ----------
     oracle : object
         The oracle model that generates the data for training.
-    tree_policy : object
+    learner : object
         The decision tree policy to be trained.
     data_per_iter : int
         The number of data points to generate per iteration.
@@ -48,31 +49,31 @@
         A list to store the rewards of the trained tree policies over iterations.
     """
 
-    def __init__(self, oracle, tree_policy, env, data_per_iter=5000, **kwargs):
+    def __init__(self, oracle, learner, env, data_per_iter=5000, **kwargs):
         assert isinstance(oracle, SB3Policy) and (
-            isinstance(tree_policy, DTPolicy)
-            or isinstance(tree_policy, ObliqueDTPolicy)
+            isinstance(learner, DTPolicy)
+            or isinstance(learner, ObliqueDTPolicy)
+            or isinstance(learner, SymbPolicy)
         )
         AgentWithSimplePolicy.__init__(self, env, **kwargs)
         if not isinstance(self.eval_env, Monitor):
             self.eval_env = Monitor(self.eval_env)
         self._oracle = oracle
-        self._tree_policy = tree_policy
-        self._policy = deepcopy(tree_policy)
+        self._learner = learner
+        self._policy = deepcopy(learner)
         self._data_per_iter = data_per_iter
 
         check_for_correct_spaces(
             self.env,
-            self._tree_policy.observation_space,
-            self._tree_policy.action_space,
+            self._learner.observation_space,
+            self._learner.action_space,
         )
         check_for_correct_spaces(
             self.env, self._oracle.observation_space, self._oracle.action_space
         )
         check_for_correct_spaces(
             self.eval_env,
-            self._tree_policy.observation_space,
-            self._tree_policy.action_space,
+            self._learner.observation_space,
+            self._learner.action_space,
         )
         check_for_correct_spaces(
             self.eval_env, self._oracle.observation_space, self._oracle.action_space
@@ -90,17 +91,17 @@ def fit(self, nb_timesteps):
         print("Fitting tree nb {} ...".format(0))
         nb_iter = int(max(1, nb_timesteps // self._data_per_iter))
         S, A = self.generate_data(self._oracle, self._data_per_iter)
-        self._tree_policy.fit_tree(S, A)
-        self._policy = deepcopy(self._tree_policy)
-        tree_reward, _ = evaluate_policy(self._tree_policy, self.eval_env)
+        self._learner.fit(S, A)
+        self._policy = deepcopy(self._learner)
+        tree_reward, _ = evaluate_policy(self._learner, self.eval_env)
         current_max_reward = tree_reward
-        # self.tree_policies = [deepcopy(self._tree_policy)]
+        # self.tree_policies = [deepcopy(self._learner)]
         # self.tree_policies_rewards = [tree_reward]
 
         for t in range(1, nb_iter + 1):
             print("Fitting tree nb {} ...".format(t + 1))
             S_tree, _ = self.generate_data(
-                self._tree_policy, int((t / nb_iter) * self._data_per_iter)
+                self._learner, int((t / nb_iter) * self._data_per_iter)
             )
             S_oracle, A_oracle = self.generate_data(
                 self._oracle, int((1 - t / nb_iter) * self._data_per_iter)
@@ -109,14 +110,14 @@ def fit(self, nb_timesteps):
 
             S = np.concatenate((S, S_tree, S_oracle))
             A = np.concatenate((A, self._oracle.predict(S_tree)[0], A_oracle))
-            self._tree_policy.fit_tree(S, A)
-            tree_reward, _ = evaluate_policy(self._tree_policy, self.eval_env)
+            self._learner.fit(S, A)
+            tree_reward, _ = evaluate_policy(self._learner, self.eval_env)
             if tree_reward > current_max_reward:
                 current_max_reward = tree_reward
-                self._policy = deepcopy(self._tree_policy)
+                self._policy = deepcopy(self._learner)
                 print("New best tree reward: {}".format(tree_reward))
-            # self.tree_policies += [deepcopy(self._tree_policy)]
+            # self.tree_policies += [deepcopy(self._learner)]
             # self.tree_policies_rewards += [tree_reward]
 
     def policy(self, obs):
diff --git a/interpreter/policies.py b/interpreter/policies.py
index 9baa076..c2c5ccd 100644
--- a/interpreter/policies.py
+++ b/interpreter/policies.py
@@ -1,10 +1,9 @@
+from pysr import PySRRegressor
 import gymnasium as gym
 from abc import ABC, abstractmethod
 import numpy as np
-from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
 from sklearn.base import RegressorMixin, ClassifierMixin
 from stable_baselines3.common.utils import is_vectorized_box_observation
-from tqdm import tqdm
 
 
 class Policy(ABC):
@@ -55,7 +54,64 @@ def predict(self, obs, state=None, deterministic=True, episode_start=0):
         """
         raise NotImplementedError
 
+
+class SymbPolicy(Policy):
+    def __init__(self, model, env):
+        assert isinstance(model, PySRRegressor)
+        assert isinstance(env.action_space, gym.spaces.Box), "Symbolic regression only works for continuous actions"
+        self.model = model
+        self.model.temp_equation_file = True
+        super().__init__(env.observation_space, env.action_space)
+
+        S = [self.observation_space.sample() for _ in range(10)]
+        A = [self.action_space.sample() for _ in range(10)]
+        self.model.fit(S, A)
+        self.model.warm_start = True
+        self.model.batching = True
+
+    def predict(self, obs, state=None, deterministic=True, episode_start=0):
+        """
+        Predict the action to take given an observation.
+
+        Parameters
+        ----------
+        obs : np.ndarray
+            The observation input.
+        state : object, optional
+            The state of the policy (default is None).
+        deterministic : bool, optional
+            Whether to use a deterministic policy (default is True).
+        episode_start : int, optional
+            The episode start index (default is 0).
+
+        Returns
+        -------
+        action : np.ndarray
+            The action to take.
+        state : object
+            The updated state of the policy.
+        """
+        if not is_vectorized_box_observation(obs, self.observation_space):
+            if isinstance(self.action_space, gym.spaces.Discrete):
+                action = self.model.predict(obs.reshape(1, -1)).squeeze().astype(int)
+            else:
+                if self.action_space.shape[0] > 1:
+                    action = self.model.predict(obs.reshape(1, -1)).squeeze()
+                else:
+                    action = self.model.predict(obs.reshape(1, -1))
+            return action, state
+        else:
+            if isinstance(self.action_space, gym.spaces.Discrete):
+                return self.model.predict(obs).astype(int), None
+            else:
+                if self.action_space.shape[0] > 1:
+                    return self.model.predict(obs), None
+                else:
+                    return self.model.predict(obs)[:, np.newaxis], None
+
+    def fit(self, X, y):
+        return self.model.fit(X, y)
+
+
 class SB3Policy(Policy):
     def __init__(self, base_policy):
         self.base_policy = base_policy
@@ -142,7 +198,7 @@ def predict(self, obs, state=None, deterministic=True, episode_start=0):
             else:
                 return self.clf.predict(obs)[:, np.newaxis], None
 
-    def fit_tree(self, S, A):
+    def fit(self, S, A):
         """
         Fit the decision tree with the provided observations and actions.
 
@@ -269,7 +325,7 @@ def predict(self, obs, state=None, deterministic=True, episode_start=0):
                 None,
             )
 
-    def fit_tree(self, S, A):
+    def fit(self, S, A):
         """
         Fit the decision tree with the provided oblique observations and actions.
 
diff --git a/setup.py b/setup.py
index baed55a..e1526c2 100644
--- a/setup.py
+++ b/setup.py
@@ -1,6 +1,6 @@
 from setuptools import setup, find_packages
 
-__version__ = "0.2.1"
+__version__ = "0.3.0"
 
 packages = find_packages(
     exclude=[
@@ -32,5 +32,6 @@
         "huggingface-sb3",
         "tqdm",
         "gym",
+        "pysr"
     ],
 )
diff --git a/tests/long_test_half_cheetah.py b/tests/long_test_half_cheetah.py
index a7bd675..e1c10ad 100644
--- a/tests/long_test_half_cheetah.py
+++ b/tests/long_test_half_cheetah.py
@@ -30,11 +30,11 @@ def long_test():
     clf = DecisionTreeRegressor(
         max_leaf_nodes=32
     )  # Change to DecisionTreeClassifier for discrete Actions.
-    tree_policy = ObliqueDTPolicy(clf, env)  #
+    learner = ObliqueDTPolicy(clf, env)  #
     # You can replace by DTPolicy(clf, env) for interpretable axis-parallel DTs.
 
     # Start the imitation learning
-    interpret = Interpreter(oracle, tree_policy, env)
+    interpret = Interpreter(oracle, learner, env)
     interpret.fit(3)
 
     # Eval and save the best tree
diff --git a/tests/test_policies_interpreter.py b/tests/test_policies_interpreter.py
index 695d77e..8b27424 100644
--- a/tests/test_policies_interpreter.py
+++ b/tests/test_policies_interpreter.py
@@ -40,7 +40,6 @@ def test_dt_policy_ctnuous_actions():
     s, _ = env.reset()
     policy.predict(s)
 
-
 def test_dt_policy_wrong_clf():
     env = gym.make("Acrobot-v1")
     clf = DecisionTreeRegressor(max_leaf_nodes=8)
@@ -49,7 +48,6 @@ def test_dt_policy_wrong_clf():
     except AssertionError:
         pass
 
-
 def test_dt_policy_ctnuous_actions_wrong_clf():
     env = gym.make("Pendulum-v1")
     clf = DecisionTreeClassifier(max_leaf_nodes=8)
@@ -80,8 +78,8 @@ def test_interpreter():
     model = PPO("MlpPolicy", env)
     oracle = SB3Policy(model.policy)
     clf = DecisionTreeClassifier(max_leaf_nodes=8)
-    tree_policy = DTPolicy(clf, env)
-    interpret = Interpreter(oracle, tree_policy, env)
+    learner = DTPolicy(clf, env)
+    interpret = Interpreter(oracle, learner, env)
     interpret.fit(5)
 
 
@@ -90,8 +88,8 @@ def test_interpreter_oblique():
     model = PPO("MlpPolicy", env)
     oracle = SB3Policy(model.policy)
     clf = DecisionTreeClassifier(max_leaf_nodes=8)
-    tree_policy = ObliqueDTPolicy(clf, env)
-    interpret = Interpreter(oracle, tree_policy, env)
+    learner = ObliqueDTPolicy(clf, env)
+    interpret = Interpreter(oracle, learner, env)
     interpret.fit(5)
 
 
@@ -100,8 +98,8 @@ def test_interpreter_ctnuous_actions():
     model = PPO("MlpPolicy", env)
     oracle = SB3Policy(model.policy)
     clf = DecisionTreeRegressor(max_leaf_nodes=8)
-    tree_policy = DTPolicy(clf, env)
-    interpret = Interpreter(oracle, tree_policy, env)
+    learner = DTPolicy(clf, env)
+    interpret = Interpreter(oracle, learner, env)
    interpret.fit(3)
 
 
@@ -110,8 +108,8 @@ def test_interpreter_oblique_ctnuous_actions():
     model = PPO("MlpPolicy", env)
     oracle = SB3Policy(model.policy)
     clf = DecisionTreeRegressor(max_leaf_nodes=8)
-    tree_policy = ObliqueDTPolicy(clf, env)
-    interpret = Interpreter(oracle, tree_policy, env)
+    learner = ObliqueDTPolicy(clf, env)
+    interpret = Interpreter(oracle, learner, env)
     interpret.fit(3)
     interpret.policy(env.reset()[0])
 
@@ -121,19 +119,21 @@ def test_interpreter_oblique_ctnuous_actions_high_dim():
     model = PPO("MlpPolicy", env)
     oracle = SB3Policy(model.policy)
     clf = DecisionTreeRegressor(max_leaf_nodes=8)
-    tree_policy = ObliqueDTPolicy(clf, env)
-    interpret = Interpreter(oracle, tree_policy, env)
+    learner = ObliqueDTPolicy(clf, env)
+    interpret = Interpreter(oracle, learner, env)
     interpret.fit(3)
     interpret.policy(env.reset()[0])
 
+
+
 def test_interpreter_ctnuous_actions_high_dim():
     env = gym.make("Ant-v4")
     model = PPO("MlpPolicy", env)
     oracle = SB3Policy(model.policy)
     clf = DecisionTreeRegressor(max_leaf_nodes=8)
-    tree_policy = DTPolicy(clf, env)
-    interpret = Interpreter(oracle, tree_policy, env)
+    learner = DTPolicy(clf, env)
+    interpret = Interpreter(oracle, learner, env)
     interpret.fit(3)
     interpret.policy(env.reset()[0])
 
@@ -143,13 +143,13 @@ def test_interpreter_rlberry():
     model = PPO("MlpPolicy", env)
     oracle = SB3Policy(model.policy)
     clf = DecisionTreeRegressor(max_leaf_nodes=8)
-    tree_policy = DTPolicy(clf, env)
+    learner = DTPolicy(clf, env)
 
     exp = ExperimentManager(
         agent_class=Interpreter,
         train_env=(gym_make, {"id": "Ant-v4"}),
         fit_budget=1e4,
-        init_kwargs=dict(oracle=oracle, tree_policy=tree_policy),
+        init_kwargs=dict(oracle=oracle, learner=learner),
         n_fit=2,
         seed=42,
     )
diff --git a/tests/test_symb_policies.py b/tests/test_symb_policies.py
new file mode 100644
index 0000000..10c1090
--- /dev/null
+++ b/tests/test_symb_policies.py
@@ -0,0 +1,60 @@
+from pysr import PySRRegressor
+
+from interpreter import SB3Policy, Interpreter, SymbPolicy
+import gymnasium as gym
+from stable_baselines3 import PPO
+from rlberry.manager import (
+    ExperimentManager,
+    evaluate_agents,
+)
+from rlberry.envs import gym_make
+
+
+def test_symb_policy_ctnuous_actions():
+    env = gym.make("Pendulum-v1")
+    model = PySRRegressor(binary_operators=["+", "-"])
+    policy = SymbPolicy(model, env)
+    s, _ = env.reset()
+    policy.predict(s)
+
+def test_symb_policy_discrete_actions():
+    env = gym.make("Acrobot-v1")
+    model = PySRRegressor(binary_operators=["+", "-"])
+    try:
+        policy = SymbPolicy(model, env)
+    except AssertionError:
+        pass
+
+def test_interpreter_symb_ctnuous_actions_high_dim():
+    env = gym.make("Swimmer-v4")
+    model = PPO("MlpPolicy", env)
+    oracle = SB3Policy(model.policy)
+    model = PySRRegressor(binary_operators=["+", "-"])
+    learner = SymbPolicy(model, env)
+    interpret = Interpreter(oracle, learner, env)
+    interpret.fit(5e3)
+    interpret.policy(env.reset()[0])
+
+
+def test_interpreter_rlberry():
+    env = gym.make("Swimmer-v4")
+    model = PPO("MlpPolicy", env)
+    oracle = SB3Policy(model.policy)
+    model = PySRRegressor(binary_operators=["+", "-"], temp_equation_file=True)
+    learner = SymbPolicy(model, env)
+
+    exp = ExperimentManager(
+        agent_class=Interpreter,
+        train_env=(gym_make, {"id": "Swimmer-v4"}),
+        fit_budget=1e4,
+        init_kwargs=dict(oracle=oracle, learner=learner),
+        n_fit=1,
+        seed=42,
+    )
+    exp.fit()
+
+    _ = evaluate_agents(
+        [exp], n_simulations=50, show=False
+    )  # Evaluate the trained agent on
+
+
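For reference, a minimal end-to-end sketch of the new `SymbPolicy` learner, mirroring `tests/test_symb_policies.py` from this diff; the environment, the untrained PPO oracle, and the PySR operator set are illustrative choices, not fixed requirements.

```python
# Minimal sketch of the new SymbPolicy learner (mirrors tests/test_symb_policies.py).
# The environment, the untrained PPO oracle, and the PySR operators are illustrative;
# any SB3 policy and any continuous-action Gymnasium env should work.
import gymnasium as gym
from pysr import PySRRegressor
from stable_baselines3 import PPO

from interpreter import Interpreter, SB3Policy, SymbPolicy

env = gym.make("Swimmer-v4")

# Oracle: a stable-baselines3 policy wrapped for imitation.
oracle = SB3Policy(PPO("MlpPolicy", env).policy)

# Learner: a symbolic regressor; SymbPolicy asserts the action space is continuous.
model = PySRRegressor(binary_operators=["+", "-"], temp_equation_file=True)
learner = SymbPolicy(model, env)

# Imitation learning: learners now expose fit(S, A) instead of fit_tree(S, A).
interpret = Interpreter(oracle, learner, env)
interpret.fit(5e3)

# Query the best policy found so far.
action = interpret.policy(env.reset()[0])
```

As in the tests, the same `Interpreter` loop works unchanged with `DTPolicy` and `ObliqueDTPolicy`, since all learners now share the `fit(S, A)` interface.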