lite_boost.ops.recurrent_gated_delta_rule

lite_boost.ops.recurrent_gated_delta_rule(query, key, value, beta, state, actual_seq_lengths, ssm_state_indices, g, gk, num_accepted_tokens, scale_value=1.0)

Recurrent GatedDeltaRule operator: recurrent linear attention for the decode phase, backed by the CANN aclnn interface.

Implements the token-by-token recurrent forward pass of the Gated Delta Rule, updating the recurrent state matrix and producing the attention output. Primarily used for decode-phase inference acceleration in hybrid linear attention models such as Qwen3.5.

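Algorithm flow (executed sequentially for each valid token t of each batch item and value head; q_t and k_t are column vectors of length D_k, v_t is a column vector of length D_v, and beta_t and g_t are scalars):

  1. State decay: S = S * exp(g_t) * exp(gk_t), where exp(gk_t) scales row i of S by exp(gk_t[i])

  2. Memory retrieval: kv_mem = S^T @ k_t   (length D_v)

  3. Delta update: S = S + k_t @ (beta_t * (v_t - kv_mem))^T   (outer product, shape [D_k, D_v])

  4. Output: o_t = S^T @ (scale_value * q_t)   (length D_v)

where S is the per-head recurrent state matrix of shape [D_k, D_v] (so [H_v, D_k, D_v] per batch item), storing the key-value associations of linear attention.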
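The four steps can be checked against a small NumPy reference. This is a sketch for a single (batch item, value head) pair in float64, assuming the conventions stated above (column vectors, state in [D_k, D_v] layout). The function name gated_delta_rule_reference is hypothetical and not part of lite_boost; the operator itself computes in bfloat16/float32 on the NPU, so its results should only be expected to agree to within bfloat16 tolerance.

```python
import numpy as np


def gated_delta_rule_reference(q, k, v, beta, g, gk, S, scale=1.0):
    """Hypothetical float64 reference for one (batch item, value head) pair.

    Shapes: q, k: [T, D_k]; v: [T, D_v]; beta, g: [T]; gk: [T, D_k];
    S: [D_k, D_v] (key dimension first). Returns (o, S) with o: [T, D_v].
    The input S is not modified.
    """
    S = np.array(S, dtype=np.float64)          # private copy of the state
    T, D_v = v.shape
    o = np.zeros((T, D_v))
    for t in range(T):
        # 1. State decay: global gate exp(g_t) times per-row gate exp(gk_t[i]).
        S = S * np.exp(g[t]) * np.exp(gk[t])[:, None]
        # 2. Memory retrieval: the value the state currently associates with k_t.
        kv_mem = S.T @ k[t]                                   # [D_v]
        # 3. Delta update: write the beta-scaled prediction error along k_t.
        S = S + np.outer(k[t], beta[t] * (v[t] - kv_mem))     # [D_k, D_v]
        # 4. Output: read the updated state with the scaled query.
        o[t] = S.T @ (scale * q[t])
    return o, S
```

With beta_t = 1 and a unit-norm key, step 3 replaces whatever was stored along k_t with v_t exactly, which is why the inputs must be L2-normalized. To compare against the operator (with H_q == H_v), slice one batch item and head, convert with `.float().cpu().numpy()`, and keep only the first actual_seq_lengths[b] tokens.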

Parameters:
  • query (torch.Tensor) – Query tensor of shape [B, H_q, T, D_k], dtype=bfloat16. Must be L2-normalized along D_k (each head vector has unit L2 norm, so every element lies in [-1, 1]). B=batch_size, H_q=num_query_heads, T=seq_len, D_k=key_dim.

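  • key (torch.Tensor) – Key tensor of shape [B, H_q, T, D_k], dtype=bfloat16. Must be L2-normalized (same as query). Key has the same head count (H_q) as query, so each query head is paired with exactly one key head; in GQA/MQA scenarios the sharing happens between these query/key heads and the value heads.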

  • value (torch.Tensor) – Value tensor of shape [B, H_v, T, D_v], dtype=bfloat16. H_v=num_value_heads (can differ from H_q to support GQA/MQA), D_v=value_dim.

  • beta (torch.Tensor) – Delta update step size of shape [B, H_v, T], dtype=bfloat16. Value range (0, 1). Controls the magnitude of each delta update: a larger beta causes new information to overwrite old memory more aggressively; a smaller beta tends to preserve existing memory.

  • state (torch.Tensor) – Recurrent state matrix of shape [B, H_v, D_k, D_v], dtype=bfloat16. Stores the cumulative key-value associations for linear attention. D_k is the key dimension (rows), D_v is the value dimension (columns). Can be initialized to zeros for the first call.

  • actual_seq_lengths (torch.Tensor) – Actual sequence lengths of shape [B], dtype=int32, used for variable-length inference. Each element is the number of valid tokens (at most T) in the corresponding batch item. E.g., [4, 3, 5] describes a batch of 3 sequences with lengths 4, 3, and 5.

  • ssm_state_indices (torch.Tensor) – SSM state indices of shape [B], dtype=int32. Indicates the index position of each batch item in the global state pool, used for state management in multi-batch scenarios. Typically [0, 1, 2, ..., B-1].

  • g (torch.Tensor) – Global decay gate of shape [B, H_v, T], dtype=float32. Must be negative. exp(g) serves as the state decay factor with range (0, 1). The more negative g is, the faster historical information is forgotten. E.g., when g=-1, approximately 37% of the historical state is retained per step.

  • gk (torch.Tensor) – Key-dimension gate of shape [B, H_v, T, D_k], dtype=float32. Must be negative. exp(gk) applies per-dimension decay independently along the key dimension, enabling finer-grained memory control. Unlike the global gate g, gk operates element-wise along the D_k dimension.

  • num_accepted_tokens (torch.Tensor) – Number of accepted tokens of shape [B], dtype=int32. Used in speculative decoding and similar scenarios to mark the number of actually accepted (non-rejected) tokens. For standard inference, this is the same as actual_seq_lengths.

  • scale_value (float, optional) – Attention scale factor, default 1.0. Typically set to 1.0 / sqrt(D_k), consistent with standard attention scaling. The query is multiplied by this scale factor before computation.
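The shape and value constraints listed above can be checked on the host before calling the operator. The following is a hypothetical helper (check_gdr_inputs is not part of lite_boost) that validates NumPy copies of the inputs. It assumes that padded positions beyond actual_seq_lengths may hold arbitrary values and therefore skips them, and it does not check dtypes or devices.

```python
import numpy as np


def check_gdr_inputs(query, key, value, beta, state, actual_seq_lengths, g, gk, atol=1e-2):
    """Hypothetical pre-flight check of the documented constraints.

    Takes NumPy arrays (e.g. tensor.float().cpu().numpy()). Dtypes and devices
    are NOT checked. Raises ValueError on the first violation found.
    """
    B, H_q, T, D_k = query.shape
    H_v, D_v = value.shape[1], value.shape[3]
    expected = {"key": (B, H_q, T, D_k), "value": (B, H_v, T, D_v),
                "beta": (B, H_v, T), "state": (B, H_v, D_k, D_v),
                "actual_seq_lengths": (B,), "g": (B, H_v, T), "gk": (B, H_v, T, D_k)}
    given = {"key": key, "value": value, "beta": beta, "state": state,
             "actual_seq_lengths": actual_seq_lengths, "g": g, "gk": gk}
    for name, shape in expected.items():
        if given[name].shape != shape:
            raise ValueError(f"{name}: expected shape {shape}, got {given[name].shape}")
    if T > 8:
        raise ValueError(f"T={T} exceeds the decode-phase limit of 8")
    if np.any(actual_seq_lengths < 0) or np.any(actual_seq_lengths > T):
        raise ValueError("actual_seq_lengths must lie in [0, T]")

    # Mask of valid token positions per batch item: [B, T].
    valid = np.arange(T)[None, :] < actual_seq_lengths[:, None]

    def tok(x):
        # [B, H, T, ...] -> [B, T, H, ...] -> [N_valid, H, ...]
        return np.moveaxis(x, 2, 1)[valid]

    for name, x in (("query", query), ("key", key)):
        if not np.allclose(np.linalg.norm(tok(x), axis=-1), 1.0, atol=atol):
            raise ValueError(f"{name} must be L2-normalized along D_k")
    b = tok(beta)
    if not np.all((b > 0) & (b < 1)):
        raise ValueError("beta must lie in (0, 1)")
    if not (np.all(tok(g) < 0) and np.all(tok(gk) < 0)):
        raise ValueError("g and gk must be negative")
```

The loose default tolerance (atol=1e-2) is a choice made here to absorb bfloat16 rounding of the normalized vectors; tighten it if the check is run on float32 data.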

Returns:

  • out (torch.Tensor): Attention output of shape [B, H_v, T, D_v], dtype=bfloat16. The linear attention result at each token position.

  • state_out (torch.Tensor): Updated recurrent state of shape [B, H_v, D_k, D_v], dtype=bfloat16. Pass it as the state argument of the next call to continue the recurrence across decode steps.
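Threading state_out back into state makes a sequence of calls equivalent to one long recurrence. The NumPy sketch below illustrates this for a single head (the helper names decode and decode_chained are hypothetical): splitting the tokens into several calls and passing each call's final state to the next one reproduces the outputs and final state of a single call over all tokens.

```python
import numpy as np


def decode(q, k, v, beta, g, gk, S, scale=1.0):
    """Run the recurrence over len(q) tokens for one head; S is [D_k, D_v].

    Returns (outputs [T, D_v], final state), analogous to (out, state_out).
    """
    outs = []
    for t in range(len(q)):
        S = S * np.exp(g[t]) * np.exp(gk[t])[:, None]            # decay
        S = S + np.outer(k[t], beta[t] * (v[t] - S.T @ k[t]))    # delta update
        outs.append(S.T @ (scale * q[t]))                        # output
    return np.stack(outs), S


def decode_chained(q, k, v, beta, g, gk, S0, chunk=1, scale=1.0):
    """Split the tokens into calls of `chunk` tokens, feeding each call's final
    state into the next call (the state_out -> state pattern)."""
    S, outs = S0, []
    for start in range(0, len(q), chunk):
        sl = slice(start, start + chunk)
        o, S = decode(q[sl], k[sl], v[sl], beta[sl], g[sl], gk[sl], S, scale)
        outs.append(o)
    return np.concatenate(outs), S
```

Restarting from a zero state instead of passing the previous state breaks this equivalence, which is why state_out must be carried between decode steps.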

Return type:

tuple[torch.Tensor, torch.Tensor]

Raises:

RuntimeError – If input tensor shapes, dtypes, or devices are invalid, or if the CANN operator execution fails.

Note

  • This operator only supports the decode phase (token-by-token inference), with sequence length T not exceeding 8. For parallel prefill computation, use the chunk-level operator.

  • Supports GQA (Grouped Query Attention) / MQA (Multi-Query Attention) style head grouping, in which H_q and H_v differ. Because query and key carry H_q heads while out and state carry H_v heads, each query/key head is shared by a group of value heads.

  • All input tensors must reside on the same NPU device.

  • The underlying CANN operator stores state in [B, H_v, D_v, D_k] layout (value dimension first). This function performs the layout conversion automatically, so state is passed in and state_out is returned in [B, H_v, D_k, D_v] layout.
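The layout conversion amounts to swapping the last two axes of the state. A NumPy sketch with hypothetical helper names is shown below; in PyTorch the equivalent operation is state.transpose(-1, -2).contiguous().

```python
import numpy as np


def to_cann_layout(state):
    """[B, H_v, D_k, D_v] -> [B, H_v, D_v, D_k] (value dimension first), contiguous."""
    return np.ascontiguousarray(np.swapaxes(state, -1, -2))


def from_cann_layout(state_vk):
    """[B, H_v, D_v, D_k] -> [B, H_v, D_k, D_v], the layout exposed by this function."""
    return np.ascontiguousarray(np.swapaxes(state_vk, -1, -2))
```

Element [b, h, i, j] of the key-first state becomes element [b, h, j, i] of the value-first state, and the two conversions are exact inverses.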

Example:

import torch
import torch_npu  # registers the "npu" device backend with PyTorch
import lite_boost.ops as lite_ops

# Qwen3.5-2B decode configuration
B, H, T, Dk, Dv = 1, 64, 1, 64, 512

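# Initialize inputs (must satisfy CANN operator constraints)
# query and key must be L2-normalized along D_k: normalize in float32, then cast to bfloat16
query  = torch.nn.functional.normalize(torch.randn(B, H, T, Dk, device="npu:0"), dim=-1).to(torch.bfloat16)
key    = torch.nn.functional.normalize(torch.randn(B, H, T, Dk, device="npu:0"), dim=-1).to(torch.bfloat16)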
value  = torch.randn(B, H, T, Dv, device="npu:0", dtype=torch.bfloat16)
beta   = torch.rand(B, H, T, device="npu:0", dtype=torch.bfloat16) * 0.9 + 0.05
state  = torch.zeros(B, H, Dk, Dv, device="npu:0", dtype=torch.bfloat16)
g      = -(torch.rand(B, H, T, device="npu:0") + 0.01)       # negative
gk     = -(torch.rand(B, H, T, Dk, device="npu:0") + 0.01)   # negative

actual_seq_lengths    = torch.tensor([T], dtype=torch.int32, device="npu:0")
ssm_state_indices     = torch.tensor([0], dtype=torch.int32, device="npu:0")
num_accepted_tokens   = torch.tensor([T], dtype=torch.int32, device="npu:0")

# Execute recurrent inference
output, state_out = lite_ops.recurrent_gated_delta_rule(
    query, key, value, beta, state,
    actual_seq_lengths, ssm_state_indices,
    g, gk, num_accepted_tokens,
    scale_value=1.0 / (Dk ** 0.5),
)
# output:    [1, 64, 1, 512]  -- attention output for the current token
# state_out: [1, 64, 64, 512] -- updated recurrent state, passed to next step