|
| 1 | +import numpy as np |
| 2 | +import matplotlib.pyplot as plt |
| 3 | +import pandas as pd |
| 4 | + |
| 5 | + |
| 6 | +def randorient(k, p, xi): |
| 7 | + # Step length |
| 8 | + Delta = xi / (p - 1) |
| 9 | + |
| 10 | + m = k + 1 |
| 11 | + |
| 12 | + # A truncated p-level grid in one dimension |
| 13 | + xs = np.arange(0, 1, Delta) |
| 14 | + xsl = len(xs) |
| 15 | + |
| 16 | + # Basic sampling matrix |
| 17 | + B = np.vstack((np.zeros((1, k)), np.tril(np.ones((k, k))))) |
| 18 | + |
| 19 | + # Randomization |
| 20 | + |
| 21 | + # Matrix with +1s and -1s on the diagonal with equal probability |
| 22 | + Dstar = np.diag(2 * np.round(np.random.rand(k)) - 1) |
| 23 | + |
| 24 | + # Random base value |
| 25 | + xstar = xs[(np.random.rand(k) * xsl).astype(int)] |
| 26 | + |
| 27 | + # Permutation matrix |
| 28 | + Pstar = np.zeros((k, k)) |
| 29 | + rp = np.random.permutation(k) |
| 30 | + for i in range(k): |
| 31 | + Pstar[i, rp[i]] = 1 |
| 32 | + |
| 33 | + # A random orientation of the sampling matrix |
| 34 | + Bstar = (np.ones((m, 1)) @ xstar.reshape(1, -1) + (Delta / 2) * ((2 * B - np.ones((m, k))) @ Dstar + np.ones((m, k)))) @ Pstar |
| 35 | + |
| 36 | + return Bstar |
| 37 | + |
| 38 | + |
| 39 | +def screeningplan(k, p, xi, r): |
| 40 | + # Empty list to accumulate screening plan rows |
| 41 | + X = [] |
| 42 | + |
| 43 | + for i in range(r): |
| 44 | + X.append(randorient(k, p, xi)) |
| 45 | + |
| 46 | + # Concatenate list of arrays into a single array |
| 47 | + X = np.vstack(X) |
| 48 | + |
| 49 | + return X |
| 50 | + |
| 51 | + |
| 52 | +def screening(X, objhandle, range_, xi, p, labels, print=False) -> pd.DataFrame: |
| 53 | + """ |
| 54 | + Screening method for global sensitivity analysis. |
| 55 | +
|
| 56 | + Args: |
| 57 | + X (np.ndarray): Design matrix with shape (n, k), where n is the number of design points and k is the number of design variables. |
| 58 | + objhandle (function): Objective function to evaluate the design points. |
| 59 | + range_ (np.ndarray): Array with shape (2, k) with the lower and upper bounds for each design variable. |
| 60 | + xi (float): Step length. |
| 61 | + p (int): Number of levels. |
| 62 | + labels (list): List with the names of the design variables. |
| 63 | + print (bool): If True, print the results in a table. If False, plot the results. |
| 64 | +
|
| 65 | + Returns: |
| 66 | + pd.DataFrame: Table with the mean and standard deviation of the elementary effects |
| 67 | +
|
| 68 | +
|
| 69 | + """ |
| 70 | + # Determine the number of design variables (k) |
| 71 | + k = X.shape[1] |
| 72 | + # Determine the number of repetitions (r) |
| 73 | + r = X.shape[0] // (k + 1) |
| 74 | + |
| 75 | + # Scale each design point to the given range and evaluate the objective function |
| 76 | + t = np.zeros(X.shape[0]) |
| 77 | + for i in range(X.shape[0]): |
| 78 | + # X[i, :] = range_[0, :] + X[i, :] * (range_[1, :] - range_[0, :]) |
| 79 | + t[i] = objhandle(X[i, :]) |
| 80 | + |
| 81 | + # Calculate the elementary effects |
| 82 | + F = np.zeros((k, r)) |
| 83 | + for i in range(r): |
| 84 | + for j in range(i * (k + 1), i * (k + 1) + k): |
| 85 | + index = np.where(X[j, :] - X[j + 1, :] != 0)[0][0] |
| 86 | + F[index, i] = (t[j + 1] - t[j]) / (xi / (p - 1)) |
| 87 | + |
| 88 | + # Compute statistical measures |
| 89 | + ssd = np.std(F, axis=1) |
| 90 | + sm = np.abs(np.mean(F, axis=1)) |
| 91 | + |
| 92 | + if print: |
| 93 | + # sort the variables by decreasing mean |
| 94 | + idx = np.argsort(-sm) |
| 95 | + labels = [labels[i] for i in idx] |
| 96 | + sm = sm[idx] |
| 97 | + ssd = ssd[idx] |
| 98 | + df = pd.DataFrame({"varname": labels, "mean": sm, "sd": ssd}) |
| 99 | + |
| 100 | + return df |
| 101 | + else: |
| 102 | + # Generate plot |
| 103 | + plt.figure() |
| 104 | + |
| 105 | + for i in range(k): |
| 106 | + plt.text(sm[i], ssd[i], labels[i], fontsize=10) |
| 107 | + |
| 108 | + plt.axis([min(sm), 1.1 * max(sm), min(ssd), 1.1 * max(ssd)]) |
| 109 | + plt.xlabel("Sample means") |
| 110 | + plt.ylabel("Sample standard deviations") |
| 111 | + plt.gca().set_xlabel("Sample means") |
| 112 | + plt.gca().set_ylabel("Sample standard deviations") |
| 113 | + plt.gca().tick_params(labelsize=10) |
| 114 | + plt.grid(True) |
| 115 | + plt.show() |
0 commit comments