Inference

Functions for sampling the inference model, checking for convergence, and processing the results.

get_trace

Sample the prior and posterior distributions for a pymc.model.core.Model returned by build_model() in stratmc.model.

extend_age_model

Extend age models to a different set of proxy observations using posterior sample ages from an existing inference.

interpolate_proxy

Use interpolated sample ages from extend_age_model() to calculate proxy values at a given set of ages (e.g., to plot 68% and 95% confidence intervals over time for a new proxy using interpolated_proxy_inference() in stratmc.plotting).

age_range_to_height

Use the posterior age model for each section to find the stratigraphic interval (with uncertainty) corresponding to a given age range.

map_ages_to_section

Helper function for section_proxy_signal() in stratmc.plotting.

count_samples

Helper function for proxy_data_density() in stratmc.plotting.

find_gaps

Helper function for proxy_data_gaps() in stratmc.plotting.

calculate_lengthscale_stability

Compute the posterior standard deviation of the pymc.gp.cov.ExpQuad covariance kernel lengthscale as additional chains are considered (i.e., for 1 to N chains).

calculate_proxy_signal_stability

Compute the residuals between the median inferred proxy signal when all N chains are considered compared to when 1 to N-1 chains are considered.

stratmc.inference.age_range_to_height(full_trace, sample_df, ages_df, lower_age, upper_age, **kwargs)

Use the posterior age model for each section to find the stratigraphic interval (with uncertainty) corresponding to a given age range. If sections is not provided, returns height estimates for every section that overlaps the target age range. To visualize the stratigraphic intervals, use section_age_range() in stratmc.plotting.

Parameters:
full_trace: arviz.InferenceData

An arviz.InferenceData object containing the full set of prior and posterior samples from get_trace() in stratmc.inference.

sample_df: pandas.DataFrame

pandas.DataFrame containing proxy data for all sections.

ages_df: pandas.DataFrame

pandas.DataFrame containing age constraints for all sections.

lower_age: float

Lower bound (youngest age) of the target age interval.

upper_age: float

Upper bound (oldest age) of the target age interval.

sections: list(str) or numpy.array(str), optional

List of sections included in the original inference; only required if not all sections in sample_df were included.

Returns:
height_range_df: pandas.DataFrame

Summary statistics for the base and top height of the target age interval (maximum likelihood estimate, median, and 68% and 95% confidence intervals) for each section.
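Conceptually, this lookup inverts each posterior age model. The sketch below is a minimal illustration for a single section, not the stratmc implementation. It assumes the section's posterior sample ages are available as a (draws x samples) NumPy array ordered from the base of the section upward, with ages decreasing up-section (larger values are older), and it assumes linear interpolation between samples. The helper names height_range_sketch and summarize are hypothetical, and the maximum likelihood estimate is omitted.

```python
import numpy as np


def height_range_sketch(heights, age_draws, lower_age, upper_age):
    """Per-draw base and top heights of an age interval in one section.

    heights: (n_samples,) heights, sorted from base to top of the section.
    age_draws: (n_draws, n_samples) posterior sample ages; within a draw,
        ages decrease up-section (larger values are older).
    """
    # np.interp needs increasing x values, so flip the up-section order.
    h = np.asarray(heights, dtype=float)[::-1]
    base, top = [], []
    for ages in np.asarray(age_draws, dtype=float):
        a = ages[::-1]
        # upper_age (oldest) marks the base; lower_age (youngest) the top.
        base.append(np.interp(upper_age, a, h))
        top.append(np.interp(lower_age, a, h))
    return np.array(base), np.array(top)


def summarize(values):
    """Median and 68%/95% central intervals across posterior draws."""
    p2_5, p16, p50, p84, p97_5 = np.percentile(values, [2.5, 16, 50, 84, 97.5])
    return {"median": p50, "68%": (p16, p84), "95%": (p2_5, p97_5)}


base, top = height_range_sketch(
    heights=[0.0, 10.0, 20.0],
    age_draws=[[100.0, 90.0, 80.0], [102.0, 92.0, 82.0]],
    lower_age=85.0,
    upper_age=95.0,
)
print(summarize(base), summarize(top))
```

Because np.interp clamps to its end points, an age outside a draw's sampled range maps to the lowest or highest sample height in this sketch; a fuller treatment would flag those draws instead.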

stratmc.inference.calculate_lengthscale_stability(full_trace, **kwargs)

Compute the posterior standard deviation of the pymc.gp.cov.ExpQuad covariance kernel lengthscale as additional chains are considered (i.e., for 1 to N chains). When the posterior has been sufficiently explored, the standard deviation will stabilize; if it has not stabilized, then additional chains should be run. Helper function for lengthscale_stability() in stratmc.plotting.

To consider chains from multiple traces associated with the same inference model, first combine the traces (saved as NetCDF files) using combine_traces() in stratmc.data.

Parameters:
full_trace: arviz.InferenceData

An arviz.InferenceData object containing the full set of prior and posterior samples from get_trace() in stratmc.inference.

proxy: str, optional

Name of the proxy; only required if multiple proxies were included in the inference model.

Returns:
gp_ls_std: np.array of float

Standard deviation of the covariance kernel lengthscale posterior; each entry corresponds to the number of chains considered (the first entry uses 1 chain and the last entry uses all N chains).
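As a rough illustration of what this metric measures, the sketch below assumes the lengthscale draws have already been extracted from the trace into a (chains x draws) NumPy array and that chains are added in index order. The function name is hypothetical, and stratmc's own extraction and chain ordering may differ.

```python
import numpy as np


def lengthscale_std_by_chains(ls_draws):
    """Std of lengthscale draws pooled over the first 1, 2, ..., N chains.

    ls_draws: (n_chains, n_draws) posterior lengthscale samples.
    """
    ls_draws = np.asarray(ls_draws, dtype=float)
    n_chains = ls_draws.shape[0]
    # Entry k pools the draws from chains 0..k before taking the std.
    return np.array([ls_draws[: k + 1].std() for k in range(n_chains)])


rng = np.random.default_rng(0)
draws = rng.normal(loc=5.0, scale=1.0, size=(8, 1000))
print(lengthscale_std_by_chains(draws))
```

If the entries keep drifting as chains are added, that is the signal described above that more chains are needed.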

stratmc.inference.calculate_proxy_signal_stability(full_trace, **kwargs)

Compute the residuals between the median inferred proxy signal when all N chains are considered compared to when 1 to N-1 chains are considered. When the posterior has been sufficiently explored, the residuals will stabilize and approach zero; if they have not stabilized, then additional chains should be run. Helper function for proxy_signal_stability() in stratmc.plotting.

To consider chains from multiple traces associated with the same inference model, first combine the traces (saved as NetCDF files) using combine_traces() in stratmc.data.

Parameters:
full_trace: arviz.InferenceData

An arviz.InferenceData object containing the full set of prior and posterior samples from get_trace() in stratmc.inference.

proxy: str, optional

Name of the proxy; only required if multiple proxies were included in the inference model.

Returns:
median_residuals: np.array of float

Residuals between the median inferred proxy value (at each age in the ages array passed to get_trace()) calculated using 1 to N-1 chains versus all N chains. Shape is (chains x ages).
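The residual calculation can be sketched as follows, assuming the posterior proxy signal has been extracted into a (chains x draws x ages) NumPy array and that chains are added in index order. This is a hypothetical helper, not the stratmc source; it returns one row for each subset of 1 to N-1 chains.

```python
import numpy as np


def median_residuals_by_chains(signal_draws):
    """Median signal from the first k chains minus the all-chain median.

    signal_draws: (n_chains, n_draws, n_ages) posterior proxy signal draws.
    Returns an (n_chains - 1, n_ages) array, one row per k = 1..N-1.
    """
    signal_draws = np.asarray(signal_draws, dtype=float)
    n_chains, _, n_ages = signal_draws.shape
    # Pool chains and draws, then take the median at each age.
    full_median = np.median(signal_draws.reshape(-1, n_ages), axis=0)
    rows = [
        np.median(signal_draws[:k].reshape(-1, n_ages), axis=0) - full_median
        for k in range(1, n_chains)
    ]
    return np.array(rows)


residuals = median_residuals_by_chains([[[0.0, 0.0]], [[2.0, 2.0]], [[4.0, 4.0]]])
print(residuals)
```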

stratmc.inference.count_samples(full_trace, time_grid=None)

Helper function for proxy_data_density() in stratmc.plotting. Counts the number of observations within discrete time bins (based on the posterior sample ages).

Parameters:
full_trace: arviz.InferenceData

An arviz.InferenceData object containing the full set of prior and posterior samples from get_trace() in stratmc.inference.

time_grid: np.array, optional

Time bin edges; if not provided, defaults to the ages array passed to get_trace().

Returns:
sample_counts: np.array

Number of observations in each time bin, summed over all posterior draws; the average number of observations per draw is sample_counts/n.

grid_centers: np.array

Time bin centers.

grid_widths: np.array

Time bin widths.

n: int

Number of posterior draws in full_trace.
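A minimal sketch of the binning, assuming the posterior sample ages are available as a (draws x samples) NumPy array and time_grid holds the bin edges. Each draw is histogrammed separately and the counts are summed, which is why dividing by n gives the per-draw average. The function name is hypothetical; this is not the stratmc source.

```python
import numpy as np


def count_samples_sketch(age_draws, time_grid):
    """Observation counts per time bin, summed over posterior draws.

    age_draws: (n_draws, n_samples) posterior sample ages.
    time_grid: 1-D array of bin edges.
    """
    age_draws = np.asarray(age_draws, dtype=float)
    edges = np.sort(np.asarray(time_grid, dtype=float))
    counts = np.zeros(len(edges) - 1, dtype=int)
    for ages in age_draws:
        # Histogram each draw separately so every draw contributes equally.
        counts += np.histogram(ages, bins=edges)[0]
    centers = (edges[:-1] + edges[1:]) / 2
    widths = np.diff(edges)
    return counts, centers, widths, age_draws.shape[0]


counts, centers, widths, n = count_samples_sketch(
    [[1.0, 2.0, 6.0], [1.5, 7.0, 8.0]], time_grid=[0.0, 5.0, 10.0]
)
print(counts / n)  # average observations per bin per draw
```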

stratmc.inference.extend_age_model(full_trace, sample_df, ages_df, new_proxies, new_proxy_df=None, **kwargs)

Extend age models to a different set of proxy observations using posterior sample ages from an existing inference. For instance, extend an age model built using C isotope data to new S isotope data collected from different stratigraphic horizons in the same sections. Note that the ages of stratigraphic horizons that were included in sample_df (but marked Exclude? = True) are already passively tracked within the model; this function is only required to estimate the ages of observations that were not in sample_df when the inference was run. To estimate ages for new measurements of the same proxy, place the new data in a different column (e.g., ‘d13c_new’).

Parameters:
full_trace: arviz.InferenceData

An arviz.InferenceData object containing the full set of prior and posterior samples from get_trace() in stratmc.inference.

sample_df: pandas.DataFrame

pandas.DataFrame with proxy data used during the inference step (as input to build_model() in stratmc.model).

ages_df: pandas.DataFrame

pandas.DataFrame containing age constraints used during the inference step (as input to build_model() in stratmc.model).

new_proxies: str or list(str)

New proxy or proxies to construct age models for.

new_proxy_df: pandas.DataFrame, optional

pandas.DataFrame containing the new proxy observations. If not provided, uses sample_df (i.e., assumes that observations for the new proxy are in the same DataFrame as the original proxy observations).

sections: list(str) or numpy.array(str), optional

List of sections included in the inference; only required if not all sections in sample_df were included.

Returns:
interpolated_df: pandas.DataFrame

pandas.DataFrame with interpolated age draws and sample age summary statistics (maximum likelihood estimate, median, and 68% and 95% confidence intervals) for each new proxy observation.
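One way to picture the extension is to interpolate, draw by draw, between the posterior ages of the samples that bracket each new horizon. The sketch below assumes a single section, linear interpolation in height, and clamping for horizons beyond the sampled range; these are illustrative choices, and stratmc may handle horizons near age constraints differently. The function name is hypothetical.

```python
import numpy as np


def extend_ages_sketch(heights, age_draws, new_heights):
    """Interpolate ages for new horizons from existing posterior age draws.

    heights: (n_samples,) heights of the samples used in the inference.
    age_draws: (n_draws, n_samples) posterior sample ages at those heights.
    new_heights: heights of the new observations in the same section.
    Returns an (n_draws, n_new) array of interpolated age draws.
    """
    heights = np.asarray(heights, dtype=float)
    order = np.argsort(heights)  # np.interp needs increasing heights
    return np.array([
        np.interp(new_heights, heights[order], ages[order])
        for ages in np.asarray(age_draws, dtype=float)
    ])


new_age_draws = extend_ages_sketch(
    heights=[0.0, 10.0], age_draws=[[100.0, 90.0], [104.0, 94.0]], new_heights=[5.0]
)
# Summary statistics for each new observation across draws.
print(np.median(new_age_draws, axis=0), np.percentile(new_age_draws, [16, 84], axis=0))
```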

stratmc.inference.find_gaps(full_trace, time_grid=None)

Helper function for proxy_data_gaps() in stratmc.plotting. Counts the number of draws from the posterior where there are no observations within discrete time bins (based on the posterior sample ages).

Parameters:
full_trace: arviz.InferenceData

An arviz.InferenceData object containing the full set of prior and posterior samples from get_trace() in stratmc.inference.

time_grid: np.array, optional

Time bin edges; if not provided, defaults to the ages array passed to get_trace().

Returns:
age_gaps: np.array of int

Number of posterior draws with no observations in each time bin; entries correspond to the bins described by grid_centers and grid_widths.

grid_centers: np.array

Time bin centers.

grid_widths: np.array

Time bin widths.

n: int

Number of posterior draws in full_trace.
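The gap count follows the same binning idea as count_samples(). The sketch below is hypothetical and assumes the posterior sample ages are available as a (draws x samples) NumPy array; a bin counts as a gap in a given draw when no sample age falls inside it, so age_gaps/n is the fraction of draws with a gap.

```python
import numpy as np


def find_gaps_sketch(age_draws, time_grid):
    """Number of posterior draws with zero observations in each time bin."""
    age_draws = np.asarray(age_draws, dtype=float)
    edges = np.sort(np.asarray(time_grid, dtype=float))
    gaps = np.zeros(len(edges) - 1, dtype=int)
    for ages in age_draws:
        counts, _ = np.histogram(ages, bins=edges)
        # A bin is a gap in this draw if no sample age falls inside it.
        gaps += (counts == 0).astype(int)
    centers = (edges[:-1] + edges[1:]) / 2
    widths = np.diff(edges)
    return gaps, centers, widths, age_draws.shape[0]


gaps, centers, widths, n = find_gaps_sketch(
    [[1.0, 2.0, 6.0], [1.5, 2.0, 3.0]], time_grid=[0.0, 5.0, 10.0]
)
print(gaps / n)  # fraction of draws with no observations in each bin
```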

stratmc.inference.get_trace(model, gp, ages, sample_df, ages_df, proxies=['d13c'], approximate=False, name='', chains=8, draws=1000, tune=2000, prior_draws=1000, target_accept=0.9, sampler='numpyro', nuts_kwargs=None, jitter=0.001, seed=None, save=True, postprocessing_backend=None, **kwargs)

Sample the prior and posterior distributions for a pymc.model.core.Model returned by build_model() in stratmc.model. By default, uses pymc.sampling.jax.sample_numpyro_nuts() to sample the posterior; set sampler to ‘blackjax’ to use pymc.sampling.jax.sample_blackjax_nuts() instead.

After the posterior has been sampled, runs check_inference() in stratmc.tests to check that superposition is never violated in the posterior. Any chains with superposition violations are removed from the trace with drop_chains() before it is returned, and a warning is issued; if save = True, both the original and ‘cleaned’ traces are saved to the traces subfolder. See check_inference() for details. Superposition violations are rare and are typically related to minor violations of detrital or intrusive age constraints.
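To make the chain-dropping step concrete, the sketch below flags chains in which any draw places an older sample above a younger one in a single section, then removes those chains. It is a hypothetical illustration that only checks the ordering of sample ages; it does not reproduce check_inference() or drop_chains(), which may test additional conditions such as the detrital and intrusive age constraints mentioned above.

```python
import numpy as np


def chains_violating_superposition(age_draws):
    """Indices of chains where any draw breaks superposition.

    age_draws: (n_chains, n_draws, n_samples) sample ages for one section,
        ordered from the base of the section upward.
    """
    age_draws = np.asarray(age_draws, dtype=float)
    # Superposition requires ages to decrease (or stay equal) up-section,
    # so any positive step between adjacent samples is a violation.
    violated = np.diff(age_draws, axis=-1) > 0
    return np.flatnonzero(violated.any(axis=(1, 2)))


draws = np.array([[[3.0, 2.0, 1.0]], [[3.0, 4.0, 1.0]]])
bad = chains_violating_superposition(draws)
cleaned = np.delete(draws, bad, axis=0)  # keep only consistent chains
print(bad, cleaned.shape)
```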

Problems during sampling, including frequent divergences or minor violations of limiting age constraints, might be resolved by increasing the number of tune steps and/or increasing target_accept (which decreases the step size).

Parameters:
model: PyMC model

pymc.model.core.Model object returned by build_model() in stratmc.model.

gp: pymc.gp.Latent or pymc.gp.HSGP

Gaussian process prior (pymc.gp.Latent or pymc.gp.HSGP) returned by build_model() in stratmc.model.

ages: numpy.array(float)

Array of ages at which to sample the posterior distribution of the proxy signal.

sample_df: pandas.DataFrame

pandas.DataFrame containing proxy data for all sections.

ages_df: pandas.DataFrame

pandas.DataFrame containing age constraints for all sections.

sections: list(str) or numpy.array(str), optional

List of sections to include in the inference model. Defaults to all sections in sample_df.

proxies: str or list(str), optional

List of proxies included in the model. Defaults to ‘d13c’.

approximate: bool, optional

Set to True if the Hilbert space GP approximation (pymc.gp.HSGP) was used in build_model(); defaults to False.

name: str, optional

Prefix for the saved NetCDF file with the inference results (the suffix is a timestamp of the function call).

chains: int, optional

Number of Markov chains to sample in parallel; defaults to 8.

draws: int, optional

Number of samples per chain to draw from the posterior; defaults to 1000.

tune: int, optional

Number of iterations to tune; defaults to 2000.

prior_draws: int, optional

Number of samples to draw from the prior; defaults to 1000.

target_accept: float, optional

Between 0 and 1 (exclusive). During tuning, the sampler adapts the proposals such that the average acceptance probability is equal to target_accept; higher values for target_accept typically lead to smaller step sizes. Defaults to 0.9.

sampler: str, optional

Which NUTS algorithm to use to sample the posterior (‘numpyro’ or ‘blackjax’); defaults to ‘numpyro’.

nuts_kwargs: dict, optional

Dictionary of keyword arguments passed to the NumPyro NUTS sampler (see pymc.sampling.jax.sample_numpyro_nuts() and numpyro.infer.hmc.NUTS) or the BlackJAX NUTS sampler (see pymc.sampling.jax.sample_blackjax_nuts()).

jitter: float, optional

Value of jitter passed to pymc.gp.Latent.conditional(). Defaults to 0.001. Changing this value may help if a linear algebra error is encountered during posterior predictive sampling.

postprocessing_backend: str, optional

Backend to use for postprocessing (‘cpu’ or ‘gpu’); defaults to None.

seed: int, optional

Random seed for sampler.

save: bool, optional

Whether to save the trace (to ‘traces’ subfolder in the current directory); defaults to True.

Returns:
full_trace: arviz.InferenceData

An arviz.InferenceData object containing the full set of prior and posterior samples.

stratmc.inference.interpolate_proxy(interp_df, proxy, ages)

Use interpolated sample ages from extend_age_model() to calculate proxy values at a given set of ages (e.g., to plot 68% and 95% confidence intervals over time for a new proxy using interpolated_proxy_inference() in stratmc.plotting).

Parameters:
interp_df: pandas.DataFrame

pandas.DataFrame with proxy data and interpolated ages from extend_age_model().

proxy: str

Name of the proxy to interpolate.

ages: list(float) or numpy.array(float)

Target ages at which to interpolate proxy values.

Returns:
interpolated_proxy_df: pandas.DataFrame

pandas.DataFrame with interpolated proxy values and summary statistics (maximum likelihood estimate, median, and 68% and 95% confidence intervals) at each age in ages.
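The idea can be sketched as: within each posterior draw, order the observations by their interpolated ages, interpolate the proxy onto the target ages, and summarize across draws. The sketch assumes the age draws are a (draws x observations) NumPy array rather than the interp_df DataFrame, uses linear interpolation, and assumes ages within a draw are distinct; the function name is hypothetical and this is not the stratmc source.

```python
import numpy as np


def interpolate_proxy_sketch(age_draws, proxy_values, target_ages):
    """Proxy values at target ages for each posterior age draw.

    age_draws: (n_draws, n_obs) interpolated ages of the proxy observations.
    proxy_values: (n_obs,) measured proxy values.
    Returns an (n_draws, n_targets) array.
    """
    values = np.asarray(proxy_values, dtype=float)
    rows = []
    for ages in np.asarray(age_draws, dtype=float):
        # Sort by age within the draw so np.interp sees increasing x values.
        order = np.argsort(ages)
        rows.append(np.interp(target_ages, ages[order], values[order]))
    return np.array(rows)


draws = interpolate_proxy_sketch([[0.0, 10.0], [0.0, 20.0]], [1.0, 3.0], [5.0])
median = np.median(draws, axis=0)
ci68 = np.percentile(draws, [16, 84], axis=0)
ci95 = np.percentile(draws, [2.5, 97.5], axis=0)
print(median, ci68, ci95)
```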

stratmc.inference.map_ages_to_section(full_trace, sample_df, ages_df, include_radiometric_ages=False, **kwargs)

Helper function for section_proxy_signal() in stratmc.plotting. Maps the ages array passed to get_trace() to height in each section using the most likely posterior age models.

Parameters:
full_trace: arviz.InferenceData

An arviz.InferenceData object containing the full set of prior and posterior samples from get_trace() in stratmc.inference.

sample_df: pandas.DataFrame

pandas.DataFrame containing proxy data for all sections.

ages_df: pandas.DataFrame

pandas.DataFrame containing age constraints for all sections.

include_radiometric_ages: bool, optional

Whether to consider radiometric ages when calculating the most likely posterior age model for each section; defaults to False.

sections: list(str) or numpy.array(str), optional

List of sections included in the inference; only required if not all sections in sample_df were included.

Returns:
age_model_df: pandas.DataFrame

pandas.DataFrame with interpolated heights at each age in the ages vector that was passed to get_trace().
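Mapping ages to height amounts to interpolating with age as the independent variable along a single age model. The sketch below takes the most likely posterior age of each sample as given (how that model is selected is not shown) and returns NaN for ages outside the section's age span; both the linear interpolation and the NaN convention are assumptions for illustration, and the function name is hypothetical.

```python
import numpy as np


def ages_to_heights_sketch(heights, ml_ages, target_ages):
    """Heights in one section at each target age, from a single age model.

    heights: (n_samples,) sample heights.
    ml_ages: (n_samples,) most likely posterior age of each sample.
    target_ages: ages to map onto the section (e.g., the ages grid).
    Ages outside the section's age span map to NaN.
    """
    heights = np.asarray(heights, dtype=float)
    ml_ages = np.asarray(ml_ages, dtype=float)
    order = np.argsort(ml_ages)  # interpolate with age as the x variable
    return np.interp(target_ages, ml_ages[order], heights[order],
                     left=np.nan, right=np.nan)


h = ages_to_heights_sketch([0.0, 10.0, 20.0], [100.0, 90.0, 80.0], [95.0, 85.0, 70.0])
print(h)
```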