Simple Point Cloud Example

Beginner ⚡ 10-15 seconds 📓 Dual Format

Learn how to generate and visualize 3D point clouds using Workshop's PointCloudModel with transformer-based architecture.

Files

Quick Start

# Enter the cloned repository and activate the environment
cd workshop
source activate.sh

# Run Python script
python examples/generative_models/geometric/simple_point_cloud_example.py

# Or use Jupyter notebook
jupyter notebook examples/generative_models/geometric/simple_point_cloud_example.ipynb

Overview

This tutorial teaches you how to work with point clouds, a fundamental representation for 3D data in machine learning. Point clouds are unordered sets of 3D coordinates that represent objects, scenes, or molecular structures.

Learning Objectives

  • Understand point cloud representation and properties
  • Configure PointCloudModel with transformer architecture
  • Generate 3D point clouds from learned distributions
  • Visualize point clouds in 3D space
  • Control generation diversity with temperature

Prerequisites

  • Basic understanding of 3D coordinates (x, y, z)
  • Familiarity with JAX and Flax NNX
  • Basic knowledge of attention mechanisms (helpful but not required)

What Are Point Clouds?

Point clouds are collections of 3D points that represent the shape or structure of objects:

Point Cloud = {(x₁, y₁, z₁), (x₂, y₂, z₂), ..., (xₙ, yₙ, zₙ)}

Key Properties

  1. Unordered (Permutation-Invariant): The order of points doesn't matter
     • {A, B, C} is the same as {C, A, B}
     • This is why transformers work well: self-attention without positional encodings has no built-in notion of point order

  2. Sparse: Represents surfaces without filling volumes
     • A sphere needs only surface points, not interior
     • More efficient than voxels for many tasks

  3. Flexible: Can represent arbitrary shapes
     • No fixed topology required
     • Handles complex, irregular geometries
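The first two properties can be checked on a toy example. The sketch below is not part of the Workshop example script; it uses plain NumPy and a synthetic unit sphere to show that shuffling the rows of an (N, 3) array leaves set-level summaries unchanged, and that only surface points are stored.

```python
import numpy as np

rng = np.random.default_rng(0)

# Sample 512 points on the unit sphere surface by normalizing Gaussian vectors.
points = rng.normal(size=(512, 3))
points /= np.linalg.norm(points, axis=1, keepdims=True)

# Shuffle the rows: same set of points, different storage order.
shuffled = points[rng.permutation(len(points))]


def set_summary(cloud):
    """Order-independent summaries: centroid and axis-aligned bounding box."""
    return cloud.mean(axis=0), cloud.min(axis=0), cloud.max(axis=0)


for a, b in zip(set_summary(points), set_summary(shuffled)):
    assert np.allclose(a, b)  # Row order does not change the shape described

# Every stored point lies on the surface (radius 1); there are no interior points.
radii = np.linalg.norm(points, axis=1)
assert np.allclose(radii, 1.0)
```

Any model or metric that treats the array as a set should behave like `set_summary` here: its result must not depend on which row comes first.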

Common Sources

Source          Example                 Typical Size
LiDAR           Autonomous vehicles     10K-1M points
3D Scanners     Industrial inspection   100K-10M points
Depth Cameras   Robotics, AR/VR         10K-100K points
Molecular       Protein structures      100-10K atoms
Photogrammetry  3D reconstruction       100K-10M points

Transformer Architecture for Point Clouds

Why Transformers?

Traditional CNNs require regular grids. Point clouds are irregular, so we need architectures that:

  1. Handle variable-size inputs: Different objects have different numbers of points
  2. Are permutation-invariant: Point order shouldn't matter
  3. Model long-range relationships: Distant points may be related

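Self-attention satisfies all three when no positional encoding is added: it accepts any number of points, treats them as a set (permuting the input permutes the output in the same way), and lets every point attend to every other point.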
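These requirements can be verified numerically on a minimal attention layer. The sketch below is a generic single-head scaled dot-product self-attention written in NumPy with random weights. It is an illustration, not PointCloudModel's implementation, and the names `self_attention` and `check` are made up for this example.

```python
import numpy as np


def self_attention(x, w_q, w_k, w_v):
    """Single-head scaled dot-product self-attention over the rows of x."""
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)  # Numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v


rng = np.random.default_rng(0)
dim = 8
w_q, w_k, w_v = (rng.normal(size=(dim, dim)) for _ in range(3))


def check(num_points):
    """Return (equivariant, invariant_after_pooling) for a random cloud."""
    x = rng.normal(size=(num_points, dim))
    perm = rng.permutation(num_points)
    out = self_attention(x, w_q, w_k, w_v)
    out_perm = self_attention(x[perm], w_q, w_k, w_v)
    # Permuting the input rows permutes the output rows in the same way.
    equivariant = bool(np.allclose(out[perm], out_perm))
    # Pooling over points removes the remaining dependence on order.
    invariant = bool(np.allclose(out.mean(axis=0), out_perm.mean(axis=0)))
    return equivariant, invariant


# The same weights handle clouds of different sizes.
results = [check(n) for n in (16, 100)]
```

Strictly speaking the layer is permutation-equivariant; an order-independent summary only appears after a symmetric pooling step such as the mean used above.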

Architecture Components

Input Points (N, 3)
      ↓
Point Embedding → (N, 128)
      ↓
┌─────────────────┐
│ Transformer     │ ← Self-Attention Layers (×3)
│  Layer 1        │    Each point attends to all others
├─────────────────┤
│ Transformer     │ ← Multi-Head (×4)
│  Layer 2        │    Different attention patterns
├─────────────────┤
│ Transformer     │ ← Layer Normalization
│  Layer 3        │    Stable training
└─────────────────┘
      ↓
Output Points (N, 3)

Key Parameters

  • num_points: 512 - Number of 3D points
  • embed_dim: 128 - Feature dimension for attention
  • num_layers: 3 - Transformer depth
  • num_heads: 4 - Multi-head attention heads
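To see how these four numbers determine tensor shapes, the following NumPy sketch traces one point cloud through an embedding, a multi-head split, and an output projection. The linear embedding and output projection are assumed forms chosen for illustration, and the attention scores are computed without learned query/key projections. Only the shapes are meant to match the diagram.

```python
import numpy as np

num_points, embed_dim, num_heads = 512, 128, 4
head_dim = embed_dim // num_heads  # Features handled by each attention head

rng = np.random.default_rng(0)
points = rng.normal(size=(num_points, 3))

# Point embedding: a linear map from xyz to embed_dim features (assumed form).
w_embed = rng.normal(size=(3, embed_dim)) * 0.1
features = points @ w_embed

# Multi-head attention splits the feature axis into num_heads groups.
heads = features.reshape(num_points, num_heads, head_dim).transpose(1, 0, 2)

# One score per (head, query point, key point) triple.
scores = heads @ heads.transpose(0, 2, 1) / np.sqrt(head_dim)

# Output head: project features back to xyz coordinates (assumed form).
w_out = rng.normal(size=(embed_dim, 3)) * 0.1
output = features @ w_out

print(features.shape, heads.shape, scores.shape, output.shape)
# → (512, 128) (4, 512, 32) (4, 512, 512) (512, 3)
```

The (4, 512, 512) score tensor is the part that grows fastest when num_points increases.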

Code Walkthrough

Step 1: Setup and Imports

import jax
import matplotlib.pyplot as plt
import numpy as np
from flax import nnx

from workshop.generative_models.core.configuration import ModelConfiguration
from workshop.generative_models.models.geometric import PointCloudModel

We use:

  • JAX: Fast numerical computing with GPU support
  • Flax NNX: Modern neural network framework
  • Matplotlib: 3D visualization

Step 2: Configure the Model

config = ModelConfiguration(
    name="point_cloud_generator",
    model_class="workshop.generative_models.models.geometric.PointCloudModel",
    input_dim=(512, 3),  # 512 points with 3D coordinates
    output_dim=(512, 3),  # Output same shape
    hidden_dims=[128, 128, 128],  # Hidden layer dimensions
    dropout_rate=0.1,  # Regularization
    parameters={
        "num_points": 512,    # Number of points
        "embed_dim": 128,     # Feature dimension
        "num_layers": 3,      # Transformer depth
        "num_heads": 4,       # Attention heads
    },
)

Configuration breakdown:

  • input_dim & output_dim: Point cloud shape (points × dimensions)
  • hidden_dims: Internal processing dimensions
  • parameters: Model-specific settings

Step 3: Create the Model

rngs = nnx.Rngs(params=jax.random.key(42))
model = PointCloudModel(config=config, rngs=rngs)

The model initializes its transformer layers with random weights.

Step 4: Generate Point Clouds

point_clouds = model.generate(
    rngs=rngs,
    n_samples=2,        # Generate 2 point clouds
    temperature=0.8,    # Control diversity
)

Temperature effects:

  • 0.5-0.7: Focused, consistent samples
  • 0.8-1.0: Balanced diversity (recommended)
  • Above 1.0: High diversity, may be noisy
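The tutorial does not say how PointCloudModel applies temperature internally. One common convention in generative samplers is to scale the standard deviation of the noise that seeds each sample, and the sketch below illustrates that convention only. The function `sample_latent` is hypothetical and not part of Workshop.

```python
import numpy as np


def sample_latent(rng, n_samples, num_points, temperature):
    """Draw Gaussian noise and scale its standard deviation by `temperature`."""
    noise = rng.standard_normal(size=(n_samples, num_points, 3))
    return temperature * noise


rng = np.random.default_rng(0)
spread = {
    t: float(sample_latent(rng, 64, 512, t).std())
    for t in (0.5, 0.8, 1.2)
}

# The measured spread of the samples tracks the temperature closely.
for t, s in spread.items():
    print(f"temperature={t:.1f}  std={s:.3f}")
```

Under this convention a temperature below 1 concentrates samples near the mode of the noise distribution, while a temperature above 1 pushes them into less likely regions, which matches the focused-versus-noisy behavior listed above.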

Step 5: Visualize in 3D

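def plot_point_cloud(points, filename=None):
    """Plot an (N, 3) point cloud; save it to `filename` if one is given."""
    points = np.asarray(points)  # Accept JAX arrays as well as NumPy arrays
    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection="3d")

    # Color by distance from origin, normalized to [0, 1]
    norm = np.sqrt(np.sum(points**2, axis=1))
    norm = (norm - norm.min()) / (norm.max() - norm.min() + 1e-8)

    # Keep the scatter handle so the colorbar can reference it
    scatter = ax.scatter(points[:, 0], points[:, 1], points[:, 2],
                         c=norm, cmap="viridis", s=20, alpha=0.7)
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.set_zlabel("Z")
    fig.colorbar(scatter, ax=ax)
    if filename is not None:
        fig.savefig(filename)
    return fig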

The visualization:

  • Projects 3D points onto 2D screen
  • Colors points by distance from origin
  • Saves plots to examples_output/ directory

Expected Output

Creating point cloud model...
Generating point clouds...
Visualizing point clouds...
Example completed! Point clouds saved as PNG files.

Generated files:

  • examples_output/point_cloud_1.png - First generated point cloud
  • examples_output/point_cloud_2.png - Second generated point cloud

Each visualization shows a 3D scatter plot with:

  • X, Y, Z axes labeled
  • Color gradient indicating spatial distribution
  • 512 points representing the generated shape

Experiments to Try

1. Vary Number of Points

# Sparse point cloud (faster, less detail)
"num_points": 256

# Dense point cloud (slower, more detail)
"num_points": 1024

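Tradeoff: More points give a better shape representation but slower generation. Because each point attends to all others, attention cost grows roughly with the square of num_points.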
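A back-of-the-envelope count makes the tradeoff concrete. The sketch below assumes full self-attention with the tutorial's default of 4 heads and 3 layers, and counts the attention-score entries for one sample if every layer's scores are held in memory at once. Actual memory use depends on the implementation and is not measured here.

```python
def attention_score_entries(num_points, num_heads=4, num_layers=3):
    """Count attention-score entries for one forward pass of one sample."""
    return num_layers * num_heads * num_points * num_points


for n in (256, 512, 1024):
    entries = attention_score_entries(n)
    megabytes = entries * 4 / 1e6  # float32 uses 4 bytes per entry
    print(f"num_points={n:5d}  entries={entries:>12,}  ~{megabytes:.1f} MB")
```

Doubling num_points multiplies the count by four, so moving from 256 to 1024 points costs about sixteen times as much in the attention layers.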

2. Adjust Temperature

# Low temperature (focused samples)
temperature=0.6

# High temperature (diverse samples)
temperature=1.2

Try it: Generate 5 samples at different temperatures and compare diversity
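Comparing diversity needs a number that ignores point order. One option is the mean pairwise Chamfer distance across a batch, sketched below in NumPy. The helper names are made up for this example, and the stand-in arrays only mimic the (n_samples, num_points, 3) layout that generated samples are assumed to have.

```python
import numpy as np


def chamfer_distance(a, b):
    """Symmetric Chamfer distance between point clouds a (N, 3) and b (M, 3)."""
    # Pairwise squared distances, shape (N, M).
    d2 = np.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    # Nearest neighbor in b for each point of a, and vice versa.
    return d2.min(axis=1).mean() + d2.min(axis=0).mean()


def mean_pairwise_diversity(batch):
    """Average Chamfer distance over all distinct pairs in an (S, N, 3) batch."""
    n = len(batch)
    pairs = [(i, j) for i in range(n) for j in range(i + 1, n)]
    return float(np.mean([chamfer_distance(batch[i], batch[j]) for i, j in pairs]))


# Stand-in data: identical clouds give zero diversity, perturbed ones do not.
rng = np.random.default_rng(0)
base = rng.normal(size=(128, 3))
identical = np.stack([base] * 3)
perturbed = np.stack([base + 0.1 * rng.normal(size=base.shape) for _ in range(3)])

print(mean_pairwise_diversity(identical))  # → 0.0
print(mean_pairwise_diversity(perturbed) > 0.0)  # → True
```

To run the experiment, convert each generated batch with `np.asarray(...)`, pass it to `mean_pairwise_diversity`, and compare the values across temperatures.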

3. Modify Architecture Depth

# Shallow network (faster, simpler patterns)
"num_layers": 2

# Deep network (slower, complex patterns)
"num_layers": 6

Note: Deeper networks may need more training data
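Depth matters for data requirements because the number of trainable weights grows linearly with num_layers. The estimate below assumes a standard transformer layer (four attention projections plus a two-layer MLP with a 4x hidden width). PointCloudModel's actual layer layout is not documented here, so treat the figures as order-of-magnitude only.

```python
def approx_transformer_params(embed_dim=128, num_layers=3, mlp_ratio=4):
    """Rough weight count for a stack of standard transformer layers.

    Assumes each layer has four embed_dim x embed_dim attention projections
    (query, key, value, output) and a two-layer MLP with hidden width
    mlp_ratio * embed_dim. Biases and layer-norm parameters are ignored.
    """
    attention = 4 * embed_dim * embed_dim
    mlp = 2 * mlp_ratio * embed_dim * embed_dim
    return num_layers * (attention + mlp)


for layers in (2, 3, 6):
    count = approx_transformer_params(num_layers=layers)
    print(f"num_layers={layers}: ~{count:,} weights")
```

Under these assumptions, going from 2 to 6 layers triples the number of weights the training data has to constrain.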

4. Multi-Head Attention

# Fewer heads (simpler attention)
"num_heads": 2

# More heads (richer attention patterns)
"num_heads": 8

num_heads must divide embed_dim evenly: e.g., 128 ÷ 4 = 32 ✓
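A small guard can catch an invalid combination before a model is built. The helper below is a hypothetical convenience function, not a Workshop API. It returns the per-head feature size or raises when the split is uneven.

```python
def head_dim(embed_dim, num_heads):
    """Return the per-head feature size, or raise if the split is uneven."""
    if embed_dim % num_heads != 0:
        raise ValueError(
            f"embed_dim={embed_dim} is not divisible by num_heads={num_heads}"
        )
    return embed_dim // num_heads


for heads in (2, 4, 8):
    print(f"num_heads={heads}: head_dim={head_dim(128, heads)}")

try:
    head_dim(128, 6)
except ValueError as err:
    print(f"rejected: {err}")
```

With embed_dim fixed at 128, more heads means fewer features per head (64, 32, 16), so adding heads trades per-head capacity for a larger number of attention patterns.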

Understanding Point Cloud Applications

1. Autonomous Vehicles

LiDAR sensors generate point clouds for:

  • Obstacle detection
  • Lane tracking
  • 3D scene understanding

Typical setup: 64-128 laser beams → 100K+ points per second

2. Robotics

Point clouds enable:

  • Object grasping (find grip points)
  • Navigation (3D mapping)
  • Human-robot interaction (gesture recognition)

3. Molecular Modeling

Proteins as point clouds:

  • Each atom is a 3D point
  • Backbone atoms: N, C-alpha, C, O
  • Sidechains: variable number of atoms

See protein_point_cloud_example.py for details

4. 3D Content Creation

Generate 3D models for:

  • Video games (procedural generation)
  • Movies (digital assets)
  • Virtual reality (environments)

Next Steps

Troubleshooting

Points look random/unstructured

Cause: Model is untrained or poorly trained

Solutions:

  1. Train on a dataset first (this example uses randomly initialized weights, so structured output is not expected)
  2. Increase num_layers for more capacity
  3. Adjust temperature for different sampling behavior

"FigureCanvasAgg is non-interactive"

Cause: matplotlib is trying to show plots in a non-GUI environment

Solution: This is only a warning; plots are still saved. To suppress it:

import matplotlib
matplotlib.use('Agg')  # Add before importing pyplot

Out of memory errors

Solutions:

# Reduce number of points
"num_points": 256  # Instead of 512

# Reduce batch size in generation
n_samples=1  # Instead of 2

# Use CPU instead of GPU
JAX_PLATFORMS=cpu python examples/generative_models/geometric/simple_point_cloud_example.py

Generated point clouds are identical

Cause: Not providing fresh random keys

Solution: Split RNG keys for each generation:

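key = jax.random.key(0)  # Any seed works; this value is only an example
n_samples = 2
point_clouds = []
for i in range(n_samples):
    key, subkey = jax.random.split(key)  # Fresh subkey for each sample
    rngs_new = nnx.Rngs(params=subkey)
    point_clouds.append(model.generate(rngs=rngs_new, n_samples=1))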

Additional Resources