11package ccm
22
33import (
4+ "bytes"
45 "context"
56 "encoding/json"
67 "errors"
@@ -26,6 +27,7 @@ import (
2627 N "github.com/sagernet/sing/common/network"
2728 aTLS "github.com/sagernet/sing/common/tls"
2829
30+ "github.com/anthropics/anthropic-sdk-go"
2931 "github.com/go-chi/chi/v5"
3032 "golang.org/x/net/http2"
3133)
@@ -82,6 +84,7 @@ type Service struct {
8284 httpServer * http.Server
8385 userManager * UserManager
8486 access sync.RWMutex
87+ usageTracker * AggregatedUsage
8588}
8689
8790func NewService (ctx context.Context , logger log.ContextLogger , tag string , options option.CCMServiceOptions ) (adapter.Service , error ) {
@@ -107,6 +110,11 @@ func NewService(ctx context.Context, logger log.ContextLogger, tag string, optio
107110
108111 userManager := NewUserManager ()
109112
113+ var usageTracker * AggregatedUsage
114+ if options .UsagesPath != "" {
115+ usageTracker = NewAggregatedUsage (options .UsagesPath )
116+ }
117+
110118 service := & Service {
111119 Adapter : boxService .NewAdapter (C .TypeCCM , tag ),
112120 ctx : ctx ,
@@ -121,7 +129,8 @@ func NewService(ctx context.Context, logger log.ContextLogger, tag string, optio
121129 Network : []string {N .NetworkTCP },
122130 Listen : options .ListenOptions ,
123131 }),
124- userManager : userManager ,
132+ userManager : userManager ,
133+ usageTracker : usageTracker ,
125134 }
126135
127136 if options .TLS != nil {
@@ -148,6 +157,15 @@ func (s *Service) Start(stage adapter.StartStage) error {
148157 }
149158 s .credentials = credentials
150159
160+ if s .usageTracker != nil {
161+ err = s .usageTracker .Load ()
162+ if err != nil {
163+ s .logger .Warn ("load usage statistics: " , err )
164+ } else {
165+ s .logger .Info ("usage statistics loaded" )
166+ }
167+ }
168+
151169 router := chi .NewRouter ()
152170 router .Mount ("/" , s )
153171
@@ -230,6 +248,47 @@ func (s *Service) authenticateRequest(r *http.Request) bool {
230248 return ok
231249}
232250
251+ func (s * Service ) getUsernameFromRequest (r * http.Request ) string {
252+ if len (s .users ) == 0 {
253+ return ""
254+ }
255+ clientToken := r .Header .Get ("x-api-key" )
256+ if clientToken == "" {
257+ return ""
258+ }
259+ username , ok := s .userManager .Authenticate (clientToken )
260+ if ! ok {
261+ return ""
262+ }
263+ return username
264+ }
265+
266+ func countMessagesInRequest (body []byte ) (model string , messagesCount int ) {
267+ var req struct {
268+ Model string `json:"model"`
269+ Messages []anthropic.MessageParam `json:"messages"`
270+ }
271+ if err := json .Unmarshal (body , & req ); err != nil {
272+ return "" , 0
273+ }
274+ return req .Model , len (req .Messages )
275+ }
276+
277+ func extractUsageFromResponse (body []byte ) (model string , usage anthropic.Usage ) {
278+ var message anthropic.Message
279+ if err := json .Unmarshal (body , & message ); err != nil {
280+ return "" , anthropic.Usage {}
281+ }
282+ return string (message .Model ), message .Usage
283+ }
284+
285+ func detectContextWindow (betaHeader string , inputTokens int64 ) int {
286+ if strings .Contains (betaHeader , "context-1m" ) && inputTokens > 200000 {
287+ return 1000000
288+ }
289+ return 200000
290+ }
291+
233292func (s * Service ) ServeHTTP (w http.ResponseWriter , r * http.Request ) {
234293 if ! strings .HasPrefix (r .URL .Path , "/v1/" ) {
235294 writeJSONError (w , r , http .StatusNotFound , "not_found_error" , "Not found" )
@@ -242,6 +301,19 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
242301 return
243302 }
244303
304+ var requestModel string
305+ var messagesCount int
306+ var username string
307+
308+ if s .usageTracker != nil && r .Body != nil {
309+ username = s .getUsernameFromRequest (r )
310+ bodyBytes , err := io .ReadAll (r .Body )
311+ if err == nil {
312+ requestModel , messagesCount = countMessagesInRequest (bodyBytes )
313+ }
314+ r .Body = io .NopCloser (bytes .NewBuffer (bodyBytes ))
315+ }
316+
245317 accessToken , err := s .getAccessToken ()
246318 if err != nil {
247319 s .logger .Error ("get access token: " , err )
@@ -263,7 +335,8 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
263335 }
264336 }
265337
266- if betaHeader := proxyRequest .Header .Get ("anthropic-beta" ); betaHeader != "" {
338+ betaHeader := proxyRequest .Header .Get ("anthropic-beta" )
339+ if betaHeader != "" {
267340 proxyRequest .Header .Set ("anthropic-beta" , anthropicBetaOAuthValue + "," + betaHeader )
268341 } else {
269342 proxyRequest .Header .Set ("anthropic-beta" , anthropicBetaOAuthValue )
@@ -290,7 +363,12 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
290363 }
291364 }
292365 w .WriteHeader (response .StatusCode )
293- s .handleResponse (w , response )
366+
367+ if s .usageTracker != nil && response .StatusCode == http .StatusOK {
368+ s .handleResponseWithTracking (w , response , requestModel , betaHeader , messagesCount , username )
369+ } else {
370+ s .handleResponse (w , response )
371+ }
294372}
295373
296374func (s * Service ) handleResponse (writer http.ResponseWriter , response * http.Response ) {
@@ -320,7 +398,146 @@ func (s *Service) handleResponse(writer http.ResponseWriter, response *http.Resp
320398 }
321399}
322400
401+ func (s * Service ) handleResponseWithTracking (writer http.ResponseWriter , response * http.Response , requestModel string , betaHeader string , messagesCount int , username string ) {
402+ mediaType , _ , err := mime .ParseMediaType (response .Header .Get ("Content-Type" ))
403+ isStreaming := err == nil && mediaType == "text/event-stream"
404+
405+ if ! isStreaming {
406+ bodyBytes , err := io .ReadAll (response .Body )
407+ if err != nil {
408+ s .logger .Error ("read response body: " , err )
409+ return
410+ }
411+
412+ responseModel , usage := extractUsageFromResponse (bodyBytes )
413+ if responseModel == "" {
414+ responseModel = requestModel
415+ }
416+
417+ if usage .InputTokens > 0 || usage .OutputTokens > 0 {
418+ contextWindow := detectContextWindow (betaHeader , usage .InputTokens )
419+ err := s .usageTracker .AddUsage (
420+ responseModel ,
421+ contextWindow ,
422+ messagesCount ,
423+ usage .InputTokens ,
424+ usage .OutputTokens ,
425+ usage .CacheReadInputTokens ,
426+ usage .CacheCreationInputTokens ,
427+ username ,
428+ )
429+ if err != nil {
430+ s .logger .Warn ("track usage: " , err )
431+ }
432+ }
433+
434+ _ , _ = writer .Write (bodyBytes )
435+ return
436+ }
437+
438+ flusher , ok := writer .(http.Flusher )
439+ if ! ok {
440+ s .logger .Error ("streaming not supported" )
441+ return
442+ }
443+
444+ var accumulatedUsage anthropic.Usage
445+ var responseModel string
446+ buffer := make ([]byte , buf .BufferSize )
447+ var leftover []byte
448+
449+ for {
450+ n , err := response .Body .Read (buffer )
451+ if n > 0 {
452+ data := append (leftover , buffer [:n ]... )
453+ lines := bytes .Split (data , []byte ("\n " ))
454+
455+ if err == nil {
456+ leftover = lines [len (lines )- 1 ]
457+ lines = lines [:len (lines )- 1 ]
458+ } else {
459+ leftover = nil
460+ }
461+
462+ for _ , line := range lines {
463+ line = bytes .TrimSpace (line )
464+ if len (line ) == 0 {
465+ continue
466+ }
467+
468+ if bytes .HasPrefix (line , []byte ("data: " )) {
469+ eventData := bytes .TrimPrefix (line , []byte ("data: " ))
470+ if bytes .Equal (eventData , []byte ("[DONE]" )) {
471+ continue
472+ }
473+
474+ var event anthropic.MessageStreamEventUnion
475+ if err := json .Unmarshal (eventData , & event ); err == nil {
476+ switch event .Type {
477+ case "message_start" :
478+ messageStart := event .AsMessageStart ()
479+ if messageStart .Message .Model != "" {
480+ responseModel = string (messageStart .Message .Model )
481+ }
482+ if messageStart .Message .Usage .InputTokens > 0 {
483+ accumulatedUsage .InputTokens = messageStart .Message .Usage .InputTokens
484+ accumulatedUsage .CacheReadInputTokens = messageStart .Message .Usage .CacheReadInputTokens
485+ accumulatedUsage .CacheCreationInputTokens = messageStart .Message .Usage .CacheCreationInputTokens
486+ }
487+ case "message_delta" :
488+ messageDelta := event .AsMessageDelta ()
489+ if messageDelta .Usage .OutputTokens > 0 {
490+ accumulatedUsage .OutputTokens = messageDelta .Usage .OutputTokens
491+ }
492+ }
493+ }
494+ }
495+ }
496+
497+ _ , writeError := writer .Write (buffer [:n ])
498+ if writeError != nil {
499+ s .logger .Error ("write streaming response: " , writeError )
500+ return
501+ }
502+ flusher .Flush ()
503+ }
504+
505+ if err != nil {
506+ if responseModel == "" {
507+ responseModel = requestModel
508+ }
509+
510+ if accumulatedUsage .InputTokens > 0 || accumulatedUsage .OutputTokens > 0 {
511+ contextWindow := detectContextWindow (betaHeader , accumulatedUsage .InputTokens )
512+ err := s .usageTracker .AddUsage (
513+ responseModel ,
514+ contextWindow ,
515+ messagesCount ,
516+ accumulatedUsage .InputTokens ,
517+ accumulatedUsage .OutputTokens ,
518+ accumulatedUsage .CacheReadInputTokens ,
519+ accumulatedUsage .CacheCreationInputTokens ,
520+ username ,
521+ )
522+ if err != nil {
523+ s .logger .Warn ("track usage: " , err )
524+ }
525+ }
526+ return
527+ }
528+ }
529+ }
530+
323531func (s * Service ) Close () error {
532+ if s .usageTracker != nil {
533+ err := s .usageTracker .Save ()
534+ if err != nil {
535+ s .logger .Error ("save usage statistics: " , err )
536+ } else {
537+ s .logger .Info ("usage statistics saved" )
538+ }
539+ }
540+
324541 return common .Close (
325542 common .PtrOrNil (s .httpServer ),
326543 common .PtrOrNil (s .listener ),
0 commit comments