Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion sdk/src/main/java/io/opentdf/platform/sdk/ECKeyPair.java
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,6 @@ public static ECPrivateKey privateKeyFromPem(String pemEncoding) {
parser.close();

JcaPEMKeyConverter converter = new JcaPEMKeyConverter().setProvider(BOUNCY_CASTLE_PROVIDER);
;
return (ECPrivateKey) converter.getPrivateKey(privateKeyInfo);
} catch (IOException e) {
throw new RuntimeException(e);
Expand Down
52 changes: 39 additions & 13 deletions sdk/src/main/java/io/opentdf/platform/sdk/NanoTDF.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;

import io.opentdf.platform.wellknownconfiguration.WellKnownServiceClientInterface;
import org.bouncycastle.jce.interfaces.ECPublicKey;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -65,6 +66,32 @@ public InvalidNanoTDFConfig(String errorMessage) {
}
}

private static Optional<Config.KASInfo> getBaseKey(WellKnownServiceClientInterface wellKnownService) {
var key = Planner.fetchBaseKey(wellKnownService);
key.ifPresent(k -> {
if (!KeyType.fromAlgorithm(k.getPublicKey().getAlgorithm()).isEc()) {
throw new SDKException(String.format("base key is not an EC key, cannot create NanoTDF using a key of type %s",
k.getPublicKey().getAlgorithm()));
}
});

return key.map(Config.KASInfo::fromSimpleKasKey);
}

private Optional<Config.KASInfo> getKasInfo(Config.NanoTDFConfig nanoTDFConfig) {
if (nanoTDFConfig.kasInfoList.isEmpty()) {
logger.debug("no kas info provided in NanoTDFConfig");
return Optional.empty();
}
Config.KASInfo kasInfo = nanoTDFConfig.kasInfoList.get(0);
String url = kasInfo.URL;
if (kasInfo.PublicKey == null || kasInfo.PublicKey.isEmpty()) {
logger.info("no public key provided for KAS at {}, retrieving", url);
kasInfo = services.kas().getECPublicKey(kasInfo, nanoTDFConfig.eccMode.getCurve());
}
return Optional.of(kasInfo);
}

private Config.HeaderInfo getHeaderInfo(Config.NanoTDFConfig nanoTDFConfig) throws InvalidNanoTDFConfig, UnsupportedNanoTDFFeature {
if (nanoTDFConfig.collectionConfig.useCollection) {
Config.HeaderInfo headerInfo = nanoTDFConfig.collectionConfig.getHeaderInfo();
Expand All @@ -74,19 +101,20 @@ private Config.HeaderInfo getHeaderInfo(Config.NanoTDFConfig nanoTDFConfig) thro
}

Gson gson = new GsonBuilder().create();
if (nanoTDFConfig.kasInfoList.isEmpty()) {
throw new InvalidNanoTDFConfig("kas url is missing");
Optional<Config.KASInfo> maybeKas = getKasInfo(nanoTDFConfig).or(() -> NanoTDF.getBaseKey(services.wellknown()));
if (maybeKas.isEmpty()) {
throw new SDKException("no KAS info provided and couldn't get base key, cannot create NanoTDF");
}

Config.KASInfo kasInfo = nanoTDFConfig.kasInfoList.get(0);
Config.KASInfo kasInfo = maybeKas.get();
String url = kasInfo.URL;
if (kasInfo.PublicKey == null || kasInfo.PublicKey.isEmpty()) {
logger.info("no public key provided for KAS at {}, retrieving", url);
kasInfo = services.kas().getECPublicKey(kasInfo, nanoTDFConfig.eccMode.getCurve());
}

// Kas url resource locator
ResourceLocator kasURL = new ResourceLocator(nanoTDFConfig.kasInfoList.get(0).URL, kasInfo.KID);
ResourceLocator kasURL = new ResourceLocator(kasInfo.URL, kasInfo.KID);
assert kasURL.getIdentifier() != null : "Identifier in ResourceLocator cannot be null";

NanoTDFType.ECCurve ecCurve = getEcCurve(nanoTDFConfig, kasInfo);
Expand Down Expand Up @@ -139,12 +167,10 @@ private Config.HeaderInfo getHeaderInfo(Config.NanoTDFConfig nanoTDFConfig) thro
// Create header
byte[] compressedPubKey = keyPair.compressECPublickey();
Header header = new Header();
ECCMode mode;
if (nanoTDFConfig.eccMode.getCurve() != keyPair.getCurve()) {
mode = new ECCMode(nanoTDFConfig.eccMode.getECCModeAsByte());
mode.setEllipticCurve(keyPair.getCurve());
} else {
mode = nanoTDFConfig.eccMode;
ECCMode mode = new ECCMode();
mode.setEllipticCurve(keyPair.getCurve());
if (logger.isWarnEnabled() && !nanoTDFConfig.eccMode.equals(mode)) {
logger.warn("ECC mode provided in NanoTDFConfig: {}, ECC mode from key: {}", nanoTDFConfig.eccMode.getCurve(), mode.getCurve());
}
header.setECCMode(mode);
header.setPayloadConfig(nanoTDFConfig.config);
Expand All @@ -169,10 +195,10 @@ private static NanoTDFType.ECCurve getEcCurve(Config.NanoTDFConfig nanoTDFConfig
logger.info("no curve specified in KASInfo, using the curve from config [{}]", nanoTDFConfig.eccMode.getCurve());
ecCurve = nanoTDFConfig.eccMode.getCurve();
} else {
if (specifiedCurve.get() != nanoTDFConfig.eccMode.getCurve()) {
logger.warn("ECCurve in NanoTDFConfig [{}] does not match the curve in KASInfo, using KASInfo curve [{}]", nanoTDFConfig.eccMode.getCurve(), specifiedCurve);
}
ecCurve = specifiedCurve.get();
if (ecCurve != nanoTDFConfig.eccMode.getCurve()) {
logger.warn("ECCurve in NanoTDFConfig [{}] does not match the curve in KASInfo, using KASInfo curve [{}]", nanoTDFConfig.eccMode.getCurve(), ecCurve);
}
}
return ecCurve;
}
Expand Down
75 changes: 61 additions & 14 deletions sdk/src/test/java/io/opentdf/platform/sdk/NanoTDFTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import com.connectrpc.ResponseMessage;
import com.connectrpc.UnaryBlockingCall;
import com.google.protobuf.Struct;
import com.google.protobuf.Value;
import io.opentdf.platform.policy.KeyAccessServer;
import io.opentdf.platform.policy.kasregistry.KeyAccessServerRegistryServiceClient;
import io.opentdf.platform.policy.kasregistry.ListKeyAccessServersRequest;
Expand All @@ -11,6 +13,8 @@

import java.nio.charset.StandardCharsets;

import io.opentdf.platform.wellknownconfiguration.GetWellKnownConfigurationResponse;
import io.opentdf.platform.wellknownconfiguration.WellKnownServiceClientInterface;
import org.apache.commons.io.output.ByteArrayOutputStream;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
Expand All @@ -23,6 +27,7 @@
import java.util.Base64;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Random;

import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
Expand All @@ -45,14 +50,27 @@ public class NanoTDFTest {
"oVP7Vpcx\n" +
"-----END PRIVATE KEY-----";

private static final String BASE_PUBLIC_KEY = "-----BEGIN PUBLIC KEY-----\n" +
"MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE/NawR/F7RJfX/odyOLPjl+5Ce1Br\n" +
"QZ/MBCIerHe26HzlBSbpa7HQHZx9PYVamHTw9+iJCY3dm8Uwp4Ab2uehnA==\n" +
"-----END PUBLIC KEY-----";

private static final String BASE_PRIVATE_KEY = "-----BEGIN PRIVATE KEY-----\n" +
"MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgB3YtAvS7lctHlPsq\n" +
"bZI8OX1B9W1c4GAIxzwKzD6iPkqhRANCAAT81rBH8XtEl9f+h3I4s+OX7kJ7UGtB\n" +
"n8wEIh6sd7bofOUFJulrsdAdnH09hVqYdPD36IkJjd2bxTCngBva56Gc\n" +
"-----END PRIVATE KEY-----" ;

private static final String KID = "r1";
private static final String BASE_KID = "basekid";

protected static KeyAccessServerRegistryServiceClient kasRegistryService;
protected static List<String> registeredKases = List.of(
"https://api.example.com/kas",
"https://other.org/kas2",
"http://localhost:8181/kas",
"https://localhost:8383/kas"
"https://localhost:8383/kas",
"https://api.kaswithbasekey.example.com"
);
protected static String platformUrl = "http://localhost:8080";

Expand All @@ -70,10 +88,16 @@ public Config.KASInfo getPublicKey(Config.KASInfo kasInfo) {

@Override
public KASInfo getECPublicKey(Config.KASInfo kasInfo, NanoTDFType.ECCurve curve) {
var k2 = kasInfo.clone();
if (Objects.equals(kasInfo.KID, BASE_KID)) {
assertThat(kasInfo.URL).isEqualTo("https://api.kaswithbasekey.example.com");
assertThat(kasInfo.Algorithm).isEqualTo("ec:secp384r1");
k2.PublicKey = BASE_PUBLIC_KEY;
return k2;
}
if (kasInfo.Algorithm != null && !"ec:secp256r1".equals(kasInfo.Algorithm)) {
throw new IllegalArgumentException("Unexpected algorithm: " + kasInfo);
}
var k2 = kasInfo.clone();
k2.KID = KID;
k2.PublicKey = kasPublicKey;
k2.Algorithm = "ec:secp256r1";
Expand All @@ -82,19 +106,14 @@ public KASInfo getECPublicKey(Config.KASInfo kasInfo, NanoTDFType.ECCurve curve)

@Override
public byte[] unwrap(Manifest.KeyAccess keyAccess, String policy, KeyType sessionKeyType) {
int index = Integer.parseInt(keyAccess.url);
var decryptor = new AsymDecryption(keypairs.get(index).getPrivate());
var bytes = Base64.getDecoder().decode(keyAccess.wrappedKey);
try {
return decryptor.decrypt(bytes);
} catch (Exception e) {
throw new RuntimeException(e);
}
throw new UnsupportedOperationException("no unwrapping ZTDFs here");
}

@Override
public byte[] unwrapNanoTDF(NanoTDFType.ECCurve curve, String header, String kasURL) {

String key = Objects.equals(kasURL, "https://api.kaswithbasekey.example.com")
? BASE_PRIVATE_KEY
: kasPrivateKey;
byte[] headerAsBytes = Base64.getDecoder().decode(header);
Header nTDFHeader = new Header(ByteBuffer.wrap(headerAsBytes));
byte[] ephemeralKey = nTDFHeader.getEphemeralKey();
Expand All @@ -103,7 +122,7 @@ public byte[] unwrapNanoTDF(NanoTDFType.ECCurve curve, String header, String kas

// Generate symmetric key
byte[] symmetricKey = ECKeyPair.computeECDHKey(ECKeyPair.publicKeyFromPem(publicKeyAsPem),
ECKeyPair.privateKeyFromPem(kasPrivateKey));
ECKeyPair.privateKeyFromPem(key));

// Generate HKDF key
MessageDigest digest;
Expand All @@ -113,8 +132,7 @@ public byte[] unwrapNanoTDF(NanoTDFType.ECCurve curve, String header, String kas
throw new SDKException("error creating SHA-256 message digest", e);
}
byte[] hashOfSalt = digest.digest(NanoTDF.MAGIC_NUMBER_AND_VERSION);
byte[] key = ECKeyPair.calculateHKDF(hashOfSalt, symmetricKey);
return key;
return ECKeyPair.calculateHKDF(hashOfSalt, symmetricKey);
}

@Override
Expand Down Expand Up @@ -203,6 +221,35 @@ void encryptionAndDecryptionWithValidKey() throws Exception {
}
}

@Test
void encryptionAndDecryptWithBaseKey() throws Exception {
var baseKeyJson = "{\"kas_url\":\"https://api.kaswithbasekey.example.com\",\"public_key\":{\"algorithm\":\"ALGORITHM_EC_P256\",\"kid\":\"" + BASE_KID + "\",\"pem\": \"" + BASE_PUBLIC_KEY + "\"}}";
var val = Value.newBuilder().setStringValue(baseKeyJson).build();
var config = Struct.newBuilder().putFields("base_key", val).build();
WellKnownServiceClientInterface wellknown = mock(WellKnownServiceClientInterface.class);
GetWellKnownConfigurationResponse response = GetWellKnownConfigurationResponse.newBuilder().setConfiguration(config).build();
when(wellknown.getWellKnownConfigurationBlocking(any(), any())).thenReturn(TestUtil.successfulUnaryCall(response));
Config.NanoTDFConfig nanoConfig = Config.newNanoTDFConfig(
Config.witDataAttributes("https://example.com/attr/Classification/value/S",
"https://example.com/attr/Classification/value/X")
);

String plainText = "Virtru!!";
ByteBuffer byteBuffer = ByteBuffer.wrap(plainText.getBytes());
ByteArrayOutputStream tdfOutputStream = new ByteArrayOutputStream();
NanoTDF nanoTDF = new NanoTDF(new FakeServicesBuilder().setKas(kas).setKeyAccessServerRegistryService(kasRegistryService).setWellknownService(wellknown).build());
nanoTDF.createNanoTDF(byteBuffer, tdfOutputStream, nanoConfig);

byte[] nanoTDFBytes = tdfOutputStream.toByteArray();
ByteArrayOutputStream plainTextStream = new ByteArrayOutputStream();
nanoTDF = new NanoTDF(new FakeServicesBuilder().setKas(kas).setKeyAccessServerRegistryService(kasRegistryService).build());
nanoTDF.readNanoTDF(ByteBuffer.wrap(nanoTDFBytes), plainTextStream, platformUrl);
String out = new String(plainTextStream.toByteArray(), StandardCharsets.UTF_8);
assertThat(out).isEqualTo(plainText);
// KAS KID
assertThat(new String(nanoTDFBytes, StandardCharsets.UTF_8)).contains(BASE_KID);
}

@Test
void testWithDifferentConfigAndKeyValues() throws Exception {
var kasInfos = new ArrayList<>();
Expand Down