Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions cebra/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

import numpy as np
import numpy.typing as npt
import pkg_resources
import packaging.version
import requests
import torch

Expand Down Expand Up @@ -75,8 +75,8 @@ def download_file_from_zip_url(url, *, file):

def _is_mps_availabe(torch):
available = False
if pkg_resources.parse_version(
torch.__version__) >= pkg_resources.parse_version("1.12"):
if packaging.version.parse(
torch.__version__) >= packaging.version.parse("1.12"):
if torch.backends.mps.is_available():
if torch.backends.mps.is_built():
available = True
Expand Down Expand Up @@ -159,17 +159,17 @@ def requires_package_version(module, version: str):
the required ``version``.
"""

required_version = pkg_resources.parse_version(version)
required_version = packaging.version.parse(version)

def _requires_package_version(function):

@wraps(function)
def wrapper(*args, patched_version=None, **kwargs):
if patched_version is not None:
installed_version = pkg_resources.parse_version(
installed_version = packaging.version.parse(
patched_version) # Use the patched version if provided
else:
installed_version = pkg_resources.parse_version(
installed_version = packaging.version.parse(
module.__version__)

if installed_version < required_version:
Expand Down
4 changes: 2 additions & 2 deletions cebra/integrations/sklearn/cebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import numpy as np
import numpy.typing as npt
import packaging.version
import pkg_resources
import importlib.metadata
import sklearn
import sklearn.utils.validation as sklearn_utils_validation
import torch
Expand Down Expand Up @@ -1397,7 +1397,7 @@ def save(self,
'numpy_version':
np.__version__,
'sklearn_version':
pkg_resources.get_distribution("scikit-learn"
importlib.metadata.distribution("scikit-learn"
).version
}
}, filename)
Expand Down
6 changes: 3 additions & 3 deletions tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pkg_resources
import packaging.version
import pytest
import torch
from sklearn.exceptions import NotFittedError
Expand Down Expand Up @@ -190,8 +190,8 @@ def test_compare_models_with_different_versions(matplotlib_version):
# minimum version of matplotlib
minimum_version = "3.6"

if pkg_resources.parse_version(
matplotlib_version) < pkg_resources.parse_version(minimum_version):
if packaging.version.parse(
matplotlib_version) < packaging.version.parse(minimum_version):
with pytest.raises(ImportError):
cebra_plot.compare_models(models=fitted_models,
patched_version=matplotlib_version)
Expand Down
6 changes: 3 additions & 3 deletions tests/test_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import _util
import _utils_deprecated
import numpy as np
import pkg_resources
import packaging.version
import pytest
import sklearn.utils.estimator_checks
import torch
Expand Down Expand Up @@ -1320,8 +1320,8 @@ def test_check_device():
with pytest.raises(ValueError):
cebra_sklearn_utils.check_device(device)

if pkg_resources.parse_version(
torch.__version__) >= pkg_resources.parse_version("1.12"):
if packaging.version.parse(
torch.__version__) >= packaging.version.parse("1.12"):

device = "mps"
torch.backends.mps.is_available = lambda: True
Expand Down
Loading