diff --git a/cte.go b/cte.go new file mode 100644 index 00000000..dbb22bef --- /dev/null +++ b/cte.go @@ -0,0 +1,45 @@ +package squirrel + +import ( + "bytes" + "strings" +) + +// CTE represents a single common table expression. They are composed of an alias, a few optional components, and a data manipulation statement, though exactly what sort of statement depends on the database system you're using. MySQL, for example, only allows SELECT statements; others, like PostgreSQL, permit INSERTs, UPDATEs, and DELETEs. +// The optional components supported by this fork of Squirrel include: +// * a list of columns +// * the keyword RECURSIVE, the use of which may place additional constraints on the data manipulation statement +type CTE struct { + Alias string + ColumnList []string + Recursive bool + Expression Sqlizer +} + +// ToSql builds the SQL for a CTE +func (c CTE) ToSql() (string, []interface{}, error) { + + var buf bytes.Buffer + + if c.Recursive { + buf.WriteString("RECURSIVE ") + } + + buf.WriteString(c.Alias) + + if len(c.ColumnList) > 0 { + buf.WriteString("(") + buf.WriteString(strings.Join(c.ColumnList, ", ")) + buf.WriteString(")") + } + + buf.WriteString(" AS (") + sql, args, err := c.Expression.ToSql() + if err != nil { + return "", []interface{}{}, err + } + buf.WriteString(sql) + buf.WriteString(")") + + return buf.String(), args, nil +} diff --git a/cte_test.go b/cte_test.go new file mode 100644 index 00000000..597252f5 --- /dev/null +++ b/cte_test.go @@ -0,0 +1,42 @@ +package squirrel + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNormalCTE(t *testing.T) { + + cte := CTE{ + Alias: "cte", + ColumnList: []string{"abc", "def"}, + Recursive: false, + Expression: Select("abc", "def").From("t").Where(Eq{"abc": 1}), + } + + sql, args, err := cte.ToSql() + + assert.Equal(t, "cte(abc, def) AS (SELECT abc, def FROM t WHERE abc = ?)", sql) + assert.Equal(t, []interface{}{1}, args) + assert.Nil(t, err) + +} + +func TestRecursiveCTE(t *testing.T) { + + // this isn't usually valid SQL, but the point is to test the RECURSIVE part + cte := CTE{ + Alias: "cte", + ColumnList: []string{"abc", "def"}, + Recursive: true, + Expression: Select("abc", "def").From("t").Where(Eq{"abc": 1}), + } + + sql, args, err := cte.ToSql() + + assert.Equal(t, "RECURSIVE cte(abc, def) AS (SELECT abc, def FROM t WHERE abc = ?)", sql) + assert.Equal(t, []interface{}{1}, args) + assert.Nil(t, err) + +} diff --git a/select.go b/select.go index d55ce4c7..f217d062 100644 --- a/select.go +++ b/select.go @@ -13,6 +13,9 @@ type selectData struct { PlaceholderFormat PlaceholderFormat RunWith BaseRunner Prefixes []Sqlizer + CTEs []Sqlizer + Union Sqlizer + UnionAll Sqlizer Options []string Columns []Sqlizer From Sqlizer @@ -78,6 +81,15 @@ func (d *selectData) toSqlRaw() (sqlStr string, args []interface{}, err error) { sql.WriteString(" ") } + if len(d.CTEs) > 0 { + sql.WriteString("WITH ") + args, err = appendToSql(d.CTEs, sql, ", ", args) + if err != nil { + return + } + sql.WriteString(" ") + } + sql.WriteString("SELECT ") if len(d.Options) > 0 { @@ -116,6 +128,22 @@ func (d *selectData) toSqlRaw() (sqlStr string, args []interface{}, err error) { } } + if d.Union != nil { + sql.WriteString(" UNION ") + args, err = appendToSql([]Sqlizer{d.Union}, sql, "", args) + if err != nil { + return + } + } + + if d.UnionAll != nil { + sql.WriteString(" UNION ALL ") + args, err = appendToSql([]Sqlizer{d.UnionAll}, sql, "", args) + if err != nil { + return + } + } + if len(d.GroupBys) > 0 { sql.WriteString(" GROUP BY ") sql.WriteString(strings.Join(d.GroupBys, ", ")) @@ -253,6 +281,22 @@ func (b SelectBuilder) Options(options ...string) SelectBuilder { return builder.Extend(b, "Options", options).(SelectBuilder) } +// With adds a non-recursive CTE to the query. +func (b SelectBuilder) With(alias string, expr Sqlizer) SelectBuilder { + return b.WithCTE(CTE{Alias: alias, ColumnList: []string{}, Recursive: false, Expression: expr}) +} + +// WithRecursive adds a recursive CTE to the query. +func (b SelectBuilder) WithRecursive(alias string, expr Sqlizer) SelectBuilder { + return b.WithCTE(CTE{Alias: alias, ColumnList: []string{}, Recursive: true, Expression: expr}) +} + +// WithCTE adds an arbitrary Sqlizer to the query. +// The sqlizer will be sandwiched between the keyword WITH and, if there's more than one CTE, a comma. +func (b SelectBuilder) WithCTE(cte Sqlizer) SelectBuilder { + return builder.Append(b, "CTEs", cte).(SelectBuilder) +} + // Columns adds result columns to the query. func (b SelectBuilder) Columns(columns ...string) SelectBuilder { parts := make([]interface{}, 0, len(columns)) @@ -289,6 +333,20 @@ func (b SelectBuilder) FromSelect(from SelectBuilder, alias string) SelectBuilde return builder.Set(b, "From", Alias(from, alias)).(SelectBuilder) } +// UnionSelect sets a union SelectBuilder which removes duplicate rows +// --> UNION combines the result from multiple SELECT statements into a single result set +func (b SelectBuilder) UnionSelect(union SelectBuilder) SelectBuilder { + union = union.PlaceholderFormat(Question) + return builder.Set(b, "Union", union).(SelectBuilder) +} + +// UnionAllSelect sets a union SelectBuilder which includes all matching rows +// --> UNION combines the result from multiple SELECT statements into a single result set +func (b SelectBuilder) UnionAllSelect(union SelectBuilder) SelectBuilder { + union = union.PlaceholderFormat(Question) + return builder.Set(b, "UnionAll", union).(SelectBuilder) +} + // JoinClause adds a join clause to the query. func (b SelectBuilder) JoinClause(pred interface{}, args ...interface{}) SelectBuilder { return builder.Append(b, "Joins", newPart(pred, args...)).(SelectBuilder) @@ -319,6 +377,16 @@ func (b SelectBuilder) CrossJoin(join string, rest ...interface{}) SelectBuilder return b.JoinClause("CROSS JOIN "+join, rest...) } +// Union adds UNION to the query. (duplicate rows are removed) +func (b SelectBuilder) Union(join string, rest ...interface{}) SelectBuilder { + return b.JoinClause("UNION "+join, rest...) +} + +// UnionAll adds UNION ALL to the query. (includes all matching rows) +func (b SelectBuilder) UnionAll(join string, rest ...interface{}) SelectBuilder { + return b.JoinClause("UNION ALL "+join, rest...) +} + // Where adds an expression to the WHERE clause of the query. // // Expressions are ANDed together in the generated SQL. diff --git a/select_test.go b/select_test.go index 80161bf5..67504721 100644 --- a/select_test.go +++ b/select_test.go @@ -279,6 +279,30 @@ func TestSelectSubqueryInConjunctionPlaceholderNumbering(t *testing.T) { assert.Equal(t, []interface{}{1, 2}, args) } +func TestOneCTE(t *testing.T) { + sql, _, err := Select("*").From("cte").With("cte", Select("abc").From("def")).ToSql() + + assert.NoError(t, err) + + assert.Equal(t, "WITH cte AS (SELECT abc FROM def) SELECT * FROM cte", sql) +} + +func TestTwoCTEs(t *testing.T) { + sql, _, err := Select("*").From("cte").With("cte", Select("abc").From("def")).With("cte2", Select("ghi").From("jkl")).ToSql() + + assert.NoError(t, err) + + assert.Equal(t, "WITH cte AS (SELECT abc FROM def), cte2 AS (SELECT ghi FROM jkl) SELECT * FROM cte", sql) +} + +func TestCTEErrorBubblesUp(t *testing.T) { + + // a SELECT with no columns raises an error + _, _, err := Select("*").From("cte").With("cte", SelectBuilder{}.From("def")).ToSql() + + assert.Error(t, err) +} + func TestSelectJoinClausePlaceholderNumbering(t *testing.T) { subquery := Select("a").Where(Eq{"b": 2}).PlaceholderFormat(Dollar) @@ -452,6 +476,45 @@ func ExampleSelectBuilder_ToSql() { } } +func TestSelectBuilderUnionToSql(t *testing.T) { + multi := Select("column1", "column2"). + From("table1"). + Where(Eq{"column1": "test"}). + UnionSelect(Select("column3", "column4").From("table2").Where(Lt{"column4": 5}). + UnionSelect(Select("column5", "column6").From("table3").Where(LtOrEq{"column5": 6}))) + sql, args, err := multi.ToSql() + assert.NoError(t, err) + + expectedSql := `SELECT column1, column2 FROM table1 WHERE column1 = ? ` + + "UNION SELECT column3, column4 FROM table2 WHERE column4 < ? " + + "UNION SELECT column5, column6 FROM table3 WHERE column5 <= ?" + assert.Equal(t, expectedSql, sql) + + expectedArgs := []interface{}{"test", 5, 6} + assert.Equal(t, expectedArgs, args) + + sql, _, err = multi.PlaceholderFormat(Dollar).ToSql() + assert.NoError(t, err) + expectedSql = `SELECT column1, column2 FROM table1 WHERE column1 = $1 ` + + "UNION SELECT column3, column4 FROM table2 WHERE column4 < $2 " + + "UNION SELECT column5, column6 FROM table3 WHERE column5 <= $3" + assert.Equal(t, expectedSql, sql) + + unionAll := Select("count(true) as C"). + From("table1"). + Where(Eq{"column1": []string{"test", "tester"}}). + UnionAllSelect(Select("count(true) as C").From("table2").Where(Select("true").Prefix("NOT EXISTS(").Suffix(")").From("table3").Where("id=table2.column3"))) + sql, args, err = unionAll.ToSql() + assert.NoError(t, err) + + expectedSql = `SELECT count(true) as C FROM table1 WHERE column1 IN (?,?) ` + + "UNION ALL SELECT count(true) as C FROM table2 WHERE NOT EXISTS( SELECT true FROM table3 WHERE id=table2.column3 )" + assert.Equal(t, expectedSql, sql) + + expectedArgs = []interface{}{"test", "tester"} + assert.Equal(t, expectedArgs, args) +} + func TestRemoveColumns(t *testing.T) { query := Select("id"). From("users").