Geometric Model Loss Functions Demo¤
A comprehensive demonstration of loss functions for geometric models, including point clouds, meshes, and voxel grids.
Files¤
- Python Script: geometric_losses_demo.py
- Jupyter Notebook: geometric_losses_demo.ipynb
Quick Start¤
# Run the Python script
python examples/generative_models/geometric/geometric_losses_demo.py
# Or use Jupyter notebook
jupyter notebook examples/generative_models/geometric/geometric_losses_demo.ipynb
Overview¤
This example provides a comprehensive tour of loss functions used for different 3D geometric representations. Understanding these losses is crucial for training effective generative models for 3D data.
Learning Objectives¤
- Understand permutation-invariant losses for point clouds
- Learn composite loss functions for meshes
- Explore specialized losses for voxel grids
- See how to configure and balance different loss components
- Compare different loss types for the same representation
Prerequisites¤
- Basic understanding of 3D representations (point clouds, meshes, voxels)
- Familiarity with loss functions in machine learning
- Knowledge of JAX and Artifex's configuration system
Background: 3D Representations and Their Losses¤
Why Geometric Losses Matter¤
Standard image losses (MSE, L1) don't work well for 3D data because:
- Point clouds are unordered: Permuting points shouldn't change the loss
- Meshes have topology: Need to preserve connectivity and smoothness
- Voxels are sparse: Most voxels are empty, causing class imbalance
The Three Representations¤
Point Clouds¤
Unordered sets of 3D points: \(\{(x_i, y_i, z_i)\}_{i=1}^N\)
- Compact representation
- Permutation invariance required
- Use Chamfer Distance or Earth Mover's Distance
Meshes¤
Vertices connected by edges and faces
- Explicit topology
- Need vertex, normal, and edge losses
- Balancing multiple objectives
Voxels¤
Regular 3D grids with occupancy values
- Like 3D images
- Sparse (mostly empty)
- Use BCE, Focal, or Dice loss
Code Walkthrough¤
1. Point Cloud Losses¤
Chamfer Distance¤
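The Chamfer Distance is the workhorse of point cloud generation. It measures how well two point sets match by matching every point to its nearest neighbor in the other set, in both directions. A common form (normalization conventions vary) is:
\[d_{CD}(X, Y) = \frac{1}{|X|}\sum_{x \in X} \min_{y \in Y} \|x - y\|_2^2 + \frac{1}{|Y|}\sum_{y \in Y} \min_{x \in X} \|x - y\|_2^2\]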
Key properties:
- Permutation invariant
- Fast to compute: O(N²) by brute force, roughly O(N log N) with spatial indexing such as k-d trees
- Good for most applications
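The nearest-neighbor matching described above can be sketched in a few lines. This is a minimal brute-force illustration written with NumPy (the same array expressions carry over to `jax.numpy`); the function name `chamfer_distance` and the choice of squared distances averaged per direction are assumptions made for this sketch, not necessarily what Artifex computes internally.

```python
import numpy as np


def chamfer_distance(x: np.ndarray, y: np.ndarray) -> float:
    """Symmetric Chamfer distance between point sets x (N, 3) and y (M, 3)."""
    # Pairwise squared Euclidean distances, shape (N, M).
    diff = x[:, None, :] - y[None, :, :]
    sq_dist = np.sum(diff**2, axis=-1)
    # Average nearest-neighbor distance in each direction.
    x_to_y = sq_dist.min(axis=1).mean()
    y_to_x = sq_dist.min(axis=0).mean()
    return float(x_to_y + y_to_x)


pred = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
target = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 1.0]])
print(chamfer_distance(pred, target))  # 1.0
```

The two sets do not need the same number of points, because each direction only looks for the closest point in the other set.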
from artifex.generative_models.core.configuration import (
PointCloudConfig,
PointCloudNetworkConfig,
)
network_config = PointCloudNetworkConfig(
name="chamfer_network",
hidden_dims=(64,), # Tuple for frozen dataclass
activation="gelu",
embed_dim=64,
num_heads=4,
num_layers=2,
dropout_rate=0.1,
)
chamfer_config = PointCloudConfig(
name="chamfer_point_cloud",
network=network_config,
num_points=125,
loss_type="chamfer", # Chamfer distance loss
dropout_rate=0.1,
)
Earth Mover's Distance (EMD)¤
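EMD finds the optimal transport plan between point sets. It is more accurate than Chamfer Distance but slower to compute. For two equally sized sets it is commonly written as:
\[d_{EMD}(X, Y) = \min_{\phi: X \to Y} \sum_{x \in X} \|x - \phi(x)\|_2\]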
Where \(\phi\) is a bijection between X and Y.
When to use:
- Quality is more important than speed
- Small point clouds (<1000 points)
- Fine geometric details matter
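To make the role of the bijection concrete, here is a deliberately naive sketch that tries every possible one-to-one matching and keeps the cheapest. The helper name `emd_brute_force` is hypothetical, and the factorial cost means it is only usable for a handful of points; a practical implementation would rely on an assignment solver (for example `scipy.optimize.linear_sum_assignment`) or an approximate scheme.

```python
from itertools import permutations

import numpy as np


def emd_brute_force(x: np.ndarray, y: np.ndarray) -> float:
    """Exact EMD for equal-size sets by trying every bijection (tiny N only)."""
    assert x.shape == y.shape, "a bijection needs equally sized sets"
    n = x.shape[0]
    # Pairwise Euclidean distances, shape (N, N).
    dist = np.linalg.norm(x[:, None, :] - y[None, :, :], axis=-1)
    rows = np.arange(n)
    # The cost of a bijection is the summed distance from x[i] to y[perm[i]].
    return float(
        min(dist[rows, list(perm)].sum() for perm in permutations(range(n)))
    )


x = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [2.0, 0.0, 0.0]])
# The same points, reordered and lifted by 0.5 along z.
y = x[[2, 0, 1]] + np.array([0.0, 0.0, 0.5])
print(emd_brute_force(x, y))  # 1.5
```

Each point is matched to its lifted copy, so the total cost is three moves of length 0.5 regardless of how `y` is ordered.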
# Same network config, different loss type
earth_mover_config = PointCloudConfig(
name="earth_mover_point_cloud",
network=network_config,
num_points=125,
loss_type="earth_mover", # EMD loss
dropout_rate=0.1,
)
2. Mesh Losses¤
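Meshes require balancing multiple geometric properties, combined as a weighted sum:
\[\mathcal{L}_{\text{mesh}} = w_v \mathcal{L}_{\text{vertex}} + w_n \mathcal{L}_{\text{normal}} + w_e \mathcal{L}_{\text{edge}}\]
- Vertex Loss: L2 distance between vertex positions
- Normal Loss: Encourages surface smoothness by matching surface normals
- Edge Loss: Preserves edge lengths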
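One way the three terms could be combined is sketched below. It assumes triangle meshes where prediction and target share the same face list, uses cosine similarity of face normals for the normal term, and uses squared differences of edge lengths for the edge term. These definitions and the names `face_normals`, `edge_lengths`, and `mesh_loss` are illustrative assumptions; Artifex's own mesh loss may define each term differently.

```python
import numpy as np


def face_normals(vertices: np.ndarray, faces: np.ndarray) -> np.ndarray:
    """Unit normal of each triangle; faces holds vertex indices, shape (F, 3)."""
    a, b, c = (vertices[faces[:, i]] for i in range(3))
    normals = np.cross(b - a, c - a)
    return normals / (np.linalg.norm(normals, axis=-1, keepdims=True) + 1e-8)


def edge_lengths(vertices: np.ndarray, faces: np.ndarray) -> np.ndarray:
    """Lengths of the three edges of each triangle, shape (F, 3)."""
    corners = vertices[faces]  # (F, 3, 3)
    return np.linalg.norm(corners - np.roll(corners, -1, axis=1), axis=-1)


def mesh_loss(pred, target, faces, vertex_w=1.0, normal_w=1.0, edge_w=1.0):
    """Weighted sum of vertex, normal, and edge terms (shared topology assumed)."""
    vertex = np.mean(np.sum((pred - target) ** 2, axis=-1))
    cosine = np.sum(face_normals(pred, faces) * face_normals(target, faces), axis=-1)
    normal = np.mean(1.0 - cosine)
    edge = np.mean((edge_lengths(pred, faces) - edge_lengths(target, faces)) ** 2)
    total = vertex_w * vertex + normal_w * normal + edge_w * edge
    return {
        "total_loss": float(total),
        "vertex": float(vertex),
        "normal": float(normal),
        "edge": float(edge),
    }


target = np.array(
    [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
)
faces = np.array([[0, 1, 2], [0, 1, 3]])
pred = 2.0 * target  # uniformly scaled copy: same normals, longer edges
print(mesh_loss(pred, target, faces, vertex_w=0.5, normal_w=1.0, edge_w=0.1))
```

In this example the normal term is essentially zero while the vertex and edge terms are not, which shows why the components need separate weights: each one reacts to a different kind of error.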
Configuring Weights¤
from artifex.generative_models.core.configuration import (
MeshConfig,
MeshNetworkConfig,
)
mesh_network = MeshNetworkConfig(
name="mesh_network",
hidden_dims=(128, 64), # Tuple for frozen dataclass
activation="gelu",
)
# Smooth surfaces (e.g., CAD models)
normal_config = MeshConfig(
name="smooth_mesh",
network=mesh_network,
num_vertices=512,
vertex_loss_weight=0.5, # Reduce vertex constraint
normal_loss_weight=1.0, # Emphasize smoothness
edge_loss_weight=0.1, # Light edge preservation
)
# Sharp edges (e.g., furniture)
edge_config = MeshConfig(
name="sharp_mesh",
network=mesh_network,
num_vertices=512,
vertex_loss_weight=0.5,
normal_loss_weight=0.1, # Less smoothing
edge_loss_weight=1.0, # Strong edge preservation
)
3. Voxel Losses¤
Voxel grids can reuse image-style per-element losses, but because occupancy is usually sparse, some choices work better than others:
Binary Cross-Entropy (BCE)¤
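Standard loss for binary voxels:
\[\mathcal{L}_{\text{BCE}} = -\frac{1}{N}\sum_{i=1}^{N}\left[y_i \log \hat{y}_i + (1 - y_i)\log(1 - \hat{y}_i)\right]\]
where \(y_i \in \{0, 1\}\) is the target occupancy of voxel \(i\) and \(\hat{y}_i\) is the predicted occupancy probability.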
Best for:
- Balanced datasets (roughly 50% occupied voxels)
- Dense 3D shapes
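As a quick illustration of per-voxel BCE on a balanced grid, the sketch below averages the binary cross-entropy over all voxels. The function name `bce_loss` and the clipping constant `eps` are assumptions of this example, and it expects probabilities rather than logits.

```python
import numpy as np


def bce_loss(pred: np.ndarray, target: np.ndarray, eps: float = 1e-7) -> float:
    """Mean binary cross-entropy over all voxels; pred holds probabilities."""
    p = np.clip(pred, eps, 1.0 - eps)  # avoid log(0)
    per_voxel = -(target * np.log(p) + (1.0 - target) * np.log(1.0 - p))
    return float(np.mean(per_voxel))


target = np.zeros((4, 4, 4))
target[:2] = 1.0  # half of the grid is occupied
pred = np.full_like(target, 0.5)  # maximally uncertain prediction
print(bce_loss(pred, target))  # ln(2), about 0.693
```

An uninformed prediction of 0.5 everywhere scores ln 2, which is a useful baseline when reading BCE values from a training run.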
Focal Loss¤
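Down-weights easy examples, focuses on hard ones. In its basic form (without the optional class-balancing factor \(\alpha\)):
\[\mathcal{L}_{\text{focal}} = -\frac{1}{N}\sum_{i=1}^{N}(1 - p_t)^{\gamma}\log(p_t)\]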
Where \(p_t = \hat{y}_i\) if \(y_i=1\), else \(1-\hat{y}_i\).
Best for:
- Imbalanced data (sparse objects)
- \(\gamma=2.0\) is typical
- Higher \(\gamma\) → more focus on hard examples
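The effect of \(\gamma\) is easiest to see on a mix of easy and hard voxels. The sketch below implements the basic focal loss without a class-balancing factor; the name `focal_loss` and the toy inputs are assumptions for illustration only.

```python
import numpy as np


def focal_loss(pred, target, gamma=2.0, eps=1e-7):
    """Mean focal loss; pred holds probabilities, target holds 0/1 labels."""
    p = np.clip(pred, eps, 1.0 - eps)
    # p_t is the probability assigned to the correct class of each voxel.
    p_t = np.where(target == 1, p, 1.0 - p)
    return float(np.mean(-((1.0 - p_t) ** gamma) * np.log(p_t)))


target = np.array([1.0, 1.0, 0.0, 0.0])
pred = np.array([0.9, 0.1, 0.1, 0.9])  # two easy voxels, two hard ones

for gamma in (0.0, 0.5, 1.0, 2.0, 5.0):
    print(f"gamma={gamma}: {focal_loss(pred, target, gamma):.4f}")
```

With \(\gamma = 0\) the expression reduces to plain BCE. At \(\gamma = 2\) an easy voxel with \(p_t = 0.9\) is scaled by \(0.1^2 = 0.01\), while a hard voxel with \(p_t = 0.1\) keeps a factor of \(0.9^2 = 0.81\).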
Dice Loss¤
Directly optimizes overlap (similar to IoU):
\[\mathcal{L}_{\text{Dice}} = 1 - \frac{2\sum_i y_i \hat{y}_i}{\sum_i y_i + \sum_i \hat{y}_i}\]
Best for:
- Segmentation-like tasks
- Maximizing shape overlap
- Handles class imbalance well
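A soft Dice loss can be sketched as follows. The small `eps` term is a common smoothing trick that keeps the ratio defined when both grids are empty; it and the function name `dice_loss` are assumptions of this example rather than a documented Artifex detail.

```python
import numpy as np


def dice_loss(pred, target, eps=1e-6):
    """Soft Dice loss: 0 for perfect overlap, close to 1 for no overlap."""
    intersection = np.sum(pred * target)
    denom = np.sum(pred) + np.sum(target)
    return float(1.0 - (2.0 * intersection + eps) / (denom + eps))


target = np.zeros((8, 8, 8))
target[2:4, 2:4, 2:4] = 1.0  # 8 occupied voxels out of 512

perfect = target.copy()
shifted = np.roll(target, 4, axis=0)  # same shape moved so nothing overlaps

print(dice_loss(perfect, target))
print(dice_loss(shifted, target))
```

Only occupied voxels contribute to the numerator and denominator, so the large empty region has no influence on the score.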
Comparison¤
from artifex.generative_models.core.configuration import (
VoxelConfig,
VoxelNetworkConfig,
)
voxel_network = VoxelNetworkConfig(
name="voxel_network",
hidden_dims=(64, 32), # Required base config
activation="relu",
base_channels=32, # Base number of 3D CNN channels
num_layers=4, # Number of 3D convolutional layers
)
# Dense shapes → BCE
bce_config = VoxelConfig(
name="dense_voxel",
network=voxel_network,
resolution=16,
loss_type="bce",
)
# Sparse shapes → Focal
focal_config = VoxelConfig(
name="sparse_voxel",
network=voxel_network,
resolution=16,
loss_type="focal",
focal_gamma=2.0, # Adjust based on sparsity
)
# Overlap optimization → Dice
dice_config = VoxelConfig(
name="overlap_voxel",
network=voxel_network,
resolution=16,
loss_type="dice",
)
Expected Output¤
===== Point Cloud Loss Functions Demo =====
Chamfer distance loss: {'total_loss': 2.92, 'mse_loss': 2.92}
Earth Mover distance loss: {'total_loss': 3.61, 'mse_loss': 3.61}
===== Mesh Loss Functions Demo =====
Default model vertex weight: 1.0
Default model normal weight: 1.0
Default model edge weight: 1.0
Normal-focused model vertex weight: 0.5
Normal-focused model normal weight: 1.0
Normal-focused model edge weight: 0.1
===== Voxel Loss Functions Demo =====
Binary cross-entropy loss: {'total_loss': 0.68, ...}
Focal loss (gamma=2.0): {'total_loss': 0.15, ...}
Dice loss: {'total_loss': 0.42, ...}
Loss function demos completed!
Key Concepts¤
Permutation Invariance¤
Point cloud losses must be invariant to point ordering:
# These should have the same loss
points_A = [[0, 0, 0], [1, 1, 1], [2, 2, 2]]
points_B = [[2, 2, 2], [0, 0, 0], [1, 1, 1]] # Same points, different order
loss(points_A, target) == loss(points_B, target) # Must be true
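The check above can be run directly. The sketch below compares an index-wise MSE with a brute-force Chamfer distance on the same reordered points; both helper functions are illustrative stand-ins written for this example, not Artifex functions.

```python
import numpy as np


def mse(pred, target):
    """Index-wise mean squared error: compares pred[i] only with target[i]."""
    return float(np.mean((pred - target) ** 2))


def chamfer(pred, target):
    """Symmetric Chamfer distance using squared nearest-neighbor distances."""
    sq_dist = np.sum((pred[:, None, :] - target[None, :, :]) ** 2, axis=-1)
    return float(sq_dist.min(axis=1).mean() + sq_dist.min(axis=0).mean())


points_a = np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [2.0, 2.0, 2.0]])
points_b = points_a[[2, 0, 1]]  # same points, different order
target = points_a + 0.1

print(mse(points_a, target), mse(points_b, target))          # values differ
print(chamfer(points_a, target), chamfer(points_b, target))  # values match
```

MSE penalizes the reordering even though the shape is unchanged, which is exactly the failure mode a permutation-invariant loss avoids.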
Loss Component Balancing¤
For composite losses (meshes), balance is key:
- Start with equal weights (1.0, 1.0, 1.0)
- Identify the most important property (smoothness vs sharp edges)
- Increase weight for that component
- Reduce others proportionally
- Validate on test shapes
Class Imbalance in Voxels¤
Voxel grids are typically 90-99% empty:
# Sparse object (5% occupied)
occupancy_ratio = 0.05
# BCE: Treats all voxels equally → biased toward predicting empty
# Focal (γ=2): Down-weights easy empty voxels → training focuses on hard voxels
# Dice: Measures overlap with occupied voxels → much less sensitive to sparsity
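A small numerical experiment makes the bias visible. The sketch below scores a degenerate prediction that marks every voxel as almost certainly empty on a grid with 5% occupancy; the grid size and the 0.01 probability are arbitrary choices made for illustration.

```python
import numpy as np

# 10 x 10 x 10 grid with 5% occupied voxels.
target = np.zeros(1000)
target[:50] = 1.0
target = target.reshape(10, 10, 10)

# A degenerate model that predicts "almost certainly empty" everywhere.
pred = np.full_like(target, 0.01)

bce = float(np.mean(-(target * np.log(pred) + (1 - target) * np.log(1 - pred))))
dice = float(1.0 - 2.0 * np.sum(pred * target) / (np.sum(pred) + np.sum(target)))

print(f"BCE:  {bce:.3f}")   # low, even though every occupied voxel is missed
print(f"Dice: {dice:.3f}")  # close to 1, the worst possible value
```

BCE comes out at about 0.24 because the 95% of correctly predicted empty voxels dominate the average, whereas Dice comes out at about 0.98 and flags the prediction as nearly useless.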
Experiments to Try¤
- Compare Chamfer vs EMD: Generate the same point cloud with both losses and compare quality and speed.
- Mesh Weight Tuning: Try different weight combinations for different mesh types (organic vs geometric).
- Voxel Sparsity Study: Compare BCE, Focal, and Dice on grids with 1%, 10%, and 50% occupancy.
- Focal Gamma Sweep: Test \(\gamma \in \{0.5, 1.0, 2.0, 5.0\}\) on sparse voxels.
- Visualization: Plot generated shapes with different losses to see visual differences.
Next Steps¤
Explore related examples to deepen your understanding:
- Geometric Models Overview: Learn about the three geometric representations and when to use each.
- Point Cloud Generation: Generate and visualize 3D point clouds with transformers.
- Geometric Benchmarks: Evaluate geometric models with specialized metrics.
- Protein Modeling: Apply geometric models to protein structure prediction.
Troubleshooting¤
High Chamfer Distance¤
Problem: Chamfer loss is unexpectedly high.
Solutions:
- Check point cloud normalization (scale to [-1, 1])
- Verify number of points matches between pred and target
- Ensure points are in same coordinate system
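The normalization step can be sketched as below. The helper `normalize_points` is hypothetical; it centers a cloud on its bounding-box midpoint and divides by half the longest side so every coordinate lands in [-1, 1]. It also returns the transform so the same center and scale can be reused for the prediction, which keeps both clouds in one coordinate system.

```python
import numpy as np


def normalize_points(points, center=None, scale=None):
    """Map a point cloud into [-1, 1]^3 using its bounding box.

    Pass the center and scale computed from the target to apply the
    same transform to the prediction.
    """
    if center is None:
        center = (points.min(axis=0) + points.max(axis=0)) / 2.0
    if scale is None:
        extent = points.max(axis=0) - points.min(axis=0)
        scale = max(float(np.max(extent)) / 2.0, 1e-12)  # guard against zero size
    return (points - center) / scale, center, scale


target = np.array([[0.0, 0.0, 0.0], [10.0, 4.0, 2.0]])
pred = np.array([[1.0, 0.0, 0.0], [9.0, 4.0, 2.0]])

target_n, center, scale = normalize_points(target)
pred_n, _, _ = normalize_points(pred, center=center, scale=scale)
print(target_n)
print(pred_n)
```

Normalizing each cloud independently would hide real offsets between prediction and target, so the sketch reuses the target's transform on purpose.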
Mesh Loss Imbalance¤
Problem: One loss component dominates others.
Solutions:
- Rescale each loss component to a comparable magnitude (for example, divide by its value early in training) before weighting
- Use relative weights (sum to 1.0)
- Monitor individual losses during training
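A rough sketch of these ideas follows: each component is divided by a reference magnitude and the weights are renormalized to sum to 1.0. The function `balanced_total`, the use of first-batch values as the reference, and the example numbers are all assumptions for illustration.

```python
def balanced_total(losses, reference, weights, eps=1e-8):
    """Combine loss components after rescaling each by a reference value."""
    weight_sum = sum(weights.values())
    total = 0.0
    for name, value in losses.items():
        relative_weight = weights[name] / weight_sum  # weights sum to 1.0
        total += relative_weight * value / (reference[name] + eps)
    return total


# Raw components on very different scales (made-up numbers).
reference = {"vertex": 0.02, "normal": 0.5, "edge": 10.0}  # e.g. first batch
current = {"vertex": 0.01, "normal": 0.5, "edge": 5.0}
weights = {"vertex": 1.0, "normal": 1.0, "edge": 1.0}

print(balanced_total(reference, reference, weights))  # close to 1.0
print(balanced_total(current, reference, weights))    # close to 2/3
```

After rescaling, a component that halves contributes the same relative improvement no matter how large its raw value is, so no single term dominates purely because of its units.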
Voxel Loss Not Decreasing¤
Problem: Loss plateaus early in training.
Solutions:
- Switch from BCE to Focal for sparse grids
- Adjust focal gamma (try 2.0 → 3.0)
- Check for label imbalance (>95% empty → use Dice)
Out of Memory¤
Problem: Voxel models run out of GPU memory.
Solutions:
- Reduce voxel resolution (32³ → 16³ cuts the number of voxels by 8×)
- Reduce batch size
- Use gradient checkpointing
- Consider point cloud representation instead
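A back-of-the-envelope estimate shows why resolution is the first thing to reduce. The helper below is hypothetical and counts only a single dense float32 tensor; real usage is higher because of intermediate activations, gradients, and optimizer state.

```python
def voxel_tensor_bytes(resolution, channels=1, batch_size=1, bytes_per_value=4):
    """Size in bytes of one dense (batch, R, R, R, channels) tensor."""
    return batch_size * channels * resolution**3 * bytes_per_value


for resolution in (16, 32, 64):
    size = voxel_tensor_bytes(resolution, channels=32, batch_size=8)
    print(f"{resolution}^3: {size / 2**20:.0f} MiB")
```

Memory grows with the cube of the resolution, so halving it saves a factor of 8, while halving the batch size saves only a factor of 2.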
Additional Resources¤
Citation¤
If you use these loss functions in your research, please cite:
@software{artifex2025,
title={Artifex: Modular Generative Modeling Library},
author={Artifex Contributors},
year={2025},
url={https://github.com/avitai/artifex}
}
References¤
- Chamfer Distance: Fan et al., "A Point Set Generation Network for 3D Object Reconstruction from a Single Image", CVPR 2017
- Earth Mover's Distance: Rubner et al., "The Earth Mover's Distance as a Metric for Image Retrieval", IJCV 2000
- Focal Loss: Lin et al., "Focal Loss for Dense Object Detection", ICCV 2017
- Dice Loss: Milletari et al., "V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation", 3DV 2016