6 changes: 3 additions & 3 deletions F2LLM/configs/config.json
@@ -1,7 +1,7 @@
 {
-    "model_path": "models/qwen3-4b",
-    "experiment_id": "4b+lr.8e-6+bs.16x32+context.1024+2epochs",
-    "train_data_path": "training_data/data_tokenized_qwen",
+    "model_path": "models/bert",
+    "experiment_id": "bert+lr.8e-6+bs.16x32+context.1024+1epochs",
+    "train_data_path": "data_tokenized_bert",
     "output_dir": "output",
     "tb_dir": "output/tb",
     "cache_dir": "cache",
72 changes: 72 additions & 0 deletions F2LLM/model_bert.py
@@ -0,0 +1,72 @@

import torch
from transformers import AutoModel, AutoTokenizer

class BertEmbedder:
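    """Bi-encoder wrapper around a BERT-style encoder: queries, positive passages,
    and hard negatives are encoded in one batch and pooled (CLS or mean) into
    fixed-size embeddings for contrastive training."""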
def __init__(self,
model_path,
max_seq_length=512,
args=None,
pool_strategy="cls"
):
self.args = args
self.dtype = torch.bfloat16
self.device = None
self.encoder = AutoModel.from_pretrained(
model_path,
trust_remote_code=True,
torch_dtype=self.dtype
)
if hasattr(self.encoder.config, "use_cache"):
self.encoder.config.use_cache = False

self.tokenizer = AutoTokenizer.from_pretrained(model_path)

if self.tokenizer.pad_token is None:
if self.tokenizer.eos_token is not None:
self.tokenizer.pad_token = self.tokenizer.eos_token
elif self.tokenizer.sep_token is not None:
self.tokenizer.pad_token = self.tokenizer.sep_token
elif self.tokenizer.cls_token is not None:
self.tokenizer.pad_token = self.tokenizer.cls_token

self.max_seq_length = max_seq_length
self.pool_strategy = pool_strategy

def set_device(self):
self.device = self.encoder.device

def _pool(self, last_hidden_state, attention_mask):
        # last_hidden_state: [bs, seq, d], attention_mask: [bs, seq]
if self.pool_strategy == "mean":
mask = attention_mask.unsqueeze(-1).type_as(last_hidden_state) # [bs, seq, 1]
summed = (last_hidden_state * mask).sum(dim=1) # [bs, d]
denom = mask.sum(dim=1).clamp_min(1e-6) # [bs, 1]
return summed / denom
return last_hidden_state[:, 0, :] # [CLS]

def forward(self, batch):
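        # batch['input_ids'] stacks all sequences along dim 0:
        # bs queries, then bs positive passages, then bs*num_hard_neg hard negatives.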
bs = batch['bs']
num_hard_neg = int((len(batch['input_ids']) - 2 * bs) / bs)

outputs = self.encoder(
input_ids=batch['input_ids'],
attention_mask=batch['attention_mask']
)
last_hidden_state = outputs.last_hidden_state
attn = batch['attention_mask']

query_emb = self._pool(last_hidden_state[:bs], attn[:bs]).unsqueeze(1) # [bs, 1, d]
passage_emb = self._pool(last_hidden_state[bs:2*bs], attn[bs:2*bs]).unsqueeze(1) # [bs, 1, d]

if num_hard_neg == 0:
neg_emb = None
else:
            neg_all = self._pool(last_hidden_state[2*bs:], attn[2*bs:])  # [bs*num_hard_neg, d]
            neg_emb = neg_all.view(bs, num_hard_neg, -1)  # [bs, num_hard_neg, d]

return {
'query_passage_features': query_emb,
'passage_passage_features': passage_emb,
'negative_passage_features': neg_emb
}
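
A minimal usage sketch of the embedder above, assuming the batch layout its forward method implies (queries first, then their positive passages, stacked along the batch dimension, with no hard negatives here); the example strings and the cosine-similarity scoring are illustrative only, not the repo's training loop.

import torch
from model_bert import BertEmbedder

queries = ["what is contrastive learning?", "how are sentence embeddings trained?"]
passages = ["Contrastive learning pulls matching pairs together in embedding space.",
            "Sentence embeddings are trained by encoding text into fixed-size vectors."]

embedder = BertEmbedder("models/bert", pool_strategy="cls")
embedder.set_device()

# Tokenize queries and passages together so they share one forward pass.
enc = embedder.tokenizer(queries + passages, padding=True, truncation=True,
                         max_length=embedder.max_seq_length, return_tensors="pt")
batch = {
    "bs": len(queries),
    "input_ids": enc["input_ids"].to(embedder.device),
    "attention_mask": enc["attention_mask"].to(embedder.device),
}

out = embedder.forward(batch)
q = out["query_passage_features"].squeeze(1)    # [bs, d]
p = out["passage_passage_features"].squeeze(1)  # [bs, d]
print(torch.nn.functional.cosine_similarity(q, p, dim=-1))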
57 changes: 57 additions & 0 deletions F2LLM/tokenize_data_bert.py
@@ -0,0 +1,57 @@
from multiprocessing import Pool
import numpy as np
import pandas as pd
import os
from transformers import AutoTokenizer
from tqdm.auto import tqdm


# Tokenize with the same BERT checkpoint that configs/config.json points to,
# and write where its train_data_path expects the data.
tokenizer = AutoTokenizer.from_pretrained('models/bert')
max_seq_length = 512  # standard BERT position limit

output_dir = 'data_tokenized_bert'
os.makedirs(output_dir, exist_ok=True)


def process_sent(sentence):
    # Keep the tokenizer's special tokens ([CLS]/[SEP]) since the BERT embedder
    # pools the [CLS] position by default; BERT tokenizers typically have no EOS token to append.
    tokenizer_outputs = tokenizer(sentence, max_length=max_seq_length, truncation=True)
    return np.array(tokenizer_outputs.input_ids)


def process_sent_batch(s):
return s.apply(process_sent)

def parallelize(data, func, num_of_processes=8):
indices = np.array_split(data.index, num_of_processes)
data_split = [data.iloc[idx] for idx in indices]
with Pool(num_of_processes) as pool:
data = pd.concat(pool.map(func, data_split))
return data


root_dir = 'datasets'
# Tokenize every parquet dataset found under root_dir.
parquet_files = [f for f in os.listdir(root_dir) if f.endswith('.parquet')]

for ds_name in tqdm(sorted(parquet_files)):
print(ds_name, flush=True)

df = pd.read_parquet(f"{root_dir}/{ds_name}")
df['query_input_ids'] = parallelize(df['query'], process_sent_batch, 62)

num_neg = 24 if 'negative_2' in df.keys() else 1

ls = df.passage.to_list()
for i in range(1, num_neg+1):
ls += df[f'negative_{i}'].to_list()
ls = list(set(ls))
df_tmp = pd.DataFrame({'text': ls})
df_tmp['input_ids'] = parallelize(df_tmp['text'], process_sent_batch, 62)
df_tmp = df_tmp.set_index('text')

df['passage_input_ids'] = df.passage.map(df_tmp.input_ids)

for i in range(1, num_neg+1):
df[f'negative_{i}_input_ids'] = df[f'negative_{i}'].map(df_tmp.input_ids)

df.to_parquet(f'{output_dir}/{ds_name}', index=False)
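
Finally, a hedged sketch of how the tokenized parquet files written above might be padded into the batch dict that BertEmbedder.forward expects; the PR does not include the actual dataloader/collator, so collate_pairs and the example dataset name are hypothetical.

import torch

def collate_pairs(rows, pad_id):
    # rows: records with 'query_input_ids' and 'passage_input_ids' arrays of token ids.
    # Queries are stacked first, then their positive passages, padded to the batch max length.
    seqs = [r["query_input_ids"] for r in rows] + [r["passage_input_ids"] for r in rows]
    max_len = max(len(s) for s in seqs)
    input_ids = torch.full((len(seqs), max_len), pad_id, dtype=torch.long)
    attention_mask = torch.zeros((len(seqs), max_len), dtype=torch.long)
    for i, s in enumerate(seqs):
        input_ids[i, :len(s)] = torch.tensor(s, dtype=torch.long)
        attention_mask[i, :len(s)] = 1
    return {"bs": len(rows), "input_ids": input_ids, "attention_mask": attention_mask}

# Example usage (hypothetical dataset name):
# df = pd.read_parquet(f"{output_dir}/example_dataset.parquet")
# batch = collate_pairs(df.head(4).to_dict("records"), pad_id=tokenizer.pad_token_id)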