Skip to content

Commit 1701717

Browse files
committed
test: Add unit tests for Pipeline SpecPatch function
Signed-off-by: Giulio Frasca <[email protected]>
1 parent 902b586 commit 1701717

File tree

1 file changed

+248
-0
lines changed

1 file changed

+248
-0
lines changed
Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
// Copyright 2025 The Kubeflow Authors
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package argocompiler
16+
17+
import (
18+
"testing"
19+
20+
wfapi "github.com/argoproj/argo-workflows/v3/pkg/apis/workflow/v1alpha1"
21+
"github.com/stretchr/testify/assert"
22+
corev1 "k8s.io/api/core/v1"
23+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
24+
)
25+
26+
func TestApplyWorkflowSpecPatch(t *testing.T) {
27+
tests := []struct {
28+
name string
29+
patchJSON string
30+
expectError bool
31+
validateFunc func(t *testing.T, wf *wfapi.Workflow)
32+
}{
33+
{
34+
name: "empty patch should skip patching",
35+
patchJSON: "{}",
36+
expectError: false,
37+
validateFunc: func(t *testing.T, wf *wfapi.Workflow) {
38+
// Should remain unchanged
39+
assert.Equal(t, "original-sa", wf.Spec.ServiceAccountName)
40+
},
41+
},
42+
{
43+
name: "empty string should skip patching",
44+
patchJSON: "",
45+
expectError: false,
46+
validateFunc: func(t *testing.T, wf *wfapi.Workflow) {
47+
// Should remain unchanged
48+
assert.Equal(t, "original-sa", wf.Spec.ServiceAccountName)
49+
},
50+
},
51+
{
52+
name: "patch service account name",
53+
patchJSON: `{"serviceAccountName": "custom-sa"}`,
54+
expectError: false,
55+
validateFunc: func(t *testing.T, wf *wfapi.Workflow) {
56+
assert.Equal(t, "custom-sa", wf.Spec.ServiceAccountName)
57+
},
58+
},
59+
{
60+
name: "patch node selector",
61+
patchJSON: `{"nodeSelector": {"node-type": "gpu", "zone": "us-west1"}}`,
62+
expectError: false,
63+
validateFunc: func(t *testing.T, wf *wfapi.Workflow) {
64+
assert.Equal(t, "gpu", wf.Spec.NodeSelector["node-type"])
65+
assert.Equal(t, "us-west1", wf.Spec.NodeSelector["zone"])
66+
},
67+
},
68+
{
69+
name: "patch multiple fields",
70+
patchJSON: `{
71+
"serviceAccountName": "gpu-runner",
72+
"nodeSelector": {"accelerator": "nvidia-tesla-k80"},
73+
"tolerations": [
74+
{"key": "nvidia.com/gpu", "operator": "Exists", "effect": "NoSchedule"}
75+
]
76+
}`,
77+
expectError: false,
78+
validateFunc: func(t *testing.T, wf *wfapi.Workflow) {
79+
assert.Equal(t, "gpu-runner", wf.Spec.ServiceAccountName)
80+
assert.Equal(t, "nvidia-tesla-k80", wf.Spec.NodeSelector["accelerator"])
81+
assert.Len(t, wf.Spec.Tolerations, 1)
82+
assert.Equal(t, "nvidia.com/gpu", wf.Spec.Tolerations[0].Key)
83+
assert.Equal(t, corev1.TolerationOperator("Exists"), wf.Spec.Tolerations[0].Operator)
84+
assert.Equal(t, corev1.TaintEffectNoSchedule, wf.Spec.Tolerations[0].Effect)
85+
},
86+
},
87+
{
88+
name: "patch pod metadata",
89+
patchJSON: `{
90+
"podMetadata": {
91+
"labels": {"env": "prod", "team": "ml"},
92+
"annotations": {"monitoring": "enabled"}
93+
}
94+
}`,
95+
expectError: false,
96+
validateFunc: func(t *testing.T, wf *wfapi.Workflow) {
97+
assert.NotNil(t, wf.Spec.PodMetadata)
98+
assert.Equal(t, "prod", wf.Spec.PodMetadata.Labels["env"])
99+
assert.Equal(t, "ml", wf.Spec.PodMetadata.Labels["team"])
100+
assert.Equal(t, "enabled", wf.Spec.PodMetadata.Annotations["monitoring"])
101+
},
102+
},
103+
{
104+
name: "merge with existing pod metadata",
105+
patchJSON: `{
106+
"podMetadata": {
107+
"labels": {"new-label": "new-value"}
108+
}
109+
}`,
110+
expectError: false,
111+
validateFunc: func(t *testing.T, wf *wfapi.Workflow) {
112+
assert.NotNil(t, wf.Spec.PodMetadata)
113+
// Should merge with existing labels
114+
assert.Equal(t, "existing-value", wf.Spec.PodMetadata.Labels["existing-label"])
115+
assert.Equal(t, "new-value", wf.Spec.PodMetadata.Labels["new-label"])
116+
},
117+
},
118+
{
119+
name: "invalid JSON should return error",
120+
patchJSON: `{"invalid": json}`,
121+
expectError: true,
122+
validateFunc: func(t *testing.T, wf *wfapi.Workflow) {
123+
// Should remain unchanged
124+
assert.Equal(t, "original-sa", wf.Spec.ServiceAccountName)
125+
},
126+
},
127+
{
128+
name: "malformed JSON should return error",
129+
patchJSON: `{incomplete`,
130+
expectError: true,
131+
validateFunc: func(t *testing.T, wf *wfapi.Workflow) {
132+
// Should remain unchanged
133+
assert.Equal(t, "original-sa", wf.Spec.ServiceAccountName)
134+
},
135+
},
136+
}
137+
138+
for _, tt := range tests {
139+
t.Run(tt.name, func(t *testing.T) {
140+
// Create a test workflow with some initial values
141+
wf := &wfapi.Workflow{
142+
ObjectMeta: metav1.ObjectMeta{
143+
Name: "test-workflow",
144+
},
145+
Spec: wfapi.WorkflowSpec{
146+
ServiceAccountName: "original-sa",
147+
PodMetadata: &wfapi.Metadata{
148+
Labels: map[string]string{
149+
"existing-label": "existing-value",
150+
},
151+
},
152+
},
153+
}
154+
155+
// Create a workflow compiler with the test workflow
156+
compiler := &workflowCompiler{
157+
wf: wf,
158+
}
159+
160+
// Apply the patch
161+
err := compiler.ApplyWorkflowSpecPatch(tt.patchJSON)
162+
163+
// Check error expectation
164+
if tt.expectError {
165+
assert.Error(t, err)
166+
} else {
167+
assert.NoError(t, err)
168+
}
169+
170+
// Run validation function
171+
if tt.validateFunc != nil {
172+
tt.validateFunc(t, wf)
173+
}
174+
})
175+
}
176+
}
177+
178+
func TestApplyWorkflowSpecPatch_NilWorkflow(t *testing.T) {
179+
compiler := &workflowCompiler{
180+
wf: nil,
181+
}
182+
183+
err := compiler.ApplyWorkflowSpecPatch(`{"serviceAccountName": "test"}`)
184+
assert.Error(t, err)
185+
assert.Contains(t, err.Error(), "workflow is nil")
186+
}
187+
188+
func TestApplyWorkflowSpecPatch_ComplexPatch(t *testing.T) {
189+
// Test a more complex patch with various field types
190+
wf := &wfapi.Workflow{
191+
ObjectMeta: metav1.ObjectMeta{
192+
Name: "complex-test-workflow",
193+
},
194+
Spec: wfapi.WorkflowSpec{
195+
ServiceAccountName: "default",
196+
},
197+
}
198+
199+
compiler := &workflowCompiler{
200+
wf: wf,
201+
}
202+
203+
complexPatch := `{
204+
"serviceAccountName": "complex-sa",
205+
"activeDeadlineSeconds": 3600,
206+
"parallelism": 5,
207+
"nodeSelector": {
208+
"kubernetes.io/arch": "amd64",
209+
"node-pool": "compute"
210+
},
211+
"tolerations": [
212+
{
213+
"key": "dedicated",
214+
"operator": "Equal",
215+
"value": "ml-workload",
216+
"effect": "NoSchedule"
217+
}
218+
],
219+
"securityContext": {
220+
"runAsUser": 1000,
221+
"runAsGroup": 1000,
222+
"fsGroup": 1000
223+
},
224+
"hostNetwork": false,
225+
"dnsPolicy": "ClusterFirst"
226+
}`
227+
228+
err := compiler.ApplyWorkflowSpecPatch(complexPatch)
229+
assert.NoError(t, err)
230+
231+
// Validate all patched fields
232+
assert.Equal(t, "complex-sa", wf.Spec.ServiceAccountName)
233+
assert.Equal(t, int64(3600), *wf.Spec.ActiveDeadlineSeconds)
234+
assert.Equal(t, int64(5), *wf.Spec.Parallelism)
235+
assert.Equal(t, "amd64", wf.Spec.NodeSelector["kubernetes.io/arch"])
236+
assert.Equal(t, "compute", wf.Spec.NodeSelector["node-pool"])
237+
assert.Len(t, wf.Spec.Tolerations, 1)
238+
assert.Equal(t, "dedicated", wf.Spec.Tolerations[0].Key)
239+
assert.Equal(t, corev1.TolerationOperator("Equal"), wf.Spec.Tolerations[0].Operator)
240+
assert.Equal(t, "ml-workload", wf.Spec.Tolerations[0].Value)
241+
assert.Equal(t, corev1.TaintEffectNoSchedule, wf.Spec.Tolerations[0].Effect)
242+
assert.NotNil(t, wf.Spec.SecurityContext)
243+
assert.Equal(t, int64(1000), *wf.Spec.SecurityContext.RunAsUser)
244+
assert.Equal(t, int64(1000), *wf.Spec.SecurityContext.RunAsGroup)
245+
assert.Equal(t, int64(1000), *wf.Spec.SecurityContext.FSGroup)
246+
assert.Equal(t, false, *wf.Spec.HostNetwork)
247+
assert.Equal(t, corev1.DNSClusterFirst, *wf.Spec.DNSPolicy)
248+
}

0 commit comments

Comments
 (0)