-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathevaluate.py
More file actions
126 lines (103 loc) · 4.25 KB
/
evaluate.py
File metadata and controls
126 lines (103 loc) · 4.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
#!/usr/bin/env python3
"""
evaluate.py — Run gasbench-compatible evaluation on trained models.
Evaluates a checkpoint against the local validation set using the same
sn34_score metric that BitMind Subnet 34 uses for scoring.
Usage:
uv run evaluate.py --modality image
uv run evaluate.py --modality image --model clip-vit-l14
uv run evaluate.py --modality video --weights path/to/model.safetensors
uv run evaluate.py --modality audio --model ast
"""
import argparse
import sys
import time
from pathlib import Path

import torch

from prepare import (
    TARGET_IMAGE_SIZE,
    TARGET_VIDEO_SIZE,
    NUM_VIDEO_FRAMES,
    evaluate_model,
    compute_sn34_score,
)
def load_model_from_checkpoint(
    modality: str,
    model_name: str,
    weights_path: Path,
    device: str = "cuda",
):
    """Rebuild a two-class detector and restore its weights from a safetensors file.

    The architecture is created with ``pretrained=False`` because every
    parameter is expected to come from the checkpoint.

    Args:
        modality: Modality key forwarded to ``get_model``.
        model_name: Architecture name forwarded to ``get_model``.
        weights_path: Location of the ``.safetensors`` checkpoint
            (converted to ``str`` before loading).
        device: Torch device string the model is moved to.

    Returns:
        The restored model, switched to eval mode and placed on ``device``.
    """
    # Imported inside the function, so importing this module does not
    # require these packages.
    from dfresearch.models import get_model
    from safetensors.torch import load_file

    detector = get_model(modality, model_name, num_classes=2, pretrained=False)
    checkpoint = load_file(str(weights_path))
    detector.load_state_dict(checkpoint)
    detector.eval()
    placed = detector.to(device)
    return placed
# Per-modality fallbacks used when --model / --batch-size are omitted.
# The keys also define the valid --modality choices.
_DEFAULTS = {
    "image": {"model": "efficientnet-b4", "batch_size": 64},
    "video": {"model": "r3d-18", "batch_size": 4},
    "audio": {"model": "wav2vec2", "batch_size": 32},
}

# Minimum validation accuracy required by the gasbench entrance exam.
_ENTRANCE_ACCURACY = 0.80


def _parse_args(argv=None) -> argparse.Namespace:
    """Parse and validate the command line (exits with status 2 on bad input)."""
    parser = argparse.ArgumentParser(description="Evaluate deepfake detection model")
    parser.add_argument("--modality", required=True, choices=list(_DEFAULTS))
    parser.add_argument("--model", default=None, help="Model name (default: auto-detect from modality)")
    parser.add_argument("--weights", type=Path, default=None, help="Path to model.safetensors")
    parser.add_argument("--batch-size", type=int, default=None)
    args = parser.parse_args(argv)
    # Reject non-positive sizes explicitly. Previously `--batch-size 0` was
    # silently replaced by the default via `or`, and negatives passed through.
    if args.batch_size is not None and args.batch_size < 1:
        parser.error("--batch-size must be a positive integer")
    return args


def _make_val_loader(modality: str, batch_size: int):
    """Build the non-augmented validation loader for *modality*."""
    from dfresearch.data import make_dataloader

    loader_kwargs = {
        "batch_size": batch_size,
        "augment_level": 0,
    }
    if modality == "image":
        loader_kwargs["target_size"] = TARGET_IMAGE_SIZE
    elif modality == "video":
        loader_kwargs["target_size"] = TARGET_VIDEO_SIZE
        loader_kwargs["num_frames"] = NUM_VIDEO_FRAMES
    return make_dataloader(modality, split="val", **loader_kwargs)


def _print_report(modality: str, model_name: str, metrics, eval_time: float) -> None:
    """Print the metric summary table followed by the entrance-exam verdict."""
    rule = "=" * 60
    print(f"\n{rule}")
    print(f"Evaluation Results — {modality} / {model_name}")
    print(rule)
    for key in ("sn34_score", "accuracy", "mcc", "brier", "mcc_norm", "brier_norm"):
        print(f"{key}: {metrics[key]:.6f}")
    print(f"eval_time: {eval_time:.1f}s")
    print(rule)

    # Competition readiness check. The threshold text is derived from the
    # same constant as the comparison, so the two cannot drift apart.
    passing = metrics["accuracy"] >= _ENTRANCE_ACCURACY
    threshold = f"{_ENTRANCE_ACCURACY:.0%}"
    print(f"\nEntrance exam threshold (>={threshold} accuracy): {'PASS' if passing else 'FAIL'}")
    if not passing:
        print(f" Your model needs at least {threshold} accuracy to pass the gasbench entrance exam.")
        print(" Keep training or try a different approach.")


def main(argv=None):
    """Evaluate a checkpoint on the validation split and print its scores.

    Args:
        argv: Command-line arguments without the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``, so ``main()``
            still works as the script entry point.

    Exits with status 2 on invalid arguments (via argparse) and status 1
    when the checkpoint file is missing.
    """
    args = _parse_args(argv)
    defaults = _DEFAULTS[args.modality]
    model_name = args.model or defaults["model"]
    batch_size = args.batch_size if args.batch_size is not None else defaults["batch_size"]
    weights_path = (
        args.weights
        if args.weights is not None
        else Path(f"results/checkpoints/{args.modality}/model.safetensors")
    )

    # is_file() rather than exists(): a directory at this path would pass
    # exists() and only fail later, when loading the weights.
    if not weights_path.is_file():
        print(f"ERROR: Weights not found at {weights_path}", file=sys.stderr)
        print(f"Run train_{args.modality}.py first, or specify --weights", file=sys.stderr)
        # Non-zero status so shell scripts / CI can detect the failure;
        # the previous bare `return` exited with status 0.
        sys.exit(1)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")
    print(f"Model: {model_name}")
    print(f"Weights: {weights_path}")

    print("\nLoading model...")
    model = load_model_from_checkpoint(args.modality, model_name, weights_path, device)
    num_params = sum(p.numel() for p in model.parameters())
    print(f"Parameters: {num_params / 1e6:.1f}M")

    print("Loading validation data...")
    val_loader = _make_val_loader(args.modality, batch_size)
    print(f"Validation batches: {len(val_loader)}")

    print("\nEvaluating...")
    # perf_counter is monotonic; time.time() can jump if the wall clock changes.
    t0 = time.perf_counter()
    metrics = evaluate_model(model, val_loader, device=device)
    eval_time = time.perf_counter() - t0

    _print_report(args.modality, model_name, metrics, eval_time)


if __name__ == "__main__":
    main()