Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FIX] Multivariate models give error when predicting when n_series > batch_size #1276

Merged

Conversation

elephaint
Copy link
Contributor

@elephaint elephaint commented Feb 26, 2025

Multivariate models give an error when predicting when n_series > batch_size. I covered this for training, but not for predicting. Subtle bug, easy fix.

The below code fails without the fix, and runs correctly with it:

import pandas as pd
import matplotlib.pyplot as plt

from neuralforecast import NeuralForecast
from neuralforecast.models import TSMixerx, NHITS
from neuralforecast.utils import generate_series

N_SERIES = 1025
FREQ= 'D'
df = generate_series(n_series=N_SERIES, seed=0, freq=FREQ, equal_ends=True)
max_ds = df.ds.max() - pd.Timedelta(14, FREQ)
Y_TRAIN_DF = df[df.ds < max_ds]
Y_TEST_DF = df[df.ds >= max_ds]

models = [TSMixerx(h=12,
                input_size=24,
                n_series=N_SERIES,
                max_steps=10,
                batch_size=10,
                revin=True,
                valid_batch_size=10,
                ),
            NHITS(h=12,
                input_size=24,
                max_steps=10,
                batch_size=10,
                valid_batch_size=100,
                )
            ]
        

fcst = NeuralForecast(models=models, freq=FREQ)
fcst.fit(df=Y_TRAIN_DF)
forecasts = fcst.predict(futr_df=Y_TEST_DF)

Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@elephaint elephaint requested a review from marcopeix February 26, 2025 15:58
@elephaint elephaint marked this pull request as draft February 26, 2025 16:17
@elephaint elephaint marked this pull request as ready for review February 26, 2025 16:31
Copy link
Contributor

@marcopeix marcopeix left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@marcopeix marcopeix merged commit bb315be into main Feb 26, 2025
17 checks passed
@marcopeix marcopeix deleted the fix/multivariate_batch_size_lower_than_n_series_predict_error branch February 26, 2025 19:27
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants