From 3762cdc8d5f55da63b3c1b4d1694a28aa5580239 Mon Sep 17 00:00:00 2001 From: Priyansh2116 Date: Fri, 5 Jun 2026 01:21:47 +0530 Subject: [PATCH 1/2] feat: add --temperature CLI flag (default 0) Expose the LLM sampling temperature as a `--temperature` Click option (float, default 0.0) so callers can trade determinism for creativity. Threads through `make_client()` into both `AnthropicClient` and `OpenAIClient`; behaviour is unchanged when the flag is omitted. Closes #1 Co-Authored-By: Claude Sonnet 4.6 --- src/promptquery/cli.py | 13 ++++++++++--- src/promptquery/llm.py | 29 +++++++++++++++-------------- 2 files changed, 25 insertions(+), 17 deletions(-) diff --git a/src/promptquery/cli.py b/src/promptquery/cli.py index 2644c15..118b8ab 100644 --- a/src/promptquery/cli.py +++ b/src/promptquery/cli.py @@ -194,6 +194,13 @@ def run_question( show_default=True, help="Maximum tables sent to the LLM after FK expansion.", ) +@click.option( + "--temperature", + default=0.0, + show_default=True, + type=float, + help="Sampling temperature passed to the LLM (0 = deterministic).", +) @click.option( "--no-selector", is_flag=True, @@ -209,7 +216,7 @@ def run_question( def main(dsn: str, model: str | None, selector_model: str | None, query: str | None, out_format: str | None, top_k: int, select_n: int, max_tables: int, - no_selector: bool, yes: bool) -> None: + temperature: float, no_selector: bool, yes: bool) -> None: """PromptQuery — natural-language SQL for Postgres. DSN is a libpq connection string, e.g. postgresql://user:pass@host/db. 
@@ -234,7 +241,7 @@ def main(dsn: str, model: str | None, selector_model: str | None, ) try: - llm = make_client(model) + llm = make_client(model, temperature=temperature) except LLMError as e: progress.print(f"[red]Error:[/red] {e}") sys.exit(1) @@ -243,7 +250,7 @@ def main(dsn: str, model: str | None, selector_model: str | None, selector_llm = None else: try: - selector_llm = make_client(selector_model) if selector_model else llm + selector_llm = make_client(selector_model, temperature=temperature) if selector_model else llm except LLMError as e: progress.print(f"[red]Selector LLM error:[/red] {e}") sys.exit(1) diff --git a/src/promptquery/llm.py b/src/promptquery/llm.py index de30afa..08bf9d5 100644 --- a/src/promptquery/llm.py +++ b/src/promptquery/llm.py @@ -19,7 +19,8 @@ def generate(self, system: str, user: str) -> str: ... class AnthropicClient(LLMClient): name = "anthropic" - def __init__(self, model: str = "claude-sonnet-4-6", api_key: str | None = None): + def __init__(self, model: str = "claude-sonnet-4-6", api_key: str | None = None, + temperature: float = 0): try: import anthropic except ImportError as e: @@ -29,13 +30,13 @@ def __init__(self, model: str = "claude-sonnet-4-6", api_key: str | None = None) raise LLMError("ANTHROPIC_API_KEY is not set") self._client = anthropic.Anthropic(api_key=key) self.model = model + self.temperature = temperature def generate(self, system: str, user: str) -> str: response = self._client.messages.create( model=self.model, max_tokens=2000, - # Determinism is a feature: the same question should yield the same SQL. 
- temperature=0, + temperature=self.temperature, system=system, messages=[{"role": "user", "content": user}], ) @@ -50,7 +51,8 @@ def generate(self, system: str, user: str) -> str: class OpenAIClient(LLMClient): name = "openai" - def __init__(self, model: str = "gpt-4o", api_key: str | None = None): + def __init__(self, model: str = "gpt-4o", api_key: str | None = None, + temperature: float = 0): try: import openai except ImportError as e: @@ -60,6 +62,7 @@ def __init__(self, model: str = "gpt-4o", api_key: str | None = None): raise LLMError("OPENAI_API_KEY is not set") self._client = openai.OpenAI(api_key=key) self.model = model + self.temperature = temperature # Reasoning-class OpenAI models (GPT-5.x, o1, o3, o4) accept # `max_completion_tokens` and reject the legacy `max_tokens` parameter. @@ -78,9 +81,7 @@ def generate(self, system: str, user: str) -> str: kwargs["max_completion_tokens"] = 4000 else: kwargs["max_tokens"] = 2000 - # Determinism is a feature: same question -> same SQL. temperature=0 - # plus a fixed seed gives best-effort reproducibility on chat models. 
- kwargs["temperature"] = 0 + kwargs["temperature"] = self.temperature kwargs["seed"] = 0 response = self._client.chat.completions.create(**kwargs) return response.choices[0].message.content or "" @@ -96,29 +97,29 @@ def extract_sql(text: str) -> str: return text.strip().rstrip(";").strip() -def make_client(model_spec: str | None = None) -> LLMClient: +def make_client(model_spec: str | None = None, temperature: float = 0) -> LLMClient: if model_spec: provider, _, model = model_spec.partition("/") if not model: # No explicit provider — guess from model name if provider.startswith("claude"): - return AnthropicClient(model=provider) + return AnthropicClient(model=provider, temperature=temperature) if provider.startswith("gpt") or provider.startswith("o1") or provider.startswith("o3"): - return OpenAIClient(model=provider) + return OpenAIClient(model=provider, temperature=temperature) raise LLMError( f"Cannot infer provider from model {provider!r}. " "Use 'anthropic/' or 'openai/'." ) if provider == "anthropic": - return AnthropicClient(model=model) + return AnthropicClient(model=model, temperature=temperature) if provider == "openai": - return OpenAIClient(model=model) + return OpenAIClient(model=model, temperature=temperature) raise LLMError(f"Unknown provider: {provider!r}") if os.environ.get("ANTHROPIC_API_KEY"): - return AnthropicClient() + return AnthropicClient(temperature=temperature) if os.environ.get("OPENAI_API_KEY"): - return OpenAIClient() + return OpenAIClient(temperature=temperature) raise LLMError( "No LLM API key found. Set ANTHROPIC_API_KEY or OPENAI_API_KEY." 
) From aa5c0fa22872d13119106ae62e068ffa3b86995f Mon Sep 17 00:00:00 2001 From: Priyansh2116 Date: Fri, 5 Jun 2026 14:22:06 +0530 Subject: [PATCH 2/2] fix: use FloatRange(0,2) for --temperature and note reasoning-model skip - Switch type=float to type=click.FloatRange(0.0, 2.0) so values outside 0.0-2.0 produce a clean CLI error instead of an opaque provider 400 (caveat: the Anthropic API only accepts 0.0-1.0, so values above 1.0 can still be rejected by that provider) - Add a comment in the OpenAI reasoning-model branch explaining why self.temperature is not forwarded there (those models reject it) Co-Authored-By: Claude Sonnet 4.6 --- src/promptquery/cli.py | 4 ++-- src/promptquery/llm.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/promptquery/cli.py b/src/promptquery/cli.py index 118b8ab..133f6f9 100644 --- a/src/promptquery/cli.py +++ b/src/promptquery/cli.py @@ -198,8 +198,8 @@ def run_question( "--temperature", default=0.0, show_default=True, - type=float, - help="Sampling temperature passed to the LLM (0 = deterministic).", + type=click.FloatRange(0.0, 2.0), + help="Sampling temperature passed to the LLM (0 = deterministic, max 2.0).", ) @click.option( "--no-selector", diff --git a/src/promptquery/llm.py b/src/promptquery/llm.py index 08bf9d5..e3572c5 100644 --- a/src/promptquery/llm.py +++ b/src/promptquery/llm.py @@ -78,6 +78,7 @@ def generate(self, system: str, user: str) -> str: } if self.model.startswith(self._REASONING_PREFIXES): # Reasoning models reject `temperature`/`seed`; they sample internally. + # self.temperature is intentionally not forwarded here. kwargs["max_completion_tokens"] = 4000 else: kwargs["max_tokens"] = 2000