diff --git a/constants/constants.go b/constants/constants.go index 4fe8d1d..bd02c53 100644 --- a/constants/constants.go +++ b/constants/constants.go @@ -18,8 +18,9 @@ package constants const ( - Comma = "," - LeftBracket = "(" - RightBracket = ")" - DefaultPrimaryName = "id" + Comma = "," + LeftBracket = "(" + RightBracket = ")" + DefaultPrimaryName = "id" + DefaultGormPlusConnName = "DefaultGormPlusConnName" //内置的gorm-plus 数据库连接名 ) diff --git a/gplus/cache.go b/gplus/cache.go index feed184..7c1f041 100644 --- a/gplus/cache.go +++ b/gplus/cache.go @@ -31,9 +31,10 @@ var columnNameCache sync.Map var modelInstanceCache sync.Map // Cache 缓存实体对象所有的字段名 -func Cache(models ...any) { +func Cache(opt Option, models ...any) { + db, _, _ := getDefaultDbByOpt(opt) for _, model := range models { - columnNameMap := getColumnNameMap(model) + columnNameMap := getColumnNameMap(model, db.Config.NamingStrategy) for pointer, columnName := range columnNameMap { columnNameCache.Store(pointer, columnName) } @@ -43,7 +44,7 @@ func Cache(models ...any) { } } -func getColumnNameMap(model any) map[uintptr]string { +func getColumnNameMap(model any, namingStrategy schema.Namer) map[uintptr]string { var columnNameMap = make(map[uintptr]string) valueOf := reflect.ValueOf(model).Elem() typeOf := reflect.TypeOf(model).Elem() @@ -52,14 +53,14 @@ func getColumnNameMap(model any) map[uintptr]string { // 如果当前实体嵌入了其他实体,同样需要缓存它的字段名 if field.Anonymous { // 如果存在多重嵌套,通过递归方式获取他们的字段名 - subFieldMap := getSubFieldColumnNameMap(valueOf, field) + subFieldMap := getSubFieldColumnNameMap(valueOf, field, namingStrategy) for pointer, columnName := range subFieldMap { columnNameMap[pointer] = columnName } } else { // 获取对象字段指针值 pointer := valueOf.Field(i).Addr().Pointer() - columnName := parseColumnName(field) + columnName := parseColumnName(field, namingStrategy) columnNameMap[pointer] = columnName } } @@ -67,7 +68,8 @@ func getColumnNameMap(model any) map[uintptr]string { } // GetModel 获取 -func GetModel[T any]() *T { +func GetModel[T any](opts ...OptionFunc) *T { + opt := getDefaultOptionInfo(opts...) //兼容设计 modelTypeStr := reflect.TypeOf((*T)(nil)).Elem().String() if model, ok := modelInstanceCache.Load(modelTypeStr); ok { m, isReal := model.(*T) @@ -76,12 +78,12 @@ func GetModel[T any]() *T { } } t := new(T) - Cache(t) + Cache(opt, t) return t } // 递归获取嵌套字段名 -func getSubFieldColumnNameMap(valueOf reflect.Value, field reflect.StructField) map[uintptr]string { +func getSubFieldColumnNameMap(valueOf reflect.Value, field reflect.StructField, namingStrategy schema.Namer) map[uintptr]string { result := make(map[uintptr]string) modelType := field.Type if modelType.Kind() == reflect.Ptr { @@ -90,13 +92,13 @@ func getSubFieldColumnNameMap(valueOf reflect.Value, field reflect.StructField) for j := 0; j < modelType.NumField(); j++ { subField := modelType.Field(j) if subField.Anonymous { - nestedFields := getSubFieldColumnNameMap(valueOf, subField) + nestedFields := getSubFieldColumnNameMap(valueOf, subField, namingStrategy) for key, value := range nestedFields { result[key] = value } } else { pointer := valueOf.FieldByName(modelType.Field(j).Name).Addr().Pointer() - name := parseColumnName(modelType.Field(j)) + name := parseColumnName(modelType.Field(j), namingStrategy) result[pointer] = name } } @@ -104,14 +106,17 @@ func getSubFieldColumnNameMap(valueOf reflect.Value, field reflect.StructField) return result } -// 解析字段名称 -func parseColumnName(field reflect.StructField) string { +// 解析字段名称 兼容多数据库切换, +// 如果用户使用Option的GetDb而没有传数据库连接名这边获取的namingStrategy 是默认的一个可能会有问题, +// 所以建议用户多数据库的时候弃用Option里的Db,并且重新改写初始化,给与每个db连接有连接名 +// 并且改造下多数据使用NewQuery和GetModel和NewQueryModel相关方法传入数据库连接名 +func parseColumnName(field reflect.StructField, namingStrategy schema.Namer) string { tagSetting := schema.ParseTagSetting(field.Tag.Get("gorm"), ";") name, ok := tagSetting["COLUMN"] if ok { return name } - return globalDb.Config.NamingStrategy.ColumnName("", field.Name) + return namingStrategy.ColumnName("", field.Name) } func getColumnName(v any) string { diff --git a/gplus/dao.go b/gplus/dao.go index 341496a..fa55cc4 100644 --- a/gplus/dao.go +++ b/gplus/dao.go @@ -18,6 +18,7 @@ package gplus import ( "database/sql" + "errors" "fmt" "github.com/acmestack/gorm-plus/constants" "gorm.io/gorm" @@ -28,11 +29,38 @@ import ( "time" ) -var globalDb *gorm.DB +var globalDbMap = make(map[string]*gorm.DB) +var globalDbKeys []string var defaultBatchSize = 1000 -func Init(db *gorm.DB) { - globalDb = db +// Init 可选参数dbConnNameArr 代表数据库连接名,只需要传一个就行, +// 主要为了兼容之前用户只传一个db无需修改 +func Init(db *gorm.DB, dbConnNameArr ...string) error { + var dbConnName = "" + if len(dbConnNameArr) > 0 { + dbConnName = dbConnNameArr[0] + } + return setGlobalInfo(db, dbConnName) +} + +// InitMany 初始化多个 +func InitMany(dic map[string]*gorm.DB) []error { + var errs []error + for k, v := range dic { + if err := setGlobalInfo(v, k); err != nil { + errs = append(errs, err) + } + } + return errs +} + +// GetDb 获取数据库连接 +func GetDb(dbConnName string) (*gorm.DB, error) { + db, exists := globalDbMap[dbConnName] + if exists { + return db, nil + } + return nil, errors.New("MultipleDbChange not exists dbConn:" + dbConnName + ",please check") } type Page[T any] struct { @@ -45,8 +73,8 @@ type Page[T any] struct { type Dao[T any] struct{} -func (dao Dao[T]) NewQuery() (*QueryCond[T], *T) { - return NewQuery[T]() +func (dao Dao[T]) NewQuery(opts ...OptionFunc) (*QueryCond[T], *T) { + return NewQuery[T](opts...) } func NewPage[T any](current, size int) *Page[T] { @@ -157,7 +185,7 @@ func UpdateZeroById[T any](entity *T, opts ...OptionFunc) *gorm.DB { func updateAllIfNeed(entity any, opts []OptionFunc, db *gorm.DB) { option := getOption(opts) if len(option.Selects) == 0 { - columnNameMap := getColumnNameMap(entity) + columnNameMap := getColumnNameMap(entity, db.Config.NamingStrategy) var columnNames []string for _, columnName := range columnNameMap { columnNames = append(columnNames, columnName) @@ -449,14 +477,21 @@ func buildSqlAndArgs[T any](expressions []any, sqlBuilder *strings.Builder, quer } func getDb(opts ...OptionFunc) *gorm.DB { + var db *gorm.DB option := getOption(opts) - // Clauses()目的是为了初始化Db,如果db已经被初始化了,会直接返回db - var db = globalDb.Clauses() if option.Db != nil { - db = option.Db.Clauses() + db = option.Db + } else { + db, option.DbConnName, _ = getDefaultDbByName(option.DbConnName) } + //设置session,如果需要子句仅在当前会话生效,先调用 Session(),再调用 Clauses()。 + setSessionIfNeed(option, db) + + // Clauses()目的是为了初始化Db,如果db已经被初始化了,会直接返回db + db = db.Clauses() + // 设置需要忽略的字段 setOmitIfNeed(option, db) @@ -496,6 +531,12 @@ func setOmitIfNeed(option Option, db *gorm.DB) { } } +func setSessionIfNeed(option Option, db *gorm.DB) { + if option.DbSession != nil { + db.Session(option.DbSession) + } +} + func getPkColumnName[T any]() string { var entity T entityType := reflect.TypeOf(entity) @@ -520,3 +561,60 @@ func getPkColumnName[T any]() string { } return columnName } + +func getDefaultDbConnName() string { + dbConnName := constants.DefaultGormPlusConnName + //如果用户没传数据库连接名称,优先判断全局自定义的连接名是否存在, + //如果上面不存在其次从全局globalDbKeys里获取第一个连接名 + //1.避免用户使用InitDb方法初始化数据库 自定义数据库连接名 ,然后方法里不传是哪个数据库连接名 则只能默认取第一条 + //2.再混用单库Init取初始化,做方法兼容 + _, exists := globalDbMap[dbConnName] + if exists { + return dbConnName + } + dbConnName = globalDbKeys[0] + return dbConnName +} + +// 获取如果连接名为空则默认填充的option数据 +func getDefaultOptionInfo(opts ...OptionFunc) Option { + option := getOption(opts) + if len(option.DbConnName) == 0 { + option.DbConnName = getDefaultDbConnName() //兼容之前设计 + } + return option +} + +func getDefaultDbByOpt(opt Option) (*gorm.DB, string, error) { + return getDefaultDbByName(opt.DbConnName) +} + +func getDefaultDbByName(dbConnName string) (*gorm.DB, string, error) { + if len(dbConnName) == 0 { + dbConnName = getDefaultDbConnName() + } + db, err := GetDb(dbConnName) + return db, dbConnName, err +} + +func setGlobalInfo(db *gorm.DB, dbConnName string) error { + if len(dbConnName) == 0 { + //return errors.New("InitMultiple dbConnName is empty please check") + //如果字典里不包含了默认名则使用默认名,兼容之前单库 + _, exists := globalDbMap[constants.DefaultGormPlusConnName] + if exists { + //根据db指针地址获取作为连接名,因为GORM 本身不提供直接获取数据库连接地址的方法,也不推荐使用反射来获取dsn + dbConnName = fmt.Sprintf("%p", db) + } else { + dbConnName = constants.DefaultGormPlusConnName + } + } + _, exists := globalDbMap[dbConnName] + if !exists { + // db instance register to global variable + globalDbMap[dbConnName] = db + globalDbKeys = append(globalDbKeys, dbConnName) + return nil + } + return errors.New("InitMultiple have same name:" + dbConnName + ",please check") +} diff --git a/gplus/option.go b/gplus/option.go index 1a89ee7..675530e 100644 --- a/gplus/option.go +++ b/gplus/option.go @@ -17,13 +17,17 @@ package gplus -import "gorm.io/gorm" +import ( + "gorm.io/gorm" +) type Option struct { Db *gorm.DB Selects []any Omits []any IgnoreTotal bool + DbConnName string + DbSession *gorm.Session } type OptionFunc func(*Option) @@ -35,10 +39,10 @@ func Db(db *gorm.DB) OptionFunc { } } -// Session 创建回话 +// Session 创建会话 func Session(session *gorm.Session) OptionFunc { return func(o *Option) { - o.Db = globalDb.Session(session) + o.DbSession = session //调整session 在dao类的getDb方法那边处理 } } @@ -62,3 +66,10 @@ func IgnoreTotal() OptionFunc { o.IgnoreTotal = true } } + +// DbConnName 多个数据库连接根据自定义连接名称选择切换 +func DbConnName(dbConnName string) OptionFunc { + return func(o *Option) { + o.DbConnName = dbConnName + } +} diff --git a/gplus/query.go b/gplus/query.go index d40a2b8..b13063f 100644 --- a/gplus/query.go +++ b/gplus/query.go @@ -46,7 +46,8 @@ func (q *QueryCond[T]) getSqlSegment() string { } // NewQuery 构建查询条件 -func NewQuery[T any]() (*QueryCond[T], *T) { +func NewQuery[T any](opts ...OptionFunc) (*QueryCond[T], *T) { + opt := getDefaultOptionInfo(opts...) //兼容设计 q := &QueryCond[T]{} modelTypeStr := reflect.TypeOf((*T)(nil)).Elem().String() if model, ok := modelInstanceCache.Load(modelTypeStr); ok { @@ -56,12 +57,13 @@ func NewQuery[T any]() (*QueryCond[T], *T) { } } m := new(T) - Cache(m) + Cache(opt, m) return q, m } // NewQueryModel 构建查询条件 -func NewQueryModel[T any, R any]() (*QueryCond[T], *T, *R) { +func NewQueryModel[T any, R any](opts ...OptionFunc) (*QueryCond[T], *T, *R) { + opt := getDefaultOptionInfo(opts...) //兼容设计 q := &QueryCond[T]{} var t *T var r *R @@ -83,12 +85,12 @@ func NewQueryModel[T any, R any]() (*QueryCond[T], *T, *R) { if t == nil { t = new(T) - Cache(t) + Cache(opt, t) } if r == nil { r = new(R) - Cache(r) + Cache(opt, r) } return q, t, r diff --git a/gplus/tool.go b/gplus/tool.go index 5e7917c..661f2f2 100644 --- a/gplus/tool.go +++ b/gplus/tool.go @@ -19,6 +19,7 @@ package gplus import ( "fmt" + "gorm.io/gorm/schema" "net/url" "reflect" "strconv" @@ -55,13 +56,16 @@ var builders = map[string]func(query *QueryCond[any], name string, value any){ "<": lt, } -func BuildQuery[T any](queryParams url.Values) *QueryCond[T] { +func BuildQuery[T any](queryParams url.Values, opts ...OptionFunc) *QueryCond[T] { + opt := getDefaultOptionInfo(opts...) //兼容设计 columnCondMap, conditionMap, gcond := parseParams(queryParams) parentQuery := buildParentQuery[T](conditionMap) - queryCondMap := buildQueryCondMap[T](columnCondMap) + db, _, _ := getDefaultDbByOpt(opt) + + queryCondMap := buildQueryCondMap[T](columnCondMap, db.Config.NamingStrategy) // 如果没有分组条件,直接返回默认的查询条件 if len(gcond) == 0 { @@ -159,9 +163,9 @@ func getCurrentOp(value string) string { return currentOperator } -func buildQueryCondMap[T any](columnCondMap map[string][]*Condition) map[string]*QueryCond[T] { +func buildQueryCondMap[T any](columnCondMap map[string][]*Condition, namingStrategy schema.Namer) map[string]*QueryCond[T] { var queryCondMap = make(map[string]*QueryCond[T]) - columnTypeMap := getColumnTypeMap[T]() + columnTypeMap := getColumnTypeMap[T](namingStrategy) for key, conditions := range columnCondMap { query := &QueryCond[any]{} query.columnTypeMap = columnTypeMap @@ -273,7 +277,7 @@ func buildGroupQuery[T any](gcond string, queryMaps map[string]*QueryCond[T], qu return query } -func getColumnTypeMap[T any]() map[string]reflect.Type { +func getColumnTypeMap[T any](namingStrategy schema.Namer) map[string]reflect.Type { modelTypeStr := reflect.TypeOf((*T)(nil)).Elem().String() if model, ok := columnTypeCache.Load(modelTypeStr); ok { if columnNameMap, isOk := model.(map[string]reflect.Type); isOk { @@ -285,19 +289,19 @@ func getColumnTypeMap[T any]() map[string]reflect.Type { for i := 0; i < typeOf.NumField(); i++ { field := typeOf.Field(i) if field.Anonymous { - nestedFields := getSubFieldColumnTypeMap(field) + nestedFields := getSubFieldColumnTypeMap(field, namingStrategy) for key, value := range nestedFields { columnTypeMap[key] = value } } - columnName := parseColumnName(field) + columnName := parseColumnName(field, namingStrategy) columnTypeMap[columnName] = field.Type } columnTypeCache.Store(modelTypeStr, columnTypeMap) return columnTypeMap } -func getSubFieldColumnTypeMap(field reflect.StructField) map[string]reflect.Type { +func getSubFieldColumnTypeMap(field reflect.StructField, namingStrategy schema.Namer) map[string]reflect.Type { columnTypeMap := make(map[string]reflect.Type) modelType := field.Type if modelType.Kind() == reflect.Ptr { @@ -306,12 +310,12 @@ func getSubFieldColumnTypeMap(field reflect.StructField) map[string]reflect.Type for j := 0; j < modelType.NumField(); j++ { subField := modelType.Field(j) if subField.Anonymous { - nestedFields := getSubFieldColumnTypeMap(subField) + nestedFields := getSubFieldColumnTypeMap(subField, namingStrategy) for key, value := range nestedFields { columnTypeMap[key] = value } } else { - columnName := parseColumnName(subField) + columnName := parseColumnName(subField, namingStrategy) columnTypeMap[columnName] = subField.Type } } diff --git a/tests/dao_test.go b/tests/dao_test.go index 8ad5133..818dc60 100644 --- a/tests/dao_test.go +++ b/tests/dao_test.go @@ -24,6 +24,7 @@ import ( "gorm.io/driver/mysql" "gorm.io/gorm" "gorm.io/gorm/logger" + "net/url" "reflect" "sort" "strconv" @@ -31,9 +32,13 @@ import ( ) var gormDb *gorm.DB +var gormDbConnName = "test1" +var dbAddress = "127.0.0.1:3306" +var dbUser = "root" +var dbPassword = "123456" func init() { - dsn := "root:123456@tcp(127.0.0.1:3306)/test?charset=utf8mb4&parseTime=True&loc=Local" + dsn := fmt.Sprintf("%s:%s@tcp(%s)/test?charset=utf8mb4&parseTime=True&loc=Local", dbUser, dbPassword, dbAddress) var err error gormDb, err = gorm.Open(mysql.Open(dsn), &gorm.Config{ Logger: logger.Default.LogMode(logger.Info), @@ -44,6 +49,21 @@ func init() { var u User gormDb.AutoMigrate(u) gplus.Init(gormDb) + initDb() +} + +func initDb() { + dsn := fmt.Sprintf("%s:%s@tcp(%s)/test1?charset=utf8mb4&parseTime=True&loc=Local", dbUser, dbPassword, dbAddress) + var err error + gormDb1, err := gorm.Open(mysql.Open(dsn), &gorm.Config{ + Logger: logger.Default.LogMode(logger.Info), + }) + if err != nil { + fmt.Println(err) + } + var u User + gormDb1.AutoMigrate(u) + gplus.Init(gormDb1, gormDbConnName) } func TestInsert(t *testing.T) { @@ -589,12 +609,182 @@ func TestTx(t *testing.T) { } } +func TestInsertBaseDb(t *testing.T) { + deleteOldDataBaseDb() + + user := &User{Username: "afumu", Password: "123456", Age: 18, Score: 100, Dept: "开发部门"} + resultDb := gplus.Insert(user, gplus.DbConnName(gormDbConnName)) + + if resultDb.Error != nil { + t.Fatalf("errors happened when insert: %v", resultDb.Error) + } else if resultDb.RowsAffected != 1 { + t.Fatalf("rows affected expects: %v, got %v", 1, resultDb.RowsAffected) + } + + newUser, db := gplus.SelectById[User](user.ID, gplus.DbConnName(gormDbConnName)) + if db.Error != nil { + t.Fatalf("errors happened when SelectById: %v", db.Error) + } + AssertObjEqual(t, newUser, user, "ID", "Username", "Password", "Address", "Age", "Phone", "Score", "Dept", "CreatedAt", "UpdatedAt") +} + +func TestInsertBatchBaseDb(t *testing.T) { + deleteOldDataBaseDb() + users := getUsers() + resultDb := gplus.InsertBatch[User](users, gplus.DbConnName(gormDbConnName)) + if resultDb.RowsAffected != int64(len(users)) { + t.Errorf("affected rows should be %v, but got %v", len(users), resultDb.RowsAffected) + } + + for _, user := range users { + newUser, db := gplus.SelectById[User](user.ID, gplus.DbConnName(gormDbConnName)) + if db.Error != nil { + t.Fatalf("errors happened when SelectById: %v", db.Error) + } + AssertObjEqual(t, newUser, user, "ID", "Username", "Password", "Address", "Age", "Phone", "Score", "Dept", "CreatedAt", "UpdatedAt") + } +} + +func TestDeleteByIdBaseDb(t *testing.T) { + deleteOldDataBaseDb() + users := getUsers() + gplus.InsertBatchSize[User](users, 2, gplus.DbConnName(gormDbConnName)) + + if res := gplus.DeleteById[User](users[1].ID, gplus.DbConnName(gormDbConnName)); res.Error != nil || res.RowsAffected != 1 { + t.Errorf("errors happened when deleteById: %v, affected: %v", res.Error, res.RowsAffected) + } + + _, resultDb := gplus.SelectById[User](users[1].ID, gplus.DbConnName(gormDbConnName)) + if !errors.Is(resultDb.Error, gorm.ErrRecordNotFound) { + t.Errorf("should returns record not found error, but got %v", resultDb.Error) + } +} + +func TestDeleteBaseDb(t *testing.T) { + deleteOldDataBaseDb() + users := getUsers() + opt := gplus.DbConnName(gormDbConnName) + gplus.InsertBatch[User](users, opt) + + query, u := gplus.NewQuery[User](opt) + query.Eq(&u.Username, "afumu1") + if res := gplus.Delete[User](query, gplus.DbConnName(gormDbConnName)); res.Error != nil || res.RowsAffected != 1 { + t.Errorf("errors happened when Delete: %v, affected: %v", res.Error, res.RowsAffected) + } + + _, resultDb := gplus.SelectOne[User](query, gplus.DbConnName(gormDbConnName)) + if !errors.Is(resultDb.Error, gorm.ErrRecordNotFound) { + t.Errorf("should returns record not found error, but got %v", resultDb.Error) + } +} + +func TestUpdateByIdBaseDb(t *testing.T) { + deleteOldDataBaseDb() + users := getUsers() + gplus.InsertBatch[User](users, gplus.DbConnName(gormDbConnName)) + + user := users[0] + user.Score = 100 + user.Age = 25 + + if res := gplus.UpdateById[User](user, gplus.DbConnName(gormDbConnName)); res.Error != nil || res.RowsAffected != 1 { + t.Errorf("errors happened when deleteByIds: %v, affected: %v", res.Error, res.RowsAffected) + } + + newUser, db := gplus.SelectById[User](user.ID, gplus.DbConnName(gormDbConnName)) + if db.Error != nil { + t.Fatalf("errors happened when SelectById: %v", db.Error) + } + AssertObjEqual(t, newUser, user, "ID", "Username", "Password", "Address", "Age", "Phone", "Score", "Dept", "CreatedAt", "UpdatedAt") + +} + +func TestSelectByIdBaseDb(t *testing.T) { + deleteOldDataBaseDb() + users := getUsers() + gplus.InsertBatch[User](users, gplus.DbConnName(gormDbConnName)) + user := users[0] + resultUser, db := gplus.SelectById[User](user.ID, gplus.DbConnName(gormDbConnName)) + if db.Error != nil { + t.Errorf("errors happened when selectById : %v", db.Error) + } else { + AssertObjEqual(t, resultUser, user, "ID", "Username", "Password", "Address", "Age", "Phone", "Score", "Dept", "CreatedAt", "UpdatedAt") + } +} + +func TestSelectGeneric6BaseDb(t *testing.T) { + deleteOldDataBaseDb() + users := getUsers() + opt := gplus.DbConnName(gormDbConnName) + gplus.InsertBatch[User](users, opt) + type UserVo struct { + Dept string + Score int + } + var userMap = make(map[string]int) + for _, user := range users { + userMap[user.Dept] += user.Score + } + //测试NewQuery和GetModel + query, u := gplus.NewQuery[User](opt) + uvo := gplus.GetModel[UserVo](opt) + query.Select(&u.Dept, gplus.Sum(&u.Score).As(&uvo.Score)).Group(&u.Dept) + UserVos, resultDb := gplus.SelectGeneric[User, []UserVo](query, opt) + + if resultDb.Error != nil { + t.Errorf("errors happened when resultDb : %v", resultDb.Error) + } + + for _, userVo := range UserVos { + score := userMap[userVo.Dept] + if userVo.Score != score { + t.Errorf("errors happened when SelectGeneric") + } + } + + //测试NewQueryModel + type UserV1 struct { + Name string + Age int64 + } + query, user, userV1 := gplus.NewQueryModel[User, UserV1](opt) + query.Eq(&user.Username, "afumu").And(func(q *gplus.QueryCond[User]) { + q.Eq(&user.Address, "北京").Or().Eq(&user.Age, 20) + }).Select(gplus.As(&user.Username, &userV1.Name), &user.Age) + gplus.SelectGeneric[User, []UserV1](query, opt) + + //如果还是使用旧有的方法测试 + query, u = gplus.NewQuery[User]() + uvo = gplus.GetModel[UserVo]() + query.Select(&u.Dept, gplus.Sum(&u.Score).As(&uvo.Score)).Group(&u.Dept) + UserVos, resultDb = gplus.SelectGeneric[User, []UserVo](query, opt) + + if resultDb.Error != nil { + t.Errorf("errors happened when resultDb : %v", resultDb.Error) + } +} + +func TestQueryByIdBaseDb(t *testing.T) { + opt := gplus.DbConnName(gormDbConnName) + values := url.Values{} + values["q"] = []string{"id=1"} + query := gplus.BuildQuery[User](values, opt) + gplus.SelectList[User](query, opt) +} + func deleteOldData() { q, u := gplus.NewQuery[User]() q.IsNotNull(&u.ID) gplus.Delete(q) } +func deleteOldDataBaseDb() { + opt := gplus.DbConnName(gormDbConnName) + q, u := gplus.NewQuery[User](opt) + q.IsNotNull(&u.ID) + gplus.Delete(q, gplus.DbConnName(gormDbConnName)) +} + func getUsers() []*User { user1 := &User{Username: "afumu1", Password: "123456", Age: 18, Score: 12, Dept: "开发部门"} user2 := &User{Username: "afumu2", Password: "123456", Age: 16, Score: 34, Dept: "行政部门"}