NBeatsNetwork
- class NBeatsNetwork(horizon=1, stacks=None, num_blocks_per_stack=2, units=30, num_trend_coefficients=3, num_seasonal_coefficients=5, num_generic_coefficients=7, share_weights=True, share_coefficients=True)
Bases: BaseDeepLearningNetwork
Implementation of the N-BEATS network architecture.
N-BEATS (Neural Basis Expansion Analysis for Interpretable Time Series Forecasting) is a deep learning model for univariate time series forecasting. It uses a backward and forward residual connection structure and is composed of stacks of fully connected layers.
The architecture consists of multiple stacks, each containing several blocks. Each block generates a forecast and a backcast. The backcast is subtracted from the block’s input, and the residual is passed to the next block. The forecasts from all blocks are summed up to produce the final prediction.
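The residual flow described above can be sketched in a few lines of NumPy. This is an illustration only, not the library's code: `run_blocks` and `make_toy_block` are hypothetical names, and the toy block stands in for the fully connected layers of a real block.

```python
import numpy as np

def run_blocks(x, blocks):
    """Pass a window through blocks, subtracting backcasts and summing forecasts."""
    residual = np.asarray(x, dtype=float)
    forecast = None
    for block in blocks:
        backcast, block_forecast = block(residual)
        # The next block only sees what this block failed to explain.
        residual = residual - backcast
        # The final prediction is the sum of all block forecasts.
        forecast = block_forecast if forecast is None else forecast + block_forecast
    return residual, forecast

def make_toy_block(fraction, horizon):
    """Hypothetical block: explains a fixed fraction of its input."""
    def block(residual):
        backcast = fraction * residual
        block_forecast = np.full(horizon, backcast.mean())
        return backcast, block_forecast
    return block

window = np.array([2.0, 4.0, 6.0, 8.0])
blocks = [make_toy_block(0.5, horizon=2), make_toy_block(0.5, horizon=2)]
residual, forecast = run_blocks(window, blocks)
```

Each toy block removes half of its input, so the residual shrinks from block to block while the forecast accumulates.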
This implementation is based on the paper by Oreshkin et al. [1].
- Parameters:
- horizon : int, default=1
The length of the forecast horizon.
- stacks : list of str, default=None
A list of stack types. Allowed types are "trend", "seasonality", and "generic". If None, the stacks default to ["trend", "seasonality"].
- num_blocks_per_stack : int, default=2
The number of blocks within each stack.
- units : int, default=30
The number of hidden units in the fully connected layers of each block.
- num_trend_coefficients : int, default=3
The number of polynomial coefficients for the trend block.
- num_seasonal_coefficients : int, default=5
The number of Fourier coefficients for the seasonality block.
- num_generic_coefficients : int, default=7
The number of coefficients for the generic block.
- share_weights : bool, default=True
If True, weights of the fully connected layers are shared across all blocks within a stack.
- share_coefficients : bool, default=True
If True, the backcast and forecast of each block share the same basis expansion coefficients.
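To make the coefficient parameters concrete, here is a minimal NumPy sketch of polynomial and Fourier basis expansions in the spirit of [1]. The helper names, the time grid on [0, 1), and the ordering of the Fourier terms are assumptions made for this example; the class may construct its bases differently.

```python
import numpy as np

def trend_basis(num_coefficients, length):
    """Rows are t**0, t**1, ... evaluated on a grid t in [0, 1)."""
    t = np.arange(length) / length
    return np.stack([t**p for p in range(num_coefficients)])

def seasonality_basis(num_coefficients, length):
    """Rows are a constant, then cos/sin pairs of increasing frequency.

    The ordering of the harmonics is an assumption for this sketch.
    """
    t = np.arange(length) / length
    rows = [np.ones(length)]
    k = 1
    while len(rows) < num_coefficients:
        rows.append(np.cos(2 * np.pi * k * t))
        if len(rows) < num_coefficients:
            rows.append(np.sin(2 * np.pi * k * t))
        k += 1
    return np.stack(rows)

backcast_length, horizon = 8, 4      # hypothetical window sizes
theta = np.array([1.0, 2.0, 0.0])    # one value per trend coefficient (3)

# Reusing theta for both outputs mirrors the idea behind share_coefficients=True.
trend_backcast = theta @ trend_basis(3, backcast_length)
trend_forecast = theta @ trend_basis(3, horizon)
seasonal = seasonality_basis(5, horizon)
```

In a real block the coefficient vector `theta` is produced by the fully connected layers; here it is fixed by hand so that the projection onto the basis is easy to follow.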
References
[1] Oreshkin, B. N., Carpov, D., Chapados, N., & Bengio, Y. (2019). N-BEATS: Neural basis expansion analysis for interpretable time series forecasting. arXiv preprint arXiv:1905.10437.
Methods
- build_network(input_shape, **kwargs)
Build the N-BEATS network.