diff --git a/doctr/models/layout/lw_detr/base.py b/doctr/models/layout/lw_detr/base.py index 2aa59e7d18..8c461c9fc2 100644 --- a/doctr/models/layout/lw_detr/base.py +++ b/doctr/models/layout/lw_detr/base.py @@ -147,13 +147,19 @@ def __call__(self, logits: np.ndarray, boxes: np.ndarray) -> list[tuple[list[int exp = np.exp(logits[b] - logits[b].max(axis=-1, keepdims=True)) prob = exp / exp.sum(axis=-1, keepdims=True) - prob_fg = prob[:, :-1] # exclude background - scores = prob_fg.max(axis=-1) - labels = prob_fg.argmax(axis=-1) + scores = prob.max(axis=-1) + labels = prob.argmax(axis=-1) + + # treat background as invalid prediction + bg = self.num_classes - 1 + valid = labels != bg + + scores = scores * valid # Keep only topk predictions before NMS if self.topk is not None and len(scores) > self.topk: - idxs = np.argsort(scores)[::-1][: self.topk] + idxs = np.argpartition(-scores, self.topk)[: self.topk] + idxs = idxs[np.argsort(-scores[idxs])] else: idxs = np.arange(len(scores)) diff --git a/doctr/models/layout/lw_detr/layers/pytorch.py b/doctr/models/layout/lw_detr/layers/pytorch.py index 08f9ea9fb4..e9f319e664 100644 --- a/doctr/models/layout/lw_detr/layers/pytorch.py +++ b/doctr/models/layout/lw_detr/layers/pytorch.py @@ -250,7 +250,7 @@ def forward( value = self.value_proj(encoder_hidden_states) if attention_mask is not None: # we invert the attention_mask - value = value.masked_fill(~attention_mask[..., None], float(0)) + value = value.masked_fill(attention_mask[..., None], float(0)) value = value.view(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads) sampling_offsets = self.sampling_offsets(hidden_states).view( batch_size, num_queries, self.n_heads, self.n_levels, self.n_points, 2 @@ -409,26 +409,28 @@ def gen_sine_position_embeddings(pos_tensor: torch.Tensor, hidden_size: int = 25 """ scale = 2 * math.pi dim = hidden_size // 2 + # Keep dim_t in float32 for numerical precision; cast output to match caller dtype dim_t = torch.arange(dim, dtype=torch.float32, device=pos_tensor.device) dim_t = 10000 ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / dim) - x_embed = pos_tensor[:, :, 0] * scale - y_embed = pos_tensor[:, :, 1] * scale + x_embed = pos_tensor[:, :, 0].float() * scale + y_embed = pos_tensor[:, :, 1].float() * scale pos_x = x_embed[:, :, None] / dim_t pos_y = y_embed[:, :, None] / dim_t pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2) pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2) if pos_tensor.size(-1) == 4: - w_embed = pos_tensor[:, :, 2] * scale + w_embed = pos_tensor[:, :, 2].float() * scale pos_w = w_embed[:, :, None] / dim_t pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2) - h_embed = pos_tensor[:, :, 3] * scale + h_embed = pos_tensor[:, :, 3].float() * scale pos_h = h_embed[:, :, None] / dim_t pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2) pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2) else: raise ValueError(f"Unknown pos_tensor shape(-1):{pos_tensor.size(-1)}") + # Cast back to the caller's dtype (supports bfloat16 / float16 AMP) return pos.to(pos_tensor.dtype) @@ -498,7 +500,7 @@ def get_reference( tensor containing the valid ratios for each level of the input feature maps Returns: - reference_points_inputs: (batch_size, num_queries, 1, num_levels, 4) + reference_points_inputs: (batch_size, num_queries, 1, num_levels, 6) tensor containing the reference point inputs for the decoder layers, which are the normalized center coordinates, width and height of the bounding boxes w.r.t. the valid ratios of the input feature maps @@ -535,25 +537,38 @@ def get_reference( return reference_points_inputs, query_pos def refine_boxes(self, reference_points: torch.Tensor, deltas: torch.Tensor) -> torch.Tensor: - reference_points = reference_points.to(deltas.device) - cxcy = deltas[..., :2] * reference_points[..., 2:4] + reference_points[..., :2] + """Refine bounding boxes by applying the predicted deltas to the reference points. + The reference points are in the format (cx, cy, w, h, sinθ, cosθ), and the deltas are in the same format. + The refined boxes are computed as follows: - # Clamp deltas to prevent exp() from shooting to Infinity during early training - wh = torch.clamp(deltas[..., 2:4], min=-4.0, max=2.0).exp() * reference_points[..., 2:4] + cx' = cx + delta_cx * w + cy' = cy + delta_cy * h + w' = w * exp(delta_w) + h' = h * exp(delta_h) + sinθ' = sinθ * cosΔ + cosθ * sinΔ + cosθ' = cosθ * cosΔ - sinθ * sinΔ + + Args: + reference_points: (N, S, 6) tensor containing the reference points + deltas: (N, S, 6) tensor containing the predicted deltas - # Add eps=1e-6 to avoid division-by-zero NaN creation + Returns: + refined_boxes: (N, S, 6) tensor containing the refined bounding boxes + """ + reference_points = reference_points.to(deltas.device) + cxcy = deltas[..., :2] * reference_points[..., 2:4] + reference_points[..., :2] + # size + wh = torch.clamp(deltas[..., 2:4], min=-10.0, max=10.0).exp() * reference_points[..., 2:4] + # rotation delta_rot = F.normalize(deltas[..., 4:6], dim=-1, eps=1e-6) sin_delta = delta_rot[..., 0:1] cos_delta = delta_rot[..., 1:2] sin_ref = reference_points[..., 4:5] cos_ref = reference_points[..., 5:6] - + # compose rotations sin_new = sin_ref * cos_delta + cos_ref * sin_delta cos_new = cos_ref * cos_delta - sin_ref * sin_delta - - # Add eps=1e-6 here too rot = F.normalize(torch.cat([sin_new, cos_new], dim=-1), dim=-1, eps=1e-6) - return torch.cat((cxcy, wh, rot), dim=-1) def forward( @@ -590,11 +605,7 @@ def forward( if self.bbox_embed is not None: delta = self.bbox_embed(hidden_states_norm) - reference_points = self.refine_boxes( - reference_points.squeeze(2), - delta, - ) - + reference_points = self.refine_boxes(reference_points, delta) intermediate_reference_points.append(reference_points) reference_points_inputs, query_pos = self.get_reference( diff --git a/doctr/models/layout/lw_detr/pytorch.py b/doctr/models/layout/lw_detr/pytorch.py index 848d44280e..d760504378 100644 --- a/doctr/models/layout/lw_detr/pytorch.py +++ b/doctr/models/layout/lw_detr/pytorch.py @@ -10,6 +10,7 @@ import numpy as np import torch +from scipy.optimize import linear_sum_assignment from torch import nn from torch.nn import functional as F @@ -156,8 +157,8 @@ def __init__( score_thresh: float = 0.3, iou_thresh: float = 0.5, d_model: int = 256, - num_queries: int = 130, - group_detr: int = 1, + num_queries: int = 300, + group_detr: int = 13, dec_layers: int = 3, sa_num_heads: int = 8, ca_num_heads: int = 16, @@ -185,8 +186,10 @@ def __init__( self.reference_point_embed = nn.Embedding(self.num_queries * self.group_detr, 6) # Initialize angle to (sin=0, cos=1) with torch.no_grad(): - self.reference_point_embed.weight[:, 4] = 0.0 # sinθ - self.reference_point_embed.weight[:, 5] = 1.0 # cosθ + self.reference_point_embed.weight[:, 0:2].uniform_(0.05, 0.95) + self.reference_point_embed.weight[:, 2:4].fill_(0.1) + self.reference_point_embed.weight[:, 4].zero_() + self.reference_point_embed.weight[:, 5].fill_(1.0) self.query_feat = nn.Embedding(self.num_queries * self.group_detr, self.d_model) @@ -241,7 +244,8 @@ def __init__( if hasattr(m, "bias") and m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, nn.Embedding): - nn.init.normal_(m.weight, std=0.02) + if m is not self.reference_point_embed: + nn.init.normal_(m.weight, std=0.02) elif isinstance(m, LWDETRMultiscaleDeformableAttention): nn.init.constant_(m.sampling_offsets.weight, 0.0) @@ -252,27 +256,26 @@ def __init__( .view(m.n_heads, 1, 1, 2) .repeat(1, m.n_levels, m.n_points, 1) ) - for i in range(m.n_points): grid_init[:, :, i, :] *= i + 1 - with torch.no_grad(): m.sampling_offsets.bias.copy_(grid_init.view(-1)) nn.init.constant_(m.attention_weights.weight, 0.0) nn.init.constant_(m.attention_weights.bias, 0.0) - nn.init.xavier_uniform_(m.value_proj.weight) nn.init.zeros_(m.value_proj.bias) - nn.init.xavier_uniform_(m.output_proj.weight) nn.init.zeros_(m.output_proj.bias) - if isinstance(m, nn.Linear) and m.out_features == self.num_classes: - prior_prob = 0.01 - bias_value = -math.log((1 - prior_prob) / prior_prob) if m.bias is not None: - nn.init.constant_(m.bias, bias_value) + with torch.no_grad(): + # Focal-loss prior: foreground starts with low confidence (~0.01), + # preventing background from dominating gradients at the start of training. + prior_prob = 0.01 + bias_value = -math.log((1 - prior_prob) / prior_prob) + nn.init.constant_(m.bias, 0.0) + m.bias[:-1].fill_(bias_value) if isinstance(m, LWDETRHead): last = m.layers[-1] if isinstance(last, nn.Linear): @@ -310,22 +313,19 @@ def refine_bboxes(self, reference_points: torch.Tensor, deltas: torch.Tensor) -> refined_boxes: (N, S, 6) tensor containing the refined bounding boxes """ reference_points = reference_points.to(deltas.device) - # center cxcy = deltas[..., :2] * reference_points[..., 2:4] + reference_points[..., :2] # size - wh = torch.clamp(deltas[..., 2:4], min=-4.0, max=2.0).exp() * reference_points[..., 2:4] + wh = torch.clamp(deltas[..., 2:4], min=-10.0, max=10.0).exp() * reference_points[..., 2:4] # rotation delta_rot = F.normalize(deltas[..., 4:6], dim=-1, eps=1e-6) sin_delta = delta_rot[..., 0:1] cos_delta = delta_rot[..., 1:2] sin_ref = reference_points[..., 4:5] cos_ref = reference_points[..., 5:6] - # compose rotations sin_new = sin_ref * cos_delta + cos_ref * sin_delta cos_new = cos_ref * cos_delta - sin_ref * sin_delta rot = F.normalize(torch.cat([sin_new, cos_new], dim=-1), dim=-1, eps=1e-6) - return torch.cat((cxcy, wh, rot), dim=-1) def get_valid_ratio(self, mask: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor: @@ -404,13 +404,13 @@ def gen_encoder_output_proposals( spatial_valid = ((output_proposals[..., :4] > 0.01) & (output_proposals[..., :4] < 0.99)).all(-1, keepdim=True) output_proposals_valid = spatial_valid - invalid_mask = padding_mask | ~output_proposals_valid.squeeze(-1) invalid_mask = padding_mask.unsqueeze(-1) | ~output_proposals_valid output_proposals = output_proposals.masked_fill(invalid_mask, float(0)) # assign each pixel as an object query object_query = enc_output - object_query = object_query.masked_fill(invalid_mask, float(0)) + object_query = object_query.masked_fill(invalid_mask, 0.0) + return object_query, output_proposals, invalid_mask def forward( @@ -468,6 +468,7 @@ def forward( topk = self.num_queries topk_coords_logits_list: list[torch.Tensor] = [] + topk_content_list: list[torch.Tensor] = [] # encoder predictions for auxiliary losses all_group_enc_logits: list[torch.Tensor] = [] @@ -487,17 +488,25 @@ def forward( all_group_enc_coords.append(group_enc_outputs_coord) - group_topk_proposals = torch.topk(group_enc_outputs_class_masked.max(-1)[0], topk, dim=1)[1] + scores = group_enc_outputs_class_masked[..., :-1].max(-1).values + group_topk_proposals = torch.topk(scores, topk, dim=1)[1] group_topk_coords_logits_undetach = torch.gather( group_enc_outputs_coord, 1, group_topk_proposals.unsqueeze(-1).repeat(1, 1, 6), ) - group_topk_coords_logits = group_topk_coords_logits_undetach + group_topk_coords_logits = group_topk_coords_logits_undetach.detach() topk_coords_logits_list.append(group_topk_coords_logits) + group_topk_content = torch.gather( + group_object_query, + 1, + group_topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model), + ) + topk_content_list.append(group_topk_content) topk_coords_logits = torch.cat(topk_coords_logits_list, 1) + reference_points = self.refine_bboxes(topk_coords_logits, reference_points) last_hidden_states, intermediate, intermediate_reference_points = self.decoder( @@ -506,11 +515,11 @@ def forward( spatial_shapes_list=spatial_shapes_list, valid_ratios=valid_ratios, encoder_hidden_states=source_flatten, + encoder_attention_mask=mask_flatten, ) logits = self.class_embed(last_hidden_states) - pred_boxes_delta = self.bbox_embed(last_hidden_states) - pred_boxes = self.refine_bboxes(intermediate_reference_points[-1], pred_boxes_delta) + pred_boxes = intermediate_reference_points[-1] out: dict[str, Any] = {} @@ -543,11 +552,10 @@ def _postprocess(logits, boxes): main_loss += self.compute_loss(g_logits, g_boxes, processed_targets) loss = main_loss / group_detr - # Auxiliary losses from intermediate decoder layers + # Auxiliary losses from intermediate decoder layers (group DETR) for i in range(intermediate.shape[0] - 1): aux_logits = self.class_embed(intermediate[i]) - aux_boxes_delta = self.bbox_embed(intermediate[i]) - aux_boxes = self.refine_bboxes(intermediate_reference_points[i], aux_boxes_delta) + aux_boxes = intermediate_reference_points[i + 1] split_aux_logits = aux_logits.chunk(group_detr, dim=1) split_aux_boxes = aux_boxes.chunk(group_detr, dim=1) @@ -555,48 +563,48 @@ def _postprocess(logits, boxes): aux_loss: float | torch.Tensor = 0.0 for g_logits, g_boxes in zip(split_aux_logits, split_aux_boxes): aux_loss += self.compute_loss(g_logits, g_boxes, processed_targets) - loss += 0.5 * (aux_loss / group_detr) + loss += aux_loss / group_detr # Auxiliary losses for encoder proposals enc_loss: float | torch.Tensor = 0.0 for group_logits, group_coords in zip(all_group_enc_logits, all_group_enc_coords): enc_loss += self.compute_loss(group_logits, group_coords, processed_targets) - loss += 0.1 * (enc_loss / group_detr) + loss += enc_loss / group_detr out["loss"] = loss return out def compute_loss( - self, logits: torch.Tensor, pred_boxes: torch.Tensor, targets: list[dict[str, np.ndarray]] + self, + logits: torch.Tensor, + pred_boxes: torch.Tensor, + targets: list[dict[str, np.ndarray]], ) -> torch.Tensor: - """ - Compute the loss for LW-DETR. The loss consists of three components: - classification loss, box regression loss, and rotation loss. - The classification loss is a cross-entropy loss between the predicted class logits and the target classes. - The box regression loss is a Smooth L1 loss between the predicted boxes and the target boxes, - computed only on the positive samples. - The rotation loss is computed as 1 - cosine similarity between the predicted rotation and the target rotation, - averaged over the positive samples. - The positive samples are determined using a SimOTA-like assignment strategy, where for each ground truth box, - we select the top-k queries with the lowest cost - (combination of classification cost, box regression cost, and rotation cost). + """Compute the loss between predicted logits and boxes and target labels and boxes. Args: - logits: (B, Q, C) tensor containing the predicted class logits for each query - pred_boxes: (B, Q, 6) tensor containing the predicted boxes for each query - targets: list of dictionaries where each dictionary corresponds to a sample and has keys corresponding - to class names and values corresponding to lists of boxes in either polygon format (4, 2) - or bounding box format (4,) (xmin, ymin, xmax, ymax) + logits: (N, S, C) tensor containing the predicted class logits for each query + pred_boxes: (N, S, 6) tensor containing the predicted boxes in (cx, cy, w, h, sinθ, cosθ) format + targets: list of length N, where each element is a dict with keys "labels" and "boxes", + containing the ground truth labels and boxes for each image in the batch. + The boxes are in (cx, cy, w, h, sinθ, cosθ) format. Returns: - loss: the computed loss value + A scalar tensor containing the computed loss. """ - def rotated_boxes_to_gaussian(boxes: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - """ - Convert rotated boxes in (cx, cy, w, h, sinθ, cosθ) format to Gaussian distribution parameters - (mean and covariance). + def _rotated_boxes_to_gaussian(boxes: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Convert rotated boxes in (cx, cy, w, h, sinθ, cosθ) format + to Gaussian distributions (mean and covariance). + The mean is simply (cx, cy), and the covariance is computed from the width, height, and rotation angle. + + Args: + boxes: (N, S, 6) tensor containing the rotated boxes in (cx, cy, w, h, sinθ, cosθ) format + Returns: + A tuple of (mean, covariance) where: + - mean is a (N, S, 2) tensor containing the mean (cx, cy) of the Gaussian distributions + - covariance is a (N, S, 2, 2) tensor containing the covariance matrices of the Gaussian distributions """ cxcy = boxes[..., :2] @@ -614,11 +622,13 @@ def rotated_boxes_to_gaussian(boxes: torch.Tensor) -> tuple[torch.Tensor, torch. dim=-2, ) + # Variance for a box half-width/half-height: σ² = (w/2)² + # Using w²/12 (uniform distribution) produces ~8x smaller variance, + # which collapses Bhattacharyya distance to the clamp ceiling and kills gradients. sx = (w / 2) ** 2 sy = (h / 2) ** 2 S = torch.zeros((*boxes.shape[:-1], 2, 2), device=boxes.device) - S[..., 0, 0] = sx S[..., 1, 1] = sy @@ -626,15 +636,26 @@ def rotated_boxes_to_gaussian(boxes: torch.Tensor) -> tuple[torch.Tensor, torch. return cxcy, covariance def _probiou_loss(pred_boxes: torch.Tensor, tgt_boxes: torch.Tensor) -> torch.Tensor: - """Compute the ProbIoU loss between predicted boxes and target boxes.""" - mu1, sigma1 = rotated_boxes_to_gaussian(pred_boxes) - mu2, sigma2 = rotated_boxes_to_gaussian(tgt_boxes) + """Compute the ProbIoU loss between predicted and target boxes, + where boxes are represented as Gaussian distributions. + The ProbIoU loss is defined as 1 - exp(-Bhattacharyya distance), + where the Bhattacharyya distance is computed between the two Gaussian distributions. + + Args: + pred_boxes: (N, S, 6) tensor containing the predicted boxes in (cx, cy, w, h, sinθ, cosθ) format + tgt_boxes: (N, S, 6) tensor containing the target boxes in (cx, cy, w, h, sinθ, cosθ) format + Returns: + A (N, S) tensor containing the ProbIoU loss for each pair of predicted and target boxes + """ + mu1, sigma1 = _rotated_boxes_to_gaussian(pred_boxes) + mu2, sigma2 = _rotated_boxes_to_gaussian(tgt_boxes) delta = (mu1 - mu2).unsqueeze(-1) sigma = (sigma1 + sigma2) * 0.5 eps = 1e-6 eye = torch.eye(2, device=sigma.device) * eps + sigma_safe = sigma + eye sigma1_safe = sigma1 + eye sigma2_safe = sigma2 + eye @@ -649,6 +670,7 @@ def _probiou_loss(pred_boxes: torch.Tensor, tgt_boxes: torch.Tensor) -> torch.Te bhattacharyya = 0.125 * mahalanobis + 0.5 * torch.log(det_sigma / torch.sqrt(det_sigma1 * det_sigma2)) + bhattacharyya = torch.clamp(bhattacharyya, min=0.0, max=10.0) probiou = torch.exp(-bhattacharyya) return 1 - probiou @@ -657,100 +679,73 @@ def _probiou_loss(pred_boxes: torch.Tensor, tgt_boxes: torch.Tensor) -> torch.Te total_cls = torch.tensor(0.0, device=device) total_box = torch.tensor(0.0, device=device) - total_rot = torch.tensor(0.0, device=device) + + num_matched_total = 0 for b in range(B): pred_logits = logits[b] pred_boxes_b = pred_boxes[b] - tgt_boxes = torch.as_tensor( - targets[b]["boxes"], - device=device, - dtype=pred_boxes.dtype, - ) - tgt_cls = torch.as_tensor( - targets[b]["labels"], - device=device, - dtype=torch.long, - ) + boxes = targets[b]["boxes"] + + if len(boxes) == 0: + # Penalize the model for any foreground boxes it guessed on this empty image + background_idx = self.num_classes - 1 + target_classes = torch.full((Q,), background_idx, device=device, dtype=torch.long) + total_cls += F.cross_entropy(pred_logits, target_classes) + continue + + tgt_boxes = torch.as_tensor(boxes, device=device, dtype=pred_boxes.dtype) + tgt_cls = torch.as_tensor(targets[b]["labels"], device=device, dtype=torch.long) - num_gt = len(tgt_cls) + if tgt_boxes.ndim == 1: + tgt_boxes = tgt_boxes.unsqueeze(0) pred_rot = F.normalize(pred_boxes_b[:, 4:6], dim=-1) tgt_rot = F.normalize(tgt_boxes[:, 4:6], dim=-1) with torch.no_grad(): - cls_prob = pred_logits.sigmoid() - alpha = 0.25 - gamma = 2.0 + out_logprob = pred_logits.log_softmax(-1) - neg_cost = (1 - alpha) * (cls_prob**gamma) * (-(1 - cls_prob + 1e-8).log()) + cost_cls = -out_logprob[:, tgt_cls] + cost_l1 = torch.cdist(pred_boxes_b[:, :4], tgt_boxes[:, :4], p=1) + cost_rot = 1.0 - torch.abs(pred_rot @ tgt_rot.T) - pos_cost = alpha * ((1 - cls_prob) ** gamma) * (-(cls_prob + 1e-8).log()) + total_cost = 2.0 * cost_cls + 5.0 * cost_l1 + 2.0 * cost_rot - cost_cls = pos_cost[:, tgt_cls] - neg_cost[:, tgt_cls] - cost_l1 = torch.cdist( - pred_boxes_b[:, :4], - tgt_boxes[:, :4], - p=1, - ) - cost_rot = 1 - (pred_rot @ tgt_rot.T).abs() - total_cost = 5.0 * cost_cls + 2.0 * cost_l1 + 1.0 * cost_rot - matching_matrix = torch.zeros( - (Q, num_gt), - dtype=torch.bool, - device=device, - ) + cost_np = total_cost.detach().cpu().numpy() + row_ind, col_ind = linear_sum_assignment(cost_np) - center_dist = torch.cdist( - pred_boxes_b[:, :2], - tgt_boxes[:, :2], - p=2, - ) + pos_idx = torch.as_tensor(row_ind, device=device) + gt_idx = torch.as_tensor(col_ind, device=device) - iou_like = torch.exp(-center_dist) - dynamic_k = iou_like.sum(0).int().clamp(min=1, max=10) + background_idx = self.num_classes - 1 - for gt_idx in range(num_gt): - _, candidate_idx = torch.topk(-total_cost[:, gt_idx], k=int(dynamic_k[gt_idx].item())) - matching_matrix[candidate_idx, gt_idx] = True + target_classes = torch.full((Q,), background_idx, device=device, dtype=torch.long) + target_classes[pos_idx] = tgt_cls[gt_idx] - # resolve duplicate matches - multiple_match_mask = matching_matrix.sum(1) > 1 + cls_weights = torch.ones(self.num_classes, device=device) + cls_weights[background_idx] = 0.1 - if multiple_match_mask.any(): - duplicate_idx = multiple_match_mask.nonzero(as_tuple=False).squeeze(1) - min_cost_idx = total_cost[duplicate_idx].argmin(dim=1) - # Set all matches to False for the duplicate indices, - # then set the match with the lowest cost to True - matching_matrix[duplicate_idx] = False - matching_matrix[duplicate_idx, min_cost_idx] = True + total_cls += F.cross_entropy(pred_logits, target_classes, weight=cls_weights) - pos_idx, gt_indices = matching_matrix.nonzero(as_tuple=True) + if pos_idx.numel() == 0: + continue - target_classes = torch.zeros((Q,), dtype=torch.long, device=device) + num_matched_total += pos_idx.numel() - # background = 0 - target_classes[pos_idx] = tgt_cls[gt_indices] + pred_sel = pred_boxes_b[pos_idx] + tgt_sel = tgt_boxes[gt_idx] - total_cls += F.cross_entropy(pred_logits, target_classes) + l1_loss = F.smooth_l1_loss(pred_sel[:, :4], tgt_sel[:, :4], reduction="sum", beta=0.1) + probiou_loss = _probiou_loss(pred_sel, tgt_sel).sum() + total_box += 5.0 * l1_loss + 2.0 * probiou_loss - if len(pos_idx) == 0: - continue + num_matched_total = max(num_matched_total, 1) + loss_cls = total_cls / B + loss_box = total_box / num_matched_total - pred_sel = pred_boxes_b[pos_idx] - tgt_sel = tgt_boxes[gt_indices] - # L1 loss on (cx, cy, w, h) - l1_loss = F.smooth_l1_loss(pred_sel[:, :4], tgt_sel[:, :4]) - # ProbIoU loss on the whole box (including rotation) - probiou_loss = _probiou_loss(pred_sel, tgt_sel).mean() - total_box += 2.0 * l1_loss + 0.5 * probiou_loss - # Rotation loss - cos_sim = (pred_rot[pos_idx] * tgt_rot[gt_indices]).sum(-1).abs() - rot_loss = (1 - cos_sim).mean() - total_rot += 0.5 * rot_loss - # Average the loss over the batch - return (total_cls + total_box + total_rot) / B + return loss_cls + loss_box def _lw_detr(