arviz_stats.loo_subsample

arviz_stats.loo_subsample(data, observations, pointwise=None, var_name=None, reff=None, log_weights=None, log_p=None, log_q=None, seed=315, method='lpd', thin=None, log_lik_fn=None, param_names=None, log=True)

Compute PSIS-LOO-CV using sub-sampling.

Estimates the expected log pointwise predictive density (elpd) using Pareto smoothed importance sampling leave-one-out cross-validation (PSIS-LOO-CV) with sub-sampling for large datasets. It uses either the log predictive density (LPD) or the point log predictive density (PLPD) approximation and applies a difference estimator based on a simple random sample drawn without replacement.

The PSIS-LOO-CV method is described in [1], [2]. The sub-sampling method is described in [3].
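To make the difference estimator concrete, here is a minimal NumPy sketch of the point estimate and its subsampling standard error as described in [3]. It assumes the cheap approximations π_i are already available for all N observations and the exact PSIS-LOO values for the m sub-sampled ones. The function name `difference_estimator` is hypothetical, this is not the arviz_stats implementation, and it only reproduces the subsampling part of the uncertainty (compare the subsampling_se attribute), not the total se.

```python
import numpy as np


def difference_estimator(approx, elpd_sub, idx):
    """Hypothetical sketch of the SRSWOR difference estimator for elpd_loo.

    approx   : shape (N,), cheap approximation pi_i for every observation
    elpd_sub : shape (m,), exact PSIS-LOO elpd_i for the sub-sampled observations
    idx      : shape (m,), integer indices of the sub-sampled observations
    """
    approx = np.asarray(approx, dtype=float)
    elpd_sub = np.asarray(elpd_sub, dtype=float)
    idx = np.asarray(idx, dtype=int)
    n, m = approx.size, idx.size

    # e_j = elpd_j - pi_j: how far the approximation is off on the sub-sample
    diff = elpd_sub - approx[idx]

    # sum of all approximations plus the scaled-up correction (N/m) * sum_j e_j
    elpd_hat = approx.sum() + n * diff.mean()

    # variance of an estimated total under simple random sampling without
    # replacement, including the finite population correction (1 - m/N)
    s2 = diff.var(ddof=1) if m > 1 else 0.0
    subsampling_var = n**2 * (1 - m / n) * s2 / m
    return elpd_hat, np.sqrt(subsampling_var)
```

If the approximation were exact for every observation, every e_j would be zero and the estimate would equal the sum of the approximations with no subsampling error; evaluating all N observations (m = N) also removes the subsampling error through the finite population correction.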

Parameters:
data: xarray.DataTree or InferenceData

Input data. It should contain the posterior and the log_likelihood groups.

observations: int or ndarray

The sub-sample observations to use:

  • An integer specifying the number of observations to randomly sub-sample without replacement.

  • An array of integer indices specifying the exact observations to use.
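A minimal sketch of how the two accepted forms of observations can be turned into a set of indices. The helper name `select_subsample` is hypothetical, and the use of NumPy's `Generator.choice` is an assumption: loo_subsample may draw its sample differently, so the same seed is not guaranteed to reproduce the indices it picks.

```python
import numpy as np


def select_subsample(n_obs, observations, seed=315):
    """Hypothetical helper: return sorted indices of the observations to evaluate exactly."""
    if np.isscalar(observations):
        # integer: simple random sample of that size, drawn without replacement
        rng = np.random.default_rng(seed)
        idx = rng.choice(n_obs, size=int(observations), replace=False)
    else:
        # array: use exactly the indices given, after basic validation
        idx = np.asarray(observations, dtype=int)
        if np.unique(idx).size != idx.size:
            raise ValueError("observation indices must be unique")
        if idx.min() < 0 or idx.max() >= n_obs:
            raise ValueError("observation indices out of range")
    return np.sort(idx)
```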

pointwise: bool, optional

If True, the pointwise predictive accuracy is returned. Defaults to rcParams["stats.ic_pointwise"].

var_name: str, optional

The name of the variable in the log_likelihood group that stores the pointwise log-likelihood data to use for the LOO computation.

reff: float, optional

Relative MCMC efficiency, ess / n, i.e., the number of effective samples divided by the number of actual samples. Computed from the trace by default.

log_weights: xarray.DataArray or ELPDData, optional

Smoothed log weights. Can be either:

  • A DataArray with the same shape as the log-likelihood data.

  • An ELPDData object from a previous arviz_stats.loo call.

Defaults to None. If not provided, it will be computed using the PSIS-LOO method.

log_p: ndarray or xarray.DataArray, optional

The (target) log-density evaluated at samples from the target distribution (p). If provided along with log_q, approximate posterior correction will be applied.

log_q: ndarray or xarray.DataArray, optional

The (proposal) log-density evaluated at samples from the proposal distribution (q). If provided along with log_p, approximate posterior correction will be applied.
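The role of log_p and log_q can be illustrated with plain importance sampling. In this sketch (not the library's code, and without the Pareto smoothing that PSIS applies), the draws are assumed to come from q, so each draw is reweighted by p/q before the usual leave-one-out ratio 1/p(y_i | θ) is applied; the normalizing constants of p and q cancel in the self-normalized estimate. The function name is hypothetical.

```python
import numpy as np
from scipy.special import logsumexp


def is_loo_approx_posterior(log_lik_i, log_p, log_q):
    """Hypothetical sketch: unsmoothed IS estimate of elpd_loo_i with draws from q.

    log_lik_i : log p(y_i | theta_s) for each draw s, shape (S,)
    log_p     : (unnormalized) target log-density at each draw, shape (S,)
    log_q     : proposal log-density at each draw, shape (S,)
    """
    # correction from the approximation q to the posterior p
    log_correction = log_p - log_q
    # leave-one-out importance ratios: (p / q) / p(y_i | theta)
    log_ratios = log_correction - log_lik_i
    # self-normalized estimate of log p(y_i | y_-i)
    return logsumexp(log_correction) - logsumexp(log_ratios)
```

When log_p equals log_q the correction vanishes and the expression reduces to the standard (unsmoothed) importance-sampling LOO estimate.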

seed: int, optional

Seed for random sampling.

method: str, optional

Method used for approximating the pointwise log predictive density:

  • ‘lpd’: Use the standard log predictive density approximation (default).

  • ‘plpd’: Use the point log predictive density approximation, which requires log_lik_fn.
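As a sketch of what the 'lpd' approximation computes for each observation, assuming it is the log of the posterior-averaged likelihood, lpd_i = log((1/S) Σ_s p(y_i | θ_s)), evaluated stably with logsumexp from a log-likelihood array of shape (draws, observations). The 'plpd' alternative instead evaluates the likelihood once, at the posterior mean of the parameters. The function name is hypothetical.

```python
import numpy as np
from scipy.special import logsumexp


def lpd_approximation(log_lik):
    """Hypothetical helper: log pointwise predictive density for each observation.

    log_lik : array of shape (n_samples, n_obs) holding log p(y_i | theta_s)
    """
    n_samples = log_lik.shape[0]
    # log of the average likelihood over draws, computed in log space
    return logsumexp(log_lik, axis=0) - np.log(n_samples)
```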

thin: int, optional

Thinning factor for posterior draws. If specified, the posterior is thinned by this factor to reduce computation time; for example, thin=2 uses every second draw. If None (default), all posterior draws are used. This value is stored in the returned ELPDData object and is used automatically by update_subsample.
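Thinning with thin=2 halves the number of draws per chain. A NumPy sketch, assuming a (chain, draw, observation) layout with hypothetical sizes; on an xarray object the equivalent selection would be `.isel(draw=slice(None, None, thin))`.

```python
import numpy as np

# Hypothetical log-likelihood values laid out as (chain, draw, observation),
# e.g. 4 chains of 500 draws for 8 observations
log_lik = np.zeros((4, 500, 8))

thin = 2
# keep draws 0, 2, 4, ... within every chain
thinned = log_lik[:, ::thin, :]
n_samples = thinned.shape[0] * thinned.shape[1]  # 4 * 250 = 1000 draws remain
```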

log_lik_fn: callable, optional

A function that computes the log-likelihood for a single observation given the mean values of posterior parameters. Required only when method="plpd". The function must accept the observed data value for a single point as its first argument (scalar). Subsequent arguments must correspond to the mean values of the posterior parameters specified by param_names, passed in the same order. It should return a single scalar log-likelihood value.

param_names: list, optional

Only used when method="plpd". List of parameter names to extract from the posterior. If None, all parameters are used.

log: bool, optional

Only used when method="plpd". Whether the log_lik_fn returns log-likelihood (True) or likelihood (False). Default is True.
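The contract for log_lik_fn described above can be sketched as follows: the posterior means of the parameters listed in param_names are computed once, and the function is then called with each observed value followed by those means, in the same order. The `posterior_draws` dictionary (draws along the first axis, scalar parameters only) and the helper name are hypothetical simplifications, not the arviz_stats data structures. When log=False, the returned likelihoods are converted to the log scale.

```python
import numpy as np
from scipy.stats import norm


def plpd_approximation(y, posterior_draws, log_lik_fn, param_names, log=True):
    """Hypothetical sketch: point log predictive density at the posterior means."""
    # posterior mean of each parameter, in the order given by param_names
    means = [np.mean(posterior_draws[name], axis=0) for name in param_names]
    # call log_lik_fn(observed value, *means) once per observation
    values = np.array([log_lik_fn(y_i, *means) for y_i in y], dtype=float)
    # convert likelihoods to log-likelihoods when log_lik_fn returns densities
    return values if log else np.log(values)


def normal_log_lik(y_i, mu, sigma):
    return norm.logpdf(y_i, loc=mu, scale=sigma)


def normal_lik(y_i, mu, sigma):
    return norm.pdf(y_i, loc=mu, scale=sigma)


# two hypothetical posterior draws: posterior means are mu = 1.0, sigma = 1.0
draws = {"mu": np.array([0.0, 2.0]), "sigma": np.array([0.5, 1.5])}
y = np.array([1.0, 2.0])
plpd_log = plpd_approximation(y, draws, normal_log_lik, ["mu", "sigma"])
plpd_lik = plpd_approximation(y, draws, normal_lik, ["mu", "sigma"], log=False)
```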

Returns:
ELPDData

Object with the following attributes:

  • elpd: approximated expected log pointwise predictive density (elpd)

  • se: standard error of the elpd (includes approximation and sampling uncertainty)

  • p: effective number of parameters

  • n_samples: number of samples in the posterior

  • n_data_points: total number of data points (N)

  • warning: True if the estimated shape parameter k of the Pareto distribution is > good_k for any observation in the subsample.

  • elpd_i: DataArray with pointwise elpd values (filled with NaNs for non-subsampled points), only if pointwise=True.

  • pareto_k: DataArray with Pareto shape values for the subsample (filled with NaNs for non-subsampled points), only if pointwise=True.

  • scale: scale of the elpd results (“log”, “negative_log”, or “deviance”).

  • good_k: Threshold for Pareto k warnings.

  • approx_posterior: True if approximate posterior was used.

  • subsampling_se: Standard error estimate from subsampling uncertainty only.

  • subsample_size: Number of observations in the subsample (m).

  • log_p: Log density of the target posterior.

  • log_q: Log density of the proposal posterior.

  • thin: Thinning factor for posterior draws.

  • log_weights: Smoothed log weights.
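The warning flag and the Pareto k table printed in the examples can be sketched from a NaN-filled pareto_k array. The threshold function assumes the sample-size dependent rule min(1 - 1/log10(S), 0.7) recommended in [2]; whether arviz_stats computes good_k exactly this way is an assumption, and the helper names and k values are hypothetical.

```python
import numpy as np


def pareto_k_threshold(n_samples):
    """Assumed sample-size dependent threshold: min(1 - 1/log10(S), 0.7)."""
    return min(1 - 1 / np.log10(n_samples), 0.7)


def summarize_pareto_k(pareto_k, good_k):
    """Hypothetical helper: count k values per diagnostic bin, ignoring NaNs."""
    # NaN entries correspond to observations outside the sub-sample
    k = pareto_k[~np.isnan(pareto_k)]
    counts = {
        "good": int(np.sum(k <= good_k)),
        "bad": int(np.sum((k > good_k) & (k <= 1))),
        "very bad": int(np.sum(k > 1)),
    }
    # warn if any sub-sampled observation exceeds the threshold
    warning = bool(np.any(k > good_k))
    return counts, warning
```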

See also

loo

Standard PSIS-LOO-CV.

loo_approximate_posterior

PSIS-LOO-CV for approximate posteriors.

compare

Compare models based on ELPD.

update_subsample

Update a previously computed sub-sampled LOO-CV.

References

[1]

Vehtari et al. Practical Bayesian model evaluation using leave-one-out cross-validation and WAIC. Statistics and Computing. 27(5) (2017) https://doi.org/10.1007/s11222-016-9696-4 arXiv preprint https://arxiv.org/abs/1507.04544.

[2]

Vehtari et al. Pareto Smoothed Importance Sampling. Journal of Machine Learning Research, 25(72) (2024) https://jmlr.org/papers/v25/19-556.html arXiv preprint https://arxiv.org/abs/1507.02646.

[3]

Magnusson, M., Riis Andersen, M., Jonasson, J., & Vehtari, A. Bayesian Leave-One-Out Cross-Validation for Large Data. Proceedings of the 36th International Conference on Machine Learning, PMLR 97:4244–4253 (2019) https://proceedings.mlr.press/v97/magnusson19a.html arXiv preprint https://arxiv.org/abs/1904.10679.

Examples

Calculate sub-sampled PSIS-LOO-CV using 4 random observations:

In [1]: from arviz_stats import loo_subsample
   ...: from arviz_base import load_arviz_data
   ...: data = load_arviz_data("centered_eight")
   ...: loo_results = loo_subsample(data, observations=4, var_name="obs", pointwise=True)
   ...: loo_results
   ...: 
Out[1]: 
Computed from 2000 by 4 subsampled log-likelihood
values from 8 total observations.

         Estimate   SE subsampling SE
elpd_loo     -30.8  1.5            0.3
p_loo          0.9

------

Pareto k diagnostic values:
                         Count   Pct.
(-Inf, 0.70]   (good)        4  100.0%
   (0.70, 1]   (bad)         0    0.0%
    (1, Inf)   (very bad)    0    0.0%

Return the pointwise values for the sub-sample:

In [2]: loo_results.elpd_i
Out[2]: 
<xarray.DataArray 'elpd_i' (school: 8)> Size: 64B
array([-4.89190585,         nan,         nan, -3.46496198, -3.4780878 ,
               nan,         nan, -3.95934834])
Coordinates:
  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
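Because non-subsampled points are NaN, the exactly evaluated observations can be recovered by masking. The sketch below uses the values printed in Out[2]; on the DataArray itself, `loo_results.elpd_i.dropna("school")` gives the same selection. Scaling the mean of these four values up to all 8 schools gives about -31.6 rather than the reported -30.8. A difference is expected, because the difference estimator also uses the approximation for every observation instead of extrapolating from the sub-sample alone.

```python
import numpy as np

# pointwise elpd values copied from Out[2]; NaN marks schools outside the sub-sample
elpd_i = np.array([-4.89190585, np.nan, np.nan, -3.46496198, -3.4780878,
                   np.nan, np.nan, -3.95934834])

subsampled = np.flatnonzero(~np.isnan(elpd_i))  # positions evaluated exactly
values = elpd_i[subsampled]

# naive extrapolation N * mean(elpd_j), for comparison with the reported estimate
naive_total = elpd_i.size * values.mean()
```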

We can also use the PLPD approximation with a custom log-likelihood function. The function must compute the log-likelihood of a single observation given the posterior mean values of the parameters. For the Eight Schools model, we define a function that evaluates a normal likelihood for each observation at the posterior mean of theta, using a fixed sigma for simplicity:

In [3]: import numpy as np
   ...: from arviz_stats import loo_subsample
   ...: from arviz_base import load_arviz_data
   ...: from scipy.stats import norm
   ...: data = load_arviz_data("centered_eight")
   ...: 
   ...: def log_lik_fn(y, theta):
   ...:     sigma = 12.5  # Using a fixed sigma for simplicity
   ...:     return norm.logpdf(y, loc=theta, scale=sigma)
   ...: 
   ...: loo_results = loo_subsample(
   ...:     data,
   ...:     observations=4,
   ...:     var_name="obs",
   ...:     method="plpd",
   ...:     log_lik_fn=log_lik_fn,
   ...:     param_names=["theta"],
   ...:     pointwise=True
   ...: )
   ...: loo_results
   ...: 
Out[3]: 
Computed from 2000 by 4 subsampled log-likelihood
values from 8 total observations.

         Estimate   SE subsampling SE
elpd_loo     -30.4  1.1            0.7
p_loo          0.0

------

Pareto k diagnostic values:
                         Count   Pct.
(-Inf, 0.70]   (good)        4  100.0%
   (0.70, 1]   (bad)         0    0.0%
    (1, Inf)   (very bad)    0    0.0%