Protein Model Extensions Example¤
This example demonstrates how to use protein-specific extensions with Workshop's geometric model framework, combining domain knowledge with general-purpose geometric models.
Files¤
- Python Script: protein_model_extension.py
- Jupyter Notebook: protein_model_extension.ipynb
Quick Start¤
# Run the Python script
source activate.sh
python examples/generative_models/protein/protein_model_extension.py
# Or use Jupyter notebook
jupyter notebook examples/generative_models/protein/protein_model_extension.ipynb
Overview¤
This example shows how to enhance a point cloud model with protein-specific extensions. Extensions add domain knowledge about protein structure, chemistry, and geometry to improve model predictions and learning.
Learning Objectives¤
- Understand the extension system in Workshop
- Learn how to create and configure protein-specific extensions
- See how extensions enhance model outputs with domain knowledge
- Understand extension contribution to loss calculations
- Learn to combine multiple extensions for complex domain knowledge
Prerequisites¤
- Basic understanding of protein structure (residues, backbone atoms)
- Familiarity with point cloud models
- Knowledge of Flax NNX modules
- Understanding of JAX random number generation
Background: Protein Structure and Extensions¤
Protein Basics¤
Proteins are polymers composed of amino acid residues connected by peptide bonds. The backbone consists of repeating units with four key atoms per residue:
- N (Nitrogen): Backbone nitrogen
- Cα (Alpha Carbon): Central carbon with side chain
- C (Carbonyl Carbon): Carbon of carbonyl group
- O (Oxygen): Carbonyl oxygen
Geometric Constraints¤
Protein structures follow specific geometric constraints due to chemical bonding:
- Bond Lengths: Relatively fixed distances between bonded atoms
- Bond Angles: Preferred angles between consecutive bonds
- Torsion Angles: Backbone flexibility through φ (phi) and ψ (psi) angles
These constraints make protein structure prediction a constrained optimization problem, which extensions help enforce.
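The three quantities above can be computed directly from atom coordinates. The following is a minimal sketch in plain Python (standard library only, so it runs without JAX); the function names are illustrative and are not part of Workshop's API.

```python
import math

def _sub(a, b):
    return tuple(x - y for x, y in zip(a, b))

def _dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def _cross(a, b):
    return (a[1] * b[2] - a[2] * b[1],
            a[2] * b[0] - a[0] * b[2],
            a[0] * b[1] - a[1] * b[0])

def bond_length(p, q):
    """Distance between two bonded atoms."""
    return math.dist(p, q)

def bond_angle(p, q, r):
    """Angle p-q-r in degrees, measured at the middle atom q."""
    u, v = _sub(p, q), _sub(r, q)
    cos_theta = _dot(u, v) / (math.hypot(*u) * math.hypot(*v))
    # Clamp to [-1, 1] to guard against floating-point overshoot.
    return math.degrees(math.acos(max(-1.0, min(1.0, cos_theta))))

def torsion_angle(p0, p1, p2, p3):
    """Signed dihedral angle in degrees around the p1-p2 bond."""
    b0, b1, b2 = _sub(p0, p1), _sub(p2, p1), _sub(p3, p2)
    n = math.hypot(*b1)
    axis = tuple(x / n for x in b1)
    # Project b0 and b2 onto the plane perpendicular to the central bond.
    v = _sub(b0, tuple(_dot(b0, axis) * x for x in axis))
    w = _sub(b2, tuple(_dot(b2, axis) * x for x in axis))
    return math.degrees(math.atan2(_dot(_cross(axis, v), w), _dot(v, w)))

# Toy coordinates chosen for easy checking, not taken from a real protein.
length = bond_length((0.0, 0.0, 0.0), (3.0, 4.0, 0.0))
angle = bond_angle((1.0, 0.0, 0.0), (0.0, 0.0, 0.0), (0.0, 1.0, 0.0))
cis = torsion_angle((1.0, 0.0, 0.0), (0.0, 0.0, 0.0), (0.0, 0.0, 1.0), (1.0, 0.0, 1.0))
trans = torsion_angle((1.0, 0.0, 0.0), (0.0, 0.0, 0.0), (0.0, 0.0, 1.0), (-1.0, 0.0, 1.0))
```

For the backbone, φ would be the torsion over C(i-1), N(i), Cα(i), C(i) and ψ the torsion over N(i), Cα(i), C(i), N(i+1).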
Extension System¤
Workshop's extension system allows adding domain-specific knowledge to base models. For proteins, we have:
- Protein Mixin Extension: Integrates amino acid type information
- Protein Constraints Extension: Enforces backbone geometry
- Bond Length Extension: Monitors and penalizes bond violations
- Bond Angle Extension: Monitors and penalizes angle violations
Code Walkthrough¤
1. Import Required Modules¤
import jax
import jax.numpy as jnp
from flax import nnx
from workshop.generative_models.core.configuration import ModelConfiguration
from workshop.generative_models.extensions.base.extensions import ExtensionConfig
from workshop.generative_models.extensions.protein import (
    BondAngleExtension,
    BondLengthExtension,
    ProteinMixinExtension,
)
from workshop.generative_models.extensions.protein.constraints import (
    ProteinBackboneConstraint,
)
from workshop.generative_models.models.geometric.point_cloud import (
    PointCloudModel,
)
We import:
- JAX for array operations and random number generation
- Flax NNX for neural network modules
- Workshop configuration and extension classes
- Protein-specific extensions
- Point cloud model
2. Configure Protein Structure¤
num_residues = 10
atoms_per_residue = 4  # N, CA, C, O
num_points = num_residues * atoms_per_residue
embedding_dim = 64
model_config = ModelConfiguration(
    name="protein_point_cloud",
    model_class="PointCloudModel",
    input_dim=num_points,
    output_dim=num_points,
    hidden_dims=[embedding_dim] * 2,
    parameters={
        "num_points": num_points,
        "embed_dim": embedding_dim,
        "num_layers": 2,
        "num_heads": 4,
    },
    dropout_rate=0.1,
)
Key configuration points:
- num_points: Must be specified to override default (1024)
- hidden_dims: List of hidden dimensions for each layer
- parameters: Additional model-specific parameters
- dropout_rate: Regularization during training
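Because num_points is the product of residues and atoms per residue, each point index corresponds to one backbone atom of one residue. The sketch below assumes a residue-major layout (all four atoms of residue 0, then residue 1, and so on) in the N, CA, C, O order given by the config comment; the helper names are hypothetical and the actual layout used by the model may differ.

```python
ATOM_NAMES = ("N", "CA", "C", "O")  # order assumed from the config comment

def point_index(residue, atom, atoms_per_residue=4):
    """Flat point index of a named backbone atom (residue-major layout assumed)."""
    return residue * atoms_per_residue + ATOM_NAMES.index(atom)

def residue_and_atom(index, atoms_per_residue=4):
    """Inverse mapping: flat point index back to (residue, atom name)."""
    residue, offset = divmod(index, atoms_per_residue)
    return residue, ATOM_NAMES[offset]

num_residues = 10
num_points = num_residues * len(ATOM_NAMES)
last_oxygen = point_index(num_residues - 1, "O")
second_residue_ca = residue_and_atom(5)
```

Under this layout, the "backbone_indices": [0, 1, 2, 3] setting used by the extensions would refer to the per-residue offsets of N, CA, C, and O.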
3. Create Protein-Specific Extensions¤
3.1 Protein Mixin Extension¤
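# One parent key is split into per-extension keys (the seed value is illustrative)
key = jax.random.PRNGKey(0)
key, mixin_key, constraint_key, backbone_key = jax.random.split(key, 4)
# Plain dict that collects the extensions; it is wrapped in nnx.Dict in step 4
extensions_dict = {}
mixin_config = ExtensionConfig(
    name="protein_mixin",
    weight=1.0,
    enabled=True,
    extensions={
        "embedding_dim": embedding_dim,
        "num_aa_types": 20,
    },
)
extensions_dict["protein_mixin"] = ProteinMixinExtension(
    config=mixin_config,
    rngs=nnx.Rngs(params=mixin_key),
)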
The mixin extension learns embeddings for all 20 standard amino acid types, allowing the model to incorporate sequence information.
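To illustrate what "learning embeddings for amino acid types" means, here is a minimal plain-Python sketch of an embedding lookup. It is not the ProteinMixinExtension implementation: the residue ordering in AA_ORDER, the helper names, and the random initialization are all assumptions made for the example.

```python
import random

# Assumed ordering: the 20 standard amino acids sorted by one-letter code.
# The index convention used by the real extension may differ.
AA_ORDER = "ACDEFGHIKLMNPQRSTVWY"

def encode_sequence(sequence):
    """Map one-letter residue codes to integer types in the range 0-19."""
    return [AA_ORDER.index(aa) for aa in sequence]

def make_embedding_table(num_aa_types=20, embedding_dim=64, seed=0):
    """Randomly initialized table; in a trained model these rows are learned."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 0.02) for _ in range(embedding_dim)]
            for _ in range(num_aa_types)]

def embed(aatype, table):
    """Look up one embedding vector per residue."""
    return [table[index] for index in aatype]

aatype = encode_sequence("ACDA")
table = make_embedding_table()
embedded = embed(aatype, table)
```

Residues of the same type share one row of the table, which is how sequence identity becomes a feature the model can condition on.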
3.2 Protein Constraints Extension¤
constraint_config = ExtensionConfig(
    name="protein_constraints",
    weight=1.0,
    enabled=True,
    extensions={
        "num_residues": num_residues,
        "backbone_indices": [0, 1, 2, 3],
    },
)
extensions_dict["protein_constraints"] = ProteinBackboneConstraint(
    config=constraint_config,
    rngs=nnx.Rngs(params=constraint_key),
)
This extension enforces geometric constraints on backbone atoms during generation.
3.3 Bond Length Extension¤
bond_length_config = ExtensionConfig(
    name="bond_length",
    weight=1.0,
    enabled=True,
    extensions={
        "num_residues": num_residues,
        "backbone_indices": [0, 1, 2, 3],
    },
)
extensions_dict["bond_length"] = BondLengthExtension(
    config=bond_length_config,
    rngs=nnx.Rngs(params=constraint_key),
)
Monitors bond lengths and calculates violations for use in the loss function.
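As a rough illustration of what a bond length violation term can look like, the sketch below compares backbone bond lengths against commonly cited ideal values (approximate, in Angstroms) and penalizes deviations beyond a tolerance. The tolerance, the squared-excess form, and all function names are assumptions; BondLengthExtension may compute its loss differently.

```python
import math

# Approximate ideal backbone bond lengths in Angstroms (commonly cited values).
IDEAL_LENGTHS = {("N", "CA"): 1.458, ("CA", "C"): 1.525, ("C", "O"): 1.231}
PEPTIDE_BOND = 1.329  # C(i) to N(i+1)

def bond_length_penalty(residues, tolerance=0.05):
    """Mean squared deviation beyond `tolerance` over all backbone bonds.

    `residues` is a list of dicts mapping atom name to an (x, y, z) tuple.
    """
    excess = []
    for residue in residues:
        for (a, b), ideal in IDEAL_LENGTHS.items():
            deviation = abs(math.dist(residue[a], residue[b]) - ideal)
            excess.append(max(0.0, deviation - tolerance) ** 2)
    # Peptide bonds connect consecutive residues.
    for prev, nxt in zip(residues, residues[1:]):
        deviation = abs(math.dist(prev["C"], nxt["N"]) - PEPTIDE_BOND)
        excess.append(max(0.0, deviation - tolerance) ** 2)
    return sum(excess) / len(excess)

# Toy residue with ideal bond lengths (angles are not realistic here).
ideal_residue = {
    "N": (0.0, 0.0, 0.0),
    "CA": (1.458, 0.0, 0.0),
    "C": (1.458, 1.525, 0.0),
    "O": (1.458 + 1.231, 1.525, 0.0),
}
# Same residue with the C-O bond stretched by 0.25 Angstroms.
stretched_residue = dict(ideal_residue, O=(1.458 + 1.481, 1.525, 0.0))

ideal_penalty = bond_length_penalty([ideal_residue])
stretched_penalty = bond_length_penalty([stretched_residue])
```

The tolerance keeps small, physically reasonable fluctuations from contributing to the loss, so only clear violations are penalized.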
3.4 Bond Angle Extension¤
bond_angle_config = ExtensionConfig(
    name="bond_angle",
    weight=0.5,  # Lower weight - angles are more flexible
    enabled=True,
    extensions={
        "num_residues": num_residues,
        "backbone_indices": [0, 1, 2, 3],
    },
)
extensions_dict["bond_angle"] = BondAngleExtension(
    config=bond_angle_config,
    rngs=nnx.Rngs(params=backbone_key),
)
Note the lower weight (0.5) compared to bond lengths, reflecting the fact that bond angles are more flexible than bond lengths in real proteins.
4. Wrap Extensions in nnx.Dict¤
Flax NNX 0.12.0+ requires extensions to be wrapped in nnx.Dict for proper parameter tracking and serialization.
5. Create Model with Extensions¤
The point cloud model now has access to all four extensions and will use them during forward passes and loss calculation.
6. Create Test Batch¤
batch = {
    "aatype": aatype,  # Shape: (batch_size, num_residues)
    "positions": coords,  # Shape: (batch_size, num_points, 3)
    "mask": mask,  # Shape: (batch_size, num_points)
}
The batch contains:
- aatype: Amino acid types (integers 0-19 for 20 amino acids)
- positions: 3D coordinates of all atoms
- mask: Binary mask indicating valid atoms
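The batch above refers to aatype, coords, and mask without showing how they are built. One possible way to create random test inputs with the documented shapes is sketched below; the seed, the use of Gaussian coordinates, and the all-ones mask are assumptions for illustration, not necessarily what the example script does.

```python
import jax
import jax.numpy as jnp

batch_size = 2
num_residues = 10
atoms_per_residue = 4  # N, CA, C, O
num_points = num_residues * atoms_per_residue

# Separate keys for the two random tensors (seed value is arbitrary).
key = jax.random.PRNGKey(0)
key, aa_key, coord_key = jax.random.split(key, 3)

# Integer amino acid types in [0, 20); maxval is exclusive.
aatype = jax.random.randint(aa_key, (batch_size, num_residues), 0, 20)
# Random 3D coordinates; real inputs would be atom positions in Angstroms.
coords = jax.random.normal(coord_key, (batch_size, num_points, 3))
# All atoms marked as valid.
mask = jnp.ones((batch_size, num_points))

batch = {"aatype": aatype, "positions": coords, "mask": mask}
```

Random coordinates will not satisfy any bond constraints, so non-zero extension losses are expected for a batch like this.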
7. Forward Pass with Extensions¤
During the forward pass:
- Model processes input through transformer layers
- Each enabled extension runs on the intermediate representations
- Extension outputs are collected and returned alongside main output
Extension outputs might include:
- Amino acid embeddings (from mixin)
- Constraint violation metrics
- Bond statistics
- Angle statistics
8. Calculate Loss with Extensions¤
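The loss function combines the main reconstruction loss with the weighted extension losses:
\[\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{MSE}} + \sum_i w_i \, \mathcal{L}_{\text{ext}_i}\]
Where: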
- \(\mathcal{L}_{\text{MSE}}\): Main reconstruction loss
- \(w_i\): Extension weight
- \(\mathcal{L}_{\text{ext}_i}\): Extension-specific loss
This multi-objective loss encourages the model to:
- Reconstruct input positions accurately
- Respect bond length constraints
- Maintain realistic bond angles
- Utilize amino acid type information
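The weighted combination of a reconstruction loss and per-extension losses can be sketched in a few lines of plain Python. The function name, the handling of disabled extensions, and the numeric values are illustrative assumptions; they are not the values produced by the example script.

```python
def combine_losses(mse_loss, extension_losses, weights, enabled=None):
    """Return mse_loss plus the weighted sum of enabled extension losses.

    extension_losses and weights are dicts keyed by extension name.
    Extensions missing from `enabled` are treated as enabled.
    """
    enabled = enabled or {}
    total = mse_loss
    for name, loss in extension_losses.items():
        if enabled.get(name, True):
            total += weights.get(name, 1.0) * loss
    return total

# Illustrative numbers only.
extension_losses = {"bond_length": 2.0, "bond_angle": 1.0}
weights = {"bond_length": 1.0, "bond_angle": 0.5}

total = combine_losses(10.0, extension_losses, weights)
without_angles = combine_losses(
    10.0, extension_losses, weights, enabled={"bond_angle": False}
)
```

Setting a weight to zero or disabling an extension removes its term from the total, which is a convenient way to measure how much each constraint contributes.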
Expected Output¤
Created extensions: protein_mixin, protein_constraints, bond_length, bond_angle
Created model: PointCloudModel
Model outputs:
- Main output shape: (2, 40, 3)
- Extension outputs:
- protein_mixin
- protein_constraints
- bond_length
- bond_angle
Loss calculation:
- Available loss keys: ['total_loss', 'mse_loss', 'protein_mixin', 'protein_constraints', 'bond_length', 'bond_angle']
- Total loss: 89.77
- total_loss: 89.77
- mse_loss: 87.01
Protein model extension demo completed successfully!
Key Concepts¤
Extension Configuration¤
Extensions use ExtensionConfig with:
- name: Identifier for the extension
- weight: Multiplier applied to the extension's loss term in the total loss (typically between 0 and 1, though larger values are allowed)
- enabled: Whether extension is active
- extensions: Extension-specific parameters dict
Extension Weights¤
Weights control the relative importance of different constraints:
- Bond lengths: Weight 1.0 (strict constraint)
- Bond angles: Weight 0.5 (more flexible)
- Constraints: Weight 1.0 (enforce geometry)
- Mixin: Weight 1.0 (sequence information)
Adjust weights based on your application's priorities.
Flax NNX 0.12.0+ Compatibility¤
Always wrap extension dictionaries in nnx.Dict:
# CORRECT
extensions = nnx.Dict(extensions_dict)
# WRONG (will fail in NNX 0.12.0+)
extensions = extensions_dict
Random Number Generation¤
Each extension receives its own RNG for parameter initialization:
key, mixin_key, constraint_key, backbone_key = jax.random.split(key, 4)
extensions_dict["protein_mixin"] = ProteinMixinExtension(
    config=mixin_config,
    rngs=nnx.Rngs(params=mixin_key),  # Separate key
)
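Extensions that receive distinct keys are initialized independently. Note that the walkthrough above passes constraint_key to both the constraints and bond length extensions, so those two are initialized from the same key; split an additional key from the parent if they should be independent.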
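The effect of key splitting can be checked with jax.random alone, without any Workshop classes. In this small sketch the seed and the sample shapes are arbitrary.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
key, mixin_key, constraint_key, backbone_key = jax.random.split(key, 4)

# Different subkeys give different samples.
from_mixin = jax.random.normal(mixin_key, (3,))
from_constraint = jax.random.normal(constraint_key, (3,))

# Reusing the same subkey reproduces exactly the same samples.
from_mixin_again = jax.random.normal(mixin_key, (3,))

keys_differ = not bool(jnp.allclose(from_mixin, from_constraint))
reuse_repeats = bool(jnp.array_equal(from_mixin, from_mixin_again))
```

This is why two modules that are given the same key would start from identical random draws wherever their parameter shapes match.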
Experiments to Try¤
- Adjust Extension Weights
- Disable Specific Extensions
- Increase Protein Size
- Add Custom Extensions: Create your own extension for other properties (e.g., secondary structure, hydrophobicity).
- Visualize Bond Statistics: Extract and plot bond length/angle distributions from extension outputs.
- Compare With/Without Extensions: Train two models (with and without extensions) and compare structure quality.
Next Steps¤
Explore related examples to deepen your understanding:
- Protein Extensions Deep Dive: Learn more about individual protein extensions and their implementation.
- Protein Point Cloud Model: Explore the ProteinPointCloudModel that combines point clouds with protein constraints.
- Protein Diffusion: Use diffusion models for protein structure generation with extensions.
- Protein Benchmarks: Evaluate protein models with domain-specific metrics.
Troubleshooting¤
Extension Not Contributing to Loss¤
Problem: Extension appears in outputs but not in loss.
Solution: Check that:
- Extension weight is non-zero
- Extension is enabled (enabled=True)
- Extension implements compute_loss() method
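The three checks above can be automated with a small helper. The sketch below uses a hypothetical SimpleExtensionConfig dataclass as a stand-in for ExtensionConfig (only the name, weight, and enabled fields described in this document are assumed) and duck-types the extension object.

```python
from dataclasses import dataclass

@dataclass
class SimpleExtensionConfig:
    """Hypothetical stand-in for ExtensionConfig, for illustration only."""
    name: str
    weight: float = 1.0
    enabled: bool = True

def diagnose_extension(config, extension):
    """Return a list of reasons why an extension would not affect the loss."""
    problems = []
    if config.weight == 0:
        problems.append("weight is zero")
    if not config.enabled:
        problems.append("extension is disabled")
    if not callable(getattr(extension, "compute_loss", None)):
        problems.append("compute_loss() is not implemented")
    return problems

class WithLoss:
    def compute_loss(self, batch, outputs):
        return 0.0

class WithoutLoss:
    pass

ok = diagnose_extension(SimpleExtensionConfig("bond_length"), WithLoss())
broken = diagnose_extension(
    SimpleExtensionConfig("bond_angle", weight=0.0, enabled=False), WithoutLoss()
)
```

An empty list means none of the three listed causes applies, so the problem lies elsewhere.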
nnx.Dict Error¤
Problem: TypeError: extensions must be nnx.Dict
Solution: Wrap your extensions dictionary in nnx.Dict before passing it to the model: extensions = nnx.Dict(extensions_dict)
Bond Violations Too High¤
Problem: Bond length/angle violations are unreasonably large.
Solution:
- Check input coordinates are in correct units (Angstroms)
- Verify backbone indices match your atom ordering
- Increase extension weights to penalize violations more
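A quick way to test the units is to look at the distance between consecutive Cα atoms, which is about 3.8 Å for the usual trans peptide bond. The heuristic below is a sketch; the thresholds and function names are assumptions, not part of Workshop.

```python
import math
import statistics

def median_ca_distance(ca_coords):
    """Median distance between consecutive C-alpha atoms."""
    distances = [math.dist(a, b) for a, b in zip(ca_coords, ca_coords[1:])]
    return statistics.median(distances)

def guess_units(ca_coords):
    """Guess coordinate units from consecutive C-alpha spacing (about 3.8 Angstroms)."""
    d = median_ca_distance(ca_coords)
    if 3.0 <= d <= 4.5:
        return "angstrom"
    if 0.30 <= d <= 0.45:
        return "nanometer"
    return "unknown"

# Toy chains laid out on a straight line, for checking the heuristic only.
chain_angstrom = [(i * 3.8, 0.0, 0.0) for i in range(10)]
chain_nanometer = [(i * 0.38, 0.0, 0.0) for i in range(10)]

units_a = guess_units(chain_angstrom)
units_nm = guess_units(chain_nanometer)
```

If the result is "nanometer", multiplying coordinates by 10 before building the batch should bring bond statistics into the expected range. A result of "unknown" on real data more likely points to a wrong atom ordering or backbone index setting.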
Out of Memory¤
Problem: GPU runs out of memory with extensions.
Solution:
- Reduce batch size
- Reduce number of residues
- Reduce embedding dimension
- Disable less critical extensions
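To see why batch size and residue count matter most, it helps to estimate the size of the attention score tensors. This back-of-the-envelope sketch assumes the model applies full self-attention over all points with float32 scores, which the document does not confirm; it also ignores activations, parameters, and extension state.

```python
def attention_score_bytes(batch_size, num_heads, num_points, bytes_per_value=4):
    """Bytes for one layer's attention scores: batch x heads x points x points."""
    return batch_size * num_heads * num_points * num_points * bytes_per_value

# Settings from this example: batch of 2, 4 heads, 10 residues x 4 atoms.
small = attention_score_bytes(batch_size=2, num_heads=4, num_points=40)
# Doubling the number of residues doubles num_points.
doubled = attention_score_bytes(batch_size=2, num_heads=4, num_points=80)
growth = doubled / small
```

Under this assumption memory grows linearly with batch size but quadratically with the number of residues, so reducing protein size is usually the more effective lever.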
Additional Resources¤
- Extension System Documentation
- Point Cloud Models Guide
- Protein Modeling Tutorial
- Flax NNX Documentation
- JAX Documentation
Citation¤
If you use this example in your research, please cite: