|
4 | 4 | "encoding/json" |
5 | 5 | "fmt" |
6 | 6 | "net/http" |
| 7 | + "sort" |
7 | 8 | "strings" |
8 | 9 |
|
9 | 10 | "github.com/utmstack/soc-ai/configurations" |
@@ -67,136 +68,195 @@ func ChangeAlertStatus(id string, status int, observations string) error { |
67 | 68 | return nil |
68 | 69 | } |
69 | 70 |
|
| 71 | +type AlertCounts struct { |
| 72 | + Incidents int |
| 73 | + FalsePositive int |
| 74 | + Standard int |
| 75 | + Unclassified int |
| 76 | +} |
| 77 | + |
| 78 | +type MatchTypeCounts struct { |
| 79 | + SourceIP AlertCounts |
| 80 | + DestinationIP AlertCounts |
| 81 | + SourceUser AlertCounts |
| 82 | + DestinationUser AlertCounts |
| 83 | +} |
| 84 | + |
70 | 85 | type AlertCorrelation struct { |
71 | | - CurrentAlert schema.Alert |
72 | | - RelatedAlerts []schema.Alert |
73 | | - Classifications []string |
| 86 | + CurrentAlert schema.Alert |
| 87 | + RelatedAlerts []schema.Alert |
| 88 | + Counts MatchTypeCounts |
74 | 89 | } |
75 | 90 |
|
76 | | -func GetRelatedAlerts() ([]schema.Alert, error) { |
77 | | - result, err := ElasticSearch(configurations.ALERT_INDEX_PATTERN, "", "") |
| 91 | +func GetRelatedAlerts(alertName string) ([]schema.Alert, error) { |
| 92 | + result, err := ElasticSearch(configurations.ALERT_INDEX_PATTERN, "name", alertName) |
78 | 93 | if err != nil { |
79 | 94 | return nil, fmt.Errorf("error getting historical alerts: %v", err) |
80 | 95 | } |
81 | 96 |
|
82 | 97 | var alerts []schema.Alert |
83 | | - err = json.Unmarshal(result, &alerts) |
84 | | - if err != nil { |
| 98 | + if err := json.Unmarshal(result, &alerts); err != nil { |
85 | 99 | return nil, fmt.Errorf("error unmarshalling alerts: %v", err) |
86 | 100 | } |
87 | 101 |
|
88 | 102 | return alerts, nil |
89 | 103 | } |
90 | 104 |
|
91 | | -func FindRelatedAlerts(currentAlert schema.Alert) (*AlertCorrelation, error) { |
92 | | - correlation := &AlertCorrelation{ |
93 | | - CurrentAlert: currentAlert, |
94 | | - RelatedAlerts: make([]schema.Alert, 0), |
95 | | - Classifications: make([]string, 0), |
96 | | - } |
97 | | - |
98 | | - historicalResponses, err := GetRelatedAlerts() |
| 105 | +func FindRelatedAlerts(current schema.Alert) (*AlertCorrelation, error) { |
| 106 | + alerts, err := GetRelatedAlerts(current.Name) |
99 | 107 | if err != nil { |
100 | 108 | return nil, err |
101 | 109 | } |
102 | 110 |
|
103 | | - for _, hist := range historicalResponses { |
104 | | - if isAlertRelated(currentAlert, hist) { |
105 | | - correlation.RelatedAlerts = append(correlation.RelatedAlerts, hist) |
106 | | - |
107 | | - classification := "This alert has not been classified" |
108 | | - if len(hist.Tags) > 0 { |
109 | | - classification = strings.Join(hist.Tags, ", ") |
| 111 | + corr := &AlertCorrelation{CurrentAlert: current} |
| 112 | + for _, hist := range alerts { |
| 113 | + if hist.ID == current.ID { |
| 114 | + continue |
| 115 | + } |
| 116 | + if related, matches := isAlertRelated(current, hist); related { |
| 117 | + classif := getAlertClassification(hist) |
| 118 | + for _, m := range matches { |
| 119 | + incrementCount(&corr.Counts, m, classif) |
110 | 120 | } |
111 | | - correlation.Classifications = append(correlation.Classifications, classification) |
| 121 | + corr.RelatedAlerts = append(corr.RelatedAlerts, hist) |
112 | 122 | } |
113 | 123 | } |
114 | | - |
115 | | - utils.Logger.Info("Completed related alerts search. Found %d related alerts for ID: %s", |
116 | | - len(correlation.RelatedAlerts), currentAlert.ID) |
117 | | - |
118 | | - return correlation, nil |
| 124 | + return corr, nil |
119 | 125 | } |
120 | 126 |
|
121 | | -func isAlertRelated(current, historical schema.Alert) bool { |
122 | | - if current.ID == historical.ID { |
123 | | - return false |
| 127 | +func isAlertRelated(current, historical schema.Alert) (bool, []string) { |
| 128 | + if current.ID == historical.ID || current.Name != historical.Name { |
| 129 | + return false, nil |
124 | 130 | } |
125 | 131 |
|
126 | | - if current.Destination.IP != "" && current.Destination.IP == historical.Destination.IP { |
127 | | - return true |
| 132 | + var matches []string |
| 133 | + |
| 134 | + if current.Source.IP != "" && current.Source.IP == historical.Source.IP { |
| 135 | + matches = append(matches, "SourceIP") |
128 | 136 | } |
129 | | - if current.Destination.Port != 0 && current.Destination.Port == historical.Destination.Port { |
130 | | - return true |
| 137 | + if current.Destination.IP != "" && current.Destination.IP == historical.Destination.IP { |
| 138 | + matches = append(matches, "DestinationIP") |
131 | 139 | } |
132 | | - if current.Destination.Host != "" && current.Destination.Host == historical.Destination.Host { |
133 | | - return true |
| 140 | + if current.Source.User != "" && current.Source.User == historical.Source.User { |
| 141 | + matches = append(matches, "SourceUser") |
134 | 142 | } |
135 | 143 | if current.Destination.User != "" && current.Destination.User == historical.Destination.User { |
136 | | - return true |
| 144 | + matches = append(matches, "DestinationUser") |
137 | 145 | } |
138 | 146 |
|
139 | | - if current.Source.IP != "" && current.Source.IP == historical.Source.IP { |
140 | | - return true |
| 147 | + sort.Strings(matches) |
| 148 | + return len(matches) > 0, matches |
| 149 | +} |
| 150 | + |
| 151 | +func getAlertClassification(alert schema.Alert) string { |
| 152 | + if len(alert.Tags) == 0 { |
| 153 | + return "Unclassified alert" |
141 | 154 | } |
142 | | - if current.Source.Port != 0 && current.Source.Port == historical.Source.Port { |
143 | | - return true |
| 155 | + switch strings.ToLower(alert.Tags[0]) { |
| 156 | + case "possible incident": |
| 157 | + return "Possible incident" |
| 158 | + case "false positive": |
| 159 | + return "False positive" |
| 160 | + case "standard alert": |
| 161 | + return "Standard alert" |
| 162 | + default: |
| 163 | + return "Unclassified alert" |
144 | 164 | } |
145 | | - if current.Source.Host != "" && current.Source.Host == historical.Source.Host { |
146 | | - return true |
| 165 | +} |
| 166 | + |
| 167 | +func incrementCount(cnts *MatchTypeCounts, matchType, classif string) { |
| 168 | + var ac *AlertCounts |
| 169 | + |
| 170 | + switch matchType { |
| 171 | + case "SourceIP": |
| 172 | + ac = &cnts.SourceIP |
| 173 | + case "DestinationIP": |
| 174 | + ac = &cnts.DestinationIP |
| 175 | + case "SourceUser": |
| 176 | + ac = &cnts.SourceUser |
| 177 | + case "DestinationUser": |
| 178 | + ac = &cnts.DestinationUser |
147 | 179 | } |
148 | | - if current.Source.User != "" && current.Source.User == historical.Source.User { |
149 | | - return true |
| 180 | + switch classif { |
| 181 | + case "Possible incident": |
| 182 | + ac.Incidents++ |
| 183 | + case "False positive": |
| 184 | + ac.FalsePositive++ |
| 185 | + case "Standard Alert": |
| 186 | + ac.Standard++ |
| 187 | + default: |
| 188 | + ac.Unclassified++ |
150 | 189 | } |
151 | | - |
152 | | - return false |
153 | 190 | } |
154 | 191 |
|
155 | | -func BuildCorrelationContext(correlation *AlertCorrelation) string { |
156 | | - var context strings.Builder |
157 | | - |
158 | | - context.WriteString("\nHistorical Context:\n") |
159 | | - context.WriteString(fmt.Sprintf("Found %d related alerts with similar characteristics:\n", len(correlation.RelatedAlerts))) |
160 | | - |
161 | | - for i, alert := range correlation.RelatedAlerts { |
162 | | - context.WriteString(fmt.Sprintf("\nRelated Alert %d:\n", i+1)) |
163 | | - context.WriteString(fmt.Sprintf("- Name: %s\n", alert.Name)) |
164 | | - context.WriteString(fmt.Sprintf("- Severity: %s\n", alert.SeverityLabel)) |
165 | | - context.WriteString(fmt.Sprintf("- Category: %s\n", alert.Category)) |
166 | | - |
167 | | - classification := "This alert has not been classified" |
168 | | - if i < len(correlation.Classifications) { |
169 | | - classification = correlation.Classifications[i] |
170 | | - } |
171 | | - context.WriteString(fmt.Sprintf("- Classification: %s\n", classification)) |
172 | | - |
173 | | - context.WriteString(fmt.Sprintf("- Time: %s\n", alert.Timestamp)) |
174 | | - |
175 | | - if alert.Source.IP != "" { |
176 | | - context.WriteString(fmt.Sprintf("- Source IP: %s\n", alert.Source.IP)) |
177 | | - } |
178 | | - if alert.Destination.IP != "" { |
179 | | - context.WriteString(fmt.Sprintf("- Destination IP: %s\n", alert.Destination.IP)) |
180 | | - } |
181 | | - if alert.Source.Host != "" { |
182 | | - context.WriteString(fmt.Sprintf("- Source Host: %s\n", alert.Source.Host)) |
183 | | - } |
184 | | - if alert.Destination.Host != "" { |
185 | | - context.WriteString(fmt.Sprintf("- Destination Host: %s\n", alert.Destination.Host)) |
186 | | - } |
187 | | - if alert.Source.User != "" { |
188 | | - context.WriteString(fmt.Sprintf("- Source User: %s\n", alert.Source.User)) |
| 192 | +func BuildCorrelationContext(corr *AlertCorrelation) string { |
| 193 | + if corr == nil || len(corr.RelatedAlerts) == 0 { |
| 194 | + return "No related alerts exist" |
| 195 | + } |
| 196 | + // Group alerts by matches and classifications |
| 197 | + // Example: "SourceIP+DestinationIP" -> { "Possible incident": 2, "False positive": 1 } |
| 198 | + groups := make(map[string]map[string]int) |
| 199 | + for _, alert := range corr.RelatedAlerts { |
| 200 | + if rel, mts := isAlertRelated(corr.CurrentAlert, alert); rel { |
| 201 | + key := strings.Join(mts, "+") |
| 202 | + if _, ok := groups[key]; !ok { |
| 203 | + groups[key] = make(map[string]int) |
| 204 | + } |
| 205 | + classif := getAlertClassification(alert) |
| 206 | + groups[key][classif]++ |
189 | 207 | } |
190 | | - if alert.Destination.User != "" { |
191 | | - context.WriteString(fmt.Sprintf("- Destination User: %s\n", alert.Destination.User)) |
| 208 | + } |
| 209 | + // Ordered summary |
| 210 | + var sb strings.Builder |
| 211 | + total := len(corr.RelatedAlerts) |
| 212 | + sb.WriteString("\nHistorical Context: ") |
| 213 | + sb.WriteString(fmt.Sprintf("In the past, there are %d alerts with the same name", total)) |
| 214 | + |
| 215 | + // Ordered keys |
| 216 | + keys := make([]string, 0, len(groups)) |
| 217 | + for k := range groups { |
| 218 | + keys = append(keys, k) |
| 219 | + } |
| 220 | + sort.Strings(keys) |
| 221 | + |
| 222 | + for _, k := range keys { |
| 223 | + sub := groups[k] |
| 224 | + // Count total for this group |
| 225 | + n := 0 |
| 226 | + for _, v := range sub { |
| 227 | + n += v |
192 | 228 | } |
193 | | - if alert.Source.Port != 0 { |
194 | | - context.WriteString(fmt.Sprintf("- Source Port: %d\n", alert.Source.Port)) |
| 229 | + sb.WriteString(fmt.Sprintf("\n- %d match the same %s", n, translateMatchTypes(strings.Split(k, "+")))) |
| 230 | + if n > 0 { |
| 231 | + sb.WriteString(" and of these " + formatClassifications(sub)) |
195 | 232 | } |
196 | | - if alert.Destination.Port != 0 { |
197 | | - context.WriteString(fmt.Sprintf("- Destination Port: %d\n", alert.Destination.Port)) |
| 233 | + } |
| 234 | + return sb.String() |
| 235 | +} |
| 236 | + |
| 237 | +var matchTypeNames = map[string]string{ |
| 238 | + "SourceIP": "Source IP", |
| 239 | + "DestinationIP": "Destination IP", |
| 240 | + "SourceUser": "Source User", |
| 241 | + "DestinationUser": "Destination User", |
| 242 | +} |
| 243 | + |
| 244 | +func translateMatchTypes(types []string) string { |
| 245 | + sort.Strings(types) |
| 246 | + var out []string |
| 247 | + for _, t := range types { |
| 248 | + if name, ok := matchTypeNames[t]; ok { |
| 249 | + out = append(out, name) |
198 | 250 | } |
199 | 251 | } |
| 252 | + return strings.Join(out, " and ") |
| 253 | +} |
200 | 254 |
|
201 | | - return context.String() |
| 255 | +func formatClassifications(m map[string]int) string { |
| 256 | + parts := make([]string, 0, len(m)) |
| 257 | + for classif, cnt := range m { |
| 258 | + parts = append(parts, fmt.Sprintf("%d were classified as %s", cnt, classif)) |
| 259 | + } |
| 260 | + sort.Strings(parts) |
| 261 | + return strings.Join(parts, ", ") |
202 | 262 | } |
0 commit comments