465 lines
20 KiB
Python
465 lines
20 KiB
Python
"""Multi-pass pipeline orchestrator with incremental saves and resumability."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import time
|
|
from dataclasses import dataclass, field
|
|
from datetime import datetime, timezone
|
|
|
|
from physcom.db.repository import Repository
|
|
from physcom.engine.combinator import generate_combinations
|
|
from physcom.engine.constraint_resolver import ConstraintResolver, ConstraintResult
|
|
from physcom.engine.scorer import Scorer
|
|
from physcom.llm.base import LLMProvider, LLMRateLimitError
|
|
from physcom.models.combination import Combination, ScoredResult
|
|
from physcom.models.domain import Domain
|
|
|
|
|
|
@dataclass
class PipelineResult:
    """Summary of a pipeline run.

    Counters are accumulated incrementally by Pipeline.run and pushed to the
    DB after each pass, so a partially completed run still reports progress.
    """

    # Combinations produced by the combinator before any filtering.
    total_generated: int = 0
    # Pass 1 (constraint resolution): unconditionally valid combos.
    pass1_valid: int = 0
    # Pass 1: combos rejected by the constraint resolver (status "p1_fail").
    pass1_failed: int = 0
    # Pass 1: combos that passed only conditionally.
    pass1_conditional: int = 0
    # Pass 2 (physics estimation): combos with metric estimates saved.
    pass2_estimated: int = 0
    # Pass 2: combos whose estimates were all zero (status "p2_fail").
    pass2_failed: int = 0
    # Pass 3 (scoring): combos that received a composite score.
    pass3_scored: int = 0
    # Pass 3: scored combos at or above the score threshold.
    pass3_above_threshold: int = 0
    # Pass 3: scored combos below the threshold (status "p3_fail").
    pass3_failed: int = 0
    # Pass 4 (LLM review): combos reviewed and deemed plausible.
    pass4_reviewed: int = 0
    # Pass 4: combos the LLM deemed implausible (status "p4_fail").
    pass4_failed: int = 0
    # Pass 5: never incremented by Pipeline.run — human review happens
    # outside the automated loop.
    pass5_human_reviewed: int = 0
    # Top-ranked result rows for the domain, fetched at end of run.
    top_results: list[dict] = field(default_factory=list)
|
|
|
|
|
|
class CancelledError(Exception):
    """Raised when a pipeline run is cancelled.

    Thrown by Pipeline._check_cancelled when the run's DB status has been
    set to "cancelled"; Pipeline.run catches it to finalize the run record
    and return partial results.
    """
|
|
|
|
|
|
def _describe_combination(combo: Combination) -> str:
|
|
"""Build a natural-language description of a combination."""
|
|
parts = [f"{e.dimension}: {e.name}" for e in combo.entities]
|
|
descriptions = [e.description for e in combo.entities if e.description]
|
|
header = " + ".join(parts)
|
|
detail = "; ".join(descriptions)
|
|
return f"{header}. {detail}"
|
|
|
|
|
|
class Pipeline:
    """Orchestrates the multi-pass viability pipeline."""

    def __init__(
        self,
        repo: Repository,
        resolver: ConstraintResolver,
        scorer: Scorer,
        llm: LLMProvider | None = None,
    ) -> None:
        """Store collaborators; no work happens until run() is called.

        Args:
            repo: Persistence layer for combinations, scores and run state.
            resolver: Pass-1 constraint resolver.
            scorer: Pass-3 composite scorer.
            llm: Optional LLM provider. When None, pass 2 falls back to the
                heuristic stub estimator and pass 4 (LLM review) is skipped.
        """
        self.repo = repo
        self.resolver = resolver
        self.scorer = scorer
        self.llm = llm
|
|
|
|
def _check_cancelled(self, run_id: int | None) -> None:
|
|
"""Raise CancelledError if the run has been cancelled."""
|
|
if run_id is None:
|
|
return
|
|
run = self.repo.get_pipeline_run(run_id)
|
|
if run and run["status"] == "cancelled":
|
|
raise CancelledError("Pipeline run cancelled")
|
|
|
|
def _update_run_counters(
|
|
self, run_id: int | None, result: PipelineResult, current_pass: int
|
|
) -> None:
|
|
"""Update pipeline_run progress counters in the DB."""
|
|
if run_id is None:
|
|
return
|
|
self.repo.update_pipeline_run(
|
|
run_id,
|
|
combos_pass1=result.pass1_valid
|
|
+ result.pass1_conditional
|
|
+ result.pass1_failed,
|
|
combos_pass2=result.pass2_estimated,
|
|
combos_pass3=result.pass3_scored,
|
|
combos_pass4=result.pass4_reviewed,
|
|
current_pass=current_pass,
|
|
)
|
|
|
|
    def run(
        self,
        domain: Domain,
        dimensions: list[str],
        score_threshold: float = 0.1,
        passes: list[int] | None = None,
        run_id: int | None = None,
    ) -> PipelineResult:
        """Execute the requested passes over every combination.

        Each combination is driven through passes 1-4 in sequence
        (combo-first, not pass-first). Progress is persisted after every
        pass so a crashed or cancelled run can resume: combos whose recorded
        pass_reached already covers a pass are counted but not recomputed.
        Pass 5 (human review) is never executed by this loop;
        result.pass5_human_reviewed stays 0.

        Args:
            domain: Domain whose metric bounds drive estimation and scoring.
            dimensions: Dimension names to combine.
            score_threshold: Minimum composite score to survive pass 3.
            passes: Subset of [1, 2, 3, 4, 5] to execute; defaults to all.
            run_id: Optional pipeline_run row id used for progress counters
                and cooperative cancellation; None disables both.

        Returns:
            PipelineResult with per-pass counters and top-ranked results.
        """
        if passes is None:
            passes = [1, 2, 3, 4, 5]

        result = PipelineResult()

        # Mark run as running (unless already cancelled)
        if run_id is not None:
            run_record = self.repo.get_pipeline_run(run_id)
            if run_record and run_record["status"] == "cancelled":
                # Cancelled before we started: return existing results only.
                result.top_results = self.repo.get_top_results(domain.name, limit=20)
                return result
            self.repo.update_pipeline_run(
                run_id,
                status="running",
                started_at=datetime.now(timezone.utc).isoformat(),
            )

        # Generate all combinations
        combos = generate_combinations(self.repo, dimensions)
        result.total_generated = len(combos)

        # Save all combinations to DB (also loads status for existing combos)
        for combo in combos:
            self.repo.save_combination(combo)

        if run_id is not None:
            self.repo.update_pipeline_run(run_id, total_combos=len(combos))

        # Prepare metric lookup
        metric_names = [mb.metric_name for mb in domain.metric_bounds]
        bounds_by_name = {mb.metric_name: mb for mb in domain.metric_bounds}

        # ── Combo-first loop ─────────────────────────────────────
        try:
            for combo in combos:
                # Cooperative cancellation point: raises CancelledError,
                # handled below to finalize the run record.
                self._check_cancelled(run_id)

                # Check existing progress for this combo in this domain
                existing_pass = self.repo.get_combo_pass_reached(
                    combo.id, domain.id
                ) or 0

                # Load existing result to preserve human review data
                existing_result = self.repo.get_existing_result(
                    combo.id, domain.id
                )

                # ── Pass 1: Constraint Resolution ────────────────
                if 1 in passes and existing_pass < 1:
                    cr: ConstraintResult = self.resolver.resolve(combo)
                    if cr.status == "p1_fail":
                        combo.status = "p1_fail"
                        combo.block_reason = "; ".join(cr.violations)
                        self.repo.update_combination_status(
                            combo.id, "p1_fail", combo.block_reason
                        )
                        # Save a result row so failed combos appear in results
                        self.repo.save_result(
                            combo.id,
                            domain.id,
                            composite_score=0.0,
                            pass_reached=1,
                        )
                        result.pass1_failed += 1
                        self._update_run_counters(run_id, result, current_pass=1)
                        continue  # p1_fail — skip remaining passes
                    else:
                        # Both "valid" and "conditional" are stored as
                        # "valid"; the distinction survives only in the
                        # in-memory counters.
                        combo.status = "valid"
                        self.repo.update_combination_status(combo.id, "valid")
                        if cr.status == "conditional":
                            result.pass1_conditional += 1
                        else:
                            result.pass1_valid += 1

                    self._update_run_counters(run_id, result, current_pass=1)
                elif 1 in passes:
                    # Already pass1'd — check if it failed
                    # NOTE: previously-conditional combos were persisted as
                    # "valid", so they are counted as valid on resume.
                    if combo.status.endswith("_fail"):
                        result.pass1_failed += 1
                        continue
                    else:
                        result.pass1_valid += 1
                else:
                    # Pass 1 not requested; check if failed from a prior run
                    if combo.status.endswith("_fail"):
                        result.pass1_failed += 1
                        continue

                # ── Pass 2: Physics Estimation ───────────────────
                raw_metrics: dict[str, float] = {}
                if 2 in passes and existing_pass < 2:
                    description = _describe_combination(combo)
                    if self.llm:
                        raw_metrics = self.llm.estimate_physics(
                            description, metric_names
                        )
                    else:
                        raw_metrics = self._stub_estimate(combo, metric_names)

                    # Save raw estimates immediately (crash-safe)
                    estimate_dicts = []
                    for mname, rval in raw_metrics.items():
                        mb = bounds_by_name.get(mname)
                        # Metrics without a DB id cannot be persisted.
                        if mb and mb.metric_id:
                            estimate_dicts.append({
                                "metric_id": mb.metric_id,
                                "raw_value": rval,
                                "estimation_method": "llm" if self.llm else "stub",
                                "confidence": 1.0,
                            })
                    if estimate_dicts:
                        self.repo.save_raw_estimates(
                            combo.id, domain.id, estimate_dicts
                        )

                    # Check for all-zero estimates → p2_fail
                    if raw_metrics and all(v == 0.0 for v in raw_metrics.values()):
                        combo.status = "p2_fail"
                        combo.block_reason = "All metric estimates are zero"
                        self.repo.update_combination_status(
                            combo.id, "p2_fail", combo.block_reason
                        )
                        self.repo.save_result(
                            combo.id, domain.id,
                            composite_score=0.0, pass_reached=2,
                        )
                        result.pass2_failed += 1
                        self._update_run_counters(run_id, result, current_pass=2)
                        continue

                    result.pass2_estimated += 1
                    self._update_run_counters(run_id, result, current_pass=2)
                elif 2 in passes:
                    # Already estimated — reload raw values from DB
                    existing_scores = self.repo.get_combination_scores(
                        combo.id, domain.id
                    )
                    raw_metrics = {
                        s["metric_name"]: s["raw_value"] for s in existing_scores
                    }
                    result.pass2_estimated += 1
                else:
                    # Pass 2 not requested, use empty metrics
                    raw_metrics = {}

                # ── Pass 3: Scoring & Ranking ────────────────────
                if 3 in passes and existing_pass < 3:
                    sr = self.scorer.score_combination(combo, raw_metrics)

                    # Persist per-metric scores with normalized values
                    score_dicts = []
                    for s in sr.scores:
                        mb = bounds_by_name.get(s.metric_name)
                        if mb and mb.metric_id:
                            score_dicts.append({
                                "metric_id": mb.metric_id,
                                "raw_value": s.raw_value,
                                "normalized_score": s.normalized_score,
                                "estimation_method": s.estimation_method,
                                "confidence": s.confidence,
                            })
                    if score_dicts:
                        self.repo.save_scores(combo.id, domain.id, score_dicts)

                    # Preserve existing human data
                    novelty_flag = (
                        existing_result["novelty_flag"] if existing_result else None
                    )
                    human_notes = (
                        existing_result["human_notes"] if existing_result else None
                    )

                    if sr.composite_score < score_threshold:
                        # Below threshold: persist the score anyway so the
                        # combo shows up in results, then mark it p3_fail.
                        self.repo.save_result(
                            combo.id, domain.id,
                            sr.composite_score, pass_reached=3,
                            novelty_flag=novelty_flag,
                            human_notes=human_notes,
                        )
                        combo.status = "p3_fail"
                        combo.block_reason = (
                            f"Composite score {sr.composite_score:.4f} "
                            f"below threshold {score_threshold}"
                        )
                        self.repo.update_combination_status(
                            combo.id, "p3_fail", combo.block_reason
                        )
                        result.pass3_failed += 1
                        result.pass3_scored += 1
                        self._update_run_counters(run_id, result, current_pass=3)
                        continue

                    self.repo.save_result(
                        combo.id,
                        domain.id,
                        sr.composite_score,
                        pass_reached=3,
                        novelty_flag=novelty_flag,
                        human_notes=human_notes,
                    )
                    self.repo.update_combination_status(combo.id, "scored")

                    result.pass3_scored += 1
                    result.pass3_above_threshold += 1

                    self._update_run_counters(run_id, result, current_pass=3)
                elif 3 in passes and existing_pass >= 3:
                    # Already scored — count it
                    result.pass3_scored += 1
                    if existing_result and existing_result["composite_score"] is not None:
                        if existing_result["composite_score"] >= score_threshold:
                            result.pass3_above_threshold += 1

                # ── Pass 4: LLM Review ───────────────────────────
                # Requires an LLM provider; skipped entirely otherwise.
                if 4 in passes and self.llm:
                    # Re-read progress from the DB rather than reusing
                    # existing_pass/existing_result: pass 3 above may have
                    # just written a newer result row for this combo.
                    cur_pass = self.repo.get_combo_pass_reached(
                        combo.id, domain.id
                    ) or 0
                    if cur_pass < 4:
                        cur_result = self.repo.get_existing_result(
                            combo.id, domain.id
                        )
                        # Only review combos that scored at/above threshold.
                        if (
                            cur_result
                            and cur_result["composite_score"] is not None
                            and cur_result["composite_score"] >= score_threshold
                        ):
                            description = _describe_combination(combo)
                            db_scores = self.repo.get_combination_scores(
                                combo.id, domain.id
                            )
                            score_dict = {
                                s["metric_name"]: s["normalized_score"]
                                for s in db_scores
                                if s["normalized_score"] is not None
                            }
                            review_result: tuple[str, bool] | None = None
                            try:
                                review_result = self.llm.review_plausibility(
                                    description, score_dict
                                )
                            except LLMRateLimitError as exc:
                                # Wait out the rate limit once, then retry.
                                self._wait_for_rate_limit(run_id, exc.retry_after)
                                try:
                                    review_result = self.llm.review_plausibility(
                                        description, score_dict
                                    )
                                except LLMRateLimitError:
                                    pass  # still limited; skip, retry next run
                            if review_result is not None:
                                review_text, plausible = review_result
                                if not plausible:
                                    self.repo.save_result(
                                        combo.id, domain.id,
                                        cur_result["composite_score"],
                                        pass_reached=4,
                                        novelty_flag=cur_result.get("novelty_flag"),
                                        llm_review=review_text,
                                        human_notes=cur_result.get("human_notes"),
                                    )
                                    combo.status = "p4_fail"
                                    combo.block_reason = "LLM deemed implausible"
                                    self.repo.update_combination_status(
                                        combo.id, "p4_fail", combo.block_reason
                                    )
                                    result.pass4_failed += 1
                                else:
                                    self.repo.save_result(
                                        combo.id, domain.id,
                                        cur_result["composite_score"],
                                        pass_reached=4,
                                        novelty_flag=cur_result.get("novelty_flag"),
                                        llm_review=review_text,
                                        human_notes=cur_result.get("human_notes"),
                                    )
                                    self.repo.update_combination_status(
                                        combo.id, "llm_reviewed"
                                    )
                                    result.pass4_reviewed += 1
                                self._update_run_counters(
                                    run_id, result, current_pass=4
                                )

        except CancelledError:
            # Cancelled mid-loop: finalize the run record and return the
            # partial counters plus whatever results exist so far.
            if run_id is not None:
                self.repo.update_pipeline_run(
                    run_id,
                    status="cancelled",
                    completed_at=datetime.now(timezone.utc).isoformat(),
                )
            result.top_results = self.repo.get_top_results(domain.name, limit=20)
            return result

        # Mark run as completed
        if run_id is not None:
            self.repo.update_pipeline_run(
                run_id,
                status="completed",
                completed_at=datetime.now(timezone.utc).isoformat(),
            )

        result.top_results = self.repo.get_top_results(domain.name, limit=20)
        return result
|
|
|
|
def _wait_for_rate_limit(self, run_id: int | None, retry_after: int) -> None:
|
|
"""Mark run rate_limited, sleep with cancel checks, then resume."""
|
|
if run_id is not None:
|
|
self.repo.update_pipeline_run(run_id, status="rate_limited")
|
|
waited = 0
|
|
while waited < retry_after:
|
|
time.sleep(5)
|
|
waited += 5
|
|
self._check_cancelled(run_id)
|
|
if run_id is not None:
|
|
self.repo.update_pipeline_run(run_id, status="running")
|
|
|
|
def _stub_estimate(
|
|
self, combo: Combination, metric_names: list[str]
|
|
) -> dict[str, float]:
|
|
"""Simple heuristic estimation from dependency data."""
|
|
raw: dict[str, float] = {m: 0.0 for m in metric_names}
|
|
|
|
# Extract intrinsic properties from entities
|
|
power_density = 0.0 # W/kg
|
|
energy_density = 0.0 # Wh/kg
|
|
mass_kg = 100.0 # default
|
|
for entity in combo.entities:
|
|
for dep in entity.dependencies:
|
|
if dep.key == "power_density_w_kg" and dep.constraint_type == "provides":
|
|
power_density = max(power_density, float(dep.value))
|
|
if dep.key == "energy_density_wh_kg" and dep.constraint_type == "provides":
|
|
energy_density = max(energy_density, float(dep.value))
|
|
if dep.key == "mass_kg" and dep.constraint_type == "range_min":
|
|
mass_kg = max(mass_kg, float(dep.value))
|
|
|
|
# Rough speed estimate: higher power density → faster
|
|
if "speed" in raw:
|
|
raw["speed"] = min(power_density * 0.5, 300000)
|
|
|
|
if "cost_efficiency" in raw:
|
|
raw["cost_efficiency"] = max(0.01, 2.0 - power_density / 1000)
|
|
|
|
if "safety" in raw:
|
|
raw["safety"] = 0.5
|
|
|
|
if "availability" in raw:
|
|
raw["availability"] = 0.5
|
|
|
|
if "range_fuel" in raw:
|
|
raw["range_fuel"] = min(energy_density * 10, 1e10)
|
|
|
|
if "range_degradation" in raw:
|
|
raw["range_degradation"] = 365
|
|
|
|
if "cargo_capacity" in raw:
|
|
raw["cargo_capacity"] = mass_kg * 0.5
|
|
|
|
if "cargo_capacity_kg" in raw:
|
|
raw["cargo_capacity_kg"] = mass_kg * 0.3
|
|
|
|
if "environmental_impact" in raw:
|
|
raw["environmental_impact"] = max(0.0, power_density * 0.2)
|
|
|
|
if "reliability" in raw:
|
|
raw["reliability"] = 0.5
|
|
|
|
return raw
|