"""Compute PSIS-LOO-CV using sub-sampling."""
import numpy as np
import xarray as xr
from arviz_base import rcParams
from xarray_einstats.stats import logsumexp
from arviz_stats.loo.helper_loo import (
_compute_loo_results,
_diff_srs_estimator,
_get_r_eff,
_prepare_full_arrays,
_prepare_loo_inputs,
_prepare_subsample,
_prepare_update_subsample,
_select_obs_by_coords,
_select_obs_by_indices,
_srs_estimator,
_warn_pareto_k,
)
from arviz_stats.loo.loo_approximate_posterior import loo_approximate_posterior
from arviz_stats.utils import ELPDData
[docs]
def loo_subsample(
data,
observations,
pointwise=None,
var_name=None,
reff=None,
log_weights=None,
log_p=None,
log_q=None,
seed=315,
method="lpd",
thin=None,
log_lik_fn=None,
param_names=None,
log=True,
):
"""Compute PSIS-LOO-CV using sub-sampling.
Estimates the expected log pointwise predictive density (elpd) using Pareto smoothed
importance sampling leave-one-out cross-validation (PSIS-LOO-CV) with sub-sampling
for large datasets. Uses either log predictive density (LPD) or point log predictive
density (PLPD) approximation and applies a difference estimator based on a simple random
sample without replacement.
The PSIS-LOO-CV method is described in [1]_, [2]_. The sub-sampling
method is described in [3]_.
Parameters
----------
data : DataTree or InferenceData
Input data. It should contain the posterior and the log_likelihood groups.
observations : int or ndarray
The sub-sample observations to use:
- An integer specifying the number of observations to randomly sub-sample without
replacement.
- An array of integer indices specifying the exact observations to use.
pointwise: bool, optional
If True the pointwise predictive accuracy will be returned. Defaults to
``rcParams["stats.ic_pointwise"]``.
var_name : str, optional
The name of the variable in log_likelihood groups storing the pointwise log
likelihood data to use for loo computation.
reff: float, optional
Relative MCMC efficiency, ``ess / n`` i.e. number of effective samples divided by the number
of actual samples. Computed from trace by default.
log_weights : DataArray or ELPDData, optional
Smoothed log weights. Can be either:
- A DataArray with the same shape as the log likelihood data
- An ELPDData object from a previous :func:`arviz_stats.loo` call.
Defaults to None. If not provided, it will be computed using the PSIS-LOO method.
log_p : ndarray or DataArray, optional
The (target) log-density evaluated at samples from the target distribution (p).
If provided along with ``log_q``, approximate posterior correction will be applied.
log_q : ndarray or DataArray, optional
The (proposal) log-density evaluated at samples from the proposal distribution (q).
If provided along with ``log_p``, approximate posterior correction will be applied.
seed: int, optional
Seed for random sampling.
method: str, optional
Method used for approximating the pointwise log predictive density:
- 'lpd': Use standard log predictive density approximation (default)
- 'plpd': Use Point Log Predictive Density approximation which requires a ``log_lik_fn``.
thin: int, optional
Thinning factor for posterior draws. If specified, the posterior will be thinned
by this factor to reduce computation time. For example, using thin=2 will use
every 2nd draw. If None (default), all posterior draws are used. This value is stored
in the returned ELPDData object and will be automatically used by ``update_subsample``.
log_lik_fn : callable, optional
A function that computes the log-likelihood for a single observation given the
mean values of posterior parameters. Required only when ``method="plpd"``.
The function must accept the observed data value for a single point as its
first argument (scalar). Subsequent arguments must correspond to the mean
values of the posterior parameters specified by ``param_names``, passed in the
same order. It should return a single scalar log-likelihood value.
param_names : list, optional
Only used when ``method="plpd"``. List of parameter names to extract from
the posterior. If None, all parameters are used.
log: bool, optional
Only used when ``method="plpd"``. Whether the ``log_lik_fn`` returns
log-likelihood (True) or likelihood (False). Default is True.
Returns
-------
ELPDData
Object with the following attributes:
- **elpd**: approximated expected log pointwise predictive density (elpd)
- **se**: standard error of the elpd (includes approximation and sampling uncertainty)
- **p**: effective number of parameters
- **n_samples**: number of samples in the posterior
- **n_data_points**: total number of data points (N)
- **warning**: True if the estimated shape parameter k of the Pareto distribution
is > ``good_k`` for any observation in the subsample.
- **elpd_i**: :class:`~xarray.DataArray` with pointwise elpd values (filled with NaNs
for non-subsampled points), only if ``pointwise=True``.
- **pareto_k**: :class:`~xarray.DataArray` with Pareto shape values for the subsample
(filled with NaNs for non-subsampled points), only if ``pointwise=True``.
- **scale**: scale of the elpd results ("log", "negative_log", or "deviance").
- **good_k**: Threshold for Pareto k warnings.
- **approx_posterior**: True if approximate posterior was used.
- **subsampling_se**: Standard error estimate from subsampling uncertainty only.
- **subsample_size**: Number of observations in the subsample (m).
- **log_p**: Log density of the target posterior.
- **log_q**: Log density of the proposal posterior.
- **thin**: Thinning factor for posterior draws.
- **log_weights**: Smoothed log weights.
Examples
--------
Calculate sub-sampled PSIS-LOO-CV using 4 random observations:
.. ipython::
In [1]: from arviz_stats import loo_subsample
...: from arviz_base import load_arviz_data
...: data = load_arviz_data("centered_eight")
...: loo_results = loo_subsample(data, observations=4, var_name="obs", pointwise=True)
...: loo_results
Return the pointwise values for the sub-sample:
.. ipython::
In [2]: loo_results.elpd_i
We can also use the PLPD approximation method with a custom log-likelihood function.
We need to define a function that computes the log-likelihood for a single observation
given the mean values of posterior parameters. For the Eight Schools model, we define a
function that computes the likelihood for each observation using the *global mean* of the
parameters (e.g., the overall mean `theta`):
.. ipython::
In [1]: import numpy as np
...: from arviz_stats import loo_subsample
...: from arviz_base import load_arviz_data
...: from scipy.stats import norm
...: data = load_arviz_data("centered_eight")
...:
...: def log_lik_fn(y, theta):
...: sigma = 12.5 # Using a fixed sigma for simplicity
...: return norm.logpdf(y, loc=theta, scale=sigma)
...:
...: loo_results = loo_subsample(
...: data,
...: observations=4,
...: var_name="obs",
...: method="plpd",
...: log_lik_fn=log_lik_fn,
...: param_names=["theta"],
...: pointwise=True
...: )
...: loo_results
See Also
--------
loo : Standard PSIS-LOO-CV.
loo_approximate_posterior : PSIS-LOO-CV for approximate posteriors.
compare : Compare models based on ELPD.
update_subsample : Update a previously computed sub-sampled LOO-CV.
References
----------
.. [1] Vehtari et al. *Practical Bayesian model evaluation using leave-one-out cross-validation
and WAIC*. Statistics and Computing. 27(5) (2017) https://doi.org/10.1007/s11222-016-9696-4
arXiv preprint https://arxiv.org/abs/1507.04544.
.. [2] Vehtari et al. *Pareto Smoothed Importance Sampling*.
Journal of Machine Learning Research, 25(72) (2024) https://jmlr.org/papers/v25/19-556.html
arXiv preprint https://arxiv.org/abs/1507.02646
.. [3] Magnusson, M., Riis Andersen, M., Jonasson, J., & Vehtari, A. *Bayesian Leave-One-Out
Cross-Validation for Large Data.* Proceedings of the 36th International Conference on
Machine Learning, PMLR 97:4244–4253 (2019)
https://proceedings.mlr.press/v97/magnusson19a.html
arXiv preprint https://arxiv.org/abs/1904.10679
"""
loo_inputs = _prepare_loo_inputs(data, var_name, thin)
pointwise = rcParams["stats.ic_pointwise"] if pointwise is None else pointwise
if method not in ["lpd", "plpd"]:
raise ValueError("Method must be either 'lpd' or 'plpd'")
if method == "plpd" and log_lik_fn is None:
raise ValueError("log_lik_fn must be provided when method='plpd'")
log_likelihood = loo_inputs.log_likelihood
if reff is None:
reff = _get_r_eff(data, loo_inputs.n_samples)
subsample_data = _prepare_subsample(
data,
log_likelihood,
loo_inputs.var_name,
observations,
seed,
method,
log_lik_fn,
param_names,
log,
loo_inputs.obs_dims,
loo_inputs.sample_dims,
loo_inputs.n_data_points,
loo_inputs.n_samples,
)
sample_ds = xr.Dataset({loo_inputs.var_name: subsample_data.log_likelihood_sample})
if log_p is not None and log_q is not None:
sample_data = xr.DataTree()
sample_data["log_likelihood"] = sample_ds
loo_approx = loo_approximate_posterior(
sample_data,
log_p,
log_q,
True,
loo_inputs.var_name,
)
elpd_loo_i = loo_approx.elpd_i
pareto_k_sample_da = loo_approx.pareto_k
approx_posterior = True
else:
if log_weights is not None:
if isinstance(log_weights, ELPDData):
if log_weights.log_weights is None:
raise ValueError("ELPDData object does not contain log_weights")
log_weights = log_weights.log_weights
if loo_inputs.var_name in log_weights:
log_weights = log_weights[loo_inputs.var_name]
if len(loo_inputs.obs_dims) > 1:
stacked_obs_dim = "__obs__"
log_weights_stacked = log_weights.stack({stacked_obs_dim: loo_inputs.obs_dims})
log_weights_sample = _select_obs_by_indices(
log_weights_stacked, subsample_data.indices, [stacked_obs_dim], stacked_obs_dim
)
log_weights_sample = log_weights_sample.unstack(stacked_obs_dim)
else:
obs_dim = loo_inputs.obs_dims[0]
log_weights_sample = _select_obs_by_indices(
log_weights, subsample_data.indices, loo_inputs.obs_dims, obs_dim
)
log_weights_sample_ds = xr.Dataset({loo_inputs.var_name: log_weights_sample})
_, pareto_k_ds = sample_ds.azstats.psislw(r_eff=reff, dim=loo_inputs.sample_dims)
log_weights_ds = log_weights_sample_ds + sample_ds
else:
log_weights_ds, pareto_k_ds = sample_ds.azstats.psislw(
r_eff=reff, dim=loo_inputs.sample_dims
)
log_weights_sample = log_weights_ds[loo_inputs.var_name]
log_weights_ds += sample_ds
elpd_loo_i = logsumexp(log_weights_ds, dims=loo_inputs.sample_dims)[loo_inputs.var_name]
pareto_k_sample_da = pareto_k_ds[loo_inputs.var_name]
approx_posterior = False
warn_mg, good_k = _warn_pareto_k(pareto_k_sample_da, loo_inputs.n_samples)
elpd_loo_hat, subsampling_se, se = _diff_srs_estimator(
elpd_loo_i,
subsample_data.lpd_approx_sample,
subsample_data.lpd_approx_all,
loo_inputs.n_data_points,
subsample_data.subsample_size,
)
# Calculate p_loo using SRS estimation directly on the p_loo values
# from the subsample
p_loo_sample = subsample_data.lpd_approx_sample - elpd_loo_i
p_loo, _, _ = _srs_estimator(
p_loo_sample,
loo_inputs.n_data_points,
subsample_data.subsample_size,
)
if not pointwise:
stored_log_weights = log_weights_sample if "log_weights_sample" in locals() else None
return ELPDData(
"loo",
elpd_loo_hat,
se,
p_loo,
loo_inputs.n_samples,
loo_inputs.n_data_points,
"log",
warn_mg,
good_k,
None,
None,
approx_posterior,
subsampling_se,
subsample_data.subsample_size,
log_p,
log_q,
thin,
stored_log_weights,
)
elpd_i_full, pareto_k_full = _prepare_full_arrays(
elpd_loo_i,
pareto_k_sample_da,
subsample_data.lpd_approx_all,
subsample_data.indices,
loo_inputs.obs_dims,
elpd_loo_hat,
)
if "log_weights_sample" in locals() and log_weights_sample is not None:
log_weights_full = xr.Dataset({loo_inputs.var_name: log_weights_sample})
else:
log_weights_full = None
return ELPDData(
"loo",
elpd_loo_hat,
se,
p_loo,
loo_inputs.n_samples,
loo_inputs.n_data_points,
"log",
warn_mg,
good_k,
elpd_i_full,
pareto_k_full,
approx_posterior,
subsampling_se,
subsample_data.subsample_size,
log_p,
log_q,
thin,
log_weights_full,
)
[docs]
def update_subsample(
loo_orig,
data,
observations=None,
var_name=None,
reff=None,
log_weights=None,
seed=315,
method="lpd",
log_lik_fn=None,
param_names=None,
log=True,
):
"""Update a sub-sampled PSIS-LOO-CV object with new observations.
Extends a sub-sampled PSIS-LOO-CV result by adding new observations to the sub-sample
without recomputing values for previously sampled observations. This allows for
incrementally improving the sub-sampled PSIS-LOO-CV estimate with additional observations.
The sub-sampling method is described in [1]_.
Parameters
----------
loo_orig : ELPDData
Original PSIS-LOO-CV result created with ``loo_subsample`` with ``pointwise=True``.
data : DataTree or InferenceData
Input data. It should contain the posterior and the log_likelihood groups.
observations : int or ndarray, optional
The additional observations to use:
- An integer specifying the number of new observations to randomly sub-sample
without replacement.
- An array of integer indices specifying the exact new observations to use.
- If None or 0, returns the original PSIS-LOO-CV result unchanged.
var_name : str, optional
The name of the variable in log_likelihood groups storing the pointwise log
likelihood data to use for loo computation.
reff : float, optional
Relative MCMC efficiency, ``ess / n`` i.e. number of effective samples divided by the number
of actual samples. Computed from trace by default.
log_weights : DataArray or ELPDData, optional
Smoothed log weights. Can be either:
- A DataArray with the same shape as the log likelihood data
- An ELPDData object from a previous :func:`arviz_stats.loo` call.
Defaults to None. If not provided, it will be computed using the PSIS-LOO method.
seed : int, optional
Seed for random sampling.
method: str, optional
Method used for approximating the pointwise log predictive density:
- 'lpd': Use standard log predictive density approximation (default)
- 'plpd': Use Point Log Predictive Density approximation which requires a ``log_lik_fn``.
log_lik_fn : callable, optional
A function that computes the log-likelihood for a single observation given the
mean values of posterior parameters. Required only when ``method="plpd"``.
The function must accept the observed data value for a single point as its
first argument (scalar). Subsequent arguments must correspond to the mean
values of the posterior parameters specified by ``param_names``, passed in the
same order. It should return a single scalar log-likelihood value.
param_names: list, optional
Only used when ``method="plpd"``. List of parameter names to extract from
the posterior. If None, all parameters are used.
log: bool, optional
Only used when ``method="plpd"``. Whether the ``log_lik_fn`` returns
log-likelihood (True) or likelihood (False). Default is True.
Returns
-------
ELPDData
Object with the following attributes:
- **elpd**: updated approximated expected log pointwise predictive density (elpd)
- **se**: standard error of the elpd (includes approximation and sampling uncertainty)
- **p**: effective number of parameters
- **n_samples**: number of samples in the posterior
- **n_data_points**: total number of data points (N)
- **warning**: True if the estimated shape parameter k of the Pareto distribution
is > ``good_k`` for any observation in the subsample.
- **elpd_i**: :class:`~xarray.DataArray` with pointwise elpd values (filled with NaNs
for non-subsampled points), only if ``pointwise=True``.
- **pareto_k**: :class:`~xarray.DataArray` with Pareto shape values for the subsample
(filled with NaNs for non-subsampled points), only if ``pointwise=True``.
- **scale**: scale of the elpd results ("log", "negative_log", or "deviance").
- **good_k**: Threshold for Pareto k warnings.
- **approx_posterior**: True if approximate posterior was used.
- **subsampling_se**: Standard error estimate from subsampling uncertainty only.
- **subsample_size**: Number of observations in the subsample (original + new).
- **log_p**: Log density of the target posterior.
- **log_q**: Log density of the proposal posterior.
- **thin**: Thinning factor for posterior draws.
- **log_weights**: Smoothed log weights.
Examples
--------
Calculate initial sub-sampled PSIS-LOO-CV using 4 observations, then update with 4 more:
.. ipython::
:okwarning:
In [1]: from arviz_stats import loo_subsample, update_subsample
...: from arviz_base import load_arviz_data
...: data = load_arviz_data("non_centered_eight")
...: initial_loo = loo_subsample(data, observations=4, var_name="obs", pointwise=True)
...: updated_loo = update_subsample(initial_loo, data, observations=2)
...: updated_loo
See Also
--------
loo : Exact PSIS-LOO cross-validation.
loo_subsample : PSIS-LOO-CV with subsampling.
compare : Compare models based on ELPD.
References
----------
.. [1] Magnusson, M., Riis Andersen, M., Jonasson, J., & Vehtari, A. *Bayesian Leave-One-Out
Cross-Validation for Large Data.* Proceedings of the 36th International Conference on
Machine Learning, PMLR 97:4244–4253 (2019)
https://proceedings.mlr.press/v97/magnusson19a.html
arXiv preprint https://arxiv.org/abs/1904.10679
"""
if observations is None or (isinstance(observations, int) and observations == 0):
return loo_orig
if loo_orig.elpd_i is None:
raise ValueError("Original loo_subsample result must have pointwise=True")
if method not in ["lpd", "plpd"]:
raise ValueError("Method must be either 'lpd' or 'plpd'")
if method == "plpd" and log_lik_fn is None:
raise ValueError("log_lik_fn must be provided when method='plpd'")
thin = getattr(loo_orig, "thin_factor", None)
loo_inputs = _prepare_loo_inputs(data, var_name, thin)
update_data = _prepare_update_subsample(
loo_orig, data, observations, var_name, seed, method, log_lik_fn, param_names, log
)
if reff is None:
reff = _get_r_eff(data, loo_inputs.n_samples)
# Get log densities from original ELPD data if they exist
log_p = getattr(loo_orig, "log_p", None)
log_q = getattr(loo_orig, "log_q", None)
log_weights_new = None
if log_weights is None:
log_weights = getattr(loo_orig, "log_weights", None)
if log_weights is not None:
if isinstance(log_weights, ELPDData):
if log_weights.log_weights is None:
raise ValueError("ELPDData object does not contain log_weights")
log_weights = log_weights.log_weights
if loo_inputs.var_name in log_weights:
log_weights = log_weights[loo_inputs.var_name]
if len(loo_inputs.obs_dims) > 1:
stacked_obs_dim = "__obs__"
log_weights_stacked = log_weights.stack({stacked_obs_dim: loo_inputs.obs_dims})
log_weights_new = _select_obs_by_indices(
log_weights_stacked, update_data.new_indices, [stacked_obs_dim], stacked_obs_dim
)
log_weights_new = log_weights_new.unstack(stacked_obs_dim)
else:
obs_dim = loo_inputs.obs_dims[0]
log_weights_new = _select_obs_by_indices(
log_weights, update_data.new_indices, loo_inputs.obs_dims, obs_dim
)
if log_weights_new is None:
log_weights_new_ds, _ = update_data.log_likelihood_new.azstats.psislw(
r_eff=reff, dim=loo_inputs.sample_dims
)
log_weights_new = log_weights_new_ds[loo_inputs.var_name]
elpd_loo_i_new_da, pareto_k_new_da, approx_posterior = _compute_loo_results(
log_likelihood=update_data.log_likelihood_new,
var_name=loo_inputs.var_name,
sample_dims=loo_inputs.sample_dims,
n_samples=loo_inputs.n_samples,
n_data_points=len(update_data.new_indices),
log_weights=log_weights_new,
reff=reff,
log_p=log_p,
log_q=log_q,
return_pointwise=True,
)
combined_elpd_i_da = xr.concat(
[update_data.old_elpd_i, elpd_loo_i_new_da], dim=update_data.concat_dim
)
combined_pareto_k_da = xr.concat(
[update_data.old_pareto_k, pareto_k_new_da], dim=update_data.concat_dim
)
good_k = loo_orig.good_k
warn_mg, _ = _warn_pareto_k(combined_pareto_k_da, loo_inputs.n_samples)
lpd_approx_sample_da = _select_obs_by_coords(
update_data.lpd_approx_all, combined_elpd_i_da, loo_inputs.obs_dims, "__obs__"
)
elpd_loo_hat, subsampling_se, se = _diff_srs_estimator(
combined_elpd_i_da,
lpd_approx_sample_da,
update_data.lpd_approx_all,
loo_inputs.n_data_points,
update_data.combined_size,
)
# Calculate p_loo using SRS estimation directly on the p_loo values
# from the subsample
p_loo_sample = lpd_approx_sample_da - combined_elpd_i_da
p_loo, _, _ = _srs_estimator(
p_loo_sample,
loo_inputs.n_data_points,
update_data.combined_size,
)
combined_indices = np.concatenate((update_data.old_indices, update_data.new_indices))
elpd_i_full, pareto_k_full = _prepare_full_arrays(
combined_elpd_i_da,
combined_pareto_k_da,
update_data.lpd_approx_all,
combined_indices,
loo_inputs.obs_dims,
elpd_loo_hat,
)
if loo_orig.log_weights is not None and log_weights_new is not None:
old_log_weights = loo_orig.log_weights
if isinstance(old_log_weights, xr.Dataset):
old_log_weights = old_log_weights[loo_inputs.var_name]
if isinstance(log_weights_new, xr.Dataset):
log_weights_new = log_weights_new[loo_inputs.var_name]
combined_log_weights = xr.concat(
[old_log_weights, log_weights_new], dim=update_data.concat_dim
)
log_weights_full = xr.Dataset({loo_inputs.var_name: combined_log_weights})
else:
log_weights_full = None
return ELPDData(
"loo",
elpd_loo_hat,
se,
p_loo,
loo_inputs.n_samples,
loo_inputs.n_data_points,
"log",
warn_mg,
good_k,
elpd_i_full,
pareto_k_full,
approx_posterior,
subsampling_se,
update_data.combined_size,
log_p,
log_q,
thin,
log_weights_full,
)