7
7
"errors"
8
8
"fmt"
9
9
"go/format"
10
+ "path/filepath"
10
11
"strings"
11
12
"text/template"
12
13
@@ -126,7 +127,7 @@ func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.Generat
126
127
}
127
128
128
129
if options .OmitUnusedStructs {
129
- enums , structs = filterUnusedStructs (enums , structs , queries )
130
+ enums , structs = filterUnusedStructs (options , enums , structs , queries )
130
131
}
131
132
132
133
if err := validate (options , enums , structs , queries ); err != nil {
@@ -216,6 +217,7 @@ func generate(req *plugin.GenerateRequest, options *opts.Options, enums []Enum,
216
217
"imports" : i .Imports ,
217
218
"hasImports" : i .HasImports ,
218
219
"hasPrefix" : strings .HasPrefix ,
220
+ "trimPrefix" : strings .TrimPrefix ,
219
221
220
222
// These methods are Go specific, they do not belong in the codegen package
221
223
// (as that is language independent)
@@ -237,14 +239,15 @@ func generate(req *plugin.GenerateRequest, options *opts.Options, enums []Enum,
237
239
238
240
output := map [string ]string {}
239
241
240
- execute := func (name , templateName string ) error {
242
+ execute := func (name , packageName , templateName string ) error {
241
243
imports := i .Imports (name )
242
244
replacedQueries := replaceConflictedArg (imports , queries )
243
245
244
246
var b bytes.Buffer
245
247
w := bufio .NewWriter (& b )
246
248
tctx .SourceName = name
247
249
tctx .GoQueries = replacedQueries
250
+ tctx .Package = packageName
248
251
err := tmpl .ExecuteTemplate (w , templateName , & tctx )
249
252
w .Flush ()
250
253
if err != nil {
@@ -256,8 +259,13 @@ func generate(req *plugin.GenerateRequest, options *opts.Options, enums []Enum,
256
259
return fmt .Errorf ("source error: %w" , err )
257
260
}
258
261
259
- if templateName == "queryFile" && options .OutputFilesSuffix != "" {
260
- name += options .OutputFilesSuffix
262
+ if templateName == "queryFile" {
263
+ if options .OutputQueryFilesDirectory != "" {
264
+ name = filepath .Join (options .OutputQueryFilesDirectory , name )
265
+ }
266
+ if options .OutputFilesSuffix != "" {
267
+ name += options .OutputFilesSuffix
268
+ }
261
269
}
262
270
263
271
if ! strings .HasSuffix (name , ".go" ) {
@@ -289,24 +297,29 @@ func generate(req *plugin.GenerateRequest, options *opts.Options, enums []Enum,
289
297
batchFileName = options .OutputBatchFileName
290
298
}
291
299
292
- if err := execute (dbFileName , "dbFile" ); err != nil {
300
+ modelsPackageName := options .Package
301
+ if options .OutputModelsPackage != "" {
302
+ modelsPackageName = options .OutputModelsPackage
303
+ }
304
+
305
+ if err := execute (dbFileName , options .Package , "dbFile" ); err != nil {
293
306
return nil , err
294
307
}
295
- if err := execute (modelsFileName , "modelsFile" ); err != nil {
308
+ if err := execute (modelsFileName , modelsPackageName , "modelsFile" ); err != nil {
296
309
return nil , err
297
310
}
298
311
if options .EmitInterface {
299
- if err := execute (querierFileName , "interfaceFile" ); err != nil {
312
+ if err := execute (querierFileName , options . Package , "interfaceFile" ); err != nil {
300
313
return nil , err
301
314
}
302
315
}
303
316
if tctx .UsesCopyFrom {
304
- if err := execute (copyfromFileName , "copyfromFile" ); err != nil {
317
+ if err := execute (copyfromFileName , options . Package , "copyfromFile" ); err != nil {
305
318
return nil , err
306
319
}
307
320
}
308
321
if tctx .UsesBatch {
309
- if err := execute (batchFileName , "batchFile" ); err != nil {
322
+ if err := execute (batchFileName , options . Package , "batchFile" ); err != nil {
310
323
return nil , err
311
324
}
312
325
}
@@ -317,7 +330,7 @@ func generate(req *plugin.GenerateRequest, options *opts.Options, enums []Enum,
317
330
}
318
331
319
332
for source := range files {
320
- if err := execute (source , "queryFile" ); err != nil {
333
+ if err := execute (source , options . Package , "queryFile" ); err != nil {
321
334
return nil , err
322
335
}
323
336
}
@@ -367,7 +380,7 @@ func checkNoTimesForMySQLCopyFrom(queries []Query) error {
367
380
return nil
368
381
}
369
382
370
- func filterUnusedStructs (enums []Enum , structs []Struct , queries []Query ) ([]Enum , []Struct ) {
383
+ func filterUnusedStructs (options * opts. Options , enums []Enum , structs []Struct , queries []Query ) ([]Enum , []Struct ) {
371
384
keepTypes := make (map [string ]struct {})
372
385
373
386
for _ , query := range queries {
@@ -394,16 +407,23 @@ func filterUnusedStructs(enums []Enum, structs []Struct, queries []Query) ([]Enu
394
407
395
408
keepEnums := make ([]Enum , 0 , len (enums ))
396
409
for _ , enum := range enums {
397
- _ , keep := keepTypes [enum .Name ]
398
- _ , keepNull := keepTypes ["Null" + enum .Name ]
410
+ var enumType string
411
+ if options .ModelsPackageImportPath != "" {
412
+ enumType = options .OutputModelsPackage + "." + enum .Name
413
+ } else {
414
+ enumType = enum .Name
415
+ }
416
+
417
+ _ , keep := keepTypes [enumType ]
418
+ _ , keepNull := keepTypes ["Null" + enumType ]
399
419
if keep || keepNull {
400
420
keepEnums = append (keepEnums , enum )
401
421
}
402
422
}
403
423
404
424
keepStructs := make ([]Struct , 0 , len (structs ))
405
425
for _ , st := range structs {
406
- if _ , ok := keepTypes [st .Name ]; ok {
426
+ if _ , ok := keepTypes [st .Type () ]; ok {
407
427
keepStructs = append (keepStructs , st )
408
428
}
409
429
}
0 commit comments