@@ -2,17 +2,19 @@ package redis
2
2
3
3
import (
4
4
"context"
5
- "encoding/json"
6
5
"sync"
7
6
"time"
8
7
9
8
"github.com/go-redis/redis"
9
+ "github.com/json-iterator/go"
10
10
"gopkg.in/session.v2"
11
11
)
12
12
13
13
var (
14
- _ session.ManagerStore = & managerStore {}
15
- _ session.Store = & store {}
14
+ _ session.ManagerStore = & managerStore {}
15
+ _ session.Store = & store {}
16
+ jsonMarshal = jsoniter .Marshal
17
+ jsonUnmarshal = jsoniter .Unmarshal
16
18
)
17
19
18
20
// NewRedisStore Create an instance of a redis store
@@ -23,6 +25,11 @@ func NewRedisStore(opt *Options) session.ManagerStore {
23
25
return & managerStore {cli : redis .NewClient (opt .redisOptions ())}
24
26
}
25
27
28
+ // NewRedisStoreWithCli Create an instance of a redis store
29
+ func NewRedisStoreWithCli (cli * redis.Client ) session.ManagerStore {
30
+ return & managerStore {cli : cli }
31
+ }
32
+
26
33
type managerStore struct {
27
34
cli * redis.Client
28
35
}
@@ -39,33 +46,43 @@ func (s *managerStore) getValue(sid string) (string, error) {
39
46
return cmd .Val (), nil
40
47
}
41
48
42
- func (s * managerStore ) parseValue (value string ) (map [string ]string , error ) {
43
- var values map [string ]string
44
-
49
+ func (s * managerStore ) parseValue (value string ) (map [string ]interface {}, error ) {
50
+ var values map [string ]interface {}
45
51
if len (value ) > 0 {
46
- err := json . Unmarshal ([]byte (value ), & values )
52
+ err := jsonUnmarshal ([]byte (value ), & values )
47
53
if err != nil {
48
54
return nil , err
49
55
}
50
56
}
51
57
52
58
if values == nil {
53
- values = make (map [string ]string )
59
+ values = make (map [string ]interface {} )
54
60
}
55
61
return values , nil
56
62
}
57
63
58
64
func (s * managerStore ) Create (ctx context.Context , sid string , expired int64 ) (session.Store , error ) {
59
- values := make (map [string ]string )
60
- return & store {ctx : ctx , sid : sid , cli : s .cli , expired : expired , values : values }, nil
65
+ return & store {
66
+ ctx : ctx ,
67
+ sid : sid ,
68
+ cli : s .cli ,
69
+ expired : expired ,
70
+ values : make (map [string ]interface {}),
71
+ }, nil
61
72
}
62
73
63
74
func (s * managerStore ) Update (ctx context.Context , sid string , expired int64 ) (session.Store , error ) {
64
75
value , err := s .getValue (sid )
65
76
if err != nil {
66
77
return nil , err
67
78
} else if value == "" {
68
- return s .Create (ctx , sid , expired )
79
+ return & store {
80
+ ctx : ctx ,
81
+ sid : sid ,
82
+ cli : s .cli ,
83
+ expired : expired ,
84
+ values : make (map [string ]interface {}),
85
+ }, nil
69
86
}
70
87
71
88
cmd := s .cli .Set (sid , value , time .Duration (expired )* time .Second )
@@ -77,37 +94,80 @@ func (s *managerStore) Update(ctx context.Context, sid string, expired int64) (s
77
94
if err != nil {
78
95
return nil , err
79
96
}
80
-
81
- return & store {ctx : ctx , sid : sid , cli : s .cli , expired : expired , values : values }, nil
97
+ return & store {
98
+ ctx : ctx ,
99
+ sid : sid ,
100
+ cli : s .cli ,
101
+ expired : expired ,
102
+ values : values ,
103
+ }, nil
82
104
}
83
105
84
106
func (s * managerStore ) Delete (_ context.Context , sid string ) error {
107
+ if ok , err := s .Check (nil , sid ); err != nil {
108
+ return err
109
+ } else if ! ok {
110
+ return nil
111
+ }
112
+
85
113
cmd := s .cli .Del (sid )
86
114
return cmd .Err ()
87
115
}
88
116
89
117
func (s * managerStore ) Check (_ context.Context , sid string ) (bool , error ) {
90
- cmd := s .cli .Get (sid )
118
+ cmd := s .cli .Exists (sid )
91
119
if err := cmd .Err (); err != nil {
92
- if err == redis .Nil {
93
- return false , nil
94
- }
95
120
return false , err
96
121
}
97
- return true , nil
122
+ return cmd .Val () > 0 , nil
123
+ }
124
+
125
+ func (s * managerStore ) Refresh (ctx context.Context , oldsid , sid string , expired int64 ) (session.Store , error ) {
126
+ value , err := s .getValue (oldsid )
127
+ if err != nil {
128
+ return nil , err
129
+ } else if value == "" {
130
+ return & store {
131
+ ctx : ctx ,
132
+ sid : sid ,
133
+ cli : s .cli ,
134
+ expired : expired ,
135
+ values : make (map [string ]interface {}),
136
+ }, nil
137
+ }
138
+
139
+ pipe := s .cli .TxPipeline ()
140
+ pipe .Set (sid , value , time .Duration (expired )* time .Second )
141
+ pipe .Del (oldsid )
142
+ _ , err = pipe .Exec ()
143
+ if err != nil {
144
+ return nil , err
145
+ }
146
+
147
+ values , err := s .parseValue (value )
148
+ if err != nil {
149
+ return nil , err
150
+ }
151
+ return & store {
152
+ ctx : ctx ,
153
+ sid : sid ,
154
+ cli : s .cli ,
155
+ expired : expired ,
156
+ values : values ,
157
+ }, nil
98
158
}
99
159
100
160
func (s * managerStore ) Close () error {
101
161
return s .cli .Close ()
102
162
}
103
163
104
164
type store struct {
165
+ sync.RWMutex
166
+ ctx context.Context
105
167
sid string
106
- cli * redis.Client
107
168
expired int64
108
- values map [string ]string
109
- sync.RWMutex
110
- ctx context.Context
169
+ values map [string ]interface {}
170
+ cli * redis.Client
111
171
}
112
172
113
173
func (s * store ) Context () context.Context {
@@ -118,20 +178,20 @@ func (s *store) SessionID() string {
118
178
return s .sid
119
179
}
120
180
121
- func (s * store ) Set (key , value string ) {
181
+ func (s * store ) Set (key string , value interface {} ) {
122
182
s .Lock ()
123
183
s .values [key ] = value
124
184
s .Unlock ()
125
185
}
126
186
127
- func (s * store ) Get (key string ) (string , bool ) {
187
+ func (s * store ) Get (key string ) (interface {} , bool ) {
128
188
s .RLock ()
129
189
defer s .RUnlock ()
130
190
val , ok := s .values [key ]
131
191
return val , ok
132
192
}
133
193
134
- func (s * store ) Delete (key string ) string {
194
+ func (s * store ) Delete (key string ) interface {} {
135
195
s .RLock ()
136
196
v , ok := s .values [key ]
137
197
s .RUnlock ()
@@ -145,7 +205,7 @@ func (s *store) Delete(key string) string {
145
205
146
206
func (s * store ) Flush () error {
147
207
s .Lock ()
148
- s .values = make (map [string ]string )
208
+ s .values = make (map [string ]interface {} )
149
209
s .Unlock ()
150
210
return s .Save ()
151
211
}
@@ -155,7 +215,11 @@ func (s *store) Save() error {
155
215
156
216
s .RLock ()
157
217
if len (s .values ) > 0 {
158
- buf , _ := json .Marshal (s .values )
218
+ buf , err := jsonMarshal (s .values )
219
+ if err != nil {
220
+ s .RUnlock ()
221
+ return err
222
+ }
159
223
value = string (buf )
160
224
}
161
225
s .RUnlock ()
0 commit comments