diff --git a/og_marl/vault_utils/analyse_vault.py b/og_marl/vault_utils/analyse_vault.py
index 03be0e9f..7c78eef1 100644
--- a/og_marl/vault_utils/analyse_vault.py
+++ b/og_marl/vault_utils/analyse_vault.py
@@ -153,6 +153,14 @@ def describe_episode_returns(
     plot_saving_rel_dir: str = "vaults",
     n_bins: Optional[int] = 50,
 ) -> None:
+    """Describe a vault by its episode returns.
+
+    From the specified directory and for the specified uids, summarises each
+    dataset by its episode returns. The summary includes a table of the episode
+    return mean, standard deviation, min and max. Additionally, the distributions
+    of episode returns are visualised in histograms and violin plots. n_bins sets
+    the number of bins in the histogram.
+    """
     # get all uids if not specified
     if vault_uids is None:
         vault_uids = get_available_uids(f"./{rel_dir}/{vault_name}")
diff --git a/og_marl/vault_utils/combine_vaults.py b/og_marl/vault_utils/combine_vaults.py
index 420558cf..81690aa4 100644
--- a/og_marl/vault_utils/combine_vaults.py
+++ b/og_marl/vault_utils/combine_vaults.py
@@ -30,6 +30,7 @@ def get_all_vaults(
     vault_uids: Optional[list[str]] = None,
     rel_dir: str = "vaults",
 ) -> list[Vault]:
+    """Gets a list of Vaults from the specified directory, one Vault per uid."""
     if vault_uids is None:
         vault_uids = get_available_uids(f"./{rel_dir}/{vault_name}")
 
@@ -40,6 +41,7 @@
 
 
 def stitch_vault_from_many(vlts: list[Vault], vault_name: str, vault_uid: str, rel_dir: str) -> int:
+    """Given a list of vaults, writes all experience from each of them into one new vault."""
     all_data = vlts[0].read()
     offline_data = all_data.experience
 
@@ -84,6 +86,11 @@ def stitch_vault_from_many(vlts: list[Vault], vault_name: str, vault_uid: str, r
 
 
 def combine_vaults(rel_dir: str, vault_name: str, vault_uids: Optional[list[str]] = None) -> str:
+    """Combines the datasets in a vault.
+
+    Takes multiple datasets in a vault and combines them into one new vault
+    with "_combined" appended to the vault name.
+    """
     # check that the vault to be combined exists
     if not check_directory_exists_and_not_empty(f"./{rel_dir}/{vault_name}"):
         print(f"Vault './{rel_dir}/{vault_name}' does not exist and cannot be combined.")
diff --git a/og_marl/vault_utils/download_vault.py b/og_marl/vault_utils/download_vault.py
index 5d8aa2b6..011e7358 100644
--- a/og_marl/vault_utils/download_vault.py
+++ b/og_marl/vault_utils/download_vault.py
@@ -140,6 +140,7 @@
 
 
 def print_download_options() -> Dict[str, Dict]:
+    """Prints and returns all options for downloading vaults from the OG-MARL HuggingFace."""
     print("VAULT_INFO:")
     for source in VAULT_INFO.keys():
         print(f"\t {source}")
@@ -158,6 +159,11 @@ def download_and_unzip_vault(
     dataset_base_dir: str = "./vaults",
     dataset_download_url: str = "",
 ) -> str:
+    """Downloads and unzips a vault.
+
+    The vault is stored at dataset_base_dir/dataset_source/env_name/scenario_name.vlt/.
+    If the vault already exists and is not empty, the download does not happen.
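+
+    Illustrative usage (a sketch; the source, env and scenario names below are
+    assumptions and must match an entry in VAULT_INFO):
+
+        download_and_unzip_vault(
+            dataset_source="og_marl",
+            env_name="smac_v1",
+            scenario_name="3m",
+        )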
+ """ # to prevent downloading the vault twice into the same folder if check_directory_exists_and_not_empty( f"{dataset_base_dir}/{dataset_source}/{env_name}/{scenario_name}.vlt" @@ -215,6 +221,7 @@ def download_and_unzip_vault( def check_directory_exists_and_not_empty(path: str) -> bool: + """Checks that the directory at path exists and is not empty.""" # Check if the directory exists if os.path.exists(path) and os.path.isdir(path): # Check if the directory is not empty @@ -227,6 +234,7 @@ def check_directory_exists_and_not_empty(path: str) -> bool: def get_available_uids(rel_vault_path: str) -> List[str]: + """Obtains the uids of datasets in a vault at the relative path.""" vault_uids = sorted( next(os.walk(rel_vault_path))[1], reverse=True, diff --git a/og_marl/vault_utils/subsample_bespoke.py b/og_marl/vault_utils/subsample_bespoke.py index ea07d470..ccaa7558 100644 --- a/og_marl/vault_utils/subsample_bespoke.py +++ b/og_marl/vault_utils/subsample_bespoke.py @@ -21,6 +21,7 @@ # given bin edges and a sorted array of values, get the bin number per value def get_bin_numbers(sorted_values: Array, bin_edges: Array) -> Array: + """Assigns known, sorted values to given bins.""" bin_numbers = np.zeros_like(sorted_values) def get_bin_number(bin_num: int, value: float) -> int: @@ -46,6 +47,12 @@ def get_bin_number(bin_num: int, value: float) -> int: def bin_processed_data( all_sorted_return_start_end: Array, n_bins: int = 500 ) -> Tuple[Array, Array, Array, Array, Array]: + """Bins rows in an array according to the values in a particular column. + + Returns bar_labels connected with bar_heights, + as well as padded_heights whose entries' indices are the bin number. + Bin edges and all numbers are also returned. + """ # get bin edges, including final endpoint bin_edges = jnp.linspace( start=min(min(all_sorted_return_start_end[:, 0]), 0), @@ -68,9 +75,16 @@ def bin_processed_data( return bar_labels, bar_heights, padded_heights.astype(int), bin_edges, bin_numbers -# sample from pdf according to heights -# BIG NOTE: CHECK THE DISPARITY, OTHERWISE YOUR DISTRIBUTION WILL BE TOO MUCH def episode_idxes_sampled_from_pdf(pdf: Array, bar_heights: Array) -> list[int]: + """Gets a list of episode indices according to how many episodes you want from each bin. + + Given an array of desired bar heights and the actual bar heights, produces the + indices of the episodes which should be sampled to produce the desired bar heights. + It is assumed that episodes which will be sampled are sorted ascending, and so, for example, + with bars of heights 3,5,7 if you want a resultant histogram of 0,2,0, you will need to + sample two indices randomly from {3,4,5,6,7}. + The function will then return the list of sampled indices. + """ num_to_sample = np.round(pdf).astype(int) sample_range_edges = np.concatenate([[0], np.cumsum(bar_heights)]) diff --git a/og_marl/vault_utils/subsample_similar.py b/og_marl/vault_utils/subsample_similar.py index 94209544..90508b2a 100644 --- a/og_marl/vault_utils/subsample_similar.py +++ b/og_marl/vault_utils/subsample_similar.py @@ -27,6 +27,11 @@ # cumulative summing per-episode def get_episode_returns_and_term_idxes(offline_data: Dict[str, Array]) -> Tuple[Array, Array]: + """Gets the episode returns and final indices from a batch of experience. + + From a batch of experience extract the indices + of the final transitions as well as the returns of each episode in order. 
+ """ rewards = offline_data["rewards"][0, :, 0] terminal_flag = offline_data["terminals"][0, :, ...].all(axis=-1) @@ -49,13 +54,17 @@ def scan_cumsum( return cumsums[term_idxes - 1], term_idxes -# first store indices of episodes, then sort by episode return. -# outputs return, start, end and vault index in vault list def sort_concat(returns: Array, eps_ends: Array) -> Array: + """An original-order-aware sorting and concatenating of episode information. + + From a list of episodes ends and returns which are in the order of the episodes in the + original experience batch, produces an array of rows of return, + start index and end index for each episode. + """ + # build start indexes from end indexes since they are in order episode_start_idxes = eps_ends[:-1] + 1 episode_start_idxes = jnp.insert(episode_start_idxes, 0, 0).reshape(-1, 1) sorting_idxes = jnp.lexsort(jnp.array([returns[:, 0]]), axis=-1) - # print(sorting_idxes) return_start_end = jnp.concatenate( [returns, episode_start_idxes.reshape(-1, 1), eps_ends], axis=-1 @@ -69,6 +78,15 @@ def sort_concat(returns: Array, eps_ends: Array) -> Array: def get_idxes_of_similar_subsets( base_returns: List, comp_returns: List, tol: float = 0.1 ) -> Tuple[List, List]: + """Gets indices of episodes s.t. the subsets of two datasets have similar return distributions. + + Iteratively selects episodes from either dataset with episode return within "tol" of each other. + Importantly, returns are SORTED BEFORE BEING PASSED TO THIS FUNCTION. + Returns the list of all such almost-matching episodes. + (For each episode Ea in dataset A, if there is an episode Eb in dataset B with a return within + "tol" of the return of E, select Ea and Eb to be sampled. + If not, move on from Ea.) + """ base_selected_idxes = [] comp_selected_idxes = [] @@ -98,6 +116,7 @@ def subsample_similar( new_rel_dir: str, new_vault_name: str, ) -> None: + """Subsamples 2 datasets s.t. the new datasets have similar episode return distributions.""" # check that a subsampled vault by the same name does not already exist if check_directory_exists_and_not_empty(f"./{new_rel_dir}/{new_vault_name}"): print( diff --git a/og_marl/vault_utils/subsample_smaller.py b/og_marl/vault_utils/subsample_smaller.py index 0f10ed8a..4566a917 100644 --- a/og_marl/vault_utils/subsample_smaller.py +++ b/og_marl/vault_utils/subsample_smaller.py @@ -31,6 +31,12 @@ def get_length_start_end(experience: Dict[str, Array], terminal_key: str = "terminals") -> Array: + """Process experience to get the length, start and end of all episodes. + + From a block of experience, extracts the length, start position and end position of each + episode. Length is stored for the convenience of a cumsum in the following function, and + to match other episode information blocks which store return, instead. + """ # extract terminals terminal_flag = experience[terminal_key][0, :, ...].all(axis=-1) @@ -41,7 +47,7 @@ def get_length_start_end(experience: Dict[str, Array], terminal_key: str = "term start_idxes = np.zeros_like(term_idxes) start_idxes[1:] = term_idxes[:-1] - # get the length per-episode (TODO maybe redundant) + # get the length per-episode lengths = term_idxes - start_idxes # concatenate for easier unpacking @@ -51,6 +57,13 @@ def get_length_start_end(experience: Dict[str, Array], terminal_key: str = "term def select_episodes_uniformly_up_to_n_transitions(len_start_end: Array, n: int) -> Array: + """Uniformly selects episodes from an episode info array. 
+    """
     # shuffle idxes of all the episodes from the vault
     shuffled_idxes = np.arange(len_start_end.shape[0])
     np.random.shuffle(shuffled_idxes)
 
@@ -69,7 +82,6 @@ def select_episodes_uniformly_up_to_n_transitions(len_start_end: Array, n: int)
     return randomly_sampled_len_start_end
 
 
-# given the indices of the required episodes, stitch a vault and save under a user-specified name
 def stitch_vault_from_sampled_episodes_(
     experience: Dict[str, Array],
     len_start_end_sample: Array,
@@ -78,6 +90,16 @@ def stitch_vault_from_sampled_episodes_(
     rel_dir: str,
     n: int = 500_000,
 ) -> int:
+    """Writes a vault given episode information and a batch of experience.
+
+    Takes in experience and an array with columns R, S, E (return, start, end) or
+    L, S, E (length, start, end), describing episode returns (or, redundantly,
+    lengths) and their positions in the block of experience.
+    For every row in len_start_end_sample (for every episode in the sample),
+    selects the relevant transitions from "experience" and adds them to a new buffer.
+    In the end, writes the buffer of all episodes to the destination vault.
+    """
     # to prevent downloading the vault twice into the same folder
     if check_directory_exists_and_not_empty(f"{rel_dir}/{dest_vault_name}/{vault_uid}/"):
         print(f"Vault '{rel_dir}/{dest_vault_name}.vlt/{vault_uid}' already exists.")
@@ -125,6 +147,15 @@ def subsample_smaller_vault(
     vault_uids: Optional[list] = None,
     target_number_of_transitions: int = 500000,
 ) -> str:
+    """Subsamples a vault down to a smaller number of transitions.
+
+    Subsamples every dataset in the list of uids (or, if unspecified, all uids in the vault)
+    by uniformly randomly selecting episodes from each dataset and storing them in a new vault.
+
+    Once the number of transitions in the episode selection is within one trajectory length
+    of the desired number of transitions, no further transitions are included.
+    This is to avoid creating partial trajectories.
+    """
     # check that the vault to be subsampled exists
     if not check_directory_exists_and_not_empty(f"./{vaults_dir}/{vault_name}"):
         print(f"Vault './{vaults_dir}/{vault_name}' does not exist and cannot be subsampled.")