From ce9d817a264609da4e17fa708e90b8a7bdcec9f9 Mon Sep 17 00:00:00 2001 From: vovw Date: Tue, 16 Sep 2025 00:59:12 +0530 Subject: [PATCH] Remove pkg_resources usage and replace with modern alternatives - Replace pkg_resources.parse_version() with packaging.version.parse() - Replace pkg_resources.get_distribution() with importlib.metadata.distribution() - Update imports in cebra/integrations/sklearn/cebra.py, cebra/helper.py, tests/test_sklearn.py, tests/test_plot.py - Fixes deprecation warning: pkg_resources is deprecated as an API and slated for removal as early as 2025-11-30 Resolves #271 --- cebra/helper.py | 12 ++++++------ cebra/integrations/sklearn/cebra.py | 4 ++-- tests/test_plot.py | 6 +++--- tests/test_sklearn.py | 6 +++--- 4 files changed, 14 insertions(+), 14 deletions(-) diff --git a/cebra/helper.py b/cebra/helper.py index 93bae2b1..b956d7f0 100644 --- a/cebra/helper.py +++ b/cebra/helper.py @@ -32,7 +32,7 @@ import numpy as np import numpy.typing as npt -import pkg_resources +import packaging.version import requests import torch @@ -75,8 +75,8 @@ def download_file_from_zip_url(url, *, file): def _is_mps_availabe(torch): available = False - if pkg_resources.parse_version( - torch.__version__) >= pkg_resources.parse_version("1.12"): + if packaging.version.parse( + torch.__version__) >= packaging.version.parse("1.12"): if torch.backends.mps.is_available(): if torch.backends.mps.is_built(): available = True @@ -159,17 +159,17 @@ def requires_package_version(module, version: str): the required ``version``. """ - required_version = pkg_resources.parse_version(version) + required_version = packaging.version.parse(version) def _requires_package_version(function): @wraps(function) def wrapper(*args, patched_version=None, **kwargs): if patched_version is not None: - installed_version = pkg_resources.parse_version( + installed_version = packaging.version.parse( patched_version) # Use the patched version if provided else: - installed_version = pkg_resources.parse_version( + installed_version = packaging.version.parse( module.__version__) if installed_version < required_version: diff --git a/cebra/integrations/sklearn/cebra.py b/cebra/integrations/sklearn/cebra.py index 98e56747..0f056b45 100644 --- a/cebra/integrations/sklearn/cebra.py +++ b/cebra/integrations/sklearn/cebra.py @@ -28,7 +28,7 @@ import numpy as np import numpy.typing as npt import packaging.version -import pkg_resources +import importlib.metadata import sklearn import sklearn.utils.validation as sklearn_utils_validation import torch @@ -1397,7 +1397,7 @@ def save(self, 'numpy_version': np.__version__, 'sklearn_version': - pkg_resources.get_distribution("scikit-learn" + importlib.metadata.distribution("scikit-learn" ).version } }, filename) diff --git a/tests/test_plot.py b/tests/test_plot.py index 1d94d310..2bafa42b 100644 --- a/tests/test_plot.py +++ b/tests/test_plot.py @@ -25,7 +25,7 @@ import matplotlib import matplotlib.pyplot as plt import numpy as np -import pkg_resources +import packaging.version import pytest import torch from sklearn.exceptions import NotFittedError @@ -190,8 +190,8 @@ def test_compare_models_with_different_versions(matplotlib_version): # minimum version of matplotlib minimum_version = "3.6" - if pkg_resources.parse_version( - matplotlib_version) < pkg_resources.parse_version(minimum_version): + if packaging.version.parse( + matplotlib_version) < packaging.version.parse(minimum_version): with pytest.raises(ImportError): cebra_plot.compare_models(models=fitted_models, patched_version=matplotlib_version) diff --git a/tests/test_sklearn.py b/tests/test_sklearn.py index c3d2095c..63bbbab9 100644 --- a/tests/test_sklearn.py +++ b/tests/test_sklearn.py @@ -26,7 +26,7 @@ import _util import _utils_deprecated import numpy as np -import pkg_resources +import packaging.version import pytest import sklearn.utils.estimator_checks import torch @@ -1320,8 +1320,8 @@ def test_check_device(): with pytest.raises(ValueError): cebra_sklearn_utils.check_device(device) - if pkg_resources.parse_version( - torch.__version__) >= pkg_resources.parse_version("1.12"): + if packaging.version.parse( + torch.__version__) >= packaging.version.parse("1.12"): device = "mps" torch.backends.mps.is_available = lambda: True