From c9e4f8ad52306d29ef7333aadfff93a8f276b599 Mon Sep 17 00:00:00 2001
From: heya5
Date: Thu, 26 Dec 2024 21:56:44 +0800
Subject: [PATCH 1/2] Disable return all hidden states to reduce gpu memory use

---
 esm/layers/transformer_stack.py | 5 ++++-
 esm/models/esmc.py              | 6 +++++-
 2 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/esm/layers/transformer_stack.py b/esm/layers/transformer_stack.py
index 0b587819..bef1706b 100644
--- a/esm/layers/transformer_stack.py
+++ b/esm/layers/transformer_stack.py
@@ -37,8 +37,10 @@ def __init__(
         ffn_type: str = "swiglu",  # swiglu | gelu
         expansion_ratio: float = 8 / 3,
         use_flash_attn: bool = False,
+        return_hidden_states: bool = False,
     ):
         super().__init__()
+        self.return_hidden_states = return_hidden_states
         self.blocks = nn.ModuleList(
             [
                 UnifiedTransformerBlock(
@@ -90,5 +92,6 @@ def forward(
         hiddens = []
         for block in self.blocks:
             x = block(x, sequence_id, affine, affine_mask, chain_id)
-            hiddens.append(x)
+            if self.return_hidden_states:
+                hiddens.append(x)
         return self.norm(x), x, hiddens
diff --git a/esm/models/esmc.py b/esm/models/esmc.py
index 0d3438f6..020407e2 100644
--- a/esm/models/esmc.py
+++ b/esm/models/esmc.py
@@ -58,11 +58,13 @@ def __init__(
         n_layers: int,
         tokenizer: EsmSequenceTokenizer,
         use_flash_attn: bool = True,
+        return_hidden_states: bool = False,
     ):
         super().__init__()
         self.embed = nn.Embedding(64, d_model)
         self._use_flash_attn = is_flash_attn_available and use_flash_attn
+        self.return_hidden_states = return_hidden_states

         self.transformer = TransformerStack(
             d_model,
             n_heads,
@@ -70,6 +72,7 @@ def __init__(
             n_layers,
             n_layers_geom=0,
             use_flash_attn=self._use_flash_attn,
+            return_hidden_states=self.return_hidden_states
         )

         self.sequence_head = RegressionHead(d_model, 64)
@@ -164,7 +167,8 @@ def forward(
         ]

         # Stack hidden states into a [n_layers, B, L, D] matrix.
-        hiddens = torch.stack(hiddens, dim=0)  # type: ignore
+        if len(hiddens):
+            hiddens = torch.stack(hiddens, dim=0)  # type: ignore

         sequence_logits = self.sequence_head(x)
         output = ESMCOutput(

From 05412044b16687fd1b73722d9b78d996d1fcac3e Mon Sep 17 00:00:00 2001
From: heya5
Date: Mon, 30 Dec 2024 15:38:34 +0800
Subject: [PATCH 2/2] Use layer normalized last_hidden_states as embeddings

---
 esm/models/esm3.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/esm/models/esm3.py b/esm/models/esm3.py
index cbe02ddd..78034189 100644
--- a/esm/models/esm3.py
+++ b/esm/models/esm3.py
@@ -165,7 +165,8 @@ def __init__(self, d_model: int):
         self.function_head = RegressionHead(d_model, 260 * 8)
         self.residue_head = RegressionHead(d_model, 1478)

-    def forward(self, x: torch.Tensor, embed: torch.Tensor) -> ESMOutput:
+    def forward(self, x: torch.Tensor, last_hidden_state: torch.Tensor) -> ESMOutput:
+        embeddings = x.clone()
         sequence_logits = self.sequence_head(x)
         structure_logits = self.structure_head(x)
         secondary_structure_logits = self.ss8_head(x)
@@ -182,7 +183,7 @@ def forward(self, x: torch.Tensor, embed: torch.Tensor) -> ESMOutput:
             sasa_logits=sasa_logits,
             function_logits=function_logits,
             residue_logits=residue_logits,
-            embeddings=embed,
+            embeddings=embeddings,
         )


@@ -376,10 +377,10 @@ def forward(
             function_tokens,
             residue_annotation_tokens,
         )
-        x, embedding, _ = self.transformer(
+        x, last_hidden_states, _ = self.transformer(
             x, sequence_id, affine, affine_mask, chain_id
        )
-        return self.output_heads(x, embedding)
+        return self.output_heads(x, last_hidden_states)

     # The following methods are for the ESM3InferenceClient interface
     def generate(self, input: ProteinType, config: GenerationConfig) -> ProteinType:
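For context on the first patch, here is a minimal usage sketch of the new return_hidden_states flag. It is hypothetical and not part of the patch: the full ESMC constructor signature and the model dimensions below are assumptions inferred from the diff context, and the tokenizer import path may differ in your checkout.

    # Hypothetical sketch, assuming ESMC's constructor is
    # (d_model, n_heads, n_layers, tokenizer, use_flash_attn, return_hidden_states)
    # as the diff context suggests; the dimensions below are illustrative only.
    from esm.models.esmc import ESMC
    from esm.tokenization import EsmSequenceTokenizer

    # Default after this patch: per-layer hidden states are not accumulated,
    # so each block's output can be freed as soon as the next block runs,
    # lowering peak GPU memory during inference.
    model = ESMC(
        d_model=960,
        n_heads=15,
        n_layers=30,
        tokenizer=EsmSequenceTokenizer(),
        return_hidden_states=False,
    )

    # Opt back in only when per-layer representations are actually needed,
    # e.g. for layer-wise probing; this restores the previous memory footprint.
    probing_model = ESMC(
        d_model=960,
        n_heads=15,
        n_layers=30,
        tokenizer=EsmSequenceTokenizer(),
        return_hidden_states=True,
    )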
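The second patch changes which tensor ESM3 reports as embeddings: TransformerStack.forward returns (self.norm(x), x, hiddens), and the output heads now clone the normalized first tensor instead of echoing the pre-norm second one. Below is a small, self-contained toy illustration of that distinction, assuming nothing beyond standard PyTorch; the shapes are arbitrary.

    import torch
    import torch.nn as nn

    # Toy stand-in for the end of TransformerStack.forward, which returns
    # (self.norm(x), x, hiddens): normalized activations first, raw last-layer
    # activations second. Shapes here are arbitrary.
    norm = nn.LayerNorm(8)
    x = torch.randn(2, 4, 8)                 # pretend output of the last block
    normed_x, last_hidden_states = norm(x), x

    # Before the patch, the output heads were handed last_hidden_states and
    # returned it unchanged, so embeddings held the un-normalized activations.
    # After the patch, the heads clone their normalized input instead, so the
    # reported embeddings match the representation the prediction heads consume.
    embeddings = normed_x.clone()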