Skip to content

Utilities

Comprehensive utility modules for JAX operations, logging, visualization, I/O, and development tools.

Overview

  • JAX Utilities — device management, PRNG handling, dtype utilities, and Flax helpers

  • Logging & Metrics — MLflow, Weights & Biases, and custom logging integrations

  • Visualization — attention maps, latent space plots, image grids, and protein visualization

  • Profiling — memory profiling, performance tracking, and XProf integration

JAX Utilities

Device Management

from artifex.utils.jax import get_device, get_available_devices

# Get default device
device = get_device()  # Returns GPU if available, else CPU

# List all devices
devices = get_available_devices()
print(f"Available: {devices}")

Device Utilities

PRNG Handling

from artifex.utils.jax import create_prng_key, split_key

# Create a key
key = create_prng_key(42)

# Split for multiple uses
key1, key2, key3 = split_key(key, num=3)

PRNG Utilities

Data Types

from artifex.utils.jax import get_dtype, ensure_dtype

# Get appropriate dtype
dtype = get_dtype("float32")

# Convert array to dtype
array = ensure_dtype(array, "bfloat16")

Dtype Utilities

Shape Utilities

from artifex.utils.jax import flatten_batch, unflatten_batch

# Flatten batch dimensions
flat, shape = flatten_batch(tensor, num_batch_dims=2)

# Restore batch dimensions
restored = unflatten_batch(flat, shape)

Shape Utilities

Flax Utilities

from artifex.utils.jax import count_params, get_param_shapes

# Count model parameters
num_params = count_params(model)

# Get parameter shapes
shapes = get_param_shapes(model)

Flax Utilities

Logging & Metrics

Logger

from artifex.utils.logging import get_logger

logger = get_logger(__name__)
logger.info("Training started")
logger.debug("Batch size: 128")

Logger Reference

Weights & Biases

from artifex.utils.logging import WandbLogger

logger = WandbLogger(
    project="my-project",
    name="experiment-001",
    config=config_dict,
)

logger.log_metrics({"loss": 0.5, "accuracy": 0.9}, step=100)
logger.log_image("samples", image_array)

W&B Integration

MLflow

from artifex.utils.logging import MLflowLogger

logger = MLflowLogger(
    experiment_name="vae-experiments",
    tracking_uri="http://localhost:5000",
)

logger.log_params({"learning_rate": 1e-3})
logger.log_metrics({"loss": 0.5}, step=100)

MLflow Integration

Metrics Tracking

from artifex.utils.logging import MetricsTracker

tracker = MetricsTracker()
tracker.update("loss", 0.5)
tracker.update("loss", 0.4)

avg = tracker.compute("loss")  # Returns average
tracker.reset()

Metrics Tracking

Visualization

Image Grids

from artifex.utils.visualization import create_image_grid, save_image_grid

# Create grid from batch
grid = create_image_grid(images, nrow=8)

# Save to file
save_image_grid(images, "samples.png", nrow=8)

Image Grid

Latent Space Visualization

from artifex.utils.visualization import plot_latent_space

# Plot 2D latent space with labels
plot_latent_space(
    latents,
    labels=labels,
    method="tsne",  # or "pca", "umap"
    save_path="latent_space.png",
)

Latent Space

Attention Visualization

from artifex.utils.visualization import visualize_attention

# Visualize attention weights
visualize_attention(
    attention_weights,
    tokens=tokens,
    save_path="attention.png",
)

Attention Visualization

Plotting

from artifex.utils.visualization import plot_training_curves

# Plot loss curves
plot_training_curves(
    train_losses=train_losses,
    val_losses=val_losses,
    save_path="training_curves.png",
)

Plotting Utilities

Protein Visualization

from artifex.utils.visualization import visualize_protein_structure

# Visualize protein backbone
visualize_protein_structure(
    coordinates=coords,
    sequence=sequence,
    save_path="protein.png",
)

Protein Visualization

I/O Utilities

File Operations

from artifex.utils.io import save_checkpoint, load_checkpoint

# Save model checkpoint
save_checkpoint(model, optimizer, "checkpoint.ckpt")

# Load checkpoint
model, optimizer = load_checkpoint("checkpoint.ckpt")

File Utilities

Format Conversion

from artifex.utils.io import convert_format

# Convert between formats
convert_format(
    input_path="model.ckpt",
    output_path="model.safetensors",
    format="safetensors",
)

Format Utilities

Serialization

from artifex.utils.io import serialize_config, deserialize_config

# Serialize to YAML
yaml_str = serialize_config(config, format="yaml")

# Deserialize from JSON
config = deserialize_config(json_str, format="json")

Serialization

Profiling

Memory Profiling

from artifex.utils.profiling import memory_profiler

with memory_profiler() as prof:
    output = model(input)

print(f"Peak memory: {prof.peak_memory_mb:.2f} MB")

Memory Profiling

Performance Profiling

from artifex.utils.profiling import profile_function

@profile_function
def train_step(batch):
    return model(batch)

# Profiling results printed automatically

Performance Profiling

XProf Integration

from artifex.utils.profiling import start_xprof, stop_xprof

start_xprof(log_dir="profiles/")
# ... training code ...
stop_xprof()

XProf Integration

Image Utilities

Color Operations

from artifex.utils.image import rgb_to_grayscale, normalize_image

grayscale = rgb_to_grayscale(image)
normalized = normalize_image(image, mean=0.5, std=0.5)

Color Utilities

Image Metrics

from artifex.utils.image import compute_psnr, compute_ssim

psnr = compute_psnr(generated, reference)
ssim = compute_ssim(generated, reference)

Image Metrics

Transforms

from artifex.utils.image import resize, center_crop, random_flip

resized = resize(image, size=(256, 256))
cropped = center_crop(image, size=(224, 224))
flipped = random_flip(image, key=prng_key)

Image Transforms

Numerical Utilities

Math Operations

from artifex.utils.numerical import log_sum_exp, softmax_temperature

# Numerically stable log-sum-exp
result = log_sum_exp(logits, axis=-1)

# Softmax with temperature
probs = softmax_temperature(logits, temperature=0.7)

Math Utilities

Numerical Stability

from artifex.utils.numerical import safe_log, safe_divide

# Safe log with epsilon
log_x = safe_log(x, eps=1e-8)

# Safe division
result = safe_divide(a, b, eps=1e-8)

Stability Utilities

Statistics

from artifex.utils.numerical import running_mean, exponential_moving_average

# Compute running statistics
mean = running_mean(values)
ema = exponential_moving_average(values, decay=0.99)

Statistics Utilities

Text Utilities

from artifex.utils.text import compute_bleu, compute_rouge

bleu = compute_bleu(predictions, references)
rouge = compute_rouge(predictions, references)

Text Metrics

Development Utilities

Timer

from artifex.utils import Timer

with Timer("training_step"):
    output = train_step(batch)
# Prints: training_step took 0.123s

Timer

Registry

from artifex.utils import Registry

models = Registry("models")

@models.register("my_model")
class MyModel:
    pass

model_class = models.get("my_model")

Registry

Environment

from artifex.utils import get_env, set_env

# Get environment variable with default
value = get_env("MY_VAR", default="default_value")

Environment

Dependency Analyzer

from artifex.utils import analyze_dependencies

# Analyze module dependencies
deps = analyze_dependencies("artifex.generative_models")

Dependency Analyzer

Module Reference

| Category | Modules |
|---|---|
| JAX | `device`, `dtype`, `flax_utils`, `prng`, `shapes` |
| Logging | `logger`, `metrics`, `mlflow`, `wandb`, `file_utils` |
| Visualization | `attention_vis`, `image_grid`, `latent_space`, `plotting`, `protein` |
| I/O | `file`, `formats`, `serialization` |
| Profiling | `memory`, `performance`, `xprof` |
| Image | `color`, `metrics`, `transforms` |
| Numerical | `math`, `stability`, `stats` |
| Text | `metrics`, `postprocessing`, `processing` |
| Utils | `env`, `registry`, `timer`, `types` |