diff --git a/src/mcore_bridge/bridge/gpt_bridge.py b/src/mcore_bridge/bridge/gpt_bridge.py index c33a572..95b55de 100644 --- a/src/mcore_bridge/bridge/gpt_bridge.py +++ b/src/mcore_bridge/bridge/gpt_bridge.py @@ -5,11 +5,14 @@ import torch import torch.distributed as dist import torch.nn.functional as F +from contextlib import contextmanager from megatron.core import mpu from packaging import version from peft import PeftModel from peft.utils import ModulesToSaveWrapper from tqdm import tqdm +from transformers import PreTrainedModel +from transformers.utils import ContextManagers from typing import Callable, List, Optional, Union from mcore_bridge.tuners import LoraParallelLinear @@ -1744,6 +1747,43 @@ def save_weights( saver.finalize() dist.barrier() # Ensure all weights are saved completely + @contextmanager + def _patch_hf_initialize_weight(self): + + _origin_initialize_weight = PreTrainedModel._initialize_weights + + def _initialize_weight(self, *args, **kwargs): + return + + PreTrainedModel._initialize_weights = _initialize_weight + try: + yield + finally: + PreTrainedModel._initialize_weights = _origin_initialize_weight + + @contextmanager + def _patch_device_meta(self, model_cls): + __origin_init__ = model_cls.__init__ + + def __init__(self, *args, **kwargs): + with torch.device('meta'): + __origin_init__(self, *args, **kwargs) + + model_cls.__init__ = __init__ + + try: + yield + finally: + model_cls.__init__ = __origin_init__ + + def _get_meta_model_context(self, ignore_init_model_cls=None): + ignore_init_model_cls = ignore_init_model_cls or [] + if not isinstance(ignore_init_model_cls, list): + ignore_init_model_cls = [ignore_init_model_cls] + context_list = [self._patch_device_meta(model_cls) for model_cls in ignore_init_model_cls] + context_list.append(self._patch_hf_initialize_weight()) + return ContextManagers(context_list) + class MultimodalGPTBridge(GPTBridge): hf_layers_prefix = 'model.language_model.layers' diff --git 
a/src/mcore_bridge/model/mm_gpts/internvl.py b/src/mcore_bridge/model/mm_gpts/internvl.py index aa142c0..1f3bedc 100644 --- a/src/mcore_bridge/model/mm_gpts/internvl.py +++ b/src/mcore_bridge/model/mm_gpts/internvl.py @@ -1,7 +1,8 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import importlib import torch from torch import nn -from transformers import AutoModel, PretrainedConfig +from transformers import AutoModel, AutoTokenizer, PretrainedConfig from transformers.dynamic_module_utils import get_class_from_dynamic_module from mcore_bridge.bridge import GPTBridge, MultimodalGPTBridge @@ -18,6 +19,23 @@ class InternvlBridge(GPTBridge): hf_lm_head_key = 'language_model.lm_head.weight' hf_score_key = 'language_model.score.weight' + def get_hf_meta_model(self): + model_cls = [] + class_names = ['Qwen2ForCausalLM', 'Qwen3ForCausalLM', 'Qwen3MoeForCausalLM', 'GptOssForCausalLM'] + module = importlib.import_module('transformers') + for cls_name in class_names: + try: + model_cls.append(getattr(module, cls_name)) + except (ImportError, AttributeError): + pass + contexts = self._get_meta_model_context(model_cls) + hf_config = self.config.hf_config + model_cls = get_class_from_dynamic_module('modeling_internvl_chat.InternVLChatModel', hf_config.name_or_path) + with contexts: + model = model_cls(hf_config) + model._auto_class = 'AutoModelForCausalLM' + return model + class InternvlVit(HuggingFaceVit): module_mapping = {'vision_model': 'vision_model', 'mlp1': 'mlp1'} @@ -33,7 +51,6 @@ def prepare_attn_impl(self): self.hf_config.vision_config.use_flash_attn = use_flash_attn def prepare_model(self, hf_config: PretrainedConfig): - from transformers import AutoProcessor llm_model_type = self.config.llm_model_type if llm_model_type not in ['qwen2', 'qwen3', 'qwen3_moe', 'gpt_oss']: raise ValueError(f'{llm_model_type} is not supported for internvl_chat model') @@ -52,7 +69,7 @@ def prepare_model(self, hf_config: PretrainedConfig): self.select_layer = 
hf_config.select_layer self.downsample_ratio = hf_config.downsample_ratio self.ps_version = hf_config.ps_version - self.processor = AutoProcessor.from_pretrained(hf_config.name_or_path, trust_remote_code=True) + self.tokenizer = AutoTokenizer.from_pretrained(hf_config.name_or_path, trust_remote_code=True) def get_inputs_embeds(self, inputs_embeds, **kwargs): input_ids = kwargs['input_ids'] @@ -63,7 +80,7 @@ def get_inputs_embeds(self, inputs_embeds, **kwargs): inputs_embeds = inputs_embeds + vit_embeds.mean() * 0. else: vit_embeds = self.extract_feature(pixel_values.to(self.vision_model.dtype)) - selected = (input_ids == self.processor.encode('<IMG_CONTEXT>', add_special_tokens=False)[0]) + selected = (input_ids == self.tokenizer.encode('<IMG_CONTEXT>', add_special_tokens=False)[0]) inputs_embeds = inputs_embeds.clone() inputs_embeds[selected] = vit_embeds.reshape(-1, vit_embeds.shape[-1]).to(dtype=inputs_embeds.dtype) return inputs_embeds