arviz_stats.loo_kfold

arviz_stats.loo_kfold(data, wrapper, pointwise=None, var_name=None, k=10, folds=None, stratify_by=None, group_by=None, save_fits=False)

Perform exact K-fold cross-validation.

K-fold cross-validation evaluates model predictive accuracy by partitioning the data into K complementary subsets (folds), then iteratively refitting the model K times, each time holding out one fold as a test set and training on the remaining K-1 folds.

This method provides an unbiased estimate of model performance by ensuring each observation is used exactly once for testing. Unlike PSIS-LOO-CV (Pareto-smoothed importance sampling leave-one-out cross-validation), which approximates cross-validation efficiently, K-fold requires actual model refitting but yields exact results.
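
Schematically, the procedure looks like this (a minimal sketch; the refit and elpd_of callables are hypothetical stand-ins for the model refit and held-out log predictive density evaluation that the wrapper described below supplies):

import numpy as np

def kfold_elpd_sketch(y, k, refit, elpd_of, seed=0):
    # Assign each observation to one of k folds of approximately equal size
    rng = np.random.default_rng(seed)
    fold_of = rng.permutation(np.arange(len(y)) % k + 1)
    elpd_i = np.empty(len(y))
    for fold in range(1, k + 1):
        test = fold_of == fold
        fit = refit(y[~test])                 # train on the other k-1 folds
        elpd_i[test] = elpd_of(fit, y[test])  # score the held-out fold
    return elpd_i.sum(), elpd_i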

Parameters:
data : xarray.DataTree or InferenceData

Input data containing the posterior and log_likelihood groups from the full model fit.

wrapper : SamplingWrapper

An instance of the SamplingWrapper class that handles model refitting. The wrapper must implement the following methods: sel_observations, sample, get_inference_data, and log_likelihood__i.

pointwise : bool, optional

If True, return pointwise estimates. Defaults to rcParams["stats.ic_pointwise"].

var_name : str, optional

The name of the variable in log_likelihood group storing the pointwise log likelihood data to use for computation.

k : int, default=10

The number of folds for cross-validation. The data will be partitioned into k subsets of equal (or approximately equal) size.

folds : array or xarray.DataArray, optional

An optional integer array or DataArray with one element per observation in the data. Each element should be an integer from 1 to k indicating which fold the observation belongs to. For example, with k=4 and 8 observations, one possible assignment is [1,1,2,2,3,3,4,4], which puts the first two observations in fold 1, the next two in fold 2, and so on. If not provided, the data will be randomly partitioned into k folds of approximately equal size. DataArray inputs will be automatically flattened to 1D. (See the usage sketch after this parameter list.)

stratify_by : array or xarray.DataArray, optional

A categorical variable to use for stratified K-fold splitting. For example, with 8 observations where [0,0,1,1,0,0,1,1] indicates two categories (0 and 1), the algorithm ensures each fold contains approximately the same 50/50 split of 0s and 1s as the full dataset. Cannot be used together with folds or group_by. DataArray inputs will be automatically flattened to 1D.

group_by : array or xarray.DataArray, optional

A grouping variable to use for grouped K-fold splitting. For example, with [1,1,2,2,3,3,4,4] representing 4 subjects with 2 observations each, all observations from subject 1 will be placed in the same fold, all from subject 2 in the same fold, etc. This ensures related observations stay together. Cannot be used together with folds or stratify_by. DataArray inputs will be automatically flattened to 1D.

save_fits : bool, default=False

If True, store the fitted models and fold indices in the returned object.
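
As mentioned under folds, an explicit fold assignment can be passed directly. A minimal sketch, assuming data and wrapper constructed as in the Examples section below:

import numpy as np

# Two schools per fold, mirroring the [1,1,2,2,3,3,4,4] example above
folds = np.array([1, 1, 2, 2, 3, 3, 4, 4])
kfold_custom = loo_kfold(data, wrapper, k=4, folds=folds)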

Returns:
ELPDData

Object with the following attributes:

  • elpd: expected log pointwise predictive density

  • se: standard error of the elpd

  • p: effective number of parameters

  • n_samples: number of samples per fold

  • n_data_points: number of data points

  • warning: True if any issues occurred during fitting

  • elpd_i: pointwise predictive accuracy (if pointwise=True)

  • p_kfold_i: pointwise effective number of parameters (if pointwise=True)

  • pareto_k: None (not applicable for k-fold)

  • scale: “log”

Additional attributes when save_fits=True:

  • fold_fits: Dictionary containing fitted models for each fold

  • fold_indices: Dictionary containing test indices for each fold

See also

loo

Pareto-smoothed importance sampling LOO-CV

SamplingWrapper

Base class for implementing sampling wrappers

Notes

When K equals the number of observations, this becomes exact leave-one-out cross-validation. Note that arviz_stats.loo provides a much more efficient approximation for that case and is recommended.
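
For instance, with the 8 observations of the centered_eight example below, k=8 would refit once per school and reproduce exact LOO-CV. A minimal sketch, reusing data and wrapper from the Examples section:

from arviz_stats import loo, loo_kfold

psis_loo = loo(data)  # fast PSIS-LOO approximation, no refitting
# exact_loo = loo_kfold(data, wrapper, k=8)  # exact, but refits 8 times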

References

[1] Vehtari et al. Practical Bayesian model evaluation using leave-one-out cross-validation and WAIC. Statistics and Computing, 27(5) (2017). https://doi.org/10.1007/s11222-016-9696-4. arXiv preprint: https://arxiv.org/abs/1507.04544

[2] Vehtari et al. Pareto Smoothed Importance Sampling. Journal of Machine Learning Research, 25(72) (2024). https://jmlr.org/papers/v25/19-556.html. arXiv preprint: https://arxiv.org/abs/1507.02646

Examples

Unlike PSIS-LOO-CV, which approximates cross-validation without refitting, k-fold cross-validation refits the model k times, so we need to tell loo_kfold how to refit the model.

This is done by creating an instance of the SamplingWrapper class that implements four key methods: sel_observations, sample, get_inference_data, and log_likelihood__i.

In [1]: import numpy as np
   ...: import xarray as xr
   ...: from scipy import stats
   ...: from arviz_base import load_arviz_data, from_dict
   ...: from arviz_stats import loo_kfold
   ...: from arviz_stats.loo import SamplingWrapper
   ...: 
   ...: class CenteredEightWrapper(SamplingWrapper):
   ...:     def __init__(self, idata):
   ...:         super().__init__(model=None, idata_orig=idata)
   ...:         self.y_obs = idata.observed_data["obs"].values
   ...:         self.sigma = np.array([15, 10, 16, 11, 9, 11, 10, 18])
   ...: 
   ...:     # Split observations into a training set and the held-out fold
   ...:     def sel_observations(self, idx):
   ...:         all_idx = np.arange(len(self.y_obs))
   ...:         train_idx = np.setdiff1d(all_idx, idx)
   ...: 
   ...:         train_data = {
   ...:             "y": self.y_obs[train_idx],
   ...:             "sigma": self.sigma[train_idx],
   ...:             "indices": train_idx
   ...:         }
   ...:         test_data = {
   ...:             "y": self.y_obs[idx],
   ...:             "sigma": self.sigma[idx],
   ...:             "indices": idx
   ...:         }
   ...:         return train_data, test_data
   ...: 
   ...:     def sample(self, modified_observed_data):
   ...:         # Refit on the training data (simplified stand-in for the
   ...:         # actual sampler a real wrapper would call here)
   ...:         train_y = modified_observed_data["y"]
   ...:         n = 1000
   ...:         mu = np.random.normal(train_y.mean(), 5, n)
   ...:         tau = np.abs(np.random.normal(10, 2, n))
   ...:         return {"mu": mu, "tau": tau}
   ...: 
   ...:     # Package the fitted draws as InferenceData (1 chain, 1000 draws)
   ...:     def get_inference_data(self, fitted_model):
   ...:         posterior = {
   ...:             "mu": fitted_model["mu"].reshape(1, -1),
   ...:             "tau": fitted_model["tau"].reshape(1, -1)
   ...:         }
   ...:         return from_dict({"posterior": posterior})
   ...: 
   ...:     # Pointwise log likelihood of the held-out fold under the refit posterior
   ...:     def log_likelihood__i(self, excluded_obs, idata__i):
   ...:         test_y = excluded_obs["y"]
   ...:         test_sigma = excluded_obs["sigma"]
   ...:         mu = idata__i.posterior["mu"].values.flatten()
   ...:         tau = idata__i.posterior["tau"].values.flatten()
   ...: 
   ...:         var_total = tau[:, np.newaxis] ** 2 + test_sigma**2
   ...:         log_lik = stats.norm.logpdf(
   ...:             test_y, loc=mu[:, np.newaxis], scale=np.sqrt(var_total)
   ...:         )
   ...: 
   ...:         dims = ["chain", "school", "draw"]
   ...:         coords = {"school": excluded_obs["indices"]}
   ...:         return xr.DataArray(
   ...:             log_lik.T[np.newaxis, :, :], dims=dims, coords=coords
   ...:         )
   ...: 

Now let’s run k-fold cross-validation. With k=4, we’ll refit the model 4 times, each time leaving out 2 schools for testing:

In [2]: data = load_arviz_data("centered_eight")
   ...: wrapper = CenteredEightWrapper(data)
   ...: kfold_results = loo_kfold(data, wrapper, k=4, pointwise=True)
   ...: kfold_results
   ...: 
Out[2]: 
Computed from 2000 posterior samples and 8 observations log-likelihood matrix.

               Estimate       SE
elpd_loo_kfold   -31.58     0.81
p_loo_kfold        1.75        -

Sometimes we want more control over how the data are split. For instance, with imbalanced categories, stratified k-fold ensures each fold preserves a similar category distribution:

In [3]: strata = (data.observed_data["obs"] > 5).astype(int)
   ...: kfold_strat = loo_kfold(data, wrapper, k=4, stratify_by=strata)
   ...: kfold_strat
   ...: 
Out[3]: 
Computed from 2000 posterior samples and 8 observations log-likelihood matrix.

               Estimate       SE
elpd_loo_kfold   -31.44     0.74
p_loo_kfold        1.61        -

Other times we want to keep related observations together. For instance, with repeated measurements from the same subject, grouped k-fold places all of a subject's observations in the same fold:

In [4]: groups = xr.DataArray([1, 1, 2, 2, 3, 3, 4, 4], dims="school")
   ...: kfold_group = loo_kfold(data, wrapper, k=4, group_by=groups)
   ...: kfold_group
   ...: 
Out[4]: 
Computed from 2000 posterior samples and 8 observations log-likelihood matrix.

               Estimate       SE
elpd_loo_kfold   -31.83     0.75
p_loo_kfold        1.99        -
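
The pointwise values and saved fits described under Returns can then be inspected directly. A minimal sketch, using the attribute names listed above:

# Pointwise elpd contributions (available because pointwise=True was passed)
kfold_results.elpd_i

# Retain the per-fold fits and test indices for later inspection
kfold_saved = loo_kfold(data, wrapper, k=4, save_fits=True)
kfold_saved.fold_fits     # fitted model for each fold
kfold_saved.fold_indices  # test indices for each fold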