from math import sqrt

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from einops import rearrange
from einops.layers.torch import Rearrange

from utils.masking import TriangularCausalMask

class Predict(nn.Module):
    def __init__(self, individual, c_out, seq_len, pred_len, dropout):
        super(Predict, self).__init__()
        self.individual = individual
        self.c_out = c_out

        if self.individual:
            self.seq2pred = nn.ModuleList()
            self.dropout = nn.ModuleList()
            for i in range(self.c_out):
                self.seq2pred.append(nn.Linear(seq_len, pred_len))
                self.dropout.append(nn.Dropout(dropout))
        else:
            self.seq2pred = nn.Linear(seq_len, pred_len)
            self.dropout = nn.Dropout(dropout)

    # x: (B, c_out, seq_len)
    def forward(self, x):
        if self.individual:
            out = []
            for i in range(self.c_out):
                per_out = self.seq2pred[i](x[:, i, :])
                per_out = self.dropout[i](per_out)
                out.append(per_out)
            out = torch.stack(out, dim=1)
        else:
            out = self.seq2pred(x)
            out = self.dropout(out)

        return out
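
# Usage sketch (illustrative only; values assumed, shapes follow the forward() contract above):
#   head = Predict(individual=True, c_out=7, seq_len=96, pred_len=24, dropout=0.1)
#   head(torch.randn(32, 7, 96)).shape  # -> (32, 7, 24); one Linear per channel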

class Attention_Block(nn.Module):
    def __init__(self, d_model, d_ff=None, n_heads=8, dropout=0.1, activation="relu"):
        super(Attention_Block, self).__init__()
        d_ff = d_ff or 4 * d_model
        self.attention = self_attention(FullAttention, d_model, n_heads=n_heads)
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = F.relu if activation == "relu" else F.gelu

    def forward(self, x, attn_mask=None):
        new_x, attn = self.attention(
            x, x, x,
            attn_mask=attn_mask
        )
        x = x + self.dropout(new_x)

        y = x = self.norm1(x)
        y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
        y = self.dropout(self.conv2(y).transpose(-1, 1))

        return self.norm2(x + y)

class self_attention(nn.Module):
    def __init__(self, attention, d_model, n_heads):
        super(self_attention, self).__init__()
        d_keys = d_model // n_heads
        d_values = d_model // n_heads

        self.inner_attention = attention(attention_dropout=0.1)
        self.query_projection = nn.Linear(d_model, d_keys * n_heads)
        self.key_projection = nn.Linear(d_model, d_keys * n_heads)
        self.value_projection = nn.Linear(d_model, d_values * n_heads)
        self.out_projection = nn.Linear(d_values * n_heads, d_model)
        self.n_heads = n_heads

    def forward(self, queries, keys, values, attn_mask=None):
        B, L, _ = queries.shape
        _, S, _ = keys.shape
        H = self.n_heads
        queries = self.query_projection(queries).view(B, L, H, -1)
        keys = self.key_projection(keys).view(B, S, H, -1)
        values = self.value_projection(values).view(B, S, H, -1)

        out, attn = self.inner_attention(
            queries,
            keys,
            values,
            attn_mask
        )
        out = out.view(B, L, -1)
        out = self.out_projection(out)
        return out, attn

class FullAttention(nn.Module):
    def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
        super(FullAttention, self).__init__()
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, queries, keys, values, attn_mask):
        B, L, H, E = queries.shape
        _, S, _, D = values.shape
        scale = self.scale or 1. / sqrt(E)
        scores = torch.einsum("blhe,bshe->bhls", queries, keys)
        if self.mask_flag:
            if attn_mask is None:
                attn_mask = TriangularCausalMask(B, L, device=queries.device)
            scores.masked_fill_(attn_mask.mask, -np.inf)
        A = self.dropout(torch.softmax(scale * scores, dim=-1))
        V = torch.einsum("bhls,bshd->blhd", A, values)
        if self.output_attention:
            return (V.contiguous(), A)
        else:
            return (V.contiguous(), None)
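
# Note on defaults: self_attention instantiates its attention class with only
# attention_dropout, so FullAttention keeps mask_flag=True and falls back to a
# TriangularCausalMask (causal masking) whenever no attn_mask is passed in.
# Illustrative use (values assumed, not part of the original code):
#   block = Attention_Block(d_model=64, n_heads=8, dropout=0.1)
#   block(torch.randn(2, 96, 64)).shape  # -> (2, 96, 64)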

class GraphBlock(nn.Module):
    def __init__(self, c_out, d_model, conv_channel, skip_channel,
                 gcn_depth, dropout, propalpha, seq_len, node_dim):
        super(GraphBlock, self).__init__()

        self.nodevec1 = nn.Parameter(torch.randn(c_out, node_dim), requires_grad=True)
        self.nodevec2 = nn.Parameter(torch.randn(node_dim, c_out), requires_grad=True)
        self.start_conv = nn.Conv2d(1, conv_channel, (d_model - c_out + 1, 1))
        self.gconv1 = mixprop(conv_channel, skip_channel, gcn_depth, dropout, propalpha)
        self.gelu = nn.GELU()
        self.end_conv = nn.Conv2d(skip_channel, seq_len, (1, seq_len))
        self.linear = nn.Linear(c_out, d_model)
        self.norm = nn.LayerNorm(d_model)

    # x: (B, T, d_model)
    # Here we use an MLP over a learned graph to fit a complex mapping f(x).
    def forward(self, x):
        adp = F.softmax(F.relu(torch.mm(self.nodevec1, self.nodevec2)), dim=1)
        out = x.unsqueeze(1).transpose(2, 3)
        out = self.start_conv(out)
        out = self.gelu(self.gconv1(out, adp))
        out = self.end_conv(out).squeeze(-1)
        out = self.linear(out)

        return self.norm(x + out)
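
# Shape walk-through of GraphBlock.forward (illustrative; assumes T == seq_len):
#   x (B, T, d_model) -> unsqueeze/transpose            -> (B, 1, d_model, T)
#   start_conv, kernel (d_model - c_out + 1, 1)         -> (B, conv_channel, c_out, T)
#   mixprop over the c_out node axis                    -> (B, skip_channel, c_out, T)
#   end_conv, kernel (1, seq_len), then squeeze(-1)     -> (B, seq_len, c_out)
#   Linear(c_out, d_model)                              -> (B, seq_len, d_model), added residually to x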

class nconv(nn.Module):
    def __init__(self):
        super(nconv, self).__init__()

    def forward(self, x, A):
        # x: (batch, channel, node, time), A: (node, node);
        # aggregates node features over neighbours weighted by A
        x = torch.einsum('ncwl,vw->ncvl', (x, A))
        return x.contiguous()

class linear(nn.Module):
    def __init__(self, c_in, c_out, bias=True):
        super(linear, self).__init__()
        self.mlp = torch.nn.Conv2d(c_in, c_out, kernel_size=(1, 1), padding=(0, 0), stride=(1, 1), bias=bias)

    def forward(self, x):
        return self.mlp(x)

class mixprop(nn.Module):
    def __init__(self, c_in, c_out, gdep, dropout, alpha):
        super(mixprop, self).__init__()
        self.nconv = nconv()
        self.mlp = linear((gdep + 1) * c_in, c_out)
        self.gdep = gdep
        self.dropout = dropout
        self.alpha = alpha

    def forward(self, x, adj):
        # add self-loops and row-normalize the adjacency
        adj = adj + torch.eye(adj.size(0)).to(x.device)
        d = adj.sum(1)
        h = x
        out = [h]
        a = adj / d.view(-1, 1)
        # mix-hop propagation: h_k = alpha * x + (1 - alpha) * A_hat @ h_{k-1}
        for i in range(self.gdep):
            h = self.alpha * x + (1 - self.alpha) * self.nconv(h, a)
            out.append(h)
        # concatenate all hops along the channel dim and mix them with a 1x1 conv
        ho = torch.cat(out, dim=1)
        ho = self.mlp(ho)
        return ho
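
# Note: as written, mixprop stores the dropout argument but never applies it in forward();
# only the propagation depth (gdep) and mixing weight (alpha) affect the output.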

class simpleVIT(nn.Module):
    def __init__(self, in_channels, emb_size, patch_size=2, depth=1, num_heads=4, dropout=0.1, init_weight=True):
        super(simpleVIT, self).__init__()
        self.emb_size = emb_size
        self.depth = depth
        self.to_patch = nn.Sequential(
            nn.Conv2d(in_channels, emb_size, 2 * patch_size + 1, padding=patch_size),
            Rearrange('b e h w -> b (h w) e'),
        )
        self.layers = nn.ModuleList([])
        for _ in range(self.depth):
            self.layers.append(nn.ModuleList([
                nn.LayerNorm(emb_size),
                MultiHeadAttention(emb_size, num_heads, dropout),
                FeedForward(emb_size, emb_size)
            ]))

        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        B, N, _, P = x.shape
        x = self.to_patch(x)
        for norm, attn, ff in self.layers:
            x = attn(norm(x)) + x
            x = ff(x) + x

        x = x.transpose(1, 2).reshape(B, self.emb_size, -1, P)
        return x
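
# simpleVIT keeps the spatial grid: the Conv2d with kernel 2*patch_size+1 and
# padding=patch_size preserves H and W, so (B, in_channels, H, W) becomes
# (B, H*W, emb_size) tokens for attention and is reshaped back to (B, emb_size, H, W).
# Illustrative (values assumed): simpleVIT(1, 32)(torch.randn(2, 1, 8, 16)) -> (2, 32, 8, 16)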

class MultiHeadAttention(nn.Module):
    def __init__(self, emb_size, num_heads, dropout):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.keys = nn.Linear(emb_size, emb_size)
        self.queries = nn.Linear(emb_size, emb_size)
        self.values = nn.Linear(emb_size, emb_size)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.num_heads)
        keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.num_heads)
        values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
        if mask is not None:
            fill_value = torch.finfo(torch.float32).min
            energy = energy.masked_fill(~mask, fill_value)

        # scale the attention logits before the softmax
        scaling = self.emb_size ** (1 / 2)
        att = F.softmax(energy / scaling, dim=-1)
        att = self.att_drop(att)
        # weighted sum over the key/value axis
        out = torch.einsum('bhal, bhlv -> bhav', att, values)
        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.projection(out)
        return out

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )

    def forward(self, x):
        return self.net(x)
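

# Minimal smoke test (a sketch, not part of the original module): the hyperparameters
# below are arbitrary assumptions, and utils.masking.TriangularCausalMask must be
# importable for the Attention_Block call, since it is used when no mask is passed.
if __name__ == "__main__":
    B, L, D, C = 2, 96, 64, 7

    head = Predict(individual=False, c_out=C, seq_len=L, pred_len=24, dropout=0.1)
    print(head(torch.randn(B, C, L)).shape)        # torch.Size([2, 7, 24])

    attn_block = Attention_Block(d_model=D, n_heads=8, dropout=0.1)
    print(attn_block(torch.randn(B, L, D)).shape)  # torch.Size([2, 96, 64])

    graph = GraphBlock(c_out=C, d_model=D, conv_channel=32, skip_channel=32,
                       gcn_depth=2, dropout=0.1, propalpha=0.05, seq_len=L, node_dim=10)
    print(graph(torch.randn(B, L, D)).shape)       # torch.Size([2, 96, 64])

    vit = simpleVIT(in_channels=1, emb_size=32)
    print(vit(torch.randn(B, 1, 8, 16)).shape)     # torch.Size([2, 32, 8, 16])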