"""Multi-pass pipeline orchestrator with incremental saves and resumability.""" from __future__ import annotations from dataclasses import dataclass, field from datetime import datetime, timezone from physcom.db.repository import Repository from physcom.engine.combinator import generate_combinations from physcom.engine.constraint_resolver import ConstraintResolver, ConstraintResult from physcom.engine.scorer import Scorer from physcom.llm.base import LLMProvider, LLMRateLimitError from physcom.models.combination import Combination, ScoredResult from physcom.models.domain import Domain @dataclass class PipelineResult: """Summary of a pipeline run.""" total_generated: int = 0 pass1_valid: int = 0 pass1_blocked: int = 0 pass1_conditional: int = 0 pass2_estimated: int = 0 pass3_scored: int = 0 pass3_above_threshold: int = 0 pass4_reviewed: int = 0 pass5_human_reviewed: int = 0 top_results: list[dict] = field(default_factory=list) class CancelledError(Exception): """Raised when a pipeline run is cancelled.""" def _describe_combination(combo: Combination) -> str: """Build a natural-language description of a combination.""" parts = [f"{e.dimension}: {e.name}" for e in combo.entities] descriptions = [e.description for e in combo.entities if e.description] header = " + ".join(parts) detail = "; ".join(descriptions) return f"{header}. {detail}" class Pipeline: """Orchestrates the multi-pass viability pipeline.""" def __init__( self, repo: Repository, resolver: ConstraintResolver, scorer: Scorer, llm: LLMProvider | None = None, ) -> None: self.repo = repo self.resolver = resolver self.scorer = scorer self.llm = llm def _check_cancelled(self, run_id: int | None) -> None: """Raise CancelledError if the run has been cancelled.""" if run_id is None: return run = self.repo.get_pipeline_run(run_id) if run and run["status"] == "cancelled": raise CancelledError("Pipeline run cancelled") def _update_run_counters( self, run_id: int | None, result: PipelineResult, current_pass: int ) -> None: """Update pipeline_run progress counters in the DB.""" if run_id is None: return self.repo.update_pipeline_run( run_id, combos_pass1=result.pass1_valid + result.pass1_conditional + result.pass1_blocked, combos_pass2=result.pass2_estimated, combos_pass3=result.pass3_scored, combos_pass4=result.pass4_reviewed, current_pass=current_pass, ) def run( self, domain: Domain, dimensions: list[str], score_threshold: float = 0.1, passes: list[int] | None = None, run_id: int | None = None, ) -> PipelineResult: if passes is None: passes = [1, 2, 3, 4, 5] result = PipelineResult() # Mark run as running (unless already cancelled) if run_id is not None: run_record = self.repo.get_pipeline_run(run_id) if run_record and run_record["status"] == "cancelled": result.top_results = self.repo.get_top_results(domain.name, limit=20) return result self.repo.update_pipeline_run( run_id, status="running", started_at=datetime.now(timezone.utc).isoformat(), ) # Generate all combinations combos = generate_combinations(self.repo, dimensions) result.total_generated = len(combos) # Save all combinations to DB (also loads status for existing combos) for combo in combos: self.repo.save_combination(combo) if run_id is not None: self.repo.update_pipeline_run(run_id, total_combos=len(combos)) # Prepare metric lookup metric_names = [mb.metric_name for mb in domain.metric_bounds] bounds_by_name = {mb.metric_name: mb for mb in domain.metric_bounds} # ── Combo-first loop ───────────────────────────────────── try: for combo in combos: self._check_cancelled(run_id) # Check existing progress for this combo in this domain existing_pass = self.repo.get_combo_pass_reached( combo.id, domain.id ) or 0 # Load existing result to preserve human review data existing_result = self.repo.get_existing_result( combo.id, domain.id ) # ── Pass 1: Constraint Resolution ──────────────── if 1 in passes and existing_pass < 1: cr: ConstraintResult = self.resolver.resolve(combo) if cr.status == "blocked": combo.status = "blocked" combo.block_reason = "; ".join(cr.violations) self.repo.update_combination_status( combo.id, "blocked", combo.block_reason ) # Save a result row so blocked combos appear in results self.repo.save_result( combo.id, domain.id, composite_score=0.0, pass_reached=1, ) result.pass1_blocked += 1 self._update_run_counters(run_id, result, current_pass=1) continue # blocked — skip remaining passes else: combo.status = "valid" self.repo.update_combination_status(combo.id, "valid") if cr.status == "conditional": result.pass1_conditional += 1 else: result.pass1_valid += 1 self._update_run_counters(run_id, result, current_pass=1) elif 1 in passes: # Already pass1'd — check if it was blocked if combo.status == "blocked": result.pass1_blocked += 1 continue else: result.pass1_valid += 1 else: # Pass 1 not requested; check if blocked from a prior run if combo.status == "blocked": result.pass1_blocked += 1 continue # ── Pass 2: Physics Estimation ─────────────────── raw_metrics: dict[str, float] = {} if 2 in passes and existing_pass < 2: description = _describe_combination(combo) if self.llm: raw_metrics = self.llm.estimate_physics( description, metric_names ) else: raw_metrics = self._stub_estimate(combo, metric_names) # Save raw estimates immediately (crash-safe) estimate_dicts = [] for mname, rval in raw_metrics.items(): mb = bounds_by_name.get(mname) if mb and mb.metric_id: estimate_dicts.append({ "metric_id": mb.metric_id, "raw_value": rval, "estimation_method": "llm" if self.llm else "stub", "confidence": 1.0, }) if estimate_dicts: self.repo.save_raw_estimates( combo.id, domain.id, estimate_dicts ) result.pass2_estimated += 1 self._update_run_counters(run_id, result, current_pass=2) elif 2 in passes: # Already estimated — reload raw values from DB existing_scores = self.repo.get_combination_scores( combo.id, domain.id ) raw_metrics = { s["metric_name"]: s["raw_value"] for s in existing_scores } result.pass2_estimated += 1 else: # Pass 2 not requested, use empty metrics raw_metrics = {} # ── Pass 3: Scoring & Ranking ──────────────────── if 3 in passes and existing_pass < 3: sr = self.scorer.score_combination(combo, raw_metrics) # Persist per-metric scores with normalized values score_dicts = [] for s in sr.scores: mb = bounds_by_name.get(s.metric_name) if mb and mb.metric_id: score_dicts.append({ "metric_id": mb.metric_id, "raw_value": s.raw_value, "normalized_score": s.normalized_score, "estimation_method": s.estimation_method, "confidence": s.confidence, }) if score_dicts: self.repo.save_scores(combo.id, domain.id, score_dicts) # Preserve existing human data novelty_flag = ( existing_result["novelty_flag"] if existing_result else None ) human_notes = ( existing_result["human_notes"] if existing_result else None ) self.repo.save_result( combo.id, domain.id, sr.composite_score, pass_reached=3, novelty_flag=novelty_flag, human_notes=human_notes, ) self.repo.update_combination_status(combo.id, "scored") result.pass3_scored += 1 if sr.composite_score >= score_threshold: result.pass3_above_threshold += 1 self._update_run_counters(run_id, result, current_pass=3) elif 3 in passes and existing_pass >= 3: # Already scored — count it result.pass3_scored += 1 if existing_result and existing_result["composite_score"] is not None: if existing_result["composite_score"] >= score_threshold: result.pass3_above_threshold += 1 # ── Pass 4: LLM Review ─────────────────────────── if 4 in passes and self.llm: cur_pass = self.repo.get_combo_pass_reached( combo.id, domain.id ) or 0 if cur_pass < 4: cur_result = self.repo.get_existing_result( combo.id, domain.id ) if ( cur_result and cur_result["composite_score"] is not None and cur_result["composite_score"] >= score_threshold ): description = _describe_combination(combo) db_scores = self.repo.get_combination_scores( combo.id, domain.id ) score_dict = { s["metric_name"]: s["normalized_score"] for s in db_scores if s["normalized_score"] is not None } review = self.llm.review_plausibility( description, score_dict ) self.repo.save_result( combo.id, domain.id, cur_result["composite_score"], pass_reached=4, novelty_flag=cur_result.get("novelty_flag"), llm_review=review, human_notes=cur_result.get("human_notes"), ) result.pass4_reviewed += 1 self._update_run_counters( run_id, result, current_pass=4 ) except CancelledError: if run_id is not None: self.repo.update_pipeline_run( run_id, status="cancelled", completed_at=datetime.now(timezone.utc).isoformat(), ) result.top_results = self.repo.get_top_results(domain.name, limit=20) return result except LLMRateLimitError: # Rate limit hit — save progress and let the user re-run to continue. # Already-reviewed combos are persisted; resumability skips them next time. if run_id is not None: self.repo.update_pipeline_run( run_id, status="completed", completed_at=datetime.now(timezone.utc).isoformat(), ) result.top_results = self.repo.get_top_results(domain.name, limit=20) return result # Mark run as completed if run_id is not None: self.repo.update_pipeline_run( run_id, status="completed", completed_at=datetime.now(timezone.utc).isoformat(), ) result.top_results = self.repo.get_top_results(domain.name, limit=20) return result def _stub_estimate( self, combo: Combination, metric_names: list[str] ) -> dict[str, float]: """Simple heuristic estimation from dependency data.""" raw: dict[str, float] = {m: 0.0 for m in metric_names} # Extract force output from power source force_watts = 0.0 mass_kg = 100.0 # default for entity in combo.entities: for dep in entity.dependencies: if dep.key == "force_output_watts" and dep.constraint_type == "provides": force_watts = max(force_watts, float(dep.value)) if dep.key == "min_mass_kg" and dep.constraint_type == "range_min": mass_kg = max(mass_kg, float(dep.value)) # Rough speed estimate: F=ma -> v proportional to power/mass if "speed" in raw and mass_kg > 0: raw["speed"] = min(force_watts / mass_kg * 0.5, 300000) if "cost_efficiency" in raw: raw["cost_efficiency"] = max(0.01, 2.0 - force_watts / 100000) if "safety" in raw: raw["safety"] = 0.5 if "availability" in raw: raw["availability"] = 0.5 if "range_fuel" in raw: raw["range_fuel"] = min(force_watts * 0.01, 1e10) if "range_degradation" in raw: raw["range_degradation"] = 365 return raw