Skip to content

Commit ce8f070

Browse files
authored
Merge branch 'main' into redis-trim
2 parents 57ba387 + 397766a commit ce8f070

File tree

2 files changed

+222
-15
lines changed

2 files changed

+222
-15
lines changed

state/oracledatabase/oracledatabaseaccess.go

Lines changed: 67 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,17 +22,16 @@ import (
2222
"fmt"
2323
"net/url"
2424
"strconv"
25+
"strings"
2526
"time"
2627

27-
"github.com/google/uuid"
28-
2928
"github.com/dapr/components-contrib/state"
3029
stateutils "github.com/dapr/components-contrib/state/utils"
3130
"github.com/dapr/kit/logger"
3231
"github.com/dapr/kit/metadata"
3332

34-
// Blank import for the underlying Oracle Database driver.
35-
_ "github.com/sijms/go-ora/v2"
33+
"github.com/google/uuid"
34+
goora "github.com/sijms/go-ora/v2"
3635
)
3736

3837
const (
@@ -78,23 +77,26 @@ func parseMetadata(meta map[string]string) (oracleDatabaseMetadata, error) {
7877
// Init sets up OracleDatabase connection and ensures that the state table exists.
7978
func (o *oracleDatabaseAccess) Init(ctx context.Context, metadata state.Metadata) error {
8079
meta, err := parseMetadata(metadata.Properties)
81-
o.metadata = meta
8280
if err != nil {
8381
return err
8482
}
85-
if o.metadata.ConnectionString != "" {
86-
o.connectionString = meta.ConnectionString
87-
} else {
83+
84+
o.metadata = meta
85+
86+
if o.metadata.ConnectionString == "" {
8887
o.logger.Error("Missing Oracle Database connection string")
8988
return errors.New(errMissingConnectionString)
9089
}
91-
if o.metadata.OracleWalletLocation != "" {
92-
o.connectionString += "?TRACE FILE=trace.log&SSL=enable&SSL Verify=false&WALLET=" + url.QueryEscape(o.metadata.OracleWalletLocation)
90+
91+
o.connectionString, err = parseConnectionString(meta)
92+
if err != nil {
93+
o.logger.Error(err)
94+
return err
9395
}
96+
9497
db, err := sql.Open("oracle", o.connectionString)
9598
if err != nil {
9699
o.logger.Error(err)
97-
98100
return err
99101
}
100102

@@ -105,12 +107,62 @@ func (o *oracleDatabaseAccess) Init(ctx context.Context, metadata state.Metadata
105107
return err
106108
}
107109

108-
err = o.ensureStateTable(o.metadata.TableName)
110+
return o.ensureStateTable(o.metadata.TableName)
111+
}
112+
113+
func parseConnectionString(meta oracleDatabaseMetadata) (string, error) {
114+
username := ""
115+
password := ""
116+
host := ""
117+
port := 0
118+
serviceName := ""
119+
query := url.Values{}
120+
options := make(map[string]string)
121+
122+
connectionStringURL, err := url.Parse(meta.ConnectionString)
109123
if err != nil {
110-
return err
124+
return "", err
111125
}
112126

113-
return nil
127+
isURL := connectionStringURL.Scheme != "" && connectionStringURL.Host != ""
128+
if isURL {
129+
username = connectionStringURL.User.Username()
130+
password, _ = connectionStringURL.User.Password()
131+
query = connectionStringURL.Query()
132+
serviceName = strings.TrimPrefix(connectionStringURL.Path, "/")
133+
if strings.Contains(connectionStringURL.Host, ":") {
134+
host = strings.Split(connectionStringURL.Host, ":")[0]
135+
} else {
136+
host = connectionStringURL.Host
137+
}
138+
} else {
139+
host = connectionStringURL.Path
140+
}
141+
142+
if connectionStringURL.Port() != "" {
143+
port, err = strconv.Atoi(connectionStringURL.Port())
144+
if err != nil {
145+
return "", err
146+
}
147+
}
148+
149+
for k, v := range query {
150+
options[k] = v[0]
151+
}
152+
153+
if meta.OracleWalletLocation != "" {
154+
options["WALLET"] = meta.OracleWalletLocation
155+
options["TRACE FILE"] = "trace.log"
156+
options["SSL"] = "enable"
157+
options["SSL Verify"] = "false"
158+
}
159+
160+
if strings.Contains(host, "(DESCRIPTION") {
161+
// the connection string is a URL that contains the descriptor and authentication info
162+
return goora.BuildJDBC(username, password, host, options), nil
163+
} else {
164+
return goora.BuildUrl(host, port, serviceName, username, password, options), nil
165+
}
114166
}
115167

116168
// Set makes an insert or update to the database.
@@ -170,7 +222,7 @@ func (o *oracleDatabaseAccess) doSet(ctx context.Context, db querier, req *state
170222
if req.Options.Concurrency == state.FirstWrite {
171223
stmt = `INSERT INTO ` + o.metadata.TableName + `
172224
(key, value, binary_yn, etag, expiration_time)
173-
VALUES
225+
VALUES
174226
(:key, :value, :binary_yn, :etag, ` + ttlStatement + `) `
175227
} else {
176228
// As per Discord Thread https://discord.com/channels/778680217417809931/901141713089863710/938520959562952735 expiration time is reset in case of an update.
Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
/*
2+
Copyright 2025 The Dapr Authors
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
*/
15+
16+
package oracledatabase
17+
18+
import (
19+
"net/url"
20+
"testing"
21+
22+
"github.com/dapr/components-contrib/metadata"
23+
"github.com/dapr/components-contrib/state"
24+
25+
"github.com/stretchr/testify/assert"
26+
"github.com/stretchr/testify/require"
27+
)
28+
29+
func TestConnectionString(t *testing.T) {
30+
tests := []struct {
31+
name string
32+
metadata map[string]string
33+
expectedConn string
34+
expectedError string
35+
withWallet bool
36+
walletLocation string
37+
}{
38+
{
39+
name: "Simple URL format",
40+
metadata: map[string]string{
41+
"connectionString": "oracle://system:pass@localhost:1521/FREEPDB1",
42+
},
43+
expectedConn: "oracle://system:pass@localhost:1521/FREEPDB1?",
44+
},
45+
{
46+
name: "Pure descriptor format",
47+
metadata: map[string]string{
48+
"connectionString": "(DESCRIPTION=(ADDRESS=(PROTOCOL=TCP)(HOST=localhost)(PORT=1521))(CONNECT_DATA=(SERVICE_NAME=FREEPDB1)))",
49+
},
50+
expectedConn: "oracle://:@:0/?connStr=%28DESCRIPTION%3D%28ADDRESS%3D%28PROTOCOL%3DTCP%29%28HOST%3Dlocalhost%29%28PORT%3D1521%29%29%28CONNECT_DATA%3D%28SERVICE_NAME%3DFREEPDB1%29%29%29",
51+
},
52+
{
53+
name: "URL with descriptor format",
54+
metadata: map[string]string{
55+
"connectionString": "oracle://system:pass@(DESCRIPTION=(ADDRESS=(PROTOCOL=TCP)(HOST=localhost)(PORT=1521))(CONNECT_DATA=(SERVICE_NAME=FREEPDB1)))",
56+
},
57+
expectedConn: "oracle://system:pass@:0/?connStr=%28DESCRIPTION%3D%28ADDRESS%3D%28PROTOCOL%3DTCP%29%28HOST%3Dlocalhost%29%28PORT%3D1521%29%29%28CONNECT_DATA%3D%28SERVICE_NAME%3DFREEPDB1%29%29%29",
58+
},
59+
{
60+
name: "Complex descriptor with load balancing and failover",
61+
metadata: map[string]string{
62+
"connectionString": "(DESCRIPTION=(CONNECT_TIMEOUT=30)(RETRY_COUNT=20)(RETRY_DELAY=3)(FAILOVER=ON)(LOAD_BALANCE=OFF)(ADDRESS_LIST=(LOAD_BALANCE=ON)(ADDRESS=(PROTOCOL=TCP)(HOST=db1.example.com)(PORT=1521)))(ADDRESS_LIST=(LOAD_BALANCE=ON)(ADDRESS=(PROTOCOL=TCP)(HOST=db2.example.com)(PORT=1521)))(CONNECT_DATA=(SERVICE_NAME=FREEPDB1_service)))",
63+
},
64+
expectedConn: "oracle://:@:0/?connStr=%28DESCRIPTION%3D%28CONNECT_TIMEOUT%3D30%29%28RETRY_COUNT%3D20%29%28RETRY_DELAY%3D3%29%28FAILOVER%3DON%29%28LOAD_BALANCE%3DOFF%29%28ADDRESS_LIST%3D%28LOAD_BALANCE%3DON%29%28ADDRESS%3D%28PROTOCOL%3DTCP%29%28HOST%3Ddb1.example.com%29%28PORT%3D1521%29%29%29%28ADDRESS_LIST%3D%28LOAD_BALANCE%3DON%29%28ADDRESS%3D%28PROTOCOL%3DTCP%29%28HOST%3Ddb2.example.com%29%28PORT%3D1521%29%29%29%28CONNECT_DATA%3D%28SERVICE_NAME%3DFREEPDB1_service%29%29%29",
65+
},
66+
{
67+
name: "Simple URL with wallet",
68+
metadata: map[string]string{
69+
"connectionString": "oracle://system:pass@localhost:1521/service",
70+
"oracleWalletLocation": "/path/to/wallet",
71+
},
72+
withWallet: true,
73+
walletLocation: "/path/to/wallet",
74+
expectedConn: "oracle://system:pass@localhost:1521/service?WALLET=%2Fpath%2Fto%2Fwallet&TRACE FILE=trace.log&SSL=enable&SSL Verify=false",
75+
},
76+
{
77+
name: "Descriptor with wallet",
78+
metadata: map[string]string{
79+
"connectionString": "(DESCRIPTION=(ADDRESS=(PROTOCOL=TCP)(HOST=test.example.com)(PORT=1521))(CONNECT_DATA=(SERVICE_NAME=FREEPDB1)))",
80+
"oracleWalletLocation": "/path/to/wallet",
81+
},
82+
withWallet: true,
83+
walletLocation: "/path/to/wallet",
84+
expectedConn: "oracle://:@:0/?WALLET=%2Fpath%2Fto%2Fwallet&TRACE FILE=trace.log&SSL=enable&SSL Verify=false&connStr=%28DESCRIPTION%3D%28ADDRESS%3D%28PROTOCOL%3DTCP%29%28HOST%3Dtest.example.com%29%28PORT%3D1521%29%29%28CONNECT_DATA%3D%28SERVICE_NAME%3DFREEPDB1%29%29%29",
85+
},
86+
{
87+
name: "URL with descriptor and existing parameters",
88+
metadata: map[string]string{
89+
"connectionString": "oracle://system:pass@(DESCRIPTION=(ADDRESS=(PROTOCOL=TCP)(HOST=test.example.com)(PORT=1521))(CONNECT_DATA=(SERVICE_NAME=FREEPDB1)))?param1=value1",
90+
"oracleWalletLocation": "/path/to/wallet",
91+
},
92+
withWallet: true,
93+
walletLocation: "/path/to/wallet",
94+
expectedConn: "oracle://system:pass@:0/?param1=value1&TRACE FILE=trace.log&SSL=enable&SSL Verify=false&WALLET=%2Fpath%2Fto%2Fwallet&connStr=%28DESCRIPTION%3D%28ADDRESS%3D%28PROTOCOL%3DTCP%29%28HOST%3Dtest.example.com%29%28PORT%3D1521%29%29%28CONNECT_DATA%3D%28SERVICE_NAME%3DFREEPDB1%29%29%29",
95+
},
96+
{
97+
name: "Compressed descriptor format",
98+
metadata: map[string]string{
99+
"connectionString": "(DESCRIPTION=(CONNECT_TIMEOUT=90)(RETRY_COUNT=20)(RETRY_DELAY=3)(TRANSPORT_CONNECT_TIMEOUT=3)(ADDRESS=(PROTOCOL=TCP)(HOST=db.example.com)(PORT=1521))(CONNECT_DATA=(SERVICE_NAME=MYSERVICE)))",
100+
},
101+
expectedConn: "oracle://@:0/?connStr=%28DESCRIPTION%3D%28CONNECT_TIMEOUT%3D90%29%28RETRY_COUNT%3D20%29%28RETRY_DELAY%3D3%29%28TRANSPORT_CONNECT_TIMEOUT%3D3%29%28ADDRESS%3D%28PROTOCOL%3DTCP%29%28HOST%3Ddb.example.com%29%28PORT%3D1521%29%29%28CONNECT_DATA%3D%28SERVICE_NAME%3DMYSERVICE%29%29%29",
102+
},
103+
}
104+
105+
for _, tt := range tests {
106+
t.Run(tt.name, func(t *testing.T) {
107+
// Create metadata
108+
metadata := state.Metadata{
109+
Base: metadata.Base{
110+
Properties: tt.metadata,
111+
},
112+
}
113+
114+
meta, err := parseMetadata(metadata.Properties)
115+
require.NoError(t, err)
116+
117+
actualConnectionString, err := parseConnectionString(meta)
118+
require.NoError(t, err)
119+
120+
if tt.expectedError != "" {
121+
require.Error(t, err)
122+
assert.Contains(t, err.Error(), tt.expectedError)
123+
return
124+
} else {
125+
require.NoError(t, err)
126+
}
127+
128+
expectedURL, err := url.Parse(tt.expectedConn)
129+
require.NoError(t, err)
130+
actualURL, err := url.Parse(actualConnectionString)
131+
require.NoError(t, err)
132+
133+
assert.Equal(t, expectedURL.Scheme, actualURL.Scheme)
134+
assert.Equal(t, expectedURL.Host, actualURL.Host)
135+
assert.Equal(t, expectedURL.Path, actualURL.Path)
136+
assert.Equal(t, expectedURL.User.Username(), actualURL.User.Username())
137+
ep, _ := expectedURL.User.Password()
138+
ap, _ := actualURL.User.Password()
139+
assert.Equal(t, ep, ap)
140+
141+
query, err := url.ParseQuery(expectedURL.RawQuery)
142+
require.NoError(t, err)
143+
144+
for k, v := range query {
145+
assert.Equal(t, v, actualURL.Query()[k])
146+
}
147+
148+
if tt.withWallet {
149+
assert.Equal(t, tt.walletLocation, meta.OracleWalletLocation)
150+
assert.Contains(t, actualConnectionString, "WALLET=")
151+
assert.Contains(t, actualConnectionString, "SSL=enable")
152+
}
153+
})
154+
}
155+
}

0 commit comments

Comments
 (0)