# Source code for cycombinepy.batch_correct

"""High-level ``batch_correct`` orchestrator.

Port of ``batch_correct`` in ``R/02_batch_correct.R:66-210``. Runs the full
cyCombine pipeline: batch-wise normalize → SOM clustering → per-cluster ComBat
correction. Supports iterative correction with multiple SOM grid sizes by
passing ``xdim``/``ydim`` as sequences.
"""

from __future__ import annotations

from typing import Iterable, Sequence

import numpy as np
from anndata import AnnData

from cycombinepy._utils import marker_matrix, resolve_markers, set_marker_matrix
from cycombinepy.cluster import create_som
from cycombinepy.correct import CORRECTED_LAYER, correct_data
from cycombinepy.normalize import NormMethod, TiesMethod, normalize


def _as_list(v) -> list:
    if isinstance(v, (list, tuple, np.ndarray)):
        return list(v)
    return [v]


def batch_correct(
    adata: AnnData,
    markers: Iterable[str] | None = None,
    batch_key: str = "batch",
    label_key: str = "cycombine_som",
    xdim: int | Sequence[int] = 8,
    ydim: int | Sequence[int] = 8,
    rlen: int = 10,
    seed: int = 473,
    n_clusters: int | None = None,
    norm_method: NormMethod = "scale",
    ties_method: TiesMethod = "average",
    covar: str | None = None,
    anchor: str | None = None,
    ref_batch=None,
    parametric: bool = True,
    out_layer: str = CORRECTED_LAYER,
    copy: bool = False,
) -> AnnData | None:
    """Full cyCombine pipeline: normalize → SOM → per-cluster ComBat.

    Parameters
    ----------
    adata
        Input AnnData. ``adata.X`` is assumed to already be on an appropriate
        scale (e.g. post-asinh for cytometry).
    markers
        Var names to normalize/cluster/correct. Defaults to
        :func:`cycombinepy.get_markers`.
    batch_key
        Column in ``adata.obs`` holding batch assignments.
    label_key
        Column in ``adata.obs`` to write cluster labels to.
    xdim, ydim
        SOM grid dimensions. Sequences trigger iterative correction: for each
        ``(x, y)`` pair, re-normalize, re-cluster, and re-correct.
    rlen
        SOM training passes (forwarded to FlowSOM if supported).
    seed
        FlowSOM random seed.
    n_clusters
        If set, metacluster the SOM nodes into this many clusters.
    norm_method
        Normalization method used for clustering. See
        :func:`cycombinepy.normalize`.
    ties_method
        Tie-breaking rule for ``norm_method="rank"``.
    covar, anchor, ref_batch, parametric
        Forwarded to :func:`cycombinepy.correct_data`.
    out_layer
        Layer name to store the corrected matrix in.
    copy
        If True, return a corrected copy; otherwise mutate in place.
    """
    if copy:
        adata = adata.copy()

    markers = resolve_markers(adata, markers)

    xdims = _as_list(xdim)
    ydims = _as_list(ydim)
    if len(xdims) != len(ydims):
        raise ValueError("xdim and ydim must have the same length")

    # Working copy of the marker matrix that accumulates corrections between
    # iterations. Clustering sees a normalized view; correction sees the
    # current unnormalized working state.
    working = marker_matrix(adata, markers).copy()
    scratch = adata.copy()

    for x, y in zip(xdims, ydims):
        # Normalize + cluster on a fresh normalized view of the working state.
        set_marker_matrix(scratch, markers, working)
        normalize(
            scratch,
            markers=markers,
            method=norm_method,
            batch_key=batch_key,
            ties_method=ties_method,
        )
        create_som(
            scratch,
            markers=markers,
            xdim=x,
            ydim=y,
            n_clusters=n_clusters,
            seed=seed,
            rlen=rlen,
            label_key=label_key,
        )
        # Propagate the (final-iteration) cluster labels to the caller's object.
        adata.obs[label_key] = scratch.obs[label_key].values

        # Correct the (unnormalized) working state per cluster; normalize()
        # mutated scratch in place, so restore the working matrix first.
        set_marker_matrix(scratch, markers, working)
        correct_data(
            scratch,
            label_key=label_key,
            markers=markers,
            batch_key=batch_key,
            covar=covar,
            anchor=anchor,
            parametric=parametric,
            ref_batch=ref_batch,
            out_layer=out_layer,
        )
        # The corrected matrix becomes the input for the next (x, y) iteration.
        working = marker_matrix(scratch, markers, layer=out_layer)

    set_marker_matrix(adata, markers, working, layer=out_layer)
    return adata if copy else None