From aea9060a5ef46e82e5c87c1423f4a5cbdc440ea8 Mon Sep 17 00:00:00 2001 From: Eden Rochman Date: Fri, 27 Mar 2026 13:59:53 +0000 Subject: [PATCH] Fix TCN predict crash on single-sample last batch When the final batch has a single sample, model output becomes a 0-d numpy array after .cpu().numpy(). np.concatenate cannot mix 0-d and 1-d arrays, causing a crash. Wrap with np.atleast_1d() before appending to preds list. Fixes #1752 --- qlib/contrib/model/pytorch_tcn_ts.py | 2 +- tests/test_tcn_single_sample_batch.py | 39 +++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) create mode 100644 tests/test_tcn_single_sample_batch.py diff --git a/qlib/contrib/model/pytorch_tcn_ts.py b/qlib/contrib/model/pytorch_tcn_ts.py index a6cc38885c3..f39b097cfe6 100755 --- a/qlib/contrib/model/pytorch_tcn_ts.py +++ b/qlib/contrib/model/pytorch_tcn_ts.py @@ -279,7 +279,7 @@ def predict(self, dataset): with torch.no_grad(): pred = self.TCN_model(feature.float()).detach().cpu().numpy() - preds.append(pred) + preds.append(np.atleast_1d(pred)) return pd.Series(np.concatenate(preds), index=dl_test.get_index()) diff --git a/tests/test_tcn_single_sample_batch.py b/tests/test_tcn_single_sample_batch.py new file mode 100644 index 00000000000..c34e50db3bd --- /dev/null +++ b/tests/test_tcn_single_sample_batch.py @@ -0,0 +1,39 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Tests for TCN single-sample batch fix (Fixes #1752). + +When the last batch contains a single sample, the model output after +.cpu().numpy() may be a 0-d array. np.concatenate then fails because +it cannot concatenate 0-d arrays with 1-d arrays. Wrapping with +np.atleast_1d ensures all arrays are at least 1-d. +""" + +import numpy as np +import pytest + + +def test_concatenate_mixed_0d_and_1d(): + """np.concatenate fails with raw 0-d + 1-d arrays, but works with atleast_1d.""" + arr_1d = np.array([1.0, 2.0, 3.0]) + arr_0d = np.float64(4.0) # simulates single-sample .numpy() result + + # Without fix this would raise + with pytest.raises((ValueError, np.exceptions.AxisError)): + np.concatenate([arr_1d, arr_0d]) + + # With atleast_1d it works + result = np.concatenate([np.atleast_1d(arr_1d), np.atleast_1d(arr_0d)]) + np.testing.assert_array_equal(result, [1.0, 2.0, 3.0, 4.0]) + + +def test_atleast_1d_preserves_normal_arrays(): + """atleast_1d should be a no-op for arrays that are already >= 1-d.""" + arr = np.array([5.0, 6.0]) + result = np.atleast_1d(arr) + np.testing.assert_array_equal(result, arr) + assert result.ndim >= 1 + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])