"""Plotting helpers.
Pragmatic subset of ``R/utils_plotting.R``: density comparison, dimensionality
reduction, and an EMD heatmap summary.
"""
from __future__ import annotations
import os
from typing import Iterable
import numpy as np
import pandas as pd
from anndata import AnnData
from cycombinepy._utils import marker_matrix, resolve_markers
[docs]
def plot_density(
    adata: AnnData,
    markers: Iterable[str] | None = None,
    layer: str | None = "cycombine_corrected",
    batch_key: str = "batch",
    filename: str | os.PathLike | None = None,
):
    """Draw one kernel-density facet per marker, with one curve per batch.

    The marker values obtained without a layer are always plotted and labelled
    ``"uncorrected"``. When ``layer`` is not ``None`` and exists in
    ``adata.layers``, the values from that layer are added as a second row of
    facets labelled ``"corrected"``, giving a before/after comparison. A layer
    that is missing from ``adata.layers`` is skipped without error.

    Without a corrected row the marker facets wrap every four columns; with
    it, each marker gets its own column.

    If ``filename`` is given the figure is also saved there (120 dpi, tight
    bounding box). The matplotlib figure of the facet grid is returned.

    Mirrors ``plot_density`` in ``R/utils_plotting.R``.
    """
    import matplotlib.pyplot as plt
    import seaborn as sns

    markers = resolve_markers(adata, markers)
    # Batch labels are stringified so they are usable as hue levels.
    batch_labels = adata.obs[batch_key].astype(str).values

    def _to_long(values: np.ndarray, label: str) -> pd.DataFrame:
        # Wide (cells x markers) -> long (one row per cell/marker pair).
        wide = pd.DataFrame(values, columns=list(markers))
        wide[batch_key] = batch_labels
        wide["kind"] = label
        return wide.melt(
            id_vars=[batch_key, "kind"],
            var_name="marker",
            value_name="value",
        )

    frames = [_to_long(marker_matrix(adata, markers), "uncorrected")]
    show_corrected = layer is not None and layer in adata.layers
    if show_corrected:
        corrected_values = marker_matrix(adata, markers, layer=layer)
        frames.append(_to_long(corrected_values, "corrected"))
    long_df = pd.concat(frames, ignore_index=True)

    # Row faceting and column wrapping are used mutually exclusively here.
    if show_corrected:
        facet_row, wrap_after = "kind", None
    else:
        facet_row, wrap_after = None, 4

    grid = sns.FacetGrid(
        long_df,
        col="marker",
        row=facet_row,
        hue=batch_key,
        sharey=False,
        col_wrap=wrap_after,
    )
    grid.map(sns.kdeplot, "value", fill=True, alpha=0.3, common_norm=False)
    grid.add_legend()

    figure = grid.figure
    if filename is not None:
        figure.savefig(filename, dpi=120, bbox_inches="tight")
    return figure
[docs]
def plot_dimred(
    adata: AnnData,
    kind: str = "umap",
    color: str | list[str] = "batch",
    layer: str | None = None,
    seed: int = 0,
):
    """Thin wrapper around ``scanpy.pl.umap`` / ``scanpy.pl.pca``.

    ``kind`` selects the embedding and must be ``"umap"`` (default) or
    ``"pca"``; it is matched case-insensitively. Any other value raises
    ``ValueError`` instead of silently falling back to UMAP.

    All work happens on a copy of ``adata``, so the caller's object is not
    modified. If ``layer`` is provided, that layer replaces ``X`` on the copy
    before the PCA is computed, so the plot reflects the corrected values; a
    layer missing from ``adata.layers`` propagates the lookup error.

    ``color`` is forwarded to the scanpy plotting function. ``seed`` is used
    as the random state for the PCA, the neighbor graph and the UMAP, so one
    value controls the whole pipeline.

    Returns the current matplotlib figure after scanpy has drawn on it.
    """
    # Validate before the heavy imports and the data copy so a typo fails
    # fast instead of producing a UMAP the caller did not ask for.
    embedding = str(kind).lower()
    if embedding not in ("pca", "umap"):
        raise ValueError(f"kind must be 'umap' or 'pca', got {kind!r}")

    import matplotlib.pyplot as plt
    import scanpy as sc

    work = adata.copy()
    if layer is not None:
        work.X = work.layers[layer]

    # PCA cannot return more components than min(n_obs, n_vars) - 1, so bound
    # the request by both dimensions rather than by the marker count alone.
    n_comps = min(20, work.n_vars - 1, work.n_obs - 1)
    sc.pp.pca(work, n_comps=n_comps, random_state=seed)

    if embedding == "pca":
        sc.pl.pca(work, color=color, show=False)
        return plt.gcf()

    sc.pp.neighbors(work, n_neighbors=15, random_state=seed)
    sc.tl.umap(work, random_state=seed)
    sc.pl.umap(work, color=color, show=False)
    return plt.gcf()
[docs]
def plot_emd_heatmap(emd_df: pd.DataFrame, filename: str | os.PathLike | None = None):
    """Render a cluster-by-marker heatmap of mean EMD from :func:`compute_emd` output.

    ``emd_df`` must provide ``cluster``, ``marker`` and ``emd`` columns. The
    ``emd`` values are averaged for every (cluster, marker) pair, with
    clusters as heatmap rows and markers as columns. The figure grows with
    the table size but is never smaller than 4 x 3 inches.

    If ``filename`` is given the figure is also saved there (120 dpi, tight
    bounding box). The matplotlib figure is returned.
    """
    import matplotlib.pyplot as plt
    import seaborn as sns

    mean_emd = emd_df.pivot_table(
        values="emd",
        index="cluster",
        columns="marker",
        aggfunc="mean",
    )

    # Scale the canvas with the table while enforcing a minimum size.
    n_clusters, n_markers = mean_emd.shape
    width = max(4, 0.4 * n_markers)
    height = max(3, 0.3 * n_clusters)

    fig, ax = plt.subplots(figsize=(width, height))
    sns.heatmap(mean_emd, annot=False, cmap="viridis", ax=ax)
    ax.set_title("Mean EMD per (cluster, marker)")

    if filename is None:
        return fig
    fig.savefig(filename, dpi=120, bbox_inches="tight")
    return fig