Skip to content

Commit c22999a

Browse files
committed
feat: add extensible transform API for custom field transformations
Signed-off-by: Evgenii Orlov <[email protected]>
1 parent 8da9902 commit c22999a

File tree

3 files changed

+337
-9
lines changed

3 files changed

+337
-9
lines changed

transform/canonical.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,10 @@ import (
2020
"github.com/compose-spec/compose-go/v2/tree"
2121
)
2222

23-
type transformFunc func(data any, p tree.Path, ignoreParseError bool) (any, error)
23+
// Func is a function that can transform data at a specific path
24+
type Func func(data any, p tree.Path, ignoreParseError bool) (any, error)
2425

25-
var transformers = map[tree.Path]transformFunc{}
26+
var transformers = map[tree.Path]Func{}
2627

2728
func init() {
2829
transformers["services.*"] = transformService

transform/defaults.go

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,20 @@ import (
2020
"github.com/compose-spec/compose-go/v2/tree"
2121
)
2222

23-
var defaultValues = map[tree.Path]transformFunc{}
23+
// DefaultValues contains the default value transformers for compose fields
24+
var DefaultValues = map[tree.Path]Func{}
2425

2526
func init() {
26-
defaultValues["services.*.build"] = defaultBuildContext
27-
defaultValues["services.*.secrets.*"] = defaultSecretMount
28-
defaultValues["services.*.ports.*"] = portDefaults
29-
defaultValues["services.*.deploy.resources.reservations.devices.*"] = deviceRequestDefaults
30-
defaultValues["services.*.gpus.*"] = deviceRequestDefaults
27+
DefaultValues["services.*.build"] = defaultBuildContext
28+
DefaultValues["services.*.secrets.*"] = defaultSecretMount
29+
DefaultValues["services.*.ports.*"] = portDefaults
30+
DefaultValues["services.*.deploy.resources.reservations.devices.*"] = deviceRequestDefaults
31+
DefaultValues["services.*.gpus.*"] = deviceRequestDefaults
32+
}
33+
34+
// RegisterDefaultValue registers a custom transformer for the given path pattern
35+
func RegisterDefaultValue(path string, transformer Func) {
36+
DefaultValues[tree.Path(path)] = transformer
3137
}
3238

3339
// SetDefaultValues transforms a compose model to set default values to missing attributes
@@ -40,7 +46,7 @@ func SetDefaultValues(yaml map[string]any) (map[string]any, error) {
4046
}
4147

4248
func setDefaults(data any, p tree.Path) (any, error) {
43-
for pattern, transformer := range defaultValues {
49+
for pattern, transformer := range DefaultValues {
4450
if p.Matches(pattern) {
4551
t, err := transformer(data, p, false)
4652
if err != nil {

transform/defaults_test.go

Lines changed: 321 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,321 @@
1+
/*
2+
Copyright 2020 The Compose Specification Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package transform
18+
19+
import (
20+
"fmt"
21+
"strings"
22+
"testing"
23+
24+
"github.com/compose-spec/compose-go/v2/tree"
25+
"go.yaml.in/yaml/v3"
26+
"gotest.tools/v3/assert"
27+
)
28+
29+
// enhancedPortDefaults is an example of an enhanced port defaults function
30+
// that supports additional protocols and domain names in the published field.
31+
// This function can be used as a replacement for the default portDefaults
32+
// by registering it with RegisterDefaultValue.
33+
//
34+
// Example usage:
35+
//
36+
// RegisterDefaultValue("services.*.ports.*", enhancedPortDefaults)
37+
//
38+
// This function supports:
39+
// - Additional protocols: "http", "https", "tcp", "udp"
40+
// - Domain names in the published field (e.g., "example.com:80")
41+
// - All existing functionality of the original portDefaults
42+
func exampleEnhancedPortDefaults(data any, _ tree.Path, _ bool) (any, error) {
43+
switch v := data.(type) {
44+
case map[string]any:
45+
// Set default protocol if not specified
46+
if _, ok := v["protocol"]; !ok {
47+
v["protocol"] = "tcp"
48+
}
49+
50+
// Set default mode if not specified
51+
if _, ok := v["mode"]; !ok {
52+
v["mode"] = "ingress"
53+
}
54+
55+
// Enhanced protocol handling with app_protocol
56+
protocol, _ := v["protocol"].(string)
57+
switch protocol {
58+
case "http":
59+
if _, ok := v["app_protocol"]; !ok {
60+
v["app_protocol"] = "http1.1"
61+
}
62+
case "https":
63+
if _, ok := v["app_protocol"]; !ok {
64+
v["app_protocol"] = "http2"
65+
}
66+
}
67+
68+
// Auto-generate port name based on target port
69+
if _, ok := v["name"]; !ok {
70+
if target, ok := v["target"].(int); ok {
71+
switch target {
72+
case 80:
73+
v["name"] = "http"
74+
case 443:
75+
v["name"] = "https"
76+
case 3306:
77+
v["name"] = "mysql"
78+
case 5432:
79+
v["name"] = "postgres"
80+
case 6379:
81+
v["name"] = "redis"
82+
case 27017:
83+
v["name"] = "mongodb"
84+
default:
85+
v["name"] = fmt.Sprintf("port-%d", target)
86+
}
87+
}
88+
}
89+
90+
// Handle domain names in published field
91+
if published, ok := v["published"].(string); ok {
92+
// Check if published contains a domain name (simple check for illustration)
93+
if strings.Contains(published, ".") && strings.Contains(published, ":") {
94+
// Extract domain and port
95+
parts := strings.SplitN(published, ":", 2)
96+
if len(parts) == 2 {
97+
// Store domain info in custom fields
98+
v["x-published-domain"] = parts[0]
99+
v["x-published-port"] = parts[1]
100+
}
101+
}
102+
}
103+
104+
// Normalize host_ip shortcuts
105+
if hostIP, ok := v["host_ip"].(string); ok {
106+
switch hostIP {
107+
case "localhost":
108+
v["host_ip"] = "127.0.0.1"
109+
case "*":
110+
v["host_ip"] = "0.0.0.0"
111+
}
112+
}
113+
114+
// Add monitoring metadata for common ports
115+
if target, ok := v["target"].(int); ok {
116+
switch target {
117+
case 80, 443, 8080, 8443:
118+
v["x-metrics-enabled"] = true
119+
v["x-metrics-path"] = "/metrics"
120+
case 9090: // Prometheus
121+
v["x-metrics-enabled"] = true
122+
v["x-metrics-type"] = "prometheus"
123+
}
124+
}
125+
126+
return v, nil
127+
default:
128+
return data, nil
129+
}
130+
}
131+
132+
func TestRegisterDefaultValue(t *testing.T) {
133+
// Save original transformers, so as not to break possible other tests
134+
originalTransformers := make(map[tree.Path]Func)
135+
for k, v := range DefaultValues {
136+
originalTransformers[k] = v
137+
}
138+
t.Cleanup(func() {
139+
DefaultValues = originalTransformers
140+
})
141+
142+
// Register the enhanced port defaults
143+
RegisterDefaultValue("services.*.ports.*", exampleEnhancedPortDefaults)
144+
145+
// Test with various port configurations
146+
testCases := []struct {
147+
name string
148+
inputYAML string
149+
expectedYAML string
150+
}{
151+
{
152+
name: "basic port with defaults and auto-generated name",
153+
inputYAML: `
154+
services:
155+
web:
156+
ports:
157+
- target: 80
158+
`,
159+
expectedYAML: `
160+
services:
161+
web:
162+
ports:
163+
- target: 80
164+
protocol: tcp
165+
mode: ingress
166+
name: http
167+
x-metrics-enabled: true
168+
x-metrics-path: /metrics
169+
`,
170+
},
171+
{
172+
name: "port with https protocol and app_protocol",
173+
inputYAML: `
174+
services:
175+
web:
176+
ports:
177+
- target: 443
178+
protocol: https
179+
`,
180+
expectedYAML: `
181+
services:
182+
web:
183+
ports:
184+
- target: 443
185+
protocol: https
186+
app_protocol: http2
187+
mode: ingress
188+
name: https
189+
x-metrics-enabled: true
190+
x-metrics-path: /metrics
191+
`,
192+
},
193+
{
194+
name: "port with domain name in published field",
195+
inputYAML: `
196+
services:
197+
web:
198+
ports:
199+
- target: 80
200+
published: "example.com:8080"
201+
`,
202+
expectedYAML: `
203+
services:
204+
web:
205+
ports:
206+
- target: 80
207+
published: "example.com:8080"
208+
protocol: tcp
209+
mode: ingress
210+
name: http
211+
x-published-domain: example.com
212+
x-published-port: "8080"
213+
x-metrics-enabled: true
214+
x-metrics-path: /metrics
215+
`,
216+
},
217+
{
218+
name: "database port with auto-generated name",
219+
inputYAML: `
220+
services:
221+
db:
222+
ports:
223+
- target: 3306
224+
`,
225+
expectedYAML: `
226+
services:
227+
db:
228+
ports:
229+
- target: 3306
230+
protocol: tcp
231+
mode: ingress
232+
name: mysql
233+
`,
234+
},
235+
{
236+
name: "host_ip normalization",
237+
inputYAML: `
238+
services:
239+
web:
240+
ports:
241+
- target: 8080
242+
host_ip: localhost
243+
- target: 8081
244+
host_ip: "*"
245+
`,
246+
expectedYAML: `
247+
services:
248+
web:
249+
ports:
250+
- target: 8080
251+
host_ip: "127.0.0.1"
252+
protocol: tcp
253+
mode: ingress
254+
name: port-8080
255+
x-metrics-enabled: true
256+
x-metrics-path: /metrics
257+
- target: 8081
258+
host_ip: "0.0.0.0"
259+
protocol: tcp
260+
mode: ingress
261+
name: port-8081
262+
`,
263+
},
264+
{
265+
name: "prometheus port with monitoring metadata",
266+
inputYAML: `
267+
services:
268+
prometheus:
269+
ports:
270+
- target: 9090
271+
`,
272+
expectedYAML: `
273+
services:
274+
prometheus:
275+
ports:
276+
- target: 9090
277+
protocol: tcp
278+
mode: ingress
279+
name: port-9090
280+
x-metrics-enabled: true
281+
x-metrics-type: prometheus
282+
`,
283+
},
284+
{
285+
name: "http protocol with app_protocol",
286+
inputYAML: `
287+
services:
288+
web:
289+
ports:
290+
- target: 3000
291+
protocol: http
292+
`,
293+
expectedYAML: `
294+
services:
295+
web:
296+
ports:
297+
- target: 3000
298+
protocol: http
299+
app_protocol: http1.1
300+
mode: ingress
301+
name: port-3000
302+
`,
303+
},
304+
}
305+
306+
for _, tc := range testCases {
307+
t.Run(tc.name, func(t *testing.T) {
308+
var input map[string]any
309+
err := yaml.Unmarshal([]byte(tc.inputYAML), &input)
310+
assert.NilError(t, err)
311+
312+
var expected map[string]any
313+
err = yaml.Unmarshal([]byte(tc.expectedYAML), &expected)
314+
assert.NilError(t, err)
315+
316+
result, err := SetDefaultValues(input)
317+
assert.NilError(t, err)
318+
assert.DeepEqual(t, result, expected)
319+
})
320+
}
321+
}

0 commit comments

Comments
 (0)