Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions src/mcore_bridge/bridge/gpt_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@
import torch
import torch.distributed as dist
import torch.nn.functional as F
from contextlib import contextmanager
from megatron.core import mpu
from packaging import version
from peft import PeftModel
from peft.utils import ModulesToSaveWrapper
from tqdm import tqdm
from transformers import PreTrainedModel
from transformers.utils import ContextManagers
from typing import Callable, List, Optional, Union

from mcore_bridge.tuners import LoraParallelLinear
Expand Down Expand Up @@ -1744,6 +1747,43 @@ def save_weights(
saver.finalize()
dist.barrier() # Ensure all weights are saved completely

@contextmanager
def _patch_hf_initialize_weight(self):
    """Temporarily disable HuggingFace weight initialization.

    Swaps ``PreTrainedModel._initialize_weights`` for a no-op for the
    duration of the ``with`` block, so constructing an HF model skips the
    (expensive) random-init pass. The original method is always restored
    on exit, even if the body raises.

    NOTE(review): this patches the class globally, so it affects every
    ``PreTrainedModel`` constructed while the context is active.
    """
    saved_initialize_weights = PreTrainedModel._initialize_weights

    def _noop_initialize_weights(self, *args, **kwargs):
        # Skip weight init entirely; weights are loaded/assigned later.
        return

    PreTrainedModel._initialize_weights = _noop_initialize_weights
    try:
        yield
    finally:
        PreTrainedModel._initialize_weights = saved_initialize_weights

@contextmanager
def _patch_device_meta(self, model_cls):
__origin_init__ = model_cls.__init__

def __init__(self, *args, **kwargs):
with torch.device('meta'):
__origin_init__(self, *args, **kwargs)

model_cls.__init__ = __init__

try:
yield
finally:
model_cls.__init__ = __origin_init__

def _get_meta_model_context(self, ignore_init_model_cls=None):
    """Build a combined context for meta-device model construction.

    Args:
        ignore_init_model_cls: a model class, or list of model classes,
            whose ``__init__`` should be forced onto the meta device.
            ``None`` (or any falsy value) means no extra classes.

    Returns:
        A ``ContextManagers`` wrapper that, when entered, applies a
        meta-device patch per requested class plus the global no-op
        weight-initialization patch.
    """
    cls_list = ignore_init_model_cls or []
    if not isinstance(cls_list, list):
        # Allow passing a single class instead of a list.
        cls_list = [cls_list]
    managers = [self._patch_device_meta(cls) for cls in cls_list]
    managers.append(self._patch_hf_initialize_weight())
    return ContextManagers(managers)


class MultimodalGPTBridge(GPTBridge):
hf_layers_prefix = 'model.language_model.layers'
Expand Down
25 changes: 21 additions & 4 deletions src/mcore_bridge/model/mm_gpts/internvl.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import importlib
import torch
from torch import nn
from transformers import AutoModel, PretrainedConfig
from transformers import AutoModel, AutoTokenizer, PretrainedConfig
from transformers.dynamic_module_utils import get_class_from_dynamic_module

from mcore_bridge.bridge import GPTBridge, MultimodalGPTBridge
Expand All @@ -18,6 +19,23 @@ class InternvlBridge(GPTBridge):
hf_lm_head_key = 'language_model.lm_head.weight'
hf_score_key = 'language_model.score.weight'

def get_hf_meta_model(self):
    """Instantiate the InternVL chat model on the meta device.

    Collects whichever supported causal-LM classes exist in the installed
    ``transformers`` (the set varies by version), patches their
    constructors to allocate on the meta device, then builds the
    remote-code ``InternVLChatModel`` under those patches so no real
    weight memory is allocated.

    Returns:
        The meta-device model, tagged with ``_auto_class`` so it can be
        re-registered as ``AutoModelForCausalLM``.
    """
    transformers_module = importlib.import_module('transformers')
    lm_classes = []
    for cls_name in ('Qwen2ForCausalLM', 'Qwen3ForCausalLM',
                     'Qwen3MoeForCausalLM', 'GptOssForCausalLM'):
        try:
            lm_classes.append(getattr(transformers_module, cls_name))
        except (ImportError, AttributeError):
            # Class not present in this transformers version; skip it.
            pass
    contexts = self._get_meta_model_context(lm_classes)
    hf_config = self.config.hf_config
    chat_model_cls = get_class_from_dynamic_module(
        'modeling_internvl_chat.InternVLChatModel', hf_config.name_or_path)
    with contexts:
        model = chat_model_cls(hf_config)
    model._auto_class = 'AutoModelForCausalLM'
    return model


class InternvlVit(HuggingFaceVit):
module_mapping = {'vision_model': 'vision_model', 'mlp1': 'mlp1'}
Expand All @@ -33,7 +51,6 @@ def prepare_attn_impl(self):
self.hf_config.vision_config.use_flash_attn = use_flash_attn

def prepare_model(self, hf_config: PretrainedConfig):
from transformers import AutoProcessor
llm_model_type = self.config.llm_model_type
if llm_model_type not in ['qwen2', 'qwen3', 'qwen3_moe', 'gpt_oss']:
raise ValueError(f'{llm_model_type} is not supported for internvl_chat model')
Expand All @@ -52,7 +69,7 @@ def prepare_model(self, hf_config: PretrainedConfig):
self.select_layer = hf_config.select_layer
self.downsample_ratio = hf_config.downsample_ratio
self.ps_version = hf_config.ps_version
self.processor = AutoProcessor.from_pretrained(hf_config.name_or_path, trust_remote_code=True)
self.tokenizer = AutoTokenizer.from_pretrained(hf_config.name_or_path, trust_remote_code=True)

def get_inputs_embeds(self, inputs_embeds, **kwargs):
input_ids = kwargs['input_ids']
Expand All @@ -63,7 +80,7 @@ def get_inputs_embeds(self, inputs_embeds, **kwargs):
inputs_embeds = inputs_embeds + vit_embeds.mean() * 0.
else:
vit_embeds = self.extract_feature(pixel_values.to(self.vision_model.dtype))
selected = (input_ids == self.processor.encode('<IMG_CONTEXT>', add_special_tokens=False)[0])
selected = (input_ids == self.tokenizer.encode('<IMG_CONTEXT>', add_special_tokens=False)[0])
inputs_embeds = inputs_embeds.clone()
inputs_embeds[selected] = vit_embeds.reshape(-1, vit_embeds.shape[-1]).to(dtype=inputs_embeds.dtype)
return inputs_embeds
Expand Down
Loading