
BlackJAX Samplers

Module: generative_models.core.sampling.blackjax_samplers

Source: generative_models/core/sampling/blackjax_samplers.py

Overview

BlackJAX integration module.

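This module integrates BlackJAX, a library of MCMC samplers for JAX, with the framework's distributions. It lets those distributions be sampled with BlackJAX's Hamiltonian Monte Carlo (HMC), No-U-Turn Sampler (NUTS), and Metropolis-adjusted Langevin algorithm (MALA) samplers.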

Classes

BlackJAXHMC

class BlackJAXHMC

BlackJAXMALA

class BlackJAXMALA

BlackJAXNUTS

class BlackJAXNUTS

BlackJAXSamplerState

class BlackJAXSamplerState

Functions

__init__

def __init__()

__init__

def __init__()

__init__

def __init__()

ensure_scalar_log_prob

def ensure_scalar_log_prob()
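This page does not document what `ensure_scalar_log_prob` does. Judging only from its name, and from the fact that gradient-based samplers such as HMC, NUTS, and MALA differentiate the log-density (and `jax.grad` accepts only scalar-output functions), it plausibly wraps a log-probability function so that it always returns a single scalar. The plain-Python sketch below illustrates that idea under one assumption: per-element log-probabilities are summed into a joint log-density. `ensure_scalar_log_prob_sketch` is a hypothetical name, not this module's function; in JAX the reduction would typically be `jnp.sum`.

```python
import math
from typing import Callable, Sequence, Union

# A log-prob may return one float, or (assumed here) one value per dimension.
LogProbOut = Union[float, int, Sequence[float]]


def ensure_scalar_log_prob_sketch(
    log_prob: Callable[[Sequence[float]], LogProbOut],
) -> Callable[[Sequence[float]], float]:
    """Hypothetical wrapper: make `log_prob` always return a single float.

    Assumption: per-element values are log-densities of independent
    dimensions, so summing them gives the joint log-density.
    """

    def scalar_log_prob(x: Sequence[float]) -> float:
        out = log_prob(x)
        if isinstance(out, (int, float)):
            return float(out)  # already a scalar
        return float(sum(out))  # reduce per-element values to one scalar

    return scalar_log_prob


if __name__ == "__main__":
    # Per-dimension standard-normal log-densities: one value per element.
    def elementwise_normal(x):
        return [-0.5 * xi * xi - 0.5 * math.log(2 * math.pi) for xi in x]

    f = ensure_scalar_log_prob_sketch(elementwise_normal)
    print(f([0.0, 0.0]))  # the joint log-density of a 2-D standard normal at the origin
```

Returning the scalar through a closure keeps the wrapped function's call signature unchanged, so it can be passed anywhere a log-density callable is expected.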

hmc_sampling

def hmc_sampling()
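The signature of `hmc_sampling` is not shown on this page. For orientation, here is a minimal plain-Python sketch of standard HMC: resample a Gaussian momentum, simulate Hamiltonian dynamics with the leapfrog integrator, then accept or reject with a Metropolis correction. It assumes an identity mass matrix and a user-supplied gradient. `HMCState`, `hmc_step`, `step_size`, and `num_leapfrog` are hypothetical names, not this module's or BlackJAX's API; the state-plus-step structure only loosely mirrors the `init`/`step` methods listed on this page.

```python
import math
import random
from dataclasses import dataclass
from typing import Callable, List


@dataclass
class HMCState:
    """Hypothetical sampler state: current position and its log-density."""
    position: List[float]
    log_prob: float


def hmc_step(
    state: HMCState,
    log_prob: Callable[[List[float]], float],
    grad_log_prob: Callable[[List[float]], List[float]],
    rng: random.Random,
    step_size: float = 0.1,
    num_leapfrog: int = 10,
) -> HMCState:
    """One HMC transition with an identity mass matrix."""
    q = list(state.position)
    p = [rng.gauss(0.0, 1.0) for _ in q]  # momentum ~ N(0, I)
    # Hamiltonian = potential (-log_prob) + kinetic energy.
    current_h = -state.log_prob + 0.5 * sum(pi * pi for pi in p)

    # Leapfrog: half momentum step, alternating full steps, final half momentum step.
    g = grad_log_prob(q)
    p = [pi + 0.5 * step_size * gi for pi, gi in zip(p, g)]
    for i in range(num_leapfrog):
        q = [qi + step_size * pi for qi, pi in zip(q, p)]
        g = grad_log_prob(q)
        if i < num_leapfrog - 1:
            p = [pi + step_size * gi for pi, gi in zip(p, g)]
    p = [pi + 0.5 * step_size * gi for pi, gi in zip(p, g)]

    proposed_lp = log_prob(q)
    proposed_h = -proposed_lp + 0.5 * sum(pi * pi for pi in p)

    # Metropolis correction: accept with probability min(1, exp(current_h - proposed_h)).
    # 1.0 - rng.random() lies in (0, 1], so the log is always defined.
    if math.log(1.0 - rng.random()) < current_h - proposed_h:
        return HMCState(q, proposed_lp)
    return state


if __name__ == "__main__":
    # Target: 1-D standard normal.
    def log_prob(x):
        return -0.5 * sum(xi * xi for xi in x)

    def grad_log_prob(x):
        return [-xi for xi in x]

    rng = random.Random(0)
    state = HMCState([3.0], log_prob([3.0]))  # deliberately start far from the mode
    samples = []
    for _ in range(2000):
        state = hmc_step(state, log_prob, grad_log_prob, rng)
        samples.append(state.position[0])
    print(sum(samples) / len(samples))  # should be close to 0
```

Caching the current log-density in the state avoids recomputing it on every step, which is one common reason sampler libraries carry an explicit state object between calls.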

init

def init()

init

def init()

init

def init()

mala_sampling

def mala_sampling()
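Like the other functions on this page, `mala_sampling` is listed without a signature. The sketch below shows one Metropolis-adjusted Langevin (MALA) transition in plain Python: a Gaussian proposal shifted along the log-density gradient, followed by a Metropolis-Hastings correction that accounts for the proposal being asymmetric. `mala_step` and its `step_size` parameter are hypothetical names chosen for illustration, not this module's or BlackJAX's API.

```python
import math
import random
from typing import Callable, List, Tuple


def mala_step(
    x: List[float],
    log_prob: Callable[[List[float]], float],
    grad_log_prob: Callable[[List[float]], List[float]],
    rng: random.Random,
    step_size: float = 0.1,
) -> Tuple[List[float], bool]:
    """One MALA transition; returns the next position and whether it was accepted."""
    g = grad_log_prob(x)
    noise_scale = math.sqrt(2.0 * step_size)
    # Langevin proposal: y = x + step_size * grad + sqrt(2 * step_size) * N(0, I)
    y = [xi + step_size * gi + noise_scale * rng.gauss(0.0, 1.0) for xi, gi in zip(x, g)]
    gy = grad_log_prob(y)

    def log_q(to, frm, grad_frm):
        # Log-density (up to a constant) of proposing `to` from `frm`.
        sq = sum((t - f - step_size * gf) ** 2 for t, f, gf in zip(to, frm, grad_frm))
        return -sq / (4.0 * step_size)

    # Metropolis-Hastings ratio with the asymmetric proposal correction.
    log_alpha = log_prob(y) - log_prob(x) + log_q(x, y, gy) - log_q(y, x, g)
    if math.log(1.0 - rng.random()) < log_alpha:
        return y, True
    return x, False


if __name__ == "__main__":
    # Target: 1-D standard normal.
    def log_prob(x):
        return -0.5 * sum(xi * xi for xi in x)

    def grad_log_prob(x):
        return [-xi for xi in x]

    rng = random.Random(1)
    x = [2.0]
    accepted = 0
    for _ in range(2000):
        x, ok = mala_step(x, log_prob, grad_log_prob, rng, step_size=0.5)
        accepted += ok
    print("acceptance rate:", accepted / 2000)
```

Without the correction term this would be the unadjusted Langevin algorithm, whose samples are biased for any finite `step_size`; the Metropolis-Hastings step removes that bias.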

nuts_sampling

def nuts_sampling()

scalar_log_prob

def scalar_log_prob()

step

def step()

step

def step()

step

def step()

Module Statistics

  • Classes: 4
  • Functions: 14
  • Imports: 7