Protein Diffusion Example¤
Level: Advanced
Runtime: ~5 minutes
Format: Dual (.py script | .ipynb notebook)
Comprehensive protein diffusion modeling with two approaches: high-level API with extensions and direct model creation, including quality assessment and visualization.
Files¤
- Python Script: examples/generative_models/protein/protein_diffusion_example.py
- Jupyter Notebook: examples/generative_models/protein/protein_diffusion_example.ipynb
Quick Start¤
# Run the Python script
python examples/generative_models/protein/protein_diffusion_example.py
# Or open the Jupyter notebook
jupyter notebook examples/generative_models/protein/protein_diffusion_example.ipynb
Overview¤
This comprehensive example demonstrates how to build and use protein diffusion models for generating 3D protein structures. You'll learn two distinct approaches to protein modeling, understand protein-specific geometric constraints, and explore quality assessment techniques.
Learning Objectives¤
After completing this example, you will understand:
- How to create protein diffusion models with Workshop's high-level API
- Direct model creation and manipulation for protein structures
- Protein-specific loss functions and geometric constraints
- Quality assessment metrics for generated proteins
- Visualization techniques for 3D protein structures
Prerequisites¤
- Understanding of diffusion models and denoising processes
- Familiarity with protein structure representations
- Knowledge of geometric constraints in biomolecules
- Experience with JAX and Flax NNX
Theory and Key Concepts¤
Protein Structure Representation¤
Proteins are complex biomolecules composed of amino acid residues. Each residue contains multiple atoms with specific 3D coordinates:
Backbone Atoms: The main chain of every protein contains four atoms per residue:
- N (Nitrogen): Backbone nitrogen
- CA (Alpha Carbon): Central carbon atom
- C (Carbonyl Carbon): Carbonyl carbon
- O (Oxygen): Carbonyl oxygen
Representation Approaches:
- Point Cloud: Unordered set of 3D points representing atom positions
  - Advantages: Simple, flexible, good for local geometry
  - Use case: Backbone modeling, local structure refinement
- Graph: Nodes (residues/atoms) connected by edges (bonds)
  - Advantages: Captures connectivity, enforces topology
  - Use case: Full protein modeling, contact prediction
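The two representations can be sketched in a few lines. This is a minimal illustration, assuming backbone coordinates are stored as an array of shape (num_residues, 4, 3) in N, CA, C, O order (the layout used later in this example). The variable names are illustrative, and NumPy is used so the snippet runs on its own.

```python
import numpy as np

num_residues = 5
rng = np.random.default_rng(0)

# Backbone coordinates: one (x, y, z) triple per atom, atoms ordered N, CA, C, O.
backbone = rng.normal(size=(num_residues, 4, 3))

# Point-cloud view: flatten residues and atoms into one set of 3D points.
point_cloud = backbone.reshape(-1, 3)

# Graph view (residue level): nodes are CA atoms, edges link consecutive residues.
nodes = backbone[:, 1, :]
edges = np.stack([np.arange(num_residues - 1), np.arange(1, num_residues)], axis=1)

print(point_cloud.shape)  # (20, 3)
print(edges.tolist())     # [[0, 1], [1, 2], [2, 3], [3, 4]]
```

The point cloud discards the chain order, while the edge list keeps it, which is why the graph form suits connectivity-aware tasks.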
Geometric Constraints¤
Valid protein structures must satisfy strict geometric constraints:
Bond Lengths: Distance between bonded atoms must fall within specific ranges:
- C-C bonds: ~1.5 Å
- C-N bonds: ~1.3 Å for the peptide bond (the N-CA single bond is longer, ~1.46 Å)
- C=O bonds: ~1.2 Å
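As a rough sketch of how such a bond-length check could be computed, the snippet below measures three backbone bonds and reports the fraction that deviate from the approximate ideal values listed above. The function names, the 0.1 Å tolerance, and the (R, 4, 3) array layout are assumptions made for illustration, not Workshop's implementation.

```python
import numpy as np

# Approximate ideal lengths in Å, copied from the list above.
IDEAL = {"C-C": 1.5, "C-N": 1.3, "C=O": 1.2}

def bond_lengths(backbone):
    """Return selected bond lengths for an (R, 4, 3) backbone in N, CA, C, O order."""
    n, ca, c, o = (backbone[:, i, :] for i in range(4))
    return {
        "C-C": np.linalg.norm(c - ca, axis=-1),          # CA-C within each residue
        "C-N": np.linalg.norm(n[1:] - c[:-1], axis=-1),  # peptide bond to the next residue
        "C=O": np.linalg.norm(o - c, axis=-1),           # carbonyl within each residue
    }

def violation_fraction(backbone, tolerance=0.1):
    """Fraction of checked bonds deviating from ideal by more than `tolerance` Å."""
    lengths = bond_lengths(backbone)
    bad = np.concatenate([np.abs(lengths[k] - IDEAL[k]) > tolerance for k in IDEAL])
    return float(bad.mean())

# Toy two-residue chain laid out along x so that every checked bond is ideal
# (not a physically realistic conformation).
chain = np.array([
    [[0.00, 0, 0], [1.46, 0, 0], [2.96, 0, 0], [2.96, 1.2, 0]],
    [[4.26, 0, 0], [5.72, 0, 0], [7.22, 0, 0], [7.22, 1.2, 0]],
])
print(violation_fraction(chain))  # 0.0

stretched = chain.copy()
stretched[1, 3, 1] = 2.0  # stretch the second C=O bond to 2.0 Å
print(violation_fraction(stretched))  # 0.2
```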
Bond Angles: Angles between consecutive bonds follow specific distributions:
- Tetrahedral angles: ~109.5°
- Planar peptide bonds: ~120°
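The two reference angles can be reproduced from coordinates with a dot product. The helper name `bond_angle` is illustrative.

```python
import numpy as np

def bond_angle(a, b, c):
    """Angle in degrees at vertex b for the atom triple a-b-c."""
    u, v = a - b, c - b
    cos_theta = np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v))
    # Clip guards against values slightly outside [-1, 1] from rounding.
    return float(np.degrees(np.arccos(np.clip(cos_theta, -1.0, 1.0))))

centre = np.zeros(3)

# Two vertices of a regular tetrahedron centred on the origin.
tetrahedral = bond_angle(np.array([1.0, 1.0, 1.0]), centre, np.array([1.0, -1.0, -1.0]))

# Two of three directions spaced evenly in a plane.
planar = bond_angle(np.array([1.0, 0.0, 0.0]), centre, np.array([-0.5, np.sqrt(3) / 2, 0.0]))

print(round(tetrahedral, 1), round(planar, 1))  # 109.5 120.0
```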
Dihedral Angles: Rotation around bonds defines protein conformation:
- Phi (φ): Rotation around N-CA bond
- Psi (ψ): Rotation around CA-C bond
- Ramachandran plot: Shows allowed (φ, ψ) combinations
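A minimal sketch of computing these angles from backbone coordinates follows. It uses the standard definitions, φ over the atoms C(i-1), N(i), CA(i), C(i) and ψ over N(i), CA(i), C(i), N(i+1). The function names and the (R, 4, 3) layout in N, CA, C, O order are assumptions for illustration and are independent of `ProteinVisualizer`.

```python
import numpy as np

def dihedral(p0, p1, p2, p3):
    """Signed dihedral angle in degrees defined by four points."""
    b0 = p0 - p1
    b1 = p2 - p1
    b2 = p3 - p2
    b1 = b1 / np.linalg.norm(b1)
    # Project b0 and b2 onto the plane perpendicular to the central bond b1.
    v = b0 - np.dot(b0, b1) * b1
    w = b2 - np.dot(b2, b1) * b1
    x = np.dot(v, w)
    y = np.dot(np.cross(b1, v), w)
    return float(np.degrees(np.arctan2(y, x)))

def phi_psi(backbone):
    """Return (phi, psi) arrays in degrees for the interior residues of a chain."""
    n, ca, c = backbone[:, 0], backbone[:, 1], backbone[:, 2]
    interior = range(1, len(backbone) - 1)
    phi = np.array([dihedral(c[i - 1], n[i], ca[i], c[i]) for i in interior])
    psi = np.array([dihedral(n[i], ca[i], c[i], n[i + 1]) for i in interior])
    return phi, psi

# Sanity check on a right-angle arrangement.
quarter_turn = dihedral(
    np.array([1.0, 0.0, 0.0]),
    np.array([0.0, 0.0, 0.0]),
    np.array([0.0, 0.0, 1.0]),
    np.array([0.0, 1.0, 1.0]),
)
print(quarter_turn)  # 90.0

# Random coordinates stand in for a real structure here.
phi, psi = phi_psi(np.random.default_rng(0).normal(size=(6, 4, 3)))
print(phi.shape, psi.shape)  # (4,) (4,)
```

The first and last residues are skipped because φ needs the preceding residue and ψ needs the following one.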
Protein-Specific Loss Functions¤
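RMSD (Root Mean Square Deviation): Measures structural similarity between predicted and target structures:
\[\text{RMSD} = \sqrt{\frac{1}{N} \sum_{i=1}^{N} \lVert x_i - y_i \rVert^2}\]
where \(N\) is the number of atoms, \(x_i\) is the predicted position, and \(y_i\) is the target position.
Backbone Loss: Enforces correct backbone geometry:
\[\mathcal{L}_{\text{backbone}} = \sum_{i} \left( d(x_i, x_{i+1}) - d_{\text{ideal}} \right)^2\]
where \(d(x_i, x_{i+1})\) is the distance between consecutive residues and \(d_{\text{ideal}}\) is the ideal distance.
Composite Loss: Combines multiple geometric constraints as a weighted sum:
\[\mathcal{L}_{\text{total}} = \sum_{k} w_k \, \mathcal{L}_k\]
where \(\mathcal{L}_k\) are the individual loss terms (for example RMSD and backbone) and \(w_k\) are their weights.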
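The RMSD and backbone terms described above can be sketched as follows. This is an illustration under stated assumptions, not the library's `create_rmsd_loss` or `create_backbone_loss`: the RMSD is computed without superimposing the structures first, the backbone term is a plain sum over consecutive CA-CA distances, and the default `d_ideal` of 3.8 Å is the typical CA-CA spacing for a trans peptide.

```python
import numpy as np

def rmsd(pred, target):
    """RMSD between two (N, 3) coordinate arrays, without prior superposition."""
    return float(np.sqrt(np.mean(np.sum((pred - target) ** 2, axis=-1))))

def backbone_loss(ca_positions, d_ideal=3.8):
    """Sum of squared deviations of consecutive CA-CA distances from d_ideal (Å)."""
    d = np.linalg.norm(ca_positions[1:] - ca_positions[:-1], axis=-1)
    return float(np.sum((d - d_ideal) ** 2))

target = np.zeros((4, 3))
pred = target + np.array([1.0, 0.0, 0.0])  # every atom displaced by 1 Å along x
print(rmsd(pred, target))  # 1.0

# CA atoms spaced 4.8 Å apart: each of the 3 gaps is 1 Å too long.
ca = np.array([[4.8 * i, 0.0, 0.0] for i in range(4)])
print(round(backbone_loss(ca), 6))  # 3.0
```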
Code Walkthrough¤
Part 1: High-Level API with Extensions¤
The example demonstrates using Workshop's extension system for protein modeling:
# Create model with extensions
extension_config = {
    "name": "protein_diffusion_extensions",
    "description": "Extensions for protein diffusion model",
    "enabled": True,
    "use_backbone_constraints": True,
    "use_protein_mixin": True,
}

extensions = create_protein_extensions(extension_config, rngs=rngs)
model = nnx.Module()
model.extensions = extensions
The extension system provides:
- Backbone Constraints: Automatic enforcement of backbone geometry
- Protein Mixin: Domain-specific operations for proteins
- Quality Assessment: Built-in metrics for structure validation
Part 2: Direct Model Creation¤
For full control, create models directly:
from workshop.generative_models.core.configuration import ModelConfiguration
# Import path taken from the model_class string below
from workshop.generative_models.models.protein.point_cloud import ProteinPointCloudModel

num_residues = 64

config = ModelConfiguration(
    name="protein_point_cloud_model",
    model_class="workshop.generative_models.models.protein.point_cloud.ProteinPointCloudModel",
    input_dim=(num_residues, 4, 3),  # 4 backbone atoms, 3D coordinates
    hidden_dims=[128] * 4,
    parameters={
        "num_residues": num_residues,
        "num_atoms_per_residue": 4,
        "backbone_indices": [0, 1, 2, 3],  # N, CA, C, O
        "embed_dim": 128,
        "use_constraints": True,
        "constraint_config": {
            "backbone_weight": 1.0,
            "bond_weight": 1.0,
            "angle_weight": 0.5,
        },
    },
)

model = ProteinPointCloudModel(config, rngs=rngs)
Dataset Preparation¤
Load synthetic or real protein datasets:
# Create synthetic dataset for demonstration
dataset = create_synthetic_protein_dataset(
    num_proteins=50,
    min_seq_length=32,
    max_seq_length=64,
    random_seed=42,
)
# Prepare batch
batch = prepare_batch(dataset, batch_size=8, random_seed=42)
# Add noise for diffusion training
noisy_batch = add_noise_to_batch(batch, noise_level=0.1, random_seed=42)
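The internals of `add_noise_to_batch` are not shown here. As a hedged sketch of the general idea, the hypothetical helper below adds zero-mean Gaussian noise with standard deviation `noise_level` to real atoms only and leaves padded positions untouched. The function name, array shapes, and mask convention are assumptions.

```python
import numpy as np

def add_gaussian_noise(positions, mask, noise_level=0.1, seed=42):
    """Add zero-mean Gaussian noise (std = noise_level) to unmasked atoms only."""
    rng = np.random.default_rng(seed)
    noise = rng.normal(scale=noise_level, size=positions.shape)
    # The mask has one entry per atom; broadcast it over the xyz axis.
    return positions + noise * mask[..., None]

positions = np.zeros((2, 4, 4, 3))  # (batch, residues, atoms, xyz)
mask = np.ones((2, 4, 4))
mask[1, 2:] = 0.0                   # second protein has only two real residues

noisy = add_gaussian_noise(positions, mask)
print(np.all(noisy[1, 2:] == 0.0))  # True: padded atoms are unchanged
```

Fixing the seed makes the corruption reproducible, which mirrors the `random_seed` argument used in the calls above.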
Loss Function Configuration¤
Combine multiple protein-specific losses:
from workshop.generative_models.modalities.protein.losses import (
    CompositeLoss,
    create_backbone_loss,
    create_rmsd_loss,
)

loss_fn = CompositeLoss({
    "rmsd": (create_rmsd_loss(), 1.0),  # Weight: 1.0
    "backbone": (create_backbone_loss(), 0.5),  # Weight: 0.5
})
# Calculate losses
outputs = model(noisy_batch)
losses = loss_fn(batch, outputs)
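To make the weighting concrete, here is a minimal stand-in for a composite loss that takes the same name-to-(function, weight) mapping. `WeightedSumLoss` is a hypothetical class written for this illustration. Whether Workshop's `CompositeLoss` reports weighted or unweighted per-term values is not shown here, so this sketch simply assumes unweighted terms plus a weighted total.

```python
class WeightedSumLoss:
    """Hypothetical stand-in: maps name -> (loss_fn, weight), returns terms and total."""

    def __init__(self, terms):
        self.terms = terms

    def __call__(self, batch, outputs):
        losses = {}
        total = 0.0
        for name, (fn, weight) in self.terms.items():
            losses[name] = fn(batch, outputs)   # unweighted per-term value
            total += weight * losses[name]      # weighted contribution to the total
        losses["total"] = total
        return losses

# Constant toy losses make the arithmetic easy to follow.
demo_loss = WeightedSumLoss({
    "rmsd": (lambda batch, outputs: 2.0, 1.0),
    "backbone": (lambda batch, outputs: 4.0, 0.5),
})
print(demo_loss({}, {}))  # {'rmsd': 2.0, 'backbone': 4.0, 'total': 4.0}
```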
Visualization and Quality Assessment¤
Visualize generated structures and assess quality:
from workshop.visualization.protein_viz import ProteinVisualizer
# Extract positions
target_pos = batch["atom_positions"][0]
pred_pos = outputs["positions"][0]
mask = batch["atom_mask"][0]
# Calculate dihedral angles
target_phi, target_psi = ProteinVisualizer.calculate_dihedral_angles(target_pos, mask)
pred_phi, pred_psi = ProteinVisualizer.calculate_dihedral_angles(pred_pos, mask)
# Plot Ramachandran plots
ProteinVisualizer.plot_ramachandran(target_phi, target_psi, title="Target")
ProteinVisualizer.plot_ramachandran(pred_phi, pred_psi, title="Predicted")
# 3D visualization (requires py3Dmol)
viewer = ProteinVisualizer.visualize_structure(
    pred_pos,
    mask,
    show_sidechains=False,
    color_by="chain",
)
viewer.show()
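When no plotting backend is available, the information in a Ramachandran plot can still be summarized as a 2D histogram of (φ, ψ) pairs. The helper below is an illustrative sketch that does not depend on `ProteinVisualizer`. The sample angles are made-up values placed near the commonly cited α-helix region (around φ = -57°, ψ = -47°) and β-sheet region (around φ = -120°, ψ = 130°).

```python
import numpy as np

def ramachandran_histogram(phi, psi, bins=6):
    """Count (phi, psi) pairs on a bins x bins grid covering [-180, 180] degrees."""
    edges = np.linspace(-180.0, 180.0, bins + 1)
    counts, _, _ = np.histogram2d(phi, psi, bins=[edges, edges])
    return counts.astype(int)

phi = np.array([-57.0, -50.0, -115.0])
psi = np.array([-47.0, -40.0, 130.0])

counts = ramachandran_histogram(phi, psi)
print(counts[2, 2])  # 2 -> helix-like pairs in the [-60, 0) x [-60, 0) cell
print(counts[1, 5])  # 1 -> sheet-like pair in the [-120, -60) x [120, 180] cell
```

Comparing the target and predicted histograms gives a quick numeric check of whether generated structures populate the same regions.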
Expected Output¤
The example runs both approaches and prints output similar to the following (numeric values are illustrative):
=== Protein Diffusion Examples ===
This example demonstrates two approaches to protein diffusion:
1. High-level API with extension components
2. Direct model creation and manipulation
=== Running Extensions Example ===
Model structure:
- Type: Module
- Extensions: ['bond_length', 'bond_angle', 'protein_mixin']
Generated 2 protein samples
- Sample shape: (2, 64, 4, 3)
- Atom mask shape: (2, 64, 4)
Quality metrics:
- rmsd: 1.2345
- bond_violations: 0.0234
- angle_violations: 0.0156
=== Running Direct Model Example ===
Creating model...
Loading dataset...
Preparing batch...
Adding noise to batch...
Creating loss function...
Running model...
Calculating losses...
Losses:
rmsd: 0.1234
backbone: 0.0567
total: 0.1801
Displaying results...
The example also generates:
- 2D plots of protein structures
- Ramachandran plots showing dihedral angle distributions
- 3D interactive visualizations (if py3Dmol is installed)
Experiments to Try¤
- Compare Model Types: Test point cloud vs graph representations

  point_cloud_model = create_protein_diffusion_model(model_type="point_cloud")
  graph_model = create_protein_diffusion_model(model_type="graph")

- Adjust Constraint Weights: Balance different geometric constraints

  constraint_config = {
      "backbone_weight": 2.0,  # Emphasize backbone connectivity
      "bond_weight": 1.5,  # Strong bond length enforcement
      "angle_weight": 1.0,  # Moderate angle constraints
      "dihedral_weight": 0.5,  # Soft dihedral constraints
  }

- Larger Proteins: Scale to longer sequences

  model = create_protein_diffusion_model(
      num_residues=128,  # Double the default size
      hidden_dim=256,  # Increase capacity
  )

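- Custom Loss Functions: Create domain-specific losses

  import jax
  import jax.numpy as jnp

  def create_contact_loss(threshold=8.0):
      """Enforce protein contact map constraints (sketch: CA-CA contacts, mask ignored)."""
      def ca_distances(positions):
          ca = positions[..., 1, :]  # CA is atom index 1 in (N, CA, C, O)
          diff = ca[..., :, None, :] - ca[..., None, :, :]
          return jnp.sqrt(jnp.sum(diff**2, axis=-1) + 1e-8)
      def loss_fn(batch, outputs):
          # Hard target contacts vs. a soft, differentiable predicted contact map
          target = (ca_distances(batch["atom_positions"]) < threshold).astype(jnp.float32)
          pred = jax.nn.sigmoid(threshold - ca_distances(outputs["positions"]))
          return jnp.mean((pred - target) ** 2)
      return loss_fn

  loss_fn = CompositeLoss({
      "rmsd": (create_rmsd_loss(), 1.0),
      "backbone": (create_backbone_loss(), 0.5),
      "contact": (create_contact_loss(), 0.3),
  })
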
- Real Datasets: Load actual protein structures
Troubleshooting¤
Size Mismatch Warnings¤
If you see "Target size doesn't match prediction size":
- Check that num_residues matches between model and data
- Ensure batch collation handles variable-length sequences
- Use masking to handle different protein lengths
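One way to handle different protein lengths is to pad every structure to a common length and carry a mask that marks real atoms. The sketch below is illustrative. The helper names and the (B, R, 4, 3) batch layout are assumptions, not Workshop's collation code.

```python
import numpy as np

def pad_proteins(proteins, max_len):
    """Pad (R_i, 4, 3) arrays to (B, max_len, 4, 3) and build a (B, max_len, 4) mask."""
    batch = np.zeros((len(proteins), max_len, 4, 3))
    mask = np.zeros((len(proteins), max_len, 4))
    for i, protein in enumerate(proteins):
        batch[i, : len(protein)] = protein
        mask[i, : len(protein)] = 1.0
    return batch, mask

def masked_rmsd(pred, target, mask):
    """RMSD over real atoms only; padded positions do not contribute."""
    squared = np.sum((pred - target) ** 2, axis=-1) * mask
    return float(np.sqrt(squared.sum() / mask.sum()))

proteins = [np.ones((2, 4, 3)), np.ones((3, 4, 3))]  # lengths 2 and 3
target, mask = pad_proteins(proteins, max_len=4)

pred = target.copy()
pred[..., 0] += 1.0  # shift every atom, padding included, by 1 Å along x

print(target.shape)                     # (2, 4, 4, 3)
print(masked_rmsd(pred, target, mask))  # 1.0
```

Because the mask zeroes out padded atoms before averaging, the result depends only on the 20 real atoms, regardless of how much padding is added.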
Geometric Constraint Violations¤
If structures have high constraint violations:
- Increase constraint weights in constraint_config
- Add more training epochs for constraint satisfaction
- Use smaller noise levels during training
Visualization Issues¤
If 3D visualization fails:
- Install py3Dmol: pip install py3Dmol
- For Jupyter notebooks, ensure proper widget support
- Fall back to 2D plots if 3D is unavailable
Memory Issues¤
For large proteins:
- Reduce batch size
- Use gradient checkpointing
- Process proteins in chunks
Next Steps¤
- Advanced Protein Models: Explore AlphaFold-style architectures and multi-scale modeling
- Point Cloud Models: Learn specialized techniques for point cloud protein representations
- Diffusion Training: Master advanced diffusion techniques for proteins
- Protein Benchmarks: Evaluate protein models with standard benchmarks
Additional Resources¤
- Protein Data Bank (PDB) - Repository of 3D protein structures
- AlphaFold Documentation - State-of-the-art protein structure prediction
- Diffusion Models for Proteins - Research paper on protein diffusion
- Workshop Protein Modeling Guide - Comprehensive guide
- Ramachandran Plot - Understanding dihedral angles