lite_boost.ops.sparse_attention

lite_boost.ops.sparse_attention(q, k, v, attn_mask = None, scale = None, is_causal = False, head_num = 1, input_layout = "BNSD", inner_precise = 0, sparse_type = None, txt_len = 0, block_size = 128, latent_shape_q = None, latent_shape_k = None, keep_sink = True, keep_recent = True, cdf_threshold = 1.0, sparsity = 0.0, **_kwargs)

High-level sparse attention entry point.

Supports two modes:

  • dense (sparse_type=None): Calls torch_npu.npu_fusion_attention directly.

  • block-sparse (sparse_type="rf_v2"): Token rearrangement → block-wise pooling → top-k sparse mask generation → rain_fusion_attention → inverse rearrangement.
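
To make the pooling and mask-generation steps of the block-sparse mode concrete, here is a minimal pure-Python sketch for a single head. It is not the lite_boost implementation: mean pooling, dot-product block scores, and the rule keep = round((1 - sparsity) * num_k_blocks) are assumptions, and pool_blocks and topk_block_mask are hypothetical helper names. Plain lists are used so the sketch runs without an NPU.

```python
import random


def pool_blocks(tokens, block_size):
    """Mean-pool consecutive groups of `block_size` token vectors."""
    pooled = []
    for start in range(0, len(tokens), block_size):
        block = tokens[start:start + block_size]
        # Average each feature column over the tokens in this block.
        pooled.append([sum(col) / len(block) for col in zip(*block)])
    return pooled


def topk_block_mask(q_tokens, k_tokens, block_size, sparsity):
    """Return a [num_q_blocks][num_k_blocks] boolean mask of kept KV blocks."""
    q_pooled = pool_blocks(q_tokens, block_size)
    k_pooled = pool_blocks(k_tokens, block_size)
    num_k_blocks = len(k_pooled)
    # ASSUMPTION: keep the (1 - sparsity) fraction of KV blocks, at least one.
    keep = max(1, round((1.0 - sparsity) * num_k_blocks))
    mask = []
    for q_vec in q_pooled:
        # Score every KV block against this query block (dot product).
        scores = [sum(a * b for a, b in zip(q_vec, k_vec)) for k_vec in k_pooled]
        top = set(sorted(range(num_k_blocks),
                         key=lambda j: scores[j], reverse=True)[:keep])
        mask.append([j in top for j in range(num_k_blocks)])
    return mask


random.seed(0)
q = [[random.gauss(0.0, 1.0) for _ in range(4)] for _ in range(16)]
k = [[random.gauss(0.0, 1.0) for _ in range(4)] for _ in range(16)]
mask = topk_block_mask(q, k, block_size=4, sparsity=0.5)
full = topk_block_mask(q, k, block_size=4, sparsity=0.0)
```

With 16 tokens and a block size of 4, each of the 4 query blocks keeps 2 of the 4 KV blocks at sparsity=0.5, and all 4 at sparsity=0.0.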

Supported input dtypes are float16 and bfloat16.

Parameters:
  • q (Tensor) – Query tensor.

  • k (Tensor) – Key tensor.

  • v (Tensor) – Value tensor.

  • attn_mask (Tensor, optional) – Reserved (unused in current rf_v2 path). Default: None.

  • scale (float, optional) – Attention scale factor. If None, it is set to head_dim ** -0.5. Default: None.

  • is_causal (bool, optional) – Whether to apply a causal mask. Takes effect only in dense mode (sparse_type=None); not supported in the rf_v2 sparse path. Default: False.

  • head_num (int, optional) – Number of attention heads. Default: 1.

  • input_layout (str, optional) – Input layout, "BSND" or "BNSD". Default: "BNSD".

  • inner_precise (int, optional) – Precision mode. 0 for high-precision, 1 for high-performance. Default: 0.

  • sparse_type (str, optional) – Sparse type. None for dense, "rf_v2" for block-sparse. "rf_v3" and "ada_bsa" are reserved for future support. Default: None.

  • txt_len (int, optional) – Number of text prefix tokens. When >0, text tokens are separated and not rearranged spatially (rf_v2 only). Default: 0.

  • block_size (int, optional) – Block size for pooling and attention (rf_v2 only). Default: 128.

  • latent_shape_q (list[int], optional) – Latent spatial grid (t, h, w) for query. Required for rf_v2 path. For example, (1, 64, 64) for a single 64×64 frame with 4096 tokens. Default: None.

  • latent_shape_k (list[int], optional) – Latent spatial grid (t, h, w) for key/value. Reuses latent_shape_q when not provided (rf_v2 only). Default: None.

  • keep_sink (bool, optional) – Whether to keep sink tokens. Reserved for ada_bsa mode (unused in rf_v2 path). Default: True.

  • keep_recent (bool, optional) – Whether to keep recent tokens. Reserved for ada_bsa mode (unused in rf_v2 path). Default: True.

  • cdf_threshold (float, optional) – CDF threshold. Reserved for ada_bsa mode (unused in rf_v2 path). Default: 1.0.

  • sparsity (float, optional) – Sparsity ratio in [0, 1]. 0.0 keeps full attention; 0.5 prunes 50% of KV blocks (rf_v2 only). Default: 0.0.

  • **_kwargs – Additional keyword arguments (reserved for future use).
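
Several of the parameters above are tied together by simple arithmetic, which the following sketch works through for the rf_v2 path. plan_rf_v2 is a hypothetical helper, not part of lite_boost. It assumes that the sequence length equals txt_len plus t * h * w, that a partial trailing block counts as a block, and that the number of kept KV blocks is round((1 - sparsity) * num_blocks); this page does not confirm any of these rules.

```python
import math


def plan_rf_v2(q_shape, input_layout, latent_shape_q, txt_len=0,
               block_size=128, sparsity=0.0, scale=None):
    """Hypothetical helper: derive the sizes implied by the documented parameters."""
    # The layout string names the axes, so "S" and "D" locate seq_len and head_dim.
    seq_len = q_shape[input_layout.index("S")]
    head_dim = q_shape[input_layout.index("D")]

    # Documented default: scale falls back to head_dim ** -0.5 when None.
    if scale is None:
        scale = head_dim ** -0.5

    # ASSUMPTION: the sequence is the text prefix followed by the latent grid.
    t, h, w = latent_shape_q
    grid_tokens = t * h * w
    if seq_len != txt_len + grid_tokens:
        raise ValueError(
            f"seq_len={seq_len} does not match txt_len + t*h*w = {txt_len + grid_tokens}")

    # ASSUMPTION: a partial trailing block still counts as one block.
    num_blocks = math.ceil(seq_len / block_size)
    # ASSUMPTION: keep the (1 - sparsity) fraction of KV blocks, at least one.
    kept_blocks = max(1, round((1.0 - sparsity) * num_blocks))

    return {"seq_len": seq_len, "head_dim": head_dim, "scale": scale,
            "num_kv_blocks": num_blocks, "kept_kv_blocks": kept_blocks}


plan = plan_rf_v2((1, 4096, 3, 128), "BSND", (1, 64, 64), sparsity=0.5)
```

For a (1, 64, 64) latent grid with block_size=128, this yields 32 KV blocks, of which 16 are kept at sparsity=0.5 under the assumed rounding rule.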

Returns:

Attention output with same shape as input q.

Return type:

Tensor

Raises:
  • ValueError – If input_layout is not "BSND" or "BNSD".

  • ValueError – If sparse_type is neither None nor "rf_v2".
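
The two documented error conditions can be mirrored in a small stand-alone check, for example to validate a configuration before tensors are moved to the device. This is a sketch only: check_sparse_attention_args is a hypothetical name, the error messages are invented, and the actual checks inside lite_boost may differ.

```python
SUPPORTED_LAYOUTS = ("BSND", "BNSD")
SUPPORTED_SPARSE_TYPES = (None, "rf_v2")


def check_sparse_attention_args(input_layout="BNSD", sparse_type=None):
    """Hypothetical pre-check mirroring the documented ValueError conditions."""
    if input_layout not in SUPPORTED_LAYOUTS:
        raise ValueError(f"unsupported input_layout: {input_layout!r}")
    if sparse_type not in SUPPORTED_SPARSE_TYPES:
        raise ValueError(f"unsupported sparse_type: {sparse_type!r}")


# Reserved names are documented as future work, so they are rejected here too.
rejected = []
for candidate in ("rf_v3", "ada_bsa"):
    try:
        check_sparse_attention_args("BSND", candidate)
    except ValueError:
        rejected.append(candidate)
```

After the loop, rejected contains both reserved names, while ("BSND", "rf_v2") and the defaults pass the check silently.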

Examples

>>> # rf_v2 path with sparsity=0.0, i.e. no KV blocks are pruned
>>> import torch
>>> from lite_boost.ops.sparse_attention import sparse_attention
>>> device = torch.device("npu:0")
>>> batch_size, num_heads, seq_len, head_dim = 1, 3, 4096, 128
>>> scale = head_dim ** -0.5
>>> latent_shape = (1, 64, 64)
>>> q = torch.randn(batch_size, seq_len, num_heads, head_dim,
...                 dtype=torch.float16, device=device)
>>> k = torch.randn(batch_size, seq_len, num_heads, head_dim,
...                 dtype=torch.float16, device=device)
>>> v = torch.randn(batch_size, seq_len, num_heads, head_dim,
...                 dtype=torch.float16, device=device)
>>> out = sparse_attention(
...     q=q, k=k, v=v,
...     scale=scale, head_num=num_heads,
...     input_layout="BSND", inner_precise=0,
...     sparse_type="rf_v2",
...     block_size=128, latent_shape_q=latent_shape,
...     sparsity=0.0)
>>> print(out.shape)
torch.Size([1, 4096, 3, 128])