Inference

Functions for sampling the inference model, checking for convergence, and processing the results.

get_trace

Sample the prior and posterior distributions for a pymc.model.core.Model returned by build_model() in stratmc.model.

make_initial_values_per_chain

Generate valid (transformed) initial values for MCMC chains.

extend_age_model

Extend age models to a different set of proxy observations using posterior sample ages from an existing inference.

interpolate_proxy

Use interpolated sample ages from extend_age_model() to calculate proxy values at a given set of ages (e.g., to plot 68 and 95% confidence intervals over time for a new proxy using interpolated_proxy_inference() in stratmc.plotting).

age_range_to_height

Use the posterior age model for each section to find the stratigraphic interval (with uncertainty) corresponding to a given age range.

map_ages_to_section

Helper function for section_proxy_signal() in stratmc.plotting.

count_samples

Helper function for proxy_data_density() in stratmc.plotting.

find_gaps

Helper function for proxy_data_gaps() in stratmc.plotting.

calculate_lengthscale_stability

Compute the posterior standard deviation of the pymc.gp.cov.ExpQuad covariance kernel lengthscale as additional chains are considered (i.e., for 1 to N chains).

calculate_proxy_signal_stability

Compute the residuals between the median inferred proxy signal when all N chains are considered compared to when 1 to N-1 chains are considered.

get_valid_initial_sample_ages_from_dict

Modify dictionary of initial values such that all detrital and intrusive age constraints are respected.

get_valid_initvals_per_chain

Helper function for generating valid initial values for a single MCMC chain.

ordered_transform_in_prior

Helper function that enforces the ordered transform in a dictionary of prior draws.

superposition_from_dict

Modify dictionary of initial values such that superposition between depositional age constraints is respected.

superposition_depositional_and_limiting_ages_from_dict

Modify dictionary of initial values such that superposition between limiting and depositional age constraints is respected.

stratmc.inference.age_range_to_height(full_trace, sample_df, ages_df, lower_age, upper_age, **kwargs)

Use the posterior age model for each section to find the stratigraphic interval (with uncertainty) corresponding to a given age range. If sections is not provided, returns height estimates for every section that overlaps the target age range. To visualize the stratigraphic intervals, use section_age_range() in stratmc.plotting.

Parameters:
full_trace: arviz.InferenceData

An arviz.InferenceData object containing the full set of prior and posterior samples from get_trace() in stratmc.inference.

sample_df: pandas.DataFrame

pandas.DataFrame containing proxy data for all sections.

ages_df: pandas.DataFrame

pandas.DataFrame containing age constraints for all sections.

lower_age: float

Lower bound (youngest age) of the target age interval.

upper_age: float

Upper bound (oldest age) of the target age interval.

sections: list(str) or numpy.array(str), optional

List of sections included in the original inference; only required if not all sections in sample_df were included.

Returns:
height_range_df: pandas.DataFrame

Summary statistics for the base and top height of the target age interval (maximum likelihood estimate, median, and 68% and 95% confidence intervals) for each section.
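The age-to-height mapping described above can be sketched with plain NumPy. This is only an illustration under stated assumptions, not the stratmc implementation: the function name height_interval_for_age_range, the (draws, samples) layout of age_draws, and the use of linear interpolation are all hypothetical choices made for the example.

```python
import numpy as np

def height_interval_for_age_range(heights, age_draws, lower_age, upper_age):
    """Map an age range to base/top heights for each posterior draw.

    heights: (n,) sample heights, increasing upsection.
    age_draws: (draws, n) ages that decrease strictly upsection (assumed).
    Returns (base, top), each of shape (draws,).
    """
    heights = np.asarray(heights, dtype=float)
    age_draws = np.asarray(age_draws, dtype=float)
    base, top = [], []
    for draw in age_draws:
        # np.interp needs increasing x, so reverse the decreasing ages.
        # Ages outside the model are clamped to the first/last height.
        base.append(np.interp(upper_age, draw[::-1], heights[::-1]))
        top.append(np.interp(lower_age, draw[::-1], heights[::-1]))
    return np.array(base), np.array(top)

# Synthetic age model: 3 sample heights, 2 posterior draws.
heights = np.array([0.0, 10.0, 20.0])
age_draws = np.array([[540.0, 535.0, 530.0],
                      [541.0, 537.0, 531.0]])
base, top = height_interval_for_age_range(heights, age_draws, 532.0, 538.0)

# Percentiles bounding the 95% and 68% intervals, plus the median.
base_summary = np.percentile(base, [2.5, 16, 50, 84, 97.5])
```

The oldest age (upper_age) maps to the base of the interval and the youngest (lower_age) to the top, because age decreases upsection.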

stratmc.inference.calculate_lengthscale_stability(full_trace, **kwargs)

Compute the posterior standard deviation of the pymc.gp.cov.ExpQuad covariance kernel lengthscale as additional chains are considered (i.e., for 1 to N chains). When the posterior has been sufficiently explored, the standard deviation will stabilize; if it has not stabilized, then additional chains should be run. Helper function for lengthscale_stability() in stratmc.plotting.

To consider chains from multiple traces associated with the same inference model, first combine the traces (saved as NetCDF files) using combine_traces() in stratmc.data.

Parameters:
full_trace: arviz.InferenceData

An arviz.InferenceData object containing the full set of prior and posterior samples from get_trace() in stratmc.inference.

proxy: str, optional

Name of the proxy; only required if multiple proxies were included in the inference model.

Returns:
gp_ls_std: np.array of float

Standard deviation of the covariance kernel lengthscale posterior; entries correspond to the number of chains considered (first entry is 1 chain, last entry is all N chains).
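As a rough illustration of the stability check, the sketch below pools lengthscale draws from the first k chains and computes their standard deviation for k = 1 to N. The array layout (chains, draws) and the helper name lengthscale_std_by_chain_count are assumptions for the example; the draws are synthetic rather than extracted from a real trace.

```python
import numpy as np

def lengthscale_std_by_chain_count(ls_samples):
    """Std of pooled lengthscale draws using the first k chains, k = 1..N.

    ls_samples: array of shape (chains, draws).
    """
    ls_samples = np.asarray(ls_samples, dtype=float)
    n_chains = ls_samples.shape[0]
    # Slice off the first k chains and pool all of their draws.
    return np.array([ls_samples[:k].std() for k in range(1, n_chains + 1)])

rng = np.random.default_rng(0)
fake_ls = rng.lognormal(mean=1.0, sigma=0.3, size=(8, 1000))  # synthetic draws
gp_ls_std = lengthscale_std_by_chain_count(fake_ls)
```

If the entries of gp_ls_std level off as k grows, adding chains is no longer changing the estimated spread of the lengthscale posterior.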

stratmc.inference.calculate_proxy_signal_stability(full_trace, **kwargs)

Compute the residuals between the median inferred proxy signal when all N chains are considered compared to when 1 to N-1 chains are considered. When the posterior has been sufficiently explored, the residuals will stabilize and approach zero; if they have not stabilized, then additional chains should be run. Helper function for proxy_signal_stability() in stratmc.plotting.

To consider chains from multiple traces associated with the same inference model, first combine the traces (saved as NetCDF files) using combine_traces() in stratmc.data.

Parameters:
full_trace: arviz.InferenceData

An arviz.InferenceData object containing the full set of prior and posterior samples from get_trace() in stratmc.inference.

proxy: str, optional

Name of the proxy; only required if multiple proxies were included in the inference model.

Returns:
median_residuals: np.array of float

Residuals between the median inferred proxy value (at each time in ages passed to get_trace()) calculated using 1 to N-1 chains versus all N chains. Shape is (chains x ages).
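The residual calculation can be sketched as follows. This is a minimal example assuming the proxy signal draws are available as an array of shape (chains, draws, ages); the helper name median_residuals_by_chain_count is hypothetical, and this sketch returns one row per subset of 1 to N-1 chains, which may differ from the shape returned by the library.

```python
import numpy as np

def median_residuals_by_chain_count(proxy_samples):
    """Median signal from the first k chains minus the all-chain median.

    proxy_samples: (chains, draws, ages). Returns (chains - 1, ages).
    """
    proxy_samples = np.asarray(proxy_samples, dtype=float)
    n_chains = proxy_samples.shape[0]
    # Median over chains and draws at each age, using every chain.
    full_median = np.median(proxy_samples, axis=(0, 1))
    return np.array([
        np.median(proxy_samples[:k], axis=(0, 1)) - full_median
        for k in range(1, n_chains)
    ])

rng = np.random.default_rng(1)
fake_signal = rng.normal(size=(4, 200, 5))  # synthetic: 4 chains, 5 ages
residuals = median_residuals_by_chain_count(fake_signal)

# Identical chains give zero residuals, the fully "stable" limit.
identical = np.tile(fake_signal[:1], (4, 1, 1))
zero_residuals = median_residuals_by_chain_count(identical)
```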

stratmc.inference.count_samples(full_trace, time_grid=None)

Helper function for proxy_data_density() in stratmc.plotting. Counts the number of observations within discrete time bins (based on the posterior sample ages).

Parameters:
full_trace: arviz.InferenceData

An arviz.InferenceData object containing the full set of prior and posterior samples from get_trace() in stratmc.inference.

time_grid: np.array, optional

Time bin edges; if not provided, defaults to the ages array passed to get_trace().

Returns:
sample_counts: np.array

Number of observations in each time bin, summed over all posterior draws such that the average number of observations is sample_counts/n.

time_grid: np.array

Time bin edges.

n: int

Number of posterior draws in full_trace.
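The binning described above can be approximated with numpy.histogram. The sketch assumes the posterior sample ages are available as a (draws, observations) array; the helper name count_in_bins and the toy ages are made up for illustration.

```python
import numpy as np

def count_in_bins(sample_ages, time_grid):
    """Count observations per time bin, summed over all posterior draws.

    sample_ages: (draws, observations); time_grid: increasing bin edges.
    """
    sample_ages = np.asarray(sample_ages, dtype=float)
    # Flatten so that every draw contributes to the same set of bins.
    counts, edges = np.histogram(sample_ages.ravel(), bins=time_grid)
    n = sample_ages.shape[0]
    return counts, edges, n

ages = np.array([[1.0, 2.5, 2.6],
                 [1.2, 1.8, 2.9]])  # 2 draws, 3 observations
counts, edges, n = count_in_bins(ages, np.array([0.0, 1.0, 2.0, 3.0]))
mean_per_bin = counts / n  # average number of observations per bin
```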

stratmc.inference.extend_age_model(full_trace, sample_df, ages_df, new_proxies, new_proxy_df=None, **kwargs)

Extend age models to a different set of proxy observations using posterior sample ages from an existing inference. For instance, extend an age model built using C isotope data to new S isotope data collected from different stratigraphic horizons in the same sections. Note that the age of stratigraphic horizons that were included in sample_df (but marked Exclude? = True) is already passively tracked within the model; this function is only required to estimate the age of observations that were not in sample_df when the inference was run. To estimate ages for new measurements of the same proxy, place the new data in a different column (e.g., ‘d13c_new’).

Parameters:
full_trace: arviz.InferenceData

An arviz.InferenceData object containing the full set of prior and posterior samples from get_trace() in stratmc.inference.

sample_df: pandas.DataFrame

pandas.DataFrame with proxy data used during the inference step (as input to build_model() in stratmc.model).

ages_df: pandas.DataFrame

pandas.DataFrame containing age constraints used during the inference step (as input to build_model() in stratmc.model).

new_proxies: str or list(str)

New proxy(s) to construct age models for.

new_proxy_df: pandas.DataFrame, optional

pandas.DataFrame containing new proxy observations. Optional; if not provided, uses sample_df (assumes that observations for the new proxy are in the same DataFrame as the original proxy observations).

sections: list(str) or numpy.array(str), optional

List of sections included in the inference; only required if not all sections in sample_df were included.

Returns:
interpolated_df: pandas.DataFrame

pandas.DataFrame with interpolated age draws and sample age summary statistics (maximum likelihood estimate, median, and 68% and 95% confidence intervals) for each new proxy observation.
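Conceptually, extending an age model means interpolating each posterior age draw to the heights of the new observations. The sketch below assumes linear interpolation between the original sample heights, which is an assumption for illustration and may not match what stratmc does internally; the helper name and data are hypothetical.

```python
import numpy as np

def interpolate_ages_at_heights(heights, age_draws, new_heights):
    """Interpolate each posterior age draw to new stratigraphic heights.

    heights: (n,) increasing; age_draws: (draws, n); new_heights: (m,).
    Returns an array of shape (draws, m).
    """
    heights = np.asarray(heights, dtype=float)
    age_draws = np.asarray(age_draws, dtype=float)
    # One interpolation per draw preserves the uncertainty in the age model.
    return np.array([np.interp(new_heights, heights, draw) for draw in age_draws])

heights = np.array([0.0, 10.0, 20.0])            # original sample heights
age_draws = np.array([[540.0, 535.0, 530.0],
                      [541.0, 537.0, 531.0]])    # synthetic posterior draws
new_ages = interpolate_ages_at_heights(heights, age_draws, [5.0, 15.0])
median_age = np.median(new_ages, axis=0)         # per new observation
```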

stratmc.inference.find_gaps(full_trace, time_grid=None)

Helper function for proxy_data_gaps() in stratmc.plotting. Counts the number of draws from the posterior where there are no observations within discrete time bins (based on the posterior sample ages).

Parameters:
full_trace: arviz.InferenceData

An arviz.InferenceData object containing the full set of prior and posterior samples from get_trace() in stratmc.inference.

time_grid: np.array, optional

Time bin edges; if not provided, defaults to the ages array passed to get_trace().

Returns:
age_gaps: np.array of int

Number of posterior draws with no observations in each age bin; entries correspond to the bins described by grid_centers and grid_widths.

grid_centers: np.array

Time bin centers.

grid_widths: np.array

Time bin widths.

n: int

Number of posterior draws in full_trace.
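A gap count of this kind can be sketched by histogramming each posterior draw separately and counting the draws that leave a bin empty. The (draws, observations) layout and the helper name count_empty_bins are assumptions for the example.

```python
import numpy as np

def count_empty_bins(sample_ages, time_grid):
    """Count posterior draws with no observations in each time bin."""
    sample_ages = np.asarray(sample_ages, dtype=float)
    time_grid = np.asarray(time_grid, dtype=float)
    # Histogram every draw on its own: result has shape (draws, bins).
    per_draw = np.array([np.histogram(d, bins=time_grid)[0] for d in sample_ages])
    age_gaps = (per_draw == 0).sum(axis=0)
    grid_centers = 0.5 * (time_grid[:-1] + time_grid[1:])
    grid_widths = np.diff(time_grid)
    return age_gaps, grid_centers, grid_widths, sample_ages.shape[0]

ages = np.array([[0.5, 2.5, 2.6],
                 [1.2, 1.8, 2.9]])  # 2 draws, 3 observations
age_gaps, grid_centers, grid_widths, n = count_empty_bins(ages, [0.0, 1.0, 2.0, 3.0])
gap_fraction = age_gaps / n  # fraction of draws with an empty bin
```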

stratmc.inference.get_trace(model, gp, ages, sample_df, ages_df, proxies=['d13c'], approximate=False, name='', chains=8, draws=1000, tune=2000, prior_draws=1000, target_accept=0.9, sampler='numpyro', nuts_kwargs=None, jitter=0.001, seed=None, save=True, postprocessing_backend=None, generate_custom_initvals=True, initval_seed=None, save_custom_initvals=True, initvals=None, sample_predictive=True, chain_method='parallel', **kwargs)

Sample the prior and posterior distributions for a pymc.model.core.Model returned by build_model() in stratmc.model. By default, uses pymc.sampling.jax.sample_numpyro_nuts() to sample the posterior; change sampler to ‘blackjax’ to use pymc.sampling.jax.sample_blackjax_nuts().

After the posterior has been sampled, runs check_inference() in stratmc.tests to check that superposition is never violated in the posterior. Any chains with superposition violations are removed from the trace with drop_chains() before it is returned (if save = True, both the original and ‘cleaned’ traces are saved to the traces subfolder), and a warning is issued. See check_inference() for details; superposition issues are rare, and typically are related to minor violations of detrital or intrusive age constraints.

Problems during sampling, including frequent divergences or minor violations of limiting age constraints, might be resolved by increasing the number of tune steps and/or increasing target_accept (which decreases the step size).

Parameters:
model: pymc.Model

pymc.model.core.Model object returned by build_model() in stratmc.model.

gp: pymc.gp.Latent or pymc.gp.HSGP

Gaussian process prior (pymc.gp.Latent or pymc.gp.HSGP) returned by build_model() in stratmc.model.

ages: numpy.array(float)

Array of ages at which to sample the posterior distribution of the proxy signal.

sample_df: pandas.DataFrame

pandas.DataFrame containing proxy data for all sections.

ages_df: pandas.DataFrame

pandas.DataFrame containing age constraints for all sections.

sections: list(str) or numpy.array(str), optional

List of sections to include in the inference model. Defaults to all sections in sample_df.

proxies: str or list(str), optional

List of proxies included in the model. Defaults to ‘d13c’.

approximate: bool, optional

Set to True if the Hilbert space GP approximation (pymc.gp.HSGP) was used in build_model(); defaults to False.

name: str, optional

Prefix for the saved NetCDF file with the inference results (suffix is timestamp for function call).

chains: int, optional

Number of Markov chains to sample in parallel; defaults to 8.

draws: int, optional

Number of samples per chain to draw from the posterior; defaults to 1000.

tune: int, optional

Number of iterations to tune; defaults to 2000.

prior_draws: int, optional

Number of samples to draw from the prior; defaults to 1000.

target_accept: float, optional

Between 0 and 1 (exclusive). During tuning, the sampler adapts the proposals such that the average acceptance probability is equal to target_accept; higher values for target_accept typically lead to smaller step sizes. Defaults to 0.9.

generate_custom_initvals: bool, optional

Whether to generate custom initial values for each chain with make_initial_values_per_chain() prior to sampling; defaults to True. Recommended to improve exploration of the posterior, and required to avoid superposition violations for models with intermediate detrital or intrusive age constraints.

initval_seed: int, optional

Random seed for initial value generator.

save_custom_initvals: bool, optional

Whether to save the initial value dictionary generated by make_initial_values_per_chain() (to ‘initial_values’ subfolder in the current directory); defaults to True.

initvals: dict, optional

Dictionary of custom initial values generated by make_initial_values_per_chain(). If not provided, initial values will be generated with make_initial_values_per_chain() prior to sampling if generate_custom_initvals is True; if generate_custom_initvals is False, the default initial point of model will be used.

sampler: str, optional

Which NUTS algorithm to use to sample the posterior (‘numpyro’ or ‘blackjax’); defaults to ‘numpyro’.

nuts_kwargs: dict, optional

Dictionary of keyword arguments passed to NumPyro NUTS sampler (see pymc.sampling.jax.sample_numpyro_nuts() and numpyro.infer.hmc.NUTS) or blackjax NUTS sampler (see pymc.sampling.jax.sample_blackjax_nuts()).

sample_predictive: bool, optional

Whether to draw prior and posterior predictive samples; defaults to True.

jitter: float, optional

Value of jitter passed to pymc.gp.Latent.conditional(). Defaults to 0.001. Changing this value may help if a linear algebra error is encountered during posterior predictive sampling.

postprocessing_backend: str, optional

Device to use for postprocessing (‘cpu’ or ‘gpu’). Defaults to None.

chain_method: str, optional

Method for drawing samples (‘parallel’ or ‘vectorized’); defaults to ‘parallel’. The ‘vectorized’ method should be used for sampling with a GPU (requires installing JAX with GPU support).

seed: int, optional

Random seed for sampler.

save: bool, optional

Whether to save the trace (to ‘traces’ subfolder in the current directory); defaults to True.

Returns:
full_trace: arviz.InferenceData

An arviz.InferenceData object containing the full set of prior and posterior samples.
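The post-sampling superposition check and chain removal described above can be pictured with a small NumPy sketch. It is not the logic of check_inference() or drop_chains(); it only assumes that sample ages are stored as (chains, draws, samples) with samples ordered from the bottom to the top of a section, and the helper name is hypothetical.

```python
import numpy as np

def chains_violating_superposition(sample_ages):
    """Return indices of chains where any draw has age increasing upsection.

    sample_ages: (chains, draws, samples), samples ordered bottom to top.
    """
    sample_ages = np.asarray(sample_ages, dtype=float)
    # A positive difference means an overlying sample is older than the
    # sample beneath it, which violates superposition.
    increases_upward = np.diff(sample_ages, axis=-1) > 0
    bad = increases_upward.any(axis=(1, 2))
    return np.flatnonzero(bad)

sample_ages = np.array([
    [[540.0, 535.0, 530.0], [541.0, 536.0, 531.0]],  # chain 0: valid
    [[540.0, 535.0, 530.0], [541.0, 530.0, 536.0]],  # chain 1: one bad draw
])
bad = chains_violating_superposition(sample_ages)
cleaned = np.delete(sample_ages, bad, axis=0)  # keep only the valid chains
```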

stratmc.inference.get_valid_initial_sample_ages_from_dict(initval_dict, detrital_age_dist_names, intrusive_age_dist_names, maximum_age_dist_name, minimum_age_dist_name, sample_age_dist_name, sample_heights, detrital_heights, intrusive_heights, interval, sf1_name, sf2_name, shared_radiometric_age_dist)

Modify dictionary of initial values such that all detrital and intrusive age constraints are respected. Replicates logic used to set model initial point in get_valid_initial_ages() in stratmc.model.

Parameters:
initval_dict: dict

Dictionary of proposed initial values.

detrital_age_dist_names: list(str)

List of names for detrital age constraint distributions in model.

intrusive_age_dist_names: list(str)

List of names for intrusive age constraint distributions in model.

maximum_age_dist_name: str

Name of distribution for underlying maximum age constraint in model.

minimum_age_dist_name: str

Name of distribution for overlying minimum age constraint in model.

sample_age_dist_name: str

Name of sample age distribution (unsorted and unscaled) in the pymc.Model object.

sample_heights: np.array

Array of heights for samples in the current interval.

detrital_heights: list(float)

Heights of detrital age constraints.

intrusive_heights: list(float)

Heights of intrusive age constraints.

interval: int

Current interval number.

sf1_name: str

Name of the distribution associated with scaling factor 1 in model.

sf2_name: str

Name of the distribution associated with scaling factor 2 in model.

shared_radiometric_age_dist: bool, optional

Whether the radiometric age distributions are part of a single object (versus initiated as separate distributions). Defaults to True.

Returns:
initval_dict: dict

Dictionary of valid initial values; keys are model variables.

stratmc.inference.get_valid_initvals_per_chain(initval_dict, sample_df, ages_df, sections)

Helper function for generating valid initial values for a single MCMC chain. Called in make_initial_values_per_chain().

Parameters:
initval_dict: dict

Dictionary of proposed initial values.

sample_df: pandas.DataFrame

pandas.DataFrame with proxy data used during the inference step (as input to build_model() in stratmc.model).

ages_df: pandas.DataFrame

pandas.DataFrame containing age constraints used during the inference step (as input to build_model() in stratmc.model).

sections: list(str) or numpy.array(str)

List of sections included in the inference.

Returns:
initval_dict: dict

Dictionary with initial values modified to respect superposition; keys are variable names.

stratmc.inference.interpolate_proxy(interp_df, proxy, ages)

Use interpolated sample ages from extend_age_model() to calculate proxy values at a given set of ages (e.g., to plot 68 and 95% confidence intervals over time for a new proxy using interpolated_proxy_inference() in stratmc.plotting).

Parameters:
interp_df: pandas.DataFrame

pandas.DataFrame with proxy data and interpolated ages from extend_age_model().

proxy: str

Name of the proxy to interpolate.

ages: list(float) or numpy.array(float)

Target ages at which to interpolate proxy values.

Returns:
interpolated_proxy_df: pandas.DataFrame

pandas.DataFrame with interpolated proxy values and summary statistics (maximum likelihood estimate, median, and 68% and 95% confidence intervals) at each age in ages.
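To illustrate the idea, the sketch below interpolates proxy values onto target ages once per interpolated age draw, then summarizes across draws. Linear interpolation, the (draws, observations) layout, and the helper name interpolate_proxy_at_ages are assumptions made for the example, not details confirmed by the library.

```python
import numpy as np

def interpolate_proxy_at_ages(age_draws, proxy_values, target_ages):
    """Interpolate proxy values to target ages for each age draw.

    age_draws: (draws, n); proxy_values: (n,); target_ages: (m,).
    Returns an array of shape (draws, m).
    """
    age_draws = np.asarray(age_draws, dtype=float)
    proxy_values = np.asarray(proxy_values, dtype=float)
    out = []
    for draw in age_draws:
        # np.interp requires increasing x, so sort each draw by age.
        order = np.argsort(draw)
        out.append(np.interp(target_ages, draw[order], proxy_values[order]))
    return np.array(out)

age_draws = np.array([[530.0, 535.0, 540.0],
                      [531.0, 537.0, 541.0]])  # synthetic interpolated ages
proxy_values = np.array([1.0, 3.0, 2.0])       # one value per observation
interp = interpolate_proxy_at_ages(age_draws, proxy_values, [536.0])

# Bounds of the 68% interval and the median at each target age.
summary = np.percentile(interp, [16, 50, 84], axis=0)
```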

stratmc.inference.make_initial_values_per_chain(sample_df, ages_df, model, n_chains, proxies=['d13c'], seed=None, **kwargs)

Generate valid (transformed) initial values for MCMC chains. Output can be passed to the initvals argument of get_trace().

Parameters:
sample_df: pandas.DataFrame

pandas.DataFrame with proxy data used during the inference step (as input to build_model() in stratmc.model).

ages_df: pandas.DataFrame

pandas.DataFrame containing age constraints used during the inference step (as input to build_model() in stratmc.model).

model: pymc.Model

pymc.model.core.Model object returned by build_model() in stratmc.model.

n_chains: int

Number of MCMC chains to generate initial values for.

proxies: list(str)

List of proxies included in the inference.

sections: list(str) or numpy.array(str), optional

List of sections included in the inference; only required if not all sections in sample_df were included.

seed: int, optional

Random seed used to sample the prior.

Returns:
initval_dicts: list(dict)

List of dictionaries (one per chain) with valid initial values for each variable in model; keys are variable names.

stratmc.inference.map_ages_to_section(full_trace, sample_df, ages_df, include_radiometric_ages=False, **kwargs)

Helper function for section_proxy_signal() in stratmc.plotting. Maps the ages array passed to get_trace() to height in each section using the most likely posterior age models.

Parameters:
full_trace: arviz.InferenceData

An arviz.InferenceData object containing the full set of prior and posterior samples from get_trace() in stratmc.inference.

sample_df: pandas.DataFrame

pandas.DataFrame containing proxy data for all sections.

ages_df: pandas.DataFrame

pandas.DataFrame containing age constraints for all sections.

include_radiometric_ages: bool, optional

Whether to consider radiometric ages when calculating the most likely posterior age model for each section; defaults to False.

sections: list(str) or numpy.array(str), optional

List of sections included in the inference; only required if not all sections in sample_df were included.

Returns:
age_model_df: pandas.DataFrame

pandas.DataFrame with interpolated heights at each age in the ages vector that was passed to get_trace().

stratmc.inference.ordered_transform_in_prior(initval_dict, ordered_dist_name)

Helper function that enforces the ordered transform in a dictionary of prior draws.

Parameters:
initval_dict: dict

Dictionary of proposed initial values.

ordered_dist_name:

Name of the distribution with an ordered transform in model.

Returns:
initval_dict: dict

Dictionary with sorted (increasing) initial values for ordered_dist_name.
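A minimal sketch of the sorting step, assuming the dictionary maps variable names to NumPy arrays. The helper name sort_ordered_entry and the key "ages" are hypothetical, and the real function may handle transformed values differently.

```python
import numpy as np

def sort_ordered_entry(initval_dict, ordered_dist_name):
    """Return a copy of the dict with one entry sorted in increasing order."""
    out = dict(initval_dict)  # shallow copy; the input dict is not modified
    out[ordered_dist_name] = np.sort(np.asarray(out[ordered_dist_name]), axis=-1)
    return out

initvals = {"ages": np.array([3.0, 1.0, 2.0]), "other": 5}
sorted_initvals = sort_ordered_entry(initvals, "ages")
```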

stratmc.inference.superposition_depositional_and_limiting_ages_from_dict(initval_dict, depositional_age_name, detrital_age_dist_names, intrusive_age_dist_names, depositional_age_idx=None)

Modify dictionary of initial values such that superposition between limiting and depositional age constraints is respected. Replicates logic used to set model initial point in superposition_depositional_and_limiting_ages() in stratmc.model.

Parameters:
initval_dict: dict

Dictionary of proposed initial values.

depositional_age_name: str

Name of distribution for target depositional age constraint in model.

detrital_age_dist_names: list(str)

Names of underlying detrital age constraint distributions in model.

intrusive_age_dist_names: list(str)

Names of overlying intrusive age constraint distributions in model.

depositional_age_idx: int, optional

Position of depositional_age in model variable depositional_age_name. Only required if depositional_age is one of multiple ages modeled using a single multidimensional distribution.

Returns:
initval_dict: dict

Dictionary with initial values modified to respect superposition; keys are variable names.

stratmc.inference.superposition_from_dict(initval_dict, age_dist_names, section_age_df)

Modify dictionary of initial values such that superposition between depositional age constraints is respected. Replicates logic used to set model initial point in superposition() in stratmc.model.

Parameters:
initval_dict: dict

Dictionary of proposed initial values.

age_dist_names:

Names of radiometric age distributions in model (must be in stratigraphic order - lowest to highest).

section_age_df: pandas.DataFrame

pandas.DataFrame containing age constraints for current section.

Returns:
initval_dict: dict

Dictionary with initial values modified to respect superposition; keys are variable names.
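The constraint being enforced is that, moving upsection, each depositional age must be no older than the one below it. The sketch below shows one simple way to obtain such an ordering, by reassigning the proposed ages in decreasing order. This is an illustrative assumption: the adjustment actually made by superposition_from_dict() may differ, and the helper name and the keys age_0 to age_2 are hypothetical.

```python
def enforce_superposition(initval_dict, age_dist_names):
    """Reassign proposed ages so they decrease from lowest to highest horizon.

    age_dist_names must be in stratigraphic order (lowest to highest).
    """
    out = dict(initval_dict)  # shallow copy; the input dict is not modified
    # Oldest proposed age goes to the lowest horizon, youngest to the highest.
    ages = sorted((float(out[name]) for name in age_dist_names), reverse=True)
    for name, age in zip(age_dist_names, ages):
        out[name] = age
    return out

proposed = {"age_0": 530.0, "age_1": 540.0, "age_2": 520.0}
valid = enforce_superposition(proposed, ["age_0", "age_1", "age_2"])
```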