861 lines
37 KiB
Python
861 lines
37 KiB
Python
"""CRUD operations for all entities."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import hashlib
|
|
import json
|
|
import sqlite3
|
|
from datetime import datetime, timezone
|
|
from typing import Sequence
|
|
|
|
from physcom.models.entity import Dependency, Entity
|
|
from physcom.models.domain import Domain, DomainConstraint, MetricBound
|
|
from physcom.models.combination import Combination
|
|
|
|
|
|
class Repository:
|
|
"""Thin data-access layer over the SQLite database."""
|
|
|
|
def __init__(self, conn: sqlite3.Connection) -> None:
|
|
self.conn = conn
|
|
self.conn.row_factory = sqlite3.Row
|
|
|
|
# ── Dimensions ──────────────────────────────────────────────
|
|
|
|
def ensure_dimension(self, name: str, description: str = "") -> int:
|
|
"""Insert dimension if it doesn't exist, return its id."""
|
|
cur = self.conn.execute(
|
|
"INSERT OR IGNORE INTO dimensions (name, description) VALUES (?, ?)",
|
|
(name, description),
|
|
)
|
|
if cur.lastrowid and cur.rowcount:
|
|
self.conn.commit()
|
|
return cur.lastrowid
|
|
row = self.conn.execute(
|
|
"SELECT id FROM dimensions WHERE name = ?", (name,)
|
|
).fetchone()
|
|
return row["id"]
|
|
|
|
def list_dimensions(self) -> list[dict]:
|
|
rows = self.conn.execute("SELECT * FROM dimensions ORDER BY name").fetchall()
|
|
return [dict(r) for r in rows]
|
|
|
|
# ── Entities ────────────────────────────────────────────────
|
|
|
|
def add_entity(self, entity: Entity) -> Entity:
|
|
"""Persist an Entity (and its dependencies). Returns it with id set."""
|
|
dim_id = self.ensure_dimension(entity.dimension)
|
|
cur = self.conn.execute(
|
|
"INSERT INTO entities (dimension_id, name, description) VALUES (?, ?, ?)",
|
|
(dim_id, entity.name, entity.description),
|
|
)
|
|
entity.id = cur.lastrowid
|
|
entity.dimension_id = dim_id
|
|
for dep in entity.dependencies:
|
|
dep_cur = self.conn.execute(
|
|
"""INSERT INTO dependencies
|
|
(entity_id, category, key, value, unit, constraint_type)
|
|
VALUES (?, ?, ?, ?, ?, ?)""",
|
|
(entity.id, dep.category, dep.key, dep.value, dep.unit, dep.constraint_type),
|
|
)
|
|
dep.id = dep_cur.lastrowid
|
|
self.conn.commit()
|
|
return entity
|
|
|
|
def get_entity(self, entity_id: int) -> Entity | None:
    """Load one entity, including its dependencies, by primary key.

    Returns ``None`` when no entity with *entity_id* exists.
    """
    query = """SELECT e.id, e.name, e.description, d.name as dimension, e.dimension_id
        FROM entities e JOIN dimensions d ON e.dimension_id = d.id
        WHERE e.id = ?"""
    found = self.conn.execute(query, (entity_id,)).fetchone()
    if found is None:
        return None
    dependencies = self._load_dependencies(found["id"])
    # A NULL description is normalised to the empty string.
    description = found["description"] or ""
    return Entity(
        id=found["id"],
        name=found["name"],
        description=description,
        dimension=found["dimension"],
        dimension_id=found["dimension_id"],
        dependencies=dependencies,
    )
|
|
|
|
def list_entities(self, dimension: str | None = None) -> list[Entity]:
|
|
if dimension:
|
|
rows = self.conn.execute(
|
|
"""SELECT e.id, e.name, e.description, d.name as dimension, e.dimension_id
|
|
FROM entities e JOIN dimensions d ON e.dimension_id = d.id
|
|
WHERE d.name = ? ORDER BY e.name""",
|
|
(dimension,),
|
|
).fetchall()
|
|
else:
|
|
rows = self.conn.execute(
|
|
"""SELECT e.id, e.name, e.description, d.name as dimension, e.dimension_id
|
|
FROM entities e JOIN dimensions d ON e.dimension_id = d.id
|
|
ORDER BY d.name, e.name"""
|
|
).fetchall()
|
|
entities = []
|
|
for r in rows:
|
|
deps = self._load_dependencies(r["id"])
|
|
entities.append(Entity(
|
|
id=r["id"], name=r["name"], description=r["description"] or "",
|
|
dimension=r["dimension"], dimension_id=r["dimension_id"],
|
|
dependencies=deps,
|
|
))
|
|
return entities
|
|
|
|
def _load_dependencies(self, entity_id: int) -> list[Dependency]:
|
|
rows = self.conn.execute(
|
|
"SELECT * FROM dependencies WHERE entity_id = ?", (entity_id,)
|
|
).fetchall()
|
|
return [
|
|
Dependency(
|
|
id=r["id"], category=r["category"], key=r["key"],
|
|
value=r["value"], unit=r["unit"], constraint_type=r["constraint_type"],
|
|
)
|
|
for r in rows
|
|
]
|
|
|
|
def update_entity(self, entity_id: int, name: str, description: str) -> None:
|
|
self.conn.execute(
|
|
"UPDATE entities SET name = ?, description = ? WHERE id = ?",
|
|
(name, description, entity_id),
|
|
)
|
|
self.conn.commit()
|
|
|
|
def delete_entity(self, entity_id: int) -> None:
|
|
combo_ids = [
|
|
r["combination_id"]
|
|
for r in self.conn.execute(
|
|
"SELECT combination_id FROM combination_entities WHERE entity_id = ?",
|
|
(entity_id,),
|
|
).fetchall()
|
|
]
|
|
if combo_ids:
|
|
ph = ",".join("?" * len(combo_ids))
|
|
self.conn.execute(f"DELETE FROM combination_results WHERE combination_id IN ({ph})", combo_ids)
|
|
self.conn.execute(f"DELETE FROM combination_scores WHERE combination_id IN ({ph})", combo_ids)
|
|
self.conn.execute(f"DELETE FROM combination_entities WHERE combination_id IN ({ph})", combo_ids)
|
|
self.conn.execute(f"DELETE FROM combinations WHERE id IN ({ph})", combo_ids)
|
|
self.conn.execute("DELETE FROM dependencies WHERE entity_id = ?", (entity_id,))
|
|
self.conn.execute("DELETE FROM entities WHERE id = ?", (entity_id,))
|
|
self.conn.commit()
|
|
|
|
def add_dependency(self, entity_id: int, dep: Dependency) -> Dependency:
|
|
cur = self.conn.execute(
|
|
"""INSERT INTO dependencies
|
|
(entity_id, category, key, value, unit, constraint_type)
|
|
VALUES (?, ?, ?, ?, ?, ?)""",
|
|
(entity_id, dep.category, dep.key, dep.value, dep.unit, dep.constraint_type),
|
|
)
|
|
dep.id = cur.lastrowid
|
|
self.conn.commit()
|
|
return dep
|
|
|
|
def replace_entity_dependencies(self, entity_id: int, deps: list[Dependency]) -> None:
|
|
"""Delete all existing dependencies for an entity and insert new ones."""
|
|
self.conn.execute("DELETE FROM dependencies WHERE entity_id = ?", (entity_id,))
|
|
for dep in deps:
|
|
cur = self.conn.execute(
|
|
"""INSERT INTO dependencies
|
|
(entity_id, category, key, value, unit, constraint_type)
|
|
VALUES (?, ?, ?, ?, ?, ?)""",
|
|
(entity_id, dep.category, dep.key, dep.value, dep.unit, dep.constraint_type),
|
|
)
|
|
dep.id = cur.lastrowid
|
|
self.conn.commit()
|
|
|
|
def get_entity_by_name(self, dimension: str, name: str) -> Entity | None:
|
|
row = self.conn.execute(
|
|
"""SELECT e.id, e.name, e.description, d.name as dimension, e.dimension_id
|
|
FROM entities e JOIN dimensions d ON e.dimension_id = d.id
|
|
WHERE d.name = ? AND e.name = ?""",
|
|
(dimension, name),
|
|
).fetchone()
|
|
if not row:
|
|
return None
|
|
deps = self._load_dependencies(row["id"])
|
|
return Entity(
|
|
id=row["id"], name=row["name"], description=row["description"] or "",
|
|
dimension=row["dimension"], dimension_id=row["dimension_id"],
|
|
dependencies=deps,
|
|
)
|
|
|
|
def update_dependency(self, dep_id: int, dep: Dependency) -> None:
|
|
self.conn.execute(
|
|
"""UPDATE dependencies
|
|
SET category = ?, key = ?, value = ?, unit = ?, constraint_type = ?
|
|
WHERE id = ?""",
|
|
(dep.category, dep.key, dep.value, dep.unit, dep.constraint_type, dep_id),
|
|
)
|
|
self.conn.commit()
|
|
|
|
def delete_dependency(self, dep_id: int) -> None:
|
|
self.conn.execute("DELETE FROM dependencies WHERE id = ?", (dep_id,))
|
|
self.conn.commit()
|
|
|
|
def get_dependency(self, dep_id: int) -> Dependency | None:
|
|
row = self.conn.execute(
|
|
"SELECT * FROM dependencies WHERE id = ?", (dep_id,)
|
|
).fetchone()
|
|
if not row:
|
|
return None
|
|
return Dependency(
|
|
id=row["id"], category=row["category"], key=row["key"],
|
|
value=row["value"], unit=row["unit"], constraint_type=row["constraint_type"],
|
|
)
|
|
|
|
# ── Domains & Metrics ───────────────────────────────────────
|
|
|
|
def ensure_metric(self, name: str, unit: str = "", description: str = "") -> int:
|
|
self.conn.execute(
|
|
"INSERT OR IGNORE INTO metrics (name, unit, description) VALUES (?, ?, ?)",
|
|
(name, unit, description),
|
|
)
|
|
if unit:
|
|
self.conn.execute(
|
|
"UPDATE metrics SET unit = ? WHERE name = ? AND (unit IS NULL OR unit = '')",
|
|
(unit, name),
|
|
)
|
|
row = self.conn.execute("SELECT id FROM metrics WHERE name = ?", (name,)).fetchone()
|
|
self.conn.commit()
|
|
return row["id"]
|
|
|
|
def backfill_lower_is_better(self, domain_name: str, metric_name: str) -> None:
|
|
"""Set lower_is_better=1 for an existing domain-metric row that still has the default 0."""
|
|
self.conn.execute(
|
|
"""UPDATE domain_metric_weights SET lower_is_better = 1
|
|
WHERE lower_is_better = 0
|
|
AND domain_id = (SELECT id FROM domains WHERE name = ?)
|
|
AND metric_id = (SELECT id FROM metrics WHERE name = ?)""",
|
|
(domain_name, metric_name),
|
|
)
|
|
self.conn.commit()
|
|
|
|
def add_domain(self, domain: Domain) -> Domain:
|
|
cur = self.conn.execute(
|
|
"INSERT INTO domains (name, description) VALUES (?, ?)",
|
|
(domain.name, domain.description),
|
|
)
|
|
domain.id = cur.lastrowid
|
|
for mb in domain.metric_bounds:
|
|
metric_id = self.ensure_metric(mb.metric_name, unit=mb.unit)
|
|
mb.metric_id = metric_id
|
|
self.conn.execute(
|
|
"""INSERT INTO domain_metric_weights
|
|
(domain_id, metric_id, weight, norm_min, norm_max, lower_is_better)
|
|
VALUES (?, ?, ?, ?, ?, ?)""",
|
|
(domain.id, metric_id, mb.weight, mb.norm_min, mb.norm_max,
|
|
int(mb.lower_is_better)),
|
|
)
|
|
for dc in domain.constraints:
|
|
for val in dc.allowed_values:
|
|
self.conn.execute(
|
|
"INSERT OR IGNORE INTO domain_constraints (domain_id, key, value) VALUES (?, ?, ?)",
|
|
(domain.id, dc.key, val),
|
|
)
|
|
self.conn.commit()
|
|
return domain
|
|
|
|
def _load_domain_constraints(self, domain_id: int) -> list[DomainConstraint]:
|
|
rows = self.conn.execute(
|
|
"SELECT key, value FROM domain_constraints WHERE domain_id = ? ORDER BY key, value",
|
|
(domain_id,),
|
|
).fetchall()
|
|
by_key: dict[str, list[str]] = {}
|
|
for r in rows:
|
|
by_key.setdefault(r["key"], []).append(r["value"])
|
|
return [DomainConstraint(key=k, allowed_values=v) for k, v in by_key.items()]
|
|
|
|
def get_domain(self, name: str) -> Domain | None:
|
|
row = self.conn.execute("SELECT * FROM domains WHERE name = ?", (name,)).fetchone()
|
|
if not row:
|
|
return None
|
|
weights = self.conn.execute(
|
|
"""SELECT m.name, m.unit, dmw.weight, dmw.norm_min, dmw.norm_max,
|
|
dmw.metric_id, dmw.lower_is_better
|
|
FROM domain_metric_weights dmw
|
|
JOIN metrics m ON dmw.metric_id = m.id
|
|
WHERE dmw.domain_id = ?""",
|
|
(row["id"],),
|
|
).fetchall()
|
|
return Domain(
|
|
id=row["id"],
|
|
name=row["name"],
|
|
description=row["description"] or "",
|
|
metric_bounds=[
|
|
MetricBound(
|
|
metric_name=w["name"], weight=w["weight"],
|
|
norm_min=w["norm_min"], norm_max=w["norm_max"],
|
|
metric_id=w["metric_id"], unit=w["unit"] or "",
|
|
lower_is_better=bool(w["lower_is_better"]),
|
|
)
|
|
for w in weights
|
|
],
|
|
constraints=self._load_domain_constraints(row["id"]),
|
|
)
|
|
|
|
def list_domains(self) -> list[Domain]:
|
|
rows = self.conn.execute("SELECT name FROM domains ORDER BY name").fetchall()
|
|
return [self.get_domain(r["name"]) for r in rows]
|
|
|
|
def get_domain_by_id(self, domain_id: int) -> Domain | None:
|
|
row = self.conn.execute("SELECT * FROM domains WHERE id = ?", (domain_id,)).fetchone()
|
|
if not row:
|
|
return None
|
|
weights = self.conn.execute(
|
|
"""SELECT m.name, m.unit, dmw.weight, dmw.norm_min, dmw.norm_max,
|
|
dmw.metric_id, dmw.lower_is_better
|
|
FROM domain_metric_weights dmw
|
|
JOIN metrics m ON dmw.metric_id = m.id
|
|
WHERE dmw.domain_id = ?""",
|
|
(row["id"],),
|
|
).fetchall()
|
|
return Domain(
|
|
id=row["id"],
|
|
name=row["name"],
|
|
description=row["description"] or "",
|
|
metric_bounds=[
|
|
MetricBound(
|
|
metric_name=w["name"], weight=w["weight"],
|
|
norm_min=w["norm_min"], norm_max=w["norm_max"],
|
|
metric_id=w["metric_id"], unit=w["unit"] or "",
|
|
lower_is_better=bool(w["lower_is_better"]),
|
|
)
|
|
for w in weights
|
|
],
|
|
constraints=self._load_domain_constraints(row["id"]),
|
|
)
|
|
|
|
def update_domain(self, domain_id: int, name: str, description: str) -> None:
|
|
self.conn.execute(
|
|
"UPDATE domains SET name = ?, description = ? WHERE id = ?",
|
|
(name, description, domain_id),
|
|
)
|
|
self.conn.commit()
|
|
|
|
def add_metric_bound(self, domain_id: int, mb: MetricBound) -> MetricBound:
|
|
metric_id = self.ensure_metric(mb.metric_name, mb.unit)
|
|
mb.metric_id = metric_id
|
|
self.conn.execute(
|
|
"""INSERT OR REPLACE INTO domain_metric_weights
|
|
(domain_id, metric_id, weight, norm_min, norm_max, lower_is_better)
|
|
VALUES (?, ?, ?, ?, ?, ?)""",
|
|
(domain_id, metric_id, mb.weight, mb.norm_min, mb.norm_max,
|
|
int(mb.lower_is_better)),
|
|
)
|
|
self.conn.commit()
|
|
return mb
|
|
|
|
def update_metric_bound(
|
|
self, domain_id: int, metric_id: int, weight: float, norm_min: float, norm_max: float,
|
|
unit: str, lower_is_better: bool = False,
|
|
) -> None:
|
|
self.conn.execute(
|
|
"""UPDATE domain_metric_weights
|
|
SET weight = ?, norm_min = ?, norm_max = ?, lower_is_better = ?
|
|
WHERE domain_id = ? AND metric_id = ?""",
|
|
(weight, norm_min, norm_max, int(lower_is_better), domain_id, metric_id),
|
|
)
|
|
if unit:
|
|
self.conn.execute(
|
|
"UPDATE metrics SET unit = ? WHERE id = ?",
|
|
(unit, metric_id),
|
|
)
|
|
self.conn.commit()
|
|
|
|
def delete_metric_bound(self, domain_id: int, metric_id: int) -> None:
|
|
self.conn.execute(
|
|
"DELETE FROM domain_metric_weights WHERE domain_id = ? AND metric_id = ?",
|
|
(domain_id, metric_id),
|
|
)
|
|
self.conn.commit()
|
|
|
|
def delete_domain(self, domain_id: int) -> None:
|
|
self.conn.execute("DELETE FROM pipeline_runs WHERE domain_id = ?", (domain_id,))
|
|
self.conn.execute("DELETE FROM combination_results WHERE domain_id = ?", (domain_id,))
|
|
self.conn.execute("DELETE FROM combination_scores WHERE domain_id = ?", (domain_id,))
|
|
self.conn.execute("DELETE FROM domain_metric_weights WHERE domain_id = ?", (domain_id,))
|
|
self.conn.execute("DELETE FROM domain_constraints WHERE domain_id = ?", (domain_id,))
|
|
self.conn.execute("DELETE FROM domains WHERE id = ?", (domain_id,))
|
|
self.conn.commit()
|
|
|
|
def replace_domain_constraints(self, domain: Domain) -> None:
|
|
"""Delete and re-insert domain constraints. Used by seed backfill."""
|
|
if not domain.id:
|
|
existing = self.conn.execute(
|
|
"SELECT id FROM domains WHERE name = ?", (domain.name,)
|
|
).fetchone()
|
|
if not existing:
|
|
return
|
|
domain.id = existing["id"]
|
|
self.conn.execute("DELETE FROM domain_constraints WHERE domain_id = ?", (domain.id,))
|
|
for dc in domain.constraints:
|
|
for val in dc.allowed_values:
|
|
self.conn.execute(
|
|
"INSERT OR IGNORE INTO domain_constraints (domain_id, key, value) VALUES (?, ?, ?)",
|
|
(domain.id, dc.key, val),
|
|
)
|
|
self.conn.commit()
|
|
|
|
def reset_domain_results(self, domain_name: str) -> int:
|
|
"""Delete all pipeline results for a domain so it can be re-run from scratch.
|
|
|
|
Returns the number of result rows deleted.
|
|
"""
|
|
domain = self.get_domain(domain_name)
|
|
if not domain:
|
|
return 0
|
|
count = self.conn.execute(
|
|
"SELECT COUNT(*) FROM combination_results WHERE domain_id = ?",
|
|
(domain.id,),
|
|
).fetchone()[0]
|
|
self.conn.execute("DELETE FROM combination_scores WHERE domain_id = ?", (domain.id,))
|
|
self.conn.execute("DELETE FROM combination_results WHERE domain_id = ?", (domain.id,))
|
|
self.conn.execute("DELETE FROM pipeline_runs WHERE domain_id = ?", (domain.id,))
|
|
# Delete orphaned combos (no results left in any domain) and all their
|
|
# related rows — scores, entity links — so FK constraints don't block.
|
|
orphan_sql = """SELECT c.id FROM combinations c
|
|
WHERE c.id NOT IN (
|
|
SELECT DISTINCT combination_id FROM combination_results
|
|
)"""
|
|
self.conn.execute(
|
|
f"DELETE FROM combination_scores WHERE combination_id IN ({orphan_sql})"
|
|
)
|
|
self.conn.execute(
|
|
f"DELETE FROM combination_entities WHERE combination_id IN ({orphan_sql})"
|
|
)
|
|
self.conn.execute(
|
|
f"DELETE FROM combinations WHERE id IN ({orphan_sql})"
|
|
)
|
|
self.conn.commit()
|
|
return count
|
|
|
|
# ── Combinations ────────────────────────────────────────────
|
|
|
|
@staticmethod
|
|
def compute_hash(entity_ids: Sequence[int]) -> str:
|
|
key = ",".join(str(eid) for eid in sorted(entity_ids))
|
|
return hashlib.sha256(key.encode()).hexdigest()[:16]
|
|
|
|
def save_combination(self, combination: Combination) -> Combination:
|
|
entity_ids = [e.id for e in combination.entities]
|
|
combination.hash = self.compute_hash(entity_ids)
|
|
|
|
existing = self.conn.execute(
|
|
"SELECT id, status, block_reason FROM combinations WHERE hash = ?",
|
|
(combination.hash,),
|
|
).fetchone()
|
|
if existing:
|
|
combination.id = existing["id"]
|
|
combination.status = existing["status"]
|
|
combination.block_reason = existing["block_reason"]
|
|
return combination
|
|
|
|
cur = self.conn.execute(
|
|
"INSERT INTO combinations (hash, status, block_reason) VALUES (?, ?, ?)",
|
|
(combination.hash, combination.status, combination.block_reason),
|
|
)
|
|
combination.id = cur.lastrowid
|
|
for eid in entity_ids:
|
|
self.conn.execute(
|
|
"INSERT INTO combination_entities (combination_id, entity_id) VALUES (?, ?)",
|
|
(combination.id, eid),
|
|
)
|
|
self.conn.commit()
|
|
return combination
|
|
|
|
def update_combination_status(
|
|
self, combo_id: int, status: str, block_reason: str | None = None
|
|
) -> None:
|
|
# Don't downgrade from higher pass states — preserves human/LLM review data
|
|
if status in ("scored", "llm_reviewed") or status.endswith("_fail"):
|
|
row = self.conn.execute(
|
|
"SELECT status FROM combinations WHERE id = ?", (combo_id,)
|
|
).fetchone()
|
|
if row:
|
|
cur = row["status"]
|
|
# Fail statuses should not overwrite llm_reviewed or reviewed
|
|
if status.endswith("_fail") and cur in ("llm_reviewed", "reviewed"):
|
|
return
|
|
if status == "scored" and cur in ("llm_reviewed", "reviewed"):
|
|
return
|
|
if status == "llm_reviewed" and cur == "reviewed":
|
|
return
|
|
self.conn.execute(
|
|
"UPDATE combinations SET status = ?, block_reason = ? WHERE id = ?",
|
|
(status, block_reason, combo_id),
|
|
)
|
|
self.conn.commit()
|
|
|
|
def get_combination(self, combo_id: int) -> Combination | None:
|
|
row = self.conn.execute("SELECT * FROM combinations WHERE id = ?", (combo_id,)).fetchone()
|
|
if not row:
|
|
return None
|
|
entity_rows = self.conn.execute(
|
|
"SELECT entity_id FROM combination_entities WHERE combination_id = ?",
|
|
(combo_id,),
|
|
).fetchall()
|
|
entities = [self.get_entity(er["entity_id"]) for er in entity_rows]
|
|
return Combination(
|
|
id=row["id"], hash=row["hash"], status=row["status"],
|
|
block_reason=row["block_reason"], entities=entities,
|
|
)
|
|
|
|
def _bulk_load_combinations(self, combo_ids: list[int]) -> dict[int, Combination]:
|
|
"""Load multiple Combinations in O(4) queries instead of O(N*M)."""
|
|
if not combo_ids:
|
|
return {}
|
|
ph = ",".join("?" * len(combo_ids))
|
|
combo_rows = self.conn.execute(
|
|
f"SELECT * FROM combinations WHERE id IN ({ph})", combo_ids
|
|
).fetchall()
|
|
combos: dict[int, Combination] = {
|
|
r["id"]: Combination(
|
|
id=r["id"], hash=r["hash"], status=r["status"],
|
|
block_reason=r["block_reason"], entities=[],
|
|
)
|
|
for r in combo_rows
|
|
}
|
|
ce_rows = self.conn.execute(
|
|
f"SELECT combination_id, entity_id FROM combination_entities WHERE combination_id IN ({ph})",
|
|
combo_ids,
|
|
).fetchall()
|
|
combo_to_eids: dict[int, list[int]] = {}
|
|
for r in ce_rows:
|
|
combo_to_eids.setdefault(r["combination_id"], []).append(r["entity_id"])
|
|
|
|
entity_ids = list({r["entity_id"] for r in ce_rows})
|
|
if entity_ids:
|
|
eph = ",".join("?" * len(entity_ids))
|
|
entity_rows = self.conn.execute(
|
|
f"""SELECT e.id, e.name, e.description, d.name as dimension, e.dimension_id
|
|
FROM entities e JOIN dimensions d ON e.dimension_id = d.id
|
|
WHERE e.id IN ({eph})""",
|
|
entity_ids,
|
|
).fetchall()
|
|
dep_rows = self.conn.execute(
|
|
f"SELECT * FROM dependencies WHERE entity_id IN ({eph})", entity_ids
|
|
).fetchall()
|
|
deps_by_entity: dict[int, list[Dependency]] = {}
|
|
for r in dep_rows:
|
|
deps_by_entity.setdefault(r["entity_id"], []).append(Dependency(
|
|
id=r["id"], category=r["category"], key=r["key"],
|
|
value=r["value"], unit=r["unit"], constraint_type=r["constraint_type"],
|
|
))
|
|
entities_by_id: dict[int, Entity] = {
|
|
r["id"]: Entity(
|
|
id=r["id"], name=r["name"], description=r["description"] or "",
|
|
dimension=r["dimension"], dimension_id=r["dimension_id"],
|
|
dependencies=deps_by_entity.get(r["id"], []),
|
|
)
|
|
for r in entity_rows
|
|
}
|
|
for cid, eids in combo_to_eids.items():
|
|
if cid in combos:
|
|
combos[cid].entities = [entities_by_id[eid] for eid in eids if eid in entities_by_id]
|
|
return combos
|
|
|
|
def list_combinations(self, status: str | None = None) -> list[Combination]:
|
|
if status:
|
|
rows = self.conn.execute(
|
|
"SELECT id FROM combinations WHERE status = ? ORDER BY id", (status,)
|
|
).fetchall()
|
|
else:
|
|
rows = self.conn.execute("SELECT id FROM combinations ORDER BY id").fetchall()
|
|
ids = [r["id"] for r in rows]
|
|
combos = self._bulk_load_combinations(ids)
|
|
return [combos[i] for i in ids if i in combos]
|
|
|
|
# ── Scores & Results ────────────────────────────────────────
|
|
|
|
def save_scores(
|
|
self,
|
|
combo_id: int,
|
|
domain_id: int,
|
|
scores: list[dict],
|
|
) -> None:
|
|
"""Save per-metric scores. Each dict: metric_id, raw_value, normalized_score, estimation_method, confidence."""
|
|
for s in scores:
|
|
self.conn.execute(
|
|
"""INSERT OR REPLACE INTO combination_scores
|
|
(combination_id, domain_id, metric_id, raw_value, normalized_score,
|
|
estimation_method, confidence)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?)""",
|
|
(combo_id, domain_id, s["metric_id"], s["raw_value"],
|
|
s["normalized_score"], s["estimation_method"], s["confidence"]),
|
|
)
|
|
self.conn.commit()
|
|
|
|
def save_result(
|
|
self,
|
|
combo_id: int,
|
|
domain_id: int,
|
|
composite_score: float,
|
|
pass_reached: int,
|
|
novelty_flag: str | None = None,
|
|
llm_review: str | None = None,
|
|
human_notes: str | None = None,
|
|
domain_block_reason: str | None = None,
|
|
) -> None:
|
|
self.conn.execute(
|
|
"""INSERT OR REPLACE INTO combination_results
|
|
(combination_id, domain_id, composite_score, novelty_flag,
|
|
llm_review, human_notes, pass_reached, domain_block_reason)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
|
|
(combo_id, domain_id, composite_score, novelty_flag,
|
|
llm_review, human_notes, pass_reached, domain_block_reason),
|
|
)
|
|
self.conn.commit()
|
|
|
|
def get_combination_scores(self, combo_id: int, domain_id: int) -> list[dict]:
|
|
"""Return per-metric scores for a combination in a domain."""
|
|
rows = self.conn.execute(
|
|
"""SELECT cs.*, m.name as metric_name, m.unit as metric_unit
|
|
FROM combination_scores cs
|
|
JOIN metrics m ON cs.metric_id = m.id
|
|
WHERE cs.combination_id = ? AND cs.domain_id = ?""",
|
|
(combo_id, domain_id),
|
|
).fetchall()
|
|
return [dict(r) for r in rows]
|
|
|
|
def count_combinations_by_status(self, domain_name: str | None = None) -> dict[str, int]:
|
|
"""Count combos by status. If domain_name given, only combos with results in that domain."""
|
|
if domain_name:
|
|
rows = self.conn.execute(
|
|
"""SELECT c.status, COUNT(*) as cnt
|
|
FROM combination_results cr
|
|
JOIN combinations c ON cr.combination_id = c.id
|
|
JOIN domains d ON cr.domain_id = d.id
|
|
WHERE d.name = ?
|
|
GROUP BY c.status""",
|
|
(domain_name,),
|
|
).fetchall()
|
|
else:
|
|
rows = self.conn.execute(
|
|
"SELECT status, COUNT(*) as cnt FROM combinations GROUP BY status"
|
|
).fetchall()
|
|
return {r["status"]: r["cnt"] for r in rows}
|
|
|
|
def get_pipeline_summary(self, domain_name: str) -> dict | None:
|
|
"""Return a summary of results for a domain, or None if no results."""
|
|
row = self.conn.execute(
|
|
"""SELECT COUNT(*) as total,
|
|
AVG(cr.composite_score) as avg_score,
|
|
MAX(cr.composite_score) as max_score,
|
|
MIN(cr.composite_score) as min_score,
|
|
MAX(cr.pass_reached) as last_pass
|
|
FROM combination_results cr
|
|
JOIN domains d ON cr.domain_id = d.id
|
|
WHERE d.name = ?""",
|
|
(domain_name,),
|
|
).fetchone()
|
|
if not row or row["total"] == 0:
|
|
return None
|
|
failed = self.conn.execute(
|
|
"""SELECT COUNT(*) as cnt
|
|
FROM combinations c
|
|
JOIN combination_results cr ON cr.combination_id = c.id
|
|
JOIN domains d ON cr.domain_id = d.id
|
|
WHERE c.status LIKE '%\\_fail' ESCAPE '\\' AND d.name = ?""",
|
|
(domain_name,),
|
|
).fetchone()
|
|
return {
|
|
"total_results": row["total"],
|
|
"avg_score": row["avg_score"],
|
|
"max_score": row["max_score"],
|
|
"min_score": row["min_score"],
|
|
"last_pass": row["last_pass"],
|
|
"failed": failed["cnt"] if failed else 0,
|
|
}
|
|
|
|
def get_result(self, combo_id: int, domain_id: int) -> dict | None:
|
|
"""Return a single combination_result row."""
|
|
row = self.conn.execute(
|
|
"""SELECT cr.*, d.name as domain_name
|
|
FROM combination_results cr
|
|
JOIN domains d ON cr.domain_id = d.id
|
|
WHERE cr.combination_id = ? AND cr.domain_id = ?""",
|
|
(combo_id, domain_id),
|
|
).fetchone()
|
|
return dict(row) if row else None
|
|
|
|
def get_all_results(self, domain_name: str, status: str | None = None) -> list[dict]:
|
|
"""Return all results for a domain, optionally filtered by combo status."""
|
|
query = """SELECT cr.*, c.hash, c.status as combo_status, d.name as domain_name
|
|
FROM combination_results cr
|
|
JOIN combinations c ON cr.combination_id = c.id
|
|
JOIN domains d ON cr.domain_id = d.id
|
|
WHERE d.name = ?"""
|
|
params: list = [domain_name]
|
|
if status:
|
|
query += " AND c.status = ?"
|
|
params.append(status)
|
|
query += " ORDER BY cr.composite_score DESC"
|
|
rows = self.conn.execute(query, params).fetchall()
|
|
combo_ids = [r["combination_id"] for r in rows]
|
|
combos = self._bulk_load_combinations(combo_ids)
|
|
return [
|
|
{
|
|
"combination": combos.get(r["combination_id"]),
|
|
"composite_score": r["composite_score"],
|
|
"novelty_flag": r["novelty_flag"],
|
|
"llm_review": r["llm_review"],
|
|
"human_notes": r["human_notes"],
|
|
"pass_reached": r["pass_reached"],
|
|
"domain_id": r["domain_id"],
|
|
"domain_block_reason": r["domain_block_reason"],
|
|
}
|
|
for r in rows
|
|
]
|
|
|
|
def get_top_results(self, domain_name: str, limit: int = 10) -> list[dict]:
|
|
"""Return top-N results for a domain, ordered by composite_score DESC."""
|
|
rows = self.conn.execute(
|
|
"""SELECT cr.*, c.hash, c.status, d.name as domain_name
|
|
FROM combination_results cr
|
|
JOIN combinations c ON cr.combination_id = c.id
|
|
JOIN domains d ON cr.domain_id = d.id
|
|
WHERE d.name = ?
|
|
ORDER BY cr.composite_score DESC
|
|
LIMIT ?""",
|
|
(domain_name, limit),
|
|
).fetchall()
|
|
combo_ids = [r["combination_id"] for r in rows]
|
|
combos = self._bulk_load_combinations(combo_ids)
|
|
return [
|
|
{
|
|
"combination": combos.get(r["combination_id"]),
|
|
"composite_score": r["composite_score"],
|
|
"novelty_flag": r["novelty_flag"],
|
|
"llm_review": r["llm_review"],
|
|
"human_notes": r["human_notes"],
|
|
"pass_reached": r["pass_reached"],
|
|
}
|
|
for r in rows
|
|
]
|
|
|
|
def get_results_for_combination(self, combo_id: int) -> list[dict]:
|
|
"""Return all domain results for a combination."""
|
|
rows = self.conn.execute(
|
|
"""SELECT cr.*, d.name as domain_name
|
|
FROM combination_results cr
|
|
JOIN domains d ON cr.domain_id = d.id
|
|
WHERE cr.combination_id = ?""",
|
|
(combo_id,),
|
|
).fetchall()
|
|
return [dict(r) for r in rows]
|
|
|
|
# ── Pipeline Runs ────────────────────────────────────────
|
|
|
|
def create_pipeline_run(self, domain_id: int, config: dict) -> int:
|
|
"""Create a new pipeline_run record. Returns the run id."""
|
|
cur = self.conn.execute(
|
|
"""INSERT INTO pipeline_runs (domain_id, status, config, created_at)
|
|
VALUES (?, 'pending', ?, ?)""",
|
|
(domain_id, json.dumps(config), datetime.now(timezone.utc).isoformat()),
|
|
)
|
|
self.conn.commit()
|
|
return cur.lastrowid
|
|
|
|
# Allowlist of pipeline_runs columns that update_pipeline_run() may write.
# That method interpolates the keyword names into its UPDATE statement, so
# any name not listed here is rejected with ValueError before SQL is built.
_PIPELINE_RUN_UPDATABLE = frozenset({
    "status", "total_combos", "combos_pass1", "combos_pass2",
    "combos_pass3", "combos_pass4", "current_pass",
    "error_message", "started_at", "completed_at",
})
|
|
|
|
def update_pipeline_run(self, run_id: int, **fields) -> None:
|
|
"""Update fields on a pipeline_run. Only allowlisted column names are accepted."""
|
|
if not fields:
|
|
return
|
|
invalid = set(fields) - self._PIPELINE_RUN_UPDATABLE
|
|
if invalid:
|
|
raise ValueError(f"Invalid pipeline_run fields: {invalid}")
|
|
set_clause = ", ".join(f"{k} = ?" for k in fields)
|
|
values = list(fields.values())
|
|
values.append(run_id)
|
|
self.conn.execute(
|
|
f"UPDATE pipeline_runs SET {set_clause} WHERE id = ?", values
|
|
)
|
|
self.conn.commit()
|
|
|
|
def get_pipeline_run(self, run_id: int) -> dict | None:
|
|
row = self.conn.execute(
|
|
"SELECT * FROM pipeline_runs WHERE id = ?", (run_id,)
|
|
).fetchone()
|
|
return dict(row) if row else None
|
|
|
|
def list_pipeline_runs(self, domain_id: int | None = None) -> list[dict]:
|
|
if domain_id is not None:
|
|
rows = self.conn.execute(
|
|
"""SELECT pr.*, d.name as domain_name
|
|
FROM pipeline_runs pr
|
|
JOIN domains d ON pr.domain_id = d.id
|
|
WHERE pr.domain_id = ?
|
|
ORDER BY pr.created_at DESC""",
|
|
(domain_id,),
|
|
).fetchall()
|
|
else:
|
|
rows = self.conn.execute(
|
|
"""SELECT pr.*, d.name as domain_name
|
|
FROM pipeline_runs pr
|
|
JOIN domains d ON pr.domain_id = d.id
|
|
ORDER BY pr.created_at DESC"""
|
|
).fetchall()
|
|
return [dict(r) for r in rows]
|
|
|
|
def get_combo_pass_reached(self, combo_id: int, domain_id: int) -> int | None:
|
|
"""Return the pass_reached for a combo in a domain, or None if no result."""
|
|
row = self.conn.execute(
|
|
"""SELECT pass_reached FROM combination_results
|
|
WHERE combination_id = ? AND domain_id = ?""",
|
|
(combo_id, domain_id),
|
|
).fetchone()
|
|
return row["pass_reached"] if row else None
|
|
|
|
def save_raw_estimates(
|
|
self, combo_id: int, domain_id: int, estimates: list[dict]
|
|
) -> None:
|
|
"""Save raw metric estimates (pass 2) with normalized_score=NULL.
|
|
|
|
Each dict: metric_id, raw_value, estimation_method, confidence.
|
|
"""
|
|
for e in estimates:
|
|
self.conn.execute(
|
|
"""INSERT OR REPLACE INTO combination_scores
|
|
(combination_id, domain_id, metric_id, raw_value, normalized_score,
|
|
estimation_method, confidence)
|
|
VALUES (?, ?, ?, ?, NULL, ?, ?)""",
|
|
(combo_id, domain_id, e["metric_id"], e["raw_value"],
|
|
e["estimation_method"], e["confidence"]),
|
|
)
|
|
self.conn.commit()
|
|
|
|
def get_existing_result(self, combo_id: int, domain_id: int) -> dict | None:
|
|
"""Return the full combination_results row for resume logic."""
|
|
row = self.conn.execute(
|
|
"""SELECT * FROM combination_results
|
|
WHERE combination_id = ? AND domain_id = ?""",
|
|
(combo_id, domain_id),
|
|
).fetchone()
|
|
return dict(row) if row else None
|
|
|
|
# ── Admin ────────────────────────────────────────────────────
|
|
|
|
def clear_all(self) -> None:
|
|
"""Delete all data from every table in FK-safe order."""
|
|
self.conn.execute("DELETE FROM pipeline_runs")
|
|
self.conn.execute("DELETE FROM combination_results")
|
|
self.conn.execute("DELETE FROM combination_scores")
|
|
self.conn.execute("DELETE FROM combination_entities")
|
|
self.conn.execute("DELETE FROM combinations")
|
|
self.conn.execute("DELETE FROM dependencies")
|
|
self.conn.execute("DELETE FROM entities")
|
|
self.conn.execute("DELETE FROM domain_metric_weights")
|
|
self.conn.execute("DELETE FROM domain_constraints")
|
|
self.conn.execute("DELETE FROM domains")
|
|
self.conn.execute("DELETE FROM metrics")
|
|
self.conn.execute("DELETE FROM dimensions")
|
|
self.conn.commit()
|