lite_boost.ops.recurrent_gated_delta_rule
- lite_boost.ops.recurrent_gated_delta_rule(query, key, value, beta, state, actual_seq_lengths, ssm_state_indices, g, gk, num_accepted_tokens, scale_value=1.0)
Recurrent GatedDeltaRule operator: CANN aclnn-backed recurrent linear attention for decoding.
Implements the token-by-token recurrent forward pass of the Gated Delta Rule, updating the recurrent state matrix and producing the attention output. Primarily used for decode-phase inference acceleration in hybrid linear attention models such as Qwen3.5.
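Algorithm flow (executed sequentially for each token of each batch item, independently for each value head; q and k are column vectors of length D_k, v and kv_mem are column vectors of length D_v):
State decay: S = S * exp(g) * exp(gk), where exp(g) is a scalar and exp(gk) scales row i of S by exp(gk_i)
Memory retrieval: kv_mem = S^T @ k
Delta update: S = S + beta * k @ (v - kv_mem)^T  (an outer product of shape [D_k, D_v])
Output: o = S^T @ (scale_value * q)
where S is the recurrent state of one batch item, of shape [H_v, D_k, D_v] (one [D_k, D_v] matrix per value head), storing the key-value associations of linear attention.
- Parameters: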
query (torch.Tensor) – Query tensor of shape [B, H_q, T, D_k], dtype=bfloat16. Must be L2-normalized along D_k (each head vector has L2 norm 1, so every element lies in [-1, 1]). B=batch_size, H_q=num_query_heads, T=seq_len, D_k=key_dim.
key (torch.Tensor) – Key tensor of shape [B, H_q, T, D_k], dtype=bfloat16. Must be L2-normalized (same as query). In GQA/MQA scenarios, key heads are shared with the value heads in the same way as query heads.
value (torch.Tensor) – Value tensor of shape [B, H_v, T, D_v], dtype=bfloat16. H_v=num_value_heads (may differ from H_q to support GQA/MQA), D_v=value_dim.
beta (torch.Tensor) – Delta-update step size of shape [B, H_v, T], dtype=bfloat16, with values in (0, 1). Controls the magnitude of each delta update: a larger beta makes new information overwrite old memory more aggressively; a smaller beta preserves more of the existing memory.
state (torch.Tensor) – Recurrent state of shape [B, H_v, D_k, D_v], dtype=bfloat16, storing the cumulative key-value associations of linear attention. D_k indexes the rows and D_v the columns. Can be initialized to zeros for the first call.
actual_seq_lengths (torch.Tensor) – Actual sequence lengths of shape [B], dtype=int32, used for variable-length inference. Each element is the number of valid tokens in the corresponding batch item; e.g., [4, 3, 5] describes 3 sequences with 4, 3, and 5 valid tokens.
ssm_state_indices (torch.Tensor) – SSM state indices of shape [B], dtype=int32. Gives the slot of each batch item in the global state pool, used for state management when serving multiple batch items. Typically [0, 1, 2, ..., B-1].
g (torch.Tensor) – Global decay gate of shape [B, H_v, T], dtype=float32. Must be strictly negative, so that the decay factor exp(g) lies in (0, 1). The more negative g is, the faster historical information is forgotten; e.g., g=-1 retains about 37% (exp(-1) ≈ 0.368) of the historical state per step.
gk (torch.Tensor) – Key-dimension gate of shape [B, H_v, T, D_k], dtype=float32. Must be strictly negative. exp(gk) applies an independent decay to each position along the key dimension, enabling finer-grained memory control than the global gate g, which scales the whole state uniformly.
num_accepted_tokens (torch.Tensor) – Number of accepted tokens of shape [B], dtype=int32. Used in speculative decoding and similar scenarios to record how many tokens were actually accepted (not rejected). For standard inference, this equals actual_seq_lengths.
scale_value (float, optional) – Attention scale factor, default 1.0. Typically set to 1.0 / sqrt(D_k), consistent with standard attention scaling. The query is multiplied by this factor before computation.
- Returns:
out (torch.Tensor): Attention output of shape [B, H_v, T, D_v], dtype=bfloat16, holding the linear attention result at each token position.
state_out (torch.Tensor): Updated recurrent state of shape [B, H_v, D_k, D_v], dtype=bfloat16. Pass it as the state input of the next recurrent step to form a state-passing chain.
- Return type:
tuple[torch.Tensor, torch.Tensor]
- Raises:
RuntimeError – If input tensor shapes, dtypes, or devices are invalid, or if the CANN operator execution fails.
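The parameter constraints above (L2-normalized query/key, beta in (0, 1), strictly negative g and gk, T <= 8) can be enforced on the host before tensors are moved to the NPU. The sketch below is a minimal NumPy illustration; `make_decode_inputs` and `check_decode_inputs` are hypothetical helpers, not part of lite_boost, and they assume H_q == H_v for simplicity and keep everything in float32. The real operator expects bfloat16 for query/key/value/beta/state, so the arrays would still need converting (e.g. via `torch.from_numpy`) before the call.

```python
import numpy as np


def make_decode_inputs(B, H, T, Dk, Dv, seed=0):
    """Hypothetical helper: random inputs that satisfy the documented constraints.

    Assumes H_q == H_v == H; all arrays are float32 NumPy (not bfloat16 tensors).
    """
    rng = np.random.default_rng(seed)
    q = rng.standard_normal((B, H, T, Dk)).astype(np.float32)
    k = rng.standard_normal((B, H, T, Dk)).astype(np.float32)
    # L2-normalize every head vector along the last (D_k) axis.
    q /= np.linalg.norm(q, axis=-1, keepdims=True)
    k /= np.linalg.norm(k, axis=-1, keepdims=True)
    v = rng.standard_normal((B, H, T, Dv)).astype(np.float32)
    beta = (rng.random((B, H, T)) * 0.9 + 0.05).astype(np.float32)  # in [0.05, 0.95)
    g = -(rng.random((B, H, T)) + 0.01).astype(np.float32)          # strictly negative
    gk = -(rng.random((B, H, T, Dk)) + 0.01).astype(np.float32)     # strictly negative
    state = np.zeros((B, H, Dk, Dv), dtype=np.float32)              # zero initial state
    lengths = np.full((B,), T, dtype=np.int32)                      # every token valid
    return q, k, v, beta, g, gk, state, lengths


def check_decode_inputs(q, k, beta, g, gk, T, lengths, atol=1e-2):
    """Hypothetical helper: raise ValueError if a documented constraint is violated.

    The loose default tolerance leaves room for bfloat16 rounding of q and k.
    """
    if not np.allclose(np.linalg.norm(q, axis=-1), 1.0, atol=atol):
        raise ValueError("query must be L2-normalized along D_k")
    if not np.allclose(np.linalg.norm(k, axis=-1), 1.0, atol=atol):
        raise ValueError("key must be L2-normalized along D_k")
    if not ((beta > 0) & (beta < 1)).all():
        raise ValueError("beta must lie in (0, 1)")
    if not ((g < 0).all() and (gk < 0).all()):
        raise ValueError("g and gk must be strictly negative")
    if T > 8:
        raise ValueError("the decode-phase operator supports T <= 8")
    if (lengths > T).any():
        raise ValueError("actual_seq_lengths cannot exceed T")


# Usage: build a small batch and validate it (raises nothing).
q, k, v, beta, g, gk, state, lengths = make_decode_inputs(B=2, H=4, T=1, Dk=16, Dv=32)
check_decode_inputs(q, k, beta, g, gk, T=1, lengths=lengths)
```

Checking on the host keeps constraint violations out of the opaque RuntimeError path of the device call.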
Note
This operator only supports the decode phase (token-by-token inference), with sequence length T not exceeding 8. For parallel prefill computation, use the chunk-level operator.
Supports GQA (Grouped Query Attention) / MQA (Multi-Query Attention) modes, i.e., H_q and H_v may differ. Because beta, g, gk, state and the output are all indexed by the H_v value heads, each query/key head is shared by several value heads.
All input tensors must reside on the same NPU device.
The CANN operator stores the state internally in [B, H_v, D_v, D_k] layout (value dimension first). This function automatically converts between that layout and the [B, H_v, D_k, D_v] layout used by state and state_out.
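The layout conversion described in the last note amounts to swapping the two trailing axes of the state. The NumPy sketch below illustrates the round trip; `to_kernel_layout` and `from_kernel_layout` are hypothetical names, and this is not the wrapper's actual code. In PyTorch the equivalent operation would be `state.transpose(-1, -2).contiguous()`.

```python
import numpy as np


def to_kernel_layout(state):
    """Hypothetical: [B, H_v, D_k, D_v] -> [B, H_v, D_v, D_k] (value dimension first)."""
    return np.ascontiguousarray(np.swapaxes(state, -1, -2))


def from_kernel_layout(state_kernel):
    """Hypothetical: [B, H_v, D_v, D_k] -> [B, H_v, D_k, D_v]; inverse of to_kernel_layout."""
    return np.ascontiguousarray(np.swapaxes(state_kernel, -1, -2))


# B=2, H_v=3, D_k=4, D_v=5
state = np.arange(2 * 3 * 4 * 5, dtype=np.float32).reshape(2, 3, 4, 5)
kernel_state = to_kernel_layout(state)
assert kernel_state.shape == (2, 3, 5, 4)
# Element at (d_k=0, d_v=3) moves to (d_v=3, d_k=0).
assert kernel_state[1, 2, 3, 0] == state[1, 2, 0, 3]
assert np.array_equal(from_kernel_layout(kernel_state), state)
```

Making the result contiguous matters because a device kernel typically reads the buffer in memory order rather than through strided views.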
Example:
```python
import torch
import torch.nn.functional as F
import torch_npu  # noqa: F401  registers the "npu" device with PyTorch
import lite_boost.ops as lite_ops

# Qwen3.5-2B decode configuration
B, H, T, Dk, Dv = 1, 64, 1, 64, 512
dev = "npu:0"

# Initialize inputs (must satisfy CANN operator constraints)
# query and key must be L2-normalized along D_k
query = F.normalize(torch.randn(B, H, T, Dk, device=dev), p=2.0, dim=-1).to(torch.bfloat16)
key = F.normalize(torch.randn(B, H, T, Dk, device=dev), p=2.0, dim=-1).to(torch.bfloat16)
value = torch.randn(B, H, T, Dv, device=dev, dtype=torch.bfloat16)
beta = torch.rand(B, H, T, device=dev, dtype=torch.bfloat16) * 0.9 + 0.05  # in (0, 1)
state = torch.zeros(B, H, Dk, Dv, device=dev, dtype=torch.bfloat16)  # zero initial state
g = -(torch.rand(B, H, T, device=dev) + 0.01)  # float32, strictly negative
gk = -(torch.rand(B, H, T, Dk, device=dev) + 0.01)  # float32, strictly negative
actual_seq_lengths = torch.tensor([T], dtype=torch.int32, device=dev)
ssm_state_indices = torch.tensor([0], dtype=torch.int32, device=dev)
num_accepted_tokens = torch.tensor([T], dtype=torch.int32, device=dev)

# Execute recurrent inference
output, state_out = lite_ops.recurrent_gated_delta_rule(
    query, key, value, beta, state,
    actual_seq_lengths, ssm_state_indices, g, gk, num_accepted_tokens,
    scale_value=1.0 / (Dk ** 0.5),
)
# output:    [1, 64, 1, 512]  -- attention output for the current token
# state_out: [1, 64, 64, 512] -- updated recurrent state, passed to the next step
```
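To cross-check results off-device, the four-step algorithm flow and the state-passing chain can be reproduced with a small NumPy reference. This is a sketch under stated assumptions, not the CANN kernel: `reference_recurrent_gdr` is a hypothetical helper that handles a single sequence, assumes H_q == H_v, treats every token as valid (it ignores actual_seq_lengths, ssm_state_indices and num_accepted_tokens), computes in float64 rather than bfloat16, and keeps the documented [H, D_k, D_v] state layout.

```python
import numpy as np


def reference_recurrent_gdr(q, k, v, beta, state, g, gk, scale=1.0):
    """Hypothetical NumPy reference of the documented recurrence for one sequence.

    Assumes H_q == H_v == H, every token valid, float64 math, [H, D_k, D_v] state.
    Shapes: q, k [H, T, D_k]; v [H, T, D_v]; beta, g [H, T]; gk [H, T, D_k];
    state [H, D_k, D_v]. Returns (out [H, T, D_v], state_out [H, D_k, D_v]).
    """
    S = np.array(state, dtype=np.float64)  # copy so the caller's state is untouched
    H, T, _ = q.shape
    out = np.zeros((H, T, v.shape[-1]))
    for t in range(T):
        for h in range(H):
            qt = scale * q[h, t]  # query is scaled before use
            kt, vt = k[h, t], v[h, t]
            # 1. Decay: scalar exp(g) times a per-row exp(gk) along D_k.
            S[h] *= np.exp(g[h, t]) * np.exp(gk[h, t])[:, None]
            # 2. Retrieval: kv_mem = S^T @ k, a vector of length D_v.
            kv_mem = S[h].T @ kt
            # 3. Delta update: add beta * outer(k, v - kv_mem).
            S[h] += beta[h, t] * np.outer(kt, vt - kv_mem)
            # 4. Output: o = S^T @ (scale * q).
            out[h, t] = S[h].T @ qt
    return out, S


# Usage: T tokens in one call vs. T chained single-token calls.
rng = np.random.default_rng(0)
H, T, Dk, Dv = 2, 3, 4, 5
q = rng.standard_normal((H, T, Dk))
q /= np.linalg.norm(q, axis=-1, keepdims=True)
k = rng.standard_normal((H, T, Dk))
k /= np.linalg.norm(k, axis=-1, keepdims=True)
v = rng.standard_normal((H, T, Dv))
beta = rng.uniform(0.05, 0.95, (H, T))
g = -rng.uniform(0.01, 1.0, (H, T))
gk = -rng.uniform(0.01, 1.0, (H, T, Dk))
S0 = np.zeros((H, Dk, Dv))
scale = 1.0 / np.sqrt(Dk)

out_all, S_all = reference_recurrent_gdr(q, k, v, beta, S0, g, gk, scale)

S = S0
for t in range(T):
    sl = slice(t, t + 1)  # keep the T axis with length 1
    o_t, S = reference_recurrent_gdr(q[:, sl], k[:, sl], v[:, sl], beta[:, sl],
                                     S, g[:, sl], gk[:, sl], scale)
    assert np.allclose(o_t[:, 0], out_all[:, t])
assert np.allclose(S, S_all)
```

Because each token only reads the state left by the previous one, processing T tokens in one call and chaining T single-token calls through the returned state give the same outputs and final state, which is what the usage code checks. Comparing against the NPU operator would additionally require converting its bfloat16 results and allowing a correspondingly loose tolerance.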