# Simple Point Cloud Example

Learn how to generate and visualize 3D point clouds using Workshop's PointCloudModel with a transformer-based architecture.
## Files

- Python script: `simple_point_cloud_example.py`
- Jupyter notebook: `simple_point_cloud_example.ipynb`
## Quick Start

```bash
# Clone and setup
cd workshop
source activate.sh

# Run the Python script
python examples/generative_models/geometric/simple_point_cloud_example.py

# Or use the Jupyter notebook
jupyter notebook examples/generative_models/geometric/simple_point_cloud_example.ipynb
```
## Overview

This tutorial teaches you how to work with point clouds, the fundamental representation for 3D data in machine learning. Point clouds are unordered sets of 3D coordinates that represent objects, scenes, or molecular structures.
## Learning Objectives

- Understand point cloud representation and properties
- Configure PointCloudModel with transformer architecture
- Generate 3D point clouds from learned distributions
- Visualize point clouds in 3D space
- Control generation diversity with temperature
## Prerequisites

- Basic understanding of 3D coordinates (x, y, z)
- Familiarity with JAX and Flax NNX
- Basic knowledge of attention mechanisms (helpful but not required)
## What Are Point Clouds?

Point clouds are collections of 3D points that represent the shape or structure of objects:
### Key Properties

- Unordered (Permutation-Invariant): The order of points doesn't matter.
    - `{A, B, C}` is the same as `{C, A, B}`
    - This is why transformers work well (they're permutation-invariant)
- Sparse: Represents surfaces without filling volumes.
    - A sphere needs only surface points, not its interior
    - More efficient than voxels for many tasks
- Flexible: Can represent arbitrary shapes.
    - No fixed topology required
    - Handles complex, irregular geometries
### Common Sources

| Source | Example | Typical Point Count |
|---|---|---|
| LiDAR | Autonomous vehicles | 10K-1M points |
| 3D Scanners | Industrial inspection | 100K-10M points |
| Depth Cameras | Robotics, AR/VR | 10K-100K points |
| Molecular | Protein structures | 100-10K atoms |
| Photogrammetry | 3D reconstruction | 100K-10M points |
## Transformer Architecture for Point Clouds

### Why Transformers?

Traditional CNNs require regular grids. Point clouds are irregular, so we need architectures that:
- Handle variable-size inputs: Different objects have different numbers of points
- Are permutation-invariant: Point order shouldn't matter
- Model long-range relationships: Distant points may be related
Transformers satisfy all three requirements!
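
To make the permutation property concrete, here is a toy, framework-free sketch (not the Workshop implementation) showing that plain dot-product self-attention respects point ordering: permuting the input points simply permutes the per-point outputs (pooling those features then gives a fully permutation-invariant representation).

```python
import jax
import jax.numpy as jnp

def self_attention(x):
    """Single-head dot-product self-attention over a set of point features."""
    scores = x @ x.T / jnp.sqrt(x.shape[-1])   # (N, N) pairwise similarity scores
    weights = jax.nn.softmax(scores, axis=-1)  # each point attends to all others
    return weights @ x                         # (N, D) attended features

points = jax.random.normal(jax.random.PRNGKey(0), (8, 3))  # toy cloud of 8 points
perm = jax.random.permutation(jax.random.PRNGKey(1), 8)    # random reordering

out = self_attention(points)
out_perm = self_attention(points[perm])
print(jnp.allclose(out[perm], out_perm, atol=1e-5))  # True: output permutes with the input
```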
### Architecture Components

```text
Input Points (N, 3)
        │
Point Embedding → (N, 128)
        │
┌───────────────────┐
│  Transformer      │  ← Self-Attention Layers (×3)
│  Layer 1          │    Each point attends to all others
├───────────────────┤
│  Transformer      │  ← Multi-Head (×4)
│  Layer 2          │    Different attention patterns
├───────────────────┤
│  Transformer      │  ← Layer Normalization
│  Layer 3          │    Stable training
└───────────────────┘
        │
Output Points (N, 3)
```
### Key Parameters

- `num_points`: 512 - Number of 3D points
- `embed_dim`: 128 - Feature dimension for attention
- `num_layers`: 3 - Transformer depth
- `num_heads`: 4 - Multi-head attention heads
## Code Walkthrough

### Step 1: Setup and Imports

```python
import jax
import matplotlib.pyplot as plt
import numpy as np
from flax import nnx

from workshop.generative_models.core.configuration import ModelConfiguration
from workshop.generative_models.models.geometric import PointCloudModel
```
We use:
- JAX: Fast numerical computing with GPU support
- Flax NNX: Modern neural network framework
- Matplotlib: 3D visualization
### Step 2: Configure the Model

```python
config = ModelConfiguration(
    name="point_cloud_generator",
    model_class="workshop.generative_models.models.geometric.PointCloudModel",
    input_dim=(512, 3),           # 512 points with 3D coordinates
    output_dim=(512, 3),          # Output same shape
    hidden_dims=[128, 128, 128],  # Hidden layer dimensions
    dropout_rate=0.1,             # Regularization
    parameters={
        "num_points": 512,  # Number of points
        "embed_dim": 128,   # Feature dimension
        "num_layers": 3,    # Transformer depth
        "num_heads": 4,     # Attention heads
    },
)
```
Configuration breakdown:

- `input_dim` & `output_dim`: Point cloud shape (points × dimensions)
- `hidden_dims`: Internal processing dimensions
- `parameters`: Model-specific settings
### Step 3: Create the Model

The model initializes its transformer layers with random weights.
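
A minimal sketch of this step, assuming `PointCloudModel` accepts the configuration and an `nnx.Rngs` container (check the example script for the exact constructor signature in your Workshop version):

```python
# Assumption: the constructor takes `config` and `rngs`; adjust to match
# the signature used in simple_point_cloud_example.py.
rngs = nnx.Rngs(params=jax.random.PRNGKey(0))
model = PointCloudModel(config=config, rngs=rngs)
```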
### Step 4: Generate Point Clouds

```python
point_clouds = model.generate(
    rngs=rngs,
    n_samples=2,      # Generate 2 point clouds
    temperature=0.8,  # Control diversity
)
```

Temperature effects:

- 0.5-0.7: Focused, consistent samples
- 0.8-1.0: Balanced diversity (recommended)
- 1.0+: High diversity, may be noisy
### Step 5: Visualize in 3D

```python
def plot_point_cloud(points, filename=None):
    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection="3d")

    # Color by distance from origin
    norm = np.sqrt(np.sum(points**2, axis=1))
    norm = (norm - norm.min()) / (norm.max() - norm.min() + 1e-8)

    scatter = ax.scatter(points[:, 0], points[:, 1], points[:, 2],
                         c=norm, cmap="viridis", s=20, alpha=0.7)
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.set_zlabel("Z")
    fig.colorbar(scatter, ax=ax)

    if filename is not None:
        fig.savefig(filename, dpi=150, bbox_inches="tight")
    return fig
```
The visualization:

- Projects 3D points onto the 2D screen
- Colors points by distance from origin
- Saves plots to the `examples_output/` directory (see the usage sketch below)
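
A hedged usage sketch, assuming `point_clouds` from Step 4 has shape `(n_samples, num_points, 3)` and that you want the file names listed under Expected Output:

```python
import os

os.makedirs("examples_output", exist_ok=True)
for i, cloud in enumerate(np.asarray(point_clouds)):
    # Save each generated cloud as examples_output/point_cloud_<i+1>.png
    plot_point_cloud(cloud, filename=f"examples_output/point_cloud_{i + 1}.png")
```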
## Expected Output

```text
Creating point cloud model...
Generating point clouds...
Visualizing point clouds...
Example completed! Point clouds saved as PNG files.
```

Generated files:

- `examples_output/point_cloud_1.png` - First generated point cloud
- `examples_output/point_cloud_2.png` - Second generated point cloud
Each visualization shows a 3D scatter plot with:
- X, Y, Z axes labeled
- Color gradient indicating spatial distribution
- 512 points representing the generated shape
## Experiments to Try

### 1. Vary Number of Points

```python
# Sparse point cloud (faster, less detail)
"num_points": 256

# Dense point cloud (slower, more detail)
"num_points": 1024
```
Tradeoff: More points = better shape representation but slower generation
### 2. Adjust Temperature

```python
# Low temperature (focused samples)
temperature=0.6

# High temperature (diverse samples)
temperature=1.2
```

Try it: Generate 5 samples at several different temperatures and compare diversity, as in the sketch below.
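
A rough sketch of that experiment, assuming the `model` and `rngs` objects from the walkthrough above; the coordinate standard deviation is only a crude diversity proxy, not a formal metric.

```python
# Assumes `model` and `rngs` exist as in Steps 3-4.
for temperature in (0.6, 0.8, 1.0, 1.2, 1.5):
    samples = model.generate(rngs=rngs, n_samples=5, temperature=temperature)
    spread = float(np.std(np.asarray(samples)))  # crude spread of generated coordinates
    print(f"temperature={temperature}: coordinate std = {spread:.3f}")
```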
### 3. Modify Architecture Depth

```python
# Shallow network (faster, simpler patterns)
"num_layers": 2

# Deep network (slower, complex patterns)
"num_layers": 6
```
Note: Deeper networks may need more training data
### 4. Multi-Head Attention

```python
# Fewer heads (simpler attention)
"num_heads": 2

# More heads (richer attention patterns)
"num_heads": 8
```

`num_heads` must divide `embed_dim` evenly, e.g. 128 ÷ 4 = 32 ✓
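
A quick sanity check you can run before building the configuration (plain Python, not a Workshop API):

```python
# Verify the head count divides the embedding dimension evenly.
embed_dim, num_heads = 128, 8
assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
print(embed_dim // num_heads)  # per-head feature size: 16
```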
## Understanding Point Cloud Applications

### 1. Autonomous Vehicles

LiDAR sensors generate point clouds for:
- Obstacle detection
- Lane tracking
- 3D scene understanding
Typical setup: 64-128 laser beams producing 100K+ points per second
### 2. Robotics

Point clouds enable:
- Object grasping (find grip points)
- Navigation (3D mapping)
- Human-robot interaction (gesture recognition)
### 3. Molecular Modeling

Proteins as point clouds:
- Each atom is a 3D point
- Backbone atoms: N, C-alpha, C, O
- Sidechains: variable number of atoms
See `protein_point_cloud_example.py` for details.
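
As a toy illustration of the representation (the coordinates below are made up, not taken from any real structure), a residue's backbone atoms become rows of an `(N_atoms, 3)` array:

```python
import numpy as np

# One residue's backbone atoms (N, C-alpha, C, O) as a tiny point cloud.
# These coordinates are illustrative only; real ones would come from a PDB file.
backbone = np.array([
    [0.00, 0.00, 0.00],  # N
    [1.46, 0.00, 0.00],  # C-alpha
    [2.05, 1.40, 0.00],  # C
    [1.35, 2.42, 0.00],  # O
])
print(backbone.shape)  # (4, 3)
```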
### 4. 3D Content Creation

Generate 3D models for:
- Video games (procedural generation)
- Movies (digital assets)
- Virtual reality (environments)
## Next Steps

- Geometric Models Overview - quick reference for point clouds, meshes, and voxels
- Geometric Losses - specialized loss functions for point clouds
- Protein Point Clouds - apply point clouds to protein structure modeling
- Geometric Benchmarks - evaluate on standard geometric datasets
## Troubleshooting

### Points look random/unstructured

Cause: The model is untrained or poorly trained.

Solutions:

- Train on a dataset first (this example uses randomly initialized weights)
- Increase `num_layers` for more capacity
- Adjust `temperature` for different sampling behavior
"FigureCanvasAgg is non-interactive"Β€
Cause: matplotlib trying to show plots in non-GUI environment
Solution: This is just a warning, plots are still saved. To suppress:
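
A minimal sketch of that approach (the backend must be set before `pyplot` is first used; where exactly this goes depends on your script):

```python
# Force the non-interactive Agg backend in headless environments and
# save figures to disk instead of calling plt.show().
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
```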
### Out of memory errors

Solutions:

```python
# Reduce the number of points
"num_points": 256  # instead of 512

# Reduce the batch size during generation
n_samples=1  # instead of 2
```

```bash
# Use CPU instead of GPU
JAX_PLATFORMS=cpu python simple_point_cloud_example.py
```
### Generated point clouds are identical

Cause: Not providing fresh random keys.

Solution: Split RNG keys for each generation:

```python
key = jax.random.PRNGKey(0)  # base key; any seed works
for i in range(n_samples):
    key, subkey = jax.random.split(key)       # fresh subkey per sample
    rngs_new = nnx.Rngs(params=subkey)
    point_cloud = model.generate(rngs=rngs_new, n_samples=1)
```
## Additional Resources

- PointNet: Deep Learning on Point Sets - Original point cloud deep learning paper
- Attention is All You Need - Transformer architecture
- Workshop Geometric Models API - Full API documentation
- JAX Point Cloud Processing - Community discussions