11#!/usr/bin/env python
22"""Tests for config classes."""
33
4- import pytest
54from omegaconf import OmegaConf
65
76from pytorch_tabular .config import TrainerConfig
@@ -15,7 +14,7 @@ def test_devices_list_to_devices_conversion(self):
1514 # Test with a list of devices
1615 trainer_config = TrainerConfig (devices_list = [0 , 1 ])
1716 assert trainer_config .devices == [0 , 1 ]
18-
17+
1918 # Wrap with OmegaConf as done in TabularModel
2019 config = OmegaConf .structured (trainer_config )
2120 assert config .devices == [0 , 1 ]
@@ -24,31 +23,31 @@ def test_devices_list_multiple_gpus(self):
2423 """Test devices_list with multiple GPU IDs as documented."""
2524 trainer_config = TrainerConfig (devices_list = [1 , 2 , 3 , 4 ])
2625 assert trainer_config .devices == [1 , 2 , 3 , 4 ]
27-
26+
2827 config = OmegaConf .structured (trainer_config )
2928 assert config .devices == [1 , 2 , 3 , 4 ]
3029
3130 def test_devices_int_value (self ):
3231 """Test that devices accepts integer values."""
3332 trainer_config = TrainerConfig (devices = 2 )
3433 assert trainer_config .devices == 2
35-
34+
3635 config = OmegaConf .structured (trainer_config )
3736 assert config .devices == 2
3837
3938 def test_devices_default_value (self ):
4039 """Test that devices has default value of -1."""
4140 trainer_config = TrainerConfig ()
4241 assert trainer_config .devices == - 1
43-
42+
4443 config = OmegaConf .structured (trainer_config )
4544 assert config .devices == - 1
4645
4746 def test_devices_list_single_device (self ):
4847 """Test devices_list with a single device."""
4948 trainer_config = TrainerConfig (devices_list = [0 ])
5049 assert trainer_config .devices == [0 ]
51-
50+
5251 config = OmegaConf .structured (trainer_config )
5352 assert config .devices == [0 ]
5453
@@ -57,21 +56,18 @@ def test_devices_list_precedence(self):
5756 # When both are provided, devices_list should take precedence
5857 trainer_config = TrainerConfig (devices = 2 , devices_list = [0 , 1 ])
5958 assert trainer_config .devices == [0 , 1 ]
60-
59+
6160 config = OmegaConf .structured (trainer_config )
6261 assert config .devices == [0 , 1 ]
6362
6463 def test_omegaconf_merge_compatibility (self ):
6564 """Test that config works correctly with OmegaConf.merge."""
6665 trainer_config = TrainerConfig (devices_list = [0 , 1 ], max_epochs = 10 )
6766 config = OmegaConf .structured (trainer_config )
68-
67+
6968 # Simulate merging as done in TabularModel
70- merged = OmegaConf .merge (
71- OmegaConf .to_container (config ),
72- {"accelerator" : "gpu" }
73- )
74-
69+ merged = OmegaConf .merge (OmegaConf .to_container (config ), {"accelerator" : "gpu" })
70+
7571 assert merged .devices == [0 , 1 ]
7672 assert merged .max_epochs == 10
7773 assert merged .accelerator == "gpu"
0 commit comments