""" DEMI v17-compatible MAGI service 2026 03 31 ================================ Purpose ------- This module updates the legacy MAGI agent service to use DEMI v17 core math while preserving the existing agent-facing interface: - class name remains ``MAGIService`` - public methods remain ``calculate(...)`` and ``calculate_multiple(...)`` - return value remains a single probability float per target - MySQL access pattern remains the same Compatibility choices --------------------- 1. The surrounding AI system is *not* required to change. 2. The preferred intercept adjustment is the DEMI v17 correction-factor lookup. 3. If no correction-factor lookup table is configured, the service falls back to computing the same correction factor internally from the current subgraph so the existing agent does not break. 4. EPV-balanced predictor selection from the full DEMI v17 pipeline is exposed as an *optional* compatibility flag and is disabled by default, because the legacy service scores only the patient-provided event set and enabling selection by default would change legacy behavior. Environment variables --------------------- Database (legacy names preserved) MAGI_DB_HOST MAGI_DB_USER MAGI_DB_PASSWORD MAGI_DB_NAME MAGI_DB_PORT Optional correction-factor lookup table DEMI_CORRECTION_FACTOR_TABLE_PATH MAGI_CORRECTION_FACTOR_TABLE_PATH # alias Optional service behavior DEMI_SERVICE_ENABLE_SELECTION # true/false, default false DEMI_EPV # default 5 DEMI_USER_CAP # default 500 """ import logging import os from collections import defaultdict from typing import Dict, List, Optional, Set, Tuple, Union import mysql.connector import numpy as np import pandas as pd # --------------------------------------------------------------------------- # DEMI v17 math helpers # --------------------------------------------------------------------------- def compute_or_v17_scalar(a: float, b: float, c: float, d: float) -> float: """DEMI v17 adjusted OR for a single 2×2 table. Returns exp-scale OR. Returns np.nan only when a == 0 and c == 0. """ if any(np.isnan(x) for x in [a, b, c, d]): return np.nan if a == 0 and c == 0: return np.nan thresh = 30.0 ab = a + b ac = a + c both_large = (ab >= thresh) and (ac >= thresh) if a == 0: if b == 0: return float(1.0 / (c + 1.0)) return float(1.0 / min(b, c)) if both_large: if b == 0 and c == 0: return float(a + 1.0) if b == 0: return float(ac / (b + 1.0)) if c == 0: return float(ab / (c + 1.0)) return float((a * d) / (b * c)) if b == 0 and c == 0: return float(a + 1.0) if b == 0: return float(ac / (b + 1.0)) if c == 0: return float(ab / (c + 1.0)) raw = (a * d) / (b * c) or1 = ab / c or2 = ac / b if raw >= 1.0: return float(max(or1, or2)) return float(min(or1, or2)) def compute_or_vectorized( df: pd.DataFrame, se_thr: float = 1.0, min_cell_thr: int = 5, cap_thr: int = 30, is_diagnosis: Union[bool, np.ndarray] = True, is_outcome_target: Union[bool, np.ndarray] = False, ) -> np.ndarray: """Vectorized DEMI v17 adjusted OR. Legacy parameters are accepted for backward compatibility and ignored. """ a = df.get("n_code_target", pd.Series(np.nan, index=df.index)).to_numpy(float) b = df.get("n_code_no_target", pd.Series(np.nan, index=df.index)).to_numpy(float) c = df.get("n_target_no_code", pd.Series(np.nan, index=df.index)).to_numpy(float) d = df.get("n_no_code_no_target", pd.Series(np.nan, index=df.index)).to_numpy(float) thresh = 30.0 ab = a + b ac = a + c out = np.full(len(a), np.nan, dtype=float) valid = ~(np.isnan(a) | np.isnan(b) | np.isnan(c) | np.isnan(d)) defined = valid & ~((a == 0) & (c == 0)) both_large = defined & (ab >= thresh) & (ac >= thresh) small = defined & ~both_large a0 = defined & (a == 0) a0_b0 = a0 & (b == 0) a0_ok = a0 & (b > 0) out[a0_b0] = 1.0 / (c[a0_b0] + 1.0) min_bc = np.minimum(b, c) safe_a0 = a0_ok & (min_bc > 0) out[safe_a0] = 1.0 / min_bc[safe_a0] bl_a = both_large & (a > 0) m = bl_a & (b == 0) & (c == 0) out[m] = a[m] + 1.0 m = bl_a & (b == 0) & (c > 0) out[m] = ac[m] / (b[m] + 1.0) m = bl_a & (b > 0) & (c == 0) out[m] = ab[m] / (c[m] + 1.0) m = bl_a & (b > 0) & (c > 0) out[m] = (a[m] * d[m]) / (b[m] * c[m]) sm_a = small & (a > 0) m = sm_a & (b == 0) & (c == 0) out[m] = a[m] + 1.0 m = sm_a & (b == 0) & (c > 0) out[m] = ac[m] / (b[m] + 1.0) m = sm_a & (b > 0) & (c == 0) out[m] = ab[m] / (c[m] + 1.0) m = sm_a & (b > 0) & (c > 0) raw = (a[m] * d[m]) / (b[m] * c[m]) or1 = ab[m] / c[m] or2 = ac[m] / b[m] out[m] = np.where(raw >= 1.0, np.maximum(or1, or2), np.minimum(or1, or2)) return out def _compute_pref( outcome_rows: pd.DataFrame, pref_override: Optional[float] = None, ) -> Tuple[float, float, Optional[float], Optional[float]]: """Compute baseline prevalence and prior odds from outcome rows.""" n_y1: Optional[float] = None n_y0: Optional[float] = None if len(outcome_rows) > 0: row = outcome_rows.iloc[0] a = float(row.get("n_code_target", 0) or 0) b = float(row.get("n_code_no_target", 0) or 0) c = float(row.get("n_target_no_code", 0) or 0) d = float(row.get("n_no_code_no_target", 0) or 0) n_y1 = a + c n_y0 = b + d if pref_override is not None: pref = float(pref_override) n_y1_eff = pref n_y0_eff = 1.0 - pref elif n_y1 is not None and n_y0 is not None and (n_y1 + n_y0) > 0: pref = n_y1 / (n_y1 + n_y0) n_y1_eff = n_y1 n_y0_eff = n_y0 else: pref = 0.5 n_y1_eff = 1.0 n_y0_eff = 1.0 prior_odds = n_y1_eff / n_y0_eff if n_y0_eff > 0 else 1.0 return pref, prior_odds, n_y1, n_y0 def _partition_kplus_kminus(nodes: List[int], t_val: pd.Series) -> Tuple[List[int], List[int]]: """Partition nodes by sign of total effect.""" k_plus = [k for k in nodes if np.isfinite(t_val.loc[k]) and t_val.loc[k] > 0.0] k_minus = [k for k in nodes if np.isfinite(t_val.loc[k]) and t_val.loc[k] < 0.0] return k_plus, k_minus def _run_recursion_within_group( group: List[int], t_val: pd.Series, lambda_l: Dict[int, pd.Series], sign: int, prev_j: Optional[pd.Series] = None, ) -> pd.Series: """DEMI v17 backward recursion within K+ or K-.""" if not group: return pd.Series(dtype=float) d_val = pd.Series(np.nan, index=group, dtype=float) d_val.iloc[-1] = t_val.loc[group[-1]] for i in range(len(group) - 2, -1, -1): k = group[i] t_k = float(t_val.loc[k]) lam_k = lambda_l.get(k, pd.Series(dtype=float)) indirect = 0.0 for j in group[i + 1 :]: lam_kj = float(lam_k.get(j, 0.0)) if lam_kj != 0.0 and prev_j is not None: pj = float(prev_j.get(j, 0.0)) lam_kj = lam_kj - pj if lam_kj != 0.0: indirect += lam_kj * float(d_val.loc[j]) raw_d = t_k - indirect d_val.loc[k] = max(0.0, raw_d) if sign == 1 else min(0.0, raw_d) return d_val def select_top_predictors( t_series: pd.Series, n_cases: int, *, epv: int = 5, user_cap: int = 500, min_a: int = 0, a_series: Optional[pd.Series] = None, ) -> pd.Index: """EPV-balanced dual-tail selection from DEMI v17.""" t = t_series.dropna() t = t[t != 0.0] n_max = min(max(n_cases // epv, 1), user_cap) n_each = n_max // 2 harmful = t[t > 0].sort_values(ascending=False) protective = t[t < 0].sort_values(ascending=True) n_harm = min(n_each, len(harmful)) n_prot = min(n_each, len(protective)) spare = n_max - n_harm - n_prot if spare > 0: if n_harm < n_each: n_prot = min(len(protective), n_prot + spare) elif n_prot < n_each: n_harm = min(len(harmful), n_harm + spare) return harmful.index[:n_harm].append(protective.index[:n_prot]) def _stable_log_correction_factor(prev_j: pd.Series, d_val: pd.Series) -> float: """Stable log(correction_factor) using DEMI v17 math.""" log_cf = 0.0 common = [k for k in d_val.index if k in prev_j.index] for k in common: d_k = float(d_val.loc[k]) p_k = float(prev_j.loc[k]) if not np.isfinite(d_k) or p_k < 0.0 or p_k > 1.0: continue if p_k == 0.0: continue if p_k == 1.0: log_cf += d_k continue log_cf += float(np.logaddexp(np.log(p_k) + d_k, np.log1p(-p_k))) return float(log_cf) _CORRECTION_FACTOR_RENAME = { "target": "target_concept_code_int", "target_int": "target_concept_code_int", "n_positive": "n_outcome_yes", "num_positive": "n_outcome_yes", "number_positive": "n_outcome_yes", "n_negative": "n_outcome_no", "num_negative": "n_outcome_no", "number_negative": "n_outcome_no", "n_missing": "n_missing_value", "num_missing": "n_missing_value", "number_missing": "n_missing_value", "missing_value": "n_missing_value", "n_factors_or_gt1": "n_factors_or_gt1", "num_factors_or_gt1": "n_factors_or_gt1", "number_factors_or_gt1": "n_factors_or_gt1", "n_factors_or_lt1": "n_factors_or_lt1", "num_factors_or_lt1": "n_factors_or_lt1", "number_factors_or_lt1": "n_factors_or_lt1", "correction_factor": "correction_factor", "e_phi": "correction_factor", "E_phi": "correction_factor", } def _normalize_correction_factor_table(cf_table: Union[str, pd.DataFrame]) -> pd.DataFrame: """Load and normalize a correction-factor lookup table.""" if isinstance(cf_table, str): df = pd.read_csv(cf_table) else: df = cf_table.copy() rename = {} for col in df.columns: key = str(col).strip() norm = key.lower().replace(" ", "_") if norm in _CORRECTION_FACTOR_RENAME: rename[col] = _CORRECTION_FACTOR_RENAME[norm] if rename: df = df.rename(columns=rename) required = [ "target_concept_code_int", "n_outcome_yes", "n_outcome_no", "n_missing_value", "n_factors_or_gt1", "n_factors_or_lt1", "correction_factor", ] missing = [c for c in required if c not in df.columns] if missing: raise KeyError(f"Correction-factor table missing required columns: {missing}") for col in required[:-1]: df[col] = pd.to_numeric(df[col], errors="raise").astype(int) df["correction_factor"] = pd.to_numeric(df["correction_factor"], errors="raise").astype(float) return df[required].copy() def lookup_correction_factor( correction_factor_table: Union[str, pd.DataFrame], *, target_concept_code_int: int, n_outcome_yes: int, n_outcome_no: int, n_missing_value: int, n_factors_or_gt1: int, n_factors_or_lt1: int, ) -> Tuple[float, pd.Series]: """Exact-key lookup of the DEMI v17 correction factor.""" df = _normalize_correction_factor_table(correction_factor_table) mask = ( (df["target_concept_code_int"] == int(target_concept_code_int)) & (df["n_outcome_yes"] == int(n_outcome_yes)) & (df["n_outcome_no"] == int(n_outcome_no)) & (df["n_missing_value"] == int(n_missing_value)) & (df["n_factors_or_gt1"] == int(n_factors_or_gt1)) & (df["n_factors_or_lt1"] == int(n_factors_or_lt1)) ) hits = df.loc[mask].copy() if hits.empty: raise KeyError( "No correction factor found for key " f"(target={int(target_concept_code_int)}, " f"n_outcome_yes={int(n_outcome_yes)}, n_outcome_no={int(n_outcome_no)}, " f"n_missing_value={int(n_missing_value)}, " f"n_factors_or_gt1={int(n_factors_or_gt1)}, n_factors_or_lt1={int(n_factors_or_lt1)})" ) unique_cf = hits["correction_factor"].drop_duplicates() if len(unique_cf) != 1: raise ValueError("Multiple correction-factor rows matched the same key with different values.") row = hits.iloc[0] cf = float(row["correction_factor"]) if not np.isfinite(cf) or cf <= 0.0: raise ValueError(f"Invalid correction_factor in lookup table: {cf}") return cf, row # --------------------------------------------------------------------------- # Updated service with legacy public API # --------------------------------------------------------------------------- def build_snomed_ancestor_sets( relationship_path: str, concept_ids: Optional[Set[int]] = None, ) -> Dict[int, Set[int]]: """Build transitive SNOMED ancestor sets from a Full or Snapshot RF2 file. This matches the DEMI v17 helper so the service can reuse the same ancestor-exclusion behavior when SNOMED rollup filtering is enabled. """ isa = 116680003 rel = pd.read_csv( relationship_path, sep="\t", usecols=["id", "effectiveTime", "active", "typeId", "sourceId", "destinationId"], dtype=str, ) rel["_t"] = rel["effectiveTime"].astype(int) rel = rel.loc[rel.groupby("id")["_t"].idxmax()].copy() rel = rel[(rel["active"] == "1") & (rel["typeId"] == str(isa))] rel["sourceId"] = rel["sourceId"].astype(int) rel["destinationId"] = rel["destinationId"].astype(int) if concept_ids is not None: seeds: Set[int] = set(concept_ids) frontier: Set[int] = set(seeds) all_relevant: Set[int] = set() while frontier: parents_of_frontier = set( rel.loc[rel["sourceId"].isin(frontier), "destinationId"].unique() ) new_nodes = parents_of_frontier - all_relevant - seeds all_relevant.update(frontier) frontier = new_nodes keep = all_relevant | seeds rel = rel[rel["sourceId"].isin(keep) | rel["destinationId"].isin(keep)] parents_map: Dict[int, Set[int]] = defaultdict(set) for _, row in rel.iterrows(): parents_map[int(row["sourceId"])].add(int(row["destinationId"])) all_concepts: Set[int] = set(parents_map.keys()) | { p for ps in parents_map.values() for p in ps } ancestors: Dict[int, Set[int]] = { c: set(parents_map.get(c, set())) for c in all_concepts } changed = True while changed: changed = False for c in all_concepts: new_anc: Set[int] = set() for p in list(ancestors[c]): new_anc.update(ancestors.get(p, set())) before = len(ancestors[c]) ancestors[c].update(new_anc) if len(ancestors[c]) > before: changed = True return ancestors class MAGIService: """Legacy agent-facing service upgraded to DEMI v17 core logic. Public API is intentionally preserved. """ def __init__( self, correction_factor_table: Optional[Union[str, pd.DataFrame]] = None, enable_selection: Optional[bool] = None, epv: Optional[int] = None, user_cap: Optional[int] = None, snomed_relationship_path: Optional[str] = None, ): self.logger = logging.getLogger(__name__) self.db_config = { "host": os.getenv("MAGI_DB_HOST", "rapidimprovement.ai"), "user": os.getenv("MAGI_DB_USER", "aiintakeuser"), "password": os.getenv("MAGI_DB_PASSWORD", ""), "database": os.getenv("MAGI_DB_NAME", "magidbv2"), "port": int(os.getenv("MAGI_DB_PORT", "3306")), } env_cf = ( os.getenv("DEMI_CORRECTION_FACTOR_TABLE_PATH", "") or os.getenv("MAGI_CORRECTION_FACTOR_TABLE_PATH", "") ) self.correction_factor_table = correction_factor_table or env_cf or None env_snomed = ( os.getenv("DEMI_SNOMED_RELATIONSHIP_PATH", "") or os.getenv("MAGI_SNOMED_RELATIONSHIP_PATH", "") or os.getenv("SNOMED_SNAPSHOT_PATH", "") or os.getenv("SNOMED_REL_PATH", "") ) self.snomed_relationship_path = snomed_relationship_path or env_snomed or None self._snomed_ancestors: Optional[Dict[int, Set[int]]] = None self._snomed_cache_loaded = False if enable_selection is None: raw = os.getenv("DEMI_SERVICE_ENABLE_SELECTION", "false").strip().lower() self.enable_selection = raw in {"1", "true", "yes", "y", "on"} else: self.enable_selection = bool(enable_selection) self.epv = int(epv if epv is not None else os.getenv("DEMI_EPV", "5")) self.user_cap = int(user_cap if user_cap is not None else os.getenv("DEMI_USER_CAP", "500")) self.last_result: Optional[Dict[str, Union[float, int, str, pd.Series, pd.DataFrame, list]]] = None self.logger.info( "MAGI database configured: %s:%s/%s", self.db_config["host"], self.db_config["port"], self.db_config["database"], ) self.logger.info( "Using DEMI v17-compatible service core (legacy public API preserved)." ) if self.correction_factor_table: self.logger.info("Correction-factor lookup configured.") else: self.logger.warning( "No correction-factor lookup table configured. The service will " "fall back to internal correction-factor calculation for compatibility." ) if self.enable_selection: self.logger.info("Optional EPV-balanced predictor selection is enabled.") if self.snomed_relationship_path: self._load_snomed_ancestors_once() if self._snomed_ancestors is not None: self.logger.info( "SNOMED ancestor cache initialized from %s.", self.snomed_relationship_path, ) else: self.logger.warning( "SNOMED relationship path was configured but ancestor cache was not built. " "Scoring will proceed without SNOMED ancestor exclusion." ) # ------------------------------------------------------------------ # Database helpers # ------------------------------------------------------------------ def _get_connection(self): try: return mysql.connector.connect(**self.db_config) except Exception as error: self.logger.error("Error connecting to database: %s", error) return None def _close_connection(self, connection) -> None: if connection: connection.close() def _get_table_columns(self, table_name: str) -> Set[str]: connection = self._get_connection() if not connection: return set() try: cursor = connection.cursor() cursor.execute(f"SHOW COLUMNS FROM {table_name}") return {str(row[0]) for row in cursor.fetchall()} except Exception as error: self.logger.warning("Could not inspect table %s: %s", table_name, error) return set() finally: self._close_connection(connection) def get_concept_id(self, concept_code: str) -> Optional[int]: connection = self._get_connection() if not connection: return None try: cursor = connection.cursor() cursor.execute( "SELECT concept_code_int FROM concept_names WHERE concept_code = %s LIMIT 1", (concept_code,), ) result = cursor.fetchone() if result: return int(result[0]) self.logger.warning("No concept found for code: %s", concept_code) return None except Exception as error: self.logger.error("Error getting concept ID: %s", error) return None finally: self._close_connection(connection) def _get_magi_dataframe(self, concept_ids: List[int]) -> pd.DataFrame: connection = self._get_connection() if not connection: return pd.DataFrame() try: placeholders = ",".join(["%s"] * len(concept_ids)) query = f""" SELECT * FROM magi_counts_published WHERE target_concept_code_int IN ({placeholders}) AND concept_code_int IN ({placeholders}) """ params = concept_ids + concept_ids return pd.read_sql_query(query, connection, params=params) except Exception as error: self.logger.error("Error getting MAGI/DEMI data: %s", error) return pd.DataFrame() finally: self._close_connection(connection) def _augment_concept_metadata(self, df: pd.DataFrame, concept_ids: List[int]) -> pd.DataFrame: """Best-effort merge of concept_code / concept_name metadata. This is optional and should never break scoring if metadata columns are absent. """ if df.empty: return df columns = self._get_table_columns("concept_names") if not columns or "concept_code_int" not in columns or "concept_code" not in columns: return df wanted = ["concept_code_int", "concept_code"] for optional_col in ["concept_name", "standard_concept_name"]: if optional_col in columns: wanted.append(optional_col) connection = self._get_connection() if not connection: return df try: placeholders = ",".join(["%s"] * len(concept_ids)) query = f"SELECT {', '.join(wanted)} FROM concept_names WHERE concept_code_int IN ({placeholders})" meta = pd.read_sql_query(query, connection, params=concept_ids) if meta.empty: return df rename = {"concept_code": "concept_code"} for col in ["concept_name", "standard_concept_name"]: if col in meta.columns: rename[col] = col meta = meta.rename(columns=rename) merged = df.merge(meta, how="left", on="concept_code_int") return merged except Exception as error: self.logger.warning("Could not augment concept metadata: %s", error) return df finally: self._close_connection(connection) def _load_snomed_ancestors_once(self) -> None: """Load the SNOMED ancestor map at most once for the life of the service.""" if self._snomed_cache_loaded: return self._snomed_cache_loaded = True if not self.snomed_relationship_path: return try: self._snomed_ancestors = build_snomed_ancestor_sets(self.snomed_relationship_path) except Exception as error: self._snomed_ancestors = None self.logger.warning( "Could not initialize SNOMED ancestor cache from %s: %s", self.snomed_relationship_path, error, ) def _snomed_ancestor_exclusion_mask_cached( self, concept_codes: pd.Series, concept_code_strs: Optional[pd.Series], ) -> pd.Series: """Cached request-time SNOMED ancestor exclusion mask.""" if concept_codes.empty: return pd.Series(False, index=concept_codes.index) self._load_snomed_ancestors_once() if not isinstance(self._snomed_ancestors, dict): return pd.Series(False, index=concept_codes.index) if concept_code_strs is not None: is_snomed = concept_code_strs.astype(str).str.contains("SNOMED", na=False, regex=False) else: is_snomed = pd.Series(True, index=concept_codes.index) snomed_int_codes: Set[int] = set( concept_codes[is_snomed].dropna().astype(int).unique() ) if not snomed_int_codes: return pd.Series(False, index=concept_codes.index) to_exclude: Set[int] = set() ancestors = self._snomed_ancestors for code_a in snomed_int_codes: for code_b in snomed_int_codes: if code_a != code_b and code_a in ancestors.get(code_b, set()): to_exclude.add(code_a) break return concept_codes.isin(to_exclude) & is_snomed def _apply_rollup_exclusion( self, df: pd.DataFrame, outcome: Optional[int] = None, ) -> Tuple[pd.DataFrame, int, int, int]: """Apply DEMI v17 rollup exclusion rules when metadata is available.""" if df.empty: return df, 0, 0, 0 name_col = None for candidate in ("concept_name", "standard_concept_name"): if candidate in df.columns: name_col = candidate break n_rollup = 0 n_rollup_snomed = 0 n_rollup_pattern = 0 if self.snomed_relationship_path is not None and "concept_code_int" in df.columns: code_str_col = df["concept_code"] if "concept_code" in df.columns else None snomed_exclude = self._snomed_ancestor_exclusion_mask_cached( df["concept_code_int"], code_str_col, ) if outcome is not None: snomed_exclude = snomed_exclude & (df["concept_code_int"] != outcome) n_rollup_snomed = int(snomed_exclude.sum()) df = df[~snomed_exclude].copy() if name_col is not None: name = df[name_col].fillna("") rule1_pattern = ( r'\bfinding\b|\bfindings\b|\bdisorder of\b|^disorder of\b' r'|\bstructure\b|\bobservation\b|\bcondition\b|^disease$' ) rule1_mask = name.str.contains(rule1_pattern, case=False, regex=True, na=False) generic_ending = ( r'\bdisorder\b|\bdisease\b|\bsyndrome\b' r'|\bcondition\b|\bproblem\b|\bsymptom\b' ) has_generic = name.str.contains(generic_ending, case=False, regex=True, na=False) if "n_target_no_code" in df.columns and "n_code_target" in df.columns: ac_col = df["n_code_target"].fillna(0) + df["n_target_no_code"].fillna(0) c_rate = df["n_target_no_code"].fillna(0) / ac_col.replace(0, np.nan) rule2_mask = has_generic & (c_rate < 0.03) else: rule2_mask = pd.Series(False, index=df.index) rollup_mask = rule1_mask | rule2_mask if outcome is not None: rollup_mask = rollup_mask & (df["concept_code_int"] != outcome) if self.snomed_relationship_path and "concept_code" in df.columns: already_handled = df["concept_code"].astype(str).str.contains("SNOMED", na=False, regex=False) rollup_mask = rollup_mask & ~already_handled n_rollup_pattern = int(rollup_mask.sum()) df = df[~rollup_mask].copy() n_rollup = n_rollup_pattern + n_rollup_snomed return df, n_rollup, n_rollup_snomed, n_rollup_pattern # ------------------------------------------------------------------ # Public API (unchanged) # ------------------------------------------------------------------ def calculate( self, events: List[str], target: str, minimum_count_threshold: int = 15, ) -> float: """Calculate probability using the legacy public API.""" try: event_ids = [ eid for eid in (self.get_concept_id(e) for e in events) if eid is not None ] target_id = self.get_concept_id(target) if not event_ids or target_id is None: self.logger.warning("Insufficient valid concept IDs for DEMI calculation") return 0.0 return self._run_magi_algorithm(event_ids, target_id, minimum_count_threshold) except Exception as error: self.logger.error("Error in DEMI calculation: %s", error) return 0.0 def calculate_multiple( self, events: List[str], targets: Dict[str, str], ) -> Dict[str, float]: """Calculate probabilities for multiple targets.""" return { drug_name: self.calculate(events, target_code) for drug_name, target_code in targets.items() } # ------------------------------------------------------------------ # Internal algorithm runner (name preserved) # ------------------------------------------------------------------ def _run_magi_algorithm( self, event_ids: List[int], outcome_id: int, minimum_count_threshold: int = 15, ) -> float: all_concept_ids = list(set(event_ids + [outcome_id])) if len(all_concept_ids) < 2: return 0.0 magi_dataframe = self._get_magi_dataframe(all_concept_ids) if magi_dataframe.empty: self.logger.warning("No MAGI/DEMI data found for given concepts") return 0.0 required_columns = [ "target_concept_code_int", "concept_code_int", "n_code_target", "n_code_no_target", "n_target_no_code", "n_code", "n_code_before_target", "n_target_before_code", ] missing = [c for c in required_columns if c not in magi_dataframe.columns] if missing: self.logger.error("Missing required columns: %s", missing) return 0.0 if "n_no_code_no_target" not in magi_dataframe.columns: if "n_no_target" in magi_dataframe.columns: magi_dataframe["n_no_code_no_target"] = ( pd.to_numeric(magi_dataframe["n_no_target"], errors="coerce").fillna(0.0) - pd.to_numeric(magi_dataframe["n_code_no_target"], errors="coerce").fillna(0.0) ).clip(lower=0.0) else: self.logger.error( "Missing both n_no_code_no_target and n_no_target; cannot derive 2x2 table." ) return 0.0 concept_id_set = set(all_concept_ids) magi_dataframe = magi_dataframe[ magi_dataframe["target_concept_code_int"].isin(concept_id_set) & magi_dataframe["concept_code_int"].isin(concept_id_set) ].copy() if magi_dataframe.empty: return 0.0 magi_dataframe = self._augment_concept_metadata(magi_dataframe, all_concept_ids) try: predict_proba = self._analyze_causal_sequence( magi_dataframe, events=all_concept_ids, force_outcome=outcome_id, lambda_min_count=minimum_count_threshold, fix_suppression=True, ) input_dict = {event_id: 1 for event_id in event_ids} return float(predict_proba(input_dict)) except Exception as error: import traceback self.logger.error("Error running DEMI v17-compatible algorithm: %s", error) self.logger.error(traceback.format_exc()) return 0.0 # ------------------------------------------------------------------ # Core algorithm (internal name preserved) # ------------------------------------------------------------------ def _analyze_causal_sequence( self, df_in: pd.DataFrame, *, events: List[int] = None, force_outcome: int = None, lambda_min_count: int = 15, fix_suppression: bool = True, ): """DEMI v17-compatible internal scoring routine. This preserves the old service method shape by returning only a ``predict_proba`` callable, while internally using DEMI v17 math. ``fix_suppression`` is accepted for backward compatibility only. DEMI v17 suppression is enforced inside the K+/K− recursion and is not disabled by passing ``fix_suppression=False``. """ df = df_in.copy() df["target_concept_code_int"] = pd.to_numeric( df["target_concept_code_int"], errors="coerce" ).astype("Int64") df["concept_code_int"] = pd.to_numeric( df["concept_code_int"], errors="coerce" ).astype("Int64") if events is None: targets = df["target_concept_code_int"].dropna().unique() codes = df["concept_code_int"].dropna().unique() events = sorted(set(int(x) for x in targets) | set(int(x) for x in codes)) else: events = sorted(set(int(e) for e in events)) numeric_cols = [ "n_code_target", "n_code", "n_target", "n_no_target", "n_code_before_target", "n_target_before_code", "n_code_no_target", "n_target_no_code", "n_no_code_no_target", ] for col in numeric_cols: if col in df.columns: df[col] = pd.to_numeric(df[col], errors="coerce") if not fix_suppression: self.logger.warning( "fix_suppression=False was provided, but DEMI v17 suppression remains active in K+/K− recursion." ) n_rollup = 0 n_rollup_snomed = 0 n_rollup_pattern = 0 try: df, n_rollup, n_rollup_snomed, n_rollup_pattern = self._apply_rollup_exclusion( df, outcome=force_outcome, ) except Exception as exc: self.logger.warning("Rollup exclusion failed, proceeding without it: %s", exc) n_rollup = n_rollup_snomed = n_rollup_pattern = 0 group_cols = [ "n_target_before_code", "n_code_before_target", "n_code_target", "n_code", "n_target", "n_no_target", "n_code_no_target", "n_target_no_code", "n_no_code_no_target", ] group_cols = [c for c in group_cols if c in df.columns] edge_index = df.groupby( ["target_concept_code_int", "concept_code_int"], as_index=True, )[group_cols].sum() scores: Dict[int, float] = {} for zi in events: if zi not in edge_index.index.get_level_values(0): scores[zi] = 0.0 continue try: pairs = edge_index.xs(zi, level="target_concept_code_int") scores[zi] = float( (pairs["n_target_before_code"].fillna(0.0) - pairs["n_code_before_target"].fillna(0.0)).sum() ) except KeyError: scores[zi] = 0.0 sorted_scores = pd.Series(scores).sort_values(ascending=False) if force_outcome is not None: if force_outcome in sorted_scores.index: outcome = int(force_outcome) else: raise ValueError( f"force_outcome={force_outcome} not found in events. " f"Available: {sorted(sorted_scores.index.tolist())}" ) else: outcome = int(sorted_scores.index[-1]) temporal_order = [int(ev) for ev in sorted_scores.index if ev != outcome] + [outcome] nodes = temporal_order[:-1] pos_by_event = {ev: i for i, ev in enumerate(temporal_order)} e_rows, e_keys = [], [] for k in nodes: key = (outcome, int(k)) if key in edge_index.index: e_rows.append(edge_index.loc[key]) e_keys.append(k) total_or = pd.Series(1.0, index=nodes, dtype=float) t_val = pd.Series(0.0, index=nodes, dtype=float) if e_rows: edf = pd.DataFrame(e_rows, index=e_keys) or_vals = compute_or_vectorized(edf) for k, ov in zip(e_keys, or_vals): if np.isfinite(ov) and ov > 0: total_or.loc[k] = float(ov) t_val.loc[k] = float(np.log(ov)) outcome_rows = df[df["target_concept_code_int"] == outcome] pref, prior_odds_data, n_y1, n_y0 = _compute_pref(outcome_rows) if self.enable_selection: selected_idx = select_top_predictors( t_val, n_cases=int(n_y1) if n_y1 is not None else 1, epv=self.epv, user_cap=self.user_cap, ) selected_set = set(selected_idx.tolist()) nodes = [n for n in nodes if n in selected_set] n_total = float(n_y1 + n_y0) if (n_y1 is not None and n_y0 is not None) else 0.0 prev_j_dict: Dict[int, float] = {} if n_total > 0: for j in nodes: key_jy = (outcome, int(j)) key_yj = (int(j), outcome) n_j = 0.0 for key in (key_jy, key_yj): if key in edge_index.index: row = edge_index.loc[key] n_j = float(row.get("n_code", 0.0)) if n_j == 0.0: n_j = ( float(row.get("n_code_target", 0.0)) + float(row.get("n_code_no_target", 0.0)) ) if n_j > 0: break prev_j_dict[j] = n_j / n_total if n_j > 0 else 0.0 prev_j_series = pd.Series(prev_j_dict, dtype=float) lambda_l: Dict[int, pd.Series] = {} for k in nodes: pos_k = pos_by_event[k] downstream = [int(ev) for ev in temporal_order[pos_k + 1 : -1]] lam: Dict[int, float] = {} for j in downstream: if j not in nodes: continue key = (int(j), int(k)) if key not in edge_index.index: lam[j] = 0.0 continue row = edge_index.loc[key] n_kj = float(row.get("n_code_target", 0.0)) n_k = float(row.get("n_code", 0.0)) if n_k == 0.0: n_k = ( float(row.get("n_code_target", 0.0)) + float(row.get("n_code_no_target", 0.0)) ) lam[j] = ( 0.0 if n_k < lambda_min_count else float(np.clip(n_kj / n_k if n_k > 0 else 0.0, 0.0, 1.0)) ) lambda_l[k] = pd.Series(lam, dtype=float) k_plus, k_minus = _partition_kplus_kminus(nodes, t_val) d_plus = _run_recursion_within_group(k_plus, t_val, lambda_l, sign=+1, prev_j=prev_j_series) d_minus = _run_recursion_within_group(k_minus, t_val, lambda_l, sign=-1, prev_j=prev_j_series) d_val = pd.Series(0.0, index=nodes, dtype=float) for k, v in d_plus.items(): d_val.loc[k] = v if np.isfinite(v) else float(t_val.loc[k]) for k, v in d_minus.items(): d_val.loc[k] = v if np.isfinite(v) else float(t_val.loc[k]) p_marginal = float(n_y1) / float(n_y1 + n_y0) if (n_y1 and n_y0 and n_y1 + n_y0 > 0) else 0.5 prior_odds = p_marginal / (1.0 - p_marginal) beta_0_prior = float(np.log(prior_odds)) n_factors_or_gt1 = int(len(k_plus)) n_factors_or_lt1 = int(len(k_minus)) correction_factor_source = "internal_fallback" correction_factor_row = None try: if self.correction_factor_table is not None: correction_factor, correction_factor_row = lookup_correction_factor( self.correction_factor_table, target_concept_code_int=outcome, n_outcome_yes=int(n_y1) if n_y1 is not None else 0, n_outcome_no=int(n_y0) if n_y0 is not None else 0, n_missing_value=0, n_factors_or_gt1=n_factors_or_gt1, n_factors_or_lt1=n_factors_or_lt1, ) correction_factor_source = "lookup_table" else: raise KeyError("No correction_factor_table configured") except Exception as error: self.logger.warning( "Falling back to internal correction-factor calculation: %s", error ) correction_factor = float(np.exp(_stable_log_correction_factor(prev_j_series, d_val))) if not np.isfinite(correction_factor) or correction_factor <= 0.0: correction_factor = 1.0 beta_0 = beta_0_prior - float(np.log(correction_factor)) beta_vals = d_val.astype(float) predictors = list(beta_vals.index) beta_vec = beta_vals.values def predict_proba(Z): def sigmoid(x): return 1.0 / (1.0 + np.exp(-np.clip(x, -700, 700))) if isinstance(Z, pd.DataFrame): x = Z.reindex(columns=predictors, fill_value=0.0).astype(float).to_numpy() return sigmoid(beta_0 + x @ beta_vec) if isinstance(Z, (dict, pd.Series)): x = np.array([float(Z.get(p, 0.0)) for p in predictors]) return float(sigmoid(beta_0 + float(x @ beta_vec))) arr = np.asarray(Z, dtype=float) if arr.ndim == 1: return float(sigmoid(beta_0 + float(arr @ beta_vec))) if arr.ndim == 2: return sigmoid(beta_0 + arr @ beta_vec) raise ValueError("Z must be 1-D or 2-D array-like") self.last_result = { "outcome": outcome, "sorted_scores": sorted_scores, "temporal_order": temporal_order, "predictors": predictors, "K_plus": k_plus, "K_minus": k_minus, "total_or": total_or.reindex(nodes), "T_val": t_val.reindex(nodes), "D_val": d_val, "lambda_l": lambda_l, "prev_j": prev_j_series, "n_outcome_yes": n_y1, "n_outcome_no": n_y0, "pref": pref, "prior_odds_data": prior_odds_data, "prior_odds": prior_odds, "beta_0_prior": beta_0_prior, "beta_0": beta_0, "correction_factor": correction_factor, "correction_factor_source": correction_factor_source, "correction_factor_row": correction_factor_row.to_dict() if correction_factor_row is not None else None, "n_rollup_excluded": n_rollup, "n_rollup_snomed": n_rollup_snomed, "n_rollup_pattern": n_rollup_pattern, } return predict_proba class DEMIService(MAGIService): """Optional clearer alias; does not affect legacy imports.""" pass