Audio Modality Guide
This guide covers working with audio data in Workshop, including audio representations, waveform processing, spectrograms, and best practices for audio-based generative models.
Overview
Workshop's audio modality provides a unified interface for working with audio data in different representations: raw waveforms, mel-spectrograms, and STFTs (Short-Time Fourier Transforms).
- Multiple Representations - Support for raw waveforms, mel-spectrograms, and STFT representations
- Sample Rate Control - Work with any sample rate from 8 kHz to 48 kHz and beyond
- Spectrogram Processing - Built-in mel-spectrogram and STFT computation with configurable parameters
- Synthetic Datasets - Ready-to-use synthetic audio datasets (sine waves, noise, chirps)
- Augmentation - Time-domain and frequency-domain audio augmentation techniques
- JAX-Native - Full JAX compatibility with JIT compilation and GPU acceleration
Audio Representations
Supported Formats
Workshop supports three audio representations:
from workshop.generative_models.modalities.audio.base import AudioRepresentation
# Raw waveform (time-domain signal)
AudioRepresentation.RAW_WAVEFORM
# Mel-spectrogram (perceptually-scaled frequency representation)
AudioRepresentation.MEL_SPECTROGRAM
# STFT (Short-Time Fourier Transform)
AudioRepresentation.STFT
Configuring Audio Modality
from workshop.generative_models.modalities.audio import (
AudioModality,
AudioModalityConfig,
AudioRepresentation
)
from flax import nnx
# Initialize RNG
rngs = nnx.Rngs(0)
# Raw waveform configuration
waveform_config = AudioModalityConfig(
representation=AudioRepresentation.RAW_WAVEFORM,
sample_rate=16000, # 16 kHz
duration=2.0, # 2 seconds
normalize=True
)
waveform_modality = AudioModality(config=waveform_config, rngs=rngs)
# Mel-spectrogram configuration
mel_config = AudioModalityConfig(
representation=AudioRepresentation.MEL_SPECTROGRAM,
sample_rate=16000,
n_mel_channels=80, # Number of mel bins
hop_length=256, # Hop size for STFT
n_fft=1024, # FFT size
duration=2.0,
normalize=True
)
mel_modality = AudioModality(config=mel_config, rngs=rngs)
# STFT configuration
stft_config = AudioModalityConfig(
representation=AudioRepresentation.STFT,
sample_rate=22050, # half of CD quality (44.1 kHz / 2)
n_fft=2048,
hop_length=512,
duration=3.0,
normalize=True
)
stft_modality = AudioModality(config=stft_config, rngs=rngs)
Audio Shape Properties
# Raw waveform
print(f"Time steps: {waveform_modality.n_time_steps}") # 32000 (16000 * 2.0)
print(f"Output shape: {waveform_modality.output_shape}") # (32000,)
# Mel-spectrogram
print(f"Time frames: {mel_modality.n_time_frames}") # 125 (32000 // 256)
print(f"Output shape: {mel_modality.output_shape}") # (80, 125)
# STFT
print(f"Frequency bins: {stft_config.n_fft // 2 + 1}") # 1025
print(f"Output shape: {stft_modality.output_shape}") # (1025, time_frames)
Audio Datasets
Synthetic Audio Datasets
Workshop provides synthetic audio datasets for testing and development:
from workshop.generative_models.modalities.audio.datasets import (
SyntheticAudioDataset,
create_audio_dataset
)
# Configure audio
audio_config = AudioModalityConfig(
representation=AudioRepresentation.RAW_WAVEFORM,
sample_rate=16000,
duration=1.0,
normalize=True
)
# Create synthetic dataset
audio_dataset = SyntheticAudioDataset(
config=audio_config,
n_samples=1000,
audio_types=["sine", "noise", "chirp"],
name="SyntheticAudio"
)
# Access sample
sample = audio_dataset[0]
print(sample["audio"].shape) # (16000,) - 1 second at 16kHz
print(sample["audio_type"]) # "sine" or "noise" or "chirp"
print(sample["sample_rate"]) # 16000
print(sample["duration"]) # 1.0
Audio Types
Sine Waves - Pure tones at random frequencies:
audio_dataset = SyntheticAudioDataset(
config=audio_config,
n_samples=1000,
audio_types=["sine"]
)
# Sine waves have:
# - Random frequencies (200-800 Hz)
# - Fixed amplitude (0.5)
# - Clean sinusoidal waveform
White Noise - Random Gaussian noise:
audio_dataset = SyntheticAudioDataset(
config=audio_config,
n_samples=1000,
audio_types=["noise"]
)
# White noise has:
# - Gaussian distribution
# - Fixed amplitude (0.3)
# - Flat frequency spectrum
Chirp Signals - Linear frequency sweeps:
audio_dataset = SyntheticAudioDataset(
config=audio_config,
n_samples=1000,
audio_types=["chirp"]
)
# Chirps have:
# - Linear frequency sweep (200-800 Hz)
# - Smooth transitions
# - Useful for testing time-frequency representations
Mixed Audio Types
# Mix multiple audio types
mixed_dataset = SyntheticAudioDataset(
config=audio_config,
n_samples=3000,
audio_types=["sine", "noise", "chirp"]
)
# Dataset will cycle through types
# Sample 0: sine
# Sample 1: noise
# Sample 2: chirp
# Sample 3: sine (cycles back)
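You can confirm the round-robin assignment directly (a quick check against the cycling behavior described above):
for i in range(6):
    print(i, mixed_dataset[i]["audio_type"])
# 0 sine, 1 noise, 2 chirp, 3 sine, 4 noise, 5 chirp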
Batching Audio Data
# Get batch using collate function
batch = audio_dataset.collate_fn([
audio_dataset[0],
audio_dataset[1],
audio_dataset[2],
audio_dataset[3]
])
print(batch["audio"].shape) # (4, 16000)
print(len(batch["audio_type"])) # 4
print(batch["sample_rate"]) # [16000, 16000, 16000, 16000]
# Iterate through dataset
for i, sample in enumerate(audio_dataset):
if i >= 5:
break
print(f"Sample {i}: {sample['audio_type']}, shape: {sample['audio'].shape}")
Factory Function
# Create dataset using factory
dataset = create_audio_dataset(
dataset_type="synthetic",
config=audio_config,
n_samples=5000,
audio_types=["sine", "noise"]
)
# With custom parameters
custom_dataset = create_audio_dataset(
dataset_type="synthetic",
config=None, # Will use defaults
n_samples=1000,
audio_types=["chirp"]
)
Audio Preprocessing
Normalization
import jax
import jax.numpy as jnp
def normalize_audio(audio: jax.Array) -> jax.Array:
"""Normalize audio to [-1, 1] range.
Args:
audio: Input audio waveform
Returns:
Normalized audio
"""
max_val = jnp.max(jnp.abs(audio))
return jnp.where(max_val > 0, audio / max_val, audio)
def rms_normalize(audio: jax.Array, target_rms: float = 0.1) -> jax.Array:
"""Normalize audio by RMS (root mean square) energy.
Args:
audio: Input audio waveform
target_rms: Target RMS value
Returns:
RMS-normalized audio
"""
rms = jnp.sqrt(jnp.mean(audio ** 2))
return audio * (target_rms / (rms + 1e-8))
# Usage
audio = jnp.array([...]) # Raw audio
normalized = normalize_audio(audio)
rms_normalized = rms_normalize(audio, target_rms=0.1)
Resampling
import jax
import jax.numpy as jnp
def resample_audio(
audio: jax.Array,
orig_sr: int,
target_sr: int
) -> jax.Array:
"""Resample audio to different sample rate.
Args:
audio: Input audio waveform
orig_sr: Original sample rate
target_sr: Target sample rate
Returns:
Resampled audio
"""
# Simple linear interpolation resampling
orig_length = len(audio)
target_length = int(orig_length * target_sr / orig_sr)
# Create interpolation indices
orig_indices = jnp.linspace(0, orig_length - 1, target_length)
# Interpolate
resampled = jnp.interp(orig_indices, jnp.arange(orig_length), audio)
return resampled
# Usage
audio = jnp.array([...]) # Audio at 16kHz
audio_22k = resample_audio(audio, orig_sr=16000, target_sr=22050)
audio_8k = resample_audio(audio, orig_sr=16000, target_sr=8000)
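Linear interpolation is adequate for quick experiments, but naive downsampling aliases any content above the new Nyquist frequency. A hedged sketch of a windowed-sinc low-pass to apply before downsampling (lowpass_for_downsampling is an illustrative helper with an assumed filter design, not part of the Workshop API):
def lowpass_for_downsampling(
    audio: jax.Array,
    orig_sr: int,
    target_sr: int,
    num_taps: int = 63
) -> jax.Array:
    """Windowed-sinc low-pass at the target Nyquist frequency (a sketch)."""
    # Cutoff as a fraction of the original Nyquist frequency
    cutoff = (target_sr / 2) / (orig_sr / 2)
    n = jnp.arange(num_taps) - (num_taps - 1) / 2
    taps = cutoff * jnp.sinc(cutoff * n) * jnp.hanning(num_taps)
    taps = taps / jnp.sum(taps)  # unity gain at DC
    return jnp.convolve(audio, taps, mode="same")

# Filter first, then decimate
filtered = lowpass_for_downsampling(audio, orig_sr=16000, target_sr=8000)
audio_8k_clean = resample_audio(filtered, orig_sr=16000, target_sr=8000)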
Duration Adjustment
def trim_or_pad(
audio: jax.Array,
target_length: int,
pad_value: float = 0.0
) -> jax.Array:
"""Trim or pad audio to target length.
Args:
audio: Input audio waveform
target_length: Target number of samples
pad_value: Value to use for padding
Returns:
Trimmed or padded audio
"""
current_length = len(audio)
if current_length >= target_length:
# Trim
return audio[:target_length]
else:
# Pad
padding = jnp.full((target_length - current_length,), pad_value)
return jnp.concatenate([audio, padding])
# Usage
audio = jnp.array([...]) # Variable length audio
fixed_length = trim_or_pad(audio, target_length=16000)
print(fixed_length.shape) # (16000,)
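When clips are longer than the target length, a random crop often beats always trimming from the start, since the model sees different portions of each recording. A minimal sketch (random_crop is an illustrative helper; jax.lax.dynamic_slice keeps it usable under jit, unlike Python slicing with traced bounds):
def random_crop(audio: jax.Array, key, target_length: int) -> jax.Array:
    """Randomly crop audio that is at least target_length samples long."""
    max_start = audio.shape[0] - target_length
    start = jax.random.randint(key, (), 0, max_start + 1)
    return jax.lax.dynamic_slice(audio, (start,), (target_length,))

key = jax.random.key(0)
cropped = random_crop(fixed_length, key, target_length=8000)
print(cropped.shape)  # (8000,)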
Spectrogram Processing
Computing Spectrograms
import jax.numpy as jnp
def compute_stft(
audio: jax.Array,
n_fft: int = 1024,
hop_length: int = 256,
window: str = "hann"
) -> jax.Array:
"""Compute Short-Time Fourier Transform.
Args:
audio: Input audio waveform
n_fft: FFT size
hop_length: Number of samples between frames
window: Window function
Returns:
Complex STFT of shape (n_freq_bins, n_frames)
"""
# Create window
if window == "hann":
win = jnp.hanning(n_fft)
else:
win = jnp.ones(n_fft)
# Number of frames
n_frames = (len(audio) - n_fft) // hop_length + 1
# Compute STFT
stft = []
for i in range(n_frames):
start = i * hop_length
frame = audio[start:start + n_fft] * win
# Apply FFT
fft_result = jnp.fft.rfft(frame)
stft.append(fft_result)
stft = jnp.stack(stft, axis=1) # (n_freq_bins, n_frames)
return stft
def compute_magnitude_spectrogram(
audio: jax.Array,
n_fft: int = 1024,
hop_length: int = 256
) -> jax.Array:
"""Compute magnitude spectrogram.
Args:
audio: Input audio waveform
n_fft: FFT size
hop_length: Hop size
Returns:
Magnitude spectrogram (n_freq_bins, n_frames)
"""
stft = compute_stft(audio, n_fft, hop_length)
return jnp.abs(stft)
def compute_power_spectrogram(
audio: jax.Array,
n_fft: int = 1024,
hop_length: int = 256
) -> jax.Array:
"""Compute power spectrogram.
Args:
audio: Input audio waveform
n_fft: FFT size
hop_length: Hop size
Returns:
Power spectrogram (n_freq_bins, n_frames)
"""
mag_spec = compute_magnitude_spectrogram(audio, n_fft, hop_length)
return mag_spec ** 2
# Usage
audio = jnp.array([...]) # Audio waveform
# Compute spectrograms
stft = compute_stft(audio, n_fft=1024, hop_length=256)
mag_spec = compute_magnitude_spectrogram(audio, n_fft=1024, hop_length=256)
power_spec = compute_power_spectrogram(audio, n_fft=1024, hop_length=256)
print(f"STFT shape: {stft.shape}") # (513, n_frames)
print(f"Magnitude spectrogram: {mag_spec.shape}") # (513, n_frames)
print(f"Power spectrogram: {power_spec.shape}") # (513, n_frames)
Mel-Spectrogram
def hz_to_mel(hz: float) -> float:
"""Convert Hz to mel scale."""
return 2595 * jnp.log10(1 + hz / 700)
def mel_to_hz(mel: float) -> float:
"""Convert mel scale to Hz."""
return 700 * (10 ** (mel / 2595) - 1)
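# Quick worked check: the 2595 / 700 constants anchor the scale so that
# 1000 Hz maps to roughly 1000 mel, and the two functions invert each other.
print(float(hz_to_mel(1000.0)))            # ~1000.0
print(float(mel_to_hz(hz_to_mel(440.0))))  # ~440.0 (round trip)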
def create_mel_filterbank(
n_fft: int,
n_mels: int,
sample_rate: int,
fmin: float = 0.0,
fmax: float | None = None
) -> jax.Array:
"""Create mel-scale filterbank.
Args:
n_fft: FFT size
n_mels: Number of mel bins
sample_rate: Audio sample rate
fmin: Minimum frequency
fmax: Maximum frequency (defaults to sample_rate / 2)
Returns:
Mel filterbank matrix (n_mels, n_freq_bins)
"""
if fmax is None:
fmax = sample_rate / 2
# Convert to mel scale
mel_min = hz_to_mel(fmin)
mel_max = hz_to_mel(fmax)
# Create mel points
mel_points = jnp.linspace(mel_min, mel_max, n_mels + 2)
hz_points = mel_to_hz(mel_points)
# Create frequency bins
n_freq_bins = n_fft // 2 + 1
freq_bins = jnp.linspace(0, sample_rate / 2, n_freq_bins)
# Create filterbank
filterbank = jnp.zeros((n_mels, n_freq_bins))
for i in range(n_mels):
left = hz_points[i]
center = hz_points[i + 1]
right = hz_points[i + 2]
# Triangular filter
for j, freq in enumerate(freq_bins):
if left <= freq <= center:
filterbank = filterbank.at[i, j].set((freq - left) / (center - left))
elif center <= freq <= right:
filterbank = filterbank.at[i, j].set((right - freq) / (right - center))
return filterbank
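# The double Python loop above is written for clarity but scales poorly with
# n_fft. A vectorized sketch of the same triangular filterbank (an
# illustrative alternative, not part of the Workshop API):
def create_mel_filterbank_vectorized(
    n_fft: int,
    n_mels: int,
    sample_rate: int,
    fmin: float = 0.0,
    fmax: float | None = None
) -> jax.Array:
    if fmax is None:
        fmax = sample_rate / 2
    mel_points = jnp.linspace(hz_to_mel(fmin), hz_to_mel(fmax), n_mels + 2)
    hz_points = mel_to_hz(mel_points)
    freqs = jnp.linspace(0, sample_rate / 2, n_fft // 2 + 1)
    left = hz_points[:-2, None]    # (n_mels, 1)
    center = hz_points[1:-1, None]
    right = hz_points[2:, None]
    # Rising and falling slopes of each triangle, clipped at zero
    up = (freqs[None, :] - left) / (center - left)
    down = (right - freqs[None, :]) / (right - center)
    return jnp.maximum(0.0, jnp.minimum(up, down))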
def compute_mel_spectrogram(
audio: jax.Array,
sample_rate: int = 16000,
n_fft: int = 1024,
hop_length: int = 256,
n_mels: int = 80
) -> jax.Array:
"""Compute mel-spectrogram.
Args:
audio: Input audio waveform
sample_rate: Sample rate
n_fft: FFT size
hop_length: Hop size
n_mels: Number of mel bins
Returns:
Mel-spectrogram (n_mels, n_frames)
"""
# Compute power spectrogram
power_spec = compute_power_spectrogram(audio, n_fft, hop_length)
# Create mel filterbank
mel_fb = create_mel_filterbank(n_fft, n_mels, sample_rate)
# Apply filterbank
mel_spec = jnp.dot(mel_fb, power_spec)
# Convert to log scale (dB)
log_mel_spec = 10 * jnp.log10(mel_spec + 1e-10)
return log_mel_spec
# Usage
audio = jnp.array([...]) # Audio waveform
mel_spec = compute_mel_spectrogram(
audio,
sample_rate=16000,
n_fft=1024,
hop_length=256,
n_mels=80
)
print(f"Mel-spectrogram shape: {mel_spec.shape}") # (80, n_frames)
Audio Augmentation
Time-Domain Augmentations
import jax
import jax.numpy as jnp
def time_shift(audio: jax.Array, key, max_shift: int = 1600):
"""Randomly shift audio in time.
Args:
audio: Input audio waveform
key: Random key
max_shift: Maximum shift in samples
Returns:
Time-shifted audio
"""
shift = jax.random.randint(key, (), -max_shift, max_shift)
# jnp.roll accepts a traced shift, so this works under jit as well
return jnp.roll(audio, shift)
def time_stretch(audio: jax.Array, key, rate_range=(0.8, 1.2)):
"""Randomly stretch or compress audio in time.
Args:
audio: Input audio waveform
key: Random key
rate_range: (min_rate, max_rate) for stretching
Returns:
Time-stretched audio
"""
rate = jax.random.uniform(key, minval=rate_range[0], maxval=rate_range[1])
# Resample on a rate-scaled grid with a fixed output length: positions
# past the end read as zero (equivalent to padding), and slowed-down
# audio is implicitly trimmed. The fixed output shape keeps this
# function jit-compatible, which the @jax.jit pipeline below relies on.
orig_length = len(audio)
positions = jnp.arange(orig_length) * rate
return jnp.interp(positions, jnp.arange(orig_length), audio, right=0.0)
def add_gaussian_noise(audio: jax.Array, key, noise_level: float = 0.005):
"""Add Gaussian noise to audio.
Args:
audio: Input audio waveform
key: Random key
noise_level: Standard deviation of noise
Returns:
Noisy audio
"""
noise = noise_level * jax.random.normal(key, audio.shape)
return audio + noise
def random_gain(audio: jax.Array, key, gain_range=(0.7, 1.3)):
"""Apply random gain (amplitude scaling).
Args:
audio: Input audio waveform
key: Random key
gain_range: (min_gain, max_gain)
Returns:
Gain-adjusted audio
"""
gain = jax.random.uniform(key, minval=gain_range[0], maxval=gain_range[1])
return audio * gain
# Usage
audio = jnp.array([...]) # Audio waveform
key = jax.random.key(0)
keys = jax.random.split(key, 4)
shifted = time_shift(audio, keys[0], max_shift=1600)
stretched = time_stretch(audio, keys[1], rate_range=(0.9, 1.1))
noisy = add_gaussian_noise(audio, keys[2], noise_level=0.005)
gained = random_gain(audio, keys[3], gain_range=(0.8, 1.2))
Frequency-Domain Augmentations
def frequency_mask(
spectrogram: jax.Array,
key,
num_masks: int = 1,
mask_param: int = 10
):
"""Apply frequency masking to spectrogram (SpecAugment).
Args:
spectrogram: Input spectrogram (freq_bins, time_frames)
key: Random key
num_masks: Number of masks to apply
mask_param: Maximum width of mask
Returns:
Masked spectrogram
"""
masked_spec = spectrogram.copy()
n_freq_bins = spectrogram.shape[0]
keys = jax.random.split(key, num_masks * 2)
for i in range(num_masks):
# Random mask width
mask_width = jax.random.randint(keys[2*i], (), 0, mask_param)
# Random starting frequency
start_freq = jax.random.randint(
keys[2*i + 1], (), 0, n_freq_bins - mask_width
)
# Apply mask
masked_spec = masked_spec.at[start_freq:start_freq+mask_width, :].set(0)
return masked_spec
def time_mask(
spectrogram: jax.Array,
key,
num_masks: int = 1,
mask_param: int = 10
):
"""Apply time masking to spectrogram (SpecAugment).
Args:
spectrogram: Input spectrogram (freq_bins, time_frames)
key: Random key
num_masks: Number of masks to apply
mask_param: Maximum width of mask
Returns:
Masked spectrogram
"""
masked_spec = spectrogram.copy()
n_time_frames = spectrogram.shape[1]
keys = jax.random.split(key, num_masks * 2)
for i in range(num_masks):
# Random mask width
mask_width = jax.random.randint(keys[2*i], (), 0, mask_param)
# Random starting time
start_time = jax.random.randint(
keys[2*i + 1], (), 0, n_time_frames - mask_width
)
# Apply mask
masked_spec = masked_spec.at[:, start_time:start_time+mask_width].set(0)
return masked_spec
def spec_augment(
spectrogram: jax.Array,
key,
freq_masks: int = 2,
time_masks: int = 2,
freq_mask_param: int = 15,
time_mask_param: int = 20
):
"""Apply SpecAugment (frequency + time masking).
Args:
spectrogram: Input spectrogram
key: Random key
freq_masks: Number of frequency masks
time_masks: Number of time masks
freq_mask_param: Max frequency mask width
time_mask_param: Max time mask width
Returns:
Augmented spectrogram
"""
keys = jax.random.split(key, 2)
# Apply frequency masking
spec = frequency_mask(spectrogram, keys[0], freq_masks, freq_mask_param)
# Apply time masking
spec = time_mask(spec, keys[1], time_masks, time_mask_param)
return spec
# Usage
mel_spec = compute_mel_spectrogram(audio)
key = jax.random.key(0)
augmented_spec = spec_augment(
mel_spec,
key,
freq_masks=2,
time_masks=2,
freq_mask_param=15,
time_mask_param=20
)
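Note that frequency_mask and time_mask update the spectrogram through Python slices with random bounds, which works eagerly but fails under jit or vmap. A hedged where-based variant of a single frequency mask with the same effect (a sketch; the time-axis version is symmetric):
def frequency_mask_jittable(
    spectrogram: jax.Array,
    key,
    mask_param: int = 10
) -> jax.Array:
    """One frequency mask built from a boolean comparison (jit/vmap safe)."""
    n_freq_bins = spectrogram.shape[0]
    width_key, start_key = jax.random.split(key)
    mask_width = jax.random.randint(width_key, (), 0, mask_param)
    start_freq = jax.random.randint(start_key, (), 0, n_freq_bins - mask_param)
    bins = jnp.arange(n_freq_bins)[:, None]
    mask = (bins >= start_freq) & (bins < start_freq + mask_width)
    return jnp.where(mask, 0.0, spectrogram)

masked = jax.jit(frequency_mask_jittable)(mel_spec, jax.random.key(1))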
Complete Augmentation Pipeline
@jax.jit
def augment_audio(audio: jax.Array, key):
"""Apply comprehensive audio augmentation pipeline.
Args:
audio: Input audio waveform
key: Random key
Returns:
Augmented audio
"""
keys = jax.random.split(key, 4)
# Time-domain augmentations
audio = time_shift(audio, keys[0], max_shift=800)
audio = add_gaussian_noise(audio, keys[1], noise_level=0.005)
audio = random_gain(audio, keys[2], gain_range=(0.8, 1.2))
audio = time_stretch(audio, keys[3], rate_range=(0.95, 1.05))
# Normalize after augmentation
audio = normalize_audio(audio)
return audio
def augment_audio_batch(audio_batch: jax.Array, key):
"""Augment batch of audio waveforms.
Args:
audio_batch: Batch of audio (N, n_samples)
key: Random key
Returns:
Augmented batch
"""
batch_size = audio_batch.shape[0]
keys = jax.random.split(key, batch_size)
# Vectorize over batch
augmented = jax.vmap(augment_audio)(audio_batch, keys)
return augmented
# Usage in training
key = jax.random.key(0)
for batch in data_loader:
key, subkey = jax.random.split(key)
augmented_audio = augment_audio_batch(batch["audio"], subkey)
# Use augmented_audio for training
Working with Different Sample Rates
Common Sample Rates
# Telephone quality (8 kHz)
telephone_config = AudioModalityConfig(
representation=AudioRepresentation.RAW_WAVEFORM,
sample_rate=8000,
duration=2.0
)
# Standard speech (16 kHz)
speech_config = AudioModalityConfig(
representation=AudioRepresentation.RAW_WAVEFORM,
sample_rate=16000,
duration=2.0
)
# High-quality speech (22.05 kHz)
hq_speech_config = AudioModalityConfig(
representation=AudioRepresentation.RAW_WAVEFORM,
sample_rate=22050,
duration=2.0
)
# CD quality (44.1 kHz)
cd_config = AudioModalityConfig(
representation=AudioRepresentation.RAW_WAVEFORM,
sample_rate=44100,
duration=2.0
)
# Studio quality (48 kHz)
studio_config = AudioModalityConfig(
representation=AudioRepresentation.RAW_WAVEFORM,
sample_rate=48000,
duration=2.0
)
Sample Rate Conversion
def convert_sample_rate(
audio: jax.Array,
orig_sr: int,
target_sr: int
) -> jax.Array:
"""Convert audio to different sample rate.
Args:
audio: Input audio
orig_sr: Original sample rate
target_sr: Target sample rate
Returns:
Resampled audio
"""
return resample_audio(audio, orig_sr, target_sr)
# Usage
audio_16k = jnp.array([...]) # Audio at 16 kHz
# Upsample to 22.05 kHz
audio_22k = convert_sample_rate(audio_16k, 16000, 22050)
# Downsample to 8 kHz
audio_8k = convert_sample_rate(audio_16k, 16000, 8000)
print(f"16 kHz: {len(audio_16k)} samples")
print(f"22.05 kHz: {len(audio_22k)} samples")
print(f"8 kHz: {len(audio_8k)} samples")
Complete Examples
Example 1: Audio Generation Dataset
import jax
import jax.numpy as jnp
from flax import nnx
from workshop.generative_models.modalities.audio import (
AudioModality,
AudioModalityConfig,
AudioRepresentation
)
from workshop.generative_models.modalities.audio.datasets import (
SyntheticAudioDataset
)
# Setup
rngs = nnx.Rngs(0)
# Configure
audio_config = AudioModalityConfig(
representation=AudioRepresentation.RAW_WAVEFORM,
sample_rate=16000,
duration=1.0,
normalize=True
)
modality = AudioModality(config=audio_config, rngs=rngs)
# Create datasets
train_dataset = SyntheticAudioDataset(
config=audio_config,
n_samples=10000,
audio_types=["sine", "noise", "chirp"]
)
val_dataset = SyntheticAudioDataset(
config=audio_config,
n_samples=1000,
audio_types=["sine", "noise", "chirp"]
)
# Training loop
batch_size = 32
num_epochs = 10
key = jax.random.key(42)
for epoch in range(num_epochs):
num_batches = len(train_dataset) // batch_size
for i in range(num_batches):
# Get batch
batch_samples = [train_dataset[j] for j in range(i*batch_size, (i+1)*batch_size)]
batch = train_dataset.collate_fn(batch_samples)
audio_batch = batch["audio"]
# Apply augmentation
key, subkey = jax.random.split(key)
augmented = augment_audio_batch(audio_batch, subkey)
# Training step
# loss = train_step(model, augmented)
# Validation (no augmentation)
val_batches = len(val_dataset) // batch_size
for i in range(val_batches):
val_samples = [val_dataset[j] for j in range(i*batch_size, (i+1)*batch_size)]
val_batch = val_dataset.collate_fn(val_samples)
# val_loss = validate_step(model, val_batch["audio"])
print(f"Epoch {epoch + 1}/{num_epochs} complete")
Example 2: Mel-Spectrogram Training
# Configure for mel-spectrograms
mel_config = AudioModalityConfig(
representation=AudioRepresentation.MEL_SPECTROGRAM,
sample_rate=16000,
n_mel_channels=80,
hop_length=256,
n_fft=1024,
duration=2.0,
normalize=True
)
mel_modality = AudioModality(config=mel_config, rngs=rngs)
# Create dataset
mel_dataset = SyntheticAudioDataset(
config=mel_config,
n_samples=5000,
audio_types=["sine", "chirp"]
)
# Training with spectrograms
for epoch in range(num_epochs):
for i in range(len(mel_dataset) // batch_size):
# Get audio batch
batch_samples = [mel_dataset[j] for j in range(i*batch_size, (i+1)*batch_size)]
batch = mel_dataset.collate_fn(batch_samples)
audio_batch = batch["audio"]
# Compute mel-spectrograms
mel_specs = []
for audio in audio_batch:
mel_spec = compute_mel_spectrogram(
audio,
sample_rate=16000,
n_fft=1024,
hop_length=256,
n_mels=80
)
mel_specs.append(mel_spec)
mel_batch = jnp.stack(mel_specs)
# Apply SpecAugment
key, subkey = jax.random.split(key)
# spec_augment slices with traced bounds, so it cannot be vmapped;
# apply it per sample (or use the where-based variant shown earlier)
aug_keys = jax.random.split(subkey, batch_size)
augmented_specs = jnp.stack(
    [spec_augment(spec, k) for spec, k in zip(mel_batch, aug_keys)]
)
# Training step
# loss = train_step(model, augmented_specs)
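The per-sample loop above recomputes nothing, but create_mel_spectrogram rebuilds the mel filterbank with Python loops on every call. A hedged restructuring: build the filterbank once outside the loop, then vmap the purely array-based part over the batch (mel_from_audio is an illustrative helper):
# Build the filterbank once, outside the training loop
mel_fb = create_mel_filterbank(n_fft=1024, n_mels=80, sample_rate=16000)

def mel_from_audio(audio: jax.Array) -> jax.Array:
    power_spec = compute_power_spectrogram(audio, n_fft=1024, hop_length=256)
    return 10 * jnp.log10(jnp.dot(mel_fb, power_spec) + 1e-10)

mel_batch = jax.vmap(mel_from_audio)(audio_batch)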
Example 3: Custom Audio Dataset
from typing import Iterator
from workshop.generative_models.modalities.base import BaseDataset
class CustomAudioDataset(BaseDataset):
"""Custom audio dataset loading from files."""
def __init__(
self,
config: AudioModalityConfig,
audio_paths: list[str],
split: str = "train",
*,
rngs: nnx.Rngs,
):
super().__init__(config, split, rngs=rngs)
self.audio_paths = audio_paths
self.audio_samples = self._load_audio_files()
def _load_audio_files(self):
"""Load audio files."""
samples = []
for path in self.audio_paths:
# In practice: use librosa, soundfile, etc.
# For demo, generate synthetic
audio = jax.random.uniform(
jax.random.key(hash(path)),
(16000,)
)
samples.append(audio)
return samples
def __len__(self) -> int:
return len(self.audio_samples)
def __iter__(self) -> Iterator[dict[str, jax.Array]]:
for i, audio in enumerate(self.audio_samples):
yield {
"audio": audio,
"index": jnp.array(i),
"path": self.audio_paths[i]
}
def get_batch(self, batch_size: int) -> dict[str, jax.Array]:
key = self.rngs.sample() if "sample" in self.rngs else jax.random.key(0)
indices = jax.random.randint(key, (batch_size,), 0, len(self))
batch_audio = [self.audio_samples[int(idx)] for idx in indices]
batch_paths = [self.audio_paths[int(idx)] for idx in indices]
return {
"audio": jnp.stack(batch_audio),
"indices": indices,
"paths": batch_paths
}
# Usage
audio_paths = ["/path/to/audio1.wav", "/path/to/audio2.wav", ...]
custom_dataset = CustomAudioDataset(
config=audio_config,
audio_paths=audio_paths,
rngs=rngs
)
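The _load_audio_files stub above generates placeholder audio. A hedged sketch of loading real files (assumes the third-party soundfile package is installed; load_wav is an illustrative helper, not part of Workshop):
import soundfile as sf  # third-party: pip install soundfile

def load_wav(path: str, target_sr: int = 16000) -> jax.Array:
    """Read an audio file, mix to mono, resample, and normalize."""
    data, sr = sf.read(path, dtype="float32")
    if data.ndim > 1:
        data = data.mean(axis=1)  # mix multi-channel down to mono
    audio = jnp.asarray(data)
    if sr != target_sr:
        audio = resample_audio(audio, orig_sr=sr, target_sr=target_sr)
    return normalize_audio(audio)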
Best Practices
DO
Audio Loading
- Use appropriate sample rate for your task
- Normalize audio to consistent amplitude
- Trim or pad to consistent duration for batching
- Validate audio shapes before training
- Cache processed audio when possible
- Use synthetic datasets for testing pipelines
Spectrograms
- Choose appropriate n_fft and hop_length for task
- Use mel-spectrograms for perceptual tasks
- Apply log scaling to spectrograms
- Consider phase information for reconstruction tasks
- Use appropriate number of mel bins (typically 40-128)
Augmentation
- Apply augmentation only during training
- Balance augmentation strength with audio quality
- Use SpecAugment for spectrogram-based models
- Test augmentations by listening to samples
- Use JIT compilation for performance
DON'T
Common Mistakes
- Mix different sample rates in same batch
- Forget to normalize audio amplitudes
- Apply augmentation during validation
- Use non-JAX operations in data pipeline
- Load full audio files if working with clips
- Ignore clipping (values > 1.0 or < -1.0)
Performance Issues
- Load audio from disk during training
- Compute spectrograms on-the-fly every epoch
- Use Python loops for audio processing
- Keep multiple copies of audio in memory
- Use very long audio clips on limited GPU memory
Quality Issues
- Over-augment audio (too much distortion)
- Use inappropriate sample rates
- Mix time-domain and frequency-domain representations
- Ignore audio phase for waveform reconstruction
- Apply excessive noise
Summary
This guide covered:
- Audio representations - Waveforms, mel-spectrograms, STFT
- Audio datasets - Synthetic datasets with various audio types
- Preprocessing - Normalization, resampling, duration adjustment
- Spectrograms - Computing STFT, magnitude, power, and mel-spectrograms
- Augmentation - Time-domain and frequency-domain techniques
- Sample rates - Working with different sample rates
- Complete examples - Training pipelines and custom datasets
- Best practices - DOs and DON'Ts for audio data
Next Steps
- Working with multiple modalities and aligned multi-modal datasets
- Deep dive into image datasets, preprocessing, and augmentation
- Learn about text tokenization, vocabulary management, and sequences
- Complete API documentation for all dataset classes and functions