Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/cortex-snapshot/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ pub enum SnapshotError {
Io(#[from] std::io::Error),
#[error("Failed to create snapshot: {0}")]
CreateFailed(String),
#[error("Invalid snapshot storage ID: {0}")]
InvalidId(String),
#[error("Failed to restore snapshot: {0}")]
RestoreFailed(String),
#[error("Git command '{command}' timed out after {timeout_secs}s")]
Expand Down
93 changes: 82 additions & 11 deletions src/cortex-snapshot/src/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
use crate::{Result, RevertPoint, Snapshot, SnapshotError};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
use std::path::{Component, Path, PathBuf};
use tokio::fs;
use tracing::{debug, info};

Expand All @@ -27,7 +27,7 @@ impl SnapshotStorage {
/// Save a snapshot.
pub async fn save_snapshot(&self, snapshot: &Snapshot) -> Result<()> {
self.init().await?;
let path = self.snapshot_path(&snapshot.id);
let path = self.snapshot_path(&snapshot.id)?;
let json = serde_json::to_string_pretty(snapshot)
.map_err(|e| SnapshotError::CreateFailed(e.to_string()))?;
fs::write(&path, json).await?;
Expand All @@ -37,7 +37,7 @@ impl SnapshotStorage {

/// Load a snapshot by ID.
pub async fn load_snapshot(&self, id: &str) -> Result<Snapshot> {
let path = self.snapshot_path(id);
let path = self.snapshot_path(id)?;
let json = fs::read_to_string(&path)
.await
.map_err(|_| SnapshotError::NotFound(id.to_string()))?;
Expand All @@ -48,7 +48,7 @@ impl SnapshotStorage {

/// Delete a snapshot.
pub async fn delete_snapshot(&self, id: &str) -> Result<()> {
let path = self.snapshot_path(id);
let path = self.snapshot_path(id)?;
if path.exists() {
fs::remove_file(&path).await?;
debug!("Deleted snapshot: {}", id);
Expand Down Expand Up @@ -94,7 +94,7 @@ impl SnapshotStorage {
history: &[RevertPoint],
) -> Result<()> {
self.init().await?;
let path = self.history_path(session_id);
let path = self.history_path(session_id)?;
let json = serde_json::to_string_pretty(history)
.map_err(|e| SnapshotError::CreateFailed(e.to_string()))?;
fs::write(&path, json).await?;
Expand All @@ -104,7 +104,7 @@ impl SnapshotStorage {

/// Load session revert history.
pub async fn load_revert_history(&self, session_id: &str) -> Result<Vec<RevertPoint>> {
let path = self.history_path(session_id);
let path = self.history_path(session_id)?;
if !path.exists() {
return Ok(Vec::new());
}
Expand Down Expand Up @@ -132,16 +132,34 @@ impl SnapshotStorage {
Ok(removed)
}

fn snapshot_path(&self, id: &str) -> PathBuf {
self.storage_path.join(format!("{}.json", id))
fn snapshot_path(&self, id: &str) -> Result<PathBuf> {
validate_storage_id("snapshot id", id)?;
Ok(self.storage_path.join(format!("{}.json", id)))
}

fn history_path(&self, session_id: &str) -> PathBuf {
self.storage_path
.join(format!("history_{}.json", session_id))
fn history_path(&self, session_id: &str) -> Result<PathBuf> {
validate_storage_id("session id", session_id)?;
Ok(self
.storage_path
.join(format!("history_{}.json", session_id)))
}
}

fn validate_storage_id(kind: &str, id: &str) -> Result<()> {
if id.is_empty()
|| id.chars().any(|c| matches!(c, '/' | '\\' | '\0'))
|| Path::new(id)
.components()
.any(|component| !matches!(component, Component::Normal(_)))
{
return Err(SnapshotError::InvalidId(format!(
"{kind} must be a single filename segment: {id:?}"
)));
}

Ok(())
}

/// Index for fast snapshot lookup.
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct SnapshotIndex {
Expand Down Expand Up @@ -203,4 +221,57 @@ mod tests {
assert_eq!(loaded.tree_hash, snapshot.tree_hash);
assert_eq!(loaded.description, snapshot.description);
}

#[tokio::test]
async fn test_snapshot_storage_rejects_traversal_snapshot_id() {
let temp_dir = TempDir::new().unwrap();
let storage = SnapshotStorage::new(temp_dir.path());

let mut snapshot = Snapshot::new("test_hash".to_string());
snapshot.id = "../escaped_snapshot".to_string();

let result = storage.save_snapshot(&snapshot).await;
assert!(matches!(result, Err(SnapshotError::InvalidId(_))));
assert!(!temp_dir.path().join("escaped_snapshot.json").exists());
}

#[tokio::test]
async fn test_revert_history_rejects_traversal_session_id() {
let temp_dir = TempDir::new().unwrap();
let storage = SnapshotStorage::new(temp_dir.path());
storage.init().await.unwrap();
fs::create_dir_all(storage.storage_path.join("history_existing"))
.await
.unwrap();

let result = storage
.save_revert_history("existing/../../escaped_history", &[])
.await;

assert!(matches!(result, Err(SnapshotError::InvalidId(_))));
assert!(!temp_dir.path().join("escaped_history.json").exists());
}

#[test]
fn test_storage_ids_allow_plain_uuid_like_segments() {
assert!(validate_storage_id("snapshot id", "550e8400-e29b-41d4-a716-446655440000").is_ok());
assert!(validate_storage_id("session id", "session_name.1").is_ok());
}

#[test]
fn test_storage_ids_reject_path_segments() {
for id in [
"../escape",
"/absolute",
"nested/path",
r"nested\path",
".",
"..",
] {
assert!(matches!(
validate_storage_id("snapshot id", id),
Err(SnapshotError::InvalidId(_))
));
}
}
}