Skip to content

Commit

Permalink
try test again
Browse files Browse the repository at this point in the history
  • Loading branch information
isaacmg committed Nov 2, 2023
1 parent 249ea04 commit 8174a2f
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 2 deletions.
22 changes: 21 additions & 1 deletion flood_forecast/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,27 @@ def infer_on_torch_model(

def handle_evaluation_series_loader(csv_series_id_loader: SeriesIDTestLoader, model, device,
hours_to_forecast: int, datetime_start):
pass
data = csv_series_id_loader.get_from_start_date_all(datetime_start)
for i in range(0, len(data)):
history, df_train_and_test, forecast_start_idx = data[i]
print(history)
print(df_train_and_test)
print(forecast_start_idx)
"""
end_tensor = generate_predictions(
model,
df_train_and_test,
csv_series_id_loader,
history,
device,
forecast_start_idx,
model.params["dataset_params"]["forecast_length"],
hours_to_forecast,
decoder_params=None,
multi_params=1
)
print(end_tensor)"""
return


def handle_ci_multi(prediction_samples: torch.Tensor, csv_test_loader: CSVTestLoader, multi_params: int,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_series_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(self, model):

def test_series_test_loader(self):
loader_ds1 = SeriesIDTestLoader("PLANT_ID", self.dataset_params, "shit")
self.assertEqual(loader_ds1.get_from_start_date_all(datetime(2020, 8, 1))[0][0].shape[0], 20)
self.assertEqual(loader_ds1.get_from_start_date_all(datetime(2020, 7, 6))[0][0].shape[0], 20)
# historical_rows, all_rows_orig, targ_idx = loader_ds1.get_from_start_date_all(datetime(2020, 8, 1))[0]
# self.assertEqual(historical_rows.shape[0], 20)
# self.assertEqual(historical_rows.shape[1], 3)
Expand Down

0 comments on commit 8174a2f

Please sign in to comment.