Skip to content

Commit 71ce9db

Browse files
committed
table interface fix
1 parent 04c446c commit 71ce9db

File tree

9 files changed

+87
-53
lines changed

9 files changed

+87
-53
lines changed

README.md

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,20 +30,22 @@ import (
3030
var db, _ = orm.OpenMysql("user:password@tcp(127.0.0.1:3306)/mydb?parseTime=true&charset=utf8mb4&loc=Asia%2FShanghai")
3131

3232
//user table model
33-
var UserTable = orm.NewQuery(User{}, db)
33+
var UserTable = orm.NewQuery(User{})
3434

3535
type User struct {
3636
Id int `json:"id"`
3737
Name string `json:"name"`
3838
}
3939

40-
//Table interface: implements two methods below
41-
func (User) TableName() string {
42-
return "user"
40+
func (User) Connection() []*sql.DB {
41+
return []*sql.DB{db}
4342
}
4443
func (User) DatabaseName() string {
4544
return "mydb"
4645
}
46+
func (User) TableName() string {
47+
return "user"
48+
}
4749
```
4850

4951
## migration (create table from struct | create struct from table)

example.go

Lines changed: 39 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -10,30 +10,37 @@ import (
1010
//connect mysql db
1111
var db, _ = orm.OpenMysql("user:password@tcp(127.0.0.1:3306)/mydb?parseTime=true&charset=utf8mb4&loc=Asia%2FShanghai")
1212

13-
//user table model
14-
var UserTable = orm.NewQuery(User{}, db)
13+
//query user
14+
var UserQuery = orm.NewQuery(User{})
1515

1616
type User struct {
1717
Id int `json:"id"`
1818
Name string `json:"name"`
1919
}
2020

21-
func (User) TableName() string {
22-
return "user"
21+
func (User) Connection() []*sql.DB {
22+
return []*sql.DB{db}
2323
}
2424
func (User) DatabaseName() string {
2525
return "mydb"
2626
}
27+
func (User) TableName() string {
28+
return "user"
29+
}
2730

2831
//user table model
29-
var OrderTable = orm.NewQuery(Order{}, db)
32+
var OrderQuery = orm.NewQuery(Order{})
3033

3134
type Order struct {
3235
Id int `json:"id"`
3336
UserId int `json:"user_id"`
3437
OrderAmount int `json:"order_amount"`
3538
}
3639

40+
func (o Order) Connection() []*sql.DB {
41+
return []*sql.DB{db}
42+
}
43+
3744
func (Order) TableName() string {
3845
return "order"
3946
}
@@ -44,88 +51,88 @@ func (Order) DatabaseName() string {
4451
func main() {
4552
{
4653
//create db table from go struct
47-
_, _ = orm.CreateTableFromStruct(UserTable)
54+
_, _ = orm.CreateTableFromStruct(UserQuery)
4855
//create go struct from db table
49-
_ = orm.CreateStructFromTable(UserTable)
56+
_ = orm.CreateStructFromTable(UserQuery)
5057
}
5158

5259
//query select
5360
{
5461
//get first user (name='join') as struct
55-
user, query := UserTable.Where(&UserTable.T.Name, "john").Get()
62+
user, query := UserQuery.Where(&UserQuery.T.Name, "john").Get()
5663
fmt.Println(user, query.Sql(), query.Error())
5764

5865
//get users by primary ids
59-
users, query := UserTable.Gets(1, 2, 3)
66+
users, query := UserQuery.Gets(1, 2, 3)
6067
fmt.Println(users, query.Sql(), query.Error())
6168

6269
//get user rows as []map[string]interface
63-
rows, query := UserTable.Limit(5).GetRows()
70+
rows, query := UserQuery.Limit(5).GetRows()
6471
fmt.Println(rows, query.Sql(), query.Error())
6572

6673
//get users count(*)
67-
count, query := UserTable.GetCount()
74+
count, query := UserQuery.GetCount()
6875
fmt.Println(count, query.Sql(), query.Error())
6976

7077
//get user names map key by id
7178
var userNameKeyById map[int]string
72-
UserTable.Select(&UserTable.T.Id, &UserTable.T.Name).GetTo(&userNameKeyById)
79+
UserQuery.Select(&UserQuery.T.Id, &UserQuery.T.Name).GetTo(&userNameKeyById)
7380

7481
//get users map key by name
7582
var usersMapkeyByName map[string][]User
76-
UserTable.Select(&UserTable.T.Name, UserTable.AllCols()).GetTo(&usersMapkeyByName)
83+
UserQuery.Select(&UserQuery.T.Name, UserQuery.AllCols()).GetTo(&usersMapkeyByName)
7784

7885
//select rank by column
79-
OrderTable.Where(&OrderTable.T.UserId, 1).
80-
Select(OrderTable.AllCols()).
81-
SelectRank(&OrderTable.T.OrderAmount, "order_amount_rank").GetRows()
86+
OrderQuery.Where(&OrderQuery.T.UserId, 1).
87+
Select(OrderQuery.AllCols()).
88+
SelectRank(&OrderQuery.T.OrderAmount, "order_amount_rank").GetRows()
8289
}
8390

8491
//query update and delete and insert
8592
{
8693
//update user set name="hello" where id=1
87-
UserTable.WherePrimary(1).Update(&UserTable.T.Name, "hello")
94+
UserQuery.WherePrimary(1).Update(&UserQuery.T.Name, "hello")
8895

8996
//query delete
90-
UserTable.Delete(1, 2, 3)
97+
UserQuery.Delete(1, 2, 3)
9198

9299
//query insert
93-
_ = UserTable.Insert(User{Name: "han"}).LastInsertId //insert one row and get id
100+
_ = UserQuery.Insert(User{Name: "han"}).LastInsertId //insert one row and get id
94101

95102
//insert batch on duplicate key update name=values(name)
96-
_ = UserTable.InsertsIgnore([]User{{Id: 1, Name: "han"}, {Id: 2, Name: "jen"}},
97-
[]orm.UpdateColumn{{Column: &UserTable.T.Name, Val: &UserTable.T.Name}})
103+
_ = UserQuery.InsertsIgnore([]User{{Id: 1, Name: "han"}, {Id: 2, Name: "jen"}},
104+
[]orm.UpdateColumn{{Column: &UserQuery.T.Name, Val: &UserQuery.T.Name}})
98105
}
99106

100107
//query join
101108
{
102-
UserTable.Join(OrderTable.T, func(query orm.Query[User]) orm.Query[User] {
103-
return query.Where(&UserTable.T.Id, &OrderTable.T.UserId)
104-
}).Where(&OrderTable.T.OrderAmount, 100).
105-
Select(UserTable.AllCols()).Gets()
109+
UserQuery.Join(OrderQuery.T, func(query orm.Query[User]) orm.Query[User] {
110+
return query.Where(&UserQuery.T.Id, &OrderQuery.T.UserId)
111+
}).Where(&OrderQuery.T.OrderAmount, 100).
112+
Select(UserQuery.AllCols()).Gets()
106113
}
107114
{
108115
//transaction
109-
_ = UserTable.Transaction(func(tx *sql.Tx) error {
110-
newId := UserTable.UseTx(tx).Insert(User{Name: "john"}).LastInsertId //insert
116+
_ = UserQuery.Transaction(func(tx *sql.Tx) error {
117+
newId := UserQuery.UseTx(tx).Insert(User{Name: "john"}).LastInsertId //insert
111118
fmt.Println(newId)
112119
return errors.New("I want rollback") //rollback
113120
})
114121
}
115122

116123
{
117124
//subquery
118-
subquery := UserTable.Where(&UserTable.T.Id, 1).SubQuery()
125+
subquery := UserQuery.Where(&UserQuery.T.Id, 1).SubQuery()
119126

120127
//where in suquery
121-
UserTable.Where(&UserTable.T.Id, orm.WhereIn, subquery).Gets()
128+
UserQuery.Where(&UserQuery.T.Id, orm.WhereIn, subquery).Gets()
122129

123130
//insert subquery
124-
UserTable.InsertSubquery(subquery, nil)
131+
UserQuery.InsertSubquery(subquery, nil)
125132

126133
//join subquery
127-
UserTable.Join(subquery, func(query orm.Query[User]) orm.Query[User] {
128-
return query.Where(&UserTable.T.Id, orm.Raw("sub.id"))
134+
UserQuery.Join(subquery, func(query orm.Query[User]) orm.Query[User] {
135+
return query.Where(&UserQuery.T.Id, orm.Raw("sub.id"))
129136
}).Gets()
130137
}
131138
}

orm/query.go

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,15 +81,24 @@ func (m Query[T]) UseTx(tx *sql.Tx) Query[T] {
8181
func (m Query[T]) DB() *sql.DB {
8282
return m.writeDB()
8383
}
84+
85+
func (m Query[T]) DBs() []*sql.DB {
86+
if len(m.writeAndReadDbs) == 0 && len(m.tables) > 0 {
87+
m.writeAndReadDbs = m.tables[0].table.Connection()
88+
}
89+
return m.writeAndReadDbs
90+
}
8491
func (m Query[T]) writeDB() *sql.DB {
85-
if len(m.writeAndReadDbs) > 0 {
86-
return m.writeAndReadDbs[0]
92+
dbs := m.DBs()
93+
if len(dbs) > 0 {
94+
return dbs[0]
8795
}
8896
return nil
8997
}
9098
func (m Query[T]) readDB() *sql.DB {
91-
if len(m.writeAndReadDbs) > 1 {
92-
return m.writeAndReadDbs[rand.Intn(len(m.writeAndReadDbs)-1)+1] //rand get db
99+
dbs := m.DBs()
100+
if len(dbs) > 1 {
101+
return dbs[rand.Intn(len(dbs)-1)+1] //rand get db
93102
} else {
94103
return m.writeDB()
95104
}

orm/query_select_gen.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ func (m Query[T]) SubQuery() SubQuery {
1212
mt := cte.WithRecursiveCte(m.SubQuery(), cte.T.TableName())
1313
tempTable := mt.generateSelectQuery(mt.columns...)
1414

15-
tempTable.dbs = mt.writeAndReadDbs
15+
tempTable.dbs = mt.DBs()
1616
tempTable.tx = mt.tx
1717
tempTable.dbName = mt.tables[0].table.DatabaseName()
1818
if mt.result.Err != nil {
@@ -24,7 +24,7 @@ func (m Query[T]) SubQuery() SubQuery {
2424
mt := m
2525
tempTable := mt.generateSelectQuery(mt.columns...)
2626

27-
tempTable.dbs = mt.writeAndReadDbs
27+
tempTable.dbs = mt.DBs()
2828
tempTable.tx = mt.tx
2929
tempTable.dbName = mt.tables[0].table.DatabaseName()
3030
if mt.result.Err != nil {

orm/query_with.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ func (m Query[T]) WithParentsOnColumn(pidColumn interface{}) Query[T] {
1515
coln := strings.Split(col, ".")
1616
newcol := strings.Trim(coln[len(coln)-1], "`")
1717

18-
cte := NewQueryRaw(tempName, m.writeAndReadDbs...)
18+
cte := NewQueryRaw(tempName, m.DBs()...)
1919

20-
appendQuery := NewQuery(*m.T, m.writeAndReadDbs...)
20+
appendQuery := NewQuery(*m.T, m.DBs()...)
2121
appendQuery = appendQuery.Join(cte.T, func(query Query[T]) Query[T] {
2222
return query.Where(appendQuery.tables[0].tableStruct.Field(0).Addr().Interface(), Raw(tempName+"."+newcol))
2323
}).Select(appendQuery.AllCols())
@@ -43,9 +43,9 @@ func (m Query[T]) WithChildrenOnColumn(pidColumn interface{}) Query[T] {
4343
coln := strings.Split(col, ".")
4444
newcol := strings.Trim(coln[len(coln)-1], "`")
4545

46-
cte := NewQueryRaw(tempName, m.writeAndReadDbs...)
46+
cte := NewQueryRaw(tempName, m.DBs()...)
4747

48-
appendQuery := NewQuery(*m.T, m.writeAndReadDbs...)
48+
appendQuery := NewQuery(*m.T, m.DBs()...)
4949
appendQuery = appendQuery.Join(cte.T, func(query Query[T]) Query[T] {
5050
return query.Where(pcol, Raw(tempName+"."+newcol))
5151
}).Select(appendQuery.AllCols())

orm/subquery.go

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,14 @@ func NewSubQuery(prepareSql string, bindings ...interface{}) SubQuery {
2424
return SubQuery{raw: prepareSql, bindings: bindings}
2525
}
2626

27+
func (m SubQuery) Connection() []*sql.DB {
28+
return m.dbs
29+
}
30+
31+
func (m SubQuery) DatabaseName() string {
32+
return m.dbName
33+
}
34+
2735
func (m SubQuery) TableName() string {
2836
if m.tableName != "" {
2937
return m.tableName
@@ -34,10 +42,6 @@ func (m SubQuery) TableName() string {
3442
return ""
3543
}
3644

37-
func (m SubQuery) DatabaseName() string {
38-
return m.dbName
39-
}
40-
4145
func (m SubQuery) Error() error {
4246
return m.err
4347
}

orm/table.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
package orm
22

3+
import "database/sql"
4+
35
type Table interface {
4-
TableName() string
6+
Connection() []*sql.DB
57
DatabaseName() string
8+
TableName() string
69
}

orm_migrate_test.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package main
22

33
import (
4+
"database/sql"
45
"github.com/folospace/go-mysql-orm/orm"
56
"testing"
67
"time"
@@ -44,6 +45,9 @@ type Family struct {
4445
Updated time.Time `json:"updated" orm:"updated,timestamp" default:"CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP"`
4546
}
4647

48+
func (Family) Connection() []*sql.DB {
49+
return []*sql.DB{tdb}
50+
}
4751
func (Family) TableName() string {
4852
return "family"
4953
}
@@ -54,7 +58,7 @@ func (Family) DatabaseName() string {
5458

5559
func TestMigrate(t *testing.T) {
5660
t.Run("create_struct", func(t *testing.T) {
57-
FamilyTable := orm.NewQuery(Family{}, tdb)
61+
FamilyTable := orm.NewQuery(Family{})
5862

5963
err := orm.CreateStructFromTable(FamilyTable)
6064

orm_select_test.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
package main
22

33
import (
4+
"database/sql"
45
"github.com/folospace/go-mysql-orm/orm"
56
"testing"
67
"time"
78
)
89

910
var tdb, _ = orm.OpenMysql("rfamro@tcp(mysql-rfam-public.ebi.ac.uk:4497)/Rfam?parseTime=true&charset=utf8mb4&loc=Asia%2FShanghai")
1011

11-
var FamilyTable2 = orm.NewQuery(Family2{}, tdb)
12+
var FamilyTable2 = orm.NewQuery(Family2{})
1213

1314
type Family2 struct {
1415
RfamAcc string `json:"rfam_acc" orm:"rfam_acc,varchar(7),primary,unique"`
@@ -48,6 +49,10 @@ type Family2 struct {
4849
Updated time.Time `json:"updated" orm:"updated,timestamp" default:"CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP"`
4950
}
5051

52+
func (f Family2) Connection() []*sql.DB {
53+
return []*sql.DB{tdb}
54+
}
55+
5156
func (Family2) TableName() string {
5257
return "family"
5358
}

0 commit comments

Comments
 (0)