Skip to content

Commit 6841b1a

Browse files
Copilotmanujosephv
andcommitted
Apply ruff linting and formatting
Co-authored-by: manujosephv <10508493+manujosephv@users.noreply.github.com>
1 parent ac4c5fd commit 6841b1a

File tree

2 files changed

+16
-20
lines changed

2 files changed

+16
-20
lines changed

src/pytorch_tabular/config/config.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import os
77
import re
88
from dataclasses import MISSING, dataclass, field
9-
from typing import Any, Dict, Iterable, List, Optional, Union
9+
from typing import Any, Dict, Iterable, List, Optional
1010

1111
from omegaconf import OmegaConf
1212

@@ -192,9 +192,9 @@ class DataConfig:
192192
)
193193

194194
def __post_init__(self):
195-
assert (
196-
len(self.categorical_cols) + len(self.continuous_cols) + len(self.date_columns) > 0
197-
), "There should be at-least one feature defined in categorical, continuous, or date columns"
195+
assert len(self.categorical_cols) + len(self.continuous_cols) + len(self.date_columns) > 0, (
196+
"There should be at-least one feature defined in categorical, continuous, or date columns"
197+
)
198198
_validate_choices(self)
199199
if os.name == "nt" and self.num_workers != 0:
200200
print("Windows does not support num_workers > 0. Setting num_workers to 0")
@@ -255,9 +255,9 @@ class InferredConfig:
255255

256256
def __post_init__(self):
257257
if self.embedding_dims is not None:
258-
assert all(
259-
(isinstance(t, Iterable) and len(t) == 2) for t in self.embedding_dims
260-
), "embedding_dims must be a list of tuples (cardinality, embedding_dim)"
258+
assert all((isinstance(t, Iterable) and len(t) == 2) for t in self.embedding_dims), (
259+
"embedding_dims must be a list of tuples (cardinality, embedding_dim)"
260+
)
261261
self.embedded_cat_dim = sum([t[1] for t in self.embedding_dims])
262262
else:
263263
self.embedded_cat_dim = 0

tests/test_config.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#!/usr/bin/env python
22
"""Tests for config classes."""
33

4-
import pytest
54
from omegaconf import OmegaConf
65

76
from pytorch_tabular.config import TrainerConfig
@@ -15,7 +14,7 @@ def test_devices_list_to_devices_conversion(self):
1514
# Test with a list of devices
1615
trainer_config = TrainerConfig(devices_list=[0, 1])
1716
assert trainer_config.devices == [0, 1]
18-
17+
1918
# Wrap with OmegaConf as done in TabularModel
2019
config = OmegaConf.structured(trainer_config)
2120
assert config.devices == [0, 1]
@@ -24,31 +23,31 @@ def test_devices_list_multiple_gpus(self):
2423
"""Test devices_list with multiple GPU IDs as documented."""
2524
trainer_config = TrainerConfig(devices_list=[1, 2, 3, 4])
2625
assert trainer_config.devices == [1, 2, 3, 4]
27-
26+
2827
config = OmegaConf.structured(trainer_config)
2928
assert config.devices == [1, 2, 3, 4]
3029

3130
def test_devices_int_value(self):
3231
"""Test that devices accepts integer values."""
3332
trainer_config = TrainerConfig(devices=2)
3433
assert trainer_config.devices == 2
35-
34+
3635
config = OmegaConf.structured(trainer_config)
3736
assert config.devices == 2
3837

3938
def test_devices_default_value(self):
4039
"""Test that devices has default value of -1."""
4140
trainer_config = TrainerConfig()
4241
assert trainer_config.devices == -1
43-
42+
4443
config = OmegaConf.structured(trainer_config)
4544
assert config.devices == -1
4645

4746
def test_devices_list_single_device(self):
4847
"""Test devices_list with a single device."""
4948
trainer_config = TrainerConfig(devices_list=[0])
5049
assert trainer_config.devices == [0]
51-
50+
5251
config = OmegaConf.structured(trainer_config)
5352
assert config.devices == [0]
5453

@@ -57,21 +56,18 @@ def test_devices_list_precedence(self):
5756
# When both are provided, devices_list should take precedence
5857
trainer_config = TrainerConfig(devices=2, devices_list=[0, 1])
5958
assert trainer_config.devices == [0, 1]
60-
59+
6160
config = OmegaConf.structured(trainer_config)
6261
assert config.devices == [0, 1]
6362

6463
def test_omegaconf_merge_compatibility(self):
6564
"""Test that config works correctly with OmegaConf.merge."""
6665
trainer_config = TrainerConfig(devices_list=[0, 1], max_epochs=10)
6766
config = OmegaConf.structured(trainer_config)
68-
67+
6968
# Simulate merging as done in TabularModel
70-
merged = OmegaConf.merge(
71-
OmegaConf.to_container(config),
72-
{"accelerator": "gpu"}
73-
)
74-
69+
merged = OmegaConf.merge(OmegaConf.to_container(config), {"accelerator": "gpu"})
70+
7571
assert merged.devices == [0, 1]
7672
assert merged.max_epochs == 10
7773
assert merged.accelerator == "gpu"

0 commit comments

Comments
 (0)