Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
.vscode
8 changes: 4 additions & 4 deletions oracle/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -448,19 +448,19 @@ func writeQuotedIdentifier(builder *strings.Builder, identifier string) {
// - plsqlBuilder: The builder to write the PL/SQL code into.
// - dbNames: The slice containing the column names.
// - table: The table name
func writeTableRecordCollectionDecl(plsqlBuilder *strings.Builder, dbNames []string, table string) {
func writeTableRecordCollectionDecl(db *gorm.DB, plsqlBuilder *strings.Builder, dbNames []string, table string) {
// Declare a record where each element has the same structure as a row from the given table
plsqlBuilder.WriteString(" TYPE t_record IS RECORD (\n")
for i, field := range dbNames {
if i > 0 {
plsqlBuilder.WriteString(",\n")
}
plsqlBuilder.WriteString(" ")
writeQuotedIdentifier(plsqlBuilder, field)
db.QuoteTo(plsqlBuilder, field)
plsqlBuilder.WriteString(" ")
writeQuotedIdentifier(plsqlBuilder, table)
db.QuoteTo(plsqlBuilder, table)
plsqlBuilder.WriteString(".")
writeQuotedIdentifier(plsqlBuilder, field)
db.QuoteTo(plsqlBuilder, field)
plsqlBuilder.WriteString("%TYPE")
}
plsqlBuilder.WriteString("\n")
Expand Down
46 changes: 23 additions & 23 deletions oracle/create.go
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau

// Start PL/SQL block
plsqlBuilder.WriteString("DECLARE\n")
writeTableRecordCollectionDecl(&plsqlBuilder, stmt.Schema.DBNames, stmt.Table)
writeTableRecordCollectionDecl(db, &plsqlBuilder, stmt.Schema.DBNames, stmt.Table)
plsqlBuilder.WriteString(" l_affected_records t_records;\n")

// Create array types and variables for each column
Expand Down Expand Up @@ -323,7 +323,7 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau
// FORALL with MERGE and RETURNING BULK COLLECT INTO
plsqlBuilder.WriteString(fmt.Sprintf(" FORALL i IN 1..%d\n", len(createValues.Values)))
plsqlBuilder.WriteString(" MERGE INTO ")
writeQuotedIdentifier(&plsqlBuilder, stmt.Table)
db.QuoteTo(&plsqlBuilder, stmt.Table)
plsqlBuilder.WriteString(" t\n")
// Build USING clause
plsqlBuilder.WriteString(" USING (SELECT ")
Expand All @@ -332,7 +332,7 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau
plsqlBuilder.WriteString(", ")
}
plsqlBuilder.WriteString(fmt.Sprintf("l_col_%d_array(i) AS ", idx))
writeQuotedIdentifier(&plsqlBuilder, column.Name)
db.QuoteTo(&plsqlBuilder, column.Name)
}
plsqlBuilder.WriteString(" FROM DUAL) s\n")

Expand All @@ -344,9 +344,9 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau
plsqlBuilder.WriteString(" AND ")
}
plsqlBuilder.WriteString("t.")
writeQuotedIdentifier(&plsqlBuilder, conflictCol.Name)
db.QuoteTo(&plsqlBuilder, conflictCol.Name)
plsqlBuilder.WriteString(" = s.")
writeQuotedIdentifier(&plsqlBuilder, conflictCol.Name)
db.QuoteTo(&plsqlBuilder, conflictCol.Name)
}
plsqlBuilder.WriteString(")\n")

Expand All @@ -371,9 +371,9 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau
plsqlBuilder.WriteString(", ")
}
plsqlBuilder.WriteString("t.")
writeQuotedIdentifier(&plsqlBuilder, column.Name)
db.QuoteTo(&plsqlBuilder, column.Name)
plsqlBuilder.WriteString(" = s.")
writeQuotedIdentifier(&plsqlBuilder, column.Name)
db.QuoteTo(&plsqlBuilder, column.Name)
updateCount++
}
}
Expand Down Expand Up @@ -405,9 +405,9 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau
plsqlBuilder.WriteString(", ")
}
plsqlBuilder.WriteString("t.")
writeQuotedIdentifier(&plsqlBuilder, column.Name)
db.QuoteTo(&plsqlBuilder, column.Name)
plsqlBuilder.WriteString(" = s.")
writeQuotedIdentifier(&plsqlBuilder, column.Name)
db.QuoteTo(&plsqlBuilder, column.Name)
updateCount++
}
}
Expand All @@ -427,9 +427,9 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau
}
}
plsqlBuilder.WriteString(" WHEN MATCHED THEN UPDATE SET t.")
writeQuotedIdentifier(&plsqlBuilder, noopCol)
plsqlBuilder.WriteString(" = s.")
writeQuotedIdentifier(&plsqlBuilder, noopCol)
db.QuoteTo(&plsqlBuilder, noopCol)
plsqlBuilder.WriteString(" = t.")
db.QuoteTo(&plsqlBuilder, noopCol)
plsqlBuilder.WriteString("\n")
}

Expand All @@ -444,7 +444,7 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau
if insertCount > 0 {
plsqlBuilder.WriteString(", ")
}
writeQuotedIdentifier(&plsqlBuilder, column.Name)
db.QuoteTo(&plsqlBuilder, column.Name)
insertCount++
}
}
Expand All @@ -459,7 +459,7 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau
plsqlBuilder.WriteString(", ")
}
plsqlBuilder.WriteString("s.")
writeQuotedIdentifier(&plsqlBuilder, column.Name)
db.QuoteTo(&plsqlBuilder, column.Name)
insertCount++
}
}
Expand All @@ -475,7 +475,7 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau
if insertCount > 0 {
plsqlBuilder.WriteString(", ")
}
writeQuotedIdentifier(&plsqlBuilder, column.Name)
db.QuoteTo(&plsqlBuilder, column.Name)
insertCount++
}
}
Expand All @@ -489,7 +489,7 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau
plsqlBuilder.WriteString(", ")
}
plsqlBuilder.WriteString("s.")
writeQuotedIdentifier(&plsqlBuilder, column.Name)
db.QuoteTo(&plsqlBuilder, column.Name)
insertCount++
}
}
Expand All @@ -503,7 +503,7 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau
if i > 0 {
plsqlBuilder.WriteString(", ")
}
writeQuotedIdentifier(&plsqlBuilder, column)
db.QuoteTo(&plsqlBuilder, column)
}
plsqlBuilder.WriteString("\n BULK COLLECT INTO l_affected_records;\n")

Expand All @@ -514,7 +514,7 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau
if field := findFieldByDBName(schema, column); field != nil {
stmt.Vars = append(stmt.Vars, sql.Out{Dest: createTypedDestination(field)})
plsqlBuilder.WriteString(fmt.Sprintf(" IF l_affected_records.COUNT > %d THEN :%d := l_affected_records(%d).", rowIdx, outParamIndex+1, rowIdx+1))
writeQuotedIdentifier(&plsqlBuilder, column)
db.QuoteTo(&plsqlBuilder, column)
plsqlBuilder.WriteString("; END IF;\n")
outParamIndex++
}
Expand Down Expand Up @@ -548,7 +548,7 @@ func buildBulkInsertOnlyPLSQL(db *gorm.DB, createValues clause.Values) {

// Start PL/SQL block
plsqlBuilder.WriteString("DECLARE\n")
writeTableRecordCollectionDecl(&plsqlBuilder, stmt.Schema.DBNames, stmt.Table)
writeTableRecordCollectionDecl(db, &plsqlBuilder, stmt.Schema.DBNames, stmt.Table)
plsqlBuilder.WriteString(" l_inserted_records t_records;\n")

// Create array types and variables for each column
Expand Down Expand Up @@ -582,14 +582,14 @@ func buildBulkInsertOnlyPLSQL(db *gorm.DB, createValues clause.Values) {
// FORALL with RETURNING BULK COLLECT INTO
plsqlBuilder.WriteString(fmt.Sprintf(" FORALL i IN 1..%d\n", len(createValues.Values)))
plsqlBuilder.WriteString(" INSERT INTO ")
writeQuotedIdentifier(&plsqlBuilder, stmt.Table)
db.QuoteTo(&plsqlBuilder, stmt.Table)
plsqlBuilder.WriteString(" (")
// Add column names
for i, column := range createValues.Columns {
if i > 0 {
plsqlBuilder.WriteString(", ")
}
writeQuotedIdentifier(&plsqlBuilder, column.Name)
db.QuoteTo(&plsqlBuilder, column.Name)
}
plsqlBuilder.WriteString(") VALUES (")

Expand All @@ -609,7 +609,7 @@ func buildBulkInsertOnlyPLSQL(db *gorm.DB, createValues clause.Values) {
if i > 0 {
plsqlBuilder.WriteString(", ")
}
writeQuotedIdentifier(&plsqlBuilder, column)
db.QuoteTo(&plsqlBuilder, column)
}
plsqlBuilder.WriteString("\n BULK COLLECT INTO l_inserted_records;\n")

Expand All @@ -618,7 +618,7 @@ func buildBulkInsertOnlyPLSQL(db *gorm.DB, createValues clause.Values) {
for rowIdx := 0; rowIdx < len(createValues.Values); rowIdx++ {
for _, column := range allColumns {
var columnBuilder strings.Builder
writeQuotedIdentifier(&columnBuilder, column)
db.QuoteTo(&columnBuilder, column)
quotedColumn := columnBuilder.String()

if field := findFieldByDBName(schema, column); field != nil {
Expand Down
16 changes: 8 additions & 8 deletions oracle/delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -255,13 +255,13 @@ func buildBulkDeletePLSQL(db *gorm.DB) {

// Start PL/SQL block
plsqlBuilder.WriteString("DECLARE\n")
writeTableRecordCollectionDecl(&plsqlBuilder, stmt.Schema.DBNames, stmt.Table)
writeTableRecordCollectionDecl(db, &plsqlBuilder, stmt.Schema.DBNames, stmt.Table)
plsqlBuilder.WriteString(" l_deleted_records t_records;\n")
plsqlBuilder.WriteString("BEGIN\n")

// Build DELETE statement
plsqlBuilder.WriteString(" DELETE FROM ")
writeQuotedIdentifier(&plsqlBuilder, stmt.Table)
db.QuoteTo(&plsqlBuilder, stmt.Table)

// Add WHERE clause if it exists
if whereClause, hasWhere := stmt.Clauses["WHERE"]; hasWhere {
Expand All @@ -278,7 +278,7 @@ func buildBulkDeletePLSQL(db *gorm.DB) {
if i > 0 {
plsqlBuilder.WriteString(", ")
}
writeQuotedIdentifier(&plsqlBuilder, column)
db.QuoteTo(&plsqlBuilder, column)

}
plsqlBuilder.WriteString("\n BULK COLLECT INTO l_deleted_records;\n")
Expand All @@ -297,7 +297,7 @@ func buildBulkDeletePLSQL(db *gorm.DB) {

plsqlBuilder.WriteString(fmt.Sprintf(" IF l_deleted_records.COUNT > %d THEN\n", rowIdx))
plsqlBuilder.WriteString(fmt.Sprintf(" :%d := l_deleted_records(%d).", outParamIndex+1, rowIdx+1))
writeQuotedIdentifier(&plsqlBuilder, column)
db.QuoteTo(&plsqlBuilder, column)
plsqlBuilder.WriteString(";\n")
plsqlBuilder.WriteString(" END IF;\n")
outParamIndex++
Expand All @@ -324,9 +324,9 @@ func buildWhereClause(db *gorm.DB, plsqlBuilder *strings.Builder, expressions []
case clause.Eq:
// Write the column name
if columnName, ok := e.Column.(string); ok {
writeQuotedIdentifier(plsqlBuilder, columnName)
db.QuoteTo(plsqlBuilder, columnName)
} else if columnExpr, ok := e.Column.(clause.Column); ok {
writeQuotedIdentifier(plsqlBuilder, columnExpr.Name)
db.QuoteTo(plsqlBuilder, columnExpr.Name)
} else {
plsqlBuilder.WriteString(fmt.Sprintf("%v", e.Column))
}
Expand All @@ -342,9 +342,9 @@ func buildWhereClause(db *gorm.DB, plsqlBuilder *strings.Builder, expressions []

case clause.IN:
if columnName, ok := e.Column.(string); ok {
writeQuotedIdentifier(plsqlBuilder, columnName)
db.QuoteTo(plsqlBuilder, columnName)
} else if columnExpr, ok := e.Column.(clause.Column); ok {
writeQuotedIdentifier(plsqlBuilder, columnExpr.Name)
db.QuoteTo(plsqlBuilder, columnExpr.Name)
} else {
plsqlBuilder.WriteString(fmt.Sprintf("%v", e.Column))
}
Expand Down
27 changes: 19 additions & 8 deletions oracle/oracle.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,14 @@ import (
_ "github.com/godror/godror"
)

const DefaultDriverName string = "godror"

type Config struct {
DriverName string
DataSourceName string
Conn *sql.DB
DefaultStringSize uint
DriverName string
DataSourceName string
Conn *sql.DB
DefaultStringSize uint
SkipQuoteIdentifiers bool
}

type Dialector struct {
Expand All @@ -79,7 +82,7 @@ func (d Dialector) Name() string {

// Open creates a new godror Dialector with the given DSN
func Open(dsn string) gorm.Dialector {
return &Dialector{Config: &Config{DriverName: "godror", DataSourceName: dsn}}
return &Dialector{Config: &Config{DataSourceName: dsn}}
}

// New creates a new Dialector with the given config
Expand All @@ -89,6 +92,10 @@ func New(config Config) gorm.Dialector {

// Initializes the database connection
func (d Dialector) Initialize(db *gorm.DB) (err error) {
if d.DriverName == "" {
d.DriverName = DefaultDriverName
}

d.DefaultStringSize = 4000

config := &callbacks.Config{
Expand Down Expand Up @@ -237,9 +244,13 @@ func (d Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v inter

// Manages quoting of identifiers
func (d Dialector) QuoteTo(writer clause.Writer, str string) {
var builder strings.Builder
writeQuotedIdentifier(&builder, str)
writer.WriteString(builder.String())
out := str
if !d.SkipQuoteIdentifiers {
var builder strings.Builder
writeQuotedIdentifier(&builder, str)
out = builder.String()
}
_, _ = writer.WriteString(out)
}

var numericPlaceholder = regexp.MustCompile(`:(\d+)`)
Expand Down
10 changes: 5 additions & 5 deletions oracle/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -476,21 +476,21 @@ func buildUpdatePLSQL(db *gorm.DB) {

// Start PL/SQL block
plsqlBuilder.WriteString("DECLARE\n")
writeTableRecordCollectionDecl(&plsqlBuilder, stmt.Schema.DBNames, stmt.Table)
writeTableRecordCollectionDecl(db, &plsqlBuilder, stmt.Schema.DBNames, stmt.Table)
plsqlBuilder.WriteString(" l_updated_records t_records;\n")
plsqlBuilder.WriteString("BEGIN\n")

// Build UPDATE statement
plsqlBuilder.WriteString(" UPDATE ")
writeQuotedIdentifier(&plsqlBuilder, stmt.Table)
db.QuoteTo(&plsqlBuilder, stmt.Table)
plsqlBuilder.WriteString(" SET ")

// Add SET assignments - handle both regular values and expressions
for i, assignment := range set {
if i > 0 {
plsqlBuilder.WriteString(", ")
}
writeQuotedIdentifier(&plsqlBuilder, assignment.Column.Name)
db.QuoteTo(&plsqlBuilder, assignment.Column.Name)
plsqlBuilder.WriteString(" = ")

// Check if the value is a clause.Expr (like gorm.Expr)
Expand Down Expand Up @@ -528,7 +528,7 @@ func buildUpdatePLSQL(db *gorm.DB) {
if i > 0 {
plsqlBuilder.WriteString(", ")
}
writeQuotedIdentifier(&plsqlBuilder, column)
db.QuoteTo(&plsqlBuilder, column)
}
plsqlBuilder.WriteString("\n BULK COLLECT INTO l_updated_records;\n")

Expand Down Expand Up @@ -559,7 +559,7 @@ func buildUpdatePLSQL(db *gorm.DB) {
// Add the assignment to PL/SQL with correct parameter reference
plsqlBuilder.WriteString(fmt.Sprintf(" IF l_updated_records.COUNT > %d THEN\n", rowIdx))
plsqlBuilder.WriteString(fmt.Sprintf(" :%d := l_updated_records(%d).", paramIndex, rowIdx+1))
writeQuotedIdentifier(&plsqlBuilder, column)
db.QuoteTo(&plsqlBuilder, column)
plsqlBuilder.WriteString(";\n")
plsqlBuilder.WriteString(" END IF;\n")
}
Expand Down
1 change: 0 additions & 1 deletion tests/.gitignore

This file was deleted.

Loading
Loading