Tutorial: GPU Acceleration

This tutorial shows how to use GPU acceleration to speed up training, often by 10-100x on large datasets.

Duration: ~10 minutes

Prerequisites: NVIDIA/AMD/Apple GPU available and drivers installed

Why GPU?

GPU acceleration is critical for large-scale topic modeling:

  • CPU time: 1000 docs × 10000 words × 100 iters = hours

  • GPU time: Same computation = minutes

JAX automatically handles GPU computations when available.

Checking GPU Availability

First, verify GPU access:

import jax
import jax.numpy as jnp

# List available devices
devices = jax.devices()
print("Available devices:")
for device in devices:
    print(f"  {device}")

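Expected output with GPU (for example, on a machine with two NVIDIA GPUs; the exact device strings vary by JAX version):

Available devices:
  cuda:0
  cuda:1

Without GPU (again, the exact string varies by JAX version):

Available devices:
  TFRT_CPU_0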
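
The same check can be wrapped in a function that scripts call before starting a long run. The sketch below is one possible way to do it, not part of poisson-topicmodels: the helper name accelerator_available is hypothetical, and it treats any non-CPU device (GPU or TPU) as an accelerator.

```python
def accelerator_available() -> bool:
    """Return True if JAX's default backend exposes a non-CPU device."""
    try:
        import jax
    except ImportError:
        # JAX is not installed at all
        return False
    try:
        devices = jax.devices()
    except RuntimeError:
        # Raised when a requested platform (e.g. JAX_PLATFORMS=gpu) cannot be initialized
        return False
    return any(d.platform != "cpu" for d in devices)


if __name__ == "__main__":
    print("Accelerator available:", accelerator_available())
```

Returning False instead of raising lets a script fall back to a smaller CPU run when no GPU is found.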

Enabling GPU for poisson-topicmodels

Option 1: Environment variable (recommended)

export JAX_PLATFORMS=gpu
python your_script.py

Option 2: Set in Python before JAX is imported

import os
# Must run before jax (or any package that imports jax) is imported
os.environ['JAX_PLATFORMS'] = 'gpu'

from poisson_topicmodels import PF
# Now uses GPU

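Option 3: Select specific CUDA devices

# Restrict JAX to GPUs 0 and 1
# (this controls which GPUs are visible; it does not by itself enable GPU support)
export CUDA_VISIBLE_DEVICES=0,1
python your_script.py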
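
Options 1 and 2 can be combined so that a value exported in the shell always wins and the script only supplies a default. The following is a minimal sketch under that assumption; configure_jax_platform is a hypothetical helper name, not an API of JAX or poisson-topicmodels.

```python
import os
import sys
import warnings


def configure_jax_platform(platform: str = "gpu") -> str:
    """Set JAX_PLATFORMS unless it is already set; return the value in effect."""
    if "jax" in sys.modules:
        # JAX reads the variable when it initializes, so a late change is ignored
        warnings.warn("jax is already imported; changing JAX_PLATFORMS now has no effect")
    # setdefault keeps a value exported in the shell (Option 1) in charge
    return os.environ.setdefault("JAX_PLATFORMS", platform)


print("JAX_PLATFORMS =", configure_jax_platform("gpu"))
```

Call it at the very top of the script, before importing poisson_topicmodels.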

Setting Up GPU Environment

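NVIDIA GPU (CUDA):

  1. Install a recent NVIDIA driver. The cuda12 pip wheels bundle CUDA 12 and cuDNN, so a separate CUDA Toolkit installation is not required (use the cuda12_local extra instead if you want JAX to use a locally installed toolkit).

  2. Install GPU-enabled JAX:

pip install -U "jax[cuda12]"

  3. Verify:

python -c "import jax; print(jax.devices())"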

AMD GPU (ROCm):

pip install jax[rocm]

Apple Silicon (Metal):

pip install jax-metal

Training with GPU

Once GPU is enabled, training automatically uses it:

from poisson_topicmodels import PF

model = PF(counts, vocab, num_topics=20, batch_size=256)

# This automatically uses GPU if available
params = model.train_step(num_steps=200, lr=0.01)

# That's it! No code changes needed.

No explicit GPU calls required—JAX handles it transparently.

Monitoring GPU Usage

Check GPU utilization:

Command line:

# NVIDIA: nvidia-smi shows GPU usage
nvidia-smi -l 1  # Update every second

# AMD: rerun rocm-smi every second
watch -n 1 rocm-smi

Expected during training: 80-95% GPU utilization

In Python:

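import subprocess
import time

def monitor_gpu(interval=2):
    """Print the utilization of each NVIDIA GPU every `interval` seconds until Ctrl+C."""
    try:
        while True:
            result = subprocess.run(
                ['nvidia-smi', '--query-gpu=utilization.gpu',
                 '--format=csv,nounits,noheader'],
                capture_output=True, text=True, check=True)
            # nvidia-smi prints one line per GPU
            for i, utilization in enumerate(result.stdout.split()):
                print(f"GPU {i} utilization: {utilization}%")
            time.sleep(interval)
    except KeyboardInterrupt:
        pass

# Run in a separate terminal or process while training runs
monitor_gpu()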
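
To log memory alongside utilization, nvidia-smi can be asked for several fields at once (`--query-gpu=index,utilization.gpu,memory.used,memory.total --format=csv,nounits,noheader`) and the CSV parsed in Python. The sketch below shows only the parsing step on sample text; parse_gpu_stats is a hypothetical helper and the sample values are made up.

```python
import csv
import io

FIELDS = ("index", "utilization.gpu", "memory.used", "memory.total")


def parse_gpu_stats(csv_text: str) -> list[dict]:
    """Parse `nvidia-smi --format=csv,nounits,noheader` output into one dict per GPU."""
    rows = csv.reader(io.StringIO(csv_text), skipinitialspace=True)
    return [
        {name: int(value) for name, value in zip(FIELDS, row)}
        for row in rows
        if row  # skip blank lines
    ]


# Sample text in the shape nvidia-smi prints for two GPUs (values are made up)
sample = "0, 87, 30211, 40960\n1, 91, 29876, 40960\n"
for gpu in parse_gpu_stats(sample):
    print(gpu)
```

With nounits, the memory values are in MiB. Fields that a GPU does not report would need extra handling before the int() conversion.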

Memory Management

GPU Memory Issues:

If you get “out of memory” errors:

# 1. Reduce batch size (smaller batches need less GPU memory per step)
model = PF(counts, vocab, num_topics=20, batch_size=64)

# 2. Reduce vocabulary size
# Remove rare words: keep only top 5000 words

# 3. Reduce number of documents
# Sample documents or process in chunks
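
Reducing the vocabulary can be done before the model is built. The sketch below is one way to keep only the most frequent words, assuming counts is a document-term matrix (dense NumPy array or SciPy sparse matrix) whose columns line up with vocab; keep_top_words is a hypothetical helper, not part of poisson-topicmodels.

```python
import numpy as np


def keep_top_words(counts, vocab, max_words=5000):
    """Keep the `max_words` most frequent columns of a document-term matrix."""
    # Total count per word; np.asarray(...).ravel() handles dense arrays and SciPy sparse matrices
    word_totals = np.asarray(counts.sum(axis=0)).ravel()
    # Indices of the most frequent words, restored to their original column order
    keep = np.sort(np.argsort(word_totals)[::-1][:max_words])
    return counts[:, keep], np.asarray(vocab)[keep]


# Tiny dense example: 3 documents, 4 words, keep the 2 most frequent
counts_demo = np.array([[5, 0, 1, 2],
                        [3, 0, 0, 4],
                        [6, 1, 0, 1]])
vocab_demo = np.array(["topic", "rare", "odd", "model"])
small_counts, small_vocab = keep_top_words(counts_demo, vocab_demo, max_words=2)
print(small_vocab)  # ['topic' 'model']
```

The reduced matrix and vocabulary can then be passed to PF in place of the originals.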

Memory-efficient training:

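# Check device memory after training
import jax

model = PF(counts, vocab, num_topics=20, batch_size=128)
params = model.train_step(num_steps=200, lr=0.01)

# memory_stats() returns None on backends that do not report memory (e.g. CPU)
stats = jax.devices()[0].memory_stats()
if stats:
    print(f"Device memory in use: {stats.get('bytes_in_use', 0) / 1e9:.2f} GB")

# If memory issues: reduce batch_size → 64 or 32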
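
Stepping the batch size down by hand can be automated with a retry loop. The sketch below is illustrative only: train_with_backoff and fake_train are hypothetical names, and it assumes that an out-of-memory failure surfaces as a RuntimeError whose message mentions RESOURCE_EXHAUSTED or "out of memory".

```python
def train_with_backoff(train_fn, batch_size=512, min_batch_size=32):
    """Call train_fn(batch_size), halving batch_size whenever it runs out of memory."""
    while True:
        try:
            return batch_size, train_fn(batch_size)
        except RuntimeError as err:
            # Assumption: OOM errors are RuntimeErrors carrying one of these markers
            message = str(err)
            is_oom = "RESOURCE_EXHAUSTED" in message or "out of memory" in message.lower()
            if not is_oom or batch_size // 2 < min_batch_size:
                raise
            batch_size //= 2
            print(f"Out of memory; retrying with batch_size={batch_size}")


# Stand-in for real training: pretends anything above 128 does not fit in memory
def fake_train(batch_size):
    if batch_size > 128:
        raise RuntimeError("RESOURCE_EXHAUSTED: Out of memory")
    return "trained"


used_batch_size, result = train_with_backoff(fake_train)
print(used_batch_size, result)  # 128 trained
```

With the API used in this tutorial, train_fn could be a function that builds PF(counts, vocab, num_topics=20, batch_size=bs) and returns the result of train_step.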

Performance Benchmarking

Compare CPU vs GPU timing:

# benchmark.py: run once per backend and compare the two timings, e.g.
#   JAX_PLATFORMS=cpu python benchmark.py
#   JAX_PLATFORMS=gpu python benchmark.py
# JAX picks its backend when it initializes, so the platform cannot be
# switched in the middle of a script.
import time

import jax
import numpy as np
from scipy.sparse import csr_matrix
from poisson_topicmodels import PF

# Small synthetic dataset
counts_small = csr_matrix(np.random.poisson(2, (100, 500)).astype(np.float32))
vocab_small = np.array([f'word_{i}' for i in range(500)])

model = PF(counts_small, vocab_small, num_topics=10, batch_size=32)
t0 = time.perf_counter()
model.train_step(num_steps=50, lr=0.01)
elapsed = time.perf_counter() - t0  # Includes one-off JIT compilation time

print(f"Backend: {jax.default_backend()}")
print(f"Training time: {elapsed:.2f}s")
# Speedup = CPU training time / GPU training time from the two runs
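
A single timed run also includes one-off costs such as JIT compilation, which can dominate on small datasets. A more careful measurement warms up first and takes the median of several runs. The sketch below uses only the standard library; time_call is a hypothetical helper, and the workload shown is a stand-in for a few training steps. When timing raw JAX operations, the callable should also call .block_until_ready() on its result, because JAX dispatches work asynchronously.

```python
import statistics
import time


def time_call(fn, repeats=3, warmup=1):
    """Return the median wall-clock seconds of fn() over `repeats` runs after `warmup` untimed runs."""
    for _ in range(warmup):
        fn()  # untimed: absorbs one-off costs such as JIT compilation
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        timings.append(time.perf_counter() - start)
    return statistics.median(timings)


# Stand-in workload; replace with a call that runs a few training steps
elapsed = time_call(lambda: sum(i * i for i in range(100_000)))
print(f"Median time: {elapsed:.4f}s")
```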

Real-world example:

Dataset: 50k documents, 50k vocabulary

CPU (16 cores):      ~2 hours per 100 iterations
GPU (A100):          ~3 minutes per 100 iterations
Speedup:             ~40x

Optimizing for Speed

Tips for maximum performance:

  1. Batch Size: Larger batches = better GPU utilization

    # Start with batch_size=256 or 512 on modern GPUs
    model = PF(counts, vocab, num_topics=20, batch_size=512)
    
  2. Multiple GPUs: Make several cards visible (this only helps if the training code distributes work across devices)

    export CUDA_VISIBLE_DEVICES=0,1,2,3
    python script.py  # All 4 GPUs are visible to JAX
    
  3. Mixed Precision: Trade accuracy for speed (advanced)

    # Not currently exposed in poisson-topicmodels
    # Future enhancement
    
  4. Profiling: Identify bottlenecks

    import jax
    
    # Record a profiler trace while training a few steps
    with jax.profiler.trace("/tmp/jax-trace"):
        model.train_step(num_steps=10, lr=0.01)
    
    # Analyze results: open the trace directory in TensorBoard's profiler plugin
    # Check if data transfer or computation is bottleneck
    

Troubleshooting GPU

Problem: JAX can’t find GPU

jax._src.lib.xla_extension.XlaRuntimeError: CUDA not found

Solution:

  • Verify CUDA installation: nvcc --version

  • Reinstall JAX: pip install --upgrade "jax[cuda12]"

  • Check CUDA_HOME: echo $CUDA_HOME
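
These checks can be gathered in one place before digging further. The sketch below collects them with the standard library only, so it still runs when JAX itself fails to import; gpu_diagnostics is a hypothetical helper name.

```python
import os
import shutil


def gpu_diagnostics() -> dict:
    """Collect basic facts about the local CUDA setup without importing JAX."""
    return {
        "nvidia-smi on PATH": shutil.which("nvidia-smi"),
        "nvcc on PATH": shutil.which("nvcc"),
        "CUDA_HOME": os.environ.get("CUDA_HOME"),
        "CUDA_VISIBLE_DEVICES": os.environ.get("CUDA_VISIBLE_DEVICES"),
        "JAX_PLATFORMS": os.environ.get("JAX_PLATFORMS"),
    }


for name, value in gpu_diagnostics().items():
    print(f"{name}: {value if value is not None else 'not found / not set'}")
```

An empty CUDA_VISIBLE_DEVICES hides every GPU from CUDA applications, so it is worth checking that the variable is either unset or lists the intended device indices.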

Problem: GPU out of memory

Solution:

  • Reduce batch_size: batch_size=64 instead of 256

  • Reduce num_topics

  • Reduce vocabulary size

  • Process data in chunks
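
A back-of-the-envelope estimate shows why these settings matter. The sketch below is only a rough model and does not describe how poisson-topicmodels actually lays out memory: it assumes float32 values, one documents × topics matrix, one topics × words matrix, and one densified mini-batch, and it ignores optimizer state and temporaries. The function name estimate_param_bytes is hypothetical.

```python
def estimate_param_bytes(num_docs, num_words, num_topics, batch_size, bytes_per_value=4):
    """Rough size of the dense arrays a factorization model might hold on the GPU."""
    doc_topic = num_docs * num_topics      # one value per (document, topic)
    topic_word = num_topics * num_words    # one value per (topic, word)
    dense_batch = batch_size * num_words   # one mini-batch of counts, densified
    return (doc_topic + topic_word + dense_batch) * bytes_per_value


base = estimate_param_bytes(500_000, 100_000, num_topics=50, batch_size=1024)
smaller = estimate_param_bytes(500_000, 5_000, num_topics=50, batch_size=64)
print(f"{base / 1e6:.0f} MB -> {smaller / 1e6:.0f} MB")  # 530 MB -> 102 MB
```

Under these assumptions the densified batch is the largest term at full vocabulary size, which is why shrinking the vocabulary and batch size together gives the biggest reduction.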

Problem: GPU slower than CPU (!?)

Solution:

  • GPU overhead for small datasets (< 10k docs)

  • GPU shines with 100k+ documents

  • Check GPU utilization (should be >80%)

  • Increase batch_size to improve utilization

Problem: Training hangs on GPU

Solution:

  • Timeout issue with GPU

  • Reduce batch_size or num_topics

  • Update JAX: pip install --upgrade jax

  • Check GPU memory: nvidia-smi

Best Practices

Development:

  • Start on CPU for quick iterations

  • Verify results make sense

  • Switch to GPU for final large-scale runs

Production:

  • Always use GPU for meaningful datasets (100k+ docs)

  • Monitor GPU utilization

  • Use optimal batch size (64-512 depending on GPU memory)

Research reproducibility:

  • Document which GPU was used

  • Set random seed (results consistent across runs)

  • GPU results may slightly differ from CPU due to floating-point precision

Scaling to Large Datasets

With GPU, you can now handle:

  • 100k documents: ~10 minutes (vs hours on CPU)

  • 500k documents: ~50 minutes

  • 1M documents: ~2 hours

Example:

# Large dataset
num_docs, num_words = 500_000, 100_000

# GPU can handle this
model = PF(
    counts=counts,
    vocab=vocab,
    num_topics=50,
    batch_size=1024  # Can use large batch on GPU
)

# Train efficiently
params = model.train_step(
    num_steps=200,
    lr=0.01,
)

# Takes ~1 hour instead of a whole day

Key Takeaway

GPU acceleration requires zero code changes—just enable it!

Once enabled, training automatically uses GPU for massive speedups.