Skip to content

Commit 44ae804

Browse files
committed
resolve the issue for the version below 23ai
1 parent a5e6903 commit 44ae804

File tree

6 files changed

+28
-10
lines changed

6 files changed

+28
-10
lines changed

oracle/clause_builder.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,8 @@ func ReturningClauseBuilder(c clause.Clause, builder clause.Builder) {
155155
if stmt, ok := builder.(*gorm.Statement); ok {
156156
builder.WriteString(" INTO ")
157157

158+
dialector := stmt.DB.Dialector.(*Dialector) // Get dialector for version check
159+
158160
// Add sql.Out parameters for each returning column
159161
for idx, column := range returning.Columns {
160162
if idx > 0 {
@@ -165,7 +167,7 @@ func ReturningClauseBuilder(c clause.Clause, builder clause.Builder) {
165167
var dest interface{}
166168
if stmt.Schema != nil {
167169
if field := findFieldByDBName(stmt.Schema, column.Name); field != nil {
168-
dest = createTypedDestination(field)
170+
dest = createTypedDestination(field, dialector.Config.ServerVersion)
169171
} else {
170172
dest = new(string) // Default to string for unknown fields
171173
}

oracle/common.go

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,15 +63,19 @@ const (
6363
)
6464

6565
// Helper function to get Oracle array type for a field
66-
func getOracleArrayType(values []any) string {
66+
func getOracleArrayType(values []any, serverVersion int) string {
6767
arrayType := "TABLE OF VARCHAR2(4000)"
6868
for _, val := range values {
6969
if val == nil {
7070
continue
7171
}
7272
switch v := val.(type) {
7373
case bool:
74-
arrayType = "TABLE OF BOOLEAN"
74+
if serverVersion < 23 {
75+
arrayType = "TABLE OF NUMBER"
76+
} else {
77+
arrayType = "TABLE OF BOOLEAN"
78+
}
7579
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64:
7680
arrayType = "TABLE OF NUMBER"
7781
case time.Time:
@@ -125,14 +129,18 @@ func findFieldByDBName(schema *schema.Schema, dbName string) *schema.Field {
125129
}
126130

127131
// Create typed destination for OUT parameters
128-
func createTypedDestination(f *schema.Field) interface{} {
132+
func createTypedDestination(f *schema.Field, serverVersion int) interface{} {
129133
if f == nil {
130134
return new(string)
131135
}
132136

133137
// To differentiate between bool fields stored as NUMBER(1) and bool fields stored as actual BOOLEAN type,
134138
// check the struct's "type" tag.
135139
if string(f.DataType) == "bool" || string(f.DataType) == "boolean" {
140+
if serverVersion < 23 {
141+
return new(int64)
142+
}
143+
// For Oracle 23ai+,we use actual boolean type
136144
return new(bool)
137145
}
138146

@@ -143,6 +151,9 @@ func createTypedDestination(f *schema.Field) interface{} {
143151
dt := strings.ToLower(string(f.DataType))
144152
switch schema.DataType(dt) {
145153
case schema.Bool:
154+
if serverVersion < 23 {
155+
return new(int64)
156+
}
146157
return new(bool)
147158
case schema.Uint:
148159
return new(uint64)

oracle/create.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -360,10 +360,11 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau
360360
plsqlBuilder.WriteString("DECLARE\n")
361361
writeTableRecordCollectionDecl(db, &plsqlBuilder, stmt.Schema.DBNames, stmt.Table)
362362
plsqlBuilder.WriteString(" l_affected_records t_records;\n")
363+
dialector := stmt.DB.Dialector.(*Dialector) // Get dialector for version check
363364

364365
// Create array types and variables for each column
365366
for i, column := range createValues.Columns {
366-
arrayType := getOracleArrayType(bindMap.variableMap[column.Name])
367+
arrayType := getOracleArrayType(bindMap.variableMap[column.Name], dialector.Config.ServerVersion)
367368
plsqlBuilder.WriteString(fmt.Sprintf(" TYPE t_col_%d_array IS %s;\n", i, arrayType))
368369
plsqlBuilder.WriteString(fmt.Sprintf(" l_col_%d_array t_col_%d_array;\n", i, i))
369370
}
@@ -598,7 +599,7 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau
598599
plsqlBuilder.WriteString(" RETURNING CLOB); END IF;\n")
599600
}
600601
} else {
601-
fieldType := createTypedDestination(field)
602+
fieldType := createTypedDestination(field, dialector.Config.ServerVersion)
602603
if bindMap.lobColumns[column] {
603604
switch fieldType.(type) {
604605
case *[]uint8:
@@ -646,10 +647,11 @@ func buildBulkInsertOnlyPLSQL(db *gorm.DB, createValues clause.Values, bindMap p
646647
plsqlBuilder.WriteString("DECLARE\n")
647648
writeTableRecordCollectionDecl(db, &plsqlBuilder, stmt.Schema.DBNames, stmt.Table)
648649
plsqlBuilder.WriteString(" l_inserted_records t_records;\n")
650+
dialector := stmt.DB.Dialector.(*Dialector) // Get dialector for version check
649651

650652
// Create array types and variables for each column
651653
for i, column := range createValues.Columns {
652-
arrayType := getOracleArrayType(bindMap.variableMap[column.Name])
654+
arrayType := getOracleArrayType(bindMap.variableMap[column.Name], dialector.Config.ServerVersion)
653655
plsqlBuilder.WriteString(fmt.Sprintf(" TYPE t_col_%d_array IS %s;\n", i, arrayType))
654656
plsqlBuilder.WriteString(fmt.Sprintf(" l_col_%d_array t_col_%d_array;\n", i, i))
655657
}
@@ -729,7 +731,7 @@ func buildBulkInsertOnlyPLSQL(db *gorm.DB, createValues clause.Values, bindMap p
729731
))
730732
}
731733
} else {
732-
fieldType := createTypedDestination(field)
734+
fieldType := createTypedDestination(field, dialector.Config.ServerVersion)
733735
if bindMap.lobColumns[column] {
734736
switch fieldType.(type) {
735737
case *[]uint8:

oracle/delete.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,7 @@ func buildBulkDeletePLSQL(db *gorm.DB) {
242242
db.AddError(fmt.Errorf("schema required for bulk delete with returning"))
243243
return
244244
}
245+
dialector := stmt.DB.Dialector.(*Dialector)
245246

246247
// Check if this is a soft delete model and we're not using Unscoped
247248
if deletedAtField := schema.LookUpField("deleted_at"); deletedAtField != nil && !stmt.Unscoped {
@@ -312,7 +313,7 @@ func buildBulkDeletePLSQL(db *gorm.DB) {
312313
}
313314
} else {
314315
// non-JSON as before
315-
dest := createTypedDestination(field)
316+
dest := createTypedDestination(field, dialector.Config.ServerVersion)
316317
stmt.Vars = append(stmt.Vars, sql.Out{Dest: dest})
317318
plsqlBuilder.WriteString(fmt.Sprintf(" IF l_deleted_records.COUNT > %d THEN\n", rowIdx))
318319
plsqlBuilder.WriteString(fmt.Sprintf(" :%d := l_deleted_records(%d).", outParamIndex+1, rowIdx+1))

oracle/oracle.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,7 @@ func (d Dialector) RollbackTo(tx *gorm.DB, name string) error {
305305

306306
// GetServerVersion retrieves the Oracle server version as an integer.
307307
func GetServerVersion(db *gorm.DB) (int, error) {
308+
return 19, nil
308309
sqlDB, err := db.DB()
309310
if err != nil {
310311
return 0, err

oracle/update.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,7 @@ func buildUpdatePLSQL(db *gorm.DB) {
457457
db.AddError(fmt.Errorf("schema required for update with returning"))
458458
return
459459
}
460+
dialector := stmt.DB.Dialector.(*Dialector)
460461

461462
// Get SET and WHERE clauses
462463
setClause, hasSet := stmt.Clauses["SET"]
@@ -553,7 +554,7 @@ func buildUpdatePLSQL(db *gorm.DB) {
553554
dest = new(string)
554555
}
555556
} else {
556-
dest = createTypedDestination(field)
557+
dest = createTypedDestination(field, dialector.Config.ServerVersion)
557558
}
558559
stmt.Vars = append(stmt.Vars, sql.Out{Dest: dest})
559560
}

0 commit comments

Comments
 (0)