 from itertools import zip_longest

 import torch
-from torch import nn, Tensor, arange, cat
-from torch.nn import Module, ModuleList
 import torch.nn.functional as F
+from torch import nn, Tensor, tensor, arange, cat
+from torch.nn import Module, ModuleList
+
 import torch.distributed as dist
 from vector_quantize_pytorch.vector_quantize_pytorch import VectorQuantize

@@ -172,6 +173,7 @@ def __init__(
         mlp_kwargs: dict = dict(),
         beam_size = None,
         eval_beam_size = None,
+        beam_score_quantizer_weights: list[float] | None = None,
         **vq_kwargs
     ):
         super().__init__()
@@ -236,6 +238,13 @@ def __init__(
         self.beam_size = default(beam_size, eval_beam_size)
         self.eval_beam_size = eval_beam_size

+        # able to assign a different weight for the scoring at each quantizer layer
+
+        beam_score_quantizer_weights = default(beam_score_quantizer_weights, [1.] * num_quantizers)
+        assert len(beam_score_quantizer_weights) == num_quantizers
+
+        self.register_buffer('beam_score_weights', tensor(beam_score_quantizer_weights), persistent = False)
+
         # setting up the MLPs for implicit neural codebooks

         self.mlps = None
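For context, a minimal usage sketch of the new argument, not part of the commit. It assumes the change lands on the residual quantizer class exported as ResidualVQ, with the constructor arguments from the library README; the weight values are purely illustrative.

# usage sketch, assuming ResidualVQ and its README constructor arguments
import torch
from vector_quantize_pytorch import ResidualVQ

residual_vq = ResidualVQ(
    dim = 256,
    num_quantizers = 4,
    codebook_size = 1024,
    beam_size = 4,                                         # beam search across the residual quantizers
    beam_score_quantizer_weights = [1., 0.75, 0.5, 0.25]   # illustrative: weight earlier quantizers more heavily
)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = residual_vq(x)           # same return signature as before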
@@ -427,6 +436,8 @@ def forward(

         for quantizer_index, (vq, maybe_mlp) in enumerate(zip(self.layers, maybe_code_transforms)):

+            is_last_step = (quantizer_index == (len(self.layers) - 1)) if not should_quantize_dropout else quantizer_index == rand_quantize_dropout_index
+
             if should_quantize_dropout and quantizer_index > rand_quantize_dropout_index:
                 all_indices = pad_at_dim(all_indices, (0, 1), value = -1, dim = -1)
                 all_losses = pad_at_dim(all_losses, (0, 1), value = 0, dim = -1)
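One note on the added flag: when quantizer dropout is active the pass effectively stops at rand_quantize_dropout_index, so is_last_step is keyed off the sampled dropout index rather than the final layer. The flag is consumed further down when choosing the per-layer beam size.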
@@ -470,7 +481,8 @@ def forward(

             if is_beam_search:

-                search_scores = einx.add('... j, ... j k -> ... (j k)', search_scores, -loss)
+                score_weight = self.beam_score_weights[quantizer_index]
+                search_scores = einx.add('... j, ... j k -> ... (j k)', search_scores, -loss * score_weight)

                 residual = rearrange(residual, '... j d -> ... j 1 d')
                 quantized_out = rearrange(quantized_out, '... j d -> ... j 1 d')
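To make the weighted accumulation concrete, here is a standalone sketch (not the library code) of what the einx.add line does at one quantizer step: every kept beam is expanded by its candidate codes, and the running score is decremented by the weighted quantization loss before the flattened candidates are ranked.

# standalone illustration of the weighted beam score update
import torch

beams, candidates = 3, 4
search_scores = torch.zeros(beams)          # running score per kept beam ('j')
loss = torch.rand(beams, candidates)        # per-candidate loss at this quantizer ('j k')
score_weight = 0.5                          # beam_score_weights[quantizer_index]

# equivalent to einx.add('... j, ... j k -> ... (j k)', search_scores, -loss * score_weight)
expanded_scores = (search_scores[:, None] - loss * score_weight).reshape(-1)   # shape: (beams * candidates,)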
@@ -506,8 +518,10 @@ def forward(

                 # handle sort and selection of highest beam size

-                if search_scores.shape[-1] > beam_size:
-                    search_scores, select_indices = search_scores.topk(beam_size, dim = -1)
+                layer_beam_size = beam_size if not is_last_step else 1
+
+                if search_scores.shape[-1] > layer_beam_size:
+                    search_scores, select_indices = search_scores.topk(layer_beam_size, dim = -1)

                     residual = batch_select(residual, select_indices, '* k d')
                     quantized_out = batch_select(quantized_out, select_indices, '* k d')
@@ -526,12 +540,6 @@ def forward(
         # handle beam search

         if is_beam_search:
-            top_index = search_scores.argmax(dim = -1, keepdim = True)
-
-            quantized_out = batch_select(quantized_out, top_index, '* k d')
-            all_indices = batch_select(all_indices, top_index, '* k l')
-            all_losses = batch_select(all_losses, top_index, '* k l')
-            all_residuals = batch_select(all_residuals, top_index, '* k l d')

             quantized_out, all_indices, all_losses, all_residuals = [t[..., 0, :] for t in (quantized_out, all_indices, all_losses, all_residuals)]

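Since the last quantizer step now shrinks the beam to a single candidate (layer_beam_size = 1 above), the explicit argmax-and-batch_select collapse removed in this final hunk is redundant; the remaining [..., 0, :] slice simply takes the sole surviving beam entry.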