From 14449003077a500dae0cb82e94311f538c030506 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Fri, 6 Mar 2026 15:03:35 +0000 Subject: [PATCH] =?UTF-8?q?=E5=85=A8=E9=9D=A2=E5=B7=A5=E7=A8=8B=E5=8C=96?= =?UTF-8?q?=E6=94=B9=E9=80=A0=EF=BC=9A=E9=85=8D=E7=BD=AE=E8=A7=A3=E8=80=A6?= =?UTF-8?q?=E3=80=81=E6=B5=8B=E8=AF=95=E4=B8=8ECI=E3=80=81=E6=96=87?= =?UTF-8?q?=E6=A1=A3=E4=B8=8E=E8=B4=A8=E9=87=8F=E5=9F=BA=E7=BA=BF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: BlueOrbit --- .github/workflows/ci.yml | 32 ++++ .gitignore | 4 +- README.md | 172 ++++++++---------- pyproject.toml | 23 +++ requirements-dev.txt | 5 + requirements.txt | 4 + src/config/experiment.yaml | 27 ++- src/dgf_common/__init__.py | 1 + src/dgf_common/code_utils.py | 17 ++ src/dgf_common/logging_utils.py | 12 ++ src/dgf_feedback/api_manager.py | 12 +- src/dgf_feedback/branch_coverage_collector.py | 38 +++- src/dgf_feedback/feedback_controller.py | 62 +++++-- src/dgf_feedback/prompt_mutator.py | 1 + src/dgf_feedback/test_feedback.py | 46 ++++- src/dgf_header_parser/ast_parser.py | 14 +- .../constraint_inferencer.py | 2 - src/dgf_header_parser/extractor.py | 16 +- src/dgf_header_parser/header_scanner.py | 1 + src/dgf_pipeline/run_pipeline.py | 17 +- src/dgf_prompt_generator/class_chain.py | 11 +- src/dgf_prompt_generator/config.example.py | 13 ++ src/dgf_prompt_generator/generator.py | 44 +++-- src/dgf_prompt_generator/llm_caller.py | 51 +++++- src/dgf_prompt_generator/prompt_template.py | 29 +-- src/dgf_prompt_generator/test_prompt_gen.py | 59 +++--- src/dgf_validator/fuzzer_runner.py | 13 +- src/dgf_validator/test_validator.py | 49 +++-- src/dgf_validator/validator.py | 61 +++++-- src/gen_fuzz_driver.py | 40 ++-- src/main.py | 53 ++++-- src/test.py | 1 - src/test_cjson.sh | 11 +- tests/test_code_utils.py | 11 ++ tests/test_llm_caller.py | 37 ++++ 35 files changed, 701 insertions(+), 288 deletions(-) create mode 100644 .github/workflows/ci.yml create mode 100644 pyproject.toml create mode 100644 requirements-dev.txt create mode 100644 requirements.txt create mode 100644 src/dgf_common/__init__.py create mode 100644 src/dgf_common/code_utils.py create mode 100644 src/dgf_common/logging_utils.py create mode 100644 src/dgf_prompt_generator/config.example.py delete mode 100644 src/test.py create mode 100644 tests/test_code_utils.py create mode 100644 tests/test_llm_caller.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..45187d6 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,32 @@ +name: CI + +on: + push: + branches: ["**"] + pull_request: + +jobs: + quality: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: "3.10" + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements-dev.txt + + - name: Ruff check + run: ruff check src + + - name: Mypy check + run: mypy src + + - name: Pytest + run: pytest diff --git a/.gitignore b/.gitignore index fccd289..3d69015 100644 --- a/.gitignore +++ b/.gitignore @@ -30,7 +30,9 @@ build/ # Test results *.cover coverage.xml -*.pytest_cache/ +.pytest_cache/ +.ruff_cache/ +.coverage # Jupyter Notebook .ipynb_checkpoints diff --git a/README.md b/README.md index 417d5be..6a33c36 100644 --- a/README.md +++ b/README.md @@ -1,126 +1,112 @@ -# DGF: 基于Prompt的Fuzz Driver自生成系统 +# DGF: 基于 Prompt 的 Fuzz Driver 自动生成系统 -> 本项目是 PromptFuzz 论文《Prompt Fuzzing for Fuzz Driver Generation》 (CCS 2024) 的全面复现实现,并扩展了调用链分析技术,用于增强 LLM 生成合理 API 调用系列的能力。 +本项目实现了从头文件 API 抽取、Prompt 构造、LLM 代码生成、编译验证、fuzz 执行到覆盖率反馈迭代的完整流程。 ---- +## 1. 模块流程 -## 一、项目概述 - -DGF 自动生成高质量的 fuzz driver,进行应用程序库的默黑模糊测试,根据覆盖率反馈和程序验证进行迭代优化,复现 PromptFuzz 论文中的核心技术思路。 - ---- - -## 二、系统模块构成 - -``` -头文件 --> 头文解析 (Header Parser) --> API 签名 - | - v - 调用链分析 (Call Chain Analysis) - | - v -Prompt 生成 (Prompt Generator) --> LLM 生成代码 - | - v -程序验证 (Validator) - | - v -覆盖率收集 (Coverage Collector) - | - v -Prompt 变异 (Prompt Mutation) <--- 反馈控制 (Feedback Controller) +```text +Header Parser -> Prompt Generator -> LLM Code -> Validator -> Fuzzer -> Coverage -> Feedback ``` ---- - -## 三、目录结构 - -``` -DGF-main/ -| -├— src/ -| ├— main.py # 总控制入口 -| ├— config/ # 配置文件 -| ├— dgf_header_parser/ # 头文解析和API提取 -| ├— dgf_prompt_generator/ # Prompt生成和LLM调用 -| ├— dgf_validator/ # 程序验证模块 -| ├— dgf_feedback/ # 反馈控制与覆盖率收集 -| └— dgf_pipeline/ # 完整流水线执行控制 -| -└— README.md +核心入口:`src/main.py` + +## 2. 目录结构 + +```text +. +├── src/ +│ ├── main.py +│ ├── config/experiment.yaml +│ ├── dgf_header_parser/ +│ ├── dgf_prompt_generator/ +│ ├── dgf_validator/ +│ ├── dgf_feedback/ +│ ├── dgf_pipeline/ +│ └── dgf_common/ +├── tests/ +├── requirements.txt +├── requirements-dev.txt +└── pyproject.toml ``` ---- +## 3. 环境要求 -## 四、快速使用 +- Python 3.9+ +- clang/llvm(建议 14+) +- 支持 libFuzzer 的编译环境 -### 1.环境供与 +可选环境变量: -- Python 3.8+ -- clang, llvm, lcov, cmake -- 支持libFuzzer的编译环境 -- 安装Python依赖: +- `OPENAI_API_KEY`(必需,除非使用本地 `src/dgf_prompt_generator/config.py`) +- `OPENAI_BASE_URL`(可选) +- `OPENAI_MODEL`(默认 `gpt-4.1-mini`) +- `OPENAI_TEMPERATURE`(默认 `0.2`) +- `LIBCLANG_PATH`(可选,如 `/usr/lib/llvm-14/lib/libclang.so.1`) + +## 4. 安装依赖 ```bash -python -m venv venv -source venv/bin/activate -pip install -r src/dgf_prompt_generator/requirements.txt -pip install -r src/dgf_validator/requirements.txt -pip install -r src/dgf_feedback/requirements.txt +python -m venv .venv +source .venv/bin/activate +pip install -r requirements-dev.txt ``` -### 2.目标库准备 +## 5. 目标库准备(以 cJSON 为例) -将测试库的源码和头文件放入指定路径,如: +将目标库放到: -``` -testdata/cJSON/ +```text +testdata/cJSON ``` -### 3.运行全流程 +并确保其可被 clang include/link(`src/config/experiment.yaml` 已给出默认路径模板)。 -```bash -cd src/ -python main.py --config config/experiment.yaml -``` - -### 4.运行结果 +## 6. 运行方式 -- 生成种子seed程序 -- 生成fuzz driver并执行libFuzzer测试 -- 生成覆盖率和bug报告 +### 6.1 运行完整流程 ---- +```bash +PYTHONPATH=src python src/main.py --config src/config/experiment.yaml +``` -## 五、配置文件 +### 6.2 单独运行 feedback pipeline -根本配置文件位于 `config/experiment.yaml`,具体包括: +```bash +PYTHONPATH=src python src/dgf_pipeline/run_pipeline.py \ + --api_json data/extracted_api.json \ + --output_dir data/feedback_results \ + --samples 5 \ + --clang_path clang \ + --include_dirs testdata/cJSON \ + --lib_dir testdata/cJSON/build \ + --libs cjson cjson_utils +``` -- `library_path`:库源码路径 -- `header_path`:头文件路径 -- `clang_bin`:clang编译器路径 -- `llm_provider`:设置LLM接口和API密钥 -- `mutation_params`:Prompt变异策略参数 +## 7. 配置说明 ---- +主配置:`src/config/experiment.yaml` -## 六、项目特性 +- `api_extraction`:头文件扫描路径、include 路径、抽取 JSON 输出位置 +- `prompt_generation`:seed driver 数量、每个 driver 的 API 数、include 模板、API 前缀过滤 +- `feedback_iteration`:每轮样本数与 fuzz 超时 +- `validator`:clang 路径、include 路径、库目录与库名 -- 完全复现 PromptFuzz 核心设计 -- 基于覆盖率的 Prompt 变异和能量调度 -- 多阶验证(编译+sanitizer+fuzzing) -- 集成 AFLFast 风格的 API energy scheduling -- 增强 **调用链分析** (扩展部分) -- 支持可复现性实验 +## 8. 开发与质量检查 ---- +```bash +ruff check src +mypy src +pytest +``` -## 七、参考文献 +仓库已包含 GitHub Actions 工作流(`.github/workflows/ci.yml`)用于自动执行上述检查。 -- PromptFuzz: Prompt Fuzzing for Fuzz Driver Generation -- CCS 2024, Yunlong Lyu et al. -- 本实现在此基础上扩展了静态程序分析分支,增强了生成合理性 +## 9. 本地 LLM 配置(可选) ---- +如不想依赖环境变量,可复制: +```text +src/dgf_prompt_generator/config.example.py -> src/dgf_prompt_generator/config.py +``` +并填写 API 配置。`config.py` 默认不提交到仓库。 diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..5742932 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,23 @@ +[tool.pytest.ini_options] +pythonpath = ["src"] +testpaths = ["src", "tests"] +addopts = "-q" + +[tool.ruff] +line-length = 100 +target-version = "py39" +src = ["src"] + +[tool.ruff.lint] +select = ["E", "F", "I", "W"] +ignore = ["E501"] + +[tool.mypy] +python_version = "3.9" +mypy_path = "src" +namespace_packages = true +explicit_package_bases = true +ignore_missing_imports = true +check_untyped_defs = false +warn_unused_ignores = true +no_implicit_optional = false diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 0000000..8b13e69 --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,5 @@ +-r requirements.txt +pytest +ruff +mypy +types-PyYAML diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..fc7d749 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,4 @@ +openai +tqdm +PyYAML +clang diff --git a/src/config/experiment.yaml b/src/config/experiment.yaml index 7a40b40..c0054a9 100644 --- a/src/config/experiment.yaml +++ b/src/config/experiment.yaml @@ -1,23 +1,34 @@ api_extraction: - header_dir: /home/lanjiachen/DGF/testdata/cJSON + header_dir: testdata/cJSON include_dirs: - - /home/lanjiachen/DGF/testdata/cJSON - extracted_api_json: /home/lanjiachen/DGF/src/data/extracted_api.json + - testdata/cJSON + extracted_api_json: data/extracted_api.json prompt_generation: - output_dir: /home/lanjiachen/DGF/src/data/seed_prompts + output_dir: data/seed_prompts samples: 2 num_funcs: 5 + system_includes: + - stdint.h + - stddef.h + - stdio.h + - stdlib.h + - string.h + - cJSON.h + - cJSON_Utils.h + api_prefixes: + - cJSON feedback_iteration: - output_dir: /home/lanjiachen/DGF/src/data/feedback_results + output_dir: data/feedback_results samples_per_round: 2 + fuzz_timeout_sec: 20 validator: - clang_path: clang-14 + clang_path: clang include_dirs: - - /home/lanjiachen/DGF/testdata/cJSON - lib_dir: /home/lanjiachen/DGF/testdata/cJSON/build + - testdata/cJSON + lib_dir: testdata/cJSON/build libs: - cjson - cjson_utils diff --git a/src/dgf_common/__init__.py b/src/dgf_common/__init__.py new file mode 100644 index 0000000..741e0f8 --- /dev/null +++ b/src/dgf_common/__init__.py @@ -0,0 +1 @@ +# Common shared helpers for DGF modules. diff --git a/src/dgf_common/code_utils.py b/src/dgf_common/code_utils.py new file mode 100644 index 0000000..d9bbc95 --- /dev/null +++ b/src/dgf_common/code_utils.py @@ -0,0 +1,17 @@ +import re + +_FENCED_CODE_PATTERN = re.compile(r"```(?:c|C|cpp|c\+\+)?\s*(.*?)```", re.DOTALL) + + +def extract_c_code_block(raw_text): + """ + Extract C/C++ code from markdown fenced block. + If no fenced block is present, return stripped raw text. + """ + if raw_text is None: + return "" + + match = _FENCED_CODE_PATTERN.search(raw_text) + if match: + return match.group(1).strip() + return raw_text.strip() diff --git a/src/dgf_common/logging_utils.py b/src/dgf_common/logging_utils.py new file mode 100644 index 0000000..b735bdc --- /dev/null +++ b/src/dgf_common/logging_utils.py @@ -0,0 +1,12 @@ +import logging +import os + + +def configure_logging(default_level="INFO"): + level_name = os.getenv("DGF_LOG_LEVEL", default_level).upper() + level = getattr(logging, level_name, logging.INFO) + + logging.basicConfig( + level=level, + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + ) diff --git a/src/dgf_feedback/api_manager.py b/src/dgf_feedback/api_manager.py index 7550e5e..a01d183 100644 --- a/src/dgf_feedback/api_manager.py +++ b/src/dgf_feedback/api_manager.py @@ -1,8 +1,11 @@ # src/dgf_feedback/api_manager.py +import logging import random from collections import defaultdict +LOGGER = logging.getLogger(__name__) + class APIManager: def __init__(self, api_list, exponent=1.0): self.api_list = api_list # list of api function names @@ -40,4 +43,11 @@ def sample_api_combination(self, num_funcs): def print_state(self): for api in self.api_list: - print(f"{api}: cov={self.coverage[api]:.2f}, seed={self.seed_count[api]}, prompt={self.prompt_count[api]}, energy={self.get_energy(api):.4f}") + LOGGER.info( + "%s: cov=%.2f, seed=%d, prompt=%d, energy=%.4f", + api, + self.coverage[api], + self.seed_count[api], + self.prompt_count[api], + self.get_energy(api), + ) diff --git a/src/dgf_feedback/branch_coverage_collector.py b/src/dgf_feedback/branch_coverage_collector.py index 0ec9903..1c1fc8c 100644 --- a/src/dgf_feedback/branch_coverage_collector.py +++ b/src/dgf_feedback/branch_coverage_collector.py @@ -1,8 +1,11 @@ # src/dgf_feedback/branch_coverage_collector.py -import subprocess -import os import json +import logging +import os +import subprocess + +LOGGER = logging.getLogger(__name__) class BranchCoverageCollector: def __init__(self, profdata_path="llvm-profdata", cov_path="llvm-cov"): @@ -14,10 +17,20 @@ def collect_branch_coverage(self, binary_path, work_dir): profdata_out = os.path.join(work_dir, "default.profdata") if not os.path.exists(profraw): - print(f"Warning: Coverage profile file not found. Likely crash occurred.") + LOGGER.warning("Coverage profile file not found at %s", profraw) return {}, 0.0 - subprocess.run([self.profdata, "merge", "-sparse", profraw, "-o", profdata_out], check=True) + try: + subprocess.run( + [self.profdata, "merge", "-sparse", profraw, "-o", profdata_out], + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + except subprocess.CalledProcessError as exc: + LOGGER.warning("Failed to merge profile data: %s", exc.stderr) + return {}, 0.0 export_cmd = [ self.cov, "export", @@ -25,8 +38,21 @@ def collect_branch_coverage(self, binary_path, work_dir): binary_path, "--format=json" ] - result = subprocess.run(export_cmd, stdout=subprocess.PIPE, check=True) - output = json.loads(result.stdout.decode()) + try: + result = subprocess.run( + export_cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + check=True, + text=True, + ) + output = json.loads(result.stdout) + except subprocess.CalledProcessError as exc: + LOGGER.warning("Failed to export coverage json: %s", exc.stderr) + return {}, 0.0 + except json.JSONDecodeError: + LOGGER.warning("Invalid coverage JSON output from llvm-cov") + return {}, 0.0 func_coverage = {} diff --git a/src/dgf_feedback/feedback_controller.py b/src/dgf_feedback/feedback_controller.py index bdea4ef..f97c0e8 100644 --- a/src/dgf_feedback/feedback_controller.py +++ b/src/dgf_feedback/feedback_controller.py @@ -1,19 +1,45 @@ +import logging import os -from dgf_prompt_generator.prompt_template import PromptTemplate -from dgf_prompt_generator.llm_caller import LLMCaller -from dgf_validator.validator import Validator -from dgf_validator.fuzzer_runner import FuzzerRunner + +from dgf_common.code_utils import extract_c_code_block from dgf_feedback.api_manager import APIManager -from dgf_feedback.prompt_mutator import PromptMutator from dgf_feedback.branch_coverage_collector import BranchCoverageCollector +from dgf_feedback.prompt_mutator import PromptMutator from dgf_feedback.sample_filter import SampleFilter +from dgf_prompt_generator.llm_caller import LLMCaller +from dgf_prompt_generator.prompt_template import PromptTemplate +from dgf_validator.fuzzer_runner import FuzzerRunner +from dgf_validator.validator import Validator + +LOGGER = logging.getLogger(__name__) class FeedbackController: - def __init__(self, api_json, output_dir, clang_path="clang-14", include_dirs=[], lib_dir=None, libs=[]): - self.prompt_template = PromptTemplate(api_json) + def __init__( + self, + api_json, + output_dir, + clang_path="clang", + include_dirs=None, + lib_dir=None, + libs=None, + system_includes=None, + api_prefixes=None, + fuzz_timeout_sec=20, + ): + self.include_dirs = include_dirs or [] + self.prompt_template = PromptTemplate( + api_json, + system_includes=system_includes or [], + api_prefixes=api_prefixes or [], + ) self.llm = LLMCaller() - self.validator = Validator(clang_path="clang-14") - self.fuzzer = FuzzerRunner(timeout_sec=20) + self.validator = Validator( + clang_path=clang_path, + work_dir=os.path.join(output_dir, "validated"), + lib_dir=lib_dir, + libs=libs or [], + ) + self.fuzzer = FuzzerRunner(timeout_sec=fuzz_timeout_sec) self.cov_collector = BranchCoverageCollector() self.sample_filter = SampleFilter(min_api_coverage=0.2, min_success_ratio=0.5) self.output_dir = output_dir @@ -40,42 +66,38 @@ def run_iteration(self, num_samples=5, base_num_funcs=5): prompt = self.prompt_template.generate_prompt_from_api_list(mutated_apis) code = self.llm.generate_code(prompt) - - # 提取代码块 - start = code.find("```c") - end = code.find("```", start + 4) - if start != -1 and end != -1: - code = code[start + 4:end].strip() - else: - code = code.strip() + code = extract_c_code_block(code) src_path = os.path.join(self.output_dir, f"fuzz_driver_{i}.c") with open(src_path, "w") as f: f.write(code) - success, binary = self.validator.validate_source(src_path, include_dirs=["../testdata/cJSON"]) + success, binary = self.validator.validate_source(src_path, include_dirs=self.include_dirs) if not success: + LOGGER.warning("Validation failed for %s", src_path) continue work_dir = os.path.dirname(binary) os.environ["LLVM_PROFILE_FILE"] = os.path.join(work_dir, "default.profraw") if not self.fuzzer.run_libfuzzer(binary, work_dir): + LOGGER.warning("Fuzzer run failed for %s", binary) continue # 分支覆盖收集 func_cov_result, overall_coverage = self.cov_collector.collect_branch_coverage(binary, work_dir) - print(f"Overall branch coverage for this driver: {overall_coverage:.2%}") + LOGGER.info("Driver %s overall branch coverage: %.2f%%", binary, overall_coverage * 100) # 写入到文件 cov_file_path = os.path.join(self.output_dir, f"coverage_{i}.txt") with open(cov_file_path, "w") as cov_file: for api, coverage in func_cov_result.items(): cov_file.write(f"{api}: {coverage:.2%}\n") cov_file.write(f"Overall coverage: {overall_coverage:.2%}\n") - + # 使用 API级别 SampleFilter if not self.sample_filter.filter_sample(mutated_apis, func_cov_result): + LOGGER.info("Sample %s filtered by coverage rule", src_path) continue # 更新 API 反馈 diff --git a/src/dgf_feedback/prompt_mutator.py b/src/dgf_feedback/prompt_mutator.py index 2306e18..4f566f0 100644 --- a/src/dgf_feedback/prompt_mutator.py +++ b/src/dgf_feedback/prompt_mutator.py @@ -2,6 +2,7 @@ import random + class PromptMutator: def __init__(self, api_manager): self.api_manager = api_manager diff --git a/src/dgf_feedback/test_feedback.py b/src/dgf_feedback/test_feedback.py index 8403d02..91e02e4 100644 --- a/src/dgf_feedback/test_feedback.py +++ b/src/dgf_feedback/test_feedback.py @@ -1,7 +1,41 @@ -import sys -import os -from dgf_feedback.feedback_controller import FeedbackController +from dgf_feedback.api_manager import APIManager +from dgf_feedback.prompt_mutator import PromptMutator +from dgf_feedback.sample_filter import SampleFilter -fc = FeedbackController(api_json="../data/cjson_extracted.json", output_dir="../data/fuzz_output_round1") -successful = fc.run_iteration(num_samples=10) -print(f"Successful samples: {len(successful)}") + +def test_sample_filter_thresholds(): + sample_filter = SampleFilter(min_api_coverage=0.5, min_success_ratio=0.5) + mutated = ["a", "b", "c", "d"] + coverage = {"a": 0.6, "b": 0.1, "c": 0.8, "d": 0.2} + assert sample_filter.filter_sample(mutated, coverage) is True + + +def test_api_manager_energy_and_sampling(): + manager = APIManager(["A", "B", "C"], exponent=1.0) + manager.update_seed("A") + manager.update_prompt("A") + manager.update_coverage("A", 0.8) + + energy_a = manager.get_energy("A") + energy_b = manager.get_energy("B") + assert energy_b > energy_a + + selected = manager.sample_api_combination(2) + assert 1 <= len(selected) <= 2 + assert set(selected).issubset({"A", "B", "C"}) + + +def test_prompt_mutator_insert_replace_crossover(): + manager = APIManager(["A", "B", "C", "D"]) + mutator = PromptMutator(manager) + + inserted = mutator.insert(["A"], num_insert=2) + assert "A" in inserted + assert len(set(inserted)) >= 1 + + replaced = mutator.replace(["A", "B"], num_replace=1) + assert len(replaced) >= 1 + assert set(replaced).issubset({"A", "B", "C", "D"}) + + crossed = mutator.crossover(["A", "B"], ["B", "C"]) + assert set(crossed) == {"A", "B", "C"} diff --git a/src/dgf_header_parser/ast_parser.py b/src/dgf_header_parser/ast_parser.py index d1006bd..4f000a8 100644 --- a/src/dgf_header_parser/ast_parser.py +++ b/src/dgf_header_parser/ast_parser.py @@ -1,18 +1,16 @@ -# import libclang -# from libclang import cindex -import json import os -import clang.cindex as cindex +import clang.cindex as cindex -# 配置 libclang 路径 -cindex.Config.set_library_file("/usr/lib/llvm-14/lib/libclang.so.1") +_LIBCLANG_PATH = os.getenv("LIBCLANG_PATH") +if _LIBCLANG_PATH and not cindex.Config.loaded: + cindex.Config.set_library_file(_LIBCLANG_PATH) class ASTParser: - def __init__(self, include_dirs=[]): - self.include_dirs = include_dirs + def __init__(self, include_dirs=None): + self.include_dirs = include_dirs or [] def parse(self, header_file): index = cindex.Index.create() diff --git a/src/dgf_header_parser/constraint_inferencer.py b/src/dgf_header_parser/constraint_inferencer.py index 1fe405e..fb26eca 100644 --- a/src/dgf_header_parser/constraint_inferencer.py +++ b/src/dgf_header_parser/constraint_inferencer.py @@ -1,7 +1,5 @@ # src/dgf_header_parser/constraint_inferencer.py -import re - class ConstraintInferencer: def __init__(self, extracted_api_json): self.api_data = extracted_api_json diff --git a/src/dgf_header_parser/extractor.py b/src/dgf_header_parser/extractor.py index f5b6263..665c889 100644 --- a/src/dgf_header_parser/extractor.py +++ b/src/dgf_header_parser/extractor.py @@ -1,13 +1,18 @@ import argparse import json -from dgf_header_parser.header_scanner import collect_header_files -from dgf_header_parser.ast_parser import ASTParser +import logging + from tqdm import tqdm +from dgf_header_parser.ast_parser import ASTParser +from dgf_header_parser.header_scanner import collect_header_files + +LOGGER = logging.getLogger(__name__) + def extract_all_api(header_dir, include_dirs): headers = collect_header_files(header_dir) parser = ASTParser(include_dirs) - + all_results = [] for h in tqdm(headers, desc="Parsing Headers"): try: @@ -18,11 +23,12 @@ def extract_all_api(header_dir, include_dirs): "result": result }) except Exception as e: - print(f"Error parsing {h}: {e}") + LOGGER.warning("Error parsing %s: %s", h, e) return all_results if __name__ == "__main__": + logging.basicConfig(level=logging.INFO, format="%(levelname)s | %(message)s") parser = argparse.ArgumentParser() parser.add_argument("--header_dir", required=True, help="Path to library header files") parser.add_argument("--include_dirs", nargs='*', default=[], help="Additional include directories") @@ -34,4 +40,4 @@ def extract_all_api(header_dir, include_dirs): with open(args.output, "w") as f: json.dump(results, f, indent=2) - print(f"Extraction completed, output saved to {args.output}") + LOGGER.info("Extraction completed, output saved to %s", args.output) diff --git a/src/dgf_header_parser/header_scanner.py b/src/dgf_header_parser/header_scanner.py index 14fac57..a9aaea8 100644 --- a/src/dgf_header_parser/header_scanner.py +++ b/src/dgf_header_parser/header_scanner.py @@ -1,5 +1,6 @@ import os + def collect_header_files(root_dir): header_files = [] for dirpath, _, filenames in os.walk(root_dir): diff --git a/src/dgf_pipeline/run_pipeline.py b/src/dgf_pipeline/run_pipeline.py index 5b984ef..482053b 100644 --- a/src/dgf_pipeline/run_pipeline.py +++ b/src/dgf_pipeline/run_pipeline.py @@ -1,15 +1,30 @@ import argparse + +from dgf_common.logging_utils import configure_logging + # from dgf_header_parser.extractor import Extractor # from dgf_prompt_generator.generator import PromptGenerator from dgf_feedback.feedback_controller import FeedbackController if __name__ == "__main__": + configure_logging() parser = argparse.ArgumentParser() parser.add_argument('--api_json', type=str, required=True) parser.add_argument('--output_dir', type=str, required=True) parser.add_argument('--samples', type=int, default=10) + parser.add_argument("--clang_path", type=str, default="clang") + parser.add_argument("--include_dirs", nargs="*", default=[]) + parser.add_argument("--lib_dir", type=str, default=None) + parser.add_argument("--libs", nargs="*", default=[]) args = parser.parse_args() # 直接调用 - fc = FeedbackController(api_json=args.api_json, output_dir=args.output_dir) + fc = FeedbackController( + api_json=args.api_json, + output_dir=args.output_dir, + clang_path=args.clang_path, + include_dirs=args.include_dirs, + lib_dir=args.lib_dir, + libs=args.libs, + ) fc.run_iteration(num_samples=args.samples) diff --git a/src/dgf_prompt_generator/class_chain.py b/src/dgf_prompt_generator/class_chain.py index df1fc78..89d6536 100644 --- a/src/dgf_prompt_generator/class_chain.py +++ b/src/dgf_prompt_generator/class_chain.py @@ -1,12 +1,19 @@ +import os import re + class CallChainAnalyzer: - def __init__(self): - self.input_file = '/home/lanjiachen/DGF/src/data/reverse_callgraph.txt' + def __init__(self, input_file=None): + self.input_file = input_file or os.getenv( + "DGF_REVERSE_CALLGRAPH_PATH", + "data/reverse_callgraph.txt", + ) self.lines = self._read_and_clean_lines() def _read_and_clean_lines(self): lines = [] + if not os.path.exists(self.input_file): + return lines with open(self.input_file, 'r') as f: for line in f: i = 0 diff --git a/src/dgf_prompt_generator/config.example.py b/src/dgf_prompt_generator/config.example.py new file mode 100644 index 0000000..8399171 --- /dev/null +++ b/src/dgf_prompt_generator/config.example.py @@ -0,0 +1,13 @@ +""" +Optional local configuration for LLM caller. +Prefer environment variables in production: + - OPENAI_API_KEY + - OPENAI_BASE_URL (optional) + - OPENAI_MODEL (optional) + - OPENAI_TEMPERATURE (optional) +""" + +API_KEY = "replace-with-your-api-key" +API_BASE_URL = "https://api.openai.com/v1" +MODEL_NAME = "gpt-4.1-mini" +TEMPERATURE = 0.2 diff --git a/src/dgf_prompt_generator/generator.py b/src/dgf_prompt_generator/generator.py index b791644..d988c92 100644 --- a/src/dgf_prompt_generator/generator.py +++ b/src/dgf_prompt_generator/generator.py @@ -1,11 +1,22 @@ import argparse -from dgf_prompt_generator.prompt_template import PromptTemplate -from dgf_prompt_generator.llm_caller import LLMCaller -from tqdm import tqdm +import logging import os +from tqdm import tqdm + +from dgf_common.code_utils import extract_c_code_block +from dgf_common.logging_utils import configure_logging +from dgf_prompt_generator.llm_caller import LLMCaller +from dgf_prompt_generator.prompt_template import PromptTemplate + +LOGGER = logging.getLogger(__name__) + def main(args): - prompt_gen = PromptTemplate(args.api_json) + prompt_gen = PromptTemplate( + args.api_json, + system_includes=args.system_includes, + api_prefixes=args.api_prefixes, + ) llm = LLMCaller() os.makedirs(args.output_dir, exist_ok=True) @@ -13,27 +24,32 @@ def main(args): for i in tqdm(range(args.samples)): prompt = prompt_gen.generate_prompt(num_funcs=args.num_funcs) code = llm.generate_code(prompt) - - # 找到code中的代码块 - # 假设代码块以 ```c 开始,并以 ``` 结束 - start = code.find("```c") - end = code.find("```", start + 4) - if start != -1 and end != -1: - code = code[start + 4:end].strip() - else: - print(f"Warning: No valid code block found in generated code for sample {i}. Using full code.") - code = code.strip() + code = extract_c_code_block(code) output_path = f"{args.output_dir}/fuzz_driver_{i}.c" with open(output_path, "w") as f: f.write(code) + LOGGER.info("Generated %s", output_path) if __name__ == "__main__": + configure_logging() parser = argparse.ArgumentParser() parser.add_argument("--api_json", required=True, help="Extracted API JSON") parser.add_argument("--output_dir", required=True, help="Directory to save generated fuzz drivers") parser.add_argument("--samples", type=int, default=5, help="Number of fuzz drivers to generate") parser.add_argument("--num_funcs", type=int, default=5, help="Number of APIs to include per driver") + parser.add_argument( + "--system_includes", + nargs="*", + default=[], + help="Header includes to place into prompt (e.g. cJSON.h)", + ) + parser.add_argument( + "--api_prefixes", + nargs="*", + default=[], + help="Only keep APIs with selected prefixes", + ) args = parser.parse_args() main(args) diff --git a/src/dgf_prompt_generator/llm_caller.py b/src/dgf_prompt_generator/llm_caller.py index 9b6b564..1769e7f 100644 --- a/src/dgf_prompt_generator/llm_caller.py +++ b/src/dgf_prompt_generator/llm_caller.py @@ -1,14 +1,52 @@ +import logging +import os +from importlib import import_module +from typing import Any + import openai -from dgf_prompt_generator import config + +LOGGER = logging.getLogger(__name__) + +local_config: Any = None +try: + # Backward-compatible local private config. + local_config = import_module("dgf_prompt_generator.config") +except Exception: # pragma: no cover - optional local module + pass class LLMCaller: - def __init__(self): + def __init__(self, api_key=None, base_url=None, model=None, temperature=None): + final_api_key = api_key or os.getenv("OPENAI_API_KEY") or os.getenv("DGF_API_KEY") + if not final_api_key and local_config is not None: + final_api_key = getattr(local_config, "API_KEY", None) + + if not final_api_key: + raise ValueError( + "Missing OpenAI API key. Set OPENAI_API_KEY (or DGF_API_KEY), " + "or provide dgf_prompt_generator/config.py." + ) + + final_base_url = base_url or os.getenv("OPENAI_BASE_URL") or os.getenv("DGF_API_BASE_URL") + if not final_base_url and local_config is not None: + final_base_url = getattr(local_config, "API_BASE_URL", None) + + final_model = model or os.getenv("OPENAI_MODEL", "gpt-4.1-mini") + if (not model) and local_config is not None: + final_model = getattr(local_config, "MODEL_NAME", final_model) + + final_temperature = temperature + if final_temperature is None: + final_temperature = os.getenv("OPENAI_TEMPERATURE", "0.2") + if local_config is not None and temperature is None: + final_temperature = getattr(local_config, "TEMPERATURE", final_temperature) + final_temperature = float(final_temperature) + self.client = openai.OpenAI( - api_key=config.API_KEY, - base_url=config.API_BASE_URL + api_key=final_api_key, + base_url=final_base_url ) - self.model = config.MODEL_NAME - self.temperature = config.TEMPERATURE + self.model = final_model + self.temperature = final_temperature def generate_code(self, prompt): response = self.client.chat.completions.create( @@ -19,4 +57,5 @@ def generate_code(self, prompt): ], temperature=self.temperature ) + LOGGER.debug("LLM generation finished using model=%s", self.model) return response.choices[0].message.content diff --git a/src/dgf_prompt_generator/prompt_template.py b/src/dgf_prompt_generator/prompt_template.py index 2d1a676..31c467a 100644 --- a/src/dgf_prompt_generator/prompt_template.py +++ b/src/dgf_prompt_generator/prompt_template.py @@ -1,15 +1,22 @@ import json import random + from dgf_header_parser.constraint_inferencer import ConstraintInferencer + class PromptTemplate: - def __init__(self, api_info_json): - # 加载API信息 - self.api_data = json.load(open(api_info_json)) - self.system_includes = [ - "stdint.h", "stddef.h", "stdio.h", - "stdlib.h", "string.h", "cJSON.h", "cJSON_Utils.h" + def __init__(self, api_info_json, system_includes=None, api_prefixes=None): + with open(api_info_json, "r") as f: + self.api_data = json.load(f) + + self.system_includes = system_includes or [ + "stdint.h", + "stddef.h", + "stdio.h", + "stdlib.h", + "string.h", ] + self.api_prefixes = api_prefixes or [] # 约束推导初始化 inferencer = ConstraintInferencer(self.api_data) @@ -18,11 +25,11 @@ def __init__(self, api_info_json): def get_all_api_names(self): functions = [] for file_entry in self.api_data: - for f in file_entry["result"]["functions"]: - # 只保留库函数(如函数名前缀筛选) - if f["name"].startswith("cJSON"): - functions.append(f["name"]) - return functions + for func in file_entry["result"]["functions"]: + name = func["name"] + if not self.api_prefixes or any(name.startswith(prefix) for prefix in self.api_prefixes): + functions.append(name) + return sorted(set(functions)) def generate_prompt(self, num_funcs=5): diff --git a/src/dgf_prompt_generator/test_prompt_gen.py b/src/dgf_prompt_generator/test_prompt_gen.py index 979ccc3..779ad31 100644 --- a/src/dgf_prompt_generator/test_prompt_gen.py +++ b/src/dgf_prompt_generator/test_prompt_gen.py @@ -1,26 +1,43 @@ -# from dgf_prompt_generator import config +import json + from dgf_prompt_generator.prompt_template import PromptTemplate -from dgf_prompt_generator.llm_caller import LLMCaller -import time -import os -prompt_gen = PromptTemplate("/home/lanjiachen/DGF/src/data/cjson_extracted.json") -llm = LLMCaller() -apiname= ["cJSON_AddObjectToObject"] -prompt = prompt_gen.generate_prompt_from_api_list(apiname) -print("======== Prompt ========") -print(prompt) -code = llm.generate_code(prompt) -print("======== Generated Code ========") -print(code) +def test_prompt_template_filters_prefix_and_generates_signature(tmp_path): + api_json = tmp_path / "api.json" + api_json.write_text( + json.dumps( + [ + { + "file": "x.h", + "result": { + "functions": [ + { + "name": "cJSON_AddObjectToObject", + "result_type": "void", + "parameters": [{"name": "obj", "type": "void*"}], + }, + { + "name": "OtherFunc", + "result_type": "int", + "parameters": [], + }, + ] + }, + } + ] + ) + ) + + template = PromptTemplate( + str(api_json), + system_includes=["stdint.h", "cJSON.h"], + api_prefixes=["cJSON"], + ) -# Save the generated prompt and code to files -timestamp = time.strftime("%Y%m%d_%H%M%S") -prompt_filename = f"./data/generated_prompt_{timestamp}.txt" -code_filename = f"./data/generated_code_{timestamp}.c" + all_names = template.get_all_api_names() + assert all_names == ["cJSON_AddObjectToObject"] -with open(prompt_filename, "w") as f: - f.write(prompt) -with open(code_filename, "w") as f: - f.write(code) + prompt = template.generate_prompt_from_api_list(["cJSON_AddObjectToObject"]) + assert "#include " in prompt + assert "cJSON_AddObjectToObject" in prompt diff --git a/src/dgf_validator/fuzzer_runner.py b/src/dgf_validator/fuzzer_runner.py index 3c6c2e5..746c104 100644 --- a/src/dgf_validator/fuzzer_runner.py +++ b/src/dgf_validator/fuzzer_runner.py @@ -1,7 +1,10 @@ # src/dgf_validator/fuzzer_runner.py -import subprocess +import logging import os +import subprocess + +LOGGER = logging.getLogger(__name__) class FuzzerRunner: def __init__(self, timeout_sec=10, max_input_size=4096): @@ -20,14 +23,14 @@ def run_libfuzzer(self, binary_path, work_dir): "-close_fd_mask=3" ] - print("Launching libFuzzer run:", " ".join(cmd)) + LOGGER.info("Launching libFuzzer run: %s", " ".join(cmd)) env = os.environ.copy() try: subprocess.run(cmd, timeout=self.timeout_sec + 5, check=True, env=env) return True except subprocess.TimeoutExpired: - print(f"Fuzzing timeout for {binary_path}") + LOGGER.warning("Fuzzing timeout for %s", binary_path) return False - except subprocess.CalledProcessError as e: - print(f"Fuzzing crash detected for {binary_path}") + except subprocess.CalledProcessError: + LOGGER.warning("Fuzzing crash detected for %s", binary_path) return False diff --git a/src/dgf_validator/test_validator.py b/src/dgf_validator/test_validator.py index e2db513..529cf19 100644 --- a/src/dgf_validator/test_validator.py +++ b/src/dgf_validator/test_validator.py @@ -1,18 +1,35 @@ import os +import subprocess +from types import SimpleNamespace + from dgf_validator.validator import Validator -from dgf_validator.runner import Runner - -# 初始化模块 -validator = Validator(clang_path="clang-14") -runner = Runner() - -# LLM 生成输出目录 -fuzz_output_dir = "../data/fuzz_output_test" -include_dirs = ["../testdata/cJSON", "/usr/include", "/usr/local/include"] - -for filename in os.listdir(fuzz_output_dir): - if filename.endswith(".c"): - src_file = os.path.join(fuzz_output_dir, filename) - success, binary = validator.validate_source(src_file, include_dirs=include_dirs) - if success: - runner.run_binary(binary) + + +def test_validate_source_uses_include_and_lib_flags(monkeypatch, tmp_path): + src_file = tmp_path / "driver.c" + src_file.write_text("int LLVMFuzzerTestOneInput(const unsigned char*d, unsigned long s){return 0;}") + + captured = {} + + def fake_run(cmd, check, stdout, stderr, text): + captured["cmd"] = cmd + return SimpleNamespace(returncode=0, stdout="", stderr="") + + monkeypatch.setattr(subprocess, "run", fake_run) + + validator = Validator( + clang_path="clang", + work_dir=str(tmp_path / "validated"), + lib_dir="/tmp/libs", + libs=["cjson", "-lfoo"], + ) + success, binary = validator.validate_source(str(src_file), include_dirs=["/tmp/include"]) + + assert success is True + assert os.path.basename(binary) == "driver" + assert "-I" in captured["cmd"] + assert "/tmp/include" in captured["cmd"] + assert "-L/tmp/libs" in captured["cmd"] + assert "-Wl,-rpath,/tmp/libs" in captured["cmd"] + assert "-lcjson" in captured["cmd"] + assert "-lfoo" in captured["cmd"] diff --git a/src/dgf_validator/validator.py b/src/dgf_validator/validator.py index dffaef7..029e269 100644 --- a/src/dgf_validator/validator.py +++ b/src/dgf_validator/validator.py @@ -1,19 +1,30 @@ -import subprocess +import logging import os import re +import subprocess + +LOGGER = logging.getLogger(__name__) class Validator: - def __init__(self, clang_path="clang-14", work_dir="./validated"): + def __init__( + self, + clang_path="clang", + work_dir="./validated", + lib_dir=None, + libs=None, + extra_link_flags=None, + ): self.clang = clang_path self.work_dir = work_dir os.makedirs(work_dir, exist_ok=True) - self.extra_link_flags = [ - "-lm", "-lpthread", "-ldl", - "-L/home/lanjiachen/DGF/testdata/cJSON/build", - "-Wl,-rpath,/home/lanjiachen/DGF/testdata/cJSON/build", - "-lcjson", "-lcjson_utils" - ] + libs = libs or [] + self.extra_link_flags = ["-lm", "-lpthread", "-ldl"] + if lib_dir: + self.extra_link_flags.extend([f"-L{lib_dir}", f"-Wl,-rpath,{lib_dir}"]) + self.extra_link_flags.extend(_normalize_lib_flags(libs)) + if extra_link_flags: + self.extra_link_flags.extend(extra_link_flags) self.fuzzer_flags = [ "-fsanitize=fuzzer,address,undefined", @@ -21,7 +32,8 @@ def __init__(self, clang_path="clang-14", work_dir="./validated"): "-O0", "-g" ] - def validate_source(self, src_file, include_dirs=[], max_retry=3): + def validate_source(self, src_file, include_dirs=None, max_retry=3): + include_dirs = include_dirs or [] output_binary = os.path.join(self.work_dir, os.path.basename(src_file).replace(".c", "")) for attempt in range(max_retry): @@ -31,24 +43,41 @@ def validate_source(self, src_file, include_dirs=[], max_retry=3): compile_cmd.extend(self.extra_link_flags) compile_cmd.extend(["-o", output_binary]) - print(f"Compiling (attempt {attempt+1}):", " ".join(compile_cmd)) + LOGGER.info("Compiling (attempt %d): %s", attempt + 1, " ".join(compile_cmd)) try: - subprocess.run(compile_cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + subprocess.run( + compile_cmd, + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) return True, output_binary except subprocess.CalledProcessError as e: - stderr_output = e.stderr.decode() - print(f"Compilation failed (attempt {attempt+1}):\n{stderr_output}") + stderr_output = e.stderr + LOGGER.warning("Compilation failed (attempt %d):\n%s", attempt + 1, stderr_output) # 检测 undefined reference to `__xxx` undefined_refs = re.findall(r"undefined reference to `(__[a-zA-Z0-9_]+)`", stderr_output) if undefined_refs and attempt + 1 < max_retry: # 尝试自动修复,追加 -lm - print("Detected undefined internal reference(s):", undefined_refs) - self.extra_link_flags.append("-lm") - print("AutoFixer: Retrying compilation with additional -lm") + LOGGER.info("Detected undefined internal reference(s): %s", undefined_refs) + if "-lm" not in self.extra_link_flags: + self.extra_link_flags.append("-lm") + LOGGER.info("AutoFixer: retrying compilation with additional -lm") continue else: return False, None return False, None + + +def _normalize_lib_flags(libs): + normalized = [] + for lib in libs: + if lib.startswith("-l"): + normalized.append(lib) + else: + normalized.append(f"-l{lib}") + return normalized diff --git a/src/gen_fuzz_driver.py b/src/gen_fuzz_driver.py index 5d67f2d..140dab8 100644 --- a/src/gen_fuzz_driver.py +++ b/src/gen_fuzz_driver.py @@ -1,12 +1,15 @@ import argparse -import yaml -import json +import logging import os -from dgf_header_parser.extractor import extract_all_api -from dgf_prompt_generator.prompt_template import PromptTemplate +import yaml + +from dgf_common.code_utils import extract_c_code_block +from dgf_common.logging_utils import configure_logging from dgf_prompt_generator.llm_caller import LLMCaller -from dgf_feedback.feedback_controller import FeedbackController +from dgf_prompt_generator.prompt_template import PromptTemplate + +LOGGER = logging.getLogger(__name__) def generate_seed_prompt(config): """ @@ -16,37 +19,34 @@ def generate_seed_prompt(config): output_dir = config['prompt_generation']['output_dir'] samples = config['prompt_generation'].get('samples', 5) num_funcs = config['prompt_generation'].get('num_funcs', 5) - - output_dir = "/home/lanjiachen/DGF/src/data/fuzz_output" - samples = config['prompt_generation'].get('samples', 5) - + system_includes = config['prompt_generation'].get('system_includes', []) + api_prefixes = config['prompt_generation'].get('api_prefixes', []) os.makedirs(output_dir, exist_ok=True) - prompt_template = PromptTemplate(api_json) + prompt_template = PromptTemplate( + api_json, + system_includes=system_includes, + api_prefixes=api_prefixes, + ) llm = LLMCaller() - print(f"[*] 开始生成 {samples} 个 fuzz driver种子") + LOGGER.info("开始生成 %d 个 fuzz driver 种子", samples) for i in range(samples): prompt = prompt_template.generate_prompt(num_funcs=num_funcs) code = llm.generate_code(prompt) - - start = code.find("```c") - end = code.find("```", start + 4) - if start != -1 and end != -1: - code = code[start + 4:end].strip() - else: - code = code.strip() + code = extract_c_code_block(code) with open(os.path.join(output_dir, f"fuzz_driver_{i}.c"), "w") as f: f.write(code) - print(f"[*] 生成fuzz driver: fuzz_driver_{i}.c") + LOGGER.info("生成 fuzz driver: fuzz_driver_%d.c", i) - print("[*] fuzz driver生成完成") + LOGGER.info("fuzz driver 生成完成") def main(): + configure_logging() parser = argparse.ArgumentParser() parser.add_argument("--config", type=str, required=True) args = parser.parse_args() diff --git a/src/main.py b/src/main.py index a1d1ea4..56e069c 100644 --- a/src/main.py +++ b/src/main.py @@ -1,12 +1,18 @@ import argparse -import yaml import json +import logging import os +import yaml + +from dgf_common.code_utils import extract_c_code_block +from dgf_common.logging_utils import configure_logging +from dgf_feedback.feedback_controller import FeedbackController from dgf_header_parser.extractor import extract_all_api -from dgf_prompt_generator.prompt_template import PromptTemplate from dgf_prompt_generator.llm_caller import LLMCaller -from dgf_feedback.feedback_controller import FeedbackController +from dgf_prompt_generator.prompt_template import PromptTemplate + +LOGGER = logging.getLogger(__name__) def extract_api(config): """ @@ -15,13 +21,15 @@ def extract_api(config): header_dir = config['api_extraction']['header_dir'] include_dirs = config['api_extraction'].get('include_dirs', []) - print(f"[*] 开始抽取API信息 from {header_dir}") + LOGGER.info("开始抽取API信息: %s", header_dir) results = extract_all_api(header_dir, include_dirs) output_path = config['api_extraction']['extracted_api_json'] + output_parent = os.path.dirname(output_path) or "." + os.makedirs(output_parent, exist_ok=True) with open(output_path, "w") as f: json.dump(results, f, indent=2) - print(f"[*] API信息保存至 {output_path}") + LOGGER.info("API信息保存至 %s", output_path) def generate_seed_prompt(config): """ @@ -31,29 +39,29 @@ def generate_seed_prompt(config): output_dir = config['prompt_generation']['output_dir'] samples = config['prompt_generation'].get('samples', 5) num_funcs = config['prompt_generation'].get('num_funcs', 5) + system_includes = config['prompt_generation'].get('system_includes', []) + api_prefixes = config['prompt_generation'].get('api_prefixes', []) os.makedirs(output_dir, exist_ok=True) - prompt_template = PromptTemplate(api_json) + prompt_template = PromptTemplate( + api_json, + system_includes=system_includes, + api_prefixes=api_prefixes, + ) llm = LLMCaller() - print(f"[*] 开始生成 {samples} 个种子 Prompt") + LOGGER.info("开始生成 %d 个种子 Prompt", samples) for i in range(samples): prompt = prompt_template.generate_prompt(num_funcs=num_funcs) code = llm.generate_code(prompt) - - start = code.find("```c") - end = code.find("```", start + 4) - if start != -1 and end != -1: - code = code[start + 4:end].strip() - else: - code = code.strip() + code = extract_c_code_block(code) with open(os.path.join(output_dir, f"fuzz_driver_{i}.c"), "w") as f: f.write(code) - print("[*] 种子Prompt生成完成") + LOGGER.info("种子 Prompt 生成完成") def run_feedback_loop(config): """ @@ -62,13 +70,16 @@ def run_feedback_loop(config): api_json = config['api_extraction']['extracted_api_json'] output_dir = config['feedback_iteration']['output_dir'] samples_per_round = config['feedback_iteration']['samples_per_round'] + system_includes = config['prompt_generation'].get('system_includes', []) + api_prefixes = config['prompt_generation'].get('api_prefixes', []) + fuzz_timeout_sec = config['feedback_iteration'].get('fuzz_timeout_sec', 20) clang_path = config['validator']['clang_path'] include_dirs = config['validator']['include_dirs'] lib_dir = config['validator']['lib_dir'] libs = config['validator']['libs'] - print(f"[*] 开始反馈循环,输出目录: {output_dir}") + LOGGER.info("开始反馈循环,输出目录: %s", output_dir) fc = FeedbackController( api_json=api_json, @@ -76,13 +87,17 @@ def run_feedback_loop(config): clang_path=clang_path, include_dirs=include_dirs, lib_dir=lib_dir, - libs=libs + libs=libs, + system_includes=system_includes, + api_prefixes=api_prefixes, + fuzz_timeout_sec=fuzz_timeout_sec, ) - print("[*] 初始化 FeedbackController 完成") + LOGGER.info("FeedbackController 初始化完成") fc.run_iteration(num_samples=samples_per_round) - print("[*] 反馈循环完成") + LOGGER.info("反馈循环完成") def main(): + configure_logging() parser = argparse.ArgumentParser() parser.add_argument("--config", type=str, required=True) args = parser.parse_args() diff --git a/src/test.py b/src/test.py deleted file mode 100644 index e28dd61..0000000 --- a/src/test.py +++ /dev/null @@ -1 +0,0 @@ -import clang.cindex diff --git a/src/test_cjson.sh b/src/test_cjson.sh index 7b8f359..f1b7ac4 100755 --- a/src/test_cjson.sh +++ b/src/test_cjson.sh @@ -2,9 +2,8 @@ set -e -# 激活 miniconda 环境 -source ~/miniconda3/etc/profile.d/conda.sh -conda activate dgf310 +ROOT_DIR="$(cd "$(dirname "$0")/.." && pwd)" +cd "${ROOT_DIR}" # 下载 cJSON 库作为测试目标 mkdir -p testdata @@ -16,11 +15,11 @@ fi cd .. -# 设置 LIBCLANG_PATH(注意按你机器实际路径填写) -export LIBCLANG_PATH="/usr/lib/llvm-14/lib/libclang.so.1" +# 可选设置 LIBCLANG_PATH(按环境覆盖) +export LIBCLANG_PATH="${LIBCLANG_PATH:-/usr/lib/llvm-14/lib/libclang.so.1}" # 执行 header_parser 模块 -python3 header_parser/extractor.py \ +python3 src/dgf_header_parser/extractor.py \ --header_dir testdata/cJSON \ --include_dirs testdata/cJSON \ --output data/cjson_extracted.json diff --git a/tests/test_code_utils.py b/tests/test_code_utils.py new file mode 100644 index 0000000..52b37a1 --- /dev/null +++ b/tests/test_code_utils.py @@ -0,0 +1,11 @@ +from dgf_common.code_utils import extract_c_code_block + + +def test_extract_c_code_block_from_fenced_markdown(): + raw = "hello\n```c\nint x = 1;\n```\nbye" + assert extract_c_code_block(raw) == "int x = 1;" + + +def test_extract_c_code_block_fallback_to_raw_text(): + raw = "int y = 2;" + assert extract_c_code_block(raw) == "int y = 2;" diff --git a/tests/test_llm_caller.py b/tests/test_llm_caller.py new file mode 100644 index 0000000..8fb50b5 --- /dev/null +++ b/tests/test_llm_caller.py @@ -0,0 +1,37 @@ +from types import SimpleNamespace + +import pytest + +from dgf_prompt_generator import llm_caller +from dgf_prompt_generator.llm_caller import LLMCaller + + +def test_llm_caller_requires_api_key(monkeypatch): + monkeypatch.delenv("OPENAI_API_KEY", raising=False) + monkeypatch.delenv("DGF_API_KEY", raising=False) + monkeypatch.setattr(llm_caller, "local_config", None) + with pytest.raises(ValueError): + LLMCaller() + + +def test_llm_caller_generate_code(monkeypatch): + class FakeCompletions: + @staticmethod + def create(**kwargs): + _ = kwargs + return SimpleNamespace( + choices=[SimpleNamespace(message=SimpleNamespace(content="```c\nint a=0;\n```"))] + ) + + class FakeClient: + def __init__(self, api_key=None, base_url=None): + self.api_key = api_key + self.base_url = base_url + self.chat = SimpleNamespace(completions=FakeCompletions()) + + monkeypatch.setenv("OPENAI_API_KEY", "dummy") + monkeypatch.setattr(llm_caller.openai, "OpenAI", FakeClient) + + caller = LLMCaller(model="demo-model") + text = caller.generate_code("hi") + assert "int a=0;" in text