Skip to content

Commit 175ba6a

Browse files
committed
Fix N-BEATS soft dependency imports for CI
1 parent 4f31786 commit 175ba6a

File tree

1 file changed

+33
-3
lines changed

1 file changed

+33
-3
lines changed

aeon/forecasting/deep_learning/_nbeats.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,22 +6,40 @@
66
__all__ = ["NBeatsForecaster"]
77

88
import numpy as np
9-
import tensorflow as tf
109
from sklearn.utils import check_random_state
1110

1211
from aeon.forecasting.base import DirectForecastingMixin
1312
from aeon.forecasting.deep_learning.base import BaseDeepForecaster
13+
from aeon.utils.validation._dependencies import _check_soft_dependencies
14+
15+
if _check_soft_dependencies("tensorflow", severity="none"):
16+
import tensorflow as tf
17+
from tensorflow.keras.layers import Layer
18+
from tensorflow.keras.utils import Sequence
19+
else:
20+
21+
class Layer:
22+
"""Dummy class for soft dependency."""
23+
24+
pass
25+
26+
class Sequence:
27+
"""Dummy class for soft dependency."""
28+
29+
pass
1430

1531

1632
def smape_loss(y_true, y_pred):
1733
"""Symmetric Mean Absolute Percentage Error (SMAPE) loss."""
34+
import tensorflow as tf
35+
1836
epsilon = 0.1
1937
numerator = tf.abs(y_true - y_pred)
2038
denominator = tf.abs(y_true) + tf.abs(y_pred) + epsilon
2139
return 200.0 * tf.reduce_mean(numerator / denominator)
2240

2341

24-
class _NBeatsDataGenerator(tf.keras.utils.Sequence):
42+
class _NBeatsDataGenerator(Sequence):
2543
"""Generates training data batches via random sampling."""
2644

2745
def __init__(self, y, window, horizon, batch_size, steps_per_epoch, random_state):
@@ -53,7 +71,7 @@ def __getitem__(self, index):
5371
return np.array(X_batch), np.array(y_batch)
5472

5573

56-
class _NBeatsBlock(tf.keras.layers.Layer):
74+
class _NBeatsBlock(Layer):
5775
"""N-BEATS basic building block."""
5876

5977
def __init__(
@@ -67,6 +85,8 @@ def __init__(
6785
**kwargs,
6886
):
6987
super().__init__(**kwargs)
88+
import tensorflow as tf
89+
7090
self.stack_type = stack_type
7191
self.input_width = input_width
7292
self.forecast_width = forecast_width
@@ -110,6 +130,8 @@ def _basis_generic(self, theta_b, theta_f):
110130
return backcast, forecast
111131

112132
def _basis_trend(self, theta_b, theta_f):
133+
import tensorflow as tf
134+
113135
t_b = np.linspace(0, 1, self.input_width)
114136
t_f = np.linspace(0, 1, self.forecast_width)
115137
degree = self.thetas_dim
@@ -122,6 +144,8 @@ def _basis_trend(self, theta_b, theta_f):
122144
return backcast, forecast
123145

124146
def _basis_seasonality(self, theta_b, theta_f):
147+
import tensorflow as tf
148+
125149
t_b = np.arange(self.input_width) / self.input_width
126150
t_f = np.arange(self.forecast_width) / self.forecast_width
127151
harmonics = self.thetas_dim // 2
@@ -283,6 +307,8 @@ def _resolve_thetas_dim(self, stack_type):
283307

284308
def build_model(self, input_shape):
285309
"""Build the N-BEATS model."""
310+
import tensorflow as tf
311+
286312
inputs = tf.keras.Input(shape=input_shape)
287313
if len(input_shape) > 1:
288314
x_res = tf.keras.layers.Flatten()(inputs)
@@ -350,6 +376,8 @@ def build_model(self, input_shape):
350376

351377
def _fit(self, y, exog=None):
352378
"""Fit the model."""
379+
import tensorflow as tf
380+
353381
rng = check_random_state(self.random_state)
354382
seed = rng.randint(0, np.iinfo(np.int32).max)
355383
tf.keras.utils.set_random_seed(seed)
@@ -395,6 +423,8 @@ def _predict(self, y=None, exog=None):
395423

396424
def predict_decomposition(self, y=None):
397425
"""Predict and decompose the forecast into stack components."""
426+
import tensorflow as tf
427+
398428
X_pred = self._prepare_input(y)
399429
outputs = []
400430
output_names = []

0 commit comments

Comments
 (0)