diff --git a/examples/src/main/java/io/opentdf/platform/GetEntitlements.java b/examples/src/main/java/io/opentdf/platform/GetEntitlements.java index 0f0ab735..f9479577 100644 --- a/examples/src/main/java/io/opentdf/platform/GetEntitlements.java +++ b/examples/src/main/java/io/opentdf/platform/GetEntitlements.java @@ -7,7 +7,6 @@ import io.opentdf.platform.sdk.*; import java.util.Collections; -import java.util.concurrent.ExecutionException; import java.util.List; diff --git a/sdk/pom.xml b/sdk/pom.xml index 795d34ba..a17f99ef 100644 --- a/sdk/pom.xml +++ b/sdk/pom.xml @@ -15,7 +15,7 @@ 2.1.0 0.7.2 4.12.0 - protocol/go/v0.3.0 + protocol/go/v0.5.0 diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/AesGcm.java b/sdk/src/main/java/io/opentdf/platform/sdk/AesGcm.java index 71445f69..96067c07 100644 --- a/sdk/src/main/java/io/opentdf/platform/sdk/AesGcm.java +++ b/sdk/src/main/java/io/opentdf/platform/sdk/AesGcm.java @@ -152,10 +152,9 @@ public byte[] encrypt(byte[] iv, int authTagLen, byte[] plaintext, int offset, i System.arraycopy(iv, 0, cipherTextWithNonce, 0, iv.length); System.arraycopy(cipherText, 0, cipherTextWithNonce, iv.length, cipherText.length); return cipherTextWithNonce; - } catch (NoSuchPaddingException | NoSuchAlgorithmException | InvalidAlgorithmParameterException e) { - throw new RuntimeException("error gcm decrypt", e); - } catch (InvalidKeyException | BadPaddingException | IllegalBlockSizeException e) { - throw new RuntimeException("error gcm decrypt", e); + } catch (NoSuchPaddingException | NoSuchAlgorithmException | InvalidAlgorithmParameterException | + InvalidKeyException | BadPaddingException | IllegalBlockSizeException e) { + throw new RuntimeException("error gcm encrypt", e); } } diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/Autoconfigure.java b/sdk/src/main/java/io/opentdf/platform/sdk/Autoconfigure.java index b60d0725..8a50cc03 100644 --- a/sdk/src/main/java/io/opentdf/platform/sdk/Autoconfigure.java +++ b/sdk/src/main/java/io/opentdf/platform/sdk/Autoconfigure.java @@ -1,12 +1,13 @@ package io.opentdf.platform.sdk; import com.connectrpc.ResponseMessageKt; +import io.opentdf.platform.policy.Algorithm; import io.opentdf.platform.policy.Attribute; import io.opentdf.platform.policy.AttributeRuleTypeEnum; import io.opentdf.platform.policy.AttributeValueSelector; -import io.opentdf.platform.policy.KasPublicKey; import io.opentdf.platform.policy.KasPublicKeyAlgEnum; import io.opentdf.platform.policy.KeyAccessServer; +import io.opentdf.platform.policy.SimpleKasKey; import io.opentdf.platform.policy.Value; import io.opentdf.platform.policy.attributes.AttributesServiceClientInterface; import io.opentdf.platform.policy.attributes.GetAttributeValuesByFqnsRequest; @@ -14,6 +15,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import javax.annotation.Nonnull; import javax.annotation.Nullable; import java.io.UnsupportedEncodingException; import java.net.URLDecoder; @@ -27,6 +29,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Optional; import java.util.Set; import java.util.StringJoiner; import java.util.function.Supplier; @@ -53,47 +56,53 @@ class RuleType { * This class includes functionality to create granter instances based on * attributes either from a list of attribute values or from a service. */ -public class Autoconfigure { +class Autoconfigure { - public static Logger logger = LoggerFactory.getLogger(Autoconfigure.class); + private static Logger logger = LoggerFactory.getLogger(Autoconfigure.class); - public static class KeySplitStep { - public String kas; - public String splitID; + private Autoconfigure() { + // Prevent instantiation, this class is a utility class that is only used statically + } + + static class KeySplitStep { + final String kas; + final String splitID; + final String kid; - public KeySplitStep(String kas, String splitId) { - this.kas = kas; - this.splitID = splitId; + KeySplitStep(String kas, String splitId) { + this(kas, splitId, null); } - @Override - public String toString() { - return "KeySplitStep{kas=" + this.kas + ", splitID=" + this.splitID + "}"; + KeySplitStep(String kas, String splitId, @Nullable String kid) { + this.kas = Objects.requireNonNull(kas); + this.splitID = Objects.requireNonNull(splitId); + this.kid = kid; } @Override - public boolean equals(Object obj) { - if (this == obj) { - return true; - } - if (obj == null || !(obj instanceof KeySplitStep)) { - return false; - } - KeySplitStep ss = (KeySplitStep) obj; - if ((this.kas.equals(ss.kas)) && (this.splitID.equals(ss.splitID))) { - return true; - } - return false; + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) return false; + KeySplitStep that = (KeySplitStep) o; + return Objects.equals(kas, that.kas) && Objects.equals(splitID, that.splitID) && Objects.equals(kid, that.kid); } @Override public int hashCode() { - return Objects.hash(kas, splitID); + return Objects.hash(kas, splitID, kid); + } + + @Override + public String toString() { + return "KeySplitStep{" + + "kas='" + kas + '\'' + + ", splitID='" + splitID + '\'' + + ", kid='" + kid + '\'' + + '}'; } } // Utility class for an attribute name FQN. - public static class AttributeNameFQN { + static class AttributeNameFQN { private final String url; private final String key; @@ -156,7 +165,7 @@ public String name() throws AutoConfigureException { } // Utility class for an attribute value FQN. - public static class AttributeValueFQN { + static class AttributeValueFQN { private final String url; private final String key; @@ -203,11 +212,11 @@ public int hashCode() { return Objects.hash(key); } - public String getKey() { + String getKey() { return key; } - public String authority() { + String authority() { Pattern pattern = Pattern.compile("^(https?://[\\w./-]+)/attr/\\S*/value/\\S*$"); Matcher matcher = pattern.matcher(url); if (!matcher.find()) { @@ -216,7 +225,7 @@ public String authority() { return matcher.group(1); } - public AttributeNameFQN prefix() throws AutoConfigureException { + AttributeNameFQN prefix() throws AutoConfigureException { Pattern pattern = Pattern.compile("^(https?://[\\w./-]+/attr/\\S*)/value/\\S*$"); Matcher matcher = pattern.matcher(url); if (!matcher.find()) { @@ -225,7 +234,7 @@ public AttributeNameFQN prefix() throws AutoConfigureException { return new AttributeNameFQN(matcher.group(1)); } - public String value() { + String value() { Pattern pattern = Pattern.compile("^https?://[\\w./-]+/attr/\\S*/value/(\\S*)$"); Matcher matcher = pattern.matcher(url); if (!matcher.find()) { @@ -238,7 +247,7 @@ public String value() { } } - public String name() { + String name() { Pattern pattern = Pattern.compile("^https?://[\\w./-]+/attr/(\\S*)/value/\\S*$"); Matcher matcher = pattern.matcher(url); if (!matcher.find()) { @@ -246,13 +255,13 @@ public String name() { } try { return URLDecoder.decode(matcher.group(1), StandardCharsets.UTF_8.name()); - } catch (UnsupportedEncodingException | IllegalArgumentException e) { - throw new RuntimeException("invalid attributeInstance", e); + } catch (UnsupportedEncodingException | IllegalArgumentException e) { + throw new RuntimeException("illegal attribute instance", e); } } } - public static class KeyAccessGrant { + static class KeyAccessGrant { public Attribute attr; public List kases; @@ -266,40 +275,103 @@ public KeyAccessGrant(Attribute attr, List kases) { static class Granter { private final List policy; private final Map grants = new HashMap<>(); + private final Map> mappedKeys = new HashMap<>(); + private boolean hasGrants = false; + private boolean hasMappedKeys = false; - public Granter(List policy) { + Granter(List policy) { this.policy = policy; } - public Map getGrants() { - return new HashMap(grants); + Map getGrants() { + return new HashMap<>(grants); } - public List getPolicy() { + List getPolicy() { return policy; } - public void addGrant(AttributeValueFQN fqn, String kas, Attribute attr) { - grants.computeIfAbsent(fqn.key, k -> new KeyAccessGrant(attr, new ArrayList<>())).kases.add(kas); - } + boolean addAllGrants(AttributeValueFQN fqn, List granted, List mapped, Attribute attr) { + boolean foundMappedKey = false; + for (var mappedKey: mapped) { + foundMappedKey = true; + mappedKeys.computeIfAbsent(fqn.key, k -> new ArrayList<>()).add(Config.KASInfo.fromSimpleKasKey(mappedKey)); + grants.computeIfAbsent(fqn.key, k -> new KeyAccessGrant(attr, new ArrayList<>())).kases.add(mappedKey.getKasUri()); + } - public void addAllGrants(AttributeValueFQN fqn, List gs, Attribute attr) { - if (gs.isEmpty()) { - grants.putIfAbsent(fqn.key, new KeyAccessGrant(attr, new ArrayList<>())); - } else { - for (KeyAccessServer g : gs) { - if (g != null) { - addGrant(fqn, g.getUri(), attr); + if (foundMappedKey) { + hasMappedKeys = true; + return true; + } + + boolean foundGrantedKey = false; + for (var grantedKey: granted) { + foundGrantedKey = true; + grants.computeIfAbsent(fqn.key, k -> new KeyAccessGrant(attr, new ArrayList<>())).kases.add(grantedKey.getUri()); + if (!grantedKey.getKasKeysList().isEmpty()) { + for (var kas : grantedKey.getKasKeysList()) { + mappedKeys.computeIfAbsent(fqn.key, k -> new ArrayList<>()).add(Config.KASInfo.fromSimpleKasKey(kas)); } + continue; } + var cachedGrantKeys = grantedKey.getPublicKey().getCached().getKeysList(); + + if (logger.isDebugEnabled()) { + logger.debug("found {} keys cached in policy service", cachedGrantKeys.size()); + } + + for (var cachedGrantKey: cachedGrantKeys) { + var mappedKey = new Config.KASInfo(); + mappedKey.URL = grantedKey.getUri(); + mappedKey.KID = cachedGrantKey.getKid(); + mappedKey.Algorithm = Autoconfigure.algProto2String(cachedGrantKey.getAlg()); + mappedKey.PublicKey = cachedGrantKey.getPem(); + mappedKey.Default = false; + mappedKeys.computeIfAbsent(fqn.key, k -> new ArrayList<>()).add(mappedKey); + } + } + + if (!grants.containsKey(fqn.key)) { + grants.put(fqn.key, new KeyAccessGrant(attr, new ArrayList<>())); + } + + if (foundGrantedKey) { + hasGrants = true; } + return foundGrantedKey; } - public KeyAccessGrant byAttribute(AttributeValueFQN fqn) { + KeyAccessGrant byAttribute(AttributeValueFQN fqn) { return grants.get(fqn.key); } - public List plan(List defaultKas, Supplier genSplitID) + List getSplits(List defaultKases, Supplier genSplitID, Supplier> baseKeySupplier) throws AutoConfigureException { + if (hasMappedKeys) { + logger.debug("generating plan from mapped keys"); + return planFromAttributes(genSplitID); + } + if (hasGrants) { + logger.debug("generating plan from grants"); + return plan(genSplitID); + } + + var baseKey = baseKeySupplier.get(); + if (baseKey.isPresent()) { + var key = baseKey.get(); + String kas = key.getKasUri(); + String splitID = ""; + String kid = key.getPublicKey().getKid(); + return Collections.singletonList(new KeySplitStep(kas, splitID, kid)); + } + + logger.warn("no grants or mapped keys found, generating plan from default KASes. this is deprecated"); + // this is a little bit weird because we don't take into account the KIDs here. This is the way + // that it works in the go SDK but it seems a bit odd + return generatePlanFromDefaultKases(defaultKases, genSplitID); + } + + @Nonnull + List plan(Supplier genSplitID) throws AutoConfigureException { AttributeBooleanExpression b = constructAttributeBoolean(); BooleanKeyExpression k = insertKeysForAttribute(b); @@ -310,18 +382,7 @@ public List plan(List defaultKas, Supplier genSpli k = k.reduce(); int l = k.size(); if (l == 0) { - // default behavior: split key across all default KAS - if (defaultKas.isEmpty()) { - throw new AutoConfigureException("no default KAS specified; required for grantless plans"); - } else if (defaultKas.size() == 1) { - return Collections.singletonList(new KeySplitStep(defaultKas.get(0), "")); - } else { - List result = new ArrayList<>(); - for (String kas : defaultKas) { - result.add(new KeySplitStep(kas, genSplitID.get())); - } - return result; - } + throw new AutoConfigureException("generated an empty plan"); } List steps = new ArrayList<>(); @@ -334,6 +395,45 @@ public List plan(List defaultKas, Supplier genSpli return steps; } + @Nonnull + List planFromAttributes(Supplier genSplitID) + throws AutoConfigureException { + AttributeBooleanExpression b = constructAttributeBoolean(); + BooleanKeyExpression k = assignKeysTo(b); + if (k == null) { + throw new AutoConfigureException("Error assigning keys to attribute"); + } + + k = k.reduce(); + int l = k.size(); + if (l == 0) { + return Collections.emptyList(); + } + + List steps = new ArrayList<>(); + for (KeyClause v : k.values) { + String splitID = (l > 1) ? genSplitID.get() : ""; + for (PublicKeyInfo o : v.values) { + steps.add(new KeySplitStep(o.kas, splitID, o.kid)); + } + } + return steps; + } + + static List generatePlanFromDefaultKases(List defaultKas, Supplier genSplitID) { + if (defaultKas.isEmpty()) { + throw new AutoConfigureException("no default KAS specified; required for grantless plans"); + } else if (defaultKas.size() == 1) { + return Collections.singletonList(new KeySplitStep(defaultKas.get(0), "")); + } else { + List result = new ArrayList<>(); + for (String kas : defaultKas) { + result.add(new KeySplitStep(kas, genSplitID.get())); + } + return result; + } + } + BooleanKeyExpression insertKeysForAttribute(AttributeBooleanExpression e) throws AutoConfigureException { List kcs = new ArrayList<>(e.must.size()); @@ -361,13 +461,53 @@ BooleanKeyExpression insertKeysForAttribute(AttributeBooleanExpression e) throws logger.warn("Unknown attribute rule type: " + clause); } - KeyClause kc = new KeyClause(op, kcv); - kcs.add(kc); + kcs.add(new KeyClause(op, kcv)); } return new BooleanKeyExpression(kcs); } + BooleanKeyExpression assignKeysTo(AttributeBooleanExpression e) { + var keyClauses = new ArrayList(); + for (var clause : e.must) { + ArrayList keys = new ArrayList<>(); + if (clause.values.isEmpty()) { + logger.warn("No values found for attribute {}", clause.def.getFqn()); + continue; + } + for (var value : clause.values) { + var mapped = mappedKeys.get(value.key); + if (mapped == null) { + logger.warn("No keys found for attribute value {}", value); + continue; + } + for (var kasInfo : mapped) { + if (kasInfo.URL == null || kasInfo.URL.isEmpty()) { + logger.warn("No KAS URL found for attribute value {}", value); + continue; + } + keys.add(new PublicKeyInfo(kasInfo.URL, kasInfo.KID)); + } + } + + String op = ruleToOperator(clause.def.getRule()); + if (op.equals(RuleType.UNSPECIFIED)) { + logger.warn("Unknown attribute rule type {}", op); + } + + keyClauses.add(new KeyClause(op, keys)); + } + + return new BooleanKeyExpression(keyClauses); + } + + /** + * Constructs an AttributeBooleanExpression from the policy, splitting each attribute + * into its own clause. Each clause contains the attribute definition and a list of + * values. + * @return + * @throws AutoConfigureException + */ AttributeBooleanExpression constructAttributeBoolean() throws AutoConfigureException { Map prefixes = new HashMap<>(); List sortedPrefixes = new ArrayList<>(); @@ -378,7 +518,7 @@ AttributeBooleanExpression constructAttributeBoolean() throws AutoConfigureExcep clause.values.add(aP); } else if (byAttribute(aP) != null) { var x = new SingleAttributeClause(byAttribute(aP).attr, - new ArrayList(Arrays.asList(aP))); + new ArrayList<>(Arrays.asList(aP))); prefixes.put(a.getKey(), x); sortedPrefixes.add(a.getKey()); } @@ -391,39 +531,6 @@ AttributeBooleanExpression constructAttributeBoolean() throws AutoConfigureExcep return new AttributeBooleanExpression(must); } - static class AttributeMapping { - - private Map dict; - - public AttributeMapping() { - this.dict = new HashMap<>(); - } - - public void put(Attribute ad) throws AutoConfigureException { - if (this.dict == null) { - this.dict = new HashMap<>(); - } - - AttributeNameFQN prefix = new AttributeNameFQN(ad.getFqn()); - - if (this.dict.containsKey(prefix)) { - throw new AutoConfigureException("Attribute prefix already found: [" + prefix.toString() + "]"); - } - - this.dict.put(prefix, ad); - } - - public Attribute get(AttributeNameFQN prefix) throws AutoConfigureException { - Attribute ad = this.dict.get(prefix); - if (ad == null) { - throw new AutoConfigureException("Unknown attribute type: [" + prefix.toString() + "], not in [" - + this.dict.keySet().toString() + "]"); - } - return ad; - } - - } - static class SingleAttributeClause { private Attribute def; @@ -435,9 +542,9 @@ public SingleAttributeClause(Attribute def, List values) { } } - class AttributeBooleanExpression { + static class AttributeBooleanExpression { - private List must; + private final List must; public AttributeBooleanExpression(List must) { this.must = must; @@ -478,25 +585,65 @@ public String toString() { } - public class PublicKeyInfo { - private String kas; + static class PublicKeyInfo implements Comparable { + final String kas; + final String kid; + + PublicKeyInfo(String kas) { + this(kas, null); + } - public PublicKeyInfo(String kas) { + PublicKeyInfo(String kas, String kid) { this.kas = kas; + this.kid = kid; } - public String getKas() { + String getKas() { return kas; } - public void setKas(String kas) { - this.kas = kas; + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) return false; + PublicKeyInfo that = (PublicKeyInfo) o; + return Objects.equals(kas, that.kas) && Objects.equals(kid, that.kid); + } + + @Override + public int hashCode() { + return Objects.hash(kas, kid); + } + + @Override + public String toString() { + return "PublicKeyInfo{" + + "kas='" + kas + '\'' + + ", kid='" + kid + '\'' + + '}'; + } + + @Override + public int compareTo(PublicKeyInfo o) { + if (this.kas.equals(o.kas)) { + if (this.kid == null && o.kid == null) { + return 0; + } + if (this.kid == null) { + return -1; + } + if (o.kid == null) { + return 1; + } + return this.kid.compareTo(o.kid); + } else { + return this.kas.compareTo(o.kas); + } } } - public class KeyClause { - private String operator; - private List values; + static class KeyClause { + private final String operator; + private final List values; public KeyClause(String operator, List values) { this.operator = operator; @@ -531,8 +678,8 @@ public String toString() { } } - public class BooleanKeyExpression { - private List values; + static class BooleanKeyExpression { + private final List values; public BooleanKeyExpression(List values) { this.values = values; @@ -572,7 +719,7 @@ public BooleanKeyExpression reduce() { continue; } Disjunction terms = new Disjunction(); - terms.add(k.getKas()); + terms.add(k); if (!within(conjunction, terms)) { conjunction.add(terms); } @@ -584,25 +731,22 @@ public BooleanKeyExpression reduce() { } List newValues = new ArrayList<>(); - for (List d : conjunction) { + for (List d : conjunction) { List pki = new ArrayList<>(); - for (String k : d) { - pki.add(new PublicKeyInfo(k)); - } + pki.addAll(d); newValues.add(new KeyClause(RuleType.ANY_OF, pki)); } return new BooleanKeyExpression(newValues); } public Disjunction sortedNoDupes(List l) { - Set set = new HashSet<>(); + Set set = new HashSet<>(); Disjunction list = new Disjunction(); for (PublicKeyInfo e : l) { - String kas = e.getKas(); - if (!kas.equals(RuleType.EMPTY_TERM) && !set.contains(kas)) { - set.add(kas); - list.add(kas); + if (!Objects.equals(e.getKas(), RuleType.EMPTY_TERM) && !set.contains(e)) { + set.add(e); + list.add(e); } } @@ -612,7 +756,7 @@ public Disjunction sortedNoDupes(List l) { } - class Disjunction extends ArrayList { + static class Disjunction extends ArrayList { public boolean less(Disjunction r) { int m = Math.min(this.size(), r.size()); @@ -683,7 +827,7 @@ public static String ruleToOperator(AttributeRuleTypeEnum e) { // Given a policy (list of data attributes or tags), // get a set of grants from attribute values to KASes. // Unlike `NewGranterFromService`, this works offline. - public static Granter newGranterFromAttributes(Value... attrValues) throws AutoConfigureException { + static Granter newGranterFromAttributes(KASKeyCache keyCache, Value... attrValues) throws AutoConfigureException { var attrsAndValues = Arrays.stream(attrValues).map(v -> { if (!v.hasAttribute()) { throw new AutoConfigureException("tried to use an attribute that is not initialized"); @@ -694,11 +838,11 @@ public static Granter newGranterFromAttributes(Value... attrValues) throws AutoC .build(); }).collect(Collectors.toList()); - return getGranter(null, attrsAndValues); + return getGranter(keyCache, attrsAndValues); } // Gets a list of directory of KAS grants for a list of attribute FQNs - public static Granter newGranterFromService(AttributesServiceClientInterface as, KASKeyCache keyCache, AttributeValueFQN... fqns) throws AutoConfigureException { + static Granter newGranterFromService(AttributesServiceClientInterface as, KASKeyCache keyCache, AttributeValueFQN... fqns) throws AutoConfigureException { GetAttributeValuesByFqnsRequest request = GetAttributeValuesByFqnsRequest.newBuilder() .addAllFqns(Arrays.stream(fqns).map(AttributeValueFQN::toString).collect(Collectors.toList())) .setWithValue(AttributeValueSelector.newBuilder().setWithKeyAccessGrants(true).build()) @@ -711,73 +855,65 @@ public static Granter newGranterFromService(AttributesServiceClientInterface as, return getGranter(keyCache, new ArrayList<>(av.getFqnAttributeValuesMap().values())); } - private static List getGrants(GetAttributeValuesByFqnsResponse.AttributeAndValue attributeAndValue) { - var val = attributeAndValue.getValue(); - var attribute = attributeAndValue.getAttribute(); - if (!val.getGrantsList().isEmpty()) { - if (logger.isDebugEnabled()) { - logger.debug("adding grants from attribute value [{}]: {}", val.getFqn(), val.getGrantsList().stream().map(KeyAccessServer::getUri).collect(Collectors.toList())); - } - return val.getGrantsList(); - } else if (!attribute.getGrantsList().isEmpty()) { - var attributeGrants = attribute.getGrantsList(); - if (logger.isDebugEnabled()) { - logger.debug("adding grants from attribute [{}]: {}", attribute.getFqn(), attributeGrants.stream().map(KeyAccessServer::getId).collect(Collectors.toList())); - } - return attributeGrants; - } else if (!attribute.getNamespace().getGrantsList().isEmpty()) { - var nsGrants = attribute.getNamespace().getGrantsList(); - if (logger.isDebugEnabled()) { - logger.debug("adding grants from namespace [{}]: [{}]", attribute.getNamespace().getName(), nsGrants.stream().map(KeyAccessServer::getId).collect(Collectors.toList())); - } - return nsGrants; - } else { - // this is needed to mark the fact that we have an empty - if (logger.isDebugEnabled()) { - logger.debug("didn't find any grants on value, attribute, or namespace for attribute value [{}]", val.getFqn()); - } - return Collections.emptyList(); + static Autoconfigure.Granter createGranter(SDK.Services services, Config.TDFConfig tdfConfig) { + Autoconfigure.Granter granter = new Autoconfigure.Granter(new ArrayList<>()); + if (tdfConfig.attributeValues != null && !tdfConfig.attributeValues.isEmpty()) { + granter = Autoconfigure.newGranterFromAttributes(services.kas().getKeyCache(), tdfConfig.attributeValues.toArray(new Value[0])); + } else if (tdfConfig.attributes != null && !tdfConfig.attributes.isEmpty()) { + granter = Autoconfigure.newGranterFromService(services.attributes(), services.kas().getKeyCache(), + tdfConfig.attributes.toArray(new Autoconfigure.AttributeValueFQN[0])); } - + return granter; } private static Granter getGranter(@Nullable KASKeyCache keyCache, List values) { - Granter grants = new Granter(values.stream().map(GetAttributeValuesByFqnsResponse.AttributeAndValue::getValue).map(Value::getFqn).map(AttributeValueFQN::new).collect(Collectors.toList())); + List attributeValues = values.stream() + .map(GetAttributeValuesByFqnsResponse.AttributeAndValue::getValue) + .map(Value::getFqn) + .map(AttributeValueFQN::new) + .collect(Collectors.toList()); + Granter grants = new Granter(attributeValues); for (var attributeAndValue: values) { - var attributeGrants = getGrants(attributeAndValue); String fqnstr = attributeAndValue.getValue().getFqn(); AttributeValueFQN fqn = new AttributeValueFQN(fqnstr); - grants.addAllGrants(fqn, attributeGrants, attributeAndValue.getAttribute()); - if (keyCache != null) { - storeKeysToCache(attributeGrants, keyCache); + + var value = attributeAndValue.getValue(); + var attribute = attributeAndValue.getAttribute(); + var namespace = attribute.getNamespace(); + + if (grants.addAllGrants(fqn, value.getGrantsList(), value.getKasKeysList(), attribute)) { + storeKeysToCache(value.getGrantsList(), value.getKasKeysList(), keyCache); + continue; + } + if (grants.addAllGrants(fqn, attribute.getGrantsList(), attribute.getKasKeysList(), attribute)) { + storeKeysToCache(attribute.getGrantsList(), attribute.getKasKeysList(), keyCache); + continue; + } + if (grants.addAllGrants(fqn, namespace.getGrantsList(), namespace.getKasKeysList(), attribute)) { + storeKeysToCache(namespace.getGrantsList(), namespace.getKasKeysList(), keyCache); } } return grants; } - - static void storeKeysToCache(List kases, KASKeyCache keyCache) { - for (KeyAccessServer kas : kases) { - List keys = kas.getPublicKey().getCached().getKeysList(); - if (keys.isEmpty()) { - logger.debug("No cached key in policy service for KAS: " + kas.getUri()); - continue; - } - for (KasPublicKey ki : keys) { - Config.KASInfo kasInfo = new Config.KASInfo(); - kasInfo.URL = kas.getUri(); - kasInfo.KID = ki.getKid(); - kasInfo.Algorithm = algProto2String(ki.getAlg()); - kasInfo.PublicKey = ki.getPem(); - keyCache.store(kasInfo); - } + static void storeKeysToCache(List kases, List kasKeys, KASKeyCache keyCache) { + if (keyCache == null) { + return; + } + for (var kas : kases) { + Config.KASInfo.fromKeyAccessServer(kas).forEach(keyCache::store); } + kasKeys.stream().map(Config.KASInfo::fromSimpleKasKey).forEach(keyCache::store); + } + + static String algProto2String(Algorithm e) { + return KeyType.fromAlgorithm(e).getCurveName(); } - private static String algProto2String(KasPublicKeyAlgEnum e) { + static String algProto2String(KasPublicKeyAlgEnum e) { switch (e) { case KAS_PUBLIC_KEY_ALG_ENUM_EC_SECP256R1: return "ec:secp256r1"; diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/Config.java b/sdk/src/main/java/io/opentdf/platform/sdk/Config.java index 0a1f4a46..34e4fa27 100644 --- a/sdk/src/main/java/io/opentdf/platform/sdk/Config.java +++ b/sdk/src/main/java/io/opentdf/platform/sdk/Config.java @@ -1,13 +1,20 @@ package io.opentdf.platform.sdk; +import io.opentdf.platform.policy.KeyAccessServer; +import io.opentdf.platform.policy.SimpleKasKey; import io.opentdf.platform.policy.Value; import io.opentdf.platform.sdk.Autoconfigure.AttributeValueFQN; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.net.URI; import java.net.URISyntaxException; import java.util.*; import java.util.function.Consumer; import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static io.opentdf.platform.sdk.Autoconfigure.algProto2String; /** * Configuration class for setting various configurations related to TDF. @@ -22,6 +29,7 @@ public class Config { public static final String KAS_PUBLIC_KEY_PATH = "/kas_public_key"; public static final String DEFAULT_MIME_TYPE = "application/octet-stream"; public static final int MAX_COLLECTION_ITERATION = (1 << 24) - 1; + private static Logger logger = LoggerFactory.getLogger(Config.class); public enum TDFFormat { JSONFormat, @@ -33,8 +41,6 @@ public enum IntegrityAlgorithm { GMAC } - public static final int K_HTTP_OK = 200; - public static class KASInfo implements Cloneable { public String URL; public String PublicKey; @@ -71,6 +77,36 @@ public String toString() { } return sb.append("}").toString(); } + + public static List fromKeyAccessServer(KeyAccessServer kas) { + var keys = kas.getPublicKey().getCached().getKeysList(); + if (keys.isEmpty()) { + logger.warn("Invalid KAS key mapping for kas [{}]: publicKey is empty", kas.getUri()); + return Collections.emptyList(); + } + return keys.stream().flatMap(ki -> { + if (ki.getPem().isEmpty()) { + logger.warn("Invalid KAS key mapping for kas [{}]: publicKey PEM is empty", kas.getUri()); + return Stream.empty(); + } + Config.KASInfo kasInfo = new Config.KASInfo(); + kasInfo.URL = kas.getUri(); + kasInfo.KID = ki.getKid(); + kasInfo.Algorithm = algProto2String(ki.getAlg()); + kasInfo.PublicKey = ki.getPem(); + return Stream.of(kasInfo); + }).collect(Collectors.toList()); + } + + public static KASInfo fromSimpleKasKey(SimpleKasKey ki) { + Config.KASInfo kasInfo = new Config.KASInfo(); + kasInfo.URL = ki.getKasUri(); + kasInfo.KID = ki.getPublicKey().getKid(); + kasInfo.Algorithm = algProto2String(ki.getPublicKey().getAlgorithm()); + kasInfo.PublicKey = ki.getPublicKey().getPem(); + + return kasInfo; + } } public static class AssertionVerificationKeys { diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/ECCMode.java b/sdk/src/main/java/io/opentdf/platform/sdk/ECCMode.java index ce000d54..bcb8b208 100644 --- a/sdk/src/main/java/io/opentdf/platform/sdk/ECCMode.java +++ b/sdk/src/main/java/io/opentdf/platform/sdk/ECCMode.java @@ -1,5 +1,7 @@ package io.opentdf.platform.sdk; +import java.util.Objects; + public class ECCMode { private ECCModeStruct data; @@ -115,6 +117,18 @@ public static int getECCompressedPubKeySize(NanoTDFType.ECCurve curve) { } } + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) return false; + ECCMode eccMode = (ECCMode) o; + return Objects.equals(data, eccMode.data); + } + + @Override + public int hashCode() { + return Objects.hashCode(data); + } + private class ECCModeStruct { int curveMode; int unused; diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/ECKeyPair.java b/sdk/src/main/java/io/opentdf/platform/sdk/ECKeyPair.java index 53095bd7..88680efc 100644 --- a/sdk/src/main/java/io/opentdf/platform/sdk/ECKeyPair.java +++ b/sdk/src/main/java/io/opentdf/platform/sdk/ECKeyPair.java @@ -197,18 +197,12 @@ public static byte[] compressECPublickey(String pemECPubKey) { PemObject pemObject = pemReader.readPemObject(); PublicKey pubKey = ecKeyFac.generatePublic(new X509EncodedKeySpec(pemObject.getContent())); return ((ECPublicKey) pubKey).getQ().getEncoded(true); - } catch (NoSuchAlgorithmException e) { - throw new RuntimeException(e); - } catch (IOException e) { - throw new RuntimeException(e); - } catch (InvalidKeySpecException e) { - throw new RuntimeException(e); - } catch (NoSuchProviderException e) { + } catch (NoSuchAlgorithmException | IOException | InvalidKeySpecException | NoSuchProviderException e) { throw new RuntimeException(e); } } - public static String publicKeyFromECPoint(byte[] ecPoint, String curveName) { + public static String publicKeyFromECPoint(byte[] ecPoint, String curveName) throws RuntimeException { try { // Create EC Public key ECNamedCurveParameterSpec ecSpec = ECNamedCurveTable.getParameterSpec(curveName); @@ -274,7 +268,7 @@ public static byte[] computeECDHKey(ECPublicKey publicKey, ECPrivateKey privateK } public static byte[] calculateHKDF(byte[] salt, byte[] secret) { - byte[] key = new byte[secret.length]; + byte[] key = new byte[32]; HKDFParameters params = new HKDFParameters(secret, salt, null); HKDFBytesGenerator hkdf = new HKDFBytesGenerator(SHA256Digest.newInstance()); diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/KASClient.java b/sdk/src/main/java/io/opentdf/platform/sdk/KASClient.java index bafc3f2c..bfa1ffdd 100644 --- a/sdk/src/main/java/io/opentdf/platform/sdk/KASClient.java +++ b/sdk/src/main/java/io/opentdf/platform/sdk/KASClient.java @@ -85,7 +85,7 @@ public KASInfo getECPublicKey(Config.KASInfo kasInfo, NanoTDFType.ECCurve curve) @Override public Config.KASInfo getPublicKey(Config.KASInfo kasInfo) { - Config.KASInfo cachedValue = this.kasKeyCache.get(kasInfo.URL, kasInfo.Algorithm); + Config.KASInfo cachedValue = this.kasKeyCache.get(kasInfo.URL, kasInfo.Algorithm, kasInfo.KID); if (cachedValue != null) { return cachedValue; } diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/KASKeyCache.java b/sdk/src/main/java/io/opentdf/platform/sdk/KASKeyCache.java index 5879dd05..75bae93c 100644 --- a/sdk/src/main/java/io/opentdf/platform/sdk/KASKeyCache.java +++ b/sdk/src/main/java/io/opentdf/platform/sdk/KASKeyCache.java @@ -7,6 +7,7 @@ import java.time.temporal.ChronoUnit; import java.util.HashMap; import java.util.Map; +import java.util.Objects; /** * Class representing a cache for KAS (Key Access Server) information. @@ -24,14 +25,14 @@ public void clear() { this.cache = new HashMap<>(); } - public Config.KASInfo get(String url, String algorithm) { - log.debug("retrieving kasinfo for url = [{}], algorithm = [{}]", url, algorithm); - KASKeyRequest cacheKey = new KASKeyRequest(url, algorithm); + public Config.KASInfo get(String url, String algorithm, String kid) { + log.debug("retrieving kasinfo for url = [{}], algorithm = [{}], kid = [{}]", url, algorithm, kid); + KASKeyRequest cacheKey = new KASKeyRequest(url, algorithm, kid); LocalDateTime now = LocalDateTime.now(); TimeStampedKASInfo cachedValue = cache.get(cacheKey); if (cachedValue == null) { - log.debug("didn't find kasinfo for url = [{}], algorithm = [{}]", url, algorithm); + log.debug("didn't find kasinfo for key= [{}]", cacheKey); return null; } @@ -49,7 +50,7 @@ public Config.KASInfo get(String url, String algorithm) { public void store(Config.KASInfo kasInfo) { log.debug("storing kasInfo into the cache {}", kasInfo); - KASKeyRequest cacheKey = new KASKeyRequest(kasInfo.URL, kasInfo.Algorithm); + KASKeyRequest cacheKey = new KASKeyRequest(kasInfo.URL, kasInfo.Algorithm, kasInfo.KID); cache.put(cacheKey, new TimeStampedKASInfo(kasInfo, LocalDateTime.now())); } } @@ -85,30 +86,34 @@ public TimeStampedKASInfo(Config.KASInfo kasInfo, LocalDateTime timestamp) { class KASKeyRequest { private String url; private String algorithm; + private String kid; - public KASKeyRequest(String url, String algorithm) { + public KASKeyRequest(String url, String algorithm, String kid) { this.url = url; this.algorithm = algorithm; + this.kid = kid; } - // Override equals and hashCode to ensure proper functioning of the HashMap @Override public boolean equals(Object o) { - if (this == o) return true; - if (o == null || !(o instanceof KASKeyRequest)) return false; + if (o == null || getClass() != o.getClass()) return false; KASKeyRequest that = (KASKeyRequest) o; - if (algorithm == null){ - return url.equals(that.url); - } - return url.equals(that.url) && algorithm.equals(that.algorithm); + return Objects.equals(url, that.url) && Objects.equals(algorithm, that.algorithm) && Objects.equals(kid, that.kid); } @Override public int hashCode() { - int result = 31 * url.hashCode(); - if (algorithm != null) { - result = result + algorithm.hashCode(); - } - return result; + return Objects.hash(url, algorithm, kid); + } + + @Override + public String toString() { + return "KASKeyRequest{" + + "url='" + url + '\'' + + ", algorithm='" + algorithm + '\'' + + ", kid='" + kid + '\'' + + '}'; } + + // Override equals and hashCode to ensure proper functioning of the HashMap } \ No newline at end of file diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/KeyType.java b/sdk/src/main/java/io/opentdf/platform/sdk/KeyType.java index 9c5cf010..3b49668b 100644 --- a/sdk/src/main/java/io/opentdf/platform/sdk/KeyType.java +++ b/sdk/src/main/java/io/opentdf/platform/sdk/KeyType.java @@ -1,5 +1,9 @@ package io.opentdf.platform.sdk; +import io.opentdf.platform.policy.Algorithm; + +import java.util.Set; + public enum KeyType { RSA2048Key("rsa:2048"), EC256Key("ec:secp256r1"), @@ -39,7 +43,24 @@ public static KeyType fromString(String keyType) { throw new IllegalArgumentException("No enum constant for key type: " + keyType); } + public static KeyType fromAlgorithm(Algorithm a) { + switch (a) { + case ALGORITHM_RSA_2048: + return RSA2048Key; + case ALGORITHM_EC_P256: + return EC256Key; + case ALGORITHM_EC_P384: + return EC384Key; + case ALGORITHM_EC_P521: + return EC521Key; + default: + throw new IllegalArgumentException("Unsupported algorithm: " + a); + } + } + + private static final Set EC_KEY_TYPES = Set.of(EC256Key, EC384Key, EC521Key); + public boolean isEc() { - return this != RSA2048Key; + return EC_KEY_TYPES.contains(this); } } \ No newline at end of file diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/NanoTDF.java b/sdk/src/main/java/io/opentdf/platform/sdk/NanoTDF.java index e0b41549..9b798962 100644 --- a/sdk/src/main/java/io/opentdf/platform/sdk/NanoTDF.java +++ b/sdk/src/main/java/io/opentdf/platform/sdk/NanoTDF.java @@ -16,6 +16,8 @@ import com.google.gson.Gson; import com.google.gson.GsonBuilder; +import io.opentdf.platform.wellknownconfiguration.WellKnownServiceClientInterface; +import org.bouncycastle.jcajce.provider.asymmetric.ec.KeyFactorySpi; import org.bouncycastle.jce.interfaces.ECPublicKey; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -65,6 +67,17 @@ public InvalidNanoTDFConfig(String errorMessage) { } } + private static Optional getBaseKey(WellKnownServiceClientInterface wellKnownService) { + var key = Planner.fetchBaseKey(wellKnownService); + key.ifPresent(k -> { + if (!KeyType.fromAlgorithm(k.getPublicKey().getAlgorithm()).isEc()) { + throw new SDKException(String.format("base key is not an EC key, cannot create NanoTDF using a key of type %s", + k.getPublicKey().getAlgorithm())); + } + }); + return key.map(Config.KASInfo::fromSimpleKasKey); + } + private Config.HeaderInfo getHeaderInfo(Config.NanoTDFConfig nanoTDFConfig) throws InvalidNanoTDFConfig, UnsupportedNanoTDFFeature { if (nanoTDFConfig.collectionConfig.useCollection) { Config.HeaderInfo headerInfo = nanoTDFConfig.collectionConfig.getHeaderInfo(); @@ -74,22 +87,19 @@ private Config.HeaderInfo getHeaderInfo(Config.NanoTDFConfig nanoTDFConfig) thro } Gson gson = new GsonBuilder().create(); - if (nanoTDFConfig.kasInfoList.isEmpty()) { - throw new InvalidNanoTDFConfig("kas url is missing"); - } + Optional maybeKas = getKasInfo(nanoTDFConfig).or(() -> NanoTDF.getBaseKey(services.wellknown())); - Config.KASInfo kasInfo = nanoTDFConfig.kasInfoList.get(0); - String url = kasInfo.URL; - if (kasInfo.PublicKey == null || kasInfo.PublicKey.isEmpty()) { - logger.info("no public key provided for KAS at {}, retrieving", url); - kasInfo = services.kas().getECPublicKey(kasInfo, nanoTDFConfig.eccMode.getEllipticCurveType()); + if (maybeKas.isEmpty()) { + throw new SDKException("no KAS info provided and couldn't get base key, cannot create NanoTDF"); } + var kasInfo = maybeKas.get(); + // Kas url resource locator - ResourceLocator kasURL = new ResourceLocator(nanoTDFConfig.kasInfoList.get(0).URL, kasInfo.KID); + ResourceLocator kasURL = new ResourceLocator(kasInfo.URL, kasInfo.KID); assert kasURL.getIdentifier() != null : "Identifier in ResourceLocator cannot be null"; - ECKeyPair keyPair = new ECKeyPair(nanoTDFConfig.eccMode.getCurveName(), ECKeyPair.ECAlgorithm.ECDSA); + ECKeyPair keyPair = new ECKeyPair(kasInfo.Algorithm, ECKeyPair.ECAlgorithm.ECDSA); // Generate symmetric key ECPublicKey kasPublicKey = ECKeyPair.publicKeyFromPem(kasInfo.PublicKey); @@ -138,7 +148,12 @@ private Config.HeaderInfo getHeaderInfo(Config.NanoTDFConfig nanoTDFConfig) thro // Create header byte[] compressedPubKey = keyPair.compressECPublickey(); Header header = new Header(); - header.setECCMode(nanoTDFConfig.eccMode); + var mode = new ECCMode(); + mode.setEllipticCurve(Enum.valueOf(NanoTDFType.ECCurve.class, keyPair.curveName())); + if (logger.isWarnEnabled() && !nanoTDFConfig.eccMode.equals(mode)) { + logger.warn("ECC mode provided in NanoTDFConfig: {}, ECC mode from key: {}", nanoTDFConfig.eccMode, mode); + } + header.setECCMode(mode); header.setPayloadConfig(nanoTDFConfig.config); header.setEphemeralKey(compressedPubKey); header.setKasLocator(kasURL); @@ -152,6 +167,21 @@ private Config.HeaderInfo getHeaderInfo(Config.NanoTDFConfig nanoTDFConfig) thro return headerInfo; } + private Optional getKasInfo(Config.NanoTDFConfig nanoTDFConfig) { + if (nanoTDFConfig.kasInfoList.isEmpty()) { + logger.debug("no kas info provided in NanoTDFConfig"); + return Optional.empty(); + } + + Config.KASInfo kasInfo = nanoTDFConfig.kasInfoList.get(0); + String url = kasInfo.URL; + if (kasInfo.PublicKey == null || kasInfo.PublicKey.isEmpty()) { + logger.info("no public key provided for KAS at {}, retrieving", url); + kasInfo = services.kas().getECPublicKey(kasInfo, nanoTDFConfig.eccMode.getEllipticCurveType()); + } + return Optional.of(kasInfo); + } + public int createNanoTDF(ByteBuffer data, OutputStream outputStream, Config.NanoTDFConfig nanoTDFConfig) throws SDKException, IOException { int nanoTDFSize = 0; diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/Planner.java b/sdk/src/main/java/io/opentdf/platform/sdk/Planner.java new file mode 100644 index 00000000..a40a1ddd --- /dev/null +++ b/sdk/src/main/java/io/opentdf/platform/sdk/Planner.java @@ -0,0 +1,186 @@ +package io.opentdf.platform.sdk; + +import com.connectrpc.ConnectException; +import com.google.gson.Gson; +import com.google.gson.JsonSyntaxException; +import com.google.gson.annotations.SerializedName; +import io.opentdf.platform.policy.Algorithm; +import io.opentdf.platform.policy.SimpleKasKey; +import io.opentdf.platform.policy.SimpleKasPublicKey; +import io.opentdf.platform.wellknownconfiguration.GetWellKnownConfigurationRequest; +import io.opentdf.platform.wellknownconfiguration.GetWellKnownConfigurationResponse; +import io.opentdf.platform.wellknownconfiguration.WellKnownServiceClientInterface; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.UUID; +import java.util.function.BiFunction; + + +public class Planner { + private static final String BASE_KEY = "base_key"; + private final Config.TDFConfig tdfConfig; + private final SDK.Services services; + private final BiFunction granterFactory; + + + private static final Logger logger = LoggerFactory.getLogger(Planner.class); + + public Planner(Config.TDFConfig config, SDK.Services services, BiFunction granterFactory) { + this.tdfConfig = Objects.requireNonNull(config); + this.services = Objects.requireNonNull(services); + this.granterFactory = granterFactory; + } + + private static String getUUID() { + return UUID.randomUUID().toString(); + } + + Map> getSplits(Config.TDFConfig tdfConfig) { + List splitPlan; + if (tdfConfig.autoconfigure) { + if (tdfConfig.splitPlan != null && !tdfConfig.splitPlan.isEmpty()) { + throw new IllegalArgumentException("cannot use autoconfigure with a split plan provided in the TDFConfig"); + } + splitPlan = getAutoconfigurePlan(services, tdfConfig); + } else if (tdfConfig.splitPlan == null || tdfConfig.splitPlan.isEmpty()) { + splitPlan = generatePlanFromProvidedKases(tdfConfig.kasInfoList); + } else { + splitPlan = tdfConfig.splitPlan; + } + + if (tdfConfig.kasInfoList.isEmpty() && splitPlan.isEmpty()) { + throw new SDK.KasInfoMissing("kas information is missing, no key access template specified or inferred"); + } + return resolveKeys(splitPlan); + } + + private List getAutoconfigurePlan(SDK.Services services, Config.TDFConfig tdfConfig) { + Autoconfigure.Granter granter = granterFactory.apply(services, tdfConfig); + return granter.getSplits(defaultKases(tdfConfig), Planner::getUUID, () -> Planner.fetchBaseKey(services.wellknown())); + } + + + List generatePlanFromProvidedKases(List kases) { + if (kases.size() == 1) { + var kasInfo = kases.get(0); + return Collections.singletonList(new Autoconfigure.KeySplitStep(kasInfo.URL, "", kasInfo.KID)); + } + List splitPlan = new ArrayList<>(); + for (var kasInfo : kases) { + splitPlan.add(new Autoconfigure.KeySplitStep(kasInfo.URL, getUUID(), kasInfo.KID)); + } + return splitPlan; + } + + static Optional fetchBaseKey(WellKnownServiceClientInterface wellknown) { + var responseMessage = wellknown + .getWellKnownConfigurationBlocking(GetWellKnownConfigurationRequest.getDefaultInstance(), Collections.emptyMap()) + .execute(); + GetWellKnownConfigurationResponse response; + try { + response = RequestHelper.getOrThrow(responseMessage); + } catch (ConnectException e) { + throw new SDKException("unable to retrieve base key from well known endpoint", e); + } + + String baseKeyJson; + try { + baseKeyJson = response + .getConfiguration() + .getFieldsOrThrow(BASE_KEY) + .getStringValue(); + } catch (IllegalArgumentException e) { + logger.info( "no `" + BASE_KEY + "` found in well known configuration.", e); + return Optional.empty(); + } + + BaseKey baseKey; + try { + baseKey = gson.fromJson(baseKeyJson, BaseKey.class); + } catch (JsonSyntaxException e) { + throw new SDKException("base key in well known configuration is malformed [" + baseKeyJson + "]", e); + } + + if (baseKey == null || baseKey.kasUrl == null || baseKey.publicKey == null || baseKey.publicKey.kid == null || baseKey.publicKey.pem == null || baseKey.publicKey.algorithm == null) { + logger.error("base key in well known configuration is missing required fields [{}]. base key will not be used", baseKeyJson); + return Optional.empty(); + } + + return Optional.of(SimpleKasKey.newBuilder() + .setKasUri(baseKey.kasUrl) + .setPublicKey( + SimpleKasPublicKey.newBuilder() + .setKid(baseKey.publicKey.kid) + .setAlgorithm(baseKey.publicKey.algorithm) + .setPem(baseKey.publicKey.pem) + .build()) + .build()); + } + + private static final Gson gson = new Gson(); + + private static class BaseKey { + @SerializedName("kas_url") + String kasUrl; + + @SerializedName("public_key") + Key publicKey; + + private static class Key { + String kid; + String pem; + Algorithm algorithm; + } + } + + Map> resolveKeys(List splitPlan) { + Map> conjunction = new HashMap<>(); + var latestKASInfo = new HashMap(); + // Seed anything passed in manually + for (Config.KASInfo kasInfo : tdfConfig.kasInfoList) { + if (kasInfo.PublicKey != null && !kasInfo.PublicKey.isEmpty()) { + latestKASInfo.put(kasInfo.URL, kasInfo); + } + } + + for (Autoconfigure.KeySplitStep splitInfo: splitPlan) { + // Public key was passed in with kasInfoList + // TODO First look up in attribute information / add to split plan? + Config.KASInfo ki = latestKASInfo.get(splitInfo.kas); + if (ki == null || ki.PublicKey == null || ki.PublicKey.isBlank() || (splitInfo.kid != null && !splitInfo.kid.equals(ki.KID))) { + logger.info("no public key provided for KAS at {}, retrieving", splitInfo.kas); + var getKI = new Config.KASInfo(); + getKI.URL = splitInfo.kas; + getKI.Algorithm = tdfConfig.wrappingKeyType.toString(); + getKI.KID = splitInfo.kid; + getKI = services.kas().getPublicKey(getKI); + latestKASInfo.put(splitInfo.kas, getKI); + ki = getKI; + } + conjunction.computeIfAbsent(splitInfo.splitID, s -> new ArrayList<>()).add(ki); + } + return conjunction; + } + + static List defaultKases(Config.TDFConfig config) { + List allk = new ArrayList<>(); + List defk = new ArrayList<>(); + + for (Config.KASInfo kasInfo : config.kasInfoList) { + if (kasInfo.Default != null && kasInfo.Default) { + defk.add(kasInfo.URL); + } else if (defk.isEmpty()) { + allk.add(kasInfo.URL); + } + } + return defk.isEmpty() ? allk : defk; + } +} diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/SDK.java b/sdk/src/main/java/io/opentdf/platform/sdk/SDK.java index 9195fca5..a0db542d 100644 --- a/sdk/src/main/java/io/opentdf/platform/sdk/SDK.java +++ b/sdk/src/main/java/io/opentdf/platform/sdk/SDK.java @@ -9,6 +9,7 @@ import io.opentdf.platform.policy.namespaces.NamespaceServiceClientInterface; import io.opentdf.platform.policy.resourcemapping.ResourceMappingServiceClientInterface; import io.opentdf.platform.policy.subjectmapping.SubjectMappingServiceClientInterface; +import io.opentdf.platform.wellknownconfiguration.WellKnownServiceClientInterface; import javax.net.ssl.TrustManager; import java.io.IOException; @@ -75,6 +76,8 @@ public interface Services extends AutoCloseable { KeyAccessServerRegistryServiceClientInterface kasRegistry(); + WellKnownServiceClientInterface wellknown(); + KAS kas(); } diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/SDKBuilder.java b/sdk/src/main/java/io/opentdf/platform/sdk/SDKBuilder.java index dd83f75a..03b1943c 100644 --- a/sdk/src/main/java/io/opentdf/platform/sdk/SDKBuilder.java +++ b/sdk/src/main/java/io/opentdf/platform/sdk/SDKBuilder.java @@ -33,6 +33,7 @@ import io.opentdf.platform.wellknownconfiguration.GetWellKnownConfigurationRequest; import io.opentdf.platform.wellknownconfiguration.GetWellKnownConfigurationResponse; import io.opentdf.platform.wellknownconfiguration.WellKnownServiceClient; +import io.opentdf.platform.wellknownconfiguration.WellKnownServiceClientInterface; import nl.altindag.ssl.SSLFactory; import nl.altindag.ssl.pem.util.PemUtils; import okhttp3.OkHttpClient; @@ -251,6 +252,7 @@ ServicesAndInternals buildServices() { var resourceMappingService = new ResourceMappingServiceClient(client); var authorizationService = new AuthorizationServiceClient(client); var kasRegistryService = new KeyAccessServerRegistryServiceClient(client); + var wellKnownService = new WellKnownServiceClient(client); var services = new SDK.Services() { @Override @@ -290,6 +292,11 @@ public KeyAccessServerRegistryServiceClient kasRegistry() { return kasRegistryService; } + @Override + public WellKnownServiceClientInterface wellknown() { + return wellKnownService; + } + @Override public SDK.KAS kas() { return kasClient; diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/TDF.java b/sdk/src/main/java/io/opentdf/platform/sdk/TDF.java index fe547caa..cfe75d26 100644 --- a/sdk/src/main/java/io/opentdf/platform/sdk/TDF.java +++ b/sdk/src/main/java/io/opentdf/platform/sdk/TDF.java @@ -5,11 +5,8 @@ import com.google.gson.GsonBuilder; import com.nimbusds.jose.*; -import io.opentdf.platform.policy.Value; import io.opentdf.platform.policy.kasregistry.ListKeyAccessServersRequest; import io.opentdf.platform.policy.kasregistry.ListKeyAccessServersResponse; -import io.opentdf.platform.sdk.Config.TDFConfig; -import io.opentdf.platform.sdk.Autoconfigure.AttributeValueFQN; import io.opentdf.platform.sdk.Config.KASInfo; import org.apache.commons.codec.DecoderException; @@ -142,7 +139,7 @@ private PolicyObject createPolicyObject(List at private static final Base64.Encoder encoder = Base64.getEncoder(); - private void prepareManifest(Config.TDFConfig tdfConfig, SDK.KAS kas) { + private void prepareManifest(Config.TDFConfig tdfConfig, Map> splits) { manifest.tdfVersion = tdfConfig.renderVersionInfoInManifest ? TDF_VERSION : null; manifest.encryptionInformation.keyAccessType = kSplitKeyType; manifest.encryptionInformation.keyAccessObj = new ArrayList<>(); @@ -150,60 +147,12 @@ private void prepareManifest(Config.TDFConfig tdfConfig, SDK.KAS kas) { PolicyObject policyObject = createPolicyObject(tdfConfig.attributes); String base64PolicyObject = encoder .encodeToString(gson.toJson(policyObject).getBytes(StandardCharsets.UTF_8)); - Map latestKASInfo = new HashMap<>(); - if (tdfConfig.splitPlan == null || tdfConfig.splitPlan.isEmpty()) { - // Default split plan: Split keys across all KASes - List splitPlan = new ArrayList<>(tdfConfig.kasInfoList.size()); - int i = 0; - for (Config.KASInfo kasInfo : tdfConfig.kasInfoList) { - Autoconfigure.KeySplitStep step = new Autoconfigure.KeySplitStep(kasInfo.URL, ""); - if (tdfConfig.kasInfoList.size() > 1) { - step.splitID = String.format("s-%d", i++); - } - splitPlan.add(step); - if (kasInfo.PublicKey != null && !kasInfo.PublicKey.isEmpty()) { - latestKASInfo.put(kasInfo.URL, kasInfo); - } - } - tdfConfig.splitPlan = splitPlan; - } - // Seed anything passed in manually - for (Config.KASInfo kasInfo : tdfConfig.kasInfoList) { - if (kasInfo.PublicKey != null && !kasInfo.PublicKey.isEmpty()) { - latestKASInfo.put(kasInfo.URL, kasInfo); - } - } - // split plan: restructure by conjunctions - Map> conjunction = new HashMap<>(); - List splitIDs = new ArrayList<>(); - - for (Autoconfigure.KeySplitStep splitInfo : tdfConfig.splitPlan) { - // Public key was passed in with kasInfoList - // TODO First look up in attribute information / add to split plan? - Config.KASInfo ki = latestKASInfo.get(splitInfo.kas); - if (ki == null || ki.PublicKey == null || ki.PublicKey.isBlank()) { - logger.info("no public key provided for KAS at {}, retrieving", splitInfo.kas); - var getKI = new Config.KASInfo(); - getKI.URL = splitInfo.kas; - getKI.Algorithm = tdfConfig.wrappingKeyType.toString(); - getKI = kas.getPublicKey(getKI); - latestKASInfo.put(splitInfo.kas, getKI); - ki = getKI; - } - if (conjunction.containsKey(splitInfo.splitID)) { - conjunction.get(splitInfo.splitID).add(ki); - } else { - List newList = new ArrayList<>(); - newList.add(ki); - conjunction.put(splitInfo.splitID, newList); - splitIDs.add(splitInfo.splitID); - } - } + List symKeys = new ArrayList<>(splits.size()); + for (var split : splits.entrySet()) { + String splitID = split.getKey(); - List symKeys = new ArrayList<>(splitIDs.size()); - for (String splitID : splitIDs) { // Symmetric key byte[] symKey = new byte[GCM_KEY_SIZE]; sRandom.nextBytes(symKey); @@ -230,7 +179,8 @@ private void prepareManifest(Config.TDFConfig tdfConfig, SDK.KAS kas) { encryptedMetadata = encoder.encodeToString(metadata.getBytes(StandardCharsets.UTF_8)); } - for (Config.KASInfo kasInfo : conjunction.get(splitID)) { + List kasInfos = split.getValue(); + for (Config.KASInfo kasInfo : kasInfos) { if (kasInfo.PublicKey == null || kasInfo.PublicKey.isEmpty()) { throw new SDK.KasPublicKeyMissing("Kas public key is missing in kas information list"); } @@ -301,7 +251,6 @@ private String createRSAWrappedKey(Config.KASInfo kasInfo, byte[] symKey) { } } - private static final Base64.Decoder decoder = Base64.getDecoder(); public static class Reader { @@ -326,7 +275,6 @@ public Manifest getManifest() { this.aesGcm = new AesGcm(payloadKey); this.payloadKey = payloadKey; this.unencryptedMetadata = unencryptedMetadata; - } public void readPayload(OutputStream outputStream) throws SDK.SegmentSignatureMismatch, IOException { @@ -400,35 +348,11 @@ private static byte[] calculateSignature(byte[] data, byte[] secret, Config.Inte } TDFObject createTDF(InputStream payload, OutputStream outputStream, Config.TDFConfig tdfConfig) throws SDKException, IOException { - - if (tdfConfig.autoconfigure) { - Autoconfigure.Granter granter = new Autoconfigure.Granter(new ArrayList<>()); - if (tdfConfig.attributeValues != null && !tdfConfig.attributeValues.isEmpty()) { - granter = Autoconfigure.newGranterFromAttributes(tdfConfig.attributeValues.toArray(new Value[0])); - } else if (tdfConfig.attributes != null && !tdfConfig.attributes.isEmpty()) { - granter = Autoconfigure.newGranterFromService(services.attributes(), services.kas().getKeyCache(), - tdfConfig.attributes.toArray(new AttributeValueFQN[0])); - } - - if (granter == null) { - throw new AutoConfigureException("Failed to create Granter"); // Replace with appropriate error handling - } - - List dk = defaultKases(tdfConfig); - tdfConfig.splitPlan = granter.plan(dk, () -> UUID.randomUUID().toString()); - - if (tdfConfig.splitPlan == null) { - throw new AutoConfigureException("Failed to generate Split Plan"); // Replace with appropriate error - // handling - } - } - - if (tdfConfig.kasInfoList.isEmpty() && (tdfConfig.splitPlan == null || tdfConfig.splitPlan.isEmpty())) { - throw new SDK.KasInfoMissing("kas information is missing, no key access template specified or inferred"); - } + Planner planner = new Planner(tdfConfig, services, Autoconfigure::createGranter); + Map> splits = planner.getSplits(tdfConfig); TDFObject tdfObject = new TDFObject(); - tdfObject.prepareManifest(tdfConfig, services.kas()); + tdfObject.prepareManifest(tdfConfig, splits); long encryptedSegmentSize = tdfConfig.defaultSegmentSize + kGcmIvSize + kAesBlockSize; TDFWriter tdfWriter = new TDFWriter(outputStream); @@ -561,22 +485,6 @@ TDFObject createTDF(InputStream payload, OutputStream outputStream, Config.TDFCo return tdfObject; } - static List defaultKases(TDFConfig config) { - List allk = new ArrayList<>(); - List defk = new ArrayList<>(); - - for (KASInfo kasInfo : config.kasInfoList) { - if (kasInfo.Default != null && kasInfo.Default) { - defk.add(kasInfo.URL); - } else if (defk.isEmpty()) { - allk.add(kasInfo.URL); - } - } - if (defk.isEmpty()) { - return allk; - } - return defk; - } Reader loadTDF(SeekableByteChannel tdf, String platformUrl) throws SDKException, IOException { return loadTDF(tdf, Config.newTDFReaderConfig(), platformUrl); diff --git a/sdk/src/test/java/io/opentdf/platform/sdk/AutoconfigureTest.java b/sdk/src/test/java/io/opentdf/platform/sdk/AutoconfigureTest.java index 59cd0912..d37f2727 100644 --- a/sdk/src/test/java/io/opentdf/platform/sdk/AutoconfigureTest.java +++ b/sdk/src/test/java/io/opentdf/platform/sdk/AutoconfigureTest.java @@ -1,18 +1,8 @@ package io.opentdf.platform.sdk; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - import com.connectrpc.ResponseMessage; import com.connectrpc.UnaryBlockingCall; +import io.opentdf.platform.policy.Algorithm; import io.opentdf.platform.policy.Attribute; import io.opentdf.platform.policy.AttributeRuleTypeEnum; import io.opentdf.platform.policy.KasPublicKey; @@ -21,16 +11,17 @@ import io.opentdf.platform.policy.KeyAccessServer; import io.opentdf.platform.policy.Namespace; import io.opentdf.platform.policy.PublicKey; +import io.opentdf.platform.policy.SimpleKasKey; +import io.opentdf.platform.policy.SimpleKasPublicKey; import io.opentdf.platform.policy.Value; import io.opentdf.platform.policy.attributes.AttributesServiceClient; import io.opentdf.platform.policy.attributes.GetAttributeValuesByFqnsRequest; import io.opentdf.platform.policy.attributes.GetAttributeValuesByFqnsResponse; import io.opentdf.platform.sdk.Autoconfigure.AttributeValueFQN; +import io.opentdf.platform.sdk.Autoconfigure.Granter; import io.opentdf.platform.sdk.Autoconfigure.Granter.AttributeBooleanExpression; import io.opentdf.platform.sdk.Autoconfigure.Granter.BooleanKeyExpression; import io.opentdf.platform.sdk.Autoconfigure.KeySplitStep; -import io.opentdf.platform.sdk.Autoconfigure.Granter; - import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.mockito.Mockito; @@ -39,10 +30,26 @@ import java.util.Collections; import java.util.HashSet; import java.util.List; +import java.util.Objects; +import java.util.Optional; import java.util.Set; -import java.util.stream.Collectors; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.BiFunction; +import java.util.function.Supplier; import java.util.regex.Matcher; import java.util.regex.Pattern; +import java.util.stream.Collectors; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; public class AutoconfigureTest { @@ -56,6 +63,16 @@ public class AutoconfigureTest { public static final String SPECIFIED_KAS = "https://attr.kas.com/"; public static final String EVEN_MORE_SPECIFIC_KAS = "https://value.kas.com/"; private static final String NAMESPACE_KAS = "https://namespace.kas.com/"; + private static final SimpleKasKey NAMESPACE_KAS_KEY = SimpleKasKey.newBuilder().setKasUri("https://mapped.example.com").setKasId("mapped").setPublicKey( + SimpleKasPublicKey.newBuilder().setAlgorithm(Algorithm.ALGORITHM_EC_P521).setPem("namespacekey").setKid("namespacekeykid").build() + ).build(); + private static final SimpleKasKey ATTRIBUTE_KEY = SimpleKasKey.newBuilder().setKasUri("https://mapped.example.com").setKasId("mapped").setPublicKey( + SimpleKasPublicKey.newBuilder().setAlgorithm(Algorithm.ALGORITHM_EC_P521).setPem("attrpem").setKid("attrkeykid").build() + ).build(); + private static final SimpleKasKey VALUE_KEY = SimpleKasKey.newBuilder().setKasUri("https://mapped.example.com").setKasId("mapped").setPublicKey( + SimpleKasPublicKey.newBuilder().setAlgorithm(Algorithm.ALGORITHM_EC_P521).setPem("valuepem").setKid("valuekeykid").build() + ).build(); + private static Autoconfigure.AttributeNameFQN UNMAPPED; private static Autoconfigure.AttributeNameFQN SPKSPECKED; private static Autoconfigure.AttributeNameFQN SPKUNSPECKED; @@ -64,6 +81,8 @@ public class AutoconfigureTest { private static Autoconfigure.AttributeNameFQN REL; private static Autoconfigure.AttributeNameFQN UNSPECKED; private static Autoconfigure.AttributeNameFQN SPECKED; + private static Autoconfigure.AttributeNameFQN MAPPED; + private static Autoconfigure.AttributeNameFQN SPKMAPPED; private static Autoconfigure.AttributeValueFQN clsA; private static Autoconfigure.AttributeValueFQN clsS; @@ -84,6 +103,9 @@ public class AutoconfigureTest { private static Autoconfigure.AttributeValueFQN spk2spk2uns; private static Autoconfigure.AttributeValueFQN spk2spk2spk; + private static Autoconfigure.AttributeValueFQN mp2uns2uns; + private static Autoconfigure.AttributeValueFQN mp2uns2mp; + @BeforeAll public static void setup() throws AutoConfigureException { // Initialize the FQNs (Fully Qualified Names) @@ -92,8 +114,11 @@ public static void setup() throws AutoConfigureException { REL = new Autoconfigure.AttributeNameFQN("https://virtru.com/attr/Releasable%20To"); UNSPECKED = new Autoconfigure.AttributeNameFQN("https://other.com/attr/unspecified"); SPECKED = new Autoconfigure.AttributeNameFQN("https://other.com/attr/specified"); + MAPPED = new Autoconfigure.AttributeNameFQN("https://other.com/attr/mapped"); + UNMAPPED = new Autoconfigure.AttributeNameFQN("https://mapped.com/attr/unspecified"); SPKUNSPECKED = new Autoconfigure.AttributeNameFQN("https://hasgrants.com/attr/unspecified"); SPKSPECKED = new Autoconfigure.AttributeNameFQN("https://hasgrants.com/attr/specified"); + SPKMAPPED = new Autoconfigure.AttributeNameFQN("https://hasgrants.com/attr/mapped"); clsA = new Autoconfigure.AttributeValueFQN("https://virtru.com/attr/Classification/value/Allowed"); clsS = new Autoconfigure.AttributeValueFQN("https://virtru.com/attr/Classification/value/Secret"); @@ -117,6 +142,9 @@ public static void setup() throws AutoConfigureException { spk2uns2spk = new Autoconfigure.AttributeValueFQN("https://hasgrants.com/attr/unspecified/value/specked"); spk2spk2uns = new Autoconfigure.AttributeValueFQN("https://hasgrants.com/attr/specified/value/unspecked"); spk2spk2spk = new Autoconfigure.AttributeValueFQN("https://hasgrants.com/attr/specified/value/specked"); + + mp2uns2uns = new Autoconfigure.AttributeValueFQN("https://mapped.com/attr/unspecified/value/unspecked"); + mp2uns2mp = new Autoconfigure.AttributeValueFQN("https://mapped.com/attr/unspecified/value/mapped"); } private static String spongeCase(String s) { @@ -181,6 +209,7 @@ private Attribute mockAttributeFor(Autoconfigure.AttributeNameFQN fqn) { Namespace ns1 = Namespace.newBuilder().setId("v").setName("virtru.com").setFqn("https://virtru.com").build(); Namespace ns2 = Namespace.newBuilder().setId("o").setName("other.com").setFqn("https://other.com").build(); Namespace ns3 = Namespace.newBuilder().setId("h").setName("hasgrants.com").addGrants(KeyAccessServer.newBuilder().setUri(NAMESPACE_KAS).build()).setFqn("https://hasgrants.com").build(); + Namespace ns4 = Namespace.newBuilder().setId("m").setName("mapped.com").addKasKeys(NAMESPACE_KAS_KEY).build(); String key = fqn.getKey(); if (key.equals(CLS.getKey())) { @@ -216,6 +245,17 @@ private Attribute mockAttributeFor(Autoconfigure.AttributeNameFQN fqn) { .setName("unspecified").setRule(AttributeRuleTypeEnum.ATTRIBUTE_RULE_TYPE_ENUM_ANY_OF) .setName(fqn.toString()) .build(); + } else if (key.equals(MAPPED.getKey())) { + return Attribute.newBuilder().setId(MAPPED.getKey()).setNamespace(ns4) + .setName("mapped attribute").setRule(AttributeRuleTypeEnum.ATTRIBUTE_RULE_TYPE_ENUM_ANY_OF) + .setKasKeys(0, ATTRIBUTE_KEY) + .setName(fqn.toString()) + .build(); + } else if (key.equals(UNMAPPED.getKey())) { + return Attribute.newBuilder().setId(UNMAPPED.getKey()).setNamespace(ns4) + .setName("unmapped attribute").setRule(AttributeRuleTypeEnum.ATTRIBUTE_RULE_TYPE_ENUM_ANY_OF) + .setName(fqn.toString()) + .build(); } throw new IllegalArgumentException("Key not recognized: " + key); @@ -287,7 +327,13 @@ private Value mockValueFor(Autoconfigure.AttributeValueFQN fqn) throws AutoConfi p = p.toBuilder().addGrants(KeyAccessServer.newBuilder().setUri(EVEN_MORE_SPECIFIC_KAS).build()) .build(); } + } else if (Objects.equals(UNMAPPED.getKey(), an.getKey())) { + if (fqn.value().equalsIgnoreCase("mapped")) { + p = p.toBuilder().addKasKeys(VALUE_KEY) + .build(); + } } + return p; } @@ -368,7 +414,7 @@ public void testConfigurationServicePutGet() { for (ConfigurationTestCase tc : testCases) { assertDoesNotThrow(() -> { List v = valuesToPolicy(tc.getPolicy().toArray(new AttributeValueFQN[0])); - Granter grants = Autoconfigure.newGranterFromAttributes(v.toArray(new Value[0])); + Granter grants = Autoconfigure.newGranterFromAttributes(null, v.toArray(new Value[0])); assertThat(grants).isNotNull(); assertThat(grants.getGrants()).hasSize(tc.getSize()); assertThat(policyToStringKeys(tc.getPolicy())).containsAll(grants.getGrants().keySet()); @@ -451,7 +497,7 @@ public void testReasonerConstructAttributeBoolean() { new KeySplitStep(KAS_US_HCS, "2"), new KeySplitStep(KAS_US_SA, "3")))); for (ReasonerTestCase tc : testCases) { - Granter reasoner = Autoconfigure.newGranterFromAttributes( + Granter reasoner = Autoconfigure.newGranterFromAttributes(null, valuesToPolicy(tc.getPolicy().toArray(new AttributeValueFQN[0])).toArray(new Value[0])); assertThat(reasoner).isNotNull(); @@ -467,15 +513,33 @@ public void testReasonerConstructAttributeBoolean() { var wrapper = new Object() { int i = 0; }; - List plan = reasoner.plan(tc.getDefaults(), () -> { - return String.valueOf(wrapper.i++ + 1); - } - - ); - assertThat(plan).isEqualTo(tc.getPlan()); + List plan = reasoner.getSplits(tc.getDefaults(), () -> String.valueOf(wrapper.i++ + 1), Optional::empty); + assertThat(plan) + .as(tc.name) + .isEqualTo(tc.getPlan()); } } + @Test + void testUsingAttributeMappedAtNamespace() { + Granter granter = Autoconfigure.newGranterFromAttributes(new KASKeyCache(), mockValueFor(mp2uns2uns)); + var counter = new AtomicInteger(0); + var splitPlan = granter.getSplits(Collections.emptyList(), () -> Integer.toString(counter.getAndIncrement()), Optional::empty); + assertThat(splitPlan).isEqualTo(List.of(new KeySplitStep("https://mapped.example.com", "", NAMESPACE_KAS_KEY.getPublicKey().getKid()))); + } + + @Test + void testUsingAttributeMappedAtMultiplePlaces() { + var attributes = new Value[]{mockValueFor(mp2uns2uns), mockValueFor(mp2uns2mp)}; + Granter granter = Autoconfigure.newGranterFromAttributes(new KASKeyCache(), attributes); + var counter = new AtomicInteger(0); + var splitPlan = granter.getSplits(Collections.emptyList(), () -> Integer.toString(counter.getAndIncrement()), Optional::empty); + assertThat(splitPlan).isEqualTo(List.of( + new KeySplitStep(NAMESPACE_KAS_KEY.getKasUri(), "0", NAMESPACE_KAS_KEY.getPublicKey().getKid()), + new KeySplitStep(VALUE_KEY.getKasUri(), "0", VALUE_KEY.getPublicKey().getKid()) + )); + } + GetAttributeValuesByFqnsResponse getResponse(GetAttributeValuesByFqnsRequest req) { GetAttributeValuesByFqnsResponse.Builder builder = GetAttributeValuesByFqnsResponse.newBuilder(); @@ -547,7 +611,7 @@ public void testReasonerSpecificity() { List.of(KAS_US), List.of(new KeySplitStep(EVEN_MORE_SPECIFIC_KAS, ""))), new ReasonerTestCase( - "uns.uns & uns.spk => spk", + "uns.uns & spk.spk => spk", List.of(uns2uns, spk2spk), List.of(KAS_US), List.of(new KeySplitStep(EVEN_MORE_SPECIFIC_KAS, ""))), @@ -598,12 +662,11 @@ public void cancel() { var wrapper = new Object() { int i = 0; }; - List plan = reasoner.plan(tc.getDefaults(), () -> { - return String.valueOf(wrapper.i++ + 1); - } + List plan = reasoner.getSplits(tc.getDefaults(), () -> String.valueOf(wrapper.i++ + 1), Optional::empty); + assertThat(plan) + .as(tc.name) + .hasSameElementsAs(tc.getPlan()); - ); - assertThat(plan).hasSameElementsAs(tc.getPlan()); } } @@ -691,7 +754,7 @@ private static class ReasonerTestCase { private final List plan; ReasonerTestCase(String name, List policy, List defaults, String ats, String keyed, - String reduced, List plan) { + String reduced, List plan) { this.name = name; this.policy = policy; this.defaults = defaults; @@ -744,13 +807,11 @@ public List getPlan() { void testStoreKeysToCache_NoKeys() { KASKeyCache keyCache = Mockito.mock(KASKeyCache.class); KeyAccessServer kas1 = KeyAccessServer.newBuilder().setPublicKey( - PublicKey.newBuilder().setCached( - KasPublicKeySet.newBuilder())) + PublicKey.newBuilder().setCached( + KasPublicKeySet.newBuilder())) .build(); - List kases = List.of(kas1); - - Autoconfigure.storeKeysToCache(kases, keyCache); + Autoconfigure.storeKeysToCache(List.of(kas1), Collections.emptyList(), keyCache); verify(keyCache, never()).store(any(Config.KASInfo.class)); } @@ -780,14 +841,11 @@ void testStoreKeysToCache_WithKeys() { .setUri("https://example.com/kas") .build(); - // Add the KeyAccessServer to a list - List kases = List.of(kas1); - // Call the method under test - Autoconfigure.storeKeysToCache(kases, keyCache); + Autoconfigure.storeKeysToCache(List.of(kas1), Collections.emptyList(), keyCache); // Verify that the key was stored in the cache - Config.KASInfo storedKASInfo = keyCache.get("https://example.com/kas", "ec:secp256r1"); + Config.KASInfo storedKASInfo = keyCache.get("https://example.com/kas", "ec:secp256r1", "test-kid"); assertNotNull(storedKASInfo); assertEquals("https://example.com/kas", storedKASInfo.URL); assertEquals("test-kid", storedKASInfo.KID); @@ -826,21 +884,18 @@ void testStoreKeysToCache_MultipleKasEntries() { .setUri("https://example.com/kas") .build(); - // Add the KeyAccessServer to a list - List kases = List.of(kas1); - // Call the method under test - Autoconfigure.storeKeysToCache(kases, keyCache); + Autoconfigure.storeKeysToCache(List.of(kas1), Collections.emptyList(), keyCache); // Verify that the key was stored in the cache - Config.KASInfo storedKASInfo = keyCache.get("https://example.com/kas", "ec:secp256r1"); + Config.KASInfo storedKASInfo = keyCache.get("https://example.com/kas", "ec:secp256r1", "test-kid"); assertNotNull(storedKASInfo); assertEquals("https://example.com/kas", storedKASInfo.URL); assertEquals("test-kid", storedKASInfo.KID); assertEquals("ec:secp256r1", storedKASInfo.Algorithm); assertEquals("public-key-pem", storedKASInfo.PublicKey); - Config.KASInfo storedKASInfo2 = keyCache.get("https://example.com/kas", "rsa:2048"); + Config.KASInfo storedKASInfo2 = keyCache.get("https://example.com/kas", "rsa:2048", "test-kid-2"); assertNotNull(storedKASInfo2); assertEquals("https://example.com/kas", storedKASInfo2.URL); assertEquals("test-kid-2", storedKASInfo2.KID); @@ -849,7 +904,7 @@ void testStoreKeysToCache_MultipleKasEntries() { } GetAttributeValuesByFqnsResponse getResponseWithGrants(GetAttributeValuesByFqnsRequest req, - List grants) { + List grants) { GetAttributeValuesByFqnsResponse.Builder builder = GetAttributeValuesByFqnsResponse.newBuilder(); for (String v : req.getFqnsList()) { @@ -901,13 +956,15 @@ void testKeyCacheFromGrants() { AttributesServiceClient attributesServiceClient = mock(AttributesServiceClient.class); when(attributesServiceClient.getAttributeValuesByFqnsBlocking(any(), any())).thenAnswer(invocation -> { - var request = (GetAttributeValuesByFqnsRequest)invocation.getArgument(0); - return new UnaryBlockingCall(){ + var request = (GetAttributeValuesByFqnsRequest) invocation.getArgument(0); + return new UnaryBlockingCall() { @Override public ResponseMessage execute() { return new ResponseMessage.Success<>(getResponseWithGrants(request, List.of(kas1)), Collections.emptyMap(), Collections.emptyMap()); } - @Override public void cancel() { + + @Override + public void cancel() { // not really calling anything } }; @@ -920,14 +977,14 @@ public ResponseMessage execute() { assertThat(reasoner).isNotNull(); // Verify that the key was stored in the cache - Config.KASInfo storedKASInfo = keyCache.get("https://example.com/kas", "ec:secp256r1"); + Config.KASInfo storedKASInfo = keyCache.get("https://example.com/kas", "ec:secp256r1", "test-kid"); assertNotNull(storedKASInfo); assertEquals("https://example.com/kas", storedKASInfo.URL); assertEquals("test-kid", storedKASInfo.KID); assertEquals("ec:secp256r1", storedKASInfo.Algorithm); assertEquals("public-key-pem", storedKASInfo.PublicKey); - Config.KASInfo storedKASInfo2 = keyCache.get("https://example.com/kas", "rsa:2048"); + Config.KASInfo storedKASInfo2 = keyCache.get("https://example.com/kas", "rsa:2048", "test-kid-2"); assertNotNull(storedKASInfo2); assertEquals("https://example.com/kas", storedKASInfo2.URL); assertEquals("test-kid-2", storedKASInfo2.KID); @@ -935,4 +992,155 @@ public ResponseMessage execute() { assertEquals("public-key-pem-2", storedKASInfo2.PublicKey); } + @Test + void testUsingBaseKeyWhenNoMappedKeysOrGrants() { + Autoconfigure.Granter granter = Autoconfigure.newGranterFromAttributes(null); + SimpleKasKey key = SimpleKasKey.newBuilder() + .setKasUri("https://example.com/kas") + .setPublicKey( + SimpleKasPublicKey.newBuilder() + .setKid("thenewkid") + .setPem("anotherpem") + .setAlgorithm(Algorithm.ALGORITHM_EC_P521) + ).build(); + + var splits = granter.getSplits( + List.of("https://example.org/kas2"), + () -> { + throw new IllegalStateException("the plan should have a single element"); + }, + () -> Optional.of(key)); + assertThat(splits).hasSize(1); + assertThat(splits.get(0)).isEqualTo(new KeySplitStep("https://example.com/kas", "", "thenewkid")); + } + + @Test + void testUsingDefaultKasesWhenNothingElseProvided() { + Autoconfigure.Granter granter = Autoconfigure.newGranterFromAttributes(null); + var counter = new AtomicInteger(); + Supplier splitGen = () -> String.valueOf(counter.getAndIncrement()); + var splits = granter.getSplits( + List.of("https://example.org/kas1", "https://example.org/kas2"), + splitGen, + Optional::empty); + + assertThat(splits) + .hasSize(2) + .asList().containsExactly( + new KeySplitStep("https://example.org/kas1", "0", null), + new KeySplitStep("https://example.org/kas2", "1", null) + ); + } + + @Test + void createsGranterFromAttributeValues() { + // Arrange + Config.TDFConfig config = new Config.TDFConfig(); + config.attributeValues = List.of(mockValueFor(spk2spk), mockValueFor(rel2gbr)); + + SDK.Services services = mock(SDK.Services.class); + SDK.KAS kas = mock(SDK.KAS.class); + when(services.kas()).thenReturn(kas); + when(services.attributes()).thenThrow(new IllegalStateException("should never use the attribute service when attributes are provided")); + when(kas.getKeyCache()).thenReturn(null); // No cache needed for this test + + // Act + Autoconfigure.Granter granter = Autoconfigure.createGranter(services, config); + + // Assert + assertThat(granter).isNotNull(); + assertThat(granter.getPolicy()).hasSize(2); + assertThat(granter.getPolicy()).containsExactlyInAnyOrder( + new AttributeValueFQN("https://other.com/attr/specified/value/specked"), + new AttributeValueFQN("https://virtru.com/attr/Releasable%20To/value/GBR") + ); + } + + @Test + void createsGranterFromService() { + // Arrange + SDK.Services services = mock(SDK.Services.class); + SDK.KAS kas = mock(SDK.KAS.class); + AttributesServiceClient attributesServiceClient = mock(AttributesServiceClient.class); + + // Prepare a request and a mocked response + List policy = List.of( + new AttributeValueFQN("https://other.com/attr/specified/value/specked"), + new AttributeValueFQN("https://virtru.com/attr/Releasable%20To/value/GBR") + ); + + when(services.kas()).thenReturn(kas); + when(services.attributes()).thenReturn(attributesServiceClient); + + // Mock the attribute service to return a response with the expected values + when(attributesServiceClient.getAttributeValuesByFqnsBlocking(any(), any())).thenAnswer(invocation -> { + GetAttributeValuesByFqnsResponse.Builder builder = GetAttributeValuesByFqnsResponse.newBuilder(); + for (AttributeValueFQN fqn : policy) { + Value value = Value.newBuilder() + .setId(fqn.toString()) + .setFqn(fqn.toString()) + .build(); + builder.putFqnAttributeValues(fqn.toString(), + GetAttributeValuesByFqnsResponse.AttributeAndValue.newBuilder() + .setValue(value) + .build()); + } + return TestUtil.successfulUnaryCall(builder.build()); + }); + + // Act + Autoconfigure.Granter granter = Autoconfigure.createGranter(services, new Config.TDFConfig() {{ + attributeValues = null; // force use of service + attributes = policy; + }}); + + // Assert + assertThat(granter).isNotNull(); + // The policy should be empty because attributeValues is null, but the test ensures the service is called + // If you want to check the service call, verify it: + verify(services).attributes(); + } + + @Test + void getSplits_usesAutoconfigurePlan_whenAutoconfigureTrue() { + var tdfConfig = new Config.TDFConfig(); + tdfConfig.autoconfigure = true; + tdfConfig.kasInfoList = new ArrayList<>(); + tdfConfig.splitPlan = null; + + var kas = Mockito.mock(SDK.KAS.class); + Mockito.when(kas.getKeyCache()).thenReturn(new KASKeyCache()); + Config.KASInfo kasInfo = new Config.KASInfo() {{ + URL = "https://kas.example.com"; + Algorithm = "ec:secp256r1"; + KID = "kid"; + }}; + Mockito.when(kas.getPublicKey(any())).thenReturn(kasInfo); + + var services = new FakeServicesBuilder().setKas(kas).build(); + + // Mock granterFactory to return a granter with a known split plan + var expectedSplit = new Autoconfigure.KeySplitStep("https://kas.example.com", "", "kid"); + var granter = Mockito.mock(Autoconfigure.Granter.class); + Mockito.when(granter.getSplits( + Mockito.anyList(), + Mockito.any(), + Mockito.any())) + .thenReturn(List.of(expectedSplit)); + + BiFunction granterFactory = + (s, c) -> granter; + + var planner = new Planner(tdfConfig, services, granterFactory); + + // Act + var splits = planner.getSplits(tdfConfig); + + // Assert + assertThat(splits).containsKey(""); + assertThat(splits.get("")).hasSize(1); + assertThat(splits.get("").get(0).URL).isEqualTo("https://kas.example.com"); + assertThat(splits.get("").get(0).KID).isEqualTo("kid"); + assertThat(splits.get("").get(0).Algorithm).isEqualTo("ec:secp256r1"); + } } diff --git a/sdk/src/test/java/io/opentdf/platform/sdk/FakeServices.java b/sdk/src/test/java/io/opentdf/platform/sdk/FakeServices.java index b3573593..2851b22b 100644 --- a/sdk/src/test/java/io/opentdf/platform/sdk/FakeServices.java +++ b/sdk/src/test/java/io/opentdf/platform/sdk/FakeServices.java @@ -1,31 +1,34 @@ package io.opentdf.platform.sdk; -import io.opentdf.platform.authorization.AuthorizationServiceClient; -import io.opentdf.platform.policy.attributes.AttributesServiceClient; -import io.opentdf.platform.policy.kasregistry.KeyAccessServerRegistryServiceClient; -import io.opentdf.platform.policy.namespaces.NamespaceServiceClient; -import io.opentdf.platform.policy.resourcemapping.ResourceMappingServiceClient; -import io.opentdf.platform.policy.subjectmapping.SubjectMappingServiceClient; +import io.opentdf.platform.authorization.AuthorizationServiceClientInterface; +import io.opentdf.platform.policy.attributes.AttributesServiceClientInterface; +import io.opentdf.platform.policy.kasregistry.KeyAccessServerRegistryServiceClientInterface; +import io.opentdf.platform.policy.namespaces.NamespaceServiceClientInterface; +import io.opentdf.platform.policy.resourcemapping.ResourceMappingServiceClientInterface; +import io.opentdf.platform.policy.subjectmapping.SubjectMappingServiceClientInterface; +import io.opentdf.platform.wellknownconfiguration.WellKnownServiceClientInterface; import java.util.Objects; public class FakeServices implements SDK.Services { - private final AuthorizationServiceClient authorizationService; - private final AttributesServiceClient attributesService; - private final NamespaceServiceClient namespaceService; - private final SubjectMappingServiceClient subjectMappingService; - private final ResourceMappingServiceClient resourceMappingService; - private final KeyAccessServerRegistryServiceClient keyAccessServerRegistryServiceFutureStub; + private final AuthorizationServiceClientInterface authorizationService; + private final AttributesServiceClientInterface attributesService; + private final NamespaceServiceClientInterface namespaceService; + private final SubjectMappingServiceClientInterface subjectMappingService; + private final ResourceMappingServiceClientInterface resourceMappingService; + private final KeyAccessServerRegistryServiceClientInterface keyAccessServerRegistryServiceFutureStub; + private final WellKnownServiceClientInterface wellKnownService; private final SDK.KAS kas; public FakeServices( - AuthorizationServiceClient authorizationService, - AttributesServiceClient attributesService, - NamespaceServiceClient namespaceService, - SubjectMappingServiceClient subjectMappingService, - ResourceMappingServiceClient resourceMappingService, - KeyAccessServerRegistryServiceClient keyAccessServerRegistryServiceFutureStub, + AuthorizationServiceClientInterface authorizationService, + AttributesServiceClientInterface attributesService, + NamespaceServiceClientInterface namespaceService, + SubjectMappingServiceClientInterface subjectMappingService, + ResourceMappingServiceClientInterface resourceMappingService, + KeyAccessServerRegistryServiceClientInterface keyAccessServerRegistryServiceFutureStub, + WellKnownServiceClientInterface wellKnownServiceClient, SDK.KAS kas) { this.authorizationService = authorizationService; this.attributesService = attributesService; @@ -33,39 +36,45 @@ public FakeServices( this.subjectMappingService = subjectMappingService; this.resourceMappingService = resourceMappingService; this.keyAccessServerRegistryServiceFutureStub = keyAccessServerRegistryServiceFutureStub; + this.wellKnownService = wellKnownServiceClient; this.kas = kas; } @Override - public AuthorizationServiceClient authorization() { + public AuthorizationServiceClientInterface authorization() { return Objects.requireNonNull(authorizationService); } @Override - public AttributesServiceClient attributes() { + public AttributesServiceClientInterface attributes() { return Objects.requireNonNull(attributesService); } @Override - public NamespaceServiceClient namespaces() { + public NamespaceServiceClientInterface namespaces() { return Objects.requireNonNull(namespaceService); } @Override - public SubjectMappingServiceClient subjectMappings() { + public SubjectMappingServiceClientInterface subjectMappings() { return Objects.requireNonNull(subjectMappingService); } @Override - public ResourceMappingServiceClient resourceMappings() { + public ResourceMappingServiceClientInterface resourceMappings() { return Objects.requireNonNull(resourceMappingService); } @Override - public KeyAccessServerRegistryServiceClient kasRegistry() { + public KeyAccessServerRegistryServiceClientInterface kasRegistry() { return Objects.requireNonNull(keyAccessServerRegistryServiceFutureStub); } + @Override + public WellKnownServiceClientInterface wellknown() { + return Objects.requireNonNull(wellKnownService); + } + @Override public SDK.KAS kas() { return Objects.requireNonNull(kas); diff --git a/sdk/src/test/java/io/opentdf/platform/sdk/FakeServicesBuilder.java b/sdk/src/test/java/io/opentdf/platform/sdk/FakeServicesBuilder.java index 2a80f53d..558aee3b 100644 --- a/sdk/src/test/java/io/opentdf/platform/sdk/FakeServicesBuilder.java +++ b/sdk/src/test/java/io/opentdf/platform/sdk/FakeServicesBuilder.java @@ -1,47 +1,54 @@ package io.opentdf.platform.sdk; -import io.opentdf.platform.authorization.AuthorizationServiceClient; -import io.opentdf.platform.policy.attributes.AttributesServiceClient; -import io.opentdf.platform.policy.kasregistry.KeyAccessServerRegistryServiceClient; -import io.opentdf.platform.policy.namespaces.NamespaceServiceClient; -import io.opentdf.platform.policy.resourcemapping.ResourceMappingServiceClient; -import io.opentdf.platform.policy.subjectmapping.SubjectMappingServiceClient; +import io.opentdf.platform.authorization.AuthorizationServiceClientInterface; +import io.opentdf.platform.policy.attributes.AttributesServiceClientInterface; +import io.opentdf.platform.policy.kasregistry.KeyAccessServerRegistryServiceClientInterface; +import io.opentdf.platform.policy.namespaces.NamespaceServiceClientInterface; +import io.opentdf.platform.policy.resourcemapping.ResourceMappingServiceClientInterface; +import io.opentdf.platform.policy.subjectmapping.SubjectMappingServiceClientInterface; +import io.opentdf.platform.wellknownconfiguration.WellKnownServiceClientInterface; public class FakeServicesBuilder { - private AuthorizationServiceClient authorizationService; - private AttributesServiceClient attributesService; - private NamespaceServiceClient namespaceService; - private SubjectMappingServiceClient subjectMappingService; - private ResourceMappingServiceClient resourceMappingService; - private KeyAccessServerRegistryServiceClient keyAccessServerRegistryServiceFutureStub; + private AuthorizationServiceClientInterface authorizationService; + private AttributesServiceClientInterface attributesService; + private NamespaceServiceClientInterface namespaceService; + private SubjectMappingServiceClientInterface subjectMappingService; + private ResourceMappingServiceClientInterface resourceMappingService; + private KeyAccessServerRegistryServiceClientInterface keyAccessServerRegistryServiceFutureStub; + private WellKnownServiceClientInterface wellKnownServiceClient; private SDK.KAS kas; - public FakeServicesBuilder setAuthorizationService(AuthorizationServiceClient authorizationService) { + public FakeServicesBuilder setAuthorizationService(AuthorizationServiceClientInterface authorizationService) { this.authorizationService = authorizationService; return this; } - public FakeServicesBuilder setAttributesService(AttributesServiceClient attributesService) { + public FakeServicesBuilder setAttributesService(AttributesServiceClientInterface attributesService) { this.attributesService = attributesService; return this; } - public FakeServicesBuilder setNamespaceService(NamespaceServiceClient namespaceService) { + public FakeServicesBuilder setNamespaceService(NamespaceServiceClientInterface namespaceService) { this.namespaceService = namespaceService; return this; } - public FakeServicesBuilder setSubjectMappingService(SubjectMappingServiceClient subjectMappingService) { + public FakeServicesBuilder setSubjectMappingService(SubjectMappingServiceClientInterface subjectMappingService) { this.subjectMappingService = subjectMappingService; return this; } - public FakeServicesBuilder setResourceMappingService(ResourceMappingServiceClient resourceMappingService) { + public FakeServicesBuilder setResourceMappingService(ResourceMappingServiceClientInterface resourceMappingService) { this.resourceMappingService = resourceMappingService; return this; } - public FakeServicesBuilder setKeyAccessServerRegistryService(KeyAccessServerRegistryServiceClient keyAccessServerRegistryServiceFutureStub) { + public FakeServicesBuilder setWellknownService(WellKnownServiceClientInterface wellKnownServiceClient) { + this.wellKnownServiceClient = wellKnownServiceClient; + return this; + } + + public FakeServicesBuilder setKeyAccessServerRegistryService(KeyAccessServerRegistryServiceClientInterface keyAccessServerRegistryServiceFutureStub) { this.keyAccessServerRegistryServiceFutureStub = keyAccessServerRegistryServiceFutureStub; return this; } @@ -52,6 +59,7 @@ public FakeServicesBuilder setKas(SDK.KAS kas) { } public FakeServices build() { - return new FakeServices(authorizationService, attributesService, namespaceService, subjectMappingService, resourceMappingService, keyAccessServerRegistryServiceFutureStub, kas); + return new FakeServices(authorizationService, attributesService, namespaceService, subjectMappingService, + resourceMappingService, keyAccessServerRegistryServiceFutureStub, wellKnownServiceClient, kas); } } \ No newline at end of file diff --git a/sdk/src/test/java/io/opentdf/platform/sdk/KASKeyCacheTest.java b/sdk/src/test/java/io/opentdf/platform/sdk/KASKeyCacheTest.java index 5550678a..fdee682e 100644 --- a/sdk/src/test/java/io/opentdf/platform/sdk/KASKeyCacheTest.java +++ b/sdk/src/test/java/io/opentdf/platform/sdk/KASKeyCacheTest.java @@ -35,7 +35,7 @@ void testStoreAndGet_WithinTimeLimit() { kasKeyCache.store(kasInfo1); // Retrieve the item within the time limit - Config.KASInfo result = kasKeyCache.get("https://example.com/kas1", "rsa:2048"); + Config.KASInfo result = kasKeyCache.get("https://example.com/kas1", "rsa:2048", "kid1"); // Ensure the item was correctly retrieved assertNotNull(result); @@ -51,12 +51,24 @@ void testStoreAndGet_AfterTimeLimit() { kasKeyCache.store(kasInfo1); // Simulate time passing by modifying the timestamp directly - KASKeyRequest cacheKey = new KASKeyRequest("https://example.com/kas1", "rsa:2048"); + KASKeyRequest cacheKey = new KASKeyRequest("https://example.com/kas1", "rsa:2048", "kid1"); TimeStampedKASInfo timeStampedKASInfo = new TimeStampedKASInfo(kasInfo1, LocalDateTime.now().minus(6, ChronoUnit.MINUTES)); kasKeyCache.cache.put(cacheKey, timeStampedKASInfo); // Attempt to retrieve the item after the time limit - Config.KASInfo result = kasKeyCache.get("https://example.com/kas1", "rsa:2048"); + Config.KASInfo result = kasKeyCache.get("https://example.com/kas1", "rsa:2048", "kid1"); + + // Ensure the item was not retrieved (it should have expired) + assertNull(result); + } + + @Test + void testStoreAndGet_DifferentKIDs() { + // Store an item in the cache + kasKeyCache.store(kasInfo1); + + // Attempt to retrieve the item with a different KID + Config.KASInfo result = kasKeyCache.get(kasInfo1.URL, kasInfo1.Algorithm, kasInfo1.KID + "different"); // Ensure the item was not retrieved (it should have expired) assertNull(result); @@ -72,7 +84,7 @@ void testStoreAndGet_WithNullAlgorithm() { kasKeyCache.store(kasInfo1); // Retrieve the item with a null algorithm - Config.KASInfo result = kasKeyCache.get("https://example.com/kas1", null); + Config.KASInfo result = kasKeyCache.get("https://example.com/kas1", null, "kid1"); // Ensure the item was correctly retrieved assertNotNull(result); @@ -91,7 +103,7 @@ void testClearCache() { kasKeyCache.clear(); // Attempt to retrieve the item after clearing the cache - Config.KASInfo result = kasKeyCache.get("https://example.com/kas1", "rsa:2048"); + Config.KASInfo result = kasKeyCache.get("https://example.com/kas1", "rsa:2048", "kid1"); // Ensure the item was not retrieved (the cache should be empty) assertNull(result); @@ -104,8 +116,8 @@ void testStoreMultipleItemsAndGet() { kasKeyCache.store(kasInfo2); // Retrieve each item and ensure they were correctly stored and retrieved - Config.KASInfo result1 = kasKeyCache.get("https://example.com/kas1", "rsa:2048"); - Config.KASInfo result2 = kasKeyCache.get("https://example.com/kas2", "ec:secp256r1"); + Config.KASInfo result1 = kasKeyCache.get("https://example.com/kas1", "rsa:2048", "kid1"); + Config.KASInfo result2 = kasKeyCache.get("https://example.com/kas2", "ec:secp256r1", "kid2"); assertNotNull(result1); assertEquals("https://example.com/kas1", result1.URL); @@ -119,8 +131,8 @@ void testStoreMultipleItemsAndGet() { @Test void testEqualsAndHashCode() { // Create two identical KASKeyRequest objects - KASKeyRequest keyRequest1 = new KASKeyRequest("https://example.com/kas1", "rsa:2048"); - KASKeyRequest keyRequest2 = new KASKeyRequest("https://example.com/kas1", "rsa:2048"); + KASKeyRequest keyRequest1 = new KASKeyRequest("https://example.com/kas1", "rsa:2048", "kid1"); + KASKeyRequest keyRequest2 = new KASKeyRequest("https://example.com/kas1", "rsa:2048", "kid1"); // Ensure that equals and hashCode work as expected assertEquals(keyRequest1, keyRequest2); diff --git a/sdk/src/test/java/io/opentdf/platform/sdk/NanoTDFTest.java b/sdk/src/test/java/io/opentdf/platform/sdk/NanoTDFTest.java index 5e428c79..a79a42a1 100644 --- a/sdk/src/test/java/io/opentdf/platform/sdk/NanoTDFTest.java +++ b/sdk/src/test/java/io/opentdf/platform/sdk/NanoTDFTest.java @@ -1,7 +1,7 @@ package io.opentdf.platform.sdk; -import com.connectrpc.ResponseMessage; -import com.connectrpc.UnaryBlockingCall; +import com.google.protobuf.Struct; +import com.google.protobuf.Value; import io.opentdf.platform.policy.KeyAccessServer; import io.opentdf.platform.policy.kasregistry.KeyAccessServerRegistryServiceClient; import io.opentdf.platform.policy.kasregistry.ListKeyAccessServersRequest; @@ -11,18 +11,20 @@ import java.nio.charset.StandardCharsets; +import io.opentdf.platform.wellknownconfiguration.GetWellKnownConfigurationResponse; +import io.opentdf.platform.wellknownconfiguration.WellKnownServiceClientInterface; import org.apache.commons.io.output.ByteArrayOutputStream; +import org.apache.commons.lang3.NotImplementedException; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import java.nio.ByteBuffer; -import java.security.KeyPair; import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; import java.util.ArrayList; import java.util.Base64; -import java.util.Collections; import java.util.List; +import java.util.Objects; import java.util.Random; import static org.assertj.core.api.AssertionsForClassTypes.assertThat; @@ -45,20 +47,35 @@ public class NanoTDFTest { "oVP7Vpcx\n" + "-----END PRIVATE KEY-----"; + private static final String BASE_PUBLIC_KEY = "-----BEGIN PUBLIC KEY-----\n" + + "MHYwEAYHKoZIzj0CAQYFK4EEACIDYgAEzcPM21p19N27oyMG6i3KrrDdgyiNDLgN\n" + + "rTJMvZRj3V/48ysDKw5jzngtzVLpkur/l5XBGjEiCr0aGlSo8Db0YVYFRm5g74P2\n" + + "nt8xOedHSPGYq0NMDt4Il6bt1rNUSVm+\n" + + "-----END PUBLIC KEY-----\n"; + + private static final String BASE_KEY_PRIVATE = "-----BEGIN PRIVATE KEY-----\n" + + "MIG2AgEAMBAGByqGSM49AgEGBSuBBAAiBIGeMIGbAgEBBDA8zHdzkI8OhDojcRcQ\n" + + "xMqV3Gs+XGeQ/9SKbIRsyab0jJhMAEhRqzF3DQ4RSipIIX6hZANiAATNw8zbWnX0\n" + + "3bujIwbqLcqusN2DKI0MuA2tMky9lGPdX/jzKwMrDmPOeC3NUumS6v+XlcEaMSIK\n" + + "vRoaVKjwNvRhVgVGbmDvg/ae3zE550dI8ZirQ0wO3giXpu3Ws1RJWb4=\n" + + "-----END PRIVATE KEY-----\n"; + private static final String KID = "r1"; + private static final String BASE_KID = "basekid"; protected static KeyAccessServerRegistryServiceClient kasRegistryService; protected static List registeredKases = List.of( "https://api.example.com/kas", "https://other.org/kas2", "http://localhost:8181/kas", - "https://localhost:8383/kas" + "https://localhost:8383/kas", + "https://api.kaswithbasekey.example.com" ); protected static String platformUrl = "http://localhost:8080"; protected static SDK.KAS kas = new SDK.KAS() { @Override - public void close() throws Exception { + public void close() { } @Override @@ -70,10 +87,16 @@ public Config.KASInfo getPublicKey(Config.KASInfo kasInfo) { @Override public KASInfo getECPublicKey(Config.KASInfo kasInfo, NanoTDFType.ECCurve curve) { + var k2 = kasInfo.clone(); + if (Objects.equals(kasInfo.KID, BASE_KID)) { + assertThat(kasInfo.URL).isEqualTo("https://api.kaswithbasekey.example.com"); + assertThat(kasInfo.Algorithm).isEqualTo("ec:secp384r1"); + k2.PublicKey = BASE_PUBLIC_KEY; + return k2; + } if (kasInfo.Algorithm != null && !"ec:secp256r1".equals(kasInfo.Algorithm)) { throw new IllegalArgumentException("Unexpected algorithm: " + kasInfo); } - var k2 = kasInfo.clone(); k2.KID = KID; k2.PublicKey = kasPublicKey; return k2; @@ -81,18 +104,12 @@ public KASInfo getECPublicKey(Config.KASInfo kasInfo, NanoTDFType.ECCurve curve) @Override public byte[] unwrap(Manifest.KeyAccess keyAccess, String policy, KeyType sessionKeyType) { - int index = Integer.parseInt(keyAccess.url); - var decryptor = new AsymDecryption(keypairs.get(index).getPrivate()); - var bytes = Base64.getDecoder().decode(keyAccess.wrappedKey); - try { - return decryptor.decrypt(bytes); - } catch (Exception e) { - throw new RuntimeException(e); - } + throw new NotImplementedException("no unwrapping ZTDFs here"); } @Override public byte[] unwrapNanoTDF(NanoTDFType.ECCurve curve, String header, String kasURL) { + String key = Objects.equals(kasURL, "https://api.kaswithbasekey.example.com") ? BASE_KEY_PRIVATE : kasPrivateKey; byte[] headerAsBytes = Base64.getDecoder().decode(header); Header nTDFHeader = new Header(ByteBuffer.wrap(headerAsBytes)); @@ -102,7 +119,7 @@ public byte[] unwrapNanoTDF(NanoTDFType.ECCurve curve, String header, String kas // Generate symmetric key byte[] symmetricKey = ECKeyPair.computeECDHKey(ECKeyPair.publicKeyFromPem(publicKeyAsPem), - ECKeyPair.privateKeyFromPem(kasPrivateKey)); + ECKeyPair.privateKeyFromPem(key)); // Generate HKDF key MessageDigest digest; @@ -112,8 +129,7 @@ public byte[] unwrapNanoTDF(NanoTDFType.ECCurve curve, String header, String kas throw new SDKException("error creating SHA-256 message digest", e); } byte[] hashOfSalt = digest.digest(NanoTDF.MAGIC_NUMBER_AND_VERSION); - byte[] key = ECKeyPair.calculateHKDF(hashOfSalt, symmetricKey); - return key; + return ECKeyPair.calculateHKDF(hashOfSalt, symmetricKey); } @Override @@ -136,21 +152,9 @@ static void setupMocks() { // Stub the listKeyAccessServers method when(kasRegistryService.listKeyAccessServersBlocking(any(ListKeyAccessServersRequest.class), any())) - .thenReturn(new UnaryBlockingCall<>() { - @Override - public ResponseMessage execute() { - return new ResponseMessage.Success<>(mockResponse, Collections.emptyMap(), Collections.emptyMap()); - } - - @Override - public void cancel() { - // this never happens in tests - } - }); + .thenReturn(TestUtil.successfulUnaryCall(mockResponse)); } - private static ArrayList keypairs = new ArrayList<>(); - @Test void encryptionAndDecryptionWithValidKey() throws Exception { var kasInfos = new ArrayList<>(); @@ -201,6 +205,39 @@ void encryptionAndDecryptionWithValidKey() throws Exception { } } + @Test + void encryptionAndDecryptWithBaseKey() throws Exception { + var baseKeyJson = "{\"kas_url\":\"https://api.kaswithbasekey.example.com\",\"public_key\":{\"algorithm\":\"ALGORITHM_EC_P384\",\"kid\":\"thebasekid\",\"pem\": \"" + BASE_PUBLIC_KEY + "\"}}"; + var val = Value.newBuilder().setStringValue(baseKeyJson).build(); + var config = Struct.newBuilder().putFields("base_key", val).build(); + WellKnownServiceClientInterface wellknown = mock(WellKnownServiceClientInterface.class); + GetWellKnownConfigurationResponse response = GetWellKnownConfigurationResponse.newBuilder().setConfiguration(config).build(); + + when(wellknown.getWellKnownConfigurationBlocking(any(), any())).thenReturn(TestUtil.successfulUnaryCall(response)); + + Config.NanoTDFConfig nanoConfig = Config.newNanoTDFConfig( + Config.witDataAttributes("https://example.com/attr/Classification/value/S", + "https://example.com/attr/Classification/value/X") + ); + + String plainText = "Virtru!!"; + ByteBuffer byteBuffer = ByteBuffer.wrap(plainText.getBytes()); + ByteArrayOutputStream tdfOutputStream = new ByteArrayOutputStream(); + + NanoTDF nanoTDF = new NanoTDF(new FakeServicesBuilder().setKas(kas).setKeyAccessServerRegistryService(kasRegistryService).setWellknownService(wellknown).build()); + nanoTDF.createNanoTDF(byteBuffer, tdfOutputStream, nanoConfig); + + byte[] nanoTDFBytes = tdfOutputStream.toByteArray(); + ByteArrayOutputStream plainTextStream = new ByteArrayOutputStream(); + nanoTDF = new NanoTDF(new FakeServicesBuilder().setKas(kas).setKeyAccessServerRegistryService(kasRegistryService).build()); + nanoTDF.readNanoTDF(ByteBuffer.wrap(nanoTDFBytes), plainTextStream, platformUrl); + + String out = new String(plainTextStream.toByteArray(), StandardCharsets.UTF_8); + assertThat(out).isEqualTo(plainText); + // KAS KID + assertThat(new String(nanoTDFBytes, StandardCharsets.UTF_8)).contains(KID); + } + void runBasicTest(String kasUrl, boolean allowed, KeyAccessServerRegistryServiceClient kasReg, NanoTDFReaderConfig decryptConfig) throws Exception { var kasInfos = new ArrayList<>(); var kasInfo = new Config.KASInfo(); diff --git a/sdk/src/test/java/io/opentdf/platform/sdk/PlannerTest.java b/sdk/src/test/java/io/opentdf/platform/sdk/PlannerTest.java new file mode 100644 index 00000000..117db1a7 --- /dev/null +++ b/sdk/src/test/java/io/opentdf/platform/sdk/PlannerTest.java @@ -0,0 +1,269 @@ +package io.opentdf.platform.sdk; + +import com.google.protobuf.Struct; +import com.google.protobuf.Value; +import io.opentdf.platform.policy.Algorithm; +import io.opentdf.platform.wellknownconfiguration.GetWellKnownConfigurationResponse; +import io.opentdf.platform.wellknownconfiguration.WellKnownServiceClientInterface; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; + +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.stream.Collectors; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; + +class PlannerTest { + + @Test + void fetchBaseKey() { + var wellknownService = Mockito.mock(WellKnownServiceClientInterface.class); + var baseKeyJson = "{\"kas_url\":\"https://example.com/base_key\",\"public_key\":{\"algorithm\":\"ALGORITHM_RSA_2048\",\"kid\":\"thekid\",\"pem\": \"thepem\"}}"; + var val = Value.newBuilder().setStringValue(baseKeyJson).build(); + var config = Struct.newBuilder().putFields("base_key", val).build(); + var response = GetWellKnownConfigurationResponse + .newBuilder() + .setConfiguration(config) + .build(); + + Mockito.when(wellknownService.getWellKnownConfigurationBlocking(Mockito.any(), Mockito.anyMap())) + .thenReturn(TestUtil.successfulUnaryCall(response)); + + + var baseKey = Planner.fetchBaseKey(wellknownService); + assertThat(baseKey).isNotEmpty(); + var simpleKasKey = baseKey.get(); + assertThat(simpleKasKey.getKasUri()).isEqualTo("https://example.com/base_key"); + assertThat(simpleKasKey.getPublicKey().getAlgorithm()).isEqualTo(Algorithm.ALGORITHM_RSA_2048); + assertThat(simpleKasKey.getPublicKey().getKid()).isEqualTo("thekid"); + assertThat(simpleKasKey.getPublicKey().getPem()).isEqualTo("thepem"); + } + + @Test + void fetchBaseKeyWithNoBaseKey() { + var wellknownService = Mockito.mock(WellKnownServiceClientInterface.class); + var response = GetWellKnownConfigurationResponse + .newBuilder() + .setConfiguration(Struct.newBuilder().build()) + .build(); + + Mockito.when(wellknownService.getWellKnownConfigurationBlocking(Mockito.any(), Mockito.anyMap())) + .thenReturn(TestUtil.successfulUnaryCall(response)); + + var baseKey = Planner.fetchBaseKey(wellknownService); + assertThat(baseKey).isEmpty(); + } + + @Test + void fetchBaseKeyWithMissingFields() { + var wellknownService = Mockito.mock(WellKnownServiceClientInterface.class); + // Missing 'kid', 'pem', and 'algorithm' in public_key + var baseKeyJson = "{\"kas_url\":\"https://example.com/base_key\",\"public_key\":{}}"; + var val = Value.newBuilder().setStringValue(baseKeyJson).build(); + var config = Struct.newBuilder().putFields("base_key", val).build(); + var response = GetWellKnownConfigurationResponse + .newBuilder() + .setConfiguration(config) + .build(); + + Mockito.when(wellknownService.getWellKnownConfigurationBlocking(Mockito.any(), Mockito.anyMap())) + .thenReturn(TestUtil.successfulUnaryCall(response)); + + var baseKey = Planner.fetchBaseKey(wellknownService); + assertThat(baseKey).isEmpty(); + } + + @Test + void generatePlanFromProvidedKases() { + var kas1 = new Config.KASInfo(); + kas1.URL = "https://kas1.example.com"; + kas1.KID = "kid1"; + kas1.Algorithm = "rsa:2048"; + + var kas2 = new Config.KASInfo(); + kas2.URL = "https://kas2.example.com"; + kas2.KID = "kid2"; + kas2.Algorithm = "ec:secp256"; + + var tdfConfig = new Config.TDFConfig(); + tdfConfig.kasInfoList.add(kas1); + tdfConfig.kasInfoList.add(kas2); + + var planner = new Planner(tdfConfig, new FakeServicesBuilder().build(), (ignore1, ignored2) -> { throw new IllegalArgumentException("no granter needed"); }); + List splitPlan = planner.generatePlanFromProvidedKases(tdfConfig.kasInfoList); + + assertThat(splitPlan).asList().hasSize(2); + assertThat(splitPlan.get(0).kas).isEqualTo("https://kas1.example.com"); + assertThat(splitPlan.get(0).kid).isEqualTo("kid1"); + + assertThat(splitPlan.get(1).kas).isEqualTo("https://kas2.example.com"); + assertThat(splitPlan.get(1).kid).isEqualTo("kid2"); + + assertThat(splitPlan.get(0).splitID).isNotEqualTo(splitPlan.get(1).splitID); + } + + @Test + void testFillingInKeysWithAutoConfigure() { + var kas = Mockito.mock(SDK.KAS.class); + Mockito.when(kas.getPublicKey(Mockito.any())).thenAnswer(invocation -> { + Config.KASInfo kasInfo = invocation.getArgument(0, Config.KASInfo.class); + var ret = new Config.KASInfo(); + ret.URL = kasInfo.URL; + if (Objects.equals(kasInfo.URL, "https://kas1.example.com")) { + ret.PublicKey = "pem1"; + ret.Algorithm = "rsa:2048"; + ret.KID = "kid1"; + } else if (Objects.equals(kasInfo.URL, "https://kas2.example.com")) { + ret.PublicKey = "pem2"; + ret.Algorithm = "ec:secp256r1"; + ret.KID = "kid2"; + } else if (Objects.equals(kasInfo.URL, "https://kas3.example.com")) { + ret.PublicKey = "pem3"; + ret.Algorithm = "rsa:4096"; + ret.KID = "kid3"; + } else { + throw new IllegalArgumentException("Unexpected KAS URL: " + kasInfo.URL); + } + return ret; + }); + var tdfConfig = new Config.TDFConfig(); + tdfConfig.autoconfigure = true; + tdfConfig.wrappingKeyType = KeyType.RSA2048Key; + tdfConfig.kasInfoList = List.of( + new Config.KASInfo() {{ + URL = "https://kas4.example.com"; + KID = "kid4"; + Algorithm = "ec:secp384r1"; + PublicKey = "pem4"; + }} + ); + var planner = new Planner(tdfConfig, new FakeServicesBuilder().setKas(kas).build(), (ignore1, ignored2) -> { throw new IllegalArgumentException("no granter needed"); }); + var plan = List.of( + new Autoconfigure.KeySplitStep("https://kas1.example.com", "split1", null), + new Autoconfigure.KeySplitStep("https://kas4.example.com", "split1", "kid4"), + new Autoconfigure.KeySplitStep("https://kas2.example.com", "split2", "kid2"), + new Autoconfigure.KeySplitStep("https://kas3.example.com", "split2", "kid3") + ); + Map> filledInPlan = planner.resolveKeys(plan); + assertThat(filledInPlan.keySet().stream().collect(Collectors.toList())).asList().containsExactlyInAnyOrder("split1", "split2"); + assertThat(filledInPlan.get("split1")).asList().hasSize(2); + var kasInfo1 = filledInPlan.get("split1").stream().filter(k -> "kid1".equals(k.KID)).findFirst().get(); + assertThat(kasInfo1.URL).isEqualTo("https://kas1.example.com"); + assertThat(kasInfo1.Algorithm).isEqualTo("rsa:2048"); + assertThat(kasInfo1.PublicKey).isEqualTo("pem1"); + var kasInfo4 = filledInPlan.get("split1").stream().filter(k -> "kid4".equals(k.KID)).findFirst().get(); + assertThat(kasInfo4.URL).isEqualTo("https://kas4.example.com"); + assertThat(kasInfo4.Algorithm).isEqualTo("ec:secp384r1"); + assertThat(kasInfo4.PublicKey).isEqualTo("pem4"); + + assertThat(filledInPlan.get("split2")).asList().hasSize(2); + var kasInfo2 = filledInPlan.get("split2").stream().filter(kasInfo -> "kid2".equals(kasInfo.KID)).findFirst().get(); + assertThat(kasInfo2.URL).isEqualTo("https://kas2.example.com"); + assertThat(kasInfo2.Algorithm).isEqualTo("ec:secp256r1"); + assertThat(kasInfo2.PublicKey).isEqualTo("pem2"); + var kasInfo3 = filledInPlan.get("split2").stream().filter(kasInfo -> "kid3".equals(kasInfo.KID)).findFirst().get(); + assertThat(kasInfo3.URL).isEqualTo("https://kas3.example.com"); + assertThat(kasInfo3.Algorithm).isEqualTo("rsa:4096"); + assertThat(kasInfo3.PublicKey).isEqualTo("pem3"); + } + + @Test + void returnsOnlyDefaultKasesIfPresent() { + var kas1 = new Config.KASInfo(); + kas1.URL = "https://kas1.example.com"; + kas1.Default = true; + + var kas2 = new Config.KASInfo(); + kas2.URL = "https://kas2.example.com"; + kas2.Default = false; + + var kas3 = new Config.KASInfo(); + kas3.URL = "https://kas3.example.com"; + kas3.Default = true; + + var config = new Config.TDFConfig(); + config.kasInfoList.addAll(List.of(kas1, kas2, kas3)); + + List result = Planner.defaultKases(config); + + Assertions.assertThat(result).containsExactlyInAnyOrder("https://kas1.example.com", "https://kas3.example.com"); + } + + @Test + void returnsAllKasesIfNoDefault() { + var kas1 = new Config.KASInfo(); + kas1.URL = "https://kas1.example.com"; + kas1.Default = false; + + var kas2 = new Config.KASInfo(); + kas2.URL = "https://kas2.example.com"; + kas2.Default = null; // not set + + var config = new Config.TDFConfig(); + config.kasInfoList.addAll(List.of(kas1, kas2)); + + List result = Planner.defaultKases(config); + Assertions.assertThat(result).containsExactlyInAnyOrder("https://kas1.example.com", "https://kas2.example.com"); + } + + @Test + void returnsEmptyListIfNoKases() { + var config = new Config.TDFConfig(); + List result = Planner.defaultKases(config); + Assertions.assertThat(result).isEmpty(); + } + + @Test + void usesProvidedSplitPlanWhenNotAutoconfigure() { + var kas = Mockito.mock(SDK.KAS.class); + Mockito.when(kas.getPublicKey(Mockito.any())).thenAnswer(invocation -> { + Config.KASInfo kasInfo = invocation.getArgument(0, Config.KASInfo.class); + var ret = new Config.KASInfo(); + ret.URL = kasInfo.URL; + if (Objects.equals(kasInfo.URL, "https://kas1.example.com")) { + ret.PublicKey = "pem1"; + ret.Algorithm = "rsa:2048"; + ret.KID = "kid1"; + } else if (Objects.equals(kasInfo.URL, "https://kas2.example.com")) { + ret.PublicKey = "pem2"; + ret.Algorithm = "ec:secp256r1"; + ret.KID = "kid2"; + } else { + throw new IllegalArgumentException("Unexpected KAS URL: " + kasInfo.URL); + } + return ret; + }); + // Arrange + var kas1 = new Config.KASInfo(); + kas1.URL = "https://kas1.example.com"; + kas1.KID = "kid1"; + kas1.Algorithm = "rsa:2048"; + + var kas2 = new Config.KASInfo(); + kas2.URL = "https://kas2.example.com"; + kas2.KID = "kid2"; + kas2.Algorithm = "ec:secp256"; + + var splitStep1 = new Autoconfigure.KeySplitStep(kas1.URL, "split1", kas1.KID); + var splitStep2 = new Autoconfigure.KeySplitStep(kas2.URL, "split2", kas2.KID); + + var tdfConfig = new Config.TDFConfig(); + tdfConfig.autoconfigure = false; + tdfConfig.kasInfoList.add(kas1); + tdfConfig.kasInfoList.add(kas2); + tdfConfig.splitPlan = List.of(splitStep1, splitStep2); + + var planner = new Planner(tdfConfig, new FakeServicesBuilder().setKas(kas).build(), (ignore1, ignored2) -> { throw new IllegalArgumentException("no granter needed"); }); + + // Act + Map> splits = planner.getSplits(tdfConfig); + + // Assert + Assertions.assertThat(splits).hasSize(2); + Assertions.assertThat(splits.get("split1")).extracting("URL").containsExactly("https://kas1.example.com"); + Assertions.assertThat(splits.get("split2")).extracting("URL").containsExactly("https://kas2.example.com"); + } +} \ No newline at end of file diff --git a/sdk/src/test/java/io/opentdf/platform/sdk/TestUtil.java b/sdk/src/test/java/io/opentdf/platform/sdk/TestUtil.java new file mode 100644 index 00000000..53076007 --- /dev/null +++ b/sdk/src/test/java/io/opentdf/platform/sdk/TestUtil.java @@ -0,0 +1,22 @@ +package io.opentdf.platform.sdk; + +import com.connectrpc.ResponseMessage; +import com.connectrpc.UnaryBlockingCall; + +import java.util.Collections; + +public class TestUtil { + static UnaryBlockingCall successfulUnaryCall(T result) { + return new UnaryBlockingCall() { + @Override + public ResponseMessage execute() { + return new ResponseMessage.Success<>(result, Collections.emptyMap(), Collections.emptyMap()); + } + + @Override + public void cancel() { + // in tests we don't need to preserve server resources, so no-op + } + }; + } +}