Source code for plotnine_extra.stats.stat_compare_means

from __future__ import annotations

import pandas as pd
from plotnine.doctools import document
from plotnine.mapping.evaluation import after_stat
from plotnine.stats.stat import stat

from ._common import preserve_panel_columns
from ._label_utils import compute_label_position
from ._p_format import format_p_value, p_to_signif
from ._stat_test import run_stat_test


@document
class stat_compare_means(stat):
    """
    Add mean comparison p-values to a plot

    Performs statistical tests comparing groups and displays
    the results as text annotations. Supports t-test, Wilcoxon,
    ANOVA, and Kruskal-Wallis tests.

    {usage}

    Parameters
    ----------
    {common_parameters}
    method : str, default="wilcox.test"
        Statistical test method. One of ``"t.test"``,
        ``"wilcox.test"``, ``"anova"``, ``"kruskal.test"``.
    paired : bool, default=False
        Whether to perform a paired test. Applies to
        two-group tests (pairwise comparisons, and the global
        test when the panel has exactly two groups).
    comparisons : list of tuple, default=None
        List of group pairs to compare, e.g.
        ``[("A", "B"), ("A", "C")]``. If ``None``, performs
        a global test across all groups.
    ref_group : str, default=None
        Reference group for pairwise comparisons. Each group
        is compared against this reference.
    hide_ns : bool, default=False
        If ``True``, hide non-significant results.
    label : str, default="p.format"
        Label format. One of ``"p.format"``, ``"p.signif"``,
        ``"p.format.signif"``.
    label_x_npc : float or str, default="center"
        Normalized x position for global test label.
    label_y_npc : float or str, default="top"
        Normalized y position for global test label.
    p_digits : int, default=3
        Number of digits for p-value formatting.
    step_increase : float, default=0.1
        Fraction of y-range to step between comparison
        brackets.

    See Also
    --------
    plotnine.geom_text : The default `geom` for this `stat`.
    """

    _aesthetics_doc = """
    {aesthetics_table}

    **Options for computed aesthetics**

    ```python
    "label"  # Formatted test result label
    "p"  # P-value
    "p_signif"  # Significance symbol
    "method"  # Name of the test
    ```

    """

    REQUIRED_AES = {"x", "y"}
    DEFAULT_AES = {"label": after_stat("label")}
    DEFAULT_PARAMS = {
        "geom": "text",
        "position": "identity",
        "na_rm": False,
        "method": "wilcox.test",
        "paired": False,
        "comparisons": None,
        "ref_group": None,
        "hide_ns": False,
        "label": "p.format",
        "label_x_npc": "center",
        "label_y_npc": "top",
        "p_digits": 3,
        "step_increase": 0.1,
    }
    CREATES = {"label", "p", "p_signif", "method"}

    def __init__(self, mapping=None, data=None, **kwargs):
        super().__init__(mapping, data, **kwargs)
        # Remove 'label' from _kwargs so it is not forwarded
        # to the geom as a static aesthetic value. The 'label'
        # kwarg is a stat parameter controlling format (e.g.
        # "p.signif"), not a literal label string.
        self._kwargs.pop("label", None)

    def compute_panel(self, data, scales) -> pd.DataFrame:
        """
        Compute the test annotation rows for one panel.

        Rows of ``data`` are grouped by ``x``. Depending on the
        parameters, either a set of pairwise tests
        (``comparisons`` or ``ref_group`` given) or a single
        global test is run.

        Parameters
        ----------
        data : pd.DataFrame
            Panel data; the ``x`` and ``y`` columns are read.
        scales :
            Panel scales. ``scales.x`` (when present) is used
            to translate string group labels into the values
            found in ``data["x"]``.

        Returns
        -------
        pd.DataFrame
            One row per displayed test result, or an empty
            frame when fewer than two groups are present or
            nothing is left to display.
        """
        method = self.params["method"]
        paired = self.params["paired"]
        comparisons = self.params["comparisons"]
        ref_group = self.params["ref_group"]
        hide_ns = self.params["hide_ns"]
        label_type = self.params["label"]
        p_digits = self.params["p_digits"]
        step_increase = self.params["step_increase"]

        # Group data by x categories. ``list(...)`` is needed:
        # a GroupBy object exposes a non-callable ``keys``
        # attribute, so ``dict(groupby)`` would not iterate it.
        grouped = dict(list(data.groupby("x")))
        group_names = sorted(grouped.keys())

        if len(group_names) < 2:
            return pd.DataFrame()

        # Build mapping from original labels to numeric
        # x values if using a discrete scale
        label_to_num = {}
        if hasattr(scales, "x") and hasattr(scales.x, "range"):
            try:
                limits = scales.x.limits
                if limits and isinstance(limits[0], str):
                    for lbl in limits:
                        mapped = scales.x.map([lbl])
                        if len(mapped) > 0:
                            label_to_num[lbl] = mapped[0]
            except Exception:
                # Deliberate best-effort: if the scale cannot be
                # queried, keep whatever was mapped so far;
                # labels that stay unmapped are looked up as-is
                # below.
                pass

        # If label_to_num is still empty, build it from
        # group_names directly (already numeric)
        if not label_to_num:
            for name in group_names:
                label_to_num[name] = name

        # Determine what comparisons to make
        if comparisons is not None:
            # Map string labels to numeric keys; labels not in
            # the mapping are used unchanged.
            pairs = [
                (
                    label_to_num.get(g1, g1),
                    label_to_num.get(g2, g2),
                )
                for g1, g2 in comparisons
            ]
        elif ref_group is not None:
            ref_key = label_to_num.get(ref_group, ref_group)
            pairs = [
                (ref_key, g) for g in group_names if g != ref_key
            ]
        else:
            # Global test
            return self._global_test(
                data, grouped, group_names, method
            )

        # Pairwise comparisons
        return self._pairwise_test(
            data,
            grouped,
            pairs,
            method,
            paired,
            hide_ns,
            label_type,
            p_digits,
            step_increase,
        )

    def _global_test(self, data, grouped, group_names, method):
        """
        Run a global test across all groups.

        With more than two groups, the two-sample methods are
        replaced by their multi-group counterparts
        (``"wilcox.test"`` -> ``"kruskal.test"``,
        ``"t.test"`` -> ``"anova"``). With exactly two groups
        the requested method is used, and the ``paired``
        parameter is honoured.

        Returns
        -------
        pd.DataFrame
            A single row holding the label position and the
            test result.
        """
        groups = [
            grouped[g]["y"].to_numpy(dtype=float)
            for g in group_names
        ]

        # For 2 groups, use pairwise test method
        # For >2 groups, use ANOVA or Kruskal-Wallis
        if len(groups) > 2 and method in ("t.test", "wilcox.test"):
            global_method = (
                "kruskal.test" if method == "wilcox.test" else "anova"
            )
        else:
            global_method = method

        # Forward ``paired`` only for the two-group case, so a
        # paired request is not silently run as an unpaired
        # test. The keyword is omitted otherwise, leaving the
        # multi-group and default calls exactly as before.
        test_kwargs = {}
        paired = self.params["paired"]
        if len(groups) == 2 and paired:
            test_kwargs["paired"] = paired

        result = run_stat_test(
            groups, method=global_method, **test_kwargs
        )
        p_digits = self.params["p_digits"]
        p_signif = p_to_signif(result.p_value)
        label = self._make_label(result.p_value, p_signif, p_digits)

        x_pos = compute_label_position(
            data["x"].min(),
            data["x"].max(),
            self.params["label_x_npc"],
        )
        y_pos = compute_label_position(
            data["y"].min(),
            data["y"].max(),
            self.params["label_y_npc"],
        )

        return preserve_panel_columns(
            pd.DataFrame(
                {
                    "x": [x_pos],
                    "y": [y_pos],
                    "label": [label],
                    "p": [result.p_value],
                    "p_signif": [p_signif],
                    "method": [result.method],
                }
            ),
            data,
        )

    def _pairwise_test(
        self,
        data,
        grouped,
        pairs,
        method,
        paired,
        hide_ns,
        label_type,
        p_digits,
        step_increase,
    ):
        """
        Run pairwise tests between specified pairs.

        Pairs naming a group that is absent from ``grouped``
        are skipped, as are non-significant results when
        ``hide_ns`` is true. Each label is placed at the x
        midpoint of its two groups, above the data, stepping
        upward by ``step_increase`` of the y-range per pair.

        Returns
        -------
        pd.DataFrame
            One row per displayed comparison, or an empty
            frame when every pair was skipped.
        """
        results = []
        y_max = data["y"].max()
        y_range = y_max - data["y"].min()
        # Loop-invariant parts of the label height.
        y_base = y_max + y_range * 0.05
        y_step = y_range * step_increase

        # ``i`` is the index in the requested list, so a skipped
        # pair still reserves its vertical slot.
        for i, (g1, g2) in enumerate(pairs):
            if g1 not in grouped or g2 not in grouped:
                continue

            group1 = grouped[g1]["y"].to_numpy(dtype=float)
            group2 = grouped[g2]["y"].to_numpy(dtype=float)

            result = run_stat_test(
                [group1, group2],
                method=method,
                paired=paired,
            )
            p_signif = p_to_signif(result.p_value)

            if hide_ns and p_signif == "ns":
                continue

            label = self._make_label(
                result.p_value,
                p_signif,
                p_digits,
                label_type,
            )

            # Position: midpoint between groups
            # Get x positions of groups
            x1 = grouped[g1]["x"].iloc[0]
            x2 = grouped[g2]["x"].iloc[0]
            x_mid = (x1 + x2) / 2
            y_pos = y_base + y_step * i

            results.append(
                {
                    "x": x_mid,
                    "y": y_pos,
                    "label": label,
                    "p": result.p_value,
                    "p_signif": p_signif,
                    "method": result.method,
                }
            )

        if not results:
            return pd.DataFrame()

        return preserve_panel_columns(pd.DataFrame(results), data)

    def _make_label(
        self, p_value, p_signif, p_digits, label_type=None
    ):
        """
        Create label based on label type.

        Parameters
        ----------
        p_value : float
            P-value to format.
        p_signif : str
            Significance symbol for ``p_value``.
        p_digits : int
            Number of digits passed to the p-value formatter.
        label_type : str, default=None
            Label format; when ``None`` the stat's ``label``
            parameter is used.

        Returns
        -------
        str
            ``p_signif`` for ``"p.signif"``; the formatted
            p-value followed by the symbol in parentheses for
            ``"p.format.signif"``; otherwise the formatted
            p-value alone.
        """
        if label_type is None:
            label_type = self.params["label"]

        if label_type == "p.signif":
            return p_signif

        p_str = format_p_value(p_value, digits=p_digits)
        if label_type == "p.format.signif":
            return f"{p_str} ({p_signif})"
        # p.format (also the fallback for any other value)
        return p_str