lite_boost.ops.rain_fusion_attention
- lite_boost.ops.rain_fusion_attention(query, key, value, select_idx, select_num_idx, block_shape, attn_mask=None, actual_seq_lengths=None, actual_seq_lengths_kv=None, block_table=None, q_input_layout='TND', kv_input_layout='TND', num_key_value_heads=1, mask_type=0, scale_value=1.0, inner_precise=1, block_size=0)
Block-sparse fusion attention forward computation.
Calls the NPU-native aclnnRainFusionAttention operator to perform block-level sparse fused attention. Supports flexible, non-uniform sparse attention patterns: for each query block, the KV blocks it attends to are specified per head via select_idx and select_num_idx. Input tensors support the float16 and bfloat16 dtypes.
- Parameters:
query (Tensor) – Query tensor. Layout depends on q_input_layout.
key (Tensor) – Key tensor. Layout depends on kv_input_layout.
value (Tensor) – Value tensor. Layout depends on kv_input_layout.
select_idx (Tensor) – Sparse selection index matrix of shape [q_blocks, num_heads, kv_blocks] and dtype torch.int64. In each row, the valid KV block indices are sorted in ascending order and the remaining positions are filled with -1.
select_num_idx (Tensor) – Number of valid KV blocks per query block and head. Shape [q_blocks, num_heads], dtype torch.int64.
block_shape (list[int]) – Block tile size as [block_rows, block_cols], typically [128, 128].
attn_mask (Tensor, optional) – Attention mask tensor. Default: None.
actual_seq_lengths (list[int], optional) – Actual Q sequence length of each batch. Required when the layout is "TND" so that sequence boundaries are computed correctly. Default: None.
actual_seq_lengths_kv (list[int], optional) – Actual KV sequence length of each batch. Default: None.
block_table (Tensor, optional) – Block table for PagedAttention scenarios. Default: None.
q_input_layout (str, optional) – Query input layout, "TND" or "BNSD". Default: "TND".
kv_input_layout (str, optional) – Key/value input layout, "TND" or "BNSD". Default: "TND".
num_key_value_heads (int, optional) – Number of KV heads for GQA/MQA. Default: 1.
mask_type (int, optional) – Mask type (0 for a causal mask). Default: 0.
scale_value (float, optional) – Attention scale factor. Recommended value: head_dim ** -0.5. Default: 1.0.
inner_precise (int, optional) – Precision mode: 0 for high precision, 1 for high performance. Default: 1.
block_size (int, optional) – Block size; 0 means it is inferred automatically. Default: 0.
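The select_idx / select_num_idx contract above (ascending valid KV block indices per row, -1 padding, valid counts kept in a separate tensor) can be derived mechanically from a boolean block mask. The sketch below is a hypothetical helper, block_mask_to_select, that is not part of lite_boost. It uses plain Python lists so the layout is easy to inspect, and it assumes the mask is indexed [q_block][head][kv_block]. Passing the results through torch.tensor(..., dtype=torch.int64) should then yield tensors of the documented shapes.

```python
from typing import List, Tuple

# Boolean block mask indexed [q_blocks][num_heads][kv_blocks] (assumed layout).
Mask = List[List[List[bool]]]


def block_mask_to_select(mask: Mask) -> Tuple[List[List[List[int]]], List[List[int]]]:
    """Hypothetical helper: build (select_idx, select_num_idx) from a block mask.

    select_idx[q][h] lists the selected KV block indices in ascending order,
    padded with -1 up to kv_blocks entries; select_num_idx[q][h] counts the
    valid (non -1) entries.
    """
    select_idx: List[List[List[int]]] = []
    select_num_idx: List[List[int]] = []
    for q_row in mask:
        idx_row: List[List[int]] = []
        num_row: List[int] = []
        for head_row in q_row:
            kv_blocks = len(head_row)
            # enumerate walks KV blocks in order, so indices come out ascending
            valid = [kv for kv, keep in enumerate(head_row) if keep]
            idx_row.append(valid + [-1] * (kv_blocks - len(valid)))
            num_row.append(len(valid))
        select_idx.append(idx_row)
        select_num_idx.append(num_row)
    return select_idx, select_num_idx


# Illustrative block-causal pattern: 3 query blocks, 1 head, 3 KV blocks.
causal = [[[kv <= q for kv in range(3)]] for q in range(3)]
idx, num = block_mask_to_select(causal)
print(idx)  # [[[0, -1, -1]], [[0, 1, -1]], [[0, 1, 2]]]
print(num)  # [[1], [2], [3]]
```

Keeping the count in select_num_idx means the kernel does not have to scan each row for the first -1. The padding still has to be present because select_idx is a dense [q_blocks, num_heads, kv_blocks] tensor.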
- Returns:
tuple[Tensor, Tensor]
attention_out (Tensor) – Attention output with the same shape as query.
softmax_lse (Tensor) – Softmax log-sum-exp values of shape [T, N, H], useful for debugging and for gradient backpropagation.
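The following is a rough mental model of the two return values. The sketch is a slow, single-head reference in plain Python and is our interpretation, not the operator's documented behavior. It assumes that scores are multiplied by scale_value, that every query row in query block qb attends to all key positions inside the KV blocks listed in select_idx[qb][:select_num_idx[qb]], and that the LSE is the natural log of the sum of exponentiated scores, as in FlashAttention-style kernels. The sketch ignores attn_mask, mask_type (which defaults to a causal mask per the parameter list), GQA and the TND/BNSD layouts. It also returns one LSE value per query row rather than the [T, N, H] tensor. The function name block_sparse_attention_ref is hypothetical.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]


def block_sparse_attention_ref(
    q: Matrix,
    k: Matrix,
    v: Matrix,
    select_idx: List[List[int]],
    select_num_idx: List[int],
    block_rows: int,
    block_cols: int,
    scale: float,
) -> Tuple[Matrix, List[float]]:
    """Hypothetical single-head reference for block-sparse attention.

    q, k, v are [seq_len][head_dim] lists. select_idx[qb] and
    select_num_idx[qb] are the per-head slices describing which KV blocks
    query block qb may attend to. Returns (output, lse), one LSE per query row.
    """
    seq_kv = len(k)
    out: Matrix = []
    lse: List[float] = []
    for i, q_row in enumerate(q):
        qb = i // block_rows  # query block containing row i
        # Key positions covered by the valid (non -1) selected KV blocks.
        cols: List[int] = []
        for kb in select_idx[qb][: select_num_idx[qb]]:
            start = kb * block_cols
            cols.extend(range(start, min(start + block_cols, seq_kv)))
        if not cols:
            raise ValueError(f"query block {qb} selects no KV positions")
        scores = [scale * sum(a * b for a, b in zip(q_row, k[j])) for j in cols]
        m = max(scores)  # subtract the row max for numerical stability
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        lse.append(m + math.log(total))  # equals log(sum(exp(scores)))
        out.append(
            [sum(e * v[j][d] for e, j in zip(exps, cols)) / total for d in range(len(v[0]))]
        )
    return out, lse


# Toy check: 4 tokens, 2x2 blocks, head_dim 1; zero keys give uniform weights.
q = [[1.0]] * 4
k = [[0.0]] * 4
v = [[1.0], [2.0], [3.0], [4.0]]
select_idx = [[0, -1], [0, 1]]  # query block 0 sees KV block 0; block 1 sees both
select_num_idx = [1, 2]
out, lse = block_sparse_attention_ref(q, k, v, select_idx, select_num_idx, 2, 2, 1.0)
print(out)  # [[1.5], [1.5], [2.5], [2.5]]
```

Storing the LSE lets a backward pass recompute the softmax probabilities as exp(score - lse) without materializing the full attention matrix. This is the usual reason fused attention kernels return it.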
- Raises:
TypeError – If query, key, value, select_idx or select_num_idx is not a Tensor.
Examples
>>> # Build dense attention with rain_fusion_attention
>>> import math
>>> import torch
>>> import lite_boost.ops as lite_ops
>>> device = torch.device("npu:0")
>>> batch_size, num_heads, seq_len, head_dim = 1, 3, 4096, 128
>>> block_size = 128
>>> scale = head_dim ** -0.5
>>> # TND layout: [batch_size * seq_len, num_heads, head_dim]
>>> q = torch.randn(batch_size * seq_len, num_heads, head_dim,
...                 dtype=torch.float16, device=device)
>>> k = torch.randn(batch_size * seq_len, num_heads, head_dim,
...                 dtype=torch.float16, device=device)
>>> v = torch.randn(batch_size * seq_len, num_heads, head_dim,
...                 dtype=torch.float16, device=device)
>>> q_blocks = math.ceil(seq_len / block_size)
>>> kv_blocks = math.ceil(seq_len / block_size)
>>> select_idx = torch.full((q_blocks, num_heads, kv_blocks), -1,
...                         dtype=torch.int64, device=device)
>>> # Select every KV block for every query block and head (dense pattern)
>>> base_indices = torch.arange(kv_blocks, dtype=torch.int64, device=device)
>>> select_idx[...] = base_indices.repeat(q_blocks, num_heads, 1)
>>> select_num_idx = torch.full((q_blocks, num_heads), kv_blocks,
...                             dtype=torch.int64, device=device)
>>> attention_out, softmax_lse = lite_ops.rain_fusion_attention(
...     query=q, key=k, value=v,
...     select_idx=select_idx, select_num_idx=select_num_idx,
...     block_shape=[block_size, block_size],
...     scale_value=scale,
...     actual_seq_lengths=[seq_len],
...     actual_seq_lengths_kv=[seq_len])
>>> print(attention_out.shape)
torch.Size([4096, 3, 128])
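The example selects every KV block, which reproduces dense attention. A sparse pattern should only require different contents in the two index tensors. The sketch below builds a block-level sliding window in plain Python: each query block sees its own KV block and the window - 1 preceding ones, clamped to the blocks that exist. The names sliding_window_select and window are hypothetical and not part of lite_boost. The pattern is identical for every head. Under the example's setup, the lists could be converted with torch.tensor(select_idx, dtype=torch.int64, device=device), and likewise for select_num_idx.

```python
from typing import List, Tuple


def sliding_window_select(
    q_blocks: int, kv_blocks: int, num_heads: int, window: int
) -> Tuple[List[List[List[int]]], List[List[int]]]:
    """Hypothetical helper: block-level sliding-window sparsity pattern.

    Query block qb attends to KV blocks max(0, last - window + 1) .. last,
    where last = min(qb, kv_blocks - 1). Returns nested lists shaped
    [q_blocks][num_heads][kv_blocks] and [q_blocks][num_heads].
    """
    select_idx: List[List[List[int]]] = []
    select_num_idx: List[List[int]] = []
    for qb in range(q_blocks):
        last = min(qb, kv_blocks - 1)
        valid = list(range(max(0, last - window + 1), last + 1))  # ascending
        row = valid + [-1] * (kv_blocks - len(valid))  # pad with -1
        select_idx.append([list(row) for _ in range(num_heads)])
        select_num_idx.append([len(valid)] * num_heads)
    return select_idx, select_num_idx


idx, num = sliding_window_select(q_blocks=4, kv_blocks=4, num_heads=1, window=2)
print(idx)  # [[[0, -1, -1, -1]], [[0, 1, -1, -1]], [[1, 2, -1, -1]], [[2, 3, -1, -1]]]
print(num)  # [[1], [2], [2], [2]]
```

The index tensors only say which blocks are computed. Whether masking is also applied inside a selected block is governed by mask_type and attn_mask.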