
Commit 4402dee

Authored by rkazants, with Copilot and mitruska

Specify GatedDeltaNet operation for Stateful mode (#34529)

**Details:** Specify the GatedDeltaNet operation for Stateful mode. This operation is needed to optimize performance for the Qwen3-next model.

**Ticket:** 181474

### AI Assistance:

- *AI assistance used: yes*
- *If yes, summarize how AI was used and what human validation was performed (build/tests/manual checks).* It helped to design the specification under my guidance.

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Katarzyna Mitrus <katarzyna.mitrus@intel.com>

1 parent 1697c30 commit 4402dee

File tree

2 files changed: +152 −0 lines changed

docs/articles_en/documentation/openvino-ir-format/operation-sets/operation-specs.rst

Lines changed: 1 addition & 0 deletions

@@ -87,6 +87,7 @@ Operation Specifications
     GRN-1 <operation-specs/normalization/grn-1>
     GRUCell-3 <operation-specs/sequence/gru-cell-3>
     GRUSequence-5 <operation-specs/sequence/gru-sequence-5>
+    GatedDeltaNet <operation-specs/internal/gated-delta-net>
     GatherTree-1 <operation-specs/movement/gather-tree-1>
     Gather-1 <operation-specs/movement/gather-1>
     Gather-7 <operation-specs/movement/gather-7>
Lines changed: 151 additions & 0 deletions

@@ -0,0 +1,151 @@ (new file; full content below)

.. {#openvino_docs_ops_internal_GatedDeltaNet}

GatedDeltaNet
=============

.. meta::
   :description: Learn about GatedDeltaNet - a linear recurrent sequence processing
                 operation based on the delta rule with a gating mechanism.

**Versioned name**: *GatedDeltaNet*

**Category**: *Sequence processing*

**Short description**: *GatedDeltaNet* represents a linear recurrent sequence model
that combines the delta rule memory update with a gating mechanism.

**Detailed description**: *GatedDeltaNet* implements the recurrence from the paper
`arXiv:2412.06464 <https://arxiv.org/abs/2412.06464>`__. It processes a sequence of
query, key, and value vectors, using the delta rule to update a hidden state matrix
controlled by a per-token forget gate ``gate`` (applied as ``exp(g)``) and a per-token
write gate ``beta``. Queries are scaled by ``1 / sqrt(key_head_dim)`` before being used
to compute the output. The following PyTorch-equivalent code illustrates the full
computation:

.. code-block:: py

    def torch_recurrent_gated_delta_rule(
        query, key, value, recurrent_state, gate, beta,
    ):
        batch_size, sequence_length, num_heads, k_head_dim = key.shape
        v_head_dim = value.shape[-1]
        scale = 1 / (query.shape[-1] ** 0.5)
        query = query * scale

        output_attn = torch.zeros(batch_size, sequence_length, num_heads, v_head_dim).to(value)
        output_recurrent_state = recurrent_state

        for i in range(sequence_length):
            q_t = query[:, i]
            k_t = key[:, i]
            v_t = value[:, i]
            g_t = gate[:, i].exp().unsqueeze(-1).unsqueeze(-1)
            beta_t = beta[:, i].unsqueeze(-1)

            output_recurrent_state = output_recurrent_state * g_t
            kv_mem = (output_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2)
            delta = (v_t - kv_mem) * beta_t
            output_recurrent_state = output_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
            output_attn[:, i] = (output_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2)

        return output_attn, output_recurrent_state
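As a cross-check of the recurrence above, the same computation can be sketched in plain NumPy (an illustrative re-implementation, not part of the specification; the function name is hypothetical):

```python
import numpy as np

def numpy_gated_delta_rule(query, key, value, recurrent_state, gate, beta):
    # Shapes mirror the operation inputs:
    #   query, key:      [B, S, H, Dk]
    #   value:           [B, S, H, Dv]
    #   recurrent_state: [B, H, Dk, Dv]
    #   gate, beta:      [B, S, H]
    batch, seq_len, heads, _ = key.shape
    dv = value.shape[-1]
    query = query / np.sqrt(query.shape[-1])  # scale queries by 1/sqrt(Dk)

    out = np.zeros((batch, seq_len, heads, dv), dtype=value.dtype)
    state = recurrent_state.copy()
    for t in range(seq_len):
        g_t = np.exp(gate[:, t])[:, :, None, None]   # forget gate,  [B, H, 1, 1]
        beta_t = beta[:, t][:, :, None]              # write gate,   [B, H, 1]
        k_t = key[:, t][:, :, :, None]               # key column,   [B, H, Dk, 1]

        state = state * g_t                          # decay the memory
        kv_mem = (state * k_t).sum(axis=-2)          # read memory at k, [B, H, Dv]
        delta = (value[:, t] - kv_mem) * beta_t      # gated delta-rule correction
        state = state + k_t * delta[:, :, None, :]   # rank-1 memory update
        out[:, t] = (state * query[:, t][:, :, :, None]).sum(axis=-2)
    return out, state
```

For a single token starting from a zero state with ``gate = 0`` (so ``exp(g) = 1``) and ``beta = 1``, the output reduces to ``(q . k) / sqrt(Dk) * v``: the delta rule writes ``v`` at address ``k``, and the scaled query reads it back.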
**Inputs**

* **1**: ``query`` - 4D tensor of type *T* and shape ``[batch_size, seq_len, num_heads, key_head_dim]``,
  the query vectors for each token and head. Scaled internally by ``1 / sqrt(key_head_dim)``
  before computing the output. **Required.**

* **2**: ``key`` - 4D tensor of type *T* and shape ``[batch_size, seq_len, num_heads, key_head_dim]``,
  the key vectors for each token and head. **Required.**

* **3**: ``value`` - 4D tensor of type *T* and shape ``[batch_size, seq_len, num_heads, value_head_dim]``,
  the value vectors for each token and head. **Required.**

* **4**: ``recurrent_state`` - 4D tensor of type *T* and shape
  ``[batch_size, num_heads, key_head_dim, value_head_dim]``, the recurrent
  (initially all-zeros) hidden state matrix. **Required.**

* **5**: ``gate`` - 3D tensor of type *T* and shape ``[batch_size, seq_len, num_heads]``,
  the forget gate in log-space, applied as ``exp(g)`` at each time step to decay the
  hidden state before the delta update. **Required.**

* **6**: ``beta`` - 3D tensor of type *T* and shape ``[batch_size, seq_len, num_heads]``,
  the write gate controlling how much of the delta correction is applied to the hidden
  state. **Required.**

**Outputs**

* **1**: ``output_attn`` - 4D tensor of type *T* and shape
  ``[batch_size, seq_len, num_heads, value_head_dim]``, the output vectors at each time
  step, produced by applying the state matrix to the (scaled) query.

* **2**: ``output_recurrent_state`` - 4D tensor of type *T* and shape
  ``[batch_size, num_heads, key_head_dim, value_head_dim]``, the hidden state matrix
  after processing the last token in the sequence.

**Types**

* *T*: any supported floating-point type.
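Both output shapes are fully determined by the ``query`` and ``value`` shapes. A small illustrative helper (hypothetical, not part of any OpenVINO API) makes the rule explicit:

```python
def gated_delta_net_output_shapes(query_shape, value_shape):
    """Derive output shapes from query [B, S, H, Dk] and value [B, S, H, Dv]."""
    b, s, h, dk = query_shape
    dv = value_shape[-1]
    return {
        "output_attn": (b, s, h, dv),               # per-token outputs
        "output_recurrent_state": (b, h, dk, dv),   # final state matrix
    }
```

Applying it to the shapes used in the IR example below (``query`` of ``[1, 16, 8, 64]``, ``value`` of ``[1, 16, 8, 128]``) yields ``output_attn`` of ``[1, 16, 8, 128]`` and ``output_recurrent_state`` of ``[1, 8, 64, 128]``.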
**Example**

.. code-block:: xml
   :force:

    <layer ... type="GatedDeltaNet" ...>
        <input>
            <port id="0"> <!-- `query` -->
                <dim>1</dim>
                <dim>16</dim>
                <dim>8</dim>
                <dim>64</dim>
            </port>
            <port id="1"> <!-- `key` -->
                <dim>1</dim>
                <dim>16</dim>
                <dim>8</dim>
                <dim>64</dim>
            </port>
            <port id="2"> <!-- `value` -->
                <dim>1</dim>
                <dim>16</dim>
                <dim>8</dim>
                <dim>128</dim>
            </port>
            <port id="3"> <!-- `recurrent_state` -->
                <dim>1</dim>
                <dim>8</dim>
                <dim>64</dim>
                <dim>128</dim>
            </port>
            <port id="4"> <!-- `gate` -->
                <dim>1</dim>
                <dim>16</dim>
                <dim>8</dim>
            </port>
            <port id="5"> <!-- `beta` -->
                <dim>1</dim>
                <dim>16</dim>
                <dim>8</dim>
            </port>
        </input>
        <output>
            <port id="6"> <!-- `output_attn` -->
                <dim>1</dim>
                <dim>16</dim>
                <dim>8</dim>
                <dim>128</dim>
            </port>
            <port id="7"> <!-- `output_recurrent_state` -->
                <dim>1</dim>
                <dim>8</dim>
                <dim>64</dim>
                <dim>128</dim>
            </port>
        </output>
    </layer>
