# Getting Started with Bayinx
Welcome to Bayinx (Bayesian inference with JAX), a probabilistic programming language embedded in Python. This guide will help you install the package and gives a quick overview of how Bayinx works. For a more thorough tutorial, check out Basic Usage if you are unfamiliar with probabilistic programming, or Coming From Stan if you are a Stan user. If you are new to Python itself, check out the official Python tutorial or another resource to get up to speed first.
## Installation
Bayinx requires JAX and a few extra libraries from the JAX ecosystem. The easiest way to get started is to install it from PyPI using your favourite Python package manager:
### uv
```shell
# Ensure you are in your project environment
uv add bayinx
```

This installs the bare-bones version of Bayinx. If you need additional functionality such as GPU support, there are a couple of dependency groups:

```shell
uv add 'bayinx[cuda]' # Installs Bayinx with CUDA support
```
### pip
```shell
# Ensure you are in your project environment
pip install bayinx
```

This installs the bare-bones version of Bayinx. If you need additional functionality such as GPU support, there are a couple of dependency groups:

```shell
pip install 'bayinx[cuda]' # Installs Bayinx with CUDA support
```
## Defining Models In Bayinx
You can now get started!
Models are defined by writing a class that inherits from the Model base class.
For example, we can define a simple model that describes a collection of observations drawn from a Normal distribution:
```python
from bayinx.dists import Normal
from bayinx.nodes import Continuous, Observed
from bayinx import Model, define
from jaxtyping import Array


class SimpleNormalModel(Model):
    mu: Continuous = define(shape = ())
    std: Continuous[Array] = define(shape = (), lower = 0) # nodes support type hinting
    x: Observed[Array] = define(shape = 'n_obs')

    def model(self, target):
        # Accumulate likelihood
        self.x << Normal(self.mu, self.std)
```
Parameters are attributes annotated with the Continuous class, while data is annotated with Observed; both are thin wrappers around an internal type that supports type hints (e.g., Continuous[T] or Observed[T]).
You can then attach additional metadata to a node with the define function: for example, shapes via define(shape = ...) and bounds via define(lower = ..., upper = ...).
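To illustrate the metadata options, here are a couple of hypothetical node declarations (these are not part of SimpleNormalModel above, and the names theta, weights, and n_groups are made up for illustration); a sketch assuming the same imports as the model definition:

```python
# Hypothetical node declarations, as they would appear inside a Model subclass:
theta: Continuous[Array] = define(shape = (), lower = 0, upper = 1)  # scalar bounded to [0, 1]
weights: Continuous[Array] = define(shape = 'n_groups')              # shape named after a runtime dimension
```

As with x in the model above, a string shape such as 'n_groups' names a dimension whose size is supplied when the model is instantiated.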
## Fitting Models With Bayinx
Bayinx uses variational inference with normalizing flows (NFs) to approximate the posterior distribution, where the NF architecture can be customized to your preference. We'll simulate some data for demonstration:
```python
import jax.random as jr

n_obs = 100
true_mu = 10.0
true_std = 5.0

# Simulate data
x_data = jr.normal(jr.key(0), (n_obs, )) * true_std + true_mu
```
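Before fitting, it can be worth sanity-checking the simulated data: its sample mean and standard deviation should land near true_mu and true_std. A quick check (not part of the original guide):

```python
import jax.numpy as jnp
import jax.random as jr

n_obs = 100
true_mu = 10.0
true_std = 5.0

# Re-create the simulated data from above
x_data = jr.normal(jr.key(0), (n_obs,)) * true_std + true_mu

# With 100 draws, the sample moments should be close to the true values
sample_mean = float(jnp.mean(x_data))
sample_std = float(jnp.std(x_data))
print(sample_mean, sample_std)
```

The standard error of the mean here is true_std / sqrt(n_obs) = 0.5, so the sample mean will typically fall within about one unit of 10.0.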
The approximation to the posterior can then be created with the Posterior class and optimized later:
```python
from bayinx import Posterior
from bayinx.flows import DiagAffine

# Construct approximation
posterior = Posterior(
    SimpleNormalModel,
    n_obs = n_obs,
    x = x_data
)

posterior.configure(flowspecs = [DiagAffine()]) # Configure the NF architecture
posterior.fit() # Optimize the approximation
```
Once fitted, you can sample from the approximated posterior distribution to get Monte Carlo estimates for your parameters:
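A sketch of what sampling might look like, assuming the fitted Posterior exposes a sample method returning draws keyed by parameter name (the method name, signature, and return structure are assumptions; check the Bayinx API reference for the exact interface):

```python
import jax.numpy as jnp

# Hypothetical API: draw 1000 Monte Carlo samples from the fitted approximation
samples = posterior.sample(1000)

# Posterior means as point estimates; these should land near true_mu and true_std
print(jnp.mean(samples['mu']), jnp.mean(samples['std']))
```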