#!/usr/bin/env python3

import requests
import json
import sys
import time
8
def create_long_prompt(repeat: int = 200) -> str:
    """Create a very long prompt to ensure multiple batches are processed.

    The prompt needs to span several server-side decode batches so that the
    progress updates are clearly observable during prompt processing.

    Args:
        repeat: Number of times the base sentence is repeated.  The default
            (200) matches the original behaviour and is long enough to force
            multiple batches.

    Returns:
        The repeated prompt text.
    """
    base_text = (
        "This is a comprehensive test prompt designed to verify the progress "
        "functionality thoroughly. "
    ) * repeat
    return base_text
14
+
15
def test_completion_endpoint_progress(server_url):
    """Test progress functionality on /completion endpoint with long prompt"""
    print("\n=== Testing /completion endpoint progress ===")
    print("Using a very long prompt to clearly demonstrate progress...")

    prompt = create_long_prompt()
    print(f"Prompt length: {len(prompt)} characters")

    payload = {
        "prompt": prompt,
        "stream": True,
        "return_progress": True,
        "max_tokens": 10,  # Small number to focus on prompt processing
        "temperature": 0.7,
    }

    progress_events = []
    content_events = []

    try:
        print("Sending request...")
        resp = requests.post(f"{server_url}/completion", json=payload, stream=True)
        resp.raise_for_status()

        print("Receiving streaming response...")
        for raw in resp.iter_lines():
            if not raw:
                continue
            decoded = raw.decode('utf-8')
            if not decoded.startswith('data: '):
                continue
            chunk = decoded[6:]  # strip the SSE 'data: ' prefix
            if chunk.strip() == '[DONE]':
                break
            try:
                event = json.loads(chunk)
            except json.JSONDecodeError:
                continue
            if 'prompt_processing' in event:
                info = event['prompt_processing']
                progress_events.append(info)
                pct = info.get('progress', 0) * 100
                print(f"Progress: {pct:.1f}% ({info.get('n_prompt_tokens_processed', 'N/A')}/{info.get('n_prompt_tokens', 'N/A')})")
            elif 'content' in event and event.get('content', ''):
                content_events.append(event)

        print(f"\nReceived {len(progress_events)} progress responses")
        print(f"Received {len(content_events)} content responses")

        # Detailed analysis of the collected progress updates
        if not progress_events:
            print("❌ No progress responses received")
            return False

        print("\n=== Progress Analysis ===")
        for idx, info in enumerate(progress_events):
            pct = info.get('progress', 0) * 100
            done = info.get('n_prompt_tokens_processed', 0)
            total = info.get('n_prompt_tokens', 0)
            print(f"  Progress {idx + 1}: {pct:.1f}% ({done}/{total})")

        # Verify the final update reached (approximately) 100%
        final = progress_events[-1].get('progress', 0)
        if final >= 0.99:  # Allow for small floating point differences
            print("✅ Progress reached 100% as expected")
            return True
        print(f"❌ Progress did not reach 100% (last: {final * 100:.1f}%)")
        return False

    except Exception as e:
        print(f"Error: {e}")
        return False
87
+
88
def test_progress_disabled(server_url):
    """Test that progress is not sent when return_progress is false"""
    print("\n=== Testing progress disabled ===")

    prompt = create_long_prompt()

    payload = {
        "prompt": prompt,
        "stream": True,
        "return_progress": False,  # Disable progress
        "max_tokens": 10,
        "temperature": 0.7,
    }

    progress_events = []
    content_events = []

    try:
        print("Sending request with progress disabled...")
        resp = requests.post(f"{server_url}/completion", json=payload, stream=True)
        resp.raise_for_status()

        for raw in resp.iter_lines():
            if not raw:
                continue
            decoded = raw.decode('utf-8')
            if not decoded.startswith('data: '):
                continue
            chunk = decoded[6:]  # strip the SSE 'data: ' prefix
            if chunk.strip() == '[DONE]':
                break
            try:
                event = json.loads(chunk)
            except json.JSONDecodeError:
                continue
            if 'prompt_processing' in event:
                progress_events.append(event['prompt_processing'])
            elif 'content' in event and event.get('content', ''):
                content_events.append(event)

        print(f"Received {len(progress_events)} progress responses")
        print(f"Received {len(content_events)} content responses")

        # With return_progress off, zero progress payloads is the pass condition
        if not progress_events:
            print("✅ No progress responses received when disabled (correct)")
            return True
        print("❌ Progress responses received when disabled (incorrect)")
        return False

    except Exception as e:
        print(f"Error: {e}")
        return False
141
+
142
def test_batch_size_effect(server_url):
    """Test the effect of different batch sizes on progress reporting"""
    print("\n=== Testing batch size effect ===")

    prompt = create_long_prompt()

    # Test with different batch sizes
    for batch_size in (16, 32, 64):
        print(f"\nTesting with batch size: {batch_size}")

        payload = {
            "prompt": prompt,
            "stream": True,
            "return_progress": True,
            "max_tokens": 10,
            "temperature": 0.7,
        }

        progress_events = []

        try:
            # Note: We can't directly set batch_size in the request, but we can observe the effect
            # by counting progress responses - smaller batch sizes should result in more progress updates
            resp = requests.post(f"{server_url}/completion", json=payload, stream=True)
            resp.raise_for_status()

            for raw in resp.iter_lines():
                if not raw:
                    continue
                decoded = raw.decode('utf-8')
                if not decoded.startswith('data: '):
                    continue
                chunk = decoded[6:]
                if chunk.strip() == '[DONE]':
                    break
                try:
                    event = json.loads(chunk)
                except json.JSONDecodeError:
                    continue
                if 'prompt_processing' in event:
                    progress_events.append(event['prompt_processing'])

            print(f"  Progress responses: {len(progress_events)}")

        except Exception as e:
            print(f"  Error: {e}")
            continue

    print("✅ Batch size effect test completed")
    return True
193
+
194
def main():
    """Entry point: parse the server URL argument and run every progress test."""
    if len(sys.argv) != 2:
        print("Usage: python3 test-progress-feature.py <server_url>")
        print("Example: python3 test-progress-feature.py http://localhost:8081")
        sys.exit(1)

    server_url = sys.argv[1]

    print("Testing progress feature with comprehensive test cases...")
    print(f"Server URL: {server_url}")
    print("This test uses a very long prompt to clearly demonstrate progress functionality.")

    # Wait a moment for server to be ready
    time.sleep(2)

    # Run the tests in order and remember each verdict under its summary label
    results = {
        "Completion endpoint progress": test_completion_endpoint_progress(server_url),
        "Progress disabled": test_progress_disabled(server_url),
        "Batch size effect": test_batch_size_effect(server_url),
    }

    print("\n=== Test Summary ===")
    for label, passed in results.items():
        print(f"{label}: {'✅ PASS' if passed else '❌ FAIL'}")

    if all(results.values()):
        print("\n🎉 All tests passed!")
        print("The progress feature is working correctly with long prompts and small batch sizes.")
        sys.exit(0)
    else:
        print("\n💥 Some tests failed!")
        sys.exit(1)
227
+
228
# Run the test suite only when executed as a script (not when imported).
if __name__ == "__main__":
    main()