diff --git a/.cursor/environment.json b/.cursor/environment.json new file mode 100644 index 0000000..cead2c2 --- /dev/null +++ b/.cursor/environment.json @@ -0,0 +1,4 @@ +{ + "install": "bash .cursor/scripts/install-python-dev-tools.sh", + "start": "bash .cursor/scripts/startup-path.sh" +} diff --git a/.cursor/scripts/install-python-dev-tools.sh b/.cursor/scripts/install-python-dev-tools.sh new file mode 100755 index 0000000..2a5b4a5 --- /dev/null +++ b/.cursor/scripts/install-python-dev-tools.sh @@ -0,0 +1,13 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Install Python dev dependencies into user site-packages. +python3 -m pip install --user -r requirements-dev.txt + +# Expose user-installed tool entrypoints from a stable HOME/bin directory. +mkdir -p "${HOME}/bin" +for tool in pytest ruff mypy; do + if [ -x "${HOME}/.local/bin/${tool}" ]; then + ln -sf "${HOME}/.local/bin/${tool}" "${HOME}/bin/${tool}" + fi +done diff --git a/.cursor/scripts/startup-path.sh b/.cursor/scripts/startup-path.sh new file mode 100755 index 0000000..9548b3d --- /dev/null +++ b/.cursor/scripts/startup-path.sh @@ -0,0 +1,13 @@ +#!/usr/bin/env bash +set -euo pipefail + +path_export='export PATH="$HOME/.local/bin:$HOME/bin:$PATH"' + +for profile in "${HOME}/.bashrc" "${HOME}/.profile"; do + touch "${profile}" + if ! grep -Fq "${path_export}" "${profile}"; then + printf '\n%s\n' "${path_export}" >> "${profile}" + fi +done + +export PATH="${HOME}/.local/bin:${HOME}/bin:${PATH}" diff --git a/src/dgf_common/code_utils.py b/src/dgf_common/code_utils.py index d9bbc95..b4dca4b 100644 --- a/src/dgf_common/code_utils.py +++ b/src/dgf_common/code_utils.py @@ -1,6 +1,7 @@ import re -_FENCED_CODE_PATTERN = re.compile(r"```(?:c|C|cpp|c\+\+)?\s*(.*?)```", re.DOTALL) +_FENCED_C_CPP_PATTERN = re.compile(r"```(?:c|C|cpp|c\+\+)\s*(.*?)```", re.DOTALL) +_FENCED_UNTAGGED_PATTERN = re.compile(r"```[ \t]*\n(.*?)```", re.DOTALL) def extract_c_code_block(raw_text): @@ -11,7 +12,12 @@ def extract_c_code_block(raw_text): if raw_text is None: return "" - match = _FENCED_CODE_PATTERN.search(raw_text) + match = _FENCED_C_CPP_PATTERN.search(raw_text) + if match: + return match.group(1).strip() + + # Backward-compatible fallback for unlabeled fenced blocks. + match = _FENCED_UNTAGGED_PATTERN.search(raw_text) if match: return match.group(1).strip() return raw_text.strip() diff --git a/src/dgf_feedback/branch_coverage_collector.py b/src/dgf_feedback/branch_coverage_collector.py index 1c1fc8c..2c9265e 100644 --- a/src/dgf_feedback/branch_coverage_collector.py +++ b/src/dgf_feedback/branch_coverage_collector.py @@ -31,6 +31,9 @@ def collect_branch_coverage(self, binary_path, work_dir): except subprocess.CalledProcessError as exc: LOGGER.warning("Failed to merge profile data: %s", exc.stderr) return {}, 0.0 + except OSError as exc: + LOGGER.warning("Failed to execute profile merge command %s: %s", self.profdata, exc) + return {}, 0.0 export_cmd = [ self.cov, "export", @@ -50,6 +53,9 @@ def collect_branch_coverage(self, binary_path, work_dir): except subprocess.CalledProcessError as exc: LOGGER.warning("Failed to export coverage json: %s", exc.stderr) return {}, 0.0 + except OSError as exc: + LOGGER.warning("Failed to execute coverage export command %s: %s", self.cov, exc) + return {}, 0.0 except json.JSONDecodeError: LOGGER.warning("Invalid coverage JSON output from llvm-cov") return {}, 0.0 diff --git a/src/dgf_feedback/feedback_controller.py b/src/dgf_feedback/feedback_controller.py index f97c0e8..b483f9e 100644 --- a/src/dgf_feedback/feedback_controller.py +++ b/src/dgf_feedback/feedback_controller.py @@ -57,7 +57,8 @@ def run_iteration(self, num_samples=5, base_num_funcs=5): for i in range(num_samples): # === 动态选取 APIs === candidate_apis = self.api_manager.sample_api_combination(base_num_funcs) - mutated_apis = self.mutator.mutate(candidate_apis) + parent_apis = history_api_combos[-1] if history_api_combos else None + mutated_apis = self.mutator.mutate(candidate_apis, parents=parent_apis) history_api_combos.append(mutated_apis) # 记录 prompt 使用次数 diff --git a/src/dgf_feedback/prompt_mutator.py b/src/dgf_feedback/prompt_mutator.py index 4f566f0..02c8bcc 100644 --- a/src/dgf_feedback/prompt_mutator.py +++ b/src/dgf_feedback/prompt_mutator.py @@ -32,12 +32,15 @@ def crossover(self, parent_apis1, parent_apis2): return merged def mutate(self, current_apis, parents=None): - mode = random.choice(["insert", "replace", "crossover"]) + modes = ["insert", "replace"] + if parents: + modes.append("crossover") + mode = random.choice(modes) if mode == "insert": return self.insert(current_apis) elif mode == "replace": return self.replace(current_apis) - elif mode == "crossover" and parents is not None: + elif mode == "crossover": return self.crossover(current_apis, parents) else: return current_apis # 保底返回 diff --git a/src/dgf_feedback/test_feedback.py b/src/dgf_feedback/test_feedback.py index 91e02e4..fbf98c1 100644 --- a/src/dgf_feedback/test_feedback.py +++ b/src/dgf_feedback/test_feedback.py @@ -1,3 +1,4 @@ +import dgf_feedback.prompt_mutator as prompt_mutator_module from dgf_feedback.api_manager import APIManager from dgf_feedback.prompt_mutator import PromptMutator from dgf_feedback.sample_filter import SampleFilter @@ -39,3 +40,28 @@ def test_prompt_mutator_insert_replace_crossover(): crossed = mutator.crossover(["A", "B"], ["B", "C"]) assert set(crossed) == {"A", "B", "C"} + + +def test_prompt_mutator_without_parents_does_not_offer_crossover(monkeypatch): + manager = APIManager(["A", "B", "C"]) + mutator = PromptMutator(manager) + seen_modes = [] + + def fake_choice(options): + seen_modes.extend(options) + return "insert" + + monkeypatch.setattr(prompt_mutator_module.random, "choice", fake_choice) + mutated = mutator.mutate(["A"], parents=None) + + assert "crossover" not in seen_modes + assert "A" in mutated + + +def test_prompt_mutator_crossover_with_parent(monkeypatch): + manager = APIManager(["A", "B", "C"]) + mutator = PromptMutator(manager) + + monkeypatch.setattr(prompt_mutator_module.random, "choice", lambda options: "crossover") + mutated = mutator.mutate(["A"], parents=["B", "C"]) + assert set(mutated) == {"A", "B", "C"} diff --git a/src/dgf_header_parser/ast_parser.py b/src/dgf_header_parser/ast_parser.py index 4f000a8..2df623d 100644 --- a/src/dgf_header_parser/ast_parser.py +++ b/src/dgf_header_parser/ast_parser.py @@ -18,9 +18,16 @@ def parse(self, header_file): tu = index.parse(header_file, args=args) return tu - def extract(self, tu): + def extract(self, tu, source_file=None): functions, structs, typedefs, enums = [], [], [], [] + target_file = os.path.realpath(source_file or tu.spelling) for node in tu.cursor.get_children(): + location_file = getattr(getattr(node, "location", None), "file", None) + location_name = getattr(location_file, "name", None) + if not location_name: + continue + if os.path.realpath(location_name) != target_file: + continue kind = node.kind if kind == cindex.CursorKind.FUNCTION_DECL: functions.append(self.extract_function(node)) diff --git a/src/dgf_header_parser/constraint_inferencer.py b/src/dgf_header_parser/constraint_inferencer.py index fb26eca..285e827 100644 --- a/src/dgf_header_parser/constraint_inferencer.py +++ b/src/dgf_header_parser/constraint_inferencer.py @@ -10,7 +10,7 @@ def infer_constraints(self): for file_entry in self.api_data: for func in file_entry["result"]["functions"]: func_name = func["name"] - constraints[func_name] = [] + constraints.setdefault(func_name, []) for param in func["parameters"]: pname = param["name"].lower() diff --git a/src/dgf_header_parser/extractor.py b/src/dgf_header_parser/extractor.py index 665c889..d94a928 100644 --- a/src/dgf_header_parser/extractor.py +++ b/src/dgf_header_parser/extractor.py @@ -11,19 +11,28 @@ def extract_all_api(header_dir, include_dirs): headers = collect_header_files(header_dir) + if not headers: + LOGGER.warning("No header files found under %s", header_dir) + return [] parser = ASTParser(include_dirs) all_results = [] + failed = 0 for h in tqdm(headers, desc="Parsing Headers"): try: tu = parser.parse(h) - result = parser.extract(tu) + result = parser.extract(tu, source_file=h) all_results.append({ "file": h, "result": result }) except Exception as e: LOGGER.warning("Error parsing %s: %s", h, e) + failed += 1 + + LOGGER.info("Header extraction finished: %d succeeded, %d failed", len(all_results), failed) + if not all_results: + raise RuntimeError("Failed to parse any header file.") return all_results diff --git a/src/dgf_header_parser/test_constraint_inferencer.py b/src/dgf_header_parser/test_constraint_inferencer.py new file mode 100644 index 0000000..f4ee828 --- /dev/null +++ b/src/dgf_header_parser/test_constraint_inferencer.py @@ -0,0 +1,35 @@ +from dgf_header_parser.constraint_inferencer import ConstraintInferencer + + +def test_infer_constraints_keeps_entries_from_same_function_name(): + api_data = [ + { + "file": "a.h", + "result": { + "functions": [ + { + "name": "dup_func", + "result_type": "void", + "parameters": [{"name": "input_len", "type": "size_t"}], + } + ] + }, + }, + { + "file": "b.h", + "result": { + "functions": [ + { + "name": "dup_func", + "result_type": "void", + "parameters": [{"name": "file_path", "type": "const char *"}], + } + ] + }, + }, + ] + + constraints = ConstraintInferencer(api_data).infer_constraints() + assert "dup_func" in constraints + params = {item["param"] for item in constraints["dup_func"]} + assert {"input_len", "file_path"}.issubset(params) diff --git a/src/dgf_prompt_generator/llm_caller.py b/src/dgf_prompt_generator/llm_caller.py index 1769e7f..c6969b6 100644 --- a/src/dgf_prompt_generator/llm_caller.py +++ b/src/dgf_prompt_generator/llm_caller.py @@ -39,7 +39,11 @@ def __init__(self, api_key=None, base_url=None, model=None, temperature=None): final_temperature = os.getenv("OPENAI_TEMPERATURE", "0.2") if local_config is not None and temperature is None: final_temperature = getattr(local_config, "TEMPERATURE", final_temperature) - final_temperature = float(final_temperature) + try: + final_temperature = float(final_temperature) + except (TypeError, ValueError): + LOGGER.warning("Invalid OPENAI_TEMPERATURE=%r, fallback to 0.2", final_temperature) + final_temperature = 0.2 self.client = openai.OpenAI( api_key=final_api_key, @@ -58,4 +62,11 @@ def generate_code(self, prompt): temperature=self.temperature ) LOGGER.debug("LLM generation finished using model=%s", self.model) - return response.choices[0].message.content + choices = getattr(response, "choices", None) + if not choices: + raise ValueError("LLM response has no choices.") + message = getattr(choices[0], "message", None) + content = getattr(message, "content", None) + if not isinstance(content, str) or not content.strip(): + raise ValueError("LLM response has empty content.") + return content diff --git a/src/dgf_prompt_generator/prompt_template.py b/src/dgf_prompt_generator/prompt_template.py index 31c467a..b17d5ac 100644 --- a/src/dgf_prompt_generator/prompt_template.py +++ b/src/dgf_prompt_generator/prompt_template.py @@ -9,13 +9,16 @@ def __init__(self, api_info_json, system_includes=None, api_prefixes=None): with open(api_info_json, "r") as f: self.api_data = json.load(f) - self.system_includes = system_includes or [ - "stdint.h", - "stddef.h", - "stdio.h", - "stdlib.h", - "string.h", - ] + if system_includes is None: + self.system_includes = [ + "stdint.h", + "stddef.h", + "stdio.h", + "stdlib.h", + "string.h", + ] + else: + self.system_includes = system_includes self.api_prefixes = api_prefixes or [] # 约束推导初始化 @@ -59,16 +62,27 @@ def _generate_prompt_from_funcs(self, selected_funcs): Please implement the LLVMFuzzerTestOneInput function that uses these APIs. -void LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) {{ +int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) {{ // Your implementation here + return 0; }}""" return prompt def get_api_signatures(self, num_funcs=5): functions = [] for file_entry in self.api_data: - functions.extend(file_entry["result"]["functions"]) - selected_funcs = random.sample(functions, min(num_funcs, len(functions))) + for func in file_entry["result"]["functions"]: + name = func["name"] + if not self.api_prefixes or any(name.startswith(prefix) for prefix in self.api_prefixes): + functions.append(func) + if not functions: + return [] + try: + safe_num_funcs = int(num_funcs) + except (TypeError, ValueError): + safe_num_funcs = 0 + safe_num_funcs = max(0, safe_num_funcs) + selected_funcs = random.sample(functions, min(safe_num_funcs, len(functions))) return selected_funcs def format_func_signature(self, func): diff --git a/src/dgf_prompt_generator/test_prompt_gen.py b/src/dgf_prompt_generator/test_prompt_gen.py index 779ad31..ab33783 100644 --- a/src/dgf_prompt_generator/test_prompt_gen.py +++ b/src/dgf_prompt_generator/test_prompt_gen.py @@ -41,3 +41,64 @@ def test_prompt_template_filters_prefix_and_generates_signature(tmp_path): prompt = template.generate_prompt_from_api_list(["cJSON_AddObjectToObject"]) assert "#include " in prompt assert "cJSON_AddObjectToObject" in prompt + + +def test_prompt_template_generate_prompt_uses_int_signature_and_prefix_filter(tmp_path): + api_json = tmp_path / "api.json" + api_json.write_text( + json.dumps( + [ + { + "file": "x.h", + "result": { + "functions": [ + { + "name": "cJSON_Parse", + "result_type": "int", + "parameters": [], + }, + { + "name": "OtherFunc", + "result_type": "void", + "parameters": [], + }, + ] + }, + } + ] + ) + ) + + template = PromptTemplate(str(api_json), api_prefixes=["cJSON"]) + prompt = template.generate_prompt(num_funcs=5) + assert "int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size)" in prompt + assert "return 0;" in prompt + assert "cJSON_Parse" in prompt + assert "OtherFunc" not in prompt + + +def test_prompt_template_allows_empty_system_includes_and_negative_num_funcs(tmp_path): + api_json = tmp_path / "api.json" + api_json.write_text( + json.dumps( + [ + { + "file": "x.h", + "result": { + "functions": [ + { + "name": "OnlyFunc", + "result_type": "void", + "parameters": [], + } + ] + }, + } + ] + ) + ) + + template = PromptTemplate(str(api_json), system_includes=[]) + prompt = template.generate_prompt(num_funcs=-5) + assert "#include " not in prompt + assert template.get_api_signatures(num_funcs=-5) == [] diff --git a/src/dgf_validator/fuzzer_runner.py b/src/dgf_validator/fuzzer_runner.py index 746c104..742036c 100644 --- a/src/dgf_validator/fuzzer_runner.py +++ b/src/dgf_validator/fuzzer_runner.py @@ -34,3 +34,6 @@ def run_libfuzzer(self, binary_path, work_dir): except subprocess.CalledProcessError: LOGGER.warning("Fuzzing crash detected for %s", binary_path) return False + except OSError as exc: + LOGGER.warning("Failed to execute fuzzer binary %s: %s", binary_path, exc) + return False diff --git a/src/dgf_validator/validator.py b/src/dgf_validator/validator.py index 029e269..b7cac03 100644 --- a/src/dgf_validator/validator.py +++ b/src/dgf_validator/validator.py @@ -69,6 +69,9 @@ def validate_source(self, src_file, include_dirs=None, max_retry=3): continue else: return False, None + except OSError as e: + LOGGER.warning("Failed to execute compiler %s: %s", self.clang, e) + return False, None return False, None diff --git a/tests/test_code_utils.py b/tests/test_code_utils.py index 52b37a1..34fb31f 100644 --- a/tests/test_code_utils.py +++ b/tests/test_code_utils.py @@ -9,3 +9,13 @@ def test_extract_c_code_block_from_fenced_markdown(): def test_extract_c_code_block_fallback_to_raw_text(): raw = "int y = 2;" assert extract_c_code_block(raw) == "int y = 2;" + + +def test_extract_c_code_block_does_not_treat_python_fence_as_c(): + raw = "```python\nprint('x')\n```" + assert extract_c_code_block(raw) == raw + + +def test_extract_c_code_block_supports_untagged_fence(): + raw = "prefix\n```\nint z = 3;\n```\nsuffix" + assert extract_c_code_block(raw) == "int z = 3;" diff --git a/tests/test_llm_caller.py b/tests/test_llm_caller.py index 8fb50b5..31d9c7f 100644 --- a/tests/test_llm_caller.py +++ b/tests/test_llm_caller.py @@ -35,3 +35,39 @@ def __init__(self, api_key=None, base_url=None): caller = LLMCaller(model="demo-model") text = caller.generate_code("hi") assert "int a=0;" in text + + +def test_llm_caller_invalid_temperature_falls_back(monkeypatch): + class FakeClient: + def __init__(self, api_key=None, base_url=None): + self.api_key = api_key + self.base_url = base_url + self.chat = SimpleNamespace(completions=SimpleNamespace(create=lambda **kwargs: None)) + + monkeypatch.setenv("OPENAI_API_KEY", "dummy") + monkeypatch.setenv("OPENAI_TEMPERATURE", "not-a-number") + monkeypatch.setattr(llm_caller.openai, "OpenAI", FakeClient) + + caller = LLMCaller(model="demo-model") + assert caller.temperature == 0.2 + + +def test_llm_caller_generate_code_raises_on_empty_choices(monkeypatch): + class FakeCompletions: + @staticmethod + def create(**kwargs): + _ = kwargs + return SimpleNamespace(choices=[]) + + class FakeClient: + def __init__(self, api_key=None, base_url=None): + self.api_key = api_key + self.base_url = base_url + self.chat = SimpleNamespace(completions=FakeCompletions()) + + monkeypatch.setenv("OPENAI_API_KEY", "dummy") + monkeypatch.setattr(llm_caller.openai, "OpenAI", FakeClient) + + caller = LLMCaller(model="demo-model") + with pytest.raises(ValueError, match="no choices"): + caller.generate_code("hi") diff --git a/tests/test_runtime_resilience.py b/tests/test_runtime_resilience.py new file mode 100644 index 0000000..e8ffd78 --- /dev/null +++ b/tests/test_runtime_resilience.py @@ -0,0 +1,101 @@ +import subprocess +from types import SimpleNamespace + +import pytest + +import dgf_header_parser.ast_parser as ast_parser_module +import dgf_header_parser.extractor as extractor_module +from dgf_feedback.branch_coverage_collector import BranchCoverageCollector +from dgf_header_parser.ast_parser import ASTParser +from dgf_validator.fuzzer_runner import FuzzerRunner +from dgf_validator.validator import Validator + + +def test_branch_coverage_collector_handles_missing_profdata_binary(monkeypatch, tmp_path): + work_dir = tmp_path + (work_dir / "default.profraw").write_text("profile") + collector = BranchCoverageCollector(profdata_path="missing-profdata", cov_path="llvm-cov") + + def fake_run(cmd, **kwargs): + _ = kwargs + if cmd[0] == "missing-profdata": + raise FileNotFoundError("not found") + return SimpleNamespace(returncode=0, stdout="{}", stderr="") + + monkeypatch.setattr(subprocess, "run", fake_run) + func_cov, overall = collector.collect_branch_coverage("fake-bin", str(work_dir)) + assert func_cov == {} + assert overall == 0.0 + + +def test_validator_returns_false_when_compiler_not_found(monkeypatch, tmp_path): + src_file = tmp_path / "driver.c" + src_file.write_text("int LLVMFuzzerTestOneInput(const unsigned char*d, unsigned long s){return 0;}") + + def fake_run(*args, **kwargs): + _ = args, kwargs + raise FileNotFoundError("clang not found") + + monkeypatch.setattr(subprocess, "run", fake_run) + validator = Validator(clang_path="missing-clang", work_dir=str(tmp_path / "validated")) + assert validator.validate_source(str(src_file)) == (False, None) + + +def test_fuzzer_runner_handles_missing_binary(monkeypatch, tmp_path): + def fake_run(*args, **kwargs): + _ = args, kwargs + raise FileNotFoundError("missing binary") + + monkeypatch.setattr(subprocess, "run", fake_run) + runner = FuzzerRunner(timeout_sec=1) + assert runner.run_libfuzzer(str(tmp_path / "nope"), str(tmp_path)) is False + + +def test_ast_parser_extract_only_keeps_nodes_from_target_header(tmp_path): + header = tmp_path / "a.h" + other = tmp_path / "b.h" + header.write_text("int a(void);") + other.write_text("int b(void);") + + parser = ASTParser() + def fake_extract_function(node): + return {"name": node.spelling} + + parser.extract_function = fake_extract_function # type: ignore[method-assign] + + function_decl_kind = ast_parser_module.cindex.CursorKind.FUNCTION_DECL + + class FakeNode: + def __init__(self, spelling, file_path): + self.kind = function_decl_kind + self.spelling = spelling + self.location = SimpleNamespace(file=SimpleNamespace(name=str(file_path))) + + tu = SimpleNamespace( + spelling=str(header), + cursor=SimpleNamespace( + get_children=lambda: [ + FakeNode("in_header", header), + FakeNode("from_other_file", other), + ] + ), + ) + + result = parser.extract(tu, source_file=str(header)) + assert [func["name"] for func in result["functions"]] == ["in_header"] + + +def test_extract_all_api_raises_when_every_header_fails(monkeypatch): + class BrokenParser: + def __init__(self, include_dirs=None): + _ = include_dirs + + def parse(self, header_file): + _ = header_file + raise RuntimeError("parse failure") + + monkeypatch.setattr(extractor_module, "collect_header_files", lambda _: ["a.h", "b.h"]) + monkeypatch.setattr(extractor_module, "ASTParser", BrokenParser) + + with pytest.raises(RuntimeError, match="Failed to parse any header file"): + extractor_module.extract_all_api("/tmp/headers", [])