Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions doctr/models/layout/lw_detr/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,13 +147,19 @@ def __call__(self, logits: np.ndarray, boxes: np.ndarray) -> list[tuple[list[int
exp = np.exp(logits[b] - logits[b].max(axis=-1, keepdims=True))
prob = exp / exp.sum(axis=-1, keepdims=True)

prob_fg = prob[:, :-1] # exclude background
scores = prob_fg.max(axis=-1)
labels = prob_fg.argmax(axis=-1)
scores = prob.max(axis=-1)
labels = prob.argmax(axis=-1)

# treat background as invalid prediction
bg = self.num_classes - 1
valid = labels != bg

scores = scores * valid

# Keep only topk predictions before NMS
if self.topk is not None and len(scores) > self.topk:
idxs = np.argsort(scores)[::-1][: self.topk]
idxs = np.argpartition(-scores, self.topk)[: self.topk]
idxs = idxs[np.argsort(-scores[idxs])]
else:
idxs = np.arange(len(scores))

Expand Down
51 changes: 31 additions & 20 deletions doctr/models/layout/lw_detr/layers/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@
value = self.value_proj(encoder_hidden_states)
if attention_mask is not None:
# we invert the attention_mask
value = value.masked_fill(~attention_mask[..., None], float(0))
value = value.masked_fill(attention_mask[..., None], float(0))
value = value.view(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads)
sampling_offsets = self.sampling_offsets(hidden_states).view(
batch_size, num_queries, self.n_heads, self.n_levels, self.n_points, 2
Expand Down Expand Up @@ -409,26 +409,28 @@
"""
scale = 2 * math.pi
dim = hidden_size // 2
# Keep dim_t in float32 for numerical precision; cast output to match caller dtype
dim_t = torch.arange(dim, dtype=torch.float32, device=pos_tensor.device)
dim_t = 10000 ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / dim)
x_embed = pos_tensor[:, :, 0] * scale
y_embed = pos_tensor[:, :, 1] * scale
x_embed = pos_tensor[:, :, 0].float() * scale
y_embed = pos_tensor[:, :, 1].float() * scale
pos_x = x_embed[:, :, None] / dim_t
pos_y = y_embed[:, :, None] / dim_t
pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)
if pos_tensor.size(-1) == 4:
w_embed = pos_tensor[:, :, 2] * scale
w_embed = pos_tensor[:, :, 2].float() * scale
pos_w = w_embed[:, :, None] / dim_t
pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2)

h_embed = pos_tensor[:, :, 3] * scale
h_embed = pos_tensor[:, :, 3].float() * scale
pos_h = h_embed[:, :, None] / dim_t
pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2)

pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
else:
raise ValueError(f"Unknown pos_tensor shape(-1):{pos_tensor.size(-1)}")
# Cast back to the caller's dtype (supports bfloat16 / float16 AMP)
return pos.to(pos_tensor.dtype)


Expand Down Expand Up @@ -498,7 +500,7 @@
tensor containing the valid ratios for each level of the input feature maps

Returns:
reference_points_inputs: (batch_size, num_queries, 1, num_levels, 4)
reference_points_inputs: (batch_size, num_queries, 1, num_levels, 6)
tensor containing the reference point inputs for the decoder layers,
which are the normalized center coordinates,
width and height of the bounding boxes w.r.t. the valid ratios of the input feature maps
Expand Down Expand Up @@ -535,25 +537,38 @@
return reference_points_inputs, query_pos

def refine_boxes(self, reference_points: torch.Tensor, deltas: torch.Tensor) -> torch.Tensor:
reference_points = reference_points.to(deltas.device)
cxcy = deltas[..., :2] * reference_points[..., 2:4] + reference_points[..., :2]
"""Refine bounding boxes by applying the predicted deltas to the reference points.

Check notice on line 540 in doctr/models/layout/lw_detr/layers/pytorch.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

doctr/models/layout/lw_detr/layers/pytorch.py#L540

Missing dashed underline after section ('Returns') (D407)

Check notice on line 540 in doctr/models/layout/lw_detr/layers/pytorch.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

doctr/models/layout/lw_detr/layers/pytorch.py#L540

Section name should end with a newline ('Returns', not 'Returns:') (D406)
The reference points are in the format (cx, cy, w, h, sinθ, cosθ), and the deltas are in the same format.
The refined boxes are computed as follows:

# Clamp deltas to prevent exp() from shooting to Infinity during early training
wh = torch.clamp(deltas[..., 2:4], min=-4.0, max=2.0).exp() * reference_points[..., 2:4]
cx' = cx + delta_cx * w
cy' = cy + delta_cy * h
w' = w * exp(delta_w)
h' = h * exp(delta_h)
sinθ' = sinθ * cosΔ + cosθ * sinΔ
cosθ' = cosθ * cosΔ - sinθ * sinΔ

Args:
reference_points: (N, S, 6) tensor containing the reference points
deltas: (N, S, 6) tensor containing the predicted deltas

# Add eps=1e-6 to avoid division-by-zero NaN creation
Returns:
refined_boxes: (N, S, 6) tensor containing the refined bounding boxes
"""
reference_points = reference_points.to(deltas.device)
cxcy = deltas[..., :2] * reference_points[..., 2:4] + reference_points[..., :2]
# size
wh = torch.clamp(deltas[..., 2:4], min=-10.0, max=10.0).exp() * reference_points[..., 2:4]
# rotation
delta_rot = F.normalize(deltas[..., 4:6], dim=-1, eps=1e-6)
sin_delta = delta_rot[..., 0:1]
cos_delta = delta_rot[..., 1:2]
sin_ref = reference_points[..., 4:5]
cos_ref = reference_points[..., 5:6]

# compose rotations
sin_new = sin_ref * cos_delta + cos_ref * sin_delta
cos_new = cos_ref * cos_delta - sin_ref * sin_delta

# Add eps=1e-6 here too
rot = F.normalize(torch.cat([sin_new, cos_new], dim=-1), dim=-1, eps=1e-6)

return torch.cat((cxcy, wh, rot), dim=-1)

def forward(
Expand Down Expand Up @@ -590,11 +605,7 @@
if self.bbox_embed is not None:
delta = self.bbox_embed(hidden_states_norm)

reference_points = self.refine_boxes(
reference_points.squeeze(2),
delta,
)

reference_points = self.refine_boxes(reference_points, delta)
intermediate_reference_points.append(reference_points)

reference_points_inputs, query_pos = self.get_reference(
Expand Down
Loading
Loading