from typing import Optional, List, Union, Tuple, Dict, Literal
import matplotlib
import numpy as np
import matplotlib.pyplot as plt
import re
import inspect
import arviz as az
from arviz import InferenceData
from lmfit.model import ModelResult
try:
import corner
except ImportError:
corner = None
from .oc import Linear, Quadratic, Keplerian, Sinusoidal, Parameter, OC, ModelComponent
[docs]
class Plot:
[docs]
@staticmethod
def plot_data(
data: OC,
*,
ax: Optional[plt.Axes] = None,
x_col: str = "cycle",
y_col: str = "oc",
plot_kwargs: Optional[dict] = None
) -> plt.Axes:
"""
Plot the raw O−C data with optional error bars and labeling.
Parameters
----------
data : OC
The observational O−C dataset. Must have a `data` attribute (pandas DataFrame)
containing at least columns for `x_col` and `y_col`. Optionally, it can include
'minimum_time_error' for y-error bars and 'labels' for grouping data by category.
ax : matplotlib.axes.Axes, optional
Axes object on which to plot. If None, a new figure and axes are created.
x_col : str, default "cycle"
Name of the column to use for the x-axis.
y_col : str, default "oc"
Name of the column to use for the y-axis.
plot_kwargs : dict, optional
Additional keyword arguments passed to `matplotlib.pyplot.errorbar`.
Returns
-------
matplotlib.axes.Axes
The axes containing the plotted data.
Notes
-----
- If 'labels' exist in the DataFrame, points are color-coded per unique label.
- Unlabeled points are plotted in gray.
- If 'minimum_time_error' exists, it is used as y-error bars.
- Axes are automatically labeled and a grid is added.
"""
if ax is None:
fig, ax = plt.subplots(figsize=(10.0, 5.4))
# ax = ax # Not needed anymore
plot_kwargs = dict(fmt="o", markersize=4.5, color="tab:blue", alpha=0.8, capsize=2, label="Data", zorder=1) | (
plot_kwargs or {})
x_values = np.asarray(data.data[x_col].to_numpy(), dtype=float)
y_values = np.asarray(data.data[y_col].to_numpy(), dtype=float)
if "labels" in data.data.columns:
labels_data = data.data["labels"]
unique_labels = sorted(list(set(labels_data.dropna().unique())))
if len(unique_labels) > 0:
colormap = plt.get_cmap("tab10")
for index, label in enumerate(unique_labels):
mask = (labels_data == label).to_numpy(dtype=bool)
if not np.any(mask):
continue
color = colormap(index % 10)
local_kwargs = plot_kwargs.copy()
local_kwargs["color"] = color
local_kwargs["label"] = f"Data ({label})"
y_error = None
if "minimum_time_error" in data.data.columns:
y_error = np.asarray(data.data["minimum_time_error"].to_numpy(), dtype=float)[mask]
ax.errorbar(x_values[mask], y_values[mask], yerr=y_error, **local_kwargs)
# Unlabeled data
mask_unlabeled = labels_data.isna().to_numpy(dtype=bool)
if np.any(mask_unlabeled):
local_kwargs = plot_kwargs.copy()
local_kwargs["color"] = "gray"
local_kwargs["label"] = "Data (unlabeled)"
y_error = None
if "minimum_time_error" in data.data.columns:
y_error = np.asarray(data.data["minimum_time_error"].to_numpy(), dtype=float)[mask_unlabeled]
ax.errorbar(x_values[mask_unlabeled], y_values[mask_unlabeled], yerr=y_error, **local_kwargs)
ax.legend()
else:
ax.errorbar(x_values, y_values, yerr=(np.asarray(data.data["minimum_time_error"].to_numpy(),
dtype=float) if "minimum_time_error" in data.data.columns else None),
**plot_kwargs)
else:
y_error = None
if "minimum_time_error" in data.data.columns:
y_error = np.asarray(data.data["minimum_time_error"].to_numpy(), dtype=float)
ax.errorbar(x_values, y_values, yerr=y_error, **plot_kwargs)
ax.set_ylabel("O−C")
ax.set_xlabel(x_col.capitalize())
ax.grid(True, alpha=0.25)
return ax
[docs]
@classmethod
def plot_model_pymc(
cls,
inference_data: az.InferenceData,
data: "OCPyMC", # noqa: F821
*,
ax: Optional[plt.Axes] = None,
x_col: str = "cycle",
n_points: int = 800,
sum_kwargs: Optional[dict] = None,
comp_kwargs: Optional[dict] = None,
plot_kwargs: Optional[dict] = None,
plot_band: bool = True,
extension_factor: float = 0.1,
model_components: Optional[list] = None
) -> plt.Axes:
"""
Plot a model fit to O−C data using a PyMC inference result, with optional uncertainty bands
and component decomposition.
Parameters
----------
inference_data : arviz.InferenceData
Posterior samples from a PyMC model. Expected to contain parameter variables (2D arrays
with shape [chain, draw]), and optionally 'y_model' or 'y_model_dense' for precomputed model fits.
data : OCPyMC
Observational O−C dataset to plot against. Must have a `data` attribute (pandas DataFrame)
containing at least the `x_col` column.
ax : matplotlib.axes.Axes, optional
Axes object on which to plot. If None, uses current axes or creates a new figure.
x_col : str, default "cycle"
Column in `data` to use as the x-axis.
n_points : int, default 800
Number of points to use when plotting continuous model curves.
sum_kwargs : dict, optional
Additional keyword arguments for the summed model line.
comp_kwargs : dict, optional
Additional keyword arguments for individual model component lines.
plot_kwargs : dict, optional
Additional keyword arguments for the data points plot (markers, color, etc.).
plot_band : bool, default True
Whether to display a 1σ uncertainty band from posterior samples.
extension_factor : float, default 0.1
Fractional extension beyond the data range for plotting the fit curve.
model_components : list, optional
List of model component objects (Linear, Quadratic, Sinusoidal, Keplerian, etc.)
to reconstruct and plot individual contributions.
Returns
-------
matplotlib.axes.Axes
The axes containing the plotted data, model fit, components, and optional uncertainty band.
Notes
-----
- If 'y_model_dense' and 'dense_x' exist in `inference_data`, they are used for a smooth fit.
- Otherwise, the method reconstructs the model from posterior medians of scalar parameters.
- If `model_components` are provided, each component is plotted individually.
- Uncertainty bands are computed from a subset of posterior samples (default 200 draws).
- Automatically handles multiple components, labeling, and plotting the sum of components.
"""
if ax is None:
ax = plt.gca()
def split_name(variable_name: str):
underscore_index = variable_name.rfind("_")
return (variable_name[:underscore_index],
variable_name[underscore_index + 1:]) if underscore_index != -1 else (None, None)
def parse_prefix(prefix_str: str):
match = re.match(r"^([A-Za-z_]+?)(\d+)?$", prefix_str)
if not match:
return (prefix_str, 0)
base = match.group(1)
index = int(match.group(2)) if match.group(2) is not None else 0
return (base, index)
scalars = [variable_name for variable_name, data_array in inference_data.posterior.data_vars.items() if
getattr(data_array, "ndim", 0) == 2 and variable_name not in {"y_model", "y_model_dense", "y_obs"}]
if not scalars:
return ax
medians_dict: dict[str, float] = {}
for variable_name in scalars:
data_array = inference_data.posterior[variable_name]
value = data_array.median(dim=("chain", "draw")).item()
medians_dict[variable_name] = float(value)
groups: dict[str, dict[str, float]] = {}
for variable_name, value in medians_dict.items():
prefix, param_name = split_name(variable_name)
if prefix is None:
continue
groups.setdefault(prefix, {})[param_name] = value
order = sorted(groups.keys(), key=lambda p: parse_prefix(p))
components = []
valid_order = []
for prefix in order:
base, _ = parse_prefix(prefix)
fields = groups[prefix]
component = None
if base in ("linear", "lin"):
component = Linear(
a=Parameter(value=fields.get("a", 0.0), fixed=True),
b=Parameter(value=fields.get("b", 0.0), fixed=True)
)
elif base in ("quadratic", "quad"):
component = Quadratic(
q=Parameter(value=fields.get("q", 0.0), fixed=True)
)
elif base in ("keplerian", "kep", "lite", "LiTE"):
t0_value = fields.get("T0", fields.get("T", 0.0))
component = Keplerian(
amp=Parameter(value=fields.get("amp", 0.0), fixed=True),
e=Parameter(value=fields.get("e", 0.0), fixed=True),
omega=Parameter(value=fields.get("omega", 0.0), fixed=True),
P=Parameter(value=fields.get("P", 1.0), fixed=True),
T0=Parameter(value=t0_value, fixed=True),
name=prefix,
)
elif base in ("sinusoidal", "sin"):
component = Sinusoidal(
amp=Parameter(value=fields.get("amp", 0.0), fixed=True),
P=Parameter(value=fields.get("P", 1.0), fixed=True)
)
if component is not None:
components.append(component)
valid_order.append(prefix)
order = valid_order
x = np.asarray(data.data[x_col].to_numpy(), dtype=float)
xmin, xmax = (float(np.min(x)), float(np.max(x))) if x.size else (0.0, 1.0)
margin = (xmax - xmin) * extension_factor
xline = np.linspace(xmin - margin, xmax + margin, n_points)
band = None
# 1. Best Fallback: Use y_model_dense if it exists in inference_data
if "y_model_dense" in inference_data.posterior and "dense_x" in inference_data.posterior:
y_dense_post = inference_data.posterior["y_model_dense"]
x_dense_vals = inference_data.posterior["dense_x"].values[0, 0] # Constant across chains/draws
y_fit = y_dense_post.median(dim=("chain", "draw")).values
if not components:
if ax is None:
fig, ax = plt.subplots(figsize=(10.0, 5.4))
fit_color = (plot_kwargs or {}).get("color", "red")
ax.plot(x_dense_vals, y_fit, color=fit_color, lw=2.6, label="Fit (Median)", zorder=5)
if plot_band:
low = y_dense_post.quantile(0.16, dim=("chain", "draw")).values
high = y_dense_post.quantile(0.84, dim=("chain", "draw")).values
ax.fill_between(x_dense_vals, low, high, color=fit_color, alpha=0.3, linewidth=0,
label=r"Uncertainty (1$\sigma$)", zorder=4)
return ax
# 2. Secondary Fallback: Interpolate y_model at observation points
if "y_model" in inference_data.posterior and len(x) == inference_data.posterior["y_model"].shape[-1]:
y_model_post = inference_data.posterior["y_model"]
y_total_obs = y_model_post.median(dim=("chain", "draw")).values
# If we reconstructed no components, we use y_model points as the fit line
if not components:
if ax is None:
fig, ax = plt.subplots(figsize=(10.0, 5.4))
# Handle duplicates and sort using pandas for robustness
import pandas as pd
df_temp = pd.DataFrame({'x': x, 'y': y_total_obs})
# Check for uncertainty band data
if plot_band:
df_temp['low'] = y_model_post.quantile(0.16, dim=("chain", "draw")).values
df_temp['high'] = y_model_post.quantile(0.84, dim=("chain", "draw")).values
# Group by x and take mean to handle multiple observations at the same cycle
df_average = df_temp.groupby('x').mean().sort_index()
xs_clean = df_average.index.values
ys_clean = df_average['y'].values
fit_color = (plot_kwargs or {}).get("color", "red")
x_range = xs_clean.max() - xs_clean.min()
ext_margin = x_range * extension_factor
# Check if expensive model components are available for proper extension
_model_comps = model_components or getattr(inference_data, 'attrs', {}).get('_model_components', None)
_model_prefs = getattr(inference_data, 'attrs', {}).get('_model_prefixes', None)
has_expensive_models = (
_model_comps is not None
and any(getattr(c, '_expensive', False) for c in _model_comps)
)
if has_expensive_models and ext_margin > 0:
# Use model_func with posterior medians to evaluate full extended range
try:
x_full = np.linspace(xs_clean.min() - ext_margin, xs_clean.max() + ext_margin, n_points)
median_params = {}
for var_name in inference_data.posterior.data_vars:
vals = inference_data.posterior[var_name].values
if vals.ndim == 2: # (chain, draw) -> scalar param
median_params[var_name] = float(np.median(vals))
# Build prefixes if not stored: infer from posterior var names
if _model_prefs is None:
_model_prefs = []
base_names = [getattr(c, 'name', c.__class__.__name__.lower()) for c in _model_comps]
counts = {n: base_names.count(n) for n in base_names}
seen = {n: 0 for n in base_names}
for n in base_names:
seen[n] += 1
if counts[n] > 1:
_model_prefs.append(f"{n}{seen[n]}_")
else:
_model_prefs.append(f"{n}_")
y_full = np.zeros(len(x_full), dtype=float)
for comp, pref in zip(_model_comps, _model_prefs):
comp_params = {}
for pname in getattr(comp, 'params', {}):
full_name = pref + pname
if full_name in median_params:
comp_params[pname] = median_params[full_name]
else:
comp_params[pname] = float(comp.params[pname].value)
y_full = y_full + np.asarray(comp.model_func(x_full, **comp_params), dtype=float)
ax.plot(x_full, y_full, color=fit_color, lw=2.6, label="Fit (Median)", zorder=5)
except Exception:
# Fall back to spline interpolation with flat extension
from scipy.interpolate import make_interp_spline
x_inner = np.linspace(xs_clean.min(), xs_clean.max(), 1000)
spl = make_interp_spline(xs_clean, ys_clean, k=3)
y_inner = spl(x_inner)
x_left = np.linspace(xs_clean.min() - ext_margin, xs_clean.min(), 50, endpoint=False)
x_right = np.linspace(xs_clean.max(), xs_clean.max() + ext_margin, 50)[1:]
x_full = np.concatenate([x_left, x_inner, x_right])
y_full = np.concatenate(
[np.full_like(x_left, y_inner[0]), y_inner, np.full_like(x_right, y_inner[-1])])
ax.plot(x_full, y_full, color=fit_color, lw=2.6, label="Fit (Median)", zorder=5)
else:
try:
from scipy.interpolate import make_interp_spline
x_inner = np.linspace(xs_clean.min(), xs_clean.max(), 1000)
spl = make_interp_spline(xs_clean, ys_clean, k=3)
y_inner = spl(x_inner)
x_left = np.linspace(xs_clean.min() - ext_margin, xs_clean.min(), 50, endpoint=False)
x_right = np.linspace(xs_clean.max(), xs_clean.max() + ext_margin, 50)[1:]
x_full = np.concatenate([x_left, x_inner, x_right])
y_full = np.concatenate(
[np.full_like(x_left, y_inner[0]), y_inner, np.full_like(x_right, y_inner[-1])])
ax.plot(x_full, y_full, color=fit_color, lw=2.6, label="Fit (Median)", zorder=5)
if plot_band:
spl_low = make_interp_spline(xs_clean, df_average['low'].values, k=3)
spl_high = make_interp_spline(xs_clean, df_average['high'].values, k=3)
low_inner = spl_low(x_inner)
high_inner = spl_high(x_inner)
low_full = np.concatenate(
[np.full_like(x_left, low_inner[0]), low_inner, np.full_like(x_right, low_inner[-1])])
high_full = np.concatenate([np.full_like(x_left, high_inner[0]), high_inner,
np.full_like(x_right, high_inner[-1])])
ax.fill_between(x_full, low_full, high_full, color=fit_color, alpha=0.3, linewidth=0,
label=r"Uncertainty (1$\sigma$)", zorder=4)
except Exception:
ax.plot(xs_clean, ys_clean, color=fit_color, lw=2.6, label="Fit (Median)", zorder=5)
if plot_band:
ax.fill_between(xs_clean, df_average['low'].values, df_average['high'].values,
color=fit_color, alpha=0.3, linewidth=0, label=r"Uncertainty (1$\sigma$)",
zorder=4)
return ax
if plot_band and components:
subset = az.extract(inference_data, num_samples=200)
y_samples = []
n_draws = subset.sample.size
for sample_index in range(n_draws):
y_total = np.zeros_like(xline)
for index, prefix in enumerate(order):
component = components[index]
kwargs = {}
for param_name in groups[prefix].keys():
variable_name = f"{prefix}_{param_name}"
if variable_name in subset:
value = subset[variable_name].values[sample_index]
kwargs[param_name] = float(value)
y_total += component.model_func(xline, **kwargs)
y_samples.append(y_total)
y_samples = np.array(y_samples)
low = np.percentile(y_samples, 16, axis=0)
high = np.percentile(y_samples, 84, axis=0)
band = (xline, low, high)
return cls.plot_model_components(
components,
xline,
ax=ax,
sum_kwargs=sum_kwargs,
comp_kwargs=comp_kwargs,
uncertainty_band=band
)
[docs]
@classmethod
def plot_model_lmfit(
cls,
result,
data: "OCLMFit", # noqa: F821
*,
ax: Optional[plt.Axes] = None,
x_col: str = "cycle",
n_points: int = 500,
plot_kwargs: Optional[dict] = None,
extension_factor: float = 0.1
) -> plt.Axes:
"""
Plot a model fit to O−C data using an lmfit.ModelResult object, with optional uncertainty bands.
Parameters
----------
result : lmfit.model.ModelResult
The fitted model result returned by `lmfit.Model.fit`.
Expected to provide `eval(x=...)` and optionally `eval_uncertainty(x=..., sigma=1)`.
data : OCLMFit
Observational O−C dataset to plot against. Must have a `data` attribute (pandas DataFrame)
containing at least the `x_col` column.
ax : matplotlib.axes.Axes, optional
Axes object on which to plot. If None, creates a new figure and axes.
x_col : str, default "cycle"
Column in `data` to use as the x-axis.
n_points : int, default 500
Number of points to evaluate for a smooth model curve.
plot_kwargs : dict, optional
Additional keyword arguments for the fit line (color, linewidth, label, etc.).
extension_factor : float, default 0.1
Fractional extension beyond the data range for plotting the fit curve.
Returns
-------
matplotlib.axes.Axes
The axes containing the plotted data, model fit, and optional uncertainty band.
Notes
-----
- The method evaluates the fitted model on a dense set of points across the data range,
optionally extended by `extension_factor`.
- If `result.eval_uncertainty` is available, a 1σ uncertainty band is plotted around the fit.
- Data points are plotted separately using the `plot_data` method (if called externally).
"""
if ax is None:
fig, ax = plt.subplots(figsize=(10.0, 5.4))
x = np.asarray(data.data[x_col].to_numpy(), dtype=float)
xmin, xmax = (float(np.min(x)), float(np.max(x))) if x.size else (0.0, 1.0)
margin = (xmax - xmin) * extension_factor
x_dense = np.linspace(xmin - margin, xmax + margin, n_points)
y_fit_dense = result.eval(x=x_dense)
plot_kwargs = dict(color="red", label="Fit", zorder=5) | (plot_kwargs or {})
try:
dely = result.eval_uncertainty(x=x_dense, sigma=1)
ax.fill_between(x_dense, y_fit_dense - dely, y_fit_dense + dely, color="red", alpha=0.3, linewidth=0,
label=r"Uncertainty (1$\sigma$)", zorder=4)
except Exception:
pass
ax.plot(x_dense, y_fit_dense, **plot_kwargs)
return ax
[docs]
@classmethod
def plot_model_components(
cls,
model_components: list,
xline: np.ndarray,
*,
ax: Optional[plt.Axes] = None,
sum_kwargs: Optional[dict] = None,
comp_kwargs: Optional[dict] = None,
uncertainty_band: Optional[tuple] = None
) -> plt.Axes:
"""
Plot individual model components and their sum over a specified x-range, with optional uncertainty bands.
Parameters
----------
model_components : list
List of model component objects. Each component must have a `model_func(x, **params)` method
and a `params` attribute (dict of Parameter objects or numeric values).
xline : np.ndarray
Array of x-values to evaluate the component models.
ax : matplotlib.axes.Axes, optional
Axes object on which to plot. If None, a new figure and axes are created.
sum_kwargs : dict, optional
Keyword arguments for the summed model curve. Default color is red, linewidth 2.6, alpha 0.95.
comp_kwargs : dict, optional
Keyword arguments for individual component curves. Default linewidth 1.5, alpha 0.9, linestyle '--'.
uncertainty_band : tuple, optional
Tuple `(x_band, y_low, y_high)` representing an uncertainty envelope around the sum of components.
If provided, plotted as a filled area behind the curves.
Returns
-------
matplotlib.axes.Axes
The axes containing the component curves, sum curve, and optional uncertainty band.
Notes
-----
- Each component is evaluated using its `model_func` and current parameter values.
- The sum curve is drawn on top of the components, optionally with an uncertainty band.
- Components without parameters (or with missing parameters) will raise a KeyError.
- Useful for visualizing contributions of multiple additive model components
in O−C analysis or similar time-series modeling contexts.
"""
def _comp_name(comp):
return getattr(comp, "name", comp.__class__.__name__.lower())
def _sig_param_names(comp):
sig = inspect.signature(comp.model_func)
params = list(sig.parameters.values())[1:] # skip 'x'
names = [p.name for p in params if p.kind not in (p.VAR_KEYWORD, p.VAR_POSITIONAL)]
# If the function uses **kwargs, we assume it takes everything in comp.params
if any(p.kind == p.VAR_KEYWORD for p in params):
names = list(getattr(comp, "params", {}).keys())
return names
def _param_value(v):
return getattr(v, "value", v)
def _eval_component(comp, xvals):
pnames = _sig_param_names(comp)
params_dict = getattr(comp, "params", {}) or {}
kwargs = {}
for pname in pnames:
if pname not in params_dict:
raise KeyError(f"Component '{_comp_name(comp)}' missing parameter '{pname}'")
kwargs[pname] = float(_param_value(params_dict[pname]))
return comp.model_func(xvals, **kwargs)
if ax is None:
fig, ax = plt.subplots(figsize=(10.0, 5.4))
sum_color = (sum_kwargs or {}).get("color", "red")
sum_kwargs = dict(lw=2.6, alpha=0.95, label="Sum of selected components", color=sum_color, zorder=5) | (
sum_kwargs or {})
comp_kwargs = dict(lw=1.5, alpha=0.9, linestyle="--") | (comp_kwargs or {})
comp_curves = []
for comp in model_components:
y_comp = _eval_component(comp, xline)
comp_curves.append((comp, y_comp))
y_sum = np.sum([yc for _, yc in comp_curves], axis=0) if comp_curves else np.zeros_like(xline)
if uncertainty_band is not None:
bx, blow, bhigh = uncertainty_band
ax.fill_between(bx, blow, bhigh, color=sum_color, alpha=0.3, linewidth=0, label=r"Uncertainty (1$\sigma$)",
zorder=4)
ax.plot(xline, y_sum, **sum_kwargs)
for component, y_comp in comp_curves:
ax.plot(xline, y_comp, label=_comp_name(component), **comp_kwargs)
return ax
[docs]
@classmethod
def plot(
cls,
data: "OC",
model: Union[InferenceData, ModelResult, List[ModelComponent]] = None,
*,
ax: Optional[plt.Axes] = None,
res_ax: Optional[plt.Axes] = None,
res: bool = True,
title: Optional[str] = None,
x_col: str = "cycle",
y_col: str = "oc",
fig_size: tuple = (10, 7),
plot_kwargs: Optional[dict] = None,
extension_factor: float = 0.1,
model_components: Optional[list] = None
) -> Union[plt.Axes, Tuple[plt.Axes, plt.Axes]]:
"""
Plot data with optional model fit and residuals.
This is a high-level plotting function that can:
- Display raw O−C data points (with optional labels and error bars),
- Overlay model fits from PyMC (`InferenceData`), lmfit (`ModelResult`), or a list of model components,
- Display residuals below the main plot if requested.
Parameters
----------
data : OC
The data object containing O−C measurements. Must have `data` attribute (pandas DataFrame)
with at least columns specified by `x_col` and `y_col`.
model : Union[InferenceData, ModelResult, List[ModelComponent]], optional
Model to overlay on the data:
- PyMC model (`InferenceData` from arviz) with posterior samples,
- lmfit result (`ModelResult`) with `.eval()` method,
- List of component objects with `.model_func` and `.params`.
If None, only the raw data is plotted.
ax : matplotlib.axes.Axes, optional
Axes for the main data/fit plot. If None, a new figure is created.
res_ax : matplotlib.axes.Axes, optional
Axes for residuals plot. If None and `res=True`, a new subplot is created.
res : bool, default True
Whether to plot residuals beneath the main plot.
title : str, optional
Title for the main plot.
x_col : str, default "cycle"
Column in `data.data` used for x-axis.
y_col : str, default "oc"
Column in `data.data` used for y-axis.
fig_size : tuple, default (10, 7)
Figure size in inches.
plot_kwargs : dict, optional
Keyword arguments passed to the main data plot (color, markers, alpha, etc.).
extension_factor : float, default 0.1
Fractional extension beyond the data range for plotting model fits.
model_components : list, optional
If provided, used for plotting PyMC components in `plot_model_pymc`.
Returns
-------
matplotlib.axes.Axes or tuple(matplotlib.axes.Axes, matplotlib.axes.Axes)
If `res=True`, returns a tuple `(ax, res_ax)` for main and residual plots.
Otherwise, returns only `ax`.
Notes
-----
- Automatically handles labeled and unlabeled data points.
- Residuals are computed as `y - y_model` if model is provided.
- Supports both PyMC posterior models (median and uncertainty bands) and lmfit fits.
- Useful for O−C analysis in astronomy or any time-series with additive models.
"""
x = np.asarray(data.data[x_col].to_numpy(), dtype=float)
y = np.asarray(data.data[y_col].to_numpy(), dtype=float)
mask = np.isfinite(x) & np.isfinite(y)
x_clean = x[mask]
y_clean = y[mask]
yerr = None
if "minimum_time_error" in data.data.columns:
yerr = np.asarray(data.data["minimum_time_error"].to_numpy(), dtype=float)
yerr_clean = yerr[mask] if yerr is not None else None
labels = data.data.get("labels", None)
labels_clean = labels[mask] if labels is not None else None
# ax_main = ax
# res_ax_internal = res_ax
if ax is None:
if model is not None and res:
fig, (ax, res_ax) = plt.subplots(2, 1, figsize=fig_size, sharex=True,
gridspec_kw={'height_ratios': [3, 1], 'hspace': 0.04})
else:
fig, ax = plt.subplots(figsize=(fig_size[0], fig_size[1] * 0.75))
res_ax = None
else:
if res and res_ax is None:
res = False
cls.plot_data(data, ax=ax, x_col=x_col, y_col=y_col, plot_kwargs=plot_kwargs)
def _plot_resid(ax_r, x_r, resid_r, yerr_r, labels_r):
# scatter_kwargs = dict(fmt='o', markersize=3, alpha=0.8, elinewidth=0.8, capsize=1) # This was unused
resid_kwargs = dict(fmt='o', markersize=3, alpha=0.8, elinewidth=0.8, capsize=1)
if labels_r is not None:
unique_labels = sorted(list(set(labels_r.dropna().unique())))
if len(unique_labels) > 0:
cmap = plt.get_cmap("tab10")
for i, lbl in enumerate(unique_labels):
m = (labels_r == lbl).to_numpy(dtype=bool)
if not np.any(m):
continue
c = cmap(i % 10)
ax_r.errorbar(x_r[m], resid_r[m], yerr=(yerr_r[m] if yerr_r is not None else None), color=c,
**resid_kwargs)
# Check for unlabeled data (NaN in labels)
m_nan = labels_r.isna().to_numpy(dtype=bool)
if np.any(m_nan):
ax_r.errorbar(x_r[m_nan], resid_r[m_nan], yerr=(yerr_r[m_nan] if yerr_r is not None else None),
color="gray", **resid_kwargs)
return
ax_r.errorbar(x_r, resid_r, yerr=yerr_r, color="tab:blue", **resid_kwargs)
if model is not None:
is_pymc = hasattr(model, "posterior")
is_lmfit = hasattr(model, "eval")
is_list = isinstance(model, (list, tuple))
is_component = hasattr(model, "model_func") and hasattr(model, "params")
if is_component:
model = [model]
is_list = True
if is_pymc:
cls.plot_model_pymc(inference_data=model, data=data, ax=ax, x_col=x_col, plot_kwargs=plot_kwargs,
extension_factor=extension_factor, model_components=model_components)
if res and res_ax is not None:
y_model_post = model.posterior["y_model"]
yfit = y_model_post.median(dim=("chain", "draw")).values
if yfit.shape == y.shape:
resid = y - yfit
_plot_resid(res_ax, x, resid, yerr, labels)
res_ax.axhline(0, color="gray", lw=1.5, ls="--", alpha=0.6)
elif is_lmfit:
cls.plot_model_lmfit(result=model, data=data, ax=ax, x_col=x_col, plot_kwargs=plot_kwargs,
extension_factor=extension_factor)
if res and res_ax is not None:
y_fit_at_x = model.eval(x=x_clean)
resid = y_clean - y_fit_at_x
_plot_resid(res_ax, x_clean, resid, yerr_clean, labels_clean)
res_ax.axhline(0, color="gray", lw=1.5, ls="--", alpha=0.6)
elif is_list:
xmin, xmax = (float(np.min(x)), float(np.max(x))) if x.size else (0.0, 1.0)
margin = (xmax - xmin) * extension_factor
xline = np.linspace(xmin - margin, xmax + margin, 800)
cls.plot_model_components(model, xline=xline, ax=ax)
if res and res_ax is not None:
y_model_at_obs = np.zeros_like(x)
# ... internal logic omitted for brevity, but I need to make sure I don't break it
def _sig_param_names(comp):
sig = inspect.signature(comp.model_func)
params = list(sig.parameters.values())[1:]
names = [p.name for p in params if p.kind not in (p.VAR_KEYWORD, p.VAR_POSITIONAL)]
if any(p.kind == p.VAR_KEYWORD for p in params):
names = list(getattr(comp, "params", {}).keys())
return names
def _param_value(v):
return getattr(v, "value", v)
for comp in model:
pnames = _sig_param_names(comp)
params_dict = getattr(comp, "params", {}) or {}
kwargs = {}
for pname in pnames:
if pname in params_dict:
kwargs[pname] = float(_param_value(params_dict[pname]))
y_model_at_obs += comp.model_func(x, **kwargs)
resid = y - y_model_at_obs
_plot_resid(res_ax, x, resid, yerr, labels)
res_ax.axhline(0, color="gray", lw=1.5, ls="--", alpha=0.6)
if res_ax:
res_ax.set_ylabel("Resid")
res_ax.set_xlabel(x_col.capitalize())
res_ax.grid(True, alpha=0.25)
ax.set_xlabel("")
if title:
ax.set_title(title)
ax.legend(loc="best")
if ax is None: # If we created the figure internally
if res_ax is None: # If no residuals subplot was created
try:
fig.tight_layout()
except Exception:
pass
return ax
@staticmethod
def _format_label(name: str, unit: Optional[str] = None) -> str:
r"""
Convert a parameter name into a nicely formatted LaTeX string for plotting.
Parameters
----------
name : str
The parameter name, e.g., 'P', 'omega', 'q', 'amp', or with suffixes like 'amp_1'.
unit : str, optional
Optional unit string to append, e.g., 'days', 'deg'.
Returns
-------
str
LaTeX-formatted string suitable for matplotlib labels.
Examples:
- "omega" -> r"$\omega$"
- "amp_1" -> r"$A_{1}$"
- "P" with unit "days" -> r"$P$ [days]"
Notes
-----
- Recognizes common O−C or orbital parameter names: omega, e, P, T0, T, m, a, b, q, sigma, gamma, tau, amp.
- If the name includes an index (like 'amp_2'), it is converted to a subscript in LaTeX.
- If the name is not in the predefined mapping, the raw name is returned (optionally with unit).
"""
mapping = {
"omega": r"$\omega$",
"e": r"$e$",
"P": r"$P$",
"T0": r"$T_0$",
"T": r"$T$",
"m": r"$m$",
"a": r"$a$",
"b": r"$b$",
"q": r"$q$",
"sigma": r"$\sigma$",
"gamma": r"$\gamma$",
"tau": r"$\tau$",
"amp": r"$A$"
}
parts = name.rsplit("_", 1)
if len(parts) == 2 and parts[1] in mapping:
sym = mapping[parts[1]]
formatted = sym
pre = parts[0]
m = re.match(r".*?(\d+)$", pre)
if m:
formatted = fr"{sym}_{{{m.group(1)}}}"
else:
pass
elif name in mapping:
formatted = mapping[name]
else:
formatted = name
if unit:
return fr"{formatted} [{unit}]"
return formatted
[docs]
@staticmethod
def plot_corner(
inference_data: az.InferenceData,
var_names: Optional[List[str]] = None,
cornerstyle: Literal["corner", "arviz"] = "corner",
units: Optional[Dict[str, str]] = None,
**kwargs
) -> Union[plt.Figure, az.plot_pair]:
"""
Generate a corner plot (pairwise parameter plot) from PyMC/ArviZ inference data.
Parameters
----------
inference_data : arviz.InferenceData
The posterior samples from a Bayesian model.
var_names : list of str, optional
Names of parameters to include. If None, automatically selects all
posterior parameters with 2 dimensions (chain, draw) excluding model output.
cornerstyle : {'corner', 'arviz'}, default='corner'
Which library/style to use for the corner plot:
- 'corner' : uses the `corner` package.
- 'arviz' : uses ArviZ's `plot_pair`.
units : dict, optional
Mapping from parameter name to unit string, e.g., {'P': 'days'}.
Units are appended to axis labels.
**kwargs
Additional keyword arguments passed to `corner.corner` or `arviz.plot_pair`.
Returns
-------
matplotlib.figure.Figure
The figure object for the 'corner' style.
arviz.plot_pair axes or Figure
The ArviZ axes object when `cornerstyle='arviz'`.
Notes
-----
- Automatically ignores fixed parameters with negligible variation.
- Adds median values as "truths" in the plot if using 'corner' and no `truths` provided.
- If units are provided, labels are formatted with LaTeX, e.g., r"$P$ [days]".
- Supports up to 200 subplots in ArviZ style using `plot.max_subplots` context.
Raises
------
ImportError
If `corner` library is required but not installed.
ValueError
If no suitable parameters are found for plotting.
"""
if var_names is None:
variable_candidates = [var_name for var_name in inference_data.posterior.data_vars
if getattr(inference_data.posterior[var_name], "ndim", 0) == 2
and var_name not in {"y_model", "y_model_dense", "y_obs", "dense_x"}]
else:
variable_candidates = var_names
selected_variables = []
for var_name in variable_candidates:
values_array = inference_data.posterior[var_name].values
# Check if there's actual variation. ptp is peak-to-peak (max - min)
if np.ptp(values_array) > 1e-12:
selected_variables.append(var_name)
if not selected_variables:
# If everything is fixed, we can't really do a corner plot,
# but let's at least not crash or warn cryptically.
if variable_candidates:
selected_variables = [variable_candidates[0]]
else:
raise ValueError("No suitable parameters found for corner plot.")
if cornerstyle == "corner":
if corner is None:
raise ImportError("Corner plot requires 'corner' library. Please install it with `pip install corner`.")
extracted_samples = az.extract(inference_data, var_names=selected_variables)
samples = np.vstack([extracted_samples[var_name].values for var_name in selected_variables]).T
# Map the 'range' list if its length doesn't match selected_variables
if "range" in kwargs and isinstance(kwargs["range"], (list, np.ndarray)):
range_list = list(kwargs["range"])
if len(range_list) != len(selected_variables):
all_variables = list(inference_data.posterior.data_vars)
if len(range_list) == len(all_variables):
indices = [all_variables.index(v) for v in selected_variables]
kwargs["range"] = [range_list[idx] for idx in indices]
elif len(range_list) == len(variable_candidates):
indices = [variable_candidates.index(v) for v in selected_variables]
kwargs["range"] = [range_list[idx] for idx in indices]
elif len(set(range_list)) == 1:
# If all values are the same, corner accepts a single float
kwargs["range"] = range_list[0]
# Calculate medians for truth lines
medians = [float(inference_data.posterior[var_name].median(dim=("chain", "draw"))) for var_name in
selected_variables]
plot_labels = [Plot._format_label(var_name, (units or {}).get(var_name)) for var_name in selected_variables]
if "quantiles" not in kwargs:
kwargs["quantiles"] = [0.16, 0.5, 0.84]
if "show_titles" not in kwargs:
kwargs["show_titles"] = True
if "title_fmt" not in kwargs:
kwargs["title_fmt"] = ".4f"
# Add truths if not already provided
if "truths" not in kwargs:
kwargs["truths"] = medians
kwargs.setdefault("truth_color", "red")
figure = corner.corner(samples, labels=plot_labels, **kwargs)
return figure
elif cornerstyle == "arviz":
if "marginals" not in kwargs:
kwargs["marginals"] = True
if "kind" not in kwargs:
kwargs["kind"] = "kde"
with az.rc_context({"plot.max_subplots": 200}):
return az.plot_pair(inference_data, var_names=selected_variables, **kwargs)
else:
raise ValueError(f"Unknown cornerstyle: {cornerstyle}. Use 'corner' or 'arviz'.")
[docs]
@staticmethod
def plot_trace(inference_data, var_names=None, **kwargs) -> matplotlib.axes.Axes:
"""
Generate trace plots for posterior samples from a PyMC InferenceData object.
Trace plots show the sampled parameter values across chains and draws,
allowing evaluation of convergence, mixing, and sampling behavior.
Parameters
----------
inference_data : arviz.InferenceData
The posterior sampling results, typically returned by a PyMC fit.
var_names : list of str, optional
List of parameter names to include in the trace plot.
If None, all variable parameters with variation are included.
**kwargs
Additional keyword arguments passed to `arviz.plot_trace`.
Returns
-------
matplotlib.axes.Axes
Array of matplotlib axes objects containing the trace plots.
Notes
-----
- Automatically excludes fixed parameters (near-zero variance) unless all are fixed.
- Figures are automatically tightened with `tight_layout`.
- Designed to complement `plot_corner` for full posterior visualization.
"""
if var_names is None:
# Only take variables with 2 dimensions (chain, draw) - these are usually parameters
variable_candidates = [var_name for var_name in inference_data.posterior.data_vars
if var_name not in {"y_model", "y_model_dense", "y_obs", "dense_x"}
and getattr(inference_data.posterior[var_name], "ndim", 0) == 2]
else:
variable_candidates = var_names
selected_variables = []
for var_name in variable_candidates:
values_array = inference_data.posterior[var_name].values
# Exclude variables with near-zero variance (e.g. fixed parameters)
if values_array.std() > 1e-11:
selected_variables.append(var_name)
if not selected_variables:
selected_variables = variable_candidates
axes = az.plot_trace(inference_data, var_names=selected_variables, **kwargs)
try:
fig = axes.flatten()[0].figure
fig.tight_layout()
except Exception:
pass
return axes