from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.sparse as sparse
from jax import jit, random
from numpyro.infer import SVI, TraceMeanField_ELBO
from optax import adam
from tqdm import tqdm
from wordcloud import WordCloud
from .Metrics import Metrics
class NumpyroModel(ABC):
    """
    Abstract base class for all used probabilistic models.

    Subclasses must implement at least their own model (``_model``) and
    variational guide (``_guide``).

    Attributes
    ----------
    Metrics : Metrics
        Instance metrics tracker (per instance, not shared).
    estimated_params : dict
        Estimated parameters after training.
    D : int
        Number of documents.
    V : int
        Vocabulary size.
    batch_size : int
        Mini-batch size for stochastic variational inference.
    counts : scipy.sparse.csr_matrix
        Document-term matrix.
    vocab : np.ndarray
        Vocabulary array.
    K : int
        Number of topics.
    """

    def __init__(self) -> None:
        """Initialize base model with per-instance metrics."""
        self.Metrics = Metrics(loss=[])
        self.estimated_params: Dict[str, Any] = {}
        # These will be set by child classes, declared here for type checking
        self.D: int
        self.V: int
        self.batch_size: int
        self.counts: sparse.csr_matrix
        self.vocab: np.ndarray
        self.K: int
        # Optional dense device copy of ``counts``; populated by
        # ``_prepare_dense_cache`` and cleared again when ``train_step`` ends.
        self._dense_counts_cache: Optional[jax.Array] = None

    @abstractmethod
    def _model(self, Y_batch: Any, d_batch: Any) -> None:
        """Define the probabilistic model."""
        pass

    @abstractmethod
    def _guide(self, Y_batch: Any, d_batch: Any) -> None:
        """Define the variational guide."""
        pass

    def _prepare_dense_cache(
        self, cache_dense_counts: Optional[bool], dense_cache_max_gb: float
    ) -> None:
        """Optionally cache counts as a dense JAX array for faster mini-batching.

        Parameters
        ----------
        cache_dense_counts : bool | None
            ``True`` builds the cache, ``False`` disables it, ``None`` builds
            it only if a dense float32 ``D x V`` matrix fits in
            ``dense_cache_max_gb``.
        dense_cache_max_gb : float
            Size limit in GiB (``1024**3`` bytes) used by the ``None`` mode.
        """
        self._dense_counts_cache = None
        if jax.default_backend().lower() == "metal":
            # Metal backend currently errors on this device_put path.
            return
        if cache_dense_counts is None:
            dense_size_bytes = self.D * self.V * np.dtype(np.float32).itemsize
            cache_dense_counts = dense_size_bytes <= dense_cache_max_gb * (1024**3)
        if cache_dense_counts:
            # Cast while still sparse so only one dense float32 copy is allocated.
            dense_counts = np.ascontiguousarray(self.counts.astype(np.float32).toarray())
            self._dense_counts_cache = jax.device_put(dense_counts)

    def _get_batch(self, rng: jax.Array, Y: sparse.csr_matrix) -> Tuple[jnp.ndarray, ...]:
        """
        Draw a random mini-batch of documents (sampled with replacement).

        Parameters
        ----------
        rng : jax.random.PRNGKey
            Random number generator key.
        Y : scipy.sparse.csr_matrix
            The word counts array. Not read when the dense cache is populated
            (the cache is built from ``self.counts``).

        Returns
        -------
        tuple
            Y_batch : jax.Array
                Float32 word counts for the batch, shape ``(batch_size, V)``.
            D_batch : jax.Array
                Indices of documents in the batch, shape ``(batch_size,)``.
        """
        D_batch = random.randint(rng, shape=(self.batch_size,), minval=0, maxval=self.D)
        if self._dense_counts_cache is not None:
            Y_batch = self._dense_counts_cache[D_batch]
        else:
            Y_batch = jnp.asarray(Y[np.asarray(D_batch)].toarray(), dtype=jnp.float32)
        # Ensure the shape of Y_batch is (batch_size, V)
        assert Y_batch.shape == (
            self.batch_size,
            self.V,
        ), f"Shape mismatch: {Y_batch.shape} != ({self.batch_size}, {self.V})"
        return Y_batch, D_batch

    def train_step(
        self,
        num_steps: int,
        lr: float,
        random_seed: Optional[int] = None,
        jit_compile: bool = True,
        cache_dense_counts: Optional[bool] = None,
        dense_cache_max_gb: float = 0.75,
    ) -> Dict[str, Any]:
        """
        Train the model using Stochastic Variational Inference (SVI).

        Per-step losses (divided by ``D``) are appended to
        ``self.Metrics.loss``.

        Parameters
        ----------
        num_steps : int
            Number of training iterations. Must be > 0.
        lr : float
            Learning rate for the Adam optimizer. Must be > 0.
        random_seed : int, optional
            Seed for the JAX PRNG. Every random draw of the run (initial
            batch, SVI initialisation, all mini-batches) is derived from it.
            ``None`` uses seed 0, so runs are deterministic by default.
        jit_compile : bool, optional
            Whether to JIT compile SVI updates. Keep enabled for long runs;
            disable to avoid compile overhead in very short runs.
        cache_dense_counts : bool | None, optional
            If True, cache sparse counts as dense array for faster batching.
            If None, auto-enable when estimated dense matrix size fits in
            ``dense_cache_max_gb``.
        dense_cache_max_gb : float, optional
            Maximum dense cache size in GB used by auto mode.

        Returns
        -------
        dict
            Estimated parameters after training.

        Raises
        ------
        ValueError
            If num_steps <= 0, lr <= 0 or dense_cache_max_gb <= 0.
        """
        if num_steps <= 0:
            raise ValueError(f"num_steps must be > 0, got {num_steps}")
        if lr <= 0:
            raise ValueError(f"lr must be > 0, got {lr}")
        if dense_cache_max_gb <= 0:
            raise ValueError(f"dense_cache_max_gb must be > 0, got {dense_cache_max_gb}")
        svi_batch = SVI(
            model=self._model, guide=self._guide, optim=adam(lr), loss=TraceMeanField_ELBO()
        )
        svi_batch_update = jit(svi_batch.update) if jit_compile else svi_batch.update
        self._prepare_dense_cache(
            cache_dense_counts=cache_dense_counts, dense_cache_max_gb=dense_cache_max_gb
        )
        # One root key so that ``random_seed`` controls the whole run.
        seed = 0 if random_seed is None else random_seed
        batch_key, init_key, steps_key = random.split(random.PRNGKey(seed), 3)
        try:
            Y_batch, D_batch = self._get_batch(batch_key, self.counts)
            svi_state = svi_batch.init(init_key, Y_batch=Y_batch, d_batch=D_batch)
            step_keys = random.split(steps_key, num_steps)
            with tqdm(range(num_steps)) as pbar:
                for step in pbar:
                    Y_batch, D_batch = self._get_batch(step_keys[step], self.counts)
                    svi_state, loss = svi_batch_update(
                        svi_state, Y_batch=Y_batch, d_batch=D_batch
                    )
                    # Losses are recorded per document (divided by D).
                    self.Metrics.loss.append(float(loss / self.D))
                    if step % 10 == 0:
                        pbar.set_description(
                            "Init loss: "
                            + "{:10.4f}".format(self.Metrics.loss[0])
                            + "; Avg loss (last 10 iter): "
                            + "{:10.4f}".format(np.mean(self.Metrics.loss[-10:]))
                        )
            self.estimated_params = svi_batch.get_params(svi_state)
        finally:
            # Release the (potentially large) dense cache even if training fails.
            self._dense_counts_cache = None
        return self.estimated_params

    def return_topics(self) -> Tuple[np.ndarray, np.ndarray]:
        """
        Return the topics for each document.

        Returns
        -------
        categories : np.ndarray
            Index of the largest topic weight for each document (shape: D,).
        E_theta : np.ndarray
            Estimated topic weights ``theta_shape / theta_rate`` for each
            document (shape: D, K).

        Raises
        ------
        ValueError
            If model has not been trained yet (no estimated parameters).
        """
        if not self.estimated_params:
            raise ValueError("Model must be trained before calling return_topics()")
        E_theta = np.asarray(
            self.estimated_params["theta_shape"] / self.estimated_params["theta_rate"]
        )
        return np.argmax(E_theta, axis=1), E_theta

    def return_beta(self) -> pd.DataFrame:
        """
        Return the beta matrix (word-topic associations) for the model.

        Returns
        -------
        pd.DataFrame
            DataFrame with words as index and topics as columns, containing
            the estimated word-topic weights ``beta_shape / beta_rate``
            (not normalised to probabilities).

        Raises
        ------
        ValueError
            If model has not been trained yet (no estimated parameters).
        """
        if not self.estimated_params:
            raise ValueError("Model must be trained before calling return_beta()")
        E_beta = np.asarray(
            self.estimated_params["beta_shape"] / self.estimated_params["beta_rate"]
        )
        # Transpose so that rows line up with ``self.vocab``.
        return pd.DataFrame(E_beta.T, index=self.vocab)

    def return_top_words_per_topic(self, n: int = 10) -> Dict[Any, List[str]]:
        """Return the ``n`` highest-weighted words of every topic.

        Parameters
        ----------
        n : int, optional
            Number of words per topic (default 10).

        Returns
        -------
        dict
            Maps each column label of :meth:`return_beta` to its top words,
            ordered from highest to lowest weight.
        """
        beta = self.return_beta()
        return {topic: beta[topic].nlargest(n).index.tolist() for topic in beta}

    def plot_model_loss(
        self, window: int = 10, save_path: Optional[str] = None
    ) -> "Tuple[plt.Figure, plt.Axes]":
        """
        Plot the training loss over time with full and smoothed curves.

        Parameters
        ----------
        window : int, optional
            Window size for moving average smoothing. Default is 10.
        save_path : str, optional
            Path to save the figure.

        Returns
        -------
        tuple of (plt.Figure, np.ndarray of plt.Axes)

        Raises
        ------
        ValueError
            If no training loss has been recorded yet.
        """
        if not self.Metrics.loss:
            raise ValueError("No training loss data available. Train the model first.")
        losses = self.Metrics.loss
        with plt.rc_context(self._setup_academic_style()):
            fig, axes = plt.subplots(1, 2, figsize=(14, 4))
            # Full loss curve
            axes[0].plot(losses)
            axes[0].set_xlabel("Step")
            axes[0].set_ylabel("ELBO Loss")
            axes[0].set_title("Training Loss Over Time")
            # Smoothed loss (centred moving average)
            smoothed = pd.Series(losses).rolling(window=window, center=True).mean()
            axes[1].plot(smoothed, linewidth=2)
            axes[1].set_xlabel("Step")
            axes[1].set_ylabel("ELBO Loss (smoothed)")
            axes[1].set_title(f"Training Loss (Moving Average, window={window})")
            fig.tight_layout()
            if save_path:
                fig.savefig(save_path, dpi=300, bbox_inches="tight")
        print(f"Initial loss: {losses[0]:.4f}")
        print(f"Final loss: {losses[-1]:.4f}")
        if losses[0] != 0:
            print(f"Loss reduction: {(1 - losses[-1] / losses[0]) * 100:.1f}%")
        else:
            # A relative reduction is undefined when the initial loss is 0.
            print("Loss reduction: n/a")
        return fig, axes

    def plot_topic_wordclouds(
        self,
        n_words: int = 50,
        figsize: Tuple[int, int] = (16, 12),
        save_path: Optional[str] = None,
    ) -> "Tuple[plt.Figure, np.ndarray]":
        """
        Plot wordclouds for each topic based on beta values.

        Parameters
        ----------
        n_words : int, optional
            Maximum number of words per wordcloud (default 50).
        figsize : tuple, optional
            Figure size ``(width, height)`` (default ``(16, 12)``).
        save_path : str, optional
            Path to save the figure.

        Returns
        -------
        tuple of (plt.Figure, np.ndarray of Axes)
            The axes array is always 2-D.
        """
        if not self.estimated_params:
            raise ValueError("Model must be trained before calling plot_topic_wordclouds()")
        beta_df = self.return_beta()
        K = beta_df.shape[1]
        # Roughly square grid of panels.
        n_cols = int(np.ceil(np.sqrt(K)))
        n_rows = int(np.ceil(K / n_cols))
        with plt.rc_context(self._setup_academic_style()):
            fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize)
            axes = np.atleast_2d(axes)
            axes_flat = axes.flatten()
            for topic_idx in range(K):
                topic_col = beta_df.iloc[:, topic_idx]
                word_freq = dict(topic_col.nlargest(n_words))
                if word_freq:
                    wc = WordCloud(
                        width=400,
                        height=300,
                        background_color="white",
                        relative_scaling=0.5,
                        min_font_size=10,
                    ).generate_from_frequencies(word_freq)
                    axes_flat[topic_idx].imshow(wc, interpolation="bilinear")
                col_name = beta_df.columns[topic_idx]
                # Integer labels (Python or NumPy ints) get a "Topic k" title.
                if isinstance(col_name, (int, np.integer)):
                    title = f"Topic {col_name}"
                else:
                    title = str(col_name)
                axes_flat[topic_idx].set_title(title, fontsize=11, fontweight="bold")
                axes_flat[topic_idx].axis("off")
            # Hide grid cells left over after the last topic.
            for idx in range(K, len(axes_flat)):
                axes_flat[idx].axis("off")
            fig.tight_layout()
            if save_path:
                fig.savefig(save_path, dpi=300, bbox_inches="tight")
        return fig, axes

    # ------------------------------------------------------------------
    # Academic-style plotting
    # ------------------------------------------------------------------
    @staticmethod
    def _setup_academic_style() -> Dict[str, Any]:
        """Return matplotlib rcParams overrides for a clean academic look.

        All plot methods in the library use this via
        ``plt.rc_context(self._setup_academic_style())``.

        Returns
        -------
        dict
            Matplotlib rcParams dictionary.
        """
        return {
            "font.family": "serif",
            "font.size": 9,
            "axes.titlesize": 11,
            "axes.labelsize": 10,
            "xtick.labelsize": 8,
            "ytick.labelsize": 8,
            "legend.fontsize": 8,
            "figure.dpi": 150,
            "axes.spines.top": False,
            "axes.spines.right": False,
            "axes.linewidth": 0.6,
            "xtick.major.width": 0.6,
            "ytick.major.width": 0.6,
            "lines.linewidth": 1.0,
            "axes.grid": True,
            "axes.grid.axis": "x",
            "grid.alpha": 0.15,
            "grid.linewidth": 0.4,
            "grid.color": "#999999",
        }

    # ------------------------------------------------------------------
    # Model summary
    # ------------------------------------------------------------------
    def summary(self, n_top_words: int = 5) -> str:
        """Print and return a formatted text summary of the fitted model.

        Includes model class name, dimensions, loss trajectory, and
        top words per topic. Subclasses can extend the output by
        overriding :meth:`_summary_extra`.

        Parameters
        ----------
        n_top_words : int, optional
            Number of top words to show per topic (default 5).

        Returns
        -------
        str
            Multi-line summary string.
        """
        lines: List[str] = []
        sep = "=" * 60
        lines.append(sep)
        lines.append(f" Model: {self.__class__.__name__}")
        lines.append(f" Topics (K): {self.K}")
        lines.append(f" Vocabulary (V): {self.V}")
        lines.append(f" Documents (D): {self.D}")
        lines.append(f" Batch size: {self.batch_size}")
        # Subclass-specific info
        extra = self._summary_extra()
        if extra:
            lines.append(extra)
        lines.append(sep)
        # Loss information
        losses = self.Metrics.loss
        if losses:
            lines.append(f" Initial ELBO loss: {losses[0]:.4f}")
            lines.append(f" Final ELBO loss: {losses[-1]:.4f}")
            if losses[0] != 0:
                reduction = (1 - losses[-1] / losses[0]) * 100
                lines.append(f" Loss reduction: {reduction:.1f}%")
            else:
                # A relative reduction is undefined when the initial loss is 0.
                lines.append(" Loss reduction: n/a")
            lines.append(f" Training steps: {len(losses)}")
        else:
            lines.append(" Model has not been trained yet.")
        lines.append(sep)
        # Top words per topic (best effort: subclasses may store other params)
        if self.estimated_params:
            try:
                top_words = self.return_top_words_per_topic(n=n_top_words)
                lines.append(" Top words per topic:")
                for topic, words in top_words.items():
                    label = str(topic)
                    lines.append(f" {label:>25s}: {', '.join(words)}")
            except Exception:
                lines.append(" (top words not available for this model)")
            lines.append(sep)
        result = "\n".join(lines)
        print(result)
        return result

    def _summary_extra(self) -> str:
        """Hook for subclass-specific summary lines.

        Override in subclasses to append model-specific information
        to :meth:`summary`. Return an empty string to add nothing.

        Returns
        -------
        str
        """
        return ""

    # ------------------------------------------------------------------
    # Topic-quality metrics
    # ------------------------------------------------------------------
    def compute_topic_coherence(
        self,
        texts: Optional[List[List[str]]] = None,
        metric: str = "c_npmi",
        top_n: int = 10,
    ) -> pd.DataFrame:
        """Compute topic coherence scores (NPMI or UMass).

        Occurrence is binary: a word occurs in a document when its count is
        positive. For each topic the ``top_n`` highest-weighted words are
        taken in rank order, and the score is the mean over all word pairs
        ``(w_i, w_j)`` with ``w_i`` ranked above ``w_j``:

        * ``'u_mass'``: ``log((D(w_i, w_j) + eps) / (D(w_i) + eps))``, i.e.
          conditioned on the higher-ranked word (Mimno et al., 2011).
        * ``'c_npmi'``: ``log2(P(w_i, w_j) / (P(w_i) P(w_j))) / -log2(P(w_i, w_j))``.
          Pairs that never co-occur score ``-1`` (the NPMI lower bound);
          pairs with ``P(w_i, w_j) == 1`` score ``0``.

        Parameters
        ----------
        texts : list of list of str, optional
            Tokenised reference corpus. If ``None``, word co-occurrence
            is estimated from ``self.counts`` and ``self.vocab``. Tokens not
            in ``self.vocab`` are ignored.
        metric : ``{'c_npmi', 'u_mass'}``, optional
            Coherence measure (default ``'c_npmi'``).
        top_n : int, optional
            Number of top words per topic used for the calculation
            (default 10).

        Returns
        -------
        pd.DataFrame
            DataFrame with columns ``['topic', 'coherence']``; topics with
            fewer than two words in the vocabulary get ``NaN``. The frame is
            also stored as ``self.Metrics.coherence_scores``.

        Raises
        ------
        ValueError
            If ``metric`` is unknown, ``texts`` is empty, or the model has
            not been trained.
        """
        if metric not in ("c_npmi", "u_mass"):
            raise ValueError(f"metric must be 'c_npmi' or 'u_mass', got {metric!r}")
        beta_df = self.return_beta()
        topic_names = [str(c) for c in beta_df.columns]
        word2idx = {w: i for i, w in enumerate(self.vocab)}

        if texts is not None:
            # Reduce each reference document to the set of vocabulary ids it contains.
            doc_sets = [{word2idx[tok] for tok in doc if tok in word2idx} for doc in texts]
            if not doc_sets:
                raise ValueError("texts must contain at least one document")
            n_docs = len(doc_sets)

            def presence(indices: List[int]) -> np.ndarray:
                """0/1 matrix (n_docs, len(indices)) of word occurrence."""
                return np.array(
                    [[idx in doc for idx in indices] for doc in doc_sets], dtype=np.float64
                )

        else:
            # CSC makes column slicing cheap; only top-word columns get densified.
            counts_csc = sparse.csc_matrix(self.counts)
            n_docs = counts_csc.shape[0]

            def presence(indices: List[int]) -> np.ndarray:
                """0/1 matrix (n_docs, len(indices)) of word occurrence."""
                return (counts_csc[:, indices].toarray() > 0).astype(np.float64)

        eps = 1e-12
        scores: List[float] = []
        for k in range(beta_df.shape[1]):
            words = beta_df.iloc[:, k].nlargest(top_n).index.tolist()
            indices = [word2idx[w] for w in words if w in word2idx]
            if len(indices) < 2:
                scores.append(float("nan"))
                continue
            occurs = presence(indices)
            doc_freq = occurs.sum(axis=0)  # D(w)
            co_freq = occurs.T @ occurs  # D(w_i, w_j)
            # Upper-triangle pairs: ``hi`` is always ranked above ``lo``.
            hi, lo = np.triu_indices(len(indices), k=1)
            joint = co_freq[hi, lo]
            if metric == "u_mass":
                pair_scores = np.log((joint + eps) / (doc_freq[hi] + eps))
            else:
                p_hi = doc_freq[hi] / n_docs
                p_lo = doc_freq[lo] / n_docs
                p_joint = joint / n_docs
                pmi = np.log2((p_joint + eps) / (p_hi * p_lo + eps))
                denom = -np.log2(p_joint + eps)
                valid = denom > eps
                pair_scores = np.where(valid, pmi / np.where(valid, denom, 1.0), 0.0)
                # Never co-occurring pairs take the NPMI lower bound.
                pair_scores = np.where(joint == 0, -1.0, pair_scores)
            scores.append(float(pair_scores.mean()))

        df = pd.DataFrame({"topic": topic_names, "coherence": scores})
        self.Metrics.coherence_scores = df
        return df

    def compute_topic_diversity(self, top_n: int = 25) -> float:
        """Compute topic diversity (Dieng et al., 2020).

        Measures the fraction of unique words across all topics' top-n
        lists. Values near 1.0 indicate diverse topics; near 0
        indicates redundancy. The score is also stored as
        ``self.Metrics.diversity``.

        Parameters
        ----------
        top_n : int, optional
            Number of top words per topic (default 25).

        Returns
        -------
        float
            Topic diversity score in ``[0, 1]``.
        """
        beta_df = self.return_beta()
        all_words: List[str] = [
            word
            for k in range(beta_df.shape[1])
            for word in beta_df.iloc[:, k].nlargest(top_n).index.tolist()
        ]
        diversity = len(set(all_words)) / max(len(all_words), 1)
        self.Metrics.diversity = diversity
        return diversity

    # ------------------------------------------------------------------
    # Additional post-fitting plots
    # ------------------------------------------------------------------
    def plot_topic_prevalence(
        self, save_path: Optional[str] = None
    ) -> "Tuple[plt.Figure, plt.Axes]":
        """Horizontal bar chart of mean topic prevalence across documents.

        Bars are sorted by prevalence, with the largest at the top.

        Parameters
        ----------
        save_path : str, optional
            Path to save the figure.

        Returns
        -------
        tuple of (plt.Figure, plt.Axes)
        """
        if not self.estimated_params:
            raise ValueError("Model must be trained first.")
        _, E_theta = self.return_topics()
        mean_prev = np.asarray(E_theta).mean(axis=0)
        beta_df = self.return_beta()
        topic_labels = [str(c) for c in beta_df.columns]
        order = np.argsort(mean_prev)
        with plt.rc_context(self._setup_academic_style()):
            fig, ax = plt.subplots(figsize=(6, max(3, 0.35 * len(topic_labels))))
            ax.barh(
                np.arange(len(order)),
                mean_prev[order],
                color="#4E79A7",
                edgecolor="white",
                linewidth=0.3,
            )
            ax.set_yticks(np.arange(len(order)))
            ax.set_yticklabels([topic_labels[i] for i in order])
            ax.set_xlabel("Mean topic weight")
            ax.set_title("Topic prevalence")
            fig.tight_layout()
            if save_path:
                fig.savefig(save_path, dpi=300, bbox_inches="tight")
        return fig, ax

    def plot_topic_correlation(
        self, save_path: Optional[str] = None
    ) -> "Tuple[plt.Figure, plt.Axes]":
        """Heatmap of pairwise cosine similarity between topic-word vectors.

        Parameters
        ----------
        save_path : str, optional
            Path to save the figure.

        Returns
        -------
        tuple of (plt.Figure, plt.Axes)
        """
        if not self.estimated_params:
            raise ValueError("Model must be trained first.")
        beta_df = self.return_beta()
        beta_mat = beta_df.values.T  # (K, V)
        norms = np.linalg.norm(beta_mat, axis=1, keepdims=True)
        # All-zero topic vectors would divide by zero; leave them as zeros.
        norms = np.where(norms == 0, 1.0, norms)
        normed = beta_mat / norms
        sim = normed @ normed.T
        topic_labels = [str(c) for c in beta_df.columns]
        with plt.rc_context(self._setup_academic_style()):
            fig, ax = plt.subplots(
                figsize=(max(5, 0.6 * len(topic_labels)), max(4, 0.5 * len(topic_labels)))
            )
            im = ax.imshow(sim, cmap="RdBu_r", vmin=-1, vmax=1, aspect="auto")
            ax.set_xticks(range(len(topic_labels)))
            ax.set_xticklabels(topic_labels, rotation=45, ha="right")
            ax.set_yticks(range(len(topic_labels)))
            ax.set_yticklabels(topic_labels)
            ax.set_title("Topic similarity (cosine)")
            fig.colorbar(im, ax=ax, shrink=0.8)
            fig.tight_layout()
            if save_path:
                fig.savefig(save_path, dpi=300, bbox_inches="tight")
        return fig, ax

    def plot_document_topic_heatmap(
        self,
        n_docs: int = 50,
        sort_by_topic: bool = False,
        save_path: Optional[str] = None,
    ) -> "Tuple[plt.Figure, plt.Axes]":
        """Heatmap of document-topic proportions for the first documents.

        Parameters
        ----------
        n_docs : int, optional
            Number of documents to display, taken from the start of the
            corpus (default 50).
        sort_by_topic : bool, optional
            If True, reorder the displayed documents by their dominant
            topic, keeping corpus order within a topic (default False).
        save_path : str, optional
            Path to save the figure.

        Returns
        -------
        tuple of (plt.Figure, plt.Axes)
        """
        if not self.estimated_params:
            raise ValueError("Model must be trained first.")
        cats, E_theta = self.return_topics()
        cats = np.asarray(cats)
        E_theta = np.asarray(E_theta)
        n_docs = min(n_docs, E_theta.shape[0])
        order = np.arange(n_docs)
        if sort_by_topic:
            # Reorder the selected documents only; stable keeps corpus order.
            order = order[np.argsort(cats[:n_docs], kind="stable")]
        subset = E_theta[order]
        # Row-normalise for visualisation
        row_sums = subset.sum(axis=1, keepdims=True)
        row_sums = np.where(row_sums == 0, 1.0, row_sums)
        subset_norm = subset / row_sums
        beta_df = self.return_beta()
        topic_labels = [str(c) for c in beta_df.columns]
        with plt.rc_context(self._setup_academic_style()):
            fig, ax = plt.subplots(
                figsize=(max(6, 0.5 * len(topic_labels)), max(6, 0.15 * n_docs))
            )
            im = ax.imshow(subset_norm, aspect="auto", cmap="YlOrRd", interpolation="nearest")
            ax.set_xlabel("Topic")
            ax.set_ylabel("Document")
            ax.set_xticks(range(len(topic_labels)))
            ax.set_xticklabels(topic_labels, rotation=45, ha="right")
            ax.set_title(f"Document-topic proportions (n={n_docs})")
            fig.colorbar(im, ax=ax, shrink=0.8)
            fig.tight_layout()
            if save_path:
                fig.savefig(save_path, dpi=300, bbox_inches="tight")
        return fig, ax