Inconsistent behaviour between single GPU and distributed implementations #1

@OscarYau525

Dear authors,
I would like to check whether the following distributed code matches the design of SogCLR.

The distributed branch of dynamic_contrastive_loss() in builder.py might be inconsistent with its non-distributed counterpart, because:

  1. When distributed, all_gather_layer only backpropagates through the locally computed encodings.
  2. Each GPU computes its loss using its own logits_ab_aa and logits_ba_bb, so the off-diagonal inner products of the encodings do not have their gradients fully computed. Every GPU should instead compute the same full logits_ab_aa so that all gradients are computed, i.e., logits_ab_aa should be built from inner products of the gathered encodings (hidden1_large, hidden2_large).
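
To illustrate the two points above, here is a minimal single-process sketch. It does not use the repository code: `toy_loss` is a hypothetical stand-in for the contrastive loss, and `gathered_for_rank` imitates an all-gather whose backward pass keeps only the local slice by detaching the remote chunks (an assumption about how all_gather_layer behaves, based on point 1). It also ignores loss scaling and the gradient averaging done by DDP. Under these assumptions, using only the local rows of the logits gives a gradient that differs from the single-GPU one, while computing the full matrix on every rank recovers it.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
world_size, local_bs, dim, T = 2, 3, 4, 0.5
h = F.normalize(torch.randn(world_size * local_bs, dim), dim=1)

def toy_loss(rows, cols):
    # Hypothetical stand-in for the contrastive loss:
    # any nonlinear function of the pairwise inner products works here.
    return torch.exp(rows @ cols.T / T).sum()

def gathered_for_rank(rank):
    # Imitate a gather whose backward keeps only the local slice:
    # remote chunks are detached, the local chunk stays in the graph.
    local = h[rank * local_bs:(rank + 1) * local_bs].clone().requires_grad_(True)
    chunks = [local if r == rank else h[r * local_bs:(r + 1) * local_bs].detach()
              for r in range(world_size)]
    return local, torch.cat(chunks, dim=0)

# Reference: single-GPU gradient over the whole batch.
h_ref = h.clone().requires_grad_(True)
toy_loss(h_ref, h_ref).backward()
full_grad = h_ref.grad

local_rows_grad, full_matrix_grad = [], []
for rank in range(world_size):
    # (a) each rank uses only its local rows against all columns
    local, gathered = gathered_for_rank(rank)
    (g,) = torch.autograd.grad(toy_loss(local, gathered), local)
    local_rows_grad.append(g)
    # (b) each rank computes the full logits matrix
    local, gathered = gathered_for_rank(rank)
    (g,) = torch.autograd.grad(toy_loss(gathered, gathered), local)
    full_matrix_grad.append(g)

local_rows_grad = torch.cat(local_rows_grad)
full_matrix_grad = torch.cat(full_matrix_grad)

print("local rows match single GPU: ", torch.allclose(local_rows_grad, full_grad, atol=1e-6))
print("full matrix match single GPU:", torch.allclose(full_matrix_grad, full_grad, atol=1e-6))
```

In case (a) the terms where a local encoding appears as a column of a remote row are computed on another rank, and their gradient never reaches the local encoder.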

I suggest the following implementation for correct distributed behaviour:

def dynamic_contrastive_loss(self, hidden1, hidden2, index=None, gamma=0.9, distributed=True):
    # Get (normalized) hidden1 and hidden2.
    hidden1, hidden2 = F.normalize(hidden1, p=2, dim=1), F.normalize(hidden2, p=2, dim=1)
    batch_size = hidden1.shape[0]
    
    # Gather hidden1/hidden2 across replicas and create local labels.
    if distributed:  
        # Use all_gather_layer rather than concat_all_gather(), so gradients reach the local encodings.
        hidden1_large = torch.cat(all_gather_layer.apply(hidden1), dim=0)
        hidden2_large = torch.cat(all_gather_layer.apply(hidden2), dim=0)
        enlarged_batch_size = hidden1_large.shape[0]

        labels_idx = torch.arange(enlarged_batch_size, dtype=torch.long)

        labels = F.one_hot(labels_idx, enlarged_batch_size*2).to(self.device) 
        batch_size = enlarged_batch_size
    else:
        hidden1_large = hidden1
        hidden2_large = hidden2
        labels = F.one_hot(torch.arange(batch_size, dtype=torch.long), batch_size * 2).to(self.device) 

    # Each rank should compute the whole logits matrix, because u_i differs across the rows.

    logits_aa = torch.matmul(hidden1_large, hidden1_large.T) # (b * world_size, b * world_size)
    logits_bb = torch.matmul(hidden2_large, hidden2_large.T)
    logits_ab = torch.matmul(hidden1_large, hidden2_large.T)
    logits_ba = torch.matmul(hidden2_large, hidden1_large.T)

    #  SogCLR
    neg_mask = 1-labels
    logits_ab_aa = torch.cat([logits_ab, logits_aa ], 1) # neg. pairs inner product, (b * world_size, 2 * b * world_size)
    logits_ba_bb = torch.cat([logits_ba, logits_bb ], 1)
    
    neg_logits1 = torch.exp(logits_ab_aa /self.T)*neg_mask   #(B, 2B)
    neg_logits2 = torch.exp(logits_ba_bb /self.T)*neg_mask

    neg_logits1[:, batch_size:].fill_diagonal_(0) # replaces the role of LARGE_NUM
    neg_logits2[:, batch_size:].fill_diagonal_(0) # replaces the role of LARGE_NUM

    if distributed:
        index = concat_all_gather(index)

    # u init    
    if self.u[index.cpu()].sum() == 0:
        gamma = 1
        
    u1 = (1 - gamma) * self.u[index.cpu()].cuda() + gamma * torch.sum(neg_logits1, dim=1, keepdim=True)/(2*(batch_size-1))
    u2 = (1 - gamma) * self.u[index.cpu()].cuda() + gamma * torch.sum(neg_logits2, dim=1, keepdim=True)/(2*(batch_size-1))

    self.u[index.cpu()] = (u1.detach().cpu() + u2.detach().cpu())/2 

    p_neg_weights1 = (neg_logits1/u1).detach()
    p_neg_weights2 = (neg_logits2/u2).detach()

    def softmax_cross_entropy_with_logits(labels, logits, weights):
        expsum_neg_logits = torch.sum(weights*logits, dim=1, keepdim=True)/(2*(batch_size-1))
        normalized_logits = logits - expsum_neg_logits
        return -torch.sum(labels * normalized_logits, dim=1)

    loss_a = softmax_cross_entropy_with_logits(labels, logits_ab_aa, p_neg_weights1)
    loss_b = softmax_cross_entropy_with_logits(labels, logits_ba_bb, p_neg_weights2)
    loss = (loss_a + loss_b).mean()

    return loss
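
For reference, this is my reading of what the code above computes for row $i$ of logits_ab_aa, written out as equations. The notation is my own and not taken from the paper: $s_{ij}$ is an entry of logits_ab_aa, $s_{ii^{+}}$ is the positive-pair entry selected by labels, $\mathcal{N}_i$ is the set of $2(B-1)$ negatives left after masking, $B$ is the (gathered) batch size, and $\bar{u}_i$ is the value stored in self.u. The same holds for logits_ba_bb with $u_i^{(b)}$.

```latex
\begin{aligned}
u_i^{(a)} &= (1-\gamma)\,\bar{u}_i + \frac{\gamma}{2(B-1)} \sum_{j \in \mathcal{N}_i} \exp\!\left(s_{ij}/T\right), \\
\bar{u}_i &\leftarrow \tfrac{1}{2}\left(u_i^{(a)} + u_i^{(b)}\right), \\
p_{ij} &= \frac{\exp\!\left(s_{ij}/T\right)}{u_i^{(a)}} \quad \text{(detached, treated as a constant)}, \\
\ell_i^{(a)} &= -\,s_{ii^{+}} + \frac{1}{2(B-1)} \sum_{j \in \mathcal{N}_i} p_{ij}\, s_{ij}.
\end{aligned}
```

Since $u_i^{(a)}$ depends on the whole row $i$, every row needs all of its $2(B-1)$ negatives, which is why each rank computes the full matrix here.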

Thanks!
