diff --git a/README.md b/README.md index 6f70a47..f1988d7 100644 --- a/README.md +++ b/README.md @@ -54,9 +54,9 @@ An example consists of two pipelines: * Writing into Snowflake 1. Reading files from provided by `inputFile` argument. 2. Counting words - 3. Writing counts into Snowflake table provided by `tableName` argument. + 3. Writing counts into Snowflake table provided by `table` argument. * Reading from Snowflake - 1. Reading counts from Snowflake table provided by `tableName` argument. + 1. Reading counts from Snowflake table provided by `table` argument. 2. Writing counts into provided by `output` argument. #### Executing: @@ -70,7 +70,7 @@ An example consists of two pipelines: --password= \ --database= \ --schema= \ - --tableName= \ + --table= \ --storageIntegrationName= \ --stagingBucketName= \ --runner= \ diff --git a/beam-sdks-java-io-snowflake-2.22.0-SNAPSHOT.jar b/beam-sdks-java-io-snowflake-2.23.0-SNAPSHOT.jar similarity index 95% rename from beam-sdks-java-io-snowflake-2.22.0-SNAPSHOT.jar rename to beam-sdks-java-io-snowflake-2.23.0-SNAPSHOT.jar index 64e97db..340b385 100644 Binary files a/beam-sdks-java-io-snowflake-2.22.0-SNAPSHOT.jar and b/beam-sdks-java-io-snowflake-2.23.0-SNAPSHOT.jar differ diff --git a/build.gradle b/build.gradle index 02baa77..686c36b 100644 --- a/build.gradle +++ b/build.gradle @@ -22,10 +22,11 @@ repositories { dependencies { testCompile group: 'junit', name: 'junit', version: '4.12' - compile files('beam-sdks-java-io-snowflake-2.22.0-SNAPSHOT.jar') + compile files('beam-sdks-java-io-snowflake-2.23.0-SNAPSHOT.jar') compile group: 'org.apache.beam', name: 'beam-sdks-java-core', version: '2.22.0' compile group: 'org.apache.beam', name: 'beam-runners-direct-java', version: '2.22.0' compile group: 'org.apache.beam', name: 'beam-runners-google-cloud-dataflow-java', version: '2.22.0' + compile 'com.google.cloud:google-cloud-kms:1.20.0' } task execute (type:JavaExec) { diff --git a/src/main/java/batching/SnowflakeWordCountOptions.java b/src/main/java/batching/SnowflakeWordCountOptions.java index 6f98425..5aa0099 100644 --- a/src/main/java/batching/SnowflakeWordCountOptions.java +++ b/src/main/java/batching/SnowflakeWordCountOptions.java @@ -4,6 +4,7 @@ import org.apache.beam.sdk.options.Default; import org.apache.beam.sdk.options.Description; import org.apache.beam.sdk.options.Validation.Required; +import org.apache.beam.sdk.options.ValueProvider; /** * Supported PipelineOptions used in provided examples. @@ -21,4 +22,10 @@ public interface SnowflakeWordCountOptions extends SnowflakePipelineOptions { String getOutput(); void setOutput(String value); + + @Description( + "KMS Encryption Key should be in the format projects/{gcp_project}/locations/{key_region}/keyRings/{key_ring}/cryptoKeys/{kms_key_name}") + ValueProvider getKMSEncryptionKey(); + + void setKMSEncryptionKey(ValueProvider keyName); } \ No newline at end of file diff --git a/src/main/java/batching/WordCountExample.java b/src/main/java/batching/WordCountExample.java index 42d3b6e..e813c8c 100644 --- a/src/main/java/batching/WordCountExample.java +++ b/src/main/java/batching/WordCountExample.java @@ -13,6 +13,7 @@ import org.apache.beam.sdk.metrics.Distribution; import org.apache.beam.sdk.metrics.Metrics; import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.options.ValueProvider; import org.apache.beam.sdk.transforms.Count; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.MapElements; @@ -23,10 +24,11 @@ import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PDone; +import util.KMSEncryptedNestedValueProvider; /** * An example that contains batch writing and reading from Snowflake. Inspired by Apache Beam/WordCount-example(https://github.com/apache/beam/blob/master/examples/java/src/main/java/org/apache/beam/examples/WordCount.java) - * + *

* Check main README for more information. */ public class WordCountExample { @@ -98,9 +100,16 @@ private static PTransform> createSnowflakeRead public static SnowflakeIO.DataSourceConfiguration createSnowflakeConfiguration(SnowflakeWordCountOptions options) { return SnowflakeIO.DataSourceConfiguration.create() - .withUsernamePasswordAuth(options.getUsername(), options.getPassword()) + .withUsernamePasswordAuth( + maybeDecrypt(options.getUsername(), options.getKMSEncryptionKey()), + maybeDecrypt(options.getPassword(), options.getKMSEncryptionKey()) + ) .withOAuth(options.getOauthToken()) - .withKeyPairRawAuth(options.getUsername(), options.getRawPrivateKey(), options.getPrivateKeyPassphrase()) + .withKeyPairRawAuth( + maybeDecrypt(options.getUsername(), options.getKMSEncryptionKey()), + maybeDecrypt(options.getRawPrivateKey(), options.getKMSEncryptionKey()), + maybeDecrypt(options.getPrivateKeyPassphrase(), options.getKMSEncryptionKey()) + ) .withDatabase(options.getDatabase()) .withServerName(options.getServerName()) .withSchema(options.getSchema()); @@ -154,4 +163,10 @@ public void processElement(@Element String element, OutputReceiver recei } } + private static ValueProvider maybeDecrypt( + ValueProvider unencryptedValue, ValueProvider kmsKey) { + + return new KMSEncryptedNestedValueProvider(unencryptedValue, kmsKey); + } + } diff --git a/src/main/java/util/DualInputNestedValueProvider.java b/src/main/java/util/DualInputNestedValueProvider.java new file mode 100644 index 0000000..549e67a --- /dev/null +++ b/src/main/java/util/DualInputNestedValueProvider.java @@ -0,0 +1,100 @@ +/* + * Copyright (C) 2018 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. SecondTou may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANSecondT KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package util; + +import com.google.common.base.MoreObjects; +import org.apache.beam.sdk.options.ValueProvider; +import org.apache.beam.sdk.transforms.SerializableFunction; + +import java.io.Serializable; + +/** + * {@link DualInputNestedValueProvider} is an implementation of {@link ValueProvider} that allows + * for wrapping two {@link ValueProvider} objects. It's inspired by {@link + * org.apache.beam.sdk.options.ValueProvider.NestedValueProvider} but it can accept two inputs + * rather than one. + */ +public class DualInputNestedValueProvider + implements ValueProvider, Serializable { + + /** Pair like struct holding two values. */ + public static class TranslatorInput { + private final FirstT x; + private final SecondT y; + + public TranslatorInput(FirstT x, SecondT y) { + this.x = x; + this.y = y; + } + + public FirstT getX() { + return x; + } + + public SecondT getY() { + return y; + } + } + + private final ValueProvider valueX; + private final ValueProvider valueY; + private final SerializableFunction, T> translator; + private transient volatile T cachedValue; + + public DualInputNestedValueProvider( + ValueProvider valueX, + ValueProvider valueY, + SerializableFunction, T> translator) { + this.valueX = valueX; + this.valueY = valueY; + this.translator = translator; + } + + /** Creates a {@link NestedValueProvider} that wraps two provided values. */ + public static DualInputNestedValueProvider of( + ValueProvider valueX, + ValueProvider valueY, + SerializableFunction, T> translator) { + DualInputNestedValueProvider factory = + new DualInputNestedValueProvider<>(valueX, valueY, translator); + return factory; + } + + @Override + public T get() { + if (cachedValue == null) { + cachedValue = translator.apply(new TranslatorInput<>(valueX.get(), valueY.get())); + } + return cachedValue; + } + + @Override + public boolean isAccessible() { + return valueX.isAccessible() && valueY.isAccessible(); + } + + @Override + public String toString() { + if (isAccessible()) { + return String.valueOf(get()); + } + return MoreObjects.toStringHelper(this) + .add("valueX", valueX) + .add("valueY", valueY) + .add("translator", translator.getClass().getSimpleName()) + .toString(); + } +} diff --git a/src/main/java/util/KMSEncryptedNestedValueProvider.java b/src/main/java/util/KMSEncryptedNestedValueProvider.java new file mode 100644 index 0000000..a80a984 --- /dev/null +++ b/src/main/java/util/KMSEncryptedNestedValueProvider.java @@ -0,0 +1,118 @@ +/* + * Copyright (C) 2018 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. SecondTou may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANSecondT KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package util; + +import com.google.cloud.kms.v1.DecryptResponse; +import com.google.cloud.kms.v1.KeyManagementServiceClient; +import com.google.protobuf.ByteString; +import org.apache.beam.sdk.options.ValueProvider; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.Base64; +import java.util.regex.Pattern; + + +/** + * {@link KMSEncryptedNestedValueProvider} is a subclass of {@link DualInputNestedValueProvider} + * that allows for taking two {@link ValueProvider} objects - one as an encrypted string and the + * other as a KMS encryption key. If no encryption key is passed, the string is returned, else + * the encryption key is used to decrypt the encrypted string. + */ +public class KMSEncryptedNestedValueProvider + extends DualInputNestedValueProvider { + private static final Pattern KEYNAME_PATTERN = + Pattern.compile( + "projects/([^/]+)/locations/([a-zA-Z0-9_-]{1,63})/keyRings/" + + "[a-zA-Z0-9_-]{1,63}/cryptoKeys/[a-zA-Z0-9_-]{1,63}"); + + /** The log to output status messages to. */ + private static final Logger LOG = LoggerFactory.getLogger(KMSEncryptedNestedValueProvider.class); + + private static class KmsTranslatorInput + implements SerializableFunction, String> { + private KmsTranslatorInput() {} + + public static KmsTranslatorInput of() { + return new KmsTranslatorInput(); + } + + @Override + public String apply(TranslatorInput input) { + String decrypted; + String unencrypted; + String kmsKey; + + unencrypted = input.getX(); + kmsKey = input.getY(); + + if (kmsKey == null || unencrypted.isEmpty()) { + LOG.info("KMS Key is not specified. Using: " + unencrypted); + return unencrypted; + } else if (!testkmsKey(kmsKey)) { + IllegalArgumentException exception = + new IllegalArgumentException("Provided KMS Key %s is invalid"); + throw new RuntimeException(exception); + } else { + try { + decrypted = decryptWithKMS(unencrypted /*value*/, kmsKey /*key*/); + } catch (IOException e) { + throw new RuntimeException(e); + } + return decrypted; + } + } + } + + /** Creates a {@link KMSEncryptedNestedValueProvider} that wraps + * the key and the encrypted value. + */ + public KMSEncryptedNestedValueProvider(ValueProvider value, ValueProvider key) { + super(value, key, KmsTranslatorInput.of()); + } + + private static boolean testkmsKey(String kmsKey) { + return KEYNAME_PATTERN.matcher(kmsKey).matches(); + } + + /** Uses the GCP KMS client to decrypt an encrypted value using a KMS key of the form + * projects/{gcp_project}/locations/{key_region}/keyRings/{key_ring}/cryptoKeys/{kms_key_name} + * The encrypted value should be a base64 encrypted string which has been encrypted using + * the KMS encrypt API call. + * See + * this KMS API Encrypt Link. + */ + private static String decryptWithKMS(String encryptedValue, String kmsKey) throws IOException { + /* + kmsKey should be in the following format: + projects/{gcp_project}/locations/{key_region}/keyRings/{key_ring}/cryptoKeys/{kms_key_name} + */ + + byte[] cipherText = Base64.getDecoder().decode(encryptedValue.getBytes("UTF-8")); + + + try (KeyManagementServiceClient client = KeyManagementServiceClient.create()) { + + // Decrypt the ciphertext with Cloud KMS. + DecryptResponse response = client.decrypt(kmsKey, ByteString.copyFrom(cipherText)); + + // Extract the plaintext from the response. + return new String(response.getPlaintext().toByteArray()); + } + } +}