14
14
import threading
15
15
import time
16
16
from contextlib import contextmanager
17
- from typing import Dict , Optional
17
+ from typing import Dict , Optional , List
18
+ from collections import deque
18
19
19
20
from ..page import make_applications , render_page
20
21
from ..utils import deserialize_binary_event
21
22
from ...session import CoroutineBasedSession , ThreadBasedSession , register_session_implement_for_target
22
- from ...session .base import get_session_info_from_headers
23
+ from ...session .base import get_session_info_from_headers , Session
23
24
from ...utils import random_str , LRUDict , isgeneratorfunction , iscoroutinefunction , check_webio_js
24
25
25
26
@@ -35,7 +36,7 @@ def request_obj(self):
35
36
Return the current request object"""
36
37
pass
37
38
38
- def request_method (self ):
39
+ def request_method (self ) -> str :
39
40
"""返回当前请求的方法,大写
40
41
Return the HTTP method of the current request, uppercase"""
41
42
pass
@@ -45,29 +46,19 @@ def request_headers(self) -> Dict:
45
46
Return the header dictionary of the current request"""
46
47
pass
47
48
48
- def request_url_parameter (self , name , default = None ):
49
+ def request_url_parameter (self , name , default = None ) -> str :
49
50
"""返回当前请求的URL参数
50
51
Returns the value of the given URL parameter of the current request"""
51
52
pass
52
53
53
- def request_body (self ):
54
+ def request_body (self ) -> bytes :
54
55
"""返回当前请求的body数据
55
56
Returns the data of the current request body
56
57
57
58
:return: bytes/bytearray
58
59
"""
59
60
return b''
60
61
61
- def request_json (self ) -> Optional [Dict ]:
62
- """返回当前请求的json反序列化后的内容,若请求数据不为json格式,返回None
63
- Return the data (json deserialization) of the currently requested, if the data is not in json format, return None"""
64
- try :
65
- if self .request_headers ().get ('content-type' ) == 'application/octet-stream' :
66
- return deserialize_binary_event (self .request_body ())
67
- return json .loads (self .request_body ())
68
- except Exception :
69
- return None
70
-
71
62
def set_header (self , name , value ):
72
63
"""为当前响应设置header
73
64
Set a header for the current response"""
@@ -92,7 +83,7 @@ def get_response(self):
92
83
Get the current response object"""
93
84
pass
94
85
95
- def get_client_ip (self ):
86
+ def get_client_ip (self ) -> str :
96
87
"""获取用户的ip
97
88
Get the user's ip"""
98
89
pass
@@ -102,6 +93,56 @@ def get_client_ip(self):
102
93
_event_loop = None
103
94
104
95
96
+ class ReliableTransport :
97
+ def __init__ (self , session : Session , message_window : int = 4 ):
98
+ self .session = session
99
+ self .messages = deque ()
100
+ self .window_size = message_window
101
+ self .min_msg_id = 0 # the id of the first message in the window
102
+ self .finished_event_id = - 1 # the id of the last finished event
103
+
104
+ @staticmethod
105
+ def close_message (ack ):
106
+ return dict (
107
+ commands = [[dict (command = 'close_session' )]],
108
+ seq = ack + 1
109
+ )
110
+
111
+ def push_event (self , events : List [Dict ], seq : int ) -> int :
112
+ """Send client events to the session and return the success message count"""
113
+ if not events :
114
+ return 0
115
+
116
+ submit_cnt = 0
117
+ for eid , event in enumerate (events , start = seq ):
118
+ if eid > self .finished_event_id :
119
+ self .finished_event_id = eid # todo: use lock for check and set operation
120
+ self .session .send_client_event (event )
121
+ submit_cnt += 1
122
+
123
+ return submit_cnt
124
+
125
+ def get_response (self , ack = 0 ):
126
+ """
127
+ ack num is the number of messages that the client has received.
128
+ response is a list of messages that the client should receive, along with their min id `seq`.
129
+ """
130
+ while ack >= self .min_msg_id and self .messages :
131
+ self .messages .popleft ()
132
+ self .min_msg_id += 1
133
+
134
+ if len (self .messages ) < self .window_size :
135
+ msgs = self .session .get_task_commands ()
136
+ if msgs :
137
+ self .messages .append (msgs )
138
+
139
+ return dict (
140
+ commands = list (self .messages ),
141
+ seq = self .min_msg_id ,
142
+ ack = self .finished_event_id
143
+ )
144
+
145
+
105
146
# todo: use lock to avoid thread race condition
106
147
class HttpHandler :
107
148
"""基于HTTP的后端Handler实现
@@ -112,7 +153,7 @@ class HttpHandler:
112
153
113
154
"""
114
155
_webio_sessions = {} # WebIOSessionID -> WebIOSession()
115
- _webio_last_commands = {} # WebIOSessionID -> (last commands, commands sequence id)
156
+ _webio_transports = {} # WebIOSessionID -> ReliableTransport(), type: Dict[str, ReliableTransport]
116
157
_webio_expire = LRUDict () # WebIOSessionID -> last active timestamp. In increasing order of last active time
117
158
_webio_expire_lock = threading .Lock ()
118
159
@@ -143,23 +184,13 @@ def _remove_expired_sessions(cls, session_expire_seconds):
143
184
if session :
144
185
session .close (nonblock = True )
145
186
del cls ._webio_sessions [sid ]
187
+ del cls ._webio_transports [sid ]
146
188
147
189
@classmethod
148
190
def _remove_webio_session (cls , sid ):
149
191
cls ._webio_sessions .pop (sid , None )
150
192
cls ._webio_expire .pop (sid , None )
151
193
152
- @classmethod
153
- def get_response (cls , sid , ack = 0 ):
154
- commands , seq = cls ._webio_last_commands .get (sid , ([], 0 ))
155
- if ack == seq :
156
- webio_session = cls ._webio_sessions [sid ]
157
- commands = webio_session .get_task_commands ()
158
- seq += 1
159
- cls ._webio_last_commands [sid ] = (commands , seq )
160
-
161
- return {'commands' : commands , 'seq' : seq }
162
-
163
194
def _process_cors (self , context : HttpContext ):
164
195
"""Handling cross-domain requests: check the source of the request and set headers"""
165
196
origin = context .request_headers ().get ('Origin' , '' )
@@ -209,6 +240,14 @@ def get_cdn(self, context):
209
240
return False
210
241
return self .cdn
211
242
243
+ def read_event_data (self , context : HttpContext ) -> List [Dict ]:
244
+ try :
245
+ if context .request_headers ().get ('content-type' ) == 'application/octet-stream' :
246
+ return [deserialize_binary_event (context .request_body ())]
247
+ return json .loads (context .request_body ())
248
+ except Exception :
249
+ return []
250
+
212
251
@contextmanager
213
252
def handle_request_context (self , context : HttpContext ):
214
253
"""called when every http request"""
@@ -240,16 +279,18 @@ def handle_request_context(self, context: HttpContext):
240
279
context .set_content (html )
241
280
return context .get_response ()
242
281
243
- webio_session_id = None
282
+ ack = int (context .request_url_parameter ('ack' , 0 ))
283
+ webio_session_id = request_headers ['webio-session-id' ]
284
+ new_request = False
285
+ if webio_session_id .startswith ('NEW-' ):
286
+ new_request = True
287
+ webio_session_id = webio_session_id [4 :]
244
288
245
- # 初始请求,创建新 Session
246
- if not request_headers ['webio-session-id' ] or request_headers ['webio-session-id' ] == 'NEW' :
289
+ if new_request and webio_session_id not in cls ._webio_sessions : # 初始请求,创建新 Session
247
290
if context .request_method () == 'POST' : # 不能在POST请求中创建Session,防止CSRF攻击
248
291
context .set_status (403 )
249
292
return context .get_response ()
250
293
251
- webio_session_id = random_str (24 )
252
- context .set_header ('webio-session-id' , webio_session_id )
253
294
session_info = get_session_info_from_headers (context .request_headers ())
254
295
session_info ['user_ip' ] = context .get_client_ip ()
255
296
session_info ['request' ] = context .request_obj ()
@@ -264,17 +305,23 @@ def handle_request_context(self, context: HttpContext):
264
305
session_cls = ThreadBasedSession
265
306
webio_session = session_cls (application , session_info = session_info )
266
307
cls ._webio_sessions [webio_session_id ] = webio_session
267
- yield type (self ).WAIT_MS_ON_POST / 1000.0 # <--- <--- <--- <--- <--- <--- <--- <--- <--- <--- <--- <---
268
- elif request_headers ['webio-session-id' ] not in cls ._webio_sessions : # WebIOSession deleted
269
- context .set_content ([dict (command = 'close_session' )], json_type = True )
308
+ cls ._webio_transports [webio_session_id ] = ReliableTransport (webio_session )
309
+ yield cls .WAIT_MS_ON_POST / 1000.0 # <--- <--- <--- <--- <--- <--- <--- <--- <--- <--- <--- <---
310
+ elif webio_session_id not in cls ._webio_sessions : # WebIOSession deleted
311
+ close_msg = ReliableTransport .close_message (ack )
312
+ context .set_content (close_msg , json_type = True )
270
313
return context .get_response ()
271
314
else :
272
- webio_session_id = request_headers ['webio-session-id' ]
315
+ # in this case, the request_headers['webio-session-id'] may also startswith NEW,
316
+ # this is because the response for the previous new session request has not been received by the client,
317
+ # and the client has sent a new request with the same session id.
273
318
webio_session = cls ._webio_sessions [webio_session_id ]
274
319
275
320
if context .request_method () == 'POST' : # client push event
276
- if context .request_json () is not None :
277
- webio_session .send_client_event (context .request_json ())
321
+ seq = int (context .request_url_parameter ('seq' , 0 ))
322
+ event_data = self .read_event_data (context )
323
+ submit_cnt = cls ._webio_transports [webio_session_id ].push_event (event_data , seq )
324
+ if submit_cnt > 0 :
278
325
yield type (self ).WAIT_MS_ON_POST / 1000.0 # <--- <--- <--- <--- <--- <--- <--- <--- <--- <--- <--- <---
279
326
elif context .request_method () == 'GET' : # client pull messages
280
327
pass
@@ -283,8 +330,8 @@ def handle_request_context(self, context: HttpContext):
283
330
284
331
self .interval_cleaning ()
285
332
286
- ack = int ( context . request_url_parameter ( ' ack' , 0 ) )
287
- context .set_content (type ( self ). get_response ( webio_session_id , ack = ack ) , json_type = True )
333
+ resp = cls . _webio_transports [ webio_session_id ]. get_response ( ack )
334
+ context .set_content (resp , json_type = True )
288
335
289
336
if webio_session .closed ():
290
337
self ._remove_webio_session (webio_session_id )
0 commit comments