reformed methods; open question: how to get proper test batches
leostre committed May 13, 2024
1 parent 07b2f10 commit a8da9e6
Showing 2 changed files with 84 additions and 32 deletions.
111 changes: 80 additions & 31 deletions fedot_ind/core/models/nn/network_impl/deepar.py
@@ -17,12 +17,28 @@
from fedot.core.repository.dataset_types import DataTypesEnum
from fedot.core.operations.evaluation.operation_implementations.data_operations.ts_transformations import \
transform_features_and_target_into_lagged
from fedot_ind.core.operation.transformation.data.hankel import HankelMatrix
from fedot_ind.core.architecture.preprocessing.data_convertor import DataConverter
import torch.utils.data as data
from fedot_ind.core.architecture.settings.computational import default_device
import torch.optim.lr_scheduler as lr_scheduler
from fedot.core.data.data_split import train_test_data_setup


class _TSScaler(Module):
def __init__(self):
super().__init__()
self.factors = None
self.eps = 1e-10

def forward(self, x, normalize=True):
if normalize:
self.means = x.mean(dim=-1, keepdim=True)
self.factors = torch.sqrt(x.std(dim=-1, keepdim=True,
unbiased=False)) + self.eps
return (x - self.means) / self.factors
else:
return x * self.factors + self.means
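
For context, a minimal sketch (not part of the commit) of how the new _TSScaler behaves: normalize=True standardizes each series and stores its per-series mean and scale factor, while normalize=False inverts the transform with the stored statistics, so a round trip recovers the input.

import torch

scaler = _TSScaler()
x = torch.randn(8, 1, 120)            # (n_series, n_features, series_length)
z = scaler(x, normalize=True)         # standardized series
x_back = scaler(z, normalize=False)   # inverse transform with the stored mean/factor
assert torch.allclose(x, x_back, atol=1e-5)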

class DeepARModule(Module):
_loss_fns = {
@@ -39,11 +55,12 @@ def __init__(self, cell_type, input_size, hidden_size, rnn_layers, dropout, dist
dropout = dropout if rnn_layers > 1 else 0.
)
self.hidden_size = hidden_size
self.scaler = RevIN(
affine=False,
input_dim=input_size,
dim=-1, # -1 in case series-wise normalization, 0 for batch-wise, RNN needs series_wise
)
self.scaler = _TSScaler()
# self.scaler = RevIN(
# affine=False,
# input_dim=input_size,
# dim=-1, # -1 in case series-wise normalization, 0 for batch-wise, RNN needs series_wise
# )
self.distribution = self._loss_fns[distribution]
if distribution is not None:
self.projector = Linear(self.hidden_size, len(self.distribution.distribution_arguments))
@@ -62,19 +79,20 @@ def encode(self, ts: torch.Tensor):
return hidden_state

def _decode_whole_seq(self, ts: torch.Tensor, hidden_state: torch.Tensor):
""" used for next value predition"""
output, hidden_state = self.rnn(
ts, hidden_state
)
output = self.projector(output)
return output, hidden_state


def forward(self, x: torch.Tensor, n_samples: int = None, mode='raw'):
"""
Forward pass
x.size == (nseries, length)
"""
x = self.scaler(x, mode=True)
x = self.scaler(x, normalize=True)
hidden_state = self.encode(x)
# decode

@@ -101,6 +119,7 @@ def to_predictions(self, params: torch.Tensor):
return distr.sample((1,)).T.squeeze() # distr_n x 1
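
As a quick shape check (illustrative only, using a plain torch Normal rather than the project's loss class), sample((1,)).T.squeeze() turns one draw from a batch of distributions into a flat vector of predictions:

import torch
from torch import distributions

d = distributions.Normal(loc=torch.zeros(5), scale=torch.ones(5))
s = d.sample((1,))   # shape (1, 5): one draw per distribution in the batch
s.T.squeeze()        # shape (5,): the flat prediction vector returned above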

def _transform_params(self, distr_params, mode='raw'):
# factors =
if mode == 'raw':
transformed = distr_params
elif mode == 'quantiles':
@@ -122,13 +141,11 @@ def predict(self, test_x: torch.Tensor, mode=None):
def decode(self, x, hidden_state=None, n_samples=0, mode='raw'):
if hidden_state is None:
hidden_state = torch.zeros((self.hidden_size,)).float()

if not n_samples:
output, _ = self._decode_whole_seq(x, hidden_state)
output = self._transform_params(output, mode=mode)
else:
x = x.repeat_interleave(n_samples, 0)
hidden_state = self.rnn.repeat_interleave(hidden_state, n_samples)

# make predictions which are fed into next step
output = self.decode_autoregressive(
first_target=x[:, 0],
@@ -137,17 +154,16 @@ def decode(self, x, hidden_state=None, n_samples=0, mode='raw'):
n_decoder_steps=x.size(1),
n_samples=n_samples,
)

return output
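
The sampling branch above relies on repeat_interleave to lay out its Monte Carlo copies; a minimal illustration (not from the repo) of why that layout is convenient:

import torch

x = torch.arange(3).unsqueeze(-1)   # tensor([[0], [1], [2]])
x.repeat_interleave(2, 0)           # tensor([[0], [0], [1], [1], [2], [2]])
# copies of each series stay adjacent, so sample s of series b sits at row
# b * n_samples + s and the result can be reshaped back to (batch, n_samples, ...)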


def _decode_one(self, x,
idx,
hidden_state,
):
x = x[:, [idx], ...]
x = x[..., [idx]]
prediction, hidden_state = self._decode_whole_seq(x, hidden_state)
prediction = prediction[:, 0] # select first time step for this index
prediction = prediction[:, [0], ...] # select first time step for this index
return prediction, hidden_state

def decode_autoregressive(
@@ -217,6 +233,9 @@ def __init__(self, params: Optional[OperationParameters] = {}):
self.forecast_mode = params.get('forecast_mode', 'raw')
self.quantiles = params.get('quantiles', None)

self.test_patch_len = None



def _init_model(self, ts) -> tuple:
self.loss_fn = DeepARModule._loss_fns[self.expected_distribution]()
@@ -238,10 +257,18 @@ def _init_model(self, ts) -> tuple:

return self.loss_fn, self.optimizer

def fit(self, input_data: InputData):
self._fit_model(input_data)

def _fit_model(self, input_data: InputData, split_data: bool = False):
def fit(self, input_data: InputData, split_data: bool = False):
train_loader, val_loader = self._prepare_data(input_data, split_data=split_data)
loss_fn, optimizer = self._init_model(input_data)
self._train_loop(model=self.model,
train_loader=train_loader,
loss_fn=loss_fn,
optimizer=optimizer,
val_loader=val_loader,
)
return self

def _prepare_data(self, input_data: InputData, split_data):
val_loader = None
if self.preprocess_to_lagged:
self.patch_len = input_data.features.shape[-1]
@@ -251,19 +278,14 @@ def _fit_model(self, input_data: InputData, split_data: bool = False):
dominant_window_size = WindowSizeSelector(
method='dff').get_window_size(input_data.features)
self.patch_len = 2 * dominant_window_size
train_loader, val_loader = self._prepare_data(
train_loader, val_loader = self._get_train_val_loaders(
input_data.features, self.patch_len, split_data)

self.test_patch_len = self.patch_len
loss_fn, optimizer = self._init_model(input_data)
return self._train_loop(model=self.model,
train_loader=train_loader,
loss_fn=loss_fn,
optimizer=optimizer,
val_loader=val_loader,
)
return train_loader, val_loader
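
For orientation, a minimal numpy sketch (hypothetical helper, not the repo's lagged transformation) of what the patch_len windowing amounts to: the raw series is cut into overlapping windows of length patch_len, each followed by a horizon-long target.

import numpy as np

def sliding_windows(series: np.ndarray, patch_len: int, horizon: int):
    """Cut a 1-D series into (window, target) pairs of fixed length."""
    xs, ys = [], []
    for start in range(len(series) - patch_len - horizon + 1):
        xs.append(series[start:start + patch_len])
        ys.append(series[start + patch_len:start + patch_len + horizon])
    return np.stack(xs), np.stack(ys)

x, y = sliding_windows(np.arange(100, dtype=float), patch_len=24, horizon=6)
# x.shape == (71, 24), y.shape == (71, 6)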

def _predict(self, test_loader, output_mode):

def _predict_loop(self, test_loader, output_mode):
model = self.model # or model for inference?
output = model.predict(test_loader, output_mode)

@@ -329,7 +351,7 @@ def _train_loop(self, model,
iter_count += 1
optimizer.zero_grad()
batch_x = batch_x.float().to(default_device())
batch_y = batch_y.float().to(default_device())
batch_y = batch_y[:, ..., [0]].float().to(default_device()) # keep only the first target feature
outputs, *hidden_state = model(batch_x)
# return batch_x, outputs, batch_y

@@ -368,7 +390,7 @@ def _train_loop(self, model,
scheduler.get_last_lr()[0]))
return best_model

def _predict_loop(self, test_loader):
def __predict_loop(self, test_loader):
outputs = []
with torch.no_grad():
for x_test in test_loader:
@@ -385,7 +407,7 @@ def _create_dataset(self,
freq: int = 1):
return ts

def _prepare_data(self,
def _get_train_val_loaders(self,
ts,
patch_len=None,
split_data: bool = True,
@@ -419,10 +441,12 @@ def __ts_to_input_data(self, input_data: Union[InputData, pd.DataFrame]):
if isinstance(input_data, InputData):
return input_data

if not isinstance(input_data, pd.DataFrame):
time_series = pd.DataFrame(input_data)
task = Task(TaskTypesEnum.ts_forecasting,
TsForecastingParams(forecast_length=self.horizon))

if not isinstance(input_data, pd.DataFrame):
time_series = pd.DataFrame(input_data)

if 'datetime' in time_series.columns:
idx = pd.to_datetime(time_series['datetime'].values)
else:
@@ -445,4 +469,29 @@ def __create_torch_loader(self, train_data):
batch_size=self.batch_size, shuffle=False)
return train_loader


def _get_test_loader(self,
test_data: Union[InputData, torch.Tensor]):
test_data = self.__ts_to_input_data(test_data)
if len(test_data.features.shape) == 1:
test_data.features = test_data.features[None, ...]

if not self.preprocess_to_lagged:
features = HankelMatrix(time_series=test_data.features,
window_size=self.test_patch_len or self.patch_len).trajectory_matrix
features = torch.from_numpy(DataConverter(data=features).
convert_to_torch_format()).float().permute(2, 1, 0)
target = torch.from_numpy(DataConverter(
data=features).convert_to_torch_format()).float()
else:
features = test_data.features
features = torch.from_numpy(DataConverter(data=features).
convert_to_torch_format()).float()
target = torch.from_numpy(DataConverter(
data=features).convert_to_torch_format()).float()

test_loader = torch.utils.data.DataLoader(data.TensorDataset(features, target),
batch_size=self.batch_size, shuffle=False)
return test_loader
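
A small numpy sketch (not the repo's HankelMatrix) of the trajectory matrix built for the non-lagged test path: column j holds the window of the series starting at step j, so batches of such columns become fixed-size test inputs.

import numpy as np

def trajectory_matrix(series: np.ndarray, window_size: int) -> np.ndarray:
    """Hankel/trajectory matrix: column j is series[j : j + window_size]."""
    n_cols = len(series) - window_size + 1
    return np.stack([series[j:j + window_size] for j in range(n_cols)], axis=1)

tm = trajectory_matrix(np.arange(10, dtype=float), window_size=4)
# tm.shape == (4, 7); tm[:, 0] == [0., 1., 2., 3.], tm[:, 1] == [1., 2., 3., 4.]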


5 changes: 4 additions & 1 deletion fedot_ind/core/models/nn/network_modules/losses.py
@@ -258,6 +258,8 @@ class DistributionLoss(nn.Module):
distribution_arguments: List[str]
quantiles: List[float] = [.05, .25, .5, .75, .95]
need_affine=True
scale_dependent_idx = tuple()
loc_dependent_idx = tuple()

def __init__(
self, reduction="mean",
@@ -312,11 +314,12 @@ class NormalDistributionLoss(DistributionLoss):

distribution_class = distributions.Normal
distribution_arguments = ["loc", "scale"]
scale_dependent_idx = (1,)
loc_dependent_idx = (0,)
need_affine=False

@classmethod
def _map_x_to_distribution(self, x: torch.Tensor) -> distributions.Normal:
assert isinstance(x, torch.Tensor), 'x must be tensor!'
loc = x[..., -2]
scale = F.softplus(x[..., -1])
distr = self.distribution_class(loc=loc, scale=scale)
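A usage sketch of the mapping above (assumptions: the projector emits two raw values per step, and only the sampling path is shown): softplus keeps the scale positive, after which the Normal can be sampled for Monte Carlo forecasts.

import torch
import torch.nn.functional as F
from torch import distributions

raw = torch.randn(16, 30, 2)                       # (batch, steps, [loc, scale]) raw projector output
loc, scale = raw[..., -2], F.softplus(raw[..., -1])
distr = distributions.Normal(loc=loc, scale=scale)
samples = distr.sample((100,))                     # (100, 16, 30) Monte Carlo draws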
