Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions src/promptquery/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,13 @@ def run_question(
show_default=True,
help="Maximum tables sent to the LLM after FK expansion.",
)
@click.option(
"--temperature",
default=0.0,
show_default=True,
type=click.FloatRange(0.0, 2.0),
help="Sampling temperature passed to the LLM (0 = deterministic, max 2.0).",
)
@click.option(
"--no-selector",
is_flag=True,
Expand All @@ -209,7 +216,7 @@ def run_question(
def main(dsn: str, model: str | None, selector_model: str | None,
query: str | None, out_format: str | None,
top_k: int, select_n: int, max_tables: int,
no_selector: bool, yes: bool) -> None:
temperature: float, no_selector: bool, yes: bool) -> None:
"""PromptQuery — natural-language SQL for Postgres.

DSN is a libpq connection string, e.g. postgresql://user:pass@host/db.
Expand All @@ -234,7 +241,7 @@ def main(dsn: str, model: str | None, selector_model: str | None,
)

try:
llm = make_client(model)
llm = make_client(model, temperature=temperature)
except LLMError as e:
progress.print(f"[red]Error:[/red] {e}")
sys.exit(1)
Expand All @@ -243,7 +250,7 @@ def main(dsn: str, model: str | None, selector_model: str | None,
selector_llm = None
else:
try:
selector_llm = make_client(selector_model) if selector_model else llm
selector_llm = make_client(selector_model, temperature=temperature) if selector_model else llm
except LLMError as e:
progress.print(f"[red]Selector LLM error:[/red] {e}")
sys.exit(1)
Expand Down
30 changes: 16 additions & 14 deletions src/promptquery/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ def generate(self, system: str, user: str) -> str: ...
class AnthropicClient(LLMClient):
name = "anthropic"

def __init__(self, model: str = "claude-sonnet-4-6", api_key: str | None = None):
def __init__(self, model: str = "claude-sonnet-4-6", api_key: str | None = None,
temperature: float = 0):
try:
import anthropic
except ImportError as e:
Expand All @@ -29,13 +30,13 @@ def __init__(self, model: str = "claude-sonnet-4-6", api_key: str | None = None)
raise LLMError("ANTHROPIC_API_KEY is not set")
self._client = anthropic.Anthropic(api_key=key)
self.model = model
self.temperature = temperature

def generate(self, system: str, user: str) -> str:
response = self._client.messages.create(
model=self.model,
max_tokens=2000,
# Determinism is a feature: the same question should yield the same SQL.
temperature=0,
temperature=self.temperature,
system=system,
messages=[{"role": "user", "content": user}],
)
Expand All @@ -50,7 +51,8 @@ def generate(self, system: str, user: str) -> str:
class OpenAIClient(LLMClient):
name = "openai"

def __init__(self, model: str = "gpt-4o", api_key: str | None = None):
def __init__(self, model: str = "gpt-4o", api_key: str | None = None,
temperature: float = 0):
try:
import openai
except ImportError as e:
Expand All @@ -60,6 +62,7 @@ def __init__(self, model: str = "gpt-4o", api_key: str | None = None):
raise LLMError("OPENAI_API_KEY is not set")
self._client = openai.OpenAI(api_key=key)
self.model = model
self.temperature = temperature

# Reasoning-class OpenAI models (GPT-5.x, o1, o3, o4) accept
# `max_completion_tokens` and reject the legacy `max_tokens` parameter.
Expand All @@ -75,12 +78,11 @@ def generate(self, system: str, user: str) -> str:
}
if self.model.startswith(self._REASONING_PREFIXES):
# Reasoning models reject `temperature`/`seed`; they sample internally.
# self.temperature is intentionally not forwarded here.
kwargs["max_completion_tokens"] = 4000
else:
kwargs["max_tokens"] = 2000
# Determinism is a feature: same question -> same SQL. temperature=0
# plus a fixed seed gives best-effort reproducibility on chat models.
kwargs["temperature"] = 0
kwargs["temperature"] = self.temperature
kwargs["seed"] = 0
response = self._client.chat.completions.create(**kwargs)
return response.choices[0].message.content or ""
Expand All @@ -96,29 +98,29 @@ def extract_sql(text: str) -> str:
return text.strip().rstrip(";").strip()


def make_client(model_spec: str | None = None) -> LLMClient:
def make_client(model_spec: str | None = None, temperature: float = 0) -> LLMClient:
if model_spec:
provider, _, model = model_spec.partition("/")
if not model:
# No explicit provider — guess from model name
if provider.startswith("claude"):
return AnthropicClient(model=provider)
return AnthropicClient(model=provider, temperature=temperature)
if provider.startswith("gpt") or provider.startswith("o1") or provider.startswith("o3"):
return OpenAIClient(model=provider)
return OpenAIClient(model=provider, temperature=temperature)
raise LLMError(
f"Cannot infer provider from model {provider!r}. "
"Use 'anthropic/<model>' or 'openai/<model>'."
)
if provider == "anthropic":
return AnthropicClient(model=model)
return AnthropicClient(model=model, temperature=temperature)
if provider == "openai":
return OpenAIClient(model=model)
return OpenAIClient(model=model, temperature=temperature)
raise LLMError(f"Unknown provider: {provider!r}")

if os.environ.get("ANTHROPIC_API_KEY"):
return AnthropicClient()
return AnthropicClient(temperature=temperature)
if os.environ.get("OPENAI_API_KEY"):
return OpenAIClient()
return OpenAIClient(temperature=temperature)
raise LLMError(
"No LLM API key found. Set ANTHROPIC_API_KEY or OPENAI_API_KEY."
)