1
+ # SPDX-License-Identifier: GPL-2.0-or-later
2
+ # This file is part of Scapy
3
+ # See https://scapy.net/ for more information
4
+ # Copyright (C) 2024 Lucas Drufva <[email protected] >
5
+
6
+ # scapy.contrib.description = WebSocket
7
+ # scapy.contrib.status = loads
8
+
9
+ # Based on rfc6455
10
+
11
+ import struct
12
+ import base64
13
+ import zlib
14
+ from hashlib import sha1
15
+ from scapy .fields import (BitFieldLenField , Field , BitField , BitEnumField , ConditionalField , XIntField , FieldLenField , XNBytesField )
16
+ from scapy .layers .http import HTTPRequest , HTTPResponse
17
+ from scapy .layers .inet import TCP
18
+ from scapy .packet import Packet
19
+ from scapy .error import Scapy_Exception
20
+
21
+
22
+ class PayloadLenField (BitFieldLenField ):
23
+
24
+ def __init__ (self , name , default , length_of , size = 0 , tot_size = 0 , end_tot_size = 0 ):
25
+ # Initialize with length_of (like in BitFieldLenField) and lengthFrom (like in BitLenField)
26
+ super ().__init__ (name , default , size , length_of = length_of , tot_size = tot_size , end_tot_size = end_tot_size )
27
+
28
+ def getfield (self , pkt , s ):
29
+ s , _ = s
30
+ # Get the 7-bit field (first byte)
31
+ length_byte = s [0 ] & 0x7F
32
+ s = s [1 :]
33
+
34
+ if length_byte <= 125 :
35
+ # 7-bit length
36
+ return s , length_byte
37
+ elif length_byte == 126 :
38
+ # 16-bit length
39
+ length = struct .unpack ("!H" , s [:2 ])[0 ] # Read 2 bytes
40
+ s = s [2 :]
41
+ return s , length
42
+ elif length_byte == 127 :
43
+ # 64-bit length
44
+ length = struct .unpack ("!Q" , s [:8 ])[0 ] # Read 8 bytes
45
+ s = s [8 :]
46
+ return s , length
47
+
48
+ def addfield (self , pkt , s , val ):
49
+ p_field , p_val = pkt .getfield_and_val (self .length_of )
50
+ val = p_field .i2len (pkt , p_val )
51
+
52
+ if val <= 125 :
53
+ self .size = 7
54
+ return super ().addfield (pkt , s , val )
55
+ elif val <= 0xFFFF :
56
+ self .size = 7 + 16
57
+ s , _ , masked = s
58
+ return s + struct .pack ("!BH" , 126 | masked , val )
59
+ elif val <= 0xFFFFFFFFFFFFFFFF :
60
+ self .size = 7 + 64
61
+ s , _ , masked = s
62
+ return s + struct .pack ("!BQ" , 127 | masked , val )
63
+ else :
64
+ raise Scapy_Exception ("%s: Payload length too large" %
65
+ self .__class__ .__name__ )
66
+
67
+
68
+
69
+ class PayloadField (Field ):
70
+ """
71
+ Field for handling raw byte payloads with dynamic size.
72
+ The length of the payload is described by a preceding PayloadLenField.
73
+ """
74
+ __slots__ = ["lengthFrom" ]
75
+
76
+ def __init__ (self , name , lengthFrom ):
77
+ """
78
+ :param name: Field name
79
+ :param lengthFrom: Field name that provides the length of the payload
80
+ """
81
+ super (PayloadField , self ).__init__ (name , None )
82
+ self .lengthFrom = lengthFrom
83
+
84
+ def getfield (self , pkt , s ):
85
+ # Fetch the length from the field that specifies the length
86
+ length = getattr (pkt , self .lengthFrom )
87
+ payloadData = s [:length ]
88
+
89
+ if pkt .mask :
90
+ key = struct .pack ("I" , pkt .maskingKey )[::- 1 ]
91
+ data_int = int .from_bytes (payloadData , 'big' )
92
+ mask_repeated = key * (len (payloadData ) // 4 ) + key [: len (payloadData ) % 4 ]
93
+ mask_int = int .from_bytes (mask_repeated , 'big' )
94
+ payloadData = (data_int ^ mask_int ).to_bytes (len (payloadData ), 'big' )
95
+
96
+ if ("permessage-deflate" in pkt .extensions ):
97
+ try :
98
+ payloadData = pkt .decoder [0 ](payloadData + b"\x00 \x00 \xff \xff " )
99
+ except Exception :
100
+ # Failed to decompress payload
101
+ pass
102
+
103
+ return s [length :], payloadData
104
+
105
+ def addfield (self , pkt , s , val ):
106
+ # Ensure val is bytes and append the data to the packet
107
+ return s + bytes (val )
108
+
109
+ def i2len (self , pkt , val ):
110
+ # Length of the payload in bytes
111
+ return len (val )
112
+
113
+ class WebSocket (Packet ):
114
+ __slots__ = ["extensions" , "decoder" ]
115
+
116
+ name = "WebSocket"
117
+ fields_desc = [
118
+ BitField ("fin" , 0 , 1 ),
119
+ BitField ("rsv" , 0 , 3 ),
120
+ BitEnumField ("opcode" , 0 , 4 ,
121
+ {
122
+ 0x0 : "none" ,
123
+ 0x1 : "text" ,
124
+ 0x2 : "binary" ,
125
+ 0x8 : "close" ,
126
+ 0x9 : "ping" ,
127
+ 0xA : "pong" ,
128
+ }),
129
+ BitField ("mask" , 0 , 1 ),
130
+ PayloadLenField ("payloadLen" , 0 , length_of = "wsPayload" , size = 1 ),
131
+ ConditionalField (XNBytesField ("maskingKey" , 0 , sz = 4 ), lambda pkt : pkt .mask == 1 ),
132
+ PayloadField ("wsPayload" , lengthFrom = "payloadLen" )
133
+ ]
134
+
135
+ def __init__ (self , pkt = None , extensions = [], decoder = None , * args , ** fields ):
136
+ self .extensions = extensions
137
+ self .decoder = decoder
138
+ super ().__init__ (_pkt = pkt , * args , ** fields )
139
+
140
+ def extract_padding (self , s ):
141
+ return '' , s
142
+
143
+ @classmethod
144
+ def tcp_reassemble (cls , data , metadata , session ):
145
+ # data = the reassembled data from the same request/flow
146
+ # metadata = empty dictionary, that can be used to store data
147
+ # during TCP reassembly
148
+ # session = a dictionary proper to the bidirectional TCP session,
149
+ # that can be used to store anything
150
+ # [...]
151
+ # If the packet is available, return it. Otherwise don't.
152
+ # Whenever you return a packet, the buffer will be discarded.
153
+
154
+
155
+ HANDSHAKE_STATE_CLIENT_OPEN = 0
156
+ HANDSHAKE_STATE_SERVER_OPEN = 1
157
+ HANDSHAKE_STATE_OPEN = 2
158
+
159
+ if "handshake-state" not in session :
160
+ session ["handshake-state" ] = HANDSHAKE_STATE_CLIENT_OPEN
161
+
162
+ if "extensions" not in session :
163
+ session ["extensions" ] = {}
164
+
165
+
166
+ if session ["handshake-state" ] == HANDSHAKE_STATE_CLIENT_OPEN :
167
+ ht = HTTPRequest (data )
168
+
169
+ if ht .Method != b"GET" :
170
+ return None
171
+
172
+ if not ht .Upgrade or ht .Upgrade .lower () != b"websocket" :
173
+ return None
174
+
175
+ if b"Sec-WebSocket-Key" not in ht .Unknown_Headers :
176
+ return None
177
+
178
+
179
+ session ["handshake-key" ] = ht .Unknown_Headers [b"Sec-WebSocket-Key" ]
180
+
181
+ if "original" in metadata :
182
+ session ["server-port" ] = metadata ["original" ][TCP ].dport
183
+ else :
184
+ print ("No original packet" )
185
+
186
+ session ["handshake-state" ] = HANDSHAKE_STATE_SERVER_OPEN
187
+
188
+ return ht
189
+
190
+ elif session ["handshake-state" ] == HANDSHAKE_STATE_SERVER_OPEN :
191
+ ht = HTTPResponse (data )
192
+
193
+ if not ht .Upgrade .lower () == b"websocket" :
194
+ return None
195
+
196
+ # Verify key-accept handshake:
197
+ correct_accept = base64 .b64encode (sha1 (session ["handshake-key" ] + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" .encode ()).digest ())
198
+ if ht .Unknown_Headers [b"Sec-WebSocket-Accept" ] != correct_accept :
199
+ #TODO handle or Logg wrong accept key
200
+ pass
201
+
202
+ if b"Sec-WebSocket-Extensions" in ht .Unknown_Headers :
203
+ session ["extensions" ] = {}
204
+ for extension in ht .Unknown_Headers [b"Sec-WebSocket-Extensions" ].decode ().strip ().split (";" ):
205
+ key_value_pair = extension .split ("=" , 1 ) + [None ]
206
+ session ["extensions" ][key_value_pair [0 ].strip ()] = key_value_pair [1 ]
207
+
208
+ if "permessage-deflate" in session ["extensions" ]:
209
+ def create_decompressor (window_bits ):
210
+ decoder = zlib .decompressobj (wbits = - window_bits )
211
+ def decomp (data ):
212
+ nonlocal decoder
213
+ return decoder .decompress (data , 0 )
214
+
215
+ def reset ():
216
+ nonlocal decoder
217
+ nonlocal window_bits
218
+ decoder = zlib .decompressobj (wbits = - window_bits )
219
+
220
+ return (decomp , reset )
221
+
222
+ # Default values
223
+ client_wb = 12
224
+ server_wb = 15
225
+
226
+ # Check for new values in extensions header
227
+ if "client_max_window_bits" in session ["extensions" ]:
228
+ client_wb = int (session ["extensions" ]["client_max_window_bits" ])
229
+
230
+ if "server_max_window_bits" in session ["extensions" ]:
231
+ server_wb = int (session ["extensions" ]["server_max_window_bits" ])
232
+
233
+
234
+ session ["server-decoder" ] = create_decompressor (client_wb )
235
+ session ["client-decoder" ] = create_decompressor (server_wb )
236
+
237
+
238
+ session ["handshake-state" ] = HANDSHAKE_STATE_OPEN
239
+
240
+ return ht
241
+
242
+
243
+ # Handshake is done:
244
+ if "original" not in metadata :
245
+ return
246
+
247
+ if "permessage-deflate" in session ["extensions" ]:
248
+ is_server = True if metadata ["original" ][TCP ].sport == session ["server-port" ] else False
249
+ ws = WebSocket (bytes (data ), extensions = session ["extensions" ], decoder = session ["server-decoder" ] if is_server else session ["client-decoder" ])
250
+ return ws
251
+ else :
252
+ ws = WebSocket (bytes (data ), extensions = session ["extensions" ])
253
+ return ws
0 commit comments