diff --git a/flood_forecast/evaluator.py b/flood_forecast/evaluator.py index c0c7f9d14..b90f7273b 100644 --- a/flood_forecast/evaluator.py +++ b/flood_forecast/evaluator.py @@ -333,7 +333,27 @@ def infer_on_torch_model( def handle_evaluation_series_loader(csv_series_id_loader: SeriesIDTestLoader, model, device, hours_to_forecast: int, datetime_start): - pass + data = csv_series_id_loader.get_from_start_date_all(datetime_start) + for i in range(0, len(data)): + history, df_train_and_test, forecast_start_idx = data[i] + print(history) + print(df_train_and_test) + print(forecast_start_idx) + """ + end_tensor = generate_predictions( + model, + df_train_and_test, + csv_series_id_loader, + history, + device, + forecast_start_idx, + model.params["dataset_params"]["forecast_length"], + hours_to_forecast, + decoder_params=None, + multi_params=1 + ) + print(end_tensor)""" + return def handle_ci_multi(prediction_samples: torch.Tensor, csv_test_loader: CSVTestLoader, multi_params: int, diff --git a/tests/test_series_id.py b/tests/test_series_id.py index 555efe83b..36380f96d 100644 --- a/tests/test_series_id.py +++ b/tests/test_series_id.py @@ -51,7 +51,7 @@ def __init__(self, model): def test_series_test_loader(self): loader_ds1 = SeriesIDTestLoader("PLANT_ID", self.dataset_params, "shit") - self.assertEqual(loader_ds1.get_from_start_date_all(datetime(2020, 8, 1))[0][0].shape[0], 20) + self.assertEqual(loader_ds1.get_from_start_date_all(datetime(2020, 7, 6))[0][0].shape[0], 20) # historical_rows, all_rows_orig, targ_idx = loader_ds1.get_from_start_date_all(datetime(2020, 8, 1))[0] # self.assertEqual(historical_rows.shape[0], 20) # self.assertEqual(historical_rows.shape[1], 3)