From eb0245702d9d37024a4307bea4ca85e2fac716f1 Mon Sep 17 00:00:00 2001 From: NotYuSheng Date: Wed, 20 May 2026 08:09:07 +0800 Subject: [PATCH] fix: overhaul conversation extraction and add optional LLM validation - Fixes incorrect time gap reference (was using wrong message index) - Collects all consecutive messages from same sender as one turn (no more dropped messages) - Splits chats into distinct conversations using a configurable silence gap threshold - Groups full back-and-forth into multi-turn samples instead of isolated pairs - Updates convert_to_sharegpt.py for the new multi-turn format - Adds validator.py: optional LLM pass that scores each sample for coherence and quality, controlled via DIALOGSMITH_LLM_VALIDATE env var - Adds .env.example documenting all new configuration options Closes #1, #2, #3, #4, #5, #6 Co-Authored-By: Claude Sonnet 4.6 --- .env.example | 12 +++ scripts/convert_to_sharegpt.py | 33 ++++-- scripts/telegram_extract.py | 192 +++++++++++++++++++++++---------- scripts/validator.py | 143 ++++++++++++++++++++++++ 4 files changed, 310 insertions(+), 70 deletions(-) create mode 100644 .env.example create mode 100644 scripts/validator.py diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..1179511 --- /dev/null +++ b/.env.example @@ -0,0 +1,12 @@ +# ── LLM Validation ──────────────────────────────────────────────────────────── +# Validates extracted conversation samples for coherence and quality before +# writing the dataset. Enabled by default when ANTHROPIC_API_KEY is set. +# Set to false to skip validation entirely (faster, no API calls). +DIALOGSMITH_LLM_VALIDATE=true + +# Model used for validation scoring (defaults to claude-haiku-4-5-20251001). +# A fast, cheap model is recommended here — the validator runs once per sample. +DIALOGSMITH_LLM_MODEL=claude-haiku-4-5-20251001 + +# Your Anthropic API key. Required when DIALOGSMITH_LLM_VALIDATE=true. 
+ANTHROPIC_API_KEY=your_api_key_here diff --git a/scripts/convert_to_sharegpt.py b/scripts/convert_to_sharegpt.py index fe7e684..9cc3963 100644 --- a/scripts/convert_to_sharegpt.py +++ b/scripts/convert_to_sharegpt.py @@ -3,25 +3,36 @@ input_path = "./data/chat_dataset.jsonl" output_path = "./data/chat_sharegpt.json" +ROLE_MAP = { + "user": "human", + "assistant": "gpt", +} + output_data = [] with open(input_path, "r", encoding="utf-8") as infile: for line in infile: sample = json.loads(line) - prompt = sample.get("prompt", "").strip() - response = sample.get("response", "").strip() + turns = sample.get("conversations", []) + + if not turns: + continue + + conversations = [] + for turn in turns: + role = ROLE_MAP.get(turn.get("role", ""), turn.get("role", "")) + text = turn.get("text", "").strip() + if role and text: + conversations.append({"from": role, "value": text}) - if not prompt or not response: - continue # Skip blank entries + # Must have at least one human and one gpt turn + roles_present = {t["from"] for t in conversations} + if "human" not in roles_present or "gpt" not in roles_present: + continue - output_data.append({ - "conversations": [ - {"from": "user", "value": prompt}, - {"from": "assistant", "value": response} - ] - }) + output_data.append({"conversations": conversations}) with open(output_path, "w", encoding="utf-8") as outfile: json.dump(output_data, outfile, indent=2, ensure_ascii=False) -print(f"Converted {len(output_data)} valid samples to ShareGPT format.") +print(f"Converted {len(output_data)} valid conversation samples to ShareGPT format.") diff --git a/scripts/telegram_extract.py b/scripts/telegram_extract.py index 7ec0308..f685d0e 100644 --- a/scripts/telegram_extract.py +++ b/scripts/telegram_extract.py @@ -1,9 +1,11 @@ import json # CONFIGURATION -RESULT_PATH = "./data/result.json" # Path to your exported result.json file +RESULT_PATH = "./data/result.json" # Path to your exported result.json file OUTPUT_PATH = 
"./data/chat_dataset.jsonl" -TIME_GAP_THRESHOLD = 30 # Seconds between chained messages +MESSAGE_CHAIN_THRESHOLD = 30 # Max seconds between chained messages from same sender +CONVERSATION_GAP_THRESHOLD = 3600 # Seconds of silence that starts a new conversation + def get_user_name(data): personal_info = data.get("personal_information", {}) @@ -11,82 +13,151 @@ def get_user_name(data): last = personal_info.get("last_name", "") return f"{first} {last}".strip() + def load_all_messages_from_result(data): all_messages = [] chat_list = data.get("chats", {}).get("list", []) - for chat in chat_list: messages = chat.get("messages", []) - if not messages: + if messages: + all_messages.append(messages) + return all_messages # List of message lists (per chat) + + +def get_text(msg): + """Extract plain text from a message, handling both string and entity list formats.""" + text = msg.get("text", "") + if isinstance(text, str): + return text.strip() + if isinstance(text, list): + return "".join( + t["text"] if isinstance(t, dict) else t + for t in msg.get("text_entities", text) + ).strip() + return "" + + +def is_valid_message(msg): + return msg.get("type") == "message" and bool(get_text(msg)) + + +def collect_turn(messages, start_idx, sender, chain_threshold): + """ + Collect all consecutive messages from `sender` starting at `start_idx`, + chaining them if within `chain_threshold` seconds of the previous message + in the chain. 
+ + Returns (texts: list[str], next_idx: int, last_unixtime: int) + """ + texts = [] + last_unixtime = None + j = start_idx + + while j < len(messages): + msg = messages[j] + + if not is_valid_message(msg): + j += 1 continue - all_messages.append(messages) - return all_messages # List of message lists (per chat) + if msg.get("from") != sender: + break + + unixtime = int(msg["date_unixtime"]) + + if last_unixtime is not None and (unixtime - last_unixtime) > chain_threshold: + break + + texts.append(get_text(msg)) + last_unixtime = unixtime + j += 1 + + return texts, j, last_unixtime + + +def split_into_conversations(messages, gap_threshold): + """ + Split a flat message list into sub-lists representing distinct conversations, + based on silence gaps between valid messages. + """ + conversations = [] + current = [] + last_unixtime = None + + for msg in messages: + if not is_valid_message(msg): + continue + + unixtime = int(msg["date_unixtime"]) + + if last_unixtime is not None and (unixtime - last_unixtime) > gap_threshold: + if current: + conversations.append(current) + current = [] + + current.append(msg) + last_unixtime = unixtime + + if current: + conversations.append(current) + + return conversations + def format_conversations(message_groups, your_name): + """ + For each chat, split messages into conversations, then walk each conversation + collecting alternating turns into multi-turn samples. + + Each sample is a list of {"role": ..., "text": ...} dicts. + Consecutive messages from the same sender are concatenated into one turn. 
+ """ samples = [] for messages in message_groups: - i = 0 - while i < len(messages) - 1: - msg = messages[i] - - if msg.get("type") != "message" or not msg.get("text"): - i += 1 - continue - - sender = msg.get("from") - if sender == your_name: - i += 1 - continue # Only process messages from others as prompts - - # Format the prompt - if isinstance(msg["text"], str): - prompt = f"{sender}: {msg['text']}" - else: - prompt = f"{sender}: {''.join([t['text'] for t in msg.get('text_entities', [])])}" - - # Gather your consecutive responses - response = [] - j = i + 1 - while j < len(messages): - next_msg = messages[j] - if next_msg.get("type") != "message" or not next_msg.get("text"): - j += 1 - continue - - if next_msg.get("from") != your_name: - break - - time_diff = int(next_msg["date_unixtime"]) - int(messages[j - 1]["date_unixtime"]) - if time_diff > TIME_GAP_THRESHOLD: - break - - text = next_msg["text"] - if isinstance(text, str): - response.append(text) - elif isinstance(text, list): - response.append(''.join([t["text"] for t in next_msg.get("text_entities", [])])) - - j += 1 - - if response: - samples.append({ - "prompt": prompt.strip(), - "response": "\n".join(response).strip() - }) - - i = j # Move pointer forward + conversations = split_into_conversations(messages, CONVERSATION_GAP_THRESHOLD) + + for conversation in conversations: + turns = [] + i = 0 + + while i < len(conversation): + msg = conversation[i] + sender = msg.get("from") + role = "assistant" if sender == your_name else "user" + + texts, next_i, _ = collect_turn( + conversation, i, sender, MESSAGE_CHAIN_THRESHOLD + ) + + if texts: + turn_text = "\n".join(texts) + # Merge with previous turn if same role (edge case: gap exceeded mid-block) + if turns and turns[-1]["role"] == role: + turns[-1]["text"] += "\n" + turn_text + else: + turns.append({"role": role, "text": turn_text}) + + i = next_i + + # Only keep conversations that have at least one user + one assistant turn + roles = [t["role"] for t 
in turns] + if "user" in roles and "assistant" in roles: + samples.append(turns) return samples + def save_dataset(samples, out_path): with open(out_path, "w", encoding="utf-8") as f: for sample in samples: - json.dump(sample, f, ensure_ascii=False) + json.dump({"conversations": sample}, f, ensure_ascii=False) f.write("\n") + if __name__ == "__main__": + from validator import validate_samples + print(f"Loading {RESULT_PATH}...") with open(RESULT_PATH, encoding="utf-8") as f: data = json.load(f) @@ -96,9 +167,12 @@ def save_dataset(samples, out_path): message_groups = load_all_messages_from_result(data) - print("Formatting prompt-response pairs...") + print("Formatting multi-turn conversations...") samples = format_conversations(message_groups, your_name) + print(f"Extracted {len(samples)} conversation samples.") + + samples = validate_samples(samples) - print(f"Saving {len(samples)} samples to {OUTPUT_PATH}...") + print(f"Saving {len(samples)} conversation samples to {OUTPUT_PATH}...") save_dataset(samples, OUTPUT_PATH) print("Done.") diff --git a/scripts/validator.py b/scripts/validator.py new file mode 100644 index 0000000..b05a659 --- /dev/null +++ b/scripts/validator.py @@ -0,0 +1,143 @@ +""" +Optional LLM-based conversation quality validator. + +Controlled via environment variables: + DIALOGSMITH_LLM_VALIDATE=true/false (default: true if ANTHROPIC_API_KEY is set) + DIALOGSMITH_LLM_MODEL=... (default: claude-haiku-4-5-20251001) + ANTHROPIC_API_KEY=... + +Each conversation sample is scored on two axes: + - coherence: does this read as a natural, continuous conversation? + - quality: is this a meaningful exchange worth training on? + +Samples that fail either check are excluded from the output. +A summary of filtered samples is printed so the user can audit decisions. 
+""" + +import json +import os + +VALIDATE_ENV = "DIALOGSMITH_LLM_VALIDATE" +MODEL_ENV = "DIALOGSMITH_LLM_MODEL" +DEFAULT_MODEL = "claude-haiku-4-5-20251001" + +COHERENCE_THRESHOLD = 0.5 # 0–1, below this the conversation is considered incoherent +QUALITY_THRESHOLD = 0.5 # 0–1, below this the sample is considered low-quality + + +def _should_validate(): + val = os.environ.get(VALIDATE_ENV, "").strip().lower() + if val == "false": + return False + if val == "true": + return True + # Default: enable if API key is present + return bool(os.environ.get("ANTHROPIC_API_KEY", "").strip()) + + +def _get_client(): + try: + import anthropic + except ImportError: + raise ImportError( + "The 'anthropic' package is required for LLM validation. " + "Install it with: pip install anthropic" + ) + api_key = os.environ.get("ANTHROPIC_API_KEY", "").strip() + if not api_key: + raise EnvironmentError( + "ANTHROPIC_API_KEY is not set. " + f"Set {VALIDATE_ENV}=false to disable validation." + ) + return anthropic.Anthropic(api_key=api_key) + + +def _format_conversation(turns): + lines = [] + for turn in turns: + role = turn.get("role", "unknown").upper() + text = turn.get("text", "").strip() + lines.append(f"{role}: {text}") + return "\n".join(lines) + + +def _score_sample(client, model, turns): + """ + Ask the LLM to score a conversation sample. + Returns (coherence: float, quality: float, reason: str). + """ + conversation_text = _format_conversation(turns) + + prompt = f"""You are evaluating a conversation sample for use in fine-tuning a language model. + +Rate the following conversation on two dimensions, each from 0.0 to 1.0: + +1. coherence: Does this read as a natural, continuous conversation where each message follows logically from the previous? (0 = completely disjointed, 1 = perfectly coherent) +2. quality: Is this a meaningful, substantive exchange worth training on? Penalise one-word replies, pure greetings, or exchanges with no informational content. 
(0 = worthless, 1 = highly valuable) + +Respond with ONLY a JSON object in this exact format: +{{"coherence": number, "quality": number, "reason": "one short sentence"}} + +Conversation: +{conversation_text}""" + + response = client.messages.create( + model=model, + max_tokens=128, + messages=[{"role": "user", "content": prompt}], + ) + + raw = response.content[0].text.strip() + result = json.loads(raw) + return float(result["coherence"]), float(result["quality"]), result.get("reason", "") + + +def validate_samples(samples): + """ + Validate a list of conversation samples. + + Each sample is a list of {"role": ..., "text": ...} dicts (as produced by telegram_extract.py). + + Returns filtered list of samples that pass validation. + If validation is disabled or unavailable, returns all samples unchanged. + """ + if not _should_validate(): + print("[validator] LLM validation disabled — skipping.") + return samples + + try: + client = _get_client() + except (ImportError, EnvironmentError) as e: + print(f"[validator] WARNING: {e}") + print("[validator] Skipping LLM validation and returning all samples.") + return samples + + model = os.environ.get(MODEL_ENV, DEFAULT_MODEL).strip() + print(f"[validator] Running LLM validation with model: {model}") + + passed = [] + filtered = [] + + for i, turns in enumerate(samples): + try: + coherence, quality, reason = _score_sample(client, model, turns) + except Exception as e: + print(f"[validator] Sample {i}: scoring failed ({e}), keeping sample.") + passed.append(turns) + continue + + if coherence < COHERENCE_THRESHOLD: + filtered.append((i, "incoherent", coherence, quality, reason)) + elif quality < QUALITY_THRESHOLD: + filtered.append((i, "low-quality", coherence, quality, reason)) + else: + passed.append(turns) + + print(f"[validator] {len(passed)} passed, {len(filtered)} filtered out of {len(samples)} total.") + + if filtered: + print("[validator] Filtered samples:") + for idx, reason_type, coh, qual, reason in filtered: + print(f" sample {idx:4d} | 
{reason_type:12s} | coherence={coh:.2f} quality={qual:.2f} | {reason}") + + return passed