Source code for poisson_topicmodels.models.SPF

from typing import Any, Dict, List, Optional, Tuple

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import numpyro.distributions as dist
import pandas as pd
import scipy.sparse as sparse
from numpyro import param, plate, sample
from numpyro.distributions import constraints

# Abstract class - defining the minimum requirements for the probabilistic model
from .numpyro_model import NumpyroModel


[docs] class SPF(NumpyroModel): """ Seeded Poisson Factorization (SPF) topic model. Guided topic modeling with keyword priors. SPF allows researchers to incorporate domain knowledge by specifying seed words for each topic, which increases the topical prevalence of those words in the model. Parameters ---------- counts : scipy.sparse.csr_matrix Document-term matrix of shape (D, V) with word counts. vocab : np.ndarray Vocabulary array of shape (V,) containing word terms. keywords : Dict[Any, List[str]] Dictionary mapping topic identifiers to lists of seed words. Keys can be strings or integers. Example: {0: ['climate', 'environment'], 1: ['economy', 'trade']} or {'climate': ['climate', 'environment'], 'economy': ['economy', 'trade']} residual_topics : int Number of residual (unsupervised) topics. Must be >= 0. batch_size : int Mini-batch size for stochastic variational inference. Must satisfy 0 < batch_size <= D. Attributes ---------- D : int Number of documents. V : int Vocabulary size. K : int Total number of topics (seeded + residual). counts : scipy.sparse.csr_matrix Document-term matrix. vocab : np.ndarray Vocabulary array. keywords : Dict[int, List[str]] Seed words for guided topics. residual_topics : int Number of unsupervised topics. Examples -------- >>> from scipy.sparse import random >>> import numpy as np >>> from topicmodels import SPF >>> counts = random(100, 500, density=0.01, format='csr') >>> vocab = np.array([f'word_{i}' for i in range(500)]) >>> keywords = { ... 0: ['word_1', 'word_2', 'word_3'], ... 1: ['word_10', 'word_11', 'word_12'], ... } >>> model = SPF(counts, vocab, keywords, residual_topics=5, batch_size=32) >>> params = model.train_step(num_steps=100, lr=0.01, random_seed=42) """ def __init__( self, counts: sparse.csr_matrix, vocab: np.ndarray, keywords: Dict[Any, List[str]], residual_topics: int, batch_size: int, ) -> None: """ Initialize the SPF model with input validation. Parameters ---------- counts : scipy.sparse.csr_matrix Document-term matrix. vocab : np.ndarray Vocabulary array. keywords : Dict[Any, List[str]] Seed words for each seeded topic. residual_topics : int Number of unsupervised topics. batch_size : int Mini-batch size. Raises ------ TypeError If counts is not sparse or keywords is not dict. ValueError If dimensions are invalid or keywords contain unknown terms. """ super().__init__() # Input validation if not sparse.issparse(counts): raise TypeError(f"counts must be a scipy sparse matrix, got {type(counts).__name__}") D, V = counts.shape if D == 0 or V == 0: raise ValueError(f"counts matrix is empty: shape ({D}, {V})") if vocab.shape[0] != V: raise ValueError(f"vocab size {vocab.shape[0]} != counts columns {V}") if not isinstance(keywords, dict): raise TypeError(f"keywords must be dict, got {type(keywords).__name__}") if residual_topics < 0: raise ValueError(f"residual_topics must be >= 0, got {residual_topics}") if batch_size <= 0 or batch_size > D: raise ValueError(f"batch_size must satisfy 0 < batch_size <= {D}, got {batch_size}") # validate that keywords is not empty and that each topic has at least one keyword as string if len(keywords) == 0: raise ValueError( "keywords dictionary is empty; must contain at least one topic with keywords" ) for topic_id, words in keywords.items(): if not isinstance(words, list) or len(words) == 0: raise ValueError( f"keywords for topic {topic_id} must be a non-empty list of strings" ) # Validate keywords are in vocabulary vocab_set = set(vocab) for topic_id, words in keywords.items(): for word in words: if word not in vocab_set: raise ValueError(f"Keyword '{word}' (topic {topic_id}) not in vocabulary") # Store validated inputs self.counts = counts self.V = V self.D = D self.vocab = vocab self.residual_topics = residual_topics self.K = residual_topics + len(keywords) self.keywords = keywords # Compute keyword indices vocab_lookup = {word: index for index, word in enumerate(vocab)} kw_indices_topics = [ (idx, vocab_lookup[keyword]) for idx, topic_id in enumerate(keywords.keys()) for keyword in keywords[topic_id] ] self.Tilde_V = len(kw_indices_topics) self.kw_indices = tuple(zip(*kw_indices_topics)) if kw_indices_topics else ((), ()) self.batch_size = batch_size def _model(self, Y_batch: jnp.ndarray, d_batch: jnp.ndarray) -> None: """ Define the probabilistic generative model using NumPyro. Model structure: - Beta (K x V): topic-word distributions with keyword boosts - Beta_tilde: additional weights for seeded words - Theta (D x K): document-topic distributions - Y_batch (batch_size x V): observed word counts Parameters ---------- Y_batch : jnp.ndarray Batch of observed word counts (batch_size, V). d_batch : jnp.ndarray Document indices in batch (batch_size,). """ # Topic-word distributions: Beta ~ Gamma(0.3, 0.3) with plate("k", size=self.K, dim=-2): with plate("k_v", size=self.V, dim=-1): beta = sample("beta", dist.Gamma(0.3, 0.3)) # Boost for seed words: Beta_tilde ~ Gamma(1, 0.3) with plate("tilde_v", size=self.Tilde_V): beta_tilde = sample("beta_tilde", dist.Gamma(1.0, 0.3)) # Add seed word boosts to beta beta = beta.at[self.kw_indices].add(beta_tilde) # Document-topic distributions: Theta ~ Gamma(0.3, 0.3) with plate("d", size=self.D, subsample_size=self.batch_size, dim=-2): with plate("d_k", size=self.K, dim=-1): theta = sample("theta", dist.Gamma(0.3, 0.3)) # Poisson rate parameter P = jnp.matmul(theta, beta) # Word counts likelihood with plate("v", size=self.V, dim=-1): sample("Y_batch", dist.Poisson(P), obs=Y_batch) def _guide(self, Y_batch: jnp.ndarray, d_batch: jnp.ndarray) -> None: """ Define the variational guide for the model. Parameters ---------- Y_batch : numpy.ndarray The observed word counts for the current batch. d_batch : numpy.ndarray Indices of documents in the current batch. """ # Define variational parameter a_beta = param( "beta_shape", init_value=jnp.ones([self.K, self.V]), constraint=constraints.positive ) b_beta = param( "beta_rate", init_value=jnp.ones([self.K, self.V]) * self.D / 1000 * 2, constraint=constraints.positive, ) a_theta = param( "theta_shape", init_value=jnp.ones([self.D, self.K]), constraint=constraints.positive ) b_theta = param( "theta_rate", init_value=jnp.ones([self.D, self.K]) * self.D / 1000, constraint=constraints.positive, ) a_beta_tilde = param( "beta_tilde_shape", init_value=jnp.ones([self.Tilde_V]) * 2, constraint=constraints.positive, ) b_beta_tilde = param( "beta_tilde_rate", init_value=jnp.ones([self.Tilde_V]), constraint=constraints.positive ) with plate("k", size=self.K, dim=-2): with plate("k_v", size=self.V, dim=-1): sample("beta", dist.Gamma(a_beta, b_beta)) with plate("tilde_v", size=self.Tilde_V): sample("beta_tilde", dist.Gamma(a_beta_tilde, b_beta_tilde)) with plate("d", size=self.D, subsample_size=self.batch_size, dim=-2): with plate("d_k", size=self.K, dim=-1): sample("theta", dist.Gamma(a_theta[d_batch], b_theta[d_batch]))
[docs] def return_topics(self): """ Return the topics for each document. Reimplemented from the base class due to the guided topic modeling approach, where topics are not fully unsupervised. Returns ------- tuple topics : numpy.ndarray Array of recoded topics. E_theta : numpy.ndarray Estimated topic proportions for each document. """ def recode_cats(argmaxes, keywords): """ Recodes the argmax index into topic strings. :param argmaxes: np.array() or jnp.array() because of vectorized parallel computing :param keywords: Dictionary containing keyword topics :return: Array of recoded topics """ num_keywords = len(keywords.keys()) max_index = num_keywords - 1 keyword_keys = np.array(list(keywords.keys())).astype(str) # clip argmaxes to be within the valid range of keyword topics argmaxes_clipped = np.clip(argmaxes, 0, max_index) topics = np.where( argmaxes <= max_index, keyword_keys[argmaxes_clipped], f"No_keyword_topic_{argmaxes - max_index}", ) return topics E_theta = self.estimated_params["theta_shape"] / self.estimated_params["theta_rate"] categories = np.argmax(E_theta, axis=1) topics = recode_cats(np.array(categories), self.keywords) return topics, E_theta
[docs] def return_beta(self): """ Return the beta matrix for the model, i.e. topic-word intensities. Reimplemented from the base class due to the higher rates approach for seed words. Returns ------- pandas.DataFrame DataFrame containing the beta matrix with words as rows and topics as columns. """ E_beta = self.estimated_params["beta_shape"] / self.estimated_params["beta_rate"] E_beta_tilde = ( self.estimated_params["beta_tilde_shape"] / self.estimated_params["beta_tilde_rate"] ) E_beta = E_beta.at[self.kw_indices].add(E_beta_tilde) rs_names = [f"residual_topic_{i+1}" for i in range(self.residual_topics)] return pd.DataFrame( jnp.transpose(E_beta), index=self.vocab, columns=list(self.keywords.keys()) + rs_names )
# def return_top_words_per_topic(self, n=10): # beta = self.return_beta() # return {topic: beta[topic].nlargest(n).index.tolist() for topic in beta} def _recode_topics(self, indices: np.ndarray) -> np.ndarray: keyword_keys = np.array(list(self.keywords.keys())) num_keywords = len(keyword_keys) indices_clipped = np.clip(indices, 0, num_keywords - 1) return np.where( indices < num_keywords, keyword_keys[indices_clipped], [f"No_keyword_topic_{i - num_keywords}" for i in indices], )
[docs] def plot_seed_effectiveness( self, save_path: Optional[str] = None, ) -> Tuple[plt.Figure, np.ndarray]: """Grouped bar chart comparing seed vs. non-seed word weights per topic. For every seeded topic, shows the mean beta weight of seed words alongside the mean beta weight of all other words. Helps assess whether seed words actually dominate their intended topics. Parameters ---------- save_path : str, optional Path to save the figure. Returns ------- tuple of (plt.Figure, np.ndarray of Axes) """ if not self.estimated_params: raise ValueError("Model must be trained first.") beta_df = self.return_beta() keyword_keys = list(self.keywords.keys()) vocab_list = list(self.vocab) topic_labels: List[str] = [] seed_means: List[float] = [] other_means: List[float] = [] for idx, topic_id in enumerate(keyword_keys): col = beta_df[topic_id] seed_words = set(self.keywords[topic_id]) is_seed = np.array([w in seed_words for w in vocab_list]) seed_means.append(float(col.values[is_seed].mean())) other_means.append(float(col.values[~is_seed].mean())) topic_labels.append(str(topic_id)) x = np.arange(len(topic_labels)) width = 0.35 with plt.rc_context(self._setup_academic_style()): fig, ax = plt.subplots(figsize=(max(6, len(topic_labels) * 1.2), 4)) ax.bar(x - width / 2, seed_means, width, label="Seed words", color="#4E79A7") ax.bar(x + width / 2, other_means, width, label="Other words", color="#E15759") ax.set_xticks(x) ax.set_xticklabels(topic_labels, rotation=30, ha="right") ax.set_ylabel(r"Mean $\beta$ weight") ax.set_title("Seed word effectiveness") ax.legend() fig.tight_layout() if save_path: fig.savefig(save_path, dpi=300, bbox_inches="tight") return fig, np.array([ax])
def _summary_extra(self) -> str: """SPF-specific summary information.""" kw_info = ", ".join( f"{k} ({len(v)} words)" for k, v in self.keywords.items() ) lines = [ f" Seeded topics: {len(self.keywords)}", f" Residual topics: {self.residual_topics}", f" Keyword groups: {kw_info}", ] return "\n".join(lines)