@@ -52,8 +52,8 @@ func TestFirewallDBMigration(t *testing.T) {
52
52
t .Skipf ("Skipping Firewall DB migration test for kvdb build" )
53
53
}
54
54
55
- makeSQLDB := func (t * testing.T , sessionsStore session. Store ) ( * SQLDB ,
56
- * SQLQueriesExecutor [SQLQueries ]) {
55
+ makeSQLDB := func (t * testing.T ,
56
+ sessionsStore session. Store ) * SQLQueriesExecutor [SQLQueries ] {
57
57
58
58
testDBStore := NewTestDBWithSessions (t , sessionsStore , clock )
59
59
@@ -64,13 +64,13 @@ func TestFirewallDBMigration(t *testing.T) {
64
64
65
65
queries := sqlc .NewForType (baseDB , baseDB .BackendType )
66
66
67
- return store , NewSQLQueriesExecutor (baseDB , queries )
67
+ return NewSQLQueriesExecutor (baseDB , queries )
68
68
}
69
69
70
70
// The assertMigrationResults function will currently assert that
71
71
// the migrated kv stores entries in the SQLDB match the original kv
72
72
// stores entries in the BoltDB.
73
- assertMigrationResults := func (t * testing.T , store * SQLDB ,
73
+ assertMigrationResults := func (t * testing.T , store SQLQueries ,
74
74
kvEntries []* kvEntry ) {
75
75
76
76
var (
@@ -83,9 +83,7 @@ func TestFirewallDBMigration(t *testing.T) {
83
83
getRuleID := func (ruleName string ) int64 {
84
84
ruleID , ok := ruleIDs [ruleName ]
85
85
if ! ok {
86
- ruleID , err = store .db .GetRuleID (
87
- ctx , ruleName ,
88
- )
86
+ ruleID , err = store .GetRuleID (ctx , ruleName )
89
87
require .NoError (t , err )
90
88
91
89
ruleIDs [ruleName ] = ruleID
@@ -97,7 +95,7 @@ func TestFirewallDBMigration(t *testing.T) {
97
95
getGroupID := func (groupAlias []byte ) int64 {
98
96
groupID , ok := groupIDs [string (groupAlias )]
99
97
if ! ok {
100
- groupID , err = store .db . GetSessionIDByAlias (
98
+ groupID , err = store .GetSessionIDByAlias (
101
99
ctx , groupAlias ,
102
100
)
103
101
require .NoError (t , err )
@@ -111,7 +109,7 @@ func TestFirewallDBMigration(t *testing.T) {
111
109
getFeatureID := func (featureName string ) int64 {
112
110
featureID , ok := featureIDs [featureName ]
113
111
if ! ok {
114
- featureID , err = store .db . GetFeatureID (
112
+ featureID , err = store .GetFeatureID (
115
113
ctx , featureName ,
116
114
)
117
115
require .NoError (t , err )
@@ -125,7 +123,7 @@ func TestFirewallDBMigration(t *testing.T) {
125
123
// First we extract all migrated kv entries from the SQLDB,
126
124
// in order to be able to compare them to the original kv
127
125
// entries, to ensure that the migration was successful.
128
- sqlKvEntries , err := store .db . ListAllKVStoresRecords (ctx )
126
+ sqlKvEntries , err := store .ListAllKVStoresRecords (ctx )
129
127
require .NoError (t , err )
130
128
require .Equal (t , len (kvEntries ), len (sqlKvEntries ))
131
129
@@ -141,7 +139,7 @@ func TestFirewallDBMigration(t *testing.T) {
141
139
ruleID := getRuleID (entry .ruleName )
142
140
143
141
if entry .groupAlias .IsNone () {
144
- sqlVal , err := store .db . GetGlobalKVStoreRecord (
142
+ sqlVal , err := store .GetGlobalKVStoreRecord (
145
143
ctx ,
146
144
sqlc.GetGlobalKVStoreRecordParams {
147
145
Key : entry .key ,
@@ -159,7 +157,7 @@ func TestFirewallDBMigration(t *testing.T) {
159
157
groupAlias := entry .groupAlias .UnwrapOrFail (t )
160
158
groupID := getGroupID (groupAlias [:])
161
159
162
- v , err := store .db . GetGroupKVStoreRecord (
160
+ v , err := store .GetGroupKVStoreRecord (
163
161
ctx ,
164
162
sqlc.GetGroupKVStoreRecordParams {
165
163
Key : entry .key ,
@@ -184,7 +182,7 @@ func TestFirewallDBMigration(t *testing.T) {
184
182
entry .featureName .UnwrapOrFail (t ),
185
183
)
186
184
187
- sqlVal , err := store .db . GetFeatureKVStoreRecord (
185
+ sqlVal , err := store .GetFeatureKVStoreRecord (
188
186
ctx ,
189
187
sqlc.GetFeatureKVStoreRecordParams {
190
188
Key : entry .key ,
@@ -290,7 +288,7 @@ func TestFirewallDBMigration(t *testing.T) {
290
288
291
289
// Create the SQL store that we will migrate the data
292
290
// to.
293
- sqlStore , txEx := makeSQLDB (t , sessionsStore )
291
+ txEx := makeSQLDB (t , sessionsStore )
294
292
295
293
// Perform the migration.
296
294
//
@@ -299,15 +297,20 @@ func TestFirewallDBMigration(t *testing.T) {
299
297
var opts sqldb.MigrationTxOptions
300
298
err = txEx .ExecTx (ctx , & opts ,
301
299
func (tx SQLQueries ) error {
302
- return MigrateFirewallDBToSQL (
300
+ err = MigrateFirewallDBToSQL (
303
301
ctx , firewallStore .DB , tx ,
304
302
)
303
+ if err != nil {
304
+ return err
305
+ }
306
+
307
+ // Assert migration results.
308
+ assertMigrationResults (t , tx , entries )
309
+
310
+ return nil
305
311
}, sqldb .NoOpReset ,
306
312
)
307
313
require .NoError (t , err )
308
-
309
- // Assert migration results.
310
- assertMigrationResults (t , sqlStore , entries )
311
314
})
312
315
}
313
316
}
0 commit comments