lite_boost.ops.sparse_attention
- lite_boost.ops.sparse_attention(q, k, v, attn_mask=None, scale=None, is_causal=False, head_num=1, input_layout="BNSD", inner_precise=0, sparse_type=None, txt_len=0, block_size=128, latent_shape_q=None, latent_shape_k=None, keep_sink=True, keep_recent=True, cdf_threshold=1.0, sparsity=0.0, **_kwargs)
High-level sparse attention entry point.
Supports two modes:
- dense (sparse_type=None): Calls torch_npu.npu_fusion_attention directly.
- block-sparse (sparse_type="rf_v2"): Token rearrangement → block-wise pooling → top-k sparse mask generation → rain_fusion_attention → inverse rearrangement.
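The pooling and mask-generation steps of the block-sparse mode can be pictured with a small pure-Python sketch. This is not the lite_boost implementation: the source does not say how blocks are pooled or scored, so mean pooling, dot-product scoring, and the rule "keep ceil((1 - sparsity) * num_blocks) key blocks per query block" are all assumptions, and the helper names are hypothetical. Token rearrangement and the fused attention kernel are left out.

```python
import math


def mean_pool_blocks(x, block_size):
    """Average consecutive token vectors in groups of block_size (assumed pooling)."""
    pooled = []
    for start in range(0, len(x), block_size):
        block = x[start:start + block_size]
        dim = len(block[0])
        pooled.append([sum(row[d] for row in block) / len(block) for d in range(dim)])
    return pooled


def topk_block_mask(q, k, block_size, sparsity):
    """Return mask[i][j] == True when query block i may attend to key block j."""
    q_blocks = mean_pool_blocks(q, block_size)
    k_blocks = mean_pool_blocks(k, block_size)
    # Assumption: the number of kept key blocks is rounded up and never zero.
    keep = max(1, math.ceil((1.0 - sparsity) * len(k_blocks)))
    mask = []
    for qb in q_blocks:
        # Score each key block by the dot product of the pooled vectors.
        scores = [sum(a * b for a, b in zip(qb, kb)) for kb in k_blocks]
        ranked = sorted(range(len(scores)), key=lambda j: scores[j], reverse=True)
        kept = set(ranked[:keep])
        mask.append([j in kept for j in range(len(k_blocks))])
    return mask


# Toy single-head input: 8 tokens, head_dim 2, block_size 2, so 4 blocks each.
q = [[1.0, 0.0]] * 4 + [[0.0, 1.0]] * 4
k = [[1.0, 0.0]] * 2 + [[0.0, 1.0]] * 2 + [[1.0, 0.0]] * 2 + [[0.0, 1.0]] * 2
mask = topk_block_mask(q, k, block_size=2, sparsity=0.5)
print(mask[0])  # [True, False, True, False]
```

With sparsity=0.0 every entry of the mask is True, which corresponds to full attention over all key blocks.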
Input tensors support float16 and bfloat16 dtypes.
- Parameters:
q (Tensor) – Query tensor.
k (Tensor) – Key tensor.
v (Tensor) – Value tensor.
attn_mask (Tensor, optional) – Reserved (unused in current rf_v2 path). Default: None.
scale (float, optional) – Attention scale factor. If None, it is set to head_dim ** -0.5. Default: None.
is_causal (bool, optional) – Whether to apply a causal mask. Effective in dense mode (sparse_type=None); unsupported in the rf_v2 sparse path. Default: False.
head_num (int, optional) – Number of attention heads. Default: 1.
input_layout (str, optional) – Input layout, "BSND" or "BNSD". Default: "BNSD".
inner_precise (int, optional) – Precision mode. 0 for high-precision, 1 for high-performance. Default: 0.
sparse_type (str, optional) – Sparse type. None for dense, "rf_v2" for block-sparse. "rf_v3" and "ada_bsa" are reserved for future support. Default: None.
txt_len (int, optional) – Number of text prefix tokens. When >0, text tokens are separated and not rearranged spatially (rf_v2 only). Default: 0.
block_size (int, optional) – Block size for pooling and attention (rf_v2 only). Default: 128.
latent_shape_q (list[int], optional) – Latent spatial grid (t, h, w) for query. Required for the rf_v2 path. For example, (1, 64, 64) for a single 64×64 frame with 4096 tokens. Default: None.
latent_shape_k (list[int], optional) – Latent spatial grid (t, h, w) for key/value. Reuses latent_shape_q when not provided (rf_v2 only). Default: None.
keep_sink (bool, optional) – Whether to keep sink tokens. Reserved for ada_bsa mode (unused in rf_v2 path). Default: True.
keep_recent (bool, optional) – Whether to keep recent tokens. Reserved for ada_bsa mode (unused in rf_v2 path). Default: True.
cdf_threshold (float, optional) – CDF threshold. Reserved for ada_bsa mode (unused in rf_v2 path). Default: 1.0.
sparsity (float, optional) – Sparsity ratio in [0, 1]. 0.0 for full attention, 0.5 prunes 50% of KV blocks (rf_v2 only). Default: 0.0.
**_kwargs – Additional keyword arguments (reserved for future use).
- Returns:
Attention output with the same shape as the input q.
- Return type:
Tensor
- Raises:
ValueError – If input_layout is not "BSND" or "BNSD".
ValueError – If sparse_type is not None or "rf_v2".
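The documented error conditions can be mirrored in a small pre-flight check that runs without an NPU. The function below is a hypothetical helper, not part of lite_boost. Only the first two checks correspond to the ValueError cases listed above. The latent_shape_q and sparsity checks are assumptions drawn from the parameter descriptions, and the library may handle those cases differently.

```python
def check_sparse_attention_args(input_layout="BNSD", sparse_type=None,
                                latent_shape_q=None, sparsity=0.0):
    """Hypothetical argument check modelled on the documented behaviour."""
    # Documented: ValueError for an unknown layout.
    if input_layout not in ("BSND", "BNSD"):
        raise ValueError(f"input_layout must be 'BSND' or 'BNSD', got {input_layout!r}")
    # Documented: ValueError for anything other than None or "rf_v2".
    if sparse_type not in (None, "rf_v2"):
        raise ValueError(f"sparse_type must be None or 'rf_v2', got {sparse_type!r}")
    # Assumption: the docs call latent_shape_q "required" for rf_v2 but do not
    # say which error is raised when it is missing.
    if sparse_type == "rf_v2" and latent_shape_q is None:
        raise ValueError("latent_shape_q is required for the rf_v2 path")
    # Assumption: the docs give the range [0, 1] but do not say it is enforced.
    if not 0.0 <= sparsity <= 1.0:
        raise ValueError(f"sparsity must be in [0, 1], got {sparsity}")


# Passes silently for the configuration used in the documented example.
check_sparse_attention_args("BSND", "rf_v2", (1, 64, 64), 0.0)
```

Note that "rf_v3" and "ada_bsa" are rejected here as well, since the documentation lists them as reserved rather than supported.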
Examples
>>> # rf_v2 sparse attention (sparsity=0.0 keeps every KV block, i.e. full attention)
>>> import torch
>>> from lite_boost.ops.sparse_attention import sparse_attention
>>> device = torch.device("npu:0")
>>> batch_size, num_heads, seq_len, head_dim = 1, 3, 4096, 128
>>> scale = head_dim ** -0.5
>>> latent_shape = (1, 64, 64)  # 1 * 64 * 64 = 4096 tokens, matching seq_len
>>> q = torch.randn(batch_size, seq_len, num_heads, head_dim,
...                 dtype=torch.float16, device=device)
>>> k = torch.randn(batch_size, seq_len, num_heads, head_dim,
...                 dtype=torch.float16, device=device)
>>> v = torch.randn(batch_size, seq_len, num_heads, head_dim,
...                 dtype=torch.float16, device=device)
>>> out = sparse_attention(
...     q=q, k=k, v=v,
...     scale=scale, head_num=num_heads,
...     input_layout="BSND", inner_precise=0,
...     sparse_type="rf_v2",
...     block_size=128, latent_shape_q=latent_shape,
...     sparsity=0.0)
>>> print(out.shape)
torch.Size([1, 4096, 3, 128])
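The sizes in the example can be worked out without an NPU. The sketch below computes the token count implied by a latent grid, the number of blocks at a given block_size, and how many KV blocks would survive a given sparsity. The helper name block_budget is hypothetical, and the rounding rule (round up, keep at least one block) is an assumption. The documentation only states that 0.5 prunes 50% of KV blocks.

```python
import math


def block_budget(latent_shape, block_size=128, sparsity=0.0):
    """Return (num_tokens, num_blocks, kept_blocks) for a (t, h, w) latent grid."""
    t, h, w = latent_shape
    num_tokens = t * h * w
    num_blocks = math.ceil(num_tokens / block_size)
    # Assumption: kept blocks are rounded up and never drop to zero.
    kept_blocks = max(1, math.ceil((1.0 - sparsity) * num_blocks))
    return num_tokens, num_blocks, kept_blocks


print(block_budget((1, 64, 64)))                # (4096, 32, 32)
print(block_budget((1, 64, 64), sparsity=0.5))  # (4096, 32, 16)
```

For the example above this gives 32 blocks of 128 tokens, all of which are kept at sparsity=0.0.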