# Source code for plotnine_extra.stats.stat_pointdensity

from __future__ import annotations

from typing import TYPE_CHECKING, cast

import numpy as np
import pandas as pd
from plotnine.doctools import document
from plotnine.mapping.evaluation import after_stat
from plotnine.stats.density import get_var_type, kde
from plotnine.stats.stat import stat

if TYPE_CHECKING:
    from plotnine.typing import FloatArray


@document
class stat_pointdensity(stat):
    """
    Compute density estimation for each point

    {usage}

    Parameters
    ----------
    {common_parameters}
    package : Literal["statsmodels", "scipy", "sklearn"], \
default="statsmodels"
        Package whose kernel density estimation to use.
    kde_params : dict, default=None
        Keyword arguments to pass on to the kde class.

    See Also
    --------
    plotnine.geom_point : The default `geom` for this `stat`.
    statsmodels.nonparametric.kde.KDEMultivariate
    scipy.stats.gaussian_kde
    sklearn.neighbors.KernelDensity
    """

    _aesthetics_doc = """
    {aesthetics_table}

    **Options for computed aesthetics**

    ```python
    "density"  # Computed density at a point
    ```

    """

    REQUIRED_AES = {"x", "y"}
    DEFAULT_AES = {"color": after_stat("density")}
    DEFAULT_PARAMS = {
        "geom": "point",
        "position": "identity",
        "na_rm": False,
        "package": "statsmodels",
        "kde_params": None,
    }
    CREATES = {"density"}

    def setup_params(self, data):
        """
        Prepare the kde parameters before any group is computed

        Updates ``self.params`` in place and returns nothing:

        - ``kde_params`` is always left as a dict owned by this stat.
          ``None`` becomes ``{}``, and a user-supplied dict is copied.
        - When ``package == "statsmodels"``, the backend name is switched
          to ``"statsmodels-m"``. Unless the user already supplied a
          ``var_type``, one is derived from the ``x`` and ``y`` columns.

        Parameters
        ----------
        data : pandas.DataFrame
            Layer data; only the ``x`` and ``y`` columns are read.
        """
        params = self.params

        # Work on a private copy so the dict the user passed in is never
        # modified when "var_type" is filled in below.
        kde_params = dict(params["kde_params"] or {})
        params["kde_params"] = kde_params

        if params["package"] == "statsmodels":
            # "statsmodels-m" selects the multivariate statsmodels
            # estimator (KDEMultivariate in the See Also list), which
            # needs one type code per variable.
            params["package"] = "statsmodels-m"
            if "var_type" not in kde_params:
                x_type = get_var_type(data["x"])
                y_type = get_var_type(data["y"])
                kde_params["var_type"] = f"{x_type}{y_type}"

    def compute_group(self, data, scales) -> pd.DataFrame:
        """
        Evaluate the 2-D kernel density at every point of one group

        Parameters
        ----------
        data : pandas.DataFrame
            Group data with ``x`` and ``y`` columns.
        scales :
            Unused here; part of the ``stat.compute_group`` interface.

        Returns
        -------
        pandas.DataFrame
            Columns ``x``, ``y`` (taken from ``data``, keeping its index)
            and ``density``, with one row per input point.
        """
        package = self.params["package"]
        # Guard against setup_params not having normalised a None value.
        kde_params = self.params["kde_params"] or {}

        x = cast("FloatArray", data["x"].to_numpy())
        y = cast("FloatArray", data["y"].to_numpy())

        # Shape (n_points, 2). The same points are used both to fit the
        # estimator and as the locations where the density is evaluated.
        var_data = np.array([x, y]).T
        density = kde(var_data, var_data, package, **kde_params)

        return pd.DataFrame(
            {
                "x": data["x"],
                "y": data["y"],
                "density": density.flatten(),
            }
        )