Add pluggable LLM support with Gemini provider
- Add LLMProvider registry (llm/registry.py) that builds a provider from env vars (LLM_PROVIDER, GEMINI_API_KEY, GEMINI_MODEL)
- Add GeminiLLMProvider using the google-genai SDK
- Wire build_llm_provider() into CLI and web pipeline route (replacing llm=None)
- Wrap pass 2 and pass 4 LLM calls in per-combo try/except so API errors skip individual combos rather than aborting the whole run
- Add gemini optional dep to pyproject.toml; Dockerfile installs [web,gemini]
- Document env vars in .env.example and README
- Lower requires-python to >=3.10 to match installed system Python

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -135,7 +135,8 @@ def run(ctx, domain_name, passes, threshold, dimensions):
|
||||
pass_list = [int(p.strip()) for p in passes.split(",")]
|
||||
dim_list = [d.strip() for d in dimensions.split(",")]
|
||||
|
||||
pipeline = Pipeline(repo, resolver, scorer, llm=None)
|
||||
from physcom.llm.registry import build_llm_provider
|
||||
pipeline = Pipeline(repo, resolver, scorer, llm=build_llm_provider())
|
||||
click.echo(f"Running pipeline for '{domain_name}' (passes={pass_list}, threshold={threshold})")
|
||||
click.echo(f"Dimensions: {dim_list}")
|
||||
|
||||
@@ -209,16 +210,12 @@ def review(ctx, combination_id):
|
||||
notes = click.prompt("Human notes (or empty)", default="")
|
||||
|
||||
if novelty != "skip" or notes:
|
||||
# Get all domains this combo has results for
|
||||
rows = repo.conn.execute(
|
||||
"SELECT domain_id, composite_score FROM combination_results WHERE combination_id = ?",
|
||||
(combo.id,),
|
||||
).fetchall()
|
||||
for row in rows:
|
||||
for row in repo.get_results_for_combination(combo.id):
|
||||
repo.save_result(
|
||||
combo.id, row["domain_id"], row["composite_score"],
|
||||
pass_reached=5,
|
||||
novelty_flag=novelty if novelty != "skip" else None,
|
||||
llm_review=row.get("llm_review"),
|
||||
human_notes=notes or None,
|
||||
)
|
||||
repo.update_combination_status(combo.id, "reviewed")
|
||||
|
||||
@@ -184,9 +184,12 @@ class Pipeline:
|
||||
if 2 in passes and existing_pass < 2:
|
||||
description = _describe_combination(combo)
|
||||
if self.llm:
|
||||
raw_metrics = self.llm.estimate_physics(
|
||||
description, metric_names
|
||||
)
|
||||
try:
|
||||
raw_metrics = self.llm.estimate_physics(
|
||||
description, metric_names
|
||||
)
|
||||
except Exception:
|
||||
raw_metrics = self._stub_estimate(combo, metric_names)
|
||||
else:
|
||||
raw_metrics = self._stub_estimate(combo, metric_names)
|
||||
|
||||
@@ -284,32 +287,34 @@ class Pipeline:
|
||||
and cur_result["composite_score"] is not None
|
||||
and cur_result["composite_score"] >= score_threshold
|
||||
):
|
||||
description = _describe_combination(combo)
|
||||
db_scores = self.repo.get_combination_scores(
|
||||
combo.id, domain.id
|
||||
)
|
||||
score_dict = {
|
||||
s["metric_name"]: s["normalized_score"]
|
||||
for s in db_scores
|
||||
if s["normalized_score"] is not None
|
||||
}
|
||||
review = self.llm.review_plausibility(
|
||||
description, score_dict
|
||||
)
|
||||
|
||||
self.repo.save_result(
|
||||
combo.id,
|
||||
domain.id,
|
||||
cur_result["composite_score"],
|
||||
pass_reached=4,
|
||||
novelty_flag=cur_result.get("novelty_flag"),
|
||||
llm_review=review,
|
||||
human_notes=cur_result.get("human_notes"),
|
||||
)
|
||||
result.pass4_reviewed += 1
|
||||
self._update_run_counters(
|
||||
run_id, result, current_pass=4
|
||||
)
|
||||
try:
|
||||
description = _describe_combination(combo)
|
||||
db_scores = self.repo.get_combination_scores(
|
||||
combo.id, domain.id
|
||||
)
|
||||
score_dict = {
|
||||
s["metric_name"]: s["normalized_score"]
|
||||
for s in db_scores
|
||||
if s["normalized_score"] is not None
|
||||
}
|
||||
review = self.llm.review_plausibility(
|
||||
description, score_dict
|
||||
)
|
||||
self.repo.save_result(
|
||||
combo.id,
|
||||
domain.id,
|
||||
cur_result["composite_score"],
|
||||
pass_reached=4,
|
||||
novelty_flag=cur_result.get("novelty_flag"),
|
||||
llm_review=review,
|
||||
human_notes=cur_result.get("human_notes"),
|
||||
)
|
||||
result.pass4_reviewed += 1
|
||||
self._update_run_counters(
|
||||
run_id, result, current_pass=4
|
||||
)
|
||||
except Exception:
|
||||
pass # skip this combo; don't abort the run
|
||||
|
||||
except CancelledError:
|
||||
if run_id is not None:
|
||||
|
||||
57
src/physcom/llm/providers/gemini.py
Normal file
57
src/physcom/llm/providers/gemini.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""Gemini LLM provider via google-genai SDK."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
|
||||
from physcom.llm.base import LLMProvider
|
||||
from physcom.llm.prompts import PHYSICS_ESTIMATION_PROMPT, PLAUSIBILITY_REVIEW_PROMPT
|
||||
|
||||
|
||||
class GeminiLLMProvider(LLMProvider):
    """LLM provider backed by Google Gemini.

    Uses the google-genai SDK's ``Client.models.generate_content`` for both
    physics-metric estimation and plausibility review. The SDK is an optional
    dependency (``physcom[gemini]``); importing it is deferred to ``__init__``
    so the package imports cleanly without it installed.
    """

    def __init__(self, api_key: str, model: str = "gemini-2.0-flash") -> None:
        """Create a Gemini-backed provider.

        Args:
            api_key: Google AI Studio API key.
            model: Gemini model name (default: gemini-2.0-flash).

        Raises:
            ImportError: if the google-genai package is not installed.
        """
        try:
            from google import genai
        except ImportError as exc:
            # Chain the original error so the missing-package traceback is kept.
            raise ImportError(
                "google-genai is required: pip install 'physcom[gemini]'"
            ) from exc
        self._client = genai.Client(api_key=api_key)
        self._model = model

    def estimate_physics(
        self, combination_description: str, metrics: list[str]
    ) -> dict[str, float]:
        """Ask Gemini for numeric estimates of *metrics* for a combination.

        Returns a dict with an entry for every requested metric; values the
        model omitted or garbled fall back to a neutral 0.5.
        """
        prompt = PHYSICS_ESTIMATION_PROMPT.format(
            description=combination_description,
            metrics=", ".join(metrics),
        )
        response = self._client.models.generate_content(
            model=self._model, contents=prompt
        )
        # response.text is Optional in the SDK; treat None as an empty reply.
        return self._parse_json(response.text or "", metrics)

    def review_plausibility(
        self, combination_description: str, scores: dict[str, float]
    ) -> str:
        """Ask Gemini for a free-text plausibility review of scored metrics."""
        scores_str = "\n".join(f"- {k}: {v:.3f}" for k, v in scores.items())
        prompt = PLAUSIBILITY_REVIEW_PROMPT.format(
            description=combination_description,
            scores=scores_str,
        )
        response = self._client.models.generate_content(
            model=self._model, contents=prompt
        )
        # response.text is Optional in the SDK; treat None as an empty reply.
        return (response.text or "").strip()

    def _parse_json(self, text: str, metrics: list[str]) -> dict[str, float]:
        """Strip markdown fences and parse JSON; fall back to 0.5 per metric.

        Robust to: empty/None-ish replies, non-JSON text, JSON that is not an
        object (AttributeError on ``.items()``), and non-numeric values. Every
        requested metric is guaranteed a value in the returned dict.
        """
        text = re.sub(r"```(?:json)?\s*", "", text).strip().rstrip("`").strip()
        try:
            data = json.loads(text)
            parsed = {k: float(v) for k, v in data.items() if k in metrics}
        except (json.JSONDecodeError, ValueError, TypeError, AttributeError):
            parsed = {}
        # Fill in any metric the model omitted so callers never hit a KeyError.
        return {m: parsed.get(m, 0.5) for m in metrics}
|
||||
30
src/physcom/llm/registry.py
Normal file
30
src/physcom/llm/registry.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""Build an LLMProvider from environment variables."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
from physcom.llm.base import LLMProvider
|
||||
|
||||
|
||||
def build_llm_provider() -> LLMProvider | None:
    """Construct an LLMProvider from environment variables.

    Recognized variables:
        LLM_PROVIDER   — provider name ('gemini'; more can be added)
        GEMINI_API_KEY — required when LLM_PROVIDER=gemini
        GEMINI_MODEL   — optional Gemini model name (default: gemini-2.0-flash)

    Returns None when LLM_PROVIDER is unset or empty; raises ValueError for
    an unknown provider name or a missing required key.
    """
    name = os.environ.get("LLM_PROVIDER", "").lower().strip()
    if not name:
        # No provider configured — caller runs with stub estimates.
        return None

    if name != "gemini":
        raise ValueError(f"Unknown LLM_PROVIDER: {name!r}. Supported: gemini")

    key = os.environ.get("GEMINI_API_KEY", "")
    if not key:
        raise ValueError("LLM_PROVIDER=gemini requires GEMINI_API_KEY to be set")

    # Import lazily so the optional google-genai dependency is only needed
    # when Gemini is actually selected.
    from physcom.llm.providers.gemini import GeminiLLMProvider

    return GeminiLLMProvider(
        api_key=key,
        model=os.environ.get("GEMINI_MODEL", "gemini-2.0-flash"),
    )
|
||||
Reference in New Issue
Block a user