-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path: d2code_pipeline.py
More file actions
122 lines (102 loc) · 5.01 KB
/
d2code_pipeline.py
File metadata and controls
122 lines (102 loc) · 5.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import torch
import json
from tqdm import tqdm
from typing import List, Dict, Any, Optional
from transformers import AutoTokenizer, AutoModelForCausalLM
from .utils import load_model, load_jsonl, dump_jsonl
from .diversify import DiversifySelector
from .denoise import Denoiser
class D2CodePipeline:
    """D²-Code main pipeline: integrate diversification and denoising.

    Phase 1 re-orders retrieved contexts with ``DiversifySelector``; phase 2
    runs ``Denoiser`` over the first ``topk`` diversified contexts and builds
    the final prompt from the denoised texts.
    """

    def __init__(self, model_name_or_path: str, device: Optional[str] = None,
                 model_root: Optional[str] = "/path/to/llms"):
        """Initialize the pipeline and load the model/tokenizer.

        Args:
            model_name_or_path: Model directory name, resolved as
                ``f"{model_root}/{model_name_or_path}"``. Used verbatim when
                ``model_root`` is None.
            device: Target device. Defaults to "cuda" when available, else "cpu".
            model_root: Directory prefix for ``model_name_or_path``. The
                default keeps the previously hard-coded location.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        if model_root is None:
            model_path = model_name_or_path
        else:
            model_path = f"{model_root}/{model_name_or_path}"

        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        if self.tokenizer.pad_token is None:
            # No pad token defined: fall back to the EOS token for padding.
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Half precision is poorly supported on CPU, so only use float16 on
        # accelerator devices and keep float32 on the CPU fallback.
        dtype = torch.float32 if str(self.device).startswith("cpu") else torch.float16
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=dtype,
            # Enables attention outputs in the model config; presumably
            # consumed by Denoiser — TODO confirm.
            output_attentions=True,
        ).to(self.device)

        self.diversify_selector = DiversifySelector()
        self.denoiser = Denoiser(self.model, self.tokenizer)

    @staticmethod
    def _context_text(ctx: Any) -> str:
        """Return the raw code text carried by one retrieved-context entry.

        Dict entries yield their ``'context'`` value, else their ``'code'``
        value. Anything else (including plain strings) is converted with ``str``.
        """
        if isinstance(ctx, dict):
            if 'context' in ctx:
                return ctx['context']
            if 'code' in ctx:
                return ctx['code']
        return str(ctx)

    def process_contexts(self, contexts: List[Dict], query: str,
                         diversify_config: Optional[Dict[str, Any]] = None,
                         denoise_config: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Diversify contexts first, then denoise a top-k subset.

        Args:
            contexts: Retrieved contexts, passed unchanged to the selector.
            query: The query/prompt the contexts were retrieved for.
            diversify_config: Optional keys ``'alpha'`` and ``'cluster_size'``.
                Each supplied key overrides the selector attribute of the same
                name; missing keys leave the selector's current value in place.
            denoise_config: Optional keys ``'topk'`` and ``'ratio'``. A missing
                ``'topk'`` denoises every diversified context. A missing
                ``'ratio'`` is not passed, so ``Denoiser.denoise_contexts``
                must supply its own default — TODO confirm it has one.

        Returns:
            Dict with the original, diversified, top-k and denoised contexts,
            the final prompt, and the query.
        """
        diversify_config = diversify_config or {}
        denoise_config = denoise_config or {}

        # Phase 1: diversify (penalty-based reward)
        print("Phase 1: Diversifying contexts using penalty-based reward function...")
        for key in ('alpha', 'cluster_size'):
            if key in diversify_config:
                setattr(self.diversify_selector, key, diversify_config[key])
        diversified_contexts = self.diversify_selector.context_selection(contexts)

        # Phase 2: denoise only top-k diversified contexts
        topk = denoise_config.get('topk')
        label = topk if topk is not None else "all"
        print(f"Phase 2: Denoising top-{label} contexts using block-level top-k aggregation...")
        topk_contexts = diversified_contexts[:topk]  # a None bound keeps the full list
        context_texts = [self._context_text(ctx) for ctx in topk_contexts]

        denoise_kwargs = {}
        if 'ratio' in denoise_config:
            denoise_kwargs['ratio'] = denoise_config['ratio']
        denoised_contexts = self.denoiser.denoise_contexts(context_texts, query, **denoise_kwargs)
        final_prompt = self.denoiser.build_prompt(query, denoised_contexts)

        return {
            'original_contexts': contexts,
            'diversified_contexts': diversified_contexts,
            'topk_contexts': topk_contexts,
            'denoised_contexts': denoised_contexts,
            'final_prompt': final_prompt,
            'query': query,
        }

    def process_batch(self, data_list: List[Dict],
                      diversify_config: Optional[Dict[str, Any]] = None,
                      denoise_config: Optional[Dict[str, Any]] = None,
                      output_file: Optional[str] = None) -> List[Dict]:
        """Process a batch of items, skipping ones that are incomplete or fail.

        Each item supplies contexts via ``'contexts'`` (else ``'top_k_context'``)
        and a query via ``'query'`` (else ``'prompt'``). Results are written
        to ``output_file`` as JSONL when it is given.
        """
        results = []
        for item in tqdm(data_list, desc="Processing batch"):
            try:
                contexts = item.get('contexts', item.get('top_k_context', []))
                query = item.get('query', item.get('prompt', ''))
                if not contexts or not query:
                    print(f"Skipping item due to missing data: {item}")
                    continue
                # Treat an explicit ``None`` metadata the same as a missing one.
                metadata = item.get('metadata') or {}
                result = self.process_contexts(contexts, query, diversify_config, denoise_config)
                result['metadata'] = metadata
                result['task_id'] = item.get('task_id', metadata.get('task_id', ''))
                results.append(result)
            except Exception as e:
                # Best-effort batch: report which item failed and keep going.
                item_id = item.get('task_id', '') if isinstance(item, dict) else ''
                print(f"Error processing item {item_id!r}: {type(e).__name__}: {e}")

        if output_file:
            dump_jsonl(results, output_file)
            print(f"Results saved to: {output_file}")
        return results

    def process_from_file(self, input_file: str, output_file: str,
                          diversify_config: Optional[Dict[str, Any]] = None,
                          denoise_config: Optional[Dict[str, Any]] = None) -> List[Dict]:
        """Load items from a JSONL file, process them, and save results when
        ``output_file`` is truthy."""
        print(f"Loading data from: {input_file}")
        data_list = load_jsonl(input_file)
        print(f"Processing {len(data_list)} items...")
        return self.process_batch(data_list, diversify_config, denoise_config, output_file)
def main():
    """Example entrypoint: run the pipeline once over placeholder paths.

    No diversify/denoise configuration is passed to ``process_from_file``.
    Any exception raised while processing the file is printed instead of
    propagated.
    """
    pipeline = D2CodePipeline('')  # placeholder: fill in a real model name
    source_path = "/path/to/initial-retrieved-context"
    target_path = "/path/to/optimized-context"

    print("Processing with default settings...")
    try:
        pipeline.process_from_file(source_path, target_path)
    except Exception as err:
        print(f"Error processing: {err}")
if __name__ == "__main__":
    # Run the example entrypoint only when executed as a script, not on import.
    main()