66__all__ = ["NBeatsForecaster" ]
77
88import numpy as np
9- import tensorflow as tf
109from sklearn .utils import check_random_state
1110
1211from aeon .forecasting .base import DirectForecastingMixin
1312from aeon .forecasting .deep_learning .base import BaseDeepForecaster
13+ from aeon .utils .validation ._dependencies import _check_soft_dependencies
14+
15+ if _check_soft_dependencies ("tensorflow" , severity = "none" ):
16+ import tensorflow as tf
17+ from tensorflow .keras .layers import Layer
18+ from tensorflow .keras .utils import Sequence
19+ else :
20+
21+ class Layer :
22+ """Dummy class for soft dependency."""
23+
24+ pass
25+
26+ class Sequence :
27+ """Dummy class for soft dependency."""
28+
29+ pass
1430
1531
1632def smape_loss (y_true , y_pred ):
1733 """Symmetric Mean Absolute Percentage Error (SMAPE) loss."""
34+ import tensorflow as tf
35+
1836 epsilon = 0.1
1937 numerator = tf .abs (y_true - y_pred )
2038 denominator = tf .abs (y_true ) + tf .abs (y_pred ) + epsilon
2139 return 200.0 * tf .reduce_mean (numerator / denominator )
2240
2341
24- class _NBeatsDataGenerator (tf . keras . utils . Sequence ):
42+ class _NBeatsDataGenerator (Sequence ):
2543 """Generates training data batches via random sampling."""
2644
2745 def __init__ (self , y , window , horizon , batch_size , steps_per_epoch , random_state ):
@@ -53,7 +71,7 @@ def __getitem__(self, index):
5371 return np .array (X_batch ), np .array (y_batch )
5472
5573
56- class _NBeatsBlock (tf . keras . layers . Layer ):
74+ class _NBeatsBlock (Layer ):
5775 """N-BEATS basic building block."""
5876
5977 def __init__ (
@@ -67,6 +85,8 @@ def __init__(
6785 ** kwargs ,
6886 ):
6987 super ().__init__ (** kwargs )
88+ import tensorflow as tf
89+
7090 self .stack_type = stack_type
7191 self .input_width = input_width
7292 self .forecast_width = forecast_width
@@ -110,6 +130,8 @@ def _basis_generic(self, theta_b, theta_f):
110130 return backcast , forecast
111131
112132 def _basis_trend (self , theta_b , theta_f ):
133+ import tensorflow as tf
134+
113135 t_b = np .linspace (0 , 1 , self .input_width )
114136 t_f = np .linspace (0 , 1 , self .forecast_width )
115137 degree = self .thetas_dim
@@ -122,6 +144,8 @@ def _basis_trend(self, theta_b, theta_f):
122144 return backcast , forecast
123145
124146 def _basis_seasonality (self , theta_b , theta_f ):
147+ import tensorflow as tf
148+
125149 t_b = np .arange (self .input_width ) / self .input_width
126150 t_f = np .arange (self .forecast_width ) / self .forecast_width
127151 harmonics = self .thetas_dim // 2
@@ -283,6 +307,8 @@ def _resolve_thetas_dim(self, stack_type):
283307
284308 def build_model (self , input_shape ):
285309 """Build the N-BEATS model."""
310+ import tensorflow as tf
311+
286312 inputs = tf .keras .Input (shape = input_shape )
287313 if len (input_shape ) > 1 :
288314 x_res = tf .keras .layers .Flatten ()(inputs )
@@ -350,6 +376,8 @@ def build_model(self, input_shape):
350376
351377 def _fit (self , y , exog = None ):
352378 """Fit the model."""
379+ import tensorflow as tf
380+
353381 rng = check_random_state (self .random_state )
354382 seed = rng .randint (0 , np .iinfo (np .int32 ).max )
355383 tf .keras .utils .set_random_seed (seed )
@@ -395,6 +423,8 @@ def _predict(self, y=None, exog=None):
395423
396424 def predict_decomposition (self , y = None ):
397425 """Predict and decompose the forecast into stack components."""
426+ import tensorflow as tf
427+
398428 X_pred = self ._prepare_input (y )
399429 outputs = []
400430 output_names = []
0 commit comments