"""Per-cluster ComBat correction.
Port of ``correct_data`` in ``R/02_batch_correct.R:356-544``. The AnnData is split
by its SOM cluster label, each sub-group is corrected with
:func:`cycombinepy.combat.run_combat`, and results are stitched back in the original
row order. Values are capped to the per-cluster min/max of the input (matching R
lines 524-531).
"""
from __future__ import annotations
from typing import Iterable
import numpy as np
import pandas as pd
from anndata import AnnData
from cycombinepy._utils import (
check_confound,
check_obs_key,
marker_matrix,
resolve_markers,
set_marker_matrix,
)
from cycombinepy.combat import run_combat
CORRECTED_LAYER = "cycombine_corrected"
def _build_model_matrix(
    df_sub: pd.DataFrame,
    covar: str | None,
    anchor: str | None,
) -> np.ndarray | None:
    """Construct a covariate design matrix (intercept removed).

    Uses :mod:`formulaic` so the result matches R's ``stats::model.matrix``
    (treatment contrasts, first level dropped). Returns ``None`` when both
    column names are ``None`` or the matrix ends up empty.
    """
    from formulaic import model_matrix

    columns = [c for c in (covar, anchor) if c is not None]
    if not columns:
        return None
    frame = df_sub[columns].astype("category")
    design = np.asarray(model_matrix(" + ".join(columns), frame), dtype=float)
    # formulaic places an all-ones intercept in the first column; strip it
    # so inmoose receives a pure covariate block.
    if design.shape[1] and np.all(design[:, 0] == 1):
        design = design[:, 1:]
    return design if design.size else None
def _resolve_num_factors(
    series: pd.Series,
    batch: pd.Series,
    design: np.ndarray | None,
) -> int:
    """Return the effective number of factor levels, mirroring R lines 455-506.

    Collapses to 1 when the factor is confounded with batch, or when the
    cluster is heavily skewed towards a single level; otherwise returns the
    number of distinct levels.
    """
    # Confounded with batch -> the factor carries no usable signal here.
    if check_confound(batch, design):
        return 1
    level_counts = series.value_counts()
    n_levels = level_counts.size
    # Skew check: outside the dominant level there must be, on average,
    # at least ~5 observations per level; otherwise treat as one level.
    if level_counts.sum() < level_counts.max() + 5 * n_levels:
        return 1
    return n_levels
def correct_data(
    adata: AnnData,
    label_key: str = "cycombine_som",
    markers: Iterable[str] | None = None,
    batch_key: str = "batch",
    covar: str | None = None,
    anchor: str | None = None,
    parametric: bool = True,
    ref_batch=None,
    layer: str | None = None,
    out_layer: str = CORRECTED_LAYER,
    copy: bool = False,
) -> AnnData | None:
    """Per-cluster ComBat batch correction.

    Parameters
    ----------
    adata
        AnnData with a cluster label in ``adata.obs[label_key]`` and a batch in
        ``adata.obs[batch_key]``.
    label_key
        Column in ``adata.obs`` with the SOM cluster id (from :func:`create_som`).
    markers
        Var names to correct. If ``None``, uses :func:`cycombinepy.get_markers`.
    batch_key
        Column in ``adata.obs`` giving the batch assignment.
    covar, anchor
        Optional ``adata.obs`` columns used as ComBat covariates. Skew- and
        confound-detection follow the R logic at lines 455-506.
    parametric
        Parametric vs. non-parametric ComBat prior.
    ref_batch
        Optional reference batch that is kept unchanged.
    layer
        If given, read the uncorrected matrix from this layer rather than ``X``.
    out_layer
        Name of the layer to store the corrected matrix in.
    copy
        If True, return a corrected copy; otherwise mutate in place.

    Returns
    -------
    AnnData | None
        The corrected copy when ``copy=True``; otherwise ``None`` (the input
        ``adata`` is modified in place).
    """
    import warnings  # hoisted: previously imported inside the per-cluster except

    check_obs_key(adata, batch_key)
    check_obs_key(adata, label_key)
    if covar is not None:
        check_obs_key(adata, covar)
    if anchor is not None:
        check_obs_key(adata, anchor)
    markers = resolve_markers(adata, markers)
    if copy:
        adata = adata.copy()
    X = marker_matrix(adata, markers, layer=layer)  # (n_cells, n_markers)
    labels = adata.obs[label_key].astype(str).to_numpy()
    batches = adata.obs[batch_key].astype(str).to_numpy()
    corrected = X.copy()
    for lab in pd.unique(labels):
        idx = np.where(labels == lab)[0]
        if idx.size == 0:
            continue
        sub_X = X[idx]  # (n_sub, n_markers)
        sub_batch = pd.Series(batches[idx])
        if sub_batch.nunique() <= 1:
            # Only one batch in this cluster — nothing to correct. (R lines 448-452)
            continue
        sub_df = adata.obs.iloc[idx]
        # Covar / anchor handling: determine effective level count.
        num_covar = 1
        if covar is not None:
            cov_design = _build_model_matrix(sub_df, covar, None)
            num_covar = _resolve_num_factors(sub_df[covar], sub_batch, cov_design)
        num_anchor = 1
        if anchor is not None:
            anc_design = _build_model_matrix(sub_df, None, anchor)
            num_anchor = _resolve_num_factors(sub_df[anchor], sub_batch, anc_design)
        # If both are non-trivial, check that their combination is not confounded
        # with batch; if it is, drop anchor (R prioritises covar, lines 489-495).
        if num_covar > 1 and num_anchor > 1:
            joint = _build_model_matrix(sub_df, covar, anchor)
            if check_confound(sub_batch, joint):
                num_anchor = 1
        eff_covar = covar if num_covar > 1 else None
        eff_anchor = anchor if num_anchor > 1 else None
        mod = _build_model_matrix(sub_df, eff_covar, eff_anchor)
        # inmoose expects (n_features, n_samples)
        x_t = sub_X.T
        try:
            corrected_sub = run_combat(
                x_t,
                batch=sub_batch.values,
                mod=mod,
                parametric=parametric,
                ref_batch=ref_batch,
            ).T
        except Exception as exc:  # pragma: no cover
            # If ComBat fails inside a cluster (e.g. singular cov), leave untouched.
            # This matches the spirit of R's skip-on-confound handling.
            warnings.warn(
                f"ComBat failed for cluster {lab!r} ({exc}); leaving uncorrected.",
                RuntimeWarning,
            )
            continue
        # Cap to per-marker min/max within this cluster (R lines 524-531).
        lo = sub_X.min(axis=0)
        hi = sub_X.max(axis=0)
        corrected[idx] = np.clip(corrected_sub, lo, hi)
    set_marker_matrix(adata, markers, corrected, layer=out_layer)
    return adata if copy else None