
Getting Started with Bayinx

Welcome to Bayinx (Bayesian inference with JAX), a probabilistic programming language embedded in Python. This guide walks you through installing the package and gives a quick overview of how Bayinx works. For a more thorough tutorial, see Basic Usage if you are unfamiliar with probabilistic programming, or Coming From Stan if you are a Stan user. If you are new to Python, the official Python tutorial (or a similar resource) will get you up to speed.

Installation

Bayinx requires JAX and a few extra libraries from the JAX ecosystem. The easiest way to get started is to install it from PyPI using your favourite Python package manager:

uv

# Ensure you are in your project environment
uv add bayinx

This installs the bare-bones version of Bayinx. If you need additional functionality such as GPU support, optional dependency groups are available, for example:

# Ensure you are in your project environment
uv add 'bayinx[cuda]' # Installs Bayinx with CUDA support

pip

# Ensure you are in your project environment
pip install bayinx

This installs the bare-bones version of Bayinx. If you need additional functionality such as GPU support, optional dependency groups are available, for example:

# Ensure you are in your project environment
pip install 'bayinx[cuda]' # Installs Bayinx with CUDA support
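To confirm that the installation worked, a quick check like the following may help. It is a minimal sketch using the standard library's `importlib.util.find_spec` and JAX's `jax.default_backend()`; the helper `check_install` is hypothetical and not part of Bayinx. With the CUDA extra installed and a GPU visible to JAX, the reported backend would typically be `"gpu"` rather than `"cpu"`.

```python
import importlib.util


def check_install(packages=("bayinx", "jax")):
    """Hypothetical helper: report whether each package can be imported."""
    # find_spec returns None when a top-level module cannot be found
    return {name: importlib.util.find_spec(name) is not None for name in packages}


status = check_install()
print(status)

if status["jax"]:
    import jax

    # default_backend() names the platform JAX will use, e.g. "cpu" or "gpu"
    print("JAX backend:", jax.default_backend())
```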

Defining Models In Bayinx

You can now get started!

Models are defined by writing a class that inherits from the Model base class. For example, here is a simple model for a collection of observations drawn from a Normal distribution:

from bayinx.dists import Normal
from bayinx.nodes import Continuous, Observed
from bayinx import Model, define
from jaxtyping import Array

class SimpleNormalModel(Model):
    mu: Continuous = define(shape = ())
    std: Continuous[Array] = define(shape = (), lower = 0) # nodes support type hinting

    x: Observed[Array] = define(shape = 'n_obs')

    def model(self, target):
        # Accumulate likelihood
        self.x << Normal(self.mu, self.std)
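The `model` method above accumulates the likelihood of `x` via `self.x << Normal(self.mu, self.std)`. Assuming the observations are treated as independent and the target is accumulated on the log scale (standard practice, but an assumption about Bayinx internals), the contribution is a sum of Normal log-densities. The plain-Python sketch below computes that sum by hand and cross-checks it against the standard library's `statistics.NormalDist`; the helper names are hypothetical.

```python
import math
from statistics import NormalDist


def normal_logpdf(x, mu, std):
    """Log-density of a Normal(mu, std) distribution evaluated at x."""
    z = (x - mu) / std
    return -0.5 * math.log(2 * math.pi) - math.log(std) - 0.5 * z * z


def log_likelihood(x_obs, mu, std):
    """Sum of independent Normal log-densities over all observations."""
    return sum(normal_logpdf(xi, mu, std) for xi in x_obs)


x_obs = [9.0, 10.5, 12.0]
ll = log_likelihood(x_obs, mu=10.0, std=5.0)

# Cross-check against the standard library's Normal density
reference = sum(math.log(NormalDist(10.0, 5.0).pdf(xi)) for xi in x_obs)
print(ll, reference)
```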

Parameters are class attributes annotated with Continuous, while data is annotated with Observed. Both are thin wrappers around an underlying value whose type can be given as a type hint (e.g., Continuous[T] or Observed[T]); the hint is optional, as mu shows. Additional metadata for a node is declared with the define function, for example its shape with define(shape = ...) and its bounds with define(lower = ..., upper = ...). A shape can also be given by name, as with shape = 'n_obs' for x, in which case its value is passed in alongside the data, as in the Posterior call below.
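Bounds such as `lower = 0` restrict a parameter to part of the real line, while optimizers and normalizing flows work most naturally on unconstrained values. A common way to bridge the two (used, for example, by Stan) is to work with an unconstrained value z, map it through x = lower + exp(z), and add the log-Jacobian, which is simply z, to the log-density. Whether Bayinx uses this particular transform is an assumption; the sketch only illustrates the idea, and its function names are hypothetical.

```python
import math


def to_constrained(z, lower=0.0):
    """Map an unconstrained real z onto (lower, inf) with lower + exp(z)."""
    return lower + math.exp(z)


def to_unconstrained(x, lower=0.0):
    """Inverse map: recover z from a value x > lower."""
    return math.log(x - lower)


def log_abs_det_jacobian(z):
    """log |d/dz (lower + exp(z))| = log(exp(z)) = z."""
    return z


# Any real z yields a valid (strictly positive) standard deviation
for z in (-3.0, 0.0, 2.0):
    std = to_constrained(z)
    print(z, std, log_abs_det_jacobian(z))
```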

Fitting Models With Bayinx

Bayinx uses variational inference with normalizing flows (NFs) to approximate the posterior distribution, and the NF architecture can be customized to suit your model. We'll simulate some data for demonstration:

import jax.random as jr

n_obs = 100
true_mu = 10.0
true_std = 5.0

# Simulate data
x_data = jr.normal(jr.key(0), (n_obs, )) * true_std + true_mu

An approximation to the posterior is then constructed with the Posterior class, configured, and optimized:

from bayinx import Posterior
from bayinx.flows import DiagAffine

# Construct approximation
posterior = Posterior(
    SimpleNormalModel,
    n_obs = n_obs,
    x = x_data
)
posterior.configure(flowspecs = [DiagAffine()]) # Configure the NF architecture
posterior.fit() # Optimize the approximation
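Fitting a variational approximation generally means maximizing the evidence lower bound (ELBO), E_q[log p(x, z) - log q(z)], over the flow's parameters. As a rough illustration of that objective, the sketch below estimates the ELBO by Monte Carlo for a diagonal affine transform of a standard Normal, z = loc + exp(log_scale) * eps. Treating `DiagAffine` as exactly this transform, and `fit()` as maximizing exactly this objective, are assumptions; the function names are hypothetical and no optimizer is shown.

```python
import math
import random


def diag_affine_sample(loc, log_scale, rng):
    """Draw z = loc + exp(log_scale) * eps with eps ~ N(0, I); return (z, log q(z))."""
    eps = [rng.gauss(0.0, 1.0) for _ in loc]
    z = [m + math.exp(s) * e for m, s, e in zip(loc, log_scale, eps)]
    # Change of variables: log q(z) = log N(eps; 0, I) - sum(log_scale)
    log_base = sum(-0.5 * math.log(2 * math.pi) - 0.5 * e * e for e in eps)
    return z, log_base - sum(log_scale)


def elbo_estimate(log_joint, loc, log_scale, n_samples=1000, seed=0):
    """Monte Carlo estimate of E_q[log p(x, z) - log q(z)]."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        z, log_q = diag_affine_sample(loc, log_scale, rng)
        total += log_joint(z) - log_q
    return total / n_samples


def std_normal_logpdf(z):
    """Stand-in normalized target: an independent standard Normal per coordinate."""
    return sum(-0.5 * math.log(2 * math.pi) - 0.5 * zi * zi for zi in z)


# A q that matches the normalized target exactly gives ELBO = 0 (the maximum here);
# a mismatched scale gives a negative value, equal to -KL(q || p) in expectation.
print(elbo_estimate(std_normal_logpdf, loc=[0.0, 0.0], log_scale=[0.0, 0.0]))
print(elbo_estimate(std_normal_logpdf, loc=[0.0, 0.0], log_scale=[0.5, 0.5]))
```

Because the target here is normalized, the ELBO equals minus the KL divergence from q to the target, so driving it toward zero is the same as making q match the target.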

Once the approximation is fitted, you can draw samples from it to get Monte Carlo estimates of your parameters:

# Sample the approximate posterior distribution of 'mu'
n_draws = int(1e6)
mu_draws = posterior.sample('mu', n_draws)

# Monte Carlo standard error of the mean: std / sqrt(n_draws) = std / 1000
print(f"Analytic Posterior Mean for 'mu': {x_data.mean():.4f}")
print(f"Posterior Mean Estimate for 'mu': {mu_draws.mean():.4f} ± {mu_draws.std() / n_draws**0.5:.4f}")
Analytic Posterior Mean for 'mu': 10.5465
Posterior Mean Estimate for 'mu': 10.5467 ± 0.0005
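The "Analytic Posterior Mean" printed above is simply the sample mean of x_data. A short derivation shows why, assuming that leaving mu without an explicit prior amounts to a flat (improper) prior on μ, and that whatever prior std receives does not depend on μ. Writing σ for std, n for n_obs, and x̄ for the sample mean, the sum of squares splits into a part that does not involve μ and a part that does:

```latex
\begin{aligned}
p(\mu \mid \sigma, x) &\propto \prod_{i=1}^{n} \exp\!\left(-\frac{(x_i - \mu)^2}{2\sigma^2}\right)
  = \exp\!\left(-\frac{\sum_{i=1}^{n} (x_i - \bar{x})^2 + n(\mu - \bar{x})^2}{2\sigma^2}\right) \\
  &\propto \exp\!\left(-\frac{n(\mu - \bar{x})^2}{2\sigma^2}\right)
  \quad\Longrightarrow\quad \mu \mid \sigma, x \sim \mathcal{N}\!\left(\bar{x}, \sigma^2 / n\right), \\
\mathbb{E}[\mu \mid x] &= \mathbb{E}_{\sigma \mid x}\big[\mathbb{E}[\mu \mid \sigma, x]\big] = \bar{x}.
\end{aligned}
```

The conditional spread σ/√n also explains the size of the reported uncertainty: with n_obs = 100 and a true std of 5, the draws of mu should have a standard deviation of roughly 0.5, and 0.5 / 1000 matches the ± 0.0005 Monte Carlo standard error printed above.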