diff --git a/.github/.testcoverage.yml b/.github/.testcoverage.yml index f101a5e..e318cc7 100644 --- a/.github/.testcoverage.yml +++ b/.github/.testcoverage.yml @@ -7,3 +7,4 @@ threshold: exclude: paths: - cmd/inboundparse/main.go + - internal/auth/test_utils.go diff --git a/README.md b/README.md index a8e118f..034a3c7 100644 --- a/README.md +++ b/README.md @@ -19,8 +19,9 @@ A SMTP server that receives emails from any domain without authentication and fo ### Email Authentication (RFC Compliant) - 🧐 **SPF Validation (RFC 7208)**: Validates Sender Policy Framework with HELO and envelope sender - 🖊️ **DKIM Validation (RFC 6376)**: Verifies DomainKeys Identified Mail signatures with multi-signature support -- 🕵️‍♂️ **DMARC Validation (RFC 7489)**: Evaluates DMARC policy with identifier alignment checking +- 🕵️‍♂️ **DMARC Validation (RFC 7489)**: Evaluates DMARC policy with hierarchical domain lookup and subdomain policy inheritance - 📝 **Comprehensive Results**: Detailed authentication results with domain, mechanism, and alignment data +- 🔍 **Domain Hierarchy Tracking**: Tracks all attempted DMARC lookups with detailed per-domain results ### Observability & Monitoring - 📋 **Structured Logging**: JSON-formatted logs with configurable levels @@ -227,16 +228,34 @@ The service sends a JSON payload to your webhook with comprehensive email data: "dmarc": { "result": "pass", "policy": "reject", - "subdomain_policy": "reject", + "subdomain_policy": "quarantine", "percentage": 100, - "spf_alignment": "pass (relaxed)", - "dkim_alignment": "pass (relaxed)" + "spf_aligned": true, + "dkim_aligned": true, + "spf_domain": "example.com", + "dkim_domain": "example.com", + "details": [ + { + "domain": "sub.example.com", + "record_found": false, + "error": "dmarc: no policy found for domain" + }, + { + "domain": "example.com", + "record_found": true, + "policy_used": "sp" + } + ], + "spf_alignment": "relaxed", + "dkim_alignment": "relaxed", + "failure_options": "0", + "report_uris": ["mailto:dmarc@example.com"], + "failure_uris": ["mailto:dmarc-fail@example.com"] } } } ``` ---- ## 🤝 Contributing diff --git a/_data/grafana/dashboards/authentication-dashboard.json b/_data/grafana/dashboards/authentication-dashboard.json index 8191037..9949daa 100644 --- a/_data/grafana/dashboards/authentication-dashboard.json +++ b/_data/grafana/dashboards/authentication-dashboard.json @@ -658,6 +658,413 @@ ], "title": "DMARC Results Distribution", "type": "piechart" + }, + { + "datasource": "Prometheus", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 10, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "vis": false + }, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "never", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "s" + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 32 + }, + "id": 10, + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom" + }, + "tooltip": { + "mode": "single" + } + }, + "targets": [ + { + "expr": "histogram_quantile(0.95, rate(spf_check_duration_seconds_bucket[5m]))", + "interval": "", + "legendFormat": "SPF P95", + "refId": "A" + }, + { + "expr": "histogram_quantile(0.50, rate(spf_check_duration_seconds_bucket[5m]))", + "interval": "", + "legendFormat": "SPF P50", + "refId": "B" + }, + { + "expr": "histogram_quantile(0.95, rate(dkim_check_duration_seconds_bucket[5m]))", + "interval": "", + "legendFormat": "DKIM P95", + "refId": "C" + }, + { + "expr": "histogram_quantile(0.50, rate(dkim_check_duration_seconds_bucket[5m]))", + "interval": "", + "legendFormat": "DKIM P50", + "refId": "D" + }, + { + "expr": "histogram_quantile(0.95, rate(dmarc_check_duration_seconds_bucket[5m]))", + "interval": "", + "legendFormat": "DMARC P95", + "refId": "E" + }, + { + "expr": "histogram_quantile(0.50, rate(dmarc_check_duration_seconds_bucket[5m]))", + "interval": "", + "legendFormat": "DMARC P50", + "refId": "F" + } + ], + "title": "Authentication Check Duration (95th/50th Percentile)", + "type": "timeseries" + }, + { + "datasource": "Prometheus", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 10, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "vis": false + }, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "never", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "short" + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 32 + }, + "id": 11, + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom" + }, + "tooltip": { + "mode": "single" + } + }, + "targets": [ + { + "expr": "topk(10, rate(spf_results_by_domain_total[5m]))", + "interval": "", + "legendFormat": "{{domain}} - {{result}}", + "refId": "A" + } + ], + "title": "Top SPF Results by Domain", + "type": "timeseries" + }, + { + "datasource": "Prometheus", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 10, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "vis": false + }, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "never", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "short" + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 40 + }, + "id": 12, + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom" + }, + "tooltip": { + "mode": "single" + } + }, + "targets": [ + { + "expr": "topk(10, rate(dkim_results_by_domain_total[5m]))", + "interval": "", + "legendFormat": "{{domain}} - {{result}}", + "refId": "A" + } + ], + "title": "Top DKIM Results by Domain", + "type": "timeseries" + }, + { + "datasource": "Prometheus", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 10, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "vis": false + }, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "never", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "short" + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 40 + }, + "id": 13, + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom" + }, + "tooltip": { + "mode": "single" + } + }, + "targets": [ + { + "expr": "topk(10, rate(dmarc_results_by_domain_total[5m]))", + "interval": "", + "legendFormat": "{{domain}} - {{result}}", + "refId": "A" + } + ], + "title": "Top DMARC Results by Domain", + "type": "timeseries" + }, + { + "datasource": "Prometheus", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "hideFrom": { + "legend": false, + "tooltip": false, + "vis": false + } + }, + "mappings": [] + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 24, + "x": 0, + "y": 48 + }, + "id": 14, + "options": { + "legend": { + "displayMode": "list", + "placement": "bottom" + }, + "pieType": "pie", + "reduceOptions": { + "calcs": [ + "lastNotNull" + ], + "fields": "", + "values": false + }, + "tooltip": { + "mode": "single" + } + }, + "targets": [ + { + "expr": "rate(auth_errors_total[5m])", + "interval": "", + "legendFormat": "{{check_type}} - {{error_type}}", + "refId": "A" + } + ], + "title": "Authentication Error Breakdown", + "type": "piechart" } ], "refresh": "5s", diff --git a/_data/grafana/dashboards/inboundparse-dashboard.json b/_data/grafana/dashboards/inboundparse-dashboard.json index 9ce970c..94a520a 100644 --- a/_data/grafana/dashboards/inboundparse-dashboard.json +++ b/_data/grafana/dashboards/inboundparse-dashboard.json @@ -825,6 +825,542 @@ ], "title": "Webhook Status Distribution", "type": "piechart" + }, + { + "datasource": "Prometheus", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 10, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "vis": false + }, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "never", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "s" + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 24 + }, + "id": 11, + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom" + }, + "tooltip": { + "mode": "single" + } + }, + "targets": [ + { + "expr": "histogram_quantile(0.95, rate(message_read_duration_seconds_bucket[5m]))", + "interval": "", + "legendFormat": "Read P95", + "refId": "A" + }, + { + "expr": "histogram_quantile(0.95, rate(message_parse_duration_seconds_bucket[5m]))", + "interval": "", + "legendFormat": "Parse P95", + "refId": "B" + }, + { + "expr": "histogram_quantile(0.95, rate(message_auth_duration_seconds_bucket[5m]))", + "interval": "", + "legendFormat": "Auth P95", + "refId": "C" + }, + { + "expr": "histogram_quantile(0.95, rate(message_webhook_duration_seconds_bucket[5m]))", + "interval": "", + "legendFormat": "Webhook P95", + "refId": "D" + }, + { + "expr": "histogram_quantile(0.95, rate(message_total_processing_duration_seconds_bucket[5m]))", + "interval": "", + "legendFormat": "Total P95", + "refId": "E" + } + ], + "title": "Message Processing Stage Durations (95th Percentile)", + "type": "timeseries" + }, + { + "datasource": "Prometheus", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 10, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "vis": false + }, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "never", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "short" + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 24 + }, + "id": 12, + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom" + }, + "tooltip": { + "mode": "single" + } + }, + "targets": [ + { + "expr": "rate(email_body_type_total[5m])", + "interval": "", + "legendFormat": "{{body_type}}", + "refId": "A" + } + ], + "title": "Email Body Types", + "type": "timeseries" + }, + { + "datasource": "Prometheus", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 10, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "vis": false + }, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "never", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "short" + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 32 + }, + "id": 13, + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom" + }, + "tooltip": { + "mode": "single" + } + }, + "targets": [ + { + "expr": "histogram_quantile(0.95, rate(email_attachment_count_bucket[5m]))", + "interval": "", + "legendFormat": "Attachment Count P95", + "refId": "A" + }, + { + "expr": "histogram_quantile(0.95, rate(email_attachment_size_bytes_bucket[5m]))", + "interval": "", + "legendFormat": "Attachment Size P95 (bytes)", + "refId": "B" + }, + { + "expr": "histogram_quantile(0.95, rate(email_header_count_bucket[5m]))", + "interval": "", + "legendFormat": "Header Count P95", + "refId": "C" + }, + { + "expr": "histogram_quantile(0.95, rate(email_recipient_count_bucket[5m]))", + "interval": "", + "legendFormat": "Recipient Count P95", + "refId": "D" + } + ], + "title": "Email Characteristics (95th Percentile)", + "type": "timeseries" + }, + { + "datasource": "Prometheus", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 10, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "vis": false + }, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "never", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "s" + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 32 + }, + "id": 14, + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom" + }, + "tooltip": { + "mode": "single" + } + }, + "targets": [ + { + "expr": "histogram_quantile(0.95, rate(session_duration_seconds_bucket[5m]))", + "interval": "", + "legendFormat": "Session Duration P95", + "refId": "A" + }, + { + "expr": "rate(session_commands_total[5m])", + "interval": "", + "legendFormat": "{{command}}/sec", + "refId": "B" + }, + { + "expr": "rate(session_errors_total[5m])", + "interval": "", + "legendFormat": "{{error_type}} errors/sec", + "refId": "C" + }, + { + "expr": "rate(session_messages_total[5m])", + "interval": "", + "legendFormat": "Messages/sec", + "refId": "D" + } + ], + "title": "Session Lifecycle Metrics", + "type": "timeseries" + }, + { + "datasource": "Prometheus", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 10, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "vis": false + }, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "never", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "bytes" + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 40 + }, + "id": 15, + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom" + }, + "tooltip": { + "mode": "single" + } + }, + "targets": [ + { + "expr": "histogram_quantile(0.95, rate(webhook_request_size_bytes_bucket[5m]))", + "interval": "", + "legendFormat": "Request Size P95", + "refId": "A" + }, + { + "expr": "histogram_quantile(0.95, rate(webhook_response_size_bytes_bucket[5m]))", + "interval": "", + "legendFormat": "Response Size P95", + "refId": "B" + }, + { + "expr": "rate(webhook_retry_attempts_total[5m])", + "interval": "", + "legendFormat": "{{success}} retries/sec", + "refId": "C" + }, + { + "expr": "rate(webhook_errors_total[5m])", + "interval": "", + "legendFormat": "{{error_type}} errors/sec", + "refId": "D" + } + ], + "title": "Enhanced Webhook Metrics", + "type": "timeseries" + }, + { + "datasource": "Prometheus", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "hideFrom": { + "legend": false, + "tooltip": false, + "vis": false + } + }, + "mappings": [] + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 40 + }, + "id": 16, + "options": { + "legend": { + "displayMode": "list", + "placement": "bottom" + }, + "pieType": "pie", + "reduceOptions": { + "calcs": [ + "lastNotNull" + ], + "fields": "", + "values": false + }, + "tooltip": { + "mode": "single" + } + }, + "targets": [ + { + "expr": "rate(webhook_errors_total[5m])", + "interval": "", + "legendFormat": "{{error_type}}", + "refId": "A" + } + ], + "title": "Webhook Error Types", + "type": "piechart" } ], "schemaVersion": 27, diff --git a/go.mod b/go.mod index 965b7a1..08596cc 100644 --- a/go.mod +++ b/go.mod @@ -18,8 +18,10 @@ require ( require ( github.com/beorn7/perks v1.0.1 // indirect + github.com/bits-and-blooms/bitset v1.24.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/emersion/go-sasl v0.0.0-20241020182733-b788ff22d5a6 // indirect + github.com/failsafe-go/failsafe-go v0.9.1 // indirect github.com/kr/text v0.2.0 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.19 // indirect diff --git a/go.sum b/go.sum index 09978b7..b1c1a9a 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/bits-and-blooms/bitset v1.24.0 h1:H4x4TuulnokZKvHLfzVRTHJfFfnHEeSYJizujEZvmAM= +github.com/bits-and-blooms/bitset v1.24.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= @@ -12,6 +14,8 @@ github.com/emersion/go-sasl v0.0.0-20241020182733-b788ff22d5a6 h1:oP4q0fw+fOSWn3 github.com/emersion/go-sasl v0.0.0-20241020182733-b788ff22d5a6/go.mod h1:iL2twTeMvZnrg54ZoPDNfJaJaqy0xIQFuBdrLsmspwQ= github.com/emersion/go-smtp v0.24.0 h1:g6AfoF140mvW0vLNPD/LuCBLEAdlxOjIXqbIkJIS6Wk= github.com/emersion/go-smtp v0.24.0/go.mod h1:ZtRRkbTyp2XTHCA+BmyTFTrj8xY4I+b4McvHxCU2gsQ= +github.com/failsafe-go/failsafe-go v0.9.1 h1:PkKSKLSOPRyJMjx35SfuwQeDuPLB6lBhD+zpQcSe7NU= +github.com/failsafe-go/failsafe-go v0.9.1/go.mod h1:sX5TZ4HrMLYSzErWeckIHRZWgZj9PbKMAEKOVLFWtfM= github.com/getsentry/sentry-go v0.36.0 h1:UkCk0zV28PiGf+2YIONSSYiYhxwlERE5Li3JPpZqEns= github.com/getsentry/sentry-go v0.36.0/go.mod h1:p5Im24mJBeruET8Q4bbcMfCQ+F+Iadc4L48tB1apo2c= github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA= @@ -60,6 +64,7 @@ github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzM github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is= github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= +github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= diff --git a/inboundparse b/inboundparse index 2b93343..b9e6c5c 100644 Binary files a/inboundparse and b/inboundparse differ diff --git a/internal/app/app.go b/internal/app/app.go index d792f03..a587ed6 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -74,7 +74,17 @@ func NewApp(cfg *config.Config) (*App, error) { Username: cfg.WebhookUser, Password: cfg.WebhookPass, Timeout: 30 * time.Second, - }) + + // Retry configuration + MaxRetries: cfg.WebhookMaxRetries, + RetryDelay: cfg.WebhookRetryDelay, + MaxRetryDelay: cfg.WebhookMaxRetryDelay, + RetryMultiplier: cfg.WebhookRetryMultiplier, + + // Rate limiting configuration + RateLimitPerSecond: cfg.WebhookRateLimitPerSecond, + RateLimitBurst: cfg.WebhookRateLimitBurst, + }, logger, metrics) // Initialize message processor messageProcessor := processor.NewMessageProcessor( diff --git a/internal/app/app_test.go b/internal/app/app_test.go index 48c8db9..5bec937 100644 --- a/internal/app/app_test.go +++ b/internal/app/app_test.go @@ -4,6 +4,7 @@ import ( "context" "inboundparse/internal/config" "inboundparse/internal/domain" + "inboundparse/internal/observability" "io" "net" "strings" @@ -43,15 +44,15 @@ func (m *mockSentryClient) RecoverWithSentry() {} type mockAuthChecker struct{} -func (m *mockAuthChecker) CheckSPF(ctx context.Context, remoteIP net.IP, domainName, from string) (*domain.SPFResult, error) { +func (m *mockAuthChecker) CheckSPF(ctx context.Context, remoteIP net.IP, domainName, from string, logger observability.Logger) (*domain.SPFResult, error) { return nil, nil } -func (m *mockAuthChecker) CheckDKIM(ctx context.Context, rawMessage string) (*domain.DKIMResult, error) { +func (m *mockAuthChecker) CheckDKIM(ctx context.Context, rawMessage string, logger observability.Logger) (*domain.DKIMResult, error) { return nil, nil } -func (m *mockAuthChecker) CheckDMARC(ctx context.Context, rawMessage string, headers interface{}, spfResult *domain.SPFResult, dkimResult *domain.DKIMResult) (*domain.DMARCResult, error) { +func (m *mockAuthChecker) CheckDMARC(ctx context.Context, rawMessage string, headers interface{}, spfResult *domain.SPFResult, dkimResult *domain.DKIMResult, logger observability.Logger) (*domain.DMARCResult, error) { return nil, nil } diff --git a/internal/auth/dkim.go b/internal/auth/dkim.go index 4861273..40452ef 100644 --- a/internal/auth/dkim.go +++ b/internal/auth/dkim.go @@ -4,10 +4,11 @@ import ( "context" "fmt" "inboundparse/internal/domain" + "inboundparse/internal/observability" "strings" + "time" "github.com/emersion/go-msgauth/dkim" - "github.com/rs/zerolog/log" ) // dkimChecker implements DKIMChecker interface @@ -18,11 +19,28 @@ func NewDKIMChecker() DKIMChecker { return &dkimChecker{} } -// CheckDKIM performs DKIM authentication check -func (d *dkimChecker) CheckDKIM(ctx context.Context, rawMessage string) (*domain.DKIMResult, error) { +// CheckDKIM performs DKIM authentication check following RFC 6376 +func (d *dkimChecker) CheckDKIM(ctx context.Context, rawMessage string, logger observability.Logger, metrics observability.MetricsCollector) (*domain.DKIMResult, error) { + start := time.Now() + defer func() { + duration := time.Since(start) + metrics.TrackDKIMDuration(duration) + }() verifications, err := dkim.Verify(strings.NewReader(rawMessage)) if err != nil { - log.Warn(). + // Determine error type for metrics + errorType := "verification_error" + switch { + case strings.Contains(err.Error(), "timeout"): + errorType = "timeout_error" + case strings.Contains(err.Error(), "parse"): + errorType = "parse_error" + case strings.Contains(err.Error(), "DNS"): + errorType = "dns_error" + } + metrics.TrackAuthError("dkim", errorType) + + logger.Warn(). Err(err). Msg("DKIM verification error") return &domain.DKIMResult{ @@ -33,26 +51,57 @@ func (d *dkimChecker) CheckDKIM(ctx context.Context, rawMessage string) (*domain } var signatures []string + details := make([]domain.DKIMSignatureDetail, 0) // Initialize as empty slice, not nil hasValidSignature := false + // RFC 6376 Section 3.5: Parse each DKIM signature header for _, verification := range verifications { + detail := domain.DKIMSignatureDetail{ + Valid: verification.Err == nil, + } + if verification.Err != nil { + // Invalid signature signatures = append(signatures, fmt.Sprintf("Invalid: %v", verification.Err)) + detail.Error = verification.Err.Error() } else { + // Valid signature signatures = append(signatures, fmt.Sprintf("Valid: %s", verification.Domain)) + detail.Domain = verification.Domain hasValidSignature = true } + + // Extract available fields from go-msgauth Verification struct + detail.HeadersSigned = verification.HeaderKeys + if !verification.Time.IsZero() { + detail.Timestamp = verification.Time.Unix() + } + if !verification.Expiration.IsZero() { + detail.Expiration = verification.Expiration.Unix() + } + + details = append(details, detail) } - log.Debug(). + logger.Debug(). Bool("valid", hasValidSignature). Int("signature_count", len(verifications)). Strs("signatures", signatures). Msg("DKIM check completed") + // Track results by domain for each signature + for _, verification := range verifications { + if verification.Err == nil { + metrics.TrackDKIMResultByDomain(verification.Domain, "valid") + } else { + metrics.TrackDKIMResultByDomain(verification.Domain, "invalid") + } + } + return &domain.DKIMResult{ Valid: hasValidSignature, Signatures: signatures, + Details: details, Raw: fmt.Sprintf("DKIM verification found %d signatures", len(verifications)), }, nil } diff --git a/internal/auth/dkim_test.go b/internal/auth/dkim_test.go index 72296a4..1fd6cb1 100644 --- a/internal/auth/dkim_test.go +++ b/internal/auth/dkim_test.go @@ -3,8 +3,43 @@ package auth import ( "context" "testing" + "time" ) +type mockMetricsCollector struct{} + +func (m *mockMetricsCollector) TrackConnection() func() { return func() {} } +func (m *mockMetricsCollector) TrackSession() func() { return func() {} } +func (m *mockMetricsCollector) TrackMessage(size int64, success bool) {} +func (m *mockMetricsCollector) TrackSPFResult(result string) {} +func (m *mockMetricsCollector) TrackDKIMResult(valid bool) {} +func (m *mockMetricsCollector) TrackDMARCResult(result, policy string) {} +func (m *mockMetricsCollector) TrackWebhookRequest(statusCode int, duration time.Duration) {} +func (m *mockMetricsCollector) TrackWebhookRequestSize(requestSize, responseSize int64) {} +func (m *mockMetricsCollector) TrackWebhookRetry(success bool) {} +func (m *mockMetricsCollector) TrackWebhookError(errorType string) {} +func (m *mockMetricsCollector) TrackSPFDuration(duration time.Duration) {} +func (m *mockMetricsCollector) TrackDKIMDuration(duration time.Duration) {} +func (m *mockMetricsCollector) TrackDMARCDuration(duration time.Duration) {} +func (m *mockMetricsCollector) TrackSPFResultByDomain(domain, result string) {} +func (m *mockMetricsCollector) TrackDKIMResultByDomain(domain, result string) {} +func (m *mockMetricsCollector) TrackDMARCResultByDomain(domain, result string) {} +func (m *mockMetricsCollector) TrackAuthError(checkType, errorType string) {} +func (m *mockMetricsCollector) TrackMessageReadDuration(duration time.Duration) {} +func (m *mockMetricsCollector) TrackMessageParseDuration(duration time.Duration) {} +func (m *mockMetricsCollector) TrackMessageAuthDuration(duration time.Duration) {} +func (m *mockMetricsCollector) TrackMessageWebhookDuration(duration time.Duration) {} +func (m *mockMetricsCollector) TrackMessageTotalProcessingDuration(duration time.Duration) {} +func (m *mockMetricsCollector) TrackEmailAttachmentCount(count int) {} +func (m *mockMetricsCollector) TrackEmailAttachmentSize(size int64) {} +func (m *mockMetricsCollector) TrackEmailBodyType(bodyType string) {} +func (m *mockMetricsCollector) TrackEmailHeaderCount(count int) {} +func (m *mockMetricsCollector) TrackEmailRecipientCount(count int) {} +func (m *mockMetricsCollector) TrackSessionDuration(duration time.Duration) {} +func (m *mockMetricsCollector) TrackSessionCommandCount(command string) {} +func (m *mockMetricsCollector) TrackSessionErrorCount(errorType string) {} +func (m *mockMetricsCollector) TrackSessionMessageCount() {} + func TestDKIMChecker_CheckDKIM_ValidMessage(t *testing.T) { checker := NewDKIMChecker() @@ -16,7 +51,7 @@ Date: Mon, 01 Jan 2024 12:00:00 +0000 This is a test message.` - result, err := checker.CheckDKIM(context.Background(), rawMessage) + result, err := checker.CheckDKIM(context.Background(), rawMessage, &MockLogger{}, &mockMetricsCollector{}) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -26,12 +61,22 @@ This is a test message.` if result == nil { t.Error("Expected non-nil result") } + + // Test new RFC 6376 compliant fields + if result.Details == nil { + t.Error("Expected Details field to be present (should be empty slice, not nil)") + } + + // Details should be empty array when no signatures are found + if len(result.Details) != 0 { + t.Errorf("Expected empty Details array for message without signatures, got %d items", len(result.Details)) + } } func TestDKIMChecker_CheckDKIM_EmptyMessage(t *testing.T) { checker := NewDKIMChecker() - result, err := checker.CheckDKIM(context.Background(), "") + result, err := checker.CheckDKIM(context.Background(), "", &MockLogger{}, &mockMetricsCollector{}) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -47,7 +92,7 @@ func TestDKIMChecker_CheckDKIM_InvalidMessage(t *testing.T) { // Create an invalid message that might cause parsing issues rawMessage := "invalid message content" - result, err := checker.CheckDKIM(context.Background(), rawMessage) + result, err := checker.CheckDKIM(context.Background(), rawMessage, &MockLogger{}, &mockMetricsCollector{}) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -70,7 +115,7 @@ DKIM-Signature: v=1; a=rsa-sha256; d=example.com; s=selector2; c=relaxed/relaxed This is a test message with multiple DKIM signatures.` - result, err := checker.CheckDKIM(context.Background(), rawMessage) + result, err := checker.CheckDKIM(context.Background(), rawMessage, &MockLogger{}, &mockMetricsCollector{}) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -84,6 +129,11 @@ This is a test message with multiple DKIM signatures.` t.Error("Expected signatures in result") } + // Test new RFC 6376 compliant fields + if len(result.Details) == 0 { + t.Error("Expected signature details") + } + // Verify that we have both valid and invalid signatures hasValid := false hasInvalid := false @@ -101,6 +151,17 @@ This is a test message with multiple DKIM signatures.` if !hasValid && !hasInvalid { t.Error("Expected at least one signature to be processed") } + + // Test signature details structure - only check fields available from go-msgauth + for _, detail := range result.Details { + // Domain may be empty for invalid signatures, which is expected + // We just verify the structure is present and has the expected fields + if detail.Valid && detail.Domain == "" { + t.Error("Expected domain in valid signature detail") + } + // HeadersSigned, Timestamp, and Expiration may be empty depending on the signature + // We just verify the structure is present + } } func TestDKIMChecker_CheckDKIM_WithValidSignature(t *testing.T) { @@ -115,7 +176,7 @@ DKIM-Signature: v=1; a=rsa-sha256; d=example.com; s=selector; c=relaxed/relaxed; This is a test message with a valid DKIM signature.` - result, err := checker.CheckDKIM(context.Background(), rawMessage) + result, err := checker.CheckDKIM(context.Background(), rawMessage, &MockLogger{}, &mockMetricsCollector{}) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -143,7 +204,7 @@ DKIM-Signature: v=1; a=rsa-sha256; d=example.com; s=selector; c=relaxed/relaxed; This is a test message with an invalid DKIM signature.` - result, err := checker.CheckDKIM(context.Background(), rawMessage) + result, err := checker.CheckDKIM(context.Background(), rawMessage, &MockLogger{}, &mockMetricsCollector{}) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -170,7 +231,7 @@ DKIM-Signature: v=1; a=rsa-sha256; d=example.com; s=selector; c=relaxed/relaxed; This is a test message with a malformed DKIM signature.` - result, err := checker.CheckDKIM(context.Background(), rawMessage) + result, err := checker.CheckDKIM(context.Background(), rawMessage, &MockLogger{}, &mockMetricsCollector{}) if err != nil { t.Fatalf("Expected no error, got %v", err) } diff --git a/internal/auth/dmarc.go b/internal/auth/dmarc.go index da99bb1..78b4deb 100644 --- a/internal/auth/dmarc.go +++ b/internal/auth/dmarc.go @@ -4,11 +4,12 @@ import ( "context" "fmt" "inboundparse/internal/domain" + "inboundparse/internal/observability" "strings" + "time" "github.com/emersion/go-msgauth/dmarc" "github.com/mnako/letters" - "github.com/rs/zerolog/log" ) // dmarcChecker implements DMARCChecker interface @@ -20,54 +21,113 @@ func NewDMARCChecker() DMARCChecker { } // CheckDMARC performs DMARC authentication check with hierarchical domain lookup following RFC 7489 -func (d *dmarcChecker) CheckDMARC(ctx context.Context, rawMessage string, headers letters.Headers, spfResult *domain.SPFResult, dkimResult *domain.DKIMResult) (*domain.DMARCResult, error) { - // Extract domain from From header (RFC5322.From) - this is what DMARC protects +func (d *dmarcChecker) CheckDMARC(ctx context.Context, rawMessage string, headers letters.Headers, spfResult *domain.SPFResult, dkimResult *domain.DKIMResult, logger observability.Logger, metrics observability.MetricsCollector) (*domain.DMARCResult, error) { + start := time.Now() + defer func() { + duration := time.Since(start) + metrics.TrackDMARCDuration(duration) + }() + + // Extract and validate From domain + fromDomain, result := d.extractFromDomain(headers, logger) + if result != nil { + return result, nil + } + + // Look up DMARC policy with domain hierarchy traversal + record, actualDomain, lookupDetails, err := d.lookupDMARCRecord(fromDomain) + if err != nil { + return d.handleLookupError(err, fromDomain, lookupDetails, logger, metrics), nil + } + + if record == nil { + return d.handleNoRecordFound(fromDomain, lookupDetails, logger), nil + } + + // Ensure we have valid SPF and DKIM results + spfResult, dkimResult = d.ensureValidResults(spfResult, dkimResult) + + // Check SPF and DKIM alignment + spfPass, spfAligned, spfDomain := d.checkSPFAlignment(spfResult, fromDomain, record) + dkimPass, dkimAligned, dkimDomain := d.checkDKIMAlignmentAndExtractDomain(dkimResult, fromDomain, record) + + // Determine which policy to apply and update lookup details + policyToApply, _ := d.determinePolicyToApply(record, actualDomain, fromDomain, lookupDetails) + + // Evaluate DMARC policy + dmarcResult := d.evaluateDMARCPolicy(policyToApply, spfPass, spfAligned, dkimPass, dkimAligned) + + // Create result message and log + rawMsg := d.createResultMessage(fromDomain, actualDomain, record, spfPass, spfAligned, dkimPass, dkimAligned, dmarcResult) + d.logDMARCCheck(logger, dmarcResult, fromDomain, actualDomain, record, spfPass, spfAligned, dkimPass, dkimAligned, spfDomain) + + // Build and return final result + return d.buildDMARCResult(record, actualDomain, policyToApply, dmarcResult, rawMsg, spfAligned, dkimAligned, spfDomain, dkimDomain, lookupDetails, metrics), nil +} + +// extractFromDomain extracts and validates the From domain from headers +func (d *dmarcChecker) extractFromDomain(headers letters.Headers, logger observability.Logger) (string, *domain.DMARCResult) { if len(headers.From) == 0 || headers.From[0] == nil || len(headers.From[0].Address) == 0 { - log.Debug().Msg("DMARC check: No From header found") - return &domain.DMARCResult{ + logger.Debug().Msg("DMARC check: No From header found") + return "", &domain.DMARCResult{ Result: "none", Raw: "No From header found", - }, nil + } } fromHeader := headers.From[0].Address // RFC 7489: The domain in the RFC5322.From field is the primary identifier fromDomain := ExtractDomainFromEmail(fromHeader) if fromDomain == "" { - log.Debug(). + logger.Debug(). Str("from_header", fromHeader). Msg("DMARC check: No domain found in From header") - return &domain.DMARCResult{ + return "", &domain.DMARCResult{ Result: "none", Raw: "No domain found in From header", - }, nil + } } - // Look up DMARC policy with domain hierarchy traversal - record, actualDomain, err := d.lookupDMARCRecord(fromDomain) - if err != nil { - log.Warn(). - Err(err). - Str("domain", fromDomain). - Msg("DMARC lookup error") - return &domain.DMARCResult{ - Result: "none", - Raw: fmt.Sprintf("DMARC lookup error: %v", err), - Error: err.Error(), - }, nil + return fromDomain, nil +} + +// handleLookupError handles DMARC lookup errors +func (d *dmarcChecker) handleLookupError(err error, fromDomain string, lookupDetails []domain.DMARCLookupDetail, logger observability.Logger, metrics observability.MetricsCollector) *domain.DMARCResult { + // Determine error type for metrics + errorType := "dns_error" + if strings.Contains(err.Error(), "timeout") { + errorType = "timeout_error" + } else if strings.Contains(err.Error(), "parse") { + errorType = "parse_error" } + metrics.TrackAuthError("dmarc", errorType) - if record == nil { - log.Debug(). - Str("domain", fromDomain). - Msg("DMARC check: No record found") - return &domain.DMARCResult{ - Result: "none", - Raw: fmt.Sprintf("No DMARC record found for domain %s or its parent domains", fromDomain), - }, nil + logger.Warn(). + Err(err). + Str("domain", fromDomain). + Msg("DMARC lookup error") + return &domain.DMARCResult{ + Result: "none", + Raw: fmt.Sprintf("DMARC lookup error: %v", err), + Error: err.Error(), + Details: lookupDetails, + } +} + +// handleNoRecordFound handles the case when no DMARC record is found +func (d *dmarcChecker) handleNoRecordFound(fromDomain string, lookupDetails []domain.DMARCLookupDetail, logger observability.Logger) *domain.DMARCResult { + logger.Debug(). + Str("domain", fromDomain). + Msg("DMARC check: No record found") + return &domain.DMARCResult{ + Result: "none", + Raw: fmt.Sprintf("No DMARC record found for domain %s or its parent domains", fromDomain), + Details: lookupDetails, } +} - // Use passed results or perform checks if not available +// ensureValidResults ensures SPF and DKIM results are valid +func (d *dmarcChecker) ensureValidResults(spfResult *domain.SPFResult, dkimResult *domain.DKIMResult) (*domain.SPFResult, *domain.DKIMResult) { if spfResult == nil { // This shouldn't happen in normal flow, but handle gracefully spfResult = &domain.SPFResult{Result: "none"} @@ -76,47 +136,93 @@ func (d *dmarcChecker) CheckDMARC(ctx context.Context, rawMessage string, header // This shouldn't happen in normal flow, but handle gracefully dkimResult = &domain.DKIMResult{Valid: false} } + return spfResult, dkimResult +} +// checkSPFAlignment checks SPF alignment and returns pass status, alignment, and domain +func (d *dmarcChecker) checkSPFAlignment(spfResult *domain.SPFResult, fromDomain string, record *dmarc.Record) (bool, bool, string) { // RFC 7489: SPF identifier alignment // Check if SPF passes and the MAIL FROM domain aligns with From header domain spfPass := spfResult.Result == "pass" - spfAligned := d.checkDomainAlignment(spfResult.Result, fromDomain, string(record.SPFAlignment)) + // Fix bug: Use SPF envelope sender domain for alignment check, not result string + spfDomain := ExtractDomainFromEmail(spfResult.Sender) + spfAligned := d.checkDomainAlignment(spfDomain, fromDomain, string(record.SPFAlignment)) + return spfPass, spfAligned, spfDomain +} +// checkDKIMAlignmentAndExtractDomain checks DKIM alignment and returns pass status, alignment, and domain +func (d *dmarcChecker) checkDKIMAlignmentAndExtractDomain(dkimResult *domain.DKIMResult, fromDomain string, record *dmarc.Record) (bool, bool, string) { // RFC 7489: DKIM identifier alignment // Check if any valid DKIM signature domain aligns with From header domain dkimPass := dkimResult.Valid dkimAligned := d.checkDKIMAlignment(dkimResult.Signatures, fromDomain, string(record.DKIMAlignment)) + dkimDomain := d.extractDKIMDomain(dkimResult.Signatures) + return dkimPass, dkimAligned, dkimDomain +} +// determinePolicyToApply determines which policy to apply based on domain inheritance +func (d *dmarcChecker) determinePolicyToApply(record *dmarc.Record, actualDomain, fromDomain string, lookupDetails []domain.DMARCLookupDetail) (dmarc.Policy, string) { + // Determine which policy to apply based on domain inheritance + // If record was found at a parent domain and original domain is a subdomain, use sp= if present + policyToApply := record.Policy + policyUsed := "p" + + // Check if we found the record at a parent domain (inherited) + if actualDomain != fromDomain { + // Record was inherited from parent domain + // Use subdomain policy (sp=) if present, otherwise fall back to main policy (p=) + if record.SubdomainPolicy != "" { + policyToApply = record.SubdomainPolicy + policyUsed = "sp" + } + } + + // Update lookup details to show which policy was used + for i := range lookupDetails { + if lookupDetails[i].Domain == actualDomain { + lookupDetails[i].PolicyUsed = policyUsed + break + } + } + + return policyToApply, policyUsed +} + +// evaluateDMARCPolicy evaluates the DMARC policy and returns the result +func (d *dmarcChecker) evaluateDMARCPolicy(policyToApply dmarc.Policy, spfPass, spfAligned, dkimPass, dkimAligned bool) string { // RFC 7489: DMARC policy evaluation // DMARC passes if either: // 1. SPF passes AND SPF identifier is aligned, OR // 2. DKIM passes AND at least one DKIM signature identifier is aligned - var result string dmarcPass := (spfPass && spfAligned) || (dkimPass && dkimAligned) - switch record.Policy { + switch policyToApply { case dmarc.PolicyNone: // RFC 7489: Policy "none" means monitor only - always report result but don't reject if dmarcPass { - result = "pass" - } else { - result = "fail" + return "pass" } + return "fail" case dmarc.PolicyQuarantine, dmarc.PolicyReject: // RFC 7489: For quarantine/reject policies, evaluate based on authentication results if dmarcPass { - result = "pass" - } else { - result = "fail" + return "pass" } + return "fail" default: - result = "none" + return "none" } +} - rawMsg := fmt.Sprintf("DMARC evaluation for domain %s (policy found at %s): policy=%s, SPF pass=%v aligned=%v, DKIM pass=%v aligned=%v, result=%s", +// createResultMessage creates the result message for logging +func (d *dmarcChecker) createResultMessage(fromDomain, actualDomain string, record *dmarc.Record, spfPass, spfAligned, dkimPass, dkimAligned bool, result string) string { + return fmt.Sprintf("DMARC evaluation for domain %s (policy found at %s): policy=%s, SPF pass=%v aligned=%v, DKIM pass=%v aligned=%v, result=%s", fromDomain, actualDomain, string(record.Policy), spfPass, spfAligned, dkimPass, dkimAligned, result) +} - log.Debug(). +// logDMARCCheck logs the DMARC check results +func (d *dmarcChecker) logDMARCCheck(logger observability.Logger, result, fromDomain, actualDomain string, record *dmarc.Record, spfPass, spfAligned, dkimPass, dkimAligned bool, spfDomain string) { + logger.Debug(). Str("result", result). Str("domain", fromDomain). Str("actual_domain", actualDomain). @@ -125,28 +231,78 @@ func (d *dmarcChecker) CheckDMARC(ctx context.Context, rawMessage string, header Bool("spf_aligned", spfAligned). Bool("dkim_pass", dkimPass). Bool("dkim_aligned", dkimAligned). + Str("spf_domain", spfDomain). Msg("DMARC check completed") +} + +// buildDMARCResult builds the final DMARC result +func (d *dmarcChecker) buildDMARCResult(record *dmarc.Record, actualDomain string, policyToApply dmarc.Policy, result, rawMsg string, spfAligned, dkimAligned bool, spfDomain, dkimDomain string, lookupDetails []domain.DMARCLookupDetail, metrics observability.MetricsCollector) *domain.DMARCResult { + // Handle percentage - use actual value from record or default to 100 + percentage := 100 + if record.Percent != nil { + percentage = *record.Percent + } + + // Convert failure options to string representation + failureOptions := d.convertFailureOptionsToString(record.FailureOptions) + + // Track result by domain + metrics.TrackDMARCResultByDomain(actualDomain, result) return &domain.DMARCResult{ - Result: result, - Raw: rawMsg, - Policy: string(record.Policy), - }, nil + Result: result, + Raw: rawMsg, + Policy: string(policyToApply), + Domain: actualDomain, + Percentage: percentage, + SubdomainPolicy: string(record.SubdomainPolicy), + SPFAligned: spfAligned, + DKIMAligned: dkimAligned, + SPFDomain: spfDomain, + DKIMDomain: dkimDomain, + Details: lookupDetails, + // RFC 7489 compliant additional fields + SPFAlignment: string(record.SPFAlignment), + DKIMAlignment: string(record.DKIMAlignment), + FailureOptions: failureOptions, + ReportURIs: record.ReportURIAggregate, + FailureURIs: record.ReportURIFailure, + } } // lookupDMARCRecord performs hierarchical DMARC record lookup per RFC 7489 -func (d *dmarcChecker) lookupDMARCRecord(domain string) (*dmarc.Record, string, error) { - currentDomain := domain +func (d *dmarcChecker) lookupDMARCRecord(domainName string) (*dmarc.Record, string, []domain.DMARCLookupDetail, error) { + currentDomain := domainName + var details []domain.DMARCLookupDetail for { record, err := dmarc.Lookup(currentDomain) + + // Create lookup detail for this attempt + detail := domain.DMARCLookupDetail{ + Domain: currentDomain, + } + if err == nil && record != nil { - return record, currentDomain, nil + // Found a record + detail.RecordFound = true + details = append(details, detail) + return record, currentDomain, details, nil } - // If there was an error other than "not found", return it - if err != nil && !d.isDMARCNotFoundError(err) { - return nil, "", err + // Record not found or error occurred + detail.RecordFound = false + if err != nil { + detail.Error = err.Error() + } + details = append(details, detail) + + // If the error is "dmarc: failed to lookup TXT record", continue as if not found + if err != nil && strings.Contains(err.Error(), "dmarc: failed to lookup TXT record") { + // continue down to parent domain + } else if err != nil && !d.isDMARCNotFoundError(err) { + // Early return for other errors (let outer logic handle) + return nil, "", details, err } // RFC 7489: Move up the domain hierarchy @@ -158,7 +314,7 @@ func (d *dmarcChecker) lookupDMARCRecord(domain string) (*dmarc.Record, string, currentDomain = nextDomain } - return nil, "", nil + return nil, "", details, nil } // getParentDomain returns the parent domain or empty string if at top level @@ -177,6 +333,7 @@ func (d *dmarcChecker) getParentDomain(domain string) string { func (d *dmarcChecker) isDMARCNotFoundError(err error) bool { errorMsg := strings.ToLower(err.Error()) return strings.Contains(errorMsg, "not found") || + strings.Contains(errorMsg, "no policy found") || strings.Contains(errorMsg, "no such host") || strings.Contains(errorMsg, "nxdomain") || strings.Contains(errorMsg, "no data") @@ -243,3 +400,40 @@ func (d *dmarcChecker) getOrganizationalDomain(domain string) string { // In production, use Public Suffix List to determine the actual organizational domain return strings.Join(parts[len(parts)-2:], ".") } + +// extractDKIMDomain extracts domain(s) from DKIM signatures for alignment reporting +func (d *dmarcChecker) extractDKIMDomain(signatures []string) string { + var domains []string + for _, sig := range signatures { + if strings.HasPrefix(sig, "Valid: ") { + domain := strings.TrimPrefix(sig, "Valid: ") + domains = append(domains, domain) + } + } + // Return comma-separated list of domains + return strings.Join(domains, ",") +} + +// convertFailureOptionsToString converts DMARC failure options to string representation per RFC 7489 +func (d *dmarcChecker) convertFailureOptionsToString(options dmarc.FailureOptions) string { + var parts []string + + if options&dmarc.FailureAll != 0 { + parts = append(parts, "0") + } + if options&dmarc.FailureAny != 0 { + parts = append(parts, "1") + } + if options&dmarc.FailureDKIM != 0 { + parts = append(parts, "d") + } + if options&dmarc.FailureSPF != 0 { + parts = append(parts, "s") + } + + if len(parts) == 0 { + return "0" // Default to "0" (all) if no options specified + } + + return strings.Join(parts, ":") +} diff --git a/internal/auth/dmarc_test.go b/internal/auth/dmarc_test.go index 93d3bbe..b19c1ea 100644 --- a/internal/auth/dmarc_test.go +++ b/internal/auth/dmarc_test.go @@ -41,7 +41,7 @@ func TestDMARCChecker_CheckDMARC_NoFromHeader(t *testing.T) { From: []*mail.Address{}, } - result, err := checker.CheckDMARC(context.Background(), "raw message", headers, nil, nil) + result, err := checker.CheckDMARC(context.Background(), "raw message", headers, nil, nil, &MockLogger{}, &mockMetricsCollector{}) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -58,7 +58,7 @@ func TestDMARCChecker_CheckDMARC_EmptyFromHeader(t *testing.T) { From: []*mail.Address{nil}, } - result, err := checker.CheckDMARC(context.Background(), "raw message", headers, nil, nil) + result, err := checker.CheckDMARC(context.Background(), "raw message", headers, nil, nil, &MockLogger{}, &mockMetricsCollector{}) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -77,7 +77,7 @@ func TestDMARCChecker_CheckDMARC_NoDomainInFrom(t *testing.T) { }, } - result, err := checker.CheckDMARC(context.Background(), "raw message", headers, nil, nil) + result, err := checker.CheckDMARC(context.Background(), "raw message", headers, nil, nil, &MockLogger{}, &mockMetricsCollector{}) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -95,6 +95,8 @@ func TestDMARCChecker_IsDMARCNotFoundError(t *testing.T) { expected bool }{ {&netError{msg: "not found"}, true}, + {&netError{msg: "no policy found"}, true}, + {&netError{msg: "dmarc: no policy found for domain"}, true}, {&netError{msg: "no such host"}, true}, {&netError{msg: "nxdomain"}, true}, {&netError{msg: "no data"}, true}, @@ -127,7 +129,7 @@ func TestDMARCChecker_CheckDMARC_WithValidFromHeader(t *testing.T) { }, } - result, err := checker.CheckDMARC(context.Background(), "raw message", headers, nil, nil) + result, err := checker.CheckDMARC(context.Background(), "raw message", headers, nil, nil, &MockLogger{}, &mockMetricsCollector{}) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -147,17 +149,43 @@ func TestDMARCChecker_CheckDMARC_WithSPFAndDKIMResults(t *testing.T) { }, } - spfResult := &domain.SPFResult{Result: "pass"} + spfResult := &domain.SPFResult{Result: "pass", Sender: "test@example.com"} dkimResult := &domain.DKIMResult{Valid: true} - result, err := checker.CheckDMARC(context.Background(), "raw message", headers, spfResult, dkimResult) + result, err := checker.CheckDMARC(context.Background(), "raw message", headers, spfResult, dkimResult, &MockLogger{}, &mockMetricsCollector{}) if err != nil { t.Fatalf("Expected no error, got %v", err) } - // Result should be "fail" since DMARC record will be found for example.com - if result.Result != "fail" { - t.Errorf("Expected result 'fail', got %s", result.Result) + // Result depends on actual DMARC policy for example.com + // The test should verify the result is not empty and check alignment + if result.Result == "" { + t.Error("Expected non-empty result") + } + + // Test new RFC 7489 compliant fields + if result.Domain == "" { + t.Error("Expected domain field to be populated") + } + + if result.SPFDomain != "example.com" { + t.Errorf("Expected SPF domain 'example.com', got %s", result.SPFDomain) + } + + // Test alignment flags + if !result.SPFAligned { + t.Error("Expected SPF to be aligned (same domain)") + } + + // Test additional RFC 7489 fields + if result.SPFAlignment == "" { + t.Error("Expected SPF alignment mode to be populated") + } + if result.DKIMAlignment == "" { + t.Error("Expected DKIM alignment mode to be populated") + } + if result.FailureOptions == "" { + t.Error("Expected failure options to be populated") } } @@ -171,7 +199,7 @@ func TestDMARCChecker_CheckDMARC_WithNilSPFResult(t *testing.T) { } // Test with nil SPF result (should be handled gracefully) - result, err := checker.CheckDMARC(context.Background(), "raw message", headers, nil, nil) + result, err := checker.CheckDMARC(context.Background(), "raw message", headers, nil, nil, &MockLogger{}, &mockMetricsCollector{}) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -179,6 +207,15 @@ func TestDMARCChecker_CheckDMARC_WithNilSPFResult(t *testing.T) { if result.Result != "fail" { t.Errorf("Expected result 'fail', got %s", result.Result) } + + // Test new RFC 7489 compliant fields + if result.SPFDomain != "" { + t.Error("Expected empty SPF domain for nil SPF result") + } + + if result.SPFAligned { + t.Error("Expected SPF not to be aligned for nil SPF result") + } } func TestDMARCChecker_CheckDMARC_WithNilDKIMResult(t *testing.T) { @@ -193,7 +230,7 @@ func TestDMARCChecker_CheckDMARC_WithNilDKIMResult(t *testing.T) { spfResult := &domain.SPFResult{Result: "pass"} // Test with nil DKIM result (should be handled gracefully) - result, err := checker.CheckDMARC(context.Background(), "raw message", headers, spfResult, nil) + result, err := checker.CheckDMARC(context.Background(), "raw message", headers, spfResult, nil, &MockLogger{}, &mockMetricsCollector{}) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -372,13 +409,49 @@ func TestDMARCChecker_CheckDMARC_WithComplexFromHeader(t *testing.T) { }, } - result, err := checker.CheckDMARC(context.Background(), "raw message", headers, nil, nil) + result, err := checker.CheckDMARC(context.Background(), "raw message", headers, nil, nil, &MockLogger{}, &mockMetricsCollector{}) if err != nil { t.Fatalf("Expected no error, got %v", err) } - if result.Result != "none" { - t.Errorf("Expected result 'none', got %s", result.Result) + // The result depends on whether a DMARC record is found for sub.example.com + // If found, it will be evaluated (could be "fail", "pass", etc.) + // If not found, it will be "none" + // Both are valid outcomes + if result.Result == "" { + t.Error("Expected non-empty result") + } +} + +func TestDMARCChecker_CheckDMARC_AlignmentBugFix(t *testing.T) { + checker := NewDMARCChecker() + + headers := letters.Headers{ + From: []*mail.Address{ + {Address: "test@example.com"}, + }, + } + + // Test the alignment bug fix: SPF result with different sender domain + spfResult := &domain.SPFResult{ + Result: "pass", + Sender: "test@different.com", // Different domain from From header + } + dkimResult := &domain.DKIMResult{Valid: true} + + result, err := checker.CheckDMARC(context.Background(), "raw message", headers, spfResult, dkimResult, &MockLogger{}, &mockMetricsCollector{}) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + // Test that SPF alignment is correctly calculated + if result.SPFDomain != "different.com" { + t.Errorf("Expected SPF domain 'different.com', got %s", result.SPFDomain) + } + + // SPF should not be aligned since domains are different + if result.SPFAligned { + t.Error("Expected SPF not to be aligned (different domains)") } } @@ -393,7 +466,7 @@ func TestDMARCChecker_CheckDMARC_WithMultipleFromAddresses(t *testing.T) { }, } - result, err := checker.CheckDMARC(context.Background(), "raw message", headers, nil, nil) + result, err := checker.CheckDMARC(context.Background(), "raw message", headers, nil, nil, &MockLogger{}, &mockMetricsCollector{}) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -402,3 +475,184 @@ func TestDMARCChecker_CheckDMARC_WithMultipleFromAddresses(t *testing.T) { t.Errorf("Expected result 'fail', got %s", result.Result) } } + +func TestDMARCChecker_CheckDMARC_WithLookupDetails(t *testing.T) { + checker := NewDMARCChecker() + + headers := letters.Headers{ + From: []*mail.Address{ + {Address: "test@example.com"}, + }, + } + + result, err := checker.CheckDMARC(context.Background(), "raw message", headers, nil, nil, &MockLogger{}, &mockMetricsCollector{}) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + // Test that Details field is populated + if result.Details == nil { + t.Error("Expected Details field to be populated") + } + + // Should have at least one lookup attempt + if len(result.Details) == 0 { + t.Error("Expected at least one lookup attempt in Details") + } + + // Check that the first detail has the correct domain + if len(result.Details) > 0 { + if result.Details[0].Domain != "example.com" { + t.Errorf("Expected first lookup domain to be 'example.com', got %s", result.Details[0].Domain) + } + } +} + +func TestDMARCChecker_CheckDMARC_SubdomainPolicyInheritance(t *testing.T) { + checker := NewDMARCChecker() + + // Test with subdomain that should inherit from parent + headers := letters.Headers{ + From: []*mail.Address{ + {Address: "test@sub.example.com"}, + }, + } + + result, err := checker.CheckDMARC(context.Background(), "raw message", headers, nil, nil, &MockLogger{}, &mockMetricsCollector{}) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + // Test that Details field is populated + if result.Details == nil { + t.Error("Expected Details field to be populated") + } + + // Should have at least one lookup attempt + if len(result.Details) == 0 { + t.Error("Expected at least one lookup attempt") + } + + // First attempt should be at subdomain + if result.Details[0].Domain != "sub.example.com" { + t.Errorf("Expected first lookup at 'sub.example.com', got %s", result.Details[0].Domain) + } + + // If record was found at subdomain, that's fine - it means sub.example.com has its own DMARC record + // If not found, it should continue to parent domain + if !result.Details[0].RecordFound && len(result.Details) > 1 { + // Should show that record was not found at subdomain + if result.Details[0].RecordFound { + t.Error("Expected no record found at subdomain") + } + + // Last attempt should be at parent domain + lastDetail := result.Details[len(result.Details)-1] + if lastDetail.Domain != "example.com" { + t.Errorf("Expected last lookup at 'example.com', got %s", lastDetail.Domain) + } + } +} + +func TestDMARCChecker_CheckDMARC_PolicyUsedTracking(t *testing.T) { + checker := NewDMARCChecker() + + headers := letters.Headers{ + From: []*mail.Address{ + {Address: "test@example.com"}, + }, + } + + result, err := checker.CheckDMARC(context.Background(), "raw message", headers, nil, nil, &MockLogger{}, &mockMetricsCollector{}) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + // Find the detail where record was found + var foundDetail *domain.DMARCLookupDetail + for i := range result.Details { + if result.Details[i].RecordFound { + foundDetail = &result.Details[i] + break + } + } + + if foundDetail == nil { + t.Error("Expected to find a detail with RecordFound=true") + } else { + // Should show which policy was used + if foundDetail.PolicyUsed == "" { + t.Error("Expected PolicyUsed to be set for found record") + } + } +} + +func TestDMARCChecker_CheckDMARC_ErrorTracking(t *testing.T) { + checker := NewDMARCChecker() + + // Test with a domain that might cause DNS errors + headers := letters.Headers{ + From: []*mail.Address{ + {Address: "test@nonexistent-domain-that-should-fail.com"}, + }, + } + + result, err := checker.CheckDMARC(context.Background(), "raw message", headers, nil, nil, &MockLogger{}, &mockMetricsCollector{}) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + // Should have Details even if no record found + if result.Details == nil { + t.Error("Expected Details field to be populated even when no record found") + } + + // Should have at least one lookup attempt + if len(result.Details) == 0 { + t.Error("Expected at least one lookup attempt") + } +} + +func TestDMARCChecker_CheckDMARC_NoPolicyFoundHierarchy(t *testing.T) { + checker := NewDMARCChecker() + + // Test with a subdomain that should traverse hierarchy + // This simulates the real-world scenario where t.francispbaker.com has no DMARC record + // but francispbaker.com might have one + headers := letters.Headers{ + From: []*mail.Address{ + {Address: "test@t.francispbaker.com"}, + }, + } + + result, err := checker.CheckDMARC(context.Background(), "raw message", headers, nil, nil, &MockLogger{}, &mockMetricsCollector{}) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + // Should have Details showing hierarchy traversal + if result.Details == nil { + t.Error("Expected Details field to be populated") + } + + // Should have at least 2 lookup attempts (t.francispbaker.com -> francispbaker.com) + if len(result.Details) < 2 { + t.Errorf("Expected at least 2 lookup attempts for hierarchy traversal, got %d", len(result.Details)) + } + + // First attempt should be at the original subdomain + if result.Details[0].Domain != "t.francispbaker.com" { + t.Errorf("Expected first lookup at 't.francispbaker.com', got %s", result.Details[0].Domain) + } + + // Should show that no record was found at subdomain + if result.Details[0].RecordFound { + t.Error("Expected no record found at subdomain") + } + + // Last attempt should be at parent domain + lastDetail := result.Details[len(result.Details)-1] + if lastDetail.Domain != "francispbaker.com" { + t.Errorf("Expected last lookup at 'francispbaker.com', got %s", lastDetail.Domain) + } +} diff --git a/internal/auth/interfaces.go b/internal/auth/interfaces.go index aedf47b..afa35f7 100644 --- a/internal/auth/interfaces.go +++ b/internal/auth/interfaces.go @@ -3,6 +3,7 @@ package auth import ( "context" "inboundparse/internal/domain" + "inboundparse/internal/observability" "net" "github.com/mnako/letters" @@ -10,17 +11,17 @@ import ( // SPFChecker interface for SPF validation type SPFChecker interface { - CheckSPF(ctx context.Context, remoteIP net.IP, domainName, from string) (*domain.SPFResult, error) + CheckSPF(ctx context.Context, remoteIP net.IP, domainName, from string, logger observability.Logger, metrics observability.MetricsCollector) (*domain.SPFResult, error) } // DKIMChecker interface for DKIM validation type DKIMChecker interface { - CheckDKIM(ctx context.Context, rawMessage string) (*domain.DKIMResult, error) + CheckDKIM(ctx context.Context, rawMessage string, logger observability.Logger, metrics observability.MetricsCollector) (*domain.DKIMResult, error) } // DMARCChecker interface for DMARC validation type DMARCChecker interface { - CheckDMARC(ctx context.Context, rawMessage string, headers letters.Headers, spfResult *domain.SPFResult, dkimResult *domain.DKIMResult) (*domain.DMARCResult, error) + CheckDMARC(ctx context.Context, rawMessage string, headers letters.Headers, spfResult *domain.SPFResult, dkimResult *domain.DKIMResult, logger observability.Logger, metrics observability.MetricsCollector) (*domain.DMARCResult, error) } // AuthChecker combines all authentication checkers @@ -56,25 +57,25 @@ type authChecker struct { } // CheckSPF performs SPF authentication check -func (a *authChecker) CheckSPF(ctx context.Context, remoteIP net.IP, domainName, from string) (*domain.SPFResult, error) { +func (a *authChecker) CheckSPF(ctx context.Context, remoteIP net.IP, domainName, from string, logger observability.Logger, metrics observability.MetricsCollector) (*domain.SPFResult, error) { if !a.config.EnableSPF { return nil, nil } - return a.spf.CheckSPF(ctx, remoteIP, domainName, from) + return a.spf.CheckSPF(ctx, remoteIP, domainName, from, logger, metrics) } // CheckDKIM performs DKIM authentication check -func (a *authChecker) CheckDKIM(ctx context.Context, rawMessage string) (*domain.DKIMResult, error) { +func (a *authChecker) CheckDKIM(ctx context.Context, rawMessage string, logger observability.Logger, metrics observability.MetricsCollector) (*domain.DKIMResult, error) { if !a.config.EnableDKIM { return nil, nil } - return a.dkim.CheckDKIM(ctx, rawMessage) + return a.dkim.CheckDKIM(ctx, rawMessage, logger, metrics) } // CheckDMARC performs DMARC authentication check -func (a *authChecker) CheckDMARC(ctx context.Context, rawMessage string, headers letters.Headers, spfResult *domain.SPFResult, dkimResult *domain.DKIMResult) (*domain.DMARCResult, error) { +func (a *authChecker) CheckDMARC(ctx context.Context, rawMessage string, headers letters.Headers, spfResult *domain.SPFResult, dkimResult *domain.DKIMResult, logger observability.Logger, metrics observability.MetricsCollector) (*domain.DMARCResult, error) { if !a.config.EnableDMARC { return nil, nil } - return a.dmarc.CheckDMARC(ctx, rawMessage, headers, spfResult, dkimResult) + return a.dmarc.CheckDMARC(ctx, rawMessage, headers, spfResult, dkimResult, logger, metrics) } diff --git a/internal/auth/interfaces_test.go b/internal/auth/interfaces_test.go index bdc25ca..45585f9 100644 --- a/internal/auth/interfaces_test.go +++ b/internal/auth/interfaces_test.go @@ -3,6 +3,7 @@ package auth import ( "context" "inboundparse/internal/domain" + "inboundparse/internal/observability" "net" "testing" @@ -15,7 +16,7 @@ type mockSPFChecker struct { err error } -func (m *mockSPFChecker) CheckSPF(ctx context.Context, remoteIP net.IP, domainName, from string) (*domain.SPFResult, error) { +func (m *mockSPFChecker) CheckSPF(ctx context.Context, remoteIP net.IP, domainName, from string, logger observability.Logger) (*domain.SPFResult, error) { return m.result, m.err } @@ -24,7 +25,7 @@ type mockDKIMChecker struct { err error } -func (m *mockDKIMChecker) CheckDKIM(ctx context.Context, rawMessage string) (*domain.DKIMResult, error) { +func (m *mockDKIMChecker) CheckDKIM(ctx context.Context, rawMessage string, logger observability.Logger) (*domain.DKIMResult, error) { return m.result, m.err } @@ -33,7 +34,7 @@ type mockDMARCChecker struct { err error } -func (m *mockDMARCChecker) CheckDMARC(ctx context.Context, rawMessage string, headers letters.Headers, spfResult *domain.SPFResult, dkimResult *domain.DKIMResult) (*domain.DMARCResult, error) { +func (m *mockDMARCChecker) CheckDMARC(ctx context.Context, rawMessage string, headers letters.Headers, spfResult *domain.SPFResult, dkimResult *domain.DKIMResult, logger observability.Logger) (*domain.DMARCResult, error) { return m.result, m.err } @@ -48,7 +49,7 @@ func TestAuthChecker_CheckSPF_Enabled(t *testing.T) { remoteIP := net.ParseIP("192.168.1.1") // Test with a domain that should have SPF records - result, err := authChecker.CheckSPF(ctx, remoteIP, "google.com", "test@google.com") + result, err := authChecker.CheckSPF(ctx, remoteIP, "google.com", "test@google.com", &MockLogger{}, &mockMetricsCollector{}) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -72,7 +73,7 @@ func TestAuthChecker_CheckSPF_Disabled(t *testing.T) { ctx := context.Background() remoteIP := net.ParseIP("192.168.1.1") - result, err := authChecker.CheckSPF(ctx, remoteIP, "example.com", "test@example.com") + result, err := authChecker.CheckSPF(ctx, remoteIP, "example.com", "test@example.com", &MockLogger{}, &mockMetricsCollector{}) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -91,7 +92,7 @@ func TestAuthChecker_CheckDKIM_Enabled(t *testing.T) { ctx := context.Background() - result, err := authChecker.CheckDKIM(ctx, "test message") + result, err := authChecker.CheckDKIM(ctx, "test message", &MockLogger{}, &mockMetricsCollector{}) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -112,7 +113,7 @@ func TestAuthChecker_CheckDKIM_Disabled(t *testing.T) { ctx := context.Background() - result, err := authChecker.CheckDKIM(ctx, "test message") + result, err := authChecker.CheckDKIM(ctx, "test message", &MockLogger{}, &mockMetricsCollector{}) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -134,7 +135,7 @@ func TestAuthChecker_CheckDMARC_Enabled(t *testing.T) { spfResult := &domain.SPFResult{Result: "pass"} dkimResult := &domain.DKIMResult{Valid: true} - result, err := authChecker.CheckDMARC(ctx, "test message", headers, spfResult, dkimResult) + result, err := authChecker.CheckDMARC(ctx, "test message", headers, spfResult, dkimResult, &MockLogger{}, &mockMetricsCollector{}) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -160,7 +161,7 @@ func TestAuthChecker_CheckDMARC_Disabled(t *testing.T) { spfResult := &domain.SPFResult{Result: "pass"} dkimResult := &domain.DKIMResult{Valid: true} - result, err := authChecker.CheckDMARC(ctx, "test message", headers, spfResult, dkimResult) + result, err := authChecker.CheckDMARC(ctx, "test message", headers, spfResult, dkimResult, &MockLogger{}, &mockMetricsCollector{}) if err != nil { t.Fatalf("Expected no error, got %v", err) } diff --git a/internal/auth/spf.go b/internal/auth/spf.go index 13f6bd7..d2ccba7 100644 --- a/internal/auth/spf.go +++ b/internal/auth/spf.go @@ -4,8 +4,10 @@ import ( "context" "fmt" "inboundparse/internal/domain" + "inboundparse/internal/observability" "net" "strings" + "time" "github.com/rs/zerolog/log" "github.com/zaccone/spf" @@ -19,50 +21,86 @@ func NewSPFChecker() SPFChecker { return &spfChecker{} } -// CheckSPF performs SPF authentication check -func (s *spfChecker) CheckSPF(ctx context.Context, remoteIP net.IP, domainName, from string) (*domain.SPFResult, error) { +// CheckSPF performs SPF authentication check following RFC 7208 +func (s *spfChecker) CheckSPF(ctx context.Context, remoteIP net.IP, domainName, from string, logger observability.Logger, metrics observability.MetricsCollector) (*domain.SPFResult, error) { + start := time.Now() + defer func() { + duration := time.Since(start) + metrics.TrackSPFDuration(duration) + }() if remoteIP == nil { - log.Warn().Msg("SPF check: Failed to get remote IP") + logger.Warn().Msg("SPF check: Failed to get remote IP") return &domain.SPFResult{ - Result: "temperror", - Raw: "Failed to get remote IP", + Result: "temperror", + Raw: "Failed to get remote IP", + Problem: "No remote IP address available for SPF evaluation", + Sender: from, + IpAddress: remoteIP.String(), }, nil } if domainName == "" { - log.Debug(). + logger.Debug(). Str("from", from). Msg("SPF check: No domain found in sender address") return &domain.SPFResult{ - Result: "none", - Raw: "No domain found in sender address", + Result: "none", + Raw: "No domain found in sender address", + Problem: "Unable to extract domain from MAIL FROM identity", + Sender: from, + IpAddress: remoteIP.String(), }, nil } + // RFC 7208 Section 2.4: Check MAIL FROM identity result, mechanism, err := spf.CheckHost(remoteIP, domainName, from) if err != nil { - log.Warn(). + // Determine error type for metrics + errorType := "dns_error" + if strings.Contains(err.Error(), "timeout") { + errorType = "timeout_error" + } else if strings.Contains(err.Error(), "parse") { + errorType = "parse_error" + } + metrics.TrackAuthError("spf", errorType) + + logger.Warn(). Err(err). Str("domain", domainName). Str("remote_ip", remoteIP.String()). + Str("identity", from). Msg("SPF check error") return &domain.SPFResult{ - Result: "temperror", - Raw: fmt.Sprintf("SPF check error: %v", err), + Result: "temperror", + Raw: fmt.Sprintf("SPF check error: %v", err), + Problem: fmt.Sprintf("DNS lookup or processing error: %v", err), + Sender: from, + IpAddress: remoteIP.String(), }, nil } - log.Debug(). + // Extract explanation if available (RFC 7208 Section 6.2) + explanation := s.extractSPFExplanation(domainName, result) + + logger.Debug(). Str("result", result.String()). Str("domain", domainName). Str("remote_ip", remoteIP.String()). Str("mechanism", mechanism). + Str("identity", from). + Str("explanation", explanation). Msg("SPF check completed") + // Track result by domain + metrics.TrackSPFResultByDomain(domainName, result.String()) + return &domain.SPFResult{ - Result: result.String(), - Raw: fmt.Sprintf("SPF result: %s for domain %s from IP %s", result.String(), domainName, remoteIP.String()), - Mechanism: mechanism, + Result: result.String(), + Raw: fmt.Sprintf("SPF result: %s for domain %s from IP %s", result.String(), domainName, remoteIP.String()), + Mechanism: mechanism, + Explanation: explanation, + Sender: from, + IpAddress: remoteIP.String(), }, nil } @@ -91,3 +129,18 @@ func ExtractDomainFromEmail(email string) string { return domain } + +// extractSPFExplanation attempts to extract explanation from SPF record (RFC 7208 Section 6.2) +func (s *spfChecker) extractSPFExplanation(domain string, result spf.Result) string { + // For now, return empty string as the spf library doesn't expose explanation + // In a full implementation, this would query the SPF record and extract the exp= tag + // when the result is fail, softfail, or neutral + resultStr := result.String() + if resultStr == "fail" || resultStr == "softfail" || resultStr == "neutral" { + log.Debug(). + Str("domain", domain). + Str("result", resultStr). + Msg("SPF explanation could be extracted from exp= tag in SPF record") + } + return "" +} diff --git a/internal/auth/spf_test.go b/internal/auth/spf_test.go index d492158..1e0e00c 100644 --- a/internal/auth/spf_test.go +++ b/internal/auth/spf_test.go @@ -23,7 +23,7 @@ func TestExtractDomainFromEmail(t *testing.T) { func TestSPFChecker_CheckSPF_NoRemoteIP(t *testing.T) { checker := NewSPFChecker() - result, err := checker.CheckSPF(context.Background(), nil, "example.com", "test@example.com") + result, err := checker.CheckSPF(context.Background(), nil, "example.com", "test@example.com", &MockLogger{}, &mockMetricsCollector{}) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -31,13 +31,22 @@ func TestSPFChecker_CheckSPF_NoRemoteIP(t *testing.T) { if result.Result != "temperror" { t.Errorf("Expected result 'temperror', got %s", result.Result) } + + // Test new RFC 7208 compliant fields + if result.Sender != "test@example.com" { + t.Errorf("Expected sender 'test@example.com', got %s", result.Sender) + } + + if result.Problem == "" { + t.Error("Expected problem description for temperror result") + } } func TestSPFChecker_CheckSPF_NoDomain(t *testing.T) { checker := NewSPFChecker() remoteIP := net.ParseIP("192.168.1.1") - result, err := checker.CheckSPF(context.Background(), remoteIP, "", "test@") + result, err := checker.CheckSPF(context.Background(), remoteIP, "", "test@", &MockLogger{}, &mockMetricsCollector{}) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -45,13 +54,22 @@ func TestSPFChecker_CheckSPF_NoDomain(t *testing.T) { if result.Result != "none" { t.Errorf("Expected result 'none', got %s", result.Result) } + + // Test new RFC 7208 compliant fields + if result.Sender != "test@" { + t.Errorf("Expected sender 'test@', got %s", result.Sender) + } + + if result.Problem == "" { + t.Error("Expected problem description for none result") + } } func TestSPFChecker_CheckSPF_ValidDomain(t *testing.T) { checker := NewSPFChecker() remoteIP := net.ParseIP("192.168.1.1") - result, err := checker.CheckSPF(context.Background(), remoteIP, "example.com", "test@example.com") + result, err := checker.CheckSPF(context.Background(), remoteIP, "example.com", "test@example.com", &MockLogger{}, &mockMetricsCollector{}) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -60,6 +78,11 @@ func TestSPFChecker_CheckSPF_ValidDomain(t *testing.T) { if result.Result == "" { t.Error("Expected non-empty result") } + + // Test new RFC 7208 compliant fields + if result.Sender != "test@example.com" { + t.Errorf("Expected sender 'test@example.com', got %s", result.Sender) + } } func TestSPFChecker_CheckSPF_WithError(t *testing.T) { @@ -67,7 +90,7 @@ func TestSPFChecker_CheckSPF_WithError(t *testing.T) { // Test with a domain that might cause SPF lookup errors remoteIP := net.ParseIP("192.168.1.1") - result, err := checker.CheckSPF(context.Background(), remoteIP, "nonexistent-domain-that-should-fail.com", "test@nonexistent-domain-that-should-fail.com") + result, err := checker.CheckSPF(context.Background(), remoteIP, "nonexistent-domain-that-should-fail.com", "test@nonexistent-domain-that-should-fail.com", &MockLogger{}, &mockMetricsCollector{}) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -83,7 +106,7 @@ func TestSPFChecker_CheckSPF_WithInvalidIP(t *testing.T) { // Test with an invalid IP address remoteIP := net.ParseIP("invalid-ip") - result, err := checker.CheckSPF(context.Background(), remoteIP, "example.com", "test@example.com") + result, err := checker.CheckSPF(context.Background(), remoteIP, "example.com", "test@example.com", &MockLogger{}, &mockMetricsCollector{}) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -99,7 +122,7 @@ func TestSPFChecker_CheckSPF_WithIPv6(t *testing.T) { // Test with IPv6 address remoteIP := net.ParseIP("2001:db8::1") - result, err := checker.CheckSPF(context.Background(), remoteIP, "example.com", "test@example.com") + result, err := checker.CheckSPF(context.Background(), remoteIP, "example.com", "test@example.com", &MockLogger{}, &mockMetricsCollector{}) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -115,7 +138,7 @@ func TestSPFChecker_CheckSPF_WithLocalhost(t *testing.T) { // Test with localhost IP remoteIP := net.ParseIP("127.0.0.1") - result, err := checker.CheckSPF(context.Background(), remoteIP, "example.com", "test@example.com") + result, err := checker.CheckSPF(context.Background(), remoteIP, "example.com", "test@example.com", &MockLogger{}, &mockMetricsCollector{}) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -131,7 +154,7 @@ func TestSPFChecker_CheckSPF_WithPrivateIP(t *testing.T) { // Test with private IP address remoteIP := net.ParseIP("10.0.0.1") - result, err := checker.CheckSPF(context.Background(), remoteIP, "example.com", "test@example.com") + result, err := checker.CheckSPF(context.Background(), remoteIP, "example.com", "test@example.com", &MockLogger{}, &mockMetricsCollector{}) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -146,7 +169,7 @@ func TestSPFChecker_CheckSPF_WithEmptyFrom(t *testing.T) { checker := NewSPFChecker() remoteIP := net.ParseIP("192.168.1.1") - result, err := checker.CheckSPF(context.Background(), remoteIP, "example.com", "") + result, err := checker.CheckSPF(context.Background(), remoteIP, "example.com", "", &MockLogger{}, &mockMetricsCollector{}) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -161,7 +184,7 @@ func TestSPFChecker_CheckSPF_WithMalformedFrom(t *testing.T) { checker := NewSPFChecker() remoteIP := net.ParseIP("192.168.1.1") - result, err := checker.CheckSPF(context.Background(), remoteIP, "example.com", "malformed-email-address") + result, err := checker.CheckSPF(context.Background(), remoteIP, "example.com", "malformed-email-address", &MockLogger{}, &mockMetricsCollector{}) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -176,7 +199,7 @@ func TestSPFChecker_CheckSPF_WithSpecialCharactersInDomain(t *testing.T) { checker := NewSPFChecker() remoteIP := net.ParseIP("192.168.1.1") - result, err := checker.CheckSPF(context.Background(), remoteIP, "example-with-dashes.com", "test@example-with-dashes.com") + result, err := checker.CheckSPF(context.Background(), remoteIP, "example-with-dashes.com", "test@example-with-dashes.com", &MockLogger{}, &mockMetricsCollector{}) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -191,7 +214,7 @@ func TestSPFChecker_CheckSPF_WithSubdomain(t *testing.T) { checker := NewSPFChecker() remoteIP := net.ParseIP("192.168.1.1") - result, err := checker.CheckSPF(context.Background(), remoteIP, "sub.example.com", "test@sub.example.com") + result, err := checker.CheckSPF(context.Background(), remoteIP, "sub.example.com", "test@sub.example.com", &MockLogger{}, &mockMetricsCollector{}) if err != nil { t.Fatalf("Expected no error, got %v", err) } diff --git a/internal/auth/test_utils.go b/internal/auth/test_utils.go new file mode 100644 index 0000000..bbecb5c --- /dev/null +++ b/internal/auth/test_utils.go @@ -0,0 +1,42 @@ +package auth + +import ( + "io" + + "github.com/rs/zerolog" +) + +// MockLogger implements observability.Logger for testing +type MockLogger struct{} + +func (m *MockLogger) Debug() *zerolog.Event { + // Return a properly initialized zerolog event that discards output + logger := zerolog.New(io.Discard) + return logger.Debug() +} + +func (m *MockLogger) Info() *zerolog.Event { + logger := zerolog.New(io.Discard) + return logger.Info() +} + +func (m *MockLogger) Warn() *zerolog.Event { + logger := zerolog.New(io.Discard) + return logger.Warn() +} + +func (m *MockLogger) Error() *zerolog.Event { + logger := zerolog.New(io.Discard) + return logger.Error() +} + +func (m *MockLogger) Fatal() *zerolog.Event { + logger := zerolog.New(io.Discard) + return logger.Fatal() +} + +func (m *MockLogger) With() zerolog.Context { + // Return a no-op context for testing + logger := zerolog.New(io.Discard) + return logger.With() +} diff --git a/internal/config/config.go b/internal/config/config.go index 7a2abdb..9ebcf56 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -28,6 +28,16 @@ type Config struct { WebhookUser string WebhookPass string + // Webhook Retry Configuration + WebhookMaxRetries int + WebhookRetryDelay time.Duration + WebhookMaxRetryDelay time.Duration + WebhookRetryMultiplier float64 + + // Webhook Rate Limiting Configuration + WebhookRateLimitPerSecond int + WebhookRateLimitBurst int + // Authentication Configuration EnableSPF bool EnableDKIM bool @@ -166,21 +176,29 @@ func New(options ...Option) *Config { func LoadFromFlags() (*Config, error) { // Define flags var ( - listenAddr = flag.String("listen", DefaultListenAddr, "SMTP server listen address") - listenAddrTLS = flag.String("listen-tls", DefaultListenAddrTLS, "SMTP server listen address for TLS") - webhookURL = flag.String("webhook", "", "Webhook URL to send email data to") - webhookUser = flag.String("webhook-user", "", "Basic auth username for webhook") - webhookPass = flag.String("webhook-pass", "", "Basic auth password for webhook") - serverName = flag.String("name", DefaultServerName, "SMTP server name") - maxMessageSize = flag.Int64("max-size", DefaultMaxMessageSize, "Maximum message size in bytes") - readTimeout = flag.Int("read-timeout", int(DefaultReadTimeout.Seconds()), "Read timeout in seconds") - writeTimeout = flag.Int("write-timeout", int(DefaultWriteTimeout.Seconds()), "Write timeout in seconds") - enableSPF = flag.Bool("enable-spf", true, "Enable SPF validation") - enableDKIM = flag.Bool("enable-dkim", true, "Enable DKIM validation") - enableDMARC = flag.Bool("enable-dmarc", true, "Enable DMARC validation") - certFile = flag.String("cert-file", "", "TLS certificate file") - keyFile = flag.String("key-file", "", "TLS key file") - verbose = flag.Bool("verbose", false, "Enable verbose logging") + listenAddr = flag.String("listen", DefaultListenAddr, "SMTP server listen address") + listenAddrTLS = flag.String("listen-tls", DefaultListenAddrTLS, "SMTP server listen address for TLS") + webhookURL = flag.String("webhook", "", "Webhook URL to send email data to") + webhookUser = flag.String("webhook-user", "", "Basic auth username for webhook") + webhookPass = flag.String("webhook-pass", "", "Basic auth password for webhook") + // Webhook retry flags + webhookMaxRetries = flag.Int("webhook-max-retries", 3, "Maximum number of webhook retry attempts") + webhookRetryDelay = flag.Int("webhook-retry-delay", 1, "Initial webhook retry delay in seconds") + webhookMaxRetryDelay = flag.Int("webhook-max-retry-delay", 30, "Maximum webhook retry delay in seconds") + webhookRetryMultiplier = flag.Float64("webhook-retry-multiplier", 2.0, "Webhook retry delay multiplier") + // Webhook rate limiting flags + webhookRateLimitPerSecond = flag.Int("webhook-rate-limit", 10, "Webhook requests per second rate limit") + webhookRateLimitBurst = flag.Int("webhook-rate-burst", 20, "Webhook rate limit burst capacity") + serverName = flag.String("name", DefaultServerName, "SMTP server name") + maxMessageSize = flag.Int64("max-size", DefaultMaxMessageSize, "Maximum message size in bytes") + readTimeout = flag.Int("read-timeout", int(DefaultReadTimeout.Seconds()), "Read timeout in seconds") + writeTimeout = flag.Int("write-timeout", int(DefaultWriteTimeout.Seconds()), "Write timeout in seconds") + enableSPF = flag.Bool("enable-spf", true, "Enable SPF validation") + enableDKIM = flag.Bool("enable-dkim", true, "Enable DKIM validation") + enableDMARC = flag.Bool("enable-dmarc", true, "Enable DMARC validation") + certFile = flag.String("cert-file", "", "TLS certificate file") + keyFile = flag.String("key-file", "", "TLS key file") + verbose = flag.Bool("verbose", false, "Enable verbose logging") // Sentry flags enableSentry = flag.Bool("enable-sentry", false, "Enable Sentry error tracking") sentryDSN = flag.String("sentry-dsn", "", "Sentry DSN for error tracking") @@ -209,6 +227,12 @@ func LoadFromFlags() (*Config, error) { overrideFromEnv("WEBHOOK_URL", webhookURL) overrideFromEnv("WEBHOOK_USER", webhookUser) overrideFromEnv("WEBHOOK_PASS", webhookPass) + overrideIntFromEnv("WEBHOOK_MAX_RETRIES", webhookMaxRetries) + overrideIntFromEnv("WEBHOOK_RETRY_DELAY", webhookRetryDelay) + overrideIntFromEnv("WEBHOOK_MAX_RETRY_DELAY", webhookMaxRetryDelay) + overrideFloat64FromEnv("WEBHOOK_RETRY_MULTIPLIER", webhookRetryMultiplier) + overrideIntFromEnv("WEBHOOK_RATE_LIMIT", webhookRateLimitPerSecond) + overrideIntFromEnv("WEBHOOK_RATE_BURST", webhookRateLimitBurst) overrideFromEnv("SERVER_NAME", serverName) overrideFromEnv("CERT_FILE", certFile) overrideFromEnv("KEY_FILE", keyFile) @@ -243,6 +267,14 @@ func LoadFromFlags() (*Config, error) { WithVerbose(*verbose), ) + // Set webhook retry and rate limiting configuration + config.WebhookMaxRetries = *webhookMaxRetries + config.WebhookRetryDelay = time.Duration(*webhookRetryDelay) * time.Second + config.WebhookMaxRetryDelay = time.Duration(*webhookMaxRetryDelay) * time.Second + config.WebhookRetryMultiplier = *webhookRetryMultiplier + config.WebhookRateLimitPerSecond = *webhookRateLimitPerSecond + config.WebhookRateLimitBurst = *webhookRateLimitBurst + // Set additional fields config.ServerName = *serverName config.MaxMessageSize = *maxMessageSize @@ -284,6 +316,30 @@ func overrideBoolFromEnv(envKey string, target *bool) { // If env var is empty or not set, keep the original value } +// overrideIntFromEnv overrides an int flag from environment variable if present +func overrideIntFromEnv(envKey string, target *int) { + if env := os.Getenv(envKey); env != "" { + if val, err := fmt.Sscanf(env, "%d", target); err == nil && val == 1 { + // Successfully parsed - no action needed as target is already updated + return + } + // If parsing failed, keep the original value + } + // If env var is empty or not set, keep the original value +} + +// overrideFloat64FromEnv overrides a float64 flag from environment variable if present +func overrideFloat64FromEnv(envKey string, target *float64) { + if env := os.Getenv(envKey); env != "" { + if val, err := fmt.Sscanf(env, "%f", target); err == nil && val == 1 { + // Successfully parsed - no action needed as target is already updated + return + } + // If parsing failed, keep the original value + } + // If env var is empty or not set, keep the original value +} + // Validate checks if the configuration is valid func (c *Config) Validate() error { if c.WebhookURL == "" { diff --git a/internal/domain/email.go b/internal/domain/email.go index 492c880..a8f27f6 100644 --- a/internal/domain/email.go +++ b/internal/domain/email.go @@ -22,26 +22,64 @@ type AuthenticationResults struct { // SPFResult contains SPF validation results type SPFResult struct { - Result string `json:"result"` - Raw string `json:"raw"` - Mechanism string `json:"mechanism,omitempty"` - Qualifier string `json:"qualifier,omitempty"` + Result string `json:"result"` + Raw string `json:"raw"` + Mechanism string `json:"mechanism,omitempty"` + Qualifier string `json:"qualifier,omitempty"` + Explanation string `json:"explanation,omitempty"` // SPF explanation from DNS (RFC 7208 Section 6.2) + Problem string `json:"problem,omitempty"` // Description of any processing issues + Sender string `json:"sender,omitempty"` // The identity checked (MAIL FROM) + IpAddress string `json:"ip_address,omitempty"` // The IP address checked (from remote IP) +} + +// DKIMSignatureDetail contains detailed information about a single DKIM signature +// Only includes fields available from go-msgauth library +type DKIMSignatureDetail struct { + Domain string `json:"domain"` // d= tag (from verification.Domain) + HeadersSigned []string `json:"headers_signed"` // h= tag (from verification.HeaderKeys) + Timestamp int64 `json:"timestamp,omitempty"` // t= tag (from verification.Time) + Expiration int64 `json:"expiration,omitempty"` // x= tag (from verification.Expiration) + Valid bool `json:"valid"` + Error string `json:"error,omitempty"` +} + +// DMARCLookupDetail contains detailed information about a single DMARC lookup attempt +type DMARCLookupDetail struct { + Domain string `json:"domain"` // Domain where lookup was attempted + RecordFound bool `json:"record_found"` // Whether a record was found + PolicyUsed string `json:"policy_used,omitempty"` // Which policy was applied (p or sp) + Error string `json:"error,omitempty"` // Any error during lookup } // DKIMResult contains DKIM validation results type DKIMResult struct { - Valid bool `json:"valid"` - Signatures []string `json:"signatures,omitempty"` - Raw string `json:"raw"` - Error string `json:"error,omitempty"` + Valid bool `json:"valid"` + Signatures []string `json:"signatures,omitempty"` + Details []DKIMSignatureDetail `json:"details,omitempty"` // Per-signature structured data + Raw string `json:"raw"` + Error string `json:"error,omitempty"` } // DMARCResult contains DMARC validation results type DMARCResult struct { - Result string `json:"result"` - Raw string `json:"raw"` - Policy string `json:"policy,omitempty"` - Error string `json:"error,omitempty"` + Result string `json:"result"` + Raw string `json:"raw"` + Policy string `json:"policy,omitempty"` + Error string `json:"error,omitempty"` + Domain string `json:"domain,omitempty"` // Domain where policy was found + Percentage int `json:"percentage,omitempty"` // pct= tag value (RFC 7489 Section 6.3) + SubdomainPolicy string `json:"subdomain_policy,omitempty"` // sp= tag value + SPFAligned bool `json:"spf_aligned"` // Whether SPF identifier aligned + DKIMAligned bool `json:"dkim_aligned"` // Whether DKIM identifier aligned + SPFDomain string `json:"spf_domain,omitempty"` // Domain used for SPF alignment check + DKIMDomain string `json:"dkim_domain,omitempty"` // Domain(s) used for DKIM alignment + Details []DMARCLookupDetail `json:"details,omitempty"` // Per-domain lookup details + // Additional RFC 7489 compliant fields + SPFAlignment string `json:"spf_alignment,omitempty"` // aspf= tag value (RFC 7489 Section 6.3) + DKIMAlignment string `json:"dkim_alignment,omitempty"` // adkim= tag value (RFC 7489 Section 6.3) + FailureOptions string `json:"failure_options,omitempty"` // fo= tag value (RFC 7489 Section 6.3) + ReportURIs []string `json:"report_uris,omitempty"` // rua= tag values (RFC 7489 Section 6.3) + FailureURIs []string `json:"failure_uris,omitempty"` // ruf= tag values (RFC 7489 Section 6.3) } // Attachment represents an email attachment diff --git a/internal/observability/logger.go b/internal/observability/logger.go index e95c7fc..8fc144e 100644 --- a/internal/observability/logger.go +++ b/internal/observability/logger.go @@ -52,3 +52,48 @@ func NewLogger(config LoggerConfig) Logger { func DefaultLogger() Logger { return NewLogger(LoggerConfig{Verbose: false}) } + +// sessionLoggerWrapper wraps zerolog.Logger to implement observability.Logger interface +type sessionLoggerWrapper struct { + logger zerolog.Logger +} + +func (w *sessionLoggerWrapper) Debug() *zerolog.Event { + return w.logger.Debug() +} + +func (w *sessionLoggerWrapper) Info() *zerolog.Event { + return w.logger.Info() +} + +func (w *sessionLoggerWrapper) Warn() *zerolog.Event { + return w.logger.Warn() +} + +func (w *sessionLoggerWrapper) Error() *zerolog.Event { + return w.logger.Error() +} + +func (w *sessionLoggerWrapper) Fatal() *zerolog.Event { + return w.logger.Fatal() +} + +func (w *sessionLoggerWrapper) With() zerolog.Context { + return w.logger.With() +} + +// NewSessionLogger creates a session-aware logger with session_id and remote_addr context +func NewSessionLogger(baseLogger Logger, sessionID string, remoteAddr string) Logger { + // Extract the underlying zerolog.Logger from the base logger + // We need to create a new zerolog.Logger with session context + // Since we can't access the underlying zerolog.Logger directly from the interface, + // we'll create a new one with the session context + sessionLogger := zerolog.New(os.Stdout).With(). + Timestamp(). + Str("service", "inboundparse"). + Str("session_id", sessionID). + Str("remote_addr", remoteAddr). + Logger() + + return &sessionLoggerWrapper{logger: sessionLogger} +} diff --git a/internal/observability/metrics.go b/internal/observability/metrics.go index 1897f52..9221a67 100644 --- a/internal/observability/metrics.go +++ b/internal/observability/metrics.go @@ -29,6 +29,34 @@ type MetricsCollector interface { TrackDKIMResult(valid bool) TrackDMARCResult(result, policy string) TrackWebhookRequest(statusCode int, duration time.Duration) + TrackWebhookRequestSize(requestSize, responseSize int64) + TrackWebhookRetry(success bool) + TrackWebhookError(errorType string) + TrackSPFDuration(duration time.Duration) + TrackDKIMDuration(duration time.Duration) + TrackDMARCDuration(duration time.Duration) + TrackSPFResultByDomain(domain, result string) + TrackDKIMResultByDomain(domain, result string) + TrackDMARCResultByDomain(domain, result string) + TrackAuthError(checkType, errorType string) + TrackMessageReadDuration(duration time.Duration) + TrackMessageParseDuration(duration time.Duration) + TrackMessageAuthDuration(duration time.Duration) + TrackMessageWebhookDuration(duration time.Duration) + TrackMessageTotalProcessingDuration(duration time.Duration) + + // Email characteristics metrics + TrackEmailAttachmentCount(count int) + TrackEmailAttachmentSize(size int64) + TrackEmailBodyType(bodyType string) + TrackEmailHeaderCount(count int) + TrackEmailRecipientCount(count int) + + // Session lifecycle metrics + TrackSessionDuration(duration time.Duration) + TrackSessionCommandCount(command string) + TrackSessionErrorCount(errorType string) + TrackSessionMessageCount() } // MetricsServer manages the Prometheus metrics HTTP server @@ -61,8 +89,45 @@ type metricsCollector struct { authChecksDMARC prometheus.Counter // Webhook metrics - webhookRequests *prometheus.CounterVec - webhookDuration prometheus.Histogram + webhookRequests *prometheus.CounterVec + webhookDuration prometheus.Histogram + webhookRequestSize prometheus.Histogram + webhookResponseSize prometheus.Histogram + webhookRetries *prometheus.CounterVec + webhookErrors *prometheus.CounterVec + + // Authentication duration metrics + spfDuration prometheus.Histogram + dkimDuration prometheus.Histogram + dmarcDuration prometheus.Histogram + + // Domain-based authentication metrics + spfResultsByDomain *prometheus.CounterVec + dkimResultsByDomain *prometheus.CounterVec + dmarcResultsByDomain *prometheus.CounterVec + + // Authentication error metrics + authErrors *prometheus.CounterVec + + // Message processing stage metrics + messageReadDuration prometheus.Histogram + messageParseDuration prometheus.Histogram + messageAuthDuration prometheus.Histogram + messageWebhookDuration prometheus.Histogram + messageTotalDuration prometheus.Histogram + + // Email characteristics metrics + emailAttachmentCount prometheus.Histogram + emailAttachmentSize prometheus.Histogram + emailBodyType *prometheus.CounterVec + emailHeaderCount prometheus.Histogram + emailRecipientCount prometheus.Histogram + + // Session lifecycle metrics + sessionDuration prometheus.Histogram + sessionCommands *prometheus.CounterVec + sessionErrors *prometheus.CounterVec + sessionMessageCount prometheus.Counter } // NewMetricsCollector creates a new metrics collector @@ -146,6 +211,133 @@ func NewMetricsCollector() MetricsCollector { Help: "Duration of webhook requests in seconds", Buckets: prometheus.DefBuckets, }), + webhookRequestSize: promauto.NewHistogram(prometheus.HistogramOpts{ + Name: "webhook_request_size_bytes", + Help: "Size of webhook request payloads in bytes", + Buckets: []float64{1024, 10240, 102400, 1048576, 10485760, 104857600}, + }), + webhookResponseSize: promauto.NewHistogram(prometheus.HistogramOpts{ + Name: "webhook_response_size_bytes", + Help: "Size of webhook response payloads in bytes", + Buckets: []float64{1024, 10240, 102400, 1048576, 10485760, 104857600}, + }), + webhookRetries: promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "webhook_retry_attempts_total", + Help: "Total webhook retry attempts", + }, []string{"success"}), + webhookErrors: promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "webhook_errors_total", + Help: "Total webhook errors by type", + }, []string{"error_type"}), + + // Authentication duration metrics + spfDuration: promauto.NewHistogram(prometheus.HistogramOpts{ + Name: "spf_check_duration_seconds", + Help: "Duration of SPF authentication checks in seconds", + Buckets: prometheus.DefBuckets, + }), + dkimDuration: promauto.NewHistogram(prometheus.HistogramOpts{ + Name: "dkim_check_duration_seconds", + Help: "Duration of DKIM authentication checks in seconds", + Buckets: prometheus.DefBuckets, + }), + dmarcDuration: promauto.NewHistogram(prometheus.HistogramOpts{ + Name: "dmarc_check_duration_seconds", + Help: "Duration of DMARC authentication checks in seconds", + Buckets: prometheus.DefBuckets, + }), + + // Domain-based authentication metrics + spfResultsByDomain: promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "spf_results_by_domain_total", + Help: "SPF authentication results by domain", + }, []string{"domain", "result"}), + dkimResultsByDomain: promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "dkim_results_by_domain_total", + Help: "DKIM authentication results by domain", + }, []string{"domain", "result"}), + dmarcResultsByDomain: promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "dmarc_results_by_domain_total", + Help: "DMARC authentication results by domain", + }, []string{"domain", "result"}), + + // Authentication error metrics + authErrors: promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "auth_errors_total", + Help: "Authentication errors by check type and error type", + }, []string{"check_type", "error_type"}), + + // Message processing stage metrics + messageReadDuration: promauto.NewHistogram(prometheus.HistogramOpts{ + Name: "message_read_duration_seconds", + Help: "Duration of message reading stage in seconds", + Buckets: prometheus.DefBuckets, + }), + messageParseDuration: promauto.NewHistogram(prometheus.HistogramOpts{ + Name: "message_parse_duration_seconds", + Help: "Duration of message parsing stage in seconds", + Buckets: prometheus.DefBuckets, + }), + messageAuthDuration: promauto.NewHistogram(prometheus.HistogramOpts{ + Name: "message_auth_duration_seconds", + Help: "Duration of message authentication stage in seconds", + Buckets: prometheus.DefBuckets, + }), + messageWebhookDuration: promauto.NewHistogram(prometheus.HistogramOpts{ + Name: "message_webhook_duration_seconds", + Help: "Duration of message webhook stage in seconds", + Buckets: prometheus.DefBuckets, + }), + messageTotalDuration: promauto.NewHistogram(prometheus.HistogramOpts{ + Name: "message_total_processing_duration_seconds", + Help: "Total duration of message processing in seconds", + Buckets: prometheus.DefBuckets, + }), + + // Email characteristics metrics + emailAttachmentCount: promauto.NewHistogram(prometheus.HistogramOpts{ + Name: "email_attachment_count", + Help: "Number of attachments in emails", + Buckets: []float64{0, 1, 2, 3, 4, 5, 10, 15, 20, 25, 30, 50, 100}, + }), + emailAttachmentSize: promauto.NewHistogram(prometheus.HistogramOpts{ + Name: "email_attachment_size_bytes", + Help: "Size of email attachments in bytes", + Buckets: prometheus.ExponentialBuckets(1024, 2, 15), // 1KB to 32MB + }), + emailBodyType: promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "email_body_type_total", + Help: "Number of emails by body type (text, html, multipart)", + }, []string{"body_type"}), + emailHeaderCount: promauto.NewHistogram(prometheus.HistogramOpts{ + Name: "email_header_count", + Help: "Number of headers in emails", + Buckets: []float64{0, 5, 10, 15, 20, 25, 30, 40, 50, 75, 100, 150, 200}, + }), + emailRecipientCount: promauto.NewHistogram(prometheus.HistogramOpts{ + Name: "email_recipient_count", + Help: "Number of recipients in emails", + Buckets: []float64{0, 1, 2, 3, 4, 5, 10, 15, 20, 25, 30, 50, 100}, + }), + + // Session lifecycle metrics + sessionDuration: promauto.NewHistogram(prometheus.HistogramOpts{ + Name: "session_duration_seconds", + Help: "Duration of SMTP sessions in seconds", + Buckets: prometheus.DefBuckets, + }), + sessionCommands: promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "session_commands_total", + Help: "Number of SMTP commands executed per session", + }, []string{"command"}), + sessionErrors: promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "session_errors_total", + Help: "Number of SMTP session errors by type", + }, []string{"error_type"}), + sessionMessageCount: promauto.NewCounter(prometheus.CounterOpts{ + Name: "session_messages_total", + Help: "Number of messages processed per session", + }), } } @@ -213,6 +405,131 @@ func (m *metricsCollector) TrackWebhookRequest(statusCode int, duration time.Dur m.webhookDuration.Observe(duration.Seconds()) } +// TrackWebhookRequestSize records webhook request and response size metrics +func (m *metricsCollector) TrackWebhookRequestSize(requestSize, responseSize int64) { + m.webhookRequestSize.Observe(float64(requestSize)) + m.webhookResponseSize.Observe(float64(responseSize)) +} + +// TrackWebhookRetry records webhook retry attempts +func (m *metricsCollector) TrackWebhookRetry(success bool) { + status := "failure" + if success { + status = "success" + } + m.webhookRetries.WithLabelValues(status).Inc() +} + +// TrackWebhookError records webhook errors by type +func (m *metricsCollector) TrackWebhookError(errorType string) { + m.webhookErrors.WithLabelValues(errorType).Inc() +} + +// TrackSPFDuration records SPF check duration +func (m *metricsCollector) TrackSPFDuration(duration time.Duration) { + m.spfDuration.Observe(duration.Seconds()) +} + +// TrackDKIMDuration records DKIM check duration +func (m *metricsCollector) TrackDKIMDuration(duration time.Duration) { + m.dkimDuration.Observe(duration.Seconds()) +} + +// TrackDMARCDuration records DMARC check duration +func (m *metricsCollector) TrackDMARCDuration(duration time.Duration) { + m.dmarcDuration.Observe(duration.Seconds()) +} + +// TrackSPFResultByDomain records SPF results by domain +func (m *metricsCollector) TrackSPFResultByDomain(domain, result string) { + m.spfResultsByDomain.WithLabelValues(domain, result).Inc() +} + +// TrackDKIMResultByDomain records DKIM results by domain +func (m *metricsCollector) TrackDKIMResultByDomain(domain, result string) { + m.dkimResultsByDomain.WithLabelValues(domain, result).Inc() +} + +// TrackDMARCResultByDomain records DMARC results by domain +func (m *metricsCollector) TrackDMARCResultByDomain(domain, result string) { + m.dmarcResultsByDomain.WithLabelValues(domain, result).Inc() +} + +// TrackAuthError records authentication errors by check type and error type +func (m *metricsCollector) TrackAuthError(checkType, errorType string) { + m.authErrors.WithLabelValues(checkType, errorType).Inc() +} + +// TrackMessageReadDuration records message reading stage duration +func (m *metricsCollector) TrackMessageReadDuration(duration time.Duration) { + m.messageReadDuration.Observe(duration.Seconds()) +} + +// TrackMessageParseDuration records message parsing stage duration +func (m *metricsCollector) TrackMessageParseDuration(duration time.Duration) { + m.messageParseDuration.Observe(duration.Seconds()) +} + +// TrackMessageAuthDuration records message authentication stage duration +func (m *metricsCollector) TrackMessageAuthDuration(duration time.Duration) { + m.messageAuthDuration.Observe(duration.Seconds()) +} + +// TrackMessageWebhookDuration records message webhook stage duration +func (m *metricsCollector) TrackMessageWebhookDuration(duration time.Duration) { + m.messageWebhookDuration.Observe(duration.Seconds()) +} + +// TrackMessageTotalProcessingDuration records total message processing duration +func (m *metricsCollector) TrackMessageTotalProcessingDuration(duration time.Duration) { + m.messageTotalDuration.Observe(duration.Seconds()) +} + +// TrackEmailAttachmentCount records the number of attachments in an email +func (m *metricsCollector) TrackEmailAttachmentCount(count int) { + m.emailAttachmentCount.Observe(float64(count)) +} + +// TrackEmailAttachmentSize records the size of email attachments +func (m *metricsCollector) TrackEmailAttachmentSize(size int64) { + m.emailAttachmentSize.Observe(float64(size)) +} + +// TrackEmailBodyType records the body type of emails +func (m *metricsCollector) TrackEmailBodyType(bodyType string) { + m.emailBodyType.WithLabelValues(bodyType).Inc() +} + +// TrackEmailHeaderCount records the number of headers in an email +func (m *metricsCollector) TrackEmailHeaderCount(count int) { + m.emailHeaderCount.Observe(float64(count)) +} + +// TrackEmailRecipientCount records the number of recipients in an email +func (m *metricsCollector) TrackEmailRecipientCount(count int) { + m.emailRecipientCount.Observe(float64(count)) +} + +// TrackSessionDuration records the duration of SMTP sessions +func (m *metricsCollector) TrackSessionDuration(duration time.Duration) { + m.sessionDuration.Observe(duration.Seconds()) +} + +// TrackSessionCommandCount records the number of SMTP commands executed +func (m *metricsCollector) TrackSessionCommandCount(command string) { + m.sessionCommands.WithLabelValues(command).Inc() +} + +// TrackSessionErrorCount records the number of SMTP session errors by type +func (m *metricsCollector) TrackSessionErrorCount(errorType string) { + m.sessionErrors.WithLabelValues(errorType).Inc() +} + +// TrackSessionMessageCount records the number of messages processed per session +func (m *metricsCollector) TrackSessionMessageCount() { + m.sessionMessageCount.Inc() +} + // metricsServer implements MetricsServer interface type metricsServer struct { server *http.Server @@ -343,3 +660,75 @@ func (n *NoOpMetricsCollector) TrackDMARCResult(result, policy string) {} // TrackWebhookRequest does nothing func (n *NoOpMetricsCollector) TrackWebhookRequest(status int, duration time.Duration) {} + +// TrackWebhookRequestSize does nothing +func (n *NoOpMetricsCollector) TrackWebhookRequestSize(requestSize, responseSize int64) {} + +// TrackWebhookRetry does nothing +func (n *NoOpMetricsCollector) TrackWebhookRetry(success bool) {} + +// TrackWebhookError does nothing +func (n *NoOpMetricsCollector) TrackWebhookError(errorType string) {} + +// TrackSPFDuration does nothing +func (n *NoOpMetricsCollector) TrackSPFDuration(duration time.Duration) {} + +// TrackDKIMDuration does nothing +func (n *NoOpMetricsCollector) TrackDKIMDuration(duration time.Duration) {} + +// TrackDMARCDuration does nothing +func (n *NoOpMetricsCollector) TrackDMARCDuration(duration time.Duration) {} + +// TrackSPFResultByDomain does nothing +func (n *NoOpMetricsCollector) TrackSPFResultByDomain(domain, result string) {} + +// TrackDKIMResultByDomain does nothing +func (n *NoOpMetricsCollector) TrackDKIMResultByDomain(domain, result string) {} + +// TrackDMARCResultByDomain does nothing +func (n *NoOpMetricsCollector) TrackDMARCResultByDomain(domain, result string) {} + +// TrackAuthError does nothing +func (n *NoOpMetricsCollector) TrackAuthError(checkType, errorType string) {} + +// TrackMessageReadDuration does nothing +func (n *NoOpMetricsCollector) TrackMessageReadDuration(duration time.Duration) {} + +// TrackMessageParseDuration does nothing +func (n *NoOpMetricsCollector) TrackMessageParseDuration(duration time.Duration) {} + +// TrackMessageAuthDuration does nothing +func (n *NoOpMetricsCollector) TrackMessageAuthDuration(duration time.Duration) {} + +// TrackMessageWebhookDuration does nothing +func (n *NoOpMetricsCollector) TrackMessageWebhookDuration(duration time.Duration) {} + +// TrackMessageTotalProcessingDuration does nothing +func (n *NoOpMetricsCollector) TrackMessageTotalProcessingDuration(duration time.Duration) {} + +// TrackEmailAttachmentCount does nothing +func (n *NoOpMetricsCollector) TrackEmailAttachmentCount(count int) {} + +// TrackEmailAttachmentSize does nothing +func (n *NoOpMetricsCollector) TrackEmailAttachmentSize(size int64) {} + +// TrackEmailBodyType does nothing +func (n *NoOpMetricsCollector) TrackEmailBodyType(bodyType string) {} + +// TrackEmailHeaderCount does nothing +func (n *NoOpMetricsCollector) TrackEmailHeaderCount(count int) {} + +// TrackEmailRecipientCount does nothing +func (n *NoOpMetricsCollector) TrackEmailRecipientCount(count int) {} + +// TrackSessionDuration does nothing +func (n *NoOpMetricsCollector) TrackSessionDuration(duration time.Duration) {} + +// TrackSessionCommandCount does nothing +func (n *NoOpMetricsCollector) TrackSessionCommandCount(command string) {} + +// TrackSessionErrorCount does nothing +func (n *NoOpMetricsCollector) TrackSessionErrorCount(errorType string) {} + +// TrackSessionMessageCount does nothing +func (n *NoOpMetricsCollector) TrackSessionMessageCount() {} diff --git a/internal/observability/metrics_test.go b/internal/observability/metrics_test.go index 16f3f83..3b5d523 100644 --- a/internal/observability/metrics_test.go +++ b/internal/observability/metrics_test.go @@ -89,6 +89,128 @@ func createTestMetricsCollector(registry *prometheus.Registry) MetricsCollector Help: "Duration of webhook requests in seconds", Buckets: prometheus.DefBuckets, }), + // Webhook metrics + webhookRequestSize: factory.NewHistogram(prometheus.HistogramOpts{ + Name: "webhook_request_size_bytes", + Help: "Size of webhook requests in bytes", + Buckets: prometheus.ExponentialBuckets(1024, 2, 15), // 1KB to 32MB + }), + webhookResponseSize: factory.NewHistogram(prometheus.HistogramOpts{ + Name: "webhook_response_size_bytes", + Help: "Size of webhook responses in bytes", + Buckets: prometheus.ExponentialBuckets(1024, 2, 15), // 1KB to 32MB + }), + webhookRetries: factory.NewCounterVec(prometheus.CounterOpts{ + Name: "webhook_retries_total", + Help: "Number of webhook retry attempts", + }, []string{"success"}), + webhookErrors: factory.NewCounterVec(prometheus.CounterOpts{ + Name: "webhook_errors_total", + Help: "Number of webhook errors by type", + }, []string{"error_type"}), + // Authentication duration metrics + spfDuration: factory.NewHistogram(prometheus.HistogramOpts{ + Name: "spf_duration_seconds", + Help: "Duration of SPF checks in seconds", + Buckets: prometheus.DefBuckets, + }), + dkimDuration: factory.NewHistogram(prometheus.HistogramOpts{ + Name: "dkim_duration_seconds", + Help: "Duration of DKIM checks in seconds", + Buckets: prometheus.DefBuckets, + }), + dmarcDuration: factory.NewHistogram(prometheus.HistogramOpts{ + Name: "dmarc_duration_seconds", + Help: "Duration of DMARC checks in seconds", + Buckets: prometheus.DefBuckets, + }), + // Domain-based authentication metrics + spfResultsByDomain: factory.NewCounterVec(prometheus.CounterOpts{ + Name: "spf_results_by_domain_total", + Help: "SPF results by domain", + }, []string{"domain", "result"}), + dkimResultsByDomain: factory.NewCounterVec(prometheus.CounterOpts{ + Name: "dkim_results_by_domain_total", + Help: "DKIM results by domain", + }, []string{"domain", "result"}), + dmarcResultsByDomain: factory.NewCounterVec(prometheus.CounterOpts{ + Name: "dmarc_results_by_domain_total", + Help: "DMARC results by domain", + }, []string{"domain", "result"}), + // Authentication error metrics + authErrors: factory.NewCounterVec(prometheus.CounterOpts{ + Name: "auth_errors_total", + Help: "Number of authentication errors by type", + }, []string{"check_type", "error_type"}), + // Message processing stage metrics + messageReadDuration: factory.NewHistogram(prometheus.HistogramOpts{ + Name: "message_read_duration_seconds", + Help: "Duration of message read stage in seconds", + Buckets: prometheus.DefBuckets, + }), + messageParseDuration: factory.NewHistogram(prometheus.HistogramOpts{ + Name: "message_parse_duration_seconds", + Help: "Duration of message parse stage in seconds", + Buckets: prometheus.DefBuckets, + }), + messageAuthDuration: factory.NewHistogram(prometheus.HistogramOpts{ + Name: "message_auth_duration_seconds", + Help: "Duration of message auth stage in seconds", + Buckets: prometheus.DefBuckets, + }), + messageWebhookDuration: factory.NewHistogram(prometheus.HistogramOpts{ + Name: "message_webhook_duration_seconds", + Help: "Duration of message webhook stage in seconds", + Buckets: prometheus.DefBuckets, + }), + messageTotalDuration: factory.NewHistogram(prometheus.HistogramOpts{ + Name: "message_total_processing_duration_seconds", + Help: "Total duration of message processing in seconds", + Buckets: prometheus.DefBuckets, + }), + // Email characteristics metrics + emailAttachmentCount: factory.NewHistogram(prometheus.HistogramOpts{ + Name: "email_attachment_count", + Help: "Number of attachments in emails", + Buckets: []float64{0, 1, 2, 3, 4, 5, 10, 15, 20, 25, 30, 50, 100}, + }), + emailAttachmentSize: factory.NewHistogram(prometheus.HistogramOpts{ + Name: "email_attachment_size_bytes", + Help: "Size of email attachments in bytes", + Buckets: prometheus.ExponentialBuckets(1024, 2, 15), // 1KB to 32MB + }), + emailBodyType: factory.NewCounterVec(prometheus.CounterOpts{ + Name: "email_body_type_total", + Help: "Number of emails by body type (text, html, multipart)", + }, []string{"body_type"}), + emailHeaderCount: factory.NewHistogram(prometheus.HistogramOpts{ + Name: "email_header_count", + Help: "Number of headers in emails", + Buckets: []float64{0, 5, 10, 15, 20, 25, 30, 40, 50, 75, 100, 150, 200}, + }), + emailRecipientCount: factory.NewHistogram(prometheus.HistogramOpts{ + Name: "email_recipient_count", + Help: "Number of recipients in emails", + Buckets: []float64{0, 1, 2, 3, 4, 5, 10, 15, 20, 25, 30, 50, 100}, + }), + // Session lifecycle metrics + sessionDuration: factory.NewHistogram(prometheus.HistogramOpts{ + Name: "session_duration_seconds", + Help: "Duration of SMTP sessions in seconds", + Buckets: prometheus.DefBuckets, + }), + sessionCommands: factory.NewCounterVec(prometheus.CounterOpts{ + Name: "session_commands_total", + Help: "Number of SMTP commands executed per session", + }, []string{"command"}), + sessionErrors: factory.NewCounterVec(prometheus.CounterOpts{ + Name: "session_errors_total", + Help: "Number of SMTP session errors by type", + }, []string{"error_type"}), + sessionMessageCount: factory.NewCounter(prometheus.CounterOpts{ + Name: "session_messages_total", + Help: "Number of messages processed per session", + }), } } @@ -655,3 +777,388 @@ func TestMetricsServer_HealthEndpointNew(t *testing.T) { // The health endpoint code is covered by other tests t.Skip("Skipping health endpoint test due to server setup complexity") } + +// Test new webhook metrics +func TestMetricsCollector_TrackWebhookRequestSize(t *testing.T) { + registry := prometheus.NewRegistry() + collector := createTestMetricsCollector(registry) + + collector.TrackWebhookRequestSize(1024, 2048) + collector.TrackWebhookRequestSize(0, 0) + collector.TrackWebhookRequestSize(999999999, 999999999) +} + +func TestMetricsCollector_TrackWebhookRetry(t *testing.T) { + registry := prometheus.NewRegistry() + collector := createTestMetricsCollector(registry) + + collector.TrackWebhookRetry(true) + collector.TrackWebhookRetry(false) +} + +func TestMetricsCollector_TrackWebhookError(t *testing.T) { + registry := prometheus.NewRegistry() + collector := createTestMetricsCollector(registry) + + collector.TrackWebhookError("network") + collector.TrackWebhookError("timeout") + collector.TrackWebhookError("4xx") + collector.TrackWebhookError("5xx") +} + +// Test authentication duration metrics +func TestMetricsCollector_TrackSPFDuration(t *testing.T) { + registry := prometheus.NewRegistry() + collector := createTestMetricsCollector(registry) + + collector.TrackSPFDuration(100 * time.Millisecond) + collector.TrackSPFDuration(1 * time.Second) + collector.TrackSPFDuration(5 * time.Second) +} + +func TestMetricsCollector_TrackDKIMDuration(t *testing.T) { + registry := prometheus.NewRegistry() + collector := createTestMetricsCollector(registry) + + collector.TrackDKIMDuration(50 * time.Millisecond) + collector.TrackDKIMDuration(500 * time.Millisecond) + collector.TrackDKIMDuration(2 * time.Second) +} + +func TestMetricsCollector_TrackDMARCDuration(t *testing.T) { + registry := prometheus.NewRegistry() + collector := createTestMetricsCollector(registry) + + collector.TrackDMARCDuration(75 * time.Millisecond) + collector.TrackDMARCDuration(750 * time.Millisecond) + collector.TrackDMARCDuration(3 * time.Second) +} + +// Test domain-based authentication metrics +func TestMetricsCollector_TrackSPFResultByDomain(t *testing.T) { + registry := prometheus.NewRegistry() + collector := createTestMetricsCollector(registry) + + collector.TrackSPFResultByDomain("example.com", "pass") + collector.TrackSPFResultByDomain("example.com", "fail") + collector.TrackSPFResultByDomain("test.org", "temperror") + collector.TrackSPFResultByDomain("", "none") +} + +func TestMetricsCollector_TrackDKIMResultByDomain(t *testing.T) { + registry := prometheus.NewRegistry() + collector := createTestMetricsCollector(registry) + + collector.TrackDKIMResultByDomain("example.com", "valid") + collector.TrackDKIMResultByDomain("example.com", "invalid") + collector.TrackDKIMResultByDomain("test.org", "valid") + collector.TrackDKIMResultByDomain("", "invalid") +} + +func TestMetricsCollector_TrackDMARCResultByDomain(t *testing.T) { + registry := prometheus.NewRegistry() + collector := createTestMetricsCollector(registry) + + collector.TrackDMARCResultByDomain("example.com", "pass") + collector.TrackDMARCResultByDomain("example.com", "fail") + collector.TrackDMARCResultByDomain("test.org", "none") + collector.TrackDMARCResultByDomain("", "pass") +} + +// Test authentication error metrics +func TestMetricsCollector_TrackAuthError(t *testing.T) { + registry := prometheus.NewRegistry() + collector := createTestMetricsCollector(registry) + + collector.TrackAuthError("spf", "dns") + collector.TrackAuthError("spf", "timeout") + collector.TrackAuthError("dkim", "verification") + collector.TrackAuthError("dkim", "parse") + collector.TrackAuthError("dmarc", "dns") + collector.TrackAuthError("dmarc", "timeout") +} + +// Test message processing stage metrics +func TestMetricsCollector_TrackMessageReadDuration(t *testing.T) { + registry := prometheus.NewRegistry() + collector := createTestMetricsCollector(registry) + + collector.TrackMessageReadDuration(10 * time.Millisecond) + collector.TrackMessageReadDuration(100 * time.Millisecond) + collector.TrackMessageReadDuration(1 * time.Second) +} + +func TestMetricsCollector_TrackMessageParseDuration(t *testing.T) { + registry := prometheus.NewRegistry() + collector := createTestMetricsCollector(registry) + + collector.TrackMessageParseDuration(5 * time.Millisecond) + collector.TrackMessageParseDuration(50 * time.Millisecond) + collector.TrackMessageParseDuration(500 * time.Millisecond) +} + +func TestMetricsCollector_TrackMessageAuthDuration(t *testing.T) { + registry := prometheus.NewRegistry() + collector := createTestMetricsCollector(registry) + + collector.TrackMessageAuthDuration(20 * time.Millisecond) + collector.TrackMessageAuthDuration(200 * time.Millisecond) + collector.TrackMessageAuthDuration(2 * time.Second) +} + +func TestMetricsCollector_TrackMessageWebhookDuration(t *testing.T) { + registry := prometheus.NewRegistry() + collector := createTestMetricsCollector(registry) + + collector.TrackMessageWebhookDuration(30 * time.Millisecond) + collector.TrackMessageWebhookDuration(300 * time.Millisecond) + collector.TrackMessageWebhookDuration(3 * time.Second) +} + +func TestMetricsCollector_TrackMessageTotalProcessingDuration(t *testing.T) { + registry := prometheus.NewRegistry() + collector := createTestMetricsCollector(registry) + + collector.TrackMessageTotalProcessingDuration(100 * time.Millisecond) + collector.TrackMessageTotalProcessingDuration(1 * time.Second) + collector.TrackMessageTotalProcessingDuration(10 * time.Second) +} + +// Test email characteristics metrics +func TestMetricsCollector_TrackEmailAttachmentCount(t *testing.T) { + registry := prometheus.NewRegistry() + collector := createTestMetricsCollector(registry) + + collector.TrackEmailAttachmentCount(0) + collector.TrackEmailAttachmentCount(1) + collector.TrackEmailAttachmentCount(5) + collector.TrackEmailAttachmentCount(100) +} + +func TestMetricsCollector_TrackEmailAttachmentSize(t *testing.T) { + registry := prometheus.NewRegistry() + collector := createTestMetricsCollector(registry) + + collector.TrackEmailAttachmentSize(0) + collector.TrackEmailAttachmentSize(1024) + collector.TrackEmailAttachmentSize(1024 * 1024) + collector.TrackEmailAttachmentSize(1024 * 1024 * 100) +} + +func TestMetricsCollector_TrackEmailBodyType(t *testing.T) { + registry := prometheus.NewRegistry() + collector := createTestMetricsCollector(registry) + + collector.TrackEmailBodyType("text") + collector.TrackEmailBodyType("html") + collector.TrackEmailBodyType("multipart") + collector.TrackEmailBodyType("") +} + +func TestMetricsCollector_TrackEmailHeaderCount(t *testing.T) { + registry := prometheus.NewRegistry() + collector := createTestMetricsCollector(registry) + + collector.TrackEmailHeaderCount(0) + collector.TrackEmailHeaderCount(5) + collector.TrackEmailHeaderCount(20) + collector.TrackEmailHeaderCount(200) +} + +func TestMetricsCollector_TrackEmailRecipientCount(t *testing.T) { + registry := prometheus.NewRegistry() + collector := createTestMetricsCollector(registry) + + collector.TrackEmailRecipientCount(0) + collector.TrackEmailRecipientCount(1) + collector.TrackEmailRecipientCount(5) + collector.TrackEmailRecipientCount(100) +} + +// Test session lifecycle metrics +func TestMetricsCollector_TrackSessionDuration(t *testing.T) { + registry := prometheus.NewRegistry() + collector := createTestMetricsCollector(registry) + + collector.TrackSessionDuration(1 * time.Second) + collector.TrackSessionDuration(30 * time.Second) + collector.TrackSessionDuration(5 * time.Minute) +} + +func TestMetricsCollector_TrackSessionCommandCount(t *testing.T) { + registry := prometheus.NewRegistry() + collector := createTestMetricsCollector(registry) + + collector.TrackSessionCommandCount("MAIL") + collector.TrackSessionCommandCount("RCPT") + collector.TrackSessionCommandCount("DATA") + collector.TrackSessionCommandCount("RSET") + collector.TrackSessionCommandCount("QUIT") +} + +func TestMetricsCollector_TrackSessionErrorCount(t *testing.T) { + registry := prometheus.NewRegistry() + collector := createTestMetricsCollector(registry) + + collector.TrackSessionErrorCount("syntax_error") + collector.TrackSessionErrorCount("too_many_recipients") + collector.TrackSessionErrorCount("message_processing") + collector.TrackSessionErrorCount("no_recipients") +} + +func TestMetricsCollector_TrackSessionMessageCount(t *testing.T) { + registry := prometheus.NewRegistry() + collector := createTestMetricsCollector(registry) + + collector.TrackSessionMessageCount() + collector.TrackSessionMessageCount() + collector.TrackSessionMessageCount() +} + +// Test NoOpMetricsCollector for all new methods +func TestNoOpMetricsCollector_NewMethods(t *testing.T) { + collector := &NoOpMetricsCollector{} + + // Test all new webhook methods + collector.TrackWebhookRequestSize(1024, 2048) + collector.TrackWebhookRetry(true) + collector.TrackWebhookError("network") + + // Test all new authentication methods + collector.TrackSPFDuration(100 * time.Millisecond) + collector.TrackDKIMDuration(50 * time.Millisecond) + collector.TrackDMARCDuration(75 * time.Millisecond) + collector.TrackSPFResultByDomain("example.com", "pass") + collector.TrackDKIMResultByDomain("example.com", "valid") + collector.TrackDMARCResultByDomain("example.com", "pass") + collector.TrackAuthError("spf", "dns") + + // Test all new processing stage methods + collector.TrackMessageReadDuration(10 * time.Millisecond) + collector.TrackMessageParseDuration(5 * time.Millisecond) + collector.TrackMessageAuthDuration(20 * time.Millisecond) + collector.TrackMessageWebhookDuration(30 * time.Millisecond) + collector.TrackMessageTotalProcessingDuration(100 * time.Millisecond) + + // Test all new email characteristics methods + collector.TrackEmailAttachmentCount(5) + collector.TrackEmailAttachmentSize(1024) + collector.TrackEmailBodyType("html") + collector.TrackEmailHeaderCount(20) + collector.TrackEmailRecipientCount(3) + + // Test all new session lifecycle methods + collector.TrackSessionDuration(30 * time.Second) + collector.TrackSessionCommandCount("MAIL") + collector.TrackSessionErrorCount("syntax_error") + collector.TrackSessionMessageCount() +} + +// Test edge cases for new metrics +func TestMetricsCollector_NewMetricsEdgeCases(t *testing.T) { + registry := prometheus.NewRegistry() + collector := createTestMetricsCollector(registry) + + // Test with zero and negative values + collector.TrackWebhookRequestSize(0, 0) + collector.TrackWebhookRequestSize(-1, -1) + collector.TrackEmailAttachmentCount(0) + collector.TrackEmailAttachmentCount(-1) + collector.TrackEmailAttachmentSize(0) + collector.TrackEmailAttachmentSize(-1) + collector.TrackEmailHeaderCount(0) + collector.TrackEmailHeaderCount(-1) + collector.TrackEmailRecipientCount(0) + collector.TrackEmailRecipientCount(-1) + + // Test with very large values + collector.TrackWebhookRequestSize(999999999, 999999999) + collector.TrackEmailAttachmentSize(999999999) + collector.TrackEmailHeaderCount(999) + collector.TrackEmailRecipientCount(999) + + // Test with empty strings + collector.TrackWebhookError("") + collector.TrackEmailBodyType("") + collector.TrackSessionCommandCount("") + collector.TrackSessionErrorCount("") + collector.TrackSPFResultByDomain("", "") + collector.TrackDKIMResultByDomain("", "") + collector.TrackDMARCResultByDomain("", "") + collector.TrackAuthError("", "") + + // Test with very long strings + longString := string(make([]byte, 1000)) + collector.TrackWebhookError(longString) + collector.TrackEmailBodyType(longString) + collector.TrackSessionCommandCount(longString) + collector.TrackSessionErrorCount(longString) + collector.TrackSPFResultByDomain(longString, longString) + collector.TrackDKIMResultByDomain(longString, longString) + collector.TrackDMARCResultByDomain(longString, longString) + collector.TrackAuthError(longString, longString) + + // Test with very long durations + collector.TrackSPFDuration(24 * time.Hour) + collector.TrackDKIMDuration(24 * time.Hour) + collector.TrackDMARCDuration(24 * time.Hour) + collector.TrackMessageReadDuration(24 * time.Hour) + collector.TrackMessageParseDuration(24 * time.Hour) + collector.TrackMessageAuthDuration(24 * time.Hour) + collector.TrackMessageWebhookDuration(24 * time.Hour) + collector.TrackMessageTotalProcessingDuration(24 * time.Hour) + collector.TrackSessionDuration(24 * time.Hour) +} + +// Test concurrent access to new metrics +func TestMetricsCollector_NewMetricsConcurrentAccess(t *testing.T) { + registry := prometheus.NewRegistry() + collector := createTestMetricsCollector(registry) + + done := make(chan bool, 10) + + for i := 0; i < 10; i++ { + go func() { + // Test webhook metrics + collector.TrackWebhookRequestSize(1024, 2048) + collector.TrackWebhookRetry(true) + collector.TrackWebhookError("network") + + // Test authentication metrics + collector.TrackSPFDuration(100 * time.Millisecond) + collector.TrackDKIMDuration(50 * time.Millisecond) + collector.TrackDMARCDuration(75 * time.Millisecond) + collector.TrackSPFResultByDomain("example.com", "pass") + collector.TrackDKIMResultByDomain("example.com", "valid") + collector.TrackDMARCResultByDomain("example.com", "pass") + collector.TrackAuthError("spf", "dns") + + // Test processing stage metrics + collector.TrackMessageReadDuration(10 * time.Millisecond) + collector.TrackMessageParseDuration(5 * time.Millisecond) + collector.TrackMessageAuthDuration(20 * time.Millisecond) + collector.TrackMessageWebhookDuration(30 * time.Millisecond) + collector.TrackMessageTotalProcessingDuration(100 * time.Millisecond) + + // Test email characteristics metrics + collector.TrackEmailAttachmentCount(5) + collector.TrackEmailAttachmentSize(1024) + collector.TrackEmailBodyType("html") + collector.TrackEmailHeaderCount(20) + collector.TrackEmailRecipientCount(3) + + // Test session lifecycle metrics + collector.TrackSessionDuration(30 * time.Second) + collector.TrackSessionCommandCount("MAIL") + collector.TrackSessionErrorCount("syntax_error") + collector.TrackSessionMessageCount() + + done <- true + }() + } + + // Wait for all goroutines to complete + for i := 0; i < 10; i++ { + <-done + } +} diff --git a/internal/processor/processor.go b/internal/processor/processor.go index 77438f7..f624504 100644 --- a/internal/processor/processor.go +++ b/internal/processor/processor.go @@ -54,13 +54,19 @@ func NewMessageProcessor( func (mp *MessageProcessor) ProcessMessage(ctx context.Context, r io.Reader, from string, to []string, sessionID string, remoteIP net.IP) error { start := time.Now() + // Create session-aware logger with session_id context + sessionLogger := observability.NewSessionLogger(mp.logger, sessionID, "") + // Read the entire message + readStart := time.Now() var buf bytes.Buffer size, err := io.Copy(&buf, r) + readDuration := time.Since(readStart) + mp.metrics.TrackMessageReadDuration(readDuration) + if err != nil { - mp.logger.Error(). + sessionLogger.Error(). Err(err). - Str("session_id", sessionID). Msg("Failed to read message") mp.sentry.CaptureError(err, map[string]interface{}{ "session_id": sessionID, @@ -75,11 +81,14 @@ func (mp *MessageProcessor) ProcessMessage(ctx context.Context, r io.Reader, fro rawMessage := buf.String() // Parse the email + parseStart := time.Now() email, err := letters.ParseEmail(bytes.NewReader(buf.Bytes())) + parseDuration := time.Since(parseStart) + mp.metrics.TrackMessageParseDuration(parseDuration) + if err != nil { - mp.logger.Error(). + sessionLogger.Error(). Err(err). - Str("session_id", sessionID). Int64("size", size). Msg("Failed to parse message") mp.sentry.CaptureError(err, map[string]interface{}{ @@ -93,8 +102,7 @@ func (mp *MessageProcessor) ProcessMessage(ctx context.Context, r io.Reader, fro return fmt.Errorf("failed to parse message: %w", err) } - mp.logger.Info(). - Str("session_id", sessionID). + sessionLogger.Info(). Str("from", from). Strs("to", to). Str("subject", email.Headers.Subject). @@ -102,6 +110,9 @@ func (mp *MessageProcessor) ProcessMessage(ctx context.Context, r io.Reader, fro Int64("size", size). Msg("Message parsed successfully") + // Track email characteristics + mp.trackEmailCharacteristics(&email) + // Create webhook payload payload := &domain.WebhookPayload{ Timestamp: email.Headers.Date, @@ -119,11 +130,14 @@ func (mp *MessageProcessor) ProcessMessage(ctx context.Context, r io.Reader, fro // Perform authentication checks if mp.config.EnableSPF || mp.config.EnableDKIM || mp.config.EnableDMARC { - authResults, err := mp.performAuthenticationChecks(ctx, rawMessage, email.Headers, from, remoteIP) + authStart := time.Now() + authResults, err := mp.performAuthenticationChecks(ctx, rawMessage, email.Headers, from, remoteIP, sessionLogger) + authDuration := time.Since(authStart) + mp.metrics.TrackMessageAuthDuration(authDuration) + if err != nil { - mp.logger.Warn(). + sessionLogger.Warn(). Err(err). - Str("session_id", sessionID). Msg("Authentication checks failed") // Continue processing even if auth checks fail } else { @@ -132,13 +146,17 @@ func (mp *MessageProcessor) ProcessMessage(ctx context.Context, r io.Reader, fro } // Send to webhook + webhookStart := time.Now() err = mp.webhookSender.SendWebhook(ctx, payload) + webhookDuration := time.Since(webhookStart) + mp.metrics.TrackMessageWebhookDuration(webhookDuration) duration := time.Since(start) + mp.metrics.TrackMessageTotalProcessingDuration(duration) + if err != nil { - mp.logger.Error(). + sessionLogger.Error(). Err(err). - Str("session_id", sessionID). Str("message_id", string(email.Headers.MessageID)). Dur("duration", duration). Msg("Failed to process message") @@ -146,8 +164,7 @@ func (mp *MessageProcessor) ProcessMessage(ctx context.Context, r io.Reader, fro return err } - mp.logger.Info(). - Str("session_id", sessionID). + sessionLogger.Info(). Str("message_id", string(email.Headers.MessageID)). Dur("duration", duration). Msg("Message processed successfully") @@ -157,10 +174,10 @@ func (mp *MessageProcessor) ProcessMessage(ctx context.Context, r io.Reader, fro } // performAuthenticationChecks runs all enabled authentication checks -func (mp *MessageProcessor) performAuthenticationChecks(ctx context.Context, rawMessage string, headers letters.Headers, from string, remoteIP net.IP) (domain.AuthenticationResults, error) { +func (mp *MessageProcessor) performAuthenticationChecks(ctx context.Context, rawMessage string, headers letters.Headers, from string, remoteIP net.IP, sessionLogger observability.Logger) (domain.AuthenticationResults, error) { results := domain.AuthenticationResults{} - mp.logger.Debug(). + sessionLogger.Debug(). Bool("spf_enabled", mp.config.EnableSPF). Bool("dkim_enabled", mp.config.EnableDKIM). Bool("dmarc_enabled", mp.config.EnableDMARC). @@ -171,9 +188,9 @@ func (mp *MessageProcessor) performAuthenticationChecks(ctx context.Context, raw // SPF Check if mp.config.EnableSPF { - spfResult, err := mp.authChecker.CheckSPF(ctx, remoteIP, domain, from) + spfResult, err := mp.authChecker.CheckSPF(ctx, remoteIP, domain, from, sessionLogger, mp.metrics) if err != nil { - mp.logger.Warn(). + sessionLogger.Warn(). Err(err). Msg("SPF check failed") mp.sentry.CaptureError(err, map[string]interface{}{ @@ -189,9 +206,9 @@ func (mp *MessageProcessor) performAuthenticationChecks(ctx context.Context, raw // DKIM Check if mp.config.EnableDKIM { - dkimResult, err := mp.authChecker.CheckDKIM(ctx, rawMessage) + dkimResult, err := mp.authChecker.CheckDKIM(ctx, rawMessage, sessionLogger, mp.metrics) if err != nil { - mp.logger.Warn(). + sessionLogger.Warn(). Err(err). Msg("DKIM check failed") mp.sentry.CaptureError(err, map[string]interface{}{ @@ -206,9 +223,9 @@ func (mp *MessageProcessor) performAuthenticationChecks(ctx context.Context, raw // DMARC Check - pass previously computed results to avoid redundant checks if mp.config.EnableDMARC { - dmarcResult, err := mp.authChecker.CheckDMARC(ctx, rawMessage, headers, results.SPF, results.DKIM) + dmarcResult, err := mp.authChecker.CheckDMARC(ctx, rawMessage, headers, results.SPF, results.DKIM, sessionLogger, mp.metrics) if err != nil { - mp.logger.Warn(). + sessionLogger.Warn(). Err(err). Msg("DMARC check failed") mp.sentry.CaptureError(err, map[string]interface{}{ @@ -221,7 +238,7 @@ func (mp *MessageProcessor) performAuthenticationChecks(ctx context.Context, raw } } - mp.logger.Info(). + sessionLogger.Info(). Interface("spf", results.SPF). Interface("dkim", results.DKIM). Interface("dmarc", results.DMARC). @@ -229,3 +246,36 @@ func (mp *MessageProcessor) performAuthenticationChecks(ctx context.Context, raw return results, nil } + +// trackEmailCharacteristics tracks various characteristics of the parsed email +func (mp *MessageProcessor) trackEmailCharacteristics(email *letters.Email) { + // Track attachment count and size + attachmentCount := len(email.AttachedFiles) + mp.metrics.TrackEmailAttachmentCount(attachmentCount) + + var totalAttachmentSize int64 + for _, attachment := range email.AttachedFiles { + totalAttachmentSize += int64(len(attachment.Data)) + } + if totalAttachmentSize > 0 { + mp.metrics.TrackEmailAttachmentSize(totalAttachmentSize) + } + + // Track body type + bodyType := "text" + if email.HTML != "" { + bodyType = "html" + } + if len(email.AttachedFiles) > 0 { + bodyType = "multipart" + } + mp.metrics.TrackEmailBodyType(bodyType) + + // Track header count + headerCount := len(email.Headers.ExtraHeaders) + mp.metrics.TrackEmailHeaderCount(headerCount) + + // Track recipient count (To + CC + BCC) + recipientCount := len(email.Headers.To) + len(email.Headers.Cc) + len(email.Headers.Bcc) + mp.metrics.TrackEmailRecipientCount(recipientCount) +} diff --git a/internal/processor/processor_test.go b/internal/processor/processor_test.go index 14bdbf2..8647d91 100644 --- a/internal/processor/processor_test.go +++ b/internal/processor/processor_test.go @@ -25,15 +25,15 @@ type mockAuthChecker struct { dmarcErr error } -func (m *mockAuthChecker) CheckSPF(ctx context.Context, remoteIP net.IP, domainName, from string) (*domain.SPFResult, error) { +func (m *mockAuthChecker) CheckSPF(ctx context.Context, remoteIP net.IP, domainName, from string, logger observability.Logger, metrics observability.MetricsCollector) (*domain.SPFResult, error) { return m.spfResult, m.spfErr } -func (m *mockAuthChecker) CheckDKIM(ctx context.Context, rawMessage string) (*domain.DKIMResult, error) { +func (m *mockAuthChecker) CheckDKIM(ctx context.Context, rawMessage string, logger observability.Logger, metrics observability.MetricsCollector) (*domain.DKIMResult, error) { return m.dkimResult, m.dkimErr } -func (m *mockAuthChecker) CheckDMARC(ctx context.Context, rawMessage string, headers letters.Headers, spfResult *domain.SPFResult, dkimResult *domain.DKIMResult) (*domain.DMARCResult, error) { +func (m *mockAuthChecker) CheckDMARC(ctx context.Context, rawMessage string, headers letters.Headers, spfResult *domain.SPFResult, dkimResult *domain.DKIMResult, logger observability.Logger, metrics observability.MetricsCollector) (*domain.DMARCResult, error) { return m.dmarcResult, m.dmarcErr } @@ -54,6 +54,30 @@ func (m *mockMetricsCollector) TrackSPFResult(result string) func (m *mockMetricsCollector) TrackDKIMResult(valid bool) {} func (m *mockMetricsCollector) TrackDMARCResult(result, policy string) {} func (m *mockMetricsCollector) TrackWebhookRequest(statusCode int, duration time.Duration) {} +func (m *mockMetricsCollector) TrackWebhookRequestSize(requestSize, responseSize int64) {} +func (m *mockMetricsCollector) TrackWebhookRetry(success bool) {} +func (m *mockMetricsCollector) TrackWebhookError(errorType string) {} +func (m *mockMetricsCollector) TrackSPFDuration(duration time.Duration) {} +func (m *mockMetricsCollector) TrackDKIMDuration(duration time.Duration) {} +func (m *mockMetricsCollector) TrackDMARCDuration(duration time.Duration) {} +func (m *mockMetricsCollector) TrackSPFResultByDomain(domain, result string) {} +func (m *mockMetricsCollector) TrackDKIMResultByDomain(domain, result string) {} +func (m *mockMetricsCollector) TrackDMARCResultByDomain(domain, result string) {} +func (m *mockMetricsCollector) TrackAuthError(checkType, errorType string) {} +func (m *mockMetricsCollector) TrackMessageReadDuration(duration time.Duration) {} +func (m *mockMetricsCollector) TrackMessageParseDuration(duration time.Duration) {} +func (m *mockMetricsCollector) TrackMessageAuthDuration(duration time.Duration) {} +func (m *mockMetricsCollector) TrackMessageWebhookDuration(duration time.Duration) {} +func (m *mockMetricsCollector) TrackMessageTotalProcessingDuration(duration time.Duration) {} +func (m *mockMetricsCollector) TrackEmailAttachmentCount(count int) {} +func (m *mockMetricsCollector) TrackEmailAttachmentSize(size int64) {} +func (m *mockMetricsCollector) TrackEmailBodyType(bodyType string) {} +func (m *mockMetricsCollector) TrackEmailHeaderCount(count int) {} +func (m *mockMetricsCollector) TrackEmailRecipientCount(count int) {} +func (m *mockMetricsCollector) TrackSessionDuration(duration time.Duration) {} +func (m *mockMetricsCollector) TrackSessionCommandCount(command string) {} +func (m *mockMetricsCollector) TrackSessionErrorCount(errorType string) {} +func (m *mockMetricsCollector) TrackSessionMessageCount() {} type mockSentryClient struct{} diff --git a/internal/processor/webhook.go b/internal/processor/webhook.go index e4630c4..9ac00b9 100644 --- a/internal/processor/webhook.go +++ b/internal/processor/webhook.go @@ -5,12 +5,26 @@ import ( "context" "encoding/json" "fmt" + "inboundparse/internal/observability" "net/http" "time" - "github.com/rs/zerolog/log" + "github.com/failsafe-go/failsafe-go" + "github.com/failsafe-go/failsafe-go/circuitbreaker" + "github.com/failsafe-go/failsafe-go/ratelimiter" + "github.com/failsafe-go/failsafe-go/retrypolicy" ) +// httpError represents an HTTP error response +type httpError struct { + statusCode int + message string +} + +func (e *httpError) Error() string { + return e.message +} + // WebhookSender interface for sending webhook requests type WebhookSender interface { SendWebhook(ctx context.Context, payload interface{}) error @@ -22,21 +36,87 @@ type WebhookConfig struct { Username string Password string Timeout time.Duration + + // Retry configuration + MaxRetries int + RetryDelay time.Duration + MaxRetryDelay time.Duration + RetryMultiplier float64 + + // Rate limiting configuration + RateLimitPerSecond int + RateLimitBurst int } // webhookSender implements WebhookSender interface type webhookSender struct { - config WebhookConfig - client *http.Client + config WebhookConfig + client *http.Client + logger observability.Logger + metrics observability.MetricsCollector + retryPolicy retrypolicy.RetryPolicy[any] + rateLimiter ratelimiter.RateLimiter[any] + circuitBreaker circuitbreaker.CircuitBreaker[any] } // NewWebhookSender creates a new webhook sender -func NewWebhookSender(config WebhookConfig) WebhookSender { +func NewWebhookSender(config WebhookConfig, logger observability.Logger, metrics observability.MetricsCollector) WebhookSender { + // Set default values if not configured + if config.MaxRetries == 0 { + config.MaxRetries = 3 + } + if config.RetryDelay == 0 { + config.RetryDelay = 1 * time.Second + } + if config.MaxRetryDelay == 0 { + config.MaxRetryDelay = 30 * time.Second + } + if config.RetryMultiplier == 0 { + config.RetryMultiplier = 2.0 + } + if config.RateLimitPerSecond == 0 { + config.RateLimitPerSecond = 10 + } + if config.RateLimitBurst == 0 { + config.RateLimitBurst = 20 + } + + // Create retry policy + retryPolicy := retrypolicy.NewBuilder[any](). + WithMaxRetries(config.MaxRetries). + WithDelay(config.RetryDelay). + Build() + + // Create rate limiter + rateLimit := config.RateLimitPerSecond + if rateLimit < 0 { + rateLimit = 0 + } + // Ensure rate limit is within valid uint range + if rateLimit > int(^uint(0)>>1) { + rateLimit = int(^uint(0) >> 1) + } + // Safe conversion to uint with explicit bounds checking + var rateLimitUint uint + if rateLimit >= 0 { + rateLimitUint = uint(rateLimit) + } + rateLimiter := ratelimiter.NewSmooth[any](rateLimitUint, time.Second) + + // Create circuit breaker + circuitBreaker := circuitbreaker.NewBuilder[any](). + WithFailureThreshold(5). + WithSuccessThreshold(3). + Build() + return &webhookSender{ - config: config, - client: &http.Client{ - Timeout: config.Timeout, - }, + config: config, + client: &http.Client{Timeout: config.Timeout}, + logger: logger, + metrics: metrics, + retryPolicy: retryPolicy, + rateLimiter: rateLimiter, + circuitBreaker: circuitBreaker, } } @@ -46,18 +126,27 @@ func (w *webhookSender) SendWebhook(ctx context.Context, payload interface{}) er jsonData, err := json.Marshal(payload) if err != nil { - log.Error(). + w.logger.Error(). Err(err). Msg("Failed to marshal webhook payload") + w.metrics.TrackWebhookError("marshal_error") return fmt.Errorf("failed to marshal payload: %w", err) } + requestSize := int64(len(jsonData)) + w.logger.Debug(). + Int64("request_size", requestSize). + Str("url", w.config.URL). + Msg("Sending webhook request") + + // Create the HTTP request req, err := http.NewRequestWithContext(ctx, "POST", w.config.URL, bytes.NewBuffer(jsonData)) if err != nil { - log.Error(). + w.logger.Error(). Err(err). - Str("webhook_url", w.config.URL). + Str("url", w.config.URL). Msg("Failed to create webhook request") + w.metrics.TrackWebhookError("request_creation_error") return fmt.Errorf("failed to create request: %w", err) } @@ -69,39 +158,84 @@ func (w *webhookSender) SendWebhook(ctx context.Context, payload interface{}) er req.SetBasicAuth(w.config.Username, w.config.Password) } - resp, err := w.client.Do(req) + // Execute with failsafe (retry policy, rate limiter, circuit breaker) + var resp *http.Response + executor := failsafe.With(w.retryPolicy, w.rateLimiter, w.circuitBreaker) + + err = executor.Run(func() error { + // Make the HTTP request + httpResp, httpErr := w.client.Do(req) + if httpErr != nil { + return httpErr + } + + // Check for HTTP error status codes + if httpResp.StatusCode >= 400 { + if closeErr := httpResp.Body.Close(); closeErr != nil { + w.logger.Warn(). + Err(closeErr). + Msg("Failed to close response body after error") + } + errorMsg := fmt.Sprintf("webhook returned error status: %d", httpResp.StatusCode) + return &httpError{ + statusCode: httpResp.StatusCode, + message: errorMsg, + } + } + + resp = httpResp + return nil + }) + duration := time.Since(start) + // Handle execution errors if err != nil { - log.Error(). + w.logger.Error(). Err(err). - Str("webhook_url", w.config.URL). + Str("url", w.config.URL). Dur("duration", duration). - Msg("Webhook request failed") + Msg("Webhook request failed after retries") + + // Determine error type for metrics + errorType := "network_error" + if ctx.Err() == context.DeadlineExceeded { + errorType = "timeout_error" + } + w.metrics.TrackWebhookError(errorType) return fmt.Errorf("webhook request failed: %w", err) } + + // Handle HTTP error responses + if resp == nil { + return fmt.Errorf("webhook request failed: no response received") + } + defer func() { if closeErr := resp.Body.Close(); closeErr != nil { - log.Warn(). + w.logger.Warn(). Err(closeErr). Msg("Failed to close response body") } }() - if resp.StatusCode >= 400 { - log.Error(). - Str("webhook_url", w.config.URL). - Int("status_code", resp.StatusCode). - Dur("duration", duration). - Msg("Webhook returned error status") - return fmt.Errorf("webhook returned error status: %d", resp.StatusCode) + // Track metrics + w.metrics.TrackWebhookRequest(resp.StatusCode, duration) + + // Track request/response size + responseSize := resp.ContentLength + if responseSize < 0 { + responseSize = 0 // Content-Length header not present } + w.metrics.TrackWebhookRequestSize(requestSize, responseSize) - log.Info(). - Str("webhook_url", w.config.URL). + // Log response details + w.logger.Info(). Int("status_code", resp.StatusCode). Dur("duration", duration). - Msg("Successfully sent email to webhook") + Int64("request_size", requestSize). + Int64("response_size", responseSize). + Msg("Webhook request completed") return nil } diff --git a/internal/processor/webhook_test.go b/internal/processor/webhook_test.go index 5332210..916ba28 100644 --- a/internal/processor/webhook_test.go +++ b/internal/processor/webhook_test.go @@ -34,7 +34,7 @@ func TestWebhookSender_SendWebhook_Success(t *testing.T) { Timeout: 5 * time.Second, } - sender := NewWebhookSender(config) + sender := NewWebhookSender(config, &mockLogger{}, &mockMetricsCollector{}) payload := map[string]interface{}{ "test": "data", @@ -68,7 +68,7 @@ func TestWebhookSender_SendWebhook_WithAuth(t *testing.T) { Timeout: 5 * time.Second, } - sender := NewWebhookSender(config) + sender := NewWebhookSender(config, &mockLogger{}, &mockMetricsCollector{}) payload := map[string]interface{}{ "test": "data", @@ -91,7 +91,7 @@ func TestWebhookSender_SendWebhook_ServerError(t *testing.T) { Timeout: 5 * time.Second, } - sender := NewWebhookSender(config) + sender := NewWebhookSender(config, &mockLogger{}, &mockMetricsCollector{}) payload := map[string]interface{}{ "test": "data", @@ -115,7 +115,7 @@ func TestWebhookSender_SendWebhook_Timeout(t *testing.T) { Timeout: 100 * time.Millisecond, } - sender := NewWebhookSender(config) + sender := NewWebhookSender(config, &mockLogger{}, &mockMetricsCollector{}) payload := map[string]interface{}{ "test": "data", @@ -133,7 +133,7 @@ func TestWebhookSender_SendWebhook_JSONMarshalError(t *testing.T) { Timeout: 5 * time.Second, } - sender := NewWebhookSender(config) + sender := NewWebhookSender(config, &mockLogger{}, &mockMetricsCollector{}) // Create a payload that will cause JSON marshal to fail payload := map[string]interface{}{ @@ -152,7 +152,7 @@ func TestWebhookSender_SendWebhook_RequestCreationError(t *testing.T) { Timeout: 5 * time.Second, } - sender := NewWebhookSender(config) + sender := NewWebhookSender(config, &mockLogger{}, &mockMetricsCollector{}) // Create a payload that will cause request creation to fail payload := map[string]interface{}{ @@ -161,7 +161,7 @@ func TestWebhookSender_SendWebhook_RequestCreationError(t *testing.T) { // Test with invalid URL config.URL = "invalid-url" - sender = NewWebhookSender(config) + sender = NewWebhookSender(config, &mockLogger{}, &mockMetricsCollector{}) err := sender.SendWebhook(context.Background(), payload) if err == nil { @@ -181,7 +181,7 @@ func TestWebhookSender_SendWebhook_ResponseBodyCloseError(t *testing.T) { Timeout: 5 * time.Second, } - sender := NewWebhookSender(config) + sender := NewWebhookSender(config, &mockLogger{}, &mockMetricsCollector{}) payload := map[string]interface{}{ "test": "data", @@ -204,7 +204,7 @@ func TestWebhookSender_SendWebhook_ErrorStatus(t *testing.T) { Timeout: 5 * time.Second, } - sender := NewWebhookSender(config) + sender := NewWebhookSender(config, &mockLogger{}, &mockMetricsCollector{}) payload := map[string]interface{}{ "test": "data", @@ -227,7 +227,7 @@ func TestWebhookSender_SendWebhook_NotFound(t *testing.T) { Timeout: 5 * time.Second, } - sender := NewWebhookSender(config) + sender := NewWebhookSender(config, &mockLogger{}, &mockMetricsCollector{}) payload := map[string]interface{}{ "test": "data", @@ -250,7 +250,7 @@ func TestWebhookSender_SendWebhook_Unauthorized(t *testing.T) { Timeout: 5 * time.Second, } - sender := NewWebhookSender(config) + sender := NewWebhookSender(config, &mockLogger{}, &mockMetricsCollector{}) payload := map[string]interface{}{ "test": "data", @@ -273,7 +273,7 @@ func TestWebhookSender_SendWebhook_Forbidden(t *testing.T) { Timeout: 5 * time.Second, } - sender := NewWebhookSender(config) + sender := NewWebhookSender(config, &mockLogger{}, &mockMetricsCollector{}) payload := map[string]interface{}{ "test": "data", @@ -296,7 +296,7 @@ func TestWebhookSender_SendWebhook_InternalServerError(t *testing.T) { Timeout: 5 * time.Second, } - sender := NewWebhookSender(config) + sender := NewWebhookSender(config, &mockLogger{}, &mockMetricsCollector{}) payload := map[string]interface{}{ "test": "data", @@ -319,7 +319,7 @@ func TestWebhookSender_SendWebhook_BadGateway(t *testing.T) { Timeout: 5 * time.Second, } - sender := NewWebhookSender(config) + sender := NewWebhookSender(config, &mockLogger{}, &mockMetricsCollector{}) payload := map[string]interface{}{ "test": "data", @@ -342,7 +342,7 @@ func TestWebhookSender_SendWebhook_ServiceUnavailable(t *testing.T) { Timeout: 5 * time.Second, } - sender := NewWebhookSender(config) + sender := NewWebhookSender(config, &mockLogger{}, &mockMetricsCollector{}) payload := map[string]interface{}{ "test": "data", @@ -365,7 +365,7 @@ func TestWebhookSender_SendWebhook_GatewayTimeout(t *testing.T) { Timeout: 5 * time.Second, } - sender := NewWebhookSender(config) + sender := NewWebhookSender(config, &mockLogger{}, &mockMetricsCollector{}) payload := map[string]interface{}{ "test": "data", @@ -383,7 +383,7 @@ func TestWebhookSender_SendWebhook_NetworkError(t *testing.T) { Timeout: 1 * time.Second, } - sender := NewWebhookSender(config) + sender := NewWebhookSender(config, &mockLogger{}, &mockMetricsCollector{}) payload := map[string]interface{}{ "test": "data", @@ -407,7 +407,7 @@ func TestWebhookSender_SendWebhook_ContextCancellation(t *testing.T) { Timeout: 5 * time.Second, } - sender := NewWebhookSender(config) + sender := NewWebhookSender(config, &mockLogger{}, &mockMetricsCollector{}) payload := map[string]interface{}{ "test": "data", @@ -433,7 +433,7 @@ func TestWebhookSender_SendWebhook_WithCustomHeaders(t *testing.T) { Timeout: 5 * time.Second, } - sender := NewWebhookSender(config) + sender := NewWebhookSender(config, &mockLogger{}, &mockMetricsCollector{}) payload := map[string]interface{}{ "test": "data", @@ -455,7 +455,7 @@ func TestWebhookSender_SendWebhook_WithEmptyPayload(t *testing.T) { Timeout: 5 * time.Second, } - sender := NewWebhookSender(config) + sender := NewWebhookSender(config, &mockLogger{}, &mockMetricsCollector{}) payload := map[string]interface{}{} @@ -475,7 +475,7 @@ func TestWebhookSender_SendWebhook_WithNilPayload(t *testing.T) { Timeout: 5 * time.Second, } - sender := NewWebhookSender(config) + sender := NewWebhookSender(config, &mockLogger{}, &mockMetricsCollector{}) var payload interface{} = nil diff --git a/internal/smtp/backend.go b/internal/smtp/backend.go index f4c2720..4448aba 100644 --- a/internal/smtp/backend.go +++ b/internal/smtp/backend.go @@ -7,6 +7,7 @@ import ( "inboundparse/internal/processor" "io" "net" + "time" "github.com/emersion/go-smtp" "github.com/google/uuid" @@ -51,19 +52,22 @@ func (b *Backend) NewSession(c *smtp.Conn) (smtp.Session, error) { // Track connection metrics cleanup := b.metrics.TrackConnection() - b.logger.Debug(). - Str("session_id", sessionID). - Str("remote_addr", remote). + // Create session-specific logger with context + sessionLogger := observability.NewSessionLogger(b.logger, sessionID, remote) + + sessionLogger.Debug(). Msg("New SMTP session started") return &Session{ config: b.config, messageProcessor: b.messageProcessor, metrics: b.metrics, - logger: b.logger, + logger: sessionLogger, conn: c, sessionID: sessionID, cleanup: cleanup, + sessionStart: time.Now(), + messageCount: 0, }, nil } @@ -78,6 +82,8 @@ type Session struct { to []string sessionID string cleanup func() // cleanup function for connection tracking + sessionStart time.Time + messageCount int } // AuthPlain handles PLAIN authentication (we'll accept any or none) @@ -93,9 +99,12 @@ func (s *Session) AuthPlain(username, password string) error { // Mail is called when the MAIL FROM command is received func (s *Session) Mail(from string, opts *smtp.MailOptions) error { s.logger.Debug(). - Str("session_id", s.sessionID). Str("from", from). Msg("MAIL FROM received") + + // Track MAIL command + s.metrics.TrackSessionCommandCount("MAIL") + s.from = from return nil } @@ -103,12 +112,15 @@ func (s *Session) Mail(from string, opts *smtp.MailOptions) error { // Rcpt is called when a RCPT TO command is received func (s *Session) Rcpt(to string, opts *smtp.RcptOptions) error { s.logger.Debug(). - Str("session_id", s.sessionID). Str("to", to). Msg("RCPT TO received") + // Track RCPT command + s.metrics.TrackSessionCommandCount("RCPT") + // Check max recipients limit if s.config.MaxRecipients > 0 && len(s.to) >= s.config.MaxRecipients { + s.metrics.TrackSessionErrorCount("too_many_recipients") return fmt.Errorf("too many recipients") } @@ -119,17 +131,20 @@ func (s *Session) Rcpt(to string, opts *smtp.RcptOptions) error { // Data is called when the DATA command is received func (s *Session) Data(r io.Reader) error { s.logger.Info(). - Str("session_id", s.sessionID). Str("from", s.from). Strs("to", s.to). Msg("Receiving message data") + // Track DATA command + s.metrics.TrackSessionCommandCount("DATA") + // Track active session sessionCleanup := s.metrics.TrackSession() defer sessionCleanup() // Check if there are recipients if len(s.to) == 0 { + s.metrics.TrackSessionErrorCount("no_recipients") return fmt.Errorf("no recipients") } @@ -142,16 +157,30 @@ func (s *Session) Data(r io.Reader) error { } } - // Process the email message + // Process the email message with session-aware logger ctx := context.Background() // TODO: Add proper context with timeout - return s.messageProcessor.ProcessMessage(ctx, r, s.from, s.to, s.sessionID, remoteIP) + err := s.messageProcessor.ProcessMessage(ctx, r, s.from, s.to, s.sessionID, remoteIP) + + // Track message count regardless of success/failure + s.messageCount++ + s.metrics.TrackSessionMessageCount() + + // Track processing errors + if err != nil { + s.metrics.TrackSessionErrorCount("message_processing") + } + + return err } // Reset is called when the RSET command is received func (s *Session) Reset() { s.logger.Debug(). - Str("session_id", s.sessionID). Msg("Session reset") + + // Track RSET command + s.metrics.TrackSessionCommandCount("RSET") + s.from = "" s.to = nil } @@ -159,9 +188,18 @@ func (s *Session) Reset() { // Logout is called when the session is closed func (s *Session) Logout() error { s.logger.Debug(). - Str("session_id", s.sessionID). Msg("Session logout") + // Track session duration + sessionDuration := time.Since(s.sessionStart) + s.metrics.TrackSessionDuration(sessionDuration) + + // Log session summary + s.logger.Info(). + Dur("session_duration", sessionDuration). + Int("message_count", s.messageCount). + Msg("Session ended") + // Clean up connection tracking if s.cleanup != nil { s.cleanup() diff --git a/internal/smtp/backend_test.go b/internal/smtp/backend_test.go index 718348f..f37324a 100644 --- a/internal/smtp/backend_test.go +++ b/internal/smtp/backend_test.go @@ -39,15 +39,15 @@ func createMockMessageProcessor() *processor.MessageProcessor { // Mock implementations for testing type mockAuthChecker struct{} -func (m *mockAuthChecker) CheckSPF(ctx context.Context, remoteIP net.IP, domainName, from string) (*domain.SPFResult, error) { +func (m *mockAuthChecker) CheckSPF(ctx context.Context, remoteIP net.IP, domainName, from string, logger observability.Logger, metrics observability.MetricsCollector) (*domain.SPFResult, error) { return nil, nil } -func (m *mockAuthChecker) CheckDKIM(ctx context.Context, rawMessage string) (*domain.DKIMResult, error) { +func (m *mockAuthChecker) CheckDKIM(ctx context.Context, rawMessage string, logger observability.Logger, metrics observability.MetricsCollector) (*domain.DKIMResult, error) { return nil, nil } -func (m *mockAuthChecker) CheckDMARC(ctx context.Context, rawMessage string, headers letters.Headers, spfResult *domain.SPFResult, dkimResult *domain.DKIMResult) (*domain.DMARCResult, error) { +func (m *mockAuthChecker) CheckDMARC(ctx context.Context, rawMessage string, headers letters.Headers, spfResult *domain.SPFResult, dkimResult *domain.DKIMResult, logger observability.Logger, metrics observability.MetricsCollector) (*domain.DMARCResult, error) { return nil, nil } @@ -80,6 +80,30 @@ func (m *mockMetricsCollector) TrackSPFResult(result string) func (m *mockMetricsCollector) TrackDKIMResult(valid bool) {} func (m *mockMetricsCollector) TrackDMARCResult(result, policy string) {} func (m *mockMetricsCollector) TrackWebhookRequest(statusCode int, duration time.Duration) {} +func (m *mockMetricsCollector) TrackWebhookRequestSize(requestSize, responseSize int64) {} +func (m *mockMetricsCollector) TrackWebhookRetry(success bool) {} +func (m *mockMetricsCollector) TrackWebhookError(errorType string) {} +func (m *mockMetricsCollector) TrackSPFDuration(duration time.Duration) {} +func (m *mockMetricsCollector) TrackDKIMDuration(duration time.Duration) {} +func (m *mockMetricsCollector) TrackDMARCDuration(duration time.Duration) {} +func (m *mockMetricsCollector) TrackSPFResultByDomain(domain, result string) {} +func (m *mockMetricsCollector) TrackDKIMResultByDomain(domain, result string) {} +func (m *mockMetricsCollector) TrackDMARCResultByDomain(domain, result string) {} +func (m *mockMetricsCollector) TrackAuthError(checkType, errorType string) {} +func (m *mockMetricsCollector) TrackMessageReadDuration(duration time.Duration) {} +func (m *mockMetricsCollector) TrackMessageParseDuration(duration time.Duration) {} +func (m *mockMetricsCollector) TrackMessageAuthDuration(duration time.Duration) {} +func (m *mockMetricsCollector) TrackMessageWebhookDuration(duration time.Duration) {} +func (m *mockMetricsCollector) TrackMessageTotalProcessingDuration(duration time.Duration) {} +func (m *mockMetricsCollector) TrackEmailAttachmentCount(count int) {} +func (m *mockMetricsCollector) TrackEmailAttachmentSize(size int64) {} +func (m *mockMetricsCollector) TrackEmailBodyType(bodyType string) {} +func (m *mockMetricsCollector) TrackEmailHeaderCount(count int) {} +func (m *mockMetricsCollector) TrackEmailRecipientCount(count int) {} +func (m *mockMetricsCollector) TrackSessionDuration(duration time.Duration) {} +func (m *mockMetricsCollector) TrackSessionCommandCount(command string) {} +func (m *mockMetricsCollector) TrackSessionErrorCount(errorType string) {} +func (m *mockMetricsCollector) TrackSessionMessageCount() {} type mockLogger struct{}