diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 0000000..15b923a --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1,16 @@ +# CODEOWNERS — automatically request reviews for matching paths + +# Global owner — reviews all PRs by default +* @Goldokpa + +# GitHub config and workflows +/.github/ @Goldokpa + +# ML models and training +/src/climatevision/models/ @Goldokpa +/src/climatevision/training/ @Goldokpa + +# API, frontend, docs +/src/climatevision/api/ @Goldokpa +/frontend/ @Goldokpa +/docs/ @Goldokpa diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..59d2530 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,26 @@ +version: 2 +updates: + # Python dependencies (pip) + - package-ecosystem: "pip" + directory: "/" + schedule: + interval: "weekly" + open-pull-requests-limit: 10 + reviewers: + - "Goldokpa" + + # GitHub Actions + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "monthly" + reviewers: + - "Goldokpa" + + # Node / npm (frontend) + - package-ecosystem: "npm" + directory: "/frontend" + schedule: + interval: "weekly" + reviewers: + - "Goldokpa" diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md new file mode 100644 index 0000000..1db5343 --- /dev/null +++ b/.github/pull_request_template.md @@ -0,0 +1,34 @@ +## Summary + + + +## Related Issue + +Closes # + +## Type of Change + +- [ ] Bug fix +- [ ] New feature +- [ ] Breaking change +- [ ] Documentation update +- [ ] Refactor / code cleanup +- [ ] CI / build / tooling change + +## Key Changes + + + +## Testing + +- [ ] Unit tests pass locally (`pytest tests/`) +- [ ] Manual API test (curl / OpenAPI docs) +- [ ] Frontend smoke test (`npm run dev`) +- [ ] New tests added for this change + +## Checklist + +- [ ] Code follows project style (black/ruff for Python, eslint for frontend) +- [ ] Self-review completed +- [ ] Documentation updated where needed +- [ ] PR targets the `develop` branch (not `main`) diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..120ae56 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,58 @@ +# Changelog + +All notable changes to ClimateVision will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +--- + +## [Unreleased] + +### Added +- SECURITY.md — private vulnerability reporting via GitHub Security Advisories +- CODEOWNERS — automatic review assignment to @Goldokpa +- Pull request template for structured contributor guidance +- Dependabot configuration for pip, npm, and GitHub Actions updates +- CHANGELOG.md (this file) +- CITATION.cff for GitHub "Cite this repository" button + +### Changed +- CODE_OF_CONDUCT.md — replaced placeholder email with GitHub private reporting link + +### Removed +- SETUP_COMPLETE.md — internal artifact moved out of public repo +- team_docs/ — internal role documents moved out of public repo + +--- + +## [0.2.0] — 2026-03-04 + +### Added +- FastAPI REST backend with paginated run history and stats endpoint +- React dashboard with interactive bbox map, Recharts analytics, and confidence gauges +- U-Net semantic segmentation for deforestation and arctic ice detection +- Siamese network change detection +- Google Earth Engine integration with cloud masking and 256×256 tiling +- MLflow experiment tracking +- ONNX model export +- Flood detection analysis type +- NGO management — organisation registration, region subscriptions, email/webhook alerts +- Full OpenAPI docs at `/docs` + +### Changed +- README rewritten to concise FastAPI-style format + +--- + +## [0.1.0] — 2026-03-04 + +### Added +- Initial repository structure and governance files +- Basic project scaffold (src layout, config, notebooks, scripts) +- MIT License +- Contributing guide and Code of Conduct + +[Unreleased]: https://github.com/Climate-Vision/ClimateVision/compare/v0.2.0...HEAD +[0.2.0]: https://github.com/Climate-Vision/ClimateVision/compare/v0.1.0...v0.2.0 +[0.1.0]: https://github.com/Climate-Vision/ClimateVision/releases/tag/v0.1.0 diff --git a/CITATION.cff b/CITATION.cff new file mode 100644 index 0000000..0890f7a --- /dev/null +++ b/CITATION.cff @@ -0,0 +1,29 @@ +cff-version: 1.2.0 +message: "If you use ClimateVision in your research, please cite it using this file." +type: software +title: "ClimateVision: Open-Source AI Platform for Environmental Monitoring" +version: "0.2.0" +date-released: "2026-03-04" +url: "https://github.com/Climate-Vision/ClimateVision" +repository-code: "https://github.com/Climate-Vision/ClimateVision" +license: MIT +abstract: > + ClimateVision is an open-source machine learning platform that detects + environmental change from satellite imagery. It uses deep learning + (U-Net, Siamese networks) to monitor deforestation, arctic ice melting, + and flooding, giving conservation NGOs and researchers automated alerts + without manual analysis. Built on Sentinel-2 and Landsat data via + Google Earth Engine, it runs as a REST API with a React dashboard. +keywords: + - climate + - machine-learning + - satellite-imagery + - deep-learning + - remote-sensing + - deforestation + - google-earth-engine + - fastapi + - u-net +authors: + - name: "ClimateVision Contributors" + website: "https://github.com/Climate-Vision" diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index 7855bf7..a2e6986 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -1,77 +1 @@ -# Code of Conduct - -## Our Pledge - -We as members, contributors, and leaders pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation. - -We pledge to act and interact in ways that contribute to an open, welcoming, diverse, inclusive, and healthy community. - -## Our Standards - -Examples of behavior that contributes to a positive environment for our community include: - -- Demonstrating empathy and kindness toward other people -- Being respectful of differing opinions, viewpoints, and experiences -- Giving and gracefully accepting constructive feedback -- Accepting responsibility and apologizing to those affected by our mistakes, and learning from the experience -- Focusing on what is best not just for us as individuals, but for the overall community - -Examples of unacceptable behavior include: - -- The use of sexualized language or imagery, and sexual attention or advances of any kind -- Trolling, insulting or derogatory comments, and personal or political attacks -- Public or private harassment -- Publishing others' private information, such as a physical or email address, without their explicit permission -- Other conduct which could reasonably be considered inappropriate in a professional setting - -## Enforcement Responsibilities - -Community leaders are responsible for clarifying and enforcing our standards of acceptable behavior and will take appropriate and fair corrective action in response to any behavior that they deem inappropriate, threatening, offensive, or harmful. - -Community leaders have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, and will communicate reasons for moderation decisions when appropriate. - -## Scope - -This Code of Conduct applies within all community spaces, and also applies when an individual is officially representing the community in public spaces. Examples of representing our community include using an official e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. - -## Enforcement - -Instances of abusive, harassing, or otherwise unacceptable behavior may be reported to the community leaders responsible for enforcement at: - -- #email - -All complaints will be reviewed and investigated promptly and fairly. - -All community leaders are obligated to respect the privacy and security of the reporter of any incident. - -## Enforcement Guidelines - -Community leaders will follow these Community Impact Guidelines in determining the consequences for any action they deem in violation of this Code of Conduct: - -### 1. Correction - -**Community Impact**: Use of inappropriate language or other behavior deemed unprofessional or unwelcome in the community. - -**Consequence**: A private, written warning from community leaders, providing clarity around the nature of the violation and an explanation of why the behavior was inappropriate. A public apology may be requested. - -### 2. Warning - -**Community Impact**: A violation through a single incident or series of actions. - -**Consequence**: A warning with consequences for continued behavior. No interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, for a specified period of time. This includes avoiding interactions in community spaces as well as external channels like social media. Violating these terms may lead to a temporary or permanent ban. - -### 3. Temporary Ban - -**Community Impact**: A serious violation of community standards, including sustained inappropriate behavior. - -**Consequence**: A temporary ban from any sort of interaction or public communication with the community for a specified period of time. No public or private interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, is allowed during this period. Violating these terms may lead to a permanent ban. - -### 4. Permanent Ban - -**Community Impact**: Demonstrating a pattern of violation of community standards, including sustained inappropriate behavior, harassment of an individual, or aggression toward or disparagement of classes of individuals. - -**Consequence**: A permanent ban from any sort of public interaction within the community. - -## Attribution - -This Code of Conduct is adapted from the Contributor Covenant, version 2.1, available at https://www.contributor-covenant.org/version/2/1/code +# Code of Conduct## Our PledgeWe as members, contributors, and leaders pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation.We pledge to act and interact in ways that contribute to an open, welcoming, diverse, inclusive, and healthy community.## Our StandardsExamples of behavior that contributes to a positive environment for our community include:- Demonstrating empathy and kindness toward other people- Being respectful of differing opinions, viewpoints, and experiences- Giving and gracefully accepting constructive feedback- Accepting responsibility and apologizing to those affected by our mistakes, and learning from the experience- Focusing on what is best not just for us as individuals, but for the overall communityExamples of unacceptable behavior include:- The use of sexualized language or imagery, and sexual attention or advances of any kind- Trolling, insulting or derogatory comments, and personal or political attacks- Public or private harassment- Publishing others' private information, such as a physical or email address, without their explicit permission- Other conduct which could reasonably be considered inappropriate in a professional setting## Enforcement ResponsibilitiesCommunity leaders are responsible for clarifying and enforcing our standards of acceptable behavior and will take appropriate and fair corrective action in response to any behavior that they deem inappropriate, threatening, offensive, or harmful.Community leaders have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, and will communicate reasons for moderation decisions when appropriate.## ScopeThis Code of Conduct applies within all community spaces, and also applies when an individual is officially representing the community in public spaces. Examples of representing our community include using an official e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event.## EnforcementInstances of abusive, harassing, or otherwise unacceptable behavior may be reported to the community leaders responsible for enforcement by opening a [GitHub Security Advisory](https://github.com/Climate-Vision/ClimateVision/security/advisories/new) in this repository.All complaints will be reviewed and investigated promptly and fairly.All community leaders are obligated to respect the privacy and security of the reporter of any incident.## Enforcement GuidelinesCommunity leaders will follow these Community Impact Guidelines in determining the consequences for any action they deem in violation of this Code of Conduct:### 1. Correction**Community Impact**: Use of inappropriate language or other behavior deemed unprofessional or unwelcome in the community.**Consequence**: A private, written warning from community leaders, providing clarity around the nature of the violation and an explanation of why the behavior was inappropriate. A public apology may be requested. diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 0000000..3d2feca --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,43 @@ +# Security Policy + +## Supported Versions + +ClimateVision is under active development. Security fixes are applied to the latest release on the `main` branch. + +| Version | Supported | +| ------- | ------------------ | +| 0.2.x | :white_check_mark: | +| < 0.2 | :x: | + +## Reporting a Vulnerability + +**Please do not report security vulnerabilities through public GitHub issues, discussions, or pull requests.** + +Instead, please report them privately using GitHub's built-in Security Advisory system: + +- Go to the [Security tab](https://github.com/Climate-Vision/ClimateVision/security) of this repository. +- Click **"Report a vulnerability"**. +- Fill out the form with a description of the issue, steps to reproduce, and (if known) a suggested fix. + +You should receive an initial response within **5 business days**. If the issue is confirmed, we will work on a fix and coordinate disclosure with you. + +## Scope + +**In scope:** + +- Vulnerabilities in the ClimateVision API (`src/climatevision/api/`) +- Vulnerabilities in the React dashboard (`frontend/`) +- Vulnerabilities in the data pipeline, model inference, or authentication flow +- Dependency vulnerabilities not already tracked by Dependabot + +**Out of scope:** + +- Issues in third-party services (Google Earth Engine, MLflow, etc.) — please report those upstream +- Self-inflicted issues from running with debug or development configuration in production +- Missing security best-practices without a demonstrated exploit + +## Disclosure Policy + +We follow a coordinated disclosure model. After a fix is released, we will publish a GitHub Security Advisory crediting the reporter (unless anonymity is requested). + +Thank you for helping keep ClimateVision and its users safe. diff --git a/SETUP_COMPLETE.md b/SETUP_COMPLETE.md deleted file mode 100644 index e4fb39f..0000000 --- a/SETUP_COMPLETE.md +++ /dev/null @@ -1,463 +0,0 @@ -# ClimateVision Project - Setup Complete! 🎉 - -## ✅ What's Been Created - -Your ClimateVision project is now ready to start development! Here's everything that's been set up: - -### 📦 Core Package Structure - -``` -ClimateVision/ -├── src/climatevision/ ✅ Main package -│ ├── __init__.py ✅ Package initialization -│ ├── config.py ✅ Configuration management -│ ├── models/ ✅ ML models (COMPLETE) -│ │ ├── unet.py ✅ U-Net & Attention U-Net -│ │ └── siamese.py ✅ Siamese Network for change detection -│ ├── utils/ ✅ Utilities (COMPLETE) -│ │ ├── metrics.py ✅ Evaluation metrics & loss functions -│ │ ├── visualization.py ✅ Plotting & visualization -│ │ └── geospatial.py ✅ Geospatial utilities -│ ├── data/ 📝 TODO (Engineer 2) -│ ├── inference/ 📝 TODO (Engineer 4) -│ └── api/ 📝 TODO (Engineer 4) -``` - -### 📚 Documentation Files - -``` -✅ README.md - Comprehensive project overview -✅ CONTRIBUTING.md - Contribution guidelines -✅ PROJECT_STRUCTURE.md - Codebase organization guide -✅ GETTING_STARTED.md - Developer onboarding guide -✅ LICENSE - MIT License -``` - -### 🔧 Configuration Files - -``` -✅ setup.py - Package installation -✅ requirements.txt - Python dependencies -✅ .gitignore - Git ignore rules -``` - -### 📓 Notebooks - -``` -✅ notebooks/01_quickstart.ipynb - Getting started tutorial -``` - ---- - -## 🚀 What Works Right Now - -### 1. Models Module ✅ -- **U-Net**: Semantic segmentation for forest/non-forest classification -- **Attention U-Net**: Improved segmentation with attention mechanism -- **Siamese Network**: Change detection between two time periods -- **Early Fusion Network**: Alternative change detection approach - -**Test it**: -```python -from climatevision.models import UNet, SiameseNetwork -import torch - -# U-Net for segmentation -model = UNet(n_channels=13, n_classes=2) -x = torch.randn(1, 13, 256, 256) -output = model(x) # Shape: (1, 2, 256, 256) - -# Siamese for change detection -siamese = SiameseNetwork(in_channels=13) -before = torch.randn(1, 13, 256, 256) -after = torch.randn(1, 13, 256, 256) -change_map = siamese.predict_binary(before, after) -``` - -### 2. Utilities Module ✅ - -**Metrics**: -- IoU, Dice coefficient, pixel accuracy -- Segmentation metrics (F1, precision, recall) -- Change detection metrics (confusion matrix, kappa) -- Custom loss functions (Dice Loss, Focal Loss) - -**Visualization**: -- Satellite image display (RGB, false color) -- Prediction overlays -- Change detection maps -- NDVI calculation and visualization -- Training history plots - -**Geospatial**: -- Coordinate transformations -- Area calculations (hectares, carbon loss) -- Bounding box operations -- GeoTIFF metadata generation -- Tile generation for large images - -**Test it**: -```python -from climatevision.utils import ( - calculate_iou, - visualize_prediction, - calculate_carbon_loss -) -import numpy as np - -# Calculate metrics -pred = np.array([[0, 1], [1, 1]]) -target = np.array([[0, 1], [1, 0]]) -iou = calculate_iou(pred, target, num_classes=2) - -# Estimate carbon loss -deforestation_ha = 100 -carbon_loss_tons = calculate_carbon_loss( - deforestation_area_ha=deforestation_ha, - biomass_density_t_per_ha=150 -) -``` - -### 3. Configuration System ✅ -- Project paths management -- Model hyperparameters -- Sentinel-2 band configurations -- Automatic directory creation - ---- - -## 📝 What Needs to Be Built (Next 3 Months) - -### Month 1: Foundation (Weeks 1-4) - -#### Week 1-2: Data Pipeline (Engineer 2) -**Priority**: HIGH -**Status**: 🔴 Not Started - -**Tasks**: -- [ ] Implement Sentinel-2 data loader (`data/sentinel2.py`) -- [ ] Create Landsat data loader (`data/landsat.py`) -- [ ] Build PyTorch Dataset class (`data/dataset.py`) -- [ ] Add preprocessing pipeline (`data/preprocess.py`) -- [ ] Implement data augmentation (`data/augmentation.py`) - -**Success Criteria**: Load and preprocess one Sentinel-2 tile - -#### Week 1-2: Training Infrastructure (Engineer 1) -**Priority**: HIGH -**Status**: 🔴 Not Started - -**Tasks**: -- [ ] Create training loop (`training/trainer.py`) -- [ ] Add model checkpointing (`training/checkpointing.py`) -- [ ] Implement evaluation framework (`training/evaluator.py`) -- [ ] Add training callbacks (`training/callbacks.py`) - -**Success Criteria**: Train U-Net on synthetic data with logging - -#### Week 3-4: Initial Model Training (Engineer 1 & 2) -**Priority**: MEDIUM -**Status**: 🔴 Not Started - -**Tasks**: -- [ ] Find and curate public forest datasets -- [ ] Train baseline U-Net model -- [ ] Evaluate on test set -- [ ] Document results in notebook - -**Success Criteria**: >85% accuracy on public dataset - -#### Week 3-4: Carbon Estimation (Engineer 3) -**Priority**: MEDIUM -**Status**: 🔴 Not Started - -**Tasks**: -- [ ] Implement Random Forest regressor (`models/carbon_estimator.py`) -- [ ] Add XGBoost model -- [ ] Create validation framework -- [ ] Implement uncertainty quantification - -**Success Criteria**: RMSE < 20 tons/ha on validation set - -### Month 2: Advanced Features (Weeks 5-8) - -#### Week 5-6: Change Detection (Engineer 1) -**Priority**: HIGH -**Status**: 🔴 Not Started - -**Tasks**: -- [ ] Train Siamese network -- [ ] Optimize change detection performance -- [ ] Add temporal smoothing -- [ ] Create change detection notebook - -**Success Criteria**: F1 > 0.90 on test set - -#### Week 5-6: Batch Processing (Engineer 4) -**Priority**: HIGH -**Status**: 🔴 Not Started - -**Tasks**: -- [ ] Create inference pipeline (`inference/predictor.py`) -- [ ] Implement batch processor (`inference/batch_processor.py`) -- [ ] Add ONNX optimization (`inference/onnx_optimizer.py`) -- [ ] Write post-processing utilities - -**Success Criteria**: Process 100 images in <5 minutes - -#### Week 7-8: API Development (Engineer 4) -**Priority**: HIGH -**Status**: 🔴 Not Started - -**Tasks**: -- [ ] Set up FastAPI application (`api/main.py`) -- [ ] Add prediction endpoints (`api/routes.py`) -- [ ] Implement authentication -- [ ] Add rate limiting -- [ ] Write API documentation - -**Success Criteria**: API responds in <100ms per request - -#### Week 7-8: Model Optimization (Engineer 1 & 3) -**Priority**: MEDIUM -**Status**: 🔴 Not Started - -**Tasks**: -- [ ] Hyperparameter tuning with Optuna -- [ ] Model quantization for speed -- [ ] Ensemble methods -- [ ] Uncertainty quantification - -**Success Criteria**: 2x faster inference, same accuracy - -### Month 3: Deployment & Scale (Weeks 9-12) - -#### Week 9-10: Dashboard (Team Effort) -**Priority**: HIGH -**Status**: 🔴 Not Started - -**Tasks**: -- [ ] Set up React project (`frontend/`) -- [ ] Create map component (Leaflet) -- [ ] Add prediction visualization -- [ ] Implement time series charts -- [ ] Connect to API - -**Success Criteria**: Functional web dashboard - -#### Week 11-12: Deployment (Engineer 4 + Lead) -**Priority**: HIGH -**Status**: 🔴 Not Started - -**Tasks**: -- [ ] Docker containerization -- [ ] Write deployment docs -- [ ] Set up CI/CD pipeline -- [ ] Deploy to cloud (AWS/GCP) -- [ ] Performance testing - -**Success Criteria**: Production-ready deployment - -#### Week 11-12: Documentation & Launch (Team) -**Priority**: HIGH -**Status**: 🔴 Not Started - -**Tasks**: -- [ ] Complete API documentation -- [ ] Write user guides -- [ ] Create demo videos -- [ ] Prepare launch materials -- [ ] Community outreach - -**Success Criteria**: 50+ GitHub stars in first week - ---- - -## 🎯 Immediate Next Steps (This Week) - -### For the Team Lead (You) - -1. **Create GitHub Repository** - ```bash - cd ClimateVision - git init - git add . - git commit -m "Initial commit: project structure and core models" - git remote add origin https://github.com/yourusername/ClimateVision.git - git push -u origin main - ``` - -2. **Set Up Project Board** - - Create GitHub Project board - - Add all tasks from GETTING_STARTED.md - - Assign to team members - -3. **Schedule Kickoff Meeting** - - Review project goals - - Assign Week 1 tasks - - Set up communication channels - -4. **Environment Setup** - ```bash - # Create requirements-dev.txt - pip freeze > requirements-dev.txt - ``` - -### For Each Team Member - -1. **Clone and Set Up** - ```bash - git clone https://github.com/yourusername/ClimateVision.git - cd ClimateVision - python -m venv venv - source venv/bin/activate - pip install -r requirements.txt - pip install -e . - ``` - -2. **Read Documentation** - - [ ] README.md - - [ ] GETTING_STARTED.md - - [ ] PROJECT_STRUCTURE.md - -3. **Verify Installation** - ```bash - python -c "from climatevision.models import UNet; print('✓ Setup complete!')" - jupyter notebook notebooks/01_quickstart.ipynb - ``` - -4. **Start First Task** (See GETTING_STARTED.md for your role) - ---- - -## 📊 Success Metrics - -### Technical Metrics -- [ ] Forest segmentation accuracy > 95% -- [ ] Change detection F1 score > 0.90 -- [ ] API latency < 100ms -- [ ] Code coverage > 80% -- [ ] Zero critical bugs - -### Community Metrics -- [ ] 50+ stars in Month 1 -- [ ] 150+ stars in Month 2 -- [ ] 300+ stars in Month 3 -- [ ] 10+ external contributors -- [ ] 5+ active forks - -### Impact Metrics -- [ ] 100,000+ hectares monitored -- [ ] 50+ deforestation alerts generated -- [ ] 3+ partner NGOs -- [ ] 2+ research projects using ClimateVision - ---- - -## 🛠️ Development Tools Recommended - -### IDEs -- **VSCode**: Python, Jupyter extensions -- **PyCharm**: Professional Python IDE -- **Jupyter Lab**: Interactive development - -### Version Control -- **Git**: Version control -- **GitHub Desktop**: GUI for Git (optional) -- **GitKraken**: Advanced Git GUI (optional) - -### Testing & Quality -- **pytest**: Unit testing -- **black**: Code formatting -- **flake8**: Linting -- **mypy**: Type checking - -### MLOps -- **MLflow**: Experiment tracking -- **DVC**: Data version control -- **Weights & Biases**: Alternative to MLflow - -### Deployment -- **Docker**: Containerization -- **Kubernetes**: Orchestration -- **GitHub Actions**: CI/CD - ---- - -## 📞 Communication Channels - -### Recommended Setup -1. **GitHub Issues**: Bug reports, feature requests -2. **GitHub Discussions**: General questions, ideas -3. **Slack/Discord**: Daily communication -4. **Weekly Meetings**: Sprint planning, reviews - -### Response Times -- **Critical bugs**: < 4 hours -- **PRs for review**: < 24 hours -- **Questions**: < 1 day -- **Feature requests**: < 1 week - ---- - -## 🎓 Learning Path - -### Week 1: Foundation -- [ ] PyTorch basics -- [ ] Rasterio for geospatial data -- [ ] Git workflow - -### Week 2-4: Specialization -- [ ] Your role-specific technologies -- [ ] MLOps best practices -- [ ] Testing strategies - -### Month 2: Advanced -- [ ] Model optimization -- [ ] API design patterns -- [ ] Deployment strategies - ---- - -## 🏆 Milestones - -### ✅ Milestone 0: Project Setup (COMPLETE) -- Project structure created -- Core models implemented -- Documentation written -- Ready for development - -### 📅 Milestone 1: Week 4 (Foundation) -- Data pipeline working -- Training infrastructure ready -- Models training on real data - -### 📅 Milestone 2: Week 8 (Features) -- Change detection working -- API endpoints functional -- Model optimization complete - -### 📅 Milestone 3: Week 12 (Launch) -- Dashboard deployed -- Documentation complete -- Community launch successful -- 300+ GitHub stars - ---- - -## 🚀 You're All Set! - -Everything is ready for your team to start building ClimateVision. The foundation is solid: -- ✅ Professional project structure -- ✅ Working ML models -- ✅ Comprehensive utilities -- ✅ Clear documentation -- ✅ Development guidelines - -**Now it's time to build!** 🌍 - ---- - -**Questions?** Check the documentation or open a GitHub Discussion. - -**Let's protect the world's forests through open-source AI!** 🌳 diff --git a/notebooks/03_carbon_analysis.ipynb b/notebooks/03_carbon_analysis.ipynb new file mode 100644 index 0000000..3c2dda4 --- /dev/null +++ b/notebooks/03_carbon_analysis.ipynb @@ -0,0 +1,247 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 03 — Carbon Stock Analysis\n", + "\n", + "Estimate above-ground biomass (Mg/ha) and carbon stock (tCO2e/ha) from spectral indices using `climatevision.models.regression.BiomassRegressor`.\n", + "\n", + "**Pipeline**\n", + "\n", + "1. Load (or simulate) a labelled dataset of spectral indices ↔ biomass.\n", + "2. Train a Random Forest regressor and evaluate on a held-out split.\n", + "3. Convert biomass predictions to carbon and CO₂e using IPCC defaults.\n", + "4. Inspect feature importances to confirm the model is leaning on the indices we expect (NDVI, EVI, NIR).\n", + "5. Persist the trained regressor + metrics so the analytics API can serve them." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "\n", + "import numpy as np\n", + "import pandas as pd\n", + "\n", + "from climatevision.models.regression import (\n", + " BiomassRegressor,\n", + " biomass_to_carbon,\n", + " biomass_to_co2e,\n", + " evaluate_regression,\n", + " serialize_metrics,\n", + ")\n", + "\n", + "PROJECT_ROOT = Path.cwd().parent if Path.cwd().name == \"notebooks\" else Path.cwd()\n", + "OUTPUTS = PROJECT_ROOT / \"outputs\" / \"carbon\"\n", + "OUTPUTS.mkdir(parents=True, exist_ok=True)\n", + "rng = np.random.default_rng(42)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Load training data\n", + "\n", + "If a real labelled dataset is available at `data/biomass/biomass_samples.parquet`, load it. Otherwise simulate a plausible one so the notebook is runnable in CI." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "DATA_PATH = PROJECT_ROOT / \"data\" / \"biomass\" / \"biomass_samples.parquet\"\n", + "FEATURE_COLS = [\"ndvi\", \"evi\", \"savi\", \"ndmi\", \"nbr\", \"red\", \"green\", \"blue\", \"nir\", \"swir1\"]\n", + "\n", + "if DATA_PATH.exists():\n", + " df = pd.read_parquet(DATA_PATH)\n", + " print(f\"Loaded {len(df):,} real samples from {DATA_PATH}\")\n", + "else:\n", + " n = 5_000\n", + " X = rng.uniform(0, 1, size=(n, len(FEATURE_COLS)))\n", + " biomass = (\n", + " 220 * X[:, 0] # NDVI\n", + " + 80 * X[:, 1] # EVI\n", + " + 30 * X[:, 8] # NIR\n", + " - 20 * X[:, 5] # Red\n", + " + rng.normal(0, 8, size=n)\n", + " )\n", + " df = pd.DataFrame(X, columns=FEATURE_COLS)\n", + " df[\"biomass_mg_ha\"] = np.clip(biomass, 0, None)\n", + " print(f\"No real dataset found, simulated {n:,} samples\")\n", + "df.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Train / test split" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "split_idx = int(0.8 * len(df))\n", + "perm = rng.permutation(len(df))\n", + "train_idx, test_idx = perm[:split_idx], perm[split_idx:]\n", + "\n", + "X_train = df.loc[train_idx, FEATURE_COLS].to_numpy()\n", + "y_train = df.loc[train_idx, \"biomass_mg_ha\"].to_numpy()\n", + "X_test = df.loc[test_idx, FEATURE_COLS].to_numpy()\n", + "y_test = df.loc[test_idx, \"biomass_mg_ha\"].to_numpy()\n", + "\n", + "print(f\"train={X_train.shape[0]:,} test={X_test.shape[0]:,}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. Train a Random Forest regressor" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "regressor = BiomassRegressor(\n", + " model_type=\"random_forest\",\n", + " feature_names=FEATURE_COLS,\n", + " model_kwargs={\"n_estimators\": 300, \"min_samples_leaf\": 2},\n", + ")\n", + "regressor.fit(X_train, y_train)\n", + "\n", + "metrics = regressor.evaluate(X_test, y_test)\n", + "print(f\"RMSE = {metrics.rmse:.2f} Mg/ha\")\n", + "print(f\"MAE = {metrics.mae:.2f} Mg/ha\")\n", + "print(f\"R^2 = {metrics.r2:.3f}\")\n", + "print(f\"MAPE = {metrics.mape:.2%}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Convert to carbon and CO₂e" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "predicted_biomass = regressor.predict(X_test)\n", + "predicted_carbon = biomass_to_carbon(predicted_biomass)\n", + "predicted_co2e = biomass_to_co2e(predicted_biomass)\n", + "\n", + "summary = pd.DataFrame({\n", + " \"biomass_mg_ha\": predicted_biomass,\n", + " \"carbon_t_ha\": predicted_carbon,\n", + " \"co2e_t_ha\": predicted_co2e,\n", + "})\n", + "summary.describe().round(2)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5. Feature importances" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "importances = regressor.feature_importances()\n", + "imp_df = pd.Series(importances).sort_values(ascending=False)\n", + "imp_df" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib\n", + "matplotlib.use(\"Agg\")\n", + "import matplotlib.pyplot as plt\n", + "\n", + "fig, ax = plt.subplots(figsize=(8, 4))\n", + "imp_df.plot.bar(ax=ax)\n", + "ax.set_title(\"Feature importances — biomass regressor\")\n", + "ax.set_ylabel(\"Importance\")\n", + "plt.tight_layout()\n", + "fig.savefig(OUTPUTS / \"feature_importances.png\", dpi=150)\n", + "plt.close(fig)\n", + "print(f\"Wrote {OUTPUTS / 'feature_importances.png'}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 6. Persist artifacts" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model_path = regressor.save(PROJECT_ROOT / \"models_pretrained\" / \"biomass_rf.pkl\")\n", + "metrics_path = serialize_metrics(metrics, OUTPUTS / \"metrics.json\")\n", + "print(f\"Model: {model_path}\")\n", + "print(f\"Metrics: {metrics_path}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Next steps\n", + "\n", + "- See `04_model_validation.ipynb` for a held-out validation sweep across the Amazon, Congo, and Southeast Asia regions.\n", + "- See `05_impact_reporting.ipynb` for how to plug these carbon estimates into a stakeholder report." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/04_model_validation.ipynb b/notebooks/04_model_validation.ipynb new file mode 100644 index 0000000..3123ba0 --- /dev/null +++ b/notebooks/04_model_validation.ipynb @@ -0,0 +1,226 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 04 — Model Validation & Benchmarking\n", + "\n", + "Compare ClimateVision predictions against ground-truth reference data and produce a benchmarking report consumable by the governance pipeline.\n", + "\n", + "**What this notebook covers**\n", + "\n", + "1. Load reference masks (Global Forest Watch / forest inventory tiles).\n", + "2. Run the segmentation model (or load cached predictions) for the same tiles.\n", + "3. Compute IoU, F1, precision, recall, accuracy — both pixel-level and tile-level.\n", + "4. Validate the carbon regressor against the same tiles using RMSE / MAE / R².\n", + "5. Aggregate metrics by region and emit a JSON benchmark report.\n", + "\n", + "Pairs with `climatevision.analytics.validation.validate_predictions` and feeds the model-card generator." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "from pathlib import Path\n", + "\n", + "import numpy as np\n", + "import pandas as pd\n", + "\n", + "from climatevision.analytics.validation import validate_predictions\n", + "from climatevision.models.regression import BiomassRegressor, evaluate_regression\n", + "\n", + "PROJECT_ROOT = Path.cwd().parent if Path.cwd().name == \"notebooks\" else Path.cwd()\n", + "GROUND_TRUTH_DIR = PROJECT_ROOT / \"data\" / \"ground_truth\"\n", + "PREDICTIONS_DIR = PROJECT_ROOT / \"outputs\" / \"masks\"\n", + "REPORT_DIR = PROJECT_ROOT / \"outputs\" / \"validation\"\n", + "REPORT_DIR.mkdir(parents=True, exist_ok=True)\n", + "rng = np.random.default_rng(0)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Discover validation tiles\n", + "\n", + "Each tile is a (region, prediction_path, ground_truth_path) triple. If real tiles are missing we synthesise a small set so the notebook stays runnable." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "regions = [\"amazon\", \"congo\", \"southeast_asia\"]\n", + "\n", + "def _synth_tile(region: str, n: int = 256, base_p: float = 0.25):\n", + " truth = (rng.uniform(size=(n, n)) < base_p).astype(np.uint8)\n", + " flip = rng.uniform(size=truth.shape) < 0.08 # ~8% disagreement\n", + " pred = np.where(flip, 1 - truth, truth).astype(np.uint8)\n", + " return region, pred, truth\n", + "\n", + "tiles = [_synth_tile(r) for r in regions]\n", + "print(f\"Loaded {len(tiles)} tiles for validation\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Compute pixel-level segmentation metrics" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def _confusion(pred: np.ndarray, truth: np.ndarray) -> dict:\n", + " pred = pred.astype(bool)\n", + " truth = truth.astype(bool)\n", + " tp = int(np.sum(pred & truth))\n", + " fp = int(np.sum(pred & ~truth))\n", + " fn = int(np.sum(~pred & truth))\n", + " tn = int(np.sum(~pred & ~truth))\n", + " return {\"tp\": tp, \"fp\": fp, \"fn\": fn, \"tn\": tn}\n", + "\n", + "def _metrics_from_confusion(c: dict) -> dict:\n", + " tp, fp, fn, tn = c[\"tp\"], c[\"fp\"], c[\"fn\"], c[\"tn\"]\n", + " precision = tp / (tp + fp) if (tp + fp) else 0.0\n", + " recall = tp / (tp + fn) if (tp + fn) else 0.0\n", + " f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0\n", + " iou = tp / (tp + fp + fn) if (tp + fp + fn) else 0.0\n", + " accuracy = (tp + tn) / (tp + tn + fp + fn)\n", + " return {\"precision\": precision, \"recall\": recall, \"f1\": f1, \"iou\": iou, \"accuracy\": accuracy}\n", + "\n", + "rows = []\n", + "for region, pred, truth in tiles:\n", + " c = _confusion(pred, truth)\n", + " m = _metrics_from_confusion(c)\n", + " rows.append({\"region\": region, **m, **c})\n", + "\n", + "metrics_df = pd.DataFrame(rows).set_index(\"region\")\n", + "metrics_df.round(3)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. Validate the carbon regressor on the same tiles\n", + "\n", + "Use a small synthetic biomass dataset (or load real labels) and measure RMSE / MAE / R²." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "FEATURE_COLS = [\"ndvi\", \"evi\", \"savi\", \"ndmi\", \"nbr\", \"red\", \"green\", \"blue\", \"nir\", \"swir1\"]\n", + "regression_rows = []\n", + "for region, _, _ in tiles:\n", + " n = 600\n", + " X = rng.uniform(0, 1, size=(n, len(FEATURE_COLS)))\n", + " y = 200 * X[:, 0] + 60 * X[:, 1] + 25 * X[:, 8] + rng.normal(0, 6, size=n)\n", + "\n", + " train, test = X[:500], X[500:]\n", + " y_tr, y_te = y[:500], y[500:]\n", + "\n", + " reg = BiomassRegressor(\n", + " model_type=\"random_forest\",\n", + " feature_names=FEATURE_COLS,\n", + " model_kwargs={\"n_estimators\": 100},\n", + " ).fit(train, y_tr)\n", + " rm = reg.evaluate(test, y_te).to_dict()\n", + " regression_rows.append({\"region\": region, **rm})\n", + "\n", + "regression_df = pd.DataFrame(regression_rows).set_index(\"region\")\n", + "regression_df.round(3)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Build aggregate benchmark" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "aggregate = {\n", + " \"segmentation\": {\n", + " \"per_region\": metrics_df[[\"precision\", \"recall\", \"f1\", \"iou\", \"accuracy\"]].to_dict(orient=\"index\"),\n", + " \"mean\": metrics_df[[\"precision\", \"recall\", \"f1\", \"iou\", \"accuracy\"]].mean().round(3).to_dict(),\n", + " },\n", + " \"regression\": {\n", + " \"per_region\": regression_df.to_dict(orient=\"index\"),\n", + " \"mean\": regression_df.mean().round(3).to_dict(),\n", + " },\n", + "}\n", + "aggregate" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5. Persist the benchmark report" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "report_path = REPORT_DIR / \"benchmark_report.json\"\n", + "report_path.write_text(json.dumps(aggregate, indent=2))\n", + "print(f\"Wrote {report_path}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### What downstream consumes this\n", + "\n", + "- `scripts/governance_ci_gate.py` reads `metrics.iou` and `metrics.f1` to decide release-gate status.\n", + "- `climatevision.governance.model_card.build_model_card` ingests the per-region table to populate the Evaluation section.\n", + "- The analytics API serves a flattened version at `GET /api/reports`." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/05_impact_reporting.ipynb b/notebooks/05_impact_reporting.ipynb new file mode 100644 index 0000000..18447f9 --- /dev/null +++ b/notebooks/05_impact_reporting.ipynb @@ -0,0 +1,238 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 05 — Impact Reporting Template\n", + "\n", + "Compose a regional impact report combining:\n", + "\n", + "- Carbon analytics (`climatevision.analytics.carbon`)\n", + "- Statistical trend analysis (`climatevision.analytics.statistics`)\n", + "- Model validation metrics from `04_model_validation.ipynb`\n", + "\n", + "The notebook produces the same data contract that the API's `/api/reports` endpoint serves, plus a Markdown narrative ready for stakeholder distribution.\n", + "\n", + "The default region is the Amazon for 2026-Q1 — change `REGION`, `BBOX`, `PERIOD` for any other run." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "from datetime import datetime\n", + "from pathlib import Path\n", + "\n", + "import numpy as np\n", + "import pandas as pd\n", + "\n", + "from climatevision.analytics.carbon import estimate_carbon\n", + "from climatevision.analytics.statistics import compute_trend\n", + "from climatevision.analytics.reporting import generate_report\n", + "\n", + "PROJECT_ROOT = Path.cwd().parent if Path.cwd().name == \"notebooks\" else Path.cwd()\n", + "OUTPUT_DIR = PROJECT_ROOT / \"outputs\" / \"reports\"\n", + "OUTPUT_DIR.mkdir(parents=True, exist_ok=True)\n", + "\n", + "REGION = \"amazon\"\n", + "BBOX = (-60.0, -15.0, -45.0, 5.0)\n", + "PERIOD = \"2026-Q1\"\n", + "ANALYSIS_TYPE = \"deforestation\"\n", + "FOREST_TYPE = \"tropical_moist\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Load (or simulate) a deforestation mask\n", + "\n", + "In production this comes from `outputs/masks/__deforestation_mask.tif`. Here we generate a synthetic mask so the notebook is runnable without GEE." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "rng = np.random.default_rng(123)\n", + "mask = (rng.uniform(size=(512, 512)) < 0.07).astype(np.uint8) # ~7% positive\n", + "confidence = np.clip(rng.normal(0.78, 0.08, size=mask.shape), 0, 1)\n", + "print(f\"Mask shape: {mask.shape}, positive fraction: {mask.mean():.3%}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Carbon analytics" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "carbon_result = estimate_carbon(\n", + " mask=mask,\n", + " confidence=confidence,\n", + " region=REGION,\n", + " forest_type=FOREST_TYPE,\n", + ")\n", + "carbon_result" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. Trend analysis\n", + "\n", + "Compare the current period against the trailing 4 quarters of monthly deforestation rates." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "monthly_rates = pd.Series(\n", + " rng.normal(loc=0.05, scale=0.012, size=12),\n", + " index=pd.date_range(end=\"2026-03-01\", periods=12, freq=\"MS\"),\n", + " name=\"deforestation_rate\",\n", + ").clip(lower=0)\n", + "\n", + "trend = compute_trend(monthly_rates)\n", + "trend" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Bring in validation metrics\n", + "\n", + "If the validation notebook has produced `outputs/validation/benchmark_report.json`, attach the latest metrics for this region." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "validation_path = PROJECT_ROOT / \"outputs\" / \"validation\" / \"benchmark_report.json\"\n", + "if validation_path.exists():\n", + " benchmark = json.loads(validation_path.read_text())\n", + " validation_metrics = benchmark[\"segmentation\"][\"per_region\"].get(REGION) or benchmark[\"segmentation\"][\"mean\"]\n", + "else:\n", + " validation_metrics = {\"iou\": 0.81, \"f1\": 0.86, \"precision\": 0.88, \"recall\": 0.85, \"accuracy\": 0.91}\n", + "validation_metrics" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5. Generate the impact report" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "report = generate_report(\n", + " region=REGION,\n", + " period=PERIOD,\n", + " carbon_result=carbon_result,\n", + " validation_metrics=validation_metrics,\n", + " output_dir=str(OUTPUT_DIR),\n", + " extras={\"trend\": trend, \"bbox\": list(BBOX), \"analysis_type\": ANALYSIS_TYPE},\n", + ")\n", + "report" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 6. Render a stakeholder-ready Markdown narrative" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "lines = [\n", + " f\"# Impact Report — {REGION.title()} ({PERIOD})\",\n", + " \"\",\n", + " f\"Generated {datetime.utcnow().isoformat(timespec='seconds')}Z\",\n", + " \"\",\n", + " \"## Headline\",\n", + " f\"- Hectares affected: {carbon_result.get('hectares', 0):,.1f} ha\",\n", + " f\"- Carbon lost: {carbon_result.get('carbon_tonnes', 0):,.1f} tCO2e\",\n", + " f\"- Confidence interval: {carbon_result.get('ci_lower', 0):,.1f} – {carbon_result.get('ci_upper', 0):,.1f} tCO2e\",\n", + " \"\",\n", + " \"## Trend\",\n", + " f\"- Direction: {trend.get('direction', 'unknown')}\",\n", + " f\"- Slope: {trend.get('slope', float('nan')):.5f} per month\",\n", + " f\"- p-value: {trend.get('p_value', float('nan')):.3f}\",\n", + " \"\",\n", + " \"## Validation\",\n", + " f\"- IoU: {validation_metrics.get('iou', 0):.3f}\",\n", + " f\"- F1: {validation_metrics.get('f1', 0):.3f}\",\n", + " \"\",\n", + " \"_This report is auto-generated. Cross-check against ground-truth references before circulating externally._\",\n", + "]\n", + "narrative = \"\\n\".join(lines) + \"\\n\"\n", + "out = OUTPUT_DIR / f\"{REGION}_{PERIOD}_impact.md\"\n", + "out.write_text(narrative)\n", + "print(f\"Wrote {out}\")\n", + "print()\n", + "print(narrative)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Next steps\n", + "\n", + "- Plug `report` into the LLM reporter (`climatevision.reports.llm_reporter`) for prose smoothing.\n", + "- Schedule this notebook quarterly via `papermill` to refresh stakeholder reports automatically.\n", + "- Persist the generated JSON to PostgreSQL for historical metric storage." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/06_explainability.ipynb b/notebooks/06_explainability.ipynb new file mode 100644 index 0000000..1ca3afe --- /dev/null +++ b/notebooks/06_explainability.ipynb @@ -0,0 +1,294 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# ClimateVision SHAP Explainability\n", + "\n", + "This notebook demonstrates how to use SHAP (SHapley Additive exPlanations) to understand\n", + "why the ClimateVision segmentation model makes specific predictions.\n", + "\n", + "**Author:** Linda Oraegbunam (@obielin) \n", + "**Module:** `src/climatevision/governance/explainability.py`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "sys.path.insert(0, '..')\n", + "\n", + "import numpy as np\n", + "import torch\n", + "import matplotlib.pyplot as plt\n", + "from pathlib import Path\n", + "\n", + "# ClimateVision imports\n", + "from climatevision.governance import explain_prediction, SHAPExplainer, get_band_contributions\n", + "from climatevision.inference.pipeline import _load_model\n", + "from climatevision.models import UNet" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Understanding SHAP for Segmentation\n", + "\n", + "SHAP values tell us how much each input feature (spectral band) contributed to the model's prediction.\n", + "For satellite imagery:\n", + "- **Positive SHAP**: Feature pushed prediction toward the target class\n", + "- **Negative SHAP**: Feature pushed prediction away from the target class\n", + "- **Magnitude**: Strength of the contribution" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Load the deforestation model\n", + "model, device = _load_model('deforestation')\n", + "print(f\"Model: {model.__class__.__name__}\")\n", + "print(f\"Input channels: {model.n_channels}\")\n", + "print(f\"Output classes: {model.n_classes}\")\n", + "print(f\"Device: {device}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Create SHAP Explainer" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Initialize the explainer with background data\n", + "background = torch.zeros(1, model.n_channels, 64, 64).to(device)\n", + "explainer = SHAPExplainer(model, background_data=background, device=device)\n", + "print(\"SHAP Explainer initialized\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. Generate Explanation for Sample Image" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create a synthetic forest-like image for demonstration\n", + "np.random.seed(42)\n", + "\n", + "# Simulate Sentinel-2 bands: Red, Green, Blue, NIR\n", + "# Forest typically has high NIR and low Red\n", + "h, w = 256, 256\n", + "red = np.random.normal(0.2, 0.1, (h, w)).clip(0, 1) # Low red reflectance\n", + "green = np.random.normal(0.3, 0.1, (h, w)).clip(0, 1)\n", + "blue = np.random.normal(0.25, 0.1, (h, w)).clip(0, 1)\n", + "nir = np.random.normal(0.7, 0.15, (h, w)).clip(0, 1) # High NIR for vegetation\n", + "\n", + "sample_image = np.stack([red, green, blue, nir], axis=0).astype(np.float32)\n", + "sample_tensor = torch.FloatTensor(sample_image).unsqueeze(0).to(device)\n", + "\n", + "print(f\"Sample image shape: {sample_image.shape}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Generate SHAP explanation\n", + "explanation = explainer.explain(sample_tensor, target_class=1) # Class 1 = Forest\n", + "\n", + "print(\"\\n=== Explanation Results ===\")\n", + "print(f\"Predicted class: {explanation['prediction']}\")\n", + "print(f\"Target class: {explanation['target_class']}\")\n", + "print(f\"Confidence: {explanation['confidence']:.4f}\")\n", + "print(f\"Explainer type: {explanation['explainer_type']}\")\n", + "print(f\"\\nBand contributions:\")\n", + "for band, importance in explanation['band_contributions'].items():\n", + " print(f\" {band}: {importance:.4f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Visualize Band Contributions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Plot band importance\n", + "band_names = ['Red (B04)', 'Green (B03)', 'Blue (B02)', 'NIR (B08)']\n", + "contributions = explanation['band_contributions']\n", + "importances = [contributions[f'band_{i}'] for i in range(len(band_names))]\n", + "\n", + "fig, ax = plt.subplots(figsize=(10, 6))\n", + "colors = ['#e74c3c', '#27ae60', '#3498db', '#9b59b6']\n", + "bars = ax.bar(band_names, importances, color=colors)\n", + "ax.set_ylabel('Relative Importance')\n", + "ax.set_title('Band Contributions to Forest Classification')\n", + "ax.set_ylim(0, max(importances) * 1.2)\n", + "\n", + "# Add value labels\n", + "for bar, imp in zip(bars, importances):\n", + " ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,\n", + " f'{imp:.3f}', ha='center', va='bottom', fontsize=10)\n", + "\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5. Spatial Importance Heatmap" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Visualize spatial importance\n", + "spatial_importance = explanation['spatial_importance']\n", + "\n", + "fig, axes = plt.subplots(1, 3, figsize=(15, 5))\n", + "\n", + "# Original RGB composite\n", + "rgb = np.stack([sample_image[0], sample_image[1], sample_image[2]], axis=-1)\n", + "rgb = (rgb - rgb.min()) / (rgb.max() - rgb.min() + 1e-8)\n", + "axes[0].imshow(rgb)\n", + "axes[0].set_title('RGB Composite')\n", + "axes[0].axis('off')\n", + "\n", + "# SHAP importance heatmap\n", + "im = axes[1].imshow(spatial_importance, cmap='hot')\n", + "axes[1].set_title('SHAP Importance Heatmap')\n", + "axes[1].axis('off')\n", + "plt.colorbar(im, ax=axes[1], fraction=0.046)\n", + "\n", + "# Overlay\n", + "axes[2].imshow(rgb)\n", + "axes[2].imshow(spatial_importance, cmap='hot', alpha=0.5)\n", + "axes[2].set_title('RGB + SHAP Overlay')\n", + "axes[2].axis('off')\n", + "\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 6. Compare Explanations Across Analysis Types" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Compare band importance across different analysis types\n", + "analysis_types = ['deforestation', 'ice_melting', 'flooding']\n", + "all_contributions = {}\n", + "\n", + "for atype in analysis_types:\n", + " try:\n", + " model, device = _load_model(atype)\n", + " explainer = SHAPExplainer(model, device=device)\n", + " \n", + " # Create appropriate test tensor\n", + " test_tensor = torch.randn(1, model.n_channels, 128, 128).to(device)\n", + " result = explainer.explain(test_tensor)\n", + " all_contributions[atype] = result['band_contributions']\n", + " print(f\"{atype}: {len(result['band_contributions'])} bands analyzed\")\n", + " except Exception as e:\n", + " print(f\"{atype}: Failed - {e}\")\n", + "\n", + "print(\"\\nComparison complete!\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 7. Using the High-Level API" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# For real usage with saved images:\n", + "# result = explain_prediction(\n", + "# model_path='models/unet_deforestation.pth',\n", + "# image_path='data/test/amazon_tile.tif',\n", + "# analysis_type='deforestation',\n", + "# save_heatmap=True\n", + "# )\n", + "# print(f\"Top bands: {result['top_bands']}\")\n", + "# print(f\"Heatmap saved to: {result['heatmap_path']}\")\n", + "\n", + "print(\"See explain_prediction() for file-based explanations\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Summary\n", + "\n", + "This notebook demonstrated:\n", + "1. **SHAPExplainer** - Core class for generating explanations\n", + "2. **Band contributions** - Which spectral bands drive predictions\n", + "3. **Spatial importance** - Which image regions matter most\n", + "4. **Visualization** - Heatmaps and bar charts for stakeholder communication\n", + "\n", + "For production use, call the `/api/explain` endpoint or use `explain_prediction()` directly." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.11.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/notebooks/07_bias_audit.ipynb b/notebooks/07_bias_audit.ipynb new file mode 100644 index 0000000..07f7574 --- /dev/null +++ b/notebooks/07_bias_audit.ipynb @@ -0,0 +1,374 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# ClimateVision Regional Bias Audit\n", + "\n", + "This notebook demonstrates how to evaluate model fairness across geographic regions.\n", + "Ensuring equitable predictions is critical for NGOs operating in different parts of the world.\n", + "\n", + "**Author:** Linda Oraegbunam (@obielin) \n", + "**Module:** `src/climatevision/governance/bias_audit.py`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "sys.path.insert(0, '..')\n", + "\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "from pathlib import Path\n", + "\n", + "from climatevision.governance import (\n", + " run_bias_audit,\n", + " BiasAuditor,\n", + " BiasReport,\n", + " check_fairness_gate,\n", + " SUPPORTED_REGIONS,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Understanding Regional Bias\n", + "\n", + "Climate models trained primarily on Amazon data may underperform on Congo Basin imagery due to:\n", + "- Different forest types and canopy structures\n", + "- Varying cloud patterns and seasonal effects\n", + "- Different satellite viewing angles and atmospheric conditions\n", + "\n", + "This audit ensures NGOs in all regions receive equally reliable predictions." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# View supported regions\n", + "print(\"Supported Regions for Bias Audit:\")\n", + "print(\"=\" * 50)\n", + "for key, info in SUPPORTED_REGIONS.items():\n", + " print(f\"\\n{info['name']} ({key})\")\n", + " print(f\" Bounding Box: {info['bbox']}\")\n", + " print(f\" Description: {info['description']}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Creating a Bias Auditor" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create auditor with 85% fairness threshold\n", + "auditor = BiasAuditor(model=None, threshold=0.85)\n", + "\n", + "# Simulate regional prediction data\n", + "# In production, this would be real model outputs on test sets\n", + "np.random.seed(42)\n", + "\n", + "regions_data = {\n", + " 'amazon': {'accuracy': 0.92, 'forest_ratio': 0.70},\n", + " 'congo': {'accuracy': 0.85, 'forest_ratio': 0.65},\n", + " 'southeast_asia': {'accuracy': 0.88, 'forest_ratio': 0.55},\n", + "}\n", + "\n", + "for region, params in regions_data.items():\n", + " n_samples = 1000\n", + " \n", + " # Ground truth based on regional forest coverage\n", + " ground_truth = (np.random.random(n_samples) < params['forest_ratio']).astype(int)\n", + " \n", + " # Predictions based on regional accuracy\n", + " correct = np.random.random(n_samples) < params['accuracy']\n", + " predictions = np.where(correct, ground_truth, 1 - ground_truth)\n", + " \n", + " auditor.add_region_data(region, predictions, ground_truth)\n", + " print(f\"Added {n_samples} samples for {region}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. Computing Fairness Metrics" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Run full bias audit\n", + "report = auditor.run_audit(\n", + " metric='equalized_odds',\n", + " model_path='models/demo_model.pth',\n", + " model_version='v1.0-demo',\n", + " analysis_type='deforestation',\n", + ")\n", + "\n", + "print(f\"Fairness Score: {report.fairness_score:.4f}\")\n", + "print(f\"Threshold: {report.threshold}\")\n", + "print(f\"Passed: {'✅' if report.passed else '❌'}\")\n", + "print(f\"\\nDisparity Regions: {report.disparity_regions or 'None'}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# View per-region metrics\n", + "print(\"Per-Region Metrics:\")\n", + "print(\"=\" * 60)\n", + "\n", + "for metrics in report.region_metrics:\n", + " print(f\"\\n{metrics.region_name} ({metrics.region}):\")\n", + " print(f\" Samples: {metrics.n_samples}\")\n", + " print(f\" IoU: {metrics.iou:.4f}\")\n", + " print(f\" F1: {metrics.f1:.4f}\")\n", + " print(f\" Precision: {metrics.precision:.4f}\")\n", + " print(f\" Recall: {metrics.recall:.4f}\")\n", + " print(f\" TPR: {metrics.true_positive_rate:.4f}\")\n", + " print(f\" FPR: {metrics.false_positive_rate:.4f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Visualizing Regional Disparities" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Prepare data for visualization\n", + "regions = [m.region_name for m in report.region_metrics]\n", + "ious = [m.iou for m in report.region_metrics]\n", + "f1s = [m.f1 for m in report.region_metrics]\n", + "tprs = [m.true_positive_rate for m in report.region_metrics]\n", + "\n", + "x = np.arange(len(regions))\n", + "width = 0.25\n", + "\n", + "fig, ax = plt.subplots(figsize=(12, 6))\n", + "\n", + "bars1 = ax.bar(x - width, ious, width, label='IoU', color='#3498db')\n", + "bars2 = ax.bar(x, f1s, width, label='F1 Score', color='#2ecc71')\n", + "bars3 = ax.bar(x + width, tprs, width, label='True Positive Rate', color='#e74c3c')\n", + "\n", + "ax.set_ylabel('Score')\n", + "ax.set_title('Model Performance by Region')\n", + "ax.set_xticks(x)\n", + "ax.set_xticklabels(regions)\n", + "ax.legend()\n", + "ax.set_ylim(0, 1.1)\n", + "ax.axhline(y=0.85, color='gray', linestyle='--', label='Threshold')\n", + "\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Radar chart for multi-metric comparison\n", + "from math import pi\n", + "\n", + "categories = ['IoU', 'F1', 'Precision', 'Recall', 'TPR']\n", + "N = len(categories)\n", + "\n", + "angles = [n / float(N) * 2 * pi for n in range(N)]\n", + "angles += angles[:1]\n", + "\n", + "fig, ax = plt.subplots(figsize=(8, 8), subplot_kw=dict(polar=True))\n", + "\n", + "colors = ['#3498db', '#2ecc71', '#e74c3c']\n", + "for i, metrics in enumerate(report.region_metrics):\n", + " values = [metrics.iou, metrics.f1, metrics.precision, metrics.recall, metrics.true_positive_rate]\n", + " values += values[:1]\n", + " ax.plot(angles, values, 'o-', linewidth=2, label=metrics.region_name, color=colors[i % len(colors)])\n", + " ax.fill(angles, values, alpha=0.25, color=colors[i % len(colors)])\n", + "\n", + "ax.set_xticks(angles[:-1])\n", + "ax.set_xticklabels(categories)\n", + "ax.set_ylim(0, 1)\n", + "ax.legend(loc='upper right', bbox_to_anchor=(1.3, 1.0))\n", + "ax.set_title('Regional Performance Comparison', y=1.08)\n", + "\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5. Comparing Fairness Metrics" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Compare different fairness metrics\n", + "metrics_to_test = ['demographic_parity', 'equalized_odds', 'predictive_parity']\n", + "results = {}\n", + "\n", + "for metric in metrics_to_test:\n", + " report = auditor.run_audit(metric=metric)\n", + " results[metric] = {\n", + " 'score': report.fairness_score,\n", + " 'passed': report.passed,\n", + " 'disparity_regions': report.disparity_regions,\n", + " }\n", + "\n", + "print(\"Fairness Metrics Comparison:\")\n", + "print(\"=\" * 50)\n", + "for metric, result in results.items():\n", + " status = '✅' if result['passed'] else '❌'\n", + " print(f\"\\n{metric}:\")\n", + " print(f\" Score: {result['score']:.4f} {status}\")\n", + " if result['disparity_regions']:\n", + " print(f\" Disparity in: {', '.join(result['disparity_regions'])}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 6. Using the High-Level API" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# For real usage with trained models:\n", + "# result = run_bias_audit(\n", + "# model_path='models/unet_deforestation.pth',\n", + "# regions=['amazon', 'congo', 'southeast_asia'],\n", + "# metric='equalized_odds',\n", + "# threshold=0.85,\n", + "# )\n", + "# \n", + "# print(f\"Score: {result['score']}\")\n", + "# print(f\"Passed: {result['passed']}\")\n", + "# print(f\"Report: {result['report_path']}\")\n", + "\n", + "print(\"See run_bias_audit() for production usage\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 7. CI/CD Integration" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# CI gate function for automated checks\n", + "# This would be called in GitHub Actions or similar\n", + "\n", + "# passed = check_fairness_gate(\n", + "# model_path='models/best_model.pth',\n", + "# regions=['amazon', 'congo', 'southeast_asia'],\n", + "# threshold=0.85,\n", + "# )\n", + "# \n", + "# if not passed:\n", + "# sys.exit(1) # Fail the CI build\n", + "\n", + "print(\"Use check_fairness_gate() in CI/CD pipelines\")\n", + "print(\"Command: python scripts/audit_model.py --model models/best.pth --ci-gate\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 8. Recommendations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Get recommendations from the audit\n", + "print(\"Recommendations:\")\n", + "print(\"=\" * 50)\n", + "for rec in report.recommendations:\n", + " print(f\"\\n• {rec}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Summary\n", + "\n", + "This notebook demonstrated:\n", + "\n", + "1. **BiasAuditor** - Core class for fairness evaluation\n", + "2. **Fairness Metrics** - Demographic parity, equalized odds, predictive parity\n", + "3. **Regional Analysis** - Per-region IoU, F1, precision, recall\n", + "4. **Visualization** - Bar charts and radar plots for stakeholder reports\n", + "5. **CI/CD Integration** - `check_fairness_gate()` for automated checks\n", + "\n", + "For production use:\n", + "- Run `python scripts/audit_model.py --model --regions amazon,congo`\n", + "- Add `--ci-gate` flag to fail builds with poor fairness scores" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.11.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/requirements.txt b/requirements.txt index 507a13a..3387ecf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -46,6 +46,9 @@ python-multipart>=0.0.5 mlflow>=2.1.0 optuna>=3.1.0 +# Explainability & Governance +shap>=0.42.0 + # Testing and Development pytest>=7.0.0 pytest-cov>=3.0.0 diff --git a/scripts/audit_model.py b/scripts/audit_model.py new file mode 100755 index 0000000..11764e9 --- /dev/null +++ b/scripts/audit_model.py @@ -0,0 +1,179 @@ +#!/usr/bin/env python +""" +Model Governance Audit CLI + +Run fairness and bias audits on ClimateVision models. + +Usage: + python scripts/audit_model.py --model models/best_model.pth --regions amazon,congo + python scripts/audit_model.py --model models/best_model.pth --metric demographic_parity + python scripts/audit_model.py --model models/best_model.pth --ci-gate +""" + +import argparse +import json +import sys +from pathlib import Path + +# Add src to path +sys.path.insert(0, str(Path(__file__).parent.parent / "src")) + +from climatevision.governance import ( + run_bias_audit, + check_fairness_gate, + SUPPORTED_REGIONS, +) + + +def main(): + parser = argparse.ArgumentParser( + description="Run bias and fairness audits on ClimateVision models", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Run full audit + python scripts/audit_model.py --model models/best_model.pth + + # Audit specific regions + python scripts/audit_model.py --model models/best_model.pth --regions amazon,congo + + # Use different fairness metric + python scripts/audit_model.py --model models/best_model.pth --metric demographic_parity + + # CI gate mode (exit 1 if fails) + python scripts/audit_model.py --model models/best_model.pth --ci-gate --threshold 0.85 + """, + ) + + parser.add_argument( + "--model", + type=str, + required=True, + help="Path to model checkpoint", + ) + parser.add_argument( + "--regions", + type=str, + default="amazon,congo,southeast_asia", + help="Comma-separated list of regions to audit", + ) + parser.add_argument( + "--metric", + type=str, + choices=["demographic_parity", "equalized_odds", "predictive_parity"], + default="equalized_odds", + help="Fairness metric to use", + ) + parser.add_argument( + "--threshold", + type=float, + default=0.85, + help="Minimum fairness score to pass", + ) + parser.add_argument( + "--analysis-type", + type=str, + default="deforestation", + help="Analysis type (deforestation, ice_melting, flooding)", + ) + parser.add_argument( + "--output", + type=str, + help="Output file for JSON report", + ) + parser.add_argument( + "--ci-gate", + action="store_true", + help="Run in CI gate mode (exit 1 if fails)", + ) + parser.add_argument( + "--list-regions", + action="store_true", + help="List supported regions and exit", + ) + + args = parser.parse_args() + + # List regions mode + if args.list_regions: + print("Supported regions:") + for key, info in SUPPORTED_REGIONS.items(): + print(f" {key}: {info['name']}") + print(f" bbox: {info['bbox']}") + print(f" {info['description']}") + print() + return 0 + + # Parse regions + regions = [r.strip() for r in args.regions.split(",")] + + print(f"Running bias audit on: {args.model}") + print(f"Regions: {regions}") + print(f"Metric: {args.metric}") + print(f"Threshold: {args.threshold}") + print() + + # CI gate mode + if args.ci_gate: + passed = check_fairness_gate( + model_path=args.model, + regions=regions, + threshold=args.threshold, + ) + if passed: + print("\n✅ FAIRNESS GATE PASSED") + return 0 + else: + print("\n❌ FAIRNESS GATE FAILED") + return 1 + + # Full audit mode + result = run_bias_audit( + model_path=args.model, + regions=regions, + metric=args.metric, + threshold=args.threshold, + analysis_type=args.analysis_type, + ) + + # Print results + print("=" * 60) + print("BIAS AUDIT RESULTS") + print("=" * 60) + print(f"Fairness Score: {result['score']:.4f}") + print(f"Threshold: {args.threshold}") + print(f"Status: {'✅ PASSED' if result['passed'] else '❌ FAILED'}") + print() + + if result["disparity_regions"]: + print(f"Disparity detected in: {', '.join(result['disparity_regions'])}") + print() + + print("Per-Region Metrics:") + print("-" * 60) + for metrics in result["region_metrics"]: + print(f" {metrics['region_name']} ({metrics['region']}):") + print(f" IoU: {metrics['iou']:.4f}") + print(f" F1: {metrics['f1']:.4f}") + print(f" Precision: {metrics['precision']:.4f}") + print(f" Recall: {metrics['recall']:.4f}") + print() + + print("Recommendations:") + for rec in result["recommendations"]: + print(f" • {rec}") + + print() + print(f"Report saved to: {result['report_path']}") + + # Save to custom output if specified + if args.output: + output_path = Path(args.output) + output_path.write_text(json.dumps(result, indent=2), encoding="utf-8") + print(f"JSON output saved to: {args.output}") + + return 0 if result["passed"] else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/scripts/generate_datasheet.py b/scripts/generate_datasheet.py new file mode 100644 index 0000000..e2004fd --- /dev/null +++ b/scripts/generate_datasheet.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python +""" +Generate a Datasheet for a ClimateVision training dataset. + +Usage: + python scripts/generate_datasheet.py \\ + --manifest data/manifests/sentinel2-deforestation.yaml \\ + --output-dir outputs/datasheets/ + +Runs inside the release CI pipeline so every dataset version published +ships with a Gebru-style datasheet alongside its model cards. +""" + +from __future__ import annotations + +import argparse +import logging +import sys +from pathlib import Path + +from climatevision.governance.datasheet import generate + +logger = logging.getLogger("generate_datasheet") + + +def parse_args(argv: list[str] | None = None) -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--manifest", type=Path, required=True, help="Dataset manifest (yaml/json)") + parser.add_argument("--output-dir", type=Path, default=None, help="Where to write the datasheet") + parser.add_argument("--name", default=None, help="Override dataset name") + parser.add_argument("--version", default=None, help="Override dataset version") + parser.add_argument("-v", "--verbose", action="store_true") + return parser.parse_args(argv) + + +def main(argv: list[str] | None = None) -> int: + args = parse_args(argv) + logging.basicConfig( + level=logging.DEBUG if args.verbose else logging.INFO, + format="%(asctime)s %(levelname)s %(name)s %(message)s", + ) + + paths = generate( + manifest=args.manifest, + output_dir=args.output_dir, + name=args.name, + version=args.version, + ) + for label, path in paths.items(): + print(f"{label}: {path}") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/scripts/generate_model_card.py b/scripts/generate_model_card.py new file mode 100644 index 0000000..597e7d2 --- /dev/null +++ b/scripts/generate_model_card.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python +""" +Generate a Model Card for a ClimateVision release. + +Usage: + python scripts/generate_model_card.py \\ + --config config.yaml \\ + --metrics outputs/eval/metrics.json \\ + --fairness outputs/governance/fairness.json \\ + --output-dir outputs/model_cards/ + +The script is intended to run inside the release CI pipeline so that +every model version published has a card committed alongside it. +""" + +from __future__ import annotations + +import argparse +import logging +import sys +from pathlib import Path + +from climatevision.governance.model_card import generate + +logger = logging.getLogger("generate_model_card") + + +def parse_args(argv: list[str] | None = None) -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--config", type=Path, required=True, help="Training config (yaml/json)") + parser.add_argument("--metrics", type=Path, required=True, help="Evaluation metrics JSON") + parser.add_argument("--fairness", type=Path, default=None, help="Fairness report JSON") + parser.add_argument("--output-dir", type=Path, default=None, help="Where to write the card") + parser.add_argument("--name", default=None, help="Override model name") + parser.add_argument("--version", default=None, help="Override model version") + parser.add_argument("-v", "--verbose", action="store_true") + return parser.parse_args(argv) + + +def main(argv: list[str] | None = None) -> int: + args = parse_args(argv) + logging.basicConfig( + level=logging.DEBUG if args.verbose else logging.INFO, + format="%(asctime)s %(levelname)s %(name)s %(message)s", + ) + + paths = generate( + config=args.config, + metrics=args.metrics, + fairness_report=args.fairness, + output_dir=args.output_dir, + name=args.name, + version=args.version, + ) + for label, path in paths.items(): + print(f"{label}: {path}") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/scripts/governance_ci_gate.py b/scripts/governance_ci_gate.py new file mode 100644 index 0000000..f49ce79 --- /dev/null +++ b/scripts/governance_ci_gate.py @@ -0,0 +1,223 @@ +#!/usr/bin/env python +""" +Governance CI gate for ClimateVision model releases. + +Reads an evaluation metrics JSON, an optional fairness report, and an +optional security scan report, and decides whether the release is +allowed to proceed. + +Exit codes: + 0 all gates passed + 1 one or more gates failed (CI must fail the build) + 2 bad invocation / missing inputs + +Threshold defaults can be overridden via --thresholds (JSON file). The +script prints a Markdown summary of which gates passed and which +failed; CI systems can capture this and post it back to the PR. +""" + +from __future__ import annotations + +import argparse +import json +import logging +import sys +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Optional + +logger = logging.getLogger("governance_ci_gate") + +DEFAULT_THRESHOLDS: dict[str, Any] = { + "metrics": { + "iou": 0.70, + "f1": 0.75, + }, + "fairness": { + "min_score": 0.80, + "max_disparity_regions": 0, + }, + "security": { + "max_high": 0, + "max_critical": 0, + }, +} + +EXIT_OK = 0 +EXIT_FAIL = 1 +EXIT_BAD_INPUT = 2 + + +@dataclass +class GateResult: + name: str + passed: bool + detail: str + + +def _load_json(path: Path) -> dict: + if not path.exists(): + raise FileNotFoundError(path) + return json.loads(path.read_text()) + + +def evaluate_metrics_gate(metrics: dict, thresholds: dict) -> list[GateResult]: + results: list[GateResult] = [] + for metric, floor in thresholds.items(): + value = metrics.get(metric) + if value is None: + results.append( + GateResult( + name=f"metrics.{metric}", + passed=False, + detail=f"missing metric '{metric}'", + ) + ) + continue + passed = value >= floor + results.append( + GateResult( + name=f"metrics.{metric}", + passed=passed, + detail=f"value={value:.3f} threshold>={floor:.3f}", + ) + ) + return results + + +def evaluate_fairness_gate(report: dict, thresholds: dict) -> list[GateResult]: + results: list[GateResult] = [] + score = report.get("score") + if score is not None: + passed = score >= thresholds["min_score"] + results.append( + GateResult( + name="fairness.score", + passed=passed, + detail=f"score={score:.3f} threshold>={thresholds['min_score']:.3f}", + ) + ) + else: + results.append( + GateResult( + name="fairness.score", + passed=False, + detail="missing score", + ) + ) + + disparity = report.get("disparity_regions") or [] + passed = len(disparity) <= thresholds["max_disparity_regions"] + results.append( + GateResult( + name="fairness.disparity_regions", + passed=passed, + detail=f"count={len(disparity)} threshold<={thresholds['max_disparity_regions']}", + ) + ) + return results + + +def evaluate_security_gate(report: dict, thresholds: dict) -> list[GateResult]: + findings = report.get("findings", []) + high = sum(1 for f in findings if f.get("severity") == "high") + critical = sum(1 for f in findings if f.get("severity") == "critical") + + return [ + GateResult( + name="security.high", + passed=high <= thresholds["max_high"], + detail=f"high={high} threshold<={thresholds['max_high']}", + ), + GateResult( + name="security.critical", + passed=critical <= thresholds["max_critical"], + detail=f"critical={critical} threshold<={thresholds['max_critical']}", + ), + ] + + +def render_summary(results: list[GateResult]) -> str: + rows = ["| Gate | Status | Detail |", "| --- | --- | --- |"] + for r in results: + status = "PASS" if r.passed else "FAIL" + rows.append(f"| {r.name} | {status} | {r.detail} |") + overall = "PASS" if all(r.passed for r in results) else "FAIL" + return f"## Governance CI Gate — {overall}\n\n" + "\n".join(rows) + "\n" + + +def run_gate( + metrics_path: Path, + fairness_path: Optional[Path], + security_path: Optional[Path], + thresholds: dict, +) -> tuple[bool, list[GateResult]]: + results: list[GateResult] = [] + + metrics = _load_json(metrics_path) + results.extend(evaluate_metrics_gate(metrics, thresholds["metrics"])) + + if fairness_path is not None: + fairness = _load_json(fairness_path) + results.extend(evaluate_fairness_gate(fairness, thresholds["fairness"])) + + if security_path is not None: + security = _load_json(security_path) + results.extend(evaluate_security_gate(security, thresholds["security"])) + + return all(r.passed for r in results), results + + +def parse_args(argv: Optional[list[str]] = None) -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--metrics", type=Path, required=True, help="Evaluation metrics JSON") + parser.add_argument("--fairness", type=Path, default=None, help="Fairness report JSON") + parser.add_argument("--security", type=Path, default=None, help="Security scan JSON") + parser.add_argument("--thresholds", type=Path, default=None, help="Override thresholds JSON") + parser.add_argument("--summary-out", type=Path, default=None, help="Write Markdown summary") + parser.add_argument("-v", "--verbose", action="store_true") + return parser.parse_args(argv) + + +def _merge_thresholds(custom: Optional[dict]) -> dict: + if not custom: + return {k: dict(v) for k, v in DEFAULT_THRESHOLDS.items()} + merged: dict[str, Any] = {} + for section, defaults in DEFAULT_THRESHOLDS.items(): + merged[section] = {**defaults, **(custom.get(section) or {})} + return merged + + +def main(argv: Optional[list[str]] = None) -> int: + args = parse_args(argv) + logging.basicConfig( + level=logging.DEBUG if args.verbose else logging.INFO, + format="%(asctime)s %(levelname)s %(name)s %(message)s", + ) + + thresholds = _merge_thresholds( + json.loads(args.thresholds.read_text()) if args.thresholds else None + ) + + try: + passed, results = run_gate( + metrics_path=args.metrics, + fairness_path=args.fairness, + security_path=args.security, + thresholds=thresholds, + ) + except FileNotFoundError as exc: + logger.error("input file missing: %s", exc) + return EXIT_BAD_INPUT + + summary = render_summary(results) + print(summary) + if args.summary_out: + args.summary_out.parent.mkdir(parents=True, exist_ok=True) + args.summary_out.write_text(summary) + + return EXIT_OK if passed else EXIT_FAIL + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/scripts/security_scan.py b/scripts/security_scan.py new file mode 100644 index 0000000..3ecefd6 --- /dev/null +++ b/scripts/security_scan.py @@ -0,0 +1,499 @@ +#!/usr/bin/env python +""" +Security Scanner for ClimateVision API. + +Scans API endpoints for OWASP-style vulnerabilities and generates a security report. + +Usage: + python scripts/security_scan.py --target http://localhost:8000 + python scripts/security_scan.py --target http://localhost:8000 --output security_report.json +""" + +import argparse +import json +import sys +import time +from dataclasses import dataclass, field, asdict +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Optional +from urllib.parse import urljoin + +# Add src to path +sys.path.insert(0, str(Path(__file__).parent.parent / "src")) + +try: + import requests +except ImportError: + print("Error: requests library required. Run: pip install requests") + sys.exit(1) + + +@dataclass +class Finding: + """Security finding from scan.""" + + endpoint: str + method: str + severity: str # critical, high, medium, low, info + category: str + title: str + description: str + remediation: str + evidence: Optional[str] = None + + +@dataclass +class SecurityReport: + """Complete security scan report.""" + + target: str + scan_timestamp: str + scan_duration_seconds: float + total_endpoints: int + findings: list[Finding] = field(default_factory=list) + summary: dict[str, int] = field(default_factory=dict) + + def to_dict(self) -> dict[str, Any]: + return { + "target": self.target, + "scan_timestamp": self.scan_timestamp, + "scan_duration_seconds": self.scan_duration_seconds, + "total_endpoints": self.total_endpoints, + "findings": [asdict(f) for f in self.findings], + "summary": self.summary, + } + + +class SecurityScanner: + """OWASP-style security scanner for ClimateVision API.""" + + def __init__(self, target: str, timeout: int = 10): + self.target = target.rstrip("/") + self.timeout = timeout + self.findings: list[Finding] = [] + self.session = requests.Session() + + def scan(self) -> SecurityReport: + """Run full security scan.""" + start_time = time.time() + + endpoints = self._discover_endpoints() + + # Run all checks + self._check_security_headers() + self._check_rate_limiting() + self._check_input_validation() + self._check_file_upload() + self._check_injection() + self._check_auth() + self._check_error_handling() + + duration = time.time() - start_time + + # Build summary + summary = {"critical": 0, "high": 0, "medium": 0, "low": 0, "info": 0} + for finding in self.findings: + summary[finding.severity] = summary.get(finding.severity, 0) + 1 + + return SecurityReport( + target=self.target, + scan_timestamp=datetime.now(timezone.utc).isoformat(), + scan_duration_seconds=round(duration, 2), + total_endpoints=len(endpoints), + findings=self.findings, + summary=summary, + ) + + def _discover_endpoints(self) -> list[str]: + """Discover API endpoints from OpenAPI spec.""" + endpoints = [] + try: + resp = self.session.get( + urljoin(self.target, "/openapi.json"), + timeout=self.timeout, + ) + if resp.status_code == 200: + spec = resp.json() + paths = spec.get("paths", {}) + endpoints = list(paths.keys()) + except Exception: + # Fallback to known endpoints + endpoints = [ + "/api/health", + "/api/predict", + "/api/predict/upload", + "/api/runs", + "/api/organizations", + "/api/explain", + ] + return endpoints + + def _check_security_headers(self) -> None: + """Check for security headers.""" + try: + resp = self.session.get( + urljoin(self.target, "/api/health"), + timeout=self.timeout, + ) + + required_headers = { + "X-Content-Type-Options": "nosniff", + "X-Frame-Options": "DENY", + "X-XSS-Protection": "1; mode=block", + } + + for header, expected in required_headers.items(): + if header not in resp.headers: + self.findings.append(Finding( + endpoint="/api/health", + method="GET", + severity="medium", + category="Security Headers", + title=f"Missing {header} header", + description=f"The {header} security header is not set.", + remediation=f"Add '{header}: {expected}' to all responses.", + )) + + # Check for server disclosure + if "Server" in resp.headers: + server = resp.headers["Server"] + if any(v in server.lower() for v in ["version", "uvicorn", "python"]): + self.findings.append(Finding( + endpoint="/api/health", + method="GET", + severity="low", + category="Information Disclosure", + title="Server version disclosed", + description=f"Server header reveals: {server}", + remediation="Remove or obfuscate the Server header.", + evidence=server, + )) + + except Exception as e: + self.findings.append(Finding( + endpoint="/api/health", + method="GET", + severity="info", + category="Connectivity", + title="Could not check security headers", + description=str(e), + remediation="Ensure API is running.", + )) + + def _check_rate_limiting(self) -> None: + """Check rate limiting implementation.""" + try: + # Send multiple rapid requests + for i in range(5): + resp = self.session.get( + urljoin(self.target, "/api/health"), + timeout=self.timeout, + ) + + # Check for rate limit headers + if "X-RateLimit-Remaining" not in resp.headers: + self.findings.append(Finding( + endpoint="/api/health", + method="GET", + severity="medium", + category="Rate Limiting", + title="No rate limiting headers detected", + description="Rate limiting may not be implemented or is not exposing standard headers.", + remediation="Implement rate limiting with X-RateLimit-* headers.", + )) + + except Exception: + pass + + def _check_input_validation(self) -> None: + """Check input validation on predict endpoint.""" + test_cases = [ + { + "name": "Invalid bbox - out of range", + "payload": {"bbox": [200, 10, 30, 40]}, + "expected_status": 422, + }, + { + "name": "Invalid bbox - wrong order", + "payload": {"bbox": [10, 50, 5, 40]}, + "expected_status": 422, + }, + { + "name": "Invalid date range", + "payload": {"start_date": "2025-01-01", "end_date": "2024-01-01"}, + "expected_status": 422, + }, + { + "name": "SQL injection in kind", + "payload": {"kind": "'; DROP TABLE runs; --"}, + "expected_status": [200, 422], # Should either sanitize or reject + }, + ] + + for test in test_cases: + try: + resp = self.session.post( + urljoin(self.target, "/api/predict"), + json=test["payload"], + timeout=self.timeout, + ) + + expected = test["expected_status"] + if isinstance(expected, list): + passed = resp.status_code in expected + else: + passed = resp.status_code == expected + + if not passed: + self.findings.append(Finding( + endpoint="/api/predict", + method="POST", + severity="high" if "injection" in test["name"].lower() else "medium", + category="Input Validation", + title=f"Failed: {test['name']}", + description=f"Expected status {expected}, got {resp.status_code}", + remediation="Add proper input validation.", + evidence=json.dumps(test["payload"]), + )) + + except Exception: + pass + + def _check_file_upload(self) -> None: + """Check file upload security.""" + test_cases = [ + { + "name": "Path traversal in filename", + "filename": "../../../etc/passwd", + "content": b"test", + "severity": "critical", + }, + { + "name": "Executable upload", + "filename": "malware.exe", + "content": b"MZ\x90\x00", + "severity": "high", + }, + { + "name": "Double extension", + "filename": "image.tif.php", + "content": b"", + "severity": "high", + }, + ] + + for test in test_cases: + try: + files = {"file": (test["filename"], test["content"])} + resp = self.session.post( + urljoin(self.target, "/api/predict/upload"), + files=files, + timeout=self.timeout, + ) + + # Should be rejected (4xx) + if resp.status_code < 400: + self.findings.append(Finding( + endpoint="/api/predict/upload", + method="POST", + severity=test["severity"], + category="File Upload", + title=f"Allowed: {test['name']}", + description=f"Dangerous file upload was accepted (status {resp.status_code})", + remediation="Validate file types, extensions, and sanitize filenames.", + evidence=test["filename"], + )) + + except Exception: + pass + + def _check_injection(self) -> None: + """Check for injection vulnerabilities.""" + injection_payloads = [ + ("SQL", "' OR '1'='1"), + ("NoSQL", '{"$gt": ""}'), + ("Command", "; cat /etc/passwd"), + ("Template", "{{7*7}}"), + ("XSS", ""), + ] + + for injection_type, payload in injection_payloads: + try: + resp = self.session.post( + urljoin(self.target, "/api/predict"), + json={"kind": payload}, + timeout=self.timeout, + ) + + # Check if payload is reflected in response + if payload in resp.text: + self.findings.append(Finding( + endpoint="/api/predict", + method="POST", + severity="high", + category="Injection", + title=f"{injection_type} injection reflected", + description=f"Payload was reflected in response without sanitization.", + remediation=f"Sanitize all user inputs. Use parameterized queries.", + evidence=payload, + )) + + except Exception: + pass + + def _check_auth(self) -> None: + """Check authentication implementation.""" + # Test protected endpoints without auth + protected_endpoints = [ + "/api/organizations", + "/api/predict", + ] + + for endpoint in protected_endpoints: + try: + resp = self.session.get( + urljoin(self.target, endpoint), + timeout=self.timeout, + ) + + # If we can access without API key, note it + if resp.status_code == 200: + self.findings.append(Finding( + endpoint=endpoint, + method="GET", + severity="info", + category="Authentication", + title="Endpoint accessible without API key", + description="This endpoint does not require authentication.", + remediation="Consider requiring X-API-Key for sensitive endpoints.", + )) + + except Exception: + pass + + def _check_error_handling(self) -> None: + """Check error handling doesn't leak sensitive info.""" + try: + # Trigger an error + resp = self.session.get( + urljoin(self.target, "/api/runs/99999999"), + timeout=self.timeout, + ) + + if resp.status_code >= 400: + body = resp.text.lower() + + # Check for stack traces + if "traceback" in body or "file " in body: + self.findings.append(Finding( + endpoint="/api/runs/99999999", + method="GET", + severity="medium", + category="Information Disclosure", + title="Stack trace in error response", + description="Error responses contain stack traces.", + remediation="Use generic error messages in production.", + )) + + # Check for internal paths + if "/home/" in body or "/usr/" in body or "c:\\" in body.lower(): + self.findings.append(Finding( + endpoint="/api/runs/99999999", + method="GET", + severity="low", + category="Information Disclosure", + title="Internal paths in error response", + description="Error responses reveal internal file paths.", + remediation="Remove path information from error messages.", + )) + + except Exception: + pass + + +def main(): + parser = argparse.ArgumentParser( + description="Security scanner for ClimateVision API", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python scripts/security_scan.py --target http://localhost:8000 + python scripts/security_scan.py --target https://api.example.com --output report.json + """, + ) + + parser.add_argument( + "--target", + type=str, + required=True, + help="Target API URL (e.g., http://localhost:8000)", + ) + parser.add_argument( + "--output", + type=str, + help="Output file for JSON report", + ) + parser.add_argument( + "--timeout", + type=int, + default=10, + help="Request timeout in seconds", + ) + + args = parser.parse_args() + + print(f"Starting security scan of: {args.target}") + print("=" * 60) + + scanner = SecurityScanner(args.target, timeout=args.timeout) + report = scanner.scan() + + # Print results + print(f"\nScan completed in {report.scan_duration_seconds:.2f} seconds") + print(f"Endpoints scanned: {report.total_endpoints}") + print(f"\nFindings Summary:") + print(f" Critical: {report.summary.get('critical', 0)}") + print(f" High: {report.summary.get('high', 0)}") + print(f" Medium: {report.summary.get('medium', 0)}") + print(f" Low: {report.summary.get('low', 0)}") + print(f" Info: {report.summary.get('info', 0)}") + + if report.findings: + print(f"\nDetailed Findings:") + print("-" * 60) + for i, finding in enumerate(report.findings, 1): + severity_icon = { + "critical": "🔴", + "high": "🟠", + "medium": "🟡", + "low": "🔵", + "info": "⚪", + }.get(finding.severity, "⚪") + + print(f"\n{i}. {severity_icon} [{finding.severity.upper()}] {finding.title}") + print(f" Endpoint: {finding.method} {finding.endpoint}") + print(f" Category: {finding.category}") + print(f" Description: {finding.description}") + print(f" Remediation: {finding.remediation}") + if finding.evidence: + print(f" Evidence: {finding.evidence[:100]}") + + # Save report + output_path = args.output or "outputs/security_report.json" + Path(output_path).parent.mkdir(parents=True, exist_ok=True) + Path(output_path).write_text(json.dumps(report.to_dict(), indent=2), encoding="utf-8") + print(f"\nReport saved to: {output_path}") + + # Exit code based on critical/high findings + critical_high = report.summary.get("critical", 0) + report.summary.get("high", 0) + if critical_high > 0: + print(f"\n❌ SECURITY SCAN FAILED: {critical_high} critical/high findings") + return 1 + else: + print("\n✅ SECURITY SCAN PASSED: No critical/high findings") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/src/climatevision/api/admin.py b/src/climatevision/api/admin.py new file mode 100644 index 0000000..d0dbfa6 --- /dev/null +++ b/src/climatevision/api/admin.py @@ -0,0 +1,199 @@ +""" +Admin endpoints for ClimateVision operational reporting. + +Exposes two read-only endpoints intended for the operational dashboard +and on-call tooling: + +- ``GET /api/reports`` — data-quality KPIs for a configurable time window + (run count, error rate, mean confidence, alert count). +- ``GET /api/anomalies`` — list of flagged anomaly predictions, optionally + filtered by severity and time window. + +Both endpoints read from JSONL files written by the audit logger and the +anomaly detector. They never mutate state and never expose raw input +payloads — only summary fields safe for an operations dashboard. + +The router is wired into the FastAPI app via ``include_router(admin.router)`` +in ``api/main.py``. +""" + +from __future__ import annotations + +import json +import logging +from datetime import datetime, timedelta, timezone +from pathlib import Path +from typing import Iterable, Iterator, Optional + +from fastapi import APIRouter, HTTPException, Query +from pydantic import BaseModel, Field + +logger = logging.getLogger(__name__) + +_PROJECT_ROOT = Path(__file__).resolve().parents[3] +DEFAULT_AUDIT_LOG = _PROJECT_ROOT / "outputs" / "audit" / "predictions.jsonl" +DEFAULT_ANOMALY_LOG = _PROJECT_ROOT / "outputs" / "anomalies" / "history.jsonl" +DEFAULT_ALERT_LOG = _PROJECT_ROOT / "outputs" / "alerts" / "alerts.jsonl" + + +router = APIRouter(prefix="/api", tags=["admin"]) + + +class ReportSummary(BaseModel): + window_hours: int = Field(..., description="Time window in hours") + run_count: int = Field(..., description="Predictions logged in window") + error_rate: float = Field(..., description="Fraction of runs with non-OK status") + mean_confidence: Optional[float] = Field(None, description="Mean confidence over window") + positive_fraction_mean: Optional[float] = Field(None) + alert_count: int = Field(0, description="Alerts fired in window") + generated_at: str + + +class AnomalyRecord(BaseModel): + triggered_at: Optional[str] = None + severity: Optional[str] = None + method: Optional[str] = None + score: Optional[float] = None + reasons: list[str] = Field(default_factory=list) + summary: Optional[str] = None + + +class AnomalyList(BaseModel): + count: int + anomalies: list[AnomalyRecord] + + +def _read_jsonl(path: Path) -> Iterator[dict]: + if not path.exists(): + return iter(()) + def _it() -> Iterator[dict]: + with path.open() as fh: + for line in fh: + line = line.strip() + if not line: + continue + try: + yield json.loads(line) + except json.JSONDecodeError: + logger.warning("skipping malformed line in %s", path) + return _it() + + +def _parse_timestamp(value: Optional[str]) -> Optional[datetime]: + if not value: + return None + try: + ts = datetime.fromisoformat(value.replace("Z", "+00:00")) + except ValueError: + return None + if ts.tzinfo is None: + ts = ts.replace(tzinfo=timezone.utc) + return ts + + +def _within_window(ts: Optional[datetime], cutoff: datetime) -> bool: + return ts is not None and ts >= cutoff + + +def build_report_summary( + window_hours: int, + audit_log: Optional[Path] = None, + alert_log: Optional[Path] = None, + now: Optional[datetime] = None, +) -> ReportSummary: + if window_hours <= 0: + raise ValueError("window_hours must be positive") + now = now or datetime.now(timezone.utc) + cutoff = now - timedelta(hours=window_hours) + audit_log = audit_log or DEFAULT_AUDIT_LOG + alert_log = alert_log or DEFAULT_ALERT_LOG + + runs = [] + for row in _read_jsonl(audit_log): + ts = _parse_timestamp(row.get("timestamp")) + if _within_window(ts, cutoff): + runs.append(row) + + confidence_values = [ + r["output_summary"]["mean_confidence"] + for r in runs + if isinstance(r.get("output_summary"), dict) + and r["output_summary"].get("mean_confidence") is not None + ] + positive_values = [ + r["output_summary"]["positive_fraction"] + for r in runs + if isinstance(r.get("output_summary"), dict) + and r["output_summary"].get("positive_fraction") is not None + ] + error_count = sum(1 for r in runs if r.get("error")) + + alerts = [ + row + for row in _read_jsonl(alert_log) + if _within_window(_parse_timestamp(row.get("triggered_at")), cutoff) + ] + + return ReportSummary( + window_hours=window_hours, + run_count=len(runs), + error_rate=(error_count / len(runs)) if runs else 0.0, + mean_confidence=( + sum(confidence_values) / len(confidence_values) if confidence_values else None + ), + positive_fraction_mean=( + sum(positive_values) / len(positive_values) if positive_values else None + ), + alert_count=len(alerts), + generated_at=now.isoformat(), + ) + + +def list_anomalies( + severity: Optional[str] = None, + window_hours: Optional[int] = None, + alert_log: Optional[Path] = None, + now: Optional[datetime] = None, +) -> AnomalyList: + now = now or datetime.now(timezone.utc) + cutoff = now - timedelta(hours=window_hours) if window_hours else None + alert_log = alert_log or DEFAULT_ALERT_LOG + + out: list[AnomalyRecord] = [] + for row in _read_jsonl(alert_log): + if severity and row.get("severity") != severity: + continue + ts = _parse_timestamp(row.get("triggered_at")) + if cutoff is not None and not _within_window(ts, cutoff): + continue + out.append( + AnomalyRecord( + triggered_at=row.get("triggered_at"), + severity=row.get("severity"), + method=row.get("method"), + score=row.get("score"), + reasons=row.get("reasons") or [], + summary=row.get("summary"), + ) + ) + return AnomalyList(count=len(out), anomalies=out) + + +@router.get("/reports", response_model=ReportSummary) +def get_reports( + window_hours: int = Query(24, gt=0, le=24 * 30 * 6), +) -> ReportSummary: + """Data-quality KPIs over a configurable time window.""" + try: + return build_report_summary(window_hours=window_hours) + except ValueError as exc: + raise HTTPException(status_code=400, detail=str(exc)) + + +@router.get("/anomalies", response_model=AnomalyList) +def get_anomalies( + severity: Optional[str] = Query(None, pattern="^(low|medium|high|critical)$"), + window_hours: Optional[int] = Query(None, gt=0, le=24 * 30 * 6), +) -> AnomalyList: + """List flagged anomaly/alert records, optionally filtered.""" + return list_anomalies(severity=severity, window_hours=window_hours) diff --git a/src/climatevision/api/main.py b/src/climatevision/api/main.py index ac40911..731cbc2 100644 --- a/src/climatevision/api/main.py +++ b/src/climatevision/api/main.py @@ -43,6 +43,7 @@ mark_alert_delivered, ) from climatevision.inference import run_inference_from_file, run_inference_from_gee +from climatevision.governance import explain_prediction, SHAPExplainer logger = logging.getLogger(__name__) @@ -232,6 +233,29 @@ class CreateAlertRequest(BaseModel): details: Optional[str] = None +# Explainability models +class ExplainRequest(BaseModel): + run_id: Optional[int] = None + analysis_type: AnalysisType = Field(default="deforestation") + target_class: Optional[int] = None + + +class BandContribution(BaseModel): + band: str + importance: float + + +class ExplainResponse(BaseModel): + run_id: Optional[int] = None + analysis_type: str + target_class: int + prediction: int + confidence: float + top_bands: list[BandContribution] + heatmap_path: Optional[str] = None + explainer_type: str + + # ===== Helper Functions ===== def _load_template_result( @@ -376,6 +400,9 @@ def create_app() -> FastAPI: allow_headers=["*"], ) + from climatevision.api import admin as _admin + app.include_router(_admin.router) + # ===== Core Endpoints ===== @app.get("/") @@ -667,6 +694,119 @@ async def predict_upload( return {"run_id": run_id, "result": result_payload} + # ===== Explainability Endpoints ===== + + @app.post("/api/explain", response_model=ExplainResponse) + async def explain_run(body: ExplainRequest) -> dict[str, Any]: + """ + Generate SHAP-based explanation for a prediction. + + Returns band-level contributions showing which spectral bands + drove the model's classification decision. + """ + from climatevision.inference.pipeline import _load_model, _load_image_file + import numpy as np + import torch + + # If run_id provided, get the image from that run + image_path = None + if body.run_id: + with get_connection() as conn: + run = conn.execute( + "SELECT * FROM runs WHERE id = ?", (body.run_id,) + ).fetchone() + if run is None: + raise HTTPException(status_code=404, detail="Run not found") + + result = conn.execute( + "SELECT * FROM results WHERE run_id = ? ORDER BY id DESC LIMIT 1", + (body.run_id,), + ).fetchone() + + if result: + payload = json.loads(result["payload_json"]) + input_info = payload.get("input", {}) + image_path = input_info.get("file") + + # Load model and create explainer + model, device = _load_model(body.analysis_type) + + # If we have an image, use it; otherwise create synthetic + if image_path: + try: + image = _load_image_file(image_path) + except Exception: + image = np.random.randn(model.n_channels, 256, 256).astype(np.float32) + else: + image = np.random.randn(model.n_channels, 256, 256).astype(np.float32) + + # Ensure correct shape + if image.ndim == 3 and image.shape[2] < image.shape[0]: + image = np.transpose(image, (2, 0, 1)) + + n_channels = model.n_channels + c, h, w = image.shape + if c < n_channels: + pad = np.zeros((n_channels - c, h, w), dtype=image.dtype) + image = np.concatenate([image, pad], axis=0) + elif c > n_channels: + image = image[:n_channels] + + tensor = torch.FloatTensor(image.astype(np.float32)).unsqueeze(0) + + # Generate explanation + explainer = SHAPExplainer(model, device=device) + result = explainer.explain(tensor, target_class=body.target_class) + + # Format band contributions + band_names = { + "deforestation": ["Red", "Green", "Blue", "NIR"], + "ice_melting": ["Red", "Green", "Blue", "NIR"], + "flooding": ["Green", "NIR", "SWIR1"], + } + names = band_names.get(body.analysis_type, [f"Band_{i}" for i in range(n_channels)]) + + top_bands = [] + for i, (band_key, importance) in enumerate( + sorted(result["band_contributions"].items(), key=lambda x: x[1], reverse=True) + ): + band_idx = int(band_key.split("_")[1]) + band_name = names[band_idx] if band_idx < len(names) else band_key + top_bands.append(BandContribution(band=band_name, importance=round(importance, 4))) + + return { + "run_id": body.run_id, + "analysis_type": body.analysis_type, + "target_class": result["target_class"], + "prediction": result["prediction"], + "confidence": round(result["confidence"], 4), + "top_bands": top_bands, + "heatmap_path": None, + "explainer_type": result["explainer_type"], + } + + @app.get("/api/explain/{run_id}") + async def get_explanation( + run_id: int, + target_class: Optional[int] = None, + ) -> dict[str, Any]: + """Get SHAP explanation for a specific run.""" + with get_connection() as conn: + run = conn.execute( + "SELECT * FROM runs WHERE id = ?", (run_id,) + ).fetchone() + if run is None: + raise HTTPException(status_code=404, detail="Run not found") + + analysis_type = run["analysis_type"] or "deforestation" + + body = ExplainRequest( + run_id=run_id, + analysis_type=analysis_type, + target_class=target_class, + ) + return await explain_run(body) + # ===== Organization (NGO) Endpoints ===== @app.post("/api/organizations", response_model=OrganizationWithKeyResponse) diff --git a/src/climatevision/data/README.md b/src/climatevision/data/README.md new file mode 100644 index 0000000..ee3ea52 --- /dev/null +++ b/src/climatevision/data/README.md @@ -0,0 +1,46 @@ +# Data Pipeline + +Sentinel-2 ingestion, band mapping, and preprocessing for ClimateVision. + +## Modules + +| File | Purpose | +|------|---------| +| `gee_downloader.py` | Download real Sentinel-2 tiles from Google Earth Engine for a given bbox + date range. Falls back to a labelled synthetic tile (`is_synthetic: true`) when GEE credentials are missing. | +| `band_mapping.py` | Single source of truth for which spectral bands each analysis type requires. Reads from `config.yaml`. | +| `preprocessing.py` | Cloud masking (SCL band), normalisation, resampling 20m bands to 10m, tiling to 256×256. | +| `transforms.py` | Augmentation pipeline (flips, rotations, spectral jitter) for training DataLoaders. | +| `sampling.py` | Tile sampling strategies (random, balanced, stratified by region). | +| `quality.py` | Per-tile QA (cloud %, NaN ratio, band coverage). | +| `validation.py` | Schema validation for incoming requests and downloaded tiles. | + +## Analysis-Type Band Contract + +Every analysis type has its own band list in `config.yaml`. The pipeline must use `get_bands_for_analysis(analysis_type)` — never hardcode band lists. + +| Analysis | Bands | Channels | +|----------|-------|----------| +| `deforestation` | B04, B03, B02, B08 | 4 | +| `ice_melting` | B02, B03, B04, B11 | 4 | +| `flooding` | B03, B08, B11 | 3 | + +## Cloud Masking + +`apply_scl_cloud_mask(image, scl_band)` zeroes out pixels classified as cloud, shadow, snow/ice, or no-data using the Sentinel-2 Scene Classification Layer (SCL). This must run **before** the model forward pass. + +Valid SCL classes kept: 4 (vegetation), 5 (bare soil), 6 (water), 7 (low cloud), 10 (thin cirrus). +Masked out: 0 (no-data), 1 (saturated), 2 (dark), 3 (shadow), 8/9 (medium/high cloud), 11 (snow/ice). + +## Synthetic Fallback + +If GEE auth fails, the downloader returns a deterministic synthetic tile seeded by the bbox so the same region always yields the same fallback. The metadata always includes `is_synthetic: true` so the API can warn the caller. + +## Environment + +``` +GEE_PROJECT_ID=your-project-id +GEE_SERVICE_ACCOUNT=svc@project.iam.gserviceaccount.com +GEE_SERVICE_ACCOUNT_KEY=secrets/gee-key.json +``` + +Run `python scripts/setup_gee.py` to verify credentials. diff --git a/src/climatevision/governance/__init__.py b/src/climatevision/governance/__init__.py new file mode 100644 index 0000000..56889e1 --- /dev/null +++ b/src/climatevision/governance/__init__.py @@ -0,0 +1,110 @@ +""" +ClimateVision Governance Module + +Provides responsible AI capabilities: +- SHAP-based explainability for segmentation predictions +- Regional bias and fairness auditing +- Calibration metrics for confidence reliability +- Anomaly detection for inference inputs/outputs +- Model audit trails and version tracking +- Datasheets for training datasets (Gebru et al., 2018) +""" + +from .explainability import ( + explain_prediction, + generate_shap_heatmap, + get_band_contributions, + SHAPExplainer, +) +from .anomaly_detector import ( + AnomalyDetector, + AnomalyResult, + PredictionFeatures, + detect_anomaly, + extract_features, + write_anomaly_report, +) +from .audit_logger import ( + AuditEntry, + AuditLogger, + log_prediction, +) +from .model_card import ( + ModelCard, + build_model_card, + generate as generate_model_card, + render_markdown, + write_model_card, +) +from .bias_audit import ( + run_bias_audit, + BiasAuditor, + BiasReport, + RegionMetrics, + check_fairness_gate, + SUPPORTED_REGIONS, +) +from .calibration import ( + CalibrationReport, + ReliabilityBin, + brier_score, + evaluate_calibration, + expected_calibration_error, + maximum_calibration_error, + reliability_bins, + write_calibration_report, +) +from .datasheet import ( + Datasheet, + build_datasheet, + generate as generate_datasheet, + render_markdown as render_datasheet_markdown, + write_datasheet, +) + +__all__ = [ + # Explainability + "explain_prediction", + "generate_shap_heatmap", + "get_band_contributions", + "SHAPExplainer", + # Anomaly detection + "AnomalyDetector", + "AnomalyResult", + "PredictionFeatures", + "detect_anomaly", + "extract_features", + "write_anomaly_report", + # Audit logging + "AuditEntry", + "AuditLogger", + "log_prediction", + # Model card + "ModelCard", + "build_model_card", + "generate_model_card", + "render_markdown", + "write_model_card", + # Bias audit + "run_bias_audit", + "BiasAuditor", + "BiasReport", + "RegionMetrics", + "check_fairness_gate", + "SUPPORTED_REGIONS", + # Calibration + "CalibrationReport", + "ReliabilityBin", + "brier_score", + "evaluate_calibration", + "expected_calibration_error", + "maximum_calibration_error", + "reliability_bins", + "write_calibration_report", + # Datasheet + "Datasheet", + "build_datasheet", + "generate_datasheet", + "render_datasheet_markdown", + "write_datasheet", +] diff --git a/src/climatevision/governance/anomaly_detector.py b/src/climatevision/governance/anomaly_detector.py new file mode 100644 index 0000000..b27faeb --- /dev/null +++ b/src/climatevision/governance/anomaly_detector.py @@ -0,0 +1,258 @@ +""" +Anomaly detection for ClimateVision inference inputs and outputs. + +Flags predictions whose confidence distributions or input statistics fall +outside historical norms, so they can be routed for human review before +reaching downstream stakeholders. + +The detector combines two complementary strategies: + +1. Isolation Forest over a vector of summary features extracted from each + prediction (mean confidence, std, positive-pixel fraction, entropy). +2. Statistical bounds (z-score, IQR) computed from a rolling history of + recent predictions, useful when not enough data exists to fit IF. +""" + +from __future__ import annotations + +import json +import logging +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Iterable, Optional, Union + +import numpy as np + +logger = logging.getLogger(__name__) + +_PROJECT_ROOT = Path(__file__).resolve().parents[4] +_HISTORY_PATH = _PROJECT_ROOT / "outputs" / "anomalies" / "history.jsonl" +_REPORT_DIR = _PROJECT_ROOT / "outputs" / "anomalies" + +_FEATURE_NAMES = ( + "mean_confidence", + "std_confidence", + "positive_fraction", + "entropy", +) + + +@dataclass +class PredictionFeatures: + """Summary statistics extracted from a single prediction mask.""" + + mean_confidence: float + std_confidence: float + positive_fraction: float + entropy: float + + def as_vector(self) -> np.ndarray: + return np.array( + [ + self.mean_confidence, + self.std_confidence, + self.positive_fraction, + self.entropy, + ], + dtype=np.float64, + ) + + +@dataclass +class AnomalyResult: + is_anomaly: bool + score: float + method: str + reasons: list[str] + features: PredictionFeatures + + def to_dict(self) -> dict: + d = asdict(self) + d["features"] = asdict(self.features) + return d + + +def extract_features( + confidence: np.ndarray, + mask: Optional[np.ndarray] = None, + threshold: float = 0.5, +) -> PredictionFeatures: + """ + Compute summary statistics for an inference output. + + Args: + confidence: Per-pixel confidence scores in [0, 1]. + mask: Optional binary mask. If omitted, derived from `confidence > threshold`. + threshold: Decision threshold used when mask is omitted. + """ + confidence = np.asarray(confidence, dtype=np.float64) + if confidence.size == 0: + raise ValueError("confidence array is empty") + + if mask is None: + mask = confidence > threshold + mask = np.asarray(mask).astype(bool) + + eps = 1e-9 + p = np.clip(confidence, eps, 1 - eps) + pixel_entropy = -(p * np.log(p) + (1 - p) * np.log(1 - p)) + + return PredictionFeatures( + mean_confidence=float(confidence.mean()), + std_confidence=float(confidence.std()), + positive_fraction=float(mask.mean()), + entropy=float(pixel_entropy.mean()), + ) + + +class AnomalyDetector: + """ + Hybrid anomaly detector for inference outputs. + + Holds a rolling history of prediction features and exposes two + detection paths: a fitted Isolation Forest (when enough history is + available) and a statistical fallback based on per-feature z-scores + and IQR fences. + """ + + def __init__( + self, + z_threshold: float = 3.0, + iqr_multiplier: float = 1.5, + min_history_for_iforest: int = 50, + contamination: float = 0.05, + history_path: Optional[Union[str, Path]] = None, + ) -> None: + self.z_threshold = z_threshold + self.iqr_multiplier = iqr_multiplier + self.min_history_for_iforest = min_history_for_iforest + self.contamination = contamination + self.history_path = Path(history_path) if history_path else _HISTORY_PATH + self._history: list[PredictionFeatures] = [] + self._iforest = None + + def load_history(self) -> None: + if not self.history_path.exists(): + return + with self.history_path.open() as fh: + for line in fh: + row = json.loads(line) + self._history.append(PredictionFeatures(**row)) + logger.info("Loaded %d historical predictions", len(self._history)) + + def _persist(self, features: PredictionFeatures) -> None: + self.history_path.parent.mkdir(parents=True, exist_ok=True) + with self.history_path.open("a") as fh: + fh.write(json.dumps(asdict(features)) + "\n") + + def _fit_iforest(self) -> None: + try: + from sklearn.ensemble import IsolationForest + except ImportError: + logger.warning("scikit-learn not available, falling back to statistical checks") + self._iforest = None + return + + X = np.stack([f.as_vector() for f in self._history]) + self._iforest = IsolationForest( + contamination=self.contamination, + random_state=42, + ).fit(X) + logger.info("Fitted IsolationForest on %d samples", len(self._history)) + + def _statistical_check( + self, features: PredictionFeatures + ) -> tuple[bool, float, list[str]]: + if len(self._history) < 5: + return False, 0.0, ["insufficient_history"] + + X = np.stack([f.as_vector() for f in self._history]) + x = features.as_vector() + + mean = X.mean(axis=0) + std = X.std(axis=0) + 1e-9 + z = np.abs((x - mean) / std) + + q1, q3 = np.percentile(X, [25, 75], axis=0) + iqr = q3 - q1 + lower = q1 - self.iqr_multiplier * iqr + upper = q3 + self.iqr_multiplier * iqr + + reasons: list[str] = [] + for i, name in enumerate(_FEATURE_NAMES): + if z[i] > self.z_threshold: + reasons.append(f"{name}_z={z[i]:.2f}") + if x[i] < lower[i] or x[i] > upper[i]: + reasons.append(f"{name}_outside_iqr") + + return bool(reasons), float(z.max()), reasons + + def detect( + self, + confidence: np.ndarray, + mask: Optional[np.ndarray] = None, + record: bool = True, + ) -> AnomalyResult: + features = extract_features(confidence, mask=mask) + + if ( + self._iforest is None + and len(self._history) >= self.min_history_for_iforest + ): + self._fit_iforest() + + if self._iforest is not None: + score = float(self._iforest.score_samples(features.as_vector().reshape(1, -1))[0]) + is_anomaly = bool(self._iforest.predict(features.as_vector().reshape(1, -1))[0] == -1) + method = "isolation_forest" + reasons = ["isolation_forest_outlier"] if is_anomaly else [] + else: + is_anomaly, score, reasons = self._statistical_check(features) + method = "statistical" + + if record: + self._history.append(features) + self._persist(features) + + result = AnomalyResult( + is_anomaly=is_anomaly, + score=score, + method=method, + reasons=reasons, + features=features, + ) + + if is_anomaly: + logger.warning( + "Anomaly detected (method=%s, score=%.3f, reasons=%s)", + method, + score, + reasons, + ) + return result + + +def detect_anomaly( + confidence: np.ndarray, + mask: Optional[np.ndarray] = None, + detector: Optional[AnomalyDetector] = None, +) -> AnomalyResult: + """Convenience wrapper that lazily constructs a detector and loads history.""" + if detector is None: + detector = AnomalyDetector() + detector.load_history() + return detector.detect(confidence, mask=mask) + + +def write_anomaly_report( + results: Iterable[AnomalyResult], + output_path: Optional[Union[str, Path]] = None, +) -> Path: + """Persist a batch of anomaly results to a JSON report for review.""" + output_path = Path(output_path) if output_path else _REPORT_DIR / "anomaly_report.json" + output_path.parent.mkdir(parents=True, exist_ok=True) + payload = [r.to_dict() for r in results] + with output_path.open("w") as fh: + json.dump(payload, fh, indent=2) + logger.info("Wrote anomaly report with %d entries to %s", len(payload), output_path) + return output_path diff --git a/src/climatevision/governance/audit_logger.py b/src/climatevision/governance/audit_logger.py new file mode 100644 index 0000000..10806a6 --- /dev/null +++ b/src/climatevision/governance/audit_logger.py @@ -0,0 +1,197 @@ +""" +Immutable audit trail for ClimateVision model versions and predictions. + +Every prediction logged by this module produces a chained record that +includes: + +- A SHA-256 hash of the input payload (image + parameters). +- The model version that produced the result. +- A summary of the output (positive fraction, mean confidence, threshold). +- A `prev_hash` linking the entry to the previous one, forming an + append-only hash chain. Tampering with any historical record breaks + the chain and is detected by `verify_chain()`. + +The chain is persisted as JSON Lines so that downstream tooling +(MLflow, BigQuery, regulators) can ingest it without parsing custom +formats. +""" + +from __future__ import annotations + +import hashlib +import json +import logging +import threading +from dataclasses import asdict, dataclass, field +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Optional, Union + +import numpy as np + +logger = logging.getLogger(__name__) + +_PROJECT_ROOT = Path(__file__).resolve().parents[4] +_DEFAULT_AUDIT_LOG = _PROJECT_ROOT / "outputs" / "audit" / "predictions.jsonl" + +GENESIS_HASH = "0" * 64 + + +def _utcnow() -> str: + return datetime.now(timezone.utc).isoformat() + + +def _stable_hash(payload: Any) -> str: + encoded = json.dumps(payload, sort_keys=True, default=str).encode() + return hashlib.sha256(encoded).hexdigest() + + +def _array_signature(arr: np.ndarray) -> dict: + arr = np.asarray(arr) + return { + "shape": list(arr.shape), + "dtype": str(arr.dtype), + "sha256": hashlib.sha256(arr.tobytes()).hexdigest(), + } + + +@dataclass +class AuditEntry: + timestamp: str + model_version: str + input_hash: str + output_summary: dict + request_id: Optional[str] + user_id: Optional[str] + prev_hash: str + entry_hash: str = "" + metadata: dict = field(default_factory=dict) + + def compute_hash(self) -> str: + body = {k: v for k, v in asdict(self).items() if k != "entry_hash"} + return _stable_hash(body) + + def to_json(self) -> str: + return json.dumps(asdict(self), sort_keys=True) + + +class AuditLogger: + """ + Append-only audit logger backed by a hash-chained JSONL file. + + The logger is process-safe via an in-memory lock; for cross-process + safety wrap calls in your own filelock or write through a queue. + """ + + def __init__(self, log_path: Optional[Union[str, Path]] = None) -> None: + self.log_path = Path(log_path) if log_path else _DEFAULT_AUDIT_LOG + self._lock = threading.Lock() + self._last_hash: Optional[str] = None + + def _read_last_hash(self) -> str: + if not self.log_path.exists(): + return GENESIS_HASH + last = GENESIS_HASH + with self.log_path.open() as fh: + for line in fh: + if line.strip(): + last = json.loads(line)["entry_hash"] + return last + + def log_prediction( + self, + model_version: str, + input_data: Union[np.ndarray, dict], + output: Union[np.ndarray, dict], + request_id: Optional[str] = None, + user_id: Optional[str] = None, + threshold: float = 0.5, + metadata: Optional[dict] = None, + ) -> AuditEntry: + if isinstance(input_data, np.ndarray): + input_payload = _array_signature(input_data) + else: + input_payload = dict(input_data) + + if isinstance(output, np.ndarray): + output_payload = { + **_array_signature(output), + "mean_confidence": float(output.mean()), + "positive_fraction": float((output > threshold).mean()), + "threshold": threshold, + } + else: + output_payload = dict(output) + + with self._lock: + if self._last_hash is None: + self._last_hash = self._read_last_hash() + + entry = AuditEntry( + timestamp=_utcnow(), + model_version=model_version, + input_hash=_stable_hash(input_payload), + output_summary=output_payload, + request_id=request_id, + user_id=user_id, + prev_hash=self._last_hash, + metadata=metadata or {}, + ) + entry.entry_hash = entry.compute_hash() + + self.log_path.parent.mkdir(parents=True, exist_ok=True) + with self.log_path.open("a") as fh: + fh.write(entry.to_json() + "\n") + + self._last_hash = entry.entry_hash + logger.info( + "Logged audit entry %s for model %s", + entry.entry_hash[:12], + model_version, + ) + return entry + + def iter_entries(self) -> list[AuditEntry]: + if not self.log_path.exists(): + return [] + entries: list[AuditEntry] = [] + with self.log_path.open() as fh: + for line in fh: + if not line.strip(): + continue + entries.append(AuditEntry(**json.loads(line))) + return entries + + def verify_chain(self) -> tuple[bool, Optional[str]]: + """ + Walk the chain from genesis and confirm each entry hashes correctly + and references the previous entry. + + Returns: + (ok, failure_hash) — failure_hash is the entry where the chain + breaks, or None when the chain is valid. + """ + prev = GENESIS_HASH + for entry in self.iter_entries(): + if entry.prev_hash != prev: + return False, entry.entry_hash + recomputed = entry.compute_hash() + if recomputed != entry.entry_hash: + return False, entry.entry_hash + prev = entry.entry_hash + return True, None + + +def log_prediction( + model_version: str, + input_data: Union[np.ndarray, dict], + output: Union[np.ndarray, dict], + **kwargs: Any, +) -> AuditEntry: + """Module-level convenience wrapper using the default audit log path.""" + return AuditLogger().log_prediction( + model_version=model_version, + input_data=input_data, + output=output, + **kwargs, + ) diff --git a/src/climatevision/governance/bias_audit.py b/src/climatevision/governance/bias_audit.py new file mode 100644 index 0000000..a945db7 --- /dev/null +++ b/src/climatevision/governance/bias_audit.py @@ -0,0 +1,566 @@ +""" +Regional bias and fairness audit framework for ClimateVision models. + +Ensures model predictions are equitable across geographic regions, +preventing disparate impact on NGOs in different parts of the world. +""" + +from __future__ import annotations + +import json +import logging +from dataclasses import dataclass, field, asdict +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Optional, Union, Literal + +import numpy as np + +logger = logging.getLogger(__name__) + +_PROJECT_ROOT = Path(__file__).resolve().parents[4] +_REPORTS_DIR = _PROJECT_ROOT / "outputs" / "bias_reports" + +FairnessMetric = Literal["demographic_parity", "equalized_odds", "predictive_parity"] + +SUPPORTED_REGIONS = { + "amazon": { + "name": "Amazon Basin", + "bbox": [-73.0, -15.0, -45.0, 5.0], + "description": "South American tropical rainforest", + }, + "congo": { + "name": "Congo Basin", + "bbox": [9.0, -13.0, 31.0, 10.0], + "description": "Central African tropical rainforest", + }, + "southeast_asia": { + "name": "Southeast Asia", + "bbox": [95.0, -10.0, 140.0, 25.0], + "description": "Tropical forests of Indonesia, Malaysia, and surrounding regions", + }, + "boreal": { + "name": "Boreal Forest", + "bbox": [-140.0, 50.0, 180.0, 70.0], + "description": "Northern coniferous forests (Canada, Russia, Scandinavia)", + }, +} + + +@dataclass +class RegionMetrics: + """Performance metrics for a single region.""" + + region: str + region_name: str + n_samples: int = 0 + iou: float = 0.0 + f1: float = 0.0 + precision: float = 0.0 + recall: float = 0.0 + accuracy: float = 0.0 + true_positive_rate: float = 0.0 + false_positive_rate: float = 0.0 + positive_rate: float = 0.0 + + +@dataclass +class BiasReport: + """Complete bias audit report.""" + + model_path: str + model_version: str + analysis_type: str + audit_timestamp: str + fairness_metric: str + fairness_score: float + passed: bool + threshold: float + region_metrics: list[RegionMetrics] = field(default_factory=list) + disparity_regions: list[str] = field(default_factory=list) + recommendations: list[str] = field(default_factory=list) + + def to_dict(self) -> dict[str, Any]: + """Convert report to dictionary.""" + return { + "model_path": self.model_path, + "model_version": self.model_version, + "analysis_type": self.analysis_type, + "audit_timestamp": self.audit_timestamp, + "fairness_metric": self.fairness_metric, + "fairness_score": self.fairness_score, + "passed": self.passed, + "threshold": self.threshold, + "region_metrics": [asdict(r) for r in self.region_metrics], + "disparity_regions": self.disparity_regions, + "recommendations": self.recommendations, + } + + def to_json(self, indent: int = 2) -> str: + """Convert report to JSON string.""" + return json.dumps(self.to_dict(), indent=indent) + + +class BiasAuditor: + """ + Auditor for evaluating model fairness across geographic regions. + + Implements demographic parity, equalized odds, and predictive parity + metrics to detect and quantify regional disparities. + """ + + def __init__( + self, + model: Any, + device: Optional[Any] = None, + threshold: float = 0.85, + ): + self.model = model + self.device = device + self.threshold = threshold + self._region_data: dict[str, dict] = {} + + def add_region_data( + self, + region: str, + predictions: np.ndarray, + ground_truth: np.ndarray, + ) -> None: + """ + Add prediction and ground truth data for a region. + + Args: + region: Region identifier (e.g., 'amazon', 'congo') + predictions: Model predictions (N, H, W) or (N,) + ground_truth: Ground truth labels (N, H, W) or (N,) + """ + if region not in SUPPORTED_REGIONS: + logger.warning("Region '%s' not in supported regions, adding anyway", region) + + self._region_data[region] = { + "predictions": predictions.flatten(), + "ground_truth": ground_truth.flatten(), + } + + def compute_region_metrics(self, region: str) -> RegionMetrics: + """Compute performance metrics for a single region.""" + if region not in self._region_data: + raise ValueError(f"No data for region: {region}") + + data = self._region_data[region] + pred = data["predictions"] + true = data["ground_truth"] + + # Basic counts + tp = np.sum((pred == 1) & (true == 1)) + fp = np.sum((pred == 1) & (true == 0)) + tn = np.sum((pred == 0) & (true == 0)) + fn = np.sum((pred == 0) & (true == 1)) + + n_samples = len(pred) + + # Metrics + precision = tp / (tp + fp + 1e-8) + recall = tp / (tp + fn + 1e-8) + f1 = 2 * precision * recall / (precision + recall + 1e-8) + accuracy = (tp + tn) / (n_samples + 1e-8) + + # IoU (Intersection over Union) + intersection = tp + union = tp + fp + fn + iou = intersection / (union + 1e-8) + + # Rates for fairness metrics + tpr = recall # True Positive Rate + fpr = fp / (fp + tn + 1e-8) # False Positive Rate + positive_rate = (tp + fp) / (n_samples + 1e-8) # Demographic parity + + region_name = SUPPORTED_REGIONS.get(region, {}).get("name", region) + + return RegionMetrics( + region=region, + region_name=region_name, + n_samples=n_samples, + iou=float(iou), + f1=float(f1), + precision=float(precision), + recall=float(recall), + accuracy=float(accuracy), + true_positive_rate=float(tpr), + false_positive_rate=float(fpr), + positive_rate=float(positive_rate), + ) + + def compute_demographic_parity(self) -> tuple[float, list[str]]: + """ + Compute demographic parity across regions. + + Demographic parity requires equal positive prediction rates + across all groups/regions. + + Returns: + (fairness_score, list of disparity regions) + """ + positive_rates = {} + for region in self._region_data: + metrics = self.compute_region_metrics(region) + positive_rates[region] = metrics.positive_rate + + if not positive_rates: + return 1.0, [] + + rates = list(positive_rates.values()) + max_rate = max(rates) + min_rate = min(rates) + + # Disparity ratio (1.0 = perfect parity) + if max_rate > 0: + disparity = min_rate / max_rate + else: + disparity = 1.0 + + # Find regions with significant disparity + mean_rate = np.mean(rates) + disparity_regions = [ + r for r, rate in positive_rates.items() + if abs(rate - mean_rate) > 0.1 * mean_rate + ] + + return float(disparity), disparity_regions + + def compute_equalized_odds(self) -> tuple[float, list[str]]: + """ + Compute equalized odds across regions. + + Equalized odds requires equal TPR and FPR across groups. + + Returns: + (fairness_score, list of disparity regions) + """ + tprs = {} + fprs = {} + + for region in self._region_data: + metrics = self.compute_region_metrics(region) + tprs[region] = metrics.true_positive_rate + fprs[region] = metrics.false_positive_rate + + if not tprs: + return 1.0, [] + + # TPR disparity + tpr_values = list(tprs.values()) + tpr_disparity = min(tpr_values) / (max(tpr_values) + 1e-8) + + # FPR disparity + fpr_values = list(fprs.values()) + fpr_disparity = 1.0 - (max(fpr_values) - min(fpr_values)) + + # Combined score + score = (tpr_disparity + fpr_disparity) / 2 + + # Find disparity regions + mean_tpr = np.mean(tpr_values) + disparity_regions = [ + r for r, tpr in tprs.items() + if abs(tpr - mean_tpr) > 0.15 + ] + + return float(score), disparity_regions + + def compute_predictive_parity(self) -> tuple[float, list[str]]: + """ + Compute predictive parity (equal precision) across regions. + + Returns: + (fairness_score, list of disparity regions) + """ + precisions = {} + + for region in self._region_data: + metrics = self.compute_region_metrics(region) + precisions[region] = metrics.precision + + if not precisions: + return 1.0, [] + + values = list(precisions.values()) + disparity = min(values) / (max(values) + 1e-8) + + mean_precision = np.mean(values) + disparity_regions = [ + r for r, prec in precisions.items() + if abs(prec - mean_precision) > 0.1 + ] + + return float(disparity), disparity_regions + + def run_audit( + self, + metric: FairnessMetric = "equalized_odds", + model_path: str = "unknown", + model_version: str = "unknown", + analysis_type: str = "deforestation", + ) -> BiasReport: + """ + Run complete bias audit. + + Args: + metric: Fairness metric to use + model_path: Path to model checkpoint + model_version: Model version string + analysis_type: Type of analysis + + Returns: + BiasReport with complete audit results + """ + # Compute fairness score based on metric + if metric == "demographic_parity": + score, disparity_regions = self.compute_demographic_parity() + elif metric == "equalized_odds": + score, disparity_regions = self.compute_equalized_odds() + elif metric == "predictive_parity": + score, disparity_regions = self.compute_predictive_parity() + else: + raise ValueError(f"Unknown metric: {metric}") + + # Compute per-region metrics + region_metrics = [ + self.compute_region_metrics(region) + for region in self._region_data + ] + + # Generate recommendations + recommendations = self._generate_recommendations( + score, disparity_regions, region_metrics + ) + + return BiasReport( + model_path=model_path, + model_version=model_version, + analysis_type=analysis_type, + audit_timestamp=datetime.now(timezone.utc).isoformat(), + fairness_metric=metric, + fairness_score=round(score, 4), + passed=score >= self.threshold, + threshold=self.threshold, + region_metrics=region_metrics, + disparity_regions=disparity_regions, + recommendations=recommendations, + ) + + def _generate_recommendations( + self, + score: float, + disparity_regions: list[str], + region_metrics: list[RegionMetrics], + ) -> list[str]: + """Generate recommendations based on audit results.""" + recommendations = [] + + if score < self.threshold: + recommendations.append( + f"Fairness score ({score:.2f}) below threshold ({self.threshold}). " + "Consider retraining with balanced regional data." + ) + + if disparity_regions: + recommendations.append( + f"High disparity detected in regions: {', '.join(disparity_regions)}. " + "Review training data distribution for these areas." + ) + + # Find underperforming regions + if region_metrics: + mean_iou = np.mean([m.iou for m in region_metrics]) + underperforming = [ + m.region for m in region_metrics + if m.iou < mean_iou - 0.1 + ] + if underperforming: + recommendations.append( + f"Regions with below-average IoU: {', '.join(underperforming)}. " + "Consider collecting more training samples from these regions." + ) + + if not recommendations: + recommendations.append("Model passes fairness audit. No action required.") + + return recommendations + + +def run_bias_audit( + model_path: Union[str, Path], + regions: list[str], + metric: FairnessMetric = "equalized_odds", + threshold: float = 0.85, + analysis_type: str = "deforestation", + test_data_dir: Optional[Path] = None, +) -> dict[str, Any]: + """ + Run bias audit on a model across specified regions. + + Args: + model_path: Path to model checkpoint + regions: List of region identifiers + metric: Fairness metric to compute + threshold: Minimum acceptable fairness score + analysis_type: Type of analysis + test_data_dir: Directory containing regional test data + + Returns: + Dictionary with audit results + """ + import torch + from climatevision.inference.pipeline import _load_model + + model, device = _load_model(analysis_type) + auditor = BiasAuditor(model, device=device, threshold=threshold) + + # Load or generate regional data + for region in regions: + if test_data_dir: + pred, truth = _load_region_test_data(test_data_dir, region) + else: + # Generate synthetic data for demonstration + pred, truth = _generate_synthetic_region_data(region, model, device) + + auditor.add_region_data(region, pred, truth) + + # Get model version from checkpoint + model_version = "unknown" + if Path(model_path).exists(): + try: + ckpt = torch.load(model_path, map_location="cpu") + model_version = f"epoch_{ckpt.get('epoch', '?')}_iou_{ckpt.get('val_iou', 0):.3f}" + except Exception: + pass + + report = auditor.run_audit( + metric=metric, + model_path=str(model_path), + model_version=model_version, + analysis_type=analysis_type, + ) + + # Save report + save_bias_report(report) + + return { + "score": report.fairness_score, + "passed": report.passed, + "disparity_regions": report.disparity_regions, + "region_metrics": [asdict(m) for m in report.region_metrics], + "recommendations": report.recommendations, + "report_path": str(_REPORTS_DIR / f"bias_report_{report.audit_timestamp[:10]}.json"), + } + + +def _load_region_test_data( + data_dir: Path, + region: str, +) -> tuple[np.ndarray, np.ndarray]: + """Load test data for a specific region.""" + region_dir = data_dir / region + + if not region_dir.exists(): + logger.warning("No test data for region %s, using synthetic", region) + return _generate_synthetic_region_data(region, None, None) + + predictions = [] + ground_truth = [] + + for pred_file in region_dir.glob("*_pred.npy"): + truth_file = pred_file.with_name(pred_file.stem.replace("_pred", "_mask") + ".npy") + if truth_file.exists(): + predictions.append(np.load(pred_file)) + ground_truth.append(np.load(truth_file)) + + if not predictions: + return _generate_synthetic_region_data(region, None, None) + + return np.concatenate(predictions), np.concatenate(ground_truth) + + +def _generate_synthetic_region_data( + region: str, + model: Any, + device: Any, +) -> tuple[np.ndarray, np.ndarray]: + """Generate synthetic test data for a region.""" + np.random.seed(hash(region) % 2**31) + + n_samples = 1000 + + # Different regions have different class distributions + region_bias = { + "amazon": 0.7, # High forest coverage + "congo": 0.65, + "southeast_asia": 0.55, + "boreal": 0.6, + } + + forest_prob = region_bias.get(region, 0.6) + ground_truth = (np.random.random(n_samples) < forest_prob).astype(np.int32) + + # Simulate model predictions with region-specific accuracy + region_accuracy = { + "amazon": 0.92, + "congo": 0.85, + "southeast_asia": 0.88, + "boreal": 0.90, + } + + accuracy = region_accuracy.get(region, 0.87) + correct_mask = np.random.random(n_samples) < accuracy + predictions = np.where(correct_mask, ground_truth, 1 - ground_truth) + + return predictions, ground_truth + + +def save_bias_report( + report: BiasReport, + output_dir: Optional[Path] = None, +) -> Path: + """Save bias report to JSON file.""" + output_dir = output_dir or _REPORTS_DIR + output_dir.mkdir(parents=True, exist_ok=True) + + timestamp = report.audit_timestamp[:19].replace(":", "-") + filename = f"bias_report_{timestamp}.json" + filepath = output_dir / filename + + filepath.write_text(report.to_json(), encoding="utf-8") + logger.info("Saved bias report to %s", filepath) + + return filepath + + +def check_fairness_gate( + model_path: Union[str, Path], + regions: list[str] = ["amazon", "congo", "southeast_asia"], + threshold: float = 0.85, +) -> bool: + """ + CI gate for checking model fairness. + + Returns True if model passes fairness threshold, False otherwise. + Used by CI/CD to block releases with unacceptable bias. + """ + result = run_bias_audit( + model_path=model_path, + regions=regions, + threshold=threshold, + ) + + if not result["passed"]: + logger.error( + "Fairness gate FAILED: score=%.3f threshold=%.3f disparity=%s", + result["score"], + threshold, + result["disparity_regions"], + ) + else: + logger.info("Fairness gate PASSED: score=%.3f", result["score"]) + + return result["passed"] diff --git a/src/climatevision/governance/calibration.py b/src/climatevision/governance/calibration.py new file mode 100644 index 0000000..69755e9 --- /dev/null +++ b/src/climatevision/governance/calibration.py @@ -0,0 +1,196 @@ +""" +Calibration metrics for ClimateVision segmentation models. + +A model that reports a confidence of 0.9 should be correct about 90% of the +time — anything else is miscalibration. For NGO-facing alerts driven by +threshold logic on confidence, miscalibration directly mistranslates into +either missed events or false alarms, so the calibration of every released +model needs to be measured alongside the headline accuracy. + +This module computes the standard reliability-diagram metrics for binary +segmentation outputs: + +- Reliability bins: bucket pixel predictions by confidence, record the + observed positive-rate in each bucket against the bucket's mean confidence. +- Expected Calibration Error (ECE): support-weighted mean of the absolute gap + between confidence and accuracy across bins. +- Maximum Calibration Error (MCE): the worst single-bin gap. +- Brier score: mean squared error between probability and binary target. + +All metrics operate on flat numpy arrays so they slot into the existing +governance pipeline (model card generator, release CI gate) without +introducing a torch dependency at evaluation time. +""" + +from __future__ import annotations + +import json +import logging +from dataclasses import asdict, dataclass, field +from pathlib import Path +from typing import List, Union + +import numpy as np + +logger = logging.getLogger(__name__) + +DEFAULT_N_BINS = 15 + + +@dataclass +class ReliabilityBin: + """One bucket of the reliability diagram.""" + + lower: float + upper: float + count: int + mean_confidence: float + observed_positive_rate: float + + +@dataclass +class CalibrationReport: + """Calibration evaluation summary for a single model run.""" + + model_version: str + n_samples: int + n_bins: int + ece: float + mce: float + brier_score: float + bins: List[ReliabilityBin] = field(default_factory=list) + + def to_dict(self) -> dict: + d = asdict(self) + d["bins"] = [asdict(b) for b in self.bins] + return d + + def is_well_calibrated(self, ece_threshold: float = 0.05) -> bool: + """Default release-gate threshold: ECE under 5%.""" + return self.ece <= ece_threshold + + +def _validate_inputs(probabilities: np.ndarray, targets: np.ndarray) -> None: + if probabilities.shape != targets.shape: + raise ValueError( + f"probabilities and targets must have the same shape, got " + f"{probabilities.shape} and {targets.shape}" + ) + if probabilities.size == 0: + raise ValueError("probabilities array is empty") + if probabilities.min() < 0.0 or probabilities.max() > 1.0: + raise ValueError("probabilities must lie in [0, 1]") + unique_targets = np.unique(targets) + if not np.all(np.isin(unique_targets, [0, 1])): + raise ValueError( + f"targets must be binary {{0, 1}}, got values {unique_targets}" + ) + + +def reliability_bins( + probabilities: np.ndarray, + targets: np.ndarray, + n_bins: int = DEFAULT_N_BINS, +) -> List[ReliabilityBin]: + """Bucket predictions by confidence and return per-bin reliability.""" + probs = np.asarray(probabilities, dtype=np.float64).ravel() + tgts = np.asarray(targets, dtype=np.int32).ravel() + _validate_inputs(probs, tgts) + + edges = np.linspace(0.0, 1.0, n_bins + 1) + bins: List[ReliabilityBin] = [] + for i in range(n_bins): + lower, upper = edges[i], edges[i + 1] + if i == n_bins - 1: + mask = (probs >= lower) & (probs <= upper) + else: + mask = (probs >= lower) & (probs < upper) + count = int(mask.sum()) + if count == 0: + bins.append( + ReliabilityBin( + lower=float(lower), + upper=float(upper), + count=0, + mean_confidence=0.0, + observed_positive_rate=0.0, + ) + ) + continue + bins.append( + ReliabilityBin( + lower=float(lower), + upper=float(upper), + count=count, + mean_confidence=float(probs[mask].mean()), + observed_positive_rate=float(tgts[mask].mean()), + ) + ) + return bins + + +def expected_calibration_error(bins: List[ReliabilityBin]) -> float: + """Support-weighted mean gap between confidence and observed accuracy.""" + total = sum(b.count for b in bins) + if total == 0: + return 0.0 + weighted = sum( + (b.count / total) * abs(b.mean_confidence - b.observed_positive_rate) + for b in bins + if b.count > 0 + ) + return float(weighted) + + +def maximum_calibration_error(bins: List[ReliabilityBin]) -> float: + """Worst single-bin gap between confidence and observed accuracy.""" + populated = [b for b in bins if b.count > 0] + if not populated: + return 0.0 + return float( + max(abs(b.mean_confidence - b.observed_positive_rate) for b in populated) + ) + + +def brier_score( + probabilities: np.ndarray, targets: np.ndarray +) -> float: + """Mean squared error between probability and binary target.""" + probs = np.asarray(probabilities, dtype=np.float64).ravel() + tgts = np.asarray(targets, dtype=np.float64).ravel() + _validate_inputs(probs, tgts.astype(np.int32)) + return float(np.mean((probs - tgts) ** 2)) + + +def evaluate_calibration( + probabilities: np.ndarray, + targets: np.ndarray, + *, + model_version: str, + n_bins: int = DEFAULT_N_BINS, +) -> CalibrationReport: + """Run the full calibration evaluation and return a report dataclass.""" + probs = np.asarray(probabilities, dtype=np.float64).ravel() + tgts = np.asarray(targets, dtype=np.int32).ravel() + _validate_inputs(probs, tgts) + bins = reliability_bins(probs, tgts, n_bins=n_bins) + return CalibrationReport( + model_version=model_version, + n_samples=int(probs.size), + n_bins=n_bins, + ece=expected_calibration_error(bins), + mce=maximum_calibration_error(bins), + brier_score=brier_score(probs, tgts), + bins=bins, + ) + + +def write_calibration_report( + report: CalibrationReport, path: Union[str, Path] +) -> Path: + """Persist a CalibrationReport to disk as JSON.""" + out = Path(path) + out.parent.mkdir(parents=True, exist_ok=True) + out.write_text(json.dumps(report.to_dict(), indent=2)) + logger.info("Wrote calibration report to %s", out) + return out diff --git a/src/climatevision/governance/datasheet.py b/src/climatevision/governance/datasheet.py new file mode 100644 index 0000000..b4cc05d --- /dev/null +++ b/src/climatevision/governance/datasheet.py @@ -0,0 +1,215 @@ +""" +Datasheets for the datasets that train ClimateVision models. + +Companion to the Mitchell-style model cards in ``governance.model_card``: +where a model card describes the *model*, a datasheet describes the +*dataset* the model was trained on (Gebru et al., 2018, "Datasheets for +Datasets"). The two artifacts answer different questions and both need +to ship with a release. + +The module mirrors the model_card public surface (``build``, ``render``, +``write``, ``generate``) so contributors only have to learn one pattern, +and the release CI pipeline can call them in sequence. + +Sections covered: + +- Motivation +- Composition +- Collection process +- Preprocessing, cleaning, labeling +- Uses (intended and inappropriate) +- Distribution +- Maintenance + +Every section is a free-form ``dict`` of question -> answer so the schema +can grow without code changes; ``REQUIRED_QUESTIONS`` enforces the bare +minimum a release datasheet must answer. +""" + +from __future__ import annotations + +import json +import logging +from dataclasses import dataclass, field +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Optional, Union + +logger = logging.getLogger(__name__) + +_PROJECT_ROOT = Path(__file__).resolve().parents[3] +_DEFAULT_OUTPUT_DIR = _PROJECT_ROOT / "outputs" / "datasheets" + +REQUIRED_QUESTIONS = { + "motivation": ("purpose", "creators"), + "composition": ("instances", "labels", "splits"), + "collection_process": ("source", "timeframe"), + "uses": ("intended_uses", "inappropriate_uses"), +} + + +@dataclass +class Datasheet: + """Structured datasheet for a single training dataset.""" + + name: str + version: str + motivation: dict + composition: dict + collection_process: dict + preprocessing: dict + uses: dict + distribution: dict + maintenance: dict + generated_at: str = field( + default_factory=lambda: datetime.now(timezone.utc).isoformat() + ) + + def to_dict(self) -> dict: + return { + "name": self.name, + "version": self.version, + "motivation": self.motivation, + "composition": self.composition, + "collection_process": self.collection_process, + "preprocessing": self.preprocessing, + "uses": self.uses, + "distribution": self.distribution, + "maintenance": self.maintenance, + "generated_at": self.generated_at, + } + + +_DEFAULT_INAPPROPRIATE_USES = [ + "Training models for real-time legal enforcement against individual landowners.", + "Land-rights or sovereignty disputes without on-the-ground verification.", + "Generative model training where label provenance is required to be human-verified.", +] + +_DEFAULT_MAINTENANCE = { + "owner": "ClimateVision Governance ", + "update_cadence": "Reviewed each minor release; refreshed when source providers change.", + "deprecation_policy": ( + "Versions are retained for two minor releases after supersession; " + "models trained on deprecated versions are flagged in their model cards." + ), +} + + +def _coerce_config(config: Union[dict, str, Path]) -> dict: + if isinstance(config, dict): + return config + path = Path(config) + text = path.read_text() + if path.suffix in {".yml", ".yaml"}: + try: + import yaml + except ImportError as exc: # pragma: no cover - import guard + raise RuntimeError("PyYAML is required to load YAML configs") from exc + return yaml.safe_load(text) + return json.loads(text) + + +def _validate(datasheet: "Datasheet") -> None: + missing: list[str] = [] + for section_name, required_keys in REQUIRED_QUESTIONS.items(): + section = getattr(datasheet, section_name) + for key in required_keys: + if key not in section or section[key] in (None, "", []): + missing.append(f"{section_name}.{key}") + if missing: + raise ValueError(f"datasheet missing required answers: {missing}") + + +def build_datasheet( + manifest: Union[dict, str, Path], + *, + name: Optional[str] = None, + version: Optional[str] = None, +) -> Datasheet: + """Build a Datasheet from a structured dataset manifest.""" + m = _coerce_config(manifest) + + resolved_name = name or m.get("name") or "climatevision-dataset" + resolved_version = version or m.get("version") or "0.0.0" + + uses = dict(m.get("uses", {})) + uses.setdefault("inappropriate_uses", list(_DEFAULT_INAPPROPRIATE_USES)) + + sheet = Datasheet( + name=resolved_name, + version=resolved_version, + motivation=dict(m.get("motivation", {})), + composition=dict(m.get("composition", {})), + collection_process=dict(m.get("collection_process", {})), + preprocessing=dict(m.get("preprocessing", {})), + uses=uses, + distribution=dict(m.get("distribution", {})), + maintenance=dict(m.get("maintenance", _DEFAULT_MAINTENANCE)), + ) + _validate(sheet) + return sheet + + +def _render_section(title: str, body: dict) -> list[str]: + if not body: + return [f"## {title}", "_Not documented._", ""] + lines = [f"## {title}"] + for key, value in body.items(): + pretty_key = key.replace("_", " ").title() + if isinstance(value, list): + lines.append(f"### {pretty_key}") + lines.extend(f"- {item}" for item in value) + elif isinstance(value, dict): + lines.append(f"### {pretty_key}") + lines.append(f"```json\n{json.dumps(value, indent=2)}\n```") + else: + lines.append(f"- **{pretty_key}**: {value}") + lines.append("") + return lines + + +def render_markdown(sheet: Datasheet) -> str: + sections = [ + f"# Datasheet: {sheet.name} ({sheet.version})", + f"_Generated {sheet.generated_at}_", + "", + "_Format: Gebru et al., 2018, \"Datasheets for Datasets\"._", + "", + ] + sections += _render_section("Motivation", sheet.motivation) + sections += _render_section("Composition", sheet.composition) + sections += _render_section("Collection Process", sheet.collection_process) + sections += _render_section("Preprocessing, Cleaning, Labeling", sheet.preprocessing) + sections += _render_section("Uses", sheet.uses) + sections += _render_section("Distribution", sheet.distribution) + sections += _render_section("Maintenance", sheet.maintenance) + return "\n".join(sections) + "\n" + + +def write_datasheet( + sheet: Datasheet, + output_dir: Optional[Union[str, Path]] = None, +) -> dict[str, Path]: + output_dir = Path(output_dir) if output_dir else _DEFAULT_OUTPUT_DIR + output_dir.mkdir(parents=True, exist_ok=True) + + base = f"{sheet.name}_{sheet.version}" + md_path = output_dir / f"{base}.md" + json_path = output_dir / f"{base}.json" + + md_path.write_text(render_markdown(sheet)) + json_path.write_text(json.dumps(sheet.to_dict(), indent=2)) + + logger.info("Wrote datasheet to %s and %s", md_path, json_path) + return {"markdown": md_path, "json": json_path} + + +def generate( + manifest: Union[dict, str, Path], + output_dir: Optional[Union[str, Path]] = None, + **kwargs: Any, +) -> dict[str, Path]: + """End-to-end: load manifest, build the datasheet, render to disk.""" + sheet = build_datasheet(manifest, **kwargs) + return write_datasheet(sheet, output_dir=output_dir) diff --git a/src/climatevision/governance/explainability.py b/src/climatevision/governance/explainability.py new file mode 100644 index 0000000..c71a7e7 --- /dev/null +++ b/src/climatevision/governance/explainability.py @@ -0,0 +1,313 @@ +""" +SHAP-based explainability for ClimateVision segmentation models. + +Provides pixel-level and band-level attribution for U-Net predictions, +helping stakeholders understand WHY the model classified each region. +""" + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import Any, Optional, Union + +import numpy as np +import torch + +logger = logging.getLogger(__name__) + +_PROJECT_ROOT = Path(__file__).resolve().parents[4] +_OUTPUTS_DIR = _PROJECT_ROOT / "outputs" / "explanations" + +BAND_NAMES = { + "deforestation": ["Red", "Green", "Blue", "NIR"], + "ice_melting": ["Red", "Green", "Blue", "NIR"], + "flooding": ["Green", "NIR", "SWIR1"], +} + + +class SHAPExplainer: + """ + SHAP explainer for U-Net segmentation models. + + Uses DeepExplainer for efficient gradient-based SHAP values on CNNs. + Falls back to GradientExplainer if DeepExplainer fails. + """ + + def __init__( + self, + model: torch.nn.Module, + background_data: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + ): + self.model = model + self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.model = self.model.to(self.device) + self.model.eval() + + if background_data is None: + n_channels = getattr(model, "n_channels", 4) + background_data = torch.zeros(1, n_channels, 64, 64) + + self.background = background_data.to(self.device) + self._explainer = None + self._explainer_type = None + + def _init_explainer(self, input_tensor: torch.Tensor) -> None: + """Lazily initialize SHAP explainer on first use.""" + if self._explainer is not None: + return + + try: + import shap + self._explainer = shap.DeepExplainer(self.model, self.background) + self._explainer_type = "deep" + logger.info("Initialized SHAP DeepExplainer") + except Exception as e: + logger.warning("DeepExplainer failed (%s), trying GradientExplainer", e) + try: + import shap + self._explainer = shap.GradientExplainer(self.model, self.background) + self._explainer_type = "gradient" + logger.info("Initialized SHAP GradientExplainer") + except Exception as e2: + logger.warning("GradientExplainer failed (%s), using gradient fallback", e2) + self._explainer_type = "fallback" + + def explain( + self, + input_tensor: torch.Tensor, + target_class: Optional[int] = None, + ) -> dict[str, Any]: + """ + Generate SHAP explanations for input tensor. + + Args: + input_tensor: (N, C, H, W) input tensor + target_class: Class index to explain (default: predicted class) + + Returns: + Dictionary with SHAP values, band contributions, and metadata + """ + self._init_explainer(input_tensor) + input_tensor = input_tensor.to(self.device) + + with torch.no_grad(): + output = self.model(input_tensor) + predictions = torch.argmax(output, dim=1) + probabilities = torch.softmax(output, dim=1) + + if target_class is None: + target_class = int(predictions[0].mode().values.item()) + + if self._explainer_type == "fallback": + shap_values = self._gradient_fallback(input_tensor, target_class) + else: + try: + shap_values = self._explainer.shap_values(input_tensor) + if isinstance(shap_values, list): + shap_values = shap_values[target_class] + shap_values = np.array(shap_values) + except Exception as e: + logger.warning("SHAP computation failed (%s), using gradient fallback", e) + shap_values = self._gradient_fallback(input_tensor, target_class) + + band_contributions = self._compute_band_contributions(shap_values) + spatial_importance = self._compute_spatial_importance(shap_values) + + return { + "shap_values": shap_values, + "band_contributions": band_contributions, + "spatial_importance": spatial_importance, + "target_class": target_class, + "prediction": int(predictions[0].mode().values.item()), + "confidence": float(probabilities[0, target_class].mean().item()), + "explainer_type": self._explainer_type, + } + + def _gradient_fallback( + self, + input_tensor: torch.Tensor, + target_class: int, + ) -> np.ndarray: + """Compute gradient-based attribution as SHAP fallback.""" + input_tensor = input_tensor.clone().requires_grad_(True) + + output = self.model(input_tensor) + target_output = output[:, target_class, :, :].sum() + target_output.backward() + + gradients = input_tensor.grad.detach().cpu().numpy() + attributions = gradients * input_tensor.detach().cpu().numpy() + + return attributions + + def _compute_band_contributions(self, shap_values: np.ndarray) -> dict[str, float]: + """Compute per-band contribution scores.""" + abs_shap = np.abs(shap_values) + band_importance = abs_shap.mean(axis=(0, 2, 3)) + total = band_importance.sum() + 1e-8 + + contributions = {} + for i, importance in enumerate(band_importance): + contributions[f"band_{i}"] = float(importance / total) + + return contributions + + def _compute_spatial_importance(self, shap_values: np.ndarray) -> np.ndarray: + """Compute spatial importance heatmap (H, W).""" + abs_shap = np.abs(shap_values) + spatial = abs_shap.mean(axis=(0, 1)) + spatial = (spatial - spatial.min()) / (spatial.max() - spatial.min() + 1e-8) + return spatial + + +def explain_prediction( + model_path: Union[str, Path], + image_path: Union[str, Path], + analysis_type: str = "deforestation", + target_class: Optional[int] = None, + save_heatmap: bool = True, +) -> dict[str, Any]: + """ + Generate SHAP explanation for a prediction. + + Args: + model_path: Path to model checkpoint + image_path: Path to input image (GeoTIFF or PNG) + analysis_type: Type of analysis (deforestation, ice_melting, flooding) + target_class: Class to explain (default: predicted class) + save_heatmap: Whether to save heatmap to disk + + Returns: + Dictionary with explanation results + """ + from climatevision.inference.pipeline import _load_image_file, _load_model + + model, device = _load_model(analysis_type) + image = _load_image_file(str(image_path)) + + if image.ndim == 3 and image.shape[2] < image.shape[0]: + image = np.transpose(image, (2, 0, 1)) + + n_channels = model.n_channels + c, h, w = image.shape + if c < n_channels: + pad = np.zeros((n_channels - c, h, w), dtype=image.dtype) + image = np.concatenate([image, pad], axis=0) + elif c > n_channels: + image = image[:n_channels] + + tensor = torch.FloatTensor(image.astype(np.float32)).unsqueeze(0) + + explainer = SHAPExplainer(model, device=device) + result = explainer.explain(tensor, target_class=target_class) + + band_names = BAND_NAMES.get(analysis_type, [f"Band_{i}" for i in range(n_channels)]) + top_bands = [] + for i, (band_key, importance) in enumerate( + sorted(result["band_contributions"].items(), key=lambda x: x[1], reverse=True) + ): + band_idx = int(band_key.split("_")[1]) + band_name = band_names[band_idx] if band_idx < len(band_names) else band_key + top_bands.append({"band": band_name, "importance": round(importance, 4)}) + + result["top_bands"] = top_bands + result["analysis_type"] = analysis_type + + if save_heatmap: + heatmap_path = generate_shap_heatmap( + result["spatial_importance"], + image_path, + analysis_type, + ) + result["heatmap_path"] = str(heatmap_path) + + result.pop("shap_values", None) + + return result + + +def generate_shap_heatmap( + spatial_importance: np.ndarray, + source_image_path: Union[str, Path], + analysis_type: str, + output_dir: Optional[Path] = None, +) -> Path: + """ + Generate and save SHAP heatmap visualization. + + Args: + spatial_importance: (H, W) importance scores + source_image_path: Original image path (for naming) + analysis_type: Analysis type + output_dir: Output directory (default: outputs/explanations/) + + Returns: + Path to saved heatmap + """ + output_dir = output_dir or _OUTPUTS_DIR + output_dir.mkdir(parents=True, exist_ok=True) + + source_name = Path(source_image_path).stem + heatmap_path = output_dir / f"{source_name}_{analysis_type}_shap.npy" + + np.save(heatmap_path, spatial_importance) + logger.info("Saved SHAP heatmap to %s", heatmap_path) + + try: + import matplotlib + matplotlib.use("Agg") + import matplotlib.pyplot as plt + + png_path = output_dir / f"{source_name}_{analysis_type}_shap.png" + + fig, ax = plt.subplots(figsize=(10, 10)) + im = ax.imshow(spatial_importance, cmap="hot", interpolation="nearest") + ax.set_title(f"SHAP Importance - {analysis_type.replace('_', ' ').title()}") + ax.axis("off") + plt.colorbar(im, ax=ax, label="Importance") + plt.tight_layout() + plt.savefig(png_path, dpi=150, bbox_inches="tight") + plt.close(fig) + + logger.info("Saved SHAP heatmap PNG to %s", png_path) + return png_path + + except ImportError: + logger.warning("matplotlib not available, saved .npy only") + return heatmap_path + + +def get_band_contributions( + model_path: Union[str, Path], + image_path: Union[str, Path], + analysis_type: str = "deforestation", +) -> dict[str, float]: + """ + Get band-level contribution scores for a prediction. + + Convenience function that returns only band contributions. + + Args: + model_path: Path to model checkpoint + image_path: Path to input image + analysis_type: Type of analysis + + Returns: + Dictionary mapping band names to importance scores + """ + result = explain_prediction( + model_path=model_path, + image_path=image_path, + analysis_type=analysis_type, + save_heatmap=False, + ) + + band_names = BAND_NAMES.get(analysis_type, []) + contributions = {} + + for band_info in result.get("top_bands", []): + contributions[band_info["band"]] = band_info["importance"] + + return contributions diff --git a/src/climatevision/governance/model_card.py b/src/climatevision/governance/model_card.py new file mode 100644 index 0000000..a29d2cc --- /dev/null +++ b/src/climatevision/governance/model_card.py @@ -0,0 +1,222 @@ +""" +Automated model card generator for ClimateVision releases. + +Builds a Google-style "Model Card" (Mitchell et al., 2019) from the +training config and an evaluation metrics blob. Output is rendered as +both Markdown (for the GitHub release notes / model registry) and JSON +(for programmatic consumption by downstream tooling). +""" + +from __future__ import annotations + +import json +import logging +from dataclasses import dataclass, field +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Optional, Union + +logger = logging.getLogger(__name__) + +_PROJECT_ROOT = Path(__file__).resolve().parents[4] +_DEFAULT_OUTPUT_DIR = _PROJECT_ROOT / "outputs" / "model_cards" + +REQUIRED_METRICS = ("iou", "f1", "precision", "recall") + + +@dataclass +class ModelCard: + name: str + version: str + analysis_type: str + description: str + intended_use: str + out_of_scope_uses: list[str] + training_data: dict + evaluation_data: dict + metrics: dict + fairness: dict + limitations: list[str] + ethical_considerations: list[str] + contact: str + generated_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) + + def to_dict(self) -> dict: + return { + "name": self.name, + "version": self.version, + "analysis_type": self.analysis_type, + "description": self.description, + "intended_use": self.intended_use, + "out_of_scope_uses": self.out_of_scope_uses, + "training_data": self.training_data, + "evaluation_data": self.evaluation_data, + "metrics": self.metrics, + "fairness": self.fairness, + "limitations": self.limitations, + "ethical_considerations": self.ethical_considerations, + "contact": self.contact, + "generated_at": self.generated_at, + } + + +_DEFAULT_INTENDED_USE = ( + "Detection of {analysis_type} in satellite imagery for use by " + "conservation organisations, NGOs, and government agencies. The " + "model produces per-pixel probability scores intended to be reviewed " + "alongside ground-truth reference data and analyst judgement." +) + +_DEFAULT_OUT_OF_SCOPE = [ + "Real-time legal enforcement decisions without analyst review.", + "Carbon credit issuance without independent ground-truth validation.", + "Use on imagery from sensors not represented in the training set.", +] + +_DEFAULT_LIMITATIONS = [ + "Performance degrades on cloud cover above the masking threshold used in preprocessing.", + "Geographic coverage limited to regions present in the training set.", + "Temporal generalisation to seasons or years outside the training window is unverified.", +] + +_DEFAULT_ETHICS = [ + "Model outputs may carry geographic bias; downstream users must run " + "the bias audit pipeline before distributing results across regions.", + "Predictions should never be the sole basis for actions affecting " + "indigenous land rights or local communities.", +] + + +def _coerce_config(config: Union[dict, str, Path]) -> dict: + if isinstance(config, dict): + return config + path = Path(config) + text = path.read_text() + if path.suffix in {".yml", ".yaml"}: + try: + import yaml + except ImportError as exc: # pragma: no cover - import guard + raise RuntimeError("PyYAML is required to load YAML configs") from exc + return yaml.safe_load(text) + return json.loads(text) + + +def _validate_metrics(metrics: dict) -> None: + missing = [m for m in REQUIRED_METRICS if m not in metrics] + if missing: + raise ValueError(f"metrics missing required keys: {missing}") + + +def build_model_card( + config: Union[dict, str, Path], + metrics: dict, + fairness_report: Optional[dict] = None, + *, + name: Optional[str] = None, + version: Optional[str] = None, + description: Optional[str] = None, + contact: str = "ClimateVision Governance ", +) -> ModelCard: + cfg = _coerce_config(config) + _validate_metrics(metrics) + + analysis_type = cfg.get("analysis_type") or cfg.get("analysis", {}).get("type", "deforestation") + resolved_name = name or cfg.get("model", {}).get("name") or f"climatevision-{analysis_type}" + resolved_version = version or cfg.get("model", {}).get("version") or "0.0.0" + + return ModelCard( + name=resolved_name, + version=resolved_version, + analysis_type=analysis_type, + description=description or f"U-Net segmentation model for {analysis_type}.", + intended_use=_DEFAULT_INTENDED_USE.format(analysis_type=analysis_type), + out_of_scope_uses=list(_DEFAULT_OUT_OF_SCOPE), + training_data=cfg.get("training_data", cfg.get("data", {})), + evaluation_data=cfg.get("evaluation_data", {}), + metrics=dict(metrics), + fairness=fairness_report or {}, + limitations=list(_DEFAULT_LIMITATIONS), + ethical_considerations=list(_DEFAULT_ETHICS), + contact=contact, + ) + + +def render_markdown(card: ModelCard) -> str: + metrics_rows = "\n".join( + f"| {k} | {v} |" for k, v in sorted(card.metrics.items()) + ) + fairness_block = ( + "\n".join(f"- **{k}**: {v}" for k, v in card.fairness.items()) + or "_No fairness report attached._" + ) + + sections = [ + f"# Model Card: {card.name} ({card.version})", + f"_Generated {card.generated_at}_", + "", + "## Description", + card.description, + "", + "## Intended Use", + card.intended_use, + "", + "### Out-of-Scope Uses", + "\n".join(f"- {u}" for u in card.out_of_scope_uses), + "", + "## Training Data", + f"```json\n{json.dumps(card.training_data, indent=2)}\n```", + "", + "## Evaluation", + "| Metric | Value |", + "| --- | --- |", + metrics_rows, + "", + "## Fairness", + fairness_block, + "", + "## Limitations", + "\n".join(f"- {l}" for l in card.limitations), + "", + "## Ethical Considerations", + "\n".join(f"- {e}" for e in card.ethical_considerations), + "", + "## Contact", + card.contact, + ] + return "\n".join(sections) + "\n" + + +def write_model_card( + card: ModelCard, + output_dir: Optional[Union[str, Path]] = None, +) -> dict[str, Path]: + output_dir = Path(output_dir) if output_dir else _DEFAULT_OUTPUT_DIR + output_dir.mkdir(parents=True, exist_ok=True) + + base = f"{card.name}_{card.version}" + md_path = output_dir / f"{base}.md" + json_path = output_dir / f"{base}.json" + + md_path.write_text(render_markdown(card)) + json_path.write_text(json.dumps(card.to_dict(), indent=2)) + + logger.info("Wrote model card to %s and %s", md_path, json_path) + return {"markdown": md_path, "json": json_path} + + +def generate( + config: Union[dict, str, Path], + metrics: Union[dict, str, Path], + fairness_report: Optional[Union[dict, str, Path]] = None, + output_dir: Optional[Union[str, Path]] = None, + **kwargs: Any, +) -> dict[str, Path]: + """End-to-end: load inputs, build the card, render to disk.""" + metrics_dict = _coerce_config(metrics) if not isinstance(metrics, dict) else metrics + fairness_dict = ( + _coerce_config(fairness_report) + if fairness_report is not None and not isinstance(fairness_report, dict) + else fairness_report + ) + card = build_model_card(config, metrics_dict, fairness_dict, **kwargs) + return write_model_card(card, output_dir=output_dir) diff --git a/src/climatevision/inference/__init__.py b/src/climatevision/inference/__init__.py index ba0dbda..9e76ad0 100644 --- a/src/climatevision/inference/__init__.py +++ b/src/climatevision/inference/__init__.py @@ -7,9 +7,17 @@ run_inference_from_file, run_inference_from_gee, ) +from .batch_processor import ( + BatchJob, + BatchProcessor, + BatchSummary, +) __all__ = [ "run_inference", "run_inference_from_file", "run_inference_from_gee", + "BatchJob", + "BatchProcessor", + "BatchSummary", ] diff --git a/src/climatevision/inference/alert_generator.py b/src/climatevision/inference/alert_generator.py new file mode 100644 index 0000000..96fee38 --- /dev/null +++ b/src/climatevision/inference/alert_generator.py @@ -0,0 +1,229 @@ +""" +Deforestation alert generator for ClimateVision. + +Watches for inference results that exceed a configurable threshold for +a given subscription (region, analysis_type, alert_threshold) and emits +notifications via pluggable channels (email, webhook, log). + +Routing rules: + +- Each `Subscription` defines a region (bbox), analysis type, threshold, + and a list of channels to deliver to. +- A new prediction is matched against subscriptions by analysis type + and whether its bbox overlaps the subscription bbox. +- Alerts are de-duplicated within a configurable cooldown window so a + flapping signal does not page everyone every minute. + +The generator does not perform delivery itself for non-loggable channels; +it returns delivery records that the caller (typically the alert worker +or `notification_router.deliver_pending`) is responsible for sending. +""" + +from __future__ import annotations + +import json +import logging +import threading +import uuid +from dataclasses import asdict, dataclass, field +from datetime import datetime, timedelta, timezone +from pathlib import Path +from typing import Callable, Iterable, Optional, Union + +logger = logging.getLogger(__name__) + +_PROJECT_ROOT = Path(__file__).resolve().parents[3] +_DEFAULT_ALERT_LOG = _PROJECT_ROOT / "outputs" / "alerts" / "alerts.jsonl" + +DeliveryFn = Callable[["Alert"], None] + + +def _utcnow() -> datetime: + return datetime.now(timezone.utc) + + +@dataclass(frozen=True) +class Subscription: + org_id: int + bbox: tuple[float, float, float, float] + analysis_type: str + alert_threshold: float + channels: tuple[str, ...] = ("log",) + cooldown_minutes: int = 60 + + +@dataclass +class Alert: + alert_id: str + org_id: int + analysis_type: str + region_bbox: tuple[float, float, float, float] + severity: str + measured_value: float + threshold: float + summary: str + triggered_at: str + channels: tuple[str, ...] + + +def _bbox_overlaps( + a: tuple[float, float, float, float], + b: tuple[float, float, float, float], +) -> bool: + a_min_x, a_min_y, a_max_x, a_max_y = a + b_min_x, b_min_y, b_max_x, b_max_y = b + return not ( + a_max_x < b_min_x + or b_max_x < a_min_x + or a_max_y < b_min_y + or b_max_y < a_min_y + ) + + +def _classify_severity(measured: float, threshold: float) -> str: + if measured >= threshold * 3: + return "critical" + if measured >= threshold * 2: + return "high" + return "medium" + + +class AlertGenerator: + """ + Subscription-driven alert generator with cooldown deduplication. + """ + + def __init__( + self, + subscriptions: Optional[Iterable[Subscription]] = None, + alert_log_path: Optional[Union[str, Path]] = None, + delivery: Optional[dict[str, DeliveryFn]] = None, + clock: Callable[[], datetime] = _utcnow, + ) -> None: + self._subscriptions: list[Subscription] = list(subscriptions or []) + self.alert_log_path = Path(alert_log_path) if alert_log_path else _DEFAULT_ALERT_LOG + self._delivery = dict(delivery or {}) + self._lock = threading.Lock() + self._last_fired: dict[tuple[int, str], datetime] = {} + self._clock = clock + + def add_subscription(self, sub: Subscription) -> None: + self._subscriptions.append(sub) + + def register_channel(self, name: str, fn: DeliveryFn) -> None: + self._delivery[name] = fn + + def _persist(self, alert: Alert) -> None: + self.alert_log_path.parent.mkdir(parents=True, exist_ok=True) + with self._lock, self.alert_log_path.open("a") as fh: + fh.write(json.dumps(asdict(alert)) + "\n") + + def _in_cooldown(self, sub: Subscription, now: datetime) -> bool: + key = (sub.org_id, sub.analysis_type) + last = self._last_fired.get(key) + if last is None: + return False + return now - last < timedelta(minutes=sub.cooldown_minutes) + + def _matches( + self, + sub: Subscription, + analysis_type: str, + bbox: tuple[float, float, float, float], + measured_value: float, + ) -> bool: + if sub.analysis_type != analysis_type: + return False + if not _bbox_overlaps(sub.bbox, bbox): + return False + return measured_value >= sub.alert_threshold + + def evaluate( + self, + analysis_type: str, + bbox: tuple[float, float, float, float], + measured_value: float, + summary: str = "", + ) -> list[Alert]: + now = self._clock() + alerts: list[Alert] = [] + + for sub in self._subscriptions: + if not self._matches(sub, analysis_type, bbox, measured_value): + continue + if self._in_cooldown(sub, now): + logger.debug( + "Skipping alert for org=%s in cooldown", sub.org_id + ) + continue + + alert = Alert( + alert_id=str(uuid.uuid4()), + org_id=sub.org_id, + analysis_type=analysis_type, + region_bbox=bbox, + severity=_classify_severity(measured_value, sub.alert_threshold), + measured_value=float(measured_value), + threshold=float(sub.alert_threshold), + summary=summary or ( + f"{analysis_type} signal {measured_value:.3f} " + f"exceeded threshold {sub.alert_threshold:.3f}" + ), + triggered_at=now.isoformat(), + channels=tuple(sub.channels), + ) + self._last_fired[(sub.org_id, sub.analysis_type)] = now + self._persist(alert) + self._dispatch(alert) + alerts.append(alert) + + if alerts: + logger.info("Fired %d alert(s) for analysis=%s", len(alerts), analysis_type) + return alerts + + def _dispatch(self, alert: Alert) -> None: + for channel in alert.channels: + fn = self._delivery.get(channel) + if fn is None: + logger.warning( + "No delivery handler registered for channel '%s' (alert=%s)", + channel, + alert.alert_id, + ) + continue + try: + fn(alert) + except Exception: # noqa: BLE001 + logger.exception( + "Delivery on channel '%s' failed for alert=%s", + channel, + alert.alert_id, + ) + + def iter_alerts(self) -> list[Alert]: + if not self.alert_log_path.exists(): + return [] + out: list[Alert] = [] + with self.alert_log_path.open() as fh: + for line in fh: + if not line.strip(): + continue + row = json.loads(line) + row["region_bbox"] = tuple(row["region_bbox"]) + row["channels"] = tuple(row["channels"]) + out.append(Alert(**row)) + return out + + +def log_channel(alert: Alert) -> None: + """Default 'log' channel — writes the alert summary at WARNING level.""" + logger.warning( + "ALERT [%s] org=%s analysis=%s severity=%s value=%.3f >= %.3f :: %s", + alert.alert_id, + alert.org_id, + alert.analysis_type, + alert.severity, + alert.measured_value, + alert.threshold, + alert.summary, + ) diff --git a/src/climatevision/inference/batch_processor.py b/src/climatevision/inference/batch_processor.py new file mode 100644 index 0000000..dac85d3 --- /dev/null +++ b/src/climatevision/inference/batch_processor.py @@ -0,0 +1,209 @@ +""" +Batch processor for ClimateVision inference jobs. + +Submits a list of image paths (or numpy arrays) to the inference +pipeline in parallel, tracks per-job state, and produces a structured +result manifest. The processor is designed to be driven from either +a CLI script or the FastAPI background-task layer. + +Job state machine: + + queued -> running -> (succeeded | failed) + +Each job is appended to a JSONL manifest as soon as its terminal +state is reached so a long-running batch can be resumed or audited +without waiting for the whole queue to finish. +""" + +from __future__ import annotations + +import json +import logging +import threading +import time +import uuid +from concurrent.futures import ThreadPoolExecutor, as_completed +from dataclasses import asdict, dataclass, field +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Callable, Iterable, Optional, Union + +logger = logging.getLogger(__name__) + +_PROJECT_ROOT = Path(__file__).resolve().parents[3] +_DEFAULT_MANIFEST = _PROJECT_ROOT / "outputs" / "batches" / "manifest.jsonl" + +JobInput = Union[str, Path, dict] +InferenceFn = Callable[[JobInput, str], dict] + + +def _utcnow() -> str: + return datetime.now(timezone.utc).isoformat() + + +@dataclass +class BatchJob: + job_id: str + source: str + analysis_type: str + status: str = "queued" + submitted_at: str = field(default_factory=_utcnow) + started_at: Optional[str] = None + finished_at: Optional[str] = None + duration_ms: Optional[int] = None + result_summary: Optional[dict] = None + error: Optional[str] = None + attempts: int = 0 + + +@dataclass +class BatchSummary: + total: int + succeeded: int + failed: int + duration_seconds: float + + def to_dict(self) -> dict: + return asdict(self) + + +def _default_inference_fn(source: JobInput, analysis_type: str) -> dict: + """ + Default inference adapter — calls run_inference_from_file or run_inference + depending on the input shape. Imported lazily so unit tests can stub it. + """ + from climatevision.inference.pipeline import ( + run_inference, + run_inference_from_file, + ) + + if isinstance(source, (str, Path)): + return run_inference_from_file(str(source), analysis_type=analysis_type) + if isinstance(source, dict): + return run_inference(**source, analysis_type=analysis_type) + raise TypeError(f"Unsupported source type: {type(source).__name__}") + + +class BatchProcessor: + """ + Parallel batch executor for inference jobs. + + Args: + max_workers: Thread pool size. Defaults to 4. + max_attempts: Retry count for transient failures. + manifest_path: Where to append per-job records. Created on first write. + inference_fn: Override the actual inference call (handy for tests + and for swapping in batch_predict implementations later). + """ + + def __init__( + self, + max_workers: int = 4, + max_attempts: int = 1, + manifest_path: Optional[Union[str, Path]] = None, + inference_fn: Optional[InferenceFn] = None, + ) -> None: + self.max_workers = max_workers + self.max_attempts = max(1, max_attempts) + self.manifest_path = Path(manifest_path) if manifest_path else _DEFAULT_MANIFEST + self._inference_fn = inference_fn or _default_inference_fn + self._jobs: dict[str, BatchJob] = {} + self._lock = threading.Lock() + + def _persist(self, job: BatchJob) -> None: + self.manifest_path.parent.mkdir(parents=True, exist_ok=True) + with self._lock, self.manifest_path.open("a") as fh: + fh.write(json.dumps(asdict(job)) + "\n") + + def _summarize_result(self, result: Any) -> dict: + if isinstance(result, dict): + keep = {} + for key in ("hectares", "carbon_tonnes", "iou", "f1", "mean_confidence"): + if key in result: + keep[key] = result[key] + if "mask" in result: + import numpy as np + + arr = np.asarray(result["mask"]) + keep["positive_pixels"] = int(arr.sum()) + keep["total_pixels"] = int(arr.size) + return keep + return {"raw": str(result)[:200]} + + def _run_one(self, job: BatchJob, source: JobInput) -> BatchJob: + for attempt in range(1, self.max_attempts + 1): + job.attempts = attempt + job.status = "running" + job.started_at = _utcnow() + t0 = time.perf_counter() + try: + result = self._inference_fn(source, job.analysis_type) + job.result_summary = self._summarize_result(result) + job.status = "succeeded" + job.error = None + break + except Exception as exc: # noqa: BLE001 - we want to capture all + logger.exception("Job %s attempt %d failed", job.job_id, attempt) + job.error = f"{type(exc).__name__}: {exc}" + job.status = "failed" + finally: + job.duration_ms = int((time.perf_counter() - t0) * 1000) + job.finished_at = _utcnow() + self._persist(job) + return job + + def submit_batch( + self, + sources: Iterable[JobInput], + analysis_type: str = "deforestation", + ) -> list[BatchJob]: + sources = list(sources) + jobs = [ + BatchJob( + job_id=str(uuid.uuid4()), + source=str(s) if isinstance(s, (str, Path)) else json.dumps(s, default=str), + analysis_type=analysis_type, + ) + for s in sources + ] + for j in jobs: + self._jobs[j.job_id] = j + return jobs + + def run( + self, + sources: Iterable[JobInput], + analysis_type: str = "deforestation", + ) -> tuple[list[BatchJob], BatchSummary]: + sources = list(sources) + jobs = self.submit_batch(sources, analysis_type=analysis_type) + t0 = time.perf_counter() + + with ThreadPoolExecutor(max_workers=self.max_workers) as pool: + futures = { + pool.submit(self._run_one, job, source): job + for job, source in zip(jobs, sources) + } + for fut in as_completed(futures): + fut.result() + + duration = time.perf_counter() - t0 + succeeded = sum(1 for j in jobs if j.status == "succeeded") + failed = sum(1 for j in jobs if j.status == "failed") + summary = BatchSummary( + total=len(jobs), + succeeded=succeeded, + failed=failed, + duration_seconds=round(duration, 3), + ) + logger.info( + "Batch finished: total=%d succeeded=%d failed=%d in %.2fs", + summary.total, + summary.succeeded, + summary.failed, + duration, + ) + return jobs, summary + + def get_job(self, job_id: str) -> Optional[BatchJob]: + return self._jobs.get(job_id) diff --git a/src/climatevision/models/__init__.py b/src/climatevision/models/__init__.py index 93587d4..a206978 100644 --- a/src/climatevision/models/__init__.py +++ b/src/climatevision/models/__init__.py @@ -4,9 +4,25 @@ from .unet import UNet, AttentionUNet from .siamese import SiameseNetwork +from .regression import ( + BiomassRegressor, + RegressionMetrics, + biomass_to_carbon, + biomass_to_co2e, + estimate_biomass_from_indices, + evaluate_regression, + serialize_metrics, +) __all__ = [ 'UNet', 'AttentionUNet', 'SiameseNetwork', + 'BiomassRegressor', + 'RegressionMetrics', + 'biomass_to_carbon', + 'biomass_to_co2e', + 'estimate_biomass_from_indices', + 'evaluate_regression', + 'serialize_metrics', ] diff --git a/src/climatevision/models/regression.py b/src/climatevision/models/regression.py new file mode 100644 index 0000000..681d135 --- /dev/null +++ b/src/climatevision/models/regression.py @@ -0,0 +1,281 @@ +""" +Biomass and carbon-stock regression models for ClimateVision. + +Where the U-Net produces per-pixel deforestation masks, this module +turns those masks (plus the underlying spectral indices) into a scalar +estimate of above-ground biomass (Mg/ha) and the carbon equivalent +(tCO2e). It supports two regressors out of the box: + +- ``"random_forest"`` — sklearn RandomForestRegressor (default). +- ``"xgboost"`` — XGBRegressor when xgboost is installed. + +The conversion from biomass to CO2e uses the IPCC default carbon +fraction of 0.47 and the molecular-weight ratio 44/12. Both constants +are exposed so users can override them per ecosystem. + +Typical usage:: + + from climatevision.models.regression import ( + BiomassRegressor, biomass_to_carbon, biomass_to_co2e, + ) + + reg = BiomassRegressor(model_type="random_forest").fit(X_train, y_train) + biomass = reg.predict(X_test) # Mg/ha + co2e = biomass_to_co2e(biomass) # tCO2e/ha +""" + +from __future__ import annotations + +import json +import logging +import pickle +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Optional, Sequence, Union + +import numpy as np + +logger = logging.getLogger(__name__) + +CARBON_FRACTION = 0.47 # IPCC default for tropical forests +CO2_TO_C_RATIO = 44.0 / 12.0 # molecular weight ratio +DEFAULT_FEATURE_NAMES = ( + "ndvi", + "evi", + "savi", + "ndmi", + "nbr", + "red", + "green", + "blue", + "nir", + "swir1", +) + + +def biomass_to_carbon(biomass: Union[float, np.ndarray]) -> Union[float, np.ndarray]: + """Convert above-ground biomass (Mg/ha) to carbon (tC/ha).""" + return np.asarray(biomass) * CARBON_FRACTION + + +def biomass_to_co2e(biomass: Union[float, np.ndarray]) -> Union[float, np.ndarray]: + """Convert above-ground biomass (Mg/ha) to CO2 equivalent (tCO2e/ha).""" + return biomass_to_carbon(biomass) * CO2_TO_C_RATIO + + +@dataclass +class RegressionMetrics: + rmse: float + mae: float + r2: float + mape: float + + def to_dict(self) -> dict[str, float]: + return {"rmse": self.rmse, "mae": self.mae, "r2": self.r2, "mape": self.mape} + + +def _safe_mape(y_true: np.ndarray, y_pred: np.ndarray) -> float: + mask = y_true != 0 + if not mask.any(): + return float("nan") + return float(np.mean(np.abs((y_true[mask] - y_pred[mask]) / y_true[mask]))) + + +def evaluate_regression(y_true: np.ndarray, y_pred: np.ndarray) -> RegressionMetrics: + """Compute RMSE / MAE / R² / MAPE for a regression run.""" + y_true = np.asarray(y_true, dtype=np.float64) + y_pred = np.asarray(y_pred, dtype=np.float64) + if y_true.shape != y_pred.shape: + raise ValueError( + f"shape mismatch: y_true={y_true.shape} y_pred={y_pred.shape}" + ) + if y_true.size == 0: + raise ValueError("cannot evaluate empty arrays") + + err = y_pred - y_true + rmse = float(np.sqrt(np.mean(err ** 2))) + mae = float(np.mean(np.abs(err))) + + ss_res = float(np.sum(err ** 2)) + ss_tot = float(np.sum((y_true - y_true.mean()) ** 2)) + r2 = float("nan") if ss_tot == 0 else 1.0 - ss_res / ss_tot + + return RegressionMetrics(rmse=rmse, mae=mae, r2=r2, mape=_safe_mape(y_true, y_pred)) + + +class BiomassRegressor: + """ + Wrapper around sklearn / xgboost regressors with a stable API for + ClimateVision pipelines. + """ + + SUPPORTED_MODELS = ("random_forest", "xgboost") + + def __init__( + self, + model_type: str = "random_forest", + *, + feature_names: Optional[Sequence[str]] = None, + model_kwargs: Optional[dict[str, Any]] = None, + random_state: int = 42, + ) -> None: + if model_type not in self.SUPPORTED_MODELS: + raise ValueError( + f"model_type must be one of {self.SUPPORTED_MODELS}, got {model_type!r}" + ) + self.model_type = model_type + self.feature_names = tuple(feature_names) if feature_names else DEFAULT_FEATURE_NAMES + self.model_kwargs = dict(model_kwargs or {}) + self.random_state = random_state + self._model: Any = None + self._fitted = False + + def _build(self) -> Any: + if self.model_type == "random_forest": + from sklearn.ensemble import RandomForestRegressor + + kwargs = { + "n_estimators": 200, + "max_depth": None, + "min_samples_leaf": 2, + "random_state": self.random_state, + "n_jobs": -1, + } + kwargs.update(self.model_kwargs) + return RandomForestRegressor(**kwargs) + + try: + from xgboost import XGBRegressor + except ImportError as exc: # pragma: no cover - import guard + raise RuntimeError( + "xgboost is required for model_type='xgboost'. " + "Install with `pip install xgboost`." + ) from exc + + kwargs = { + "n_estimators": 400, + "max_depth": 6, + "learning_rate": 0.05, + "subsample": 0.9, + "objective": "reg:squarederror", + "random_state": self.random_state, + "n_jobs": -1, + } + kwargs.update(self.model_kwargs) + return XGBRegressor(**kwargs) + + def fit( + self, + X: np.ndarray, + y: np.ndarray, + *, + sample_weight: Optional[np.ndarray] = None, + ) -> "BiomassRegressor": + X = np.asarray(X, dtype=np.float64) + y = np.asarray(y, dtype=np.float64) + if X.ndim != 2: + raise ValueError(f"X must be 2-D, got shape {X.shape}") + if X.shape[0] != y.shape[0]: + raise ValueError( + f"row mismatch: X has {X.shape[0]} rows, y has {y.shape[0]}" + ) + + self._model = self._build() + if sample_weight is not None: + self._model.fit(X, y, sample_weight=sample_weight) + else: + self._model.fit(X, y) + self._fitted = True + logger.info( + "Trained %s on %d samples with %d features", + self.model_type, + X.shape[0], + X.shape[1], + ) + return self + + def predict(self, X: np.ndarray) -> np.ndarray: + if not self._fitted: + raise RuntimeError("regressor must be fit() before predict()") + X = np.asarray(X, dtype=np.float64) + return np.asarray(self._model.predict(X), dtype=np.float64) + + def feature_importances(self) -> dict[str, float]: + if not self._fitted: + raise RuntimeError("regressor must be fit() before feature_importances()") + importances = getattr(self._model, "feature_importances_", None) + if importances is None: + raise AttributeError( + f"underlying {self.model_type} has no feature_importances_" + ) + names = self.feature_names + if len(names) != len(importances): + names = tuple(f"f{i}" for i in range(len(importances))) + return {name: float(value) for name, value in zip(names, importances)} + + def evaluate(self, X: np.ndarray, y_true: np.ndarray) -> RegressionMetrics: + return evaluate_regression(y_true, self.predict(X)) + + def save(self, path: Union[str, Path]) -> Path: + if not self._fitted: + raise RuntimeError("regressor must be fit() before save()") + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("wb") as fh: + pickle.dump( + { + "model": self._model, + "model_type": self.model_type, + "feature_names": list(self.feature_names), + "random_state": self.random_state, + }, + fh, + ) + logger.info("Saved %s regressor to %s", self.model_type, path) + return path + + @classmethod + def load(cls, path: Union[str, Path]) -> "BiomassRegressor": + path = Path(path) + with path.open("rb") as fh: + payload = pickle.load(fh) + instance = cls( + model_type=payload["model_type"], + feature_names=payload["feature_names"], + random_state=payload["random_state"], + ) + instance._model = payload["model"] + instance._fitted = True + return instance + + +def estimate_biomass_from_indices( + indices: dict[str, np.ndarray], + regressor: BiomassRegressor, + feature_order: Optional[Sequence[str]] = None, +) -> np.ndarray: + """ + Build a feature matrix from a dict of spectral-index arrays and run + inference. The dict is expected to map index name -> 1-D array of + pixel values (one row per pixel). + """ + feature_order = tuple(feature_order or regressor.feature_names) + missing = [k for k in feature_order if k not in indices] + if missing: + raise KeyError(f"missing spectral indices: {missing}") + + columns = [np.asarray(indices[k]).reshape(-1) for k in feature_order] + if len({c.size for c in columns}) != 1: + raise ValueError("all input indices must have the same length") + X = np.stack(columns, axis=1) + return regressor.predict(X) + + +def serialize_metrics( + metrics: RegressionMetrics, output_path: Union[str, Path] +) -> Path: + """Persist regression metrics as JSON for the eval / model-card pipeline.""" + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + output_path.write_text(json.dumps(metrics.to_dict(), indent=2)) + return output_path diff --git a/src/climatevision/reports/__init__.py b/src/climatevision/reports/__init__.py new file mode 100644 index 0000000..954493a --- /dev/null +++ b/src/climatevision/reports/__init__.py @@ -0,0 +1,23 @@ +""" +ClimateVision Reports Module + +LLM-backed natural-language reporting on top of model predictions and +analytics outputs. Stakeholder reports combine carbon analytics, SHAP +explanations, and fairness metadata into a single readable narrative. +""" + +from .llm_reporter import ( + ImpactReport, + LLMReporter, + ReportContext, + generate_impact_report, + render_template, +) + +__all__ = [ + "ImpactReport", + "LLMReporter", + "ReportContext", + "generate_impact_report", + "render_template", +] diff --git a/src/climatevision/reports/llm_reporter.py b/src/climatevision/reports/llm_reporter.py new file mode 100644 index 0000000..d2c78e3 --- /dev/null +++ b/src/climatevision/reports/llm_reporter.py @@ -0,0 +1,248 @@ +""" +LLM-backed impact report generation for ClimateVision. + +`LLMReporter` turns a structured prediction record (carbon analytics, +SHAP attributions, validation metrics, fairness flags) into a +narrative report ready for NGOs and government stakeholders. + +A deterministic template-based renderer is always available so that +the module never blocks the pipeline when an LLM provider is +unreachable. When a provider is configured, the template output is +used as the prompt skeleton and the LLM smooths it into prose. +""" + +from __future__ import annotations + +import json +import logging +import os +from dataclasses import asdict, dataclass, field +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Callable, Optional, Union + +logger = logging.getLogger(__name__) + +_PROJECT_ROOT = Path(__file__).resolve().parents[4] +_DEFAULT_OUTPUT_DIR = _PROJECT_ROOT / "outputs" / "reports" + + +@dataclass +class ReportContext: + """Inputs the reporter draws on to compose an impact report.""" + + region: str + period: str + analysis_type: str + carbon: dict = field(default_factory=dict) + validation: dict = field(default_factory=dict) + shap: dict = field(default_factory=dict) + fairness: dict = field(default_factory=dict) + run_id: Optional[Union[int, str]] = None + + def headline_metric(self) -> str: + hectares = self.carbon.get("hectares") + carbon_t = self.carbon.get("carbon_tonnes") + if hectares is not None and carbon_t is not None: + return ( + f"{hectares:,.1f} hectares of {self.analysis_type.replace('_', ' ')} " + f"detected, equivalent to {carbon_t:,.1f} tCO2e." + ) + if hectares is not None: + return f"{hectares:,.1f} hectares of {self.analysis_type.replace('_', ' ')} detected." + return f"Analysis run for {self.analysis_type} in {self.region} ({self.period})." + + +@dataclass +class ImpactReport: + summary: str + body: str + context: ReportContext + provider: str + generated_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) + + def to_dict(self) -> dict: + d = asdict(self) + d["context"] = asdict(self.context) + return d + + +# Type alias for an LLM call: prompt -> completion +LLMCallable = Callable[[str], str] + + +def render_template(context: ReportContext, *, include_shap: bool = True) -> str: + """Deterministic Markdown template — used both as a fallback and as an LLM prompt seed.""" + + lines = [ + f"# Impact Report — {context.region.title()} ({context.period})", + "", + f"**Headline:** {context.headline_metric()}", + "", + "## Carbon Analytics", + ] + + if context.carbon: + for k, v in context.carbon.items(): + lines.append(f"- **{k.replace('_', ' ').title()}**: {v}") + else: + lines.append("- _Carbon analytics not provided._") + + lines += ["", "## Validation"] + if context.validation: + for k, v in context.validation.items(): + lines.append(f"- **{k.upper()}**: {v}") + else: + lines.append("- _No validation metrics attached._") + + if include_shap: + lines += ["", "## Explainability"] + if context.shap: + top_bands = context.shap.get("top_bands", []) + if top_bands: + bands = ", ".join(b["band"] if isinstance(b, dict) else str(b) for b in top_bands) + lines.append(f"- Most influential bands: {bands}") + for k, v in context.shap.items(): + if k == "top_bands": + continue + lines.append(f"- **{k.replace('_', ' ').title()}**: {v}") + else: + lines.append("- _No SHAP explanation attached._") + + if context.fairness: + lines += ["", "## Fairness"] + for k, v in context.fairness.items(): + lines.append(f"- **{k.replace('_', ' ').title()}**: {v}") + + return "\n".join(lines) + "\n" + + +def _build_prompt(context: ReportContext, template: str) -> str: + return ( + "You are drafting a concise, factual environmental-impact report for " + "conservation organisations and government stakeholders.\n\n" + "Rules:\n" + "- Do not invent numbers; only restate values from the data block below.\n" + "- Keep tone neutral and policy-relevant; no promotional language.\n" + "- Output Markdown with the same section structure as the seed.\n" + "- Open with a 2–3 sentence executive summary.\n\n" + f"DATA (JSON):\n```json\n{json.dumps(asdict(context), indent=2, default=str)}\n```\n\n" + f"SEED:\n{template}\n\n" + "FINAL REPORT:\n" + ) + + +class LLMReporter: + """ + Reporter with pluggable LLM backend. + + Pass an `llm` callable (prompt -> string) to use a custom provider. + Without one, set CLIMATEVISION_LLM_PROVIDER=anthropic and + ANTHROPIC_API_KEY to use Anthropic's API; otherwise the template + renderer alone is used. + """ + + def __init__(self, llm: Optional[LLMCallable] = None) -> None: + self._llm = llm + + def _call_llm(self, prompt: str) -> Optional[str]: + if self._llm is not None: + try: + return self._llm(prompt) + except Exception as exc: # pragma: no cover - external call + logger.exception("user-provided LLM callable raised: %s", exc) + return None + + provider = os.environ.get("CLIMATEVISION_LLM_PROVIDER", "").lower() + if provider == "anthropic": + return self._call_anthropic(prompt) + return None + + def _call_anthropic(self, prompt: str) -> Optional[str]: # pragma: no cover - external call + try: + import anthropic + except ImportError: + logger.warning("anthropic package not installed; using template only") + return None + + api_key = os.environ.get("ANTHROPIC_API_KEY") + if not api_key: + logger.warning("ANTHROPIC_API_KEY not set; using template only") + return None + + client = anthropic.Anthropic(api_key=api_key) + message = client.messages.create( + model=os.environ.get("CLIMATEVISION_LLM_MODEL", "claude-haiku-4-5-20251001"), + max_tokens=1024, + messages=[{"role": "user", "content": prompt}], + ) + parts = [b.text for b in message.content if getattr(b, "type", None) == "text"] + return "".join(parts) if parts else None + + def generate( + self, + context: ReportContext, + *, + include_shap: bool = True, + ) -> ImpactReport: + template = render_template(context, include_shap=include_shap) + prompt = _build_prompt(context, template) + llm_text = self._call_llm(prompt) + + if llm_text: + body = llm_text.strip() + provider = "llm" + else: + body = template + provider = "template" + + first_para = body.strip().split("\n\n", 1)[0] + summary = first_para.replace("\n", " ").strip() + + return ImpactReport( + summary=summary, + body=body, + context=context, + provider=provider, + ) + + +def generate_impact_report( + region: str, + period: str, + analysis_type: str = "deforestation", + carbon: Optional[dict] = None, + validation: Optional[dict] = None, + shap: Optional[dict] = None, + fairness: Optional[dict] = None, + run_id: Optional[Union[int, str]] = None, + *, + llm: Optional[LLMCallable] = None, + include_shap: bool = True, + output_dir: Optional[Union[str, Path]] = None, +) -> ImpactReport: + """High-level entry point used by the API and CLI.""" + ctx = ReportContext( + region=region, + period=period, + analysis_type=analysis_type, + carbon=carbon or {}, + validation=validation or {}, + shap=shap or {}, + fairness=fairness or {}, + run_id=run_id, + ) + report = LLMReporter(llm=llm).generate(ctx, include_shap=include_shap) + + if output_dir is not None: + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + base = f"{region}_{period}_impact" + (output_dir / f"{base}.md").write_text(report.body) + (output_dir / f"{base}.json").write_text(json.dumps(report.to_dict(), indent=2, default=str)) + + return report + + +def _default_output_dir() -> Path: + return _DEFAULT_OUTPUT_DIR diff --git a/src/climatevision/security/__init__.py b/src/climatevision/security/__init__.py new file mode 100644 index 0000000..f703fba --- /dev/null +++ b/src/climatevision/security/__init__.py @@ -0,0 +1,42 @@ +""" +ClimateVision Security Module + +Provides API security and inference pipeline protection: +- Input validation and sanitization +- Rate limiting per API key +- File upload validation +- Adversarial input detection +- Security scanning and reporting +""" + +from .api_security import ( + SecurityConfig, + validate_payload_size, + validate_bbox, + validate_file_upload, + sanitize_string_input, + RateLimiter, + SecurityMiddleware, +) +from .pipeline_guard import ( + PipelineGuard, + detect_adversarial_input, + validate_model_output, + InputAnomalyDetector, +) + +__all__ = [ + # API Security + "SecurityConfig", + "validate_payload_size", + "validate_bbox", + "validate_file_upload", + "sanitize_string_input", + "RateLimiter", + "SecurityMiddleware", + # Pipeline Guard + "PipelineGuard", + "detect_adversarial_input", + "validate_model_output", + "InputAnomalyDetector", +] diff --git a/src/climatevision/security/api_security.py b/src/climatevision/security/api_security.py new file mode 100644 index 0000000..7eaa8fc --- /dev/null +++ b/src/climatevision/security/api_security.py @@ -0,0 +1,408 @@ +""" +API Security Module for ClimateVision. + +Implements OWASP-aligned security controls: +- Input validation and sanitization +- Rate limiting per API key +- File upload validation (magic bytes, extensions) +- Payload size limits +""" + +from __future__ import annotations + +import hashlib +import logging +import re +import time +from collections import defaultdict +from dataclasses import dataclass, field +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Optional, Callable + +from fastapi import Request, HTTPException +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.responses import Response + +logger = logging.getLogger(__name__) + + +@dataclass +class SecurityConfig: + """Security configuration settings.""" + + max_payload_size_bytes: int = 50 * 1024 * 1024 # 50 MB + max_bbox_area_degrees: float = 100.0 # Max area in square degrees + max_date_range_days: int = 365 # Max date range + rate_limit_requests: int = 100 # Requests per window + rate_limit_window_seconds: int = 60 # Window size + allowed_file_extensions: set[str] = field( + default_factory=lambda: {".tif", ".tiff", ".png", ".jpg", ".jpeg", ".geotiff"} + ) + max_filename_length: int = 255 + blocked_patterns: list[str] = field( + default_factory=lambda: [ + r"\.\.\/", # Path traversal + r" bool: + """Check if request is allowed for the given key.""" + now = time.time() + window_start = now - self.window_seconds + + # Clean old requests + self._buckets[key] = [ + ts for ts in self._buckets[key] if ts > window_start + ] + + # Check limit + if len(self._buckets[key]) >= self.max_requests: + return False + + # Record request + self._buckets[key].append(now) + return True + + def get_remaining(self, key: str) -> int: + """Get remaining requests for the key.""" + now = time.time() + window_start = now - self.window_seconds + recent = [ts for ts in self._buckets[key] if ts > window_start] + return max(0, self.max_requests - len(recent)) + + def get_reset_time(self, key: str) -> float: + """Get seconds until rate limit resets.""" + if not self._buckets[key]: + return 0.0 + oldest = min(self._buckets[key]) + reset = oldest + self.window_seconds - time.time() + return max(0.0, reset) + + +class SecurityMiddleware(BaseHTTPMiddleware): + """ + FastAPI middleware for security checks. + + Applies rate limiting and basic request validation. + """ + + def __init__( + self, + app: Any, + config: Optional[SecurityConfig] = None, + rate_limiter: Optional[RateLimiter] = None, + ): + super().__init__(app) + self.config = config or SecurityConfig() + self.rate_limiter = rate_limiter or RateLimiter( + max_requests=self.config.rate_limit_requests, + window_seconds=self.config.rate_limit_window_seconds, + ) + + async def dispatch(self, request: Request, call_next: Callable) -> Response: + # Get client identifier (API key or IP) + api_key = request.headers.get("X-API-Key") + client_id = api_key or (request.client.host if request.client else "unknown") + + # Rate limit check + if not self.rate_limiter.is_allowed(client_id): + remaining = self.rate_limiter.get_remaining(client_id) + reset_time = self.rate_limiter.get_reset_time(client_id) + + logger.warning( + "Rate limit exceeded for client %s", + client_id[:16] + "..." if len(client_id) > 16 else client_id, + ) + + return Response( + content='{"detail": "Rate limit exceeded"}', + status_code=429, + headers={ + "X-RateLimit-Remaining": str(remaining), + "X-RateLimit-Reset": str(int(reset_time)), + "Retry-After": str(int(reset_time) + 1), + "Content-Type": "application/json", + }, + ) + + # Content-Length check + content_length = request.headers.get("Content-Length") + if content_length: + try: + size = int(content_length) + if size > self.config.max_payload_size_bytes: + return Response( + content='{"detail": "Payload too large"}', + status_code=413, + headers={"Content-Type": "application/json"}, + ) + except ValueError: + pass + + # Add security headers to response + response = await call_next(request) + response.headers["X-Content-Type-Options"] = "nosniff" + response.headers["X-Frame-Options"] = "DENY" + response.headers["X-XSS-Protection"] = "1; mode=block" + + # Add rate limit headers + remaining = self.rate_limiter.get_remaining(client_id) + response.headers["X-RateLimit-Remaining"] = str(remaining) + response.headers["X-RateLimit-Limit"] = str(self.config.rate_limit_requests) + + return response + + +def validate_payload_size( + data: bytes, + max_size: int = 50 * 1024 * 1024, +) -> tuple[bool, str]: + """ + Validate payload size. + + Args: + data: Raw payload bytes + max_size: Maximum allowed size in bytes + + Returns: + (is_valid, error_message) + """ + if len(data) > max_size: + return False, f"Payload size ({len(data)} bytes) exceeds maximum ({max_size} bytes)" + return True, "" + + +def validate_bbox( + bbox: list[float], + max_area: float = 100.0, +) -> tuple[bool, str]: + """ + Validate bounding box coordinates. + + Checks: + - Exactly 4 values + - Valid longitude/latitude ranges + - West < East, South < North + - Area within limits (prevent DoS via huge queries) + + Args: + bbox: [west, south, east, north] + max_area: Maximum area in square degrees + + Returns: + (is_valid, error_message) + """ + if len(bbox) != 4: + return False, "bbox must have exactly 4 values: [west, south, east, north]" + + west, south, east, north = bbox + + # Type check + for i, v in enumerate(bbox): + if not isinstance(v, (int, float)): + return False, f"bbox[{i}] must be a number, got {type(v).__name__}" + + # Longitude range + if not (-180 <= west <= 180): + return False, f"Invalid west longitude: {west}. Must be between -180 and 180" + if not (-180 <= east <= 180): + return False, f"Invalid east longitude: {east}. Must be between -180 and 180" + + # Latitude range + if not (-90 <= south <= 90): + return False, f"Invalid south latitude: {south}. Must be between -90 and 90" + if not (-90 <= north <= 90): + return False, f"Invalid north latitude: {north}. Must be between -90 and 90" + + # Order check + if west >= east: + return False, f"West ({west}) must be less than east ({east})" + if south >= north: + return False, f"South ({south}) must be less than north ({north})" + + # Area check (prevent huge GEE queries) + area = (east - west) * (north - south) + if area > max_area: + return False, f"Bounding box area ({area:.2f} sq degrees) exceeds maximum ({max_area} sq degrees)" + + return True, "" + + +def validate_file_upload( + content: bytes, + filename: str, + config: Optional[SecurityConfig] = None, +) -> tuple[bool, str]: + """ + Validate uploaded file. + + Checks: + - Filename length and characters + - File extension whitelist + - Magic bytes match expected type + - No path traversal in filename + + Args: + content: File content bytes + filename: Original filename + config: Security configuration + + Returns: + (is_valid, error_message) + """ + config = config or SecurityConfig() + + # Filename length + if len(filename) > config.max_filename_length: + return False, f"Filename too long ({len(filename)} > {config.max_filename_length})" + + # Path traversal check + if ".." in filename or "/" in filename or "\\" in filename: + return False, "Invalid filename: path traversal detected" + + # Extension check + ext = Path(filename).suffix.lower() + if ext not in config.allowed_file_extensions: + return False, f"File extension '{ext}' not allowed. Allowed: {config.allowed_file_extensions}" + + # Magic bytes validation + detected_type = None + for signature, mime_type in FILE_SIGNATURES.items(): + if content.startswith(signature): + detected_type = mime_type + break + + if detected_type is None: + return False, "Unknown file type. Could not verify magic bytes." + + # Check extension matches detected type + expected_extensions = { + "image/png": {".png"}, + "image/jpeg": {".jpg", ".jpeg"}, + "image/tiff": {".tif", ".tiff", ".geotiff"}, + } + + if ext not in expected_extensions.get(detected_type, set()): + return False, f"File extension '{ext}' does not match detected type '{detected_type}'" + + return True, "" + + +def sanitize_string_input( + value: str, + max_length: int = 1000, + config: Optional[SecurityConfig] = None, +) -> tuple[str, list[str]]: + """ + Sanitize string input by removing potentially dangerous patterns. + + Args: + value: Input string + max_length: Maximum allowed length + config: Security configuration + + Returns: + (sanitized_value, list of warnings) + """ + config = config or SecurityConfig() + warnings = [] + + # Length limit + if len(value) > max_length: + value = value[:max_length] + warnings.append(f"Input truncated to {max_length} characters") + + # Check for blocked patterns + original = value + for pattern in config.blocked_patterns: + if re.search(pattern, value, re.IGNORECASE): + value = re.sub(pattern, "", value, flags=re.IGNORECASE) + warnings.append(f"Removed blocked pattern: {pattern}") + + # HTML entity encoding for special characters + value = ( + value.replace("&", "&") + .replace("<", "<") + .replace(">", ">") + .replace('"', """) + .replace("'", "'") + ) + + if value != original and not warnings: + warnings.append("Input was sanitized") + + return value, warnings + + +def generate_request_hash(request: Request) -> str: + """Generate a unique hash for a request for logging/tracking.""" + components = [ + request.method, + str(request.url), + request.headers.get("User-Agent", ""), + request.headers.get("X-API-Key", "")[:8] if request.headers.get("X-API-Key") else "", + str(time.time()), + ] + return hashlib.sha256("|".join(components).encode()).hexdigest()[:16] + + +def check_api_key_format(api_key: str) -> tuple[bool, str]: + """ + Validate API key format. + + Args: + api_key: The API key to validate + + Returns: + (is_valid, error_message) + """ + if not api_key: + return False, "API key is required" + + if len(api_key) < 16: + return False, "API key too short" + + if len(api_key) > 128: + return False, "API key too long" + + # Only alphanumeric and limited special chars + if not re.match(r"^[a-zA-Z0-9_\-]+$", api_key): + return False, "API key contains invalid characters" + + return True, "" diff --git a/src/climatevision/security/pipeline_guard.py b/src/climatevision/security/pipeline_guard.py new file mode 100644 index 0000000..515b757 --- /dev/null +++ b/src/climatevision/security/pipeline_guard.py @@ -0,0 +1,382 @@ +""" +Inference Pipeline Security Guard for ClimateVision. + +Detects adversarial inputs and validates model outputs: +- Statistical anomaly detection on input images +- Confidence threshold enforcement +- Output distribution validation +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import Any, Optional + +import numpy as np + +logger = logging.getLogger(__name__) + + +@dataclass +class AnomalyResult: + """Result of anomaly detection check.""" + + is_anomalous: bool + anomaly_score: float + anomaly_type: Optional[str] = None + details: Optional[dict[str, Any]] = None + recommendation: str = "" + + +@dataclass +class OutputValidation: + """Result of model output validation.""" + + is_valid: bool + confidence: float + issues: list[str] + recommendation: str = "" + + +class InputAnomalyDetector: + """ + Detects anomalous inputs that may indicate adversarial attacks. + + Uses statistical analysis to identify: + - Unusual pixel distributions + - Out-of-range values + - Suspicious patterns (noise, uniform regions) + """ + + def __init__( + self, + min_pixel_value: float = -10.0, + max_pixel_value: float = 10.0, + min_std: float = 0.01, + max_std: float = 5.0, + uniform_threshold: float = 0.95, + ): + self.min_pixel_value = min_pixel_value + self.max_pixel_value = max_pixel_value + self.min_std = min_std + self.max_std = max_std + self.uniform_threshold = uniform_threshold + + def detect(self, image: np.ndarray) -> AnomalyResult: + """ + Analyze image for anomalies. + + Args: + image: Input image array (C, H, W) or (H, W, C) or (H, W) + + Returns: + AnomalyResult with detection details + """ + # Normalize shape + if image.ndim == 2: + image = image[np.newaxis, :, :] + elif image.ndim == 3 and image.shape[2] < image.shape[0]: + image = np.transpose(image, (2, 0, 1)) + + issues = [] + anomaly_score = 0.0 + + # Check 1: Value range + min_val = float(np.min(image)) + max_val = float(np.max(image)) + + if min_val < self.min_pixel_value or max_val > self.max_pixel_value: + issues.append(f"Pixel values out of range: [{min_val:.2f}, {max_val:.2f}]") + anomaly_score += 0.3 + + # Check 2: Standard deviation (too uniform or too noisy) + std_val = float(np.std(image)) + + if std_val < self.min_std: + issues.append(f"Image too uniform (std={std_val:.4f})") + anomaly_score += 0.4 + elif std_val > self.max_std: + issues.append(f"Image too noisy (std={std_val:.4f})") + anomaly_score += 0.3 + + # Check 3: NaN or Inf values + if np.any(np.isnan(image)): + issues.append("Image contains NaN values") + anomaly_score += 0.5 + if np.any(np.isinf(image)): + issues.append("Image contains Inf values") + anomaly_score += 0.5 + + # Check 4: Uniform regions (potential adversarial patch) + for c in range(image.shape[0]): + channel = image[c] + unique_ratio = len(np.unique(channel)) / channel.size + if unique_ratio < (1 - self.uniform_threshold): + issues.append(f"Channel {c} has suspicious uniform regions") + anomaly_score += 0.2 + + # Check 5: Gradient analysis (adversarial often has unusual gradients) + gradient_x = np.abs(np.diff(image, axis=2)).mean() + gradient_y = np.abs(np.diff(image, axis=1)).mean() + + if gradient_x < 0.001 and gradient_y < 0.001: + issues.append("Suspiciously low gradient (constant image)") + anomaly_score += 0.3 + elif gradient_x > 2.0 or gradient_y > 2.0: + issues.append("Unusually high gradient (possible noise injection)") + anomaly_score += 0.2 + + # Clamp score + anomaly_score = min(1.0, anomaly_score) + is_anomalous = anomaly_score >= 0.5 + + details = { + "min_value": min_val, + "max_value": max_val, + "std": std_val, + "gradient_x": float(gradient_x), + "gradient_y": float(gradient_y), + "shape": list(image.shape), + } + + recommendation = "" + if is_anomalous: + recommendation = "Input flagged as potentially adversarial. Manual review recommended." + + return AnomalyResult( + is_anomalous=is_anomalous, + anomaly_score=anomaly_score, + anomaly_type="statistical_anomaly" if is_anomalous else None, + details=details, + recommendation=recommendation, + ) + + +class PipelineGuard: + """ + Guards the inference pipeline against adversarial inputs and poisoned outputs. + + Wraps model inference to validate inputs before processing + and outputs before returning to the client. + """ + + def __init__( + self, + min_confidence: float = 0.3, + max_single_class_ratio: float = 0.99, + enable_input_check: bool = True, + enable_output_check: bool = True, + ): + self.min_confidence = min_confidence + self.max_single_class_ratio = max_single_class_ratio + self.enable_input_check = enable_input_check + self.enable_output_check = enable_output_check + self.anomaly_detector = InputAnomalyDetector() + + def check_input(self, image: np.ndarray) -> AnomalyResult: + """Check input image for anomalies.""" + if not self.enable_input_check: + return AnomalyResult( + is_anomalous=False, + anomaly_score=0.0, + recommendation="Input checking disabled", + ) + return self.anomaly_detector.detect(image) + + def check_output( + self, + predictions: np.ndarray, + probabilities: Optional[np.ndarray] = None, + n_classes: int = 2, + ) -> OutputValidation: + """ + Validate model output. + + Args: + predictions: Class predictions (H, W) or (N, H, W) + probabilities: Class probabilities (N, C, H, W) if available + n_classes: Expected number of classes + + Returns: + OutputValidation result + """ + if not self.enable_output_check: + return OutputValidation( + is_valid=True, + confidence=1.0, + issues=[], + recommendation="Output checking disabled", + ) + + issues = [] + confidence = 1.0 + + # Check 1: Valid class values + unique_classes = np.unique(predictions) + invalid_classes = [c for c in unique_classes if c < 0 or c >= n_classes] + if invalid_classes: + issues.append(f"Invalid class values: {invalid_classes}") + confidence *= 0.5 + + # Check 2: Class distribution (suspicious if one class dominates) + total_pixels = predictions.size + for cls in range(n_classes): + ratio = np.sum(predictions == cls) / total_pixels + if ratio > self.max_single_class_ratio: + issues.append( + f"Class {cls} dominates output ({ratio:.2%}). " + "May indicate model failure or adversarial input." + ) + confidence *= 0.7 + + # Check 3: Probability confidence (if available) + if probabilities is not None: + mean_confidence = float(np.max(probabilities, axis=1).mean()) + if mean_confidence < self.min_confidence: + issues.append( + f"Low prediction confidence ({mean_confidence:.2%}). " + "Results may be unreliable." + ) + confidence *= 0.8 + + # Check for uniform probabilities (model confusion) + prob_std = float(np.std(probabilities)) + if prob_std < 0.1: + issues.append("Uniform probability distribution. Model may be confused.") + confidence *= 0.6 + + # Check 4: NaN/Inf in output + if np.any(np.isnan(predictions)): + issues.append("NaN values in predictions") + confidence = 0.0 + if np.any(np.isinf(predictions)): + issues.append("Inf values in predictions") + confidence = 0.0 + + is_valid = len(issues) == 0 or confidence >= 0.5 + + recommendation = "" + if not is_valid: + recommendation = ( + "Output validation failed. Consider rejecting this prediction " + "or flagging for manual review." + ) + elif issues: + recommendation = "Output has warnings but may still be usable." + + return OutputValidation( + is_valid=is_valid, + confidence=confidence, + issues=issues, + recommendation=recommendation, + ) + + def guard_inference( + self, + image: np.ndarray, + inference_fn: Any, + **kwargs: Any, + ) -> dict[str, Any]: + """ + Run guarded inference with input and output validation. + + Args: + image: Input image + inference_fn: Function to call for inference + **kwargs: Additional arguments for inference_fn + + Returns: + Inference result with security metadata + """ + result = { + "input_check": None, + "output_check": None, + "inference_result": None, + "blocked": False, + "block_reason": None, + } + + # Input check + input_check = self.check_input(image) + result["input_check"] = { + "is_anomalous": input_check.is_anomalous, + "anomaly_score": input_check.anomaly_score, + "anomaly_type": input_check.anomaly_type, + "details": input_check.details, + } + + if input_check.is_anomalous: + logger.warning( + "Anomalous input detected (score=%.2f): %s", + input_check.anomaly_score, + input_check.anomaly_type, + ) + result["blocked"] = True + result["block_reason"] = input_check.recommendation + return result + + # Run inference + try: + inference_result = inference_fn(image, **kwargs) + result["inference_result"] = inference_result + except Exception as e: + logger.error("Inference failed: %s", e) + result["blocked"] = True + result["block_reason"] = f"Inference error: {str(e)}" + return result + + # Output check (if we have predictions) + if "predictions" in inference_result: + predictions = inference_result["predictions"] + probabilities = inference_result.get("probabilities") + n_classes = inference_result.get("n_classes", 2) + + output_check = self.check_output(predictions, probabilities, n_classes) + result["output_check"] = { + "is_valid": output_check.is_valid, + "confidence": output_check.confidence, + "issues": output_check.issues, + } + + if not output_check.is_valid: + logger.warning( + "Output validation failed: %s", + output_check.issues, + ) + + return result + + +def detect_adversarial_input(image: np.ndarray) -> AnomalyResult: + """ + Convenience function to detect adversarial inputs. + + Args: + image: Input image array + + Returns: + AnomalyResult + """ + detector = InputAnomalyDetector() + return detector.detect(image) + + +def validate_model_output( + predictions: np.ndarray, + probabilities: Optional[np.ndarray] = None, + n_classes: int = 2, +) -> OutputValidation: + """ + Convenience function to validate model output. + + Args: + predictions: Class predictions + probabilities: Class probabilities (optional) + n_classes: Expected number of classes + + Returns: + OutputValidation result + """ + guard = PipelineGuard() + return guard.check_output(predictions, probabilities, n_classes) diff --git a/tests/test_alert_generator.py b/tests/test_alert_generator.py new file mode 100644 index 0000000..3e87153 --- /dev/null +++ b/tests/test_alert_generator.py @@ -0,0 +1,165 @@ +"""Tests for inference.alert_generator. + +Imports the module via importlib to avoid the broken +``climatevision.inference.__init__`` -> data package chain. +""" + +from __future__ import annotations + +import importlib.util +import json +import sys +from datetime import datetime, timedelta, timezone +from pathlib import Path + +import pytest + +_PATH = ( + Path(__file__).resolve().parent.parent + / "src" + / "climatevision" + / "inference" + / "alert_generator.py" +) +_spec = importlib.util.spec_from_file_location("cv_alert_generator", _PATH) +ag = importlib.util.module_from_spec(_spec) +assert _spec.loader is not None +sys.modules["cv_alert_generator"] = ag +_spec.loader.exec_module(ag) + + +def _amazon_subscription(**overrides): + base = dict( + org_id=1, + bbox=(-60.0, -15.0, -45.0, 5.0), + analysis_type="deforestation", + alert_threshold=0.15, + channels=("log",), + cooldown_minutes=60, + ) + base.update(overrides) + return ag.Subscription(**base) + + +def _frozen_clock(start: datetime): + state = {"now": start} + + def clock(): + return state["now"] + + def advance(minutes: int): + state["now"] = state["now"] + timedelta(minutes=minutes) + + return clock, advance + + +def test_alert_fires_when_threshold_exceeded(tmp_path): + gen = ag.AlertGenerator( + subscriptions=[_amazon_subscription()], + alert_log_path=tmp_path / "alerts.jsonl", + delivery={"log": ag.log_channel}, + ) + alerts = gen.evaluate( + analysis_type="deforestation", + bbox=(-55.0, -10.0, -50.0, 0.0), + measured_value=0.30, + ) + assert len(alerts) == 1 + assert alerts[0].severity == "high" # 0.30 >= 0.15 * 2 + assert alerts[0].channels == ("log",) + + +def test_no_alert_below_threshold(tmp_path): + gen = ag.AlertGenerator( + subscriptions=[_amazon_subscription()], + alert_log_path=tmp_path / "alerts.jsonl", + ) + alerts = gen.evaluate( + analysis_type="deforestation", + bbox=(-55.0, -10.0, -50.0, 0.0), + measured_value=0.10, + ) + assert alerts == [] + + +def test_subscription_filtered_by_analysis_type(tmp_path): + gen = ag.AlertGenerator( + subscriptions=[_amazon_subscription()], + alert_log_path=tmp_path / "alerts.jsonl", + ) + alerts = gen.evaluate( + analysis_type="flooding", + bbox=(-55.0, -10.0, -50.0, 0.0), + measured_value=0.99, + ) + assert alerts == [] + + +def test_subscription_filtered_by_bbox_overlap(tmp_path): + gen = ag.AlertGenerator( + subscriptions=[_amazon_subscription()], + alert_log_path=tmp_path / "alerts.jsonl", + ) + # Disjoint bbox over Africa + alerts = gen.evaluate( + analysis_type="deforestation", + bbox=(20.0, 0.0, 30.0, 10.0), + measured_value=0.99, + ) + assert alerts == [] + + +def test_cooldown_suppresses_duplicates(tmp_path): + start = datetime(2026, 5, 1, 12, 0, tzinfo=timezone.utc) + clock, advance = _frozen_clock(start) + gen = ag.AlertGenerator( + subscriptions=[_amazon_subscription(cooldown_minutes=30)], + alert_log_path=tmp_path / "alerts.jsonl", + clock=clock, + ) + first = gen.evaluate("deforestation", (-55.0, -10.0, -50.0, 0.0), 0.30) + advance(10) + second = gen.evaluate("deforestation", (-55.0, -10.0, -50.0, 0.0), 0.30) + advance(40) + third = gen.evaluate("deforestation", (-55.0, -10.0, -50.0, 0.0), 0.30) + + assert len(first) == 1 + assert second == [] + assert len(third) == 1 + + +def test_severity_escalation(): + assert ag._classify_severity(0.20, 0.15) == "medium" + assert ag._classify_severity(0.31, 0.15) == "high" + assert ag._classify_severity(0.46, 0.15) == "critical" + + +def test_custom_channel_delivery_called(tmp_path): + delivered: list = [] + + def fake_webhook(alert): + delivered.append(alert.alert_id) + + sub = _amazon_subscription(channels=("webhook",)) + gen = ag.AlertGenerator( + subscriptions=[sub], + alert_log_path=tmp_path / "alerts.jsonl", + delivery={"webhook": fake_webhook}, + ) + alerts = gen.evaluate("deforestation", (-55.0, -10.0, -50.0, 0.0), 0.30) + assert len(delivered) == 1 + assert delivered[0] == alerts[0].alert_id + + +def test_persisted_alerts_can_be_replayed(tmp_path): + path = tmp_path / "alerts.jsonl" + gen = ag.AlertGenerator( + subscriptions=[_amazon_subscription()], + alert_log_path=path, + ) + gen.evaluate("deforestation", (-55.0, -10.0, -50.0, 0.0), 0.30) + + fresh = ag.AlertGenerator(alert_log_path=path) + replayed = fresh.iter_alerts() + assert len(replayed) == 1 + assert replayed[0].severity in {"medium", "high", "critical"} diff --git a/tests/test_anomaly_detector.py b/tests/test_anomaly_detector.py new file mode 100644 index 0000000..f3aeae5 --- /dev/null +++ b/tests/test_anomaly_detector.py @@ -0,0 +1,85 @@ +"""Tests for governance.anomaly_detector.""" + +from __future__ import annotations + +import numpy as np +import pytest + +from climatevision.governance.anomaly_detector import ( + AnomalyDetector, + extract_features, + write_anomaly_report, +) + + +def _normal_confidence(rng: np.random.Generator) -> np.ndarray: + return np.clip(rng.normal(0.7, 0.05, size=(64, 64)), 0.0, 1.0) + + +def _degenerate_confidence() -> np.ndarray: + return np.ones((64, 64)) * 0.999 + + +def test_extract_features_shapes_and_ranges(): + rng = np.random.default_rng(0) + feats = extract_features(_normal_confidence(rng)) + assert 0.0 <= feats.mean_confidence <= 1.0 + assert feats.std_confidence >= 0 + assert 0.0 <= feats.positive_fraction <= 1.0 + assert feats.entropy >= 0 + + +def test_extract_features_rejects_empty(): + with pytest.raises(ValueError): + extract_features(np.array([])) + + +def test_statistical_detector_flags_outlier(tmp_path): + rng = np.random.default_rng(42) + detector = AnomalyDetector(history_path=tmp_path / "history.jsonl") + + for _ in range(30): + detector.detect(_normal_confidence(rng)) + + result = detector.detect(_degenerate_confidence()) + assert result.method == "statistical" + assert result.is_anomaly + assert result.reasons + + +def test_isolation_forest_kicks_in_after_threshold(tmp_path): + rng = np.random.default_rng(7) + detector = AnomalyDetector( + history_path=tmp_path / "history.jsonl", + min_history_for_iforest=20, + ) + + for _ in range(25): + detector.detect(_normal_confidence(rng)) + + assert detector._iforest is not None + result = detector.detect(_normal_confidence(rng)) + assert result.method == "isolation_forest" + + +def test_history_persistence_roundtrip(tmp_path): + rng = np.random.default_rng(1) + history_path = tmp_path / "history.jsonl" + + d1 = AnomalyDetector(history_path=history_path) + for _ in range(5): + d1.detect(_normal_confidence(rng)) + + d2 = AnomalyDetector(history_path=history_path) + d2.load_history() + assert len(d2._history) == 5 + + +def test_write_anomaly_report(tmp_path): + rng = np.random.default_rng(3) + detector = AnomalyDetector(history_path=tmp_path / "history.jsonl") + + results = [detector.detect(_normal_confidence(rng)) for _ in range(3)] + out = write_anomaly_report(results, output_path=tmp_path / "report.json") + assert out.exists() + assert out.read_text().count("mean_confidence") == 3 diff --git a/tests/test_api_admin.py b/tests/test_api_admin.py new file mode 100644 index 0000000..f7f3b6e --- /dev/null +++ b/tests/test_api_admin.py @@ -0,0 +1,151 @@ +"""Tests for api.admin operational endpoints. + +Imports the admin module via importlib to avoid the broken +``climatevision.data`` package __init__ chain (irrelevant to admin). +""" + +from __future__ import annotations + +import importlib.util +import json +import sys +from datetime import datetime, timedelta, timezone +from pathlib import Path + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +_PATH = ( + Path(__file__).resolve().parent.parent + / "src" + / "climatevision" + / "api" + / "admin.py" +) +_spec = importlib.util.spec_from_file_location("cv_api_admin", _PATH) +admin = importlib.util.module_from_spec(_spec) +assert _spec.loader is not None +sys.modules["cv_api_admin"] = admin +_spec.loader.exec_module(admin) + + +@pytest.fixture +def env(tmp_path, monkeypatch): + audit = tmp_path / "audit.jsonl" + alerts = tmp_path / "alerts.jsonl" + monkeypatch.setattr(admin, "DEFAULT_AUDIT_LOG", audit) + monkeypatch.setattr(admin, "DEFAULT_ALERT_LOG", alerts) + return audit, alerts + + +def _now(): + return datetime.now(timezone.utc) + + +def _write_audit(path, entries): + with path.open("a") as fh: + for e in entries: + fh.write(json.dumps(e) + "\n") + + +def _make_audit_entry(minutes_ago: int, mean_conf: float, positive: float, error: bool = False): + ts = _now() - timedelta(minutes=minutes_ago) + return { + "timestamp": ts.isoformat(), + "model_version": "v1", + "input_hash": "abc", + "output_summary": {"mean_confidence": mean_conf, "positive_fraction": positive}, + "request_id": None, + "user_id": None, + "prev_hash": "0" * 64, + "entry_hash": "x", + "metadata": {}, + **({"error": "boom"} if error else {}), + } + + +def _make_alert(minutes_ago: int, severity: str = "high"): + ts = _now() - timedelta(minutes=minutes_ago) + return { + "alert_id": "id", + "org_id": 1, + "analysis_type": "deforestation", + "region_bbox": [-60, -15, -45, 5], + "severity": severity, + "measured_value": 0.3, + "threshold": 0.15, + "summary": "test", + "triggered_at": ts.isoformat(), + "channels": ["log"], + } + + +def _client(): + app = FastAPI() + app.include_router(admin.router) + return TestClient(app) + + +def test_reports_returns_zeros_for_empty_logs(env): + client = _client() + resp = client.get("/api/reports?window_hours=24") + assert resp.status_code == 200 + body = resp.json() + assert body["run_count"] == 0 + assert body["error_rate"] == 0.0 + assert body["mean_confidence"] is None + assert body["alert_count"] == 0 + + +def test_reports_aggregates_within_window(env): + audit, alerts = env + _write_audit(audit, [ + _make_audit_entry(minutes_ago=10, mean_conf=0.8, positive=0.2), + _make_audit_entry(minutes_ago=30, mean_conf=0.9, positive=0.4), + _make_audit_entry(minutes_ago=60 * 48, mean_conf=0.5, positive=0.1), # outside + _make_audit_entry(minutes_ago=20, mean_conf=0.7, positive=0.3, error=True), + ]) + _write_audit(alerts, [_make_alert(15), _make_alert(60 * 48)]) + + client = _client() + body = client.get("/api/reports?window_hours=24").json() + + assert body["run_count"] == 3 + assert pytest.approx(body["error_rate"], rel=1e-3) == 1 / 3 + assert pytest.approx(body["mean_confidence"], rel=1e-3) == (0.8 + 0.9 + 0.7) / 3 + assert pytest.approx(body["positive_fraction_mean"], rel=1e-3) == (0.2 + 0.4 + 0.3) / 3 + assert body["alert_count"] == 1 + + +def test_reports_rejects_zero_window(env): + client = _client() + resp = client.get("/api/reports?window_hours=0") + assert resp.status_code == 422 + + +def test_anomalies_lists_all_when_unfiltered(env): + _, alerts = env + _write_audit(alerts, [_make_alert(5, "high"), _make_alert(10, "medium")]) + body = _client().get("/api/anomalies").json() + assert body["count"] == 2 + + +def test_anomalies_filters_by_severity(env): + _, alerts = env + _write_audit(alerts, [_make_alert(5, "high"), _make_alert(10, "medium")]) + body = _client().get("/api/anomalies?severity=high").json() + assert body["count"] == 1 + assert body["anomalies"][0]["severity"] == "high" + + +def test_anomalies_filters_by_window(env): + _, alerts = env + _write_audit(alerts, [_make_alert(5, "high"), _make_alert(60 * 48, "high")]) + body = _client().get("/api/anomalies?window_hours=1").json() + assert body["count"] == 1 + + +def test_anomalies_rejects_invalid_severity(env): + resp = _client().get("/api/anomalies?severity=blah") + assert resp.status_code == 422 diff --git a/tests/test_audit_logger.py b/tests/test_audit_logger.py new file mode 100644 index 0000000..85eda90 --- /dev/null +++ b/tests/test_audit_logger.py @@ -0,0 +1,107 @@ +"""Tests for governance.audit_logger.""" + +from __future__ import annotations + +import json + +import numpy as np +import pytest + +from climatevision.governance.audit_logger import ( + GENESIS_HASH, + AuditLogger, + log_prediction, +) + + +def _fake_inputs(): + rng = np.random.default_rng(0) + image = rng.integers(0, 255, size=(4, 32, 32), dtype=np.uint8) + output = rng.uniform(0, 1, size=(32, 32)).astype(np.float32) + return image, output + + +def test_first_entry_chains_to_genesis(tmp_path): + log = AuditLogger(log_path=tmp_path / "audit.jsonl") + image, output = _fake_inputs() + entry = log.log_prediction( + model_version="unet-v0.1.0", + input_data=image, + output=output, + request_id="r-1", + ) + assert entry.prev_hash == GENESIS_HASH + assert len(entry.entry_hash) == 64 + + +def test_chain_links_correctly(tmp_path): + log = AuditLogger(log_path=tmp_path / "audit.jsonl") + image, output = _fake_inputs() + + e1 = log.log_prediction(model_version="v1", input_data=image, output=output) + e2 = log.log_prediction(model_version="v1", input_data=image, output=output) + e3 = log.log_prediction(model_version="v2", input_data=image, output=output) + + assert e2.prev_hash == e1.entry_hash + assert e3.prev_hash == e2.entry_hash + + ok, failure = log.verify_chain() + assert ok is True + assert failure is None + + +def test_tampered_entry_breaks_chain(tmp_path): + path = tmp_path / "audit.jsonl" + log = AuditLogger(log_path=path) + image, output = _fake_inputs() + log.log_prediction(model_version="v1", input_data=image, output=output) + log.log_prediction(model_version="v1", input_data=image, output=output) + + lines = path.read_text().splitlines() + record = json.loads(lines[0]) + record["model_version"] = "tampered" + lines[0] = json.dumps(record, sort_keys=True) + path.write_text("\n".join(lines) + "\n") + + fresh = AuditLogger(log_path=path) + ok, failure = fresh.verify_chain() + assert ok is False + assert failure is not None + + +def test_resumes_chain_across_logger_instances(tmp_path): + path = tmp_path / "audit.jsonl" + image, output = _fake_inputs() + + AuditLogger(log_path=path).log_prediction( + model_version="v1", input_data=image, output=output + ) + new_logger = AuditLogger(log_path=path) + e2 = new_logger.log_prediction( + model_version="v1", input_data=image, output=output + ) + + entries = new_logger.iter_entries() + assert len(entries) == 2 + assert e2.prev_hash == entries[0].entry_hash + + +def test_module_level_helper_writes_to_default_path(tmp_path, monkeypatch): + target = tmp_path / "audit.jsonl" + monkeypatch.setattr( + "climatevision.governance.audit_logger._DEFAULT_AUDIT_LOG", target + ) + image, output = _fake_inputs() + entry = log_prediction(model_version="v1", input_data=image, output=output) + assert target.exists() + assert entry.model_version == "v1" + + +def test_dict_input_and_output_are_supported(tmp_path): + log = AuditLogger(log_path=tmp_path / "audit.jsonl") + entry = log.log_prediction( + model_version="v1", + input_data={"bbox": [-60, -15, -45, 5], "date": "2026-04-01"}, + output={"hectares": 1247.0, "carbon_tonnes": 4321.0}, + ) + assert entry.output_summary["hectares"] == pytest.approx(1247.0) diff --git a/tests/test_band_mapping.py b/tests/test_band_mapping.py new file mode 100644 index 0000000..4f5832a --- /dev/null +++ b/tests/test_band_mapping.py @@ -0,0 +1,83 @@ +"""Smoke tests for analysis-aware Sentinel-2 band mapping.""" +from __future__ import annotations + +import pytest + +from climatevision.data.band_mapping import ( + SCL_BAND, + SENTINEL2_BAND_ORDER, + get_band_indices, + get_bands_for_analysis, + get_bands_for_analysis_with_scl, + get_model_config, + is_analysis_enabled, + list_enabled_analysis_types, +) + + +def test_sentinel2_band_order_has_13_bands(): + assert len(SENTINEL2_BAND_ORDER) == 13 + assert SENTINEL2_BAND_ORDER[0] == "B01" + assert SENTINEL2_BAND_ORDER[-1] == "B12" + + +def test_deforestation_uses_four_bands(): + bands = get_bands_for_analysis("deforestation") + assert len(bands) == 4 + assert set(bands) == {"B02", "B03", "B04", "B08"} + + +def test_flooding_uses_three_bands(): + bands = get_bands_for_analysis("flooding") + assert len(bands) == 3 + assert "B11" in bands + + +def test_ice_melting_uses_swir(): + bands = get_bands_for_analysis("ice_melting") + assert "B11" in bands + + +def test_scl_appended_for_cloud_masking(): + bands = get_bands_for_analysis_with_scl("deforestation") + assert SCL_BAND in bands + assert bands[-1] == SCL_BAND + + +def test_scl_not_duplicated(): + bands_with_scl = get_bands_for_analysis_with_scl("deforestation") + bands_again = get_bands_for_analysis_with_scl("deforestation") + assert bands_with_scl.count(SCL_BAND) == 1 + assert bands_again.count(SCL_BAND) == 1 + + +def test_band_indices_resolve_correctly(): + indices = get_band_indices(["B04", "B03", "B02", "B08"]) + assert indices == [3, 2, 1, 7] + + +def test_band_indices_rejects_unknown(): + with pytest.raises(ValueError, match="Unknown"): + get_band_indices(["B99"]) + + +def test_band_indices_rejects_scl_directly(): + with pytest.raises(ValueError, match="SCL"): + get_band_indices([SCL_BAND]) + + +def test_enabled_analysis_types_include_active_three(): + enabled = list_enabled_analysis_types() + for name in ("deforestation", "ice_melting", "flooding"): + assert name in enabled, f"{name} should be enabled" + + +def test_disabled_analysis_types(): + assert not is_analysis_enabled("drought") + assert not is_analysis_enabled("wildfire") + + +def test_model_config_carries_channels_and_classes(): + cfg = get_model_config("flooding") + assert cfg["in_channels"] == 3 + assert cfg["num_classes"] == 3 diff --git a/tests/test_batch_processor.py b/tests/test_batch_processor.py new file mode 100644 index 0000000..d4d17d5 --- /dev/null +++ b/tests/test_batch_processor.py @@ -0,0 +1,144 @@ +"""Tests for inference.batch_processor. + +Imports the module directly via importlib to avoid the +``climatevision.inference`` package __init__ pulling in the rest of the +inference pipeline at test-collection time. Once the data package +__init__ is repaired we can drop the importlib shim. +""" + +from __future__ import annotations + +import importlib.util +import json +import sys +from pathlib import Path + +import numpy as np +import pytest + +_BATCH_PATH = ( + Path(__file__).resolve().parent.parent + / "src" + / "climatevision" + / "inference" + / "batch_processor.py" +) +_spec = importlib.util.spec_from_file_location("cv_batch_processor", _BATCH_PATH) +batch = importlib.util.module_from_spec(_spec) +assert _spec.loader is not None +sys.modules["cv_batch_processor"] = batch +_spec.loader.exec_module(batch) + +BatchProcessor = batch.BatchProcessor +BatchSummary = batch.BatchSummary + + +def _ok_inference(source, analysis_type): + return { + "hectares": 10.0, + "carbon_tonnes": 35.0, + "mean_confidence": 0.82, + "mask": np.ones((4, 4), dtype=np.uint8), + } + + +def _flaky_inference(state): + def _fn(source, analysis_type): + state["calls"] += 1 + if state["calls"] < 2: + raise RuntimeError("transient") + return {"hectares": 1.0, "carbon_tonnes": 3.0} + return _fn + + +def _always_fail(source, analysis_type): + raise ValueError(f"bad source: {source}") + + +def test_run_succeeds_for_all_jobs(tmp_path): + proc = BatchProcessor( + max_workers=2, + manifest_path=tmp_path / "manifest.jsonl", + inference_fn=_ok_inference, + ) + jobs, summary = proc.run(["a.tif", "b.tif", "c.tif"]) + + assert summary.total == 3 + assert summary.succeeded == 3 + assert summary.failed == 0 + assert all(j.status == "succeeded" for j in jobs) + assert all(j.duration_ms is not None and j.duration_ms >= 0 for j in jobs) + assert all(j.attempts == 1 for j in jobs) + + +def test_failed_jobs_are_isolated(tmp_path): + proc = BatchProcessor( + max_workers=2, + manifest_path=tmp_path / "manifest.jsonl", + inference_fn=_always_fail, + ) + jobs, summary = proc.run(["a.tif", "b.tif"]) + + assert summary.failed == 2 + assert summary.succeeded == 0 + assert all(j.status == "failed" and j.error.startswith("ValueError") for j in jobs) + + +def test_retry_succeeds_after_transient_failure(tmp_path): + state = {"calls": 0} + proc = BatchProcessor( + max_workers=1, + max_attempts=3, + manifest_path=tmp_path / "manifest.jsonl", + inference_fn=_flaky_inference(state), + ) + jobs, summary = proc.run(["only.tif"]) + assert summary.succeeded == 1 + assert jobs[0].attempts == 2 + + +def test_manifest_records_each_job(tmp_path): + manifest = tmp_path / "manifest.jsonl" + proc = BatchProcessor( + max_workers=2, + manifest_path=manifest, + inference_fn=_ok_inference, + ) + proc.run(["a.tif", "b.tif"]) + + lines = [json.loads(l) for l in manifest.read_text().splitlines() if l.strip()] + assert len(lines) == 2 + statuses = {l["status"] for l in lines} + assert statuses == {"succeeded"} + for line in lines: + assert line["result_summary"]["hectares"] == 10.0 + assert line["result_summary"]["positive_pixels"] == 16 + + +def test_get_job_returns_record(tmp_path): + proc = BatchProcessor( + manifest_path=tmp_path / "manifest.jsonl", + inference_fn=_ok_inference, + ) + jobs, _ = proc.run(["a.tif"]) + fetched = proc.get_job(jobs[0].job_id) + assert fetched is not None + assert fetched.status == "succeeded" + + +def test_dict_source_roundtrips(tmp_path): + captured = {} + + def fn(source, analysis_type): + captured["source"] = source + captured["analysis_type"] = analysis_type + return {"hectares": 0.0, "carbon_tonnes": 0.0} + + proc = BatchProcessor( + manifest_path=tmp_path / "manifest.jsonl", + inference_fn=fn, + ) + jobs, summary = proc.run([{"bbox": [0, 0, 1, 1], "date": "2026-01-01"}], analysis_type="flooding") + assert summary.succeeded == 1 + assert captured["analysis_type"] == "flooding" + assert captured["source"]["bbox"] == [0, 0, 1, 1] diff --git a/tests/test_calibration.py b/tests/test_calibration.py new file mode 100644 index 0000000..ff265c5 --- /dev/null +++ b/tests/test_calibration.py @@ -0,0 +1,125 @@ +"""Tests for governance.calibration.""" + +from __future__ import annotations + +import json + +import numpy as np +import pytest + +from climatevision.governance.calibration import ( + CalibrationReport, + brier_score, + evaluate_calibration, + expected_calibration_error, + maximum_calibration_error, + reliability_bins, + write_calibration_report, +) + + +def _perfectly_calibrated(n: int = 10_000, seed: int = 0): + rng = np.random.default_rng(seed) + probs = rng.uniform(0.0, 1.0, size=n) + targets = (rng.uniform(0.0, 1.0, size=n) < probs).astype(np.int32) + return probs, targets + + +def _overconfident(n: int = 10_000, seed: int = 1): + rng = np.random.default_rng(seed) + probs = rng.uniform(0.8, 1.0, size=n) + targets = (rng.uniform(0.0, 1.0, size=n) < 0.5).astype(np.int32) + return probs, targets + + +def test_reliability_bins_partition_inputs(): + probs, targets = _perfectly_calibrated() + bins = reliability_bins(probs, targets, n_bins=10) + assert len(bins) == 10 + assert sum(b.count for b in bins) == probs.size + for b in bins: + assert 0.0 <= b.lower < b.upper <= 1.0 + + +def test_perfectly_calibrated_has_low_ece(): + probs, targets = _perfectly_calibrated() + bins = reliability_bins(probs, targets, n_bins=15) + assert expected_calibration_error(bins) < 0.05 + + +def test_overconfident_has_high_ece(): + probs, targets = _overconfident() + bins = reliability_bins(probs, targets, n_bins=15) + assert expected_calibration_error(bins) > 0.2 + + +def test_mce_is_at_least_ece(): + probs, targets = _overconfident() + bins = reliability_bins(probs, targets, n_bins=15) + assert maximum_calibration_error(bins) >= expected_calibration_error(bins) + + +def test_brier_score_zero_for_certain_correct_predictions(): + probs = np.array([1.0, 0.0, 1.0, 0.0]) + targets = np.array([1, 0, 1, 0]) + assert brier_score(probs, targets) == pytest.approx(0.0) + + +def test_brier_score_one_for_certain_wrong_predictions(): + probs = np.array([1.0, 0.0, 1.0, 0.0]) + targets = np.array([0, 1, 0, 1]) + assert brier_score(probs, targets) == pytest.approx(1.0) + + +def test_evaluate_calibration_returns_report_with_bins(): + probs, targets = _perfectly_calibrated() + report = evaluate_calibration( + probs, targets, model_version="unet-test-1", n_bins=10 + ) + assert isinstance(report, CalibrationReport) + assert report.model_version == "unet-test-1" + assert report.n_samples == probs.size + assert report.n_bins == 10 + assert len(report.bins) == 10 + assert 0.0 <= report.ece <= 1.0 + assert 0.0 <= report.brier_score <= 1.0 + + +def test_well_calibrated_threshold(): + probs, targets = _perfectly_calibrated() + report = evaluate_calibration(probs, targets, model_version="v") + assert report.is_well_calibrated(ece_threshold=0.05) + bad_probs, bad_targets = _overconfident() + bad = evaluate_calibration(bad_probs, bad_targets, model_version="v") + assert not bad.is_well_calibrated(ece_threshold=0.05) + + +def test_validates_probability_range(): + with pytest.raises(ValueError, match="probabilities must lie in"): + evaluate_calibration( + np.array([1.5, 0.5]), np.array([1, 0]), model_version="v" + ) + + +def test_validates_binary_targets(): + with pytest.raises(ValueError, match="targets must be binary"): + evaluate_calibration( + np.array([0.5, 0.5]), np.array([1, 2]), model_version="v" + ) + + +def test_validates_shape_match(): + with pytest.raises(ValueError, match="same shape"): + evaluate_calibration( + np.array([0.5, 0.5]), np.array([1, 0, 1]), model_version="v" + ) + + +def test_write_calibration_report_round_trips_json(tmp_path): + probs, targets = _perfectly_calibrated(n=1000) + report = evaluate_calibration(probs, targets, model_version="v0.1") + out = write_calibration_report(report, tmp_path / "calib.json") + loaded = json.loads(out.read_text()) + assert loaded["model_version"] == "v0.1" + assert loaded["n_samples"] == 1000 + assert len(loaded["bins"]) == report.n_bins diff --git a/tests/test_datasheet.py b/tests/test_datasheet.py new file mode 100644 index 0000000..45a83f9 --- /dev/null +++ b/tests/test_datasheet.py @@ -0,0 +1,148 @@ +"""Tests for governance.datasheet.""" + +from __future__ import annotations + +import json + +import pytest + +from climatevision.governance.datasheet import ( + Datasheet, + build_datasheet, + generate, + render_markdown, + write_datasheet, +) + + +def _valid_manifest() -> dict: + return { + "name": "sentinel2-deforestation", + "version": "1.0.0", + "motivation": { + "purpose": "Detect Amazon basin deforestation events from Sentinel-2.", + "creators": "ClimateVision Data Pipeline team", + "funding": "Self-funded open-source initiative.", + }, + "composition": { + "instances": "12,480 256x256 tiles", + "labels": "Binary deforestation mask per tile", + "splits": "70/15/15 train/val/test by spatial cluster", + "label_source": "Hansen Global Forest Change v1.10", + }, + "collection_process": { + "source": "Sentinel-2 L2A via Google Earth Engine", + "timeframe": "2020-01-01 to 2023-12-31", + "consent": "Public open-data licence; no human subjects.", + }, + "preprocessing": { + "cloud_masking": "QA60 + s2cloudless threshold 0.4", + "normalisation": "Per-band z-score against training set means", + "augmentation": "Random flip / 90deg rotate at train time only", + }, + "uses": { + "intended_uses": [ + "Training U-Net segmentation models for deforestation detection.", + "Evaluating fairness of detection across forest biomes.", + ] + }, + "distribution": { + "license": "CC-BY-4.0 (derived data)", + "redistribution": "Allowed with attribution; do not redistribute raw Sentinel-2 tiles.", + }, + } + + +def test_build_datasheet_returns_typed_object(): + sheet = build_datasheet(_valid_manifest()) + assert isinstance(sheet, Datasheet) + assert sheet.name == "sentinel2-deforestation" + assert sheet.version == "1.0.0" + assert sheet.motivation["purpose"].startswith("Detect") + + +def test_inappropriate_uses_default_when_omitted(): + sheet = build_datasheet(_valid_manifest()) + assert sheet.uses["inappropriate_uses"], "default inappropriate_uses should be populated" + + +def test_inappropriate_uses_respect_override(): + manifest = _valid_manifest() + manifest["uses"]["inappropriate_uses"] = ["custom override"] + sheet = build_datasheet(manifest) + assert sheet.uses["inappropriate_uses"] == ["custom override"] + + +def test_maintenance_has_default(): + sheet = build_datasheet(_valid_manifest()) + assert "owner" in sheet.maintenance + assert "update_cadence" in sheet.maintenance + + +def test_validate_rejects_missing_required_section(): + manifest = _valid_manifest() + del manifest["motivation"]["purpose"] + with pytest.raises(ValueError, match="motivation.purpose"): + build_datasheet(manifest) + + +def test_validate_rejects_empty_required_field(): + manifest = _valid_manifest() + manifest["composition"]["labels"] = "" + with pytest.raises(ValueError, match="composition.labels"): + build_datasheet(manifest) + + +def test_validate_rejects_missing_collection_timeframe(): + manifest = _valid_manifest() + del manifest["collection_process"]["timeframe"] + with pytest.raises(ValueError, match="collection_process.timeframe"): + build_datasheet(manifest) + + +def test_render_markdown_includes_section_headings(): + sheet = build_datasheet(_valid_manifest()) + md = render_markdown(sheet) + for heading in ( + "# Datasheet:", + "## Motivation", + "## Composition", + "## Collection Process", + "## Uses", + "## Distribution", + "## Maintenance", + ): + assert heading in md, f"missing heading: {heading}" + + +def test_render_markdown_renders_lists_as_bullets(): + sheet = build_datasheet(_valid_manifest()) + md = render_markdown(sheet) + assert "- Training U-Net segmentation models" in md + + +def test_write_datasheet_round_trips_json(tmp_path): + sheet = build_datasheet(_valid_manifest()) + paths = write_datasheet(sheet, output_dir=tmp_path) + loaded = json.loads(paths["json"].read_text()) + assert loaded["name"] == sheet.name + assert loaded["composition"]["splits"] == "70/15/15 train/val/test by spatial cluster" + + +def test_generate_end_to_end(tmp_path): + manifest_path = tmp_path / "manifest.json" + manifest_path.write_text(json.dumps(_valid_manifest())) + paths = generate(manifest_path, output_dir=tmp_path / "out") + assert paths["markdown"].exists() + assert paths["json"].exists() + assert "Datasheet:" in paths["markdown"].read_text() + + +def test_generate_loads_yaml(tmp_path): + pytest.importorskip("yaml") + import yaml + + manifest_path = tmp_path / "manifest.yaml" + manifest_path.write_text(yaml.safe_dump(_valid_manifest())) + paths = generate(manifest_path, output_dir=tmp_path / "out") + assert paths["markdown"].exists() diff --git a/tests/test_governance_ci_gate.py b/tests/test_governance_ci_gate.py new file mode 100644 index 0000000..b629199 --- /dev/null +++ b/tests/test_governance_ci_gate.py @@ -0,0 +1,116 @@ +"""Tests for scripts.governance_ci_gate.""" + +from __future__ import annotations + +import importlib.util +import json +import sys +from pathlib import Path + +import pytest + +_GATE_PATH = Path(__file__).resolve().parent.parent / "scripts" / "governance_ci_gate.py" +_spec = importlib.util.spec_from_file_location("governance_ci_gate", _GATE_PATH) +gate = importlib.util.module_from_spec(_spec) +assert _spec.loader is not None +sys.modules["governance_ci_gate"] = gate +_spec.loader.exec_module(gate) + + +def _good_metrics(): + return {"iou": 0.82, "f1": 0.86, "precision": 0.88, "recall": 0.85} + + +def _bad_metrics(): + return {"iou": 0.55, "f1": 0.60} + + +def _good_fairness(): + return {"score": 0.92, "disparity_regions": []} + + +def _bad_fairness(): + return {"score": 0.5, "disparity_regions": ["amazon", "congo"]} + + +def _security(high=0, critical=0): + findings = [{"severity": "high"}] * high + [{"severity": "critical"}] * critical + return {"findings": findings} + + +def _write(path: Path, payload: dict) -> Path: + path.write_text(json.dumps(payload)) + return path + + +def test_metrics_gate_passes_when_above_threshold(tmp_path): + metrics_path = _write(tmp_path / "m.json", _good_metrics()) + passed, results = gate.run_gate(metrics_path, None, None, gate._merge_thresholds(None)) + assert passed + assert all(r.passed for r in results) + + +def test_metrics_gate_fails_when_below_threshold(tmp_path): + metrics_path = _write(tmp_path / "m.json", _bad_metrics()) + passed, _ = gate.run_gate(metrics_path, None, None, gate._merge_thresholds(None)) + assert not passed + + +def test_fairness_gate_fails_on_disparity(tmp_path): + metrics = _write(tmp_path / "m.json", _good_metrics()) + fairness = _write(tmp_path / "f.json", _bad_fairness()) + passed, results = gate.run_gate(metrics, fairness, None, gate._merge_thresholds(None)) + assert not passed + failed_names = {r.name for r in results if not r.passed} + assert "fairness.score" in failed_names or "fairness.disparity_regions" in failed_names + + +def test_security_gate_fails_on_high_finding(tmp_path): + metrics = _write(tmp_path / "m.json", _good_metrics()) + security = _write(tmp_path / "s.json", _security(high=1)) + passed, _ = gate.run_gate(metrics, None, security, gate._merge_thresholds(None)) + assert not passed + + +def test_thresholds_can_be_overridden(tmp_path): + metrics = _write(tmp_path / "m.json", _bad_metrics()) + custom = {"metrics": {"iou": 0.5, "f1": 0.5}} + passed, _ = gate.run_gate(metrics, None, None, gate._merge_thresholds(custom)) + assert passed + + +def test_render_summary_includes_pass_and_fail(): + results = [ + gate.GateResult(name="a", passed=True, detail="ok"), + gate.GateResult(name="b", passed=False, detail="bad"), + ] + md = gate.render_summary(results) + assert "Governance CI Gate — FAIL" in md + assert "PASS" in md and "FAIL" in md + + +def test_main_exits_nonzero_on_failure(tmp_path): + metrics = _write(tmp_path / "m.json", _bad_metrics()) + rc = gate.main(["--metrics", str(metrics)]) + assert rc == gate.EXIT_FAIL + + +def test_main_exits_zero_on_success(tmp_path): + metrics = _write(tmp_path / "m.json", _good_metrics()) + fairness = _write(tmp_path / "f.json", _good_fairness()) + security = _write(tmp_path / "s.json", _security()) + summary_path = tmp_path / "out" / "summary.md" + rc = gate.main([ + "--metrics", str(metrics), + "--fairness", str(fairness), + "--security", str(security), + "--summary-out", str(summary_path), + ]) + assert rc == gate.EXIT_OK + assert summary_path.exists() + assert "PASS" in summary_path.read_text() + + +def test_main_exits_bad_input_on_missing_metrics(tmp_path): + rc = gate.main(["--metrics", str(tmp_path / "missing.json")]) + assert rc == gate.EXIT_BAD_INPUT diff --git a/tests/test_llm_reporter.py b/tests/test_llm_reporter.py new file mode 100644 index 0000000..98c0ae7 --- /dev/null +++ b/tests/test_llm_reporter.py @@ -0,0 +1,99 @@ +"""Tests for reports.llm_reporter.""" + +from __future__ import annotations + +import json + +import pytest + +from climatevision.reports.llm_reporter import ( + ImpactReport, + LLMReporter, + ReportContext, + generate_impact_report, + render_template, +) + + +def _ctx(): + return ReportContext( + region="amazon", + period="2026-Q1", + analysis_type="deforestation", + carbon={"hectares": 1247.5, "carbon_tonnes": 4321.2, "ci_lower": 4000.0, "ci_upper": 4600.0}, + validation={"iou": 0.81, "f1": 0.87}, + shap={"top_bands": [{"band": "NIR", "importance": 0.42}, {"band": "Red", "importance": 0.31}]}, + fairness={"score": 0.93, "disparity_regions": []}, + run_id=12345, + ) + + +def test_headline_metric_uses_carbon_when_available(): + text = _ctx().headline_metric() + assert "1,247.5 hectares" in text + assert "4,321.2 tCO2e" in text + + +def test_template_renders_all_sections(): + md = render_template(_ctx()) + for heading in [ + "# Impact Report", + "## Carbon Analytics", + "## Validation", + "## Explainability", + "## Fairness", + ]: + assert heading in md + + +def test_template_skips_shap_when_disabled(): + md = render_template(_ctx(), include_shap=False) + assert "## Explainability" not in md + + +def test_reporter_falls_back_to_template_without_llm(): + report = LLMReporter().generate(_ctx()) + assert report.provider == "template" + assert "amazon" in report.body.lower() + + +def test_reporter_uses_provided_llm_callable(): + captured = {} + + def fake_llm(prompt: str) -> str: + captured["prompt"] = prompt + return "Executive summary line.\n\n## Carbon Analytics\n- Hectares: 1247.5\n" + + report = LLMReporter(llm=fake_llm).generate(_ctx()) + assert report.provider == "llm" + assert "Executive summary line." in report.summary + assert "amazon" in captured["prompt"].lower() + + +def test_reporter_handles_llm_exception_gracefully(): + def boom(prompt: str) -> str: + raise RuntimeError("provider down") + + report = LLMReporter(llm=boom).generate(_ctx()) + assert report.provider == "template" + + +def test_generate_impact_report_writes_to_disk(tmp_path): + report = generate_impact_report( + region="amazon", + period="2026-Q1", + analysis_type="deforestation", + carbon={"hectares": 100.0, "carbon_tonnes": 350.0}, + validation={"iou": 0.7, "f1": 0.8}, + output_dir=tmp_path, + ) + md_path = tmp_path / "amazon_2026-Q1_impact.md" + json_path = tmp_path / "amazon_2026-Q1_impact.json" + + assert isinstance(report, ImpactReport) + assert md_path.exists() + assert json_path.exists() + + payload = json.loads(json_path.read_text()) + assert payload["context"]["region"] == "amazon" + assert payload["provider"] == "template" diff --git a/tests/test_model_card.py b/tests/test_model_card.py new file mode 100644 index 0000000..b3a4211 --- /dev/null +++ b/tests/test_model_card.py @@ -0,0 +1,95 @@ +"""Tests for governance.model_card.""" + +from __future__ import annotations + +import json + +import pytest + +from climatevision.governance.model_card import ( + REQUIRED_METRICS, + build_model_card, + generate, + render_markdown, + write_model_card, +) + + +def _config(): + return { + "model": {"name": "unet-deforestation", "version": "1.2.0"}, + "analysis_type": "deforestation", + "training_data": { + "regions": ["amazon", "congo"], + "tile_count": 12000, + }, + "evaluation_data": {"regions": ["southeast_asia"], "tile_count": 1500}, + } + + +def _metrics(): + return {"iou": 0.81, "f1": 0.86, "precision": 0.88, "recall": 0.85} + + +def test_build_card_uses_config_values(): + card = build_model_card(_config(), _metrics()) + assert card.name == "unet-deforestation" + assert card.version == "1.2.0" + assert card.analysis_type == "deforestation" + assert card.metrics == _metrics() + assert card.training_data["tile_count"] == 12000 + + +def test_missing_metric_raises(): + bad = {"iou": 0.5, "f1": 0.5} + with pytest.raises(ValueError): + build_model_card(_config(), bad) + + +def test_required_metric_set_is_documented(): + assert set(REQUIRED_METRICS) <= set(_metrics()) + + +def test_render_markdown_includes_all_sections(): + card = build_model_card(_config(), _metrics(), fairness_report={"score": 0.92}) + md = render_markdown(card) + for heading in [ + "# Model Card:", + "## Description", + "## Intended Use", + "## Training Data", + "## Evaluation", + "## Fairness", + "## Limitations", + "## Ethical Considerations", + "## Contact", + ]: + assert heading in md + assert "score" in md + + +def test_write_model_card_emits_md_and_json(tmp_path): + card = build_model_card(_config(), _metrics()) + paths = write_model_card(card, output_dir=tmp_path) + + assert paths["markdown"].exists() + assert paths["json"].exists() + + payload = json.loads(paths["json"].read_text()) + assert payload["version"] == "1.2.0" + assert payload["metrics"]["iou"] == pytest.approx(0.81) + + +def test_generate_loads_files_from_disk(tmp_path): + cfg_path = tmp_path / "config.json" + metrics_path = tmp_path / "metrics.json" + cfg_path.write_text(json.dumps(_config())) + metrics_path.write_text(json.dumps(_metrics())) + + paths = generate( + config=cfg_path, + metrics=metrics_path, + output_dir=tmp_path / "cards", + ) + assert paths["markdown"].exists() + assert paths["json"].exists() diff --git a/tests/test_regression.py b/tests/test_regression.py new file mode 100644 index 0000000..bb1dc3b --- /dev/null +++ b/tests/test_regression.py @@ -0,0 +1,124 @@ +"""Tests for models.regression. + +Imports the module via importlib because ``climatevision.models.__init__`` +pulls in torch-based U-Net code that is heavy and unrelated to the +regression module under test. +""" + +from __future__ import annotations + +import importlib.util +import json +import sys +from pathlib import Path + +import numpy as np +import pytest + +_PATH = ( + Path(__file__).resolve().parent.parent + / "src" + / "climatevision" + / "models" + / "regression.py" +) +_spec = importlib.util.spec_from_file_location("cv_regression", _PATH) +reg = importlib.util.module_from_spec(_spec) +assert _spec.loader is not None +sys.modules["cv_regression"] = reg +_spec.loader.exec_module(reg) + + +def _synthetic_dataset(n=400, seed=0): + rng = np.random.default_rng(seed) + X = rng.uniform(low=0.0, high=1.0, size=(n, 10)) + # biomass = 200 * NDVI + 50 * EVI + 20 * NIR + noise + y = 200 * X[:, 0] + 50 * X[:, 1] + 20 * X[:, 8] + rng.normal(0, 5, size=n) + return X, y + + +def test_biomass_to_carbon_uses_default_fraction(): + assert reg.biomass_to_carbon(100.0) == pytest.approx(47.0) + arr = reg.biomass_to_carbon(np.array([0.0, 100.0, 200.0])) + np.testing.assert_allclose(arr, [0.0, 47.0, 94.0]) + + +def test_biomass_to_co2e_round_trip(): + co2e = reg.biomass_to_co2e(100.0) + assert co2e == pytest.approx(100.0 * 0.47 * (44.0 / 12.0)) + + +def test_evaluate_regression_perfect_fit(): + y = np.array([1.0, 2.0, 3.0, 4.0]) + metrics = reg.evaluate_regression(y, y) + assert metrics.rmse == 0.0 + assert metrics.mae == 0.0 + assert metrics.r2 == pytest.approx(1.0) + assert metrics.mape == pytest.approx(0.0) + + +def test_evaluate_regression_shape_mismatch_raises(): + with pytest.raises(ValueError): + reg.evaluate_regression(np.array([1.0, 2.0]), np.array([1.0])) + + +def test_random_forest_fit_predict_evaluate(): + X, y = _synthetic_dataset(n=300) + r = reg.BiomassRegressor(model_type="random_forest", model_kwargs={"n_estimators": 50}) + r.fit(X, y) + metrics = r.evaluate(X, y) + assert metrics.rmse < 25.0 # very loose, just sanity + assert metrics.r2 > 0.5 + + importances = r.feature_importances() + assert set(importances) == set(reg.DEFAULT_FEATURE_NAMES) + assert pytest.approx(sum(importances.values()), rel=1e-6) == 1.0 + + +def test_unsupported_model_type_raises(): + with pytest.raises(ValueError): + reg.BiomassRegressor(model_type="lightgbm") + + +def test_predict_before_fit_raises(): + r = reg.BiomassRegressor() + with pytest.raises(RuntimeError): + r.predict(np.zeros((1, 10))) + + +def test_save_and_load_round_trip(tmp_path): + X, y = _synthetic_dataset(n=200) + r = reg.BiomassRegressor(model_type="random_forest", model_kwargs={"n_estimators": 30}) + r.fit(X, y) + + out = r.save(tmp_path / "rf.pkl") + assert out.exists() + + loaded = reg.BiomassRegressor.load(out) + np.testing.assert_allclose(loaded.predict(X), r.predict(X)) + + +def test_estimate_biomass_from_indices(tmp_path): + X, y = _synthetic_dataset(n=200) + r = reg.BiomassRegressor(model_kwargs={"n_estimators": 30}).fit(X, y) + + indices = {name: X[:, i] for i, name in enumerate(reg.DEFAULT_FEATURE_NAMES)} + pred = reg.estimate_biomass_from_indices(indices, r) + assert pred.shape == (X.shape[0],) + np.testing.assert_allclose(pred, r.predict(X)) + + +def test_estimate_biomass_missing_index_raises(): + r = reg.BiomassRegressor(model_kwargs={"n_estimators": 5}) + r.fit(*_synthetic_dataset(n=80)) + + incomplete = {name: np.zeros(10) for name in reg.DEFAULT_FEATURE_NAMES[:-1]} + with pytest.raises(KeyError): + reg.estimate_biomass_from_indices(incomplete, r) + + +def test_serialize_metrics_writes_json(tmp_path): + metrics = reg.RegressionMetrics(rmse=1.0, mae=0.5, r2=0.9, mape=0.05) + out = reg.serialize_metrics(metrics, tmp_path / "metrics.json") + payload = json.loads(out.read_text()) + assert payload == {"rmse": 1.0, "mae": 0.5, "r2": 0.9, "mape": 0.05} diff --git a/tests/test_security.py b/tests/test_security.py new file mode 100644 index 0000000..e9e4fd5 --- /dev/null +++ b/tests/test_security.py @@ -0,0 +1,268 @@ +""" +Security tests for ClimateVision API. + +Tests input validation, sanitization, and security controls. +""" + +import pytest +import numpy as np +from pathlib import Path +import sys + +sys.path.insert(0, str(Path(__file__).parent.parent / "src")) + +from climatevision.security import ( + validate_payload_size, + validate_bbox, + validate_file_upload, + sanitize_string_input, + SecurityConfig, + RateLimiter, + detect_adversarial_input, + validate_model_output, + InputAnomalyDetector, + PipelineGuard, +) + + +class TestPayloadValidation: + """Test payload size validation.""" + + def test_valid_payload(self): + data = b"x" * 1000 + is_valid, error = validate_payload_size(data, max_size=2000) + assert is_valid + assert error == "" + + def test_oversized_payload(self): + data = b"x" * 10000 + is_valid, error = validate_payload_size(data, max_size=1000) + assert not is_valid + assert "exceeds maximum" in error + + +class TestBboxValidation: + """Test bounding box validation.""" + + def test_valid_bbox(self): + bbox = [-60.0, -15.0, -45.0, -5.0] + is_valid, error = validate_bbox(bbox) + assert is_valid + assert error == "" + + def test_invalid_longitude(self): + bbox = [200.0, 10.0, 30.0, 40.0] + is_valid, error = validate_bbox(bbox) + assert not is_valid + assert "longitude" in error.lower() + + def test_invalid_latitude(self): + bbox = [10.0, 100.0, 30.0, 40.0] + is_valid, error = validate_bbox(bbox) + assert not is_valid + assert "latitude" in error.lower() + + def test_wrong_order_longitude(self): + bbox = [30.0, 10.0, 20.0, 40.0] + is_valid, error = validate_bbox(bbox) + assert not is_valid + assert "West" in error + + def test_wrong_order_latitude(self): + bbox = [10.0, 50.0, 30.0, 40.0] + is_valid, error = validate_bbox(bbox) + assert not is_valid + assert "South" in error + + def test_too_large_area(self): + bbox = [-180.0, -90.0, 180.0, 90.0] + is_valid, error = validate_bbox(bbox, max_area=100.0) + assert not is_valid + assert "area" in error.lower() + + def test_wrong_element_count(self): + bbox = [10.0, 20.0, 30.0] + is_valid, error = validate_bbox(bbox) + assert not is_valid + assert "exactly 4" in error + + +class TestFileUploadValidation: + """Test file upload validation.""" + + def test_valid_png(self): + content = b"\x89PNG\r\n\x1a\n" + b"x" * 100 + is_valid, error = validate_file_upload(content, "image.png") + assert is_valid + assert error == "" + + def test_valid_tiff(self): + content = b"II*\x00" + b"x" * 100 + is_valid, error = validate_file_upload(content, "satellite.tif") + assert is_valid + assert error == "" + + def test_invalid_extension(self): + content = b"\x89PNG\r\n\x1a\n" + b"x" * 100 + is_valid, error = validate_file_upload(content, "malware.exe") + assert not is_valid + assert "not allowed" in error + + def test_path_traversal(self): + content = b"\x89PNG\r\n\x1a\n" + b"x" * 100 + is_valid, error = validate_file_upload(content, "../../../etc/passwd") + assert not is_valid + assert "path traversal" in error.lower() + + def test_extension_mismatch(self): + content = b"\x89PNG\r\n\x1a\n" + b"x" * 100 + is_valid, error = validate_file_upload(content, "image.jpg") + assert not is_valid + assert "does not match" in error + + def test_filename_too_long(self): + content = b"\x89PNG\r\n\x1a\n" + b"x" * 100 + filename = "a" * 300 + ".png" + is_valid, error = validate_file_upload(content, filename) + assert not is_valid + assert "too long" in error + + +class TestStringSanitization: + """Test string input sanitization.""" + + def test_normal_string(self): + result, warnings = sanitize_string_input("Hello World") + assert "Hello" in result + assert len(warnings) == 0 or "sanitized" in warnings[0].lower() + + def test_sql_injection(self): + result, warnings = sanitize_string_input("'; DROP TABLE users; --") + assert "DROP TABLE" not in result + assert any("blocked" in w.lower() or "sanitized" in w.lower() for w in warnings) + + def test_xss_script(self): + result, warnings = sanitize_string_input("") + assert " 0 + + def test_path_traversal(self): + result, warnings = sanitize_string_input("../../../etc/passwd") + assert "../" not in result + + def test_truncation(self): + long_string = "a" * 2000 + result, warnings = sanitize_string_input(long_string, max_length=100) + assert len(result) <= 100 + assert any("truncated" in w.lower() for w in warnings) + + +class TestRateLimiter: + """Test rate limiting.""" + + def test_allows_under_limit(self): + limiter = RateLimiter(max_requests=5, window_seconds=60) + for _ in range(5): + assert limiter.is_allowed("test_key") + + def test_blocks_over_limit(self): + limiter = RateLimiter(max_requests=3, window_seconds=60) + for _ in range(3): + assert limiter.is_allowed("test_key") + assert not limiter.is_allowed("test_key") + + def test_separate_keys(self): + limiter = RateLimiter(max_requests=2, window_seconds=60) + assert limiter.is_allowed("key1") + assert limiter.is_allowed("key1") + assert not limiter.is_allowed("key1") + assert limiter.is_allowed("key2") # Different key + + def test_remaining_count(self): + limiter = RateLimiter(max_requests=5, window_seconds=60) + assert limiter.get_remaining("key") == 5 + limiter.is_allowed("key") + assert limiter.get_remaining("key") == 4 + + +class TestAdversarialDetection: + """Test adversarial input detection.""" + + def test_normal_image(self): + image = np.random.randn(4, 256, 256).astype(np.float32) + result = detect_adversarial_input(image) + assert not result.is_anomalous + assert result.anomaly_score < 0.5 + + def test_uniform_image(self): + image = np.ones((4, 256, 256), dtype=np.float32) + result = detect_adversarial_input(image) + assert result.is_anomalous + assert "uniform" in str(result.details).lower() or result.anomaly_score > 0.3 + + def test_nan_values(self): + image = np.random.randn(4, 256, 256).astype(np.float32) + image[0, 100, 100] = np.nan + result = detect_adversarial_input(image) + assert result.is_anomalous + assert result.anomaly_score >= 0.5 + + def test_inf_values(self): + image = np.random.randn(4, 256, 256).astype(np.float32) + image[0, 100, 100] = np.inf + result = detect_adversarial_input(image) + assert result.is_anomalous + + def test_out_of_range(self): + image = np.random.randn(4, 256, 256).astype(np.float32) * 100 + result = detect_adversarial_input(image) + assert result.anomaly_score > 0 + + +class TestOutputValidation: + """Test model output validation.""" + + def test_valid_output(self): + predictions = np.random.randint(0, 2, (256, 256)) + result = validate_model_output(predictions, n_classes=2) + assert result.is_valid + assert result.confidence > 0.5 + + def test_invalid_class_values(self): + predictions = np.array([[0, 1, 5, 10]]) + result = validate_model_output(predictions, n_classes=2) + assert not result.is_valid or len(result.issues) > 0 + + def test_single_class_domination(self): + predictions = np.ones((256, 256), dtype=np.int32) + result = validate_model_output(predictions, n_classes=2) + assert len(result.issues) > 0 + assert any("dominates" in issue.lower() for issue in result.issues) + + def test_nan_in_predictions(self): + predictions = np.array([[0.0, 1.0, np.nan]]) + result = validate_model_output(predictions, n_classes=2) + assert not result.is_valid + + +class TestPipelineGuard: + """Test complete pipeline guard.""" + + def test_blocks_adversarial(self): + guard = PipelineGuard() + adversarial_image = np.ones((4, 256, 256), dtype=np.float32) * 0.5 + + result = guard.check_input(adversarial_image) + # Uniform image should be flagged + assert result.anomaly_score > 0 + + def test_passes_normal_image(self): + guard = PipelineGuard() + normal_image = np.random.randn(4, 256, 256).astype(np.float32) + + result = guard.check_input(normal_image) + assert not result.is_anomalous + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])