Skip to content

Commit

Permalink
generalized target data generation to both flu and covid
Browse files Browse the repository at this point in the history
  • Loading branch information
matthewcornell committed Dec 3, 2024
1 parent eec4110 commit 6492920
Show file tree
Hide file tree
Showing 15 changed files with 266 additions and 48 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ Initially the visualization will have these limitations:
- For the `task_ids` entry in predtimechart config option generation, we use `value` for both `value` and `text`, rather than asking the user to provide a mapping from `value` to `text`. A solution is to require that mapping in `predtimechart-config.yml`.
- The `initial_as_of` and `current_date` config fields are the last of `hub_config.fetch_reference_dates`.
- The `initial_task_ids` config field is the first `task_ids` `value`.
- Target data generation: The app `generate_target_json_files.py` is limited to hubs that store their target data as a .csv file in the `target-data` subdirectory. That file is specified via the `target_data_file_name` field in the hub's `predtimechart-config.yml` file. We expect the file has these columns: `date`, `value`, and `location`.

# Required hub configuration

Expand Down Expand Up @@ -93,7 +94,7 @@ We plan to primarily use https://github.com/hubverse-org/example-complex-forecas
- Where is the source data coming from - GitHub vs. S3?
- Which model output formats will we support? The hub docs mention CSV and parquet, but that others (e.g., zipped files) might be supported.
- Regarding naming the .json files, should we be influenced by Arrow's partitioning scheme where it names intermediate directories according to filtering.
- We might need separate apps to update config options vs. visualization data (json files) for the case where the user has changed predtimechart-config.yml independent of a round closing.
- We might need separate apps to update config options vs. visualization data (json files) for the case where the user has changed `predtimechart-config.yml` independent of a round closing.
- Should we filter out `hub_config.horizon_col_name == 0` ?
- Should `forecast_data_for_model_df()`'s `quantile_levels` be stored in a config file somewhere?

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ dev = [
[project.entry-points."console_scripts"]
hub_predtimechart = "hub_predtimechart.app.generate_json_files:main"
ptc_generate_json_files = "hub_predtimechart.app.generate_json_files:main"
ptc_generate_flusight_targets = "hub_predtimechart.app.generate_target_json_files_FluSight:main"
ptc_generate_target_json_files = "hub_predtimechart.app.generate_target_json_files:main"


[build-system]
Expand Down
3 changes: 2 additions & 1 deletion src/hub_predtimechart/app/generate_json_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,8 @@ def _generate_json_files(hub_config: HubConfig, output_dir: Path, is_regenerate:
if not model_id_to_df: # no model outputs for reference_date
continue

# iterate over each (target X task_ids) combination, outputting to the corresponding json file
# iterate over each (target X task_ids) combination (for now we only support one target), outputting to the
# corresponding json file
available_as_ofs = hub_config.get_available_as_ofs().values()
newest_reference_date = max([max(date) for date in available_as_ofs])
for task_ids_tuple in hub_config.fetch_task_ids_tuples:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import structlog

from hub_predtimechart.app.generate_json_files import json_file_name
from hub_predtimechart.generate_target_data import target_data_for_FluSight
from hub_predtimechart.hub_config import HubConfig
from hub_predtimechart.util.logs import setup_logging


Expand All @@ -17,8 +17,9 @@

@click.command()
@click.argument('hub_dir', type=click.Path(file_okay=False, exists=True))
@click.argument('ptc_config_file', type=click.Path(file_okay=True, exists=False))
@click.argument('target_out_dir', type=click.Path(file_okay=False, exists=True))
def main(hub_dir, target_out_dir):
def main(hub_dir, ptc_config_file, target_out_dir):
'''
Generates the target data json files used by https://github.com/reichlab/predtimechart to
visualize a hub's forecasts.
Expand All @@ -34,17 +35,18 @@ def main(hub_dir, target_out_dir):
logger.info(f'main({hub_dir=}, {target_out_dir=}): entered')

hub_dir = Path(hub_dir)
hub_config = HubConfig(hub_dir, Path(ptc_config_file))
target_out_dir = Path(target_out_dir)
target_data_df = get_target_data_df(hub_dir, 'target-hospital-admissions.csv')
target_data_df = get_target_data_df(hub_dir, hub_config.target_data_file_name)

# for each location,
# - generate target data file contents
# - save as json
json_files = []
for loc in target_data_df['location'].unique():
task_ids_tuple = (loc,)
location_data_dict = target_data_for_FluSight(target_data_df, task_ids_tuple)
file_name = json_file_name('wk inc flu hosp', task_ids_tuple, reference_date_from_today())
location_data_dict = ptc_target_data(target_data_df, task_ids_tuple)
file_name = json_file_name('wk inc flu hosp', task_ids_tuple, reference_date_from_today().isoformat())
json_files.append(target_out_dir / file_name)
with open(target_out_dir / file_name, 'w') as fp:
json.dump(location_data_dict, fp, indent=4)
Expand All @@ -53,8 +55,41 @@ def main(hub_dir, target_out_dir):


def get_target_data_df(hub_dir, target_data_filename):
# load the target data csv file from the hub repo for now, file path for target data is hard coded
return pl.read_csv(hub_dir / 'target-data' / target_data_filename, null_values=["NA"])
"""
Loads the target data csv file from the hub repo for now, file path for target data is hard coded to 'target-data'.
Raises FileNotFoundError if target data file does not exist.
"""
schema = {
"location": pl.String, # to handle the 'US' location (the only one that doesn't parse as Int64)
# "date": pl.Date,
# "value": pl.Int64
}
target_data_file_path = hub_dir / 'target-data' / target_data_filename
try:
return pl.read_csv(target_data_file_path, schema_overrides=schema, null_values=["NA"])
except FileNotFoundError as error:
raise FileNotFoundError(f"target data file not found. {target_data_file_path=}, {error=}")


def ptc_target_data(target_data_df: pl.DataFrame, task_ids_tuple: tuple[str]):
"""
Returns a dict for a single reference date in the target data format documented at https://github.com/reichlab/predtimechart?tab=readme-ov-file#fetchdata-truth-data-format.
That is, looking at that example, this function returns the date and value columns as in
tests/expected/FluSight-forecast-hub/target/wk-inc-flu-hosp_US.json :
{
"date": ["2024-04-27", "2024-04-20", "..."],
"y": [2337, 2860, "..."]
}
"""
loc = task_ids_tuple[0]
target_data_loc = target_data_df.filter(pl.col('location') == loc).sort('date')
target_data_ptc = {
"date": target_data_loc['date'].to_list(),
"y": target_data_loc['value'].to_list()
}

return target_data_ptc


def reference_date_from_today(now: date = None) -> date:
Expand Down
26 changes: 0 additions & 26 deletions src/hub_predtimechart/generate_target_data.py

This file was deleted.

6 changes: 4 additions & 2 deletions src/hub_predtimechart/hub_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ def __init__(self, hub_dir: Path, ptc_config_file: Path):
model_tasks_ele = tasks['rounds'][self.rounds_idx]['model_tasks'][self.model_tasks_idx]
self.task_ids = sorted(model_tasks_ele['task_ids'].keys())

# set target_data_file_name
self.target_data_file_name = ptc_config.get('target_data_file_name')

# set viz_task_ids and fetch_targets. recall: we assume there is only one target_metadata entry, only one
# entry under its `target_keys`
target_metadata = model_tasks_ele['target_metadata'][0]
Expand Down Expand Up @@ -136,13 +139,12 @@ def get_sorted_values_or_first_config_ref_date(reference_dates: set[str]):
return sorted(list(reference_dates))


# loop over every (reference_date X model_id) combination.
# loop over every (reference_date X model_id) combination
as_ofs = {self.fetch_target_id: set()}
for reference_date in self.reference_dates: # ex: ['2022-10-22', '2022-10-29', ...]
for model_id in self.model_id_to_metadata: # ex: ['Flusight-baseline', 'MOBS-GLEAM_FLUH', ...]
model_output_file = self.model_output_file_for_ref_date(model_id, reference_date)
if model_output_file:
# todo xx extract to function, call from here and _generate_json_files()
if model_output_file.suffix == '.csv':
df = pd.read_csv(model_output_file, usecols=[self.target_col_name])
elif model_output_file.suffix in ['.parquet', '.pqt']:
Expand Down
7 changes: 6 additions & 1 deletion src/hub_predtimechart/ptc_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,12 @@
"type": "string"
}
}
}
},
"target_data_file_name": {
"description": "optional name of the target data file located in the hub's target-data dir",
"type": "string",
"minLength": 1
},
},
"required": [
"rounds_idx",
Expand Down
12 changes: 12 additions & 0 deletions tests/expected/covid19-forecast-hub/target/wk-inc-flu-hosp_01.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
{
"date": [
"2024-11-09",
"2024-11-16",
"2024-11-23"
],
"y": [
112,
81,
70
]
}
12 changes: 12 additions & 0 deletions tests/expected/covid19-forecast-hub/target/wk-inc-flu-hosp_US.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
{
"date": [
"2024-11-09",
"2024-11-16",
"2024-11-23"
],
"y": [
7691,
7595,
7290
]
}
30 changes: 21 additions & 9 deletions tests/hub_predtimechart/test_generate_target_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,33 @@
from pathlib import Path

import polars as pl
import pytest
from freezegun import freeze_time

from hub_predtimechart.app.generate_target_json_files_FluSight import reference_date_from_today, get_target_data_df
from hub_predtimechart.generate_target_data import target_data_for_FluSight
from hub_predtimechart.app.generate_target_json_files import reference_date_from_today, get_target_data_df, \
ptc_target_data


def test_generate_target_data_flusight_forecast_hub():
hub_dir = Path('tests/hubs/FluSight-forecast-hub')
target_data_df = pl.read_csv(hub_dir / "target-data/target-hospital-admissions.csv")

target_data_df = get_target_data_df(hub_dir, 'target-hospital-admissions.csv')
for loc in ['US', '01']:
task_ids_tuple = (loc,)
with open(f'tests/expected/FluSight-forecast-hub/target/wk-inc-flu-hosp_{loc}.json') as fp:
exp_data = json.load(fp)
act_data = ptc_target_data(target_data_df, task_ids_tuple)
assert act_data == exp_data


# case: Flusight-baseline
model_output_file = hub_dir / 'model-output/Flusight-baseline/2022-10-22-Flusight-baseline.csv'
act_data = target_data_for_FluSight(target_data_df, task_ids_tuple)
assert act_data == exp_data
def test_generate_target_data_covid19_forecast_hub():
hub_dir = Path('tests/hubs/covid19-forecast-hub')
target_data_df = get_target_data_df(hub_dir, 'covid-hospital-admissions.csv')
for loc in ['US', '01']:
task_ids_tuple = (loc,)
with open(f'tests/expected/covid19-forecast-hub/target/wk-inc-flu-hosp_{loc}.json') as fp:
exp_data = json.load(fp)
act_data = ptc_target_data(target_data_df, task_ids_tuple)
assert act_data == exp_data


@freeze_time("2024-10-24")
Expand All @@ -40,7 +48,7 @@ def test_reference_date_from_today():
assert act_reference_date == exp_reference_date


def test_get_target_data_df():
def test_get_target_data_df_error_cases():
hub_dir = Path('tests/hubs/FluSight-forecast-hub')
act_target_data_df = get_target_data_df(hub_dir, 'target-hospital-admissions-no-na.csv')
assert act_target_data_df["value"].dtype == pl.datatypes.Int64
Expand All @@ -49,3 +57,7 @@ def test_get_target_data_df():
act_target_data_df = get_target_data_df(hub_dir, 'target-hospital-admissions-yes-na.csv')
assert act_target_data_df["value"].dtype == pl.datatypes.Int64
assert act_target_data_df["value"].to_list() == [3, 16, None, 106, 151, 23, 64, 8, 2, 266]

# case: file not found
with pytest.raises(FileNotFoundError, match="target data file not found"):
get_target_data_df(hub_dir, 'non-existent-file.csv')
1 change: 1 addition & 0 deletions tests/hub_predtimechart/test_hub_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def test_hub_config_complex_forecast_hub():
assert (sorted(list(hub_config.model_id_to_metadata.keys())) ==
sorted(['Flusight-baseline', 'MOBS-GLEAM_FLUH', 'PSI-DICE']))
assert hub_config.task_ids == sorted(['reference_date', 'target', 'horizon', 'location', 'target_end_date'])
assert hub_config.target_data_file_name == 'covid-hospital-admissions.csv'
assert hub_config.target_col_name == 'target'
assert hub_config.viz_task_ids == sorted(['location'])
assert hub_config.fetch_target_id == 'wk inc flu hosp'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ target_date_col_name: 'target_end_date'
horizon_col_name: 'horizon'
initial_checked_models: ['FluSight-baseline', 'FluSight-ensemble']
disclaimer: Most forecasts have failed to reliably predict rapid changes in the trends of reported cases and hospitalizations. Due to this limitation, they should not be relied upon for decisions about the possibility or timing of rapid changes in trends.
target_data_file_name: 'target-hospital-admissions.csv'
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,4 @@ task_id_text:
"72": "Puerto Rico"
"74": "U.S. Minor Outlying Islands"
"78": "Virgin Islands"
target_data_file_name: 'covid-hospital-admissions.csv'
Loading

0 comments on commit 6492920

Please sign in to comment.