domain-level constraints
This commit is contained in:
@@ -9,7 +9,7 @@ from datetime import datetime, timezone
|
||||
from typing import Sequence
|
||||
|
||||
from physcom.models.entity import Dependency, Entity
|
||||
from physcom.models.domain import Domain, MetricBound
|
||||
from physcom.models.domain import Domain, DomainConstraint, MetricBound
|
||||
from physcom.models.combination import Combination
|
||||
|
||||
|
||||
@@ -249,9 +249,25 @@ class Repository:
|
||||
(domain.id, metric_id, mb.weight, mb.norm_min, mb.norm_max,
|
||||
int(mb.lower_is_better)),
|
||||
)
|
||||
for dc in domain.constraints:
|
||||
for val in dc.allowed_values:
|
||||
self.conn.execute(
|
||||
"INSERT OR IGNORE INTO domain_constraints (domain_id, key, value) VALUES (?, ?, ?)",
|
||||
(domain.id, dc.key, val),
|
||||
)
|
||||
self.conn.commit()
|
||||
return domain
|
||||
|
||||
def _load_domain_constraints(self, domain_id: int) -> list[DomainConstraint]:
|
||||
rows = self.conn.execute(
|
||||
"SELECT key, value FROM domain_constraints WHERE domain_id = ? ORDER BY key, value",
|
||||
(domain_id,),
|
||||
).fetchall()
|
||||
by_key: dict[str, list[str]] = {}
|
||||
for r in rows:
|
||||
by_key.setdefault(r["key"], []).append(r["value"])
|
||||
return [DomainConstraint(key=k, allowed_values=v) for k, v in by_key.items()]
|
||||
|
||||
def get_domain(self, name: str) -> Domain | None:
|
||||
row = self.conn.execute("SELECT * FROM domains WHERE name = ?", (name,)).fetchone()
|
||||
if not row:
|
||||
@@ -277,6 +293,7 @@ class Repository:
|
||||
)
|
||||
for w in weights
|
||||
],
|
||||
constraints=self._load_domain_constraints(row["id"]),
|
||||
)
|
||||
|
||||
def list_domains(self) -> list[Domain]:
|
||||
@@ -308,6 +325,7 @@ class Repository:
|
||||
)
|
||||
for w in weights
|
||||
],
|
||||
constraints=self._load_domain_constraints(row["id"]),
|
||||
)
|
||||
|
||||
def update_domain(self, domain_id: int, name: str, description: str) -> None:
|
||||
@@ -359,9 +377,28 @@ class Repository:
|
||||
self.conn.execute("DELETE FROM combination_results WHERE domain_id = ?", (domain_id,))
|
||||
self.conn.execute("DELETE FROM combination_scores WHERE domain_id = ?", (domain_id,))
|
||||
self.conn.execute("DELETE FROM domain_metric_weights WHERE domain_id = ?", (domain_id,))
|
||||
self.conn.execute("DELETE FROM domain_constraints WHERE domain_id = ?", (domain_id,))
|
||||
self.conn.execute("DELETE FROM domains WHERE id = ?", (domain_id,))
|
||||
self.conn.commit()
|
||||
|
||||
def replace_domain_constraints(self, domain: Domain) -> None:
|
||||
"""Delete and re-insert domain constraints. Used by seed backfill."""
|
||||
if not domain.id:
|
||||
existing = self.conn.execute(
|
||||
"SELECT id FROM domains WHERE name = ?", (domain.name,)
|
||||
).fetchone()
|
||||
if not existing:
|
||||
return
|
||||
domain.id = existing["id"]
|
||||
self.conn.execute("DELETE FROM domain_constraints WHERE domain_id = ?", (domain.id,))
|
||||
for dc in domain.constraints:
|
||||
for val in dc.allowed_values:
|
||||
self.conn.execute(
|
||||
"INSERT OR IGNORE INTO domain_constraints (domain_id, key, value) VALUES (?, ?, ?)",
|
||||
(domain.id, dc.key, val),
|
||||
)
|
||||
self.conn.commit()
|
||||
|
||||
def reset_domain_results(self, domain_name: str) -> int:
|
||||
"""Delete all pipeline results for a domain so it can be re-run from scratch.
|
||||
|
||||
@@ -560,14 +597,15 @@ class Repository:
|
||||
novelty_flag: str | None = None,
|
||||
llm_review: str | None = None,
|
||||
human_notes: str | None = None,
|
||||
domain_block_reason: str | None = None,
|
||||
) -> None:
|
||||
self.conn.execute(
|
||||
"""INSERT OR REPLACE INTO combination_results
|
||||
(combination_id, domain_id, composite_score, novelty_flag,
|
||||
llm_review, human_notes, pass_reached)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)""",
|
||||
llm_review, human_notes, pass_reached, domain_block_reason)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||
(combo_id, domain_id, composite_score, novelty_flag,
|
||||
llm_review, human_notes, pass_reached),
|
||||
llm_review, human_notes, pass_reached, domain_block_reason),
|
||||
)
|
||||
self.conn.commit()
|
||||
|
||||
@@ -667,6 +705,7 @@ class Repository:
|
||||
"human_notes": r["human_notes"],
|
||||
"pass_reached": r["pass_reached"],
|
||||
"domain_id": r["domain_id"],
|
||||
"domain_block_reason": r["domain_block_reason"],
|
||||
}
|
||||
for r in rows
|
||||
]
|
||||
@@ -814,6 +853,7 @@ class Repository:
|
||||
self.conn.execute("DELETE FROM dependencies")
|
||||
self.conn.execute("DELETE FROM entities")
|
||||
self.conn.execute("DELETE FROM domain_metric_weights")
|
||||
self.conn.execute("DELETE FROM domain_constraints")
|
||||
self.conn.execute("DELETE FROM domains")
|
||||
self.conn.execute("DELETE FROM metrics")
|
||||
self.conn.execute("DELETE FROM dimensions")
|
||||
|
||||
@@ -81,14 +81,15 @@ CREATE TABLE IF NOT EXISTS combination_scores (
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS combination_results (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
combination_id INTEGER NOT NULL REFERENCES combinations(id),
|
||||
domain_id INTEGER NOT NULL REFERENCES domains(id),
|
||||
composite_score REAL,
|
||||
novelty_flag TEXT,
|
||||
llm_review TEXT,
|
||||
human_notes TEXT,
|
||||
pass_reached INTEGER,
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
combination_id INTEGER NOT NULL REFERENCES combinations(id),
|
||||
domain_id INTEGER NOT NULL REFERENCES domains(id),
|
||||
composite_score REAL,
|
||||
novelty_flag TEXT,
|
||||
llm_review TEXT,
|
||||
human_notes TEXT,
|
||||
pass_reached INTEGER,
|
||||
domain_block_reason TEXT,
|
||||
UNIQUE(combination_id, domain_id)
|
||||
);
|
||||
|
||||
@@ -109,6 +110,14 @@ CREATE TABLE IF NOT EXISTS pipeline_runs (
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS domain_constraints (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
domain_id INTEGER NOT NULL REFERENCES domains(id),
|
||||
key TEXT NOT NULL,
|
||||
value TEXT NOT NULL,
|
||||
UNIQUE(domain_id, key, value)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_deps_entity ON dependencies(entity_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_deps_category_key ON dependencies(category, key);
|
||||
CREATE INDEX IF NOT EXISTS idx_combo_status ON combinations(status);
|
||||
@@ -126,6 +135,26 @@ def _migrate(conn: sqlite3.Connection) -> None:
|
||||
"ALTER TABLE domain_metric_weights ADD COLUMN lower_is_better INTEGER NOT NULL DEFAULT 0"
|
||||
)
|
||||
|
||||
# Create domain_constraints table if missing (added after initial schema)
|
||||
tables = {r[0] for r in conn.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table'"
|
||||
).fetchall()}
|
||||
if "domain_constraints" not in tables:
|
||||
conn.execute("""CREATE TABLE IF NOT EXISTS domain_constraints (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
domain_id INTEGER NOT NULL REFERENCES domains(id),
|
||||
key TEXT NOT NULL,
|
||||
value TEXT NOT NULL,
|
||||
UNIQUE(domain_id, key, value)
|
||||
)""")
|
||||
|
||||
# Add domain_block_reason to combination_results if missing
|
||||
result_cols = {r[1] for r in conn.execute("PRAGMA table_info(combination_results)").fetchall()}
|
||||
if "domain_block_reason" not in result_cols:
|
||||
conn.execute(
|
||||
"ALTER TABLE combination_results ADD COLUMN domain_block_reason TEXT"
|
||||
)
|
||||
|
||||
# Backfill: cost_efficiency is lower-is-better in all domains
|
||||
conn.execute(
|
||||
"""UPDATE domain_metric_weights SET lower_is_better = 1
|
||||
|
||||
@@ -5,6 +5,7 @@ from __future__ import annotations
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from physcom.models.combination import Combination
|
||||
from physcom.models.domain import DomainConstraint
|
||||
from physcom.models.entity import Dependency
|
||||
|
||||
|
||||
@@ -158,6 +159,25 @@ class ConstraintResolver:
|
||||
f"(under-density)"
|
||||
)
|
||||
|
||||
def check_domain_constraints(
|
||||
self, combination: Combination, constraints: list[DomainConstraint]
|
||||
) -> ConstraintResult:
|
||||
"""Check if a combo's entity requirements fall within domain-allowed values."""
|
||||
result = ConstraintResult()
|
||||
for dc in constraints:
|
||||
allowed = set(dc.allowed_values)
|
||||
for entity in combination.entities:
|
||||
for dep in entity.dependencies:
|
||||
if dep.key == dc.key and dep.constraint_type == "requires":
|
||||
if dep.value not in allowed:
|
||||
result.violations.append(
|
||||
f"{entity.name} requires {dc.key}={dep.value} "
|
||||
f"but domain only allows {dc.allowed_values}"
|
||||
)
|
||||
if result.violations:
|
||||
result.status = "p1_fail"
|
||||
return result
|
||||
|
||||
def _check_unmet_requirements(
|
||||
self, all_deps: list[tuple[str, Dependency]], result: ConstraintResult
|
||||
) -> None:
|
||||
|
||||
@@ -164,6 +164,26 @@ class Pipeline:
|
||||
else:
|
||||
combo.status = "valid"
|
||||
self.repo.update_combination_status(combo.id, "valid")
|
||||
|
||||
# Domain constraint check (per-domain block only)
|
||||
if domain.constraints:
|
||||
dc_result = self.resolver.check_domain_constraints(
|
||||
combo, domain.constraints
|
||||
)
|
||||
if dc_result.status == "p1_fail":
|
||||
self.repo.save_result(
|
||||
combo.id, domain.id,
|
||||
composite_score=0.0, pass_reached=1,
|
||||
domain_block_reason="; ".join(
|
||||
dc_result.violations
|
||||
),
|
||||
)
|
||||
result.pass1_failed += 1
|
||||
self._update_run_counters(
|
||||
run_id, result, current_pass=1
|
||||
)
|
||||
continue
|
||||
|
||||
if cr.status == "conditional":
|
||||
result.pass1_conditional += 1
|
||||
else:
|
||||
@@ -175,8 +195,11 @@ class Pipeline:
|
||||
if combo.status.endswith("_fail"):
|
||||
result.pass1_failed += 1
|
||||
continue
|
||||
else:
|
||||
result.pass1_valid += 1
|
||||
# Check if domain-blocked from a prior run
|
||||
if existing_result and existing_result["pass_reached"] == 1:
|
||||
result.pass1_failed += 1
|
||||
continue
|
||||
result.pass1_valid += 1
|
||||
else:
|
||||
# Pass 1 not requested; check if failed from a prior run
|
||||
if combo.status.endswith("_fail"):
|
||||
|
||||
@@ -18,6 +18,14 @@ class MetricBound:
|
||||
metric_id: int | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class DomainConstraint:
|
||||
"""Whitelist constraint: only these values are allowed for a dependency key."""
|
||||
|
||||
key: str # dependency key, e.g. "medium"
|
||||
allowed_values: list[str] = field(default_factory=list) # e.g. ["ground", "air"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Domain:
|
||||
"""A context frame that defines what 'good' means (e.g., urban_commuting)."""
|
||||
@@ -25,4 +33,5 @@ class Domain:
|
||||
name: str
|
||||
description: str = ""
|
||||
metric_bounds: list[MetricBound] = field(default_factory=list)
|
||||
constraints: list[DomainConstraint] = field(default_factory=list)
|
||||
id: int | None = None
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from physcom.models.entity import Entity, Dependency
|
||||
from physcom.models.domain import Domain, MetricBound
|
||||
from physcom.models.domain import Domain, DomainConstraint, MetricBound
|
||||
|
||||
|
||||
# ── Platforms — Ground ──────────────────────────────────────────
|
||||
@@ -726,6 +726,7 @@ URBAN_COMMUTING = Domain(
|
||||
MetricBound("availability", weight=0.15, norm_min=0.0, norm_max=1.0, unit="0-1"),
|
||||
MetricBound("range_fuel", weight=0.10, norm_min=5000, norm_max=500000, unit="m"),
|
||||
],
|
||||
constraints=[DomainConstraint("medium", ["ground", "air"])],
|
||||
)
|
||||
|
||||
INTERPLANETARY = Domain(
|
||||
@@ -738,6 +739,7 @@ INTERPLANETARY = Domain(
|
||||
MetricBound("cost_efficiency", weight=0.10, norm_min=1.0, norm_max=1e6, unit="$/m", lower_is_better=True),
|
||||
MetricBound("range_degradation", weight=0.10, norm_min=8640000, norm_max=3.1536e9, unit="s"),
|
||||
],
|
||||
constraints=[DomainConstraint("medium", ["space"])],
|
||||
)
|
||||
|
||||
MARITIME_SHIPPING = Domain(
|
||||
@@ -750,6 +752,7 @@ MARITIME_SHIPPING = Domain(
|
||||
MetricBound("safety", weight=0.20, norm_min=0.0, norm_max=1.0, unit="0-1"),
|
||||
MetricBound("range_fuel", weight=0.15, norm_min=100000, norm_max=40000000, unit="m"),
|
||||
],
|
||||
constraints=[DomainConstraint("medium", ["water"])],
|
||||
)
|
||||
|
||||
LAST_MILE_DELIVERY = Domain(
|
||||
@@ -762,6 +765,7 @@ LAST_MILE_DELIVERY = Domain(
|
||||
MetricBound("safety", weight=0.15, norm_min=0.0, norm_max=1.0, unit="0-1"),
|
||||
MetricBound("environmental_impact", weight=0.10, norm_min=0, norm_max=5e-4, unit="kg/m", lower_is_better=True),
|
||||
],
|
||||
constraints=[DomainConstraint("medium", ["ground", "air"])],
|
||||
)
|
||||
|
||||
CROSS_COUNTRY_FREIGHT = Domain(
|
||||
@@ -774,6 +778,7 @@ CROSS_COUNTRY_FREIGHT = Domain(
|
||||
MetricBound("range_fuel", weight=0.20, norm_min=100000, norm_max=5000000, unit="m"),
|
||||
MetricBound("reliability", weight=0.10, norm_min=0.0, norm_max=1.0, unit="0-1"),
|
||||
],
|
||||
constraints=[DomainConstraint("medium", ["ground"])],
|
||||
)
|
||||
|
||||
ALL_DOMAINS = [
|
||||
@@ -831,5 +836,7 @@ def load_transport_seed(repo) -> dict:
|
||||
repo.ensure_metric(mb.metric_name, unit=mb.unit)
|
||||
if mb.lower_is_better:
|
||||
repo.backfill_lower_is_better(domain.name, mb.metric_name)
|
||||
# Backfill domain constraints
|
||||
repo.replace_domain_constraints(domain)
|
||||
|
||||
return counts
|
||||
|
||||
Reference in New Issue
Block a user