- Load m.unit in get_domain() so MetricBound carries units from DB
- Add Unit column to domains list template
- Make load_transport_seed() idempotent with IntegrityError handling and metric unit backfill for existing DBs
- Remove unused imports (json, sqlite3, Entity)
- Simplify combinator loop to list comprehension
- Merge duplicate conditional/valid branches in pipeline
- Consolidate duplicated SQL in get_all_results()
- Expand CLAUDE.md with fuller architecture docs and conventions

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
303 lines
10 KiB
Python
303 lines
10 KiB
Python
"""Tests for async pipeline: resume, cancellation, status guard, run lifecycle."""
|
|
|
|
from physcom.engine.constraint_resolver import ConstraintResolver
|
|
from physcom.engine.scorer import Scorer
|
|
from physcom.engine.pipeline import Pipeline, CancelledError
|
|
|
|
|
|
def test_pipeline_run_lifecycle(seeded_repo):
    """A tracked run starts out ``pending`` and is ``completed`` after ``Pipeline.run``.

    Only the two end states are observed here; the intermediate ``running``
    state is never read back by this test.
    """
    repo = seeded_repo
    domain = repo.get_domain("urban_commuting")
    run_config = {
        "passes": [1, 2, 3],
        "threshold": 0.1,
        "dimensions": ["platform", "power_source"],
    }
    run_id = repo.create_pipeline_run(domain.id, run_config)

    # A freshly created run record has not been started yet.
    assert repo.get_pipeline_run(run_id)["status"] == "pending"

    pipeline = Pipeline(repo, ConstraintResolver(), Scorer(domain))
    pipeline.run(domain, ["platform", "power_source"], passes=[1, 2, 3], run_id=run_id)

    finished = repo.get_pipeline_run(run_id)
    assert finished["status"] == "completed"
    # NOTE(review): 81 is presumably the platform x power_source product of
    # the seed data -- confirm against the seeded_repo fixture.
    assert finished["total_combos"] == 81
    assert finished["started_at"] is not None
    assert finished["completed_at"] is not None
|
|
|
|
|
|
def test_pipeline_run_failed(seeded_repo):
    """A run record updated to ``failed`` keeps that status and its error message.

    No pipeline is executed: the failure is written directly through
    ``update_pipeline_run`` and then read back.
    """
    repo = seeded_repo
    domain = repo.get_domain("urban_commuting")
    run_id = repo.create_pipeline_run(
        domain.id,
        {"passes": [1], "threshold": 0.1, "dimensions": ["platform", "power_source"]},
    )

    # Stand-in for what the web route does when the pipeline raises.
    repo.update_pipeline_run(run_id, status="failed", error_message="Test error")

    failed_run = repo.get_pipeline_run(run_id)
    assert failed_run["status"] == "failed"
    assert failed_run["error_message"] == "Test error"
|
|
|
|
|
|
def test_resume_skips_completed_combos(seeded_repo):
    """Re-running the same passes on the same domain should skip already-completed combos.

    Two tracked runs with identical arguments are executed back to back on
    the same repository.  The assertions only pin down that the second run
    generates as many combinations as the first and ends with status
    ``completed``; the skipping itself is not observed directly.
    """
    repo = seeded_repo
    domain = repo.get_domain("urban_commuting")

    resolver = ConstraintResolver()
    scorer = Scorer(domain)
    pipeline = Pipeline(repo, resolver, scorer)

    # First run: passes 1-3
    run_id_1 = repo.create_pipeline_run(domain.id, {"passes": [1, 2, 3]})
    result1 = pipeline.run(
        domain, ["platform", "power_source"],
        score_threshold=0.01, passes=[1, 2, 3], run_id=run_id_1,
    )
    # Guard: if nothing was estimated the first time, a re-run proves nothing.
    assert result1.pass2_estimated > 0

    # Second run: same passes -- should skip all combos (already pass_reached >= 3)
    run_id_2 = repo.create_pipeline_run(domain.id, {"passes": [1, 2, 3]})
    result2 = pipeline.run(
        domain, ["platform", "power_source"],
        score_threshold=0.01, passes=[1, 2, 3], run_id=run_id_2,
    )
    # pass2_estimated still counted (reloaded from DB) but no new estimation work.
    # The key thing: the run completes successfully.
    assert result2.total_generated == result1.total_generated
    run2 = repo.get_pipeline_run(run_id_2)
    assert run2["status"] == "completed"
|
|
|
|
|
|
def test_cancellation_stops_processing(seeded_repo):
    """A run already flagged ``cancelled`` should end without estimating any combo.

    The run record is moved to ``running`` and then ``cancelled`` before
    ``Pipeline.run`` is invoked, so the pipeline sees the cancellation from
    the start.
    """
    repo = seeded_repo
    domain = repo.get_domain("urban_commuting")
    pipeline = Pipeline(repo, ConstraintResolver(), Scorer(domain))

    run_id = repo.create_pipeline_run(domain.id, {"passes": [1, 2, 3]})

    # Pre-cancel: walk the record through running -> cancelled up front.
    for status in ("running", "cancelled"):
        repo.update_pipeline_run(run_id, status=status)

    outcome = pipeline.run(
        domain,
        ["platform", "power_source"],
        score_threshold=0.01,
        passes=[1, 2, 3],
        run_id=run_id,
    )

    # The cancelled status must survive the pipeline call ...
    assert repo.get_pipeline_run(run_id)["status"] == "cancelled"
    # ... and no combo may have reached the estimation pass.
    assert outcome.pass2_estimated == 0
|
|
|
|
|
|
def test_status_guard_no_downgrade_reviewed(seeded_repo):
    """update_combination_status should not downgrade 'reviewed' to 'scored'.

    A combination scored by the pipeline is forced to ``reviewed`` with a
    direct SQL update, then ``update_combination_status`` is asked to set it
    back to ``scored``; the stored status must remain ``reviewed``.
    """
    repo = seeded_repo
    domain = repo.get_domain("urban_commuting")

    resolver = ConstraintResolver()
    scorer = Scorer(domain)
    pipeline = Pipeline(repo, resolver, scorer)

    # Run pipeline to get scored combos; only its side effects on the
    # repository are needed, so the return value is not kept.
    pipeline.run(
        domain, ["platform", "power_source"],
        score_threshold=0.01, passes=[1, 2, 3],
    )

    # Find a scored combo and manually mark it as reviewed
    scored_combos = repo.list_combinations(status="scored")
    assert len(scored_combos) > 0

    combo = scored_combos[0]
    # Written with raw SQL so the status guard under test is bypassed here.
    repo.conn.execute(
        "UPDATE combinations SET status = 'reviewed' WHERE id = ?", (combo.id,)
    )
    repo.conn.commit()

    # Attempt to downgrade to 'scored'
    repo.update_combination_status(combo.id, "scored")

    # Should still be 'reviewed'
    reloaded = repo.get_combination(combo.id)
    assert reloaded.status == "reviewed"
|
|
|
|
|
|
def test_human_notes_preserved_on_rerun(seeded_repo):
    """Notes attached to a result must still be there after the pipeline re-runs.

    Steps: run the pipeline, attach a note to the first result, reset that
    result's ``pass_reached`` to 0, run the pipeline again, and check the note.
    """
    repo = seeded_repo
    domain = repo.get_domain("urban_commuting")
    pipeline = Pipeline(repo, ConstraintResolver(), Scorer(domain))

    def run_all_passes():
        # Both pipeline invocations in this test use identical arguments.
        pipeline.run(
            domain,
            ["platform", "power_source"],
            score_threshold=0.01,
            passes=[1, 2, 3],
        )

    # Initial run populates the results table.
    run_all_passes()

    existing = repo.get_all_results(domain.name)
    assert len(existing) > 0
    picked = existing[0]
    combo_id = picked["combination"].id
    domain_id = picked["domain_id"]

    # Re-save the same result, this time carrying a human note.
    repo.save_result(
        combo_id,
        domain_id,
        picked["composite_score"],
        pass_reached=picked["pass_reached"],
        novelty_flag=picked["novelty_flag"],
        human_notes="Important human insight",
    )

    # Reset pass_reached so the re-run processes this combo again.
    repo.conn.execute(
        """UPDATE combination_results SET pass_reached = 0
           WHERE combination_id = ? AND domain_id = ?""",
        (combo_id, domain_id),
    )
    repo.conn.commit()

    # Second run over the same data.
    run_all_passes()

    # The note written above must have survived.
    after = repo.get_existing_result(combo_id, domain_id)
    assert after["human_notes"] == "Important human insight"
|
|
|
|
|
|
def test_list_pipeline_runs(seeded_repo):
    """list_pipeline_runs should return runs for a domain or all domains.

    Two runs are created for one domain; both the unfiltered listing and the
    listing filtered by ``domain_id`` must contain at least two entries, and
    every filtered entry must carry that domain's id.
    """
    repo = seeded_repo
    domain = repo.get_domain("urban_commuting")

    # Only the existence of the two runs matters; their ids are not needed.
    repo.create_pipeline_run(domain.id, {"passes": [1]})
    repo.create_pipeline_run(domain.id, {"passes": [1, 2, 3]})

    all_runs = repo.list_pipeline_runs()
    assert len(all_runs) >= 2

    domain_runs = repo.list_pipeline_runs(domain_id=domain.id)
    assert len(domain_runs) >= 2
    assert all(r["domain_id"] == domain.id for r in domain_runs)
|
|
|
|
|
|
def test_get_combo_pass_reached(seeded_repo):
    """``get_combo_pass_reached`` reports 3 for a scored combo and ``None`` for an unknown id."""
    repo = seeded_repo
    domain = repo.get_domain("urban_commuting")
    pipeline = Pipeline(repo, ConstraintResolver(), Scorer(domain))

    pipeline.run(
        domain,
        ["platform", "power_source"],
        score_threshold=0.01,
        passes=[1, 2, 3],
    )

    # Any combination that ended up with status "scored" will do.
    scored = repo.list_combinations(status="scored")
    assert len(scored) > 0
    sample = scored[0]

    assert repo.get_combo_pass_reached(sample.id, domain.id) == 3

    # An id with no stored result yields None.
    assert repo.get_combo_pass_reached(99999, domain.id) is None
|
|
|
|
|
|
def test_blocked_combos_have_results(seeded_repo):
    """Blocked combinations must have result rows alongside the scored ones.

    Checks that the number of result rows equals blocked + scored, and that
    every blocked row has ``pass_reached == 1`` and ``composite_score == 0.0``.
    """
    repo = seeded_repo
    domain = repo.get_domain("urban_commuting")
    pipeline = Pipeline(repo, ConstraintResolver(), Scorer(domain))

    summary = pipeline.run(
        domain,
        ["platform", "power_source"],
        score_threshold=0.01,
        passes=[1, 2, 3],
    )

    # The check below is only meaningful if something was blocked.
    assert summary.pass1_blocked > 0

    # Every combo, blocked or scored, should be backed by a result row:
    # blocked ones stop at pass 1, the others go through to pass 3.
    rows = repo.get_all_results(domain.name)
    assert len(rows) == summary.pass1_blocked + summary.pass3_scored

    # Narrow down to the rows whose combination is marked blocked.
    blocked_rows = []
    for row in rows:
        if row["combination"].status == "blocked":
            blocked_rows.append(row)
    assert len(blocked_rows) == summary.pass1_blocked

    for row in blocked_rows:
        assert row["pass_reached"] == 1
        assert row["composite_score"] == 0.0
|
|
|
|
|
|
def test_all_passes_run_and_tracked(seeded_repo):
    """Running passes 1-3 leaves nonzero per-pass counters on the run record.

    The pass-2 and pass-3 counters must also match the counts on the object
    returned by ``Pipeline.run``.
    """
    repo = seeded_repo
    domain = repo.get_domain("urban_commuting")
    pipeline = Pipeline(repo, ConstraintResolver(), Scorer(domain))

    run_id = repo.create_pipeline_run(domain.id, {"passes": [1, 2, 3]})
    summary = pipeline.run(
        domain,
        ["platform", "power_source"],
        score_threshold=0.01,
        passes=[1, 2, 3],
        run_id=run_id,
    )

    record = repo.get_pipeline_run(run_id)
    assert record["combos_pass1"] > 0, "Pass 1 counter should be nonzero"
    assert record["combos_pass2"] > 0, "Pass 2 counter should be nonzero"
    assert record["combos_pass3"] > 0, "Pass 3 counter should be nonzero"

    # Pass 2 counts valid + conditional combos (blocked ones are not estimated).
    assert record["combos_pass2"] == summary.pass2_estimated
    # Pass 3 counts every scored combo, not only those above the threshold.
    assert record["combos_pass3"] == summary.pass3_scored
|
|
|
|
|
|
def test_save_combination_loads_existing_status(seeded_repo):
    """Saving a combo that already exists returns the stored id, status and block reason."""
    from physcom.models.combination import Combination

    repo = seeded_repo

    # One entity from each dimension is enough to form a combination.
    platforms = repo.list_entities(dimension="platform")
    power_sources = repo.list_entities(dimension="power_source")
    entities = platforms[:1] + power_sources[:1]

    first = repo.save_combination(Combination(entities=entities))
    assert first.status == "pending"

    # Flip the stored row to blocked.
    repo.update_combination_status(first.id, "blocked", "test reason")

    # A second save of the same entities should surface the stored state.
    second = repo.save_combination(Combination(entities=entities))
    assert second.id == first.id
    assert second.status == "blocked"
    assert second.block_reason == "test reason"
|