Source code for poisson_topicmodels.models.CPF

from typing import Dict, List, Optional, Tuple

import jax.nn as jnn
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import numpyro.distributions as dist
import pandas as pd
import scipy.sparse as sparse
from scipy import stats as sp_stats
from numpyro import param, plate, sample
from numpyro.distributions import constraints

# Abstract class - defining the minimum requirements for the probabilistic model
from .numpyro_model import NumpyroModel


# Create numpyro model
[docs] class CPF(NumpyroModel): """ Covariate Poisson Factorization (CPF) topic model. Topic model that incorporates document-level covariates to capture how topics vary with external variables (e.g., author attributes, temporal features). Parameters ---------- counts : scipy.sparse.csr_matrix Document-term matrix of shape (D, V) with word counts. vocab : np.ndarray Vocabulary array of shape (V,) containing word terms. covariates : np.ndarray or pd.DataFrame Document-level covariates of shape (D, C) where C is number of features. num_topics : int Number of topics K. Must be > 0. batch_size : int Mini-batch size for stochastic variational inference. Must satisfy 0 < batch_size <= D. Attributes ---------- D : int Number of documents. V : int Vocabulary size. K : int Number of topics. C : int Number of covariate features. counts : scipy.sparse.csr_matrix Document-term matrix. vocab : np.ndarray Vocabulary array. X_design_matrix : jnp.ndarray Design matrix of covariates. Examples -------- >>> from scipy.sparse import random >>> import numpy as np >>> from topicmodels import CPF >>> counts = random(100, 500, density=0.01, format='csr') >>> vocab = np.array([f'word_{i}' for i in range(500)]) >>> covariates = np.random.randn(100, 3) # 3 covariate features >>> model = CPF(counts, vocab, covariates, num_topics=10, batch_size=32) >>> params = model.train_step(num_steps=100, lr=0.01, random_seed=42) """ def __init__( self, counts: sparse.csr_matrix, vocab: np.ndarray, num_topics: int, batch_size: int, X_design_matrix: Optional[np.ndarray] = None, ) -> None: """ Initialize the CPF model with input validation. Parameters ---------- counts : scipy.sparse.csr_matrix Document-term matrix. vocab : np.ndarray Vocabulary array. num_topics : int Number of topics. batch_size : int Mini-batch size. X_design_matrix : np.ndarray or pd.DataFrame, optional Document-level covariates. Raises ------ TypeError If counts is not sparse or covariates have wrong type. ValueError If dimensions are invalid or inconsistent. """ super().__init__() # Input validation if not sparse.issparse(counts): raise TypeError(f"counts must be a scipy sparse matrix, got {type(counts).__name__}") D, V = counts.shape if D == 0 or V == 0: raise ValueError(f"counts matrix is empty: shape ({D}, {V})") if vocab.shape[0] != V: raise ValueError(f"vocab size {vocab.shape[0]} != counts columns {V}") if num_topics <= 0: raise ValueError(f"num_topics must be > 0, got {num_topics}") if batch_size <= 0 or batch_size > D: raise ValueError(f"batch_size must satisfy 0 < batch_size <= {D}, got {batch_size}") if X_design_matrix is not None: _have_names = False if isinstance(X_design_matrix, pd.DataFrame): covariate_names = [str(col) for col in X_design_matrix.columns] _have_names = True X_design_matrix = X_design_matrix.values X_design_matrix = np.asarray(X_design_matrix) if X_design_matrix.ndim != 2: raise ValueError(f"covariates must be 2D, got shape {X_design_matrix.shape}") if X_design_matrix.shape[0] != D: raise ValueError(f"covariates has {X_design_matrix.shape[0]} rows, expected {D}") if X_design_matrix.shape[1] == 0: raise ValueError("covariates matrix is empty (0 columns)") if not _have_names: covariate_names = [f"cov_{i}" for i in range(X_design_matrix.shape[1])] else: covariate_names = ["intercept"] # Store validated inputs self.counts = counts self.D = D self.V = V self.vocab = vocab self.K = num_topics self.batch_size = batch_size self.covariates: List[str] = covariate_names self.X_design_matrix = ( jnp.array(X_design_matrix) if X_design_matrix is not None else jnp.ones((D, 1)) ) self.C = self.X_design_matrix.shape[1] # -- Model -- def _model(self, Y_batch: jnp.ndarray, d_batch: jnp.ndarray) -> None: """ Define the probabilistic generative model using NumPyro. Model structure: - Beta (K x V): topic-word distributions - Lambda (C x K): covariate effects on topics - Theta (D x K): document-topic distributions (covariate-dependent) - Y_batch (batch_size x V): observed word counts Parameters ---------- Y_batch : jnp.ndarray Batch of observed word counts (batch_size, V). d_batch : jnp.ndarray Document indices in batch (batch_size,). """ # Topic distributions with plate("k", size=self.K, dim=-2): with plate("k_v", size=self.V, dim=-1): beta = sample("beta", dist.Gamma(0.3, 0.3)) # Covariate effects with plate("c", size=self.C, dim=-2): with plate("c_k", size=self.K, dim=-1): lambda_ = sample("phi", dist.Normal(0.0, 1.0)) # Transform covariate effects via softplus a_theta_S = jnn.softplus(jnp.matmul(self.X_design_matrix, lambda_))[d_batch] # Document distribution with plate("d", size=self.D, subsample_size=self.batch_size, dim=-2): with plate("d_k", size=self.K, dim=-1): theta = sample("theta", dist.Gamma(a_theta_S, 0.3)) # Poisson rate P = jnp.matmul(theta, beta) with plate("d_v", size=self.V, dim=-1): sample("Y_batch", dist.Poisson(P), obs=Y_batch) # -- Guide, i.e. variational family -- def _guide(self, Y_batch: jnp.ndarray, d_batch: jnp.ndarray) -> None: """ Define the variational guide (approximate posterior). Uses Gamma family for topic distributions and Normal family for covariate effects. Parameters ---------- Y_batch : jnp.ndarray Batch of observed word counts. d_batch : jnp.ndarray Document indices in batch. """ # Define variational parameter a_beta = param( "beta_shape", init_value=jnp.ones([self.K, self.V]), constraint=constraints.positive ) b_beta = param( "beta_rate", init_value=jnp.ones([self.K, self.V]) * self.D / 1000 * 2, constraint=constraints.positive, ) a_theta = param( "theta_shape", init_value=jnp.ones([self.D, self.K]), constraint=constraints.positive ) b_theta = param( "theta_rate", init_value=jnp.ones([self.D, self.K]) * self.D / 1000, constraint=constraints.positive, ) location_lambda = param("lambda_location", init_value=jnp.zeros([self.C, self.K])) scale_lambda = param( "lambda_scale", init_value=jnp.ones([self.C, self.K]), constraint=constraints.positive ) with plate("k", size=self.K, dim=-2): with plate("k_v", size=self.V, dim=-1): sample("beta", dist.Gamma(a_beta, b_beta)) with plate("c", size=self.C, dim=-2): with plate("c_k", size=self.K, dim=-1): sample("phi", dist.Normal(location_lambda, scale_lambda)) with plate("d", size=self.D, subsample_size=self.batch_size, dim=-2): with plate("d_k", size=self.K, dim=-1): sample("theta", dist.Gamma(a_theta[d_batch], b_theta[d_batch]))
[docs] def return_covariate_effects(self) -> pd.DataFrame: """Return point estimates of covariate effects (lambda). Returns ------- pd.DataFrame DataFrame with covariates as rows and topics as columns. Raises ------ ValueError If model has not been trained yet. """ if not self.estimated_params: raise ValueError("Model must be trained before calling return_covariate_effects()") topic_names = [f"topic_{i + 1}" for i in range(self.K)] return pd.DataFrame( np.asarray(self.estimated_params["lambda_location"]), index=self.covariates, columns=topic_names, )
[docs] def return_covariate_effects_ci(self, ci: float = 0.95) -> pd.DataFrame: """Return covariate effects with credible intervals. Uses the Normal variational posterior for lambda: ``mean = lambda_location``, ``CI = mean +/- z * lambda_scale``. Parameters ---------- ci : float, optional Credible-interval level (default 0.95). Returns ------- pd.DataFrame DataFrame with columns ``['covariate', 'topic', 'mean', 'lower', 'upper']``. Raises ------ ValueError If model has not been trained yet. """ if not self.estimated_params: raise ValueError("Model must be trained before calling return_covariate_effects_ci()") loc = np.asarray(self.estimated_params["lambda_location"]) # (C, K) scale = np.asarray(self.estimated_params["lambda_scale"]) # (C, K) z = sp_stats.norm.ppf(1.0 - (1.0 - ci) / 2.0) topic_names = [f"topic_{i + 1}" for i in range(self.K)] rows = [] for c_idx, cov_name in enumerate(self.covariates): for k_idx, topic_name in enumerate(topic_names): rows.append({ "covariate": cov_name, "topic": topic_name, "mean": float(loc[c_idx, k_idx]), "lower": float(loc[c_idx, k_idx] - z * scale[c_idx, k_idx]), "upper": float(loc[c_idx, k_idx] + z * scale[c_idx, k_idx]), }) return pd.DataFrame(rows)
[docs] def plot_cov_effects( self, ci: float = 0.95, topics: Optional[List[str]] = None, figsize_per_topic: Tuple[float, float] = (5.0, 0.28), save_path: Optional[str] = None, ) -> Tuple[plt.Figure, np.ndarray]: r"""Forest plot of covariate effects (lambda) with credible intervals. Parameters ---------- ci : float, optional Credible-interval level (default 0.95). topics : list of str, optional Subset of topic names to plot. If ``None``, all topics are plotted. figsize_per_topic : tuple of float, optional ``(width, height_per_covariate)`` for panel sizing. save_path : str, optional Path to save the figure. Returns ------- tuple of (plt.Figure, np.ndarray of Axes) """ if not self.estimated_params: raise ValueError("Model must be trained before calling plot_cov_effects()") topic_names = [f"topic_{i + 1}" for i in range(self.K)] if topics is not None: topic_idx = [i for i, t in enumerate(topic_names) if t in topics] if not topic_idx: raise ValueError(f"None of {topics} found in {topic_names}") plot_topics = [topic_names[i] for i in topic_idx] else: plot_topics = topic_names topic_idx = list(range(self.K)) loc = np.asarray(self.estimated_params["lambda_location"]) # (C, K) scale = np.asarray(self.estimated_params["lambda_scale"]) # (C, K) z = sp_stats.norm.ppf(1.0 - (1.0 - ci) / 2.0) n_topics = len(plot_topics) n_cov = loc.shape[0] with plt.rc_context(self._setup_academic_style()): fig_w = figsize_per_topic[0] fig_h = max(3.0, n_cov * figsize_per_topic[1]) ncols = min(n_topics, 4) nrows = int(np.ceil(n_topics / ncols)) fig, axes = plt.subplots( nrows, ncols, figsize=(fig_w * ncols, fig_h * nrows), sharey=True, squeeze=False, ) axes_flat = axes.flatten() for panel_i, (ki, tname) in enumerate(zip(topic_idx, plot_topics)): ax = axes_flat[panel_i] means = loc[:, ki] lo = means - z * scale[:, ki] hi = means + z * scale[:, ki] y_pos = np.arange(n_cov)[::-1] ax.axvline(0, color="#999999", linewidth=0.5, zorder=0) for j in range(n_cov): ax.plot([lo[j], hi[j]], [y_pos[j], y_pos[j]], color="#4E79A7", linewidth=1.2, zorder=1) ax.scatter(means[j], y_pos[j], color="#4E79A7", s=18, zorder=2, edgecolors="white", linewidths=0.3) ax.set_yticks(y_pos) ax.set_yticklabels(list(self.covariates)) ax.set_title(tname) ax.set_xlabel(r"$\lambda$") for idx in range(n_topics, len(axes_flat)): axes_flat[idx].set_visible(False) fig.tight_layout() if save_path: fig.savefig(save_path, dpi=300, bbox_inches="tight") return fig, axes
def _summary_extra(self) -> str: """CPF-specific summary information.""" lines = [ f" Covariates (C): {self.C}", f" Covariate names: {', '.join(self.covariates)}", ] return "\n".join(lines)