diff --git a/pyproject.toml b/pyproject.toml index 07a8fd06..62927639 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,7 @@ dependencies = [ "addict>=2.4.0", "deepdiff>=6.0.0", "pandas>=1.5.0", + "fireworks-ai>=0.19.12", ] [project.urls] diff --git a/tests/pytest/data/halueval_sample_dataset.jsonl b/tests/pytest/data/halueval_sample_dataset.jsonl new file mode 100644 index 00000000..3411f0bd --- /dev/null +++ b/tests/pytest/data/halueval_sample_dataset.jsonl @@ -0,0 +1,3 @@ +{"knowledge": "Arthur's Magazine (1844–1846) was an American literary periodical published in Philadelphia in the 19th century.First for Women is a woman's magazine published by Bauer Media Group in the USA.", "question": "Which magazine was started first Arthur's Magazine or First for Women?", "right_answer": "Arthur's Magazine", "hallucinated_answer": "First for Women was started first."} +{"knowledge": "The Oberoi family is an Indian family that is famous for its involvement in hotels, namely through The Oberoi Group.The Oberoi Group is a hotel company with its head office in Delhi.", "question": "The Oberoi family is part of a hotel company that has a head office in what city?", "right_answer": "Delhi", "hallucinated_answer": "The Oberoi family's hotel company is based in Mumbai."} +{"knowledge": "Allison Beth \"Allie\" Goertz (born March 2, 1991) is an American musician. Goertz is known for her satirical songs based on various pop culture topics. 
Her videos are posted on YouTube under the name of Cossbysweater.Milhouse Mussolini van Houten is a fictional character featured in the animated television series \"The Simpsons\", voiced by Pamela Hayden, and created by Matt Groening who named the character after President Richard Nixon's middle name.", "question": "Musician and satirist Allie Goertz wrote a song about the \"The Simpsons\" character Milhouse, who Matt Groening named after who?", "right_answer": "President Richard Nixon", "hallucinated_answer": "Allie Goertz wrote a song about Milhouse, a popular TV character, named after an influential political figure."} diff --git a/tests/pytest/test_hallucination.py b/tests/pytest/test_hallucination.py new file mode 100644 index 00000000..23179a25 --- /dev/null +++ b/tests/pytest/test_hallucination.py @@ -0,0 +1,118 @@ +""" +Hallucination detection test using LLM-as-judge. + +This test demonstrates how to detect factual inaccuracies in model responses +by comparing them against provided knowledge using an LLM judge, similar to +tau's evaluate_nl_assertions approach. 
+""" + +import json +from typing import Any, Dict, List + +from fireworks import LLM + +from eval_protocol.models import EvaluateResult, EvaluationRow, Message, MetricResult +from eval_protocol.pytest import default_single_turn_rollout_processor, evaluation_test + +judge_llm = LLM(model="accounts/fireworks/models/kimi-k2-instruct", deployment_type="serverless") + +def hallucination_dataset_adapter(data: List[Dict[str, Any]]) -> List[EvaluationRow]: + """Convert HaluEval dataset to EvaluationRow objects.""" + return [ + EvaluationRow( + messages=[Message(role="user", content=f"Knowledge: {item['knowledge']}\n\nQuestion: {item['question']}")], + ground_truth=item["right_answer"] + ) + for item in data + ] + + +@evaluation_test( + input_dataset=["tests/pytest/data/halueval_sample_dataset.jsonl"], + dataset_adapter=hallucination_dataset_adapter, + model=["accounts/fireworks/models/kimi-k2-instruct"], + rollout_input_params=[{"temperature": 0.0, "max_tokens": 512}], + rollout_processor=default_single_turn_rollout_processor, + threshold_of_success=1.0, + num_runs=1, + mode="pointwise", +) +def test_hallucination_detection(row: EvaluationRow) -> EvaluationRow: + """ + Test for response correctness using LLM-as-judge. + """ + messages = row.messages + assistant_response = messages[-1].content + + if not assistant_response: + return EvaluateResult(score=0.0, reason="❌ No assistant response found") + + correct_answer = row.ground_truth + + system_prompt = """ + TASK + - You will be given an assistant's response and the correct answer. + - Your job is to evaluate whether the assistant's response is factually consistent with the correct answer. + - Grade whether the assistant got it right or wrong. 
+ +    FORMAT +    - Respond with ONLY a JSON object (no markdown fences or extra text) with the following fields: +        - `reasoning`: a short explanation for your classification +        - `is_correct`: `true` if the assistant's response matches the correct answer, `false` otherwise + +    Example response structure: +    { +        "reasoning": "", +        "is_correct": true +    } +    """ + +    user_prompt = f""" +    assistant_response: +    {assistant_response} + +    correct_answer: +    {correct_answer} +    """ + +    try: +        response = judge_llm.chat.completions.create( +            messages=[ +                {"role": "system", "content": system_prompt}, +                {"role": "user", "content": user_prompt} +            ], +            temperature=0.1, +            max_tokens=500, +        ) + +        result_data = json.loads(response.choices[0].message.content) +        is_correct = result_data.get("is_correct", False) +        reasoning = result_data.get("reasoning", "Could not parse reasoning") + +    except Exception as e: +        # Fallback if parsing fails +        is_correct = False +        reasoning = f"Evaluation failed: {str(e)}" + +    score = 1.0 if is_correct else 0.0 + +    if is_correct: +        assessment = "✅ Response is correct" +    else: +        assessment = "❌ Response is incorrect" + +    reason = f"{assessment}\nReasoning: {reasoning}" + +    row.evaluation_result = EvaluateResult( +        score=score, +        reason=reason, +        metrics={ +            "llm_judge": MetricResult( +                score=score, +                reason=reasoning, +                is_score_valid=True +            ) +        } +    ) + +    return row \ No newline at end of file diff --git a/uv.lock b/uv.lock index 7738a5c3..99d89084 100644 --- a/uv.lock +++ b/uv.lock @@ -903,7 +903,7 @@ version = "3.22.5" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "attrs" }, - { name = "docstring-parser", marker = "python_full_version < '4.0'" }, + { name = "docstring-parser", marker = "python_full_version < '4'" }, { name = "rich" }, { name = "rich-rst" }, { name = "typing-extensions", marker = "python_full_version < '3.11'" }, @@ -1130,6 +1130,7 @@ dependencies = [ { name = "deepdiff" }, { name = "docstring-parser" }, { name = "fastapi" }, + { name = "fireworks-ai"
}, { name = "fsspec" }, { name = "gymnasium" }, { name = "httpx" }, @@ -1239,6 +1240,7 @@ requires-dist = [ { name = "docstring-parser", specifier = ">=0.15" }, { name = "e2b", marker = "extra == 'dev'" }, { name = "fastapi", specifier = ">=0.68.0" }, + { name = "fireworks-ai", specifier = ">=0.19.12" }, { name = "fireworks-ai", marker = "extra == 'fireworks'", specifier = ">=0.19.10" }, { name = "flake8", marker = "extra == 'dev'", specifier = ">=3.9.2" }, { name = "fsspec" },