Skip to content

Commit

Permalink
Merge pull request #59 from instadeepai/feat/concurrent-downloads
Browse files Browse the repository at this point in the history
Modifies json_utils.pull_neptune_data to now download files concurrently
  • Loading branch information
RuanJohn authored Apr 12, 2024
2 parents 3188c46 + 5869627 commit 3b7fd46
Showing 1 changed file with 89 additions and 46 deletions.
135 changes: 89 additions & 46 deletions marl_eval/json_tools/json_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import os
import zipfile
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Dict, List, Tuple

import neptune
Expand Down Expand Up @@ -112,65 +114,106 @@ def concatenate_json_files(

def pull_neptune_data(
project_name: str,
tag: List,
tags: List[str],
store_directory: str = "./downloaded_json_data",
neptune_data_key: str = "metrics",
disable_progress_bar: bool = False,
) -> None:
"""Pulls experiment json data from Neptune to a local directory.
"""Downloads logs from a Neptune project based on provided tags.
Args:
project_name (str): Name of the Neptune project.
tag (List): List of tags for the experiment(s) that contain the
desired JSON files.
store_directory (str, optional): Directory to store the data.
Default: ./downloaded_json_data.
neptune_data_key (str, optional): Key in the neptune run where the
json data is stored. Default: metrics.
tags (List[str]): List of tags associated with the desired experiments.
store_directory (str, optional): Directory to store the downloaded logs.
Default is "./downloaded_json_data".
neptune_data_key (str, optional): Key for the Neptune data to download.
Default is "metrics".
disable_progress_bar (bool, optional): Whether to hide a progress bar.
Default is False.
Raises:
ValueError: If the provided project name or tags are invalid.
"""
# Get the run ids
project = neptune.init_project(project=project_name)
runs_table_df = project.fetch_runs_table(state="inactive", tag=tag).to_pandas()
run_ids = runs_table_df["sys/id"].values.tolist()

# Check if store_directory exists
if not os.path.exists(store_directory):
os.makedirs(store_directory)
# Create the log directory if it doesn't exist
os.makedirs(store_directory, exist_ok=True)

# Suppress neptune logger
# Disable Neptune logging
neptune_logger = logging.getLogger("neptune")
neptune_logger.setLevel(logging.ERROR)

# Download and unzip the data
for run_id in tqdm(run_ids, desc="Downloading Neptune Data"):
run = neptune.init_run(project=project_name, with_id=run_id, mode="read-only")
for j, data_key in enumerate(
run.get_structure()[neptune_data_key].keys(), start=1
# Initialize the Neptune project
try:
project = neptune.init_project(project=project_name)
except Exception as e:
raise ValueError(f"Invalid project name '{project_name}': {e}")

# Fetch runs based on provided tags
try:
runs_table_df = project.fetch_runs_table(
state="inactive", columns=["sys/id"], tag=tags, sort_by="sys/id"
).to_pandas()
except Exception as e:
raise ValueError(f"Invalid tags {tags}: {e}")

run_ids = runs_table_df["sys/id"].values.tolist()

# Download logs concurrently
with ThreadPoolExecutor() as executor:
futures = [
executor.submit(
_download_and_extract_data,
project_name,
run_id,
store_directory,
neptune_data_key,
)
for run_id in run_ids
]
for future in tqdm(
as_completed(futures),
total=len(futures),
desc="Downloading JSON logs",
disable=disable_progress_bar,
):
# Create a unique filename
file_path = f"{store_directory}/{data_key}_{run_id}_{j}"
run[f"{neptune_data_key}/{data_key}"].download(destination=file_path)
# Try to unzip the file else continue to the next file
try:
with zipfile.ZipFile(file_path, "r") as zip_ref:
# Create a directory to store unzipped data
os.makedirs(f"{file_path}_unzip", exist_ok=True)
# Unzip the data
zip_ref.extractall(f"{file_path}_unzip")
# Remove the zip file
os.remove(file_path)
except zipfile.BadZipFile:
# If the file is not zipped continue to the next file
# as it is already downloaded and doesn't need to be
# unzipped.
continue
except Exception as e:
print(
f"The following error occurred while unzipping or storing JSON \
data for run {run_id} at path {file_path}: {e}"
)
run.stop()
future.result()

# Restore neptune logger level
neptune_logger.setLevel(logging.INFO)

print(f"{Fore.CYAN}{Style.BRIGHT}Data downloaded successfully!{Style.RESET_ALL}")


def _download_and_extract_data(
project_name: str, run_id: str, store_directory: str, neptune_data_key: str
) -> None:
try:
with neptune.init_run(
project=project_name, with_id=run_id, mode="read-only"
) as run:
for j, data_key in enumerate(
run.get_structure()[neptune_data_key].keys(), start=1
):
file_path = f"{store_directory}/{run_id}"
if j > 1:
file_path += f"_{j}"
run[f"{neptune_data_key}/{data_key}"].download(destination=file_path)
_extract_zip_file(file_path)
except Exception as e:
print(f"Error downloading data for run {run_id}: {e}")


def _extract_zip_file(file_path: str) -> None:
try:
with zipfile.ZipFile(file_path, "r") as zip_ref:
for member in zip_ref.infolist():
if not member.is_dir():
target_path = Path(f"{file_path}{Path(member.filename).suffix}")
target_path.parent.mkdir(parents=True, exist_ok=True)
with zip_ref.open(member) as src, target_path.open("wb") as dest:
dest.write(src.read())
# Remove the zip file
os.remove(file_path)
except zipfile.BadZipFile:
# If the file is not zipped, no action is required
pass
except Exception as e:
print(f"Error while unzipping or storing JSON data at path {file_path}: {e}")

0 comments on commit 3b7fd46

Please sign in to comment.