RFC: Universal Model Architecture Support (GPT/MoE/VLM/LRM/LAM/HLM/LCM) #176

RFC: Universal Model Architecture Support

Summary

This RFC proposes the foundational design needed to extend PyGPUkit to all of the following model architectures.

| Type | Full Name | Example Models | Status |
|------|-----------|----------------|--------|
| GPT | Generative Pre-trained Transformer | GPT-2, GPT-NeoX | Supported |
| MoE | Mixture of Experts | Mixtral, Qwen3-MoE | Supported |
| VLM | Vision-Language Model | LLaVA, Qwen-VL, InternVL | Not supported |
| LRM | Large Reasoning Model | DeepSeek-R1, o1, QwQ | Not supported |
| LAM | Large Action Model | GPT-4 w/ tools, Claude | Not supported |
| HLM | Hierarchical Language Model | Longformer, BigBird | Not supported |
| LCM | Large Concept Model | Meta SONAR | Not supported |

Current Architecture

ModelSpec Pattern (Text-only)

ModelSpec (frozen dataclass)
    |
    +-- Weight patterns (tensor name templates)
    +-- Architecture flags (norm_type, activation, use_rope, etc.)
    +-- Default hyperparameters

Limitation: ModelSpec assumes a single-modality, text-only, decoder-only architecture.
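
For illustration, a simplified sketch of the pattern (field names here are assumptions for exposition, not the actual PyGPUkit definition):

# Illustrative sketch only -- not the real PyGPUkit ModelSpec.
from dataclasses import dataclass, field

@dataclass(frozen=True)
class ModelSpecSketch:
    name: str
    weight_patterns: dict[str, str] = field(default_factory=dict)  # e.g. "model.layers.{layer}.mlp.up_proj.weight"
    norm_type: str = "rmsnorm"        # architecture flags
    activation: str = "silu"
    use_rope: bool = True
    default_hyperparams: dict[str, int] = field(default_factory=dict)  # hidden_size, num_layers, ...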

Module Structure

src/pygpukit/
    llm/     # Text-to-text (decoder-only)
    asr/     # Audio-to-text (encoder-decoder)
    tts/     # Text-to-audio (encoder + vocoder)

Proposed Architecture

1. Multi-Modal Model Abstraction

from dataclasses import dataclass
from typing import Literal

@dataclass(frozen=True)
class ModalitySpec:
    """Specification for a single modality encoder/decoder."""
    name: str                          # "vision", "audio", "text"
    encoder_type: str | None           # "vit", "clip", "whisper", None
    decoder_type: str | None           # "transformer", "diffusion", None
    embedding_dim: int
    weight_patterns: dict[str, str]    # Tensor name patterns

@dataclass(frozen=True)
class MultiModalModelSpec:
    """Extended ModelSpec for multi-modal architectures."""
    name: str
    modalities: dict[str, ModalitySpec]
    fusion_type: Literal["early", "late", "cross_attention", "interleaved"]
    backbone_spec: ModelSpec | None    # Text backbone (existing ModelSpec); None for concept-native models (see LCM)
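
As a rough illustration of how a loader might consume these specs, the sketch below (the function name and return shape are assumptions, not an existing PyGPUkit API) resolves each modality's weight patterns against the tensor names found in a checkpoint:

# Hypothetical helper: map each modality's weight-pattern keys to matching checkpoint tensors.
def resolve_modality_weights(spec: MultiModalModelSpec,
                             tensor_names: set[str]) -> dict[str, dict[str, list[str]]]:
    resolved: dict[str, dict[str, list[str]]] = {}
    for mod_name, modality in spec.modalities.items():
        resolved[mod_name] = {}
        for key, pattern in modality.weight_patterns.items():
            # Per-layer templates like "vision_tower.encoder.layers.{layer}" match by prefix.
            prefix = pattern.split("{layer}")[0] if "{layer}" in pattern else pattern
            resolved[mod_name][key] = sorted(n for n in tensor_names if n.startswith(prefix))
    return resolved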

2. VLM (Vision-Language Model)

Requirements

| Component | Description |
|-----------|-------------|
| Vision Encoder | ViT, CLIP, SigLIP, EVA-CLIP |
| Image Preprocessor | Resize, normalize, patch extraction |
| Projection Layer | Vision features -> LLM hidden dim |
| Fusion | Cross-attention or concatenation |

Proposed Spec

VLM_SPEC = MultiModalModelSpec(
    name="vlm",
    modalities={
        "vision": ModalitySpec(
            name="vision",
            encoder_type="vit",
            decoder_type=None,
            embedding_dim=1024,
            weight_patterns={
                "patch_embed": "vision_tower.embeddings.patch_embedding.weight",
                "pos_embed": "vision_tower.embeddings.position_embedding",
                "blocks": "vision_tower.encoder.layers.{layer}",
            }
        ),
        "text": ModalitySpec(
            name="text",
            encoder_type=None,
            decoder_type="transformer",
            embedding_dim=4096,
            weight_patterns={}  # Uses backbone_spec
        ),
    },
    fusion_type="cross_attention",
    backbone_spec=LLAMA_SPEC,
)

Module Structure

src/pygpukit/
    vlm/
        vision/
            vit.py           # Vision Transformer encoder
            clip.py          # CLIP encoder
            siglip.py        # SigLIP encoder
        preprocessor.py      # Image normalization, patching
        projector.py         # Vision -> text projection
        model.py             # VLMModel (vision encoder + LLM)
        loader.py            # load_vlm_from_safetensors()
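
The projection layer is small enough to sketch here; the version below uses NumPy purely to show the shapes (the real projector.py would run on PyGPUkit GPUArrays, and the activation is a placeholder):

import numpy as np

def project_vision_features(vision_feats: np.ndarray,        # (num_patches, vision_dim)
                            w1: np.ndarray, b1: np.ndarray,  # (vision_dim, llm_dim), (llm_dim,)
                            w2: np.ndarray | None = None,    # optional second layer
                            b2: np.ndarray | None = None) -> np.ndarray:
    """Map vision encoder outputs into the LLM embedding space."""
    hidden = vision_feats @ w1 + b1
    if w2 is None:
        return hidden                            # single linear projection
    return np.maximum(hidden, 0.0) @ w2 + b2     # 2-layer MLP projection (ReLU as placeholder)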

Target Models

| Model | Vision Encoder | LLM Backbone | Fusion |
|-------|----------------|--------------|--------|
| LLaVA-1.5 | CLIP-ViT-L | LLaMA-2 | Linear projection |
| LLaVA-NeXT | SigLIP | LLaMA-3 | MLP projection |
| Qwen-VL | ViT-bigG | Qwen | Resampler (Q-Former) |
| Qwen2-VL | ViT (variable) | Qwen2 | Dynamic resolution |
| InternVL2 | InternViT-6B | InternLM2 | MLP |
| Phi-3-Vision | CLIP-ViT | Phi-3 | Linear |

3. LRM (Large Reasoning Model)

Requirements

| Component | Description |
|-----------|-------------|
| Thinking Tokens | <think>, </think> special tokens |
| Extended Generation | Chain-of-thought before answer |
| Reasoning Parser | Extract thinking vs. answer |
| Streaming Support | Stream thinking tokens separately |

Proposed Extension

@dataclass(frozen=True)
class ReasoningSpec:
    """Specification for reasoning model behavior."""
    thinking_start_token: str = "<think>"
    thinking_end_token: str = "</think>"
    answer_start_token: str | None = None
    max_thinking_tokens: int = 8192
    streaming_mode: Literal["full", "answer_only", "thinking_only"] = "full"

# Add to ModelSpec
class ModelSpec:
    # ... existing fields ...
    reasoning: ReasoningSpec | None = None  # None = standard model

# Example
DEEPSEEK_R1_SPEC = ModelSpec(
    name="deepseek_r1",
    # ... standard LLaMA-like weights ...
    reasoning=ReasoningSpec(
        thinking_start_token="<think>",
        thinking_end_token="</think>",
        max_thinking_tokens=16384,
    ),
)

Decode Strategy Extension

class DecodeReasoning(DecodeStrategy):
    """Decode strategy for reasoning models."""

    def step(self, ...) -> tuple[GPUArray, ReasoningState]:
        """Returns (logits, state) where state tracks thinking/answer phase."""
        ...

    def extract_answer(self, tokens: list[int]) -> list[int]:
        """Extract answer tokens from full reasoning trace."""
        ...
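
A possible concrete form of extract_answer, assuming the thinking end token maps to a single token id (how that id is obtained from the tokenizer is left open):

# Hypothetical sketch: keep only the tokens after the last </think> occurrence.
def extract_answer(tokens: list[int], thinking_end_id: int) -> list[int]:
    try:
        last = len(tokens) - 1 - tokens[::-1].index(thinking_end_id)  # last occurrence
    except ValueError:
        return tokens          # model never closed the thinking block; return the full trace
    return tokens[last + 1:]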

Target Models

| Model | Provider | Thinking Tokens | Notes |
|-------|----------|-----------------|-------|
| DeepSeek-R1 | DeepSeek | <think> / </think> | Open weights |
| QwQ | Alibaba | Custom | Qwen-based |
| o1/o3 | OpenAI | Hidden | API only |
| Skywork-o1 | Skywork | <think> / </think> | Open weights |

4. LAM (Large Action Model)

Requirements

| Component | Description |
|-----------|-------------|
| Tool Schema | JSON Schema for function definitions |
| Tool Parser | Extract tool calls from generation |
| Action Executor | (Optional) Execute actions |
| Multi-turn Support | Tool call -> result -> continue |

Proposed Extension

@dataclass(frozen=True)
class ToolCallSpec:
    """Specification for tool/function calling."""
    tool_call_start: str = "<tool_call>"
    tool_call_end: str = "</tool_call>"
    tool_result_start: str = "<tool_result>"
    tool_result_end: str = "</tool_result>"
    schema_format: Literal["json_schema", "openai", "anthropic"] = "json_schema"
    parallel_calls: bool = False  # Support multiple tool calls per turn

class ModelSpec:
    # ... existing fields ...
    tool_calling: ToolCallSpec | None = None  # None = no tool support

Module Structure

src/pygpukit/
    llm/
        tools/
            schema.py        # ToolDefinition, ParameterSchema
            parser.py        # Parse tool calls from generation
            executor.py      # (Optional) Execute tool calls
        model.py             # Add tool_call support to CausalTransformerModel
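
A minimal sketch of what tools/parser.py could do, using the default markers from ToolCallSpec above (the function name and error handling are assumptions):

import json
import re

def parse_tool_calls(text: str,
                     start: str = "<tool_call>",
                     end: str = "</tool_call>") -> list[dict]:
    """Extract and JSON-decode every tool call emitted in a generation."""
    pattern = re.escape(start) + r"(.*?)" + re.escape(end)
    calls: list[dict] = []
    for body in re.findall(pattern, text, flags=re.DOTALL):
        try:
            calls.append(json.loads(body))
        except json.JSONDecodeError:
            continue   # skip malformed calls rather than abort the turn
    return calls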

Target Models

| Model | Tool Format | Notes |
|-------|-------------|-------|
| Qwen2.5 | Custom XML | Native tool calling |
| LLaMA-3.1 | JSON | Built-in tool format |
| Mistral | Custom | Function calling |
| Hermes-2 | JSON | Fine-tuned for tools |

5. HLM (Hierarchical Language Model)

Requirements

| Component | Description |
|-----------|-------------|
| Local Attention | Window-based attention (e.g., 512 tokens) |
| Global Attention | Sparse global tokens |
| Hierarchical Encoding | Document -> Paragraph -> Sentence |
| Long Context | 16K-128K+ token support |

Proposed Extension

@dataclass(frozen=True)
class HierarchicalSpec:
    """Specification for hierarchical attention patterns."""
    local_window: int = 512
    global_stride: int = 512  # Every Nth token attends globally
    num_global_tokens: int = 0  # Prepended global tokens (like [CLS])
    attention_pattern: Literal["sliding_window", "longformer", "bigbird"] = "sliding_window"

class ModelSpec:
    # ... existing fields ...
    hierarchical: HierarchicalSpec | None = None  # None = full attention

Attention Pattern Implementation

class SlidingWindowAttention(Attention):
    """Sliding window attention for long sequences."""

    def __init__(self, config: TransformerConfig, window_size: int):
        super().__init__(config)
        self.window_size = window_size

    def forward(self, hidden: GPUArray, ...) -> GPUArray:
        # Only attend to [i - window_size/2, i + window_size/2]
        ...

class LongformerAttention(Attention):
    """Longformer-style local + global attention."""

    def forward(self, hidden: GPUArray, global_mask: GPUArray, ...) -> GPUArray:
        # Local sliding window + global tokens attend everywhere
        ...
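
For reference, the boolean mask SlidingWindowAttention would apply can be sketched in NumPy (a decoder-only variant would additionally intersect this with the causal mask):

import numpy as np

def sliding_window_mask(seq_len: int, window_size: int) -> np.ndarray:
    """True where attention is allowed: |i - j| <= window_size // 2."""
    i = np.arange(seq_len)[:, None]   # query positions
    j = np.arange(seq_len)[None, :]   # key positions
    return np.abs(i - j) <= window_size // 2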

Target Models

| Model | Attention Type | Max Length |
|-------|----------------|------------|
| Longformer | Local + Global | 4096 |
| BigBird | Local + Global + Random | 4096 |
| Mistral | Sliding Window | 32K |
| LLaMA-3 (Long) | Full attention | 128K |
| Qwen2.5-Long | Full attention | 128K |

6. LCM (Large Concept Model)

Requirements

| Component | Description |
|-----------|-------------|
| SONAR Encoder | Text/Audio -> Concept space |
| SONAR Decoder | Concept space -> Text/Audio |
| Concept Space | Language-agnostic representation |
| Multi-modal Alignment | Text/Audio/Vision in same space |

Module Structure

src/pygpukit/
    lcm/
        sonar/
            encoder.py       # SONAR encoder (text/audio -> concepts)
            decoder.py       # SONAR decoder (concepts -> text)
            tokenizer.py     # Concept tokenizer
        model.py             # LCMModel
        loader.py            # load_lcm_from_safetensors()

Proposed Spec

@dataclass(frozen=True)
class ConceptSpec:
    """Specification for concept-based models."""
    concept_dim: int = 1024
    encoder_type: Literal["sonar", "laser"] = "sonar"
    decoder_type: Literal["sonar", "transformer"] = "sonar"
    languages: list[str] | None = None  # None = multilingual

LCM_SPEC = MultiModalModelSpec(
    name="lcm",
    modalities={
        "concept": ModalitySpec(
            name="concept",
            encoder_type="sonar",
            decoder_type="sonar",
            embedding_dim=1024,
            weight_patterns={
                "encoder": "sonar.encoder",
                "decoder": "sonar.decoder",
            }
        ),
    },
    fusion_type="early",
    backbone_spec=None,  # No text backbone, concept-native
)
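
At inference time the flow is roughly: segment text into sentences, encode each into the concept space, predict the next concept autoregressively, then decode it back to text. A shape-level sketch with placeholder callables (none of these names exist in PyGPUkit yet):

from typing import Callable, Sequence
import numpy as np

def lcm_generate_step(sentences: Sequence[str],
                      encode: Callable[[Sequence[str]], np.ndarray],     # -> (n, concept_dim)
                      predict_next: Callable[[np.ndarray], np.ndarray],  # (n, d) -> (d,)
                      decode: Callable[[np.ndarray], str]) -> str:
    concepts = encode(sentences)           # sentence-level, language-agnostic embeddings
    next_concept = predict_next(concepts)  # autoregression over concepts, not tokens
    return decode(next_concept)            # realize the predicted concept as a sentence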

Implementation Phases

Phase 1: Foundation (v0.3.0)

  1. Extend ModelSpec

    • Add reasoning: ReasoningSpec | None
    • Add tool_calling: ToolCallSpec | None
    • Add hierarchical: HierarchicalSpec | None
  2. Introduce MultiModalModelSpec

    • Define ModalitySpec for encoder/decoder
    • Implement modality registry (see the sketch after this list)
  3. Refactor Module Structure

    • Create vlm/, lcm/ modules
    • Abstract common patterns
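
One possible shape for the modality registry from item 2 (decorator name and signature are illustrative):

from typing import Callable

_MODALITY_ENCODERS: dict[tuple[str, str], Callable] = {}

def register_modality_encoder(modality: str, encoder_type: str) -> Callable:
    """Register a builder for a (modality, encoder_type) pair, e.g. ("vision", "vit")."""
    def wrap(builder: Callable) -> Callable:
        _MODALITY_ENCODERS[(modality, encoder_type)] = builder
        return builder
    return wrap

@register_modality_encoder("vision", "vit")
def build_vit_encoder(spec: ModalitySpec):
    ...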

Phase 2: VLM Support (v0.3.x)

  1. Vision encoders (ViT, CLIP, SigLIP)
  2. Image preprocessing pipeline
  3. Projection layers
  4. LLaVA-1.5, Qwen-VL support

Phase 3: LRM/LAM Support (v0.4.x)

  1. Reasoning decode strategy
  2. Tool calling parser
  3. DeepSeek-R1, Qwen2.5-tools support

Phase 4: HLM/LCM Support (v0.5.x)

  1. Sliding window attention
  2. SONAR encoder/decoder
  3. Long-context optimization

Design Principles

  1. Configuration Over Subclassing

    • Extend ModelSpec with optional specs
    • No explosion of model classes
  2. Modality Separation

    • Each modality (vision, audio, text, concept) in its own module
    • Clean interfaces between modalities
  3. Backward Compatibility

    • Existing ModelSpec works unchanged
    • New specs are optional fields
  4. Explicit Over Implicit

    • No automatic detection of model capabilities
    • Users specify what they want

Open Questions

  1. VLM Resolution: Fixed resolution vs. dynamic (Qwen2-VL style)?
  2. LRM Streaming: How to stream thinking tokens separately?
  3. LAM Execution: Should PyGPUkit execute tool calls or just parse?
  4. HLM KV Cache: How to handle sparse attention KV cache?
  5. LCM Multilingual: Full SONAR language support or subset?
