diff --git a/examples/src/main/java/io/opentdf/platform/GetEntitlements.java b/examples/src/main/java/io/opentdf/platform/GetEntitlements.java
index 0f0ab735..f9479577 100644
--- a/examples/src/main/java/io/opentdf/platform/GetEntitlements.java
+++ b/examples/src/main/java/io/opentdf/platform/GetEntitlements.java
@@ -7,7 +7,6 @@
import io.opentdf.platform.sdk.*;
import java.util.Collections;
-import java.util.concurrent.ExecutionException;
import java.util.List;
diff --git a/sdk/pom.xml b/sdk/pom.xml
index 795d34ba..a17f99ef 100644
--- a/sdk/pom.xml
+++ b/sdk/pom.xml
@@ -15,7 +15,7 @@
2.1.0
0.7.2
4.12.0
- protocol/go/v0.3.0
+ protocol/go/v0.5.0
diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/AesGcm.java b/sdk/src/main/java/io/opentdf/platform/sdk/AesGcm.java
index 71445f69..96067c07 100644
--- a/sdk/src/main/java/io/opentdf/platform/sdk/AesGcm.java
+++ b/sdk/src/main/java/io/opentdf/platform/sdk/AesGcm.java
@@ -152,10 +152,9 @@ public byte[] encrypt(byte[] iv, int authTagLen, byte[] plaintext, int offset, i
System.arraycopy(iv, 0, cipherTextWithNonce, 0, iv.length);
System.arraycopy(cipherText, 0, cipherTextWithNonce, iv.length, cipherText.length);
return cipherTextWithNonce;
- } catch (NoSuchPaddingException | NoSuchAlgorithmException | InvalidAlgorithmParameterException e) {
- throw new RuntimeException("error gcm decrypt", e);
- } catch (InvalidKeyException | BadPaddingException | IllegalBlockSizeException e) {
- throw new RuntimeException("error gcm decrypt", e);
+ } catch (NoSuchPaddingException | NoSuchAlgorithmException | InvalidAlgorithmParameterException |
+ InvalidKeyException | BadPaddingException | IllegalBlockSizeException e) {
+ throw new RuntimeException("error gcm encrypt", e);
}
}
diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/Autoconfigure.java b/sdk/src/main/java/io/opentdf/platform/sdk/Autoconfigure.java
index b60d0725..8a50cc03 100644
--- a/sdk/src/main/java/io/opentdf/platform/sdk/Autoconfigure.java
+++ b/sdk/src/main/java/io/opentdf/platform/sdk/Autoconfigure.java
@@ -1,12 +1,13 @@
package io.opentdf.platform.sdk;
import com.connectrpc.ResponseMessageKt;
+import io.opentdf.platform.policy.Algorithm;
import io.opentdf.platform.policy.Attribute;
import io.opentdf.platform.policy.AttributeRuleTypeEnum;
import io.opentdf.platform.policy.AttributeValueSelector;
-import io.opentdf.platform.policy.KasPublicKey;
import io.opentdf.platform.policy.KasPublicKeyAlgEnum;
import io.opentdf.platform.policy.KeyAccessServer;
+import io.opentdf.platform.policy.SimpleKasKey;
import io.opentdf.platform.policy.Value;
import io.opentdf.platform.policy.attributes.AttributesServiceClientInterface;
import io.opentdf.platform.policy.attributes.GetAttributeValuesByFqnsRequest;
@@ -14,6 +15,7 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.io.UnsupportedEncodingException;
import java.net.URLDecoder;
@@ -27,6 +29,7 @@
import java.util.List;
import java.util.Map;
import java.util.Objects;
+import java.util.Optional;
import java.util.Set;
import java.util.StringJoiner;
import java.util.function.Supplier;
@@ -53,47 +56,53 @@ class RuleType {
* This class includes functionality to create granter instances based on
* attributes either from a list of attribute values or from a service.
*/
-public class Autoconfigure {
+class Autoconfigure {
- public static Logger logger = LoggerFactory.getLogger(Autoconfigure.class);
+ private static Logger logger = LoggerFactory.getLogger(Autoconfigure.class);
- public static class KeySplitStep {
- public String kas;
- public String splitID;
+ private Autoconfigure() {
+ // Prevent instantiation, this class is a utility class that is only used statically
+ }
+
+ static class KeySplitStep {
+ final String kas;
+ final String splitID;
+ final String kid;
- public KeySplitStep(String kas, String splitId) {
- this.kas = kas;
- this.splitID = splitId;
+ KeySplitStep(String kas, String splitId) {
+ this(kas, splitId, null);
}
- @Override
- public String toString() {
- return "KeySplitStep{kas=" + this.kas + ", splitID=" + this.splitID + "}";
+ KeySplitStep(String kas, String splitId, @Nullable String kid) {
+ this.kas = Objects.requireNonNull(kas);
+ this.splitID = Objects.requireNonNull(splitId);
+ this.kid = kid;
}
@Override
- public boolean equals(Object obj) {
- if (this == obj) {
- return true;
- }
- if (obj == null || !(obj instanceof KeySplitStep)) {
- return false;
- }
- KeySplitStep ss = (KeySplitStep) obj;
- if ((this.kas.equals(ss.kas)) && (this.splitID.equals(ss.splitID))) {
- return true;
- }
- return false;
+ public boolean equals(Object o) {
+ if (o == null || getClass() != o.getClass()) return false;
+ KeySplitStep that = (KeySplitStep) o;
+ return Objects.equals(kas, that.kas) && Objects.equals(splitID, that.splitID) && Objects.equals(kid, that.kid);
}
@Override
public int hashCode() {
- return Objects.hash(kas, splitID);
+ return Objects.hash(kas, splitID, kid);
+ }
+
+ @Override
+ public String toString() {
+ return "KeySplitStep{" +
+ "kas='" + kas + '\'' +
+ ", splitID='" + splitID + '\'' +
+ ", kid='" + kid + '\'' +
+ '}';
}
}
// Utility class for an attribute name FQN.
- public static class AttributeNameFQN {
+ static class AttributeNameFQN {
private final String url;
private final String key;
@@ -156,7 +165,7 @@ public String name() throws AutoConfigureException {
}
// Utility class for an attribute value FQN.
- public static class AttributeValueFQN {
+ static class AttributeValueFQN {
private final String url;
private final String key;
@@ -203,11 +212,11 @@ public int hashCode() {
return Objects.hash(key);
}
- public String getKey() {
+ String getKey() {
return key;
}
- public String authority() {
+ String authority() {
Pattern pattern = Pattern.compile("^(https?://[\\w./-]+)/attr/\\S*/value/\\S*$");
Matcher matcher = pattern.matcher(url);
if (!matcher.find()) {
@@ -216,7 +225,7 @@ public String authority() {
return matcher.group(1);
}
- public AttributeNameFQN prefix() throws AutoConfigureException {
+ AttributeNameFQN prefix() throws AutoConfigureException {
Pattern pattern = Pattern.compile("^(https?://[\\w./-]+/attr/\\S*)/value/\\S*$");
Matcher matcher = pattern.matcher(url);
if (!matcher.find()) {
@@ -225,7 +234,7 @@ public AttributeNameFQN prefix() throws AutoConfigureException {
return new AttributeNameFQN(matcher.group(1));
}
- public String value() {
+ String value() {
Pattern pattern = Pattern.compile("^https?://[\\w./-]+/attr/\\S*/value/(\\S*)$");
Matcher matcher = pattern.matcher(url);
if (!matcher.find()) {
@@ -238,7 +247,7 @@ public String value() {
}
}
- public String name() {
+ String name() {
Pattern pattern = Pattern.compile("^https?://[\\w./-]+/attr/(\\S*)/value/\\S*$");
Matcher matcher = pattern.matcher(url);
if (!matcher.find()) {
@@ -246,13 +255,13 @@ public String name() {
}
try {
return URLDecoder.decode(matcher.group(1), StandardCharsets.UTF_8.name());
- } catch (UnsupportedEncodingException | IllegalArgumentException e) {
- throw new RuntimeException("invalid attributeInstance", e);
+ } catch (UnsupportedEncodingException | IllegalArgumentException e) {
+ throw new RuntimeException("illegal attribute instance", e);
}
}
}
- public static class KeyAccessGrant {
+ static class KeyAccessGrant {
public Attribute attr;
public List kases;
@@ -266,40 +275,103 @@ public KeyAccessGrant(Attribute attr, List kases) {
static class Granter {
private final List policy;
private final Map grants = new HashMap<>();
+ private final Map> mappedKeys = new HashMap<>();
+ private boolean hasGrants = false;
+ private boolean hasMappedKeys = false;
- public Granter(List policy) {
+ Granter(List policy) {
this.policy = policy;
}
- public Map getGrants() {
- return new HashMap(grants);
+ Map getGrants() {
+ return new HashMap<>(grants);
}
- public List getPolicy() {
+ List getPolicy() {
return policy;
}
- public void addGrant(AttributeValueFQN fqn, String kas, Attribute attr) {
- grants.computeIfAbsent(fqn.key, k -> new KeyAccessGrant(attr, new ArrayList<>())).kases.add(kas);
- }
+ boolean addAllGrants(AttributeValueFQN fqn, List granted, List mapped, Attribute attr) {
+ boolean foundMappedKey = false;
+ for (var mappedKey: mapped) {
+ foundMappedKey = true;
+ mappedKeys.computeIfAbsent(fqn.key, k -> new ArrayList<>()).add(Config.KASInfo.fromSimpleKasKey(mappedKey));
+ grants.computeIfAbsent(fqn.key, k -> new KeyAccessGrant(attr, new ArrayList<>())).kases.add(mappedKey.getKasUri());
+ }
- public void addAllGrants(AttributeValueFQN fqn, List gs, Attribute attr) {
- if (gs.isEmpty()) {
- grants.putIfAbsent(fqn.key, new KeyAccessGrant(attr, new ArrayList<>()));
- } else {
- for (KeyAccessServer g : gs) {
- if (g != null) {
- addGrant(fqn, g.getUri(), attr);
+ if (foundMappedKey) {
+ hasMappedKeys = true;
+ return true;
+ }
+
+ boolean foundGrantedKey = false;
+ for (var grantedKey: granted) {
+ foundGrantedKey = true;
+ grants.computeIfAbsent(fqn.key, k -> new KeyAccessGrant(attr, new ArrayList<>())).kases.add(grantedKey.getUri());
+ if (!grantedKey.getKasKeysList().isEmpty()) {
+ for (var kas : grantedKey.getKasKeysList()) {
+ mappedKeys.computeIfAbsent(fqn.key, k -> new ArrayList<>()).add(Config.KASInfo.fromSimpleKasKey(kas));
}
+ continue;
}
+ var cachedGrantKeys = grantedKey.getPublicKey().getCached().getKeysList();
+
+ if (logger.isDebugEnabled()) {
+ logger.debug("found {} keys cached in policy service", cachedGrantKeys.size());
+ }
+
+ for (var cachedGrantKey: cachedGrantKeys) {
+ var mappedKey = new Config.KASInfo();
+ mappedKey.URL = grantedKey.getUri();
+ mappedKey.KID = cachedGrantKey.getKid();
+ mappedKey.Algorithm = Autoconfigure.algProto2String(cachedGrantKey.getAlg());
+ mappedKey.PublicKey = cachedGrantKey.getPem();
+ mappedKey.Default = false;
+ mappedKeys.computeIfAbsent(fqn.key, k -> new ArrayList<>()).add(mappedKey);
+ }
+ }
+
+ if (!grants.containsKey(fqn.key)) {
+ grants.put(fqn.key, new KeyAccessGrant(attr, new ArrayList<>()));
+ }
+
+ if (foundGrantedKey) {
+ hasGrants = true;
}
+ return foundGrantedKey;
}
- public KeyAccessGrant byAttribute(AttributeValueFQN fqn) {
+ KeyAccessGrant byAttribute(AttributeValueFQN fqn) {
return grants.get(fqn.key);
}
- public List plan(List defaultKas, Supplier genSplitID)
+ List getSplits(List defaultKases, Supplier genSplitID, Supplier> baseKeySupplier) throws AutoConfigureException {
+ if (hasMappedKeys) {
+ logger.debug("generating plan from mapped keys");
+ return planFromAttributes(genSplitID);
+ }
+ if (hasGrants) {
+ logger.debug("generating plan from grants");
+ return plan(genSplitID);
+ }
+
+ var baseKey = baseKeySupplier.get();
+ if (baseKey.isPresent()) {
+ var key = baseKey.get();
+ String kas = key.getKasUri();
+ String splitID = "";
+ String kid = key.getPublicKey().getKid();
+ return Collections.singletonList(new KeySplitStep(kas, splitID, kid));
+ }
+
+ logger.warn("no grants or mapped keys found, generating plan from default KASes. this is deprecated");
+ // this is a little bit weird because we don't take into account the KIDs here. This is the way
+ // that it works in the go SDK but it seems a bit odd
+ return generatePlanFromDefaultKases(defaultKases, genSplitID);
+ }
+
+ @Nonnull
+ List plan(Supplier genSplitID)
throws AutoConfigureException {
AttributeBooleanExpression b = constructAttributeBoolean();
BooleanKeyExpression k = insertKeysForAttribute(b);
@@ -310,18 +382,7 @@ public List plan(List defaultKas, Supplier genSpli
k = k.reduce();
int l = k.size();
if (l == 0) {
- // default behavior: split key across all default KAS
- if (defaultKas.isEmpty()) {
- throw new AutoConfigureException("no default KAS specified; required for grantless plans");
- } else if (defaultKas.size() == 1) {
- return Collections.singletonList(new KeySplitStep(defaultKas.get(0), ""));
- } else {
- List result = new ArrayList<>();
- for (String kas : defaultKas) {
- result.add(new KeySplitStep(kas, genSplitID.get()));
- }
- return result;
- }
+ throw new AutoConfigureException("generated an empty plan");
}
List steps = new ArrayList<>();
@@ -334,6 +395,45 @@ public List plan(List defaultKas, Supplier genSpli
return steps;
}
+ @Nonnull
+ List planFromAttributes(Supplier genSplitID)
+ throws AutoConfigureException {
+ AttributeBooleanExpression b = constructAttributeBoolean();
+ BooleanKeyExpression k = assignKeysTo(b);
+ if (k == null) {
+ throw new AutoConfigureException("Error assigning keys to attribute");
+ }
+
+ k = k.reduce();
+ int l = k.size();
+ if (l == 0) {
+ return Collections.emptyList();
+ }
+
+ List steps = new ArrayList<>();
+ for (KeyClause v : k.values) {
+ String splitID = (l > 1) ? genSplitID.get() : "";
+ for (PublicKeyInfo o : v.values) {
+ steps.add(new KeySplitStep(o.kas, splitID, o.kid));
+ }
+ }
+ return steps;
+ }
+
+ static List generatePlanFromDefaultKases(List defaultKas, Supplier genSplitID) {
+ if (defaultKas.isEmpty()) {
+ throw new AutoConfigureException("no default KAS specified; required for grantless plans");
+ } else if (defaultKas.size() == 1) {
+ return Collections.singletonList(new KeySplitStep(defaultKas.get(0), ""));
+ } else {
+ List result = new ArrayList<>();
+ for (String kas : defaultKas) {
+ result.add(new KeySplitStep(kas, genSplitID.get()));
+ }
+ return result;
+ }
+ }
+
BooleanKeyExpression insertKeysForAttribute(AttributeBooleanExpression e) throws AutoConfigureException {
List kcs = new ArrayList<>(e.must.size());
@@ -361,13 +461,53 @@ BooleanKeyExpression insertKeysForAttribute(AttributeBooleanExpression e) throws
logger.warn("Unknown attribute rule type: " + clause);
}
- KeyClause kc = new KeyClause(op, kcv);
- kcs.add(kc);
+ kcs.add(new KeyClause(op, kcv));
}
return new BooleanKeyExpression(kcs);
}
+ BooleanKeyExpression assignKeysTo(AttributeBooleanExpression e) {
+ var keyClauses = new ArrayList();
+ for (var clause : e.must) {
+ ArrayList keys = new ArrayList<>();
+ if (clause.values.isEmpty()) {
+ logger.warn("No values found for attribute {}", clause.def.getFqn());
+ continue;
+ }
+ for (var value : clause.values) {
+ var mapped = mappedKeys.get(value.key);
+ if (mapped == null) {
+ logger.warn("No keys found for attribute value {}", value);
+ continue;
+ }
+ for (var kasInfo : mapped) {
+ if (kasInfo.URL == null || kasInfo.URL.isEmpty()) {
+ logger.warn("No KAS URL found for attribute value {}", value);
+ continue;
+ }
+ keys.add(new PublicKeyInfo(kasInfo.URL, kasInfo.KID));
+ }
+ }
+
+ String op = ruleToOperator(clause.def.getRule());
+ if (op.equals(RuleType.UNSPECIFIED)) {
+ logger.warn("Unknown attribute rule type {}", op);
+ }
+
+ keyClauses.add(new KeyClause(op, keys));
+ }
+
+ return new BooleanKeyExpression(keyClauses);
+ }
+
+ /**
+ * Constructs an AttributeBooleanExpression from the policy, splitting each attribute
+ * into its own clause. Each clause contains the attribute definition and a list of
+ * values.
+ * @return
+ * @throws AutoConfigureException
+ */
AttributeBooleanExpression constructAttributeBoolean() throws AutoConfigureException {
Map prefixes = new HashMap<>();
List sortedPrefixes = new ArrayList<>();
@@ -378,7 +518,7 @@ AttributeBooleanExpression constructAttributeBoolean() throws AutoConfigureExcep
clause.values.add(aP);
} else if (byAttribute(aP) != null) {
var x = new SingleAttributeClause(byAttribute(aP).attr,
- new ArrayList(Arrays.asList(aP)));
+ new ArrayList<>(Arrays.asList(aP)));
prefixes.put(a.getKey(), x);
sortedPrefixes.add(a.getKey());
}
@@ -391,39 +531,6 @@ AttributeBooleanExpression constructAttributeBoolean() throws AutoConfigureExcep
return new AttributeBooleanExpression(must);
}
- static class AttributeMapping {
-
- private Map dict;
-
- public AttributeMapping() {
- this.dict = new HashMap<>();
- }
-
- public void put(Attribute ad) throws AutoConfigureException {
- if (this.dict == null) {
- this.dict = new HashMap<>();
- }
-
- AttributeNameFQN prefix = new AttributeNameFQN(ad.getFqn());
-
- if (this.dict.containsKey(prefix)) {
- throw new AutoConfigureException("Attribute prefix already found: [" + prefix.toString() + "]");
- }
-
- this.dict.put(prefix, ad);
- }
-
- public Attribute get(AttributeNameFQN prefix) throws AutoConfigureException {
- Attribute ad = this.dict.get(prefix);
- if (ad == null) {
- throw new AutoConfigureException("Unknown attribute type: [" + prefix.toString() + "], not in ["
- + this.dict.keySet().toString() + "]");
- }
- return ad;
- }
-
- }
-
static class SingleAttributeClause {
private Attribute def;
@@ -435,9 +542,9 @@ public SingleAttributeClause(Attribute def, List values) {
}
}
- class AttributeBooleanExpression {
+ static class AttributeBooleanExpression {
- private List must;
+ private final List must;
public AttributeBooleanExpression(List must) {
this.must = must;
@@ -478,25 +585,65 @@ public String toString() {
}
- public class PublicKeyInfo {
- private String kas;
+ static class PublicKeyInfo implements Comparable {
+ final String kas;
+ final String kid;
+
+ PublicKeyInfo(String kas) {
+ this(kas, null);
+ }
- public PublicKeyInfo(String kas) {
+ PublicKeyInfo(String kas, String kid) {
this.kas = kas;
+ this.kid = kid;
}
- public String getKas() {
+ String getKas() {
return kas;
}
- public void setKas(String kas) {
- this.kas = kas;
+ @Override
+ public boolean equals(Object o) {
+ if (o == null || getClass() != o.getClass()) return false;
+ PublicKeyInfo that = (PublicKeyInfo) o;
+ return Objects.equals(kas, that.kas) && Objects.equals(kid, that.kid);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(kas, kid);
+ }
+
+ @Override
+ public String toString() {
+ return "PublicKeyInfo{" +
+ "kas='" + kas + '\'' +
+ ", kid='" + kid + '\'' +
+ '}';
+ }
+
+ @Override
+ public int compareTo(PublicKeyInfo o) {
+ if (this.kas.equals(o.kas)) {
+ if (this.kid == null && o.kid == null) {
+ return 0;
+ }
+ if (this.kid == null) {
+ return -1;
+ }
+ if (o.kid == null) {
+ return 1;
+ }
+ return this.kid.compareTo(o.kid);
+ } else {
+ return this.kas.compareTo(o.kas);
+ }
}
}
- public class KeyClause {
- private String operator;
- private List values;
+ static class KeyClause {
+ private final String operator;
+ private final List values;
public KeyClause(String operator, List values) {
this.operator = operator;
@@ -531,8 +678,8 @@ public String toString() {
}
}
- public class BooleanKeyExpression {
- private List values;
+ static class BooleanKeyExpression {
+ private final List values;
public BooleanKeyExpression(List values) {
this.values = values;
@@ -572,7 +719,7 @@ public BooleanKeyExpression reduce() {
continue;
}
Disjunction terms = new Disjunction();
- terms.add(k.getKas());
+ terms.add(k);
if (!within(conjunction, terms)) {
conjunction.add(terms);
}
@@ -584,25 +731,22 @@ public BooleanKeyExpression reduce() {
}
List newValues = new ArrayList<>();
- for (List d : conjunction) {
+ for (List d : conjunction) {
List pki = new ArrayList<>();
- for (String k : d) {
- pki.add(new PublicKeyInfo(k));
- }
+ pki.addAll(d);
newValues.add(new KeyClause(RuleType.ANY_OF, pki));
}
return new BooleanKeyExpression(newValues);
}
public Disjunction sortedNoDupes(List l) {
- Set set = new HashSet<>();
+ Set set = new HashSet<>();
Disjunction list = new Disjunction();
for (PublicKeyInfo e : l) {
- String kas = e.getKas();
- if (!kas.equals(RuleType.EMPTY_TERM) && !set.contains(kas)) {
- set.add(kas);
- list.add(kas);
+ if (!Objects.equals(e.getKas(), RuleType.EMPTY_TERM) && !set.contains(e)) {
+ set.add(e);
+ list.add(e);
}
}
@@ -612,7 +756,7 @@ public Disjunction sortedNoDupes(List l) {
}
- class Disjunction extends ArrayList {
+ static class Disjunction extends ArrayList {
public boolean less(Disjunction r) {
int m = Math.min(this.size(), r.size());
@@ -683,7 +827,7 @@ public static String ruleToOperator(AttributeRuleTypeEnum e) {
// Given a policy (list of data attributes or tags),
// get a set of grants from attribute values to KASes.
// Unlike `NewGranterFromService`, this works offline.
- public static Granter newGranterFromAttributes(Value... attrValues) throws AutoConfigureException {
+ static Granter newGranterFromAttributes(KASKeyCache keyCache, Value... attrValues) throws AutoConfigureException {
var attrsAndValues = Arrays.stream(attrValues).map(v -> {
if (!v.hasAttribute()) {
throw new AutoConfigureException("tried to use an attribute that is not initialized");
@@ -694,11 +838,11 @@ public static Granter newGranterFromAttributes(Value... attrValues) throws AutoC
.build();
}).collect(Collectors.toList());
- return getGranter(null, attrsAndValues);
+ return getGranter(keyCache, attrsAndValues);
}
// Gets a list of directory of KAS grants for a list of attribute FQNs
- public static Granter newGranterFromService(AttributesServiceClientInterface as, KASKeyCache keyCache, AttributeValueFQN... fqns) throws AutoConfigureException {
+ static Granter newGranterFromService(AttributesServiceClientInterface as, KASKeyCache keyCache, AttributeValueFQN... fqns) throws AutoConfigureException {
GetAttributeValuesByFqnsRequest request = GetAttributeValuesByFqnsRequest.newBuilder()
.addAllFqns(Arrays.stream(fqns).map(AttributeValueFQN::toString).collect(Collectors.toList()))
.setWithValue(AttributeValueSelector.newBuilder().setWithKeyAccessGrants(true).build())
@@ -711,73 +855,65 @@ public static Granter newGranterFromService(AttributesServiceClientInterface as,
return getGranter(keyCache, new ArrayList<>(av.getFqnAttributeValuesMap().values()));
}
- private static List getGrants(GetAttributeValuesByFqnsResponse.AttributeAndValue attributeAndValue) {
- var val = attributeAndValue.getValue();
- var attribute = attributeAndValue.getAttribute();
- if (!val.getGrantsList().isEmpty()) {
- if (logger.isDebugEnabled()) {
- logger.debug("adding grants from attribute value [{}]: {}", val.getFqn(), val.getGrantsList().stream().map(KeyAccessServer::getUri).collect(Collectors.toList()));
- }
- return val.getGrantsList();
- } else if (!attribute.getGrantsList().isEmpty()) {
- var attributeGrants = attribute.getGrantsList();
- if (logger.isDebugEnabled()) {
- logger.debug("adding grants from attribute [{}]: {}", attribute.getFqn(), attributeGrants.stream().map(KeyAccessServer::getId).collect(Collectors.toList()));
- }
- return attributeGrants;
- } else if (!attribute.getNamespace().getGrantsList().isEmpty()) {
- var nsGrants = attribute.getNamespace().getGrantsList();
- if (logger.isDebugEnabled()) {
- logger.debug("adding grants from namespace [{}]: [{}]", attribute.getNamespace().getName(), nsGrants.stream().map(KeyAccessServer::getId).collect(Collectors.toList()));
- }
- return nsGrants;
- } else {
- // this is needed to mark the fact that we have an empty
- if (logger.isDebugEnabled()) {
- logger.debug("didn't find any grants on value, attribute, or namespace for attribute value [{}]", val.getFqn());
- }
- return Collections.emptyList();
+ static Autoconfigure.Granter createGranter(SDK.Services services, Config.TDFConfig tdfConfig) {
+ Autoconfigure.Granter granter = new Autoconfigure.Granter(new ArrayList<>());
+ if (tdfConfig.attributeValues != null && !tdfConfig.attributeValues.isEmpty()) {
+ granter = Autoconfigure.newGranterFromAttributes(services.kas().getKeyCache(), tdfConfig.attributeValues.toArray(new Value[0]));
+ } else if (tdfConfig.attributes != null && !tdfConfig.attributes.isEmpty()) {
+ granter = Autoconfigure.newGranterFromService(services.attributes(), services.kas().getKeyCache(),
+ tdfConfig.attributes.toArray(new Autoconfigure.AttributeValueFQN[0]));
}
-
+ return granter;
}
private static Granter getGranter(@Nullable KASKeyCache keyCache, List values) {
- Granter grants = new Granter(values.stream().map(GetAttributeValuesByFqnsResponse.AttributeAndValue::getValue).map(Value::getFqn).map(AttributeValueFQN::new).collect(Collectors.toList()));
+ List attributeValues = values.stream()
+ .map(GetAttributeValuesByFqnsResponse.AttributeAndValue::getValue)
+ .map(Value::getFqn)
+ .map(AttributeValueFQN::new)
+ .collect(Collectors.toList());
+ Granter grants = new Granter(attributeValues);
for (var attributeAndValue: values) {
- var attributeGrants = getGrants(attributeAndValue);
String fqnstr = attributeAndValue.getValue().getFqn();
AttributeValueFQN fqn = new AttributeValueFQN(fqnstr);
- grants.addAllGrants(fqn, attributeGrants, attributeAndValue.getAttribute());
- if (keyCache != null) {
- storeKeysToCache(attributeGrants, keyCache);
+
+ var value = attributeAndValue.getValue();
+ var attribute = attributeAndValue.getAttribute();
+ var namespace = attribute.getNamespace();
+
+ if (grants.addAllGrants(fqn, value.getGrantsList(), value.getKasKeysList(), attribute)) {
+ storeKeysToCache(value.getGrantsList(), value.getKasKeysList(), keyCache);
+ continue;
+ }
+ if (grants.addAllGrants(fqn, attribute.getGrantsList(), attribute.getKasKeysList(), attribute)) {
+ storeKeysToCache(attribute.getGrantsList(), attribute.getKasKeysList(), keyCache);
+ continue;
+ }
+ if (grants.addAllGrants(fqn, namespace.getGrantsList(), namespace.getKasKeysList(), attribute)) {
+ storeKeysToCache(namespace.getGrantsList(), namespace.getKasKeysList(), keyCache);
}
}
return grants;
}
-
- static void storeKeysToCache(List kases, KASKeyCache keyCache) {
- for (KeyAccessServer kas : kases) {
- List keys = kas.getPublicKey().getCached().getKeysList();
- if (keys.isEmpty()) {
- logger.debug("No cached key in policy service for KAS: " + kas.getUri());
- continue;
- }
- for (KasPublicKey ki : keys) {
- Config.KASInfo kasInfo = new Config.KASInfo();
- kasInfo.URL = kas.getUri();
- kasInfo.KID = ki.getKid();
- kasInfo.Algorithm = algProto2String(ki.getAlg());
- kasInfo.PublicKey = ki.getPem();
- keyCache.store(kasInfo);
- }
+ static void storeKeysToCache(List kases, List kasKeys, KASKeyCache keyCache) {
+ if (keyCache == null) {
+ return;
+ }
+ for (var kas : kases) {
+ Config.KASInfo.fromKeyAccessServer(kas).forEach(keyCache::store);
}
+ kasKeys.stream().map(Config.KASInfo::fromSimpleKasKey).forEach(keyCache::store);
+ }
+
+ static String algProto2String(Algorithm e) {
+ return KeyType.fromAlgorithm(e).getCurveName();
}
- private static String algProto2String(KasPublicKeyAlgEnum e) {
+ static String algProto2String(KasPublicKeyAlgEnum e) {
switch (e) {
case KAS_PUBLIC_KEY_ALG_ENUM_EC_SECP256R1:
return "ec:secp256r1";
diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/Config.java b/sdk/src/main/java/io/opentdf/platform/sdk/Config.java
index 0a1f4a46..34e4fa27 100644
--- a/sdk/src/main/java/io/opentdf/platform/sdk/Config.java
+++ b/sdk/src/main/java/io/opentdf/platform/sdk/Config.java
@@ -1,13 +1,20 @@
package io.opentdf.platform.sdk;
+import io.opentdf.platform.policy.KeyAccessServer;
+import io.opentdf.platform.policy.SimpleKasKey;
import io.opentdf.platform.policy.Value;
import io.opentdf.platform.sdk.Autoconfigure.AttributeValueFQN;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.*;
import java.util.function.Consumer;
import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+import static io.opentdf.platform.sdk.Autoconfigure.algProto2String;
/**
* Configuration class for setting various configurations related to TDF.
@@ -22,6 +29,7 @@ public class Config {
public static final String KAS_PUBLIC_KEY_PATH = "/kas_public_key";
public static final String DEFAULT_MIME_TYPE = "application/octet-stream";
public static final int MAX_COLLECTION_ITERATION = (1 << 24) - 1;
+ private static Logger logger = LoggerFactory.getLogger(Config.class);
public enum TDFFormat {
JSONFormat,
@@ -33,8 +41,6 @@ public enum IntegrityAlgorithm {
GMAC
}
- public static final int K_HTTP_OK = 200;
-
public static class KASInfo implements Cloneable {
public String URL;
public String PublicKey;
@@ -71,6 +77,36 @@ public String toString() {
}
return sb.append("}").toString();
}
+
+ public static List fromKeyAccessServer(KeyAccessServer kas) {
+ var keys = kas.getPublicKey().getCached().getKeysList();
+ if (keys.isEmpty()) {
+ logger.warn("Invalid KAS key mapping for kas [{}]: publicKey is empty", kas.getUri());
+ return Collections.emptyList();
+ }
+ return keys.stream().flatMap(ki -> {
+ if (ki.getPem().isEmpty()) {
+ logger.warn("Invalid KAS key mapping for kas [{}]: publicKey PEM is empty", kas.getUri());
+ return Stream.empty();
+ }
+ Config.KASInfo kasInfo = new Config.KASInfo();
+ kasInfo.URL = kas.getUri();
+ kasInfo.KID = ki.getKid();
+ kasInfo.Algorithm = algProto2String(ki.getAlg());
+ kasInfo.PublicKey = ki.getPem();
+ return Stream.of(kasInfo);
+ }).collect(Collectors.toList());
+ }
+
+ public static KASInfo fromSimpleKasKey(SimpleKasKey ki) {
+ Config.KASInfo kasInfo = new Config.KASInfo();
+ kasInfo.URL = ki.getKasUri();
+ kasInfo.KID = ki.getPublicKey().getKid();
+ kasInfo.Algorithm = algProto2String(ki.getPublicKey().getAlgorithm());
+ kasInfo.PublicKey = ki.getPublicKey().getPem();
+
+ return kasInfo;
+ }
}
public static class AssertionVerificationKeys {
diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/ECCMode.java b/sdk/src/main/java/io/opentdf/platform/sdk/ECCMode.java
index ce000d54..bcb8b208 100644
--- a/sdk/src/main/java/io/opentdf/platform/sdk/ECCMode.java
+++ b/sdk/src/main/java/io/opentdf/platform/sdk/ECCMode.java
@@ -1,5 +1,7 @@
package io.opentdf.platform.sdk;
+import java.util.Objects;
+
public class ECCMode {
private ECCModeStruct data;
@@ -115,6 +117,18 @@ public static int getECCompressedPubKeySize(NanoTDFType.ECCurve curve) {
}
}
+ @Override
+ public boolean equals(Object o) {
+ if (o == null || getClass() != o.getClass()) return false;
+ ECCMode eccMode = (ECCMode) o;
+ return Objects.equals(data, eccMode.data);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(data);
+ }
+
private class ECCModeStruct {
int curveMode;
int unused;
diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/ECKeyPair.java b/sdk/src/main/java/io/opentdf/platform/sdk/ECKeyPair.java
index 53095bd7..88680efc 100644
--- a/sdk/src/main/java/io/opentdf/platform/sdk/ECKeyPair.java
+++ b/sdk/src/main/java/io/opentdf/platform/sdk/ECKeyPair.java
@@ -197,18 +197,12 @@ public static byte[] compressECPublickey(String pemECPubKey) {
PemObject pemObject = pemReader.readPemObject();
PublicKey pubKey = ecKeyFac.generatePublic(new X509EncodedKeySpec(pemObject.getContent()));
return ((ECPublicKey) pubKey).getQ().getEncoded(true);
- } catch (NoSuchAlgorithmException e) {
- throw new RuntimeException(e);
- } catch (IOException e) {
- throw new RuntimeException(e);
- } catch (InvalidKeySpecException e) {
- throw new RuntimeException(e);
- } catch (NoSuchProviderException e) {
+ } catch (NoSuchAlgorithmException | IOException | InvalidKeySpecException | NoSuchProviderException e) {
throw new RuntimeException(e);
}
}
- public static String publicKeyFromECPoint(byte[] ecPoint, String curveName) {
+ public static String publicKeyFromECPoint(byte[] ecPoint, String curveName) throws RuntimeException {
try {
// Create EC Public key
ECNamedCurveParameterSpec ecSpec = ECNamedCurveTable.getParameterSpec(curveName);
@@ -274,7 +268,7 @@ public static byte[] computeECDHKey(ECPublicKey publicKey, ECPrivateKey privateK
}
public static byte[] calculateHKDF(byte[] salt, byte[] secret) {
- byte[] key = new byte[secret.length];
+ byte[] key = new byte[32];
HKDFParameters params = new HKDFParameters(secret, salt, null);
HKDFBytesGenerator hkdf = new HKDFBytesGenerator(SHA256Digest.newInstance());
diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/KASClient.java b/sdk/src/main/java/io/opentdf/platform/sdk/KASClient.java
index bafc3f2c..bfa1ffdd 100644
--- a/sdk/src/main/java/io/opentdf/platform/sdk/KASClient.java
+++ b/sdk/src/main/java/io/opentdf/platform/sdk/KASClient.java
@@ -85,7 +85,7 @@ public KASInfo getECPublicKey(Config.KASInfo kasInfo, NanoTDFType.ECCurve curve)
@Override
public Config.KASInfo getPublicKey(Config.KASInfo kasInfo) {
- Config.KASInfo cachedValue = this.kasKeyCache.get(kasInfo.URL, kasInfo.Algorithm);
+ Config.KASInfo cachedValue = this.kasKeyCache.get(kasInfo.URL, kasInfo.Algorithm, kasInfo.KID);
if (cachedValue != null) {
return cachedValue;
}
diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/KASKeyCache.java b/sdk/src/main/java/io/opentdf/platform/sdk/KASKeyCache.java
index 5879dd05..75bae93c 100644
--- a/sdk/src/main/java/io/opentdf/platform/sdk/KASKeyCache.java
+++ b/sdk/src/main/java/io/opentdf/platform/sdk/KASKeyCache.java
@@ -7,6 +7,7 @@
import java.time.temporal.ChronoUnit;
import java.util.HashMap;
import java.util.Map;
+import java.util.Objects;
/**
* Class representing a cache for KAS (Key Access Server) information.
@@ -24,14 +25,14 @@ public void clear() {
this.cache = new HashMap<>();
}
- public Config.KASInfo get(String url, String algorithm) {
- log.debug("retrieving kasinfo for url = [{}], algorithm = [{}]", url, algorithm);
- KASKeyRequest cacheKey = new KASKeyRequest(url, algorithm);
+ public Config.KASInfo get(String url, String algorithm, String kid) {
+ log.debug("retrieving kasinfo for url = [{}], algorithm = [{}], kid = [{}]", url, algorithm, kid);
+ KASKeyRequest cacheKey = new KASKeyRequest(url, algorithm, kid);
LocalDateTime now = LocalDateTime.now();
TimeStampedKASInfo cachedValue = cache.get(cacheKey);
if (cachedValue == null) {
- log.debug("didn't find kasinfo for url = [{}], algorithm = [{}]", url, algorithm);
+ log.debug("didn't find kasinfo for key= [{}]", cacheKey);
return null;
}
@@ -49,7 +50,7 @@ public Config.KASInfo get(String url, String algorithm) {
public void store(Config.KASInfo kasInfo) {
log.debug("storing kasInfo into the cache {}", kasInfo);
- KASKeyRequest cacheKey = new KASKeyRequest(kasInfo.URL, kasInfo.Algorithm);
+ KASKeyRequest cacheKey = new KASKeyRequest(kasInfo.URL, kasInfo.Algorithm, kasInfo.KID);
cache.put(cacheKey, new TimeStampedKASInfo(kasInfo, LocalDateTime.now()));
}
}
@@ -85,30 +86,34 @@ public TimeStampedKASInfo(Config.KASInfo kasInfo, LocalDateTime timestamp) {
class KASKeyRequest {
private String url;
private String algorithm;
+ private String kid;
- public KASKeyRequest(String url, String algorithm) {
+ public KASKeyRequest(String url, String algorithm, String kid) {
this.url = url;
this.algorithm = algorithm;
+ this.kid = kid;
}
- // Override equals and hashCode to ensure proper functioning of the HashMap
@Override
public boolean equals(Object o) {
- if (this == o) return true;
- if (o == null || !(o instanceof KASKeyRequest)) return false;
+ if (o == null || getClass() != o.getClass()) return false;
KASKeyRequest that = (KASKeyRequest) o;
- if (algorithm == null){
- return url.equals(that.url);
- }
- return url.equals(that.url) && algorithm.equals(that.algorithm);
+ return Objects.equals(url, that.url) && Objects.equals(algorithm, that.algorithm) && Objects.equals(kid, that.kid);
}
@Override
public int hashCode() {
- int result = 31 * url.hashCode();
- if (algorithm != null) {
- result = result + algorithm.hashCode();
- }
- return result;
+ return Objects.hash(url, algorithm, kid);
+ }
+
+ @Override
+ public String toString() {
+ return "KASKeyRequest{" +
+ "url='" + url + '\'' +
+ ", algorithm='" + algorithm + '\'' +
+ ", kid='" + kid + '\'' +
+ '}';
}
+
+ // Override equals and hashCode to ensure proper functioning of the HashMap
}
\ No newline at end of file
diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/KeyType.java b/sdk/src/main/java/io/opentdf/platform/sdk/KeyType.java
index 9c5cf010..3b49668b 100644
--- a/sdk/src/main/java/io/opentdf/platform/sdk/KeyType.java
+++ b/sdk/src/main/java/io/opentdf/platform/sdk/KeyType.java
@@ -1,5 +1,9 @@
package io.opentdf.platform.sdk;
+import io.opentdf.platform.policy.Algorithm;
+
+import java.util.Set;
+
public enum KeyType {
RSA2048Key("rsa:2048"),
EC256Key("ec:secp256r1"),
@@ -39,7 +43,24 @@ public static KeyType fromString(String keyType) {
throw new IllegalArgumentException("No enum constant for key type: " + keyType);
}
+ public static KeyType fromAlgorithm(Algorithm a) {
+ switch (a) {
+ case ALGORITHM_RSA_2048:
+ return RSA2048Key;
+ case ALGORITHM_EC_P256:
+ return EC256Key;
+ case ALGORITHM_EC_P384:
+ return EC384Key;
+ case ALGORITHM_EC_P521:
+ return EC521Key;
+ default:
+ throw new IllegalArgumentException("Unsupported algorithm: " + a);
+ }
+ }
+
+ private static final Set EC_KEY_TYPES = Set.of(EC256Key, EC384Key, EC521Key);
+
public boolean isEc() {
- return this != RSA2048Key;
+ return EC_KEY_TYPES.contains(this);
}
}
\ No newline at end of file
diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/NanoTDF.java b/sdk/src/main/java/io/opentdf/platform/sdk/NanoTDF.java
index e0b41549..9b798962 100644
--- a/sdk/src/main/java/io/opentdf/platform/sdk/NanoTDF.java
+++ b/sdk/src/main/java/io/opentdf/platform/sdk/NanoTDF.java
@@ -16,6 +16,8 @@
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
+import io.opentdf.platform.wellknownconfiguration.WellKnownServiceClientInterface;
+import org.bouncycastle.jcajce.provider.asymmetric.ec.KeyFactorySpi;
import org.bouncycastle.jce.interfaces.ECPublicKey;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -65,6 +67,17 @@ public InvalidNanoTDFConfig(String errorMessage) {
}
}
+ private static Optional getBaseKey(WellKnownServiceClientInterface wellKnownService) {
+ var key = Planner.fetchBaseKey(wellKnownService);
+ key.ifPresent(k -> {
+ if (!KeyType.fromAlgorithm(k.getPublicKey().getAlgorithm()).isEc()) {
+ throw new SDKException(String.format("base key is not an EC key, cannot create NanoTDF using a key of type %s",
+ k.getPublicKey().getAlgorithm()));
+ }
+ });
+ return key.map(Config.KASInfo::fromSimpleKasKey);
+ }
+
private Config.HeaderInfo getHeaderInfo(Config.NanoTDFConfig nanoTDFConfig) throws InvalidNanoTDFConfig, UnsupportedNanoTDFFeature {
if (nanoTDFConfig.collectionConfig.useCollection) {
Config.HeaderInfo headerInfo = nanoTDFConfig.collectionConfig.getHeaderInfo();
@@ -74,22 +87,19 @@ private Config.HeaderInfo getHeaderInfo(Config.NanoTDFConfig nanoTDFConfig) thro
}
Gson gson = new GsonBuilder().create();
- if (nanoTDFConfig.kasInfoList.isEmpty()) {
- throw new InvalidNanoTDFConfig("kas url is missing");
- }
+ Optional maybeKas = getKasInfo(nanoTDFConfig).or(() -> NanoTDF.getBaseKey(services.wellknown()));
- Config.KASInfo kasInfo = nanoTDFConfig.kasInfoList.get(0);
- String url = kasInfo.URL;
- if (kasInfo.PublicKey == null || kasInfo.PublicKey.isEmpty()) {
- logger.info("no public key provided for KAS at {}, retrieving", url);
- kasInfo = services.kas().getECPublicKey(kasInfo, nanoTDFConfig.eccMode.getEllipticCurveType());
+ if (maybeKas.isEmpty()) {
+ throw new SDKException("no KAS info provided and couldn't get base key, cannot create NanoTDF");
}
+ var kasInfo = maybeKas.get();
+
// Kas url resource locator
- ResourceLocator kasURL = new ResourceLocator(nanoTDFConfig.kasInfoList.get(0).URL, kasInfo.KID);
+ ResourceLocator kasURL = new ResourceLocator(kasInfo.URL, kasInfo.KID);
assert kasURL.getIdentifier() != null : "Identifier in ResourceLocator cannot be null";
- ECKeyPair keyPair = new ECKeyPair(nanoTDFConfig.eccMode.getCurveName(), ECKeyPair.ECAlgorithm.ECDSA);
+ ECKeyPair keyPair = new ECKeyPair(kasInfo.Algorithm, ECKeyPair.ECAlgorithm.ECDSA);
// Generate symmetric key
ECPublicKey kasPublicKey = ECKeyPair.publicKeyFromPem(kasInfo.PublicKey);
@@ -138,7 +148,12 @@ private Config.HeaderInfo getHeaderInfo(Config.NanoTDFConfig nanoTDFConfig) thro
// Create header
byte[] compressedPubKey = keyPair.compressECPublickey();
Header header = new Header();
- header.setECCMode(nanoTDFConfig.eccMode);
+ var mode = new ECCMode();
+ mode.setEllipticCurve(Enum.valueOf(NanoTDFType.ECCurve.class, keyPair.curveName()));
+ if (logger.isWarnEnabled() && !nanoTDFConfig.eccMode.equals(mode)) {
+ logger.warn("ECC mode provided in NanoTDFConfig: {}, ECC mode from key: {}", nanoTDFConfig.eccMode, mode);
+ }
+ header.setECCMode(mode);
header.setPayloadConfig(nanoTDFConfig.config);
header.setEphemeralKey(compressedPubKey);
header.setKasLocator(kasURL);
@@ -152,6 +167,21 @@ private Config.HeaderInfo getHeaderInfo(Config.NanoTDFConfig nanoTDFConfig) thro
return headerInfo;
}
+ private Optional getKasInfo(Config.NanoTDFConfig nanoTDFConfig) {
+ if (nanoTDFConfig.kasInfoList.isEmpty()) {
+ logger.debug("no kas info provided in NanoTDFConfig");
+ return Optional.empty();
+ }
+
+ Config.KASInfo kasInfo = nanoTDFConfig.kasInfoList.get(0);
+ String url = kasInfo.URL;
+ if (kasInfo.PublicKey == null || kasInfo.PublicKey.isEmpty()) {
+ logger.info("no public key provided for KAS at {}, retrieving", url);
+ kasInfo = services.kas().getECPublicKey(kasInfo, nanoTDFConfig.eccMode.getEllipticCurveType());
+ }
+ return Optional.of(kasInfo);
+ }
+
public int createNanoTDF(ByteBuffer data, OutputStream outputStream,
Config.NanoTDFConfig nanoTDFConfig) throws SDKException, IOException {
int nanoTDFSize = 0;
diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/Planner.java b/sdk/src/main/java/io/opentdf/platform/sdk/Planner.java
new file mode 100644
index 00000000..a40a1ddd
--- /dev/null
+++ b/sdk/src/main/java/io/opentdf/platform/sdk/Planner.java
@@ -0,0 +1,186 @@
+package io.opentdf.platform.sdk;
+
+import com.connectrpc.ConnectException;
+import com.google.gson.Gson;
+import com.google.gson.JsonSyntaxException;
+import com.google.gson.annotations.SerializedName;
+import io.opentdf.platform.policy.Algorithm;
+import io.opentdf.platform.policy.SimpleKasKey;
+import io.opentdf.platform.policy.SimpleKasPublicKey;
+import io.opentdf.platform.wellknownconfiguration.GetWellKnownConfigurationRequest;
+import io.opentdf.platform.wellknownconfiguration.GetWellKnownConfigurationResponse;
+import io.opentdf.platform.wellknownconfiguration.WellKnownServiceClientInterface;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.UUID;
+import java.util.function.BiFunction;
+
+
+public class Planner {
+ private static final String BASE_KEY = "base_key";
+ private final Config.TDFConfig tdfConfig;
+ private final SDK.Services services;
+ private final BiFunction granterFactory;
+
+
+ private static final Logger logger = LoggerFactory.getLogger(Planner.class);
+
+ public Planner(Config.TDFConfig config, SDK.Services services, BiFunction granterFactory) {
+ this.tdfConfig = Objects.requireNonNull(config);
+ this.services = Objects.requireNonNull(services);
+ this.granterFactory = granterFactory;
+ }
+
+ private static String getUUID() {
+ return UUID.randomUUID().toString();
+ }
+
+ Map> getSplits(Config.TDFConfig tdfConfig) {
+ List splitPlan;
+ if (tdfConfig.autoconfigure) {
+ if (tdfConfig.splitPlan != null && !tdfConfig.splitPlan.isEmpty()) {
+ throw new IllegalArgumentException("cannot use autoconfigure with a split plan provided in the TDFConfig");
+ }
+ splitPlan = getAutoconfigurePlan(services, tdfConfig);
+ } else if (tdfConfig.splitPlan == null || tdfConfig.splitPlan.isEmpty()) {
+ splitPlan = generatePlanFromProvidedKases(tdfConfig.kasInfoList);
+ } else {
+ splitPlan = tdfConfig.splitPlan;
+ }
+
+ if (tdfConfig.kasInfoList.isEmpty() && splitPlan.isEmpty()) {
+ throw new SDK.KasInfoMissing("kas information is missing, no key access template specified or inferred");
+ }
+ return resolveKeys(splitPlan);
+ }
+
+ private List getAutoconfigurePlan(SDK.Services services, Config.TDFConfig tdfConfig) {
+ Autoconfigure.Granter granter = granterFactory.apply(services, tdfConfig);
+ return granter.getSplits(defaultKases(tdfConfig), Planner::getUUID, () -> Planner.fetchBaseKey(services.wellknown()));
+ }
+
+
+ List generatePlanFromProvidedKases(List kases) {
+ if (kases.size() == 1) {
+ var kasInfo = kases.get(0);
+ return Collections.singletonList(new Autoconfigure.KeySplitStep(kasInfo.URL, "", kasInfo.KID));
+ }
+ List splitPlan = new ArrayList<>();
+ for (var kasInfo : kases) {
+ splitPlan.add(new Autoconfigure.KeySplitStep(kasInfo.URL, getUUID(), kasInfo.KID));
+ }
+ return splitPlan;
+ }
+
+ static Optional fetchBaseKey(WellKnownServiceClientInterface wellknown) {
+ var responseMessage = wellknown
+ .getWellKnownConfigurationBlocking(GetWellKnownConfigurationRequest.getDefaultInstance(), Collections.emptyMap())
+ .execute();
+ GetWellKnownConfigurationResponse response;
+ try {
+ response = RequestHelper.getOrThrow(responseMessage);
+ } catch (ConnectException e) {
+ throw new SDKException("unable to retrieve base key from well known endpoint", e);
+ }
+
+ String baseKeyJson;
+ try {
+ baseKeyJson = response
+ .getConfiguration()
+ .getFieldsOrThrow(BASE_KEY)
+ .getStringValue();
+ } catch (IllegalArgumentException e) {
+ logger.info( "no `" + BASE_KEY + "` found in well known configuration.", e);
+ return Optional.empty();
+ }
+
+ BaseKey baseKey;
+ try {
+ baseKey = gson.fromJson(baseKeyJson, BaseKey.class);
+ } catch (JsonSyntaxException e) {
+ throw new SDKException("base key in well known configuration is malformed [" + baseKeyJson + "]", e);
+ }
+
+ if (baseKey == null || baseKey.kasUrl == null || baseKey.publicKey == null || baseKey.publicKey.kid == null || baseKey.publicKey.pem == null || baseKey.publicKey.algorithm == null) {
+ logger.error("base key in well known configuration is missing required fields [{}]. base key will not be used", baseKeyJson);
+ return Optional.empty();
+ }
+
+ return Optional.of(SimpleKasKey.newBuilder()
+ .setKasUri(baseKey.kasUrl)
+ .setPublicKey(
+ SimpleKasPublicKey.newBuilder()
+ .setKid(baseKey.publicKey.kid)
+ .setAlgorithm(baseKey.publicKey.algorithm)
+ .setPem(baseKey.publicKey.pem)
+ .build())
+ .build());
+ }
+
+ private static final Gson gson = new Gson();
+
+ private static class BaseKey {
+ @SerializedName("kas_url")
+ String kasUrl;
+
+ @SerializedName("public_key")
+ Key publicKey;
+
+ private static class Key {
+ String kid;
+ String pem;
+ Algorithm algorithm;
+ }
+ }
+
+ Map> resolveKeys(List splitPlan) {
+ Map> conjunction = new HashMap<>();
+ var latestKASInfo = new HashMap();
+ // Seed anything passed in manually
+ for (Config.KASInfo kasInfo : tdfConfig.kasInfoList) {
+ if (kasInfo.PublicKey != null && !kasInfo.PublicKey.isEmpty()) {
+ latestKASInfo.put(kasInfo.URL, kasInfo);
+ }
+ }
+
+ for (Autoconfigure.KeySplitStep splitInfo: splitPlan) {
+ // Public key was passed in with kasInfoList
+ // TODO First look up in attribute information / add to split plan?
+ Config.KASInfo ki = latestKASInfo.get(splitInfo.kas);
+ if (ki == null || ki.PublicKey == null || ki.PublicKey.isBlank() || (splitInfo.kid != null && !splitInfo.kid.equals(ki.KID))) {
+ logger.info("no public key provided for KAS at {}, retrieving", splitInfo.kas);
+ var getKI = new Config.KASInfo();
+ getKI.URL = splitInfo.kas;
+ getKI.Algorithm = tdfConfig.wrappingKeyType.toString();
+ getKI.KID = splitInfo.kid;
+ getKI = services.kas().getPublicKey(getKI);
+ latestKASInfo.put(splitInfo.kas, getKI);
+ ki = getKI;
+ }
+ conjunction.computeIfAbsent(splitInfo.splitID, s -> new ArrayList<>()).add(ki);
+ }
+ return conjunction;
+ }
+
+ static List defaultKases(Config.TDFConfig config) {
+ List allk = new ArrayList<>();
+ List defk = new ArrayList<>();
+
+ for (Config.KASInfo kasInfo : config.kasInfoList) {
+ if (kasInfo.Default != null && kasInfo.Default) {
+ defk.add(kasInfo.URL);
+ } else if (defk.isEmpty()) {
+ allk.add(kasInfo.URL);
+ }
+ }
+ return defk.isEmpty() ? allk : defk;
+ }
+}
diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/SDK.java b/sdk/src/main/java/io/opentdf/platform/sdk/SDK.java
index 9195fca5..a0db542d 100644
--- a/sdk/src/main/java/io/opentdf/platform/sdk/SDK.java
+++ b/sdk/src/main/java/io/opentdf/platform/sdk/SDK.java
@@ -9,6 +9,7 @@
import io.opentdf.platform.policy.namespaces.NamespaceServiceClientInterface;
import io.opentdf.platform.policy.resourcemapping.ResourceMappingServiceClientInterface;
import io.opentdf.platform.policy.subjectmapping.SubjectMappingServiceClientInterface;
+import io.opentdf.platform.wellknownconfiguration.WellKnownServiceClientInterface;
import javax.net.ssl.TrustManager;
import java.io.IOException;
@@ -75,6 +76,8 @@ public interface Services extends AutoCloseable {
KeyAccessServerRegistryServiceClientInterface kasRegistry();
+ WellKnownServiceClientInterface wellknown();
+
KAS kas();
}
diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/SDKBuilder.java b/sdk/src/main/java/io/opentdf/platform/sdk/SDKBuilder.java
index dd83f75a..03b1943c 100644
--- a/sdk/src/main/java/io/opentdf/platform/sdk/SDKBuilder.java
+++ b/sdk/src/main/java/io/opentdf/platform/sdk/SDKBuilder.java
@@ -33,6 +33,7 @@
import io.opentdf.platform.wellknownconfiguration.GetWellKnownConfigurationRequest;
import io.opentdf.platform.wellknownconfiguration.GetWellKnownConfigurationResponse;
import io.opentdf.platform.wellknownconfiguration.WellKnownServiceClient;
+import io.opentdf.platform.wellknownconfiguration.WellKnownServiceClientInterface;
import nl.altindag.ssl.SSLFactory;
import nl.altindag.ssl.pem.util.PemUtils;
import okhttp3.OkHttpClient;
@@ -251,6 +252,7 @@ ServicesAndInternals buildServices() {
var resourceMappingService = new ResourceMappingServiceClient(client);
var authorizationService = new AuthorizationServiceClient(client);
var kasRegistryService = new KeyAccessServerRegistryServiceClient(client);
+ var wellKnownService = new WellKnownServiceClient(client);
var services = new SDK.Services() {
@Override
@@ -290,6 +292,11 @@ public KeyAccessServerRegistryServiceClient kasRegistry() {
return kasRegistryService;
}
+ @Override
+ public WellKnownServiceClientInterface wellknown() {
+ return wellKnownService;
+ }
+
@Override
public SDK.KAS kas() {
return kasClient;
diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/TDF.java b/sdk/src/main/java/io/opentdf/platform/sdk/TDF.java
index fe547caa..cfe75d26 100644
--- a/sdk/src/main/java/io/opentdf/platform/sdk/TDF.java
+++ b/sdk/src/main/java/io/opentdf/platform/sdk/TDF.java
@@ -5,11 +5,8 @@
import com.google.gson.GsonBuilder;
import com.nimbusds.jose.*;
-import io.opentdf.platform.policy.Value;
import io.opentdf.platform.policy.kasregistry.ListKeyAccessServersRequest;
import io.opentdf.platform.policy.kasregistry.ListKeyAccessServersResponse;
-import io.opentdf.platform.sdk.Config.TDFConfig;
-import io.opentdf.platform.sdk.Autoconfigure.AttributeValueFQN;
import io.opentdf.platform.sdk.Config.KASInfo;
import org.apache.commons.codec.DecoderException;
@@ -142,7 +139,7 @@ private PolicyObject createPolicyObject(List at
private static final Base64.Encoder encoder = Base64.getEncoder();
- private void prepareManifest(Config.TDFConfig tdfConfig, SDK.KAS kas) {
+ private void prepareManifest(Config.TDFConfig tdfConfig, Map> splits) {
manifest.tdfVersion = tdfConfig.renderVersionInfoInManifest ? TDF_VERSION : null;
manifest.encryptionInformation.keyAccessType = kSplitKeyType;
manifest.encryptionInformation.keyAccessObj = new ArrayList<>();
@@ -150,60 +147,12 @@ private void prepareManifest(Config.TDFConfig tdfConfig, SDK.KAS kas) {
PolicyObject policyObject = createPolicyObject(tdfConfig.attributes);
String base64PolicyObject = encoder
.encodeToString(gson.toJson(policyObject).getBytes(StandardCharsets.UTF_8));
- Map latestKASInfo = new HashMap<>();
- if (tdfConfig.splitPlan == null || tdfConfig.splitPlan.isEmpty()) {
- // Default split plan: Split keys across all KASes
- List splitPlan = new ArrayList<>(tdfConfig.kasInfoList.size());
- int i = 0;
- for (Config.KASInfo kasInfo : tdfConfig.kasInfoList) {
- Autoconfigure.KeySplitStep step = new Autoconfigure.KeySplitStep(kasInfo.URL, "");
- if (tdfConfig.kasInfoList.size() > 1) {
- step.splitID = String.format("s-%d", i++);
- }
- splitPlan.add(step);
- if (kasInfo.PublicKey != null && !kasInfo.PublicKey.isEmpty()) {
- latestKASInfo.put(kasInfo.URL, kasInfo);
- }
- }
- tdfConfig.splitPlan = splitPlan;
- }
- // Seed anything passed in manually
- for (Config.KASInfo kasInfo : tdfConfig.kasInfoList) {
- if (kasInfo.PublicKey != null && !kasInfo.PublicKey.isEmpty()) {
- latestKASInfo.put(kasInfo.URL, kasInfo);
- }
- }
- // split plan: restructure by conjunctions
- Map> conjunction = new HashMap<>();
- List splitIDs = new ArrayList<>();
-
- for (Autoconfigure.KeySplitStep splitInfo : tdfConfig.splitPlan) {
- // Public key was passed in with kasInfoList
- // TODO First look up in attribute information / add to split plan?
- Config.KASInfo ki = latestKASInfo.get(splitInfo.kas);
- if (ki == null || ki.PublicKey == null || ki.PublicKey.isBlank()) {
- logger.info("no public key provided for KAS at {}, retrieving", splitInfo.kas);
- var getKI = new Config.KASInfo();
- getKI.URL = splitInfo.kas;
- getKI.Algorithm = tdfConfig.wrappingKeyType.toString();
- getKI = kas.getPublicKey(getKI);
- latestKASInfo.put(splitInfo.kas, getKI);
- ki = getKI;
- }
- if (conjunction.containsKey(splitInfo.splitID)) {
- conjunction.get(splitInfo.splitID).add(ki);
- } else {
- List newList = new ArrayList<>();
- newList.add(ki);
- conjunction.put(splitInfo.splitID, newList);
- splitIDs.add(splitInfo.splitID);
- }
- }
+ List symKeys = new ArrayList<>(splits.size());
+ for (var split : splits.entrySet()) {
+ String splitID = split.getKey();
- List symKeys = new ArrayList<>(splitIDs.size());
- for (String splitID : splitIDs) {
// Symmetric key
byte[] symKey = new byte[GCM_KEY_SIZE];
sRandom.nextBytes(symKey);
@@ -230,7 +179,8 @@ private void prepareManifest(Config.TDFConfig tdfConfig, SDK.KAS kas) {
encryptedMetadata = encoder.encodeToString(metadata.getBytes(StandardCharsets.UTF_8));
}
- for (Config.KASInfo kasInfo : conjunction.get(splitID)) {
+ List kasInfos = split.getValue();
+ for (Config.KASInfo kasInfo : kasInfos) {
if (kasInfo.PublicKey == null || kasInfo.PublicKey.isEmpty()) {
throw new SDK.KasPublicKeyMissing("Kas public key is missing in kas information list");
}
@@ -301,7 +251,6 @@ private String createRSAWrappedKey(Config.KASInfo kasInfo, byte[] symKey) {
}
}
-
private static final Base64.Decoder decoder = Base64.getDecoder();
public static class Reader {
@@ -326,7 +275,6 @@ public Manifest getManifest() {
this.aesGcm = new AesGcm(payloadKey);
this.payloadKey = payloadKey;
this.unencryptedMetadata = unencryptedMetadata;
-
}
public void readPayload(OutputStream outputStream) throws SDK.SegmentSignatureMismatch, IOException {
@@ -400,35 +348,11 @@ private static byte[] calculateSignature(byte[] data, byte[] secret, Config.Inte
}
TDFObject createTDF(InputStream payload, OutputStream outputStream, Config.TDFConfig tdfConfig) throws SDKException, IOException {
-
- if (tdfConfig.autoconfigure) {
- Autoconfigure.Granter granter = new Autoconfigure.Granter(new ArrayList<>());
- if (tdfConfig.attributeValues != null && !tdfConfig.attributeValues.isEmpty()) {
- granter = Autoconfigure.newGranterFromAttributes(tdfConfig.attributeValues.toArray(new Value[0]));
- } else if (tdfConfig.attributes != null && !tdfConfig.attributes.isEmpty()) {
- granter = Autoconfigure.newGranterFromService(services.attributes(), services.kas().getKeyCache(),
- tdfConfig.attributes.toArray(new AttributeValueFQN[0]));
- }
-
- if (granter == null) {
- throw new AutoConfigureException("Failed to create Granter"); // Replace with appropriate error handling
- }
-
- List dk = defaultKases(tdfConfig);
- tdfConfig.splitPlan = granter.plan(dk, () -> UUID.randomUUID().toString());
-
- if (tdfConfig.splitPlan == null) {
- throw new AutoConfigureException("Failed to generate Split Plan"); // Replace with appropriate error
- // handling
- }
- }
-
- if (tdfConfig.kasInfoList.isEmpty() && (tdfConfig.splitPlan == null || tdfConfig.splitPlan.isEmpty())) {
- throw new SDK.KasInfoMissing("kas information is missing, no key access template specified or inferred");
- }
+ Planner planner = new Planner(tdfConfig, services, Autoconfigure::createGranter);
+ Map> splits = planner.getSplits(tdfConfig);
TDFObject tdfObject = new TDFObject();
- tdfObject.prepareManifest(tdfConfig, services.kas());
+ tdfObject.prepareManifest(tdfConfig, splits);
long encryptedSegmentSize = tdfConfig.defaultSegmentSize + kGcmIvSize + kAesBlockSize;
TDFWriter tdfWriter = new TDFWriter(outputStream);
@@ -561,22 +485,6 @@ TDFObject createTDF(InputStream payload, OutputStream outputStream, Config.TDFCo
return tdfObject;
}
- static List defaultKases(TDFConfig config) {
- List allk = new ArrayList<>();
- List defk = new ArrayList<>();
-
- for (KASInfo kasInfo : config.kasInfoList) {
- if (kasInfo.Default != null && kasInfo.Default) {
- defk.add(kasInfo.URL);
- } else if (defk.isEmpty()) {
- allk.add(kasInfo.URL);
- }
- }
- if (defk.isEmpty()) {
- return allk;
- }
- return defk;
- }
Reader loadTDF(SeekableByteChannel tdf, String platformUrl) throws SDKException, IOException {
return loadTDF(tdf, Config.newTDFReaderConfig(), platformUrl);
diff --git a/sdk/src/test/java/io/opentdf/platform/sdk/AutoconfigureTest.java b/sdk/src/test/java/io/opentdf/platform/sdk/AutoconfigureTest.java
index 59cd0912..d37f2727 100644
--- a/sdk/src/test/java/io/opentdf/platform/sdk/AutoconfigureTest.java
+++ b/sdk/src/test/java/io/opentdf/platform/sdk/AutoconfigureTest.java
@@ -1,18 +1,8 @@
package io.opentdf.platform.sdk;
-import static org.assertj.core.api.Assertions.assertThat;
-import static org.assertj.core.api.Assertions.assertThatThrownBy;
-import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
-import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.assertNotNull;
-import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.never;
-import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.when;
-
import com.connectrpc.ResponseMessage;
import com.connectrpc.UnaryBlockingCall;
+import io.opentdf.platform.policy.Algorithm;
import io.opentdf.platform.policy.Attribute;
import io.opentdf.platform.policy.AttributeRuleTypeEnum;
import io.opentdf.platform.policy.KasPublicKey;
@@ -21,16 +11,17 @@
import io.opentdf.platform.policy.KeyAccessServer;
import io.opentdf.platform.policy.Namespace;
import io.opentdf.platform.policy.PublicKey;
+import io.opentdf.platform.policy.SimpleKasKey;
+import io.opentdf.platform.policy.SimpleKasPublicKey;
import io.opentdf.platform.policy.Value;
import io.opentdf.platform.policy.attributes.AttributesServiceClient;
import io.opentdf.platform.policy.attributes.GetAttributeValuesByFqnsRequest;
import io.opentdf.platform.policy.attributes.GetAttributeValuesByFqnsResponse;
import io.opentdf.platform.sdk.Autoconfigure.AttributeValueFQN;
+import io.opentdf.platform.sdk.Autoconfigure.Granter;
import io.opentdf.platform.sdk.Autoconfigure.Granter.AttributeBooleanExpression;
import io.opentdf.platform.sdk.Autoconfigure.Granter.BooleanKeyExpression;
import io.opentdf.platform.sdk.Autoconfigure.KeySplitStep;
-import io.opentdf.platform.sdk.Autoconfigure.Granter;
-
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
@@ -39,10 +30,26 @@
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
+import java.util.Objects;
+import java.util.Optional;
import java.util.Set;
-import java.util.stream.Collectors;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.BiFunction;
+import java.util.function.Supplier;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
+import java.util.stream.Collectors;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
public class AutoconfigureTest {
@@ -56,6 +63,16 @@ public class AutoconfigureTest {
public static final String SPECIFIED_KAS = "https://attr.kas.com/";
public static final String EVEN_MORE_SPECIFIC_KAS = "https://value.kas.com/";
private static final String NAMESPACE_KAS = "https://namespace.kas.com/";
+ private static final SimpleKasKey NAMESPACE_KAS_KEY = SimpleKasKey.newBuilder().setKasUri("https://mapped.example.com").setKasId("mapped").setPublicKey(
+ SimpleKasPublicKey.newBuilder().setAlgorithm(Algorithm.ALGORITHM_EC_P521).setPem("namespacekey").setKid("namespacekeykid").build()
+ ).build();
+ private static final SimpleKasKey ATTRIBUTE_KEY = SimpleKasKey.newBuilder().setKasUri("https://mapped.example.com").setKasId("mapped").setPublicKey(
+ SimpleKasPublicKey.newBuilder().setAlgorithm(Algorithm.ALGORITHM_EC_P521).setPem("attrpem").setKid("attrkeykid").build()
+ ).build();
+ private static final SimpleKasKey VALUE_KEY = SimpleKasKey.newBuilder().setKasUri("https://mapped.example.com").setKasId("mapped").setPublicKey(
+ SimpleKasPublicKey.newBuilder().setAlgorithm(Algorithm.ALGORITHM_EC_P521).setPem("valuepem").setKid("valuekeykid").build()
+ ).build();
+ private static Autoconfigure.AttributeNameFQN UNMAPPED;
private static Autoconfigure.AttributeNameFQN SPKSPECKED;
private static Autoconfigure.AttributeNameFQN SPKUNSPECKED;
@@ -64,6 +81,8 @@ public class AutoconfigureTest {
private static Autoconfigure.AttributeNameFQN REL;
private static Autoconfigure.AttributeNameFQN UNSPECKED;
private static Autoconfigure.AttributeNameFQN SPECKED;
+ private static Autoconfigure.AttributeNameFQN MAPPED;
+ private static Autoconfigure.AttributeNameFQN SPKMAPPED;
private static Autoconfigure.AttributeValueFQN clsA;
private static Autoconfigure.AttributeValueFQN clsS;
@@ -84,6 +103,9 @@ public class AutoconfigureTest {
private static Autoconfigure.AttributeValueFQN spk2spk2uns;
private static Autoconfigure.AttributeValueFQN spk2spk2spk;
+ private static Autoconfigure.AttributeValueFQN mp2uns2uns;
+ private static Autoconfigure.AttributeValueFQN mp2uns2mp;
+
@BeforeAll
public static void setup() throws AutoConfigureException {
// Initialize the FQNs (Fully Qualified Names)
@@ -92,8 +114,11 @@ public static void setup() throws AutoConfigureException {
REL = new Autoconfigure.AttributeNameFQN("https://virtru.com/attr/Releasable%20To");
UNSPECKED = new Autoconfigure.AttributeNameFQN("https://other.com/attr/unspecified");
SPECKED = new Autoconfigure.AttributeNameFQN("https://other.com/attr/specified");
+ MAPPED = new Autoconfigure.AttributeNameFQN("https://other.com/attr/mapped");
+ UNMAPPED = new Autoconfigure.AttributeNameFQN("https://mapped.com/attr/unspecified");
SPKUNSPECKED = new Autoconfigure.AttributeNameFQN("https://hasgrants.com/attr/unspecified");
SPKSPECKED = new Autoconfigure.AttributeNameFQN("https://hasgrants.com/attr/specified");
+ SPKMAPPED = new Autoconfigure.AttributeNameFQN("https://hasgrants.com/attr/mapped");
clsA = new Autoconfigure.AttributeValueFQN("https://virtru.com/attr/Classification/value/Allowed");
clsS = new Autoconfigure.AttributeValueFQN("https://virtru.com/attr/Classification/value/Secret");
@@ -117,6 +142,9 @@ public static void setup() throws AutoConfigureException {
spk2uns2spk = new Autoconfigure.AttributeValueFQN("https://hasgrants.com/attr/unspecified/value/specked");
spk2spk2uns = new Autoconfigure.AttributeValueFQN("https://hasgrants.com/attr/specified/value/unspecked");
spk2spk2spk = new Autoconfigure.AttributeValueFQN("https://hasgrants.com/attr/specified/value/specked");
+
+ mp2uns2uns = new Autoconfigure.AttributeValueFQN("https://mapped.com/attr/unspecified/value/unspecked");
+ mp2uns2mp = new Autoconfigure.AttributeValueFQN("https://mapped.com/attr/unspecified/value/mapped");
}
private static String spongeCase(String s) {
@@ -181,6 +209,7 @@ private Attribute mockAttributeFor(Autoconfigure.AttributeNameFQN fqn) {
Namespace ns1 = Namespace.newBuilder().setId("v").setName("virtru.com").setFqn("https://virtru.com").build();
Namespace ns2 = Namespace.newBuilder().setId("o").setName("other.com").setFqn("https://other.com").build();
Namespace ns3 = Namespace.newBuilder().setId("h").setName("hasgrants.com").addGrants(KeyAccessServer.newBuilder().setUri(NAMESPACE_KAS).build()).setFqn("https://hasgrants.com").build();
+ Namespace ns4 = Namespace.newBuilder().setId("m").setName("mapped.com").addKasKeys(NAMESPACE_KAS_KEY).build();
String key = fqn.getKey();
if (key.equals(CLS.getKey())) {
@@ -216,6 +245,17 @@ private Attribute mockAttributeFor(Autoconfigure.AttributeNameFQN fqn) {
.setName("unspecified").setRule(AttributeRuleTypeEnum.ATTRIBUTE_RULE_TYPE_ENUM_ANY_OF)
.setName(fqn.toString())
.build();
+ } else if (key.equals(MAPPED.getKey())) {
+ return Attribute.newBuilder().setId(MAPPED.getKey()).setNamespace(ns4)
+ .setName("mapped attribute").setRule(AttributeRuleTypeEnum.ATTRIBUTE_RULE_TYPE_ENUM_ANY_OF)
+ .setKasKeys(0, ATTRIBUTE_KEY)
+ .setName(fqn.toString())
+ .build();
+ } else if (key.equals(UNMAPPED.getKey())) {
+ return Attribute.newBuilder().setId(UNMAPPED.getKey()).setNamespace(ns4)
+ .setName("unmapped attribute").setRule(AttributeRuleTypeEnum.ATTRIBUTE_RULE_TYPE_ENUM_ANY_OF)
+ .setName(fqn.toString())
+ .build();
}
throw new IllegalArgumentException("Key not recognized: " + key);
@@ -287,7 +327,13 @@ private Value mockValueFor(Autoconfigure.AttributeValueFQN fqn) throws AutoConfi
p = p.toBuilder().addGrants(KeyAccessServer.newBuilder().setUri(EVEN_MORE_SPECIFIC_KAS).build())
.build();
}
+ } else if (Objects.equals(UNMAPPED.getKey(), an.getKey())) {
+ if (fqn.value().equalsIgnoreCase("mapped")) {
+ p = p.toBuilder().addKasKeys(VALUE_KEY)
+ .build();
+ }
}
+
return p;
}
@@ -368,7 +414,7 @@ public void testConfigurationServicePutGet() {
for (ConfigurationTestCase tc : testCases) {
assertDoesNotThrow(() -> {
List v = valuesToPolicy(tc.getPolicy().toArray(new AttributeValueFQN[0]));
- Granter grants = Autoconfigure.newGranterFromAttributes(v.toArray(new Value[0]));
+ Granter grants = Autoconfigure.newGranterFromAttributes(null, v.toArray(new Value[0]));
assertThat(grants).isNotNull();
assertThat(grants.getGrants()).hasSize(tc.getSize());
assertThat(policyToStringKeys(tc.getPolicy())).containsAll(grants.getGrants().keySet());
@@ -451,7 +497,7 @@ public void testReasonerConstructAttributeBoolean() {
new KeySplitStep(KAS_US_HCS, "2"), new KeySplitStep(KAS_US_SA, "3"))));
for (ReasonerTestCase tc : testCases) {
- Granter reasoner = Autoconfigure.newGranterFromAttributes(
+ Granter reasoner = Autoconfigure.newGranterFromAttributes(null,
valuesToPolicy(tc.getPolicy().toArray(new AttributeValueFQN[0])).toArray(new Value[0]));
assertThat(reasoner).isNotNull();
@@ -467,15 +513,33 @@ public void testReasonerConstructAttributeBoolean() {
var wrapper = new Object() {
int i = 0;
};
- List plan = reasoner.plan(tc.getDefaults(), () -> {
- return String.valueOf(wrapper.i++ + 1);
- }
-
- );
- assertThat(plan).isEqualTo(tc.getPlan());
+ List plan = reasoner.getSplits(tc.getDefaults(), () -> String.valueOf(wrapper.i++ + 1), Optional::empty);
+ assertThat(plan)
+ .as(tc.name)
+ .isEqualTo(tc.getPlan());
}
}
+ @Test
+ void testUsingAttributeMappedAtNamespace() {
+ Granter granter = Autoconfigure.newGranterFromAttributes(new KASKeyCache(), mockValueFor(mp2uns2uns));
+ var counter = new AtomicInteger(0);
+ var splitPlan = granter.getSplits(Collections.emptyList(), () -> Integer.toString(counter.getAndIncrement()), Optional::empty);
+ assertThat(splitPlan).isEqualTo(List.of(new KeySplitStep("https://mapped.example.com", "", NAMESPACE_KAS_KEY.getPublicKey().getKid())));
+ }
+
+ @Test
+ void testUsingAttributeMappedAtMultiplePlaces() {
+ var attributes = new Value[]{mockValueFor(mp2uns2uns), mockValueFor(mp2uns2mp)};
+ Granter granter = Autoconfigure.newGranterFromAttributes(new KASKeyCache(), attributes);
+ var counter = new AtomicInteger(0);
+ var splitPlan = granter.getSplits(Collections.emptyList(), () -> Integer.toString(counter.getAndIncrement()), Optional::empty);
+ assertThat(splitPlan).isEqualTo(List.of(
+ new KeySplitStep(NAMESPACE_KAS_KEY.getKasUri(), "0", NAMESPACE_KAS_KEY.getPublicKey().getKid()),
+ new KeySplitStep(VALUE_KEY.getKasUri(), "0", VALUE_KEY.getPublicKey().getKid())
+ ));
+ }
+
GetAttributeValuesByFqnsResponse getResponse(GetAttributeValuesByFqnsRequest req) {
GetAttributeValuesByFqnsResponse.Builder builder = GetAttributeValuesByFqnsResponse.newBuilder();
@@ -547,7 +611,7 @@ public void testReasonerSpecificity() {
List.of(KAS_US),
List.of(new KeySplitStep(EVEN_MORE_SPECIFIC_KAS, ""))),
new ReasonerTestCase(
- "uns.uns & uns.spk => spk",
+ "uns.uns & spk.spk => spk",
List.of(uns2uns, spk2spk),
List.of(KAS_US),
List.of(new KeySplitStep(EVEN_MORE_SPECIFIC_KAS, ""))),
@@ -598,12 +662,11 @@ public void cancel() {
var wrapper = new Object() {
int i = 0;
};
- List plan = reasoner.plan(tc.getDefaults(), () -> {
- return String.valueOf(wrapper.i++ + 1);
- }
+ List plan = reasoner.getSplits(tc.getDefaults(), () -> String.valueOf(wrapper.i++ + 1), Optional::empty);
+ assertThat(plan)
+ .as(tc.name)
+ .hasSameElementsAs(tc.getPlan());
- );
- assertThat(plan).hasSameElementsAs(tc.getPlan());
}
}
@@ -691,7 +754,7 @@ private static class ReasonerTestCase {
private final List plan;
ReasonerTestCase(String name, List policy, List defaults, String ats, String keyed,
- String reduced, List plan) {
+ String reduced, List plan) {
this.name = name;
this.policy = policy;
this.defaults = defaults;
@@ -744,13 +807,11 @@ public List getPlan() {
void testStoreKeysToCache_NoKeys() {
KASKeyCache keyCache = Mockito.mock(KASKeyCache.class);
KeyAccessServer kas1 = KeyAccessServer.newBuilder().setPublicKey(
- PublicKey.newBuilder().setCached(
- KasPublicKeySet.newBuilder()))
+ PublicKey.newBuilder().setCached(
+ KasPublicKeySet.newBuilder()))
.build();
- List kases = List.of(kas1);
-
- Autoconfigure.storeKeysToCache(kases, keyCache);
+ Autoconfigure.storeKeysToCache(List.of(kas1), Collections.emptyList(), keyCache);
verify(keyCache, never()).store(any(Config.KASInfo.class));
}
@@ -780,14 +841,11 @@ void testStoreKeysToCache_WithKeys() {
.setUri("https://example.com/kas")
.build();
- // Add the KeyAccessServer to a list
- List kases = List.of(kas1);
-
// Call the method under test
- Autoconfigure.storeKeysToCache(kases, keyCache);
+ Autoconfigure.storeKeysToCache(List.of(kas1), Collections.emptyList(), keyCache);
// Verify that the key was stored in the cache
- Config.KASInfo storedKASInfo = keyCache.get("https://example.com/kas", "ec:secp256r1");
+ Config.KASInfo storedKASInfo = keyCache.get("https://example.com/kas", "ec:secp256r1", "test-kid");
assertNotNull(storedKASInfo);
assertEquals("https://example.com/kas", storedKASInfo.URL);
assertEquals("test-kid", storedKASInfo.KID);
@@ -826,21 +884,18 @@ void testStoreKeysToCache_MultipleKasEntries() {
.setUri("https://example.com/kas")
.build();
- // Add the KeyAccessServer to a list
- List kases = List.of(kas1);
-
// Call the method under test
- Autoconfigure.storeKeysToCache(kases, keyCache);
+ Autoconfigure.storeKeysToCache(List.of(kas1), Collections.emptyList(), keyCache);
// Verify that the key was stored in the cache
- Config.KASInfo storedKASInfo = keyCache.get("https://example.com/kas", "ec:secp256r1");
+ Config.KASInfo storedKASInfo = keyCache.get("https://example.com/kas", "ec:secp256r1", "test-kid");
assertNotNull(storedKASInfo);
assertEquals("https://example.com/kas", storedKASInfo.URL);
assertEquals("test-kid", storedKASInfo.KID);
assertEquals("ec:secp256r1", storedKASInfo.Algorithm);
assertEquals("public-key-pem", storedKASInfo.PublicKey);
- Config.KASInfo storedKASInfo2 = keyCache.get("https://example.com/kas", "rsa:2048");
+ Config.KASInfo storedKASInfo2 = keyCache.get("https://example.com/kas", "rsa:2048", "test-kid-2");
assertNotNull(storedKASInfo2);
assertEquals("https://example.com/kas", storedKASInfo2.URL);
assertEquals("test-kid-2", storedKASInfo2.KID);
@@ -849,7 +904,7 @@ void testStoreKeysToCache_MultipleKasEntries() {
}
GetAttributeValuesByFqnsResponse getResponseWithGrants(GetAttributeValuesByFqnsRequest req,
- List grants) {
+ List grants) {
GetAttributeValuesByFqnsResponse.Builder builder = GetAttributeValuesByFqnsResponse.newBuilder();
for (String v : req.getFqnsList()) {
@@ -901,13 +956,15 @@ void testKeyCacheFromGrants() {
AttributesServiceClient attributesServiceClient = mock(AttributesServiceClient.class);
when(attributesServiceClient.getAttributeValuesByFqnsBlocking(any(), any())).thenAnswer(invocation -> {
- var request = (GetAttributeValuesByFqnsRequest)invocation.getArgument(0);
- return new UnaryBlockingCall(){
+ var request = (GetAttributeValuesByFqnsRequest) invocation.getArgument(0);
+ return new UnaryBlockingCall() {
@Override
public ResponseMessage execute() {
return new ResponseMessage.Success<>(getResponseWithGrants(request, List.of(kas1)), Collections.emptyMap(), Collections.emptyMap());
}
- @Override public void cancel() {
+
+ @Override
+ public void cancel() {
// not really calling anything
}
};
@@ -920,14 +977,14 @@ public ResponseMessage execute() {
assertThat(reasoner).isNotNull();
// Verify that the key was stored in the cache
- Config.KASInfo storedKASInfo = keyCache.get("https://example.com/kas", "ec:secp256r1");
+ Config.KASInfo storedKASInfo = keyCache.get("https://example.com/kas", "ec:secp256r1", "test-kid");
assertNotNull(storedKASInfo);
assertEquals("https://example.com/kas", storedKASInfo.URL);
assertEquals("test-kid", storedKASInfo.KID);
assertEquals("ec:secp256r1", storedKASInfo.Algorithm);
assertEquals("public-key-pem", storedKASInfo.PublicKey);
- Config.KASInfo storedKASInfo2 = keyCache.get("https://example.com/kas", "rsa:2048");
+ Config.KASInfo storedKASInfo2 = keyCache.get("https://example.com/kas", "rsa:2048", "test-kid-2");
assertNotNull(storedKASInfo2);
assertEquals("https://example.com/kas", storedKASInfo2.URL);
assertEquals("test-kid-2", storedKASInfo2.KID);
@@ -935,4 +992,155 @@ public ResponseMessage execute() {
assertEquals("public-key-pem-2", storedKASInfo2.PublicKey);
}
+ @Test
+ void testUsingBaseKeyWhenNoMappedKeysOrGrants() {
+ Autoconfigure.Granter granter = Autoconfigure.newGranterFromAttributes(null);
+ SimpleKasKey key = SimpleKasKey.newBuilder()
+ .setKasUri("https://example.com/kas")
+ .setPublicKey(
+ SimpleKasPublicKey.newBuilder()
+ .setKid("thenewkid")
+ .setPem("anotherpem")
+ .setAlgorithm(Algorithm.ALGORITHM_EC_P521)
+ ).build();
+
+ var splits = granter.getSplits(
+ List.of("https://example.org/kas2"),
+ () -> {
+ throw new IllegalStateException("the plan should have a single element");
+ },
+ () -> Optional.of(key));
+ assertThat(splits).hasSize(1);
+ assertThat(splits.get(0)).isEqualTo(new KeySplitStep("https://example.com/kas", "", "thenewkid"));
+ }
+
+ @Test
+ void testUsingDefaultKasesWhenNothingElseProvided() {
+ Autoconfigure.Granter granter = Autoconfigure.newGranterFromAttributes(null);
+ var counter = new AtomicInteger();
+ Supplier splitGen = () -> String.valueOf(counter.getAndIncrement());
+ var splits = granter.getSplits(
+ List.of("https://example.org/kas1", "https://example.org/kas2"),
+ splitGen,
+ Optional::empty);
+
+ assertThat(splits)
+ .hasSize(2)
+ .asList().containsExactly(
+ new KeySplitStep("https://example.org/kas1", "0", null),
+ new KeySplitStep("https://example.org/kas2", "1", null)
+ );
+ }
+
+ @Test
+ void createsGranterFromAttributeValues() {
+ // Arrange
+ Config.TDFConfig config = new Config.TDFConfig();
+ config.attributeValues = List.of(mockValueFor(spk2spk), mockValueFor(rel2gbr));
+
+ SDK.Services services = mock(SDK.Services.class);
+ SDK.KAS kas = mock(SDK.KAS.class);
+ when(services.kas()).thenReturn(kas);
+ when(services.attributes()).thenThrow(new IllegalStateException("should never use the attribute service when attributes are provided"));
+ when(kas.getKeyCache()).thenReturn(null); // No cache needed for this test
+
+ // Act
+ Autoconfigure.Granter granter = Autoconfigure.createGranter(services, config);
+
+ // Assert
+ assertThat(granter).isNotNull();
+ assertThat(granter.getPolicy()).hasSize(2);
+ assertThat(granter.getPolicy()).containsExactlyInAnyOrder(
+ new AttributeValueFQN("https://other.com/attr/specified/value/specked"),
+ new AttributeValueFQN("https://virtru.com/attr/Releasable%20To/value/GBR")
+ );
+ }
+
+ @Test
+ void createsGranterFromService() {
+ // Arrange
+ SDK.Services services = mock(SDK.Services.class);
+ SDK.KAS kas = mock(SDK.KAS.class);
+ AttributesServiceClient attributesServiceClient = mock(AttributesServiceClient.class);
+
+ // Prepare a request and a mocked response
+ List policy = List.of(
+ new AttributeValueFQN("https://other.com/attr/specified/value/specked"),
+ new AttributeValueFQN("https://virtru.com/attr/Releasable%20To/value/GBR")
+ );
+
+ when(services.kas()).thenReturn(kas);
+ when(services.attributes()).thenReturn(attributesServiceClient);
+
+ // Mock the attribute service to return a response with the expected values
+ when(attributesServiceClient.getAttributeValuesByFqnsBlocking(any(), any())).thenAnswer(invocation -> {
+ GetAttributeValuesByFqnsResponse.Builder builder = GetAttributeValuesByFqnsResponse.newBuilder();
+ for (AttributeValueFQN fqn : policy) {
+ Value value = Value.newBuilder()
+ .setId(fqn.toString())
+ .setFqn(fqn.toString())
+ .build();
+ builder.putFqnAttributeValues(fqn.toString(),
+ GetAttributeValuesByFqnsResponse.AttributeAndValue.newBuilder()
+ .setValue(value)
+ .build());
+ }
+ return TestUtil.successfulUnaryCall(builder.build());
+ });
+
+ // Act
+ Autoconfigure.Granter granter = Autoconfigure.createGranter(services, new Config.TDFConfig() {{
+ attributeValues = null; // force use of service
+ attributes = policy;
+ }});
+
+ // Assert
+ assertThat(granter).isNotNull();
+ // The policy should be empty because attributeValues is null, but the test ensures the service is called
+ // If you want to check the service call, verify it:
+ verify(services).attributes();
+ }
+
+ @Test
+ void getSplits_usesAutoconfigurePlan_whenAutoconfigureTrue() {
+ var tdfConfig = new Config.TDFConfig();
+ tdfConfig.autoconfigure = true;
+ tdfConfig.kasInfoList = new ArrayList<>();
+ tdfConfig.splitPlan = null;
+
+ var kas = Mockito.mock(SDK.KAS.class);
+ Mockito.when(kas.getKeyCache()).thenReturn(new KASKeyCache());
+ Config.KASInfo kasInfo = new Config.KASInfo() {{
+ URL = "https://kas.example.com";
+ Algorithm = "ec:secp256r1";
+ KID = "kid";
+ }};
+ Mockito.when(kas.getPublicKey(any())).thenReturn(kasInfo);
+
+ var services = new FakeServicesBuilder().setKas(kas).build();
+
+ // Mock granterFactory to return a granter with a known split plan
+ var expectedSplit = new Autoconfigure.KeySplitStep("https://kas.example.com", "", "kid");
+ var granter = Mockito.mock(Autoconfigure.Granter.class);
+ Mockito.when(granter.getSplits(
+ Mockito.anyList(),
+ Mockito.any(),
+ Mockito.any()))
+ .thenReturn(List.of(expectedSplit));
+
+ BiFunction granterFactory =
+ (s, c) -> granter;
+
+ var planner = new Planner(tdfConfig, services, granterFactory);
+
+ // Act
+ var splits = planner.getSplits(tdfConfig);
+
+ // Assert
+ assertThat(splits).containsKey("");
+ assertThat(splits.get("")).hasSize(1);
+ assertThat(splits.get("").get(0).URL).isEqualTo("https://kas.example.com");
+ assertThat(splits.get("").get(0).KID).isEqualTo("kid");
+ assertThat(splits.get("").get(0).Algorithm).isEqualTo("ec:secp256r1");
+ }
}
diff --git a/sdk/src/test/java/io/opentdf/platform/sdk/FakeServices.java b/sdk/src/test/java/io/opentdf/platform/sdk/FakeServices.java
index b3573593..2851b22b 100644
--- a/sdk/src/test/java/io/opentdf/platform/sdk/FakeServices.java
+++ b/sdk/src/test/java/io/opentdf/platform/sdk/FakeServices.java
@@ -1,31 +1,34 @@
package io.opentdf.platform.sdk;
-import io.opentdf.platform.authorization.AuthorizationServiceClient;
-import io.opentdf.platform.policy.attributes.AttributesServiceClient;
-import io.opentdf.platform.policy.kasregistry.KeyAccessServerRegistryServiceClient;
-import io.opentdf.platform.policy.namespaces.NamespaceServiceClient;
-import io.opentdf.platform.policy.resourcemapping.ResourceMappingServiceClient;
-import io.opentdf.platform.policy.subjectmapping.SubjectMappingServiceClient;
+import io.opentdf.platform.authorization.AuthorizationServiceClientInterface;
+import io.opentdf.platform.policy.attributes.AttributesServiceClientInterface;
+import io.opentdf.platform.policy.kasregistry.KeyAccessServerRegistryServiceClientInterface;
+import io.opentdf.platform.policy.namespaces.NamespaceServiceClientInterface;
+import io.opentdf.platform.policy.resourcemapping.ResourceMappingServiceClientInterface;
+import io.opentdf.platform.policy.subjectmapping.SubjectMappingServiceClientInterface;
+import io.opentdf.platform.wellknownconfiguration.WellKnownServiceClientInterface;
import java.util.Objects;
public class FakeServices implements SDK.Services {
- private final AuthorizationServiceClient authorizationService;
- private final AttributesServiceClient attributesService;
- private final NamespaceServiceClient namespaceService;
- private final SubjectMappingServiceClient subjectMappingService;
- private final ResourceMappingServiceClient resourceMappingService;
- private final KeyAccessServerRegistryServiceClient keyAccessServerRegistryServiceFutureStub;
+ private final AuthorizationServiceClientInterface authorizationService;
+ private final AttributesServiceClientInterface attributesService;
+ private final NamespaceServiceClientInterface namespaceService;
+ private final SubjectMappingServiceClientInterface subjectMappingService;
+ private final ResourceMappingServiceClientInterface resourceMappingService;
+ private final KeyAccessServerRegistryServiceClientInterface keyAccessServerRegistryServiceFutureStub;
+ private final WellKnownServiceClientInterface wellKnownService;
private final SDK.KAS kas;
public FakeServices(
- AuthorizationServiceClient authorizationService,
- AttributesServiceClient attributesService,
- NamespaceServiceClient namespaceService,
- SubjectMappingServiceClient subjectMappingService,
- ResourceMappingServiceClient resourceMappingService,
- KeyAccessServerRegistryServiceClient keyAccessServerRegistryServiceFutureStub,
+ AuthorizationServiceClientInterface authorizationService,
+ AttributesServiceClientInterface attributesService,
+ NamespaceServiceClientInterface namespaceService,
+ SubjectMappingServiceClientInterface subjectMappingService,
+ ResourceMappingServiceClientInterface resourceMappingService,
+ KeyAccessServerRegistryServiceClientInterface keyAccessServerRegistryServiceFutureStub,
+ WellKnownServiceClientInterface wellKnownServiceClient,
SDK.KAS kas) {
this.authorizationService = authorizationService;
this.attributesService = attributesService;
@@ -33,39 +36,45 @@ public FakeServices(
this.subjectMappingService = subjectMappingService;
this.resourceMappingService = resourceMappingService;
this.keyAccessServerRegistryServiceFutureStub = keyAccessServerRegistryServiceFutureStub;
+ this.wellKnownService = wellKnownServiceClient;
this.kas = kas;
}
@Override
- public AuthorizationServiceClient authorization() {
+ public AuthorizationServiceClientInterface authorization() {
return Objects.requireNonNull(authorizationService);
}
@Override
- public AttributesServiceClient attributes() {
+ public AttributesServiceClientInterface attributes() {
return Objects.requireNonNull(attributesService);
}
@Override
- public NamespaceServiceClient namespaces() {
+ public NamespaceServiceClientInterface namespaces() {
return Objects.requireNonNull(namespaceService);
}
@Override
- public SubjectMappingServiceClient subjectMappings() {
+ public SubjectMappingServiceClientInterface subjectMappings() {
return Objects.requireNonNull(subjectMappingService);
}
@Override
- public ResourceMappingServiceClient resourceMappings() {
+ public ResourceMappingServiceClientInterface resourceMappings() {
return Objects.requireNonNull(resourceMappingService);
}
@Override
- public KeyAccessServerRegistryServiceClient kasRegistry() {
+ public KeyAccessServerRegistryServiceClientInterface kasRegistry() {
return Objects.requireNonNull(keyAccessServerRegistryServiceFutureStub);
}
+ @Override
+ public WellKnownServiceClientInterface wellknown() {
+ return Objects.requireNonNull(wellKnownService);
+ }
+
@Override
public SDK.KAS kas() {
return Objects.requireNonNull(kas);
diff --git a/sdk/src/test/java/io/opentdf/platform/sdk/FakeServicesBuilder.java b/sdk/src/test/java/io/opentdf/platform/sdk/FakeServicesBuilder.java
index 2a80f53d..558aee3b 100644
--- a/sdk/src/test/java/io/opentdf/platform/sdk/FakeServicesBuilder.java
+++ b/sdk/src/test/java/io/opentdf/platform/sdk/FakeServicesBuilder.java
@@ -1,47 +1,54 @@
package io.opentdf.platform.sdk;
-import io.opentdf.platform.authorization.AuthorizationServiceClient;
-import io.opentdf.platform.policy.attributes.AttributesServiceClient;
-import io.opentdf.platform.policy.kasregistry.KeyAccessServerRegistryServiceClient;
-import io.opentdf.platform.policy.namespaces.NamespaceServiceClient;
-import io.opentdf.platform.policy.resourcemapping.ResourceMappingServiceClient;
-import io.opentdf.platform.policy.subjectmapping.SubjectMappingServiceClient;
+import io.opentdf.platform.authorization.AuthorizationServiceClientInterface;
+import io.opentdf.platform.policy.attributes.AttributesServiceClientInterface;
+import io.opentdf.platform.policy.kasregistry.KeyAccessServerRegistryServiceClientInterface;
+import io.opentdf.platform.policy.namespaces.NamespaceServiceClientInterface;
+import io.opentdf.platform.policy.resourcemapping.ResourceMappingServiceClientInterface;
+import io.opentdf.platform.policy.subjectmapping.SubjectMappingServiceClientInterface;
+import io.opentdf.platform.wellknownconfiguration.WellKnownServiceClientInterface;
public class FakeServicesBuilder {
- private AuthorizationServiceClient authorizationService;
- private AttributesServiceClient attributesService;
- private NamespaceServiceClient namespaceService;
- private SubjectMappingServiceClient subjectMappingService;
- private ResourceMappingServiceClient resourceMappingService;
- private KeyAccessServerRegistryServiceClient keyAccessServerRegistryServiceFutureStub;
+ private AuthorizationServiceClientInterface authorizationService;
+ private AttributesServiceClientInterface attributesService;
+ private NamespaceServiceClientInterface namespaceService;
+ private SubjectMappingServiceClientInterface subjectMappingService;
+ private ResourceMappingServiceClientInterface resourceMappingService;
+ private KeyAccessServerRegistryServiceClientInterface keyAccessServerRegistryServiceFutureStub;
+ private WellKnownServiceClientInterface wellKnownServiceClient;
private SDK.KAS kas;
- public FakeServicesBuilder setAuthorizationService(AuthorizationServiceClient authorizationService) {
+ public FakeServicesBuilder setAuthorizationService(AuthorizationServiceClientInterface authorizationService) {
this.authorizationService = authorizationService;
return this;
}
- public FakeServicesBuilder setAttributesService(AttributesServiceClient attributesService) {
+ public FakeServicesBuilder setAttributesService(AttributesServiceClientInterface attributesService) {
this.attributesService = attributesService;
return this;
}
- public FakeServicesBuilder setNamespaceService(NamespaceServiceClient namespaceService) {
+ public FakeServicesBuilder setNamespaceService(NamespaceServiceClientInterface namespaceService) {
this.namespaceService = namespaceService;
return this;
}
- public FakeServicesBuilder setSubjectMappingService(SubjectMappingServiceClient subjectMappingService) {
+ public FakeServicesBuilder setSubjectMappingService(SubjectMappingServiceClientInterface subjectMappingService) {
this.subjectMappingService = subjectMappingService;
return this;
}
- public FakeServicesBuilder setResourceMappingService(ResourceMappingServiceClient resourceMappingService) {
+ public FakeServicesBuilder setResourceMappingService(ResourceMappingServiceClientInterface resourceMappingService) {
this.resourceMappingService = resourceMappingService;
return this;
}
- public FakeServicesBuilder setKeyAccessServerRegistryService(KeyAccessServerRegistryServiceClient keyAccessServerRegistryServiceFutureStub) {
+ public FakeServicesBuilder setWellknownService(WellKnownServiceClientInterface wellKnownServiceClient) {
+ this.wellKnownServiceClient = wellKnownServiceClient;
+ return this;
+ }
+
+ public FakeServicesBuilder setKeyAccessServerRegistryService(KeyAccessServerRegistryServiceClientInterface keyAccessServerRegistryServiceFutureStub) {
this.keyAccessServerRegistryServiceFutureStub = keyAccessServerRegistryServiceFutureStub;
return this;
}
@@ -52,6 +59,7 @@ public FakeServicesBuilder setKas(SDK.KAS kas) {
}
public FakeServices build() {
- return new FakeServices(authorizationService, attributesService, namespaceService, subjectMappingService, resourceMappingService, keyAccessServerRegistryServiceFutureStub, kas);
+ return new FakeServices(authorizationService, attributesService, namespaceService, subjectMappingService,
+ resourceMappingService, keyAccessServerRegistryServiceFutureStub, wellKnownServiceClient, kas);
}
}
\ No newline at end of file
diff --git a/sdk/src/test/java/io/opentdf/platform/sdk/KASKeyCacheTest.java b/sdk/src/test/java/io/opentdf/platform/sdk/KASKeyCacheTest.java
index 5550678a..fdee682e 100644
--- a/sdk/src/test/java/io/opentdf/platform/sdk/KASKeyCacheTest.java
+++ b/sdk/src/test/java/io/opentdf/platform/sdk/KASKeyCacheTest.java
@@ -35,7 +35,7 @@ void testStoreAndGet_WithinTimeLimit() {
kasKeyCache.store(kasInfo1);
// Retrieve the item within the time limit
- Config.KASInfo result = kasKeyCache.get("https://example.com/kas1", "rsa:2048");
+ Config.KASInfo result = kasKeyCache.get("https://example.com/kas1", "rsa:2048", "kid1");
// Ensure the item was correctly retrieved
assertNotNull(result);
@@ -51,12 +51,24 @@ void testStoreAndGet_AfterTimeLimit() {
kasKeyCache.store(kasInfo1);
// Simulate time passing by modifying the timestamp directly
- KASKeyRequest cacheKey = new KASKeyRequest("https://example.com/kas1", "rsa:2048");
+ KASKeyRequest cacheKey = new KASKeyRequest("https://example.com/kas1", "rsa:2048", "kid1");
TimeStampedKASInfo timeStampedKASInfo = new TimeStampedKASInfo(kasInfo1, LocalDateTime.now().minus(6, ChronoUnit.MINUTES));
kasKeyCache.cache.put(cacheKey, timeStampedKASInfo);
// Attempt to retrieve the item after the time limit
- Config.KASInfo result = kasKeyCache.get("https://example.com/kas1", "rsa:2048");
+ Config.KASInfo result = kasKeyCache.get("https://example.com/kas1", "rsa:2048", "kid1");
+
+ // Ensure the item was not retrieved (it should have expired)
+ assertNull(result);
+ }
+
+ @Test
+ void testStoreAndGet_DifferentKIDs() {
+ // Store an item in the cache
+ kasKeyCache.store(kasInfo1);
+
+ // Attempt to retrieve the item with a different KID
+ Config.KASInfo result = kasKeyCache.get(kasInfo1.URL, kasInfo1.Algorithm, kasInfo1.KID + "different");
// Ensure the item was not retrieved (it should have expired)
assertNull(result);
@@ -72,7 +84,7 @@ void testStoreAndGet_WithNullAlgorithm() {
kasKeyCache.store(kasInfo1);
// Retrieve the item with a null algorithm
- Config.KASInfo result = kasKeyCache.get("https://example.com/kas1", null);
+ Config.KASInfo result = kasKeyCache.get("https://example.com/kas1", null, "kid1");
// Ensure the item was correctly retrieved
assertNotNull(result);
@@ -91,7 +103,7 @@ void testClearCache() {
kasKeyCache.clear();
// Attempt to retrieve the item after clearing the cache
- Config.KASInfo result = kasKeyCache.get("https://example.com/kas1", "rsa:2048");
+ Config.KASInfo result = kasKeyCache.get("https://example.com/kas1", "rsa:2048", "kid1");
// Ensure the item was not retrieved (the cache should be empty)
assertNull(result);
@@ -104,8 +116,8 @@ void testStoreMultipleItemsAndGet() {
kasKeyCache.store(kasInfo2);
// Retrieve each item and ensure they were correctly stored and retrieved
- Config.KASInfo result1 = kasKeyCache.get("https://example.com/kas1", "rsa:2048");
- Config.KASInfo result2 = kasKeyCache.get("https://example.com/kas2", "ec:secp256r1");
+ Config.KASInfo result1 = kasKeyCache.get("https://example.com/kas1", "rsa:2048", "kid1");
+ Config.KASInfo result2 = kasKeyCache.get("https://example.com/kas2", "ec:secp256r1", "kid2");
assertNotNull(result1);
assertEquals("https://example.com/kas1", result1.URL);
@@ -119,8 +131,8 @@ void testStoreMultipleItemsAndGet() {
@Test
void testEqualsAndHashCode() {
// Create two identical KASKeyRequest objects
- KASKeyRequest keyRequest1 = new KASKeyRequest("https://example.com/kas1", "rsa:2048");
- KASKeyRequest keyRequest2 = new KASKeyRequest("https://example.com/kas1", "rsa:2048");
+ KASKeyRequest keyRequest1 = new KASKeyRequest("https://example.com/kas1", "rsa:2048", "kid1");
+ KASKeyRequest keyRequest2 = new KASKeyRequest("https://example.com/kas1", "rsa:2048", "kid1");
// Ensure that equals and hashCode work as expected
assertEquals(keyRequest1, keyRequest2);
diff --git a/sdk/src/test/java/io/opentdf/platform/sdk/NanoTDFTest.java b/sdk/src/test/java/io/opentdf/platform/sdk/NanoTDFTest.java
index 5e428c79..a79a42a1 100644
--- a/sdk/src/test/java/io/opentdf/platform/sdk/NanoTDFTest.java
+++ b/sdk/src/test/java/io/opentdf/platform/sdk/NanoTDFTest.java
@@ -1,7 +1,7 @@
package io.opentdf.platform.sdk;
-import com.connectrpc.ResponseMessage;
-import com.connectrpc.UnaryBlockingCall;
+import com.google.protobuf.Struct;
+import com.google.protobuf.Value;
import io.opentdf.platform.policy.KeyAccessServer;
import io.opentdf.platform.policy.kasregistry.KeyAccessServerRegistryServiceClient;
import io.opentdf.platform.policy.kasregistry.ListKeyAccessServersRequest;
@@ -11,18 +11,20 @@
import java.nio.charset.StandardCharsets;
+import io.opentdf.platform.wellknownconfiguration.GetWellKnownConfigurationResponse;
+import io.opentdf.platform.wellknownconfiguration.WellKnownServiceClientInterface;
import org.apache.commons.io.output.ByteArrayOutputStream;
+import org.apache.commons.lang3.NotImplementedException;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import java.nio.ByteBuffer;
-import java.security.KeyPair;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.ArrayList;
import java.util.Base64;
-import java.util.Collections;
import java.util.List;
+import java.util.Objects;
import java.util.Random;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
@@ -45,20 +47,35 @@ public class NanoTDFTest {
"oVP7Vpcx\n" +
"-----END PRIVATE KEY-----";
+ private static final String BASE_PUBLIC_KEY = "-----BEGIN PUBLIC KEY-----\n" +
+ "MHYwEAYHKoZIzj0CAQYFK4EEACIDYgAEzcPM21p19N27oyMG6i3KrrDdgyiNDLgN\n" +
+ "rTJMvZRj3V/48ysDKw5jzngtzVLpkur/l5XBGjEiCr0aGlSo8Db0YVYFRm5g74P2\n" +
+ "nt8xOedHSPGYq0NMDt4Il6bt1rNUSVm+\n" +
+ "-----END PUBLIC KEY-----\n";
+
+ private static final String BASE_KEY_PRIVATE = "-----BEGIN PRIVATE KEY-----\n" +
+ "MIG2AgEAMBAGByqGSM49AgEGBSuBBAAiBIGeMIGbAgEBBDA8zHdzkI8OhDojcRcQ\n" +
+ "xMqV3Gs+XGeQ/9SKbIRsyab0jJhMAEhRqzF3DQ4RSipIIX6hZANiAATNw8zbWnX0\n" +
+ "3bujIwbqLcqusN2DKI0MuA2tMky9lGPdX/jzKwMrDmPOeC3NUumS6v+XlcEaMSIK\n" +
+ "vRoaVKjwNvRhVgVGbmDvg/ae3zE550dI8ZirQ0wO3giXpu3Ws1RJWb4=\n" +
+ "-----END PRIVATE KEY-----\n";
+
private static final String KID = "r1";
+ private static final String BASE_KID = "basekid";
protected static KeyAccessServerRegistryServiceClient kasRegistryService;
protected static List registeredKases = List.of(
"https://api.example.com/kas",
"https://other.org/kas2",
"http://localhost:8181/kas",
- "https://localhost:8383/kas"
+ "https://localhost:8383/kas",
+ "https://api.kaswithbasekey.example.com"
);
protected static String platformUrl = "http://localhost:8080";
protected static SDK.KAS kas = new SDK.KAS() {
@Override
- public void close() throws Exception {
+ public void close() {
}
@Override
@@ -70,10 +87,16 @@ public Config.KASInfo getPublicKey(Config.KASInfo kasInfo) {
@Override
public KASInfo getECPublicKey(Config.KASInfo kasInfo, NanoTDFType.ECCurve curve) {
+ var k2 = kasInfo.clone();
+ if (Objects.equals(kasInfo.KID, BASE_KID)) {
+ assertThat(kasInfo.URL).isEqualTo("https://api.kaswithbasekey.example.com");
+ assertThat(kasInfo.Algorithm).isEqualTo("ec:secp384r1");
+ k2.PublicKey = BASE_PUBLIC_KEY;
+ return k2;
+ }
if (kasInfo.Algorithm != null && !"ec:secp256r1".equals(kasInfo.Algorithm)) {
throw new IllegalArgumentException("Unexpected algorithm: " + kasInfo);
}
- var k2 = kasInfo.clone();
k2.KID = KID;
k2.PublicKey = kasPublicKey;
return k2;
@@ -81,18 +104,12 @@ public KASInfo getECPublicKey(Config.KASInfo kasInfo, NanoTDFType.ECCurve curve)
@Override
public byte[] unwrap(Manifest.KeyAccess keyAccess, String policy, KeyType sessionKeyType) {
- int index = Integer.parseInt(keyAccess.url);
- var decryptor = new AsymDecryption(keypairs.get(index).getPrivate());
- var bytes = Base64.getDecoder().decode(keyAccess.wrappedKey);
- try {
- return decryptor.decrypt(bytes);
- } catch (Exception e) {
- throw new RuntimeException(e);
- }
+ throw new NotImplementedException("no unwrapping ZTDFs here");
}
@Override
public byte[] unwrapNanoTDF(NanoTDFType.ECCurve curve, String header, String kasURL) {
+ String key = Objects.equals(kasURL, "https://api.kaswithbasekey.example.com") ? BASE_KEY_PRIVATE : kasPrivateKey;
byte[] headerAsBytes = Base64.getDecoder().decode(header);
Header nTDFHeader = new Header(ByteBuffer.wrap(headerAsBytes));
@@ -102,7 +119,7 @@ public byte[] unwrapNanoTDF(NanoTDFType.ECCurve curve, String header, String kas
// Generate symmetric key
byte[] symmetricKey = ECKeyPair.computeECDHKey(ECKeyPair.publicKeyFromPem(publicKeyAsPem),
- ECKeyPair.privateKeyFromPem(kasPrivateKey));
+ ECKeyPair.privateKeyFromPem(key));
// Generate HKDF key
MessageDigest digest;
@@ -112,8 +129,7 @@ public byte[] unwrapNanoTDF(NanoTDFType.ECCurve curve, String header, String kas
throw new SDKException("error creating SHA-256 message digest", e);
}
byte[] hashOfSalt = digest.digest(NanoTDF.MAGIC_NUMBER_AND_VERSION);
- byte[] key = ECKeyPair.calculateHKDF(hashOfSalt, symmetricKey);
- return key;
+ return ECKeyPair.calculateHKDF(hashOfSalt, symmetricKey);
}
@Override
@@ -136,21 +152,9 @@ static void setupMocks() {
// Stub the listKeyAccessServers method
when(kasRegistryService.listKeyAccessServersBlocking(any(ListKeyAccessServersRequest.class), any()))
- .thenReturn(new UnaryBlockingCall<>() {
- @Override
- public ResponseMessage execute() {
- return new ResponseMessage.Success<>(mockResponse, Collections.emptyMap(), Collections.emptyMap());
- }
-
- @Override
- public void cancel() {
- // this never happens in tests
- }
- });
+ .thenReturn(TestUtil.successfulUnaryCall(mockResponse));
}
- private static ArrayList keypairs = new ArrayList<>();
-
@Test
void encryptionAndDecryptionWithValidKey() throws Exception {
var kasInfos = new ArrayList<>();
@@ -201,6 +205,39 @@ void encryptionAndDecryptionWithValidKey() throws Exception {
}
}
+ @Test
+ void encryptionAndDecryptWithBaseKey() throws Exception {
+ var baseKeyJson = "{\"kas_url\":\"https://api.kaswithbasekey.example.com\",\"public_key\":{\"algorithm\":\"ALGORITHM_EC_P384\",\"kid\":\"thebasekid\",\"pem\": \"" + BASE_PUBLIC_KEY + "\"}}";
+ var val = Value.newBuilder().setStringValue(baseKeyJson).build();
+ var config = Struct.newBuilder().putFields("base_key", val).build();
+ WellKnownServiceClientInterface wellknown = mock(WellKnownServiceClientInterface.class);
+ GetWellKnownConfigurationResponse response = GetWellKnownConfigurationResponse.newBuilder().setConfiguration(config).build();
+
+ when(wellknown.getWellKnownConfigurationBlocking(any(), any())).thenReturn(TestUtil.successfulUnaryCall(response));
+
+ Config.NanoTDFConfig nanoConfig = Config.newNanoTDFConfig(
+ Config.witDataAttributes("https://example.com/attr/Classification/value/S",
+ "https://example.com/attr/Classification/value/X")
+ );
+
+ String plainText = "Virtru!!";
+ ByteBuffer byteBuffer = ByteBuffer.wrap(plainText.getBytes());
+ ByteArrayOutputStream tdfOutputStream = new ByteArrayOutputStream();
+
+ NanoTDF nanoTDF = new NanoTDF(new FakeServicesBuilder().setKas(kas).setKeyAccessServerRegistryService(kasRegistryService).setWellknownService(wellknown).build());
+ nanoTDF.createNanoTDF(byteBuffer, tdfOutputStream, nanoConfig);
+
+ byte[] nanoTDFBytes = tdfOutputStream.toByteArray();
+ ByteArrayOutputStream plainTextStream = new ByteArrayOutputStream();
+ nanoTDF = new NanoTDF(new FakeServicesBuilder().setKas(kas).setKeyAccessServerRegistryService(kasRegistryService).build());
+ nanoTDF.readNanoTDF(ByteBuffer.wrap(nanoTDFBytes), plainTextStream, platformUrl);
+
+ String out = new String(plainTextStream.toByteArray(), StandardCharsets.UTF_8);
+ assertThat(out).isEqualTo(plainText);
+ // KAS KID
+ assertThat(new String(nanoTDFBytes, StandardCharsets.UTF_8)).contains(KID);
+ }
+
void runBasicTest(String kasUrl, boolean allowed, KeyAccessServerRegistryServiceClient kasReg, NanoTDFReaderConfig decryptConfig) throws Exception {
var kasInfos = new ArrayList<>();
var kasInfo = new Config.KASInfo();
diff --git a/sdk/src/test/java/io/opentdf/platform/sdk/PlannerTest.java b/sdk/src/test/java/io/opentdf/platform/sdk/PlannerTest.java
new file mode 100644
index 00000000..117db1a7
--- /dev/null
+++ b/sdk/src/test/java/io/opentdf/platform/sdk/PlannerTest.java
@@ -0,0 +1,269 @@
+package io.opentdf.platform.sdk;
+
+import com.google.protobuf.Struct;
+import com.google.protobuf.Value;
+import io.opentdf.platform.policy.Algorithm;
+import io.opentdf.platform.wellknownconfiguration.GetWellKnownConfigurationResponse;
+import io.opentdf.platform.wellknownconfiguration.WellKnownServiceClientInterface;
+import org.assertj.core.api.Assertions;
+import org.junit.jupiter.api.Test;
+import org.mockito.Mockito;
+
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.stream.Collectors;
+
+import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
+
+class PlannerTest {
+
+ @Test
+ void fetchBaseKey() {
+ var wellknownService = Mockito.mock(WellKnownServiceClientInterface.class);
+ var baseKeyJson = "{\"kas_url\":\"https://example.com/base_key\",\"public_key\":{\"algorithm\":\"ALGORITHM_RSA_2048\",\"kid\":\"thekid\",\"pem\": \"thepem\"}}";
+ var val = Value.newBuilder().setStringValue(baseKeyJson).build();
+ var config = Struct.newBuilder().putFields("base_key", val).build();
+ var response = GetWellKnownConfigurationResponse
+ .newBuilder()
+ .setConfiguration(config)
+ .build();
+
+ Mockito.when(wellknownService.getWellKnownConfigurationBlocking(Mockito.any(), Mockito.anyMap()))
+ .thenReturn(TestUtil.successfulUnaryCall(response));
+
+
+ var baseKey = Planner.fetchBaseKey(wellknownService);
+ assertThat(baseKey).isNotEmpty();
+ var simpleKasKey = baseKey.get();
+ assertThat(simpleKasKey.getKasUri()).isEqualTo("https://example.com/base_key");
+ assertThat(simpleKasKey.getPublicKey().getAlgorithm()).isEqualTo(Algorithm.ALGORITHM_RSA_2048);
+ assertThat(simpleKasKey.getPublicKey().getKid()).isEqualTo("thekid");
+ assertThat(simpleKasKey.getPublicKey().getPem()).isEqualTo("thepem");
+ }
+
+ @Test
+ void fetchBaseKeyWithNoBaseKey() {
+ var wellknownService = Mockito.mock(WellKnownServiceClientInterface.class);
+ var response = GetWellKnownConfigurationResponse
+ .newBuilder()
+ .setConfiguration(Struct.newBuilder().build())
+ .build();
+
+ Mockito.when(wellknownService.getWellKnownConfigurationBlocking(Mockito.any(), Mockito.anyMap()))
+ .thenReturn(TestUtil.successfulUnaryCall(response));
+
+ var baseKey = Planner.fetchBaseKey(wellknownService);
+ assertThat(baseKey).isEmpty();
+ }
+
+ @Test
+ void fetchBaseKeyWithMissingFields() {
+ var wellknownService = Mockito.mock(WellKnownServiceClientInterface.class);
+ // Missing 'kid', 'pem', and 'algorithm' in public_key
+ var baseKeyJson = "{\"kas_url\":\"https://example.com/base_key\",\"public_key\":{}}";
+ var val = Value.newBuilder().setStringValue(baseKeyJson).build();
+ var config = Struct.newBuilder().putFields("base_key", val).build();
+ var response = GetWellKnownConfigurationResponse
+ .newBuilder()
+ .setConfiguration(config)
+ .build();
+
+ Mockito.when(wellknownService.getWellKnownConfigurationBlocking(Mockito.any(), Mockito.anyMap()))
+ .thenReturn(TestUtil.successfulUnaryCall(response));
+
+ var baseKey = Planner.fetchBaseKey(wellknownService);
+ assertThat(baseKey).isEmpty();
+ }
+
+ @Test
+ void generatePlanFromProvidedKases() {
+ var kas1 = new Config.KASInfo();
+ kas1.URL = "https://kas1.example.com";
+ kas1.KID = "kid1";
+ kas1.Algorithm = "rsa:2048";
+
+ var kas2 = new Config.KASInfo();
+ kas2.URL = "https://kas2.example.com";
+ kas2.KID = "kid2";
+ kas2.Algorithm = "ec:secp256";
+
+ var tdfConfig = new Config.TDFConfig();
+ tdfConfig.kasInfoList.add(kas1);
+ tdfConfig.kasInfoList.add(kas2);
+
+ var planner = new Planner(tdfConfig, new FakeServicesBuilder().build(), (ignore1, ignored2) -> { throw new IllegalArgumentException("no granter needed"); });
+ List splitPlan = planner.generatePlanFromProvidedKases(tdfConfig.kasInfoList);
+
+ assertThat(splitPlan).asList().hasSize(2);
+ assertThat(splitPlan.get(0).kas).isEqualTo("https://kas1.example.com");
+ assertThat(splitPlan.get(0).kid).isEqualTo("kid1");
+
+ assertThat(splitPlan.get(1).kas).isEqualTo("https://kas2.example.com");
+ assertThat(splitPlan.get(1).kid).isEqualTo("kid2");
+
+ assertThat(splitPlan.get(0).splitID).isNotEqualTo(splitPlan.get(1).splitID);
+ }
+
+ @Test
+ void testFillingInKeysWithAutoConfigure() {
+ var kas = Mockito.mock(SDK.KAS.class);
+ Mockito.when(kas.getPublicKey(Mockito.any())).thenAnswer(invocation -> {
+ Config.KASInfo kasInfo = invocation.getArgument(0, Config.KASInfo.class);
+ var ret = new Config.KASInfo();
+ ret.URL = kasInfo.URL;
+ if (Objects.equals(kasInfo.URL, "https://kas1.example.com")) {
+ ret.PublicKey = "pem1";
+ ret.Algorithm = "rsa:2048";
+ ret.KID = "kid1";
+ } else if (Objects.equals(kasInfo.URL, "https://kas2.example.com")) {
+ ret.PublicKey = "pem2";
+ ret.Algorithm = "ec:secp256r1";
+ ret.KID = "kid2";
+ } else if (Objects.equals(kasInfo.URL, "https://kas3.example.com")) {
+ ret.PublicKey = "pem3";
+ ret.Algorithm = "rsa:4096";
+ ret.KID = "kid3";
+ } else {
+ throw new IllegalArgumentException("Unexpected KAS URL: " + kasInfo.URL);
+ }
+ return ret;
+ });
+ var tdfConfig = new Config.TDFConfig();
+ tdfConfig.autoconfigure = true;
+ tdfConfig.wrappingKeyType = KeyType.RSA2048Key;
+ tdfConfig.kasInfoList = List.of(
+ new Config.KASInfo() {{
+ URL = "https://kas4.example.com";
+ KID = "kid4";
+ Algorithm = "ec:secp384r1";
+ PublicKey = "pem4";
+ }}
+ );
+ var planner = new Planner(tdfConfig, new FakeServicesBuilder().setKas(kas).build(), (ignore1, ignored2) -> { throw new IllegalArgumentException("no granter needed"); });
+ var plan = List.of(
+ new Autoconfigure.KeySplitStep("https://kas1.example.com", "split1", null),
+ new Autoconfigure.KeySplitStep("https://kas4.example.com", "split1", "kid4"),
+ new Autoconfigure.KeySplitStep("https://kas2.example.com", "split2", "kid2"),
+ new Autoconfigure.KeySplitStep("https://kas3.example.com", "split2", "kid3")
+ );
+ Map> filledInPlan = planner.resolveKeys(plan);
+ assertThat(filledInPlan.keySet().stream().collect(Collectors.toList())).asList().containsExactlyInAnyOrder("split1", "split2");
+ assertThat(filledInPlan.get("split1")).asList().hasSize(2);
+ var kasInfo1 = filledInPlan.get("split1").stream().filter(k -> "kid1".equals(k.KID)).findFirst().get();
+ assertThat(kasInfo1.URL).isEqualTo("https://kas1.example.com");
+ assertThat(kasInfo1.Algorithm).isEqualTo("rsa:2048");
+ assertThat(kasInfo1.PublicKey).isEqualTo("pem1");
+ var kasInfo4 = filledInPlan.get("split1").stream().filter(k -> "kid4".equals(k.KID)).findFirst().get();
+ assertThat(kasInfo4.URL).isEqualTo("https://kas4.example.com");
+ assertThat(kasInfo4.Algorithm).isEqualTo("ec:secp384r1");
+ assertThat(kasInfo4.PublicKey).isEqualTo("pem4");
+
+ assertThat(filledInPlan.get("split2")).asList().hasSize(2);
+ var kasInfo2 = filledInPlan.get("split2").stream().filter(kasInfo -> "kid2".equals(kasInfo.KID)).findFirst().get();
+ assertThat(kasInfo2.URL).isEqualTo("https://kas2.example.com");
+ assertThat(kasInfo2.Algorithm).isEqualTo("ec:secp256r1");
+ assertThat(kasInfo2.PublicKey).isEqualTo("pem2");
+ var kasInfo3 = filledInPlan.get("split2").stream().filter(kasInfo -> "kid3".equals(kasInfo.KID)).findFirst().get();
+ assertThat(kasInfo3.URL).isEqualTo("https://kas3.example.com");
+ assertThat(kasInfo3.Algorithm).isEqualTo("rsa:4096");
+ assertThat(kasInfo3.PublicKey).isEqualTo("pem3");
+ }
+
+ @Test
+ void returnsOnlyDefaultKasesIfPresent() {
+ var kas1 = new Config.KASInfo();
+ kas1.URL = "https://kas1.example.com";
+ kas1.Default = true;
+
+ var kas2 = new Config.KASInfo();
+ kas2.URL = "https://kas2.example.com";
+ kas2.Default = false;
+
+ var kas3 = new Config.KASInfo();
+ kas3.URL = "https://kas3.example.com";
+ kas3.Default = true;
+
+ var config = new Config.TDFConfig();
+ config.kasInfoList.addAll(List.of(kas1, kas2, kas3));
+
+ List result = Planner.defaultKases(config);
+
+ Assertions.assertThat(result).containsExactlyInAnyOrder("https://kas1.example.com", "https://kas3.example.com");
+ }
+
+ @Test
+ void returnsAllKasesIfNoDefault() {
+ var kas1 = new Config.KASInfo();
+ kas1.URL = "https://kas1.example.com";
+ kas1.Default = false;
+
+ var kas2 = new Config.KASInfo();
+ kas2.URL = "https://kas2.example.com";
+ kas2.Default = null; // not set
+
+ var config = new Config.TDFConfig();
+ config.kasInfoList.addAll(List.of(kas1, kas2));
+
+ List result = Planner.defaultKases(config);
+ Assertions.assertThat(result).containsExactlyInAnyOrder("https://kas1.example.com", "https://kas2.example.com");
+ }
+
+ @Test
+ void returnsEmptyListIfNoKases() {
+ var config = new Config.TDFConfig();
+ List result = Planner.defaultKases(config);
+ Assertions.assertThat(result).isEmpty();
+ }
+
+ @Test
+ void usesProvidedSplitPlanWhenNotAutoconfigure() {
+ var kas = Mockito.mock(SDK.KAS.class);
+ Mockito.when(kas.getPublicKey(Mockito.any())).thenAnswer(invocation -> {
+ Config.KASInfo kasInfo = invocation.getArgument(0, Config.KASInfo.class);
+ var ret = new Config.KASInfo();
+ ret.URL = kasInfo.URL;
+ if (Objects.equals(kasInfo.URL, "https://kas1.example.com")) {
+ ret.PublicKey = "pem1";
+ ret.Algorithm = "rsa:2048";
+ ret.KID = "kid1";
+ } else if (Objects.equals(kasInfo.URL, "https://kas2.example.com")) {
+ ret.PublicKey = "pem2";
+ ret.Algorithm = "ec:secp256r1";
+ ret.KID = "kid2";
+ } else {
+ throw new IllegalArgumentException("Unexpected KAS URL: " + kasInfo.URL);
+ }
+ return ret;
+ });
+ // Arrange
+ var kas1 = new Config.KASInfo();
+ kas1.URL = "https://kas1.example.com";
+ kas1.KID = "kid1";
+ kas1.Algorithm = "rsa:2048";
+
+ var kas2 = new Config.KASInfo();
+ kas2.URL = "https://kas2.example.com";
+ kas2.KID = "kid2";
+ kas2.Algorithm = "ec:secp256";
+
+ var splitStep1 = new Autoconfigure.KeySplitStep(kas1.URL, "split1", kas1.KID);
+ var splitStep2 = new Autoconfigure.KeySplitStep(kas2.URL, "split2", kas2.KID);
+
+ var tdfConfig = new Config.TDFConfig();
+ tdfConfig.autoconfigure = false;
+ tdfConfig.kasInfoList.add(kas1);
+ tdfConfig.kasInfoList.add(kas2);
+ tdfConfig.splitPlan = List.of(splitStep1, splitStep2);
+
+ var planner = new Planner(tdfConfig, new FakeServicesBuilder().setKas(kas).build(), (ignore1, ignored2) -> { throw new IllegalArgumentException("no granter needed"); });
+
+ // Act
+ Map> splits = planner.getSplits(tdfConfig);
+
+ // Assert
+ Assertions.assertThat(splits).hasSize(2);
+ Assertions.assertThat(splits.get("split1")).extracting("URL").containsExactly("https://kas1.example.com");
+ Assertions.assertThat(splits.get("split2")).extracting("URL").containsExactly("https://kas2.example.com");
+ }
+}
\ No newline at end of file
diff --git a/sdk/src/test/java/io/opentdf/platform/sdk/TestUtil.java b/sdk/src/test/java/io/opentdf/platform/sdk/TestUtil.java
new file mode 100644
index 00000000..53076007
--- /dev/null
+++ b/sdk/src/test/java/io/opentdf/platform/sdk/TestUtil.java
@@ -0,0 +1,22 @@
+package io.opentdf.platform.sdk;
+
+import com.connectrpc.ResponseMessage;
+import com.connectrpc.UnaryBlockingCall;
+
+import java.util.Collections;
+
+public class TestUtil {
+ static UnaryBlockingCall successfulUnaryCall(T result) {
+ return new UnaryBlockingCall() {
+ @Override
+ public ResponseMessage execute() {
+ return new ResponseMessage.Success<>(result, Collections.emptyMap(), Collections.emptyMap());
+ }
+
+ @Override
+ public void cancel() {
+ // in tests we don't need to preserve server resources, so no-op
+ }
+ };
+ }
+}