Skip to content
53 changes: 36 additions & 17 deletions expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"io"
"reflect"
"sort"
"strings"
)

Expand Down Expand Up @@ -66,6 +67,18 @@ func (e aliasExpr) ToSql() (sql string, args []interface{}, err error) {
return
}

// GenerateOrderPredicateIndex provides a slice of keys useful for ordering predicates.
func GenerateOrderPredicateIndex(predicates map[string]interface{}) []string {
keys := make([]string, len(predicates))
counter := 0
for key := range predicates {
keys[counter] = key
counter++
}
sort.Strings(keys)
return keys
}

// Eq is syntactic sugar for use with Where/Having/Set methods.
// Ex:
// .Where(Eq{"id": 1})
Expand All @@ -74,9 +87,9 @@ type Eq map[string]interface{}
func (eq Eq) toSql(useNotOpr bool) (sql string, args []interface{}, err error) {
var (
exprs []string
equalOpr string = "="
inOpr string = "IN"
nullOpr string = "IS"
equalOpr = "="
inOpr = "IN"
nullOpr = "IS"
)

if useNotOpr {
Expand All @@ -85,8 +98,12 @@ func (eq Eq) toSql(useNotOpr bool) (sql string, args []interface{}, err error) {
nullOpr = "IS NOT"
}

for key, val := range eq {
expr := ""
predicateIndex := GenerateOrderPredicateIndex(eq)

for _, key := range predicateIndex {
val := eq[key]

var expr string

switch v := val.(type) {
case driver.Valuer:
Expand Down Expand Up @@ -143,7 +160,7 @@ type Lt map[string]interface{}
func (lt Lt) toSql(opposite, orEq bool) (sql string, args []interface{}, err error) {
var (
exprs []string
opr string = "<"
opr = "<"
)

if opposite {
Expand All @@ -154,8 +171,10 @@ func (lt Lt) toSql(opposite, orEq bool) (sql string, args []interface{}, err err
opr = fmt.Sprintf("%s%s", opr, "=")
}

for key, val := range lt {
expr := ""
predicateIndex := GenerateOrderPredicateIndex(lt)

for _, key := range predicateIndex {
val := lt[key]

switch v := val.(type) {
case driver.Valuer:
Expand All @@ -167,16 +186,16 @@ func (lt Lt) toSql(opposite, orEq bool) (sql string, args []interface{}, err err
if val == nil {
err = fmt.Errorf("cannot use null with less than or greater than operators")
return
} else {
valVal := reflect.ValueOf(val)
if valVal.Kind() == reflect.Array || valVal.Kind() == reflect.Slice {
err = fmt.Errorf("cannot use array or slice with less than or greater than operators")
return
} else {
expr = fmt.Sprintf("%s %s ?", key, opr)
args = append(args, val)
}
}

valVal := reflect.ValueOf(val)
if valVal.Kind() == reflect.Array || valVal.Kind() == reflect.Slice {
err = fmt.Errorf("cannot use array or slice with less than or greater than operators")
return
}

expr := fmt.Sprintf("%s %s ?", key, opr)
args = append(args, val)
exprs = append(exprs, expr)
}
sql = strings.Join(exprs, " AND ")
Expand Down
46 changes: 45 additions & 1 deletion expr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,54 @@ package squirrel

import (
"database/sql"
"github.com/stretchr/testify/assert"
"testing"

"github.com/stretchr/testify/assert"
)

func TestGenerateOrderPredicateIndex(t *testing.T) {
output := []string{"one", "two"}

type args struct {
predicates map[string]interface{}
}

tests := []struct {
args args
want []string
}{
{args{Eq{}}, []string{}},
{args{Eq{"one": 1}}, []string{"one"}},
{args{Eq{"one": 1, "two": 2}}, output},
{args{Eq{"two": 2, "one": 1}}, output},

{args{Lt{}}, []string{}},
{args{Lt{"one": 1}}, []string{"one"}},
{args{Lt{"one": 1, "two": 2}}, output},
{args{Lt{"two": 2, "one": 1}}, output},

{args{Gt{}}, []string{}},
{args{Gt{"one": 1}}, []string{"one"}},
{args{Gt{"one": 1, "two": 2}}, output},
{args{Gt{"two": 2, "one": 1}}, output},

{args{GtOrEq{}}, []string{}},
{args{GtOrEq{"one": 1}}, []string{"one"}},
{args{GtOrEq{"one": 1, "two": 2}}, output},
{args{GtOrEq{"two": 2, "one": 1}}, output},

{args{LtOrEq{}}, []string{}},
{args{LtOrEq{"one": 1}}, []string{"one"}},
{args{LtOrEq{"one": 1, "two": 2}}, output},
{args{LtOrEq{"two": 2, "one": 1}}, output},
}

for _, tt := range tests {
got := GenerateOrderPredicateIndex(tt.args.predicates)
assert.Equal(t, tt.want, got)
}
}

func TestEqToSql(t *testing.T) {
b := Eq{"id": 1}
sql, args, err := b.ToSql()
Expand Down