# Flash Attention
**Module:** `generative_models.core.layers.flash_attention`
**Source:** `generative_models/core/layers/flash_attention.py`
## Overview
Flash Attention implementation for Flax NNX with kvax optimizations.
This module provides a Flash Attention implementation designed as a drop-in replacement for Flax NNX's `MultiHeadAttention`, with performance improvements and additional features. A minimal usage sketch follows the references below.
Based on:
- Flash Attention paper: https://arxiv.org/abs/2205.14135
- Flash Attention 2: https://arxiv.org/abs/2307.08691
- kvax implementation: https://github.com/nebius/kvax
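To illustrate the drop-in intent, here is a minimal usage sketch. Only the class name `FlashMultiHeadAttention` and the module path come from this page; the constructor arguments are assumed to mirror `flax.nnx.MultiHeadAttention` and may differ in the actual implementation.

```python
import jax.numpy as jnp
from flax import nnx

# Import path taken from this page's module listing.
from generative_models.core.layers.flash_attention import FlashMultiHeadAttention

# Assumed constructor signature, mirroring flax.nnx.MultiHeadAttention.
attn = FlashMultiHeadAttention(
    num_heads=8,
    in_features=512,
    rngs=nnx.Rngs(0),
)

x = jnp.ones((2, 128, 512))  # (batch, sequence length, features)
y = attn(x)                  # self-attention; same call convention as nnx
assert y.shape == x.shape
```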
## Classes
- `AttentionBackend`
- `AttentionMask`
- `FlashAttentionConfig`
- `FlashMultiHeadAttention`
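The configuration classes are listed here without signatures, so the sketch below is speculative: only the class names `FlashAttentionConfig` and `AttentionBackend` come from this page, and every field and enum member shown is hypothetical.

```python
# Hypothetical composition — the field names and the enum member below are
# guesses; only the class names appear in this module's listing.
from generative_models.core.layers.flash_attention import (
    AttentionBackend,
    FlashAttentionConfig,
)

config = FlashAttentionConfig(
    backend=AttentionBackend.TRITON,  # hypothetical enum member
    block_size=128,                   # hypothetical tiling parameter
)
```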
## Functions
- `__call__`
- `__init__`
- `__init__`
- `create_attention_mask`
- `flash_attention_forward_kernel`
- `flash_attention_triton`
- `init_cache`
- `make_causal_mask`
- `make_segment_mask`
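`flash_attention_forward_kernel` and `flash_attention_triton` are not documented on this page, but the computation they implement is fixed by the cited papers. The pure-JAX reference below shows the online-softmax recurrence from Flash Attention (arXiv:2205.14135) for a single head; it is an educational sketch of the algorithm, not the module's kernel, which fuses these steps on-chip via Triton.

```python
import jax
import jax.numpy as jnp

def flash_attention_reference(q, k, v, block_size=64):
    """Block-wise attention via the online-softmax recurrence.

    Numerically equivalent to softmax(q @ k.T / sqrt(d)) @ v, but never
    materialises the full (seq, seq) score matrix: each key/value block
    updates a running row-max, a running softmax denominator, and a
    running (unnormalised) output.
    """
    seq_len, d = q.shape
    scale = d ** -0.5
    m = jnp.full((seq_len, 1), -jnp.inf)  # running row-max of the scores
    l = jnp.zeros((seq_len, 1))           # running softmax denominator
    o = jnp.zeros((seq_len, d))           # running unnormalised output
    for start in range(0, seq_len, block_size):
        k_blk = k[start:start + block_size]
        v_blk = v[start:start + block_size]
        s = (q @ k_blk.T) * scale         # scores against this key block
        m_new = jnp.maximum(m, s.max(axis=-1, keepdims=True))
        p = jnp.exp(s - m_new)            # block softmax numerator
        correction = jnp.exp(m - m_new)   # rescale earlier accumulators
        l = l * correction + p.sum(axis=-1, keepdims=True)
        o = o * correction + p @ v_blk
        m = m_new
    return o / l

# Sanity check against the naive formulation.
q, k, v = (jax.random.normal(jax.random.PRNGKey(i), (128, 64)) for i in range(3))
ref = jax.nn.softmax((q @ k.T) * 64 ** -0.5) @ v
assert jnp.allclose(flash_attention_reference(q, k, v), ref, atol=1e-5)
```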
## Module Statistics
- Classes: 4
- Functions: 9
- Imports: 20