From ee039cfde8a1c21d7a00646f0beca6ace79e53cc Mon Sep 17 00:00:00 2001 From: arnab39 Date: Wed, 5 Jun 2024 03:15:39 -0400 Subject: [PATCH 1/2] added automated model-driven prior discovery --- equiadapt/common/basecanonicalization.py | 35 +++- .../images/canonicalization/discrete_group.py | 171 +++++++++++------- .../configs/experiment/default.yaml | 1 + examples/images/classification/model.py | 14 +- 4 files changed, 149 insertions(+), 72 deletions(-) diff --git a/equiadapt/common/basecanonicalization.py b/equiadapt/common/basecanonicalization.py index 6f6dfc8..a82ce1a 100644 --- a/equiadapt/common/basecanonicalization.py +++ b/equiadapt/common/basecanonicalization.py @@ -21,6 +21,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union import torch +import torch.nn.functional as F # Base skeleton for the canonicalization class # DiscreteGroupCanonicalization and ContinuousGroupCanonicalization will inherit from this class @@ -287,18 +288,41 @@ def invert_canonicalization( """ raise NotImplementedError() - def get_prior_regularization_loss(self) -> torch.Tensor: + # def get_prior_regularization_loss(self) -> torch.Tensor: + # """ + # Gets the prior regularization loss. + + # Returns: + # torch.Tensor: The prior regularization loss. + # """ + # group_activations = self.canonicalization_info_dict["group_activations"] + # dataset_prior = torch.zeros((group_activations.shape[0],), dtype=torch.long).to( + # self.device + # ) + # return torch.nn.CrossEntropyLoss()(group_activations, dataset_prior) + + def get_prior_regularization_loss(self, dataset_prior: Optional[torch.Tensor] = None) -> torch.Tensor: """ Gets the prior regularization loss. + Args: + dataset_prior (torch.Tensor, optional): The prior distribution. Defaults to identity prior. + Returns: torch.Tensor: The prior regularization loss. """ group_activations = self.canonicalization_info_dict["group_activations"] - dataset_prior = torch.zeros((group_activations.shape[0],), dtype=torch.long).to( - self.device - ) - return torch.nn.CrossEntropyLoss()(group_activations, dataset_prior) + + # If dataset_prior is not provided, create an identity prior + if dataset_prior is None: + dataset_prior = torch.zeros_like(group_activations).to(self.device) + dataset_prior[:, 0] = 1.0 # Set the first column to 1, rest to 0 + + # Ensure group_activations are in log space + log_group_activations = F.log_softmax(group_activations, dim=1) + + # KL Divergence + return F.kl_div(log_group_activations, dataset_prior, reduction='batchmean') def get_identity_metric(self) -> torch.Tensor: """ @@ -406,6 +430,7 @@ def get_prior_regularization_loss(self) -> torch.Tensor: .to(self.device) ) return torch.nn.MSELoss()(group_elements_rep, dataset_prior) + def get_identity_metric(self) -> torch.Tensor: """ diff --git a/equiadapt/images/canonicalization/discrete_group.py b/equiadapt/images/canonicalization/discrete_group.py index 0ec66c5..061c3a5 100644 --- a/equiadapt/images/canonicalization/discrete_group.py +++ b/equiadapt/images/canonicalization/discrete_group.py @@ -90,6 +90,64 @@ def __init__( if is_grayscale else transforms.Resize(size=canonicalization_hyperparams.resize_shape) ) + + # group augment specific cropping and padding (required for group_augment()) + group_augment_in_shape = canonicalization_hyperparams.resize_shape + self.crop_group_augment = ( + torch.nn.Identity() + if in_shape[0] == 1 + else transforms.CenterCrop(group_augment_in_shape) + ) + self.pad_group_augment = ( + torch.nn.Identity() + if in_shape[0] == 1 + else transforms.Pad( + math.ceil(group_augment_in_shape * 0.5), padding_mode="edge" + ) + ) + + def rotate_and_maybe_reflect( + self, x: torch.Tensor, degrees: torch.Tensor, reflect: bool = False + ) -> List[torch.Tensor]: + """ + Rotate and maybe reflect the input images. + + Args: + x (torch.Tensor): The input image. + degrees (torch.Tensor): The degrees of rotation. + reflect (bool, optional): Whether to reflect the image. Defaults to False. + + Returns: + List[torch.Tensor]: The list of rotated and maybe reflected images. + """ + x_augmented_list = [] + for degree in degrees: + x_rot = self.pad_group_augment(x) + x_rot = K.geometry.rotate(x_rot, -degree) + if reflect: + x_rot = K.geometry.hflip(x_rot) + x_rot = self.crop_group_augment(x_rot) + x_augmented_list.append(x_rot) + return x_augmented_list + + def group_augment(self, x: torch.Tensor) -> torch.Tensor: + """ + Augment the input images by applying group transformations (rotations and reflections). + This function is used both for the energy based optimization method for the discrete rotation + + Args: + x (torch.Tensor): The input image. + + Returns: + torch.Tensor: The augmented image. + """ + degrees = torch.linspace(0, 360, self.num_rotations + 1)[:-1].to(self.device) + x_augmented_list = self.rotate_and_maybe_reflect(x, degrees) + + if self.group_type == "roto-reflection": + x_augmented_list += self.rotate_and_maybe_reflect(x, degrees, reflect=True) + + return torch.cat(x_augmented_list, dim=0) def groupactivations_to_groupelement(self, group_activations: torch.Tensor) -> dict: """ @@ -257,6 +315,54 @@ def invert_canonicalization( group_element_dict=self.canonicalization_info_dict["group_element"], # type: ignore induced_rep_type=induced_rep_type, ) + + def get_prior( + self, + x: torch.Tensor, + model: torch.nn.Module, + targets: torch.Tensor, + metric_function: torch.nn.Module, + tau: float = 1.0, + ) -> torch.Tensor: + """ + Get the prior for the input images. + + Args: + x (torch.Tensor): The input images. shape = (batch_size, in_channels, height, width) + model (torch.nn.Module): The prediction model which decides the prior. + targets (torch.Tensor): The targets for the task. shape = eg. (batch_size, num_classes) + metric_function (torch.nn.Module): The function to calculate the unnormalized probability masses for each group element. + tau (float, optional): The temperature parameter. Defaults to 1.0. Decides the sharpness of the prior distribution. + + Returns: + torch.Tensor: output prior of the model and x. shape = (batch_size, group_size) + """ + with torch.no_grad(): + batch_size = x.shape[0] + x_augmented = self.group_augment(x) # size (group_size * batch_size, in_channels, height, width) + # If a self.group_augment_target is defined, apply the same transformation to the targets + # Or else just repeat the targets for each group element in the first dimension + if hasattr(self, "group_augment_target"): + targets_augmented = self.group_augment_target(targets) + else: + targets_augmented = targets.repeat(self.num_group, 1).flatten() # size (group_size * batch_size) + + # Get the output of the model for the augmented images + model_output = model(x_augmented) # size eg (group_size * batch_size, num_classes) + + # Get the unnormalized probability masses for each group element + unnormalized_prob_masses = metric_function( + model_output, targets_augmented + ).reshape(self.num_group, batch_size).transpose(0, 1) # size (batch_size, group_size) + + # Get the prior for the input images + prior = F.softmax(unnormalized_prob_masses / tau, dim=-1) # size (batch_size, group_size) + + return prior + + + + class GroupEquivariantImageCanonicalization(DiscreteGroupImageCanonicalization): @@ -360,21 +466,6 @@ def __init__( ) self.out_vector_size = canonicalization_network.out_vector_size - # group optimization specific cropping and padding (required for group_augment()) - group_augment_in_shape = canonicalization_hyperparams.resize_shape - self.crop_group_augment = ( - torch.nn.Identity() - if in_shape[0] == 1 - else transforms.CenterCrop(group_augment_in_shape) - ) - self.pad_group_augment = ( - torch.nn.Identity() - if in_shape[0] == 1 - else transforms.Pad( - math.ceil(group_augment_in_shape * 0.5), padding_mode="edge" - ) - ) - self.reference_vector = torch.nn.Parameter( torch.randn(1, self.out_vector_size), requires_grad=canonicalization_hyperparams.learn_ref_vec, @@ -384,48 +475,6 @@ def __init__( "num_group": self.num_group, } - def rotate_and_maybe_reflect( - self, x: torch.Tensor, degrees: torch.Tensor, reflect: bool = False - ) -> List[torch.Tensor]: - """ - Rotate and maybe reflect the input images. - - Args: - x (torch.Tensor): The input image. - degrees (torch.Tensor): The degrees of rotation. - reflect (bool, optional): Whether to reflect the image. Defaults to False. - - Returns: - List[torch.Tensor]: The list of rotated and maybe reflected images. - """ - x_augmented_list = [] - for degree in degrees: - x_rot = self.pad_group_augment(x) - x_rot = K.geometry.rotate(x_rot, -degree) - if reflect: - x_rot = K.geometry.hflip(x_rot) - x_rot = self.crop_group_augment(x_rot) - x_augmented_list.append(x_rot) - return x_augmented_list - - def group_augment(self, x: torch.Tensor) -> torch.Tensor: - """ - Augment the input images by applying group transformations (rotations and reflections). - - Args: - x (torch.Tensor): The input image. - - Returns: - torch.Tensor: The augmented image. - """ - degrees = torch.linspace(0, 360, self.num_rotations + 1)[:-1].to(self.device) - x_augmented_list = self.rotate_and_maybe_reflect(x, degrees) - - if self.group_type == "roto-reflection": - x_augmented_list += self.rotate_and_maybe_reflect(x, degrees, reflect=True) - - return torch.cat(x_augmented_list, dim=0) - def get_group_activations(self, x: torch.Tensor) -> torch.Tensor: """ Gets the group activations for the input image. @@ -437,12 +486,8 @@ def get_group_activations(self, x: torch.Tensor) -> torch.Tensor: torch.Tensor: The group activations. """ x = self.transformations_before_canonicalization_network_forward(x) - x_augmented = self.group_augment( - x - ) # size (batch_size * group_size, in_channels, height, width) - vector_out = self.canonicalization_network( - x_augmented - ) # size (batch_size * group_size, reference_vector_size) + x_augmented = self.group_augment(x) # size (batch_size * group_size, in_channels, height, width) + vector_out = self.canonicalization_network(x_augmented) # size (batch_size * group_size, reference_vector_size) self.canonicalization_info_dict = {"vector_out": vector_out} if self.artifact_err_wt: diff --git a/examples/images/classification/configs/experiment/default.yaml b/examples/images/classification/configs/experiment/default.yaml index 67ef281..097e405 100644 --- a/examples/images/classification/configs/experiment/default.yaml +++ b/examples/images/classification/configs/experiment/default.yaml @@ -13,6 +13,7 @@ training: loss: task_weight: 1.0 # Weight of the task loss in the total loss prior_weight: 100.0 # Weight of the prior in the loss function if zero dont use it + automated_prior: true # Whether to use automated prior weight or not group_contrast_weight: 0 # Weight of the group contrastive loss (set to 0 for group_equivariant, 0.0001 for opt_equivariant) inference: method: group # Type of inference options 1) vanilla 2) group diff --git a/examples/images/classification/model.py b/examples/images/classification/model.py index 6fd9a49..f5391c3 100644 --- a/examples/images/classification/model.py +++ b/examples/images/classification/model.py @@ -1,9 +1,11 @@ +from math import tau import pytorch_lightning as pl import torch from inference_utils import get_inference_method from model_utils import get_dataset_specific_info, get_prediction_network from omegaconf import DictConfig from torch.optim.lr_scheduler import MultiStepLR +from torch.nn import functional as F from examples.images.common.utils import get_canonicalization_network, get_canonicalizer @@ -13,9 +15,6 @@ class ImageClassifierPipeline(pl.LightningModule): def __init__(self, hyperparams: DictConfig): super().__init__() - self.loss, self.image_shape, self.num_classes = get_dataset_specific_info( - hyperparams.dataset.dataset_name - ) self.loss, self.image_shape, self.num_classes = get_dataset_specific_info( hyperparams.dataset.dataset_name ) @@ -104,7 +103,14 @@ def training_step(self, batch: torch.Tensor): # Add prior regularization loss if the prior weight is non-zero if self.hyperparams.experiment.training.loss.prior_weight: - prior_loss = self.canonicalizer.get_prior_regularization_loss() + if self.hyperparams.experiment.training.loss.automated_prior: + def metric_function(model_predictions, targets): + return -F.cross_entropy(model_predictions, targets, reduction='none') + prior = self.canonicalizer.get_prior(x, self.prediction_network, y, metric_function, tau=0.1) + prior_loss = self.canonicalizer.get_prior_regularization_loss(prior) # type: ignore + else: + prior_loss = self.canonicalizer.get_prior_regularization_loss() + loss += prior_loss * self.hyperparams.experiment.training.loss.prior_weight metric_identity = self.canonicalizer.get_identity_metric() training_metrics.update( From 13178a61d32d03ae0dbb07df0980a62a18024899 Mon Sep 17 00:00:00 2001 From: arnab39 Date: Wed, 31 Jul 2024 16:35:29 -0400 Subject: [PATCH 2/2] added RL --- examples/images/classification/model.py | 2 +- .../images/reinforcementlearning/__init__.py | 0 .../reinforcementlearning/configs/__init__.py | 0 .../canonicalization/group_equivariant.yaml | 11 + .../configs/canonicalization/identity.yaml | 1 + .../configs/default.yaml | 23 ++ .../configs/experiment/default.yaml | 12 + .../configs/prediction/default.yaml | 0 .../images/reinforcementlearning/network.py | 67 +++++ .../reinforcementlearning/prepare/__init__.py | 0 .../prepare/gym_cartpole.py | 94 +++++++ .../images/reinforcementlearning/train.py | 237 ++++++++++++++++++ .../images/reinforcementlearning/utils.py | 66 +++++ 13 files changed, 512 insertions(+), 1 deletion(-) create mode 100644 examples/images/reinforcementlearning/__init__.py create mode 100644 examples/images/reinforcementlearning/configs/__init__.py create mode 100644 examples/images/reinforcementlearning/configs/canonicalization/group_equivariant.yaml create mode 100644 examples/images/reinforcementlearning/configs/canonicalization/identity.yaml create mode 100644 examples/images/reinforcementlearning/configs/default.yaml create mode 100644 examples/images/reinforcementlearning/configs/experiment/default.yaml create mode 100644 examples/images/reinforcementlearning/configs/prediction/default.yaml create mode 100644 examples/images/reinforcementlearning/network.py create mode 100644 examples/images/reinforcementlearning/prepare/__init__.py create mode 100644 examples/images/reinforcementlearning/prepare/gym_cartpole.py create mode 100644 examples/images/reinforcementlearning/train.py create mode 100644 examples/images/reinforcementlearning/utils.py diff --git a/examples/images/classification/model.py b/examples/images/classification/model.py index f5391c3..d758091 100644 --- a/examples/images/classification/model.py +++ b/examples/images/classification/model.py @@ -106,7 +106,7 @@ def training_step(self, batch: torch.Tensor): if self.hyperparams.experiment.training.loss.automated_prior: def metric_function(model_predictions, targets): return -F.cross_entropy(model_predictions, targets, reduction='none') - prior = self.canonicalizer.get_prior(x, self.prediction_network, y, metric_function, tau=0.1) + prior = self.canonicalizer.get_prior(x, self.prediction_network, y, metric_function, tau=0.01) prior_loss = self.canonicalizer.get_prior_regularization_loss(prior) # type: ignore else: prior_loss = self.canonicalizer.get_prior_regularization_loss() diff --git a/examples/images/reinforcementlearning/__init__.py b/examples/images/reinforcementlearning/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/images/reinforcementlearning/configs/__init__.py b/examples/images/reinforcementlearning/configs/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/images/reinforcementlearning/configs/canonicalization/group_equivariant.yaml b/examples/images/reinforcementlearning/configs/canonicalization/group_equivariant.yaml new file mode 100644 index 0000000..eb2b11d --- /dev/null +++ b/examples/images/reinforcementlearning/configs/canonicalization/group_equivariant.yaml @@ -0,0 +1,11 @@ +canonicalization_type: group_equivariant +network_type: e2cnn # Options for canonization method 1) e2cnn 2) custom 3) none +network_hyperparams: + kernel_size: 5 # Kernel size for the canonization network + out_channels: 32 # Number of output channels for the canonization network + num_layers: 3 # Number of layers in the canonization network + group_type: rotation # Type of group for the canonization network + num_rotations: 4 # Number of rotations for the canonization network +beta: 1.0 # Beta parameter for the canonization network +input_crop_ratio: 0.8 # Ratio at which we crop the input to the canonicalization +resize_shape: 64 # Resize shape for the input \ No newline at end of file diff --git a/examples/images/reinforcementlearning/configs/canonicalization/identity.yaml b/examples/images/reinforcementlearning/configs/canonicalization/identity.yaml new file mode 100644 index 0000000..1598d17 --- /dev/null +++ b/examples/images/reinforcementlearning/configs/canonicalization/identity.yaml @@ -0,0 +1 @@ +canonicalization_type: identity \ No newline at end of file diff --git a/examples/images/reinforcementlearning/configs/default.yaml b/examples/images/reinforcementlearning/configs/default.yaml new file mode 100644 index 0000000..95b8d43 --- /dev/null +++ b/examples/images/reinforcementlearning/configs/default.yaml @@ -0,0 +1,23 @@ +canonicalization_type: group_equivariant # will be set later in training script +device: cuda + +# metadata specialised for each experiment +core: + version: 0.0.1 + tags: + - ${now:%Y-%m-%d} + +hydra: + run: + dir: ${oc.env:HYDRA_JOBS}/singlerun/${now:%Y-%m-%d}/ + + sweep: + dir: ${oc.env:HYDRA_JOBS}/multirun/${now:%Y-%m-%d}/ + subdir: ${hydra.job.num}_${hydra.job.id} + +defaults: + - _self_ + - env: default + - experiment: default + - canonicalization: identity + - wandb: default \ No newline at end of file diff --git a/examples/images/reinforcementlearning/configs/experiment/default.yaml b/examples/images/reinforcementlearning/configs/experiment/default.yaml new file mode 100644 index 0000000..909e87b --- /dev/null +++ b/examples/images/reinforcementlearning/configs/experiment/default.yaml @@ -0,0 +1,12 @@ +learning_rate: 0.001 +batch_size: 128 +gamma: 0.999 +eps_start: 0.9 +eps_end: 0.01 +eps_decay: 3000 +target_update: 10 +replay_memory_size: 100000 +end_score: 200 +training_stop: 142 +num_episodes: 50000 +last_episodes_num: 20 \ No newline at end of file diff --git a/examples/images/reinforcementlearning/configs/prediction/default.yaml b/examples/images/reinforcementlearning/configs/prediction/default.yaml new file mode 100644 index 0000000..e69de29 diff --git a/examples/images/reinforcementlearning/network.py b/examples/images/reinforcementlearning/network.py new file mode 100644 index 0000000..6c5729c --- /dev/null +++ b/examples/images/reinforcementlearning/network.py @@ -0,0 +1,67 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import random + +class DQN(nn.Module): + def __init__(self, input_shape, num_actions, dueling_DQN=False): + super(DQN, self).__init__() + + self.input_shape = input_shape + self.num_actions = num_actions + self.dueling_DQN = dueling_DQN + + self.features = nn.Sequential( + nn.Conv2d(input_shape[0], 32, kernel_size=5, stride=2), + nn.BatchNorm2d(32), + nn.ReLU(), + nn.Conv2d(32, 64, kernel_size=5, stride=2), + nn.BatchNorm2d(64), + nn.ReLU(), + nn.Conv2d(64, 64, kernel_size=5, stride=2), + nn.BatchNorm2d(64), + nn.ReLU() + ) + + feature_size = self._get_feature_size() + + if self.dueling_DQN: + self.advantage = nn.Sequential( + nn.Linear(feature_size, 512), + nn.BatchNorm1d(512), + nn.ReLU(), + nn.Linear(512, self.num_actions) + ) + self.value = nn.Sequential( + nn.Linear(feature_size, 512), + nn.BatchNorm1d(512), + nn.ReLU(), + nn.Linear(512, 1) + ) + else: + self.action_value = nn.Sequential( + nn.Linear(feature_size, 512), + nn.BatchNorm1d(512), + nn.ReLU(), + nn.Linear(512, self.num_actions) + ) + + def forward(self, x): + x = x.float() / 255 # Normalize the input + x = self.features(x) + x = x.view(x.size(0), -1) + + if self.dueling_DQN: + advantage = self.advantage(x) + value = self.value(x) + q_values = value + (advantage - advantage.mean(dim=1, keepdim=True)) + else: + q_values = self.action_value(x) + + return q_values + + def _get_feature_size(self): + self.features.eval() + with torch.no_grad(): + return self.features(torch.zeros(1, *self.input_shape)).view(1, -1).size(1) + diff --git a/examples/images/reinforcementlearning/prepare/__init__.py b/examples/images/reinforcementlearning/prepare/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/images/reinforcementlearning/prepare/gym_cartpole.py b/examples/images/reinforcementlearning/prepare/gym_cartpole.py new file mode 100644 index 0000000..cf68696 --- /dev/null +++ b/examples/images/reinforcementlearning/prepare/gym_cartpole.py @@ -0,0 +1,94 @@ +import numpy as np +import torch +import torchvision.transforms as T +from PIL import Image +import gym + +class CartpoleWrapper(gym.Wrapper): + def __init__(self, env, env_hyperparams): + """Initialize the wrapper for the CartPole environment to preprocess images. + + Args: + env (gym.Env): The Gym environment to wrap. + env_hyperparams (dict): Dictionary containing settings for image preprocessing, + such as resize dimensions and whether to apply grayscale. + """ + super().__init__(env) + self.env = env + self.num_actions = env.action_space.n + self.state_shape = env.observation_space.shape + + # Base transformations that are always applied + transformations = [ + T.ToPILImage(), + T.Resize(env_hyperparams["resize_pixels"], + interpolation=Image.BICUBIC), + ] + + # Conditional grayscale transformation + if env_hyperparams["grayscale"]: + transformations.append(T.Grayscale()) + + # Final transformation to tensor + transformations.append(T.ToTensor()) + + # Compose all transformations into a single callable object + self.resize = T.Compose(transformations) + + def get_cart_location(self, screen_width): + """Calculate the cart's location on the screen for cropping. + + Args: + screen_width (int): The width of the screen from the environment. + + Returns: + int: The pixel location of the center of the cart. + """ + world_width = self.env.x_threshold * 2 + scale = screen_width / world_width + return int(self.env.state[0] * scale + screen_width / 2.0) # Middle of the cart + + def get_screen(self): + """Capture, process, and crop the environment's screen. + + Transforms the screen into a format suitable for input to a neural network: + crops, downsamples, converts to grayscale, and rescales. + + Returns: + torch.Tensor: The processed screen tensor ready for model input. + """ + # Capture screen from the environment + screen = self.env.render().transpose((2, 0, 1)) # CHW format + _, screen_height, screen_width = screen.shape + + # Crop the vertical dimension to focus on the main area of interest + screen = screen[:, int(screen_height * 0.4):int(screen_height * 0.8)] + + # Define the width of the cropped area around the cart + view_width = int(screen_width * 0.6) + cart_location = self.get_cart_location(screen_width) + + # Calculate the horizontal slice range to center crop around the cart + if cart_location < view_width // 2: + slice_range = slice(view_width) + elif cart_location > (screen_width - view_width // 2): + slice_range = slice(-view_width, None) + else: + slice_range = slice(cart_location - view_width // 2, cart_location + view_width // 2) + + # Apply the calculated slice to crop horizontally + screen = screen[:, :, slice_range] + + # Normalize, convert to tensor, resize, and add a batch dimension + screen = np.ascontiguousarray(screen, dtype=np.float32) / 255. + screen = torch.from_numpy(screen) + return self.resize(screen).unsqueeze(0) + + def step(self, action): + """Apply an action to the environment, returning the processed screen, reward, done, and info.""" + return self.env.step(action) + + def reset(self): + """Reset the environment and return the initial processed screen.""" + self.env.reset() + diff --git a/examples/images/reinforcementlearning/train.py b/examples/images/reinforcementlearning/train.py new file mode 100644 index 0000000..4958c30 --- /dev/null +++ b/examples/images/reinforcementlearning/train.py @@ -0,0 +1,237 @@ +from re import A +from tracemalloc import stop +import gym +import os +import hydra +import omegaconf +import wandb +import random +import math + +from itertools import count +import torch +import torch.optim as optim +import torch.nn.functional as F +from omegaconf import DictConfig, OmegaConf +from zmq import device + +from prepare.gym_cartpole import CartpoleWrapper +from collections import deque + +from typing import List, Tuple +from omegaconf import DictConfig +from network import DQN +from utils import ReplayMemory, Transition, load_envs + +from tqdm import tqdm + + +# Setup the environment using a wrapper +def setup_environment(env_hyperparams): + if env_hyperparams["name"] == "cartpole": + env = gym.make('CartPole-v1', render_mode='rgb_array') + env = CartpoleWrapper(env, env_hyperparams) + return env + +# Action selection , if stop training == True, only exploitation +def select_action(dqn, state, steps_done, exp_hyperparams, stop_training): + dqn.eval() + sample = random.random() + eps_threshold = exp_hyperparams["eps_end"] + (exp_hyperparams["eps_start"]- exp_hyperparams["eps_end"]) * \ + math.exp(-1. * steps_done / exp_hyperparams["eps_decay"]) + # print('Epsilon = ', eps_threshold, end='\n') + if sample > eps_threshold or stop_training: + with torch.no_grad(): + # t.max(1) will return largest column value of each row. + # second column on max result is index of where max element was + # found, so we pick action with the larger expected reward. + q_values = dqn(state) + action = q_values.max(1)[1].item() + else: + action = random.randrange(dqn.num_actions) + dqn.train() + return action + +def optimize_model( + memory: ReplayMemory, + dqn: DQN, + optimizer: optim.Optimizer, + target_dqn: DQN, + batch_size: int, + gamma: float) -> None: + """ + Optimize the DQN model using the given memory replay buffer. + + Args: + memory (ReplayMemory): The replay memory buffer. + dqn (DQN): The DQN model. + optimizer (optim.Optimizer): The optimizer for updating the model parameters. + target_dqn (DQN): The target DQN model. + batch_size (int): The batch size for training. + gamma (float): The discount factor for future rewards. + """ + if len(memory) < batch_size: + return + transitions = memory.sample(batch_size) + batch = Transition(*zip(*transitions)) + + state_batch = torch.cat(batch.state) + action_batch = torch.cat(batch.action).unsqueeze(1) + + device = state_batch.device + + non_final_mask = torch.tensor( + tuple(map(lambda s: s is not None, batch.next_state)), + device=device, dtype=torch.bool + ) + non_final_next_states = torch.cat( + [s for s in batch.next_state if s is not None] + ) + + reward_batch = torch.cat(batch.reward).type(torch.FloatTensor).to(device) + + state_action_values = dqn(state_batch).gather(1, action_batch) + + next_state_values = torch.zeros(batch_size, device=action_batch.device) + next_state_values[non_final_mask] = target_dqn(non_final_next_states).max(1)[0].detach() + expected_state_action_values = (next_state_values * gamma) + reward_batch + + loss = F.smooth_l1_loss(state_action_values, expected_state_action_values.unsqueeze(1)) + wandb.log({'Loss:': loss.item()}) + + optimizer.zero_grad() + loss.backward() + for param in dqn.parameters(): + param.grad.data.clamp_(-1, 1) + optimizer.step() + + + +def train_rl(hyperparams: DictConfig) -> None: + # set system environment variables for wandb + if hyperparams["wandb"]["use_wandb"]: + print("Using wandb for logging...") + os.environ["WANDB_MODE"] = "online" + else: + print("Wandb disabled for logging...") + os.environ["WANDB_MODE"] = "disabled" + os.environ["WANDB_DIR"] = hyperparams["wandb"]["wandb_dir"] + os.environ["WANDB_CACHE_DIR"] = hyperparams["wandb"]["wandb_cache_dir"] + + # initialize wandb + wandb.init( + config=OmegaConf.to_container(hyperparams, resolve=True), + entity=hyperparams["wandb"]["wandb_entity"], + project=hyperparams["wandb"]["wandb_project"], + dir=hyperparams["wandb"]["wandb_dir"], + ) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + env = setup_environment(hyperparams["env"]) + env.reset() + init_screen = env.get_screen().to(device) + + _, inchannels, screen_height, screen_width = init_screen.shape + FRAMES = hyperparams["env"]["frames"] + input_shape = (FRAMES * inchannels, screen_height, screen_width) + print("Screen height: ", screen_height," | Width: ", screen_width) + + # Get number of actions from gym action space + num_actions = env.num_actions + + exp_hyperparams = hyperparams["experiment"] + + dqn = DQN(input_shape, num_actions).to(device) + target_dqn = DQN(input_shape, num_actions).to(device) + target_dqn.load_state_dict(dqn.state_dict()) + target_dqn.eval() + + optimizer = optim.RMSprop(dqn.parameters()) + # optimizer = optim.Adam(dqn.parameters(), lr=exp_hyperparams["learning_rate"]) + memory = ReplayMemory(exp_hyperparams["replay_memory_size"]) + mean_last = deque([0] * exp_hyperparams['last_episodes_num'], exp_hyperparams['last_episodes_num']) + + stop_training = False + count_final = 0 + steps_done = 0 + episode_durations = [] + num_episodes = exp_hyperparams["num_episodes"] + + # Wrap your range function with tqdm for a progress bar + for i_episode in tqdm(range(num_episodes), desc="Training Episodes"): + # for i_episode in range(exp_hyperparams["num_episodes"]): + # Initialize the environment and state + env.reset() + init_screen = env.get_screen().to(device) + screens = deque([init_screen] * FRAMES, FRAMES) + state = torch.cat(list(screens), dim=1) + + for t in count(): + + # Select and perform an action + action = select_action(dqn, state, steps_done, exp_hyperparams, stop_training) + + state_variables, _, done, _, _ = env.step(action) + steps_done += 1 + + # Observe new state + screens.append(env.get_screen().to(device)) + next_state = torch.cat(list(screens), dim=1) if not done else None + + # Reward modification for better stability + x, x_dot, theta, theta_dot = state_variables + r1 = (env.x_threshold - abs(x)) / env.x_threshold - 0.8 + r2 = (env.theta_threshold_radians - abs(theta)) / env.theta_threshold_radians - 0.5 + reward = r1 + r2 + reward = torch.tensor([reward], device=device) + if t >= exp_hyperparams["end_score"] - 1: + reward = reward + 20 + done = 1 + else: + if done: + reward = reward - 20 + + # Store the transition in memory + action = torch.tensor([action], device=device) + memory.push(state, action, next_state, reward) + + # Move to the next state + state = next_state + + # Perform one step of the optimization (on the target network) + if done: + episode_durations.append(t + 1) + mean_last.append(t + 1) + mean = 0 + wandb.log({'Episode duration': t+1 , 'Episode number': i_episode}) + for i in range(exp_hyperparams['last_episodes_num']): + mean = mean_last[i] + mean + mean = mean/exp_hyperparams['last_episodes_num'] + if mean < exp_hyperparams['training_stop'] and stop_training == False: + optimize_model(memory, dqn, optimizer, target_dqn, + exp_hyperparams["batch_size"], + exp_hyperparams["gamma"]) + else: + stop_training = True + break + + # Update the target network, copying all weights and biases in DQN + if i_episode % exp_hyperparams['target_update'] == 0: + target_dqn.load_state_dict(dqn.state_dict()) + if stop_training == True: + count_final += 1 + if count_final >= 100: + break + + print('Training Complete') + env.close() + +# load the variables from .env file +load_envs() + +@hydra.main(config_path=str("./configs/"), config_name="default") +def main(cfg: omegaconf.DictConfig) -> None: + train_rl(cfg) + +if __name__ == '__main__': + main() diff --git a/examples/images/reinforcementlearning/utils.py b/examples/images/reinforcementlearning/utils.py new file mode 100644 index 0000000..d7c1335 --- /dev/null +++ b/examples/images/reinforcementlearning/utils.py @@ -0,0 +1,66 @@ +import random +from collections import namedtuple +import dotenv +from typing import Optional + +# Define Transition as a namedtuple for better structure and readability +Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward')) + +class ReplayMemory: + def __init__(self, capacity): + """Initialize the ReplayMemory with a fixed capacity. + + Args: + capacity (int): The maximum size of the memory. + """ + self.capacity = capacity + self.memory = [] + self.position = 0 + + def push(self, state, action, next_state, reward): + """Saves a transition into memory. + + Overwrites the oldest transition if memory is at capacity. + Args: + state: The state of the environment before taking the action. + action: The action taken. + next_state: The state of the environment after taking the action. + reward: The reward received after taking the action. + """ + # Create a Transition from the given arguments + transition = Transition(state, action, next_state, reward) + + # Check if there is still room to append a new transition + if len(self.memory) < self.capacity: + self.memory.append(None) + # Overwrite the oldest data if the memory is full + self.memory[self.position] = transition + # Move the write position; wraps around to the beginning using modulo + self.position = (self.position + 1) % self.capacity + + def sample(self, batch_size): + """Samples a batch of transitions from memory. + + Args: + batch_size (int): Number of transitions to sample. + + Returns: + list: A list of randomly sampled transitions. + """ + return random.sample(self.memory, batch_size) + + def __len__(self): + """Return the current size of internal memory.""" + return len(self.memory) + +def load_envs(env_file: Optional[str] = None) -> None: + """ + Load all the environment variables defined in the `env_file`. + This is equivalent to `. env_file` in bash. + + It is possible to define all the system specific variables in the `env_file`. + + :param env_file: the file that defines the environment variables to use. If None + it searches for a `.env` file in the project. + """ + dotenv.load_dotenv(dotenv_path=env_file, override=True) \ No newline at end of file