Skip to content

Commit ab975bf

Browse files
committed
Add support for websockets (#4578)
1 parent a9eed2d commit ab975bf

File tree

2 files changed

+261
-8
lines changed

2 files changed

+261
-8
lines changed

scapy/contrib/websocket.py

Lines changed: 253 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,253 @@
1+
# SPDX-License-Identifier: GPL-2.0-or-later
2+
# This file is part of Scapy
3+
# See https://scapy.net/ for more information
4+
# Copyright (C) 2024 Lucas Drufva <[email protected]>
5+
6+
# scapy.contrib.description = WebSocket
7+
# scapy.contrib.status = loads
8+
9+
# Based on rfc6455
10+
11+
import struct
12+
import base64
13+
import zlib
14+
from hashlib import sha1
15+
from scapy.fields import (BitFieldLenField, Field, BitField, BitEnumField, ConditionalField, XIntField, FieldLenField, XNBytesField)
16+
from scapy.layers.http import HTTPRequest, HTTPResponse
17+
from scapy.layers.inet import TCP
18+
from scapy.packet import Packet
19+
from scapy.error import Scapy_Exception
20+
21+
22+
class PayloadLenField(BitFieldLenField):
23+
24+
def __init__(self, name, default, length_of, size=0, tot_size=0, end_tot_size=0):
25+
# Initialize with length_of (like in BitFieldLenField) and lengthFrom (like in BitLenField)
26+
super().__init__(name, default, size, length_of=length_of, tot_size=tot_size, end_tot_size=end_tot_size)
27+
28+
def getfield(self, pkt, s):
29+
s, _ = s
30+
# Get the 7-bit field (first byte)
31+
length_byte = s[0] & 0x7F
32+
s = s[1:]
33+
34+
if length_byte <= 125:
35+
# 7-bit length
36+
return s, length_byte
37+
elif length_byte == 126:
38+
# 16-bit length
39+
length = struct.unpack("!H", s[:2])[0] # Read 2 bytes
40+
s = s[2:]
41+
return s, length
42+
elif length_byte == 127:
43+
# 64-bit length
44+
length = struct.unpack("!Q", s[:8])[0] # Read 8 bytes
45+
s = s[8:]
46+
return s, length
47+
48+
def addfield(self, pkt, s, val):
49+
p_field, p_val = pkt.getfield_and_val(self.length_of)
50+
val = p_field.i2len(pkt, p_val)
51+
52+
if val <= 125:
53+
self.size = 7
54+
return super().addfield(pkt, s, val)
55+
elif val <= 0xFFFF:
56+
self.size = 7+16
57+
s, _, masked = s
58+
return s + struct.pack("!BH", 126 | masked, val)
59+
elif val <= 0xFFFFFFFFFFFFFFFF:
60+
self.size = 7+64
61+
s, _, masked = s
62+
return s + struct.pack("!BQ", 127 | masked, val)
63+
else:
64+
raise Scapy_Exception("%s: Payload length too large" %
65+
self.__class__.__name__)
66+
67+
68+
69+
class PayloadField(Field):
70+
"""
71+
Field for handling raw byte payloads with dynamic size.
72+
The length of the payload is described by a preceding PayloadLenField.
73+
"""
74+
__slots__ = ["lengthFrom"]
75+
76+
def __init__(self, name, lengthFrom):
77+
"""
78+
:param name: Field name
79+
:param lengthFrom: Field name that provides the length of the payload
80+
"""
81+
super(PayloadField, self).__init__(name, None)
82+
self.lengthFrom = lengthFrom
83+
84+
def getfield(self, pkt, s):
85+
# Fetch the length from the field that specifies the length
86+
length = getattr(pkt, self.lengthFrom)
87+
payloadData = s[:length]
88+
89+
if pkt.mask:
90+
key = struct.pack("I", pkt.maskingKey)[::-1]
91+
data_int = int.from_bytes(payloadData, 'big')
92+
mask_repeated = key * (len(payloadData) // 4) + key[: len(payloadData) % 4]
93+
mask_int = int.from_bytes(mask_repeated, 'big')
94+
payloadData = (data_int ^ mask_int).to_bytes(len(payloadData), 'big')
95+
96+
if("permessage-deflate" in pkt.extensions):
97+
try:
98+
payloadData = pkt.decoder[0](payloadData + b"\x00\x00\xff\xff")
99+
except Exception:
100+
# Failed to decompress payload
101+
pass
102+
103+
return s[length:], payloadData
104+
105+
def addfield(self, pkt, s, val):
106+
# Ensure val is bytes and append the data to the packet
107+
return s + bytes(val)
108+
109+
def i2len(self, pkt, val):
110+
# Length of the payload in bytes
111+
return len(val)
112+
113+
class WebSocket(Packet):
114+
__slots__ = ["extensions", "decoder"]
115+
116+
name = "WebSocket"
117+
fields_desc = [
118+
BitField("fin", 0, 1),
119+
BitField("rsv", 0, 3),
120+
BitEnumField("opcode", 0, 4,
121+
{
122+
0x0: "none",
123+
0x1: "text",
124+
0x2: "binary",
125+
0x8: "close",
126+
0x9: "ping",
127+
0xA: "pong",
128+
}),
129+
BitField("mask", 0, 1),
130+
PayloadLenField("payloadLen", 0, length_of="wsPayload", size=1),
131+
ConditionalField(XNBytesField("maskingKey", 0, sz=4), lambda pkt: pkt.mask == 1),
132+
PayloadField("wsPayload", lengthFrom="payloadLen")
133+
]
134+
135+
def __init__(self, pkt=None, extensions=[], decoder=None, *args, **fields):
136+
self.extensions = extensions
137+
self.decoder = decoder
138+
super().__init__(_pkt=pkt, *args, **fields)
139+
140+
def extract_padding(self, s):
141+
return '', s
142+
143+
@classmethod
144+
def tcp_reassemble(cls, data, metadata, session):
145+
# data = the reassembled data from the same request/flow
146+
# metadata = empty dictionary, that can be used to store data
147+
# during TCP reassembly
148+
# session = a dictionary proper to the bidirectional TCP session,
149+
# that can be used to store anything
150+
# [...]
151+
# If the packet is available, return it. Otherwise don't.
152+
# Whenever you return a packet, the buffer will be discarded.
153+
154+
155+
HANDSHAKE_STATE_CLIENT_OPEN = 0
156+
HANDSHAKE_STATE_SERVER_OPEN = 1
157+
HANDSHAKE_STATE_OPEN = 2
158+
159+
if "handshake-state" not in session:
160+
session["handshake-state"] = HANDSHAKE_STATE_CLIENT_OPEN
161+
162+
if "extensions" not in session:
163+
session["extensions"] = {}
164+
165+
166+
if session["handshake-state"] == HANDSHAKE_STATE_CLIENT_OPEN:
167+
ht = HTTPRequest(data)
168+
169+
if ht.Method != b"GET":
170+
return None
171+
172+
if not ht.Upgrade or ht.Upgrade.lower() != b"websocket":
173+
return None
174+
175+
if b"Sec-WebSocket-Key" not in ht.Unknown_Headers:
176+
return None
177+
178+
179+
session["handshake-key"] = ht.Unknown_Headers[b"Sec-WebSocket-Key"]
180+
181+
if "original" in metadata:
182+
session["server-port"] = metadata["original"][TCP].dport
183+
else:
184+
print("No original packet")
185+
186+
session["handshake-state"] = HANDSHAKE_STATE_SERVER_OPEN
187+
188+
return ht
189+
190+
elif session["handshake-state"] == HANDSHAKE_STATE_SERVER_OPEN:
191+
ht = HTTPResponse(data)
192+
193+
if not ht.Upgrade.lower() == b"websocket":
194+
return None
195+
196+
# Verify key-accept handshake:
197+
correct_accept = base64.b64encode(sha1(session["handshake-key"] + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11".encode()).digest())
198+
if ht.Unknown_Headers[b"Sec-WebSocket-Accept"] != correct_accept:
199+
#TODO handle or Logg wrong accept key
200+
pass
201+
202+
if b"Sec-WebSocket-Extensions" in ht.Unknown_Headers:
203+
session["extensions"] = {}
204+
for extension in ht.Unknown_Headers[b"Sec-WebSocket-Extensions"].decode().strip().split(";"):
205+
key_value_pair = extension.split("=", 1) + [None]
206+
session["extensions"][key_value_pair[0].strip()] = key_value_pair[1]
207+
208+
if "permessage-deflate" in session["extensions"]:
209+
def create_decompressor(window_bits):
210+
decoder = zlib.decompressobj(wbits=-window_bits)
211+
def decomp(data):
212+
nonlocal decoder
213+
return decoder.decompress(data, 0)
214+
215+
def reset():
216+
nonlocal decoder
217+
nonlocal window_bits
218+
decoder = zlib.decompressobj(wbits=-window_bits)
219+
220+
return (decomp, reset)
221+
222+
# Default values
223+
client_wb = 12
224+
server_wb = 15
225+
226+
# Check for new values in extensions header
227+
if "client_max_window_bits" in session["extensions"]:
228+
client_wb = int(session["extensions"]["client_max_window_bits"])
229+
230+
if "server_max_window_bits" in session["extensions"]:
231+
server_wb = int(session["extensions"]["server_max_window_bits"])
232+
233+
234+
session["server-decoder"] = create_decompressor(client_wb)
235+
session["client-decoder"] = create_decompressor(server_wb)
236+
237+
238+
session["handshake-state"] = HANDSHAKE_STATE_OPEN
239+
240+
return ht
241+
242+
243+
# Handshake is done:
244+
if "original" not in metadata:
245+
return
246+
247+
if "permessage-deflate" in session["extensions"]:
248+
is_server = True if metadata["original"][TCP].sport == session["server-port"] else False
249+
ws = WebSocket(bytes(data), extensions=session["extensions"], decoder = session["server-decoder"] if is_server else session["client-decoder"])
250+
return ws
251+
else:
252+
ws = WebSocket(bytes(data), extensions=session["extensions"])
253+
return ws

scapy/layers/http.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -529,10 +529,10 @@ def do_dissect(self, s):
529529
"""From the HTTP packet string, populate the scapy object"""
530530
first_line, body = _dissect_headers(self, s)
531531
try:
532-
Method, Path, HTTPVersion = re.split(br"\s+", first_line, maxsplit=2)
533-
self.setfieldval('Method', Method)
534-
self.setfieldval('Path', Path)
535-
self.setfieldval('Http_Version', HTTPVersion)
532+
version_status_reason = re.split(br"\s+", first_line, maxsplit=2) + [None]
533+
self.setfieldval('Http_Version', version_status_reason[0])
534+
self.setfieldval('Status_Code', version_status_reason[1])
535+
self.setfieldval('Reason_Phrase', version_status_reason[2])
536536
except ValueError:
537537
pass
538538
if body:
@@ -573,10 +573,10 @@ def do_dissect(self, s):
573573
''' From the HTTP packet string, populate the scapy object '''
574574
first_line, body = _dissect_headers(self, s)
575575
try:
576-
HTTPVersion, Status, Reason = re.split(br"\s+", first_line, maxsplit=2)
577-
self.setfieldval('Http_Version', HTTPVersion)
578-
self.setfieldval('Status_Code', Status)
579-
self.setfieldval('Reason_Phrase', Reason)
576+
version_status_reason = re.split(br"\s+", first_line, maxsplit=2) + [None]
577+
self.setfieldval('Http_Version', version_status_reason[0])
578+
self.setfieldval('Status_Code', version_status_reason[1])
579+
self.setfieldval('Reason_Phrase', version_status_reason[2])
580580
except ValueError:
581581
pass
582582
if body:

0 commit comments

Comments
 (0)