Source code for plotnine_extra.stats.stat_pointdensity
from __future__ import annotations
from typing import TYPE_CHECKING, cast
import numpy as np
import pandas as pd
from plotnine.doctools import document
from plotnine.mapping.evaluation import after_stat
from plotnine.stats.density import get_var_type, kde
from plotnine.stats.stat import stat
if TYPE_CHECKING:
from plotnine.typing import FloatArray
[docs]
@document
class stat_pointdensity(stat):
    """
    Compute density estimation for each point

    {usage}

    Parameters
    ----------
    {common_parameters}
    package : Literal["statsmodels", "scipy", "sklearn"], \
            default="statsmodels"
        Package whose kernel density estimation to use.
    kde_params : dict, default=None
        Keyword arguments to pass on to the kde class. The dict is
        copied before use, so the object you pass in is never modified.
        With the statsmodels package, a missing ``var_type`` entry is
        inferred from the types of the ``x`` and ``y`` columns.

    See Also
    --------
    plotnine.geom_point : The default `geom` for this `stat`.
    statsmodels.nonparametric.kde.KDEMultivariate
    scipy.stats.gaussian_kde
    sklearn.neighbors.KernelDensity
    """

    _aesthetics_doc = """
    {aesthetics_table}

    **Options for computed aesthetics**

    ```python
    "density"  # Computed density at a point
    ```
    """

    REQUIRED_AES = {"x", "y"}
    # Points are coloured by their estimated density unless the user
    # maps color to something else.
    DEFAULT_AES = {"color": after_stat("density")}
    DEFAULT_PARAMS = {
        "geom": "point",
        "position": "identity",
        "na_rm": False,
        "package": "statsmodels",
        "kde_params": None,
    }
    CREATES = {"density"}

    def setup_params(self, data: pd.DataFrame) -> None:
        """
        Normalise the kde parameters before any group is computed

        Replaces a ``None`` ``kde_params`` with an empty dict. For the
        statsmodels package, it switches to the ``"statsmodels-m"``
        backend name and fills in ``var_type`` from the x and y columns
        when the user did not supply one.

        Parameters
        ----------
        data : pandas.DataFrame
            Layer data; only the ``"x"`` and ``"y"`` columns are read.
        """
        params = self.params
        # Take a private copy: writing var_type into the user's own dict
        # would leak into every other layer that shares that dict, even
        # when those layers have x/y columns of a different type.
        kde_params = dict(params["kde_params"] or {})

        if params["package"] == "statsmodels":
            # NOTE(review): "statsmodels-m" presumably selects the
            # multivariate estimator (KDEMultivariate) in plotnine's
            # kde() helper -- confirm against plotnine.stats.density.
            params["package"] = "statsmodels-m"
            # var_type is only understood by the statsmodels estimator,
            # so it is inferred only on this branch.
            if "var_type" not in kde_params:
                x_type = get_var_type(data["x"])
                y_type = get_var_type(data["y"])
                kde_params["var_type"] = f"{x_type}{y_type}"

        params["kde_params"] = kde_params

    def compute_group(self, data: pd.DataFrame, scales) -> pd.DataFrame:
        """
        Estimate the density at every point of one group

        Parameters
        ----------
        data : pandas.DataFrame
            Group data with ``"x"`` and ``"y"`` columns.
        scales :
            Plot scales; unused here.

        Returns
        -------
        pandas.DataFrame
            Columns ``"x"``, ``"y"`` (taken from ``data``, index kept)
            and ``"density"``, one row per input point.
        """
        package = self.params["package"]
        # Guard against setup_params not having normalised the value.
        kde_params = self.params["kde_params"] or {}

        x = cast("FloatArray", data["x"].to_numpy())
        y = cast("FloatArray", data["y"].to_numpy())
        # One row per point, one column per variable.
        var_data = np.array([x, y]).T

        # The kde is fitted on the points and evaluated at those same
        # points, giving each observation its own density value.
        density = kde(var_data, var_data, package, **kde_params)

        return pd.DataFrame(
            {
                "x": data["x"],
                "y": data["y"],
                "density": density.flatten(),
            }
        )