arviz_stats.loo_subsample
- arviz_stats.loo_subsample(data, observations, pointwise=None, var_name=None, reff=None, log_weights=None, log_p=None, log_q=None, seed=315, method='lpd', thin=None, log_lik_fn=None, param_names=None, log=True)
Compute PSIS-LOO-CV using sub-sampling.
Estimates the expected log pointwise predictive density (elpd) using Pareto smoothed importance sampling leave-one-out cross-validation (PSIS-LOO-CV) with sub-sampling for large datasets. Uses either log predictive density (LPD) or point log predictive density (PLPD) approximation and applies a difference estimator based on a simple random sample without replacement.
The PSIS-LOO-CV method is described in [1], [2]. The sub-sampling method is described in [3].
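The difference estimator mentioned above can be sketched in a few lines of plain Python. This is an illustrative reading of the estimator described in [3], not the library's implementation: the helper name `difference_estimator` and the synthetic values are assumptions made for the example. The idea is to sum a cheap approximation over all N observations and correct it with the scaled-up differences between exact and approximate values on a simple random sample of size m.

```python
import math
import random
import statistics


def difference_estimator(exact_subsample, approx_all, subsample_idx):
    """Estimate the total of the exact pointwise values (hypothetical helper).

    exact_subsample: exact values for the subsampled observations,
        in the same order as subsample_idx.
    approx_all: approximate values for all N observations.
    subsample_idx: indices drawn by simple random sampling without replacement.
    """
    n_total = len(approx_all)
    m = len(subsample_idx)
    # Differences between exact and approximate values on the subsample.
    diffs = [e - approx_all[i] for e, i in zip(exact_subsample, subsample_idx)]
    # Total of the approximations plus the scaled-up correction.
    estimate = sum(approx_all) + n_total / m * sum(diffs)
    # Subsampling variance with the finite population correction (1 - m/N).
    var = n_total**2 * (1 - m / n_total) * statistics.variance(diffs) / m
    return estimate, math.sqrt(var)


# Synthetic demonstration; none of these numbers come from a real model.
rng = random.Random(315)
n_total, m = 1000, 50
exact_all = [rng.gauss(-3.0, 1.0) for _ in range(n_total)]
approx_all = [v + rng.gauss(0.0, 0.1) for v in exact_all]
idx = rng.sample(range(n_total), m)  # simple random sample without replacement
estimate, subsampling_se = difference_estimator(
    [exact_all[i] for i in idx], approx_all, idx
)
print(estimate, subsampling_se, sum(exact_all))
```

The better the approximation tracks the exact values, the smaller the spread of the differences and hence the subsampling standard error. When m equals N the correction is exact and the subsampling error vanishes.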
- Parameters:
- data: xarray.DataTree or InferenceData
Input data. It should contain the posterior and the log_likelihood groups.
- observations: int or ndarray
The sub-sample observations to use:
An integer specifying the number of observations to randomly sub-sample without replacement.
An array of integer indices specifying the exact observations to use.
- pointwise: bool, optional
If True, the pointwise predictive accuracy will be returned. Defaults to rcParams["stats.ic_pointwise"].
- var_name: str, optional
The name of the variable in the log_likelihood group storing the pointwise log likelihood data to use for the LOO computation.
- reff: float, optional
Relative MCMC efficiency, ess / n, i.e. the number of effective samples divided by the number of actual samples. Computed from the trace by default.
- log_weights: xarray.DataArray or ELPDData, optional
Smoothed log weights. Can be either:
A DataArray with the same shape as the log likelihood data.
An ELPDData object from a previous arviz_stats.loo call.
Defaults to None. If not provided, the weights will be computed using the PSIS-LOO method.
- log_p: ndarray or xarray.DataArray, optional
The (target) log-density evaluated at samples from the target distribution (p). If provided along with log_q, approximate posterior correction will be applied.
- log_q: ndarray or xarray.DataArray, optional
The (proposal) log-density evaluated at samples from the proposal distribution (q). If provided along with log_p, approximate posterior correction will be applied.
- seed: int, optional
Seed for random sampling. Defaults to 315.
- method: str, optional
Method used for approximating the pointwise log predictive density:
‘lpd’: Use the standard log predictive density approximation (default).
‘plpd’: Use the point log predictive density approximation, which requires a log_lik_fn.
- thin: int, optional
Thinning factor for posterior draws. If specified, the posterior will be thinned by this factor to reduce computation time. For example, using thin=2 will use every 2nd draw. If None (default), all posterior draws are used. This value is stored in the returned ELPDData object and will be automatically used by update_subsample.
- log_lik_fn: callable, optional
A function that computes the log-likelihood for a single observation given the mean values of posterior parameters. Required only when method="plpd". The function must accept the observed data value for a single point as its first argument (a scalar). Subsequent arguments must correspond to the mean values of the posterior parameters specified by param_names, passed in the same order. It should return a single scalar log-likelihood value.
- param_names: list, optional
Only used when method="plpd". List of parameter names to extract from the posterior. If None, all parameters are used.
- log: bool, optional
Only used when method="plpd". Whether log_lik_fn returns the log-likelihood (True) or the likelihood (False). Default is True.
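To make the difference between the two `method` options concrete, here is a minimal sketch for a single observation under a normal likelihood with known scale. It follows the descriptions above (LPD averages the likelihood over posterior draws, PLPD evaluates it once at the posterior mean); the function names, the normal model, and the draws are assumptions for illustration and do not reflect the library's internals.

```python
import math
import random
import statistics


def normal_logpdf(y, mu, sigma):
    """Log-density of Normal(mu, sigma) evaluated at y."""
    return -0.5 * math.log(2 * math.pi * sigma**2) - 0.5 * ((y - mu) / sigma) ** 2


def lpd(y, theta_draws, sigma):
    """Log of the likelihood averaged over posterior draws (stable log-mean-exp)."""
    log_liks = [normal_logpdf(y, t, sigma) for t in theta_draws]
    peak = max(log_liks)
    mean_lik = sum(math.exp(v - peak) for v in log_liks) / len(log_liks)
    return peak + math.log(mean_lik)


def plpd(y, theta_draws, sigma):
    """Log-likelihood evaluated once at the posterior mean of theta."""
    return normal_logpdf(y, statistics.fmean(theta_draws), sigma)


# Hypothetical posterior draws for a single location parameter.
rng = random.Random(0)
theta_draws = [rng.gauss(4.0, 3.0) for _ in range(2000)]
print(lpd(28.0, theta_draws, 12.5), plpd(28.0, theta_draws, 12.5))
```

PLPD needs one likelihood evaluation per observation instead of one per draw, which is why it requires a user-supplied `log_lik_fn` that takes parameter means.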
- Returns:
ELPDData
Object with the following attributes:
elpd: approximated expected log pointwise predictive density (elpd)
se: standard error of the elpd (includes approximation and sampling uncertainty)
p: effective number of parameters
n_samples: number of samples in the posterior
n_data_points: total number of data points (N)
warning: True if the estimated shape parameter k of the Pareto distribution is > good_k for any observation in the subsample.
elpd_i: DataArray with pointwise elpd values (filled with NaNs for non-subsampled points), only if pointwise=True.
pareto_k: DataArray with Pareto shape values for the subsample (filled with NaNs for non-subsampled points), only if pointwise=True.
scale: scale of the elpd results (“log”, “negative_log”, or “deviance”).
good_k: Threshold for Pareto k warnings.
approx_posterior: True if approximate posterior was used.
subsampling_se: Standard error estimate from subsampling uncertainty only.
subsample_size: Number of observations in the subsample (m).
log_p: Log density of the target posterior.
log_q: Log density of the proposal posterior.
thin: Thinning factor for posterior draws.
log_weights: Smoothed log weights.
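Two of the attributes above, `scale` and `good_k`, can be illustrated with simple arithmetic. The sketch below assumes the usual convention that "negative_log" multiplies the log-scale elpd by -1 and "deviance" by -2, and uses the sample-size-dependent threshold min(1 - 1/log10(S), 0.7) recommended in [2]. Whether the library computes `good_k` exactly this way is an assumption; both helper names are hypothetical.

```python
import math


def convert_elpd(elpd_log, scale):
    """Convert a log-scale elpd to another reporting scale (assumed convention)."""
    factors = {"log": 1.0, "negative_log": -1.0, "deviance": -2.0}
    return factors[scale] * elpd_log


def good_k_threshold(n_samples):
    """Pareto k threshold suggested in [2] for S posterior draws (assumed rule)."""
    return min(1.0 - 1.0 / math.log10(n_samples), 0.7)


print(convert_elpd(-30.8, "deviance"))
print(round(good_k_threshold(2000), 2))
```

With 2000 draws the threshold rounds to 0.70, which is consistent with the "(-Inf, 0.70] (good)" bin shown in the example output.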
See also
loo
Standard PSIS-LOO-CV.
loo_approximate_posterior
PSIS-LOO-CV for approximate posteriors.
compare
Compare models based on ELPD.
update_subsample
Update a previously computed sub-sampled LOO-CV.
References
[1] Vehtari et al. Practical Bayesian model evaluation using leave-one-out cross-validation and WAIC. Statistics and Computing, 27(5) (2017) https://doi.org/10.1007/s11222-016-9696-4 arXiv preprint https://arxiv.org/abs/1507.04544
[2] Vehtari et al. Pareto Smoothed Importance Sampling. Journal of Machine Learning Research, 25(72) (2024) https://jmlr.org/papers/v25/19-556.html arXiv preprint https://arxiv.org/abs/1507.02646
[3] Magnusson, M., Riis Andersen, M., Jonasson, J., & Vehtari, A. Bayesian Leave-One-Out Cross-Validation for Large Data. Proceedings of the 36th International Conference on Machine Learning, PMLR 97:4244–4253 (2019) https://proceedings.mlr.press/v97/magnusson19a.html arXiv preprint https://arxiv.org/abs/1904.10679
Examples
Calculate sub-sampled PSIS-LOO-CV using 4 random observations:
In [1]: from arviz_stats import loo_subsample
   ...: from arviz_base import load_arviz_data
   ...: data = load_arviz_data("centered_eight")
   ...: loo_results = loo_subsample(data, observations=4, var_name="obs", pointwise=True)
   ...: loo_results
   ...:
Out[1]:
Computed from 2000 by 4 subsampled log-likelihood
values from 8 total observations.

         Estimate       SE  subsampling SE
elpd_loo    -30.8      1.5             0.3
p_loo         0.9

------

Pareto k diagnostic values:
                         Count   Pct.
(-Inf, 0.70]   (good)        4  100.0%
   (0.70, 1]   (bad)         0    0.0%
    (1, Inf)   (very bad)    0    0.0%
Return the pointwise values for the sub-sample:
In [2]: loo_results.elpd_i
Out[2]:
<xarray.DataArray 'elpd_i' (school: 8)> Size: 64B
array([-4.89190585,         nan,         nan, -3.46496198, -3.4780878 ,
               nan,         nan, -3.95934834])
Coordinates:
  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
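Because non-subsampled points are stored as NaN, downstream code has to filter them before summarising. The following sketch copies the values printed above into a plain list (so it runs without ArviZ) and recovers the subsampled positions. The naive scale-up at the end is included only for comparison and is not how `elpd_loo` is computed.

```python
import math

# Values copied from the Out[2] display above.
elpd_i = [-4.89190585, math.nan, math.nan, -3.46496198,
          -3.4780878, math.nan, math.nan, -3.95934834]

# Keep only the observations that were part of the subsample.
subsampled = [(i, v) for i, v in enumerate(elpd_i) if not math.isnan(v)]
indices = [i for i, _ in subsampled]
values = [v for _, v in subsampled]

# Naive estimate of the total: scale the subsample sum by N / m.
naive_total = len(elpd_i) / len(values) * sum(values)
print(indices, naive_total)
```

The naive scale-up gives roughly -31.6, whereas the reported `elpd_loo` is -30.8, because the reported value also uses the LPD approximation for all eight observations through the difference estimator. With the actual `DataArray`, `loo_results.elpd_i.dropna("school")` should select the same four entries.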
We can also use the PLPD approximation method with a custom log-likelihood function. We need to define a function that computes the log-likelihood for a single observation given the mean values of posterior parameters. For the Eight Schools model, we define a function that computes the log-likelihood for each observation using the posterior mean of the parameters (here, the mean of theta):
In [3]: import numpy as np
   ...: from arviz_stats import loo_subsample
   ...: from arviz_base import load_arviz_data
   ...: from scipy.stats import norm
   ...: data = load_arviz_data("centered_eight")
   ...:
   ...: def log_lik_fn(y, theta):
   ...:     sigma = 12.5  # Using a fixed sigma for simplicity
   ...:     return norm.logpdf(y, loc=theta, scale=sigma)
   ...:
   ...: loo_results = loo_subsample(
   ...:     data,
   ...:     observations=4,
   ...:     var_name="obs",
   ...:     method="plpd",
   ...:     log_lik_fn=log_lik_fn,
   ...:     param_names=["theta"],
   ...:     pointwise=True
   ...: )
   ...: loo_results
   ...:
Out[3]:
Computed from 2000 by 4 subsampled log-likelihood
values from 8 total observations.

         Estimate       SE  subsampling SE
elpd_loo    -30.4      1.1             0.7
p_loo         0.0

------

Pareto k diagnostic values:
                         Count   Pct.
(-Inf, 0.70]   (good)        4  100.0%
   (0.70, 1]   (bad)         0    0.0%
    (1, Inf)   (very bad)    0    0.0%
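The `log` parameter exists because a PLPD callable may return either a log-likelihood or a plain likelihood. As a rough, standard-library-only sketch of the two forms for the fixed-sigma normal model used above, the pair below should agree after taking the logarithm. The name `lik_fn` and the test values are assumptions for illustration.

```python
import math
from statistics import NormalDist

SIGMA = 12.5  # same fixed value as in the example above


def log_lik_fn(y, theta):
    """Log-likelihood of one observation; matches the default log=True."""
    return -0.5 * math.log(2 * math.pi * SIGMA**2) - 0.5 * ((y - theta) / SIGMA) ** 2


def lik_fn(y, theta):
    """Plain likelihood of one observation; this form would need log=False."""
    return NormalDist(mu=theta, sigma=SIGMA).pdf(y)


# Hypothetical observation and parameter mean.
y_obs, theta_mean = 28.0, 4.4
print(log_lik_fn(y_obs, theta_mean), math.log(lik_fn(y_obs, theta_mean)))
```

Returning log-likelihoods directly is usually the safer choice, since plain likelihoods can underflow to zero for observations far from the mean.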