Skip to content

Commit 9af68f8

Browse files
fix: handle close code on ws client properly
1 parent a61a0ae commit 9af68f8

File tree

4 files changed

+218
-101
lines changed

4 files changed

+218
-101
lines changed

examples/websocket_client/main.mbt

Lines changed: 85 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -13,58 +13,102 @@
1313
// limitations under the License.
1414

1515
///|
16-
/// WebSocket client example
16+
/// WebSocket client example - Autobahn Test Suite Client
1717
///
18-
/// This demonstrates how to connect to a WebSocket server,
19-
/// send messages, and receive responses.
18+
/// This implements the Autobahn test suite client logic:
19+
/// 1. Get case count
20+
/// 2. Run each test case (echo all messages)
21+
/// 3. Update reports
2022
async fn main {
21-
connect_to_echo_server()
23+
run_autobahn_tests()
2224
}
2325

2426
///|
25-
pub async fn connect_to_echo_server() -> Unit {
26-
println("Connecting to WebSocket echo server at localhost:8080")
27+
/// Run Autobahn WebSocket test suite
28+
pub async fn run_autobahn_tests() -> Unit {
29+
let host = @sys.get_env_vars()
30+
.get("WS_TEST_HOST")
31+
.unwrap_or("127.0.0.1".to_string())
32+
let port = 9001
33+
let agent = "moonbit-async-websocket"
2734

28-
// Connect to the server
29-
let client = @websocket.Client::connect("0.0.0.0", "/", port=8080)
30-
println("Connected successfully!")
35+
// Step 1: Get case count
36+
@stdio.stdout.write("Getting case count from \{host}:\{port}...\n")
37+
let case_count = get_case_count(host, port)
38+
@stdio.stdout.write("Ok, will run \{case_count} cases\n\n")
3139

32-
// Send some test messages
33-
let test_messages = [
34-
"Hello, WebSocket!", "This is a test message", "MoonBit WebSocket client works!",
35-
"Final message",
36-
]
37-
for message in test_messages {
38-
// Send text message
39-
println("Sending: \{message}")
40-
client.send_text(message)
40+
// Step 2: Run each test case
41+
for case_id = 1; case_id <= case_count; case_id = case_id + 1 {
42+
@stdio.stdout.write(
43+
"Running test case \{case_id}/\{case_count} as user agent \{agent}\n",
44+
)
45+
run_test_case(host, port, case_id, agent)
46+
}
47+
48+
// Step 3: Update reports
49+
@stdio.stdout.write("\nUpdating reports...\n")
50+
update_reports(host, port, agent)
51+
@stdio.stdout.write("All tests completed!\n")
52+
}
4153

42-
// Receive echo response
43-
let response = client.receive()
44-
match response {
45-
@websocket.Text(text) => println("Received: \{text}")
46-
@websocket.Binary(data) =>
47-
println("Received binary data (\{data.length()} bytes)")
54+
///|
55+
/// Get the total number of test cases
56+
async fn get_case_count(host : String, port : Int) -> Int {
57+
let client = @websocket.Client::connect(host, "/getCaseCount", port~)
58+
let message = client.receive()
59+
client.close()
60+
match message {
61+
@websocket.Message::Text(text) => {
62+
let count_str = text.to_string()
63+
@strconv.parse_int(count_str)
64+
}
65+
_ => {
66+
@stdio.stdout.write("Error: Expected text message with case count\n")
67+
0
4868
}
4969
}
70+
}
5071

51-
// Small delay between messages
52-
// Note: In a real implementation, you might want to add a sleep function
53-
// For now, we'll just continue immediately
54-
55-
// Test binary message
56-
println("Sending binary data...")
57-
let binary_data = @encoding/utf8.encode("Binary test data")
58-
client.send_binary(binary_data)
59-
let binary_response = client.receive()
60-
match binary_response {
61-
@websocket.Text(text) => println("Received text response: \{text}")
62-
@websocket.Binary(data) =>
63-
println("Received binary response (\{data.length()} bytes)")
72+
///|
73+
/// Run a single test case - echo all messages back to server
74+
async fn run_test_case(
75+
host : String,
76+
port : Int,
77+
case_id : Int,
78+
agent : String,
79+
) -> Unit {
80+
let path = "/runCase?case=\{case_id}&agent=\{agent}"
81+
let client = @websocket.Client::connect(host, path, port~)
82+
for {
83+
let message = client.receive() catch {
84+
@websocket.ConnectionClosed(_, _) =>
85+
// Test case completed
86+
break
87+
err => {
88+
@stdio.stdout.write("Error in case \{case_id}: \{err}\n")
89+
break
90+
}
91+
}
92+
// Echo the message back (core test logic)
93+
match message {
94+
@websocket.Message::Text(text) => client.send_text(text)
95+
@websocket.Message::Binary(data) => client.send_binary(data)
96+
}
6497
}
98+
client.close()
99+
}
65100

66-
// Close the connection
67-
println("Closing connection...")
68-
client.send_close()
69-
println("Client example completed")
101+
///|
102+
/// Update test reports on the server
103+
async fn update_reports(host : String, port : Int, agent : String) -> Unit {
104+
let path = "/updateReports?agent=\{agent}"
105+
let client = @websocket.Client::connect(host, path, port~)
106+
// Wait for server to close the connection
107+
ignore(
108+
client.receive() catch {
109+
@websocket.ConnectionClosed(_, _) => @websocket.Message::Text("")
110+
_ => @websocket.Message::Text("")
111+
},
112+
)
113+
client.close()
70114
}
Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
{
22
"import": [
33
"moonbitlang/async/websocket",
4-
"moonbitlang/async"
4+
"moonbitlang/async",
5+
"moonbitlang/async/stdio",
6+
"moonbitlang/async/io",
7+
"moonbitlang/x/sys"
58
],
69
"is-main": true
710
}

src/websocket/client.mbt

Lines changed: 128 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,12 @@ pub async fn Client::connect(
5858
let key = base64_encode(nonce.unsafe_reinterpret_as_bytes())
5959
let request = "GET \{path} HTTP/1.1\r\n"
6060
conn.write(request)
61-
conn.write("Host: \{host}\r\n")
61+
let host_header = if port == 80 || port == 443 {
62+
host
63+
} else {
64+
"\{host}:\{port}"
65+
}
66+
conn.write("Host: \{host_header}\r\n")
6267
conn.write("Upgrade: websocket\r\n")
6368
conn.write("Connection: Upgrade\r\n")
6469
conn.write("Sec-WebSocket-Key: \{key}\r\n")
@@ -149,17 +154,17 @@ pub async fn Client::send_close(
149154
code? : CloseCode = Normal,
150155
reason? : String,
151156
) -> Unit {
152-
if self.closed is Some(_) {
153-
return
157+
if self.closed is Some(e) {
158+
raise e
154159
}
155-
let mut payload = FixedArray::make(0, b'\x00')
156160
let code_int = code.to_int()
157-
let reason_bytes = if reason is Some(r) {
158-
@encoding/utf8.encode(r)
159-
} else {
160-
b""
161+
let reason_bytes = @encoding/utf8.encode(reason.unwrap_or(""))
162+
if reason_bytes.length() > 123 {
163+
// Close reason too long
164+
// TODO: should we close the connection anyway?
165+
fail("Close reason too long")
161166
}
162-
payload = FixedArray::make(2 + reason_bytes.length(), b'\x00')
167+
let payload = FixedArray::make(2 + reason_bytes.length(), b'\x00')
163168
payload.unsafe_write_uint16_be(0, code_int.to_uint16())
164169
payload.blit_from_bytesview(2, reason_bytes)
165170
write_frame(
@@ -169,12 +174,16 @@ pub async fn Client::send_close(
169174
payload.unsafe_reinterpret_as_bytes(),
170175
self.rand.int().to_be_bytes(),
171176
)
177+
// Wait until the server acknowledges the close
178+
ignore(read_frame(self.conn)) catch {
179+
_ => ()
180+
}
172181
self.closed = Some(ConnectionClosed(code, reason))
173182
}
174183
175184
///|
176185
/// Send a text message
177-
pub async fn Client::send_text(self : Client, text : String) -> Unit {
186+
pub async fn Client::send_text(self : Client, text : StringView) -> Unit {
178187
if self.closed is Some(code) {
179188
raise code
180189
}
@@ -190,7 +199,7 @@ pub async fn Client::send_text(self : Client, text : String) -> Unit {
190199
191200
///|
192201
/// Send a binary message
193-
pub async fn Client::send_binary(self : Client, data : Bytes) -> Unit {
202+
pub async fn Client::send_binary(self : Client, data : BytesView) -> Unit {
194203
if self.closed is Some(code) {
195204
raise code
196205
}
@@ -247,67 +256,127 @@ pub async fn Client::receive(self : Client) -> Message {
247256
}
248257
let frames : Array[Frame] = []
249258
let mut first_opcode : OpCode? = None
250-
for {
259+
while self.closed is None {
251260
let frame = read_frame(self.conn)
252261
253262
// Handle control frames immediately
254263
match frame.opcode {
255-
OpCode::Close => {
264+
Close => {
256265
// Parse close code and reason
266+
// Ref: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.1
257267
let mut close_code = Normal
258268
let mut reason : String? = None
259-
if frame.payload is [u16be(code), .. data] {
269+
if frame.payload is [u16be(code), .. rest] {
260270
close_code = CloseCode::from_int(code.reinterpret_as_int())
261-
reason = Some(
262-
@encoding/utf8.decode(data) catch {
263-
_ => {
264-
close_code = ProtocolError
265-
""
266-
}
267-
},
268-
)
271+
reason = Some(@encoding/utf8.decode(rest)) catch {
272+
_ => {
273+
// Invalid reason, fail fast
274+
close_code = ProtocolError
275+
None
276+
}
277+
}
278+
} else {
279+
guard frame.payload is [] else {
280+
// Invalid close payload
281+
close_code = ProtocolError
282+
}
269283
}
270-
self.closed = Some(ConnectionClosed(close_code, reason))
271-
raise ConnectionClosed(close_code, reason)
284+
// If we didn't send close first, respond with close
285+
if self.closed is None {
286+
// Echo the close frame back and close
287+
self.send_close(code=close_code, reason?) catch {
288+
_ => ()
289+
}
290+
}
291+
continue
272292
}
273-
OpCode::Ping => {
293+
Ping =>
274294
// Auto-respond to ping with pong
275295
self.pong(data=frame.payload)
276-
continue
277-
}
278-
OpCode::Pong =>
296+
Pong =>
279297
// Ignore pong frames
280-
continue
281-
_ => ()
282-
}
283-
284-
// Track the first opcode for message type
285-
if first_opcode is None {
286-
first_opcode = Some(frame.opcode)
287-
}
288-
frames.push(frame)
289-
290-
// If this is the final frame, assemble the message
291-
if frame.fin {
292-
break
293-
}
294-
}
295-
296-
// Assemble message from frames
297-
let total_size = frames.fold(init=0, fn(acc, f) { acc + f.payload.length() })
298-
let data = FixedArray::make(total_size, b'\x00')
299-
let mut offset = 0
300-
for frame in frames {
301-
let payload_arr = frame.payload.to_fixedarray()
302-
for i = 0; i < payload_arr.length(); i = i + 1 {
303-
data[offset + i] = payload_arr[i]
298+
()
299+
Text =>
300+
if first_opcode is Some(_) {
301+
// Ref https://datatracker.ietf.org/doc/html/rfc6455#section-5.4
302+
// We don't have extensions, so fragments MUST NOT be interleaved
303+
self.send_close(code=ProtocolError) catch {
304+
_ => ()
305+
}
306+
} else if frame.fin {
307+
// Single-frame text message
308+
return Message::Text(@encoding/utf8.decode(frame.payload)) catch {
309+
_ => {
310+
// Ref https://datatracker.ietf.org/doc/html/rfc6455#section-8.1
311+
// We MUST Fail the WebSocket Connection if the payload is not
312+
// valid UTF-8
313+
self.send_close(code=InvalidFramePayload) catch {
314+
_ => ()
315+
}
316+
continue
317+
}
318+
}
319+
} else {
320+
first_opcode = Some(Text)
321+
// Start of fragmented text message
322+
frames.push(frame)
323+
}
324+
Binary =>
325+
if first_opcode is Some(_) {
326+
// Ref https://datatracker.ietf.org/doc/html/rfc6455#section-5.4
327+
// We don't have extensions, so fragments MUST NOT be interleaved
328+
self.send_close(code=ProtocolError) catch {
329+
_ => ()
330+
}
331+
} else if frame.fin {
332+
// Single-frame binary message
333+
return Message::Binary(frame.payload)
334+
} else {
335+
first_opcode = Some(Binary)
336+
frames.push(frame)
337+
}
338+
Continuation => {
339+
if first_opcode is None {
340+
// Continuation frame without a starting frame
341+
self.send_close(code=ProtocolError) catch {
342+
_ => ()
343+
}
344+
continue
345+
}
346+
frames.push(frame)
347+
if frame.fin {
348+
// Final fragment received, assemble message
349+
let total_size = frames.fold(init=0, fn(acc, f) {
350+
acc + f.payload.length()
351+
})
352+
let data = FixedArray::make(total_size, b'\x00')
353+
let mut offset = 0
354+
for f in frames {
355+
data.blit_from_bytes(offset, f.payload, 0, f.payload.length())
356+
offset += f.payload.length()
357+
}
358+
let message_data = data.unsafe_reinterpret_as_bytes()
359+
match first_opcode {
360+
Some(Text) => {
361+
let text = @encoding/utf8.decode(message_data) catch {
362+
_ => {
363+
self.send_close(code=InvalidFramePayload) catch {
364+
_ => ()
365+
}
366+
continue
367+
}
368+
}
369+
return Message::Text(text)
370+
}
371+
Some(Binary) => return Message::Binary(message_data)
372+
_ => panic()
373+
}
374+
// Reset for next message
375+
frames.clear()
376+
first_opcode = None
377+
}
378+
}
304379
}
305-
offset += payload_arr.length()
306-
}
307-
let message_data = data.unsafe_reinterpret_as_bytes()
308-
match first_opcode {
309-
Some(OpCode::Text) => Text(@encoding/utf8.decode_lossy(message_data))
310-
Some(OpCode::Binary) => Binary(message_data)
311-
_ => Binary(message_data)
312380
}
381+
raise self.closed.unwrap()
313382
}

0 commit comments

Comments
 (0)