-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdenoise.py
More file actions
181 lines (155 loc) · 6.72 KB
/
denoise.py
File metadata and controls
181 lines (155 loc) · 6.72 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import torch
import json
import re
import os
import gc
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
import ast
from typing import List, Tuple, Dict
import astor
class Denoiser:
    """Attention-based code denoiser.

    Only block granularity and top-k aggregation are supported.

    Retrieved code contexts are scored line by line using the attention that
    the last input token pays to each code token; the lowest-scored lines
    are dropped until the requested fraction of tokens remains.
    """

    # Matches one retrieved-context fragment: a three-line comment header
    # followed by the code, terminated by a dashed separator line.
    _CONTEXT_PATTERN = re.compile(
        r"(# the below code fragment can be found in:\n"
        r"# .+?\n"
        r"# -{50,}\n"
        r"([\s\S]*?)"
        r"# -{50,}\n)"
    )

    def __init__(self, model, tokenizer):
        """Store the model/tokenizer pair and switch the model to eval mode.

        Args:
            model: Callable model exposing ``eval()``, ``device`` and
                returning ``attentions`` when called with
                ``output_attentions=True``.
            tokenizer: Tokenizer that is callable on text and exposes
                ``tokenize()``.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.model.eval()

    def merge_consecutive_blank_lines(self, text: str) -> str:
        """Merge consecutive blank lines into a single blank line.

        Whitespace-only lines count as blank and are emitted as empty lines.
        """
        merged = []
        previous_blank = False
        for line in text.split('\n'):
            is_blank = not line.strip()
            if is_blank and previous_blank:
                continue
            merged.append("" if is_blank else line)
            previous_blank = is_blank
        return "\n".join(merged)

    def aggregate_attention_topk(
        self,
        indices: List[int],
        attn_weights: List[float],
        top_fraction: float = 0.1,
    ) -> float:
        """Aggregate attention weights using top-k (top 10% by default).

        Args:
            indices: Positions into ``attn_weights`` that belong to one line.
            attn_weights: Per-token attention weights.
            top_fraction: Fraction of the line's weights that are averaged;
                at least one and at most all values are used.

        Returns:
            The mean of the largest weights, or ``-1.0`` when no position
            can be scored (the sentinel callers use for "no score").
        """
        size = len(attn_weights)
        # Lines are tokenized independently of the joint code+query input, so
        # their positions can run past the attention vector; skip those
        # instead of raising IndexError.
        values = [attn_weights[i] for i in indices if -size <= i < size]
        if not values:
            return -1.0
        k = min(len(values), max(1, int(len(values) * top_fraction)))
        topk_vals = sorted(values, reverse=True)[:k]
        return sum(topk_vals) / k

    @staticmethod
    def _split_into_blocks(lines: List[str]) -> List[List[int]]:
        """Group indices of non-blank lines into blocks separated by blank lines."""
        blocks = []
        current = []
        for idx, line in enumerate(lines):
            if line.strip():
                current.append(idx)
            elif current:
                blocks.append(current)
                current = []
        if current:
            blocks.append(current)
        return blocks

    def _query_attention(self, original_lines: List[str], query: str):
        """Return the attention row of the last input token as a NumPy array.

        The model is run once on the code followed by the stripped query; the
        returned row excludes the final position itself.
        """
        input_text = "\n".join(original_lines) + "\n" + query.strip()
        inputs = self.tokenizer(input_text, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            outputs = self.model(**inputs, output_attentions=True)
        # Only the last layer is used, so index it directly rather than
        # stacking every layer's attention into one large tensor.
        # NOTE(review): if each entry is (batch, heads, seq, seq) as in the
        # Hugging Face docs, mean(dim=0)[0] averages over the batch and then
        # selects head 0 -- confirm that averaging over heads was not intended.
        last_layer_attn = outputs.attentions[-1].mean(dim=0)[0]
        generated_index = inputs["input_ids"].shape[1] - 1
        # Cast to float32 first: bfloat16 tensors cannot be converted to NumPy.
        return last_layer_attn[generated_index, :-1].float().cpu().numpy()

    def _map_lines_to_tokens(self, lines: List[str]) -> List[List[int]]:
        """Assign each line a consecutive range of token positions.

        NOTE(review): positions start at 0 and each line is tokenized on its
        own, so special tokens added to the joint input (e.g. BOS) are not
        accounted for -- confirm the alignment for the tokenizer in use.
        """
        line_token_indices = []
        start_idx = 0
        for line in lines:
            token_count = len(self.tokenizer.tokenize(line + "\n"))
            line_token_indices.append(list(range(start_idx, start_idx + token_count)))
            start_idx += token_count
        return line_token_indices

    @staticmethod
    def _select_kept_lines(
        blocks: List[List[int]],
        line_token_indices: List[List[int]],
        line_scores: List[float],
        ratio: float,
    ) -> set:
        """Choose the line indices to keep so that roughly ``ratio`` of tokens remain.

        Repeatedly finds the block with the lowest mean score over its
        still-kept, scored lines and drops that block's lowest-scored line.
        Lines with a negative score (no tokens) are never dropped.
        """
        kept_lines = {i for block in blocks for i in block}
        total_tokens = sum(len(line_token_indices[i]) for i in kept_lines)
        if total_tokens == 0:
            # Nothing to measure against; avoid dividing by zero below.
            return kept_lines
        kept_token_count = total_tokens

        while kept_token_count / total_tokens > ratio:
            weakest_lines = None
            weakest_score = float('inf')
            for block in blocks:
                droppable = [
                    j for j in block
                    if j in kept_lines and line_scores[j] >= 0
                ]
                if not droppable:
                    continue
                score = sum(line_scores[j] for j in droppable) / len(droppable)
                if score < weakest_score:
                    weakest_score = score
                    weakest_lines = droppable
            if weakest_lines is None:
                # Every remaining line is unscored; nothing more can be dropped.
                break
            worst_line = min(weakest_lines, key=lambda j: line_scores[j])
            kept_lines.remove(worst_line)
            kept_token_count -= len(line_token_indices[worst_line])
        return kept_lines

    def denoise_code_by_block(self, code: str, query: str, ratio: float) -> str:
        """Block-level denoising with top-k aggregation.

        Args:
            code: Code context to shrink; blank lines delimit blocks.
            query: Text appended after the code when computing attention.
            ratio: Lines are dropped while the kept fraction of tokens is
                greater than this value.

        Returns:
            The code with low-attention lines removed and consecutive blank
            lines merged. ``code`` is returned unchanged when it contains no
            non-blank line.
        """
        original_lines = code.split('\n')
        blocks = self._split_into_blocks(original_lines)
        if not blocks:
            return code

        attn_weights = self._query_attention(original_lines, query)
        line_token_indices = self._map_lines_to_tokens(original_lines)
        line_scores = [
            self.aggregate_attention_topk(indices, attn_weights)
            for indices in line_token_indices
        ]
        kept_lines = self._select_kept_lines(
            blocks, line_token_indices, line_scores, ratio
        )

        denoised_lines = []
        for i, line in enumerate(original_lines):
            if not line.strip():
                # Blank lines are always kept so block boundaries survive.
                denoised_lines.append("")
            elif i in kept_lines:
                denoised_lines.append(line)
        return self.merge_consecutive_blank_lines("\n".join(denoised_lines))

    def extract_contexts_and_query(self, text: str) -> Tuple[List[str], str]:
        """Extract code context blocks and the query part from a prompt-like text.

        Returns:
            A pair ``(contexts, query)``. Each context is the three header
            lines of a fragment followed by its code (the closing separator
            is not included). ``query`` is the stripped text after the last
            fragment, or the whole stripped text when no fragment is found.
        """
        matches = list(self._CONTEXT_PATTERN.finditer(text))
        contexts = []
        for match in matches:
            header_lines = match.group(1).splitlines()[:3]
            contexts.append("\n".join(header_lines) + "\n" + match.group(2))

        if matches:
            query = text[matches[-1].end():].strip()
        else:
            query = text.strip()
        return contexts, query

    def build_prompt(self, query: str, denoised_context: List[Dict]) -> str:
        """Build the final prompt with denoised context blocks.

        Args:
            query: Text placed after all context blocks.
            denoised_context: Items with a ``'code'`` entry; each code string
                is followed by a dashed separator line.

        Returns:
            An introductory comment, a separator, the context blocks and
            finally the query, joined by newlines.
        """
        separator = '# ' + '-' * 50
        prepend_context = "# Here are some relevant code fragments from other files of the repo:\n"
        prepend_context += separator + '\n'
        # splitlines()/join normalizes line endings and drops a trailing newline.
        prepend_blocks = [
            '\n'.join(context['code'].splitlines()) + '\n' + separator
            for context in denoised_context
        ]
        prepend_context += '\n'.join(prepend_blocks)
        return prepend_context + '\n' + query

    def denoise_contexts(self, contexts: List[str], query: str, ratio: float = 0.9) -> List[Dict]:
        """Denoise a list of code context strings at block-level using top-k aggregation.

        Returns:
            One ``{'code': denoised_code}`` dict per input context, in order.
        """
        return [
            {'code': self.denoise_code_by_block(candidate, query, ratio)}
            for candidate in contexts
        ]