Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 109 additions & 11 deletions clickhouse_std.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func (o *stdConnOpener) Connect(ctx context.Context) (_ driver.Conn, err error)
}
}

if o.opt.Addr == nil || len(o.opt.Addr) == 0 {
if len(o.opt.Addr) == 0 {
return nil, ErrAcquireConnNoAddress
}

Expand Down Expand Up @@ -342,19 +342,31 @@ func (std *stdDriver) PrepareContext(ctx context.Context, query string) (driver.
return nil, driver.ErrBadConn
}

batch, err := std.conn.prepareBatch(ctx, func(nativeTransport, error) {}, func(context.Context) (nativeTransport, error) { return nil, nil }, query, chdriver.PrepareBatchOptions{})
if err != nil {
if isConnBrokenError(err) {
std.debugf("PrepareContext got a fatal error, resetting connection: %v\n", err)
return nil, driver.ErrBadConn
// Detect INSERT to decide between batch mode and read/exec prepared stmt
// We keep the heuristic simple and case-insensitive: leading non-space must start with INSERT
trimmed := strings.TrimLeft(query, " \t\n\r")
if len(trimmed) >= 6 && strings.EqualFold(trimmed[:6], "insert") {
batch, err := std.conn.prepareBatch(ctx, func(nativeTransport, error) {}, func(context.Context) (nativeTransport, error) { return nil, nil }, query, chdriver.PrepareBatchOptions{})
if err != nil {
if isConnBrokenError(err) {
std.debugf("PrepareContext got a fatal error, resetting connection: %v\n", err)
return nil, driver.ErrBadConn
}
std.debugf("PrepareContext error: %v\n", err)
return nil, err
}
std.debugf("PrepareContext error: %v\n", err)
return nil, err
std.commit = batch.Send
return &stdBatch{
batch: batch,
debugf: std.debugf,
}, nil
}
std.commit = batch.Send
return &stdBatch{
batch: batch,

// For non-INSERT, return stdStmt that supports QueryContext and ExecContext
return &stdStmt{
query: query,
debugf: std.debugf,
conn: std.conn,
}, nil
}

Expand Down Expand Up @@ -405,6 +417,92 @@ func (s *stdBatch) Query(args []driver.Value) (driver.Rows, error) {

func (s *stdBatch) Close() error { return nil }

// stdStmt supports prepared statements for non-INSERT queries using the same connection.
type stdStmt struct {
query string
debugf func(format string, v ...any)
conn stdConnect
}

func (s *stdStmt) NumInput() int { return -1 }

// Exec executes non-INSERT statements prepared via stdStmt.
func (s *stdStmt) Exec(args []driver.Value) (driver.Result, error) {
values := make([]any, 0, len(args))
for _, v := range args {
values = append(values, v)
}
if s.conn.isBad() {
s.debugf("[stmt][exec] connection is bad")
return nil, driver.ErrBadConn
}
if err := s.conn.exec(context.Background(), s.query, values...); err != nil {
if isConnBrokenError(err) {
s.debugf("[stmt][exec] fatal error, resetting connection: %v", err)
return nil, driver.ErrBadConn
}
s.debugf("[stmt][exec] error: %v", err)
return nil, err
}
return driver.RowsAffected(0), nil
}

func (s *stdStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
if s.conn.isBad() {
s.debugf("[stmt][execctx] connection is bad")
return nil, driver.ErrBadConn
}
if err := s.conn.exec(ctx, s.query, rebind(args)...); err != nil {
if isConnBrokenError(err) {
s.debugf("[stmt][execctx] fatal error, resetting connection: %v", err)
return nil, driver.ErrBadConn
}
s.debugf("[stmt][execctx] error: %v", err)
return nil, err
}
return driver.RowsAffected(0), nil
}

func (s *stdStmt) Query(args []driver.Value) (driver.Rows, error) {
values := make([]any, 0, len(args))
for _, v := range args {
values = append(values, v)
}
if s.conn.isBad() {
s.debugf("[stmt][query] connection is bad")
return nil, driver.ErrBadConn
}
r, err := s.conn.query(context.Background(), func(nativeTransport, error) {}, s.query, values...)
if isConnBrokenError(err) {
s.debugf("[stmt][query] fatal error, resetting connection: %v", err)
return nil, driver.ErrBadConn
}
if err != nil {
s.debugf("[stmt][query] error: %v", err)
return nil, err
}
return &stdRows{rows: r, debugf: s.debugf}, nil
}

func (s *stdStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
if s.conn.isBad() {
s.debugf("[stmt][queryctx] connection is bad")
return nil, driver.ErrBadConn
}
r, err := s.conn.query(ctx, func(nativeTransport, error) {}, s.query, rebind(args)...)
if isConnBrokenError(err) {
s.debugf("[stmt][queryctx] fatal error, resetting connection: %v", err)
return nil, driver.ErrBadConn
}
if err != nil {
s.debugf("[stmt][queryctx] error: %v", err)
return nil, err
}
return &stdRows{rows: r, debugf: s.debugf}, nil
}

func (s *stdStmt) Close() error { return nil }

type stdRows struct {
rows *rows
debugf func(format string, v ...any)
Expand Down
4 changes: 4 additions & 0 deletions examples/std/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,3 +152,7 @@ func TestJSONStringExample(t *testing.T) {
clickhouse_tests.SkipOnCloud(t, "cannot modify JSON settings on cloud")
require.NoError(t, JSONStringExample())
}

func TestPreparedSelectExample(t *testing.T) {
require.NoError(t, PreparedSelect())
}
64 changes: 64 additions & 0 deletions examples/std/prepared.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// Licensed to ClickHouse, Inc. under one or more contributor
// license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright
// ownership. ClickHouse, Inc. licenses this file to you under
// the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package std

import (
"context"
"database/sql"
"fmt"

"github.com/ClickHouse/clickhouse-go/v2"
)

// PreparedSelect demonstrates using database/sql prepared statements for read queries.
func PreparedSelect() error {
conn, err := GetStdOpenDBConnection(clickhouse.Native, nil, nil, nil)
if err != nil {
return err
}
defer func(db *sql.DB) { _ = db.Close() }(conn)

ctx := context.Background()
if err := conn.PingContext(ctx); err != nil {
return err
}

stmt, err := conn.PrepareContext(ctx, "SELECT ? + ?")
if err != nil {
return err
}
defer func() { _ = stmt.Close() }()

rows, err := stmt.QueryContext(ctx, 2, 3)
if err != nil {
return err
}
defer func() { _ = rows.Close() }()

if !rows.Next() {
return fmt.Errorf("no rows returned from prepared SELECT")
}
var sum int64
if err := rows.Scan(&sum); err != nil {
return err
}
if sum != 5 {
return fmt.Errorf("unexpected result from prepared SELECT: got %d, want 5", sum)
}
return rows.Err()
}
97 changes: 97 additions & 0 deletions tests/std/prepared_stmt_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// Licensed to ClickHouse, Inc. under one or more contributor
// license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright
// ownership. ClickHouse, Inc. licenses this file to you under
// the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package std

import (
"context"
"testing"

"github.com/ClickHouse/clickhouse-go/v2"
"github.com/stretchr/testify/require"
)

func TestStdPreparedSelect(t *testing.T) {
db, err := GetStdOpenDBConnection(clickhouse.Native, nil, nil, nil)
require.NoError(t, err)
t.Cleanup(func() { _ = db.Close() })

ctx := context.Background()
require.NoError(t, db.PingContext(ctx))

stmt, err := db.PrepareContext(ctx, "SELECT ? + ?")
require.NoError(t, err)
t.Cleanup(func() { _ = stmt.Close() })

rows, err := stmt.QueryContext(ctx, 10, 5)
require.NoError(t, err)
t.Cleanup(func() { _ = rows.Close() })

require.True(t, rows.Next())
var sum int64
require.NoError(t, rows.Scan(&sum))
require.EqualValues(t, 15, sum)
require.NoError(t, rows.Err())
}

// Test for prepared selects using both positional and named params.
func TestStdPreparedFunds(t *testing.T) {
db, err := GetStdOpenDBConnection(clickhouse.Native, nil, nil, nil)
require.NoError(t, err)
t.Cleanup(func() { _ = db.Close() })

ctx := context.Background()
require.NoError(t, db.PingContext(ctx))

_, _ = db.ExecContext(ctx, "DROP TABLE IF EXISTS std_prepared_funds")
_, err = db.ExecContext(ctx, `
CREATE TABLE std_prepared_funds (
symbol String,
name String
) Engine = Memory`)
require.NoError(t, err)
t.Cleanup(func() { _, _ = db.ExecContext(ctx, "DROP TABLE IF EXISTS std_prepared_funds") })

_, err = db.ExecContext(ctx, `INSERT INTO std_prepared_funds (symbol, name) VALUES ('abc', 'ABC Fund')`)
require.NoError(t, err)

// q1: positional placeholder
stmt1, err := db.PrepareContext(ctx, `SELECT name FROM std_prepared_funds WHERE symbol=? LIMIT 1`)
require.NoError(t, err)
t.Cleanup(func() { _ = stmt1.Close() })
rows1, err := stmt1.QueryContext(ctx, "abc")
require.NoError(t, err)
t.Cleanup(func() { _ = rows1.Close() })
require.True(t, rows1.Next())
var name1 string
require.NoError(t, rows1.Scan(&name1))
require.Equal(t, "ABC Fund", name1)
require.NoError(t, rows1.Err())

// q2: named query parameter
stmt2, err := db.PrepareContext(ctx, `SELECT name FROM std_prepared_funds WHERE symbol={symbol: String} LIMIT 1`)
require.NoError(t, err)
t.Cleanup(func() { _ = stmt2.Close() })
rows2, err := stmt2.QueryContext(ctx, clickhouse.Named("symbol", "abc"))
require.NoError(t, err)
t.Cleanup(func() { _ = rows2.Close() })
require.True(t, rows2.Next())
var name2 string
require.NoError(t, rows2.Scan(&name2))
require.Equal(t, "ABC Fund", name2)
require.NoError(t, rows2.Err())
}
54 changes: 54 additions & 0 deletions tests/std/prepared_stmt_use_db_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// Licensed to ClickHouse, Inc. under one or more contributor
// license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright
// ownership. ClickHouse, Inc. licenses this file to you under
// the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package std

import (
"context"
"testing"

"github.com/ClickHouse/clickhouse-go/v2"
"github.com/stretchr/testify/require"
)

// Ensures we can execute a USE <db>; followed by a prepared SELECT.
func TestStdPreparedSelectWithUseDatabase(t *testing.T) {
db, err := GetStdOpenDBConnection(clickhouse.Native, nil, nil, nil)
require.NoError(t, err)
t.Cleanup(func() { _ = db.Close() })

ctx := context.Background()
require.NoError(t, db.PingContext(ctx))

// Explicit USE should work as Exec on connection
_, err = db.ExecContext(ctx, "USE default")
require.NoError(t, err)

stmt, err := db.PrepareContext(ctx, "SELECT ? + ?")
require.NoError(t, err)
t.Cleanup(func() { _ = stmt.Close() })

rows, err := stmt.QueryContext(ctx, 7, 8)
require.NoError(t, err)
t.Cleanup(func() { _ = rows.Close() })

require.True(t, rows.Next())
var sum int64
require.NoError(t, rows.Scan(&sum))
require.EqualValues(t, 15, sum)
require.NoError(t, rows.Err())
}
Loading