Skip to content

Commit c29b043

Browse files
朱斌Rennbon
authored andcommitted
no message
Author: Rennbon <[email protected]>
1 parent 699dcdf commit c29b043

File tree

4 files changed

+195
-1
lines changed

4 files changed

+195
-1
lines changed

go.mod

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,7 @@ module github.com/DATA-DOG/go-sqlmock
22

33
go 1.15
44

5-
require github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46
5+
require (
6+
github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46
7+
github.com/stretchr/testify v1.7.0 // indirect
8+
)

go.sum

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,12 @@
1+
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
2+
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
13
github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46 h1:veS9QfglfvqAw2e+eeNT/SbGySq8ajECXJ9e4fPoLhY=
24
github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE=
5+
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
6+
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
7+
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
8+
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
9+
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
10+
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
11+
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
12+
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

rows.go

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,12 @@ import (
44
"bytes"
55
"database/sql/driver"
66
"encoding/csv"
7+
"errors"
78
"fmt"
89
"io"
10+
"reflect"
911
"strings"
12+
"time"
1013
)
1114

1215
const invalidate = "☠☠☠ MEMORY OVERWRITTEN ☠☠☠ "
@@ -127,6 +130,133 @@ type Rows struct {
127130
closeErr error
128131
}
129132

133+
// NewRowsFromStruct new Rows from struct reflect with tagName
134+
// tagName default "json"
135+
func NewRowsFromStruct(m interface{}, tagName ...string) (*Rows, error) {
136+
if m == nil {
137+
return nil, errors.New("param m is nil")
138+
}
139+
val := reflect.ValueOf(m).Elem()
140+
if val.Kind() != reflect.Struct {
141+
return nil, errors.New("param type must be struct")
142+
}
143+
num := val.NumField()
144+
if num == 0 {
145+
return nil, errors.New("no properties available")
146+
}
147+
columns := make([]string, 0, num)
148+
var values []driver.Value
149+
tag := "json"
150+
if len(tagName) > 0 {
151+
tag = tagName[0]
152+
}
153+
for i := 0; i < num; i++ {
154+
f := val.Type().Field(i)
155+
column := f.Tag.Get(tag)
156+
if len(column) > 0 {
157+
columns = append(columns, column)
158+
values = append(values, val.Field(i))
159+
}
160+
}
161+
if len(columns) == 0 {
162+
return nil, errors.New("tag not match")
163+
}
164+
rows := &Rows{
165+
cols: columns,
166+
nextErr: make(map[int]error),
167+
converter: reflectTypeConverter{},
168+
}
169+
return rows.AddRow(values...), nil
170+
}
171+
172+
var timeKind = reflect.TypeOf(time.Time{}).Kind()
173+
174+
type reflectTypeConverter struct{}
175+
176+
func (reflectTypeConverter) ConvertValue(v interface{}) (driver.Value, error) {
177+
rv := v.(reflect.Value)
178+
switch rv.Kind() {
179+
case reflect.Ptr:
180+
// indirect pointers
181+
if rv.IsNil() {
182+
return nil, nil
183+
} else {
184+
return driver.DefaultParameterConverter.ConvertValue(rv.Elem().Interface())
185+
}
186+
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
187+
return rv.Int(), nil
188+
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32:
189+
return int64(rv.Uint()), nil
190+
case reflect.Uint64:
191+
u64 := rv.Uint()
192+
if u64 >= 1<<63 {
193+
return nil, fmt.Errorf("uint64 values with high bit set are not supported")
194+
}
195+
return int64(u64), nil
196+
case reflect.Float32, reflect.Float64:
197+
return rv.Float(), nil
198+
case reflect.Bool:
199+
return rv.Bool(), nil
200+
case reflect.Slice:
201+
ek := rv.Type().Elem().Kind()
202+
if ek == reflect.Uint8 {
203+
return rv.Bytes(), nil
204+
}
205+
return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, ek)
206+
case reflect.String:
207+
return rv.String(), nil
208+
case timeKind:
209+
return rv.Interface().(time.Time), nil
210+
}
211+
return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind())
212+
}
213+
214+
// NewRowsFromStructs new Rows from struct slice reflect with tagName
215+
// NOTE: arr must be of the same type
216+
// tagName default "json"
217+
func NewRowsFromStructs(tagName string, arr ...interface{}) (*Rows, error) {
218+
if len(arr) == 0 {
219+
return nil, errors.New("param arr is nil")
220+
}
221+
typ := reflect.TypeOf(arr[0]).Elem()
222+
if typ.Kind() != reflect.Struct {
223+
return nil, errors.New("param type must be struct")
224+
}
225+
if typ.NumField() == 0 {
226+
return nil, errors.New("no properties available")
227+
}
228+
var columns []string
229+
tag := "json"
230+
if len(tagName) > 0 {
231+
tag = tagName
232+
}
233+
for i := 0; i < typ.NumField(); i++ {
234+
f := typ.Field(i)
235+
column := f.Tag.Get(tag)
236+
if len(column) > 0 {
237+
columns = append(columns, column)
238+
}
239+
}
240+
if len(columns) == 0 {
241+
return nil, errors.New("tag not match")
242+
}
243+
rows := &Rows{
244+
cols: columns,
245+
nextErr: make(map[int]error),
246+
converter: reflectTypeConverter{},
247+
}
248+
for _, m := range arr {
249+
v := m
250+
val := reflect.ValueOf(v).Elem()
251+
var values []driver.Value
252+
for _, column := range columns {
253+
values = append(values, val.FieldByName(column))
254+
}
255+
rows.AddRow(values...)
256+
}
257+
return rows, nil
258+
}
259+
130260
// NewRows allows Rows to be created from a
131261
// sql driver.Value slice or from the CSV string and
132262
// to be used as sql driver.Rows.

rows_test.go

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@ import (
55
"database/sql"
66
"database/sql/driver"
77
"fmt"
8+
"github.com/stretchr/testify/assert"
89
"testing"
10+
"time"
911
)
1012

1113
const invalid = `☠☠☠ MEMORY OVERWRITTEN ☠☠☠ `
@@ -753,3 +755,52 @@ func ExampleRows_AddRows() {
753755
// Output: scanned id: 1 and title: one
754756
// scanned id: 2 and title: two
755757
}
758+
759+
type MockStruct struct {
760+
Type int `mock:"type"`
761+
Name string `mock:"name"`
762+
CreateTime time.Time `mock:"createTime"`
763+
}
764+
765+
func TestNewRowsFromStruct(t *testing.T) {
766+
m := &MockStruct{
767+
Type: 1,
768+
Name: "sqlMock",
769+
CreateTime: time.Now(),
770+
}
771+
excepted := NewRows([]string{"type", "name", "createTime"}).AddRow(m.Type, m.Name, m.CreateTime)
772+
773+
actual, err := NewRowsFromStruct(m, "mock")
774+
if err != nil {
775+
t.Fatal(err)
776+
}
777+
assert.EqualValues(t, excepted.cols, actual.cols)
778+
assert.EqualValues(t, excepted.rows, actual.rows)
779+
assert.EqualValues(t, excepted.def, actual.def)
780+
}
781+
782+
func TestNewRowsFromStructs(t *testing.T) {
783+
m1 := &MockStruct{
784+
Type: 1,
785+
Name: "sqlMock1",
786+
CreateTime: time.Now(),
787+
}
788+
m2 := &MockStruct{
789+
Type: 2,
790+
Name: "sqlMock2",
791+
CreateTime: time.Now(),
792+
}
793+
arr := []*MockStruct{m1, m2}
794+
795+
excepted := NewRows([]string{"type", "name", "createTime"})
796+
for _, v := range arr {
797+
excepted.AddRow(v.Type, v.Name, v.CreateTime)
798+
}
799+
actual, err := NewRowsFromStructs("mock", m1, m2)
800+
if err != nil {
801+
t.Fatal(err)
802+
}
803+
assert.EqualValues(t, excepted.cols, actual.cols)
804+
assert.EqualValues(t, excepted.rows, actual.rows)
805+
assert.EqualValues(t, excepted.def, actual.def)
806+
}

0 commit comments

Comments
 (0)