Source code for plotnine_extra.composition._plot_layout

from __future__ import annotations

from dataclasses import dataclass, field
from itertools import cycle
from typing import TYPE_CHECKING, Sequence

from ._types import ComposeAddable

if TYPE_CHECKING:
    from ._compose import Compose


@dataclass(kw_only=True)
class plot_layout(ComposeAddable):
    """
    Customise the layout of plots in a composition
    """

    nrow: int | None = None
    """
    Number of rows
    """

    ncol: int | None = None
    """
    Number of columns
    """

    byrow: bool | None = None
    """
    How to place plots into the grid.
    If None or True, they are placed row by row, left to right.
    If False, they are placed column by column, top to bottom.
    """

    widths: Sequence[float] | None = None
    """
    Relative widths of each column
    """

    heights: Sequence[float] | None = None
    """
    Relative heights of each row
    """

    # This field has no default and is not set by __init__, so it is
    # excluded from the generated __eq__ (compare=False); otherwise
    # comparing two layouts before the attribute is assigned raises
    # AttributeError.
    _cmp: Compose = field(init=False, repr=False, compare=False)
    """
    Composition that this layout is attached to
    """

    def __radd__(self, cmp: Compose) -> Compose:
        """
        Add plot layout to composition

        Invoked for `cmp + plot_layout(...)`. Assigns this layout to
        `cmp.layout` and returns the same composition object.
        """
        cmp.layout = self
        return cmp

    def _setup(self, cmp: Compose):
        """
        Setup default parameters as they are expected by the layout manager.

        - Ensure nrow and ncol have values
        - Ensure byrow has a value (defaults to True)
        - Ensure widths & heights are set and normalised to mean=1

        The layout is modified in place.

        Raises
        ------
        ValueError
            If a Beside composition has more items than `ncol`, or a
            Stack composition has more items than `nrow`.
        """
        # Imported here rather than at module level
        # (presumably to avoid a circular import -- confirm)
        from . import Beside, Stack

        # setup nrow & ncol
        if isinstance(cmp, Beside):
            # Items sit side by side: one column per item, single row
            if self.ncol is None:
                self.ncol = len(cmp)
            elif self.ncol < len(cmp):
                raise ValueError(
                    "Composition has more items than the "
                    "layout columns."
                )
            if self.nrow is None:
                self.nrow = 1
        elif isinstance(cmp, Stack):
            # Items are stacked: one row per item, single column
            if self.nrow is None:
                self.nrow = len(cmp)
            elif self.nrow < len(cmp):
                raise ValueError(
                    "Composition has more items than the "
                    "layout rows."
                )
            if self.ncol is None:
                self.ncol = 1
        else:
            # Any other composition: let plotnine's wrap_dims work out
            # the grid dimensions from the number of items
            from plotnine.facets.facet_wrap import wrap_dims

            self.nrow, self.ncol = wrap_dims(
                len(cmp), self.nrow, self.ncol
            )

        nrow, ncol = self.nrow, self.ncol

        # byrow
        if self.byrow is None:
            self.byrow = True

        # setup widths & heights
        ws, hs = self.widths, self.heights
        if ws is None:
            ws = (1 / ncol,) * ncol
        elif len(ws) != ncol:
            # Recycle (or truncate) the given values to one per column
            ws = repeat(ws, ncol)

        if hs is None:
            hs = (1 / nrow,) * nrow
        elif len(hs) != nrow:
            # Recycle (or truncate) the given values to one per row
            hs = repeat(hs, nrow)

        self.widths = normalise(ws)
        self.heights = normalise(hs)

    def update(self, other: plot_layout):
        """
        Update this layout with the contents of other

        Only values that are set on `other` are copied. For widths,
        heights, ncol and nrow "set" means truthy, so None, 0 and
        empty sequences are ignored. byrow is copied whenever it is
        not None, so an explicit False is honoured.
        """
        if other.widths:
            self.widths = other.widths
        if other.heights:
            self.heights = other.heights
        if other.ncol:
            self.ncol = other.ncol
        if other.nrow:
            self.nrow = other.nrow
        if other.byrow is not None:
            self.byrow = other.byrow
def repeat(seq: Sequence[float], n: int) -> list[float]:
    """
    Ensure returned sequence has n values, repeat as necessary

    Values of seq are recycled in order until n values have been
    produced; if seq is longer than n, it is truncated to the first
    n values.

    Raises
    ------
    ValueError
        If seq is empty and n is greater than zero, since there is
        nothing to repeat.
    """
    # Cycling an empty sequence yields nothing, which would silently
    # return fewer than n values.
    if n > 0 and len(seq) == 0:
        raise ValueError("Cannot repeat an empty sequence")
    return [val for _, val in zip(range(n), cycle(seq))]


def normalise(seq: Sequence[float]) -> list[float]:
    """
    Normalise seq so that the mean is 1

    Every value is divided by the mean of seq.

    Raises
    ------
    ValueError
        If seq is empty or if its mean is zero.
    """
    # Guard the division by len(seq) so that an empty input gives a
    # descriptive error instead of a ZeroDivisionError.
    if len(seq) == 0:
        raise ValueError("Cannot rescale: sequence is empty")
    mean = sum(seq) / len(seq)
    if mean == 0:
        raise ValueError("Cannot rescale: mean is zero")
    return [x / mean for x in seq]