Commit b877f8a

add MSGNet and TimeFilter
1 parent d7632da commit b877f8a

File tree

6 files changed, +998 −1 lines changed


exp/exp_basic.py

Lines changed: 3 additions & 1 deletion
@@ -3,7 +3,7 @@
 from models import Autoformer, Transformer, TimesNet, Nonstationary_Transformer, DLinear, FEDformer, \
     Informer, LightTS, Reformer, ETSformer, Pyraformer, PatchTST, MICN, Crossformer, FiLM, iTransformer, \
     Koopa, TiDE, FreTS, TimeMixer, TSMixer, SegRNN, MambaSimple, TemporalFusionTransformer, SCINet, PAttn, TimeXer, \
-    WPMixer, MultiPatchFormer, KANAD
+    WPMixer, MultiPatchFormer, KANAD, MSGNet, TimeFilter


 class Exp_Basic(object):
@@ -40,6 +40,8 @@ def __init__(self, args):
             'WPMixer': WPMixer,
             'MultiPatchFormer': MultiPatchFormer,
             'KANAD': KANAD,
+            'MSGNet': MSGNet,
+            'TimeFilter': TimeFilter
         }
         if args.model == 'Mamba':
             print('Please make sure you have successfully installed mamba_ssm')
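With these registry entries in place, the new models are selected by name through the model dictionary in Exp_Basic. A minimal sketch of how the lookup is typically resolved is shown below; the argparse Namespace, the CLI flag, and the convention that each model module exposes a Model class are assumptions drawn from the library's existing pattern, not new API introduced by this commit.

# Minimal sketch (not part of the commit): resolving the name-keyed registry.
from argparse import Namespace

from models import MSGNet, TimeFilter

model_dict = {'MSGNet': MSGNet, 'TimeFilter': TimeFilter}

args = Namespace(model='MSGNet')        # e.g. supplied on the CLI as --model MSGNet
module = model_dict[args.model]         # pick the module registered above
# model = module.Model(args).float()    # construction needs the full experiment config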

layers/MSGBlock.py

Lines changed: 273 additions & 0 deletions
@@ -0,0 +1,273 @@
from math import sqrt

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from einops import rearrange
from einops.layers.torch import Rearrange

from utils.masking import TriangularCausalMask


class Predict(nn.Module):
    """Map the temporal dimension from seq_len to pred_len, optionally with one head per channel."""

    def __init__(self, individual, c_out, seq_len, pred_len, dropout):
        super(Predict, self).__init__()
        self.individual = individual
        self.c_out = c_out

        if self.individual:
            # One linear head (and dropout) per output channel.
            self.seq2pred = nn.ModuleList()
            self.dropout = nn.ModuleList()
            for i in range(self.c_out):
                self.seq2pred.append(nn.Linear(seq_len, pred_len))
                self.dropout.append(nn.Dropout(dropout))
        else:
            # A single linear head shared across channels.
            self.seq2pred = nn.Linear(seq_len, pred_len)
            self.dropout = nn.Dropout(dropout)

    # x: (B, c_out, seq_len)
    def forward(self, x):
        if self.individual:
            out = []
            for i in range(self.c_out):
                per_out = self.seq2pred[i](x[:, i, :])
                per_out = self.dropout[i](per_out)
                out.append(per_out)
            out = torch.stack(out, dim=1)
        else:
            out = self.seq2pred(x)
            out = self.dropout(out)

        return out


class Attention_Block(nn.Module):
    """Transformer encoder block: multi-head self-attention followed by a 1x1-conv feed-forward network."""

    def __init__(self, d_model, d_ff=None, n_heads=8, dropout=0.1, activation="relu"):
        super(Attention_Block, self).__init__()
        d_ff = d_ff or 4 * d_model
        self.attention = self_attention(FullAttention, d_model, n_heads=n_heads)
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = F.relu if activation == "relu" else F.gelu

    def forward(self, x, attn_mask=None):
        new_x, attn = self.attention(
            x, x, x,
            attn_mask=attn_mask
        )
        x = x + self.dropout(new_x)

        y = x = self.norm1(x)
        y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
        y = self.dropout(self.conv2(y).transpose(-1, 1))

        return self.norm2(x + y)


class self_attention(nn.Module):
    def __init__(self, attention, d_model, n_heads):
        super(self_attention, self).__init__()
        d_keys = d_model // n_heads
        d_values = d_model // n_heads

        self.inner_attention = attention(attention_dropout=0.1)
        self.query_projection = nn.Linear(d_model, d_keys * n_heads)
        self.key_projection = nn.Linear(d_model, d_keys * n_heads)
        self.value_projection = nn.Linear(d_model, d_values * n_heads)
        self.out_projection = nn.Linear(d_values * n_heads, d_model)
        self.n_heads = n_heads

    def forward(self, queries, keys, values, attn_mask=None):
        B, L, _ = queries.shape
        _, S, _ = keys.shape
        H = self.n_heads
        queries = self.query_projection(queries).view(B, L, H, -1)
        keys = self.key_projection(keys).view(B, S, H, -1)
        values = self.value_projection(values).view(B, S, H, -1)

        out, attn = self.inner_attention(
            queries,
            keys,
            values,
            attn_mask
        )
        out = out.view(B, L, -1)
        out = self.out_projection(out)
        return out, attn


class FullAttention(nn.Module):
    def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
        super(FullAttention, self).__init__()
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, queries, keys, values, attn_mask):
        B, L, H, E = queries.shape
        _, S, _, D = values.shape
        scale = self.scale or 1. / sqrt(E)
        scores = torch.einsum("blhe,bshe->bhls", queries, keys)
        if self.mask_flag:
            if attn_mask is None:
                attn_mask = TriangularCausalMask(B, L, device=queries.device)
            scores.masked_fill_(attn_mask.mask, -np.inf)
        A = self.dropout(torch.softmax(scale * scores, dim=-1))
        V = torch.einsum("bhls,bshd->blhd", A, values)
        if self.output_attention:
            return (V.contiguous(), A)
        else:
            return (V.contiguous(), None)


class GraphBlock(nn.Module):
    """Learns an adaptive adjacency over the c_out variables and mixes information across them with a graph convolution."""

    def __init__(self, c_out, d_model, conv_channel, skip_channel,
                 gcn_depth, dropout, propalpha, seq_len, node_dim):
        super(GraphBlock, self).__init__()

        self.nodevec1 = nn.Parameter(torch.randn(c_out, node_dim), requires_grad=True)
        self.nodevec2 = nn.Parameter(torch.randn(node_dim, c_out), requires_grad=True)
        self.start_conv = nn.Conv2d(1, conv_channel, (d_model - c_out + 1, 1))
        self.gconv1 = mixprop(conv_channel, skip_channel, gcn_depth, dropout, propalpha)
        self.gelu = nn.GELU()
        self.end_conv = nn.Conv2d(skip_channel, seq_len, (1, seq_len))
        self.linear = nn.Linear(c_out, d_model)
        self.norm = nn.LayerNorm(d_model)

    # x: (B, T, d_model). The learned node embeddings define an adaptive
    # adjacency matrix, and the graph convolution fits a complex mapping f(x)
    # across variables before the residual connection.
    def forward(self, x):
        adp = F.softmax(F.relu(torch.mm(self.nodevec1, self.nodevec2)), dim=1)
        out = x.unsqueeze(1).transpose(2, 3)
        out = self.start_conv(out)
        out = self.gelu(self.gconv1(out, adp))
        out = self.end_conv(out).squeeze(-1)
        out = self.linear(out)

        return self.norm(x + out)


class nconv(nn.Module):
    """Graph propagation step: aggregates node features along the adjacency matrix."""

    def __init__(self):
        super(nconv, self).__init__()

    def forward(self, x, A):
        # x: (N, C, W, L), A: (V, W) -> (N, C, V, L)
        x = torch.einsum('ncwl,vw->ncvl', (x, A))
        return x.contiguous()


class linear(nn.Module):
    """1x1 convolution used as a channel-wise linear layer."""

    def __init__(self, c_in, c_out, bias=True):
        super(linear, self).__init__()
        self.mlp = torch.nn.Conv2d(c_in, c_out, kernel_size=(1, 1), padding=(0, 0), stride=(1, 1), bias=bias)

    def forward(self, x):
        return self.mlp(x)


class mixprop(nn.Module):
    """Mix-hop graph propagation: concatenates node states from 0..gdep hops and fuses them with a 1x1 convolution."""

    def __init__(self, c_in, c_out, gdep, dropout, alpha):
        super(mixprop, self).__init__()
        self.nconv = nconv()
        self.mlp = linear((gdep + 1) * c_in, c_out)
        self.gdep = gdep
        self.dropout = dropout
        self.alpha = alpha

    def forward(self, x, adj):
        # Add self-loops and row-normalise the adjacency matrix.
        adj = adj + torch.eye(adj.size(0)).to(x.device)
        d = adj.sum(1)
        h = x
        out = [h]
        a = adj / d.view(-1, 1)
        for i in range(self.gdep):
            # Retain a fraction alpha of the original signal at every hop.
            h = self.alpha * x + (1 - self.alpha) * self.nconv(h, a)
            out.append(h)
        ho = torch.cat(out, dim=1)
        ho = self.mlp(ho)
        return ho


class simpleVIT(nn.Module):
    """ViT-style encoder: a convolution embeds the input and the spatial grid is flattened into tokens, followed by depth rounds of attention and feed-forward."""

    def __init__(self, in_channels, emb_size, patch_size=2, depth=1, num_heads=4, dropout=0.1, init_weight=True):
        super(simpleVIT, self).__init__()
        self.emb_size = emb_size
        self.depth = depth
        self.to_patch = nn.Sequential(
            nn.Conv2d(in_channels, emb_size, 2 * patch_size + 1, padding=patch_size),
            Rearrange('b e (h) (w) -> b (h w) e'),
        )
        self.layers = nn.ModuleList([])
        for _ in range(self.depth):
            self.layers.append(nn.ModuleList([
                nn.LayerNorm(emb_size),
                MultiHeadAttention(emb_size, num_heads, dropout),
                FeedForward(emb_size, emb_size)
            ]))

        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        B, N, _, P = x.shape
        x = self.to_patch(x)
        for norm, attn, ff in self.layers:
            x = attn(norm(x)) + x
            x = ff(x) + x

        x = x.transpose(1, 2).reshape(B, self.emb_size, -1, P)
        return x


class MultiHeadAttention(nn.Module):
    """Multi-head self-attention over token sequences, used inside simpleVIT."""

    def __init__(self, emb_size, num_heads, dropout):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.keys = nn.Linear(emb_size, emb_size)
        self.queries = nn.Linear(emb_size, emb_size)
        self.values = nn.Linear(emb_size, emb_size)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.num_heads)
        keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.num_heads)
        values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
        if mask is not None:
            fill_value = torch.finfo(torch.float32).min
            energy.masked_fill_(~mask, fill_value)

        # Scaling is applied to the attention weights after the softmax here.
        scaling = self.emb_size ** (1 / 2)
        att = F.softmax(energy, dim=-1) / scaling
        att = self.att_drop(att)
        # Sum over the key axis.
        out = torch.einsum('bhal, bhlv -> bhav', att, values)
        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.projection(out)
        return out


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )

    def forward(self, x):
        return self.net(x)
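As a quick sanity check on the tensor shapes above, the sketch below runs a dummy batch through GraphBlock and Attention_Block. It is not part of the commit, and the hyperparameter values (conv_channel, skip_channel, gcn_depth, propalpha, node_dim) are illustrative; GraphBlock assumes the temporal length of the input equals seq_len.

# Illustrative smoke test, assuming this file is importable as layers.MSGBlock.
import torch

from layers.MSGBlock import Attention_Block, GraphBlock

B, T, d_model, c_out = 4, 96, 512, 7
x = torch.randn(B, T, d_model)

graph = GraphBlock(c_out=c_out, d_model=d_model, conv_channel=32, skip_channel=32,
                   gcn_depth=2, dropout=0.1, propalpha=0.3, seq_len=T, node_dim=10)
print(graph(x).shape)   # torch.Size([4, 96, 512]): the residual keeps (B, T, d_model)

attn = Attention_Block(d_model=d_model, n_heads=8, dropout=0.1)
print(attn(x).shape)    # torch.Size([4, 96, 512]): self-attention over the T axis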
