diff --git a/java/connector-node/risingwave-jdbc-runner/pom.xml b/java/connector-node/risingwave-jdbc-runner/pom.xml
index 1e421957b1e69..a48e596462198 100644
--- a/java/connector-node/risingwave-jdbc-runner/pom.xml
+++ b/java/connector-node/risingwave-jdbc-runner/pom.xml
@@ -43,11 +43,21 @@
             <artifactId>redshift-jdbc42</artifactId>
             <version>2.1.0.33</version>
         </dependency>
+        <dependency>
+            <groupId>com.risingwave</groupId>
+            <artifactId>risingwave-sink-jdbc</artifactId>
+            <version>0.1.0-SNAPSHOT</version>
+        </dependency>
         <dependency>
             <groupId>org.apache.logging.log4j</groupId>
             <artifactId>log4j-slf4j2-impl</artifactId>
             <scope>test</scope>
         </dependency>
+        <dependency>
+            <groupId>junit</groupId>
+            <artifactId>junit</artifactId>
+            <scope>test</scope>
+        </dependency>
     </dependencies>
diff --git a/java/connector-node/risingwave-jdbc-runner/src/main/java/com/risingwave/runner/JDBCSqlRunner.java b/java/connector-node/risingwave-jdbc-runner/src/main/java/com/risingwave/runner/JDBCSqlRunner.java
index b4337d44ba8d7..4848130efcaa5 100644
--- a/java/connector-node/risingwave-jdbc-runner/src/main/java/com/risingwave/runner/JDBCSqlRunner.java
+++ b/java/connector-node/risingwave-jdbc-runner/src/main/java/com/risingwave/runner/JDBCSqlRunner.java
@@ -16,18 +16,40 @@
 package com.risingwave.runner;

-import java.sql.*;
+import com.risingwave.connector.SnowflakeJDBCSinkConfig;
+import java.sql.Connection;
+import java.sql.DriverManager;
+import java.sql.SQLException;
+import java.sql.Statement;
+import java.util.Properties;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;

 public class JDBCSqlRunner {
     private static final Logger LOG = LoggerFactory.getLogger(JDBCSqlRunner.class);

-    public static void executeSql(String fullUrl, String[] sqls) throws Exception {
+    public static void executeSqlWithProps(
+            String fullUrl, String[] sqls, String[] propKeys, String[] propValues)
+            throws Exception {
         Connection connection = null;
         try {
             Class.forName("net.snowflake.client.jdbc.SnowflakeDriver");
-            connection = DriverManager.getConnection(fullUrl);
+            Properties props = new Properties();
+            if (propKeys != null && propValues != null) {
+                if (propKeys.length != propValues.length) {
+                    throw new IllegalArgumentException(
+                            "Property keys and values arrays must have the same length");
+                }
+                for (int i = 0; i < propKeys.length; i++) {
+                    if (propKeys[i] != null && propValues[i] != null) {
+                        props.put(propKeys[i], propValues[i]);
+                    }
+                }
+            }
+
+            SnowflakeJDBCSinkConfig.handleSnowflakeAuth(props);
+
+            connection = DriverManager.getConnection(fullUrl, props);
             connection.setAutoCommit(false);
             LOG.info("[JDBCRunner] Transaction started, auto-commit disabled");
             Statement stmt = connection.createStatement();
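Note on the new entry point: the two property arrays are positional pairs, so propKeys[i] belongs to propValues[i], and mismatched lengths fail fast with IllegalArgumentException. A minimal caller sketch (the URL, statements, and property values are placeholders, pemContent stands in for a PKCS#8 PEM string, and the enclosing method is assumed to declare throws Exception):

    // Keys/values are zipped into java.util.Properties, run through
    // SnowflakeJDBCSinkConfig.handleSnowflakeAuth, then handed to DriverManager.
    String url = "jdbc:snowflake://myaccount.snowflakecomputing.com"; // placeholder
    String[] sqls = {"CREATE TABLE IF NOT EXISTS t (v INT)"};
    String[] keys = {"user", "auth.method", "private_key_pem"};
    String[] values = {"RW_USER", "key_pair_object", pemContent}; // placeholders
    JDBCSqlRunner.executeSqlWithProps(url, sqls, keys, values);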
diff --git a/java/connector-node/risingwave-jdbc-runner/src/test/java/com/risingwave/runner/JDBCSqlRunnerTest.java b/java/connector-node/risingwave-jdbc-runner/src/test/java/com/risingwave/runner/JDBCSqlRunnerTest.java
new file mode 100644
index 0000000000000..439b156318bb8
--- /dev/null
+++ b/java/connector-node/risingwave-jdbc-runner/src/test/java/com/risingwave/runner/JDBCSqlRunnerTest.java
@@ -0,0 +1,46 @@
+/*
+ * Copyright 2025 RisingWave Labs
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.risingwave.runner;
+
+import static org.junit.Assert.assertNotNull;
+
+import com.risingwave.connector.SnowflakeJDBCSinkConfig;
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.charset.StandardCharsets;
+import java.security.PrivateKey;
+import org.junit.Test;
+
+public class JDBCSqlRunnerTest {
+
+    private String loadTestPem() throws IOException {
+        try (InputStream is =
+                getClass().getClassLoader().getResourceAsStream("test-private-key.pem")) {
+            if (is == null) {
+                throw new IOException("Test PEM file not found in resources");
+            }
+            return new String(is.readAllBytes(), StandardCharsets.UTF_8);
+        }
+    }
+
+    @Test
+    public void loadPrivateKeyFromPem_unencrypted() throws Exception {
+        String testPem = loadTestPem();
+        PrivateKey key = SnowflakeJDBCSinkConfig.loadPrivateKeyFromPem(testPem, null);
+        assertNotNull(key);
+    }
+}
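The bundled resource (added below) is an unencrypted PKCS#8 RSA key (a "BEGIN PRIVATE KEY" header, as opposed to "BEGIN ENCRYPTED PRIVATE KEY"), so the test could assert slightly more than non-nullness. A possible follow-up assertion, assuming a static import of org.junit.Assert.assertEquals:

    @Test
    public void loadPrivateKeyFromPem_returnsRsaKey() throws Exception {
        PrivateKey key = SnowflakeJDBCSinkConfig.loadPrivateKeyFromPem(loadTestPem(), null);
        // BouncyCastle's JcaPEMKeyConverter yields a JCA key that reports its algorithm.
        assertEquals("RSA", key.getAlgorithm());
    }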
diff --git a/java/connector-node/risingwave-jdbc-runner/src/test/resources/test-private-key.pem b/java/connector-node/risingwave-jdbc-runner/src/test/resources/test-private-key.pem
new file mode 100644
index 0000000000000..f3f0754de2553
--- /dev/null
+++ b/java/connector-node/risingwave-jdbc-runner/src/test/resources/test-private-key.pem
@@ -0,0 +1,28 @@
+-----BEGIN PRIVATE KEY-----
+MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDh6PSnttDsv+vi
+tUZTP1E3hVBah6PUGDWZhYgNiyW8quTWCmPvBmCR2YzuhUrY5+CtKP8UJOQico+p
+oJHSAPsrzSr6YsGs3c9SQOslBmm9Fkh9/f/GZVTVZ6u5AsUmOcVvZ2q7Sz8Vj/aR
+aIm0EJqRe9cQ5vvN9sg25rIv4xKwIZJ1VixKWJLmpCmDINqn7xvl+ldlUmSr3aGt
+w21uSDuEJhQlzO3yf2FwJMkJ9SkCm9oVDXyl77OnKXj5bOQ/rojbyGeIxDJSUDWE
+GKyRPuqKi6rSbwg6h2G/Z9qBJkqM5NNTbGRIFz/9/LdmmwvtaqCxlLtD7RVEryAp
++qTGDk5hAgMBAAECggEBAMYYfNDEYpf4A2SdCLne/9zrrfZ0kphdUkL48MDPj5vN
+TzTRj6f9s5ixZ/+QKn3hdwbguCx13QbH5mocP0IjUhyqoFFHYAWxyyaZfpjM8tO4
+QoEYxby3BpjLe62UXESUzChQSytJZFwIDXKcdIPNO3zvVzufEJcfG5no2b9cIvsG
+Dy6J1FNILWxCtDIqBM+G1B1is9DhZnUDgn0iKzINiZmh1I1l7k/4tMnozVIKAfwo
+f1kYjG/d2IzDM02mTeTElz3IKeNriaOIYTZgI26xLJxTkiFnBV4JOWFAZw15X+yR
++DrjGSIkTfhzbLa20Vt3AFM+LFK0ZoXT2dRnjbYPjQECgYEA+9XJFGwLcEX6pl1p
+IwXAjXKJdju9DDn4lmHTW0Pbw25h1EXONwm/NPafwsWmPll9kW9IwsxUQVUyBC9a
+c3Q7rF1e8ai/qqVFRIZof275MI82ciV2Mw8Hz7FPAUyoju5CvnjAEH4+irt1VE/7
+SgdvQ1gDBQFegS69ijdz+cOhFxkCgYEA5aVoseMy/gIlsCvNPyw9+Jz/zBpKItX0
+jGzdF7lhERRO2cursujKaoHntRckHcE3P/Z4K565bvVq+VaVG0T/BcBKPmPHrLmY
+iuVXidltW7Jh9/RCVwb5+BvqlwlC470PEwhqoUatY/fPJ74srztrqJHvp1L29FT5
+sdmlJW8YwokCgYAUa3dMgp5C0knKp5RY1KSSU5E11w4zKZgwiWob4lq1dAPWtHpO
+GCo63yyBHImoUJVP75gUw4Cpc4EEudo5tlkIVuHV8nroGVKOhd9/Rb5K47Hke4kk
+Brn5a0Use9qPDF65Fw1ryPDFSwHufjXAAO5SpZZJF51UGDgiNvDedbBgMQKBgHSk
+t7DjPhtW69234eCckD2fQS5ijBV1p2lMQmCygGM0dXiawvN02puOsCqDPoz+fxm2
+DwPY80cw0M0k9UeMnBxHt25JMDrDan/iTbxu++T/jlNrdebOXFlxlI5y3c7fULDS
+LZcNVzTXwhjlt7yp6d0NgzTyJw2ju9BiREfnTiRBAoGBAOPHrTOnPyjO+bVcCPTB
+WGLsbBd77mVPGIuL0XGrvbVYPE8yIcNbZcthd8VXL/38Ygy8SIZh2ZqsrU1b5WFa
+XUMLnGEODSS8x/GmW3i3KeirW5OxBNjfUzEF4XkJP8m41iTdsQEXQf9DdUY7X+CB
+VL5h7N0VstYhGgycuPpcIUQa
+-----END PRIVATE KEY-----
diff --git a/java/connector-node/risingwave-sink-jdbc/pom.xml b/java/connector-node/risingwave-sink-jdbc/pom.xml
index 222c792dfae8c..5ebddb5b6d161 100644
--- a/java/connector-node/risingwave-sink-jdbc/pom.xml
+++ b/java/connector-node/risingwave-sink-jdbc/pom.xml
@@ -53,6 +53,16 @@
             <version>3.23.0</version>
         </dependency>
+        <dependency>
+            <groupId>org.bouncycastle</groupId>
+            <artifactId>bcprov-jdk18on</artifactId>
+            <version>1.78</version>
+        </dependency>
+        <dependency>
+            <groupId>org.bouncycastle</groupId>
+            <artifactId>bcpkix-jdk18on</artifactId>
+            <version>1.78</version>
+        </dependency>
     </dependencies>
diff --git a/java/connector-node/risingwave-sink-jdbc/src/main/java/com/risingwave/connector/BatchAppendOnlyJDBCSink.java b/java/connector-node/risingwave-sink-jdbc/src/main/java/com/risingwave/connector/BatchAppendOnlyJDBCSink.java
index f4b8519efaf55..cbfa2beb8bcd9 100644
--- a/java/connector-node/risingwave-sink-jdbc/src/main/java/com/risingwave/connector/BatchAppendOnlyJDBCSink.java
+++ b/java/connector-node/risingwave-sink-jdbc/src/main/java/com/risingwave/connector/BatchAppendOnlyJDBCSink.java
@@ -50,13 +50,7 @@ public BatchAppendOnlyJDBCSink(JDBCSinkConfig config, TableSchema tableSchema) {
         var factory = JdbcUtils.getDialectFactory(jdbcUrl);
         this.config = config;
         try {
-            conn =
-                    JdbcUtils.getConnection(
-                            config.getJdbcUrl(),
-                            config.getUser(),
-                            config.getPassword(),
-                            config.isAutoCommit(),
-                            config.getBatchInsertRows());
+            conn = config.getConnection();
             // column name -> java.sql.Types
             Map<String, Integer> columnTypeMapping =
                     getColumnTypeMapping(conn, config.getTableName(), config.getSchemaName());
diff --git a/java/connector-node/risingwave-sink-jdbc/src/main/java/com/risingwave/connector/JDBCSink.java b/java/connector-node/risingwave-sink-jdbc/src/main/java/com/risingwave/connector/JDBCSink.java
index 299f20737351b..26759ff3a066e 100644
--- a/java/connector-node/risingwave-sink-jdbc/src/main/java/com/risingwave/connector/JDBCSink.java
+++ b/java/connector-node/risingwave-sink-jdbc/src/main/java/com/risingwave/connector/JDBCSink.java
@@ -57,13 +57,7 @@ public JDBCSink(JDBCSinkConfig config, TableSchema tableSchema) {
         this.config = config;
         try {
-            conn =
-                    JdbcUtils.getConnection(
-                            config.getJdbcUrl(),
-                            config.getUser(),
-                            config.getPassword(),
-                            config.isAutoCommit(),
-                            DUMMY_BATCH_INSERT_ROWS);
+            conn = config.getConnection();
             // Table schema has been validated before, so we get the PK from it directly
             this.pkColumnNames = tableSchema.getPrimaryKeys();
             // column name -> java.sql.Types
@@ -195,13 +189,7 @@ public boolean write(Iterable<SinkRow> rows) {
                     conn.close();
                     // create a new connection if the current connection is invalid
-                    conn =
-                            JdbcUtils.getConnection(
-                                    config.getJdbcUrl(),
-                                    config.getUser(),
-                                    config.getPassword(),
-                                    config.isAutoCommit(),
-                                    DUMMY_BATCH_INSERT_ROWS);
+                    conn = config.getConnection();
                     // reset the flag since we will retry to prepare the batch again
                     updateFlag = false;
                     jdbcStatements = new JdbcStatements(conn, config.getQueryTimeout());
diff --git a/java/connector-node/risingwave-sink-jdbc/src/main/java/com/risingwave/connector/JDBCSinkConfig.java b/java/connector-node/risingwave-sink-jdbc/src/main/java/com/risingwave/connector/JDBCSinkConfig.java
index 405a4778e49d8..f4a97223b8fbc 100644
--- a/java/connector-node/risingwave-sink-jdbc/src/main/java/com/risingwave/connector/JDBCSinkConfig.java
+++ b/java/connector-node/risingwave-sink-jdbc/src/main/java/com/risingwave/connector/JDBCSinkConfig.java
@@ -19,6 +19,8 @@
 import com.fasterxml.jackson.annotation.JsonCreator;
 import com.fasterxml.jackson.annotation.JsonProperty;
 import com.risingwave.connector.api.sink.CommonSinkConfig;
+import java.sql.Connection;
+import java.sql.SQLException;

 public class JDBCSinkConfig extends CommonSinkConfig {
     private String jdbcUrl;
@@ -103,4 +105,16 @@ public String getDatabaseName() {
     public int getBatchInsertRows() {
         return batchInsertRows;
     }
+
+    /**
+     * Creates a JDBC connection based on this configuration. Subclasses can override this method
+     * to provide specialized connection logic. The connection returned by this method is *not*
+     * autoCommit by default.
+     *
+     * @return JDBC connection
+     * @throws SQLException if connection fails
+     */
+    public Connection getConnection() throws SQLException {
+        return JdbcUtils.getConnectionDefault(this);
+    }
 }
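getConnection() is the extension seam this PR introduces: the sinks above keep calling config.getConnection() and never see dialect-specific authentication. A sketch of what a specialized subclass looks like (the class name and extra property are hypothetical; same package assumed, since JdbcUtils.createBaseProperties, introduced in JdbcUtils below, is package-private; the constructor mirrors the base @JsonCreator signature):

    // Hypothetical dialect config that injects one extra driver property.
    // SnowflakeJDBCSinkConfig later in this PR follows the same pattern.
    public class MyDialectJDBCSinkConfig extends JDBCSinkConfig {
        public MyDialectJDBCSinkConfig(String jdbcUrl, String tableName, String sinkType) {
            super(jdbcUrl, tableName, sinkType);
        }

        @Override
        public Connection getConnection() throws SQLException {
            Properties props = JdbcUtils.createBaseProperties(getJdbcUrl(), getUser());
            props.put("ssl", "true"); // illustrative extra property
            Connection conn = DriverManager.getConnection(getJdbcUrl(), props);
            conn.setAutoCommit(isAutoCommit());
            return conn;
        }
    }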
diff --git a/java/connector-node/risingwave-sink-jdbc/src/main/java/com/risingwave/connector/JDBCSinkFactory.java b/java/connector-node/risingwave-sink-jdbc/src/main/java/com/risingwave/connector/JDBCSinkFactory.java
index 71ec28e19b2e1..62dbf264ccaa5 100644
--- a/java/connector-node/risingwave-sink-jdbc/src/main/java/com/risingwave/connector/JDBCSinkFactory.java
+++ b/java/connector-node/risingwave-sink-jdbc/src/main/java/com/risingwave/connector/JDBCSinkFactory.java
@@ -35,10 +35,26 @@ public class JDBCSinkFactory implements SinkFactory {
     public static final String JDBC_URL_PROP = "jdbc.url";
     public static final String TABLE_NAME_PROP = "table.name";

+    /**
+     * Creates the appropriate config class based on the JDBC URL. Returns SnowflakeJDBCSinkConfig
+     * for Snowflake, otherwise JDBCSinkConfig.
+     *
+     * @param mapper ObjectMapper for deserialization
+     * @param tableProperties properties to deserialize
+     * @return appropriate config instance
+     */
+    private JDBCSinkConfig createConfig(ObjectMapper mapper, Map<String, String> tableProperties) {
+        String jdbcUrl = tableProperties.get(JDBC_URL_PROP);
+        if (jdbcUrl != null && jdbcUrl.startsWith("jdbc:snowflake")) {
+            return mapper.convertValue(tableProperties, SnowflakeJDBCSinkConfig.class);
+        }
+        return mapper.convertValue(tableProperties, JDBCSinkConfig.class);
+    }
+
     @Override
     public SinkWriter createWriter(TableSchema tableSchema, Map<String, String> tableProperties) {
         ObjectMapper mapper = new ObjectMapper();
-        JDBCSinkConfig config = mapper.convertValue(tableProperties, JDBCSinkConfig.class);
+        JDBCSinkConfig config = createConfig(mapper, tableProperties);
         if ((config.getJdbcUrl().startsWith("jdbc:snowflake")
                 || config.getJdbcUrl().startsWith("jdbc:redshift"))) {
             return new BatchAppendOnlyJDBCSink(config, tableSchema);
@@ -51,7 +67,7 @@ public void validate(
             TableSchema tableSchema, Map<String, String> tableProperties, SinkType sinkType) {
         ObjectMapper mapper = new ObjectMapper();
         mapper.configure(DeserializationFeature.FAIL_ON_MISSING_CREATOR_PROPERTIES, true);
-        JDBCSinkConfig config = mapper.convertValue(tableProperties, JDBCSinkConfig.class);
+        JDBCSinkConfig config = createConfig(mapper, tableProperties);

         String jdbcUrl = config.getJdbcUrl();
         String tableName = config.getTableName();
@@ -60,9 +76,7 @@ public void validate(
         Set<String> jdbcPks = new HashSet<>();
         Set<String> jdbcTableNames = new HashSet<>();

-        try (Connection conn =
-                        DriverManager.getConnection(
-                                jdbcUrl, config.getUser(), config.getPassword());
+        try (Connection conn = config.getConnection();
                 ResultSet tableNamesResultSet =
                         conn.getMetaData().getTables(null, schemaName, "%", null);
                 ResultSet columnResultSet =
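createConfig keys the concrete config class off the URL prefix before Jackson binds the map, so Snowflake-only fields such as auth.method deserialize into SnowflakeJDBCSinkConfig while every other URL keeps the base class. Roughly (placeholder values; imports elided; assumes the remaining CommonSinkConfig fields are optional under plain convertValue):

    Map<String, String> tableProperties = new HashMap<>();
    tableProperties.put("jdbc.url", "jdbc:snowflake://myaccount.snowflakecomputing.com");
    tableProperties.put("table.name", "sink_table");
    tableProperties.put("type", "append-only");
    tableProperties.put("auth.method", "key_pair_file");
    tableProperties.put("private_key_file", "/path/to/rsa_key.p8"); // placeholder path

    String jdbcUrl = tableProperties.get("jdbc.url");
    Class<? extends JDBCSinkConfig> target =
            jdbcUrl != null && jdbcUrl.startsWith("jdbc:snowflake")
                    ? SnowflakeJDBCSinkConfig.class
                    : JDBCSinkConfig.class;
    // config.getConnection() will now dispatch to the key-pair-aware override.
    JDBCSinkConfig config = new ObjectMapper().convertValue(tableProperties, target);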
diff --git a/java/connector-node/risingwave-sink-jdbc/src/main/java/com/risingwave/connector/JdbcUtils.java b/java/connector-node/risingwave-sink-jdbc/src/main/java/com/risingwave/connector/JdbcUtils.java
index 71e37d5f93b96..6234b44587648 100644
--- a/java/connector-node/risingwave-sink-jdbc/src/main/java/com/risingwave/connector/JdbcUtils.java
+++ b/java/connector-node/risingwave-sink-jdbc/src/main/java/com/risingwave/connector/JdbcUtils.java
@@ -25,8 +25,11 @@
 import java.sql.SQLException;
 import java.util.Optional;
 import java.util.Properties;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;

 public abstract class JdbcUtils {
+    private static final Logger LOG = LoggerFactory.getLogger(JdbcUtils.class);

     static final int CONNECTION_TIMEOUT = 30;
     static final int SOCKET_TIMEOUT = 300;
@@ -47,15 +50,18 @@ public static Optional<JdbcDialectFactory> getDialectFactory(String jdbcUrl) {
         }
     }

-    /** The connection returned by this method is *not* autoCommit */
-    public static Connection getConnection(
-            String jdbcUrl, String user, String password, boolean autoCommit, int batchInsertRows)
-            throws SQLException {
+    /**
+     * Creates base JDBC connection properties with common settings. This is a helper method that
+     * can be used by default and specialized connection logic.
+     *
+     * @param jdbcUrl JDBC URL to determine database-specific settings
+     * @param user Username for authentication
+     * @return Properties object with base connection settings
+     */
+    static Properties createBaseProperties(String jdbcUrl, String user) {
         var props = new Properties();
+
         // enable TCP keep alive to avoid connection closed by server
-        // both MySQL and PG support this property
-        // https://jdbc.postgresql.org/documentation/use/
-        // https://dev.mysql.com/doc/connectors/en/connector-j-connp-props-networking.html#cj-conn-prop_tcpKeepAlive
         props.setProperty("tcpKeepAlive", "true");

         // default timeout in seconds
@@ -66,20 +72,40 @@ public static Connection getConnection(
         int socketTimeout = isPg ? SOCKET_TIMEOUT : SOCKET_TIMEOUT * 1000;
         props.setProperty("connectTimeout", String.valueOf(connectTimeout));
         props.setProperty("socketTimeout", String.valueOf(socketTimeout));
+
         if (user != null) {
             props.put("user", user);
         }
-        if (password != null) {
-            props.put("password", password);
+
+        return props;
+    }
+
+    /**
+     * Creates a JDBC connection for the default configuration (password authentication). The
+     * connection returned by this method is *not* autoCommit by default.
+     *
+     * @param config JDBC sink configuration
+     * @return JDBC connection
+     * @throws SQLException if connection fails
+     */
+    static Connection getConnectionDefault(JDBCSinkConfig config) throws SQLException {
+        String jdbcUrl = config.getJdbcUrl();
+        var props = createBaseProperties(jdbcUrl, config.getUser());
+
+        // Default password authentication
+        if (config.getPassword() != null) {
+            props.put("password", config.getPassword());
         }
-        if (jdbcUrl.startsWith("jdbc:redshift") && batchInsertRows > 0) {
+
+        if (jdbcUrl.startsWith("jdbc:redshift") && config.getBatchInsertRows() > 0) {
             props.setProperty("reWriteBatchedInserts", "true");
-            props.setProperty("reWriteBatchedInsertsSize", String.valueOf(batchInsertRows));
+            props.setProperty(
+                    "reWriteBatchedInsertsSize", String.valueOf(config.getBatchInsertRows()));
         }

         var conn = DriverManager.getConnection(jdbcUrl, props);
         // disable auto commit can improve performance
-        conn.setAutoCommit(autoCommit);
+        conn.setAutoCommit(config.isAutoCommit());
         // explicitly set isolation level to RC
         conn.setTransactionIsolation(Connection.TRANSACTION_READ_COMMITTED);
         return conn;
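A detail that is easy to miss in createBaseProperties: SOCKET_TIMEOUT is 300, but the PostgreSQL driver reads socketTimeout in seconds while most other drivers expect milliseconds, hence the * 1000 branch. A test-style sketch of the resulting values (same package, since the helper is package-private; assumes the isPg check keys off the URL prefix and assertEquals comes from org.junit.Assert):

    Properties pg = JdbcUtils.createBaseProperties("jdbc:postgresql://db:5432/rw", "bob");
    Properties other = JdbcUtils.createBaseProperties("jdbc:snowflake://account", "bob");
    assertEquals("300", pg.getProperty("socketTimeout"));       // seconds for PG
    assertEquals("300000", other.getProperty("socketTimeout")); // milliseconds elsewhere
    assertEquals("true", pg.getProperty("tcpKeepAlive"));
    assertEquals("bob", pg.getProperty("user"));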
diff --git a/java/connector-node/risingwave-sink-jdbc/src/main/java/com/risingwave/connector/SnowflakeJDBCSinkConfig.java b/java/connector-node/risingwave-sink-jdbc/src/main/java/com/risingwave/connector/SnowflakeJDBCSinkConfig.java
new file mode 100644
index 0000000000000..dfce00ab369bc
--- /dev/null
+++ b/java/connector-node/risingwave-sink-jdbc/src/main/java/com/risingwave/connector/SnowflakeJDBCSinkConfig.java
@@ -0,0 +1,242 @@
+/*
+ * Copyright 2025 RisingWave Labs
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.risingwave.connector;
+
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import java.io.IOException;
+import java.io.StringReader;
+import java.security.GeneralSecurityException;
+import java.security.PrivateKey;
+import java.security.Security;
+import java.sql.Connection;
+import java.sql.DriverManager;
+import java.sql.SQLException;
+import java.util.Properties;
+import org.bouncycastle.asn1.pkcs.PrivateKeyInfo;
+import org.bouncycastle.jce.provider.BouncyCastleProvider;
+import org.bouncycastle.openssl.PEMParser;
+import org.bouncycastle.openssl.jcajce.JcaPEMKeyConverter;
+import org.bouncycastle.openssl.jcajce.JceOpenSSLPKCS8DecryptorProviderBuilder;
+import org.bouncycastle.operator.InputDecryptorProvider;
+import org.bouncycastle.operator.OperatorCreationException;
+import org.bouncycastle.pkcs.PKCS8EncryptedPrivateKeyInfo;
+import org.bouncycastle.pkcs.PKCSException;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Snowflake-specific JDBC sink configuration with support for key-pair authentication. Extends
+ * JDBCSinkConfig with Snowflake-specific authentication fields.
+ */
+public class SnowflakeJDBCSinkConfig extends JDBCSinkConfig {
+    private static final Logger LOG = LoggerFactory.getLogger(SnowflakeJDBCSinkConfig.class);
+
+    private static final String PROP_PRIVATE_KEY = "privateKey";
+    private static final String PROP_PRIVATE_KEY_PEM = "private_key_pem";
+    private static final String PROP_PRIVATE_KEY_PWD = "private_key_file_pwd";
+    private static final String PROP_AUTH_METHOD = "auth.method";
+    private static final String AUTH_METHOD_KEY_PAIR_OBJECT = "key_pair_object";
+
+    static {
+        if (Security.getProvider("BC") == null) {
+            Security.addProvider(new BouncyCastleProvider());
+        }
+    }
+
+    // Authentication method control (password | key_pair_file | key_pair_object)
+    @JsonProperty(value = "auth.method")
+    private String authMethod;
+
+    // Key-pair authentication via connection Properties (file-based)
+    @JsonProperty(value = "private_key_file")
+    private String privateKeyFile;
+
+    @JsonProperty(value = "private_key_file_pwd")
+    private String privateKeyFilePwd;
+
+    // Key-pair authentication via connection Properties (object-based, PEM content)
+    @JsonProperty(value = "private_key_pem")
+    private String privateKeyPem;
+
+    @JsonCreator
+    public SnowflakeJDBCSinkConfig(
+            @JsonProperty(value = "jdbc.url") String jdbcUrl,
+            @JsonProperty(value = "table.name") String tableName,
+            @JsonProperty(value = "type") String sinkType) {
+        super(jdbcUrl, tableName, sinkType);
+    }
+
+    public String getAuthMethod() {
+        return authMethod;
+    }
+
+    public String getPrivateKeyFile() {
+        return privateKeyFile;
+    }
+
+    public String getPrivateKeyFilePwd() {
+        return privateKeyFilePwd;
+    }
+
+    public String getPrivateKeyPem() {
+        return privateKeyPem;
+    }
+    /**
+     * Creates a Snowflake JDBC connection with support for key-pair authentication. Overrides the
+     * base implementation to handle Snowflake-specific auth methods.
+     *
+     * @return JDBC connection configured for Snowflake
+     * @throws SQLException if connection fails
+     */
+    @Override
+    public Connection getConnection() throws SQLException {
+        String jdbcUrl = getJdbcUrl();
+        // Use shared helper to create base properties
+        var props = JdbcUtils.createBaseProperties(jdbcUrl, getUser());
+
+        // Set authentication-related properties
+        if (authMethod != null) {
+            props.put(PROP_AUTH_METHOD, authMethod);
+        }
+        if (getPassword() != null) {
+            props.put("password", getPassword());
+        }
+        if (privateKeyFile != null) {
+            props.put("private_key_file", privateKeyFile);
+        }
+        if (privateKeyPem != null) {
+            props.put(PROP_PRIVATE_KEY_PEM, privateKeyPem);
+        }
+        if (privateKeyFilePwd != null && !privateKeyFilePwd.isEmpty()) {
+            props.put(PROP_PRIVATE_KEY_PWD, privateKeyFilePwd);
+        }
+
+        // Handle Snowflake-specific authentication
+        try {
+            handleSnowflakeAuth(props);
+        } catch (SQLException e) {
+            throw e;
+        } catch (Exception e) {
+            LOG.error("Failed to configure Snowflake authentication", e);
+            throw new SQLException("Failed to configure authentication: " + e.getMessage(), e);
+        }
+
+        var conn = DriverManager.getConnection(jdbcUrl, props);
+        conn.setAutoCommit(isAutoCommit());
+        conn.setTransactionIsolation(Connection.TRANSACTION_READ_COMMITTED);
+        return conn;
+    }
+
+    /**
+     * Handles Snowflake-specific authentication by processing the auth.method property and setting
+     * up appropriate authentication properties. This is the main entry point for Snowflake
+     * authentication configuration.
+     *
+     * <p>Supported authentication methods: - password: Standard username/password authentication -
+     * key_pair_file: Key-pair authentication using a private key file path - key_pair_object:
+     * Key-pair authentication using PEM content (converted to PrivateKey object)
+     *
+     * @param props Properties to configure (will be modified in-place). Expected properties:
+     *     auth.method, password, private_key_file, private_key_pem, private_key_file_pwd
+     * @throws IOException if PEM parsing fails
+     * @throws GeneralSecurityException if key conversion fails
+     * @throws OperatorCreationException if decryption fails
+     * @throws PKCSException if PKCS decryption fails
+     * @throws SQLException if authentication setup fails
+     */
+    public static void handleSnowflakeAuth(Properties props)
+            throws IOException,
+                    GeneralSecurityException,
+                    OperatorCreationException,
+                    PKCSException,
+                    SQLException {
+        String authMethod = props.getProperty(PROP_AUTH_METHOD);
+        if (authMethod == null) {
+            // No auth method specified, use default password authentication
+            return;
+        }
+
+        props.remove(PROP_AUTH_METHOD);
+
+        if ("password".equalsIgnoreCase(authMethod)) {
+            // Password authentication - no additional processing needed
+            // The password property is already set
+        } else if ("key_pair_file".equalsIgnoreCase(authMethod)) {
+            // File-based key-pair authentication
+            // The private_key_file and optional private_key_file_pwd properties are already set
+            props.put("authenticator", "snowflake_jwt");
+        } else if (AUTH_METHOD_KEY_PAIR_OBJECT.equalsIgnoreCase(authMethod)) {
+            // Object-based key-pair authentication - convert PEM to PrivateKey object
+            String pem = props.getProperty(PROP_PRIVATE_KEY_PEM);
+            String passphrase = props.getProperty(PROP_PRIVATE_KEY_PWD);
+
+            PrivateKey privateKey = loadPrivateKeyFromPem(pem, passphrase);
+            props.put(PROP_PRIVATE_KEY, privateKey);
+            props.remove(PROP_PRIVATE_KEY_PEM);
+            props.remove(PROP_PRIVATE_KEY_PWD);
+            props.put("authenticator", "snowflake_jwt");
+            LOG.debug("Loaded private key for Snowflake authentication");
+        }
+    }
+
+    public static PrivateKey loadPrivateKeyFromPem(String pemContent, String passphrase)
+            throws IOException, GeneralSecurityException, OperatorCreationException, PKCSException {
+        try (PEMParser parser = new PEMParser(new StringReader(pemContent))) {
+            Object parsed = parser.readObject();
+            if (parsed == null) {
+                throw new GeneralSecurityException("No key found in privateKeyPem content");
+            }
+            return convertToPrivateKey(parsed, passphrase);
+        }
+    }
+
+    public static PrivateKey convertToPrivateKey(Object parsed, String passphrase)
+            throws GeneralSecurityException, IOException, OperatorCreationException, PKCSException {
+        JcaPEMKeyConverter converter = new JcaPEMKeyConverter();
+        converter.setProvider("BC");
+
+        if (parsed instanceof PrivateKeyInfo) {
+            PrivateKeyInfo info = (PrivateKeyInfo) parsed;
+            return converter.getPrivateKey(info);
+        }
+        if (parsed instanceof PKCS8EncryptedPrivateKeyInfo) {
+            PKCS8EncryptedPrivateKeyInfo encryptedInfo = (PKCS8EncryptedPrivateKeyInfo) parsed;
+            if (passphrase == null || passphrase.isEmpty()) {
+                throw new GeneralSecurityException(
+                        "Encrypted private key provided but 'private_key_file_pwd' is missing");
+            }
+            InputDecryptorProvider decryptorProvider =
+                    new JceOpenSSLPKCS8DecryptorProviderBuilder().build(passphrase.toCharArray());
+            return converter.getPrivateKey(encryptedInfo.decryptPrivateKeyInfo(decryptorProvider));
+        }
+
+        // Some PEMs may be base64 without headers; attempt to decode as DER PKCS#8
+        if (parsed instanceof byte[]) {
+            try {
+                PrivateKeyInfo info = PrivateKeyInfo.getInstance((byte[]) parsed);
+                return converter.getPrivateKey(info);
+            } catch (Exception e) {
+                throw new GeneralSecurityException("Unsupported private key format", e);
+            }
+        }
+
+        throw new GeneralSecurityException(
+                "Unsupported private key object type: " + parsed.getClass().getName());
+    }
+}
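handleSnowflakeAuth rewrites the Properties in place for the key_pair_object path: the PEM text and passphrase are consumed, and the driver-facing PrivateKey object plus authenticator=snowflake_jwt appear instead. A small before/after sketch (pemContent is a placeholder PKCS#8 PEM string; the enclosing method is assumed to declare throws Exception):

    Properties props = new Properties();
    props.put("user", "RW_USER");
    props.put("auth.method", "key_pair_object");
    props.put("private_key_pem", pemContent); // placeholder PEM text

    SnowflakeJDBCSinkConfig.handleSnowflakeAuth(props);

    // After the call: auth.method and private_key_pem are gone, and the driver
    // receives a java.security.PrivateKey under "privateKey" plus JWT auth.
    assert props.get("privateKey") instanceof PrivateKey;
    assert "snowflake_jwt".equals(props.get("authenticator"));
    assert !props.containsKey("private_key_pem");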
diff --git a/src/connector/src/sink/jdbc_jni_client.rs b/src/connector/src/sink/jdbc_jni_client.rs
index 4f42cff584095..04df56c216d77 100644
--- a/src/connector/src/sink/jdbc_jni_client.rs
+++ b/src/connector/src/sink/jdbc_jni_client.rs
@@ -12,6 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

+use std::fmt;
+
 use anyhow::Context;
 use jni::objects::JObject;
 use risingwave_common::global_jvm::Jvm;
@@ -20,10 +22,10 @@ use risingwave_jni_core::jvm_runtime::execute_with_jni_env;

 use crate::sink::Result;

-#[derive(Debug)]
 pub struct JdbcJniClient {
     jvm: Jvm,
     jdbc_url: String,
+    driver_props: Option<Vec<(String, String)>>,
 }

 impl Clone for JdbcJniClient {
@@ -31,50 +33,94 @@ impl Clone for JdbcJniClient {
         Self {
             jvm: self.jvm,
             jdbc_url: self.jdbc_url.clone(),
+            driver_props: self.driver_props.clone(),
         }
     }
 }

+impl fmt::Debug for JdbcJniClient {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("JdbcJniClient")
+            .field("jdbc_url", &self.jdbc_url)
+            .field("driver_props", &"<redacted>")
+            .finish()
+    }
+}
+
 impl JdbcJniClient {
     pub fn new(jdbc_url: String) -> Result<Self> {
         let jvm = Jvm::get_or_init()?;
-        Ok(Self { jvm, jdbc_url })
+        Ok(Self {
+            jvm,
+            jdbc_url,
+            driver_props: None,
+        })
+    }
+
+    pub fn new_with_props(jdbc_url: String, driver_props: Vec<(String, String)>) -> Result<Self> {
+        let jvm = Jvm::get_or_init()?;
+        Ok(Self {
+            jvm,
+            jdbc_url,
+            driver_props: Some(driver_props),
+        })
     }

     pub async fn execute_sql_sync(&self, sql: Vec<String>) -> anyhow::Result<()> {
+        self.execute_sql_sync_with_props(sql).await
+    }
+
+    pub async fn execute_sql_sync_with_props(&self, sql: Vec<String>) -> anyhow::Result<()> {
         let jvm = self.jvm;
         let jdbc_url = self.jdbc_url.clone();
+        let driver_props = self.driver_props.clone().unwrap_or_default();
         tokio::task::spawn_blocking(move || -> anyhow::Result<()> {
             execute_with_jni_env(jvm, |env| {
-                // get source handler by source id
-                let full_url = env.new_string(&jdbc_url).with_context(|| {
-                    format!(
-                        "Failed to create jni string from source offset: {}.",
-                        jdbc_url
-                    )
+                let j_url = env.new_string(&jdbc_url).with_context(|| {
+                    format!("Failed to create jni string from jdbc url: {}.", jdbc_url)
                 })?;

-                let props =
-                    env.new_object_array((sql.len()) as i32, "java/lang/String", JObject::null())?;
+                // SQL array
+                let sql_arr =
+                    env.new_object_array(sql.len() as i32, "java/lang/String", JObject::null())?;
+                for (i, s) in sql.iter().enumerate() {
+                    let s_j = env.new_string(s)?;
+                    env.set_object_array_element(&sql_arr, i as i32, s_j)?;
+                }

-                for (i, sql) in sql.iter().enumerate() {
-                    let sql_j_str = env.new_string(sql)?;
-                    env.set_object_array_element(&props, i as i32, sql_j_str)?;
+                // Driver properties as separate key and value arrays
+                let keys_arr = env.new_object_array(
+                    driver_props.len() as i32,
+                    "java/lang/String",
+                    JObject::null(),
+                )?;
+                let values_arr = env.new_object_array(
+                    driver_props.len() as i32,
+                    "java/lang/String",
+                    JObject::null(),
+                )?;
+                for (i, (k, v)) in driver_props.iter().enumerate() {
+                    let k_j = env.new_string(k)?;
+                    let v_j = env.new_string(v)?;
+                    env.set_object_array_element(&keys_arr, i as i32, k_j)?;
+                    env.set_object_array_element(&values_arr, i as i32, v_j)?;
                 }

                 call_static_method!(
                     env,
                     { com.risingwave.runner.JDBCSqlRunner },
-                    { void executeSql(String, String[]) },
-                    &full_url,
-                    &props
+                    { void executeSqlWithProps(String, String[], String[], String[]) },
+                    &j_url,
+                    &sql_arr,
+                    &keys_arr,
+                    &values_arr
                 )?;
                 Ok(())
             })?;
             Ok(())
         })
         .await
-        .context("Failed to execute SQL via JDBC JNI client")?
+        .context("Failed to execute SQL via JDBC JNI client with properties")?
     }
 }
diff --git a/src/connector/src/sink/snowflake_redshift/snowflake.rs b/src/connector/src/sink/snowflake_redshift/snowflake.rs
index cf99cefddef11..5f8ccc2760bc2 100644
--- a/src/connector/src/sink/snowflake_redshift/snowflake.rs
+++ b/src/connector/src/sink/snowflake_redshift/snowflake.rs
@@ -47,6 +47,11 @@ pub const SNOWFLAKE_SINK_V2: &str = "snowflake_v2";
 pub const SNOWFLAKE_SINK_ROW_ID: &str = "__row_id";
 pub const SNOWFLAKE_SINK_OP: &str = "__op";

+const AUTH_METHOD_PASSWORD: &str = "password";
+const AUTH_METHOD_KEY_PAIR_FILE: &str = "key_pair_file";
+const AUTH_METHOD_KEY_PAIR_OBJECT: &str = "key_pair_object";
+const PROP_AUTH_METHOD: &str = "auth.method";
+
 #[serde_as]
 #[derive(Debug, Clone, Deserialize, WithOptions)]
 pub struct SnowflakeV2Config {
@@ -82,6 +87,21 @@ pub struct SnowflakeV2Config {
     #[serde(rename = "password")]
     pub password: Option<String>,

+    // Authentication method control (password | key_pair_file | key_pair_object)
+    #[serde(rename = "auth.method")]
+    pub auth_method: Option<String>,
+
+    // Key-pair authentication via connection Properties (Option 2: file-based)
+    #[serde(rename = "private_key_file")]
+    pub private_key_file: Option<String>,
+
+    #[serde(rename = "private_key_file_pwd")]
+    pub private_key_file_pwd: Option<String>,
+
+    // Key-pair authentication via connection Properties (Option 1: object-based, PEM content)
+    #[serde(rename = "private_key_pem")]
+    pub private_key_pem: Option<String>,
+
     /// Commit every n(>0) checkpoints, default is 10.
     #[serde(default = "default_commit_checkpoint_interval")]
     #[serde_as(as = "DisplayFromStr")]
@@ -121,8 +141,66 @@ fn default_with_s3() -> bool {
 }

 impl SnowflakeV2Config {
+    /// Build JDBC Properties for the Snowflake JDBC connection (no URL parameters).
+    /// Returns (`jdbc_url`, `driver_properties`).
+    /// - `driver_properties` are transformed/used by the Java runner and passed to `DriverManager::getConnection(url, props)`
+    ///
+    /// Note: This method assumes the config has been validated by `from_btreemap`.
+    pub fn build_jdbc_connection_properties(&self) -> Result<(String, Vec<(String, String)>)> {
+        let jdbc_url = self
+            .jdbc_url
+            .clone()
+            .ok_or(SinkError::Config(anyhow!("jdbc.url is required")))?;
+        let username = self
+            .username
+            .clone()
+            .ok_or(SinkError::Config(anyhow!("username is required")))?;
+
+        let mut connection_properties: Vec<(String, String)> = vec![("user".to_owned(), username)];
+
+        // auth_method is guaranteed to be Some after validation in from_btreemap
+        match self.auth_method.as_deref().unwrap() {
+            AUTH_METHOD_PASSWORD => {
+                // password is guaranteed to exist by from_btreemap validation
+                connection_properties.push(("password".to_owned(), self.password.clone().unwrap()));
+            }
+            AUTH_METHOD_KEY_PAIR_FILE => {
+                // private_key_file is guaranteed to exist by from_btreemap validation
+                connection_properties.push((
+                    "private_key_file".to_owned(),
+                    self.private_key_file.clone().unwrap(),
+                ));
+                if let Some(pwd) = self.private_key_file_pwd.clone() {
+                    connection_properties.push(("private_key_file_pwd".to_owned(), pwd));
+                }
+            }
+            AUTH_METHOD_KEY_PAIR_OBJECT => {
+                connection_properties.push((
+                    PROP_AUTH_METHOD.to_owned(),
+                    AUTH_METHOD_KEY_PAIR_OBJECT.to_owned(),
+                ));
+                // private_key_pem is guaranteed to exist by from_btreemap validation
+                connection_properties.push((
+                    "private_key_pem".to_owned(),
+                    self.private_key_pem.clone().unwrap(),
+                ));
+                if let Some(pwd) = self.private_key_file_pwd.clone() {
+                    connection_properties.push(("private_key_file_pwd".to_owned(), pwd));
+                }
+            }
+            _ => {
+                // This should never happen since from_btreemap validates auth_method
+                unreachable!(
+                    "Invalid auth_method - should have been caught during config validation"
+                )
+            }
+        }
+
+        Ok((jdbc_url, connection_properties))
+    }
+
     pub fn from_btreemap(properties: &BTreeMap<String, String>) -> Result<Self> {
-        let config =
+        let mut config =
             serde_json::from_value::<SnowflakeV2Config>(serde_json::to_value(properties).unwrap())
                 .map_err(|e| SinkError::Config(anyhow!(e)))?;
         if config.r#type != SINK_TYPE_APPEND_ONLY && config.r#type != SINK_TYPE_UPSERT {
@@ -133,6 +211,87 @@ impl SnowflakeV2Config {
                 SINK_TYPE_UPSERT
             )));
         }
+
+        // Normalize and validate authentication method
+        let has_password = config.password.is_some();
+        let has_file = config.private_key_file.is_some();
+        let has_pem = config.private_key_pem.as_deref().is_some();
+
+        let normalized_auth_method = match config
+            .auth_method
+            .as_deref()
+            .map(|s| s.trim().to_ascii_lowercase())
+        {
+            Some(method) if method == AUTH_METHOD_PASSWORD => {
+                if !has_password {
+                    return Err(SinkError::Config(anyhow!(
+                        "auth.method=password requires `password`"
+                    )));
+                }
+                if has_file || has_pem {
+                    return Err(SinkError::Config(anyhow!(
+                        "auth.method=password must not set `private_key_file`/`private_key_pem`"
+                    )));
+                }
+                AUTH_METHOD_PASSWORD.to_owned()
+            }
+            Some(method) if method == AUTH_METHOD_KEY_PAIR_FILE => {
+                if !has_file {
+                    return Err(SinkError::Config(anyhow!(
+                        "auth.method=key_pair_file requires `private_key_file`"
+                    )));
+                }
+                if has_password {
+                    return Err(SinkError::Config(anyhow!(
+                        "auth.method=key_pair_file must not set `password`"
+                    )));
+                }
+                if has_pem {
+                    return Err(SinkError::Config(anyhow!(
+                        "auth.method=key_pair_file must not set `private_key_pem`"
+                    )));
+                }
+                AUTH_METHOD_KEY_PAIR_FILE.to_owned()
+            }
+            Some(method) if method == AUTH_METHOD_KEY_PAIR_OBJECT => {
+                if !has_pem {
+                    return Err(SinkError::Config(anyhow!(
+                        "auth.method=key_pair_object requires `private_key_pem`"
+                    )));
+                }
+                if has_password {
+                    return Err(SinkError::Config(anyhow!(
"auth.method=key_pair_object must not set `password`" + ))); + } + AUTH_METHOD_KEY_PAIR_OBJECT.to_owned() + } + Some(other) => { + return Err(SinkError::Config(anyhow!( + "invalid auth.method: {} (allowed: password | key_pair_file | key_pair_object)", + other + ))); + } + None => { + // Infer auth method from supplied fields + match (has_password, has_file, has_pem) { + (true, false, false) => AUTH_METHOD_PASSWORD.to_owned(), + (false, true, false) => AUTH_METHOD_KEY_PAIR_FILE.to_owned(), + (false, false, true) => AUTH_METHOD_KEY_PAIR_OBJECT.to_owned(), + (true, true, _) | (true, _, true) | (false, true, true) => { + return Err(SinkError::Config(anyhow!( + "ambiguous auth: multiple auth options provided; remove one or set `auth.method`" + ))); + } + _ => { + return Err(SinkError::Config(anyhow!( + "no authentication configured: set either `password`, or `private_key_file`, or `private_key_pem` (or provide `auth.method`)" + ))); + } + } + } + }; + config.auth_method = Some(normalized_auth_method); Ok(config) } @@ -166,20 +325,8 @@ impl SnowflakeV2Config { ..Default::default() }; - let jdbc_url = self - .jdbc_url - .clone() - .ok_or(SinkError::Config(anyhow!("jdbc.url is required")))?; - let username = self - .username - .clone() - .ok_or(SinkError::Config(anyhow!("username is required")))?; - let password = self - .password - .clone() - .ok_or(SinkError::Config(anyhow!("password is required")))?; - let jdbc_url = format!("{}?user={}&password={}", jdbc_url, username, password); - let client = JdbcJniClient::new(jdbc_url)?; + let (jdbc_url, connection_properties) = self.build_jdbc_connection_properties()?; + let client = JdbcJniClient::new_with_props(jdbc_url, connection_properties)?; if self.with_s3 { let stage = self @@ -236,6 +383,9 @@ impl EnforceSecret for SnowflakeV2Config { "username", "password", "jdbc.url", + // Key-pair authentication secrets + "private_key_file_pwd", + "private_key_pem", }; } @@ -483,7 +633,7 @@ impl SnowflakeSinkJdbcWriter { config.snowflake_cdc_table_name.clone().unwrap_or_default() ) }; - let new_properties = BTreeMap::from([ + let mut new_properties = BTreeMap::from([ ("table.name".to_owned(), full_table_name), ("connector".to_owned(), "snowflake_v2".to_owned()), ( @@ -491,14 +641,6 @@ impl SnowflakeSinkJdbcWriter { config.jdbc_url.clone().unwrap_or_default(), ), ("type".to_owned(), "append-only".to_owned()), - ( - "user".to_owned(), - config.username.clone().unwrap_or_default(), - ), - ( - "password".to_owned(), - config.password.clone().unwrap_or_default(), - ), ( "primary_key".to_owned(), properties.get("primary_key").cloned().unwrap_or_default(), @@ -512,6 +654,13 @@ impl SnowflakeSinkJdbcWriter { config.snowflake_database.clone().unwrap_or_default(), ), ]); + + // Reuse build_jdbc_connection_properties to get driver properties (auth, user, etc.) 
+        let (_jdbc_url, connection_properties) = config.build_jdbc_connection_properties()?;
+        for (key, value) in connection_properties {
+            new_properties.insert(key, value);
+        }
+
         param.properties = new_properties;

         let jdbc_sink_writer =
@@ -1015,9 +1164,83 @@ END;"#,

 #[cfg(test)]
 mod tests {
+    use std::collections::BTreeMap;
+
     use super::*;
     use crate::sink::jdbc_jni_client::normalize_sql;

+    fn base_properties() -> BTreeMap<String, String> {
+        BTreeMap::from([
+            ("type".to_owned(), "append-only".to_owned()),
+            ("jdbc.url".to_owned(), "jdbc:snowflake://account".to_owned()),
+            ("username".to_owned(), "RW_USER".to_owned()),
+        ])
+    }
+
+    #[test]
+    fn test_build_jdbc_props_password() {
+        let mut props = base_properties();
+        props.insert("password".to_owned(), "secret".to_owned());
+        let config = SnowflakeV2Config::from_btreemap(&props).unwrap();
+        let (url, connection_properties) = config.build_jdbc_connection_properties().unwrap();
+        assert_eq!(url, "jdbc:snowflake://account");
+        let map: BTreeMap<_, _> = connection_properties.into_iter().collect();
+        assert_eq!(map.get("user"), Some(&"RW_USER".to_owned()));
+        assert_eq!(map.get("password"), Some(&"secret".to_owned()));
+        assert!(!map.contains_key("authenticator"));
+    }
+
+    #[test]
+    fn test_build_jdbc_props_key_pair_file() {
+        let mut props = base_properties();
+        props.insert(
+            "auth.method".to_owned(),
+            AUTH_METHOD_KEY_PAIR_FILE.to_owned(),
+        );
+        props.insert("private_key_file".to_owned(), "/tmp/rsa_key.p8".to_owned());
+        props.insert("private_key_file_pwd".to_owned(), "dummy".to_owned());
+        let config = SnowflakeV2Config::from_btreemap(&props).unwrap();
+        let (url, connection_properties) = config.build_jdbc_connection_properties().unwrap();
+        assert_eq!(url, "jdbc:snowflake://account");
+        let map: BTreeMap<_, _> = connection_properties.into_iter().collect();
+        assert_eq!(map.get("user"), Some(&"RW_USER".to_owned()));
+        assert_eq!(
+            map.get("private_key_file"),
+            Some(&"/tmp/rsa_key.p8".to_owned())
+        );
+        assert_eq!(map.get("private_key_file_pwd"), Some(&"dummy".to_owned()));
+    }
+
+    #[test]
+    fn test_build_jdbc_props_key_pair_object() {
+        let mut props = base_properties();
+        props.insert(
+            "auth.method".to_owned(),
+            AUTH_METHOD_KEY_PAIR_OBJECT.to_owned(),
+        );
+        props.insert(
+            "private_key_pem".to_owned(),
+            "-----BEGIN PRIVATE KEY-----
+...
+-----END PRIVATE KEY-----"
+                .to_owned(),
+        );
+        let config = SnowflakeV2Config::from_btreemap(&props).unwrap();
+        let (url, connection_properties) = config.build_jdbc_connection_properties().unwrap();
+        assert_eq!(url, "jdbc:snowflake://account");
+        let map: BTreeMap<_, _> = connection_properties.into_iter().collect();
+        assert_eq!(
+            map.get("private_key_pem"),
+            Some(
+                &"-----BEGIN PRIVATE KEY-----
+...
+-----END PRIVATE KEY-----"
+                    .to_owned()
+            )
+        );
+        assert!(!map.contains_key("private_key_file"));
+    }
+
     #[test]
     fn test_snowflake_sink_commit_coordinator() {
         let snowflake_task_context = SnowflakeTaskContext {
diff --git a/src/connector/with_options_sink.yaml b/src/connector/with_options_sink.yaml
index 32c1117b3ba8f..d1add665298f7 100644
--- a/src/connector/with_options_sink.yaml
+++ b/src/connector/with_options_sink.yaml
@@ -1481,6 +1481,18 @@ SnowflakeV2Config:
   - name: password
     field_type: String
     required: false
+  - name: auth.method
+    field_type: String
+    required: false
+  - name: private_key_file
+    field_type: String
+    required: false
+  - name: private_key_file_pwd
+    field_type: String
+    required: false
+  - name: private_key_pem
+    field_type: String
+    required: false
   - name: commit_checkpoint_interval
     field_type: u64
     comments: Commit every n(>0) checkpoints, default is 10.