Flash Attention

Module: generative_models.core.layers.flash_attention

Source: generative_models/core/layers/flash_attention.py

Overview

Flash Attention implementation for Flax NNX with kvax optimizations.

This module provides a Flash Attention implementation designed as a drop-in replacement for Flax NNX's MultiHeadAttention, offering performance improvements and additional features.
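For orientation, a minimal usage sketch follows. The constructor arguments are assumptions that mirror nnx.MultiHeadAttention (per the drop-in claim above), not signatures confirmed from this module:

```python
import jax.numpy as jnp
from flax import nnx

from generative_models.core.layers.flash_attention import FlashMultiHeadAttention

# Assumed to mirror nnx.MultiHeadAttention's constructor per the
# drop-in-replacement claim; verify against the actual signature.
attn = FlashMultiHeadAttention(
    num_heads=8,
    in_features=512,
    rngs=nnx.Rngs(0),
)
x = jnp.ones((2, 128, 512))  # (batch, sequence, features)
y = attn(x)                  # self-attention; output shape matches x
```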

Based on:

Classes

AttentionBackend

class AttentionBackend
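
No members are listed here. Conceptually, such an enum selects which kernel implementation dispatches the attention computation; the sketch below is illustrative only, and the member names are assumptions:

```python
import enum

# Illustrative only: the real AttentionBackend members may differ.
class AttentionBackend(enum.Enum):
    TRITON = "triton"  # fused Triton kernel path (flash_attention_triton)
    XLA = "xla"        # plain XLA attention fallback
```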

AttentionMask

class AttentionMask

FlashAttentionConfig

class FlashAttentionConfig
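
The fields are not documented here. As a hypothetical sketch of what a flash-attention config typically carries (tile sizes and backend selection), with all field names assumed:

```python
import dataclasses

# Hypothetical fields; the real FlashAttentionConfig may differ.
@dataclasses.dataclass
class FlashAttentionConfig:
    block_q: int = 128       # query tile size per kernel step
    block_k: int = 128       # key/value tile size per kernel step
    causal: bool = False     # apply the causal mask inside the kernel
    backend: str = "triton"  # AttentionBackend to dispatch to
```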

FlashMultiHeadAttention

class FlashMultiHeadAttention

Functions

__call__

def __call__()

__init__

def __init__()

__init__

def __init__()

create_attention_mask

def create_attention_mask()
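
Its signature is not shown here; the typical job of such a helper is to combine a padding mask with an optional causal mask into one broadcastable array. A minimal sketch under that assumption:

```python
import jax.numpy as jnp

# Sketch only; the real create_attention_mask signature may differ.
def create_attention_mask(pad_mask, causal=True):
    """pad_mask: bool (batch, seq), True for real (non-padding) tokens."""
    mask = pad_mask[:, None, :] & pad_mask[:, :, None]  # (batch, q, k)
    if causal:
        seq = pad_mask.shape[-1]
        mask = mask & jnp.tril(jnp.ones((seq, seq), dtype=bool))
    return mask[:, None, :, :]  # insert a broadcastable head axis
```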

flash_attention_forward_kernel

def flash_attention_forward_kernel()
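
The kernel itself is not reproduced here, but the algorithm it implements is the standard FlashAttention forward pass: tile over key/value blocks and maintain running softmax statistics so the full score matrix is never materialized. A pure-JAX reference sketch of that recurrence (single head, no masking; block size and signature are illustrative, not the kernel's):

```python
import jax.numpy as jnp

def flash_attention_reference(q, k, v, block_k=128):
    """Computes softmax(q @ k.T / sqrt(d)) @ v one key block at a time."""
    scale = q.shape[-1] ** -0.5
    m = jnp.full((q.shape[0],), -jnp.inf)  # running row-max of scores
    l = jnp.zeros((q.shape[0],))           # running softmax normalizer
    o = jnp.zeros_like(q)                  # running unnormalized output
    for start in range(0, k.shape[0], block_k):
        kb, vb = k[start:start + block_k], v[start:start + block_k]
        s = (q @ kb.T) * scale                   # scores for this block
        m_new = jnp.maximum(m, s.max(axis=-1))
        p = jnp.exp(s - m_new[:, None])          # block probabilities
        alpha = jnp.exp(m - m_new)               # rescale old statistics
        l = l * alpha + p.sum(axis=-1)
        o = o * alpha[:, None] + p @ vb
        m = m_new
    return o / l[:, None]
```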

flash_attention_triton

def flash_attention_triton()

init_cache

def init_cache()
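
Undocumented here; in Flax-style attention layers a cache initializer of this kind allocates zeroed key/value buffers plus a write index for autoregressive decoding. A hedged sketch of that pattern (shapes and field names assumed):

```python
import jax.numpy as jnp

# Hypothetical sketch; the real init_cache signature and layout may differ.
def init_cache(batch, max_len, num_heads, head_dim, dtype=jnp.float32):
    return {
        "key": jnp.zeros((batch, max_len, num_heads, head_dim), dtype),
        "value": jnp.zeros((batch, max_len, num_heads, head_dim), dtype),
        "index": jnp.zeros((), dtype=jnp.int32),  # next write position
    }
```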

make_causal_mask

def make_causal_mask()
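
A causal mask lets position q attend only to positions k <= q. The standard construction is shown below (the output layout of this module's helper may differ):

```python
import jax.numpy as jnp

def make_causal_mask(seq_len):
    # True where attention is allowed: lower-triangular (k <= q).
    return jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))

make_causal_mask(3)
# [[ True, False, False],
#  [ True,  True, False],
#  [ True,  True,  True]]
```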

make_segment_mask

def make_segment_mask()
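
Segment masks confine attention within packed-sequence boundaries: a token may attend only to tokens carrying the same segment id. The standard construction (this module's exact signature may differ):

```python
import jax.numpy as jnp

def make_segment_mask(segment_ids):
    # segment_ids: int (batch, seq) -> bool (batch, q, k).
    return segment_ids[:, :, None] == segment_ids[:, None, :]

make_segment_mask(jnp.array([[1, 1, 2, 2]]))[0]
# [[ True,  True, False, False],
#  [ True,  True, False, False],
#  [False, False,  True,  True],
#  [False, False,  True,  True]]
```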

Module Statistics

  • Classes: 4
  • Functions: 9
  • Imports: 20