import warnings
from typing import List, Optional, Tuple
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import numpyro.distributions as dist
import pandas as pd
import scipy.sparse as sparse
import seaborn as sns
from jax import jit, random
from numpyro import param, plate, sample
from numpyro.distributions import constraints
from numpyro.infer import SVI, TraceMeanField_ELBO
from optax import adam
from scipy.special import digamma
from tqdm import tqdm
from wordcloud import WordCloud
# Abstract class - defining the minimum requirements for the probabilistic model
from .numpyro_model import NumpyroModel
[docs]
class STBS(NumpyroModel):
    """
    STBS Model

    This class models structural text-based scaling (STBS), including
    topic-specific ideal points and author-specific covariates for
    documents authored by different individuals. The model aims to
    capture how ideology can vary by topic and with external variables.

    Word counts are given a Poisson likelihood whose rate combines
    document-topic intensities, topic-word intensities and an ideological
    term; the prior mean of each author's ideal points is the product of
    the author-level design matrix and the regression coefficients ``iota``.
    """
def __init__(
    self,
    counts: sparse.csr_matrix,
    vocab: np.ndarray,
    num_topics: int,
    authors: np.ndarray,
    batch_size: int,
    X_design_matrix: Optional[np.ndarray] = None,
    beta_shape_init: Optional[np.ndarray] = None,
    beta_rate_init: Optional[np.ndarray] = None,
    theta_shape_init: Optional[np.ndarray] = None,
    theta_rate_init: Optional[np.ndarray] = None,
    i_mu_init: Optional[np.ndarray] = None,
) -> None:
    """
    Initialize the STBS model.

    Parameters
    ----------
    counts : scipy.sparse.csr_matrix
        A 2D sparse array of shape (D, V) representing the word counts in each document,
        where D is the number of documents and V is the vocabulary size.
    vocab : np.ndarray
        A vocabulary array of shape (V,) containing word terms.
    num_topics : int
        The number of topics (K). Must be > 0.
    authors : np.ndarray or list
        An array of authors for each document (length D).
    batch_size : int
        The number of documents to be processed in each batch.
        Must satisfy 0 < batch_size <= D.
    X_design_matrix : np.ndarray or pd.DataFrame, optional
        Author-level covariates of shape (N, L). Row i must correspond to the
        i-th element of the sorted unique authors from `authors` (i.e., np.unique(authors)).
        If omitted, a single column of ones (an intercept) is used and the
        covariate is named ``"intercept"``.
    beta_shape_init : np.ndarray, optional
        Initial shape parameters for the topic-word distributions (default is None).
        Must have shape (K, V) if provided.
    beta_rate_init : np.ndarray, optional
        Initial rate parameters for the topic-word distributions (default is None).
        Must have shape (K, V) if provided.
    theta_shape_init : np.ndarray, optional
        Initial shape parameters for the document-topic distributions (default is None).
        Must have shape (D, K) if provided.
    theta_rate_init : np.ndarray, optional
        Initial rate parameters for the document-topic distributions (default is None).
        Must have shape (D, K) if provided.
    i_mu_init : np.ndarray, optional
        Initial mean parameters for the ideology-topic distributions (default is None).
        Must have shape (N, ) if provided.

    Raises
    ------
    TypeError
        If counts is not a sparse matrix.
    ValueError
        If dimensions are invalid or an initial-value array has the wrong
        type or shape.
    """
    super().__init__()

    # Input validation
    if not sparse.issparse(counts):
        raise TypeError(f"counts must be a scipy sparse matrix, got {type(counts).__name__}")
    D, V = counts.shape
    if D == 0 or V == 0:
        raise ValueError(f"counts matrix is empty: shape ({D}, {V})")
    if num_topics <= 0:
        raise ValueError(f"num_topics must be > 0, got {num_topics}")
    if batch_size <= 0 or batch_size > D:
        raise ValueError(f"batch_size must satisfy 0 < batch_size <= {D}, got {batch_size}")
    # Accept list-like vocabularies as well as arrays.
    vocab = np.asarray(vocab)
    if vocab.shape[0] != V:
        raise ValueError(f"vocab size {vocab.shape[0]} != counts columns {V}")

    # Convert authors to array-like
    authors = np.asarray(authors)
    if len(authors) != D:
        raise ValueError(f"authors length {len(authors)} != counts rows {D}")

    # np.unique sorts, so author index i refers to the i-th sorted author.
    self.authors_unique = np.unique(authors)
    self.author_map = {speaker: idx for idx, speaker in enumerate(self.authors_unique)}
    self.author_indices = np.array([self.author_map[a] for a in authors])
    self.N = len(self.authors_unique)  # number of people
    self.counts = counts
    self.D = D
    self.V = V
    self.K = num_topics
    self.batch_size = batch_size  # number of documents in a batch
    self.vocab = vocab

    if X_design_matrix is not None:
        if isinstance(X_design_matrix, pd.DataFrame):
            covariate_names = list(X_design_matrix.columns)
            X_design_matrix = X_design_matrix.values
        else:
            covariate_names = None
            # Convert before touching .shape so lists and 1-D inputs reach
            # the explicit checks below instead of failing on attribute access.
            X_design_matrix = np.asarray(X_design_matrix)
        if X_design_matrix.ndim != 2:
            raise ValueError(f"covariates must be 2D, got shape {X_design_matrix.shape}")
        if X_design_matrix.shape[0] != self.N:
            raise ValueError(
                f"covariates has {X_design_matrix.shape[0]} rows, expected {self.N}"
            )
        if X_design_matrix.shape[1] == 0:
            raise ValueError("covariates matrix is empty (0 columns)")
        if covariate_names is None:
            covariate_names = [f"cov_{i}" for i in range(X_design_matrix.shape[1])]
        self.covariates = covariate_names
        self.X_design_matrix = jnp.array(X_design_matrix)
    else:
        # Intercept-only design: still record a covariate name so that the
        # summary and coefficient accessors work without a design matrix.
        self.covariates = ["intercept"]
        self.X_design_matrix = jnp.ones((self.N, 1))
    self.L = self.X_design_matrix.shape[1]

    # check if initialization parameters have the correct shape and are jnp.arrays
    beta_inits = (beta_shape_init, beta_rate_init)
    if any(inits is None for inits in beta_inits):
        warnings.warn(
            "No initial values for beta parameters were provided. "
            "The model will initialize them uniformly.",
            stacklevel=2,
        )
    for inits in beta_inits:
        if inits is None:
            continue
        if not isinstance(inits, (np.ndarray, jnp.ndarray)):
            raise ValueError(
                "beta_shape_init and beta_rate_init must be numpy or jnp.ndarray objects "
                "with matching dimensions [num_topics times num_words]."
            )
        if inits.shape != (self.K, self.V):
            raise ValueError(
                f"beta_shape_init and beta_rate_init must have shape ({self.K}, {self.V}), "
                f"got {inits.shape}"
            )
    self.beta_rate_init = beta_rate_init
    self.beta_shape_init = beta_shape_init

    theta_inits = (theta_shape_init, theta_rate_init)
    if any(inits is None for inits in theta_inits):
        warnings.warn(
            "No initial values for theta parameters were provided. "
            "The model will initialize them uniformly.",
            stacklevel=2,
        )
    for inits in theta_inits:
        if inits is None:
            continue
        if not isinstance(inits, (np.ndarray, jnp.ndarray)):
            raise ValueError(
                "theta_shape_init and theta_rate_init must be numpy or jnp.ndarray objects "
                "with matching dimensions [num_documents times num_topics]."
            )
        if inits.shape != (self.D, self.K):
            raise ValueError(
                f"theta_shape_init and theta_rate_init must have shape ({self.D}, {self.K}), "
                f"got {inits.shape}"
            )
    self.theta_rate_init = theta_rate_init
    self.theta_shape_init = theta_shape_init

    if i_mu_init is None:
        warnings.warn(
            "No initial values for the ideology parameters were provided. "
            "The model will initialize them uniformly.",
            stacklevel=2,
        )
    else:
        if not isinstance(i_mu_init, (np.ndarray, jnp.ndarray)):
            raise ValueError(
                "i_mu_init must be a numpy or jnp.ndarray object " "with shape [num_authors, ]."
            )
        if i_mu_init.shape != (self.N,):
            raise ValueError(
                f"i_mu_init must have shape ({self.N},), " f"got {i_mu_init.shape}"
            )
    self.i_mu_init = i_mu_init
def _model(self, Y_batch: jnp.ndarray, d_batch: jnp.ndarray, i_batch: jnp.ndarray) -> None:  # type: ignore[override]
    """Define the probabilistic model using NumPyro.

    Model structure (shapes follow the plates declared below):
    - beta (K x V): topic-word intensities, Gamma with per-word rate ``b_beta``
    - eta (K x V): ideal point loadings for words, Normal with per-topic
      precision ``rho``
    - iota (L x K): regression coefficients of the covariates for each topic,
      Normal around the per-covariate mean ``iota_dot`` with precision ``omega``
    - i (N x K): author-topic ideal points, Normal around
      ``X_design_matrix @ iota`` with per-author precision ``I``
    - theta (batch_size x K): document-topic intensities, Gamma with the
      rate ``b_author`` of the document's author
    - Y_batch: observed word counts with Poisson likelihood

    Parameters
    ----------
    Y_batch : jnp.ndarray
        The observed word counts for the current batch (batch_size, V).
    d_batch : jnp.ndarray
        Indices of documents in the current batch (batch_size,).
        Not referenced in the model body; it is part of the signature
        shared with ``_guide``.
    i_batch : jnp.ndarray
        Indices of authors for the documents in the batch (batch_size,).
    """
    # Per-word rate of the Gamma prior on beta.
    with plate("v", size=self.V, dim=-1):
        b_beta = sample("b_beta", dist.Gamma(0.3, 1.0))
    # Per-topic precision of eta (rho) and its own Gamma rate (b_rho).
    with plate("k", size=self.K, dim=-1):
        b_rho = sample("b_rho", dist.Gamma(0.3, 1.0))
        rho = sample("rho", dist.Gamma(0.3, b_rho))
    with plate("k", size=self.K, dim=-2):
        with plate("k_v", size=self.V, dim=-1):
            beta = sample("beta", dist.Gamma(0.3, b_beta))
            # tile(...).T repeats the per-topic std 1/sqrt(rho) across the V words -> (K, V).
            eta = sample("eta", dist.Normal(0, jnp.tile(1 / jnp.sqrt(rho), (self.V, 1)).T))
    # Per-covariate hierarchy: precision omega (rate b_omega) and mean iota_dot.
    with plate("l", size=self.L, dim=-1):
        b_omega = sample("b_omega", dist.Gamma(0.3, 1.0))
        omega = sample("omega", dist.Gamma(0.3, b_omega))
        iota_dot = sample("iota_dot", dist.Normal(0, 1))
    with plate("l", size=self.L, dim=-2):
        with plate("l_k", size=self.K, dim=-1):
            # Mean and std are repeated across the K topics -> (L, K).
            iota = sample(
                "iota",
                dist.Normal(
                    jnp.tile(iota_dot, (self.K, 1)).T,
                    jnp.tile(1 / jnp.sqrt(omega), (self.K, 1)).T,
                ),
            )
    # Prior mean of the ideal points: (N, L) @ (L, K) -> (N, K).
    i_mu = jnp.matmul(self.X_design_matrix, iota)
    # Per-author precision of the ideal points.
    with plate("n", size=self.N, dim=-1):
        I = sample("I", dist.Gamma(0.3, 0.3))
    with plate("n", size=self.N, dim=-2):
        with plate("k", size=self.K, dim=-1):
            # Sample the per-unit latent variables (ideal points)
            i = sample("i", dist.Normal(i_mu, jnp.tile(1 / jnp.sqrt(I), (self.K, 1)).T))
    # Per-author rate of the Gamma prior on theta.
    with plate("n", size=self.N, dim=-1):
        b_author = sample("b_author", dist.Gamma(0.3, 1.0))
    # The batch is declared as a subsample of the D documents.
    with plate("d", size=self.D, subsample_size=self.batch_size, dim=-2):
        # Look up each document's author rate and repeat it across topics -> (batch_size, K).
        b_author_d = b_author[i_batch]
        b_author_dk = jnp.tile(b_author_d.reshape(-1, 1), (1, self.K))
        with plate("d_k", size=self.K, dim=-1):
            # Sample document-level latent variables (topic intensities)
            theta = sample("theta", dist.Gamma(0.3, b_author_dk))
        # Compute Poisson rates for each word
        # Broadcast to (batch_size, K, V) and sum over the topic axis -> (batch_size, V).
        P = jnp.sum(
            jnp.expand_dims(theta, axis=-1)
            * jnp.expand_dims(beta, axis=0)
            * jnp.exp(jnp.expand_dims(eta, axis=0) * jnp.expand_dims(i[i_batch], axis=-1)),
            1,
        )
        with plate("v", size=self.V, dim=-1):
            # Sample observed words
            sample("Y_batch", dist.Poisson(P), obs=Y_batch)
def _guide(self, Y_batch: jnp.ndarray, d_batch: jnp.ndarray, i_batch: jnp.ndarray) -> None:  # type: ignore[override]
    """Mean-field variational guide for the model.

    Every Gamma-distributed latent of ``_model`` gets a Gamma factor with
    learnable positive ``*_shape`` / ``*_rate`` parameters, and every Normal
    latent gets a Normal factor with a free mean and a positive scale.
    User-supplied initial values (beta, theta, ideal-point means) replace the
    default initialisation when they were stored on the instance.

    Parameters
    ----------
    Y_batch : jnp.ndarray
        The observed word counts for the current batch (batch_size, V).
        Not referenced in the guide body.
    d_batch : jnp.ndarray
        Indices of documents in the current batch (batch_size,); used to
        select the rows of the theta parameters.
    i_batch : jnp.ndarray
        Indices of authors for the documents in the batch (batch_size,).
        Not referenced in the guide body.
    """
    K, V, L, N, D = self.K, self.V, self.L, self.N, self.D

    def positive(name, init):
        # Register a variational parameter constrained to be strictly positive.
        return param(name, init_value=init, constraint=constraints.positive)

    # --- word-level rate of beta ---
    b_beta_shape = positive("b_beta_shape", jnp.ones(V))
    b_beta_rate = positive("b_beta_rate", jnp.ones(V))

    # --- topic-word intensities (optionally user-initialised) ---
    beta_rate = positive(
        "beta_rate",
        self.beta_rate_init if self.beta_rate_init is not None else jnp.ones([K, V]),
    )
    beta_shape = positive(
        "beta_shape",
        self.beta_shape_init if self.beta_shape_init is not None else jnp.ones([K, V]),
    )

    # --- topic-level precision of eta ---
    b_rho_shape = positive("b_rho_shape", jnp.ones(K))
    b_rho_rate = positive("b_rho_rate", jnp.ones(K))
    rho_shape = positive("rho_shape", jnp.ones(K))
    rho_rate = positive("rho_rate", jnp.ones(K))

    # --- ideological word loadings ---
    mu_eta = param("mu_eta", init_value=random.normal(random.PRNGKey(2), (K, V)))
    sigma_eta = positive("sigma_eta", jnp.ones([K, V]))

    # --- covariate-level hierarchy ---
    b_omega_shape = positive("b_omega_shape", jnp.ones(L))
    b_omega_rate = positive("b_omega_rate", jnp.ones(L))
    omega_shape = positive("omega_shape", jnp.ones(L))
    omega_rate = positive("omega_rate", jnp.ones(L))
    mu_iota_dot = param("mu_iota_dot", init_value=jnp.zeros(L))
    sigma_iota_dot = positive("sigma_iota_dot", jnp.ones(L))
    mu_iota = param("mu_iota", init_value=jnp.zeros([L, K]))
    sigma_iota = positive("sigma_iota", jnp.ones([L, K]))

    # --- author-level precision and ideal points ---
    I_shape = positive("I_shape", jnp.ones(N))
    I_rate = positive("I_rate", jnp.ones(N))
    if self.i_mu_init is None:
        mu_i_start = jnp.zeros((N, K))
    else:
        # One initial position per author, repeated across all topics -> (N, K).
        mu_i_start = jnp.tile(self.i_mu_init, (K, 1)).T
    mu_i = param("mu_i", init_value=mu_i_start)
    sigma_i = positive("sigma_i", jnp.ones((N, K)))

    # --- author-level rate of theta ---
    b_author_shape = positive("b_author_shape", jnp.ones(N))
    b_author_rate = positive("b_author_rate", jnp.ones(N))

    # --- document-topic intensities (optionally user-initialised) ---
    theta_rate = positive(
        "theta_rate",
        self.theta_rate_init if self.theta_rate_init is not None else jnp.ones([D, K]),
    )
    theta_shape = positive(
        "theta_shape",
        self.theta_shape_init if self.theta_shape_init is not None else jnp.ones([D, K]),
    )

    # Sample sites, in the same order and plates as in the model.
    with plate("v", size=V, dim=-1):
        sample("b_beta", dist.Gamma(b_beta_shape, b_beta_rate))
    with plate("k", size=K, dim=-1):
        sample("b_rho", dist.Gamma(b_rho_shape, b_rho_rate))
        sample("rho", dist.Gamma(rho_shape, rho_rate))
    with plate("k", size=K, dim=-2), plate("k_v", size=V, dim=-1):
        sample("beta", dist.Gamma(beta_shape, beta_rate))
        sample("eta", dist.Normal(mu_eta, sigma_eta))
    with plate("l", size=L, dim=-1):
        sample("b_omega", dist.Gamma(b_omega_shape, b_omega_rate))
        sample("omega", dist.Gamma(omega_shape, omega_rate))
        sample("iota_dot", dist.Normal(mu_iota_dot, sigma_iota_dot))
    with plate("l", size=L, dim=-2), plate("l_k", size=K, dim=-1):
        sample("iota", dist.Normal(mu_iota, sigma_iota))
    with plate("n", size=N, dim=-1):
        sample("I", dist.Gamma(I_shape, I_rate))
    with plate("n", size=N, dim=-2), plate("k", size=K, dim=-1):
        sample("i", dist.Normal(mu_i, sigma_i))
    with plate("n", N, dim=-1):
        sample("b_author", dist.Gamma(b_author_shape, b_author_rate))
    with plate("d", size=D, subsample_size=self.batch_size, dim=-2), plate(
        "d_k", size=K, dim=-1
    ):
        # Only the rows of the documents in the current batch are used.
        sample("theta", dist.Gamma(theta_shape[d_batch], theta_rate[d_batch]))
def _get_batch(
    self, rng: jnp.ndarray, Y: sparse.csr_matrix
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Draw a random mini-batch of documents from the corpus.

    Helper function specified exclusively for TBIP and STBS objects.

    Parameters
    ----------
    rng : jax.random.PRNGKey
        Random number generator key.
    Y : scipy.sparse.csr_matrix
        The word counts array.

    Returns
    -------
    Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]
        Y_batch : Dense word counts for the batch (batch_size, V).
        D_batch : Indices of documents in the batch (batch_size,).
        I_batch : Indices of authors for the documents in the batch (batch_size,).

    Raises
    ------
    AssertionError
        If batch dimensions don't match expected shape.
    """
    # Document indices are drawn with jax.random.choice's default settings.
    doc_ids = random.choice(rng, jnp.arange(self.D), shape=(self.batch_size,))
    # Only the selected rows are densified.
    dense_counts = jnp.array(Y[doc_ids].toarray())
    expected_shape = (self.batch_size, self.V)
    # Ensure the shape of the batch is (batch_size, V)
    assert (
        dense_counts.shape == expected_shape
    ), f"Shape mismatch: {dense_counts.shape} != ({self.batch_size}, {self.V})"
    author_ids = np.array(self.author_indices[doc_ids])
    return dense_counts, doc_ids, author_ids
[docs]
def train_step(self, num_steps: int, lr: float) -> dict:  # type: ignore[override]
    """Fit the STBS model with mini-batch stochastic variational inference.

    Custom train function specified exclusively for TBIP and STBS objects.
    Each step draws a fresh mini-batch, performs one jitted SVI update and
    records the loss divided by the number of documents in
    ``self.Metrics.loss``.

    Parameters
    ----------
    num_steps : int
        Number of training steps. Must be > 0.
    lr : float
        Learning rate for the optimizer. Must be > 0.

    Returns
    -------
    dict
        A dictionary containing the estimated parameter values after training
        (also stored as ``self.estimated_params``).

    Raises
    ------
    ValueError
        If num_steps or lr are invalid.
    """
    if num_steps <= 0:
        raise ValueError(f"num_steps must be > 0, got {num_steps}")
    if lr <= 0:
        raise ValueError(f"lr must be > 0, got {lr}")

    engine = SVI(
        model=self._model, guide=self._guide, optim=adam(lr), loss=TraceMeanField_ELBO()
    )
    jitted_update = jit(engine.update)

    # A first batch is only needed to initialise the optimiser state.
    init_counts, init_docs, init_authors = self._get_batch(random.PRNGKey(1), self.counts)
    state = engine.init(
        random.PRNGKey(0), Y_batch=init_counts, d_batch=init_docs, i_batch=init_authors
    )

    # One independent key per step, so every step sees its own batch.
    step_keys = random.split(random.PRNGKey(2), num_steps)
    progress = tqdm(range(num_steps))
    for step in progress:
        counts_b, docs_b, authors_b = self._get_batch(step_keys[step], self.counts)
        state, raw_loss = jitted_update(
            state, Y_batch=counts_b, d_batch=docs_b, i_batch=authors_b
        )
        self.Metrics.loss.append(float(raw_loss / self.D))
        if step % 10 != 0:
            continue
        # Refresh the progress-bar text every tenth step.
        first_loss = self.Metrics.loss[0]
        recent_mean = jnp.array(self.Metrics.loss[-10:]).mean()
        progress.set_description(
            f"Init loss: {first_loss:10.4f}; Avg loss (last 10 iter): {recent_mean:10.4f}"
        )

    self.estimated_params = engine.get_params(state)
    return self.estimated_params
[docs]
def plot_topic_wordclouds(  # type: ignore[override]
    self,
    n_words: int = 50,
    figsize: Tuple[int, int] = (16, 12),
    ideology_values: Optional[Tuple[float, ...]] = (-1, 0, 1),
    topics: Optional[List[int]] = None,
    log_corrected: bool = True,
    save_path: Optional[str] = None,
) -> Tuple[plt.Figure, np.ndarray]:
    """
    Plot wordclouds for each topic, optionally at multiple ideology positions.

    When ``ideology_values`` is ``None``, delegates to the base class
    implementation. When ``ideology_values`` is set (default: ``(-1, 0, 1)``),
    produces a grid of shape ``(n_topics, len(ideology_values))``.

    Parameters
    ----------
    n_words : int, optional
        Maximum number of words per wordcloud (default 50).
    figsize : tuple, optional
        Figure size ``(width, height)`` (default ``(16, 12)``).
    ideology_values : tuple of float or None, optional
        Ideal point values for which to draw wordclouds. Default values
        are ``(-1, 0, 1)``. Pass ``None`` to fall back to base class
        behaviour.
    topics : list of int or None, optional
        Subset of topic indices to plot. If None, all K topics are shown.
    log_corrected : bool, optional
        If True (default), words are scored by
        ``digamma(beta_shape) - log(beta_rate) + i * mu_eta``, shifted so
        that all scores are positive. If False, words are scored by
        ``(beta_shape / beta_rate) * exp(mu_eta * i)``.
        Ignored when ideology_values is None.
    save_path : str, optional
        Path to save the figure.

    Returns
    -------
    tuple of (plt.Figure, np.ndarray of Axes)

    Raises
    ------
    ValueError
        If the model is not trained, ``topics`` is empty or out of range,
        or ``ideology_values`` is empty.
    """
    if not self.estimated_params:
        raise ValueError("Model must be trained before calling plot_topic_wordclouds()")
    if ideology_values is None:
        return super().plot_topic_wordclouds(
            n_words=n_words, figsize=figsize or (16, 12), save_path=save_path
        )

    ideology_values = tuple(ideology_values)
    if not ideology_values:
        # An empty grid cannot be created by plt.subplots.
        raise ValueError("ideology_values is empty.")

    topic_indices = list(topics) if topics is not None else list(range(self.K))
    if not topic_indices:
        raise ValueError("topics list is empty.")
    if any(t < 0 or t >= self.K for t in topic_indices):
        raise ValueError(f"topics must be indices in [0, {self.K - 1}].")

    # Convert once to NumPy so scipy/pandas operate on host arrays.
    beta_shape = np.asarray(self.estimated_params["beta_shape"])
    beta_rate = np.asarray(self.estimated_params["beta_rate"])
    mu_eta = np.asarray(self.estimated_params["mu_eta"])

    word_scores: List[List[dict]] = []
    for k in topic_indices:
        topic_scores = []
        for i_val in ideology_values:
            if log_corrected:
                s = digamma(beta_shape[k]) - np.log(beta_rate[k]) + i_val * mu_eta[k]
                span = s.max() - s.min()
                # Shift so the smallest score stays positive; a constant
                # score vector would otherwise become all-zero frequencies.
                s = s - s.min() + (0.05 * span if span > 0 else 1.0)
            else:
                s = (beta_shape[k] / beta_rate[k]) * np.exp(mu_eta[k] * i_val)
            word_freq = dict(pd.Series(s, index=self.vocab).nlargest(n_words))
            topic_scores.append(word_freq)
        word_scores.append(topic_scores)

    topic_labels = [f"Topic {k}" for k in topic_indices]
    n_rows = len(topic_indices)
    n_cols = len(ideology_values)
    with plt.rc_context(self._setup_academic_style()):
        fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
        score_label = "log-corrected" if log_corrected else "linear"
        for j, i_val in enumerate(ideology_values):
            # "+g" keeps fractional positions (e.g. +0.5) instead of rounding them.
            sign = f"{i_val:+g}" if i_val != 0 else "0 (neutral)"
            axes[0, j].set_title(
                f"i = {sign}\n({score_label})",
                fontsize=10,
                fontweight="bold",
            )
        for row, label in enumerate(topic_labels):
            axes[row, 0].set_ylabel(label, fontsize=9, fontweight="bold", rotation=90, labelpad=4)
            for col, word_freq in enumerate(word_scores[row]):
                ax = axes[row, col]
                if word_freq:
                    wc = WordCloud(
                        width=400,
                        height=300,
                        background_color="white",
                        relative_scaling=0.5,
                        min_font_size=8,
                    ).generate_from_frequencies(word_freq)
                    ax.imshow(wc, interpolation="bilinear")
                ax.set_xticks([])
                ax.set_yticks([])
                for spine in ax.spines.values():
                    spine.set_visible(False)
        fig.tight_layout()
        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches="tight")
    return fig, axes
def _summary_extra(self) -> str:
    """STBS-specific summary information.

    Returns
    -------
    str
        Newline-joined lines with the number of authors and covariates and
        the covariate names; when parameters have been estimated, also the
        range and standard deviation of the ideal-point means (``mu_i``) and
        summary statistics of the coefficient means (``mu_iota``).
    """
    # A model built without a design matrix uses a single column of ones;
    # fall back to a name for it if no covariate names were stored.
    covariate_names = getattr(self, "covariates", None)
    if covariate_names is None:
        covariate_names = ["intercept"]
    lines = [
        f" Authors (N): {self.N}",
        f" Covariates (L): {self.L}",
        # Names may be non-string column labels, so stringify before joining.
        f" Covariate names: {', '.join(str(name) for name in covariate_names)}",
    ]
    if self.estimated_params:
        mu_i = np.asarray(self.estimated_params["mu_i"])
        lines.append(f" Ideal-point range: [{mu_i.min():.3f}, {mu_i.max():.3f}]")
        lines.append(f" Ideal-point std: {mu_i.std():.3f}")
        mu_iota = np.asarray(self.estimated_params["mu_iota"])
        # NOTE(review): the statistics reduce over axis 1 of mu_iota, i.e. one
        # value per row; confirm the row axis is the one the labels refer to.
        topic_ranges = mu_iota.max(axis=1) - mu_iota.min(axis=1)
        topic_stds = mu_iota.std(axis=1)
        lines.append(
            f" Iota range (mean over topics): {topic_ranges.mean():.3f} [{topic_ranges.min():.3f}, {topic_ranges.max():.3f}]"
        )
        lines.append(
            f" Iota std (mean over topics): {topic_stds.mean():.3f} [{topic_stds.min():.3f}, {topic_stds.max():.3f}]"
        )
    return "\n".join(lines)
[docs]
def plot_topic_prevalence(  # type: ignore[override]
    self,
    topic_labels: Optional[dict] = None,
    selected_topics: Optional[list] = None,
    sort: bool = True,
    figsize: tuple = (8, 4),
    save_path: Optional[str] = None,
) -> Tuple[plt.Figure, plt.Axes]:
    """Bar chart of mean normalised topic prevalence across the corpus.

    Document-topic intensities (``theta_shape / theta_rate``) are normalised
    to sum to one per document and then averaged over documents.

    Parameters
    ----------
    topic_labels : dict or None
        Optional ``{topic_index: "label"}``; unlabelled topics are shown as
        ``"Topic k"``.
    selected_topics : list or None
        Topic indices to show. If None, all topics are shown.
    sort : bool
        If True (default), bars are ordered by decreasing prevalence.
    figsize : tuple
        Figure size (default ``(8, 4)``).
    save_path : str or None
        Path to save the figure.

    Returns
    -------
    tuple of (plt.Figure, plt.Axes)

    Raises
    ------
    ValueError
        If the model has not been trained yet.
    """
    if not self.estimated_params:
        raise ValueError("Model must be trained before calling plot_topic_prevalence()")

    params = self.estimated_params
    intensity = np.asarray(params["theta_shape"]) / np.asarray(params["theta_rate"])
    shares = intensity / intensity.sum(axis=1, keepdims=True)
    prevalence = shares.mean(axis=0)
    n_topics = shares.shape[1]

    def name_of(k):
        if topic_labels and k in topic_labels:
            return topic_labels[k]
        return f"Topic {k}"

    if selected_topics is None:
        names = [name_of(k) for k in range(n_topics)]
    else:
        chosen = np.array(selected_topics)
        prevalence = prevalence[chosen]
        names = [name_of(k) for k in chosen]

    if sort:
        # Largest prevalence first.
        ranking = np.argsort(prevalence)[::-1]
        bar_names = [names[pos] for pos in ranking]
        bar_heights = prevalence[ranking]
    else:
        bar_names, bar_heights = names, prevalence

    with plt.rc_context(self._setup_academic_style()):
        fig, ax = plt.subplots(figsize=figsize)
        ax.bar(bar_names, bar_heights, color="steelblue", edgecolor="none")
        ax.set_xlabel("Topic")
        ax.set_ylabel("Mean normalised proportion")
        ax.set_title("Corpus-level topic prevalence")
        plt.xticks(rotation=45, ha="right", fontsize=7)
        sns.despine()
        fig.tight_layout()
        if save_path:
            fig.savefig(save_path, dpi=150, bbox_inches="tight")
    return fig, ax
[docs]
def plot_author_topic_heatmap(
    self,
    topic_labels: Optional[dict] = None,
    author_labels: Optional[dict] = None,
    selected_topics: Optional[list] = None,
    figsize: tuple = (16, 12),
    save_path: Optional[str] = None,
) -> Tuple[plt.Figure, plt.Axes]:
    """Heatmap of mean normalised topic proportions per author (topics x authors).

    Authors are sorted by their dominant topic so similar authors cluster together.

    Parameters
    ----------
    topic_labels : dict or None
        {topic_index: "label"}
    author_labels : dict or None
        {author_index: "label"}. Authors missing from the mapping are shown
        by their index; if the mapping is None, the original author
        identifiers from ``author_map`` are used.
    selected_topics : list or None
        Integer topic indices to restrict the plot. If None, all topics shown.
    figsize : tuple
        Figure size (default ``(16, 12)``).
    save_path : str or None
        Path to save the figure.

    Returns
    -------
    tuple of (plt.Figure, plt.Axes)

    Raises
    ------
    ValueError
        If the model has not been trained yet.
    """
    if not self.estimated_params:
        raise ValueError("Model must be trained before calling plot_author_topic_heatmap()")

    params = self.estimated_params
    intensity = np.asarray(params["theta_shape"]) / np.asarray(params["theta_rate"])
    shares = intensity / intensity.sum(axis=1, keepdims=True)
    n_topics = shares.shape[1]

    def topic_name(k):
        if topic_labels and k in topic_labels:
            return topic_labels[k]
        return f"Topic {k}"

    # Average the per-document shares within each author.
    per_doc = pd.DataFrame(shares, columns=[topic_name(k) for k in range(n_topics)])
    per_doc = per_doc.assign(author=self.author_indices)
    author_theta = per_doc.groupby("author").mean()

    if selected_topics is not None:
        author_theta = author_theta[[topic_name(k) for k in selected_topics]]

    # Sort authors by dominant topic so similar authors cluster
    dominant_topic = author_theta.values.argmax(axis=1)
    author_theta = author_theta.iloc[np.argsort(dominant_topic)]

    # Author tick labels
    index_to_author = {idx: author for author, idx in self.author_map.items()}
    label_source = author_labels if author_labels is not None else index_to_author
    xtick_labels = [label_source.get(a, str(a)) for a in author_theta.index]

    with plt.rc_context(self._setup_academic_style()):
        fig, ax = plt.subplots(figsize=figsize)
        sns.heatmap(
            author_theta.T,
            ax=ax,
            cmap="YlOrRd",
            linewidths=0,
            xticklabels=xtick_labels,
            yticklabels=author_theta.columns.tolist(),
            cbar_kws={"label": "Mean normalised proportion", "shrink": 0.6},
        )
        ax.set_xlabel("Author")
        ax.set_ylabel("Topic")
        ax.set_title("Author-topic intensity")
        ax.tick_params(axis="x", labelsize=6, rotation=90)
        ax.tick_params(axis="y", labelsize=7)
        fig.tight_layout()
        if save_path:
            fig.savefig(save_path, dpi=150, bbox_inches="tight")
    return fig, ax
[docs]
def plot_ideol_points(
    self,
    group: bool = True,
    group_var: Optional[np.ndarray] = None,
    group_labels: Optional[dict] = None,
    group_palette: Optional[dict] = None,
    topic_labels: Optional[dict] = None,
    figsize: Tuple[float, float] = (16, 12),
    save_path: Optional[str] = None,
) -> Tuple[plt.Figure, plt.Axes]:
    """Dot plot of topic-specific ideological positions of all authors.

    Topics are ordered by the absolute difference between the weighted
    average positions of the two largest groups (most polarising topic
    first). For each topic, these two group averages are drawn as a diamond
    and a square marker in the group colours with black edges, connected by
    a black horizontal line. Averages are weighted by each author's mean
    document-topic intensity.

    Parameters
    ----------
    group : bool, optional
        If True (default), colours dots by group. Falls back to
        ``self.i_mu_init`` if ``group_var`` is not provided.
        If False, all dots are plotted in a single colour.
    group_var : np.ndarray of shape (N,) or None
        Author-level grouping variable. Overrides ``self.i_mu_init`` when
        provided. Unique values are treated as group identifiers.
    group_labels : dict or None
        Mapping ``{value: "label"}``, e.g. ``{-1: "D", 0: "I", 1: "R"}``.
        Values missing from the mapping, or all values when None, are
        labelled by their raw value.
    group_palette : dict or None
        Mapping ``{label: colour}``. If None, uses a default tab10 palette.
    topic_labels : dict or None
        Optional ``{topic_index: "label"}`` for y-axis tick labels.
    figsize : tuple, optional
        Figure size (default ``(16, 12)``).
    save_path : str or None
        Path to save the figure.

    Returns
    -------
    tuple of (plt.Figure, plt.Axes)

    Raises
    ------
    ValueError
        If the model has not been trained, if grouping is requested but
        neither ``group_var`` nor ``i_mu_init`` is available, or if
        ``group_var`` does not have length N.
    """
    if not self.estimated_params:
        raise ValueError("Model must be trained before calling plot_ideol_points()")

    # 1. Resolve groups
    if not group:
        groups = np.array(["all"] * self.N)
        group_palette = {"all": "steelblue"}
        group_labels = {"all": "all"}
    else:
        if group_var is None:
            if self.i_mu_init is None:
                raise ValueError(
                    "No group_var provided and i_mu_init was not stored on the model."
                )
            group_var = np.asarray(self.i_mu_init)
        else:
            group_var = np.asarray(group_var)
        if group_var.shape[0] != self.N:
            raise ValueError(
                f"group_var must have length N={self.N}, got {group_var.shape[0]}."
            )
        unique_vals = sorted(np.unique(group_var))
        if group_labels is None:
            group_labels = {v: str(v) for v in unique_vals}
        # A partial mapping is allowed: unmapped values keep their raw value as label.
        groups = np.array([group_labels.get(v, str(v)) for v in group_var])
        if group_palette is None:
            unique_group_names = [group_labels.get(v, str(v)) for v in unique_vals]
            group_palette = dict(
                zip(unique_group_names, sns.color_palette("tab10", len(unique_group_names)))
            )

    # 2. Author-by-topic weights: mean document-topic intensity per author (long format).
    theta = np.asarray(self.estimated_params["theta_shape"]) / np.asarray(
        self.estimated_params["theta_rate"]
    )
    author_weights = (
        pd.DataFrame(theta, columns=[f"x{k}" for k in range(theta.shape[1])])
        .assign(author=self.author_indices)
        .groupby("author", as_index=False)
        .mean()
        .melt(id_vars="author", var_name="topic", value_name="weight")
    )
    author_weights["topic"] = (
        author_weights["topic"].str.replace("^x", "", regex=True).astype(int)
    )

    # 3. Author-by-topic ideal points (long format), tagged with the group of each author.
    mu_i = np.asarray(self.estimated_params["mu_i"])
    author_ideology = (
        pd.DataFrame(mu_i, columns=[f"x{k}" for k in range(mu_i.shape[1])])
        .assign(author=list(self.author_map.values()), group=groups)
        .melt(id_vars=["author", "group"], var_name="topic", value_name="ideology")
    )
    author_ideology["topic"] = (
        author_ideology["topic"].str.replace("^x", "", regex=True).astype(int)
    )

    # 4. Weighted average ideal point per (group, topic).
    authors_weighted = author_ideology.merge(author_weights, on=["author", "topic"], how="left")
    group_ideology = (
        authors_weighted.groupby(["group", "topic"], as_index=False)
        .apply(
            lambda g: pd.Series(
                {"ideology": (np.nansum(g["weight"] * g["ideology"]) / np.nansum(g["weight"]))}
            ),
            include_groups=False,
        )
        .reset_index()
    )

    # 5. Order topics by the gap between the two most frequent groups.
    top_groups = pd.Series(groups).value_counts().index[:2].tolist()
    if len(top_groups) < 2:
        K = group_ideology["topic"].nunique()
        topic_order = list(range(K))
    else:
        pivot = group_ideology[group_ideology["group"].isin(top_groups)].pivot(
            index="topic", columns="group", values="ideology"
        )
        pivot["abs_delta"] = (pivot[top_groups[0]] - pivot[top_groups[1]]).abs()
        topic_order = (
            pivot.sort_values("abs_delta", ascending=False).reset_index()["topic"].tolist()
        )

    def _label(t):
        if topic_labels and int(t) in topic_labels:
            return topic_labels[int(t)]
        return f"Topic {t}"

    label_order = [_label(t) for t in topic_order]
    author_ideology["topic_label"] = pd.Categorical(
        author_ideology["topic"].map(_label), categories=label_order, ordered=True
    )
    group_ideology["topic_label"] = pd.Categorical(
        group_ideology["topic"].map(_label), categories=label_order, ordered=True
    )

    # 6. Draw.
    with plt.rc_context(self._setup_academic_style()):
        fig, ax = plt.subplots(figsize=figsize)
        sns.scatterplot(
            data=author_ideology,
            x="ideology",
            y="topic_label",
            hue="group",
            palette=group_palette,
            alpha=0.55,
            s=18,
            ax=ax,
            legend=True,
        )
        for topic_lbl, grp in group_ideology.groupby("topic_label", observed=True):
            grp_top2 = grp[grp["group"].isin(top_groups)]
            if len(grp_top2) < 2:
                # Fewer than two groups: nothing to connect for this topic.
                continue
            xmin, xmax = grp_top2["ideology"].min(), grp_top2["ideology"].max()
            ax.hlines(
                y=topic_lbl, xmin=xmin, xmax=xmax, colors="black", linewidth=0.8, zorder=3
            )
            g0 = grp_top2[grp_top2["group"] == top_groups[0]]
            g1 = grp_top2[grp_top2["group"] == top_groups[1]]
            ax.scatter(
                g0["ideology"],
                [topic_lbl] * len(g0),
                marker="D",
                s=60,
                facecolors=group_palette[top_groups[0]],
                edgecolors="black",
                linewidths=0.8,
                zorder=4,
            )
            ax.scatter(
                g1["ideology"],
                [topic_lbl] * len(g1),
                marker="s",
                s=60,
                facecolors=group_palette[top_groups[1]],
                edgecolors="black",
                linewidths=0.8,
                zorder=4,
            )
        for topic_lbl in label_order:
            ax.axhline(y=topic_lbl, linestyle="--", color="lightgray", linewidth=0.8, zorder=0)
        ax.axvline(0, linestyle="--", color="gray", linewidth=0.8)
        ax.set_xlabel("Ideological position")
        ax.set_ylabel("Topic (sorted by polarisation)")
        ax.legend(
            title="",
            loc="upper center",
            ncol=len(group_palette),
            bbox_to_anchor=(0.5, 1.05),
            frameon=False,
        )
        sns.despine()
        fig.tight_layout()
        if save_path:
            fig.savefig(save_path, dpi=150, bbox_inches="tight")
    return fig, ax
[docs]
def plot_iota_credible_intervals(
    self,
    topic_labels: Optional[dict] = None,
    covariate_labels: Optional[dict] = None,
    selected_topics: Optional[list] = None,
    selected_covariates: Optional[list] = None,
    ci: float = 0.95,
    figsize: tuple = (16, 12),
    save_path: Optional[str] = None,
) -> Tuple[plt.Figure, plt.Axes]:
    """Single CI plot with selected covariates on y-axis and topics as hue.

    Intervals are ``mu_iota +/- z * sigma_iota`` with ``z`` the standard
    normal quantile for the requested level. Coefficients whose interval
    excludes zero are drawn with a larger diamond marker.

    Parameters
    ----------
    topic_labels : dict or None
        Optional ``{topic_index: "label"}`` for the legend.
    covariate_labels : dict or None
        Optional ``{covariate_index: "label"}`` for the y-axis; otherwise the
        stored covariate names are used.
    selected_topics : list or None
        Topic indices to show. If None, all topics are shown.
    selected_covariates : list or None
        Covariates to show, given as indices or as covariate names.
        If None, all covariates are shown.
    ci : float
        Interval level, strictly between 0 and 1 (default 0.95).
    figsize : tuple
        Figure size (default ``(16, 12)``).
    save_path : str or None
        Path to save the figure.

    Returns
    -------
    tuple of (plt.Figure, plt.Axes)

    Raises
    ------
    ValueError
        If the model has not been trained, ``ci`` is outside (0, 1), a
        covariate name is unknown, or a selection is empty.
    """
    from scipy.stats import norm as sp_norm

    if not self.estimated_params:
        raise ValueError("Model must be trained before calling plot_iota_credible_intervals()")
    if not 0 < ci < 1:
        # ci >= 1 would give an infinite quantile, ci <= 0 an inverted interval.
        raise ValueError(f"ci must satisfy 0 < ci < 1, got {ci}")

    # Stored as (covariates, topics); transpose to (topics, covariates).
    mu_iota = np.asarray(self.estimated_params["mu_iota"]).T
    sigma_iota = np.asarray(self.estimated_params["sigma_iota"]).T
    K, P = mu_iota.shape
    z = sp_norm.ppf((1 + ci) / 2)
    known_names = list(getattr(self, "covariates", []))

    def _tlabel(k):
        return topic_labels[k] if topic_labels and k in topic_labels else f"Topic {k}"

    def _clabel(p):
        if covariate_labels and p in covariate_labels:
            return covariate_labels[p]
        elif p < len(known_names):
            return known_names[p]
        else:
            return f"Cov {p}"

    if selected_covariates is not None:
        cov_idx = []
        for entry in selected_covariates:
            if isinstance(entry, str):
                # Names are resolved one by one, so names and indices may be mixed.
                if entry not in known_names:
                    raise ValueError(f"Unknown covariate name: {entry!r}")
                cov_idx.append(known_names.index(entry))
            else:
                cov_idx.append(entry)
    else:
        cov_idx = list(range(P))
    topic_idx = list(selected_topics) if selected_topics is not None else list(range(K))
    if not cov_idx:
        raise ValueError("selected_covariates is empty.")
    if not topic_idx:
        raise ValueError("selected_topics is empty.")

    mu_sub = mu_iota[np.ix_(topic_idx, cov_idx)]
    sigma_sub = sigma_iota[np.ix_(topic_idx, cov_idx)]
    col_labels = [_clabel(p) for p in cov_idx]
    n_covs = len(cov_idx)
    n_topics = len(topic_idx)
    palette = sns.color_palette("tab10", n_topics)
    # linspace with a single point returns the left end (-0.3); centre a
    # lone topic on the tick instead.
    offsets = np.linspace(-0.3, 0.3, n_topics) if n_topics > 1 else np.zeros(1)

    with plt.rc_context(self._setup_academic_style()):
        fig, ax = plt.subplots(figsize=figsize)
        ax.axvline(0, color="gray", linewidth=0.8, linestyle="--", zorder=0)
        for i, k in enumerate(topic_idx):
            mu_k = mu_sub[i]
            sigma_k = sigma_sub[i]
            lo = mu_k - z * sigma_k
            hi = mu_k + z * sigma_k
            excludes_zero = (lo > 0) | (hi < 0)
            color = palette[i]
            for j in range(n_covs):
                # Topics are spread vertically around each covariate's tick.
                y = j + offsets[i]
                ax.plot([lo[j], hi[j]], [y, y], color=color, linewidth=0.8, alpha=0.7, zorder=1)
                ax.scatter(
                    mu_k[j],
                    y,
                    color=color,
                    s=30 if excludes_zero[j] else 15,
                    zorder=2,
                    marker="D" if excludes_zero[j] else "o",
                )
        ax.set_yticks(range(n_covs))
        ax.set_yticklabels(col_labels, fontsize=7)
        ax.set_xlabel("Iota (ideology coefficient)")
        # "g" formatting avoids the truncation of int(ci * 100) (e.g. 0.57 -> 56).
        ax.set_title(f"Iota credible intervals ({ci * 100:g}%)")
        ax.tick_params(axis="x", labelsize=7)
        handles = [
            plt.Line2D([0], [0], color=palette[i], linewidth=1.5, label=_tlabel(topic_idx[i]))
            for i in range(n_topics)
        ]
        ax.legend(
            handles=handles,
            title="Topic",
            frameon=False,
            fontsize=7,
            bbox_to_anchor=(1.01, 1),
            loc="upper left",
        )
        sns.despine()
        fig.tight_layout()
        if save_path:
            fig.savefig(save_path, dpi=150, bbox_inches="tight")
    return fig, ax
[docs]
def return_ideal_points(self) -> pd.DataFrame:
    """Collect the estimated ideal points of every author on every topic.

    Returns
    -------
    pd.DataFrame
        DataFrame with columns ``['author', 'topic', 'ideal_point', 'std']``
        (variational mean ``mu_i`` and scale ``sigma_i``), sorted by topic
        then ideal point, with a fresh integer index.

    Raises
    ------
    ValueError
        If model has not been trained yet.
    """
    if not self.estimated_params:
        raise ValueError("Model must be trained before calling return_ideal_points()")

    means = np.asarray(self.estimated_params["mu_i"])
    scales = np.asarray(self.estimated_params["sigma_i"])

    # One record per (author, topic) pair; author_map gives each author's row.
    records = [
        {
            "author": name,
            "topic": topic,
            "ideal_point": float(means[row, topic]),
            "std": float(scales[row, topic]),
        }
        for name, row in self.author_map.items()
        for topic in range(self.K)
    ]
    table = pd.DataFrame(records)
    ordered = table.sort_values(["topic", "ideal_point"])
    return ordered.reset_index(drop=True)
[docs]
def return_ideal_covariates(self) -> pd.DataFrame:
    """Return covariate regression coefficient estimates (iota).

    Returns
    -------
    pd.DataFrame
        DataFrame with columns ``['covariate', 'topic', 'iota', 'std']``
        (variational mean ``mu_iota`` and scale ``sigma_iota``),
        sorted by topic then covariate.

    Raises
    ------
    ValueError
        If model has not been trained yet.
    """
    if not self.estimated_params:
        raise ValueError("Model must be trained before calling return_ideal_covariates()")
    mu_iota = np.asarray(self.estimated_params["mu_iota"])
    sigma_iota = np.asarray(self.estimated_params["sigma_iota"])

    # A model built without a design matrix uses a single column of ones;
    # fall back to a name for it if no covariate names were stored.
    covariate_names = getattr(self, "covariates", None)
    if covariate_names is None:
        covariate_names = ["intercept"]

    rows = [
        {
            "covariate": covariate,
            "topic": k,
            "iota": float(mu_iota[cov_pos, k]),
            "std": float(sigma_iota[cov_pos, k]),
        }
        for cov_pos, covariate in enumerate(covariate_names)
        for k in range(self.K)
    ]
    df = pd.DataFrame(rows)
    return df.sort_values(["topic", "covariate"]).reset_index(drop=True)