Skip to content

Commit 79a2dfe

Browse files
committed
feat: add extensible transform API for custom field transformations
Signed-off-by: Evgenii Orlov <[email protected]>
1 parent 8da9902 commit 79a2dfe

File tree

3 files changed

+334
-9
lines changed

3 files changed

+334
-9
lines changed

transform/canonical.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,10 @@ import (
2020
"github.com/compose-spec/compose-go/v2/tree"
2121
)
2222

23-
type transformFunc func(data any, p tree.Path, ignoreParseError bool) (any, error)
23+
// Func is a function that can transform data at a specific path
24+
type Func func(data any, p tree.Path, ignoreParseError bool) (any, error)
2425

25-
var transformers = map[tree.Path]transformFunc{}
26+
var transformers = map[tree.Path]Func{}
2627

2728
func init() {
2829
transformers["services.*"] = transformService

transform/defaults.go

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,20 @@ import (
2020
"github.com/compose-spec/compose-go/v2/tree"
2121
)
2222

23-
var defaultValues = map[tree.Path]transformFunc{}
23+
// DefaultValues contains the default value transformers for compose fields
24+
var DefaultValues = map[tree.Path]Func{}
2425

2526
func init() {
26-
defaultValues["services.*.build"] = defaultBuildContext
27-
defaultValues["services.*.secrets.*"] = defaultSecretMount
28-
defaultValues["services.*.ports.*"] = portDefaults
29-
defaultValues["services.*.deploy.resources.reservations.devices.*"] = deviceRequestDefaults
30-
defaultValues["services.*.gpus.*"] = deviceRequestDefaults
27+
DefaultValues["services.*.build"] = defaultBuildContext
28+
DefaultValues["services.*.secrets.*"] = defaultSecretMount
29+
DefaultValues["services.*.ports.*"] = portDefaults
30+
DefaultValues["services.*.deploy.resources.reservations.devices.*"] = deviceRequestDefaults
31+
DefaultValues["services.*.gpus.*"] = deviceRequestDefaults
32+
}
33+
34+
// RegisterDefaultValue registers a custom transformer for the given path pattern
35+
func RegisterDefaultValue(path string, transformer Func) {
36+
DefaultValues[tree.Path(path)] = transformer
3137
}
3238

3339
// SetDefaultValues transforms a compose model to set default values to missing attributes
@@ -40,7 +46,7 @@ func SetDefaultValues(yaml map[string]any) (map[string]any, error) {
4046
}
4147

4248
func setDefaults(data any, p tree.Path) (any, error) {
43-
for pattern, transformer := range defaultValues {
49+
for pattern, transformer := range DefaultValues {
4450
if p.Matches(pattern) {
4551
t, err := transformer(data, p, false)
4652
if err != nil {

transform/defaults_test.go

Lines changed: 318 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,318 @@
1+
/*
2+
Copyright 2020 The Compose Specification Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package transform
18+
19+
import (
20+
"fmt"
21+
"strings"
22+
"testing"
23+
24+
"github.com/compose-spec/compose-go/v2/tree"
25+
"go.yaml.in/yaml/v3"
26+
"gotest.tools/v3/assert"
27+
)
28+
29+
// exampleEnhancedPortDefaults is an example of an enhanced port defaults function
30+
// that supports additional protocols and domain names in the published field.
31+
// This function can be used as a replacement for the default portDefaults
32+
// by registering it with RegisterDefaultValue.
33+
//
34+
// Example usage:
35+
//
36+
// RegisterDefaultValue("services.*.ports.*", exampleEnhancedPortDefaults)
37+
//
38+
// This function supports:
39+
// - Additional protocols: "http", "https", "tcp", "udp"
40+
// - Domain names in the published field (e.g., "example.com:80")
41+
// - All existing functionality of the original portDefaults
42+
func exampleEnhancedPortDefaults(data any, _ tree.Path, _ bool) (any, error) {
43+
switch v := data.(type) {
44+
case map[string]any:
45+
// Set default protocol if not specified
46+
if _, ok := v["protocol"]; !ok {
47+
v["protocol"] = "tcp"
48+
}
49+
50+
// Set default mode if not specified
51+
if _, ok := v["mode"]; !ok {
52+
v["mode"] = "ingress"
53+
}
54+
55+
// Enhanced protocol handling with app_protocol
56+
protocol, _ := v["protocol"].(string)
57+
switch protocol {
58+
case "http":
59+
if _, ok := v["app_protocol"]; !ok {
60+
v["app_protocol"] = "http1.1"
61+
}
62+
case "https":
63+
if _, ok := v["app_protocol"]; !ok {
64+
v["app_protocol"] = "http2"
65+
}
66+
}
67+
68+
// Auto-generate port name based on target port
69+
if _, ok := v["name"]; !ok {
70+
if target, ok := v["target"].(int); ok {
71+
switch target {
72+
case 80:
73+
v["name"] = "http"
74+
case 443:
75+
v["name"] = "https"
76+
case 3306:
77+
v["name"] = "mysql"
78+
case 5432:
79+
v["name"] = "postgres"
80+
case 6379:
81+
v["name"] = "redis"
82+
case 27017:
83+
v["name"] = "mongodb"
84+
default:
85+
v["name"] = fmt.Sprintf("port-%d", target)
86+
}
87+
}
88+
}
89+
90+
// Handle domain names in published field
91+
if published, ok := v["published"].(string); ok {
92+
// Check if published contains a domain name, very KISS check for this example
93+
if strings.Contains(published, ".") && strings.Contains(published, ":") {
94+
parts := strings.SplitN(published, ":", 2)
95+
if len(parts) == 2 {
96+
v["x-published-domain"] = parts[0]
97+
v["x-published-port"] = parts[1]
98+
}
99+
}
100+
}
101+
102+
// Normalize host_ip shortcuts
103+
if hostIP, ok := v["host_ip"].(string); ok {
104+
switch hostIP {
105+
case "localhost":
106+
v["host_ip"] = "127.0.0.1"
107+
case "*":
108+
v["host_ip"] = "0.0.0.0"
109+
}
110+
}
111+
112+
// Add monitoring metadata for common ports
113+
if target, ok := v["target"].(int); ok {
114+
switch target {
115+
case 80, 443, 8080, 8443:
116+
v["x-metrics-enabled"] = true
117+
v["x-metrics-path"] = "/metrics"
118+
case 9090: // Prometheus
119+
v["x-metrics-enabled"] = true
120+
v["x-metrics-type"] = "prometheus"
121+
}
122+
}
123+
124+
return v, nil
125+
default:
126+
return data, nil
127+
}
128+
}
129+
130+
func TestRegisterDefaultValue(t *testing.T) {
131+
// Save original transformers, so as not to break possible other tests
132+
originalTransformers := make(map[tree.Path]Func)
133+
for k, v := range DefaultValues {
134+
originalTransformers[k] = v
135+
}
136+
t.Cleanup(func() {
137+
DefaultValues = originalTransformers
138+
})
139+
140+
// Register the function
141+
RegisterDefaultValue("services.*.ports.*", exampleEnhancedPortDefaults)
142+
143+
testCases := []struct {
144+
name string
145+
inputYAML string
146+
expectedYAML string
147+
}{
148+
{
149+
name: "basic port with defaults and auto-generated name",
150+
inputYAML: `
151+
services:
152+
web:
153+
ports:
154+
- target: 80
155+
`,
156+
expectedYAML: `
157+
services:
158+
web:
159+
ports:
160+
- target: 80
161+
protocol: tcp
162+
mode: ingress
163+
name: http
164+
x-metrics-enabled: true
165+
x-metrics-path: /metrics
166+
`,
167+
},
168+
{
169+
name: "port with https protocol and app_protocol",
170+
inputYAML: `
171+
services:
172+
web:
173+
ports:
174+
- target: 443
175+
protocol: https
176+
`,
177+
expectedYAML: `
178+
services:
179+
web:
180+
ports:
181+
- target: 443
182+
protocol: https
183+
app_protocol: http2
184+
mode: ingress
185+
name: https
186+
x-metrics-enabled: true
187+
x-metrics-path: /metrics
188+
`,
189+
},
190+
{
191+
name: "port with domain name in published field",
192+
inputYAML: `
193+
services:
194+
web:
195+
ports:
196+
- target: 80
197+
published: "example.com:8080"
198+
`,
199+
expectedYAML: `
200+
services:
201+
web:
202+
ports:
203+
- target: 80
204+
published: "example.com:8080"
205+
protocol: tcp
206+
mode: ingress
207+
name: http
208+
x-published-domain: example.com
209+
x-published-port: "8080"
210+
x-metrics-enabled: true
211+
x-metrics-path: /metrics
212+
`,
213+
},
214+
{
215+
name: "database port with auto-generated name",
216+
inputYAML: `
217+
services:
218+
db:
219+
ports:
220+
- target: 3306
221+
`,
222+
expectedYAML: `
223+
services:
224+
db:
225+
ports:
226+
- target: 3306
227+
protocol: tcp
228+
mode: ingress
229+
name: mysql
230+
`,
231+
},
232+
{
233+
name: "host_ip normalization",
234+
inputYAML: `
235+
services:
236+
web:
237+
ports:
238+
- target: 8080
239+
host_ip: localhost
240+
- target: 8081
241+
host_ip: "*"
242+
`,
243+
expectedYAML: `
244+
services:
245+
web:
246+
ports:
247+
- target: 8080
248+
host_ip: "127.0.0.1"
249+
protocol: tcp
250+
mode: ingress
251+
name: port-8080
252+
x-metrics-enabled: true
253+
x-metrics-path: /metrics
254+
- target: 8081
255+
host_ip: "0.0.0.0"
256+
protocol: tcp
257+
mode: ingress
258+
name: port-8081
259+
`,
260+
},
261+
{
262+
name: "prometheus port with monitoring metadata",
263+
inputYAML: `
264+
services:
265+
prometheus:
266+
ports:
267+
- target: 9090
268+
`,
269+
expectedYAML: `
270+
services:
271+
prometheus:
272+
ports:
273+
- target: 9090
274+
protocol: tcp
275+
mode: ingress
276+
name: port-9090
277+
x-metrics-enabled: true
278+
x-metrics-type: prometheus
279+
`,
280+
},
281+
{
282+
name: "http protocol with app_protocol",
283+
inputYAML: `
284+
services:
285+
web:
286+
ports:
287+
- target: 3000
288+
protocol: http
289+
`,
290+
expectedYAML: `
291+
services:
292+
web:
293+
ports:
294+
- target: 3000
295+
protocol: http
296+
app_protocol: http1.1
297+
mode: ingress
298+
name: port-3000
299+
`,
300+
},
301+
}
302+
303+
for _, tc := range testCases {
304+
t.Run(tc.name, func(t *testing.T) {
305+
var input map[string]any
306+
err := yaml.Unmarshal([]byte(tc.inputYAML), &input)
307+
assert.NilError(t, err)
308+
309+
var expected map[string]any
310+
err = yaml.Unmarshal([]byte(tc.expectedYAML), &expected)
311+
assert.NilError(t, err)
312+
313+
result, err := SetDefaultValues(input)
314+
assert.NilError(t, err)
315+
assert.DeepEqual(t, result, expected)
316+
})
317+
}
318+
}

0 commit comments

Comments
 (0)