diff --git a/src/mcore_bridge/model/gpt_model.py b/src/mcore_bridge/model/gpt_model.py
index 29db5d0..12cd97d 100644
--- a/src/mcore_bridge/model/gpt_model.py
+++ b/src/mcore_bridge/model/gpt_model.py
@@ -158,6 +158,7 @@ def _apply_rotary_pos_emb_bshd(
     rotary_interleaved: bool = False,
     multi_latent_attention: bool = False,  # not use
     mscale: float = 1.0,
+    **kwargs,
 ) -> torch.Tensor:
     """Apply rotary positional embedding to input tensor T.
 
diff --git a/src/mcore_bridge/patcher.py b/src/mcore_bridge/patcher.py
index 2de8e51..392ff34 100644
--- a/src/mcore_bridge/patcher.py
+++ b/src/mcore_bridge/patcher.py
@@ -608,6 +608,7 @@ def _apply_rotary_pos_emb_thd(
     multi_latent_attention: bool = False,
     mscale: float = 1.0,
     cp_group: torch.distributed.ProcessGroup = None,
+    **kwargs,
 ) -> torch.Tensor:
     """A baseline implementation of applying RoPE for `thd` format.
 
@@ -629,7 +630,8 @@ def _apply_rotary_pos_emb_thd(
     use_batched_rope = (freqs.dim() >= 1 and freqs.shape[0] == cu_seqlens_for_batched[-1]).item()
     if not use_batched_rope:
         logger.warning_once('Using non-batched RoPE, which may affect performance.')
-        kwargs = {'cp_group': cp_group} if mcore_013 else {}
+        if mcore_013:
+            kwargs['cp_group'] = cp_group
         return _origin_apply_rotary_pos_emb_thd(
             t,
             cu_seqlens,
@@ -646,6 +648,7 @@ def _apply_rotary_pos_emb_thd(
         rotary_interleaved=rotary_interleaved,
         multi_latent_attention=multi_latent_attention,
         mscale=mscale,
+        **kwargs,
     ).squeeze(1)
 
 rope_utils._apply_rotary_pos_emb_thd = _apply_rotary_pos_emb_thd
diff --git a/src/mcore_bridge/version.py b/src/mcore_bridge/version.py
index 3594f30..f4f1f4b 100644
--- a/src/mcore_bridge/version.py
+++ b/src/mcore_bridge/version.py
@@ -1,5 +1,5 @@
 # Make sure to modify __release_datetime__ to release time when making official release.
-__version__ = '1.0.1.dev0'
+__version__ = '1.1.0.dev0'
 # default release datetime for branches under active development is set
 # to be a time far-far-away-into-the-future
 __release_datetime__ = '2099-12-31 23:59:59'
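
For context, a minimal sketch of the keyword-forwarding pattern this patch adopts: the patched wrapper accepts arbitrary **kwargs so that newer Megatron-Core call sites can pass extra arguments (such as cp_group on mcore >= 0.13) without breaking the patched signature, and forwards them to the wrapped implementation. The names original_apply_rope, patched_apply_rope, and mcore_013 below are illustrative stand-ins, not the actual symbols.

import torch


def original_apply_rope(t: torch.Tensor, freqs: torch.Tensor, cp_group=None) -> torch.Tensor:
    # Stand-in for the upstream RoPE implementation being wrapped.
    return t


mcore_013 = True  # assumption: True when the installed Megatron-Core is >= 0.13


def patched_apply_rope(t: torch.Tensor, freqs: torch.Tensor, cp_group=None, **kwargs) -> torch.Tensor:
    # Accept and forward keyword arguments that newer call sites may pass,
    # so the patched signature keeps working when upstream adds parameters.
    if mcore_013:
        kwargs['cp_group'] = cp_group
    return original_apply_rope(t, freqs, **kwargs)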