# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Sparse flash attention (rf_v2).
Implements block-sparse attention with token rearrangement, block-wise pooling,
top-k sparse mask generation, and inverse rearrangement for video generation.
"""
import math
from typing import Optional
import torch
# ---------------------------------------------------------------------------
# avgpool — block-wise mean pooling with aligned fast path
# ---------------------------------------------------------------------------
def avgpool(input_tensor, pool_size=128, input_layout='BNSD'):
    """Block-wise mean over the sequence axis of a 4-D tensor.

    The sequence axis is split into consecutive blocks of ``pool_size``
    tokens and each block is replaced by its mean. A trailing partial block
    (when the sequence length is not a multiple of ``pool_size``) is averaged
    over the tokens it actually contains.

    Args:
        input_tensor: 4-D tensor, ``[B, S, N, D]`` when ``input_layout`` is
            ``'BSND'``; any other layout value is handled as ``[B, N, S, D]``.
        pool_size: number of tokens per block (default 128).
        input_layout: ``'BSND'`` or ``'BNSD'``.

    Returns:
        Tensor in the same layout whose sequence axis has one entry per
        block, i.e. ``ceil(S / pool_size)`` entries.
    """
    size0, size1, size2, size3 = input_tensor.shape
    if input_layout == "BSND":
        seq_axis, seqlen = 1, size1
        lead, trail = (size0,), (size2, size3)
    else:
        seq_axis, seqlen = 2, size2
        lead, trail = (size0, size1), (size3,)

    block_count, remainder = divmod(seqlen, pool_size)

    def _block_mean(chunk, blocks, width):
        # Split the sequence axis into (blocks, width) and average over width.
        return chunk.view(*lead, blocks, width, *trail).mean(dim=seq_axis + 1)

    # Aligned sequence: a single view + mean covers everything.
    if remainder == 0:
        return _block_mean(input_tensor, block_count, pool_size)

    # Unaligned sequence: pool the full blocks and the short tail separately.
    split_at = block_count * pool_size
    head_pooled = _block_mean(
        input_tensor.narrow(seq_axis, 0, split_at), block_count, pool_size)
    tail_pooled = _block_mean(
        input_tensor.narrow(seq_axis, split_at, remainder), 1, remainder)
    return torch.cat([head_pooled, tail_pooled], dim=seq_axis)
# ---------------------------------------------------------------------------
# get_mask_index — convert binary mask to select_idx format for rain_fusion op
# ---------------------------------------------------------------------------
def get_mask_index(mask):
    """Convert a boolean block mask into a front-packed column-index matrix.

    For every row of the last axis, the column indices of the ``True``
    entries are listed in ascending order at the front of the row and the
    remaining positions are filled with ``-1``.

    Args:
        mask: boolean tensor ``[B, N, rows, cols]``. The last two axes may
            differ in size (square masks behave exactly as before).

    Returns:
        int64 tensor with the same shape as ``mask``.
    """
    b, n, rows, cols = mask.shape
    flat_mask = mask.reshape(b * n, rows, cols)

    # Column index of every position; broadcast against the mask below.
    col_ids = torch.arange(cols, device=mask.device)
    # Sorting is done in float32 as in the original implementation; block
    # indices are small integers and therefore exactly representable.
    col_ids_f = col_ids.to(torch.float32)

    # Masked-out positions get a large sentinel so that sorting pushes them
    # behind every real index.
    candidates = torch.where(flat_mask, col_ids_f, torch.full_like(col_ids_f, 1e9))
    ordered, _ = torch.sort(candidates, dim=-1)

    # The first `selected` slots of each row hold real indices; the rest
    # still hold the sentinel and are replaced by -1.
    selected = flat_mask.sum(dim=-1, keepdim=True)
    result = torch.where(col_ids < selected, ordered, torch.full_like(ordered, -1))
    return result.reshape(b, n, rows, cols).to(torch.int64)
# ---------------------------------------------------------------------------
# get_blockwise_mask — generate block-sparse mask from pooled QKV
# ---------------------------------------------------------------------------
def get_blockwise_mask(  # pylint: disable=unused-argument
        qkv_pool,
        txt_len, sparsity, scale, pool_size,
        latent_shape_q, latent_shape_k=None,
        input_layout=None, return_binary=False,
        protect_first_frame=True):
    """Generate a block-sparse attention mask by top-k thresholding.

    Pooled queries and keys are scored against each other, the scores are
    soft-maxed per query block, and for every query block the highest-scoring
    key blocks are kept. Blocks belonging to the text tokens (at the end of
    the pooled sequence) and, optionally, to the first frame (at the start)
    are always kept, both as rows and as columns.

    Args:
        qkv_pool: pooled Q, K and V concatenated along dim 0 (split into
            three equal chunks; the V chunk is not used here).
        txt_len: number of text tokens; their blocks are the trailing
            ``ceil(txt_len / pool_size)`` blocks.
        sparsity: fraction of key blocks to prune per query block. The
            number of kept blocks is ``ceil(cols * (1 - sparsity))`` clamped
            to ``[1, cols]``, so ``sparsity=1.0`` keeps the single best block
            and values below ``0`` keep everything.
        scale: attention scale factor applied to the raw scores.
        pool_size: block size that was used for pooling.
        latent_shape_q: ``(t, h, w)`` of the query latent; ``h * w`` is the
            first-frame token count.
        latent_shape_k: accepted for signature compatibility; not read.
        input_layout: ``'BSND'``; any other value is handled as ``'BNSD'``.
        return_binary: if True, return the int8 mask instead of index form.
        protect_first_frame: keep all first-frame blocks when True.

    Returns:
        If ``return_binary``: int8 mask ``[B, N, q_blocks, kv_blocks]``.
        Otherwise ``(select_idx, select_num_idx)`` built from batch element 0
        only: ``select_idx`` is ``[q_blocks, N, kv_blocks]`` and
        ``select_num_idx`` is ``[q_blocks, N]``.
    """
    _, hq, wq = latent_shape_q
    first_frame_len = hq * wq
    query_pool, key_pool, _ = torch.chunk(qkv_pool, 3, dim=0)

    # Block-level attention scores on the pooled tokens.
    equation = "blnd,bsnd->bnls" if input_layout == "BSND" else "bnld,bnsd->bnls"
    attn_scores_head = torch.einsum(equation, query_pool, key_pool) * scale
    score_matrix = torch.nn.functional.softmax(attn_scores_head, dim=-1)
    cols = score_matrix.shape[-1]

    # Top-k thresholding. Clamp k so that sparsity == 1.0 (k would be 0 and
    # the threshold slice empty) and sparsity < 0 (k would exceed cols) are
    # handled instead of raising.
    keep_len = min(max(math.ceil(cols * (1 - sparsity)), 1), cols)
    topk_values, _ = torch.topk(score_matrix, k=keep_len, dim=-1)
    thresholds = topk_values[..., -1:]
    mask = score_matrix >= thresholds

    # Text blocks attend to, and are attended by, everything.
    text_block_num = (txt_len + pool_size - 1) // pool_size
    if text_block_num > 0:
        mask[:, :, -text_block_num:, :] = True
        mask[:, :, :, -text_block_num:] = True

    # First-frame blocks are likewise always kept.
    if protect_first_frame:
        firstframe_block_num = (first_frame_len + pool_size - 1) // pool_size
        if firstframe_block_num > 0:
            mask[:, :, :firstframe_block_num, :] = True
            mask[:, :, :, :firstframe_block_num] = True

    if return_binary:
        return mask.to(torch.int8)

    # Index form: only batch element 0 is exported, with the head axis moved
    # behind the query-block axis.
    select_idx = get_mask_index(mask)
    select_idx = select_idx[0].transpose(0, 1)
    select_num_idx = mask[0].transpose(0, 1).sum(dim=-1)
    return select_idx, select_num_idx
# ---------------------------------------------------------------------------
# rearrange_with_remaining — spatial token reordering (native torch, no einops)
# ---------------------------------------------------------------------------
def rearrange_with_remaining(tensor, latent_shape_q, latent_shape_k=None, input_layout=None):
    # pylint: disable=unused-argument
    """Reorder tokens from (frame, h, w) raster order into 8x8 spatial blocks.

    Transforms (hb = wb = 8, hn = h // 8, wn = w // 8):
        BSND: b (f h w) n d -> b (f hn wn hb wb) n d
        BNSD: b n (f h w) d -> b n (f hn wn hb wb) d

    When both ``h`` and ``w`` are multiples of 8, every frame is reordered.
    Otherwise the first frame is left in raster order and each remaining
    frame is emitted as: the block-interleaved 8-aligned region, then the
    leftover rows (``h % 8`` rows, full width), then the leftover columns of
    the aligned rows. A single non-aligned frame is therefore returned
    unchanged.

    Args:
        tensor: 4-D tensor whose sequence axis holds ``t * h * w`` tokens.
        latent_shape_q: ``(t, h, w)`` latent grid.
        latent_shape_k: accepted for signature compatibility; not read.
        input_layout: ``'BSND'``; any other value is handled as ``'BNSD'``.

    Returns:
        Tensor of the same shape with the sequence axis reordered.
    """
    tq, hq, wq = latent_shape_q
    size0, size1, size2, size3 = tensor.shape
    if input_layout == "BSND":
        seq_dim, lead, trail = 1, (size0,), (size2, size3)
    else:
        seq_dim, lead, trail = 2, (size0, size1), (size3,)

    hn, wn = hq // 8, wq // 8
    hq_block, wq_block = hn * 8, wn * 8
    hq_rem, wq_rem = hq % 8, wq % 8

    def _take(source, axis, start, stop):
        # Plain slice along one axis, all other axes untouched.
        index = [slice(None)] * source.dim()
        index[axis] = slice(start, stop)
        return source[tuple(index)]

    def _interleave(region, frames):
        # (frames, hn*8, wn*8) raster -> (frames, hn, wn, 8, 8) block order,
        # flattened to one token axis per frame.
        return (region
                .reshape(*lead, frames, hn, 8, wn, 8, *trail)
                .transpose(seq_dim + 2, seq_dim + 3)
                .contiguous()
                .reshape(*lead, frames, hn * wn * 64, *trail))

    # Fast path: everything is 8-aligned, reorder all frames at once.
    if hq_rem == 0 and wq_rem == 0:
        return _interleave(tensor, tq).reshape(*lead, tq * hn * wn * 64, *trail)

    # Remainder path: the first frame keeps its raster order.
    frame_len = hq * wq
    f_rest = tq - 1
    tensor_first = _take(tensor, seq_dim, 0, frame_len)
    tensor_rest = _take(tensor, seq_dim, frame_len, None)

    grid = tensor_rest.reshape(*lead, f_rest, hq, wq, *trail)
    h_axis, w_axis = seq_dim + 1, seq_dim + 2
    aligned_rows = _take(grid, h_axis, 0, hq_block)

    # All reshapes below use explicit sizes: with f_rest == 0 the tensors are
    # empty and a -1 placeholder would be ambiguous (reshape raises).
    pieces = [_interleave(_take(aligned_rows, w_axis, 0, wq_block), f_rest)]
    if hq_rem != 0:
        leftover_rows = _take(grid, h_axis, hq_block, None)
        pieces.append(leftover_rows.reshape(*lead, f_rest, hq_rem * wq, *trail))
    if wq_rem != 0:
        leftover_cols = _take(aligned_rows, w_axis, wq_block, None)
        pieces.append(leftover_cols.reshape(*lead, f_rest, hq_block * wq_rem, *trail))

    merged = torch.cat(pieces, dim=seq_dim + 1)
    merged = merged.reshape(*lead, f_rest * frame_len, *trail)
    return torch.cat([tensor_first, merged], dim=seq_dim)
# ---------------------------------------------------------------------------
# inv_rearrange_with_remaining — inverse spatial token reordering (v3-style)
# ---------------------------------------------------------------------------
def inv_rearrange_with_remaining(tensor, latent_shape_q, latent_shape_k=None, input_layout=None):
    # pylint: disable=unused-argument
    """Undo the 8x8 block reordering and restore (frame, h, w) raster order.

    Transforms (hb = wb = 8, hn = h // 8, wn = w // 8):
        BSND: b (f hn wn hb wb) n d -> b (f h w) n d
        BNSD: b n (f hn wn hb wb) d -> b n (f h w) d

    When ``h`` or ``w`` is not a multiple of 8, the first frame is passed
    through untouched and every other frame is expected to be laid out as:
    block-interleaved aligned region, leftover rows, leftover columns.

    Args:
        tensor: 4-D block-ordered tensor.
        latent_shape_q: ``(t, h, w)`` latent grid.
        latent_shape_k: accepted for signature compatibility; not read.
        input_layout: ``'BSND'``; any other value is handled as ``'BNSD'``.

    Returns:
        Tensor of the same shape with tokens back in raster order.
    """
    tq, hq, wq = latent_shape_q
    size0, size1, size2, size3 = tensor.shape
    if input_layout == "BSND":
        seq_dim, lead, trail = 1, (size0,), (size2, size3)
    else:
        seq_dim, lead, trail = 2, (size0, size1), (size3,)

    hn, wn = hq // 8, wq // 8
    hq_block, wq_block = hn * 8, wn * 8
    hq_rem, wq_rem = hq % 8, wq % 8

    def _take(source, axis, start, stop):
        # Plain slice along one axis, all other axes untouched.
        index = [slice(None)] * source.dim()
        index[axis] = slice(start, stop)
        return source[tuple(index)]

    def _deinterleave(blocked, frames):
        # (frames, hn, wn, 8, 8) block order -> (frames, hn, 8, wn, 8).
        return (blocked
                .reshape(*lead, frames, hn, wn, 8, 8, *trail)
                .transpose(seq_dim + 2, seq_dim + 3)
                .contiguous())

    # Fast path: every frame is fully block-ordered.
    if hq_rem == 0 and wq_rem == 0:
        return _deinterleave(tensor, tq).reshape(*lead, tq * hq * wq, *trail)

    # Remainder path: first frame is already in raster order.
    frame_len = hq * wq
    f_rest = tq - 1
    tensor_first = _take(tensor, seq_dim, 0, frame_len)
    per_frame = _take(tensor, seq_dim, frame_len, None).reshape(
        *lead, f_rest, frame_len, *trail)

    # Token offsets of the three regions inside each frame.
    token_axis = seq_dim + 1
    block_end = hn * wn * 64
    rows_end = block_end + hq_rem * wq
    cols_end = rows_end + hq_block * wq_rem

    grid = _deinterleave(_take(per_frame, token_axis, 0, block_end), f_rest)
    grid = grid.reshape(*lead, f_rest, hq_block, wq_block, *trail)

    # Re-attach leftover columns to the right, then leftover rows below.
    if wq_rem > 0:
        leftover_cols = _take(per_frame, token_axis, rows_end, cols_end)
        leftover_cols = leftover_cols.reshape(*lead, f_rest, hq_block, wq_rem, *trail)
        grid = torch.cat([grid, leftover_cols], dim=seq_dim + 2)
    if hq_rem > 0:
        leftover_rows = _take(per_frame, token_axis, block_end, rows_end)
        leftover_rows = leftover_rows.reshape(*lead, f_rest, hq_rem, wq, *trail)
        grid = torch.cat([grid, leftover_rows], dim=seq_dim + 1)

    restored = grid.reshape(*lead, -1, *trail)
    return torch.cat([tensor_first, restored], dim=seq_dim)
# ---------------------------------------------------------------------------
# do_tensor_rearrange_pooling — rearrange + pool QKV, returning pooled for mask gen
# ---------------------------------------------------------------------------
def do_tensor_rearrange_pooling(query, key, value, text_len, pool_size,
                                latent_shape_q, latent_shape_k, input_layout):
    """Block-reorder Q/K/V and build their block-pooled counterparts.

    With ``text_len == 0`` each of Q, K and V is rearranged and pooled on its
    own. Otherwise the three tensors are stacked along dim 0, the leading
    ``text_len`` tokens are split off, only the remaining (image) tokens are
    spatially rearranged, and the text tokens are appended *after* the image
    tokens in both the rearranged and the pooled result.

    Args:
        query, key, value: 4-D tensors in ``input_layout``.
        text_len: number of leading text tokens on the sequence axis.
        pool_size: block size for pooling.
        latent_shape_q: ``(t, h, w)`` latent grid of the image tokens.
        latent_shape_k: forwarded to the rearrange helper.
        input_layout: ``'BSND'``; any other value is handled as ``'BNSD'``.

    Returns:
        ``(query_, key_, value_, tensor_pool)`` where the first three are the
        rearranged tensors and ``tensor_pool`` is pooled Q, K, V concatenated
        along dim 0.
    """
    if text_len == 0:
        # No text: skip the cat + chunk roundtrip.
        q2, k2, v2 = [
            rearrange_with_remaining(part, latent_shape_q, latent_shape_k, input_layout)
            for part in (query, key, value)
        ]
        pooled = [avgpool(part, pool_size, input_layout) for part in (q2, k2, v2)]
        return q2, k2, v2, torch.cat(pooled, dim=0)

    seq_dim = 1 if input_layout == "BSND" else 2
    axis_prefix = (slice(None),) * seq_dim

    stacked = torch.cat((query, key, value), dim=0)
    text_part = stacked[axis_prefix + (slice(None, text_len),)]
    image_part = stacked[axis_prefix + (slice(text_len, None),)]

    image_rearranged = rearrange_with_remaining(
        image_part, latent_shape_q, latent_shape_k, input_layout)
    image_pool = avgpool(image_rearranged, pool_size, input_layout)
    text_pool = avgpool(text_part, pool_size, input_layout)

    # Text goes behind the image tokens in both outputs.
    merged = torch.cat((image_rearranged, text_part), dim=seq_dim)
    tensor_pool = torch.cat((image_pool, text_pool), dim=seq_dim)

    query_, key_, value_ = torch.chunk(merged, 3, dim=0)
    return query_, key_, value_, tensor_pool
# ---------------------------------------------------------------------------
# do_tensor_inv_rearrange — inverse rearrange for attention output
# ---------------------------------------------------------------------------
def do_tensor_inv_rearrange(tensor, text_len, latent_shape_q, latent_shape_k, input_layout):
    """Restore raster token order for an attention output.

    With ``text_len == 0`` the whole tensor is inverse-rearranged. Otherwise
    the trailing ``text_len`` tokens are treated as text: they are cut off,
    the remaining tokens are inverse-rearranged, and the text tokens are put
    back *in front of* them.

    Args:
        tensor: 4-D tensor in ``input_layout``.
        text_len: number of trailing text tokens on the sequence axis.
        latent_shape_q: ``(t, h, w)`` latent grid of the image tokens.
        latent_shape_k: forwarded to the inverse-rearrange helper.
        input_layout: ``'BSND'``; any other value is handled as ``'BNSD'``.

    Returns:
        Tensor of the same shape as ``tensor``.
    """
    if text_len == 0:
        return inv_rearrange_with_remaining(tensor, latent_shape_q, latent_shape_k, input_layout)

    seq_dim = 1 if input_layout == "BSND" else 2
    axis_prefix = (slice(None),) * seq_dim

    text_part = tensor[axis_prefix + (slice(-text_len, None),)]
    image_part = tensor[axis_prefix + (slice(None, -text_len),)]
    image_restored = inv_rearrange_with_remaining(
        image_part, latent_shape_q, latent_shape_k, input_layout)
    return torch.cat((text_part, image_restored), dim=seq_dim)
# ---------------------------------------------------------------------------
# do_tensor_pooling — pool-only (no rearrange), for legacy use
# ---------------------------------------------------------------------------
def do_tensor_pooling(tensor, text_len, pool_size=128, input_layout="BSND"):
    """Pool the text and image token regions separately and concatenate them.

    The first ``text_len`` tokens on the sequence axis are the text region,
    the rest is the image region. Each region is block-mean-pooled on its own
    and the result is ``[text blocks, image blocks]`` along the sequence axis.

    The layout is now passed to ``avgpool`` explicitly. Previously the tensor
    was sliced and concatenated on dim 1 (BSND) while ``avgpool`` ran with its
    default ``'BNSD'`` layout and therefore pooled over dim 2.

    Args:
        tensor: 4-D tensor in ``input_layout``.
        text_len: number of leading text tokens.
        pool_size: block size for pooling (default 128, the former constant).
        input_layout: ``'BSND'`` (default, matching the dim-1 slicing this
            function has always used); any other value is handled as
            ``'BNSD'``.

    Returns:
        Pooled tensor in the same layout.
    """
    seq_dim = 1 if input_layout == "BSND" else 2
    axis_prefix = (slice(None),) * seq_dim

    tensor_t = tensor[axis_prefix + (slice(None, text_len),)]
    tensor_i = tensor[axis_prefix + (slice(text_len, None),)]
    tensor_i_pool = avgpool(tensor_i, pool_size=pool_size, input_layout=input_layout)
    tensor_t_pool = avgpool(tensor_t, pool_size=pool_size, input_layout=input_layout)
    return torch.cat((tensor_t_pool, tensor_i_pool), dim=seq_dim)
# ---------------------------------------------------------------------------
# check_params
# ---------------------------------------------------------------------------
# Largest signed 32-bit integer (2**31 - 1). Passed as both `pre_tockens` and
# `next_tockens` to `npu_fusion_attention` on the dense path.
MAX_TOKEN = 2147483647
def check_params(input_layout, sparse_type):
    """Reject unsupported ``input_layout`` / ``sparse_type`` values.

    Args:
        input_layout: must be ``'BSND'`` or ``'BNSD'``.
        sparse_type: must be ``None`` or ``'rf_v2'``.

    Raises:
        ValueError: if either argument is outside its supported set. The
            layout is checked first.
    """
    supported_layouts = ('BSND', 'BNSD')
    supported_sparse_types = (None, 'rf_v2')

    layout_ok = input_layout in supported_layouts
    if not layout_ok:
        raise ValueError(f"The input_layout must be in ['BSND', 'BNSD'], but got {input_layout}.")

    sparse_ok = sparse_type in supported_sparse_types
    if not sparse_ok:
        raise ValueError(f"sparse_type must be None or 'rf_v2', but got {sparse_type}.")
# ---------------------------------------------------------------------------
# sparse_attention — main entry point
# ---------------------------------------------------------------------------
def sparse_attention(  # pylint: disable=unused-argument
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        scale: Optional[float] = None,
        is_causal: Optional[bool] = False,
        head_num: int = 1,
        input_layout: str = "BNSD",
        inner_precise: int = 0,
        sparse_type: Optional[str] = None,
        txt_len: int = 0,
        block_size: int = 128,
        latent_shape_q: Optional[list[int]] = None,
        latent_shape_k: Optional[list[int]] = None,
        keep_sink: bool = True,
        keep_recent: bool = True,
        cdf_threshold: float = 1.0,
        sparsity: float = 0.0,
        **_kwargs
):
    r"""
    High-level sparse attention entry point.

    Supports two modes:

    - **dense** (``sparse_type=None``): Calls ``torch_npu.npu_fusion_attention`` directly.
    - **block-sparse** (``sparse_type="rf_v2"``): Token rearrangement → block-wise pooling →
      top-k sparse mask generation → ``rain_fusion_attention`` → inverse rearrangement.

    Input tensors support ``float16`` and ``bfloat16`` dtypes.

    Args:
        q (Tensor): Query tensor.
        k (Tensor): Key tensor.
        v (Tensor): Value tensor.
        attn_mask (Tensor, optional): Reserved (not forwarded to either path). Default: ``None``.
        scale (float, optional): Attention scale factor. ``None`` auto-sets to ``head_dim ** -0.5``.
            Default: ``None``.
        is_causal (bool, optional): Reserved. The current implementation does not forward it
            to either the dense or the rf_v2 path, so no causal mask is applied.
            Default: ``False``.
        head_num (int, optional): Number of attention heads. Used to flatten BSND inputs to
            TND on the rf_v2 path and forwarded to the attention ops. Default: ``1``.
        input_layout (str, optional): Input layout, ``"BSND"`` or ``"BNSD"``. Default: ``"BNSD"``.
        inner_precise (int, optional): Precision mode. ``0`` for high-precision, ``1`` for
            high-performance. Forwarded on the rf_v2 path only. Default: ``0``.
        sparse_type (str, optional): Sparse type. ``None`` for dense, ``"rf_v2"`` for block-sparse.
            ``"rf_v3"`` and ``"ada_bsa"`` are reserved for future support.
            Default: ``None``.
        txt_len (int, optional): Number of text prefix tokens. When ``>0``, text tokens are
            separated and not rearranged spatially (rf_v2 only). Default: ``0``.
        block_size (int, optional): Block size for pooling and attention (rf_v2 only).
            Default: ``128``.
        latent_shape_q (list[int], optional): Latent spatial grid ``(t, h, w)`` for query.
            Required for rf_v2 path. For example, ``(1, 64, 64)`` for a single 64×64 frame
            with 4096 tokens. Default: ``None``.
        latent_shape_k (list[int], optional): Latent spatial grid ``(t, h, w)`` for key/value.
            Forwarded to the rf_v2 helpers (rf_v2 only). Default: ``None``.
        keep_sink (bool, optional): Whether to keep sink tokens. Reserved for ``ada_bsa`` mode
            (unused in rf_v2 path). Default: ``True``.
        keep_recent (bool, optional): Whether to keep recent tokens. Reserved for ``ada_bsa`` mode
            (unused in rf_v2 path). Default: ``True``.
        cdf_threshold (float, optional): CDF threshold. Reserved for ``ada_bsa`` mode
            (unused in rf_v2 path). Default: ``1.0``.
        sparsity (float, optional): Sparsity ratio in ``[0, 1]``. ``0.0`` for full attention,
            ``0.5`` prunes 50% of KV blocks (rf_v2 only). Default: ``0.0``.
        **_kwargs: Additional keyword arguments (reserved for future use).

    Returns:
        Tensor: Attention output with same shape as input ``q``.

    Raises:
        ValueError: If ``input_layout`` is not ``"BSND"`` or ``"BNSD"``.
        ValueError: If ``sparse_type`` is not ``None`` or ``"rf_v2"``.
        ValueError: If ``sparse_type`` is ``"rf_v2"`` and ``latent_shape_q`` is ``None``.

    Examples:
        >>> # rf_v2 sparse attention (sparsity=0.0 for full dense)
        >>> import math
        >>> import torch
        >>> from lite_boost.ops.sparse_attention import sparse_attention
        >>> device = torch.device("npu:0")
        >>> batch_size, num_heads, seq_len, head_dim = 1, 3, 4096, 128
        >>> scale = head_dim ** -0.5
        >>> latent_shape = (1, 64, 64)
        >>> q = torch.randn(batch_size, seq_len, num_heads, head_dim,
        ...                 dtype=torch.float16, device=device)
        >>> k = torch.randn(batch_size, seq_len, num_heads, head_dim,
        ...                 dtype=torch.float16, device=device)
        >>> v = torch.randn(batch_size, seq_len, num_heads, head_dim,
        ...                 dtype=torch.float16, device=device)
        >>> out = sparse_attention(
        ...     q=q, k=k, v=v,
        ...     scale=scale, head_num=num_heads,
        ...     input_layout="BSND", inner_precise=0,
        ...     sparse_type="rf_v2",
        ...     block_size=128, latent_shape_q=latent_shape,
        ...     sparsity=0.0)
        >>> print(out.shape)
        torch.Size([1, 4096, 3, 128])
    """
    # After this call sparse_type is guaranteed to be None or "rf_v2".
    check_params(input_layout, sparse_type)
    batch, head_dim = q.shape[0], q.shape[-1]
    scale = head_dim ** -0.5 if scale is None else scale

    # Dense path: hand the inputs straight to the fused attention op.
    if sparse_type is None:
        import torch_npu
        return torch_npu.npu_fusion_attention(
            q, k, v,
            input_layout=input_layout,
            scale=scale,
            pre_tockens=MAX_TOKEN,
            next_tockens=MAX_TOKEN,
            head_num=head_num)[0]

    # rf_v2 path. Fail early with a clear message instead of a TypeError when
    # the latent grid is unpacked deep inside the helpers.
    if latent_shape_q is None:
        raise ValueError("latent_shape_q is required when sparse_type is 'rf_v2'.")

    q_rf, k_rf, v_rf, qkv_pool = do_tensor_rearrange_pooling(
        q, k, v, txt_len, block_size, latent_shape_q, latent_shape_k, input_layout
    )
    select_idx, select_num_idx = get_blockwise_mask(
        qkv_pool, txt_len, sparsity, scale, block_size, latent_shape_q, latent_shape_k, input_layout)

    if input_layout == "BSND":
        # BSND inputs are flattened to (tokens, heads, dim) and run as TND.
        q_seq, kv_seq = q_rf.shape[1], k_rf.shape[1]
        layout = "TND"
        q_rf = q_rf.reshape(-1, head_num, head_dim)
        k_rf = k_rf.reshape(-1, head_num, head_dim)
        v_rf = v_rf.reshape(-1, head_num, head_dim)
    else:
        q_seq, kv_seq = q_rf.shape[2], k_rf.shape[2]
        layout = input_layout

    # Every batch element has the same query / key-value length.
    actual_seq_lengths = [q_seq] * batch
    actual_seq_lengths_kv = [kv_seq] * batch

    from . import rain_fusion_attention
    out, _ = rain_fusion_attention(
        q_rf, k_rf, v_rf,
        select_idx, select_num_idx,
        block_shape=[block_size, block_size],
        attn_mask=None,
        actual_seq_lengths=actual_seq_lengths,
        actual_seq_lengths_kv=actual_seq_lengths_kv,
        block_table=None,
        q_input_layout=layout,
        kv_input_layout=layout,
        num_key_value_heads=head_num,
        mask_type=0,
        scale_value=scale,
        inner_precise=inner_precise,
        block_size=0,
    )

    if layout == "TND":
        # Undo the TND flattening before restoring the token order.
        out = out.reshape(batch, q_seq, head_num, head_dim)
    return do_tensor_inv_rearrange(out, txt_len, latent_shape_q, latent_shape_k, input_layout)