Skip to content

Commit 3f6c1a8

Browse files
0.27.7
1 parent c95f526 commit 3f6c1a8

3 files changed

Lines changed: 50 additions & 12 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.27.6"
10+
version = "0.27.7"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotpython/utils/stats.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -786,24 +786,36 @@ def preprocess_df_for_ols(df, independent_var_columns, target_col) -> tuple:
786786
return X_encoded, y
787787

788788

789-
def get_combinations(z_ind: list) -> list:
789+
def get_combinations(ind_list: list, type="indices") -> list:
790790
"""
791791
Generates all possible combinations of two targets from a list of target indices. Order is not important.
792792
793793
Args:
794-
z_ind (list): A list of target indices.
794+
ind_list (list): A list of target indices.
795795
796796
Returns:
797797
list: A list of tuples, where each tuple contains a combination of two target indices.
798798
The order of the targets within a tuple is not important, and each combination
799799
appears only once.
800+
type (str): The type of output, either 'values' or 'indices'. Default is 'indices'.
800801
801802
Examples:
802803
>>> from spotpython.utils.stats import get_combinations
803-
>>> z_ind = [0, 1, 2, 30]
804-
>>> combinations = get_combinations(z_ind)
805-
>>> print(combinations)
806-
[(0, 1), (0, 2), (0, 30), (1, 2), (1, 30), (2, 30)]
804+
>>> ind_list = [0, 10, 20, 30]
805+
>>> combinations = get_combinations(ind_list)
806+
>>> combinations = get_combinations(ind_list, type='indices')
807+
[(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]
808+
>>> print(combinations, type='values')
809+
[(0, 10), (0, 20), (0, 30), (1, 20), (1, 30), (2, 30)]
807810
"""
808-
combinations = [(z_ind[i], z_ind[j]) for i in range(len(z_ind)) for j in range(i + 1, len(z_ind))]
811+
# check that ind_list is a list
812+
if not isinstance(ind_list, list):
813+
raise ValueError("ind_list must be a list.")
814+
m = len(ind_list)
815+
if type == "values":
816+
combinations = [(ind_list[i], ind_list[j]) for i in range(m) for j in range(i + 1, m)]
817+
elif type == "indices":
818+
combinations = [(i, j) for i in range(m) for j in range(i + 1, m)]
819+
else:
820+
raise ValueError("type must be either 'values' or 'indices'.")
809821
return combinations

test/test_get_combinations.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,23 +4,49 @@
44
def test_get_combinations_empty_list():
55
"""Test get_combinations with an empty list."""
66
assert get_combinations([]) == []
7+
assert get_combinations([], type="indices") == []
8+
assert get_combinations([], type="values") == []
79

810
def test_get_combinations_single_element():
911
"""Test get_combinations with a single element."""
1012
assert get_combinations([0]) == []
13+
assert get_combinations([0], type="indices") == []
14+
assert get_combinations([0], type="values") == []
1115

1216
def test_get_combinations_two_elements():
1317
"""Test get_combinations with two elements."""
1418
assert get_combinations([0, 1]) == [(0, 1)]
19+
assert get_combinations([0, 1], type="indices") == [(0, 1)]
20+
assert get_combinations([0, 1], type="values") == [(0, 1)]
1521

1622
def test_get_combinations_multiple_elements():
1723
"""Test get_combinations with multiple elements."""
1824
z_ind = [0, 1, 2, 3]
19-
expected = [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]
20-
assert get_combinations(z_ind) == expected
25+
expected_indices = [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]
26+
expected_values = [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]
27+
assert get_combinations(z_ind) == expected_indices
28+
assert get_combinations(z_ind, type="indices") == expected_indices
29+
assert get_combinations(z_ind, type="values") == expected_values
2130

2231
def test_get_combinations_non_sequential_indices():
2332
"""Test get_combinations with non-sequential indices."""
2433
z_ind = [10, 20, 30]
25-
expected = [(10, 20), (10, 30), (20, 30)] # Indices are based on values, not indices
26-
assert get_combinations(z_ind) == expected
34+
expected_indices = [(0, 1), (0, 2), (1, 2)]
35+
expected_values = [(10, 20), (10, 30), (20, 30)]
36+
assert get_combinations(z_ind) == expected_indices
37+
assert get_combinations(z_ind, type="indices") == expected_indices
38+
assert get_combinations(z_ind, type="values") == expected_values
39+
40+
def test_get_combinations_mixed_values():
41+
"""Test get_combinations with mixed values."""
42+
z_ind = [0, 10, 20, 30]
43+
expected_indices = [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]
44+
expected_values = [(0, 10), (0, 20), (0, 30), (10, 20), (10, 30), (20, 30)]
45+
assert get_combinations(z_ind) == expected_indices
46+
assert get_combinations(z_ind, type="indices") == expected_indices
47+
assert get_combinations(z_ind, type="values") == expected_values
48+
49+
def test_get_combinations_invalid_type():
50+
"""Test get_combinations with an invalid type."""
51+
with pytest.raises(ValueError):
52+
get_combinations([1, 2, 3], type="invalid")

0 commit comments

Comments
 (0)