Skip to content

Commit 5c0c47e

Browse files
authored
feat(backend): Patch pipeline spec from APIServer EnvVar (#12317)
* feat(backend): Pipeline specpatch from APIServer envvar - Allows a KFP admin to apply a custom specpatch to all incoming pipelines - Useful for installment-wide configuration, such as workflow ttl - Take in a JSON string provided via the environment variable `COMPILED_PIPELINE_SPEC_PATCH` Signed-off-by: Giulio Frasca <[email protected]> * test: Add unit tests for Pipeline SpecPatch function Signed-off-by: Giulio Frasca <[email protected]> * chore: linter fixes Signed-off-by: Giulio Frasca <[email protected]> --------- Signed-off-by: Giulio Frasca <[email protected]>
1 parent c9fba37 commit 5c0c47e

File tree

5 files changed

+323
-4
lines changed

5 files changed

+323
-4
lines changed

backend/src/apiserver/common/config.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ const (
3535
MetadataTLSEnabled string = "METADATA_TLS_ENABLED"
3636
CaBundleSecretName string = "CABUNDLE_SECRET_NAME"
3737
RequireNamespaceForPipelines string = "REQUIRE_NAMESPACE_FOR_PIPELINES"
38+
CompiledPipelineSpecPatch string = "COMPILED_PIPELINE_SPEC_PATCH"
3839
)
3940

4041
func IsPipelineVersionUpdatedByDefault() bool {
@@ -142,3 +143,7 @@ func GetMetadataTLSEnabled() bool {
142143
func GetCaBundleSecretName() string {
143144
return GetStringConfigWithDefault(CaBundleSecretName, "")
144145
}
146+
147+
func GetCompiledPipelineSpecPatch() string {
148+
return GetStringConfigWithDefault(CompiledPipelineSpecPatch, "{}")
149+
}

backend/src/v2/compiler/argocompiler/argo.go

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ import (
3131
"google.golang.org/protobuf/proto"
3232
"google.golang.org/protobuf/types/known/structpb"
3333
k8score "k8s.io/api/core/v1"
34-
"k8s.io/apimachinery/pkg/api/resource"
3534
k8sres "k8s.io/apimachinery/pkg/api/resource"
3635
k8smeta "k8s.io/apimachinery/pkg/apis/meta/v1"
3736
)
@@ -185,8 +184,17 @@ func Compile(jobArg *pipelinespec.PipelineJob, kubernetesSpecArg *pipelinespec.S
185184

186185
// compile
187186
err = compiler.Accept(job, kubernetesSpec, c)
187+
if err != nil {
188+
return nil, err
189+
}
190+
191+
// Apply any workflow spec patches from environment variable
192+
patchJSON := common.GetCompiledPipelineSpecPatch()
193+
if err := c.ApplyWorkflowSpecPatch(patchJSON); err != nil {
194+
return nil, fmt.Errorf("failed to apply workflow spec patch: %w", err)
195+
}
188196

189-
return c.wf, err
197+
return c.wf, nil
190198
}
191199

192200
func retrieveLastValidString(s string) string {
@@ -534,15 +542,15 @@ func GetWorkspacePVC(
534542
}
535543
}
536544

537-
quantity, err := resource.ParseQuantity(sizeStr)
545+
quantity, err := k8sres.ParseQuantity(sizeStr)
538546
if err != nil {
539547
return k8score.PersistentVolumeClaim{}, fmt.Errorf("invalid size value for workspace PVC: %v", err)
540548
}
541549
if quantity.Sign() < 0 {
542550
return k8score.PersistentVolumeClaim{}, fmt.Errorf("negative size value for workspace PVC: %v", sizeStr)
543551
}
544552
if pvcSpec.Resources.Requests == nil {
545-
pvcSpec.Resources.Requests = make(map[k8score.ResourceName]resource.Quantity)
553+
pvcSpec.Resources.Requests = make(map[k8score.ResourceName]k8sres.Quantity)
546554
}
547555
pvcSpec.Resources.Requests[k8score.ResourceStorage] = quantity
548556

backend/src/v2/compiler/argocompiler/spec_patch.go

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,13 @@
1515
package argocompiler
1616

1717
import (
18+
"encoding/json"
19+
"fmt"
20+
21+
wfapi "github.com/argoproj/argo-workflows/v3/pkg/apis/workflow/v1alpha1"
22+
log "github.com/sirupsen/logrus"
1823
"google.golang.org/protobuf/types/known/structpb"
24+
"k8s.io/apimachinery/pkg/util/strategicpatch"
1925
)
2026

2127
func (c *workflowCompiler) AddKubernetesSpec(name string, kubernetesSpec *structpb.Struct) error {
@@ -25,3 +31,52 @@ func (c *workflowCompiler) AddKubernetesSpec(name string, kubernetesSpec *struct
2531
}
2632
return nil
2733
}
34+
35+
// ApplyWorkflowSpecPatch applies a JSON patch to the compiled workflow specification.
36+
// It validates the JSON and applies it using Kubernetes strategic merge patch.
37+
// Only the workflow's "spec" field can be patched for security reasons.
38+
func (c *workflowCompiler) ApplyWorkflowSpecPatch(patchJSON string) error {
39+
if c.wf == nil {
40+
return fmt.Errorf("workflow is nil")
41+
}
42+
43+
// Check for empty patch string
44+
if patchJSON == "" {
45+
log.Debug("Empty workflow spec patch string provided, skipping patching")
46+
return nil
47+
}
48+
49+
log.Debug("Applying workflow spec patch")
50+
51+
// Validate that the patch is valid JSON by attempting to unmarshal it
52+
var specPatchValidation map[string]interface{}
53+
if err := json.Unmarshal([]byte(patchJSON), &specPatchValidation); err != nil {
54+
return fmt.Errorf("invalid JSON in COMPILED_PIPELINE_SPEC_PATCH: %w", err)
55+
}
56+
57+
// Check if the patch is empty (no fields to patch)
58+
if len(specPatchValidation) == 0 {
59+
log.Debug("Empty workflow spec patch provided, skipping patching")
60+
return nil
61+
}
62+
63+
// Convert the current workflow spec to JSON
64+
originalSpecJSON, err := json.Marshal(c.wf.Spec)
65+
if err != nil {
66+
return fmt.Errorf("failed to marshal workflow spec to JSON: %w", err)
67+
}
68+
69+
// Apply the strategic merge patch to the spec directly
70+
patchedSpecJSON, err := strategicpatch.StrategicMergePatch(originalSpecJSON, []byte(patchJSON), wfapi.WorkflowSpec{})
71+
if err != nil {
72+
return fmt.Errorf("failed to apply strategic merge patch to workflow spec: %w", err)
73+
}
74+
75+
// Unmarshal the patched spec back into the workflow
76+
if err := json.Unmarshal(patchedSpecJSON, &c.wf.Spec); err != nil {
77+
return fmt.Errorf("failed to unmarshal patched workflow spec: %w", err)
78+
}
79+
80+
log.Debug("Successfully applied workflow spec patch")
81+
return nil
82+
}
Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
// Copyright 2025 The Kubeflow Authors
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package argocompiler
16+
17+
import (
18+
"testing"
19+
20+
wfapi "github.com/argoproj/argo-workflows/v3/pkg/apis/workflow/v1alpha1"
21+
"github.com/stretchr/testify/assert"
22+
corev1 "k8s.io/api/core/v1"
23+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
24+
)
25+
26+
func TestApplyWorkflowSpecPatch(t *testing.T) {
27+
tests := []struct {
28+
name string
29+
patchJSON string
30+
expectError bool
31+
validateFunc func(t *testing.T, wf *wfapi.Workflow)
32+
}{
33+
{
34+
name: "empty patch should skip patching",
35+
patchJSON: "{}",
36+
expectError: false,
37+
validateFunc: func(t *testing.T, wf *wfapi.Workflow) {
38+
// Should remain unchanged
39+
assert.Equal(t, "original-sa", wf.Spec.ServiceAccountName)
40+
},
41+
},
42+
{
43+
name: "empty string should skip patching",
44+
patchJSON: "",
45+
expectError: false,
46+
validateFunc: func(t *testing.T, wf *wfapi.Workflow) {
47+
// Should remain unchanged
48+
assert.Equal(t, "original-sa", wf.Spec.ServiceAccountName)
49+
},
50+
},
51+
{
52+
name: "patch service account name",
53+
patchJSON: `{"serviceAccountName": "custom-sa"}`,
54+
expectError: false,
55+
validateFunc: func(t *testing.T, wf *wfapi.Workflow) {
56+
assert.Equal(t, "custom-sa", wf.Spec.ServiceAccountName)
57+
},
58+
},
59+
{
60+
name: "patch node selector",
61+
patchJSON: `{"nodeSelector": {"node-type": "gpu", "zone": "us-west1"}}`,
62+
expectError: false,
63+
validateFunc: func(t *testing.T, wf *wfapi.Workflow) {
64+
assert.Equal(t, "gpu", wf.Spec.NodeSelector["node-type"])
65+
assert.Equal(t, "us-west1", wf.Spec.NodeSelector["zone"])
66+
},
67+
},
68+
{
69+
name: "patch multiple fields",
70+
patchJSON: `{
71+
"serviceAccountName": "gpu-runner",
72+
"nodeSelector": {"accelerator": "nvidia-tesla-k80"},
73+
"tolerations": [
74+
{"key": "nvidia.com/gpu", "operator": "Exists", "effect": "NoSchedule"}
75+
]
76+
}`,
77+
expectError: false,
78+
validateFunc: func(t *testing.T, wf *wfapi.Workflow) {
79+
assert.Equal(t, "gpu-runner", wf.Spec.ServiceAccountName)
80+
assert.Equal(t, "nvidia-tesla-k80", wf.Spec.NodeSelector["accelerator"])
81+
assert.Len(t, wf.Spec.Tolerations, 1)
82+
assert.Equal(t, "nvidia.com/gpu", wf.Spec.Tolerations[0].Key)
83+
assert.Equal(t, corev1.TolerationOperator("Exists"), wf.Spec.Tolerations[0].Operator)
84+
assert.Equal(t, corev1.TaintEffectNoSchedule, wf.Spec.Tolerations[0].Effect)
85+
},
86+
},
87+
{
88+
name: "patch pod metadata",
89+
patchJSON: `{
90+
"podMetadata": {
91+
"labels": {"env": "prod", "team": "ml"},
92+
"annotations": {"monitoring": "enabled"}
93+
}
94+
}`,
95+
expectError: false,
96+
validateFunc: func(t *testing.T, wf *wfapi.Workflow) {
97+
assert.NotNil(t, wf.Spec.PodMetadata)
98+
assert.Equal(t, "prod", wf.Spec.PodMetadata.Labels["env"])
99+
assert.Equal(t, "ml", wf.Spec.PodMetadata.Labels["team"])
100+
assert.Equal(t, "enabled", wf.Spec.PodMetadata.Annotations["monitoring"])
101+
},
102+
},
103+
{
104+
name: "merge with existing pod metadata",
105+
patchJSON: `{
106+
"podMetadata": {
107+
"labels": {"new-label": "new-value"}
108+
}
109+
}`,
110+
expectError: false,
111+
validateFunc: func(t *testing.T, wf *wfapi.Workflow) {
112+
assert.NotNil(t, wf.Spec.PodMetadata)
113+
// Should merge with existing labels
114+
assert.Equal(t, "existing-value", wf.Spec.PodMetadata.Labels["existing-label"])
115+
assert.Equal(t, "new-value", wf.Spec.PodMetadata.Labels["new-label"])
116+
},
117+
},
118+
{
119+
name: "invalid JSON should return error",
120+
patchJSON: `{"invalid": json}`,
121+
expectError: true,
122+
validateFunc: func(t *testing.T, wf *wfapi.Workflow) {
123+
// Should remain unchanged
124+
assert.Equal(t, "original-sa", wf.Spec.ServiceAccountName)
125+
},
126+
},
127+
{
128+
name: "malformed JSON should return error",
129+
patchJSON: `{incomplete`,
130+
expectError: true,
131+
validateFunc: func(t *testing.T, wf *wfapi.Workflow) {
132+
// Should remain unchanged
133+
assert.Equal(t, "original-sa", wf.Spec.ServiceAccountName)
134+
},
135+
},
136+
}
137+
138+
for _, tt := range tests {
139+
t.Run(tt.name, func(t *testing.T) {
140+
// Create a test workflow with some initial values
141+
wf := &wfapi.Workflow{
142+
ObjectMeta: metav1.ObjectMeta{
143+
Name: "test-workflow",
144+
},
145+
Spec: wfapi.WorkflowSpec{
146+
ServiceAccountName: "original-sa",
147+
PodMetadata: &wfapi.Metadata{
148+
Labels: map[string]string{
149+
"existing-label": "existing-value",
150+
},
151+
},
152+
},
153+
}
154+
155+
// Create a workflow compiler with the test workflow
156+
compiler := &workflowCompiler{
157+
wf: wf,
158+
}
159+
160+
// Apply the patch
161+
err := compiler.ApplyWorkflowSpecPatch(tt.patchJSON)
162+
163+
// Check error expectation
164+
if tt.expectError {
165+
assert.Error(t, err)
166+
} else {
167+
assert.NoError(t, err)
168+
}
169+
170+
// Run validation function
171+
if tt.validateFunc != nil {
172+
tt.validateFunc(t, wf)
173+
}
174+
})
175+
}
176+
}
177+
178+
func TestApplyWorkflowSpecPatch_NilWorkflow(t *testing.T) {
179+
compiler := &workflowCompiler{
180+
wf: nil,
181+
}
182+
183+
err := compiler.ApplyWorkflowSpecPatch(`{"serviceAccountName": "test"}`)
184+
assert.Error(t, err)
185+
assert.Contains(t, err.Error(), "workflow is nil")
186+
}
187+
188+
func TestApplyWorkflowSpecPatch_ComplexPatch(t *testing.T) {
189+
// Test a more complex patch with various field types
190+
wf := &wfapi.Workflow{
191+
ObjectMeta: metav1.ObjectMeta{
192+
Name: "complex-test-workflow",
193+
},
194+
Spec: wfapi.WorkflowSpec{
195+
ServiceAccountName: "default",
196+
},
197+
}
198+
199+
compiler := &workflowCompiler{
200+
wf: wf,
201+
}
202+
203+
complexPatch := `{
204+
"serviceAccountName": "complex-sa",
205+
"activeDeadlineSeconds": 3600,
206+
"parallelism": 5,
207+
"nodeSelector": {
208+
"kubernetes.io/arch": "amd64",
209+
"node-pool": "compute"
210+
},
211+
"tolerations": [
212+
{
213+
"key": "dedicated",
214+
"operator": "Equal",
215+
"value": "ml-workload",
216+
"effect": "NoSchedule"
217+
}
218+
],
219+
"securityContext": {
220+
"runAsUser": 1000,
221+
"runAsGroup": 1000,
222+
"fsGroup": 1000
223+
},
224+
"hostNetwork": false,
225+
"dnsPolicy": "ClusterFirst"
226+
}`
227+
228+
err := compiler.ApplyWorkflowSpecPatch(complexPatch)
229+
assert.NoError(t, err)
230+
231+
// Validate all patched fields
232+
assert.Equal(t, "complex-sa", wf.Spec.ServiceAccountName)
233+
assert.Equal(t, int64(3600), *wf.Spec.ActiveDeadlineSeconds)
234+
assert.Equal(t, int64(5), *wf.Spec.Parallelism)
235+
assert.Equal(t, "amd64", wf.Spec.NodeSelector["kubernetes.io/arch"])
236+
assert.Equal(t, "compute", wf.Spec.NodeSelector["node-pool"])
237+
assert.Len(t, wf.Spec.Tolerations, 1)
238+
assert.Equal(t, "dedicated", wf.Spec.Tolerations[0].Key)
239+
assert.Equal(t, corev1.TolerationOperator("Equal"), wf.Spec.Tolerations[0].Operator)
240+
assert.Equal(t, "ml-workload", wf.Spec.Tolerations[0].Value)
241+
assert.Equal(t, corev1.TaintEffectNoSchedule, wf.Spec.Tolerations[0].Effect)
242+
assert.NotNil(t, wf.Spec.SecurityContext)
243+
assert.Equal(t, int64(1000), *wf.Spec.SecurityContext.RunAsUser)
244+
assert.Equal(t, int64(1000), *wf.Spec.SecurityContext.RunAsGroup)
245+
assert.Equal(t, int64(1000), *wf.Spec.SecurityContext.FSGroup)
246+
assert.Equal(t, false, *wf.Spec.HostNetwork)
247+
assert.Equal(t, corev1.DNSClusterFirst, *wf.Spec.DNSPolicy)
248+
}

manifests/kustomize/base/pipeline/ml-pipeline-apiserver-deployment.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,9 @@ spec:
126126
value: ghcr.io/kubeflow/kfp-driver:2.14.3
127127
- name: V2_LAUNCHER_IMAGE
128128
value: ghcr.io/kubeflow/kfp-launcher:2.14.3
129+
# JSON patch to apply to compiled workflow specifications
130+
- name: COMPILED_PIPELINE_SPEC_PATCH
131+
value: "{}"
129132
image: ghcr.io/kubeflow/kfp-api-server:dummy
130133
imagePullPolicy: IfNotPresent
131134
name: ml-pipeline-api-server

0 commit comments

Comments
 (0)