RFC: Universal Model Architecture Support
Summary
This RFC proposes the foundational design needed to support all of the following model architectures in PyGPUkit.
| Type | Full Name | Example Models | Status |
|---|---|---|---|
| GPT | Generative Pre-trained Transformer | GPT-2, GPT-NeoX | Supported |
| MoE | Mixture of Experts | Mixtral, Qwen3-MoE | Supported |
| VLM | Vision-Language Model | LLaVA, Qwen-VL, InternVL | Not Supported |
| LRM | Large Reasoning Model | DeepSeek-R1, o1, QwQ | Not Supported |
| LAM | Large Action Model | GPT-4 w/tools, Claude | Not Supported |
| HLM | Hierarchical Language Model | Longformer, BigBird | Not Supported |
| LCM | Large Concept Model | Meta SONAR | Not Supported |
Current Architecture
ModelSpec Pattern (Text-only)
```
ModelSpec (frozen dataclass)
|
+-- Weight patterns (tensor name templates)
+-- Architecture flags (norm_type, activation, use_rope, etc.)
+-- Default hyperparameters
```
Limitation: ModelSpec assumes single-modality text decoder-only architecture.
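For context, the current text-only pattern looks roughly like the sketch below; the field names are illustrative and approximate the shape described above, not the exact PyGPUkit definitions.

```python
# Illustrative sketch only -- field names approximate the existing ModelSpec shape.
from dataclasses import dataclass, field

@dataclass(frozen=True)
class ModelSpec:
    name: str
    # Tensor name templates, e.g. "model.layers.{layer}.self_attn.q_proj.weight"
    weight_patterns: dict[str, str] = field(default_factory=dict)
    # Architecture flags
    norm_type: str = "rmsnorm"
    activation: str = "silu"
    use_rope: bool = True
    # Default hyperparameters
    hidden_size: int = 4096
    num_layers: int = 32
```

Every field assumes a single text decoder: there is no place to describe a second encoder, a projection layer, or a fusion strategy.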
Module Structure
```
src/pygpukit/
    llm/    # Text-to-text (decoder-only)
    asr/    # Audio-to-text (encoder-decoder)
    tts/    # Text-to-audio (encoder + vocoder)
```
Proposed Architecture
1. Multi-Modal Model Abstraction
```python
@dataclass(frozen=True)
class ModalitySpec:
    """Specification for a single modality encoder/decoder."""
    name: str                    # "vision", "audio", "text"
    encoder_type: str | None     # "vit", "clip", "whisper", None
    decoder_type: str | None     # "transformer", "diffusion", None
    embedding_dim: int
    weight_patterns: dict[str, str]  # Tensor name patterns

@dataclass(frozen=True)
class MultiModalModelSpec:
    """Extended ModelSpec for multi-modal architectures."""
    name: str
    modalities: dict[str, ModalitySpec]
    fusion_type: Literal["early", "late", "cross_attention", "interleaved"]
    backbone_spec: ModelSpec  # Text backbone (existing ModelSpec)
```
2. VLM (Vision-Language Model)
Requirements
| Component | Description |
|---|---|
| Vision Encoder | ViT, CLIP, SigLIP, EVA-CLIP |
| Image Preprocessor | Resize, normalize, patch extraction |
| Projection Layer | Vision features -> LLM hidden dim |
| Fusion | Cross-attention or concatenation |
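The Image Preprocessor row amounts to resize, per-channel normalization, and patch extraction. A minimal NumPy sketch of that step follows; it is not the proposed preprocessor.py API, and the normalization constants are the standard CLIP values, which depend on the chosen vision encoder.

```python
# Minimal sketch of the "Image Preprocessor" step (not the final preprocessor.py API).
# Uses the standard CLIP normalization constants; actual values depend on the vision encoder.
import numpy as np

CLIP_MEAN = np.array([0.48145466, 0.4578275, 0.40821073], dtype=np.float32)
CLIP_STD = np.array([0.26862954, 0.26130258, 0.27577711], dtype=np.float32)

def preprocess(image: np.ndarray, patch_size: int = 14) -> np.ndarray:
    """image: (H, W, 3) uint8, already resized to the encoder resolution (e.g. 224x224).
    Returns (num_patches, patch_size * patch_size * 3) float32 patches."""
    x = image.astype(np.float32) / 255.0
    x = (x - CLIP_MEAN) / CLIP_STD  # per-channel normalization
    h, w, c = x.shape
    gh, gw = h // patch_size, w // patch_size
    # Split the image into non-overlapping patch_size x patch_size patches.
    x = x[: gh * patch_size, : gw * patch_size].reshape(gh, patch_size, gw, patch_size, c)
    return x.transpose(0, 2, 1, 3, 4).reshape(gh * gw, -1)
```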
Proposed Spec
```python
VLM_SPEC = MultiModalModelSpec(
    name="vlm",
    modalities={
        "vision": ModalitySpec(
            name="vision",
            encoder_type="vit",
            decoder_type=None,
            embedding_dim=1024,
            weight_patterns={
                "patch_embed": "vision_tower.embeddings.patch_embedding.weight",
                "pos_embed": "vision_tower.embeddings.position_embedding",
                "blocks": "vision_tower.encoder.layers.{layer}",
            },
        ),
        "text": ModalitySpec(
            name="text",
            encoder_type=None,
            decoder_type="transformer",
            embedding_dim=4096,
            weight_patterns={},  # Uses backbone_spec
        ),
    },
    fusion_type="cross_attention",
    backbone_spec=LLAMA_SPEC,
)
```
Module Structure
```
src/pygpukit/
    vlm/
        vision/
            vit.py           # Vision Transformer encoder
            clip.py          # CLIP encoder
            siglip.py        # SigLIP encoder
        preprocessor.py      # Image normalization, patching
        projector.py         # Vision -> text projection
        model.py             # VLMModel (vision encoder + LLM)
        loader.py            # load_vlm_from_safetensors()
```
Target Models
| Model | Vision Encoder | LLM Backbone | Fusion |
|---|---|---|---|
| LLaVA-1.5 | CLIP-ViT-L | LLaMA-2 | Linear projection |
| LLaVA-NeXT | SigLIP | LLaMA-3 | MLP projection |
| Qwen-VL | ViT-bigG | Qwen | Resampler (Q-Former) |
| Qwen2-VL | ViT (variable) | Qwen2 | Dynamic resolution |
| InternVL2 | InternViT-6B | InternLM2 | MLP |
| Phi-3-Vision | CLIP-ViT | Phi-3 | Linear |
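The Fusion column ranges from a single linear projection (LLaVA-1.5) to small MLPs and resamplers. As a reference point for projector.py, here is a minimal MLP-projector sketch mapping vision features into the LLM hidden dimension; the class name, shapes, and initialization are assumptions for illustration, not an existing PyGPUkit API.

```python
# Illustrative MLP projector (LLaVA-NeXT style): vision features -> LLM hidden dim.
# Not an existing PyGPUkit class; names, shapes, and initialization are sketch assumptions.
import numpy as np

class MLPProjector:
    def __init__(self, vision_dim: int = 1024, llm_dim: int = 4096, seed: int = 0):
        rng = np.random.default_rng(seed)
        self.w1 = rng.standard_normal((vision_dim, llm_dim), dtype=np.float32) * 0.02
        self.b1 = np.zeros(llm_dim, dtype=np.float32)
        self.w2 = rng.standard_normal((llm_dim, llm_dim), dtype=np.float32) * 0.02
        self.b2 = np.zeros(llm_dim, dtype=np.float32)

    def __call__(self, vision_feats: np.ndarray) -> np.ndarray:
        """vision_feats: (num_patches, vision_dim) -> (num_patches, llm_dim)."""
        h = np.maximum(vision_feats @ self.w1 + self.b1, 0.0)  # GELU in practice; ReLU here for brevity
        return h @ self.w2 + self.b2
```

A linear projector (LLaVA-1.5, Phi-3-Vision) is the same structure with the second layer removed.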
3. LRM (Large Reasoning Model)
Requirements
| Component | Description |
|---|---|
| Thinking Tokens | <think>, </think> special tokens |
| Extended Generation | Chain-of-thought before answer |
| Reasoning Parser | Extract thinking vs. answer |
| Streaming Support | Stream thinking tokens separately |
Proposed Extension
```python
@dataclass(frozen=True)
class ReasoningSpec:
    """Specification for reasoning model behavior."""
    thinking_start_token: str = "<think>"
    thinking_end_token: str = "</think>"
    answer_start_token: str | None = None
    max_thinking_tokens: int = 8192
    streaming_mode: Literal["full", "answer_only", "thinking_only"] = "full"

# Add to ModelSpec
class ModelSpec:
    # ... existing fields ...
    reasoning: ReasoningSpec | None = None  # None = standard model

# Example
DEEPSEEK_R1_SPEC = ModelSpec(
    name="deepseek_r1",
    # ... standard LLaMA-like weights ...
    reasoning=ReasoningSpec(
        thinking_start_token="<think>",
        thinking_end_token="</think>",
        max_thinking_tokens=16384,
    ),
)
```
Decode Strategy Extension
```python
class DecodeReasoning(DecodeStrategy):
    """Decode strategy for reasoning models."""

    def step(self, ...) -> tuple[GPUArray, ReasoningState]:
        """Returns (logits, state) where state tracks thinking/answer phase."""
        ...

    def extract_answer(self, tokens: list[int]) -> list[int]:
        """Extract answer tokens from full reasoning trace."""
        ...
```
Target Models
| Model | Provider | Thinking Tokens | Notes |
|---|---|---|---|
| DeepSeek-R1 | DeepSeek | <think>/</think> | Open weights |
| QwQ | Alibaba | Custom | Qwen-based |
| o1/o3 | OpenAI | Hidden | API only |
| Skywork-o1 | Skywork | <think>/</think> | Open weights |
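The extract_answer step in DecodeReasoning above reduces to splitting the generated token stream at the last thinking_end_token. A minimal sketch, assuming the thinking markers are single special tokens with known IDs (the IDs below are hypothetical; real values come from the tokenizer):

```python
# Sketch of extract_answer for a reasoning trace, assuming the thinking markers are
# single special tokens with known IDs (hypothetical IDs shown; real IDs come from the tokenizer).
THINK_START_ID = 151648  # hypothetical id for "<think>"
THINK_END_ID = 151649    # hypothetical id for "</think>"

def extract_answer(tokens: list[int]) -> list[int]:
    """Return the tokens generated after the last </think>; if no </think> was
    emitted (thinking budget exhausted), return the full trace unchanged."""
    if THINK_END_ID not in tokens:
        return tokens
    last_end = len(tokens) - 1 - tokens[::-1].index(THINK_END_ID)
    return tokens[last_end + 1:]

def split_phases(tokens: list[int]) -> tuple[list[int], list[int]]:
    """Split into (thinking, answer) segments for separate streaming."""
    answer = extract_answer(tokens)
    return tokens[: len(tokens) - len(answer)], answer
```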
4. LAM (Large Action Model)
Requirements
| Component | Description |
|---|---|
| Tool Schema | JSON Schema for function definitions |
| Tool Parser | Extract tool calls from generation |
| Action Executor | (Optional) Execute actions |
| Multi-turn Support | Tool call -> result -> continue |
Proposed Extension
```python
@dataclass(frozen=True)
class ToolCallSpec:
    """Specification for tool/function calling."""
    tool_call_start: str = "<tool_call>"
    tool_call_end: str = "</tool_call>"
    tool_result_start: str = "<tool_result>"
    tool_result_end: str = "</tool_result>"
    schema_format: Literal["json_schema", "openai", "anthropic"] = "json_schema"
    parallel_calls: bool = False  # Support multiple tool calls per turn

class ModelSpec:
    # ... existing fields ...
    tool_calling: ToolCallSpec | None = None  # None = no tool support
```
Module Structure
```
src/pygpukit/
    llm/
        tools/
            schema.py    # ToolDefinition, ParameterSchema
            parser.py    # Parse tool calls from generation
            executor.py  # (Optional) Execute tool calls
        model.py         # Add tool_call support to CausalTransformerModel
```
Target Models
| Model | Tool Format | Notes |
|---|---|---|
| Qwen2.5 | Custom XML | Native tool calling |
| LLaMA-3.1 | JSON | Built-in tool format |
| Mistral | Custom | Function calling |
| Hermes-2 | JSON | Fine-tuned for tools |
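The parser.py step reduces to scanning for the ToolCallSpec delimiters and decoding the enclosed payload. A minimal sketch, assuming JSON payloads between <tool_call> tags; as the table above shows, real formats differ per model, so the parser would dispatch on schema_format:

```python
# Sketch of a tool-call parser for <tool_call>{...}</tool_call> style output.
# Assumes JSON payloads between the ToolCallSpec delimiters; real formats vary per model
# (Qwen2.5 uses custom XML, LLaMA-3.1 uses JSON), so parser.py would dispatch on schema_format.
import json
import re

TOOL_CALL_RE = re.compile(r"<tool_call>(.*?)</tool_call>", re.DOTALL)

def parse_tool_calls(text: str) -> list[dict]:
    """Extract tool calls from generated text; malformed payloads are skipped."""
    calls = []
    for payload in TOOL_CALL_RE.findall(text):
        try:
            calls.append(json.loads(payload))
        except json.JSONDecodeError:
            continue
    return calls

# Example:
# parse_tool_calls('<tool_call>{"name": "get_weather", "arguments": {"city": "Tokyo"}}</tool_call>')
# -> [{"name": "get_weather", "arguments": {"city": "Tokyo"}}]
```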
5. HLM (Hierarchical Language Model)
Requirements
| Component | Description |
|---|---|
| Local Attention | Window-based attention (e.g., 512 tokens) |
| Global Attention | Sparse global tokens |
| Hierarchical Encoding | Document -> Paragraph -> Sentence |
| Long Context | 16K-128K+ token support |
Proposed Extension
```python
@dataclass(frozen=True)
class HierarchicalSpec:
    """Specification for hierarchical attention patterns."""
    local_window: int = 512
    global_stride: int = 512    # Every Nth token attends globally
    num_global_tokens: int = 0  # Prepended global tokens (like [CLS])
    attention_pattern: Literal["sliding_window", "longformer", "bigbird"] = "sliding_window"

class ModelSpec:
    # ... existing fields ...
    hierarchical: HierarchicalSpec | None = None  # None = full attention
```
Attention Pattern Implementation
```python
class SlidingWindowAttention(Attention):
    """Sliding window attention for long sequences."""

    def __init__(self, config: TransformerConfig, window_size: int):
        super().__init__(config)
        self.window_size = window_size

    def forward(self, hidden: GPUArray, ...) -> GPUArray:
        # Only attend to [i - window_size/2, i + window_size/2]
        ...

class LongformerAttention(Attention):
    """Longformer-style local + global attention."""

    def forward(self, hidden: GPUArray, global_mask: GPUArray, ...) -> GPUArray:
        # Local sliding window + global tokens attend everywhere
        ...
```
Target Models
| Model | Attention Type | Max Length |
|---|---|---|
| Longformer | Local + Global | 4096 |
| BigBird | Local + Global + Random | 4096 |
| Mistral | Sliding Window | 32K |
| LLaMA-3 (Long) | Full attention | 128K |
| Qwen2.5-Long | Full attention | 128K |
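The SlidingWindowAttention sketch above only changes which key positions each query may attend to. Below is a NumPy reference for the causal (Mistral-style) variant of that mask; the symmetric window in the earlier comment differs only in the bounds, and a GPU kernel would apply the same constraint without materializing the full matrix.

```python
# Reference construction of a causal sliding-window mask (NumPy, CPU only).
import numpy as np

def sliding_window_mask(seq_len: int, window_size: int) -> np.ndarray:
    """mask[i, j] is True where query i may attend to key j:
    causal (j <= i) and within the window (i - j < window_size)."""
    i = np.arange(seq_len)[:, None]
    j = np.arange(seq_len)[None, :]
    return (j <= i) & (i - j < window_size)

# Example: seq_len=6, window_size=3 -> each token attends to itself and the 2 previous tokens.
```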
6. LCM (Large Concept Model)
Requirements
| Component | Description |
|---|---|
| SONAR Encoder | Text/Audio -> Concept space |
| SONAR Decoder | Concept space -> Text/Audio |
| Concept Space | Language-agnostic representation |
| Multi-modal Alignment | Text/Audio/Vision in same space |
Module Structure
```
src/pygpukit/
    lcm/
        sonar/
            encoder.py    # SONAR encoder (text/audio -> concepts)
            decoder.py    # SONAR decoder (concepts -> text)
            tokenizer.py  # Concept tokenizer
        model.py          # LCMModel
        loader.py         # load_lcm_from_safetensors()
```
Proposed Spec
```python
@dataclass(frozen=True)
class ConceptSpec:
    """Specification for concept-based models."""
    concept_dim: int = 1024
    encoder_type: Literal["sonar", "laser"] = "sonar"
    decoder_type: Literal["sonar", "transformer"] = "sonar"
    languages: list[str] | None = None  # None = multilingual

LCM_SPEC = MultiModalModelSpec(
    name="lcm",
    modalities={
        "concept": ModalitySpec(
            name="concept",
            encoder_type="sonar",
            decoder_type="sonar",
            embedding_dim=1024,
            weight_patterns={
                "encoder": "sonar.encoder",
                "decoder": "sonar.decoder",
            },
        ),
    },
    fusion_type="early",
    backbone_spec=None,  # No text backbone, concept-native
)
```
Implementation Phases
Phase 1: Foundation (v0.3.0)
- Extend ModelSpec
  - Add reasoning: ReasoningSpec | None
  - Add tool_calling: ToolCallSpec | None
  - Add hierarchical: HierarchicalSpec | None
- Introduce MultiModalModelSpec
  - Define ModalitySpec for encoder/decoder
  - Implement modality registry
- Refactor Module Structure
  - Create vlm/, lcm/ modules
  - Abstract common patterns
Phase 2: VLM Support (v0.3.x)
- Vision encoders (ViT, CLIP, SigLIP)
- Image preprocessing pipeline
- Projection layers
- LLaVA-1.5, Qwen-VL support
Phase 3: LRM/LAM Support (v0.4.x)
- Reasoning decode strategy
- Tool calling parser
- DeepSeek-R1, Qwen2.5-tools support
Phase 4: HLM/LCM Support (v0.5.x)
- Sliding window attention
- SONAR encoder/decoder
- Long-context optimization
Design Principles
- Configuration Over Subclassing
  - Extend ModelSpec with optional specs
  - No explosion of model classes
- Modality Separation
  - Each modality (vision, audio, text, concept) in its own module
  - Clean interfaces between modalities
- Backward Compatibility
  - Existing ModelSpec works unchanged
  - New specs are optional fields
- Explicit Over Implicit
  - No automatic detection of model capabilities
  - Users specify what they want
Open Questions
- VLM Resolution: Fixed resolution vs. dynamic (Qwen2-VL style)?
- LRM Streaming: How to stream thinking tokens separately?
- LAM Execution: Should PyGPUkit execute tool calls or just parse?
- HLM KV Cache: How to handle sparse attention KV cache?
- LCM Multilingual: Full SONAR language support or subset?