generalized target data generation to both flu and covid #35

Merged
merged 4 commits into from
Dec 3, 2024

Changes from all commits
3 changes: 2 additions & 1 deletion README.md
@@ -52,6 +52,7 @@ Initially the visualization will have these limitations:
- For the `task_ids` entry in predtimechart config option generation, we use `value` for both `value` and `text`, rather than asking the user to provide a mapping from `value` to `text`. A solution is to require that mapping in `predtimechart-config.yml`.
- The `initial_as_of` and `current_date` config fields are the last of `hub_config.fetch_reference_dates`.
- The `initial_task_ids` config field is the first `task_ids` `value`.
- Target data generation: The app `generate_target_json_files.py` is limited to hubs that store their target data as a .csv file in the `target-data` subdirectory. That file is specified via the `target_data_file_name` field in the hub's `predtimechart-config.yml` file. We expect the file to have these columns: `date`, `value`, and `location`.

# Required hub configuration

@@ -93,7 +94,7 @@ We plan to primarily use https://github.com/hubverse-org/example-complex-forecas
- Where is the source data coming from - GitHub vs. S3?
- Which model output formats will we support? The hub docs mention CSV and parquet, but that others (e.g., zipped files) might be supported.
- Regarding naming the .json files, should we be influenced by Arrow's partitioning scheme where it names intermediate directories according to filtering.
- We might need separate apps to update config options vs. visualization data (json files) for the case where the user has changed predtimechart-config.yml independent of a round closing.
- We might need separate apps to update config options vs. visualization data (json files) for the case where the user has changed `predtimechart-config.yml` independent of a round closing.
- Should we filter out `hub_config.horizon_col_name == 0` ?
- Should `forecast_data_for_model_df()`'s `quantile_levels` be stored in a config file somewhere?

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -30,7 +30,7 @@ dev = [
[project.entry-points."console_scripts"]
hub_predtimechart = "hub_predtimechart.app.generate_json_files:main"
ptc_generate_json_files = "hub_predtimechart.app.generate_json_files:main"
ptc_generate_flusight_targets = "hub_predtimechart.app.generate_target_json_files_FluSight:main"
ptc_generate_target_json_files = "hub_predtimechart.app.generate_target_json_files:main"


[build-system]
3 changes: 2 additions & 1 deletion src/hub_predtimechart/app/generate_json_files.py
@@ -95,7 +95,8 @@ def _generate_json_files(hub_config: HubConfig, output_dir: Path, is_regenerate:
if not model_id_to_df: # no model outputs for reference_date
continue

# iterate over each (target X task_ids) combination, outputting to the corresponding json file
# iterate over each (target X task_ids) combination (for now we only support one target), outputting to the
# corresponding json file
available_as_ofs = hub_config.get_available_as_ofs().values()
newest_reference_date = max([max(date) for date in available_as_ofs])
for task_ids_tuple in hub_config.fetch_task_ids_tuples:
128 changes: 128 additions & 0 deletions src/hub_predtimechart/app/generate_target_json_files.py
@@ -0,0 +1,128 @@
import json
import sys
from datetime import date, timedelta
from pathlib import Path

import click
import polars as pl
import structlog

from hub_predtimechart.app.generate_json_files import json_file_name
from hub_predtimechart.hub_config import HubConfig
from hub_predtimechart.util.logs import setup_logging


setup_logging()
logger = structlog.get_logger()


@click.command()
@click.argument('hub_dir', type=click.Path(file_okay=False, exists=True))
@click.argument('ptc_config_file', type=click.Path(file_okay=True, exists=False))
@click.argument('target_out_dir', type=click.Path(file_okay=False, exists=True))
def main(hub_dir, ptc_config_file, target_out_dir):
"""
Generates the target data json files used by https://github.com/reichlab/predtimechart to
visualize a hub's forecasts. Handles missing input target data in two ways, depending on the error. 1) If the
`target_data_file_name` entry in the hub config file is missing, then the program will exit with no messages.
2) If the entry is present but the file it points to does not exist, then the program will exit with an error
message, but won't actually raise a Python exception.

HUB_DIR: (input) a directory Path of a https://hubverse.io hub to generate target data json files from

PTC_CONFIG_FILE: (input) a file Path to a `predtimechart-config.yaml` file that specifies how to process `hub_dir`
to get predtimechart output

TARGET_OUT_DIR: (output) a directory Path to output the viz target data json files to
\f
:param hub_dir: (input) a directory Path of a https://hubverse.io hub to generate target data json files from
:param ptc_config_file: (input) a file Path to a `predtimechart-config.yaml` file that specifies how to process
`hub_dir` to get predtimechart output
:param target_out_dir: (output) a directory Path to output the viz target data json files to
"""
logger.info(f'main({hub_dir=}, {target_out_dir=}): entered')

hub_dir = Path(hub_dir)
hub_config = HubConfig(hub_dir, Path(ptc_config_file))
if hub_config.target_data_file_name is None:
logger.info('No `target_data_file_name` entry found in hub config file. exiting')
return

# for each location, generate target data file contents and then save as json
json_files = []
try:
target_data_df = get_target_data_df(hub_dir, hub_config.target_data_file_name)
except FileNotFoundError as error:
logger.error(f"target data file not found. {hub_config.target_data_file_name=}, {error=}")
sys.exit(1)

for loc in target_data_df['location'].unique():
task_ids_tuple = (loc,)
target_out_dir = Path(target_out_dir)
file_name = json_file_name(hub_config.fetch_target_id, task_ids_tuple, reference_date_from_today().isoformat())
location_data_dict = ptc_target_data(target_data_df, task_ids_tuple)
json_files.append(target_out_dir / file_name)
with open(target_out_dir / file_name, 'w') as fp:
json.dump(location_data_dict, fp, indent=4)

logger.info(f'main(): done: {len(json_files)} JSON files generated: {[str(_) for _ in json_files]}. ')


def get_target_data_df(hub_dir, target_data_filename):
"""
    Loads the target data csv file from the hub repo. For now, the file path for target data is hard coded to 'target-data'.
Raises FileNotFoundError if target data file does not exist.
"""
if target_data_filename is None:
raise FileNotFoundError(f"target_data_filename was missing: {target_data_filename}")

target_data_file_path = hub_dir / 'target-data' / target_data_filename
try:
# the override schema handles the 'US' location (the only location that doesn't parse as Int64)
return pl.read_csv(target_data_file_path, schema_overrides={'location': pl.String}, null_values=["NA"])
except FileNotFoundError as error:
raise FileNotFoundError(f"target data file not found. {target_data_file_path=}, {error=}")


def ptc_target_data(target_data_df: pl.DataFrame, task_ids_tuple: tuple[str]):
"""
Returns a dict for a single reference date and location in the target data format documented at https://github.com/reichlab/predtimechart?tab=readme-ov-file#fetchdata-truth-data-format.
Note that this function currently assumes there is only one task id variable other than the reference date, horizon,
and target date, and that task id variable is a location code that matches codes used in the `location` column of
the `target_data_df` argument. That is, looking at that example, this function returns the date and value columns as
in tests/expected/FluSight-forecast-hub/target/wk-inc-flu-hosp_US.json :

{
"date": ["2024-04-27", "2024-04-20", "..."],
"y": [2337, 2860, "..."]
}
"""
loc = task_ids_tuple[0]
target_data_loc = target_data_df.filter(pl.col('location') == loc).sort('date')
target_data_ptc = {
"date": target_data_loc['date'].to_list(),
"y": target_data_loc['value'].to_list()
}

return target_data_ptc


def reference_date_from_today(now: date = None) -> date:
if now is None: # per https://stackoverflow.com/questions/52511405/freeze-time-not-working-for-default-param
now = date.today()

# Calculate the days until the next Saturday
days_to_saturday = 5 - now.weekday()
if days_to_saturday < 0:
days_to_saturday += 7

# Add the calculated days to the given date
return now + timedelta(days=days_to_saturday)


#
# main()
#

if __name__ == '__main__':
main()
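The Saturday-snapping logic in `reference_date_from_today` above can be exercised on its own. The sketch below is a self-contained re-statement of that arithmetic (the helper name `next_saturday` is ours): any date maps to the next Saturday, and a Saturday maps to itself.

```python
from datetime import date, timedelta

def next_saturday(now: date) -> date:
    # Monday is weekday 0, so Saturday is weekday 5
    days_to_saturday = 5 - now.weekday()
    if days_to_saturday < 0:  # Sunday (weekday 6) wraps to the following week
        days_to_saturday += 7
    return now + timedelta(days=days_to_saturday)

print(next_saturday(date(2024, 10, 24)))  # a Thursday -> 2024-10-26
print(next_saturday(date(2024, 10, 26)))  # already a Saturday -> unchanged
```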
78 changes: 0 additions & 78 deletions src/hub_predtimechart/app/generate_target_json_files_FluSight.py

This file was deleted.

26 changes: 0 additions & 26 deletions src/hub_predtimechart/generate_target_data.py

This file was deleted.

6 changes: 4 additions & 2 deletions src/hub_predtimechart/hub_config.py
@@ -78,6 +78,9 @@ def __init__(self, hub_dir: Path, ptc_config_file: Path):
model_tasks_ele = tasks['rounds'][self.rounds_idx]['model_tasks'][self.model_tasks_idx]
self.task_ids = sorted(model_tasks_ele['task_ids'].keys())

# set target_data_file_name
self.target_data_file_name = ptc_config.get('target_data_file_name')

# set viz_task_ids and fetch_targets. recall: we assume there is only one target_metadata entry, only one
# entry under its `target_keys`
target_metadata = model_tasks_ele['target_metadata'][0]
@@ -136,13 +139,12 @@ def get_sorted_values_or_first_config_ref_date(reference_dates: set[str]):
return sorted(list(reference_dates))


# loop over every (reference_date X model_id) combination.
# loop over every (reference_date X model_id) combination
as_ofs = {self.fetch_target_id: set()}
for reference_date in self.reference_dates: # ex: ['2022-10-22', '2022-10-29', ...]
for model_id in self.model_id_to_metadata: # ex: ['Flusight-baseline', 'MOBS-GLEAM_FLUH', ...]
model_output_file = self.model_output_file_for_ref_date(model_id, reference_date)
if model_output_file:
# todo xx extract to function, call from here and _generate_json_files()
if model_output_file.suffix == '.csv':
df = pd.read_csv(model_output_file, usecols=[self.target_col_name])
elif model_output_file.suffix in ['.parquet', '.pqt']:
7 changes: 6 additions & 1 deletion src/hub_predtimechart/ptc_schema.py
@@ -53,7 +53,12 @@
"type": "string"
}
}
}
},
"target_data_file_name": {
"description": "optional name of the target data file located in the hub's target-data dir",
"type": "string",
"minLength": 1
},
},
"required": [
"rounds_idx",
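The behavior of the new optional `target_data_file_name` schema entry can be sketched with `jsonschema`. The schema below is a hypothetical minimal fragment containing just this property, not the full predtimechart schema: since the field is not listed under `required`, omitting it validates, while an empty string fails `minLength`.

```python
from jsonschema import ValidationError, validate

# minimal hypothetical fragment mirroring the new property
schema = {
    "type": "object",
    "properties": {
        "target_data_file_name": {
            "description": "optional name of the target data file located in the hub's target-data dir",
            "type": "string",
            "minLength": 1,
        },
    },
}

validate({"target_data_file_name": "target-hospital-admissions.csv"}, schema)  # passes
validate({}, schema)  # passes: the field is optional (not in `required`)
try:
    validate({"target_data_file_name": ""}, schema)  # empty string violates minLength
except ValidationError:
    print("empty file name rejected")
```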
12 changes: 12 additions & 0 deletions tests/expected/covid19-forecast-hub/target/wk-inc-flu-hosp_01.json
@@ -0,0 +1,12 @@
{
"date": [
"2024-11-09",
"2024-11-16",
"2024-11-23"
],
"y": [
112,
81,
70
]
}
12 changes: 12 additions & 0 deletions tests/expected/covid19-forecast-hub/target/wk-inc-flu-hosp_US.json
@@ -0,0 +1,12 @@
{
"date": [
"2024-11-09",
"2024-11-16",
"2024-11-23"
],
"y": [
7691,
7595,
7290
]
}
33 changes: 24 additions & 9 deletions tests/hub_predtimechart/test_generate_target_data.py
@@ -3,25 +3,33 @@
from pathlib import Path

import polars as pl
import pytest
from freezegun import freeze_time

from hub_predtimechart.app.generate_target_json_files_FluSight import reference_date_from_today, get_target_data_df
from hub_predtimechart.generate_target_data import target_data_for_FluSight
from hub_predtimechart.app.generate_target_json_files import reference_date_from_today, get_target_data_df, \
ptc_target_data


def test_generate_target_data_flusight_forecast_hub():
hub_dir = Path('tests/hubs/FluSight-forecast-hub')
target_data_df = pl.read_csv(hub_dir / "target-data/target-hospital-admissions.csv")

target_data_df = get_target_data_df(hub_dir, 'target-hospital-admissions.csv')
for loc in ['US', '01']:
task_ids_tuple = (loc,)
with open(f'tests/expected/FluSight-forecast-hub/target/wk-inc-flu-hosp_{loc}.json') as fp:
exp_data = json.load(fp)
act_data = ptc_target_data(target_data_df, task_ids_tuple)
assert act_data == exp_data

# case: Flusight-baseline
model_output_file = hub_dir / 'model-output/Flusight-baseline/2022-10-22-Flusight-baseline.csv'
act_data = target_data_for_FluSight(target_data_df, task_ids_tuple)
assert act_data == exp_data

def test_generate_target_data_covid19_forecast_hub():
hub_dir = Path('tests/hubs/covid19-forecast-hub')
target_data_df = get_target_data_df(hub_dir, 'covid-hospital-admissions.csv')
for loc in ['US', '01']:
task_ids_tuple = (loc,)
with open(f'tests/expected/covid19-forecast-hub/target/wk-inc-flu-hosp_{loc}.json') as fp:
exp_data = json.load(fp)
act_data = ptc_target_data(target_data_df, task_ids_tuple)
assert act_data == exp_data


@freeze_time("2024-10-24")
@@ -40,7 +48,7 @@ def test_reference_date_from_today():
assert act_reference_date == exp_reference_date


def test_get_target_data_df():
def test_get_target_data_df_error_cases():
hub_dir = Path('tests/hubs/FluSight-forecast-hub')
act_target_data_df = get_target_data_df(hub_dir, 'target-hospital-admissions-no-na.csv')
assert act_target_data_df["value"].dtype == pl.datatypes.Int64
@@ -49,3 +57,10 @@ def test_get_target_data_df():
act_target_data_df = get_target_data_df(hub_dir, 'target-hospital-admissions-yes-na.csv')
assert act_target_data_df["value"].dtype == pl.datatypes.Int64
assert act_target_data_df["value"].to_list() == [3, 16, None, 106, 151, 23, 64, 8, 2, 266]

# case: file not found
with pytest.raises(FileNotFoundError, match="target_data_filename was missing"):
get_target_data_df(hub_dir, None)

with pytest.raises(FileNotFoundError, match="target data file not found"):
get_target_data_df(hub_dir, 'non-existent-file.csv')