Chore: pre-commit wrap-ups.
lbeyers committed Aug 22, 2024
1 parent 4a86475 commit ccf603f
Showing 6 changed files with 94 additions and 7 deletions.
8 changes: 8 additions & 0 deletions og_marl/vault_utils/analyse_vault.py
@@ -153,6 +153,14 @@ def describe_episode_returns(
plot_saving_rel_dir: str = "vaults",
n_bins: Optional[int] = 50,
) -> None:
"""Describe a vault.
From the specified directory and for the specified uids,
describes vaults according to their episode returns.
The descriptors include a table of episode return mean, standard deviation, min and max.
Additionally, the distributions of episode returns are visualised in histograms
and violin plots. n_bins sets the number of bins in the histograms.
"""
# get all uids if not specified
if vault_uids is None:
vault_uids = get_available_uids(f"./{rel_dir}/{vault_name}")
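For reference, a hedged usage sketch of the function this hunk documents. Only plot_saving_rel_dir and n_bins are visible in the diff, so the leading argument names (vault name, uids, source directory) are assumptions based on the conventions of the other vault_utils functions below, not confirmed by this commit:

    # Hypothetical call: vault_name, vault_uids and rel_dir are assumed
    # parameter names; plot_saving_rel_dir and n_bins appear in the diff above.
    from og_marl.vault_utils.analyse_vault import describe_episode_returns

    describe_episode_returns(
        vault_name="3m.vlt",          # assumed parameter, placeholder value
        vault_uids=None,              # None -> describe every uid in the vault
        rel_dir="vaults",             # assumed parameter
        plot_saving_rel_dir="vaults",
        n_bins=50,                    # histogram bin count
    )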
7 changes: 7 additions & 0 deletions og_marl/vault_utils/combine_vaults.py
@@ -30,6 +30,7 @@ def get_all_vaults(
vault_uids: Optional[list[str]] = None,
rel_dir: str = "vaults",
) -> list[Vault]:
"""Gets a list of Vaults from the specified directory. Each uid produces one Vault."""
if vault_uids is None:
vault_uids = get_available_uids(f"./{rel_dir}/{vault_name}")

@@ -40,6 +41,7 @@


def stitch_vault_from_many(vlts: list[Vault], vault_name: str, vault_uid: str, rel_dir: str) -> int:
"""Given a list of vaults, writes all experience in each of them to one new vault."""
all_data = vlts[0].read()
offline_data = all_data.experience

@@ -84,6 +86,11 @@ def stitch_vault_from_many(vlts: list[Vault], vault_name: str, vault_uid: str, r


def combine_vaults(rel_dir: str, vault_name: str, vault_uids: Optional[list[str]] = None) -> str:
"""Combines datasets in a vault.
Takes multiple datasets in a vault and combines them into one new vault
with added "_combined" in the name.
"""
# check that the vault to be combined exists
if not check_directory_exists_and_not_empty(f"./{rel_dir}/{vault_name}"):
print(f"Vault './{rel_dir}/{vault_name}' does not exist and cannot be combined.")
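Based only on the signature shown above, a minimal sketch of the combining workflow; the vault name is a placeholder:

    from og_marl.vault_utils.combine_vaults import combine_vaults

    # Merges every dataset found under ./vaults/my_task.vlt into one new vault
    # whose name has "_combined" appended; the new vault's name is returned.
    combined_name = combine_vaults(rel_dir="vaults", vault_name="my_task.vlt")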
8 changes: 8 additions & 0 deletions og_marl/vault_utils/download_vault.py
@@ -140,6 +140,7 @@


def print_download_options() -> Dict[str, Dict]:
"""Prints as well as returns all options for downloading vaults from OG-MARL huggingface."""
print("VAULT_INFO:")
for source in VAULT_INFO.keys():
print(f"\t {source}")
@@ -158,6 +159,11 @@ def download_and_unzip_vault(
dataset_base_dir: str = "./vaults",
dataset_download_url: str = "",
) -> str:
"""Downloads and unzips vaults.
The location of the vault is dataset_base_dir/dataset_source/env_name/scenario_name/.
If the vault already exists and is not empty, the download is skipped.
"""
# to prevent downloading the vault twice into the same folder
if check_directory_exists_and_not_empty(
f"{dataset_base_dir}/{dataset_source}/{env_name}/{scenario_name}.vlt"
@@ -215,6 +221,7 @@ def download_and_unzip_vault(


def check_directory_exists_and_not_empty(path: str) -> bool:
"""Checks that the directory at path exists and is not empty."""
# Check if the directory exists
if os.path.exists(path) and os.path.isdir(path):
# Check if the directory is not empty
@@ -227,6 +234,7 @@ def check_directory_exists_and_not_empty(path: str) -> bool:


def get_available_uids(rel_vault_path: str) -> List[str]:
"""Obtains the uids of datasets in a vault at the relative path."""
vault_uids = sorted(
next(os.walk(rel_vault_path))[1],
reverse=True,
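A sketch of the download flow these docstrings imply. The dataset_source, env_name and scenario_name parameter names are assumptions taken from the path layout in the docstring (only dataset_base_dir and dataset_download_url are visible in the signature), and the placeholder values are illustrative only:

    from og_marl.vault_utils.download_vault import (
        download_and_unzip_vault,
        get_available_uids,
        print_download_options,
    )

    print_download_options()  # prints (and returns) VAULT_INFO

    # Skips the download if the target vault directory already exists and is
    # not empty; the three leading parameter names are assumptions.
    download_and_unzip_vault(
        dataset_source="og_marl",   # assumed
        env_name="smac_v1",         # assumed
        scenario_name="3m",         # assumed
        dataset_base_dir="./vaults",
    )

    uids = get_available_uids("./vaults/og_marl/smac_v1/3m.vlt")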
18 changes: 16 additions & 2 deletions og_marl/vault_utils/subsample_bespoke.py
@@ -21,6 +21,7 @@

# given bin edges and a sorted array of values, get the bin number per value
def get_bin_numbers(sorted_values: Array, bin_edges: Array) -> Array:
"""Assigns known, sorted values to given bins."""
bin_numbers = np.zeros_like(sorted_values)

def get_bin_number(bin_num: int, value: float) -> int:
@@ -46,6 +47,12 @@ def get_bin_number(bin_num: int, value: float) -> int:
def bin_processed_data(
all_sorted_return_start_end: Array, n_bins: int = 500
) -> Tuple[Array, Array, Array, Array, Array]:
"""Bins rows in an array according to the values in a particular column.
Returns bar_labels connected with bar_heights,
as well as padded_heights whose entries' indices are the bin number.
Bin edges and all numbers are also returned.
"""
# get bin edges, including final endpoint
bin_edges = jnp.linspace(
start=min(min(all_sorted_return_start_end[:, 0]), 0),
@@ -68,9 +75,16 @@
return bar_labels, bar_heights, padded_heights.astype(int), bin_edges, bin_numbers


# sample from pdf according to heights
# BIG NOTE: check the disparity between desired and actual bar heights first, otherwise your distribution will be off
def episode_idxes_sampled_from_pdf(pdf: Array, bar_heights: Array) -> list[int]:
"""Gets a list of episode indices according to how many episodes you want from each bin.
Given an array of desired bar heights and the actual bar heights, produces the
indices of the episodes which should be sampled to produce the desired bar heights.
It is assumed that the episodes to be sampled are sorted in ascending order, so, for example,
with bars of heights 3, 5 and 7, if you want a resultant histogram of 0, 2, 0, you will need to
sample two indices randomly from {3, 4, 5, 6, 7}.
The function returns the list of sampled indices.
"""
num_to_sample = np.round(pdf).astype(int)
sample_range_edges = np.concatenate([[0], np.cumsum(bar_heights)])

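The docstring's worked example can be checked with a small self-contained sketch of the same sampling rule. The first two lines mirror the ones visible in the diff; the selection loop is an assumption about the rest of the function, not the library code itself:

    import numpy as np

    # Desired histogram (pdf) and actual bar heights from the docstring example.
    pdf = np.array([0, 2, 0])
    bar_heights = np.array([3, 5, 7])

    num_to_sample = np.round(pdf).astype(int)
    # Cumulative edges: episode indices 0-2 fall in bin 0, 3-7 in bin 1, 8-14 in bin 2.
    sample_range_edges = np.concatenate([[0], np.cumsum(bar_heights)])

    sampled: list = []
    for b, k in enumerate(num_to_sample):
        lo, hi = sample_range_edges[b], sample_range_edges[b + 1]
        sampled += list(np.random.choice(np.arange(lo, hi), size=k, replace=False))

    print(sampled)  # e.g. [4, 6]: two indices drawn from {3, 4, 5, 6, 7}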
25 changes: 22 additions & 3 deletions og_marl/vault_utils/subsample_similar.py
@@ -27,6 +27,11 @@

# cumulative summing per-episode
def get_episode_returns_and_term_idxes(offline_data: Dict[str, Array]) -> Tuple[Array, Array]:
"""Gets the episode returns and final indices from a batch of experience.
From a batch of experience, extracts the indices
of the final transitions, as well as the returns of each episode, in order.
"""
rewards = offline_data["rewards"][0, :, 0]
terminal_flag = offline_data["terminals"][0, :, ...].all(axis=-1)

@@ -49,13 +54,17 @@ def scan_cumsum(
return cumsums[term_idxes - 1], term_idxes


# first store indices of episodes, then sort by episode return.
# outputs return, start, end and vault index in vault list
def sort_concat(returns: Array, eps_ends: Array) -> Array:
"""An original-order-aware sorting and concatenating of episode information.
From a list of episode ends and returns which are in the order of the episodes in the
original experience batch, produces an array of rows of return,
start index and end index for each episode.
"""
# build start indexes from end indexes since they are in order
episode_start_idxes = eps_ends[:-1] + 1
episode_start_idxes = jnp.insert(episode_start_idxes, 0, 0).reshape(-1, 1)
sorting_idxes = jnp.lexsort(jnp.array([returns[:, 0]]), axis=-1)
# print(sorting_idxes)

return_start_end = jnp.concatenate(
[returns, episode_start_idxes.reshape(-1, 1), eps_ends], axis=-1
@@ -69,6 +78,15 @@
def get_idxes_of_similar_subsets(
base_returns: List, comp_returns: List, tol: float = 0.1
) -> Tuple[List, List]:
"""Gets indices of episodes s.t. the subsets of two datasets have similar return distributions.
Iteratively selects episodes from either dataset with episode return within "tol" of each other.
Importantly, returns are SORTED BEFORE BEING PASSED TO THIS FUNCTION.
Returns the list of all such almost-matching episodes.
(For each episode Ea in dataset A, if there is an episode Eb in dataset B with a return within
"tol" of the return of E, select Ea and Eb to be sampled.
If not, move on from Ea.)
"""
base_selected_idxes = []
comp_selected_idxes = []

@@ -98,6 +116,7 @@ def subsample_similar(
new_rel_dir: str,
new_vault_name: str,
) -> None:
"""Subsamples 2 datasets s.t. the new datasets have similar episode return distributions."""
# check that a subsampled vault by the same name does not already exist
if check_directory_exists_and_not_empty(f"./{new_rel_dir}/{new_vault_name}"):
print(
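A toy illustration of the matching rule described in the get_idxes_of_similar_subsets docstring, written as an independent two-pointer sketch over pre-sorted return lists; the library's actual loop may differ:

    def match_similar_returns(base_returns, comp_returns, tol=0.1):
        # Both lists are assumed sorted ascending, as the docstring requires.
        base_idxes, comp_idxes = [], []
        i = j = 0
        while i < len(base_returns) and j < len(comp_returns):
            diff = base_returns[i] - comp_returns[j]
            if abs(diff) <= tol:   # returns within tol: keep the pair
                base_idxes.append(i)
                comp_idxes.append(j)
                i += 1
                j += 1
            elif diff > tol:       # comp return too small: advance comp
                j += 1
            else:                  # base return unmatched: move on from it
                i += 1
        return base_idxes, comp_idxes

    print(match_similar_returns([1.0, 2.0, 3.5], [0.95, 2.2, 3.45]))
    # -> ([0, 2], [0, 2]); the 2.0 episode has no partner within tol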
35 changes: 33 additions & 2 deletions og_marl/vault_utils/subsample_smaller.py
@@ -31,6 +31,12 @@


def get_length_start_end(experience: Dict[str, Array], terminal_key: str = "terminals") -> Array:
"""Process experience to get the length, start and end of all episodes.
From a block of experience, extracts the length, start position and end position of each
episode. Length is stored for the convenience of a cumsum in the following function, and
to match other episode information blocks which store return, instead.
"""
# extract terminals
terminal_flag = experience[terminal_key][0, :, ...].all(axis=-1)

@@ -41,7 +47,7 @@ def get_length_start_end(experience: Dict[str, Array], terminal_key: str = "term
start_idxes = np.zeros_like(term_idxes)
start_idxes[1:] = term_idxes[:-1]

# get the length per-episode (TODO maybe redundant)
# get the length per-episode
lengths = term_idxes - start_idxes

# concatenate for easier unpacking
@@ -51,6 +57,13 @@ def get_length_start_end(experience: Dict[str, Array], terminal_key: str = "term


def select_episodes_uniformly_up_to_n_transitions(len_start_end: Array, n: int) -> Array:
"""Uniformly selects episodes from an episode info array.
Selects rows (episodes) uniformly at random from an array containing
episode lengths, start indices and end indices.
Shuffles the indices of the array, then takes shuffled rows in order
until the cumulative sum of their lengths exceeds n.
"""
# shuffle idxes of all the episodes from the vault
shuffled_idxes = np.arange(len_start_end.shape[0])
np.random.shuffle(shuffled_idxes)
@@ -69,7 +82,6 @@ def select_episodes_uniformly_up_to_n_transitions(len_start_end: Array, n: int)
return randomly_sampled_len_start_end


# given the indices of the required episodes, stitch a vault and save under a user-specified name
def stitch_vault_from_sampled_episodes_(
experience: Dict[str, Array],
len_start_end_sample: Array,
@@ -78,6 +90,16 @@ def stitch_vault_from_sampled_episodes_(
rel_dir: str,
n: int = 500_000,
) -> int:
"""Writes a vault given episode information and a batch of experience.
Takes in experience
and an array with columns R, S, E (return, start, end) or L, S, E (length, start, end)
describing episode returns (or, redundantly, lengths) and their positions in the block of
experience.
For every row in len_start_end_sample (for every episode in the sample),
selects the relevant transitions from "experience" and adds them to a new buffer.
Finally, writes the buffer of all episodes to the destination vault.
"""
# to prevent overwriting an existing vault in the same folder
if check_directory_exists_and_not_empty(f"{rel_dir}/{dest_vault_name}/{vault_uid}/"):
print(f"Vault '{rel_dir}/{dest_vault_name}.vlt/{vault_uid}' already exists.")
@@ -125,6 +147,15 @@ def subsample_smaller_vault(
vault_uids: Optional[list] = None,
target_number_of_transitions: int = 500000,
) -> str:
"""Subsamples a vault to a smaller number of transitions.
Subsamples every dataset in the list of uids (or, if unspecified, all uids in the vault)
by selecting episodes uniformly at random from that dataset and storing them in a new vault.
Once the number of transitions in the episode selection is within one trajectory length
of the desired number of transitions, no further transitions are included.
This is to avoid creating partial trajectories.
"""
# check that the vault to be subsampled exists
if not check_directory_exists_and_not_empty(f"./{vaults_dir}/{vault_name}"):
print(f"Vault './{vaults_dir}/{vault_name}' does not exist and cannot be subsampled.")
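To close, a hedged sketch of calling the subsampler documented above; vaults_dir and vault_name appear in the existence check in the diff, the remaining arguments in the signature, and the vault name here is a placeholder:

    from og_marl.vault_utils.subsample_smaller import subsample_smaller_vault

    # Argument order assumed from the "./{vaults_dir}/{vault_name}" check above.
    new_name = subsample_smaller_vault(
        "vaults",                         # vaults_dir
        "3m.vlt",                         # vault_name (placeholder)
        vault_uids=None,                  # subsample every dataset in the vault
        target_number_of_transitions=500_000,
    )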
