"""Multi-pass pipeline orchestrator with incremental saves and resumability.""" from __future__ import annotations import time from dataclasses import dataclass, field from datetime import datetime, timezone from physcom.db.repository import Repository from physcom.engine.combinator import generate_combinations from physcom.engine.constraint_resolver import ConstraintResolver, ConstraintResult from physcom.engine.scorer import Scorer from physcom.llm.base import LLMProvider, LLMRateLimitError from physcom.models.combination import Combination, ScoredResult from physcom.models.domain import Domain @dataclass class PipelineResult: """Summary of a pipeline run.""" total_generated: int = 0 pass1_valid: int = 0 pass1_failed: int = 0 pass1_conditional: int = 0 pass2_estimated: int = 0 pass2_failed: int = 0 pass3_scored: int = 0 pass3_above_threshold: int = 0 pass3_failed: int = 0 pass4_reviewed: int = 0 pass4_failed: int = 0 pass5_human_reviewed: int = 0 top_results: list[dict] = field(default_factory=list) class CancelledError(Exception): """Raised when a pipeline run is cancelled.""" def _describe_combination(combo: Combination) -> str: """Build a natural-language description of a combination.""" parts = [f"{e.dimension}: {e.name}" for e in combo.entities] descriptions = [e.description for e in combo.entities if e.description] header = " + ".join(parts) detail = "; ".join(descriptions) return f"{header}. {detail}" class Pipeline: """Orchestrates the multi-pass viability pipeline.""" def __init__( self, repo: Repository, resolver: ConstraintResolver, scorer: Scorer, llm: LLMProvider | None = None, ) -> None: self.repo = repo self.resolver = resolver self.scorer = scorer self.llm = llm def _check_cancelled(self, run_id: int | None) -> None: """Raise CancelledError if the run has been cancelled.""" if run_id is None: return run = self.repo.get_pipeline_run(run_id) if run and run["status"] == "cancelled": raise CancelledError("Pipeline run cancelled") def _update_run_counters( self, run_id: int | None, result: PipelineResult, current_pass: int ) -> None: """Update pipeline_run progress counters in the DB.""" if run_id is None: return self.repo.update_pipeline_run( run_id, combos_pass1=result.pass1_valid + result.pass1_conditional + result.pass1_failed, combos_pass2=result.pass2_estimated, combos_pass3=result.pass3_scored, combos_pass4=result.pass4_reviewed, current_pass=current_pass, ) def run( self, domain: Domain, dimensions: list[str], score_threshold: float = 0.1, passes: list[int] | None = None, run_id: int | None = None, ) -> PipelineResult: if passes is None: passes = [1, 2, 3, 4, 5] result = PipelineResult() # Mark run as running (unless already cancelled) if run_id is not None: run_record = self.repo.get_pipeline_run(run_id) if run_record and run_record["status"] == "cancelled": result.top_results = self.repo.get_top_results(domain.name, limit=20) return result self.repo.update_pipeline_run( run_id, status="running", started_at=datetime.now(timezone.utc).isoformat(), ) # Generate all combinations combos = generate_combinations(self.repo, dimensions) result.total_generated = len(combos) # Save all combinations to DB (also loads status for existing combos) for combo in combos: self.repo.save_combination(combo) if run_id is not None: self.repo.update_pipeline_run(run_id, total_combos=len(combos)) # Prepare metric lookup metric_names = [mb.metric_name for mb in domain.metric_bounds] bounds_by_name = {mb.metric_name: mb for mb in domain.metric_bounds} # ── Combo-first loop ───────────────────────────────────── try: for combo in combos: self._check_cancelled(run_id) # Check existing progress for this combo in this domain existing_pass = self.repo.get_combo_pass_reached( combo.id, domain.id ) or 0 # Load existing result to preserve human review data existing_result = self.repo.get_existing_result( combo.id, domain.id ) # ── Pass 1: Constraint Resolution ──────────────── if 1 in passes and existing_pass < 1: cr: ConstraintResult = self.resolver.resolve(combo) if cr.status == "p1_fail": combo.status = "p1_fail" combo.block_reason = "; ".join(cr.violations) self.repo.update_combination_status( combo.id, "p1_fail", combo.block_reason ) # Save a result row so failed combos appear in results self.repo.save_result( combo.id, domain.id, composite_score=0.0, pass_reached=1, ) result.pass1_failed += 1 self._update_run_counters(run_id, result, current_pass=1) continue # p1_fail — skip remaining passes else: combo.status = "valid" self.repo.update_combination_status(combo.id, "valid") if cr.status == "conditional": result.pass1_conditional += 1 else: result.pass1_valid += 1 self._update_run_counters(run_id, result, current_pass=1) elif 1 in passes: # Already pass1'd — check if it failed if combo.status.endswith("_fail"): result.pass1_failed += 1 continue else: result.pass1_valid += 1 else: # Pass 1 not requested; check if failed from a prior run if combo.status.endswith("_fail"): result.pass1_failed += 1 continue # ── Pass 2: Physics Estimation ─────────────────── raw_metrics: dict[str, float] = {} if 2 in passes and existing_pass < 2: description = _describe_combination(combo) if self.llm: raw_metrics = self.llm.estimate_physics( description, metric_names ) else: raw_metrics = self._stub_estimate(combo, metric_names) # Save raw estimates immediately (crash-safe) estimate_dicts = [] for mname, rval in raw_metrics.items(): mb = bounds_by_name.get(mname) if mb and mb.metric_id: estimate_dicts.append({ "metric_id": mb.metric_id, "raw_value": rval, "estimation_method": "llm" if self.llm else "stub", "confidence": 1.0, }) if estimate_dicts: self.repo.save_raw_estimates( combo.id, domain.id, estimate_dicts ) # Check for all-zero estimates → p2_fail if raw_metrics and all(v == 0.0 for v in raw_metrics.values()): combo.status = "p2_fail" combo.block_reason = "All metric estimates are zero" self.repo.update_combination_status( combo.id, "p2_fail", combo.block_reason ) self.repo.save_result( combo.id, domain.id, composite_score=0.0, pass_reached=2, ) result.pass2_failed += 1 self._update_run_counters(run_id, result, current_pass=2) continue result.pass2_estimated += 1 self._update_run_counters(run_id, result, current_pass=2) elif 2 in passes: # Already estimated — reload raw values from DB existing_scores = self.repo.get_combination_scores( combo.id, domain.id ) raw_metrics = { s["metric_name"]: s["raw_value"] for s in existing_scores } result.pass2_estimated += 1 else: # Pass 2 not requested, use empty metrics raw_metrics = {} # ── Pass 3: Scoring & Ranking ──────────────────── if 3 in passes and existing_pass < 3: sr = self.scorer.score_combination(combo, raw_metrics) # Persist per-metric scores with normalized values score_dicts = [] for s in sr.scores: mb = bounds_by_name.get(s.metric_name) if mb and mb.metric_id: score_dicts.append({ "metric_id": mb.metric_id, "raw_value": s.raw_value, "normalized_score": s.normalized_score, "estimation_method": s.estimation_method, "confidence": s.confidence, }) if score_dicts: self.repo.save_scores(combo.id, domain.id, score_dicts) # Preserve existing human data novelty_flag = ( existing_result["novelty_flag"] if existing_result else None ) human_notes = ( existing_result["human_notes"] if existing_result else None ) if sr.composite_score < score_threshold: self.repo.save_result( combo.id, domain.id, sr.composite_score, pass_reached=3, novelty_flag=novelty_flag, human_notes=human_notes, ) combo.status = "p3_fail" combo.block_reason = ( f"Composite score {sr.composite_score:.4f} " f"below threshold {score_threshold}" ) self.repo.update_combination_status( combo.id, "p3_fail", combo.block_reason ) result.pass3_failed += 1 result.pass3_scored += 1 self._update_run_counters(run_id, result, current_pass=3) continue self.repo.save_result( combo.id, domain.id, sr.composite_score, pass_reached=3, novelty_flag=novelty_flag, human_notes=human_notes, ) self.repo.update_combination_status(combo.id, "scored") result.pass3_scored += 1 result.pass3_above_threshold += 1 self._update_run_counters(run_id, result, current_pass=3) elif 3 in passes and existing_pass >= 3: # Already scored — count it result.pass3_scored += 1 if existing_result and existing_result["composite_score"] is not None: if existing_result["composite_score"] >= score_threshold: result.pass3_above_threshold += 1 # ── Pass 4: LLM Review ─────────────────────────── if 4 in passes and self.llm: cur_pass = self.repo.get_combo_pass_reached( combo.id, domain.id ) or 0 if cur_pass < 4: cur_result = self.repo.get_existing_result( combo.id, domain.id ) if ( cur_result and cur_result["composite_score"] is not None and cur_result["composite_score"] >= score_threshold ): description = _describe_combination(combo) db_scores = self.repo.get_combination_scores( combo.id, domain.id ) score_dict = { s["metric_name"]: s["normalized_score"] for s in db_scores if s["normalized_score"] is not None } review_result: tuple[str, bool] | None = None try: review_result = self.llm.review_plausibility( description, score_dict ) except LLMRateLimitError as exc: self._wait_for_rate_limit(run_id, exc.retry_after) try: review_result = self.llm.review_plausibility( description, score_dict ) except LLMRateLimitError: pass # still limited; skip, retry next run if review_result is not None: review_text, plausible = review_result if not plausible: self.repo.save_result( combo.id, domain.id, cur_result["composite_score"], pass_reached=4, novelty_flag=cur_result.get("novelty_flag"), llm_review=review_text, human_notes=cur_result.get("human_notes"), ) combo.status = "p4_fail" combo.block_reason = "LLM deemed implausible" self.repo.update_combination_status( combo.id, "p4_fail", combo.block_reason ) result.pass4_failed += 1 else: self.repo.save_result( combo.id, domain.id, cur_result["composite_score"], pass_reached=4, novelty_flag=cur_result.get("novelty_flag"), llm_review=review_text, human_notes=cur_result.get("human_notes"), ) self.repo.update_combination_status( combo.id, "llm_reviewed" ) result.pass4_reviewed += 1 self._update_run_counters( run_id, result, current_pass=4 ) except CancelledError: if run_id is not None: self.repo.update_pipeline_run( run_id, status="cancelled", completed_at=datetime.now(timezone.utc).isoformat(), ) result.top_results = self.repo.get_top_results(domain.name, limit=20) return result # Mark run as completed if run_id is not None: self.repo.update_pipeline_run( run_id, status="completed", completed_at=datetime.now(timezone.utc).isoformat(), ) result.top_results = self.repo.get_top_results(domain.name, limit=20) return result def _wait_for_rate_limit(self, run_id: int | None, retry_after: int) -> None: """Mark run rate_limited, sleep with cancel checks, then resume.""" if run_id is not None: self.repo.update_pipeline_run(run_id, status="rate_limited") waited = 0 while waited < retry_after: time.sleep(5) waited += 5 self._check_cancelled(run_id) if run_id is not None: self.repo.update_pipeline_run(run_id, status="running") def _stub_estimate( self, combo: Combination, metric_names: list[str] ) -> dict[str, float]: """Simple heuristic estimation from dependency data.""" raw: dict[str, float] = {m: 0.0 for m in metric_names} # Extract intrinsic properties from entities power_density = 0.0 # W/kg energy_density = 0.0 # Wh/kg mass_kg = 100.0 # default for entity in combo.entities: for dep in entity.dependencies: if dep.key == "power_density_w_kg" and dep.constraint_type == "provides": power_density = max(power_density, float(dep.value)) if dep.key == "energy_density_wh_kg" and dep.constraint_type == "provides": energy_density = max(energy_density, float(dep.value)) if dep.key == "mass_kg" and dep.constraint_type == "range_min": mass_kg = max(mass_kg, float(dep.value)) # Rough speed estimate: higher power density → faster if "speed" in raw: raw["speed"] = min(power_density * 0.5, 300000) if "cost_efficiency" in raw: raw["cost_efficiency"] = max(0.01, 2.0 - power_density / 1000) if "safety" in raw: raw["safety"] = 0.5 if "availability" in raw: raw["availability"] = 0.5 if "range_fuel" in raw: raw["range_fuel"] = min(energy_density * 10, 1e10) if "range_degradation" in raw: raw["range_degradation"] = 365 if "cargo_capacity" in raw: raw["cargo_capacity"] = mass_kg * 0.5 if "cargo_capacity_kg" in raw: raw["cargo_capacity_kg"] = mass_kg * 0.3 if "environmental_impact" in raw: raw["environmental_impact"] = max(0.0, power_density * 0.2) if "reliability" in raw: raw["reliability"] = 0.5 return raw