Inference
Functions for sampling the inference model, checking for convergence, and processing the results.
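get_trace() | Sample the prior and posterior distributions for a pymc.model.core.Model returned by build_model() in stratmc.model.
make_initial_values_per_chain() | Generate valid (transformed) initial values for MCMC chains.
extend_age_model() | Extend age models to a different set of proxy observations using posterior sample ages from an existing inference.
interpolate_proxy() | Use interpolated sample ages from extend_age_model() to calculate proxy values at a given set of ages.
age_range_to_height() | Use the posterior age model for each section to find the stratigraphic interval (with uncertainty) corresponding to a given age range.
map_ages_to_section() | Helper function for section_proxy_signal() in stratmc.plotting.
count_samples() | Helper function for proxy_data_density() in stratmc.plotting.
find_gaps() | Helper function for proxy_data_gaps() in stratmc.plotting.
calculate_lengthscale_stability() | Compute the posterior standard deviation of the pymc.gp.cov.ExpQuad covariance kernel lengthscale as additional chains are considered.
calculate_proxy_signal_stability() | Compute the residuals between the median inferred proxy signal when all N chains are considered compared to when 1 to N-1 chains are considered.
get_valid_initial_sample_ages_from_dict() | Modify dictionary of initial values such that all detrital and intrusive age constraints are respected.
get_valid_initvals_per_chain() | Helper function for generating valid initial values for a single MCMC chain.
ordered_transform_in_prior() | Helper function to enforce the ordered transform in a dictionary of prior draws.
superposition_from_dict() | Modify dictionary of initial values such that superposition between depositional age constraints is respected.
superposition_depositional_and_limiting_ages_from_dict() | Modify dictionary of initial values such that superposition between limiting and depositional age constraints is respected.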
- stratmc.inference.age_range_to_height(full_trace, sample_df, ages_df, lower_age, upper_age, **kwargs)
Use the posterior age model for each section to find the stratigraphic interval (with uncertainty) corresponding to a given age range. If sections is not provided, returns height estimates for every section that overlaps the target age range. To visualize the stratigraphic intervals, use section_age_range() in stratmc.plotting.
- Parameters:
- full_trace: arviz.InferenceData
An arviz.InferenceData object containing the full set of prior and posterior samples from get_trace() in stratmc.inference.
- sample_df: pandas.DataFrame
pandas.DataFrame containing proxy data for all sections.
- ages_df: pandas.DataFrame
pandas.DataFrame containing age constraints for all sections.
- lower_age: float
Lower bound (youngest age) of the target age interval.
- upper_age: float
Upper bound (oldest age) of the target age interval.
- sections: list(str) or numpy.array(str), optional
List of sections included in the original inference; only required if not all sections in sample_df were included.
- Returns:
- height_range_df: pandas.DataFrame
Summary statistics for the base and top height of the target age interval (maximum likelihood estimate, median, and 68% and 95% confidence intervals) for each section.
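The mapping from an age range to a height interval can be illustrated with a minimal sketch. It assumes that each posterior draw supplies one age per sample height, that ages strictly decrease upsection, and that linear interpolation between samples is acceptable. The helper names (interp, height_range) and the toy data are hypothetical and are not part of stratmc; the library's own implementation may differ.

```python
from bisect import bisect_right
from statistics import median


def interp(x, xs, ys):
    """Linearly interpolate y at x; xs must be strictly increasing.

    Values outside the range of xs are clamped to the end points.
    """
    if x <= xs[0]:
        return ys[0]
    if x >= xs[-1]:
        return ys[-1]
    j = bisect_right(xs, x)  # xs[j - 1] <= x < xs[j]
    t = (x - xs[j - 1]) / (xs[j] - xs[j - 1])
    return ys[j - 1] + t * (ys[j] - ys[j - 1])


def height_range(heights, age_draws, lower_age, upper_age):
    """Return per-draw base and top heights for the target age range.

    heights are ascending and ages decrease upsection, so both lists are
    reversed to give interp() an increasing x axis (age).
    """
    bases, tops = [], []
    for ages in age_draws:
        ages_inc, heights_rev = ages[::-1], heights[::-1]
        bases.append(interp(upper_age, ages_inc, heights_rev))  # oldest age
        tops.append(interp(lower_age, ages_inc, heights_rev))   # youngest age
    return bases, tops


# Hypothetical data: sample heights (m) and two posterior draws of ages (Ma)
heights = [0.0, 10.0, 20.0]
age_draws = [[540.0, 538.0, 536.0], [541.0, 539.0, 535.0]]

bases, tops = height_range(heights, age_draws, lower_age=537.0, upper_age=539.0)
print("median base:", median(bases), "median top:", median(tops))
```

Summarizing the per-draw base and top heights (median, percentiles) gives the kind of uncertainty estimate described for height_range_df.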
- stratmc.inference.calculate_lengthscale_stability(full_trace, **kwargs)
Compute the posterior standard deviation of the pymc.gp.cov.ExpQuad covariance kernel lengthscale as additional chains are considered (i.e., for 1 to N chains). When the posterior has been sufficiently explored, the standard deviation will stabilize; if it has not stabilized, then additional chains should be run. Helper function for lengthscale_stability() in stratmc.plotting.
To consider chains from multiple traces associated with the same inference model, first combine the traces (saved as NetCDF files) using combine_traces() in stratmc.data.
- Parameters:
- full_trace: arviz.InferenceData
An arviz.InferenceData object containing the full set of prior and posterior samples from get_trace() in stratmc.inference.
- proxy: str, optional
Name of the proxy; only required if multiple proxies were included in the inference model.
- Returns:
- gp_ls_std: np.array of float
Posterior standard deviation of the covariance kernel lengthscale; entries correspond to the number of chains considered (first entry is 1 chain, last entry is all N chains).
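The stability check can be sketched as a running standard deviation over pooled chains. This is an illustration only: the function name std_by_chain_count and the synthetic lengthscale draws are assumptions, and the example does not read an arviz.InferenceData object or reproduce stratmc's implementation.

```python
import random
import statistics


def std_by_chain_count(chains):
    """Standard deviation of pooled draws using the first 1, 2, ..., N chains."""
    pooled = []
    result = []
    for chain in chains:
        pooled.extend(chain)  # add one more chain to the pool
        result.append(statistics.pstdev(pooled))
    return result


rng = random.Random(0)
# Hypothetical lengthscale draws: 4 chains x 500 draws each
chains = [[rng.lognormvariate(0.0, 0.25) for _ in range(500)] for _ in range(4)]

gp_ls_std = std_by_chain_count(chains)
for n_chains, value in enumerate(gp_ls_std, start=1):
    print(n_chains, round(value, 4))
```

If the printed values are still drifting as chains are added, the posterior has probably not been explored well enough.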
- stratmc.inference.calculate_proxy_signal_stability(full_trace, **kwargs)
Compute the residuals between the median inferred proxy signal when all N chains are considered compared to when 1 to N-1 chains are considered. When the posterior has been sufficiently explored, the residuals will stabilize and approach zero; if they have not stabilized, then additional chains should be run. Helper function for proxy_signal_stability() in stratmc.plotting.
To consider chains from multiple traces associated with the same inference model, first combine the traces (saved as NetCDF files) using combine_traces() in stratmc.data.
- Parameters:
- full_trace: arviz.InferenceData
An arviz.InferenceData object containing the full set of prior and posterior samples from get_trace() in stratmc.inference.
- proxy: str, optional
Name of the proxy; only required if multiple proxies were included in the inference model.
- Returns:
- median_residuals: np.array of float
Residuals between the median inferred proxy value (at each time in ages passed to get_trace()) calculated using 1 to N-1 chains versus all N chains. Shape is (chains x ages).
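A minimal sketch of the residual calculation, assuming the proxy signal is available as nested lists indexed by chain, draw, and age. The function names and the synthetic draws are hypothetical; this is not stratmc's implementation, and the sketch returns one row for each of the 1 to N-1 partial chain counts.

```python
import random
import statistics


def median_signal(chains):
    """Median proxy value at each age, pooling draws across the given chains."""
    draws = [draw for chain in chains for draw in chain]
    return [statistics.median(values) for values in zip(*draws)]


def median_residuals(chains):
    """Residuals of the k-chain median signal relative to the all-chain median."""
    full = median_signal(chains)
    rows = []
    for k in range(1, len(chains)):  # 1 to N-1 chains
        partial = median_signal(chains[:k])
        rows.append([p - f for p, f in zip(partial, full)])
    return rows


rng = random.Random(1)
n_chains, n_draws, n_ages = 4, 200, 5
# Hypothetical proxy signal draws with shape (chains, draws, ages)
chains = [
    [[rng.gauss(0.0, 1.0) for _ in range(n_ages)] for _ in range(n_draws)]
    for _ in range(n_chains)
]

residuals = median_residuals(chains)
print(len(residuals), "rows x", len(residuals[0]), "ages")
```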
- stratmc.inference.count_samples(full_trace, time_grid=None)
Helper function for proxy_data_density() in stratmc.plotting. Counts the number of observations within discrete time bins (based on the posterior sample ages).
- Parameters:
- full_trace: arviz.InferenceData
An arviz.InferenceData object containing the full set of prior and posterior samples from get_trace() in stratmc.inference.
- time_grid: np.array, optional
Time bin edges; if not provided, defaults to the ages array passed to get_trace().
- Returns:
- sample_counts: np.array
Number of observations in each time bin, summed over all posterior draws such that the average number of observations is sample_counts/n.
- time_grid: np.array
Time bin edges.
- n: int
Number of posterior draws in full_trace.
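The relationship between sample_counts and n can be shown with a small sketch that bins posterior sample ages draw by draw. The function count_in_bins and the toy ages are hypothetical, bins are treated as half-open intervals [lower, upper), and the example is not taken from stratmc.

```python
from bisect import bisect_right


def count_in_bins(age_draws, bin_edges):
    """Sum the number of sample ages per bin over all posterior draws."""
    counts = [0] * (len(bin_edges) - 1)
    for draw in age_draws:
        for age in draw:
            i = bisect_right(bin_edges, age) - 1  # index of the bin holding age
            if 0 <= i < len(counts):              # skip ages outside the grid
                counts[i] += 1
    return counts, len(age_draws)


# Hypothetical posterior sample ages (Ma): 2 draws x 3 samples
age_draws = [[541.2, 539.8, 538.1], [540.9, 540.1, 537.5]]
bin_edges = [537.0, 539.0, 541.0, 543.0]

counts, n = count_in_bins(age_draws, bin_edges)
mean_counts = [c / n for c in counts]  # average observations per bin per draw
print(counts, n, mean_counts)
```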
- stratmc.inference.extend_age_model(full_trace, sample_df, ages_df, new_proxies, new_proxy_df=None, **kwargs)
Extend age models to a different set of proxy observations using posterior sample ages from an existing inference. For instance, extend an age model built using C isotope data to new S isotope data collected from different stratigraphic horizons in the same sections. Note that the ages of stratigraphic horizons that were included in sample_df (but marked Exclude? = True) are already passively tracked within the model; this function is only required to estimate the age of observations that were not in sample_df when the inference was run. To estimate ages for new measurements of the same proxy, place the new data in a different column (e.g., ‘d13c_new’).
- Parameters:
- full_trace: arviz.InferenceData
An arviz.InferenceData object containing the full set of prior and posterior samples from get_trace() in stratmc.inference.
- sample_df: pandas.DataFrame
pandas.DataFrame with proxy data used during the inference step (as input to build_model() in stratmc.model).
- ages_df: pandas.DataFrame
pandas.DataFrame containing age constraints used during the inference step (as input to build_model() in stratmc.model).
- new_proxies: str or list(str)
New proxy(s) to construct age models for.
- new_proxy_df: pandas.DataFrame, optional
pandas.DataFrame containing new proxy observations. Optional; if not provided, uses sample_df (assumes that observations for the new proxy are in the same DataFrame as the original proxy observations).
- sections: list(str) or numpy.array(str), optional
List of sections included in the inference; only required if not all sections in sample_df were included.
- Returns:
- interpolated_df: pandas.DataFrame
pandas.DataFrame with interpolated age draws and sample age summary statistics (maximum likelihood estimate, median, and 68% and 95% confidence intervals) for each new proxy observation.
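The idea of carrying an existing age model over to new stratigraphic horizons can be sketched as per-draw interpolation of age against height. The sketch assumes linear interpolation between the heights that were in the original inference; the names interp and ages_at_new_heights and the toy data are hypothetical and do not reflect how stratmc handles sections or age constraints internally.

```python
from bisect import bisect_right
from statistics import median


def interp(x, xs, ys):
    """Linearly interpolate y at x; xs must be strictly increasing.

    Values outside the range of xs are clamped to the end points.
    """
    if x <= xs[0]:
        return ys[0]
    if x >= xs[-1]:
        return ys[-1]
    j = bisect_right(xs, x)  # xs[j - 1] <= x < xs[j]
    t = (x - xs[j - 1]) / (xs[j] - xs[j - 1])
    return ys[j - 1] + t * (ys[j] - ys[j - 1])


def ages_at_new_heights(heights, age_draws, new_heights):
    """Interpolate an age for every new height in every posterior draw."""
    return [[interp(h, heights, ages) for h in new_heights] for ages in age_draws]


# Hypothetical data: original sample heights (m) and two posterior age draws (Ma)
heights = [0.0, 10.0, 20.0]
age_draws = [[540.0, 538.0, 536.0], [541.0, 539.0, 535.0]]
new_heights = [5.0, 15.0]  # horizons sampled for the new proxy

new_age_draws = ages_at_new_heights(heights, age_draws, new_heights)
median_ages = [median(column) for column in zip(*new_age_draws)]
print(new_age_draws, median_ages)
```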
- stratmc.inference.find_gaps(full_trace, time_grid=None)
Helper function for proxy_data_gaps() in stratmc.plotting. Counts the number of draws from the posterior where there are no observations within discrete time bins (based on the posterior sample ages).
- Parameters:
- full_trace: arviz.InferenceData
An arviz.InferenceData object containing the full set of prior and posterior samples from get_trace() in stratmc.inference.
- time_grid: np.array, optional
Time bin edges; if not provided, defaults to the ages array passed to get_trace().
- Returns:
- age_gaps: np.array of int
Number of posterior draws where there are no observations; each entry corresponds to an age bin (corresponding to grid_centers and grid_widths).
- grid_centers: np.array
Time bin centers.
- grid_widths: np.array
Time bin widths.
- n: int
Number of posterior draws in full_trace.
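A short sketch of the gap count, assuming each posterior draw is a list of sample ages and bins are half-open intervals [lower, upper). The function count_empty_bins and the toy ages are hypothetical and are not part of stratmc.

```python
from bisect import bisect_right


def count_empty_bins(age_draws, bin_edges):
    """Count, for each bin, the posterior draws with no sample age in that bin."""
    n_bins = len(bin_edges) - 1
    gaps = [0] * n_bins
    for draw in age_draws:
        occupied = set()
        for age in draw:
            i = bisect_right(bin_edges, age) - 1
            if 0 <= i < n_bins:
                occupied.add(i)
        for i in range(n_bins):
            if i not in occupied:
                gaps[i] += 1
    centers = [(lo + hi) / 2 for lo, hi in zip(bin_edges[:-1], bin_edges[1:])]
    widths = [hi - lo for lo, hi in zip(bin_edges[:-1], bin_edges[1:])]
    return gaps, centers, widths, len(age_draws)


# Hypothetical posterior sample ages (Ma): 2 draws x 3 samples
age_draws = [[541.2, 539.8, 538.1], [540.9, 540.1, 537.5]]
bin_edges = [537.0, 539.0, 541.0, 543.0]

gaps, centers, widths, n = count_empty_bins(age_draws, bin_edges)
print(gaps, centers, widths, n)
```

Dividing each entry of gaps by n gives the fraction of posterior draws in which that time bin has no data.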
- stratmc.inference.get_trace(model, gp, ages, sample_df, ages_df, proxies=['d13c'], approximate=False, name='', chains=8, draws=1000, tune=2000, prior_draws=1000, target_accept=0.9, sampler='numpyro', nuts_kwargs=None, jitter=0.001, seed=None, save=True, postprocessing_backend=None, generate_custom_initvals=True, initval_seed=None, save_custom_initvals=True, initvals=None, sample_predictive=True, chain_method='parallel', **kwargs)
Sample the prior and posterior distributions for a pymc.model.core.Model returned by build_model() in stratmc.model. By default, uses pymc.sampling.jax.sample_numpyro_nuts() to sample the posterior; change sampler to ‘blackjax’ to use pymc.sampling.jax.sample_blackjax_nuts().
After the posterior has been sampled, runs check_inference() in stratmc.tests to check that superposition is never violated in the posterior. Any chains with superposition violations are removed from the trace with drop_chains() before it is returned (if save = True, both the original and ‘cleaned’ traces are saved to the traces subfolder), and a warning is issued. See check_inference() for details; superposition issues are rare, and typically are related to minor violations of detrital or intrusive age constraints.
Problems during sampling, including frequent divergences or minor violations of limiting age constraints, might be resolved by increasing the number of tune steps and/or increasing target_accept (which decreases the step size).
- Parameters:
- model: pymc.Model
pymc.model.core.Model object returned by build_model() in stratmc.model.
- gp: pymc.gp.Latent
Gaussian process prior (pymc.gp.Latent or pymc.gp.HSGP) returned by build_model() in stratmc.model.
- ages: numpy.array(float)
Array of ages at which to sample the posterior distribution of the proxy signal.
- sample_df: pandas.DataFrame
pandas.DataFrame containing proxy data for all sections.
- ages_df: pandas.DataFrame
pandas.DataFrame containing age constraints for all sections.
- sections: list(str) or numpy.array(str), optional
List of sections to include in the inference model. Defaults to all sections in sample_df.
- proxies: str or list(str), optional
List of proxies included in the model. Defaults to ‘d13c’.
- approximate: bool, optional
Set to True if the Hilbert space GP approximation (pymc.gp.HSGP) was used in build_model(); defaults to False.
- name: str, optional
Prefix for the saved NetCDF file with the inference results (suffix is timestamp for function call).
- chains: int, optional
Number of Markov chains to sample in parallel; defaults to 8.
- draws: int, optional
Number of samples per chain to draw from the posterior; defaults to 1000.
- tune: int, optional
Number of iterations to tune; defaults to 2000.
- prior_draws: int, optional
Number of samples to draw from the prior; defaults to 1000.
- target_accept: float, optional
Between 0 and 1 (exclusive). During tuning, the sampler adapts the proposals such that the average acceptance probability is equal to target_accept; higher values for target_accept typically lead to smaller step sizes. Defaults to 0.9.
- generate_custom_initvals: bool, optional
Whether to generate custom initial values for each chain with make_initial_values_per_chain() prior to sampling; defaults to True. Recommended to improve exploration of the posterior, and required to avoid superposition violations for models with intermediate detrital or intrusive age constraints.
- initval_seed: int, optional
Random seed for initial value generator.
- save_custom_initvals: bool, optional
Whether to save the initial value dictionary generated by make_initial_values_per_chain() (to ‘initial_values’ subfolder in the current directory); defaults to True.
- initvals: dict, optional
Dictionary of custom initial values generated by make_initial_values_per_chain(). If not provided, initial values will be generated with make_initial_values_per_chain() prior to sampling if generate_custom_initvals is True; if generate_custom_initvals is False, the default initial point of model will be used.
- sampler: str, optional
Which NUTS algorithm to use to sample the posterior (‘numpyro’ or ‘blackjax’); defaults to ‘numpyro’.
- nuts_kwargs: dict, optional
Dictionary of keyword arguments passed to the NumPyro NUTS sampler (see pymc.sampling.jax.sample_numpyro_nuts() and numpyro.infer.hmc.NUTS) or the blackjax NUTS sampler (see pymc.sampling.jax.sample_blackjax_nuts()).
- sample_predictive: bool, optional
Whether to draw prior and posterior predictive samples; defaults to True.
- jitter: float, optional
Value of jitter passed to pymc.gp.Latent.conditional(). Defaults to 0.001. Changing this value may help if a linear algebra error is encountered during posterior predictive sampling.
- postprocessing_backend: str, optional
Device to use for postprocessing (‘cpu’ or ‘gpu’). Defaults to None.
- chain_method: str, optional
Method for drawing samples (‘parallel’ or ‘vectorized’); defaults to ‘parallel’. The ‘vectorized’ method should be used for sampling with a GPU (requires installing JAX with GPU support).
- seed: int, optional
Random seed for sampler.
- save: bool, optional
Whether to save the trace (to ‘traces’ subfolder in the current directory); defaults to True.
- Returns:
- full_trace: arviz.InferenceData
An arviz.InferenceData object containing the full set of prior and posterior samples.
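The post-sampling cleanup described above (removing chains with superposition violations) can be illustrated with a simplified sketch. It assumes each chain is a list of draws and each draw lists sample ages in order of increasing height, so a violation is any sample that is older than the sample below it. The functions violates_superposition and drop_bad_chains are hypothetical stand-ins; they are not check_inference() or drop_chains() from stratmc.

```python
def violates_superposition(ages_by_height):
    """True if any sample is older than the sample directly below it."""
    return any(
        upper > lower
        for lower, upper in zip(ages_by_height, ages_by_height[1:])
    )


def drop_bad_chains(chains):
    """Keep only chains in which no draw violates superposition."""
    kept = {}
    for chain_id, draws in chains.items():
        if any(violates_superposition(draw) for draw in draws):
            print(f"warning: dropping chain {chain_id} (superposition violated)")
            continue
        kept[chain_id] = draws
    return kept


# Hypothetical posterior sample ages (Ma), listed from lowest to highest sample
chains = {
    0: [[540.0, 538.0, 536.0], [541.0, 539.0, 535.0]],
    1: [[540.0, 541.0, 536.0], [541.0, 539.0, 535.0]],  # first draw violates
    2: [[539.5, 539.0, 538.0]],
}

kept = drop_bad_chains(chains)
print(sorted(kept))
```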
- stratmc.inference.get_valid_initial_sample_ages_from_dict(initval_dict, detrital_age_dist_names, intrusive_age_dist_names, maximum_age_dist_name, minimum_age_dist_name, sample_age_dist_name, sample_heights, detrital_heights, intrusive_heights, interval, sf1_name, sf2_name, shared_radiometric_age_dist)
Modify dictionary of initial values such that all detrital and intrusive age constraints are respected. Replicates logic used to set model initial point in get_valid_initial_ages() in stratmc.model.
- Parameters:
- initval_dict: dict
Dictionary of proposed initial values.
- detrital_age_dist_names: list(str)
List of names for detrital age constraint distributions in model.
- intrusive_age_dist_names: list(str)
List of names for intrusive age constraint distributions in model.
- maximum_age_dist_name: str
Name of distribution for underlying maximum age constraint in model.
- minimum_age_dist_name: str
Name of distribution for overlying minimum age constraint in model.
- sample_age_dist_name: str
Name of sample age distribution (unsorted and unscaled) in the pymc.Model object.
- sample_heights: np.array
Array of heights for samples in the current interval.
- detrital_heights: list(float)
Heights of detrital age constraints.
- intrusive_heights: list(float)
Heights of intrusive age constraints.
- interval: int
Current interval number.
- sf1_name: str
Name of the distribution associated with scaling factor 1 in model.
- sf2_name: str
Name of the distribution associated with scaling factor 2 in model.
- shared_radiometric_age_dist: bool, optional
Whether the radiometric age distributions are part of a single object (versus initiated as separate distributions). Defaults to True.
- Returns:
- initval_dict: dict
Dictionary of valid initial values; keys are model variables.
- stratmc.inference.get_valid_initvals_per_chain(initval_dict, sample_df, ages_df, sections)
Helper function for generating valid initial values for a single MCMC chain. Called in make_initial_values_per_chain().
- Parameters:
- initval_dict: dict
Dictionary of proposed initial values.
- sample_df: pandas.DataFrame
pandas.DataFrame with proxy data used during the inference step (as input to build_model() in stratmc.model).
- ages_df: pandas.DataFrame
pandas.DataFrame containing age constraints used during the inference step (as input to build_model() in stratmc.model).
- sections: list(str) or numpy.array(str)
List of sections included in the inference.
- Returns:
- initval_dict: dict
Dictionary with initial values modified to respect superposition; keys are variable names.
- stratmc.inference.interpolate_proxy(interp_df, proxy, ages)
Use interpolated sample ages from extend_age_model() to calculate proxy values at a given set of ages (e.g., to plot 68% and 95% confidence intervals over time for a new proxy using interpolated_proxy_inference() in stratmc.plotting).
- Parameters:
- interp_df: pandas.DataFrame
pandas.DataFrame with proxy data and interpolated ages from extend_age_model().
- proxy: str
Tracer to interpolate.
- ages: list(float) or numpy.array(float)
Target ages at which to interpolate proxy values.
- Returns:
- interpolated_proxy_df: pandas.DataFrame
pandas.DataFrame with interpolated proxy values and summary statistics (maximum likelihood estimate, median, and 68% and 95% confidence intervals) at each age in ages.
- stratmc.inference.make_initial_values_per_chain(sample_df, ages_df, model, n_chains, proxies=['d13c'], seed=None, **kwargs)
Generate valid (transformed) initial values for MCMC chains. Output can be passed to the initvals argument of get_trace().
- Parameters:
- sample_df: pandas.DataFrame
pandas.DataFrame with proxy data used during the inference step (as input to build_model() in stratmc.model).
- ages_df: pandas.DataFrame
pandas.DataFrame containing age constraints used during the inference step (as input to build_model() in stratmc.model).
- model: pymc.Model
pymc.model.core.Model object returned by build_model() in stratmc.model.
- n_chains: int
Number of MCMC chains to generate initial values for.
- proxies: list(str)
List of proxies included in the inference.
- sections: list(str) or numpy.array(str), optional
List of sections included in the inference; only required if not all sections in sample_df were included.
- seed: int, optional
Random seed used to sample the prior.
- Returns:
- initval_dicts: list(dict)
List of dictionaries (one per chain) with valid initial values for each variable in model; keys are variable names.
- stratmc.inference.map_ages_to_section(full_trace, sample_df, ages_df, include_radiometric_ages=False, **kwargs)
Helper function for section_proxy_signal() in stratmc.plotting. Maps the ages array passed to get_trace() to height in each section using the most likely posterior age models.
- Parameters:
- full_trace: arviz.InferenceData
An arviz.InferenceData object containing the full set of prior and posterior samples from get_trace() in stratmc.inference.
- sample_df: pandas.DataFrame
pandas.DataFrame containing proxy data for all sections.
- ages_df: pandas.DataFrame
pandas.DataFrame containing age constraints for all sections.
- include_radiometric_ages: bool, optional
Whether to consider radiometric ages when calculating the most likely posterior age model for each section; defaults to False.
- sections: list(str) or numpy.array(str), optional
List of sections included in the inference; only required if not all sections in sample_df were included.
- Returns:
- age_model_df: pandas.DataFrame
pandas.DataFrame with interpolated heights at each age in the ages vector that was passed to get_trace().
- stratmc.inference.ordered_transform_in_prior(initval_dict, ordered_dist_name)
Helper function to enforce the ordered transform in a dictionary of prior draws.
- Parameters:
- initval_dict: dict
Dictionary of proposed initial values.
- ordered_dist_name:
Name of the distribution with an ordered transform in model.
- Returns:
- initval_dict: dict
Dictionary with sorted (increasing) initial values for ordered_dist_name.
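As a rough illustration of what enforcing an ordered transform on prior draws means, the sketch below sorts the values stored under one key of an initial value dictionary into increasing order. The function sort_ordered_variable and the variable names are hypothetical; the actual helper may handle array types and transformed variables differently.

```python
def sort_ordered_variable(initval_dict, ordered_dist_name):
    """Return a copy of initval_dict with one variable sorted in increasing order."""
    updated = dict(initval_dict)  # shallow copy; the input dict is not modified
    updated[ordered_dist_name] = sorted(initval_dict[ordered_dist_name])
    return updated


# Hypothetical prior draw: an ordered variable whose values came out unsorted
initval_dict = {"ordered_positions": [0.7, 0.2, 0.5], "gp_ls": 12.0}

updated = sort_ordered_variable(initval_dict, "ordered_positions")
print(updated)
```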
- stratmc.inference.superposition_depositional_and_limiting_ages_from_dict(initval_dict, depositional_age_name, detrital_age_dist_names, intrusive_age_dist_names, depositional_age_idx=None)
Modify dictionary of initial values such that superposition between limiting and depositional age constraints is respected. Replicates logic used to set model initial point in superposition_depositional_and_limiting_ages() in stratmc.model.
- Parameters:
- initval_dict: dict
Dictionary of proposed initial values.
- depositional_age_name: str
Name of distribution for target depositional age constraint in model.
- detrital_age_dist_names: list(str)
Names of underlying detrital age constraint distributions in model.
- intrusive_age_dist_names: list(str)
Names of overlying intrusive age constraint distributions in model.
- depositional_age_idx: int, optional
Position of depositional_age in model variable depositional_age_name. Only required if depositional_age is one of multiple ages modeled using a single multidimensional distribution.
- Returns:
- initval_dict: dict
Dictionary with initial values modified to respect superposition; keys are variable names.
- stratmc.inference.superposition_from_dict(initval_dict, age_dist_names, section_age_df)
Modify dictionary of initial values such that superposition between depositional age constraints is respected. Replicates logic used to set model initial point in superposition() in stratmc.model.
- Parameters:
- initval_dict: dict
Dictionary of proposed initial values.
- age_dist_names:
Names of radiometric age distributions in model (must be in stratigraphic order - lowest to highest).
- section_age_df: pandas.DataFrame
pandas.DataFrame containing age constraints for current section.
- Returns:
- initval_dict: dict
Dictionary with initial values modified to respect superposition; keys are variable names.
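One simple way to make proposed initial ages respect superposition is to walk up the section and clamp any age that is older than the age constraint below it. The sketch below uses that clamping rule purely as an illustration; it is an assumed strategy and not the logic of superposition() in stratmc.model. The function enforce_superposition and the variable names are hypothetical.

```python
def enforce_superposition(initval_dict, age_dist_names):
    """Clamp initial ages so they never increase upsection.

    age_dist_names must be in stratigraphic order (lowest to highest).
    """
    updated = dict(initval_dict)  # shallow copy; the input dict is not modified
    previous = None
    for name in age_dist_names:
        age = updated[name]
        if previous is not None and age > previous:
            age = previous  # cannot be older than the constraint below
        updated[name] = age
        previous = age
    return updated


# Hypothetical proposed initial ages (Ma); "age_1" is older than "age_0" below it
initval_dict = {"age_0": 540.0, "age_1": 541.5, "age_2": 536.0}

result = enforce_superposition(initval_dict, ["age_0", "age_1", "age_2"])
print(result)
```

Clamping keeps each value as close as possible to its proposed draw; a real implementation might instead redraw the offending ages.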