# Source code for cycombinepy.detect

"""Batch-effect detection utilities.

Port of ``detect_batch_effect_express`` / ``detect_batch_effect`` from
``R/detect_batch_effect.R``. The Python versions return a dict of matplotlib
figures (or save them to ``out_dir``) rather than printing the R plots.
"""

from __future__ import annotations

import os
from pathlib import Path
from typing import Iterable

import numpy as np
import pandas as pd
from anndata import AnnData

from cycombinepy._utils import check_obs_key, marker_matrix, resolve_markers
from cycombinepy.evaluate import compute_emd, compute_mad


def _ensure_single_cluster_label(adata: AnnData, key: str = "_cycombine_all") -> str:
    """Return ``key`` after making sure ``adata.obs`` has a column of that name.

    A missing column is created in place and filled with the constant
    ``"all"``; an existing column is left exactly as it is. Callers pass the
    returned name as the cluster key so that compute_emd/mad see one group.
    """
    already_present = key in adata.obs.columns
    if already_present:
        return key
    # Scalar assignment broadcasts the label to every observation.
    adata.obs[key] = "all"
    return key


def detect_batch_effect_express(
    adata: AnnData,
    markers: Iterable[str] | None = None,
    batch_key: str = "batch",
    sample_key: str | None = "sample",
    downsample: int | None = None,
    out_dir: str | os.PathLike | None = None,
    seed: int = 472,
) -> dict:
    """Quick 3-panel batch-effect summary (EMD heatmap, density, MDS).

    Returns a ``dict`` of matplotlib figures keyed by ``"emd"``, ``"density"``
    and ``"mds"``. If ``out_dir`` is given, the figures are saved as PNGs and
    the dict is still returned for further inspection.

    Matches ``detect_batch_effect_express`` in ``R/detect_batch_effect.R``.

    Parameters
    ----------
    adata
        Annotated data matrix; ``adata.obs`` must contain ``batch_key``.
        The temporary single-cluster column added for the EMD computation is
        removed again, so no helper column is left behind in ``adata.obs``.
    markers
        Markers to analyse; passed through ``resolve_markers``.
    batch_key
        Column of ``adata.obs`` holding the batch label.
    sample_key
        Column of ``adata.obs`` holding the sample label. When it is ``None``
        or missing from ``adata.obs`` the MDS panel only shows a notice.
    downsample
        If given and smaller than ``adata.n_obs``, that many cells are drawn
        without replacement (seeded by ``seed``) and all panels use the
        subsample.
    out_dir
        Directory (created if needed) receiving ``detect_<name>.png`` files.
    seed
        Seed for the subsampling RNG and for the MDS initialisation.

    Returns
    -------
    dict
        ``{"emd": Figure, "density": Figure, "mds": Figure}``.

    Raises
    ------
    ImportError
        If matplotlib or seaborn is not installed.
    """
    try:
        import matplotlib.pyplot as plt
        import seaborn as sns
    except ImportError as exc:  # pragma: no cover
        raise ImportError(
            "Detection plots require matplotlib + seaborn; install with "
            "`pip install matplotlib seaborn`."
        ) from exc

    check_obs_key(adata, batch_key)
    markers = resolve_markers(adata, markers)

    if downsample is not None and adata.n_obs > downsample:
        rng = np.random.default_rng(seed)
        idx = rng.choice(adata.n_obs, downsample, replace=False)
        adata = adata[idx].copy()

    # compute_emd needs a cluster column; add a trivial one and make sure it
    # does not leak into the caller's object when no downsampling copy was made.
    obs_cols_before = set(adata.obs.columns)
    cluster_key = _ensure_single_cluster_label(adata)
    try:
        emd_df = compute_emd(
            adata, cell_key=cluster_key, batch_key=batch_key, markers=markers
        )
    finally:
        if cluster_key not in obs_cols_before:
            del adata.obs[cluster_key]

    # 1. EMD heatmap (mean per marker across batch pairs)
    pivot = emd_df.groupby("marker")["emd"].mean().to_frame("mean_emd")
    fig_emd, ax_emd = plt.subplots(figsize=(4, max(2, 0.3 * len(pivot))))
    sns.heatmap(pivot, annot=True, cmap="viridis", ax=ax_emd)
    ax_emd.set_title("Mean EMD per marker (between batches)")

    # 2. Density plots per marker, colored by batch
    X = marker_matrix(adata, markers)
    long_df = pd.DataFrame(X, columns=markers)
    long_df[batch_key] = adata.obs[batch_key].values
    long_df = long_df.melt(
        id_vars=[batch_key], var_name="marker", value_name="value"
    )
    g = sns.FacetGrid(
        long_df, col="marker", hue=batch_key, col_wrap=4, sharey=False
    )
    g.map(sns.kdeplot, "value", fill=True, alpha=0.3, common_norm=False)
    g.add_legend()
    fig_density = g.figure

    # 3. MDS of per-sample median expression
    fig_mds, ax_mds = plt.subplots(figsize=(4, 4))
    if sample_key is not None and sample_key in adata.obs.columns:
        per_cell = pd.DataFrame(X, columns=markers)
        per_cell[sample_key] = adata.obs[sample_key].values
        per_cell[batch_key] = adata.obs[batch_key].values
        # observed=True: obs columns are often categorical, and unused
        # categories would otherwise yield all-NaN median rows that MDS
        # cannot handle.
        by_sample = per_cell.groupby(sample_key, observed=True)
        medians = by_sample[list(markers)].median()
        sample_batch = by_sample[batch_key].first()

        from sklearn.manifold import MDS

        mds = MDS(n_components=2, random_state=seed, normalized_stress="auto")
        coords = mds.fit_transform(medians.values)
        for b in sample_batch.unique():
            mask = (sample_batch == b).values
            ax_mds.scatter(coords[mask, 0], coords[mask, 1], label=str(b))
        ax_mds.legend(title=batch_key)
        ax_mds.set_xlabel("MDS1")
        ax_mds.set_ylabel("MDS2")
        ax_mds.set_title("MDS of per-sample medians")
    else:
        ax_mds.text(0.5, 0.5, "No sample_key provided", ha="center", va="center")

    figs = {"emd": fig_emd, "density": fig_density, "mds": fig_mds}

    if out_dir is not None:
        out_dir = Path(out_dir)
        out_dir.mkdir(parents=True, exist_ok=True)
        for name, fig in figs.items():
            fig.savefig(out_dir / f"detect_{name}.png", dpi=120, bbox_inches="tight")

    return figs
def detect_batch_effect(
    adata: AnnData,
    markers: Iterable[str] | None = None,
    batch_key: str = "batch",
    sample_key: str | None = "sample",
    downsample: int | None = None,
    out_dir: str | os.PathLike | None = None,
    seed: int = 472,
) -> dict:
    """Comprehensive batch-effect diagnostic: express + UMAP + MAD summary.

    Matches ``detect_batch_effect`` in ``R/detect_batch_effect.R``.

    All panels are computed on one private working copy of ``adata``
    (subsampled when ``downsample`` applies) and on the same resolved marker
    set, so the caller's object is not modified and the panels describe the
    same cells and markers.

    Parameters
    ----------
    adata
        Annotated data matrix; ``adata.obs`` must contain ``batch_key``.
    markers
        Markers to analyse; resolved once through ``resolve_markers``.
    batch_key
        Column of ``adata.obs`` holding the batch label.
    sample_key
        Column of ``adata.obs`` holding the sample label (used by the MDS
        panel of the express summary).
    downsample
        If given and smaller than ``adata.n_obs``, that many cells are drawn
        without replacement (seeded by ``seed``) before any panel is computed.
    out_dir
        Directory (created if needed) receiving ``detect_<name>.png`` files.
    seed
        Seed for subsampling, MDS and UMAP.

    Returns
    -------
    dict
        The express figures (``"emd"``, ``"density"``, ``"mds"``) plus
        ``"umap"`` and ``"mad"``.

    Raises
    ------
    ImportError
        If scanpy or matplotlib (or, via the express summary, seaborn) is
        not installed.
    """
    try:
        import matplotlib.pyplot as plt
        import scanpy as sc
    except ImportError as exc:  # pragma: no cover
        raise ImportError("scanpy + matplotlib required for detect_batch_effect.") from exc

    check_obs_key(adata, batch_key)
    # Resolve once so a one-shot iterable is not exhausted by the first panel.
    # NOTE(review): assumes resolve_markers accepts its own output unchanged,
    # since the express summary resolves the list again - confirm.
    markers = resolve_markers(adata, markers)

    # Work on a private copy: the helpers below add a cluster column to .obs.
    # Same RNG call as the express summary, so it would pick the same cells.
    if downsample is not None and adata.n_obs > downsample:
        rng = np.random.default_rng(seed)
        keep = rng.choice(adata.n_obs, downsample, replace=False)
        work = adata[keep].copy()
    else:
        work = adata.copy()

    figs = detect_batch_effect_express(
        work,
        markers=markers,
        batch_key=batch_key,
        sample_key=sample_key,
        downsample=None,  # already applied above
        out_dir=None,  # saved once at the end, together with the extra panels
        seed=seed,
    )

    # UMAP colored by batch, restricted to the same marker matrix the
    # express panels use.
    a = AnnData(np.asarray(marker_matrix(work, markers)))
    # PCA needs fewer components than both dimensions of the matrix.
    sc.pp.pca(a, n_comps=min(20, a.n_vars - 1, a.n_obs - 1))
    sc.pp.neighbors(a, n_neighbors=15)
    sc.tl.umap(a, random_state=seed)
    embedding = a.obsm["X_umap"]
    batches = work.obs[batch_key]
    fig_umap, ax = plt.subplots(figsize=(5, 5))
    for b in batches.unique():
        mask = (batches == b).to_numpy()
        ax.scatter(
            embedding[mask, 0],
            embedding[mask, 1],
            s=3,
            label=str(b),
            alpha=0.6,
        )
    ax.legend(title=batch_key)
    ax.set_title("UMAP (uncorrected) colored by batch")
    figs["umap"] = fig_umap

    # Per-marker MAD summary
    cluster_key = _ensure_single_cluster_label(work)
    mad_df = compute_mad(work, cell_key=cluster_key, batch_key=batch_key, markers=markers)
    import seaborn as sns

    # NOTE(review): the batch column name in compute_mad's output is not
    # visible here; keep "batch" and fall back to batch_key if it is absent.
    hue_col = "batch" if "batch" in mad_df.columns else batch_key
    fig_mad, ax_mad = plt.subplots(figsize=(6, max(3, 0.3 * mad_df["marker"].nunique())))
    sns.barplot(data=mad_df, x="mad", y="marker", hue=hue_col, ax=ax_mad)
    ax_mad.set_title("MAD per marker per batch")
    figs["mad"] = fig_mad

    if out_dir is not None:
        out_dir = Path(out_dir)
        out_dir.mkdir(parents=True, exist_ok=True)
        for name, fig in figs.items():
            fig.savefig(out_dir / f"detect_{name}.png", dpi=120, bbox_inches="tight")

    return figs