Skip to content

Commit f49bfdf

Browse files
committed
will be merged: fix kms address parsing, add shuffle
1 parent 394df55 commit f49bfdf

File tree

2 files changed

+107
-17
lines changed

2 files changed

+107
-17
lines changed

kms.go

Lines changed: 54 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"bytes"
55
"crypto/aes"
66
"io/ioutil"
7+
"math/rand"
78

89
"encoding/base64"
910
"encoding/json"
@@ -51,38 +52,74 @@ func (c *Client) kmsAuth(url string) error {
5152
return nil
5253
}
5354

54-
func kmsSplit(keyProvider string) []string {
55-
return strings.Split(keyProvider, ",")
56-
}
55+
// parse uri like kms://[email protected];kms02.example.com:9600/kms
56+
func kmsParseProviderUri(uri string) ([]string, error) {
57+
original_uri := uri
5758

58-
// kmsUrl parse KeyProviderUri to list of URL's
59-
func (c *Client) kmsUrl(einfo *hdfs.FileEncryptionInfoProto) ([]string, error) {
60-
defaults, err := c.fetchDefaults()
61-
if err != nil {
62-
return nil, err
63-
}
64-
65-
uri := defaults.GetKeyProviderUri()
6659
if uri == "" {
67-
return nil, errors.New("KeyProviderUri not configured on server")
60+
return nil, errors.New("KeyProviderUri empty. not configured on server ?")
6861
}
6962

7063
var urls []string
7164
var proto string
7265
if strings.HasPrefix(uri, kmsSchemeHTTPS) {
7366
proto = "https://"
74-
urls = kmsSplit(uri[len(kmsSchemeHTTPS):])
67+
uri = uri[len(kmsSchemeHTTPS):]
7568
}
7669
if proto == "" && strings.HasPrefix(uri, kmsSchemeHTTP) {
7770
proto = "http://"
78-
urls = kmsSplit(uri[len(kmsSchemeHTTP):])
71+
uri = uri[len(kmsSchemeHTTP):]
7972
}
8073
if proto == "" {
81-
return nil, fmt.Errorf("not supported scheme %v", uri)
74+
return nil, fmt.Errorf("not supported uri %v", original_uri)
75+
}
76+
77+
port := ":9600" // default kms port
78+
path := "" // default path
79+
80+
parts := strings.Split(uri, ";")
81+
for i, s := range parts {
82+
path_index := strings.Index(s, "/")
83+
if path_index > -1 {
84+
path = s[path_index:]
85+
s = s[:path_index]
86+
}
87+
port_index := strings.Index(s, ":")
88+
if port_index > -1 {
89+
port = s[port_index:]
90+
s = s[:port_index]
91+
}
92+
if (path_index > -1 || port_index > -1) && i+1 != len(parts) {
93+
return nil, fmt.Errorf("bad uri: %v", original_uri)
94+
}
95+
urls = append(urls, proto+s)
8296
}
8397

8498
for i := range urls {
85-
urls[i] = proto + urls[i] + "/v1/keyversion/" + url.QueryEscape(*einfo.EzKeyVersionName) + "/_eek?eek_op=decrypt"
99+
urls[i] += port
100+
urls[i] += path
101+
}
102+
103+
return urls, nil
104+
}
105+
106+
// kmsUrl parse KeyProviderUri to list of URL's
107+
func (c *Client) kmsUrl(einfo *hdfs.FileEncryptionInfoProto) ([]string, error) {
108+
defaults, err := c.fetchDefaults()
109+
if err != nil {
110+
return nil, err
111+
}
112+
113+
urls, err := kmsParseProviderUri(defaults.GetKeyProviderUri())
114+
if err != nil {
115+
return nil, err
116+
}
117+
118+
// Reorder urls. Simple method to round robin calls across em.
119+
rand.Shuffle(len(urls), func(i, j int) { urls[i], urls[j] = urls[j], urls[i] })
120+
121+
for i := range urls {
122+
urls[i] = urls[i] + "/v1/keyversion/" + url.QueryEscape(*einfo.EzKeyVersionName) + "/_eek?eek_op=decrypt"
86123
}
87124

88125
return urls, nil
@@ -156,7 +193,7 @@ func (c *Client) kmsGetKey(einfo *hdfs.FileEncryptionInfoProto) (*TransparentEnc
156193

157194
urls, err := c.kmsUrl(einfo)
158195
if err != nil {
159-
return nil, err
196+
return nil, errors.Wrap(err, "fail to get KMS address")
160197
}
161198

162199
requestBody, err := json.Marshal(map[string]string{

kms_test.go

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
package hdfs
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/assert"
7+
)
8+
9+
func TestKmsParseProviderUri(t *testing.T) {
10+
assert.Equal(t, nil, nil)
11+
12+
urls, err := kmsParseProviderUri("")
13+
assert.Error(t, err)
14+
15+
urls, err = kmsParseProviderUri("http")
16+
assert.Error(t, err)
17+
18+
urls, err = kmsParseProviderUri("kms://https@localhost:9600/kms")
19+
assert.NoError(t, err)
20+
assert.Equal(t, 1, len(urls))
21+
assert.Equal(t, "https://localhost:9600/kms", urls[0])
22+
23+
urls, err = kmsParseProviderUri("kms://[email protected]:9600;kms02.example.com")
24+
assert.Error(t, err)
25+
26+
urls, err = kmsParseProviderUri("kms://[email protected]/kms;kms02.example.com")
27+
assert.Error(t, err)
28+
29+
urls, err = kmsParseProviderUri("kms://[email protected];kms02.example.com:9600/kms")
30+
assert.NoError(t, err)
31+
assert.Equal(t, 2, len(urls))
32+
assert.Equal(t, "http://kms01.example.com:9600/kms", urls[0])
33+
assert.Equal(t, "http://kms02.example.com:9600/kms", urls[1])
34+
35+
urls, err = kmsParseProviderUri("kms://[email protected];kms02.example.com/kms")
36+
assert.NoError(t, err)
37+
assert.Equal(t, 2, len(urls))
38+
assert.Equal(t, "http://kms01.example.com:9600/kms", urls[0])
39+
assert.Equal(t, "http://kms02.example.com:9600/kms", urls[1])
40+
41+
urls, err = kmsParseProviderUri("kms://[email protected];kms02.example.com:9600")
42+
assert.NoError(t, err)
43+
assert.Equal(t, 2, len(urls))
44+
assert.Equal(t, "http://kms01.example.com:9600", urls[0])
45+
assert.Equal(t, "http://kms02.example.com:9600", urls[1])
46+
47+
urls, err = kmsParseProviderUri("kms://[email protected];kms02.example.com;kms03.example.com")
48+
assert.NoError(t, err)
49+
assert.Equal(t, 3, len(urls))
50+
assert.Equal(t, "http://kms01.example.com:9600", urls[0])
51+
assert.Equal(t, "http://kms02.example.com:9600", urls[1])
52+
assert.Equal(t, "http://kms03.example.com:9600", urls[2])
53+
}

0 commit comments

Comments
 (0)