"""Covariate Seeded Poisson Factorization model (``poisson_topicmodels.models.CSPF``)."""

from typing import Any, Dict, List, Optional, Tuple

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import numpyro.distributions as dist
import pandas as pd
import scipy.sparse as sparse
from numpyro import param, plate, sample
from numpyro.distributions import constraints
from scipy import stats as sp_stats

from .numpyro_model import NumpyroModel


class CSPF(NumpyroModel):
    """
    Covariate Seeded Poisson Factorization with grouped design-adaptive shrinkage.

    This implementation preserves the CSPF interface while replacing the internal
    covariate-effect specification with the model in ``CSPF_model_new.tex``.

    Parameters
    ----------
    counts : scipy.sparse matrix
        Document-term count matrix of shape ``(D, V)``.
    vocab : array-like
        Vocabulary aligned with the columns of ``counts`` (length ``V``).
        Plain sequences (e.g. a list of words) are converted to a NumPy array.
    keywords : dict
        ``{topic_name: [seed words]}``. Keyword topics occupy topic indices
        ``0 .. len(keywords) - 1`` in dict order. Seed words that are not in
        ``vocab`` are skipped silently.
    residual_topics : int
        Number of unseeded topics appended after the keyword topics.
    batch_size : int
        Documents per mini-batch; must satisfy ``0 < batch_size <= D``.
    X_design_matrix : np.ndarray or pd.DataFrame, optional
        Document-level covariates of shape ``(D, C)``. DataFrame column names
        become covariate names and define the shrinkage groups (see
        :meth:`_covariate_group_key`); array columns are named
        ``cov_0, cov_1, ...``. Non-floating dtypes (bool, int, object) are
        converted to ``float64``. If ``None``, a single constant column named
        ``intercept_cov`` is used.

    Raises
    ------
    TypeError
        If ``counts`` is not sparse, ``keywords`` is not a dict, or the
        covariates cannot be converted to floating point.
    ValueError
        On empty or mis-shaped inputs, a negative ``residual_topics``, an
        out-of-range ``batch_size``, or when the model would have no topics.
    """

    def __init__(
        self,
        counts: sparse.csr_matrix,
        vocab: np.ndarray,
        keywords: Dict[Any, List[str]],
        residual_topics: int,
        batch_size: int,
        X_design_matrix: Optional[np.ndarray] = None,
    ) -> None:
        super().__init__()
        if not sparse.issparse(counts):
            raise TypeError(f"counts must be a scipy sparse matrix, got {type(counts).__name__}")
        D, V = counts.shape
        if D == 0 or V == 0:
            raise ValueError(f"counts matrix is empty: shape ({D}, {V})")
        if not hasattr(vocab, "shape"):
            # Accept plain sequences of words as well as arrays / pandas objects.
            vocab = np.asarray(vocab)
        if vocab.shape[0] != V:
            raise ValueError(f"vocab size {vocab.shape[0]} != counts columns {V}")
        if not isinstance(keywords, dict):
            raise TypeError(f"keywords must be dict, got {type(keywords).__name__}")
        if residual_topics < 0:
            raise ValueError(f"residual_topics must be >= 0, got {residual_topics}")
        if residual_topics + len(keywords) == 0:
            raise ValueError("model has no topics: keywords is empty and residual_topics is 0")
        if batch_size <= 0 or batch_size > D:
            raise ValueError(f"batch_size must satisfy 0 < batch_size <= {D}, got {batch_size}")

        x_np, covariate_names = self._prepare_design_matrix(X_design_matrix, D)

        self.counts = counts
        self.D = D
        self.V = V
        self.vocab = vocab
        self.K = residual_topics + len(keywords)
        self.keywords = keywords
        self.residual_topics = residual_topics
        self.batch_size = batch_size

        # Pair every (keyword-topic index, vocabulary index) for the seed words
        # that exist in the vocabulary; missing seed words are skipped.
        vocab_lookup = {word: index for index, word in enumerate(vocab)}
        kw_indices_topics = [
            (topic_idx, vocab_lookup[word])
            for topic_idx, words in enumerate(keywords.values())
            for word in words
            if word in vocab_lookup
        ]
        # One beta_tilde boost per seeded (topic, word) pair.
        self.Tilde_V = len(kw_indices_topics)
        # ``(topic indices, word indices)`` used for scatter-adds into beta.
        self.kw_indices = tuple(zip(*kw_indices_topics)) if kw_indices_topics else ((), ())

        self.X_design_matrix = jnp.array(x_np)
        self.C = self.X_design_matrix.shape[1]
        self.covariates = covariate_names
        self.group_index = self._build_group_index(self.covariates)
        self.G = int(self.group_index.max()) + 1 if self.C > 0 else 0
        self.group_scaling_diag = self._compute_group_scaling_diag(x_np, self.group_index, self.G)

        # theta ~ Gamma(b_theta, b_theta / mu_theta), so E[theta] = mu_theta.
        self.b_theta = 0.3
        # softplus^{-1}(1): the intercept prior is centred where softplus(.) == 1.
        self.softplus_inv_one = float(np.log(np.expm1(1.0)))
        self.s_lambda0 = 1.0  # prior standard deviation of the intercept lambda_0
        # Global shrinkage: tau2 ~ Gamma(a_tau, rho_tau), rho_tau ~ Gamma(a_rho_tau, b_rho_tau).
        self.a_tau = 0.5
        self.a_rho_tau = 0.5
        self.b_rho_tau = 1.0
        # Group shrinkage: delta2 ~ Gamma(a_delta, rho_delta),
        # rho_delta ~ Gamma(a_rho_delta, b_rho_delta).
        self.a_delta = 0.5
        self.a_rho_delta = 0.5
        self.b_rho_delta = 1.0

    @staticmethod
    def _prepare_design_matrix(
        X_design_matrix: Optional[Any], D: int
    ) -> Tuple[np.ndarray, List[str]]:
        """Validate the covariates and return ``(numeric matrix, covariate names)``.

        Raises
        ------
        ValueError
            If the matrix is not 2D, has a row count other than ``D`` or has
            no columns.
        TypeError
            If the values cannot be converted to floating point.
        """
        if X_design_matrix is None:
            return np.ones((D, 1), dtype=np.float32), ["intercept_cov"]

        is_frame = isinstance(X_design_matrix, pd.DataFrame)
        x_np = np.asarray(X_design_matrix.values if is_frame else X_design_matrix)
        if x_np.ndim != 2:
            raise ValueError(f"covariates must be 2D, got shape {x_np.shape}")
        if x_np.shape[0] != D:
            raise ValueError(f"covariates has {x_np.shape[0]} rows, expected {D}")
        if x_np.shape[1] == 0:
            raise ValueError("covariates matrix is empty (0 columns)")
        if is_frame:
            covariate_names = [str(col) for col in X_design_matrix.columns]
        else:
            covariate_names = [f"cov_{i}" for i in range(x_np.shape[1])]

        # Bool or object columns (``pd.get_dummies`` yields bool in pandas >= 2,
        # mixed-dtype frames yield object arrays) must become real numbers:
        # NumPy's ``bool @ bool`` is a logical OR-of-ANDs rather than a count,
        # which would silently corrupt ``_compute_group_scaling_diag``, and JAX
        # does not accept object-dtype arrays.
        if not np.issubdtype(x_np.dtype, np.floating):
            try:
                x_np = x_np.astype(np.float64)
            except (TypeError, ValueError) as err:
                raise TypeError(f"covariates must be numeric, got dtype {x_np.dtype}") from err
        return x_np, covariate_names

    @staticmethod
    def _covariate_group_key(name: str) -> str:
        """Return the shrinkage-group key encoded in a covariate name.

        ``"a::b"``, ``"a=b"`` and ``"a[b]"`` all map to ``"a"``; any other name
        is its own group.
        """
        if "::" in name:
            return name.split("::", 1)[0]
        if "=" in name:
            return name.split("=", 1)[0]
        if "[" in name and name.endswith("]"):
            return name.split("[", 1)[0]
        return name

    @staticmethod
    def _build_group_index(covariate_names: List[str]) -> np.ndarray:
        """
        Infer covariate groups from names using explicit separators.

        Supported separators: ``::``, ``=``, ``[name]`` notation. If none is
        present, each covariate is treated as its own group. Group ids are
        assigned in order of first appearance.
        """
        key_to_id: Dict[str, int] = {}
        ids: List[int] = []
        for name in covariate_names:
            key = CSPF._covariate_group_key(name)
            if key not in key_to_id:
                key_to_id[key] = len(key_to_id)
            ids.append(key_to_id[key])
        return np.asarray(ids, dtype=np.int32)

    @staticmethod
    def _compute_group_scaling_diag(
        x_np: np.ndarray, group_index: np.ndarray, G: int
    ) -> jnp.ndarray:
        """
        Compute diagonal entries of ``(X_g^T X_g)^{-1}`` per covariate column.

        For one-hot dummy columns this equals ``1 / n_j`` as in the model spec.
        The Gram matrix is formed in float64 (so bool / int inputs are counted,
        not OR-ed), with a tiny ridge to keep singular groups invertible.
        """
        C = x_np.shape[1]
        scaling = np.zeros(C, dtype=np.float32)
        ridge = 1e-8
        for g in range(G):
            cols = np.flatnonzero(group_index == g)
            xg = np.asarray(x_np[:, cols], dtype=np.float64)
            xtx = xg.T @ xg
            xtx_inv = np.linalg.inv(xtx + ridge * np.eye(len(cols)))
            scaling[cols] = np.diag(xtx_inv)
        return jnp.asarray(scaling)

    def _model(self, Y_batch: jnp.ndarray, d_batch: jnp.ndarray) -> None:
        """NumPyro generative model evaluated on one mini-batch.

        Parameters
        ----------
        Y_batch : jnp.ndarray
            Observed counts for the documents in ``d_batch``; must broadcast
            against the Poisson rate ``theta @ beta`` of shape ``(batch_size, V)``.
        d_batch : jnp.ndarray
            Indices of the mini-batch documents, used to gather covariate rows.
        """
        with plate("k", size=self.K, dim=-2):
            with plate("k_v", size=self.V, dim=-1):
                beta = sample("beta", dist.Gamma(0.3, 0.3))

        # Seed-word boosts; skipped when no keyword is in the vocabulary so no
        # empty plate / empty scatter index is ever built.
        if self.Tilde_V > 0:
            with plate("tilde_v", size=self.Tilde_V):
                beta_tilde = sample("beta_tilde", dist.Gamma(1.0, 0.3))
            beta = beta.at[self.kw_indices].add(beta_tilde)

        with plate("k_intercept", size=self.K):
            lambda_0 = sample(
                "lambda_intercept", dist.Normal(self.softplus_inv_one, self.s_lambda0)
            )
            # Rate hyperprior rho_tau and global shrinkage tau2 | rho_tau (equation 8).
            rho_tau = sample("rho_tau", dist.Gamma(self.a_rho_tau, self.b_rho_tau))
            tau2 = sample("tau2", dist.Gamma(self.a_tau, rho_tau))

        with plate("g", size=self.G, dim=-2):
            with plate("g_k", size=self.K, dim=-1):
                # Rate hyperprior rho_delta and group shrinkage delta2 | rho_delta (equation 9).
                rho_delta = sample(
                    "rho_delta", dist.Gamma(self.a_rho_delta, self.b_rho_delta)
                )
                delta2 = sample("delta2", dist.Gamma(self.a_delta, rho_delta))

        group_index = jnp.asarray(self.group_index)
        delta2_per_cov = delta2[group_index, :]  # (C, K): each covariate takes its group's delta2
        lambda_scale = jnp.sqrt(
            tau2[None, :] * delta2_per_cov * self.group_scaling_diag[:, None]
        )

        with plate("c", size=self.C, dim=-2):
            with plate("c_k", size=self.K, dim=-1):
                lambda_ = sample("lambda", dist.Normal(0.0, lambda_scale))

        # Only the mini-batch rows are needed; gathering before the matmul gives
        # the same values as gathering after, at batch_size instead of D cost.
        x_batch = self.X_design_matrix[d_batch]
        eta_theta = lambda_0[None, :] + jnp.matmul(x_batch, lambda_)  # equation 2
        mu_theta = jax.nn.softplus(eta_theta)  # equation 2
        theta_rate = self.b_theta / mu_theta  # equation 1

        with plate("d", size=self.D, subsample_size=self.batch_size, dim=-2):
            with plate("d_k", size=self.K, dim=-1):
                theta = sample("theta", dist.Gamma(self.b_theta, theta_rate))
            P = jnp.matmul(theta, beta)
            with plate("d_v", size=self.V, dim=-1):
                sample("Y_batch", dist.Poisson(P), obs=Y_batch)

    def _guide(self, Y_batch, d_batch):
        """Mean-field variational guide matching every latent site of ``_model``.

        Gamma sites get (shape, rate) params, Normal sites get (location, scale)
        params; theta params are stored for all ``D`` documents and indexed by
        ``d_batch``.
        """
        a_beta = param(
            "beta_shape", init_value=jnp.ones([self.K, self.V]), constraint=constraints.positive
        )
        b_beta = param(
            "beta_rate",
            init_value=jnp.ones([self.K, self.V]) * self.D / 1000 * 2,
            constraint=constraints.positive,
        )
        a_theta = param(
            "theta_shape", init_value=jnp.ones([self.D, self.K]), constraint=constraints.positive
        )
        b_theta = param(
            "theta_rate",
            init_value=jnp.ones([self.D, self.K]) * self.D / 1000,
            constraint=constraints.positive,
        )
        a_beta_tilde = param(
            "beta_tilde_shape",
            init_value=jnp.ones([self.Tilde_V]) * 2,
            constraint=constraints.positive,
        )
        b_beta_tilde = param(
            "beta_tilde_rate", init_value=jnp.ones([self.Tilde_V]), constraint=constraints.positive
        )
        location_lambda0 = param("lambda_intercept_location", init_value=jnp.zeros([self.K]))
        scale_lambda0 = param(
            "lambda_intercept_scale",
            init_value=jnp.ones([self.K]),
            constraint=constraints.positive,
        )
        location_lambda = param("lambda_location", init_value=jnp.zeros([self.C, self.K]))
        scale_lambda = param(
            "lambda_scale", init_value=jnp.ones([self.C, self.K]), constraint=constraints.positive
        )
        a_rho_tau = param(
            "rho_tau_shape", init_value=jnp.ones([self.K]), constraint=constraints.positive
        )
        b_rho_tau = param(
            "rho_tau_rate", init_value=jnp.ones([self.K]), constraint=constraints.positive
        )
        a_tau2 = param("tau2_shape", init_value=jnp.ones([self.K]), constraint=constraints.positive)
        b_tau2 = param("tau2_rate", init_value=jnp.ones([self.K]), constraint=constraints.positive)
        a_rho_delta = param(
            "rho_delta_shape",
            init_value=jnp.ones([self.G, self.K]),
            constraint=constraints.positive,
        )
        b_rho_delta = param(
            "rho_delta_rate", init_value=jnp.ones([self.G, self.K]), constraint=constraints.positive
        )
        a_delta2 = param(
            "delta2_shape", init_value=jnp.ones([self.G, self.K]), constraint=constraints.positive
        )
        b_delta2 = param(
            "delta2_rate", init_value=jnp.ones([self.G, self.K]), constraint=constraints.positive
        )

        with plate("k", size=self.K, dim=-2):
            with plate("k_v", size=self.V, dim=-1):
                sample("beta", dist.Gamma(a_beta, b_beta))

        # Mirrors the model: the beta_tilde site only exists when seeds exist.
        if self.Tilde_V > 0:
            with plate("tilde_v", size=self.Tilde_V):
                sample("beta_tilde", dist.Gamma(a_beta_tilde, b_beta_tilde))

        with plate("k_intercept", size=self.K):
            sample("lambda_intercept", dist.Normal(location_lambda0, scale_lambda0))
            sample("rho_tau", dist.Gamma(a_rho_tau, b_rho_tau))
            sample("tau2", dist.Gamma(a_tau2, b_tau2))

        with plate("g", size=self.G, dim=-2):
            with plate("g_k", size=self.K, dim=-1):
                sample("rho_delta", dist.Gamma(a_rho_delta, b_rho_delta))
                sample("delta2", dist.Gamma(a_delta2, b_delta2))

        with plate("c", size=self.C, dim=-2):
            with plate("c_k", size=self.K, dim=-1):
                sample("lambda", dist.Normal(location_lambda, scale_lambda))

        with plate("d", size=self.D, subsample_size=self.batch_size, dim=-2):
            with plate("d_k", size=self.K, dim=-1):
                sample("theta", dist.Gamma(a_theta[d_batch], b_theta[d_batch]))

    def return_topics(self) -> Tuple[np.ndarray, Any]:
        """Assign every document to its most intense topic.

        Returns
        -------
        topics : np.ndarray of str
            Per-document label of the argmax topic of ``E[theta]``: the keyword
            key (as ``str``) for seeded topics, ``"No_keyword_topic_<i>"``
            (1-based ``i``) for the ``i``-th residual topic.
        E_theta : array
            Variational mean ``theta_shape / theta_rate`` of shape ``(D, K)``.

        Raises
        ------
        ValueError
            If the model has not been trained yet.
        """
        if not self.estimated_params:
            raise ValueError("Model must be trained before calling return_topics()")
        E_theta = self.estimated_params["theta_shape"] / self.estimated_params["theta_rate"]
        categories = np.asarray(np.argmax(E_theta, axis=1)).tolist()

        keyword_keys = [str(key) for key in self.keywords]
        n_keyword_topics = len(keyword_keys)
        # Label each document individually; residual topics are numbered from 1.
        topics = np.array(
            [
                keyword_keys[k] if k < n_keyword_topics else f"No_keyword_topic_{k - n_keyword_topics + 1}"
                for k in categories
            ]
        )
        return topics, E_theta

    def return_beta(self) -> pd.DataFrame:
        """Return the variational mean of the topic-word intensities.

        Seed-word boosts ``E[beta_tilde]`` are added onto the seeded
        (topic, word) cells, mirroring ``_model``.

        Returns
        -------
        pd.DataFrame
            Vocabulary as rows, topics (see ``_topic_names``) as columns.

        Raises
        ------
        ValueError
            If the model has not been trained yet.
        """
        if not self.estimated_params:
            raise ValueError("Model must be trained before calling return_beta()")
        params = self.estimated_params
        E_beta = jnp.asarray(params["beta_shape"]) / jnp.asarray(params["beta_rate"])
        if self.Tilde_V > 0:
            E_beta_tilde = jnp.asarray(params["beta_tilde_shape"]) / jnp.asarray(
                params["beta_tilde_rate"]
            )
            E_beta = E_beta.at[self.kw_indices].add(E_beta_tilde)
        return pd.DataFrame(np.asarray(E_beta).T, index=self.vocab, columns=self._topic_names())

    def return_covariate_effects(self) -> pd.DataFrame:
        """Return point estimates of covariate effects (lambda).

        Returns
        -------
        pd.DataFrame
            DataFrame with covariates as rows and topics as columns.

        Raises
        ------
        ValueError
            If the model has not been trained yet.
        """
        if not self.estimated_params:
            raise ValueError("Model must be trained before calling return_covariate_effects()")
        return pd.DataFrame(
            np.asarray(self.estimated_params["lambda_location"]),
            index=self.covariates,
            columns=self._topic_names(),
        )

    def return_covariate_effects_ci(self, ci: float = 0.95) -> pd.DataFrame:
        """Return covariate effects with credible intervals.

        Uses the Normal variational posterior for lambda:
        ``mean = lambda_location``, ``CI = mean +/- z * lambda_scale``.

        Parameters
        ----------
        ci : float, optional
            Credible-interval level in ``(0, 1)`` (default 0.95).

        Returns
        -------
        pd.DataFrame
            DataFrame with columns ``['covariate', 'topic', 'mean', 'lower', 'upper']``,
            one row per (covariate, topic) pair, covariate-major.

        Raises
        ------
        ValueError
            If model has not been trained yet, or ``ci`` is outside ``(0, 1)``.
        """
        if not self.estimated_params:
            raise ValueError("Model must be trained before calling return_covariate_effects_ci()")
        if not 0.0 < ci < 1.0:
            raise ValueError(f"ci must be in (0, 1), got {ci}")

        # Work in float64 so bounds equal the per-element float64 arithmetic.
        loc = np.asarray(self.estimated_params["lambda_location"]).astype(np.float64)  # (C, K)
        scale = np.asarray(self.estimated_params["lambda_scale"]).astype(np.float64)  # (C, K)
        z = sp_stats.norm.ppf(1.0 - (1.0 - ci) / 2.0)
        half_width = z * scale

        topic_names = self._topic_names()
        return pd.DataFrame(
            {
                "covariate": [cov for cov in self.covariates for _ in topic_names],
                "topic": [topic for _ in self.covariates for topic in topic_names],
                "mean": loc.ravel(),
                "lower": (loc - half_width).ravel(),
                "upper": (loc + half_width).ravel(),
            }
        )

    def _summary_extra(self) -> str:
        """CSPF-specific summary information."""
        lines = [
            f"  Keywords: {len(self.keywords)} seeded topics",
            f"  Residual topics: {self.residual_topics}",
            f"  Covariates (C): {self.C}",
            f"  Covariate names: {', '.join(self.covariates)}",
        ]
        return "\n".join(lines)

    # ------------------------------------------------------------------
    # Forest-plot visualisation
    # ------------------------------------------------------------------

    def _topic_names(self) -> List[str]:
        """Return ordered list of all topic names (keyword + residual)."""
        return list(self.keywords.keys()) + [
            f"residual_topic_{i + 1}" for i in range(self.residual_topics)
        ]

    def _group_names(self) -> List[str]:
        """Return covariate-group names ordered by group id.

        Uses the same key rule as ``_build_group_index`` (``::``, ``=``,
        ``name[level]``), so position ``g`` names group ``g`` / row ``g`` of
        ``delta2``.
        """
        return list(dict.fromkeys(self._covariate_group_key(name) for name in self.covariates))

    @staticmethod
    def _gamma_ci(
        shape: np.ndarray, rate: np.ndarray, ci: float
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Point estimate and CI for a Gamma variational posterior."""
        mean = shape / rate
        alpha_lo = (1.0 - ci) / 2.0
        alpha_hi = 1.0 - alpha_lo
        lo = sp_stats.gamma.ppf(alpha_lo, a=shape, scale=1.0 / rate)
        hi = sp_stats.gamma.ppf(alpha_hi, a=shape, scale=1.0 / rate)
        return mean, lo, hi

    @staticmethod
    def _draw_forest_rows(
        ax: Any,
        means: np.ndarray,
        lo: np.ndarray,
        hi: np.ndarray,
        y_pos: np.ndarray,
        colors: Any,
        *,
        line_width: float,
        marker_size: float,
        edge_width: float,
        zero_zorder: int,
    ) -> None:
        """Draw one CI segment and point estimate per row plus a zero reference line.

        ``colors`` is a single colour string or one colour per row.
        """
        row_colors = [colors] * len(y_pos) if isinstance(colors, str) else list(colors)
        for lo_j, hi_j, y_j, colour in zip(lo, hi, y_pos, row_colors):
            ax.plot(
                [lo_j, hi_j],
                [y_j, y_j],
                color=colour,
                linewidth=line_width,
                solid_capstyle="round",
            )
        ax.scatter(
            means,
            y_pos,
            s=marker_size,
            zorder=5,
            color=colors,
            edgecolors="white",
            linewidths=edge_width,
        )
        # Zero reference line — thick solid, semi-transparent.
        ax.axvline(0, color="#333333", linewidth=1.4, linestyle="-", alpha=0.45, zorder=zero_zorder)

    # ---- public API ---------------------------------------------------

    def plot_cov_effects(
        self,
        ci: float = 0.95,
        include_shrinkage: bool = False,
        topics: Optional[List[str]] = None,
        group_colors: Optional[Dict[str, str]] = None,
        figsize_per_topic: Tuple[float, float] = (5.0, 0.28),
        save_path: Optional[str] = None,
    ) -> Dict[str, Tuple[plt.Figure, np.ndarray]]:
        r"""Plot covariate effects as forest plots.

        Parameters
        ----------
        ci : float, optional
            Credible-interval level in ``(0, 1)`` (default ``0.95`` for 95 % CI).
        include_shrinkage : bool, optional
            If ``True``, additionally produce forest plots for
            :math:`\lambda_0` (intercept), :math:`\tau^2_k` (global shrinkage),
            and :math:`\delta^2_{gk}` (group shrinkage).
        topics : list of str, optional
            Subset of topic names to plot. If ``None`` (default), all topics
            are plotted.
        group_colors : dict, optional
            Mapping ``{group_name: colour}`` used to colour the covariate
            labels on the y-axis. Groups are inferred from covariate names with
            the same rule as the model's shrinkage groups (``::``, ``=`` or
            ``name[level]``). If ``None`` a default qualitative palette is used.
        figsize_per_topic : tuple of float, optional
            ``(width, height_per_covariate)`` used to auto-size the lambda
            panels. Default ``(5.0, 0.28)``.
        save_path : str, optional
            Directory (or file path) where figures are saved. When a directory
            is given, individual PNGs are written; when a file path is given,
            only the lambda figure is saved there. If ``None``, figures are
            not saved.

        Returns
        -------
        dict
            ``{"lambda": (fig, axes), ...}`` and, when *include_shrinkage* is
            ``True``, additional entries ``"lambda_intercept"``, ``"tau2"``,
            ``"delta2"``.

        Raises
        ------
        RuntimeError
            If the model has not been trained yet.
        ValueError
            If ``ci`` is outside ``(0, 1)`` or none of ``topics`` exist.
        """
        import os

        from matplotlib.lines import Line2D

        if not self.estimated_params:
            raise RuntimeError("No estimated parameters found. Train the model first.")
        if not 0.0 < ci < 1.0:
            raise ValueError(f"ci must be in (0, 1), got {ci}")

        all_topic_names = self._topic_names()
        if topics is not None:
            topic_idx = [i for i, t in enumerate(all_topic_names) if t in topics]
            if not topic_idx:
                raise ValueError(f"None of {topics} found in model topics {all_topic_names}")
        else:
            topic_idx = list(range(len(all_topic_names)))
        plot_topics = [all_topic_names[i] for i in topic_idx]
        n_topics = len(plot_topics)
        # ``:g`` avoids int() truncation such as 0.57 * 100 -> 56.
        ci_label = f"{ci * 100:g}%"
        z = sp_stats.norm.ppf(1.0 - (1.0 - ci) / 2.0)

        # -- colours per covariate group ----------------------------------
        # Same grouping rule as the shrinkage groups, so legend entries, label
        # colours and delta2 rows all refer to the same groups.
        grp_names = self._group_names()
        if group_colors is None:
            palette = [
                "#4E79A7", "#F28E2B", "#E15759", "#76B7B2", "#59A14F",
                "#EDC948", "#B07AA1", "#FF9DA7", "#9C755F", "#BAB0AC",
            ]
            group_colors = {g: palette[i % len(palette)] for i, g in enumerate(grp_names)}

        cov_colors = [
            group_colors.get(self._covariate_group_key(name), "#333333")
            for name in self.covariates
        ]

        results: Dict[str, Tuple[plt.Figure, np.ndarray]] = {}

        # ================================================================
        # Lambda forest plot
        # ================================================================
        loc = np.asarray(self.estimated_params["lambda_location"])  # (C, K)
        scale = np.asarray(self.estimated_params["lambda_scale"])  # (C, K)
        n_cov = loc.shape[0]

        with plt.rc_context(self._setup_academic_style()):
            fig_w = figsize_per_topic[0]
            fig_h = max(3.0, n_cov * figsize_per_topic[1])
            ncols = min(n_topics, 4)
            nrows = int(np.ceil(n_topics / ncols))
            fig, axes = plt.subplots(
                nrows,
                ncols,
                figsize=(fig_w * ncols, fig_h * nrows),
                sharey=True,
                squeeze=False,
            )
            axes_flat = axes.flatten()

            # Shared x-range across all panels so effects are comparable.
            global_xmin = float(np.min(loc[:, topic_idx] - z * scale[:, topic_idx]))
            global_xmax = float(np.max(loc[:, topic_idx] + z * scale[:, topic_idx]))
            x_pad = (global_xmax - global_xmin) * 0.08
            global_xmin -= x_pad
            global_xmax += x_pad

            y_cov = np.arange(n_cov)[::-1]
            for ax, ki, tname in zip(axes_flat, topic_idx, plot_topics):
                means = loc[:, ki]
                self._draw_forest_rows(
                    ax,
                    means,
                    means - z * scale[:, ki],
                    means + z * scale[:, ki],
                    y_cov,
                    cov_colors,
                    line_width=1.2,
                    marker_size=18,
                    edge_width=0.3,
                    zero_zorder=1,
                )
                ax.set_xlim(global_xmin, global_xmax)
                ax.set_yticks(y_cov)
                ax.set_yticklabels(list(self.covariates), fontsize=7, color="#222222")
                # Colour y-tick labels by group.
                for tick_label, colour in zip(ax.get_yticklabels(), cov_colors):
                    tick_label.set_color(colour)
                ax.set_title(tname, fontweight="bold", pad=6)
                ax.set_xlabel(r"$\lambda$")
                ax.margins(y=0.02)

            for ax in axes_flat[n_topics:]:
                ax.set_visible(False)

            legend_handles = [
                Line2D(
                    [0], [0],
                    marker="o",
                    color=group_colors[g],
                    linestyle="None",
                    markersize=5,
                    label=g,
                )
                for g in grp_names
                if g in group_colors
            ]
            if legend_handles:
                fig.legend(
                    handles=legend_handles,
                    title="Covariate group",
                    loc="lower center",
                    ncol=min(len(legend_handles), 6),
                    frameon=False,
                    bbox_to_anchor=(0.5, -0.01),
                )
            fig.suptitle(
                f"Covariate Effects on Topic Intensity ({ci_label} CI)",
                fontsize=12,
                fontweight="bold",
                y=1.02,
            )
            fig.tight_layout()
            results["lambda"] = (fig, axes)

            if save_path is not None:
                target = (
                    os.path.join(save_path, "forest_lambda.png")
                    if os.path.isdir(save_path)
                    else save_path
                )
                fig.savefig(target, dpi=200, bbox_inches="tight")

        if not include_shrinkage:
            return results

        # ================================================================
        # Optional shrinkage panels (only written when save_path is a directory)
        # ================================================================
        save_dir = save_path if save_path is not None and os.path.isdir(save_path) else None

        with plt.rc_context(self._setup_academic_style()):
            y_topics = np.arange(n_topics)[::-1]

            # --- lambda_intercept ---
            loc0 = np.asarray(self.estimated_params["lambda_intercept_location"])[topic_idx]
            scale0 = np.asarray(self.estimated_params["lambda_intercept_scale"])[topic_idx]
            fig_int, ax_int = plt.subplots(figsize=(4.5, max(2.5, 0.35 * n_topics)))
            self._draw_forest_rows(
                ax_int,
                loc0,
                loc0 - z * scale0,
                loc0 + z * scale0,
                y_topics,
                "#4E79A7",
                line_width=1.3,
                marker_size=22,
                edge_width=0.4,
                zero_zorder=0,
            )
            ax_int.set_yticks(y_topics)
            ax_int.set_yticklabels(plot_topics, fontsize=8)
            ax_int.set_xlabel(r"$\lambda_0$")
            ax_int.set_title(f"Intercept $\\lambda_0$ ({ci_label} CI)", fontweight="bold", pad=6)
            ax_int.margins(y=0.04)
            fig_int.tight_layout()
            results["lambda_intercept"] = (fig_int, np.array([ax_int]))
            if save_dir is not None:
                fig_int.savefig(
                    os.path.join(save_dir, "forest_lambda_intercept.png"),
                    dpi=200,
                    bbox_inches="tight",
                )

            # --- tau2 (global shrinkage per topic) ---
            tau2_s = np.asarray(self.estimated_params["tau2_shape"])[topic_idx]
            tau2_r = np.asarray(self.estimated_params["tau2_rate"])[topic_idx]
            tau_mean, tau_lo, tau_hi = self._gamma_ci(tau2_s, tau2_r, ci)
            fig_tau, ax_tau = plt.subplots(figsize=(4.5, max(2.5, 0.35 * n_topics)))
            self._draw_forest_rows(
                ax_tau,
                tau_mean,
                tau_lo,
                tau_hi,
                y_topics,
                "#E15759",
                line_width=1.3,
                marker_size=22,
                edge_width=0.4,
                zero_zorder=0,
            )
            ax_tau.set_yticks(y_topics)
            ax_tau.set_yticklabels(plot_topics, fontsize=8)
            ax_tau.set_xlabel(r"$\tau^2$")
            ax_tau.set_title(
                f"Global Shrinkage $\\tau^2_k$ ({ci_label} CI)", fontweight="bold", pad=6
            )
            ax_tau.margins(y=0.04)
            fig_tau.tight_layout()
            results["tau2"] = (fig_tau, np.array([ax_tau]))
            if save_dir is not None:
                fig_tau.savefig(
                    os.path.join(save_dir, "forest_tau2.png"), dpi=200, bbox_inches="tight"
                )

            # --- delta2 (group shrinkage, per group x topic) ---
            d2_s = np.asarray(self.estimated_params["delta2_shape"])  # (G, K)
            d2_r = np.asarray(self.estimated_params["delta2_rate"])
            n_groups = d2_s.shape[0]
            ncols_d = min(n_topics, 4)
            nrows_d = int(np.ceil(n_topics / ncols_d))
            fig_d, axes_d = plt.subplots(
                nrows_d,
                ncols_d,
                figsize=(4.5 * ncols_d, max(2.5, 0.35 * n_groups) * nrows_d),
                sharey=True,
                squeeze=False,
            )
            axes_d_flat = axes_d.flatten()

            # Gamma CIs per selected topic, computed once and reused for the
            # shared x-range and the panels.
            d_stats = [self._gamma_ci(d2_s[:, ki], d2_r[:, ki], ci) for ki in topic_idx]
            d_xmin = float(np.min(np.concatenate([d_lo for _, d_lo, _ in d_stats])))
            d_xmax = float(np.max(np.concatenate([d_hi for _, _, d_hi in d_stats])))
            d_pad = (d_xmax - d_xmin) * 0.08
            d_xmin = max(0.0, d_xmin - d_pad)
            d_xmax += d_pad

            y_groups = np.arange(n_groups)[::-1]
            for ax, tname, (d_mean, d_lo, d_hi) in zip(axes_d_flat, plot_topics, d_stats):
                self._draw_forest_rows(
                    ax,
                    d_mean,
                    d_lo,
                    d_hi,
                    y_groups,
                    "#59A14F",
                    line_width=1.3,
                    marker_size=22,
                    edge_width=0.4,
                    zero_zorder=0,
                )
                ax.set_xlim(d_xmin, d_xmax)
                ax.set_yticks(y_groups)
                # One label per delta2 row: group g is grp_names[g].
                ax.set_yticklabels(grp_names, fontsize=8)
                ax.set_xlabel(r"$\delta^2$")
                ax.set_title(tname, fontweight="bold", pad=6)
                ax.margins(y=0.04)

            for ax in axes_d_flat[n_topics:]:
                ax.set_visible(False)

            fig_d.suptitle(
                f"Group Shrinkage $\\delta^2_{{gk}}$ ({ci_label} CI)",
                fontsize=12,
                fontweight="bold",
                y=1.02,
            )
            fig_d.tight_layout()
            results["delta2"] = (fig_d, axes_d)
            if save_dir is not None:
                fig_d.savefig(
                    os.path.join(save_dir, "forest_delta2.png"), dpi=200, bbox_inches="tight"
                )

        return results