Skip to content

Commit

Permalink
OA: RL PID Controller (SISO) #1050
Browse files Browse the repository at this point in the history
  • Loading branch information
amesin13 committed Oct 1, 2024
1 parent 04e7995 commit f2c508f
Showing 1 changed file with 104 additions and 25 deletions.
129 changes: 104 additions & 25 deletions src/mlpro/oa/control/controllers/oa_pid_controller.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from mlpro.bf.control.controllers.pid_controller import PIDController
from mlpro.bf.ml.basics import *
from mlpro.bf.streams import InstDict
from mlpro.rl import Policy, FctReward
from mlpro.oa.control.basics import OAController
from mlpro.bf.math.basics import Log,Set,MSpace
#from mlpro.oa.control.basics import OAController
from mlpro.bf.math.basics import Log,Set,MSpace,Dimension
from mlpro.bf.mt import Log, Task
from mlpro.bf.systems import Action
from mlpro.bf.control.basics import CTRLError, ControlError, Controller, SetPoint
from mlpro.bf.control.basics import ControlError, Controller, SetPoint
from mlpro.bf.systems.basics import ActionElement, State
from mlpro.bf.various import Log
from mlpro.bf.streams import InstDict, Instance
Expand All @@ -18,55 +19,133 @@

class RLPID(Policy):

def __init__(self, p_observation_space: MSpace, p_action_space: MSpace,pid_controller:PIDController ,p_id=None, p_buffer_size: int = 1, p_ada: bool = True, p_visualize: bool = False, p_logging=Log.C_LOG_ALL ):
def __init__(self, p_observation_space: MSpace, p_action_space: MSpace,pid_controller:PIDController ,policy:Policy=None,p_id=None, p_buffer_size: int = 1, p_ada: bool = True, p_visualize: bool = False, p_logging=Log.C_LOG_ALL ):
super().__init__(p_observation_space, p_action_space, p_id, p_buffer_size, p_ada, p_visualize, p_logging)

self._pid_controller = pid_controller
self._policy = policy
self._old_action = None #None
self._action_space = p_action_space
"""
policy_sb3 = WrPolicySB32MLPro(
PPO(policy="MlpPolicy",n_steps=5,env=None,_init_setup_model=False,device="cpu")
,p_cycle_limit=30,p_observation_space= p_observation_space
,p_action_space=p_action_space,p_ada=p_ada)
"""

## -------------------------------------------------------------------------------------------------
def _init_hyperparam(self, **p_par):

# 1 Create a dispatcher hyperparameter tuple for the RLPID policy
self._hyperparam_tuple = HyperParamDispatcher(p_set=self._hyperparam_space)

# 2 Extend RLPID policy's hp space and tuple from policy
try:
self._hyperparam_space.append( self._policy.get_hyperparam().get_related_set(), p_new_dim_ids=False)
self._hyperparam_tuple.add_hp_tuple(self._policy.get_hyperparam())
except:
pass

## -------------------------------------------------------------------------------------------------
def get_hyperparam(self) -> HyperParamTuple:
return self._policy.get_hyperparam()

## -------------------------------------------------------------------------------------------------
def _update_hyperparameters(self) -> bool:
return self._policy._update_hyperparameters()

## -------------------------------------------------------------------------------------------------

def _adapt(self, p_sars_elem: SARSElement) -> bool:

"""
if self._old_action is None:
#create a pid action
self._old_action = Action(p_action_space=self._action_space,p_values=self._pid_controller.get_parameter_values())

#get SARS Elements
p_state,p_action,p_reward,p_state_new=tuple(p_sars_elem.get_data().keys())

# create a new SARS
p_sars_elem_new = SARSElement(p_state=p_state,
p_action=self._old_action,
p_reward=p_reward,
p_state_new=p_state_new)

#adapt own policy
is_adapted = self._policy.adapt(p_kwargs=p_sars_elem_new)

policy_sb3 = PPO(
policy="MlpPolicy",
n_steps=5,
env=None,
_init_setup_model=False,
device="cpu")
# compute new action with new error value (second s of Sars element)
self._old_action=self._policy.compute_action(p_obs=p_state_new)

sb3_policy =WrPolicySB32MLPro()
sb3_policy._adapt_on_policy(p_sars_elem)
sb3_policy._compute_action_on_policy()
#get the pid paramter values
pid_values = self._old_action.get_feature_data().get_values()

#set paramter pid
self._pid_controller.set_parameter(p_param={"Kp":pid_values[0],
"Ti":pid_values[1],
"Tv":pid_values[2]})
return is_adapted

## -------------------------------------------------------------------------------------------------

def compute_action(self, p_obs: State) -> Action:

#get action
action=self._pid_controller.compute_action(p_ctrl_error=p_obs)

#return action
return action

class RLPIDOffPolicy(Policy):

def __init__(self, p_observation_space: MSpace, p_action_space: MSpace,pid_controller:PIDController ,p_id=None, p_buffer_size: int = 1, p_ada: bool = True, p_visualize: bool = False, p_logging=Log.C_LOG_ALL ):
super().__init__(p_observation_space, p_action_space, p_id, p_buffer_size, p_ada, p_visualize, p_logging)

self._pid_controller = pid_controller
self._action_space = p_action_space.get_dim(p_id=0).get


def _init_hyperparam(self, **p_par):

# create hp
# 1- add dim (Kp,Tn,Tv) in hp space
# 2- create hp tuple from hp space
# 3- set hp tuple values


# 1
self._hyperparam_space.add_dim( self._action_space.get_dim(p_id=0))
self._hyperparam_space.add_dim(self._action_space.get_dim(p_id=1))
self._hyperparam_space.add_dim(self._action_space.get_dim(p_id=2))

p_param={}
self._pid_controller.set_parameter(p_param)
# 2
self._hyperparam_tuple = HyperParamTuple( p_set=self._hyperparam_space )

"""
pass

#3
self._hyperparam_tuple.set_values(self._pid_controller.get_parameter_values())



## -------------------------------------------------------------------------------------------------


def _adapt(self, p_sars_elem: SARSElement) -> bool:
return False

## -------------------------------------------------------------------------------------------------

def compute_action(self, p_obs: State) -> Action:

#create control error from p_obs
crtl_error = ControlError(p_obs.get_feature_data(),p_obs.get_label_data(),p_obs.get_tstamp())
def compute_action(self, p_obs: State) -> Action:

#get action
action=self._pid_controller.compute_action(crtl_error)
action=self._pid_controller.compute_action(p_ctrl_error=p_obs)

#return action
return action









Expand Down

0 comments on commit f2c508f

Please sign in to comment.