Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
197 changes: 184 additions & 13 deletions src/cortex-snapshot/src/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

use crate::{Result, RevertPoint, Snapshot, SnapshotError};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::collections::HashMap;
use std::path::PathBuf;
use tokio::fs;
Expand Down Expand Up @@ -37,8 +38,8 @@ impl SnapshotStorage {

/// Load a snapshot by ID.
pub async fn load_snapshot(&self, id: &str) -> Result<Snapshot> {
let path = self.snapshot_path(id);
let json = fs::read_to_string(&path)
let json = self
.read_first_existing(self.snapshot_paths(id))
.await
.map_err(|_| SnapshotError::NotFound(id.to_string()))?;
let snapshot: Snapshot =
Expand All @@ -48,10 +49,11 @@ impl SnapshotStorage {

/// Delete a snapshot.
pub async fn delete_snapshot(&self, id: &str) -> Result<()> {
let path = self.snapshot_path(id);
if path.exists() {
fs::remove_file(&path).await?;
debug!("Deleted snapshot: {}", id);
for path in self.snapshot_paths(id) {
if path.exists() {
fs::remove_file(&path).await?;
debug!("Deleted snapshot: {}", id);
}
}
Ok(())
}
Expand Down Expand Up @@ -104,11 +106,15 @@ impl SnapshotStorage {

/// Load session revert history.
pub async fn load_revert_history(&self, session_id: &str) -> Result<Vec<RevertPoint>> {
let path = self.history_path(session_id);
if !path.exists() {
return Ok(Vec::new());
}
let json = fs::read_to_string(&path).await?;
let json = match self
.read_first_existing(self.history_paths(session_id))
.await
{
Ok(json) => json,
Err(_) => {
return Ok(Vec::new());
}
};
let history: Vec<RevertPoint> =
serde_json::from_str(&json).map_err(|e| SnapshotError::NotFound(e.to_string()))?;
Ok(history)
Expand All @@ -132,13 +138,76 @@ impl SnapshotStorage {
Ok(removed)
}

async fn read_first_existing(&self, paths: Vec<PathBuf>) -> std::io::Result<String> {
for path in paths {
if path.exists() {
return fs::read_to_string(&path).await;
}
}
Err(std::io::Error::new(
std::io::ErrorKind::NotFound,
"file not found",
))
}

fn snapshot_paths(&self, id: &str) -> Vec<PathBuf> {
let mut paths = vec![self.snapshot_path(id)];
if let Some(path) = self.legacy_snapshot_path(id) {
paths.push(path);
}
paths
}

fn history_paths(&self, session_id: &str) -> Vec<PathBuf> {
let mut paths = vec![self.history_path(session_id)];
if let Some(path) = self.legacy_history_path(session_id) {
paths.push(path);
}
paths
}

fn snapshot_path(&self, id: &str) -> PathBuf {
self.storage_path.join(format!("{}.json", id))
self.storage_path
.join(format!("snapshot_{}.json", Self::hashed_id(id)))
}

fn history_path(&self, session_id: &str) -> PathBuf {
self.storage_path
.join(format!("history_{}.json", session_id))
.join(format!("history_{}.json", Self::hashed_id(session_id)))
}

fn legacy_snapshot_path(&self, id: &str) -> Option<PathBuf> {
if Self::is_safe_legacy_id(id) {
Some(self.storage_path.join(format!("{}.json", id)))
} else {
None
}
}

fn legacy_history_path(&self, session_id: &str) -> Option<PathBuf> {
if Self::is_safe_legacy_id(session_id) {
Some(
self.storage_path
.join(format!("history_{}.json", session_id)),
)
} else {
None
}
}

fn hashed_id(id: &str) -> String {
let mut hasher = Sha256::new();
hasher.update(id.as_bytes());
hex::encode(hasher.finalize())
}

fn is_safe_legacy_id(id: &str) -> bool {
if id.is_empty() {
return false;
}

id.chars()
.all(|c| c.is_ascii_alphanumeric() || matches!(c, '-' | '_' | '.'))
}
}

Expand Down Expand Up @@ -203,4 +272,106 @@ mod tests {
assert_eq!(loaded.tree_hash, snapshot.tree_hash);
assert_eq!(loaded.description, snapshot.description);
}

#[test]
fn snapshot_paths_hash_untrusted_ids_into_storage_dir() {
let temp_dir = TempDir::new().unwrap();
let storage = SnapshotStorage::new(temp_dir.path());
let path = storage.snapshot_path("../escape/owned");

assert_eq!(path.parent(), Some(storage.storage_path.as_path()));

let filename = path.file_name().unwrap().to_string_lossy();
assert!(filename.starts_with("snapshot_"));
assert!(filename.ends_with(".json"));
assert!(!filename.contains(".."));
assert!(!filename.contains('/'));
assert!(!filename.contains('\\'));
}

#[test]
fn history_paths_hash_untrusted_ids_into_storage_dir() {
let temp_dir = TempDir::new().unwrap();
let storage = SnapshotStorage::new(temp_dir.path());
let path = storage.history_path("session/../../escape/owned");

assert_eq!(path.parent(), Some(storage.storage_path.as_path()));

let filename = path.file_name().unwrap().to_string_lossy();
assert!(filename.starts_with("history_"));
assert!(filename.ends_with(".json"));
assert!(!filename.contains(".."));
assert!(!filename.contains('/'));
assert!(!filename.contains('\\'));
}

#[tokio::test]
async fn save_snapshot_with_path_components_does_not_escape_storage_dir() {
let temp_dir = TempDir::new().unwrap();
let storage = SnapshotStorage::new(temp_dir.path());
let escape_dir = temp_dir.path().join("escape");
fs::create_dir_all(&escape_dir).await.unwrap();

let mut snapshot = Snapshot::new("test_hash".to_string());
snapshot.id = "../escape/owned".to_string();

storage.save_snapshot(&snapshot).await.unwrap();

assert!(!escape_dir.join("owned.json").exists());
assert!(storage.snapshot_path(&snapshot.id).exists());

let loaded = storage.load_snapshot(&snapshot.id).await.unwrap();
assert_eq!(loaded.id, snapshot.id);
}

#[tokio::test]
async fn load_snapshot_with_path_components_ignores_legacy_traversal_path() {
let temp_dir = TempDir::new().unwrap();
let storage = SnapshotStorage::new(temp_dir.path());
let escape_dir = temp_dir.path().join("escape");
fs::create_dir_all(&escape_dir).await.unwrap();

let mut snapshot = Snapshot::new("test_hash".to_string());
snapshot.id = "../escape/owned".to_string();
let json = serde_json::to_string(&snapshot).unwrap();
fs::write(escape_dir.join("owned.json"), json)
.await
.unwrap();

assert!(storage.load_snapshot(&snapshot.id).await.is_err());
}

#[tokio::test]
async fn save_history_with_path_components_does_not_escape_storage_dir() {
let temp_dir = TempDir::new().unwrap();
let storage = SnapshotStorage::new(temp_dir.path());
let escape_dir = temp_dir.path().join("escape");
fs::create_dir_all(&escape_dir).await.unwrap();

let session_id = "session/../../escape/owned";

storage.save_revert_history(session_id, &[]).await.unwrap();

assert!(!escape_dir.join("owned.json").exists());
assert!(storage.history_path(session_id).exists());

let loaded = storage.load_revert_history(session_id).await.unwrap();
assert!(loaded.is_empty());
}

#[tokio::test]
async fn load_history_with_path_components_ignores_legacy_traversal_path() {
let temp_dir = TempDir::new().unwrap();
let storage = SnapshotStorage::new(temp_dir.path());
let escape_dir = temp_dir.path().join("escape");
fs::create_dir_all(&escape_dir).await.unwrap();

let session_id = "session/../../escape/owned";
fs::write(escape_dir.join("owned.json"), "[]")
.await
.unwrap();

let loaded = storage.load_revert_history(session_id).await.unwrap();
assert!(loaded.is_empty());
}
}