lite_boost.ops.rain_fusion_attention

lite_boost.ops.rain_fusion_attention(query, key, value, select_idx, select_num_idx, block_shape, attn_mask=None, actual_seq_lengths=None, actual_seq_lengths_kv=None, block_table=None, q_input_layout='TND', kv_input_layout='TND', num_key_value_heads=1, mask_type=0, scale_value=1.0, inner_precise=1, block_size=0)

Block-sparse fusion attention forward computation.

Calls the native NPU operator aclnnRainFusionAttention to compute block-level sparse fused attention. Flexible, non-uniform sparse patterns are expressed per query block: select_idx lists the KV block indices that each query block attends to, and select_num_idx gives how many of those indices are valid. Input tensors support the float16 and bfloat16 dtypes.
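To make the block-selection semantics concrete, here is a plain-Python reference for a single head. It is an illustrative sketch, not the operator's implementation. The function name `block_sparse_attention_ref` is hypothetical, `select_idx` here is one head's slice (that is, `select_idx[:, h, :]` as nested lists), and it assumes every token in a query block attends to all tokens of each selected KV block with no further masking, so it ignores `mask_type` and `attn_mask`.

```python
import math

def block_sparse_attention_ref(q, k, v, select_idx, select_num_idx, block, scale):
    """Single-head block-sparse attention on plain Python lists (hypothetical reference).

    q, k, v: lists of row vectors, one row per token.
    select_idx[i]: KV block indices for query block i, padded with -1.
    select_num_idx[i]: number of valid entries in select_idx[i].
    Assumption: no mask is applied inside the selected blocks.
    """
    dim = len(v[0])
    out = []
    for t, q_row in enumerate(q):
        qb = t // block  # query block that token t belongs to
        # Gather the KV token positions covered by the selected blocks.
        kv_tokens = []
        for kb in select_idx[qb][:select_num_idx[qb]]:
            start = kb * block
            kv_tokens.extend(range(start, min(start + block, len(k))))
        # Scaled dot-product scores against the selected tokens only.
        scores = [scale * sum(a * b for a, b in zip(q_row, k[j])) for j in kv_tokens]
        m = max(scores)  # subtract the row max for numerical stability
        weights = [math.exp(s - m) for s in scores]
        denom = sum(weights)
        out.append([
            sum(w * v[j][d] for w, j in zip(weights, kv_tokens)) / denom
            for d in range(dim)
        ])
    return out
```

When every query block selects every KV block, this reduces to ordinary dense softmax attention; dropping blocks from a row simply removes those tokens from that row's softmax.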

Parameters:
  • query (Tensor) – Query tensor. Layout depends on q_input_layout.

  • key (Tensor) – Key tensor. Layout depends on kv_input_layout.

  • value (Tensor) – Value tensor. Layout depends on kv_input_layout.

  • select_idx (Tensor) – Sparse selection index matrix. Shape [q_blocks, num_heads, kv_blocks], dtype torch.int64. Each row lists the valid KV block indices in ascending order; the remaining positions are padded with -1.

  • select_num_idx (Tensor) – Number of valid KV blocks per query block and head. Shape [q_blocks, num_heads], dtype torch.int64.

  • block_shape (list[int]) – Block tile size as [block_rows, block_cols], typically [128, 128].

  • attn_mask (Tensor, optional) – Attention mask tensor. Default: None.

  • actual_seq_lengths (list[int], optional) – Actual Q sequence length per batch. Required for the "TND" layout so that sequence boundaries are computed correctly. Default: None.

  • actual_seq_lengths_kv (list[int], optional) – Actual KV sequence length per batch. Default: None.

  • block_table (Tensor, optional) – Block table for PagedAttention scenarios. Default: None.

  • q_input_layout (str, optional) – Query input layout, "TND" or "BNSD". Default: "TND".

  • kv_input_layout (str, optional) – Key/Value input layout, "TND" or "BNSD". Default: "TND".

  • num_key_value_heads (int, optional) – Number of KV heads for GQA/MQA. Default: 1.

  • mask_type (int, optional) – Mask type (0 for causal mask). Default: 0.

  • scale_value (float, optional) – Attention scale factor. Recommended: head_dim ** -0.5. Default: 1.0.

  • inner_precise (int, optional) – Precision mode. 0 for high-precision, 1 for high-performance. Default: 1.

  • block_size (int, optional) – Block size. 0 for automatic inference. Default: 0.
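The parameter list fixes the index format but not how to produce it. The following plain-Python sketch builds one possible pattern, a causal sliding window over blocks, and checks the invariants stated for select_idx and select_num_idx (valid indices ascending, assumed here to be without duplicates; unused slots equal to -1; counts matching). The names `build_sliding_window_pattern` and `check_pattern`, and the window pattern itself, are hypothetical illustrations, not part of lite_boost.

```python
import math

def build_sliding_window_pattern(seq_len_q, seq_len_kv, num_heads, block, window):
    """Return (select_idx, select_num_idx) as nested lists (hypothetical helper).

    Query block i keeps KV blocks max(0, i - window + 1) .. i, i.e. a causal
    sliding window of `window` blocks, identical for every head.
    """
    q_blocks = math.ceil(seq_len_q / block)
    kv_blocks = math.ceil(seq_len_kv / block)
    select_idx, select_num_idx = [], []
    for i in range(q_blocks):
        last = min(i, kv_blocks - 1)
        first = max(0, last - window + 1)
        row = list(range(first, last + 1))             # ascending block indices
        padded = row + [-1] * (kv_blocks - len(row))   # pad unused slots with -1
        select_idx.append([list(padded) for _ in range(num_heads)])
        select_num_idx.append([len(row)] * num_heads)
    return select_idx, select_num_idx

def check_pattern(select_idx, select_num_idx, kv_blocks):
    """Raise ValueError if the documented index-format invariants are violated."""
    for i, (heads_idx, heads_num) in enumerate(zip(select_idx, select_num_idx)):
        for h, (idx, n) in enumerate(zip(heads_idx, heads_num)):
            valid, pad = idx[:n], idx[n:]
            if len(idx) != kv_blocks or not 0 <= n <= kv_blocks:
                raise ValueError(f"block {i}, head {h}: bad row length or count")
            if any(not 0 <= b < kv_blocks for b in valid):
                raise ValueError(f"block {i}, head {h}: index out of range")
            if any(a >= b for a, b in zip(valid, valid[1:])):
                raise ValueError(f"block {i}, head {h}: indices not strictly ascending")
            if any(p != -1 for p in pad):
                raise ValueError(f"block {i}, head {h}: unused slots must be -1")
```

The nested lists can then be converted with `torch.tensor(select_idx, dtype=torch.int64, device=device)` (and likewise for select_num_idx) before calling the operator.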

Returns:

tuple[Tensor, Tensor]

  • attention_out (Tensor) — Attention output, same shape as query.

  • softmax_lse (Tensor) — Softmax log-sum-exp values with shape [T, N, H], used for debugging and for the backward (gradient) computation.
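As an illustration of why the log-sum-exp is kept, assume that softmax_lse stores, per query token and head, the natural-log log-sum-exp of the scaled scores. This is the usual convention in flash-style attention kernels, but this page does not spell out the exact definition. Under that assumption, a backward pass can rebuild the softmax weights as exp(s_j - lse) without recomputing the row maximum and normalizer. The helper names below are hypothetical.

```python
import math

def logsumexp(scores):
    """Numerically stable log(sum(exp(s))) over one query row's scores."""
    m = max(scores)
    return m + math.log(sum(math.exp(s - m) for s in scores))

def probs_from_lse(scores, lse):
    """Rebuild softmax weights from a stored LSE: p_j = exp(s_j - lse)."""
    return [math.exp(s - lse) for s in scores]

# Toy scaled scores (scale * q . k_j) for one query token and one head.
scores = [0.5, 1.5, -0.25]
lse = logsumexp(scores)
probs = probs_from_lse(scores, lse)
```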

Raises:

TypeError – If query, key, value, select_idx or select_num_idx is not a Tensor.

Examples

>>> # Dense selection with rain_fusion_attention: every query block selects all KV blocks
>>> import math
>>> import torch
>>> import lite_boost.ops as lite_ops
>>> device = torch.device("npu:0")
>>> batch_size, num_heads, seq_len, head_dim = 1, 3, 4096, 128
>>> block_size = 128
>>> scale = head_dim ** -0.5
>>> q = torch.randn(batch_size * seq_len, num_heads, head_dim,
...                 dtype=torch.float16, device=device)
>>> k = torch.randn(batch_size * seq_len, num_heads, head_dim,
...                 dtype=torch.float16, device=device)
>>> v = torch.randn(batch_size * seq_len, num_heads, head_dim,
...                 dtype=torch.float16, device=device)
>>> q_blocks = math.ceil(seq_len / block_size)
>>> kv_blocks = math.ceil(seq_len / block_size)
>>> select_idx = torch.full((q_blocks, num_heads, kv_blocks), -1,
...                         dtype=torch.int64, device=device)
>>> base_indices = torch.arange(kv_blocks, dtype=torch.int64, device=device)
>>> select_idx[...] = base_indices.repeat(q_blocks, num_heads, 1)
>>> select_num_idx = torch.full((q_blocks, num_heads), kv_blocks,
...                             dtype=torch.int64, device=device)
>>> attention_out, softmax_lse = lite_ops.rain_fusion_attention(
...     query=q, key=k, value=v,
...     select_idx=select_idx, select_num_idx=select_num_idx,
...     block_shape=[block_size, block_size],
...     num_key_value_heads=num_heads,
...     scale_value=scale,
...     actual_seq_lengths=[seq_len],
...     actual_seq_lengths_kv=[seq_len])
>>> print(attention_out.shape)
torch.Size([4096, 3, 128])
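The example above uses a single sequence. With several sequences in the TND layout, the tensors are assumed here to pack the sequences back to back along T, in batch order, with actual_seq_lengths giving each sequence's length as described in the parameter list. Under that assumption, the hypothetical helper `tnd_offsets` recovers each sequence's token range in the packed tensor.

```python
from itertools import accumulate

def tnd_offsets(actual_seq_lengths):
    """Map per-batch lengths to (start, end) token ranges along T (hypothetical helper).

    Assumes the sequences are packed back to back in batch order.
    """
    ends = list(accumulate(actual_seq_lengths))
    starts = [0] + ends[:-1]
    return list(zip(starts, ends))

# Three sequences of lengths 3, 5 and 2 packed into T = 10 tokens.
ranges = tnd_offsets([3, 5, 2])
# For a packed tensor `query` of shape [T, N, D], query[start:end] would hold one sequence.
```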