NBeatsNetwork

class NBeatsNetwork(horizon=1, stacks=None, num_blocks_per_stack=2, units=30, num_trend_coefficients=3, num_seasonal_coefficients=5, num_generic_coefficients=7, share_weights=True, share_coefficients=True)

Bases: BaseDeepLearningNetwork

Implementation of the N-BEATS network architecture.

N-BEATS (Neural Basis Expansion Analysis for Interpretable Time Series Forecasting) is a deep learning model for univariate time series forecasting. It is built from stacks of fully connected blocks that are linked by backward and forward residual connections, an arrangement the original paper calls doubly residual stacking.

The architecture consists of multiple stacks, each containing several blocks. Each block produces a backcast (its estimate of its own input) and a forecast. The backcast is subtracted from the block’s input, and the resulting residual is passed to the next block. The forecasts of all blocks are summed to produce the final prediction.
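The doubly residual loop described above can be sketched in NumPy. This is a toy illustration, not this class's implementation: the hypothetical helper `make_toy_block` builds a single random linear map that stands in for a block's fully connected layers and basis expansion, and the sketch processes one 1-D input window rather than a batch.

```python
import numpy as np


def make_toy_block(rng, n_timepoints, horizon):
    """Hypothetical stand-in for one N-BEATS block.

    Random linear maps replace the block's fully connected layers and
    basis expansion; only the (backcast, forecast) interface matters here.
    """
    w_back = rng.normal(scale=0.1, size=(n_timepoints, n_timepoints))
    w_fore = rng.normal(scale=0.1, size=(horizon, n_timepoints))

    def block(x):
        # x: input window of shape (n_timepoints,)
        return w_back @ x, w_fore @ x  # (backcast, forecast)

    return block


def doubly_residual_forward(x, blocks, horizon):
    """Feed each block the residual left by the previous one and sum the forecasts."""
    residual = np.asarray(x, dtype=float)
    forecast = np.zeros(horizon)
    for block in blocks:
        backcast, block_forecast = block(residual)
        residual = residual - backcast        # the part this block did not explain
        forecast = forecast + block_forecast  # partial forecasts are summed
    return forecast, residual


rng = np.random.default_rng(0)
n_timepoints, horizon = 12, 4
# e.g. 3 stacks of 2 blocks each, run in order as one sequence of blocks
blocks = [make_toy_block(rng, n_timepoints, horizon) for _ in range(6)]
x = rng.normal(size=n_timepoints)
y_hat, final_residual = doubly_residual_forward(x, blocks, horizon)
print(y_hat.shape)  # (4,)
```

Keeping the running residual explicit shows why later blocks only model what earlier blocks left unexplained, while the final prediction is simply the sum of the per-block forecasts.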

This implementation is based on the paper by Oreshkin et al. [1].

Parameters:
horizon : int, default=1

The length of the forecast horizon.

stacks : list of str or None, default=None

A list of stack types, one entry per stack. Allowed types are “trend”, “seasonality”, and “generic”. If None, [“trend”, “seasonality”] is used.

num_blocks_per_stack : int, default=2

The number of blocks within each stack.

units : int, default=30

The number of hidden units in the fully connected layers of each block.

num_trend_coefficients : int, default=3

The number of polynomial coefficients for the trend block.

num_seasonal_coefficients : int, default=5

The number of Fourier coefficients for the seasonality block.

num_generic_coefficients : int, default=7

The number of coefficients for the generic block.

share_weights : bool, default=True

If True, weights of the fully connected layers are shared across all blocks within a stack.

share_coefficients : bool, default=True

If True, the backcast and forecast of each block share the same basis expansion coefficients.
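The basis expansions controlled by num_trend_coefficients, num_seasonal_coefficients and num_generic_coefficients can be sketched in NumPy. The trend (polynomial) and seasonality (Fourier) bases follow the description in [1], with time normalized as t = i / length. How this class maps a coefficient count onto Fourier harmonics is not stated here, so the ordering below (a constant term, then cosine/sine pairs of increasing frequency) is an assumption, as are the helper names `trend_basis`, `seasonality_basis` and `project`. The generic block uses a random matrix in place of its learned projection, and share_coefficients=True is modelled by passing the same coefficient vector for backcast and forecast.

```python
import numpy as np


def trend_basis(num_coefficients, length):
    """Polynomial basis: rows t**0, t**1, ... with t = [0, 1, ..., length-1] / length."""
    t = np.arange(length) / length
    return np.stack([t ** p for p in range(num_coefficients)])


def seasonality_basis(num_coefficients, length):
    """Fourier basis (assumed ordering): 1, cos(2*pi*t), sin(2*pi*t), cos(4*pi*t), ..."""
    t = np.arange(length) / length
    rows = [np.ones(length)]
    k = 1
    while len(rows) < num_coefficients:
        rows.append(np.cos(2 * np.pi * k * t))
        if len(rows) < num_coefficients:
            rows.append(np.sin(2 * np.pi * k * t))
        k += 1
    return np.stack(rows)


def project(theta_back, theta_fore, basis_fn, n_timepoints, horizon):
    """Expand coefficients onto the backcast and forecast grids: output = theta @ basis."""
    k = len(theta_fore)
    backcast = theta_back @ basis_fn(k, n_timepoints)
    forecast = theta_fore @ basis_fn(k, horizon)
    return backcast, forecast


rng = np.random.default_rng(0)
n_timepoints, horizon = 12, 4

# Trend block, 3 coefficients, shared between backcast and forecast
theta = np.array([1.0, 2.0, 0.0])
backcast, forecast = project(theta, theta, trend_basis, n_timepoints, horizon)
# forecast = 1 + 2t at t = [0, 0.25, 0.5, 0.75], i.e. 1.0, 1.5, 2.0, 2.5

# Seasonality block, 5 coefficients, separate backcast/forecast coefficients
theta_b, theta_f = rng.normal(size=5), rng.normal(size=5)
s_back, s_fore = project(theta_b, theta_f, seasonality_basis, n_timepoints, horizon)

# Generic block, 7 coefficients: a random matrix stands in for learned weights
generic_weights = rng.normal(size=(7, horizon))
g_fore = rng.normal(size=7) @ generic_weights
```

The trend and seasonality bases are fixed, so their coefficients are directly interpretable; the generic basis is learned, which makes it more flexible but not interpretable in the same way.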

References

[1] Oreshkin, B. N., Carpov, D., Chapados, N., & Bengio, Y. (2019). N-BEATS: Neural basis expansion analysis for interpretable time series forecasting. arXiv preprint arXiv:1905.10437.

Methods


build_network(input_shape, **kwargs)

Build the N-BEATS network.

Parameters:
input_shape : tuple

Shape of the input data (n_timepoints, n_channels).

**kwargs : dict

Additional keyword arguments (unused).

Returns:
tuple

(input_layer, output_layer) representing the network.
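input_shape is read as (n_timepoints, n_channels). The sketch below shows one way such a shape can arise from a univariate series by slicing it into lookback windows and horizon-length targets. The helper `make_windows` is hypothetical and not part of this class, and treating univariate input as n_channels=1 is an assumption.

```python
import numpy as np


def make_windows(series, n_timepoints, horizon):
    """Hypothetical helper: slice a 1-D series into (lookback, target) pairs.

    Returns X with shape (n_windows, n_timepoints, 1), matching an
    input_shape of (n_timepoints, n_channels) with n_channels=1, and
    y with shape (n_windows, horizon).
    """
    series = np.asarray(series, dtype=float)
    n_windows = len(series) - n_timepoints - horizon + 1
    if n_windows < 1:
        raise ValueError("series is too short for this lookback and horizon")
    X = np.stack([series[i : i + n_timepoints] for i in range(n_windows)])
    y = np.stack(
        [series[i + n_timepoints : i + n_timepoints + horizon] for i in range(n_windows)]
    )
    return X[:, :, np.newaxis], y  # add a trailing channel axis to X


X, y = make_windows(np.arange(20), n_timepoints=12, horizon=4)
input_shape = X.shape[1:]
print(input_shape)       # (12, 1)
print(X.shape, y.shape)  # (5, 12, 1) (5, 4)
```

Each target row holds the horizon values that immediately follow its input window, so a series of length 20 with a lookback of 12 and a horizon of 4 yields 20 - 12 - 4 + 1 = 5 training pairs.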