Utilities
Utility modules for JAX operations, logging, visualization, I/O, profiling, image, numerical and text processing, and development tools.
Overview
- JAX Utilities - Device management, PRNG handling, dtype utilities, and Flax helpers
- Logging & Metrics - MLflow, Weights & Biases, and custom logging integrations
- Visualization - Attention maps, latent space plots, image grids, and protein visualization
- Profiling - Memory profiling, performance tracking, and XProf integration
JAX Utilities
Device Management
```python
from artifex.utils.jax import get_device, get_available_devices

# Get the default device
device = get_device()  # Returns a GPU if available, else the CPU

# List all devices
devices = get_available_devices()
print(f"Available: {devices}")
```
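To see what device selection involves underneath, the sketch below queries JAX directly. It assumes `get_device` and `get_available_devices` are thin wrappers around `jax.devices()`, which is not confirmed here; `pick_default_device` is a hypothetical name.

```python
import jax

def pick_default_device():
    """Return the first GPU/TPU device if one exists, else the first CPU device.

    Hypothetical stand-in for get_device(); assumed to wrap jax.devices().
    """
    for platform in ("gpu", "tpu"):
        try:
            devices = jax.devices(platform)
        except RuntimeError:
            # jax.devices raises RuntimeError when that backend is unavailable.
            continue
        if devices:
            return devices[0]
    return jax.devices("cpu")[0]

all_devices = jax.devices()  # devices on JAX's default backend
default = pick_default_device()
print(f"Default: {default.platform}, available: {all_devices}")
```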
PRNG Handling
```python
from artifex.utils.jax import create_prng_key, split_key

# Create a key from an integer seed
key = create_prng_key(42)

# Split into independent keys for multiple uses
key1, key2, key3 = split_key(key, num=3)
```
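JAX keys are explicit values rather than hidden global state, so "splitting" is how you get fresh randomness. The sketch below assumes `create_prng_key` and `split_key` correspond to `jax.random.PRNGKey` and `jax.random.split`; that mapping is an assumption. Newer JAX also offers typed keys via `jax.random.key`.

```python
import jax

# Assumed equivalent of create_prng_key(42): a deterministic key from an integer seed.
key = jax.random.PRNGKey(42)

# Assumed equivalent of split_key(key, num=3): one call returns `num` new keys.
# Keys are never advanced in place; split, then stop using the parent key.
key1, key2, key3 = jax.random.split(key, num=3)

# Each subkey drives its own random stream.
noise = jax.random.normal(key1, shape=(4,))
dropout_mask = jax.random.bernoulli(key2, p=0.9, shape=(4,))
```

Because the same seed always yields the same keys, a run can be reproduced exactly by reusing the seed.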
Data Types
```python
from artifex.utils.jax import get_dtype, ensure_dtype

# Resolve a dtype from its name
dtype = get_dtype("float32")

# Convert an array to the given dtype
array = ensure_dtype(array, "bfloat16")
```
Shape Utilities
```python
from artifex.utils.jax import flatten_batch, unflatten_batch

# Flatten the two leading batch dimensions into one
flat, shape = flatten_batch(tensor, num_batch_dims=2)

# Restore the original batch dimensions
restored = unflatten_batch(flat, shape)
```
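Flattening batch dimensions is a reshape that merges the leading axes and remembers their sizes so they can be restored. A minimal sketch, assuming the returned `shape` is the tuple of original batch dimensions (the helper names below are hypothetical):

```python
import math
import jax.numpy as jnp

def flatten_batch_sketch(x, num_batch_dims):
    """Merge the leading `num_batch_dims` axes into one; also return their shape."""
    batch_shape = x.shape[:num_batch_dims]
    flat = x.reshape((math.prod(batch_shape),) + x.shape[num_batch_dims:])
    return flat, batch_shape

def unflatten_batch_sketch(flat, batch_shape):
    """Split the leading axis back into `batch_shape`."""
    return flat.reshape(tuple(batch_shape) + flat.shape[1:])

x = jnp.zeros((4, 8, 32, 32))  # e.g. (devices, per-device batch, H, W)
flat, batch_shape = flatten_batch_sketch(x, num_batch_dims=2)
restored = unflatten_batch_sketch(flat, batch_shape)
```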
Flax Utilities
```python
from artifex.utils.jax import count_params, get_param_shapes

# Count model parameters
num_params = count_params(model)

# Get parameter shapes
shapes = get_param_shapes(model)
```
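Parameter counting reduces to walking the parameter pytree. The sketch below works on a plain nested dict of arrays; for a Flax NNX module you would first extract the parameter state (for example with `nnx.state(model, nnx.Param)`). Whether `count_params` works this way internally is an assumption, and the `params` tree is made up for illustration.

```python
import jax
import jax.numpy as jnp

# A hypothetical parameter pytree, shaped like a Dense layer mapping 3 -> 4 features.
params = {
    "dense": {"kernel": jnp.zeros((3, 4)), "bias": jnp.zeros((4,))},
}

def count_params_sketch(tree):
    """Total number of scalar parameters across all array leaves."""
    return sum(leaf.size for leaf in jax.tree_util.tree_leaves(tree))

def param_shapes_sketch(tree):
    """Same tree structure, with each array replaced by its shape."""
    return jax.tree_util.tree_map(lambda leaf: leaf.shape, tree)

num_params = count_params_sketch(params)  # 3 * 4 kernel entries + 4 biases
shapes = param_shapes_sketch(params)
```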
Logging & Metrics
Logger
```python
from artifex.utils.logging import get_logger

logger = get_logger(__name__)
logger.info("Training started")
logger.debug("Batch size: 128")
```
Weights & Biases
```python
from artifex.utils.logging import WandbLogger

logger = WandbLogger(
    project="my-project",
    name="experiment-001",
    config=config_dict,
)
logger.log_metrics({"loss": 0.5, "accuracy": 0.9}, step=100)
logger.log_image("samples", image_array)
```
MLflow
```python
from artifex.utils.logging import MLflowLogger

logger = MLflowLogger(
    experiment_name="vae-experiments",
    tracking_uri="http://localhost:5000",
)
logger.log_params({"learning_rate": 1e-3})
logger.log_metrics({"loss": 0.5}, step=100)
```
Metrics Tracking
```python
from artifex.utils.logging import MetricsTracker

tracker = MetricsTracker()
tracker.update("loss", 0.5)
tracker.update("loss", 0.4)
avg = tracker.compute("loss")  # Returns the average of recorded values (0.45)
tracker.reset()  # Clear accumulated values
```
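A running-average tracker only needs a sum and a count per metric, not the full history. The class below is a hypothetical sketch of that idea, not the `MetricsTracker` implementation; it calls `float(value)` so JAX scalars are pulled to the host once per update.

```python
from collections import defaultdict

class RunningAverageTracker:
    """Hypothetical stand-in for MetricsTracker: running sum and count per metric."""

    def __init__(self):
        self._sums = defaultdict(float)
        self._counts = defaultdict(int)

    def update(self, name, value):
        self._sums[name] += float(value)  # float() also converts JAX/NumPy scalars
        self._counts[name] += 1

    def compute(self, name):
        if self._counts.get(name, 0) == 0:
            raise KeyError(f"no values recorded for {name!r}")
        return self._sums[name] / self._counts[name]

    def reset(self):
        self._sums.clear()
        self._counts.clear()

tracker = RunningAverageTracker()
tracker.update("loss", 0.5)
tracker.update("loss", 0.4)
avg = tracker.compute("loss")
```

Keeping only two numbers per metric makes memory use constant regardless of how many steps are logged between resets.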
Visualization
Image Grids
```python
from artifex.utils.visualization import create_image_grid, save_image_grid

# Create a grid from a batch of images
grid = create_image_grid(images, nrow=8)

# Save the grid to a file
save_image_grid(images, "samples.png", nrow=8)
```
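An image grid is a padded canvas with each image copied into its cell. This sketch assumes NHWC input and that `nrow` means images per row (the torchvision convention); both are assumptions about `create_image_grid`, and `make_grid_sketch` is a hypothetical name.

```python
import numpy as np

def make_grid_sketch(images, nrow=8, padding=2, pad_value=0.0):
    """Tile an (N, H, W, C) batch into one (H', W', C) image, `nrow` images per row."""
    n, h, w, c = images.shape
    rows = -(-n // nrow)  # ceiling division: number of grid rows
    grid_h = rows * h + (rows + 1) * padding
    grid_w = nrow * w + (nrow + 1) * padding
    grid = np.full((grid_h, grid_w, c), pad_value, dtype=images.dtype)
    for idx in range(n):
        r, col = divmod(idx, nrow)
        top = padding + r * (h + padding)
        left = padding + col * (w + padding)
        grid[top:top + h, left:left + w] = images[idx]
    return grid

images = np.random.rand(10, 28, 28, 1).astype(np.float32)
grid = make_grid_sketch(images, nrow=8)  # 2 rows: 8 images, then 2 images
```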
Latent Space Visualization
```python
from artifex.utils.visualization import plot_latent_space

# Project latents to 2D and plot them, colored by label
plot_latent_space(
    latents,
    labels=labels,
    method="tsne",  # or "pca", "umap"
    save_path="latent_space.png",
)
```
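Of the three projection methods, PCA is the simplest to write out: center the latents and project them onto the top two right singular vectors. This is a generic sketch of PCA, not necessarily the code path `plot_latent_space` uses for `method="pca"`; `pca_2d` is a hypothetical name.

```python
import numpy as np

def pca_2d(latents):
    """Project (N, D) latents onto their top-2 principal components via SVD."""
    centered = latents - latents.mean(axis=0, keepdims=True)
    # Rows of vt are principal directions, sorted by decreasing singular value.
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:2].T

rng = np.random.default_rng(0)
latents = rng.normal(size=(200, 16))
coords = pca_2d(latents)  # (200, 2), ready for a scatter plot
```

Unlike t-SNE and UMAP, PCA is linear and deterministic, so distances along each axis remain interpretable.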
Attention Visualization
```python
from artifex.utils.visualization import visualize_attention

# Visualize attention weights
visualize_attention(
    attention_weights,
    tokens=tokens,
    save_path="attention.png",
)
```
Plotting
```python
from artifex.utils.visualization import plot_training_curves

# Plot training and validation loss curves
plot_training_curves(
    train_losses=train_losses,
    val_losses=val_losses,
    save_path="training_curves.png",
)
```
Protein Visualization
```python
from artifex.utils.visualization import visualize_protein_structure

# Visualize the protein backbone
visualize_protein_structure(
    coordinates=coords,
    sequence=sequence,
    save_path="protein.png",
)
```
I/O Utilities
File Operations
```python
from artifex.utils.io import save_checkpoint, load_checkpoint

# Save a model checkpoint
save_checkpoint(model, optimizer, "checkpoint.ckpt")

# Load the checkpoint
model, optimizer = load_checkpoint("checkpoint.ckpt")
```
Format Conversion
```python
from artifex.utils.io import convert_format

# Convert a checkpoint between formats
convert_format(
    input_path="model.ckpt",
    output_path="model.safetensors",
    format="safetensors",
)
```
Serialization
```python
from artifex.utils.io import serialize_config, deserialize_config

# Serialize to YAML
yaml_str = serialize_config(config, format="yaml")

# Deserialize from JSON
config = deserialize_config(json_str, format="json")
```
Profiling
Memory Profiling
```python
from artifex.utils.profiling import memory_profiler

with memory_profiler() as prof:
    output = model(inputs)

print(f"Peak memory: {prof.peak_memory_mb:.2f} MB")
```
Performance Profiling
```python
from artifex.utils.profiling import profile_function

@profile_function
def train_step(batch):
    return model(batch)

# Profiling results are printed automatically
```
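A timing decorator wraps the function, measures wall-clock time around each call, and reports it. The sketch below is a hypothetical illustration, not `profile_function` itself. The `block_until_ready` check matters for JAX: dispatch is asynchronous, so without it only the dispatch time is measured. For pytree outputs, `jax.block_until_ready` handles nested structures.

```python
import functools
import time

def profile_function_sketch(fn):
    """Hypothetical timing decorator: prints and stores the duration of each call."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        result = fn(*args, **kwargs)
        # Wait for asynchronously dispatched JAX arrays before stopping the clock.
        if hasattr(result, "block_until_ready"):
            result.block_until_ready()
        elapsed = time.perf_counter() - start
        wrapper.last_elapsed = elapsed
        print(f"{fn.__name__} took {elapsed:.4f}s")
        return result
    wrapper.last_elapsed = None
    return wrapper

@profile_function_sketch
def train_step(batch):
    return sum(batch)  # placeholder for real model work

out = train_step([1, 2, 3])
```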
XProf Integration
```python
from artifex.utils.profiling import start_xprof, stop_xprof

start_xprof(log_dir="profiles/")
# ... training code ...
stop_xprof()
```
Image Utilities
Color Operations
```python
from artifex.utils.image import rgb_to_grayscale, normalize_image

grayscale = rgb_to_grayscale(image)
normalized = normalize_image(image, mean=0.5, std=0.5)
```
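Both operations are per-pixel arithmetic. This sketch assumes HWC images with values in [0, 1] and uses the ITU-R BT.601 luma weights, a common choice that may differ from what `rgb_to_grayscale` uses; the function names are hypothetical. With `mean=0.5, std=0.5`, normalization maps [0, 1] to [-1, 1].

```python
import numpy as np

# ITU-R BT.601 luma weights (an assumed choice; other standards use other weights).
LUMA_WEIGHTS = np.array([0.299, 0.587, 0.114], dtype=np.float32)

def rgb_to_grayscale_sketch(image):
    """(..., H, W, 3) RGB -> (..., H, W, 1) grayscale via a weighted channel sum."""
    return (image @ LUMA_WEIGHTS)[..., None]

def normalize_image_sketch(image, mean=0.5, std=0.5):
    """Shift and scale pixel values: (x - mean) / std."""
    return (image - mean) / std

image = np.random.rand(32, 32, 3).astype(np.float32)  # values in [0, 1)
gray = rgb_to_grayscale_sketch(image)
normalized = normalize_image_sketch(image, mean=0.5, std=0.5)
```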
Image Metrics
```python
from artifex.utils.image import compute_psnr, compute_ssim

psnr = compute_psnr(generated, reference)
ssim = compute_ssim(generated, reference)
```
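PSNR has a closed form, 10 · log10(MAX² / MSE), where MAX is the largest possible pixel value. The sketch assumes pixels in [0, 1] (`max_val=1.0`); check which range `compute_psnr` expects before comparing numbers. SSIM compares local means, variances, and covariances over windows and is more involved, so it is not sketched here.

```python
import numpy as np

def psnr_sketch(generated, reference, max_val=1.0):
    """Peak signal-to-noise ratio in dB for images with values in [0, max_val]."""
    diff = np.asarray(generated, np.float64) - np.asarray(reference, np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        return float("inf")  # identical images
    return float(10.0 * np.log10(max_val ** 2 / mse))

reference = np.zeros((8, 8, 3))
generated = np.full((8, 8, 3), 0.1)  # constant error of 0.1, so MSE = 0.01
print(psnr_sketch(generated, reference))  # about 20 dB
```

Each tenfold reduction in MSE adds 10 dB, which is why PSNR differences of a few dB are meaningful.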
Transforms
```python
from artifex.utils.image import resize, center_crop, random_flip

resized = resize(image, size=(256, 256))
cropped = center_crop(image, size=(224, 224))
flipped = random_flip(image, key=prng_key)
```
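Center cropping is plain index arithmetic, and a JAX-style random flip uses the PRNG key to draw a boolean that selects between the original and the mirrored image. The sketch assumes HWC layout and a left-right flip with probability 0.5; the flip axis and probability used by `random_flip` are assumptions, as are the function names.

```python
import jax
import jax.numpy as jnp

def center_crop_sketch(image, size):
    """Crop the central (height, width) window of an (H, W, C) image."""
    h, w = image.shape[:2]
    ch, cw = size
    if ch > h or cw > w:
        raise ValueError(f"crop {size} is larger than image {(h, w)}")
    top, left = (h - ch) // 2, (w - cw) // 2
    return image[top:top + ch, left:left + cw]

def random_hflip_sketch(image, key, p=0.5):
    """Flip an (H, W, C) image left-right with probability p, driven by a PRNG key."""
    flip = jax.random.bernoulli(key, p)
    # jnp.where keeps the op traceable under jit (no Python branch on a traced value).
    return jnp.where(flip, image[:, ::-1], image)

image = jnp.arange(256 * 256 * 3).reshape(256, 256, 3)
cropped = center_crop_sketch(image, size=(224, 224))
flipped = random_hflip_sketch(image, jax.random.PRNGKey(0))
```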
Numerical Utilities
Math Operations
```python
from artifex.utils.numerical import log_sum_exp, softmax_temperature

# Numerically stable log-sum-exp
result = log_sum_exp(logits, axis=-1)

# Softmax with temperature
probs = softmax_temperature(logits, temperature=0.7)
```
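The standard stability trick is log Σ exp(xᵢ) = m + log Σ exp(xᵢ − m) with m = max xᵢ, so no exponent exceeds 0. Temperature divides the logits before the softmax: T < 1 sharpens the distribution and T > 1 flattens it. A minimal sketch with hypothetical names; production versions such as `jax.scipy.special.logsumexp` also handle rows that are entirely `-inf`, which this one does not.

```python
import jax
import jax.numpy as jnp

def log_sum_exp_sketch(x, axis=-1):
    """log(sum(exp(x))) computed stably by factoring out the max along `axis`."""
    m = jnp.max(x, axis=axis, keepdims=True)
    # exp(x - m) <= 1 cannot overflow; m is added back outside the log.
    return jnp.squeeze(m, axis=axis) + jnp.log(jnp.sum(jnp.exp(x - m), axis=axis))

def softmax_temperature_sketch(logits, temperature=1.0, axis=-1):
    """Softmax of logits / temperature."""
    return jax.nn.softmax(logits / temperature, axis=axis)

logits = jnp.array([1000.0, 1000.0])  # a naive exp() would overflow to inf
lse = log_sum_exp_sketch(logits)       # 1000 + log(2)
probs = softmax_temperature_sketch(jnp.array([2.0, 1.0, 0.0]), temperature=0.7)
```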
Numerical Stability
```python
from artifex.utils.numerical import safe_log, safe_divide

# Safe log with epsilon
log_x = safe_log(x, eps=1e-8)

# Safe division
result = safe_divide(a, b, eps=1e-8)
```
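There are two common ways to use `eps`: clamp the input (`log(max(x, eps))`) or shift it (`log(x + eps)`). The sketch below uses clamping for the log and a sign-preserving floor on the denominator for division; which variant `safe_log` and `safe_divide` use is an assumption, and the names are hypothetical.

```python
import jax.numpy as jnp

def safe_log_sketch(x, eps=1e-8):
    """log(x) with inputs clamped to at least eps, so log(0) gives log(eps), not -inf."""
    return jnp.log(jnp.maximum(x, eps))

def safe_divide_sketch(a, b, eps=1e-8):
    """a / b, with denominators smaller than eps in magnitude replaced by +/-eps."""
    sign = jnp.where(b >= 0, 1.0, -1.0)
    safe_b = jnp.where(jnp.abs(b) < eps, sign * eps, b)
    return a / safe_b

x = jnp.array([0.0, 1e-12, 1.0])
log_x = safe_log_sketch(x)
result = safe_divide_sketch(jnp.array([1.0, 1.0]), jnp.array([0.0, 2.0]))
```

Because the denominator is made safe before dividing, the gradient with respect to `b` also stays finite, which is what matters inside a loss.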
Statistics
```python
from artifex.utils.numerical import running_mean, exponential_moving_average

# Compute running statistics
mean = running_mean(values)
ema = exponential_moving_average(values, decay=0.99)
```
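The two statistics differ in how they weight history: a running mean weights every value equally, while an EMA with `decay` d gives the newest value weight 1 − d. This sketch assumes both return the full per-step series and that the EMA is seeded with the first value; the Artifex functions may instead return a scalar or seed at zero with bias correction (as Adam does). Names are hypothetical.

```python
def running_mean_sketch(values):
    """Cumulative mean after each element: out[i] = mean(values[: i + 1])."""
    out, total = [], 0.0
    for i, v in enumerate(values, start=1):
        total += v
        out.append(total / i)
    return out

def ema_sketch(values, decay=0.99):
    """EMA seeded with the first value: ema[t] = decay * ema[t-1] + (1 - decay) * x[t]."""
    out = []
    for v in values:
        out.append(v if not out else decay * out[-1] + (1.0 - decay) * v)
    return out

values = [1.0, 2.0, 3.0, 4.0]
means = running_mean_sketch(values)
ema = ema_sketch(values, decay=0.5)
```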
Text Utilities
```python
from artifex.utils.text import compute_bleu, compute_rouge

bleu = compute_bleu(predictions, references)
rouge = compute_rouge(predictions, references)
```
Development Utilities
Timer
```python
from artifex.utils import Timer

with Timer("training_step"):
    output = train_step(batch)
# Prints: training_step took 0.123s
```
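A timer like this is a context manager that reads a monotonic clock on entry and exit. The class below is a hypothetical sketch, not Artifex's `Timer`. When timing JAX code, call `.block_until_ready()` on outputs inside the `with` block, since JAX returns before the computation finishes.

```python
import time

class TimerSketch:
    """Hypothetical context-manager timer for the wall-clock time of a `with` body."""

    def __init__(self, name, verbose=True):
        self.name = name
        self.verbose = verbose
        self.elapsed = None

    def __enter__(self):
        self._start = time.perf_counter()  # monotonic, high-resolution clock
        return self

    def __exit__(self, exc_type, exc, tb):
        self.elapsed = time.perf_counter() - self._start
        if self.verbose:
            print(f"{self.name} took {self.elapsed:.3f}s")
        return False  # do not swallow exceptions raised inside the block

with TimerSketch("sleep") as t:
    time.sleep(0.01)
```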
Registry
```python
from artifex.utils import Registry

models = Registry("models")

@models.register("my_model")
class MyModel:
    pass

model_class = models.get("my_model")
```
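A registry maps string names to classes so configs can select components by name. The decorator form works because `register(key)` returns a function that stores the class and hands it back unchanged. This is a hypothetical sketch; whether Artifex's `Registry` rejects duplicate names or lists available keys on a miss is an assumption.

```python
class RegistrySketch:
    """Hypothetical name -> object registry with a decorator-based `register`."""

    def __init__(self, name):
        self.name = name
        self._entries = {}

    def register(self, key):
        def decorator(obj):
            if key in self._entries:
                raise KeyError(f"{key!r} is already registered in {self.name!r}")
            self._entries[key] = obj
            return obj  # return the class unchanged so it stays usable
        return decorator

    def get(self, key):
        try:
            return self._entries[key]
        except KeyError:
            available = ", ".join(sorted(self._entries)) or "<empty>"
            raise KeyError(f"{key!r} not in {self.name!r}; available: {available}") from None

models = RegistrySketch("models")

@models.register("my_model")
class MyModel:
    pass

model_class = models.get("my_model")
```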
Environment
```python
from artifex.utils import get_env, set_env

# Get an environment variable, with a default if it is unset
value = get_env("MY_VAR", default="default_value")
```
Dependency Analyzer
```python
from artifex.utils import analyze_dependencies

# Analyze module dependencies
deps = analyze_dependencies("artifex.generative_models")
```
Module Reference
| Category | Modules |
|---|---|
| JAX | device, dtype, flax_utils, prng, shapes |
| Logging | logger, metrics, mlflow, wandb, file_utils |
| Visualization | attention_vis, image_grid, latent_space, plotting, protein |
| I/O | file, formats, serialization |
| Profiling | memory, performance, xprof |
| Image | color, metrics, transforms |
| Numerical | math, stability, stats |
| Text | metrics, postprocessing, processing |
| Utils | env, registry, timer, types |
Related Documentation
- Training Guide - Using utilities in training
- Logging & Tracking - Experiment tracking
- Performance Profiling - Profiling guide