From 719b0f3bf67a9e1209a885dc5344de8b5a62efff Mon Sep 17 00:00:00 2001 From: Mark Stephenson Date: Thu, 19 Dec 2024 15:35:49 -0800 Subject: [PATCH] Issue #66: Rewarder composition --- docs/source/release_notes.rst | 1 + src/bsk_rl/data/__init__.py | 13 ++ src/bsk_rl/data/base.py | 2 +- src/bsk_rl/data/composition.py | 227 ++++++++++++++++++++++++ src/bsk_rl/gym.py | 19 +- tests/integration/data/test_int_data.py | 52 ++++++ tests/unittest/data/test_composition.py | 156 ++++++++++++++++ tests/unittest/test_gym_env.py | 13 ++ 8 files changed, 477 insertions(+), 6 deletions(-) create mode 100644 src/bsk_rl/data/composition.py create mode 100644 tests/unittest/data/test_composition.py diff --git a/docs/source/release_notes.rst b/docs/source/release_notes.rst index 0b11ce97..9eca2b6c 100644 --- a/docs/source/release_notes.rst +++ b/docs/source/release_notes.rst @@ -26,6 +26,7 @@ Development Version * Improve performance of :class:`~bsk_rl.obs.Eclipse` observations by about 95%. * Logs a warning if the initial battery charge or buffer level is incompatible with its capacity. * Optimize communication when all satellites are communicating with each other. +* Allow for the specification of multiple rewarders in the environment. diff --git a/src/bsk_rl/data/__init__.py b/src/bsk_rl/data/__init__.py index b86d2ab5..5cdc32a7 100644 --- a/src/bsk_rl/data/__init__.py +++ b/src/bsk_rl/data/__init__.py @@ -74,6 +74,19 @@ ... ) +Multiple reward systems can be added to the environment by instead passing an iterable of +reward systems to the ``data`` field of the environment constructor: + +.. code-block:: python + + env = ConstellationTasking( + ..., + data=(ScanningTimeReward(), SomeOtherReward()), + ... + ) + +On the backend, this creates a :class:`~bsk_rl.data.composition.ComposedDataStore` that +handles the combination of multiple reward systems. """ from bsk_rl.data.base import GlobalReward diff --git a/src/bsk_rl/data/base.py b/src/bsk_rl/data/base.py index d9e9d609..d4986b0f 100644 --- a/src/bsk_rl/data/base.py +++ b/src/bsk_rl/data/base.py @@ -193,7 +193,7 @@ def reward(self, new_data_dict: dict[str, Data]) -> dict[str, float]: self.data += new_data nonzero_reward = {k: v for k, v in reward.items() if v != 0} - logger.info(f"Data reward: {nonzero_reward}") + logger.info(f"Total reward: {nonzero_reward}") return reward diff --git a/src/bsk_rl/data/composition.py b/src/bsk_rl/data/composition.py new file mode 100644 index 00000000..28fa78b1 --- /dev/null +++ b/src/bsk_rl/data/composition.py @@ -0,0 +1,227 @@ +"""Data composition classes.""" + +import logging +from typing import TYPE_CHECKING, Optional + +from bsk_rl.data.base import Data, DataStore, GlobalReward +from bsk_rl.sats import Satellite +from bsk_rl.scene.scenario import Scenario + +if TYPE_CHECKING: + from bsk_rl.sats import Satellite + +logger = logging.getLogger(__name__) + + +class ComposedData(Data): + """Data for composed data types.""" + + def __init__(self, *data: Data) -> None: + """Data for composed data types. + + Args: + data: Data types to compose. + """ + self.data = data + + def __add__(self, other: "ComposedData") -> "ComposedData": + """Combine two units of composed data. + + Args: + other: Another unit of composed data to combine with this one. + + Returns: + Combined unit of composed data. + """ + if len(self.data) == 0 and len(other.data) == 0: + data = [] + elif len(self.data) == 0: + data = [type(d)() + d for d in other.data] + elif len(other.data) == 0: + data = [d + type(d)() for d in self.data] + elif len(self.data) == len(other.data): + data = [d1 + d2 for d1, d2 in zip(self.data, other.data)] + else: + raise ValueError( + "ComposedData units must have the same number of data types." + ) + return ComposedData(*data) + + def __getattr__(self, name: str): + """Search for an attribute in the datas.""" + for data in self.data: + if hasattr(data, name): + return getattr(data, name) + raise AttributeError(f"No Data in ComposedData has attribute '{name}'") + + +class ComposedDataStore(DataStore): + data_type = ComposedData + + def pass_data(self) -> None: + """Pass data to the sub-datastores. + + :meta private: + """ + for ds, data in zip(self.datastores, self.data.data): + ds.data = data + + def __init__( + self, + satellite: "Satellite", + *datastore_types: type[DataStore], + initial_data: Optional[ComposedData] = None, + ): + """DataStore for composed data types. + + Args: + satellite: Satellite which data is being stored for. + datastore_types: DataStore types to compose. + initial_data: Initial data to start the store with. Usually comes from + :class:`~bsk_rl.data.GlobalReward.initial_data`. + """ + self.data: ComposedData + super().__init__(satellite, initial_data) + self.datastores = tuple([ds(satellite) for ds in datastore_types]) + self.pass_data() + + def __getattr__(self, name: str): + """Search for an attribute in the datastores.""" + for datastore in self.datastores: + if hasattr(datastore, name): + return getattr(datastore, name) + raise AttributeError( + f"No DataStore in ComposedDataStore has attribute '{name}'" + ) + + def get_log_state(self) -> list: + """Pull information used in determining current data contribution.""" + log_states = [ds.get_log_state() for ds in self.datastores] + return log_states + + def compare_log_states(self, prev_state: list, new_state: list) -> Data: + """Generate a unit of composed data based on previous step and current step logs.""" + data = [ + ds.compare_log_states(prev, new) + for ds, prev, new in zip(self.datastores, prev_state, new_state) + ] + return ComposedData(*data) + + def update_from_logs(self) -> Data: + """Update the data store based on collected information.""" + new_data = super().update_from_logs() + self.pass_data() + return new_data + + def update_with_communicated_data(self) -> None: + """Update the data store based on collected information from other satellites.""" + super().update_with_communicated_data() + self.pass_data() + + +class ComposedReward(GlobalReward): + datastore_type = ComposedDataStore + + def pass_data(self) -> Data: + """Pass data to the sub-rewarders. + + :meta private: + """ + for rewarder, data in zip(self.rewarders, self.data.data): + rewarder.data = data + + def __init__(self, *rewarders: GlobalReward) -> None: + """Rewarder for composed data types. + + This type can be automatically constructed by passing a tuple of rewarders to + the environment constructor's `reward` argument. + + Args: + rewarders: Global rewarders to compose. + """ + super().__init__() + self.rewarders = rewarders + + def __getattr__(self, name: str): + """Search for an attribute in the rewarders.""" + for rewarder in self.rewarders: + if hasattr(rewarder, name): + return getattr(rewarder, name) + raise AttributeError( + f"No GlobalReward in ComposedReward has attribute '{name}'" + ) + + def reset_pre_sim_init(self) -> None: + """Handle resetting for all rewarders.""" + super().reset_pre_sim_init() + for rewarder in self.rewarders: + rewarder.reset_pre_sim_init() + + def reset_post_sim_init(self) -> None: + """Handle resetting for all rewarders.""" + super().reset_post_sim_init() + for rewarder in self.rewarders: + rewarder.reset_post_sim_init() + + def reset_overwrite_previous(self) -> None: + """Handle resetting for all rewarders.""" + super().reset_overwrite_previous() + for rewarder in self.rewarders: + rewarder.reset_overwrite_previous() + + def link_scenario(self, scenario: Scenario) -> None: + """Link the rewarder to the scenario.""" + super().link_scenario(scenario) + for rewarder in self.rewarders: + rewarder.link_scenario(scenario) + + def initial_data(self, satellite: Satellite) -> ComposedData: + """Furnsish the datastore with :class:`ComposedData`.""" + return ComposedData( + *[rewarder.initial_data(satellite) for rewarder in self.rewarders] + ) + + def create_data_store(self, satellite: Satellite) -> None: + """Create a :class:`CompositeDataStore` for a satellite.""" + # TODO support passing kwargs + satellite.data_store = ComposedDataStore( + satellite, + *[r.datastore_type for r in self.rewarders], + initial_data=self.initial_data(satellite), + ) + self.cum_reward[satellite.name] = 0.0 + + def calculate_reward( + self, new_data_dict: dict[str, ComposedData] + ) -> dict[str, float]: + """Calculate reward for each data type and combine them.""" + data_len = len(list(new_data_dict.values())[0].data) + + for data in new_data_dict.values(): + assert len(data.data) == data_len + + reward = {} + if data_len != 0: + for i, rewarder in enumerate(self.rewarders): + reward_i = rewarder.calculate_reward( + {sat_id: data.data[i] for sat_id, data in new_data_dict.items()} + ) + + # Logging + nonzero_reward = {k: v for k, v in reward_i.items() if v != 0} + if len(nonzero_reward) > 0: + logger.info(f"{type(rewarder).__name__} reward: {nonzero_reward}") + + for sat_id, sat_reward in reward_i.items(): + reward[sat_id] = reward.get(sat_id, 0.0) + sat_reward + return reward + + def reward(self, new_data_dict: dict[str, ComposedData]) -> dict[str, float]: + """Return combined reward calculation and update data.""" + reward = super().reward(new_data_dict) + self.pass_data() + return reward + + +__doc_title__ = "Data Composition" +__all__ = ["ComposedReward", "ComposedDataStore", "ComposedData"] diff --git a/src/bsk_rl/gym.py b/src/bsk_rl/gym.py index a2e88814..146a294d 100644 --- a/src/bsk_rl/gym.py +++ b/src/bsk_rl/gym.py @@ -13,6 +13,7 @@ from bsk_rl.comm import CommunicationMethod, NoCommunication from bsk_rl.data import GlobalReward, NoReward +from bsk_rl.data.composition import ComposedReward from bsk_rl.sats import Satellite from bsk_rl.scene import Scenario from bsk_rl.sim import Simulator @@ -36,7 +37,7 @@ def __init__( self, satellites: Union[Satellite, list[Satellite]], scenario: Optional[Scenario] = None, - rewarder: Optional[GlobalReward] = None, + rewarder: Optional[Union[GlobalReward, list[GlobalReward]]] = None, world_type: Optional[type[WorldModel]] = None, world_args: Optional[dict[str, Any]] = None, communicator: Optional[CommunicationMethod] = None, @@ -68,7 +69,8 @@ def __init__( scenario: Environment the satellite is acting in; contains information about targets, etc. See :ref:`bsk_rl.scene`. rewarder: Handles recording and rewarding for data collection towards - objectives. See :ref:`bsk_rl.data`. + objectives. Can be a single rewarder or a tuple of multiple rewarders. + See :ref:`bsk_rl.data`. communicator: Manages communication between satellites. See :ref:`bsk_rl.comm`. sat_arg_randomizer: For correlated randomization of satellites arguments. Should be a function that takes a list of satellites and returns a dictionary that @@ -125,8 +127,6 @@ def __init__( if scenario is None: scenario = Scenario() - if rewarder is None: - rewarder = NoReward() if world_type is None: world_type = self._minimum_world_model() @@ -137,7 +137,16 @@ def __init__( self.scenario = deepcopy(scenario) self.scenario.link_satellites(self.satellites) - self.rewarder = deepcopy(rewarder) + + rewarder = deepcopy(rewarder) + if rewarder is None: + rewarder = NoReward() + if ( + isinstance(rewarder, Iterable) + and not type(rewarder).__name__ == "MagicMock" + ): + rewarder = ComposedReward(*rewarder) + self.rewarder = rewarder self.rewarder.link_scenario(self.scenario) if communicator is None: diff --git a/tests/integration/data/test_int_data.py b/tests/integration/data/test_int_data.py index f652899e..96f1d4d2 100644 --- a/tests/integration/data/test_int_data.py +++ b/tests/integration/data/test_int_data.py @@ -1,5 +1,57 @@ +import gymnasium as gym + +from bsk_rl import act, data, obs, sats, scene +from bsk_rl.data.composition import ComposedReward +from bsk_rl.utils.orbital import random_orbit + # For data models not tested in other tests # NoData sufficiently checked in many cases # UniqueImageData sufficiently checked in test_int_communication + +# from ..test_int_full_environments + + +class FullFeaturedSatellite(sats.ImagingSatellite): + observation_spec = [ + obs.SatProperties(dict(prop="r_BN_P", module="dynamics", norm=6e6)), + obs.Time(), + ] + action_spec = [act.Image(n_ahead_image=10)] + + +def test_multi_rewarder(): + env = gym.make( + "GeneralSatelliteTasking-v1", + satellites=[ + FullFeaturedSatellite( + "Sentinel-2A", + sat_args=FullFeaturedSatellite.default_sat_args( + oe=random_orbit, + imageAttErrorRequirement=0.01, + imageRateErrorRequirement=0.01, + ), + ), + FullFeaturedSatellite( + "Sentinel-2B", + sat_args=FullFeaturedSatellite.default_sat_args( + oe=random_orbit, + imageAttErrorRequirement=0.01, + imageRateErrorRequirement=0.01, + ), + ), + ], + scenario=scene.UniformTargets(n_targets=1000), + rewarder=(data.UniqueImageReward(), data.UniqueImageReward()), + sim_rate=0.5, + max_step_duration=1e9, + time_limit=5700.0, + disable_env_checker=True, + ) + + assert isinstance(env.unwrapped.rewarder, ComposedReward) + + env.reset() + for _ in range(10): + env.step(env.action_space.sample()) diff --git a/tests/unittest/data/test_composition.py b/tests/unittest/data/test_composition.py new file mode 100644 index 00000000..6413feee --- /dev/null +++ b/tests/unittest/data/test_composition.py @@ -0,0 +1,156 @@ +from unittest.mock import MagicMock + +import pytest + +from bsk_rl.data.composition import ComposedData, ComposedDataStore, ComposedReward + + +class TestComposedData: + @pytest.mark.parametrize( + "data1, data2, expected", + [ + ((1, 2), (3, 4), (4, 6)), + ((), (3, 4), (3, 4)), + ((1, 2), (), (1, 2)), + ((), (), ()), + ], + ) + def test_add(self, data1, data2, expected): + composed_data = ComposedData(*data1) + other_composed_data = ComposedData(*data2) + result = composed_data + other_composed_data + assert result.data == expected + + def test_add_different_lengths(self): + composed_data = ComposedData(1, 2) + other_composed_data = ComposedData(3) + with pytest.raises(ValueError): + composed_data + other_composed_data + + def test_getattr(self): + data1 = MagicMock() + data1.a = 1 + del data1.b + del data1.c + data2 = MagicMock() + data2.b = 2 + del data2.a + del data2.c + composed_data = ComposedData(data1, data2) + assert composed_data.a == 1 + assert composed_data.b == 2 + with pytest.raises(AttributeError): + _ = composed_data.c + + +class TestComposedDataStore: + def test_pass_data(self): + sat = MagicMock() + ds1 = MagicMock() + ds1_type = MagicMock(return_value=ds1) + ds2 = MagicMock() + ds2_type = MagicMock(return_value=ds2) + composed_data_store = ComposedDataStore(sat, ds1_type, ds2_type) + composed_data_store.data = ComposedData(1, 2) + composed_data_store.pass_data() + assert ds1.data == 1 + assert ds2.data == 2 + + def test_getattr(self): + sat = MagicMock() + ds1 = MagicMock() + ds1_type = MagicMock(return_value=ds1) + ds2 = MagicMock() + ds2_type = MagicMock(return_value=ds2) + composed_data_store = ComposedDataStore(sat, ds1_type, ds2_type) + ds1.a = 1 + assert composed_data_store.a == 1 + + def test_get_log_state(self): + sat = MagicMock() + ds1 = MagicMock(get_log_state=MagicMock(return_value=1)) + ds1_type = MagicMock(return_value=ds1) + ds2 = MagicMock(get_log_state=MagicMock(return_value=2)) + ds2_type = MagicMock(return_value=ds2) + composed_data_store = ComposedDataStore(sat, ds1_type, ds2_type) + + log_states = composed_data_store.get_log_state() + for ds in [ds1, ds2]: + ds.get_log_state.assert_called_once() + assert log_states == [1, 2] + + def test_compare_log_states(self): + sat = MagicMock() + ds1 = MagicMock(get_log_state=MagicMock(return_value=1)) + ds1_type = MagicMock(return_value=ds1) + ds2 = MagicMock(get_log_state=MagicMock(return_value=2)) + ds2_type = MagicMock(return_value=ds2) + composed_data_store = ComposedDataStore(sat, ds1_type, ds2_type) + + composed_data_store.compare_log_states([1, 2], [3, 4]) + ds1.compare_log_states.assert_called_once_with(1, 3) + ds2.compare_log_states.assert_called_once_with(2, 4) + + +class TestComposedReward: + def test_pass_data(self): + rewarder1 = MagicMock() + rewarder2 = MagicMock() + composed_rewarder = ComposedReward(rewarder1, rewarder2) + composed_rewarder.data = ComposedData(1, 2) + composed_rewarder.pass_data() + assert rewarder1.data == 1 + assert rewarder2.data == 2 + + @pytest.mark.parametrize( + "function", + [ + "reset_pre_sim_init", + "reset_post_sim_init", + "reset_overwrite_previous", + ], + ) + def test_resetable(self, function): + rewarder1 = MagicMock() + rewarder2 = MagicMock() + composed_rewarder = ComposedReward(rewarder1, rewarder2) + getattr(composed_rewarder, function)() + for rewarder in [rewarder1, rewarder2]: + getattr(rewarder, function).assert_called_once() + + def test_initial_data(self): + rewarder1 = MagicMock(initial_data=MagicMock(return_value=1)) + rewarder2 = MagicMock(initial_data=MagicMock(return_value=2)) + composed_rewarder = ComposedReward(rewarder1, rewarder2) + data = composed_rewarder.initial_data("sat") + assert data.data == (1, 2) + + def test_data_store(self): + sat = MagicMock() + ds1 = MagicMock(get_log_state=MagicMock(return_value=1)) + ds1_type = MagicMock(return_value=ds1) + rewarder1 = MagicMock(datastore_type=ds1_type) + ds2 = MagicMock(get_log_state=MagicMock(return_value=2)) + ds2_type = MagicMock(return_value=ds2) + rewarder2 = MagicMock(datastore_type=ds2_type) + composed_rewarder = ComposedReward(rewarder1, rewarder2) + composed_rewarder.create_data_store(sat) + assert sat.data_store.datastores == (ds1, ds2) + + def test_calculate_reward(self): + rewarder1 = MagicMock( + calculate_reward=MagicMock(return_value={"sat1": 1, "sat2": 2}) + ) + rewarder2 = MagicMock( + calculate_reward=MagicMock(return_value={"sat1": 3, "sat2": 4}) + ) + composed_rewarder = ComposedReward(rewarder1, rewarder2) + + reward = composed_rewarder.calculate_reward( + { + "sat1": MagicMock(data=["d11", "d21"]), + "sat2": MagicMock(data=["d12", "d22"]), + } + ) + + assert reward == {"sat1": 4, "sat2": 6} diff --git a/tests/unittest/test_gym_env.py b/tests/unittest/test_gym_env.py index 91744701..04584c02 100644 --- a/tests/unittest/test_gym_env.py +++ b/tests/unittest/test_gym_env.py @@ -4,6 +4,7 @@ from gymnasium import spaces from bsk_rl import ConstellationTasking, GeneralSatelliteTasking, SatelliteTasking +from bsk_rl.data.composition import ComposedReward from bsk_rl.sats import Satellite @@ -68,6 +69,18 @@ def test_minimum_world_model_mixed(self): assert issubclass(model, TypeA) assert issubclass(model, TypeB) + def test_multiple_rewarders(self): + mock_sat = MagicMock() + mock_sat.sat_args_generator = {} + mock_rewarder = [MagicMock(scenario=None), MagicMock(scenario=None)] + env = GeneralSatelliteTasking( + satellites=[mock_sat], + world_type=MagicMock(), + scenario=MagicMock(), + rewarder=mock_rewarder, + ) + assert isinstance(env.rewarder, ComposedReward) + @patch("bsk_rl.gym.Simulator") def test_reset(self, mock_sim): mock_sat = MagicMock()