diff --git a/expr.go b/expr.go index a8749f10..a415d40b 100644 --- a/expr.go +++ b/expr.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "reflect" + "sort" "strings" ) @@ -66,6 +67,18 @@ func (e aliasExpr) ToSql() (sql string, args []interface{}, err error) { return } +// GenerateOrderPredicateIndex provides a slice of keys useful for ordering predicates. +func GenerateOrderPredicateIndex(predicates map[string]interface{}) []string { + keys := make([]string, len(predicates)) + counter := 0 + for key := range predicates { + keys[counter] = key + counter++ + } + sort.Strings(keys) + return keys +} + // Eq is syntactic sugar for use with Where/Having/Set methods. // Ex: // .Where(Eq{"id": 1}) @@ -74,9 +87,9 @@ type Eq map[string]interface{} func (eq Eq) toSql(useNotOpr bool) (sql string, args []interface{}, err error) { var ( exprs []string - equalOpr string = "=" - inOpr string = "IN" - nullOpr string = "IS" + equalOpr = "=" + inOpr = "IN" + nullOpr = "IS" ) if useNotOpr { @@ -85,8 +98,12 @@ func (eq Eq) toSql(useNotOpr bool) (sql string, args []interface{}, err error) { nullOpr = "IS NOT" } - for key, val := range eq { - expr := "" + predicateIndex := GenerateOrderPredicateIndex(eq) + + for _, key := range predicateIndex { + val := eq[key] + + var expr string switch v := val.(type) { case driver.Valuer: @@ -143,7 +160,7 @@ type Lt map[string]interface{} func (lt Lt) toSql(opposite, orEq bool) (sql string, args []interface{}, err error) { var ( exprs []string - opr string = "<" + opr = "<" ) if opposite { @@ -154,8 +171,10 @@ func (lt Lt) toSql(opposite, orEq bool) (sql string, args []interface{}, err err opr = fmt.Sprintf("%s%s", opr, "=") } - for key, val := range lt { - expr := "" + predicateIndex := GenerateOrderPredicateIndex(lt) + + for _, key := range predicateIndex { + val := lt[key] switch v := val.(type) { case driver.Valuer: @@ -167,16 +186,16 @@ func (lt Lt) toSql(opposite, orEq bool) (sql string, args []interface{}, err err if val == nil { err = fmt.Errorf("cannot use null with less than or greater than operators") return - } else { - valVal := reflect.ValueOf(val) - if valVal.Kind() == reflect.Array || valVal.Kind() == reflect.Slice { - err = fmt.Errorf("cannot use array or slice with less than or greater than operators") - return - } else { - expr = fmt.Sprintf("%s %s ?", key, opr) - args = append(args, val) - } } + + valVal := reflect.ValueOf(val) + if valVal.Kind() == reflect.Array || valVal.Kind() == reflect.Slice { + err = fmt.Errorf("cannot use array or slice with less than or greater than operators") + return + } + + expr := fmt.Sprintf("%s %s ?", key, opr) + args = append(args, val) exprs = append(exprs, expr) } sql = strings.Join(exprs, " AND ") diff --git a/expr_test.go b/expr_test.go index f8565f4b..3c811e05 100644 --- a/expr_test.go +++ b/expr_test.go @@ -2,10 +2,54 @@ package squirrel import ( "database/sql" - "github.com/stretchr/testify/assert" "testing" + + "github.com/stretchr/testify/assert" ) +func TestGenerateOrderPredicateIndex(t *testing.T) { + output := []string{"one", "two"} + + type args struct { + predicates map[string]interface{} + } + + tests := []struct { + args args + want []string + }{ + {args{Eq{}}, []string{}}, + {args{Eq{"one": 1}}, []string{"one"}}, + {args{Eq{"one": 1, "two": 2}}, output}, + {args{Eq{"two": 2, "one": 1}}, output}, + + {args{Lt{}}, []string{}}, + {args{Lt{"one": 1}}, []string{"one"}}, + {args{Lt{"one": 1, "two": 2}}, output}, + {args{Lt{"two": 2, "one": 1}}, output}, + + {args{Gt{}}, []string{}}, + {args{Gt{"one": 1}}, []string{"one"}}, + {args{Gt{"one": 1, "two": 2}}, output}, + {args{Gt{"two": 2, "one": 1}}, output}, + + {args{GtOrEq{}}, []string{}}, + {args{GtOrEq{"one": 1}}, []string{"one"}}, + {args{GtOrEq{"one": 1, "two": 2}}, output}, + {args{GtOrEq{"two": 2, "one": 1}}, output}, + + {args{LtOrEq{}}, []string{}}, + {args{LtOrEq{"one": 1}}, []string{"one"}}, + {args{LtOrEq{"one": 1, "two": 2}}, output}, + {args{LtOrEq{"two": 2, "one": 1}}, output}, + } + + for _, tt := range tests { + got := GenerateOrderPredicateIndex(tt.args.predicates) + assert.Equal(t, tt.want, got) + } +} + func TestEqToSql(t *testing.T) { b := Eq{"id": 1} sql, args, err := b.ToSql()