Skip to content

Commit

Permalink
fixing the errors from before
Browse files Browse the repository at this point in the history
  • Loading branch information
isaacmg committed Oct 30, 2023
1 parent 4c8a333 commit 212360a
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions flood_forecast/preprocessing/pytorch_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,6 @@ def __init__(
target_supplied=True,
interpolate=False,
sort_column_clone=None,
target_supplied=True,
**kwargs
):
"""
Expand All @@ -287,7 +286,6 @@ def __init__(
sort_col1 = sort_column_clone if sort_column_clone else "datetime"
self.original_df[sort_col1] = self.original_df["datetime"].astype("datetime64[ns]")
self.original_df["original_index"] = self.original_df.index
self.target_supplied1 = target_supplied
if len(self.relevant_cols3) > 0:
self.original_df[self.relevant_cols3] = self.df[self.relevant_cols3]

Expand All @@ -301,7 +299,7 @@ def get_from_start_date(self, forecast_start: datetime, original_df=None):
return self.__getitem__(revised_index - self.forecast_history)

def __getitem__(self, idx):
if self.target_supplied1:
if self.target_supplied:
historical_rows = self.df.iloc[idx: self.forecast_history + idx]
target_idx_start = self.forecast_history + idx
# Why aren't we using these
Expand Down Expand Up @@ -395,7 +393,7 @@ def __init__(
:param forecast_history: [description], defaults to 1
:type forecast_history: int, optional
:param no_scale: [description], defaults to True
:type no_scale: bool, optional
:type no_scale: bool, optionals
:param sort_column: [description], defaults to None
:type sort_column: [type], optional
"""
Expand Down Expand Up @@ -556,7 +554,7 @@ def df_to_numpy(pandas_stuff: pd.DataFrame):
return torch.from_numpy(pandas_stuff.to_numpy()).float()

def __getitem__(self, idx):
if self.target_supplied1:
if self.target_supplied:
historical_rows = self.df.iloc[idx: self.forecast_history + idx]
target_idx_start = self.forecast_history + idx
# Why aren't we using these
Expand Down Expand Up @@ -671,7 +669,7 @@ def __init__(self, series_id_col: str, main_params: dict, return_method: str, re
"""
super().__init__(series_id_col, main_params, return_method, return_all)
self.forecast_total = forecast_total
self.csv_test_loaders = [CSVTestLoader(loader_1, 336, kwargs=main_params) for loader_1 in self.listed_vals]
self.csv_test_loaders = [CSVTestLoader(loader_1, 336, **main_params) for loader_1 in self.listed_vals]

def get_from_start_date_all(self, forecast_start: datetime, series_id: int = None):
res = []
Expand Down

0 comments on commit 212360a

Please sign in to comment.