
Geometric Models Demo¤

Beginner | ⚡ Runtime: 5-10 seconds | 📓 Dual format (Python script and Jupyter notebook)

A quick reference guide demonstrating how to configure and instantiate three types of geometric models in Workshop: point clouds, meshes, and voxels.

Files¤

Quick Start¤

# Enter the cloned repository and activate its environment
cd workshop
source activate.sh

# Run Python script
python examples/generative_models/geometric/geometric_models_demo.py

# Or use Jupyter notebook
jupyter notebook examples/generative_models/geometric/geometric_models_demo.ipynb

Overview¤

This example provides a concise demonstration of Workshop's three main geometric representations: point clouds, meshes, and voxels.

Learning Objectives¤

  • Understand point cloud, mesh, and voxel representations
  • Configure models using ModelConfiguration
  • Use the unified factory pattern with create_model()
  • Understand model-specific parameters for each geometry type

Prerequisites¤

  • Basic understanding of 3D geometric representations
  • Familiarity with JAX and Flax NNX
  • Workshop installed and activated

Geometric Representations Explained¤

1. Point Clouds¤

Unordered sets of 3D points - flexible, permutation-invariant
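Permutation invariance means the model's output must not depend on the order in which the points are listed. A common way to achieve this is a PointNet-style encoder: apply the same weights to every point, then pool over points with a symmetric function such as max. The JAX sketch below illustrates that idea with a single shared linear layer; the weights and sizes are illustrative, and this is not a description of how PointCloudModel is implemented.

```python
import jax
import jax.numpy as jnp


def init_params(key, in_dim=3, embed_dim=128):
    """Random weights for one shared per-point linear layer (illustrative only)."""
    w_key, b_key = jax.random.split(key)
    w = jax.random.normal(w_key, (in_dim, embed_dim)) * 0.1
    b = jax.random.normal(b_key, (embed_dim,)) * 0.01
    return w, b


def encode(params, points):
    """Map an (N, 3) point set to a single (embed_dim,) feature vector.

    The same weights are applied to every point, and the max over the point
    axis ignores point order, so the result is permutation-invariant.
    """
    w, b = params
    per_point = jax.nn.relu(points @ w + b)  # (N, embed_dim)
    return jnp.max(per_point, axis=0)        # symmetric pooling over points


params = init_params(jax.random.PRNGKey(0))
points = jax.random.normal(jax.random.PRNGKey(1), (512, 3))
perm = jax.random.permutation(jax.random.PRNGKey(2), 512)

feat = encode(params, points)
feat_shuffled = encode(params, points[perm])  # same points, different order
print(jnp.allclose(feat, feat_shuffled))  # True
```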

point_cloud_config = ModelConfiguration(
    name="demo_point_cloud",
    model_class="workshop.generative_models.models.geometric.PointCloudModel",
    input_dim=(512, 3),
    output_dim=(512, 3),
    parameters={
        "num_points": 512,
        "embed_dim": 128,
        "num_layers": 4,
        "loss_type": "chamfer",
    },
)
point_cloud_model = create_model(point_cloud_config, rngs=rngs)

Use cases:

  • LiDAR data processing
  • 3D object detection
  • Molecular structure modeling (proteins, molecules)

Key parameters:

  • num_points: Number of points in the cloud
  • embed_dim: Embedding dimension for features
  • num_layers: Network depth
  • loss_type: Distance metric: "chamfer" (Chamfer distance) or "emd" (Earth Mover's Distance)
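With loss_type "chamfer", each point is matched to its nearest neighbour in the other cloud, in both directions. A minimal JAX sketch of the symmetric Chamfer distance follows, using squared Euclidean distances averaged over each cloud. Implementations differ (squared vs. unsquared distances, sum vs. mean), and this page does not say which convention Workshop uses.

```python
import jax
import jax.numpy as jnp


def chamfer_distance(a, b):
    """Symmetric Chamfer distance between point sets a (N, 3) and b (M, 3)."""
    # Pairwise squared distances, shape (N, M)
    d2 = jnp.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    # Average distance from each point to its nearest neighbour in the other set
    a_to_b = jnp.min(d2, axis=1).mean()
    b_to_a = jnp.min(d2, axis=0).mean()
    return a_to_b + b_to_a


key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.normal(key_a, (512, 3))
b = jax.random.normal(key_b, (512, 3))

print(chamfer_distance(a, a))      # 0.0 for identical clouds
print(chamfer_distance(a, b) > 0)  # True for different clouds
```

Nearest-neighbour matching lets several points share one partner. EMD instead requires an optimal one-to-one assignment between the two sets, which is why it is slower to compute. The (N, M) distance matrix is cheap at 512 points but grows quadratically.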

2. Mesh Models¤

Connected vertex structures with topology - surface-oriented

mesh_config = ModelConfiguration(
    name="demo_mesh",
    model_class="workshop.generative_models.models.geometric.MeshModel",
    input_dim=(512, 3),
    output_dim=(512, 3),
    hidden_dims=[256, 128, 64],
    parameters={
        "num_vertices": 512,
        "template_type": "sphere",
        "vertex_loss_weight": 1.0,
        "normal_loss_weight": 0.2,
        "edge_loss_weight": 0.1,
    },
)
mesh_model = create_model(mesh_config, rngs=rngs)

Use cases:

  • 3D graphics and rendering
  • Shape analysis and generation
  • Surface reconstruction

Key parameters:

  • num_vertices: Number of mesh vertices
  • template_type: Initial mesh template (sphere, cube, etc.)
  • Loss weights that balance the geometric terms:
      • vertex_loss_weight: Vertex position accuracy
      • normal_loss_weight: Surface normal consistency
      • edge_loss_weight: Edge length regularization
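The three weights combine separate loss terms into one objective. The JAX sketch below shows one plausible formulation: squared vertex error, 1 - cosine similarity on unit normals, and mean squared edge length computed from the triangle faces. These term definitions are assumptions, not Workshop's documented ones; only the default weights are taken from mesh_config, and the function names are hypothetical.

```python
import jax.numpy as jnp


def edges_from_faces(faces):
    """Edges of a triangle mesh; shared edges appear once per adjacent face."""
    # Each triangle (i, j, k) contributes edges (i, j), (j, k), (k, i)
    return jnp.concatenate([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [2, 0]]], axis=0)


def mesh_loss(pred_v, target_v, pred_n, target_n, faces,
              vertex_w=1.0, normal_w=0.2, edge_w=0.1):
    """Hypothetical weighted mesh loss; returns (total, per-term losses)."""
    # Vertex term: mean squared error on vertex positions
    vertex_loss = jnp.mean(jnp.sum((pred_v - target_v) ** 2, axis=-1))
    # Normal term: 1 - cosine similarity between unit normals
    normal_loss = jnp.mean(1.0 - jnp.sum(pred_n * target_n, axis=-1))
    # Edge term: mean squared edge length, discouraging long, stretched edges
    e = edges_from_faces(faces)
    edge_vec = pred_v[e[:, 0]] - pred_v[e[:, 1]]
    edge_loss = jnp.mean(jnp.sum(edge_vec ** 2, axis=-1))
    total = vertex_w * vertex_loss + normal_w * normal_loss + edge_w * edge_loss
    return total, (vertex_loss, normal_loss, edge_loss)


# A single tetrahedron: 4 vertices, 4 triangular faces
verts = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
faces = jnp.array([[0, 1, 2], [0, 1, 3], [0, 2, 3], [1, 2, 3]])
normals = jnp.tile(jnp.array([[0.0, 0.0, 1.0]]), (4, 1))  # placeholder unit normals

total, (v_loss, n_loss, e_loss) = mesh_loss(verts, verts, normals, normals, faces)
print(v_loss, n_loss, e_loss)  # 0.0 0.0 1.5
```

Even a perfect prediction keeps a nonzero edge term here, which is what a regularizer does: it biases the shape toward short, even edges regardless of the target.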

3. Voxel Models¤

Regular 3D grids - easy to process with CNNs
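Because voxels sit on a regular grid, standard 3D convolutions apply directly. This sketch runs one 3x3x3 convolution over a 16³ occupancy grid with jax.lax.conv_general_dilated. The kernel size and the 8 output channels are arbitrary choices, and nothing here reflects VoxelModel's actual layers.

```python
import jax
import jax.numpy as jnp

# One 16x16x16 single-channel occupancy grid in NDHWC layout,
# matching input_dim=(16, 16, 16, 1); a solid cube is marked occupied.
grid = jnp.zeros((1, 16, 16, 16, 1)).at[0, 4:12, 4:12, 4:12, 0].set(1.0)

# 3x3x3 kernel mapping 1 input channel to 8 output channels (DHWIO layout)
kernel = jax.random.normal(jax.random.PRNGKey(0), (3, 3, 3, 1, 8)) * 0.1

out = jax.lax.conv_general_dilated(
    grid,
    kernel,
    window_strides=(1, 1, 1),
    padding="SAME",  # keep the 16³ spatial size
    dimension_numbers=("NDHWC", "DHWIO", "NDHWC"),
)
print(out.shape)  # (1, 16, 16, 16, 8)
```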

voxel_config = ModelConfiguration(
    name="demo_voxel",
    model_class="workshop.generative_models.models.geometric.VoxelModel",
    input_dim=(16, 16, 16, 1),
    output_dim=(16, 16, 16, 1),
    parameters={
        "resolution": 16,
        "channels": [128, 64, 32, 16, 8, 1],
        "use_conditioning": True,
        "conditioning_dim": 10,
        "loss_type": "focal",
        "focal_gamma": 2.0,
    },
)
voxel_model = create_model(voxel_config, rngs=rngs)

Use cases:

  • Medical imaging (CT, MRI scans)
  • 3D scene understanding
  • Volumetric shape generation

Key parameters:

  • resolution: Grid resolution (16³ = 4,096 voxels)
  • channels: Channel widths of the successive multi-scale layers, narrowing from 128 to a single output channel
  • use_conditioning: Enable class-conditional generation
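  • loss_type: "focal" down-weights the many easy, empty voxels so they do not dominate training, which suits sparse occupancy grids
  • focal_gamma: Focusing parameter; larger values shift more of the loss onto hard-to-classify voxels (2.0 is the commonly used value)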
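For reference, the standard binary focal loss is FL(p_t) = -(1 - p_t)^γ · log(p_t), where p_t is the predicted probability of the true label and γ corresponds to focal_gamma. A minimal JAX sketch for per-voxel occupancy follows. It omits the optional α class-balancing weight; whether Workshop's "focal" loss includes it is not stated on this page.

```python
import jax
import jax.numpy as jnp


def binary_focal_loss(logits, targets, gamma=2.0, eps=1e-7):
    """Mean binary focal loss over voxels (no alpha weighting).

    logits:  raw occupancy scores, any shape
    targets: 0/1 occupancy labels, same shape
    """
    p = jax.nn.sigmoid(logits)
    # Probability assigned to the correct label of each voxel
    p_t = jnp.where(targets == 1, p, 1.0 - p)
    # (1 - p_t)^gamma shrinks the loss of voxels that are already well classified
    loss = -((1.0 - p_t) ** gamma) * jnp.log(jnp.clip(p_t, eps, 1.0))
    return jnp.mean(loss)


# Sparse target: 64 occupied voxels out of 16³ = 4,096
targets = jnp.zeros((16, 16, 16)).at[6:10, 6:10, 6:10].set(1.0)
confident = jnp.where(targets == 1, 4.0, -4.0)  # correct and confident everywhere
uncertain = jnp.zeros_like(targets)             # p = 0.5 everywhere

print(binary_focal_loss(confident, targets) < binary_focal_loss(uncertain, targets))  # True
```

Setting gamma=0.0 removes the modulating factor, and the loss reduces to plain binary cross-entropy.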

Expected Output¤

Creating point cloud model...
Created model: PointCloudModel
Sample shape: (1, 512, 3)

Creating mesh model...
Created model: MeshModel

Creating voxel model with conditioning...
Created model: VoxelModel

Demo completed successfully!

Code Walkthrough¤

Step 1: Setup Random Number Generation¤

import jax
from flax import nnx

rng = jax.random.PRNGKey(42)
rngs = nnx.Rngs(params=rng)

Seed a JAX PRNG key and pass it to nnx.Rngs as the params stream, so that parameter initialization is reproducible from run to run.
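nnx.Rngs holds the seeded key and hands out keys derived from it as layers initialize their parameters. The sketch below shows the underlying jax.random property with raw keys instead of nnx.Rngs: the same seed always yields the same draws, and keys are split rather than reused. The init_weights function is a hypothetical stand-in for model initialization.

```python
import jax
import jax.numpy as jnp


def init_weights(seed):
    """Draw a small weight matrix and bias from a seeded key (stand-in for model init)."""
    key = jax.random.PRNGKey(seed)
    # Split into independent subkeys, one per tensor, instead of reusing `key`
    k1, k2 = jax.random.split(key)
    return jax.random.normal(k1, (3, 4)), jax.random.normal(k2, (4,))


w_a, b_a = init_weights(42)
w_b, b_b = init_weights(42)
w_c, _ = init_weights(43)

print(jnp.array_equal(w_a, w_b) and jnp.array_equal(b_a, b_b))  # True: same seed, same weights
print(jnp.array_equal(w_a, w_c))                                 # False: different seed
```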

Step 2: Create Models via Factory¤

All three models use the unified create_model() factory pattern:

model = create_model(config, rngs=rngs)

This abstracts away model-specific initialization details and provides a consistent API.
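This page does not show create_model() internals. As a hypothetical illustration of the pattern, the sketch below uses a registry: the config carries a string key (like model_class), and one function maps that key to a class and builds it. ToyConfig, register, toy_create_model, and ToyPointCloudModel are all invented for this example and are not Workshop APIs.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict


@dataclass
class ToyConfig:
    """Hypothetical stand-in for ModelConfiguration (only the fields used here)."""
    name: str
    model_class: str
    parameters: Dict[str, Any] = field(default_factory=dict)


_REGISTRY: Dict[str, Callable[..., Any]] = {}


def register(path: str):
    """Class decorator that records a model class under a string key."""
    def wrap(cls):
        _REGISTRY[path] = cls
        return cls
    return wrap


def toy_create_model(config: ToyConfig, *, rngs=None):
    """Look up config.model_class and build it; callers never name the class directly."""
    try:
        cls = _REGISTRY[config.model_class]
    except KeyError:
        raise ValueError(f"Unknown model_class: {config.model_class!r}") from None
    return cls(config, rngs=rngs)


@register("toy.PointCloudModel")
class ToyPointCloudModel:
    def __init__(self, config: ToyConfig, *, rngs=None):
        self.num_points = config.parameters["num_points"]


model = toy_create_model(
    ToyConfig(name="demo", model_class="toy.PointCloudModel", parameters={"num_points": 512})
)
print(type(model).__name__, model.num_points)  # ToyPointCloudModel 512
```

Because every model goes through one entry point, switching representations means editing the config, not the calling code.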

Step 3: Generate Samples (Optional)¤

sample = point_cloud_model.sample(1, rngs=rngs)
# shape: (1, 512, 3) - batch_size=1, num_points=512, xyz=3

Choosing the Right Representation¤

| Representation | When to Use | Strengths | Limitations |
|---|---|---|---|
| Point Cloud | Raw sensor data, irregular shapes, molecular structures | Flexible, no topology required, permutation-invariant | No explicit surface, harder to render |
| Mesh | Graphics, animation, smooth surfaces | Explicit topology, efficient rendering, smooth surfaces | Requires consistent topology, harder to optimize |
| Voxel | Medical imaging, volumetric data, regular grids | Easy CNN processing, regular structure | Memory-intensive, discretization artifacts |
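To put "memory-intensive" in numbers, here is a back-of-the-envelope count of float32 storage for a single sample, using the sizes from this demo. It ignores model parameters and activations, though multi-channel activations on a dense grid scale the same way. Dense grid storage grows with the cube of the resolution, while a point cloud grows linearly with num_points.

```python
BYTES_PER_FLOAT32 = 4


def point_cloud_bytes(num_points, coords=3):
    """Storage for one point cloud: num_points x (x, y, z) floats."""
    return num_points * coords * BYTES_PER_FLOAT32


def voxel_grid_bytes(resolution, channels=1):
    """Storage for one dense grid: resolution^3 cells x channels floats."""
    return resolution ** 3 * channels * BYTES_PER_FLOAT32


print(point_cloud_bytes(512))  # 6144 (512 * 3 * 4)
for res in (8, 16, 32, 64):
    print(res, voxel_grid_bytes(res))
# 8 2048
# 16 16384
# 32 131072
# 64 1048576
```

Each doubling of resolution multiplies grid storage by 8, which is the trade-off behind the resolution experiments and troubleshooting advice on this page.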

Experiments to Try¤

  1. Change loss types - Try a different distance metric for point clouds:
      "loss_type": "emd"  # Earth Mover's Distance (slower: it solves a one-to-one point matching)
  2. Adjust mesh templates - Experiment with different starting shapes:
      "template_type": "cube"  # or "icosahedron", "octahedron"
  3. Scale voxel resolution - Balance memory against detail:
      "resolution": 32  # 32³ = 32,768 voxels (8x the memory of 16³)
  4. Conditional generation - Create class-specific voxel shapes:
      labels = jnp.array([0, 1, 5, 9])  # Different classes
      voxel_model.generate(labels, rngs=rngs)
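One assumption worth checking when passing class labels: this page does not say how VoxelModel consumes them. A common approach, assumed here, is to one-hot encode them to length conditioning_dim (a learned embedding is another option). Either way, labels must be integers in [0, conditioning_dim), which [0, 1, 5, 9] satisfies with conditioning_dim=10.

```python
import jax
import jax.numpy as jnp

conditioning_dim = 10             # matches "conditioning_dim": 10 in voxel_config
labels = jnp.array([0, 1, 5, 9])  # class ids must lie in [0, conditioning_dim)

# One-hot encode: one row per requested sample
cond = jax.nn.one_hot(labels, conditioning_dim)
print(cond.shape)  # (4, 10)
print(cond[2])     # 1.0 at index 5, zeros elsewhere
```

Note that jax.nn.one_hot does not raise on an out-of-range label; it silently returns an all-zero row, so validate labels before generating.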

Next Steps¤

Troubleshooting¤

"Backend 'cuda' is not in the list of known backends"¤

Solution: JAX is looking for CUDA but can't find it. Run with CPU:

JAX_PLATFORMS=cpu python examples/generative_models/geometric/geometric_models_demo.py
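In the Jupyter notebook there is no command line to prefix, so set the same environment variable from Python in the first cell, before anything imports jax. If jax was already imported in the kernel, restart the kernel first.

```python
import os

# Must be set before JAX is imported for the first time in this process
os.environ["JAX_PLATFORMS"] = "cpu"

import jax

print(jax.default_backend())  # cpu
```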

"ModuleNotFoundError: No module named 'workshop'"¤

Solution: Activate the environment first:

source activate.sh
python examples/generative_models/geometric/geometric_models_demo.py

Memory errors with high resolution voxels¤

Solution: Reduce the voxel resolution (adjusting input_dim and output_dim to match), or use gradient checkpointing when training:

"resolution": 8,  # Lower resolution (8³ = 512 voxels)
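Gradient checkpointing (rematerialization) only helps during training: JAX discards a function's intermediate activations after the forward pass and recomputes them during backpropagation. This page does not show a Workshop option for it, so the sketch below applies plain jax.checkpoint to a hypothetical stand-in layer; the layer and array sizes are illustrative.

```python
import jax
import jax.numpy as jnp


def layer(x, w):
    """Stand-in for one expensive layer: two matmuls with tanh nonlinearities."""
    h = jnp.tanh(x @ w)
    return jnp.tanh(h @ w.T)


# jax.checkpoint makes backprop recompute layer's intermediates (here `h`)
# instead of keeping them in memory from the forward pass.
layer_remat = jax.checkpoint(layer)

x = jax.random.normal(jax.random.PRNGKey(0), (512, 32))  # e.g. 8³ voxels x 32 features
w = jax.random.normal(jax.random.PRNGKey(1), (32, 32)) * 0.1

g_plain = jax.grad(lambda w: jnp.sum(layer(x, w) ** 2))(w)
g_remat = jax.grad(lambda w: jnp.sum(layer_remat(x, w) ** 2))(w)

# Same gradients; the checkpointed version trades extra compute for lower peak memory
print(jnp.allclose(g_plain, g_remat, rtol=1e-4, atol=1e-5))  # True
```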

Additional Resources¤