Commit 0b8fbad

last cleanup for the day
1 parent 3937aa2 commit 0b8fbad

File tree

pyproject.toml
vector_quantize_pytorch/residual_vq.py

2 files changed: +20 -12 lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.27.1"
+version = "1.27.4"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

vector_quantize_pytorch/residual_vq.py

Lines changed: 19 additions & 11 deletions

@@ -6,9 +6,10 @@
 from itertools import zip_longest
 
 import torch
-from torch import nn, Tensor, arange, cat
-from torch.nn import Module, ModuleList
 import torch.nn.functional as F
+from torch import nn, Tensor, tensor, arange, cat
+from torch.nn import Module, ModuleList
+
 import torch.distributed as dist
 from vector_quantize_pytorch.vector_quantize_pytorch import VectorQuantize
 
@@ -172,6 +173,7 @@ def __init__(
         mlp_kwargs: dict = dict(),
         beam_size = None,
         eval_beam_size = None,
+        beam_score_quantizer_weights: list[float] | None = None,
         **vq_kwargs
     ):
         super().__init__()
@@ -236,6 +238,13 @@ def __init__(
         self.beam_size = default(beam_size, eval_beam_size)
         self.eval_beam_size = eval_beam_size
 
+        # able to assign a different weight for the scoring at each quantizer layer
+
+        beam_score_quantizer_weights = default(beam_score_quantizer_weights, [1.] * num_quantizers)
+        assert len(beam_score_quantizer_weights) == num_quantizers
+
+        self.register_buffer('beam_score_weights', tensor(beam_score_quantizer_weights), persistent = False)
+
         # setting up the MLPs for implicit neural codebooks
 
         self.mlps = None
@@ -427,6 +436,8 @@ def forward(
 
         for quantizer_index, (vq, maybe_mlp) in enumerate(zip(self.layers, maybe_code_transforms)):
 
+            is_last_step = (quantizer_index == (len(self.layers) - 1)) if not should_quantize_dropout else quantizer_index == rand_quantize_dropout_index
+
             if should_quantize_dropout and quantizer_index > rand_quantize_dropout_index:
                 all_indices = pad_at_dim(all_indices, (0, 1), value = -1, dim = -1)
                 all_losses = pad_at_dim(all_losses, (0, 1), value = 0, dim = -1)
@@ -470,7 +481,8 @@ def forward(
 
             if is_beam_search:
 
-                search_scores = einx.add('... j, ... j k -> ... (j k)', search_scores, -loss)
+                score_weight = self.beam_score_weights[quantizer_index]
+                search_scores = einx.add('... j, ... j k -> ... (j k)', search_scores, -loss * score_weight)
 
                 residual = rearrange(residual, '... j d -> ... j 1 d')
                 quantized_out = rearrange(quantized_out, '... j d -> ... j 1 d')
@@ -506,8 +518,10 @@ def forward(
 
                 # handle sort and selection of highest beam size
 
-                if search_scores.shape[-1] > beam_size:
-                    search_scores, select_indices = search_scores.topk(beam_size, dim = -1)
+                layer_beam_size = beam_size if not is_last_step else 1
+
+                if search_scores.shape[-1] > layer_beam_size:
+                    search_scores, select_indices = search_scores.topk(layer_beam_size, dim = -1)
 
                 residual = batch_select(residual, select_indices, '* k d')
                 quantized_out = batch_select(quantized_out, select_indices, '* k d')
@@ -526,12 +540,6 @@ def forward(
         # handle beam search
 
         if is_beam_search:
-            top_index = search_scores.argmax(dim = -1, keepdim = True)
-
-            quantized_out = batch_select(quantized_out, top_index, '* k d')
-            all_indices = batch_select(all_indices, top_index, '* k l')
-            all_losses = batch_select(all_losses, top_index, '* k l')
-            all_residuals = batch_select(all_residuals, top_index, '* k l d')
 
             quantized_out, all_indices, all_losses, all_residuals = [t[..., 0, :] for t in (quantized_out, all_indices, all_losses, all_residuals)]
 
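For readers unfamiliar with einx, here is a rough plain-PyTorch sketch of what the updated beam-scoring lines do; the shapes and values are illustrative assumptions, not taken from the library:

import torch

# illustrative shapes only: batch of 2, 4 beams, 8 candidate codes per quantizer
search_scores = torch.zeros(2, 4)          # running score per beam
loss = torch.rand(2, 4, 8)                 # per-candidate loss at this quantizer layer
score_weight = 0.5                         # would come from self.beam_score_weights[quantizer_index]

# add the weighted negative loss of every candidate to its parent beam's score,
# then flatten the (beam, candidate) axes, mirroring einx.add('... j, ... j k -> ... (j k)', ...)
search_scores = (search_scores.unsqueeze(-1) - loss * score_weight).flatten(-2)

# keep the top beams; on the last quantizer step the diff forces the beam size to 1,
# so only the single best beam survives and the old argmax-based final selection becomes unnecessary
layer_beam_size = 4   # or 1 on the last step
search_scores, select_indices = search_scores.topk(layer_beam_size, dim = -1)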
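And a minimal usage sketch of the new beam_score_quantizer_weights argument. Apart from the parameters visible in this diff, the constructor arguments, example values, and return signature shown below are assumptions based on the library's usual ResidualVQ interface:

import torch
from vector_quantize_pytorch import ResidualVQ

residual_vq = ResidualVQ(
    dim = 256,
    num_quantizers = 4,
    codebook_size = 1024,
    beam_size = 4,                                       # beam search over candidate codes at each quantizer
    beam_score_quantizer_weights = [1., 1., 0.5, 0.25]   # new: per-quantizer weighting of the beam scores (must have length num_quantizers)
)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = residual_vq(x)   # quantized: (1, 1024, 256), indices: (1, 1024, 4)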
