
ryusudol/Centered-Kernel-Alignment


pytorch-cka


The Fastest, Memory-efficient Python Library for Centered Kernel Alignment (CKA) with Built-in Visualization

[Figure: bar chart of benchmark results]

3000% faster CKA computation across all layers of two distinct ResNet-18 models on CIFAR-10 using NVIDIA H100 GPUs

  • ⚡️ Much faster than the most popular CKA library
  • 📦 Memory-efficient minibatch CKA computation
  • 🎨 Customizable visualizations: heatmaps and line charts
  • 🧠 Supports Hugging Face models, DataParallel, and DistributedDataParallel (DDP)
  • 🐳 Installable via pip, uv, or Docker
  • 🛠️ Modern pyproject.toml packaging
  • 🤝 Python 3.10–3.14 compatibility
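For background on the minibatch feature: linear CKA compares two sets of activations X and Y (one row per example) as CKA(K, L) = HSIC(K, L) / sqrt(HSIC(K, K) · HSIC(L, L)), with Gram matrices K = X Xᵀ and L = Y Yᵀ (Kornblith et al., reference 1). The minibatch variant of Nguyen et al. (reference 2) avoids building the full n × n Gram matrices by summing an unbiased HSIC estimator over batches before taking the ratio. The NumPy sketch below illustrates that idea; the function names are hypothetical, and this is not necessarily how pytorch-cka implements it internally.

```python
import numpy as np


def unbiased_hsic(K: np.ndarray, L: np.ndarray) -> float:
    """Unbiased HSIC estimator on two n x n Gram matrices (requires n > 3)."""
    n = K.shape[0]
    if n <= 3:
        raise ValueError("unbiased HSIC needs more than 3 examples per batch")
    # The unbiased estimator works on Gram matrices with zeroed diagonals.
    K = K.copy()
    L = L.copy()
    np.fill_diagonal(K, 0.0)
    np.fill_diagonal(L, 0.0)
    ones = np.ones(n)
    term1 = np.trace(K @ L)
    term2 = (ones @ K @ ones) * (ones @ L @ ones) / ((n - 1) * (n - 2))
    term3 = 2.0 / (n - 2) * (ones @ K @ L @ ones)
    return float((term1 + term2 - term3) / (n * (n - 3)))


def minibatch_linear_cka(batches_x, batches_y) -> float:
    """Hypothetical helper: linear CKA accumulated over paired minibatches.

    Each element of batches_x / batches_y holds activations for the same
    examples, shaped (batch, ...); trailing dimensions are flattened.
    """
    hsic_xy = hsic_xx = hsic_yy = 0.0
    for X, Y in zip(batches_x, batches_y):
        X = X.reshape(X.shape[0], -1)
        Y = Y.reshape(Y.shape[0], -1)
        K = X @ X.T  # linear-kernel Gram matrix for this batch only
        L = Y @ Y.T
        hsic_xy += unbiased_hsic(K, L)
        hsic_xx += unbiased_hsic(K, K)
        hsic_yy += unbiased_hsic(L, L)
    # Averaging over batches would divide all three sums by the same count,
    # which cancels in the ratio, so plain sums are enough.
    return hsic_xy / np.sqrt(hsic_xx * hsic_yy)
```

Because only per-batch Gram matrices are built, memory for the kernel step scales with the batch size rather than the dataset size. Linear CKA is also invariant to orthogonal transforms and isotropic scaling of the features, which makes a convenient sanity check.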

📦 Installation

Requires Python >= 3.10.

# Using pip
pip install pytorch-cka

# Using uv
uv add pytorch-cka

# Using docker
docker pull ghcr.io/ryusudol/pytorch-cka:latest

# From source
git clone https://github.com/ryusudol/Centered-Kernel-Alignment
cd Centered-Kernel-Alignment
uv sync  # or: pip install -e .

👟 Quick Start

Basic Usage

from torch.utils.data import DataLoader
from cka import CKA

pretrained_model = ...  # e.g. pretrained ResNet-18
fine_tuned_model = ...  # e.g. fine-tuned ResNet-18

layers = ["layer1", "layer2", "layer3", "fc"]

dataloader = DataLoader(..., batch_size=128)

cka = CKA(
    model1=pretrained_model,
    model2=fine_tuned_model,
    model1_name="ResNet-18 (pretrained)",
    model2_name="ResNet-18 (fine-tuned)",
    model1_layers=layers,
    model2_layers=layers,
    device="cuda"
)

# Most convenient usage (auto context manager)
cka_matrix = cka(dataloader)
cka_result = cka.export(cka_matrix)

# Or explicit control
with cka:
    cka_matrix = cka.compare(dataloader)
    cka_result = cka.export(cka_matrix)
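The strings in `layers` appear to be module names as reported by `torch.nn.Module.named_modules()`; for torchvision's ResNet-18, `layer1` through `layer4` and `fc` are such names. A common way to capture intermediate outputs for named modules is a forward hook. The sketch below shows that technique with a hypothetical helper `collect_activations`; it is an assumption that pytorch-cka works this way, and the `with cka:` block may register and remove its hooks differently.

```python
import torch
import torch.nn as nn


def collect_activations(model: nn.Module, layer_names, inputs: torch.Tensor):
    """Hypothetical helper: run one batch and return {layer_name: (N, features)}."""
    modules = dict(model.named_modules())
    captured = {}
    handles = []

    def make_hook(name):
        def hook(module, args, output):
            # Flatten all but the batch dimension, e.g. (N, C, H, W) -> (N, C*H*W).
            captured[name] = output.detach().flatten(start_dim=1)
        return hook

    for name in layer_names:
        handles.append(modules[name].register_forward_hook(make_hook(name)))
    try:
        model.eval()  # note: leaves the model in eval mode afterwards
        with torch.no_grad():
            model(inputs)
    finally:
        # Remove every hook so later forward passes are unaffected.
        for handle in handles:
            handle.remove()
    return captured
```

Removing the hooks in a `finally` block leaves the model free of hooks even if the forward pass raises.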

Visualization

Heatmap

from cka import plot_cka_heatmap

fig, ax = plot_cka_heatmap(
    cka_matrix,
    layers1=layers,
    layers2=layers,
    model1_name="ResNet-18 (pretrained)",
    model2_name="ResNet-18 (fine-tuned)",
    annot=False,          # True: write the CKA value in each cell
    cmap="inferno",       # Colormap
    mask_upper=False,     # True: hide the upper triangle (for symmetric self-comparison matrices)
)
[Figures: Self-comparison heatmap | Cross-model comparison heatmap]
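`mask_upper` is useful because a self-comparison matrix is symmetric: CKA(X, Y) = CKA(Y, X), so the upper triangle repeats the lower one. The NumPy sketch below shows one way to build such a mask with a hypothetical helper `mask_upper_triangle`; whether `plot_cka_heatmap` also hides the diagonal is not stated here, so keeping the diagonal visible (`k=1`) is an assumption.

```python
import numpy as np


def mask_upper_triangle(matrix: np.ndarray) -> np.ma.MaskedArray:
    """Hypothetical helper: hide cells strictly above the diagonal."""
    # k=1 starts the mask one step above the diagonal, so the diagonal stays visible.
    mask = np.triu(np.ones_like(matrix, dtype=bool), k=1)
    return np.ma.masked_array(matrix, mask=mask)
```

Matplotlib's `imshow` draws masked cells with the colormap's "bad" color, which is transparent by default, so passing the result to `imshow` would show only the lower triangle and the diagonal.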

Trend Plot

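import torch
from cka import plot_cka_trend

# Diagonal entries: CKA between corresponding layers of the two models
# (meaningful when model1_layers and model2_layers are the same list)
diagonal = torch.diag(cka_matrix)
fig, ax = plot_cka_trend(
    diagonal,
    labels=["Pretrained vs. fine-tuned"],
    xlabel="Layer",
    ylabel="CKA Score",
)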
[Figures: Cross-model CKA score trends | Multiple trends comparison]

📚 References

  1. Kornblith, Simon, et al. "Similarity of Neural Network Representations Revisited." ICML 2019.

  2. Nguyen, Thao, Maithra Raghu, and Simon Kornblith. "Do Wide and Deep Networks Learn the Same Things?" arXiv 2020. (Minibatch CKA)

  3. Wang, Tinghua, Xiaolu Dai, and Yuze Liu. "Learning with Hilbert-Schmidt Independence Criterion: A Review." Knowledge-Based Systems 2021.

  4. Hwang, Doyeon, et al. "Tracing Representation Progression: Analyzing and Enhancing Layer-Wise Similarity." arXiv 2024.

  5. Davari, MohammadReza, et al. "Reliability of CKA as a Similarity Measure in Deep Learning." ICLR 2023.

  6. Deng, Yuqi, et al. "Manifold Approximation leads to Robust Kernel Alignment." arXiv 2025.

  7. Lee, Jeeyoon, et al. "Path to Intelligence: Measuring Similarity between Human Brain and Large Language Model Beyond Language Task." arXiv 2025.

Related Projects

📝 License

MIT License
