Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions backend/src/apiserver/common/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ const (
MetadataTLSEnabled string = "METADATA_TLS_ENABLED"
CaBundleSecretName string = "CABUNDLE_SECRET_NAME"
RequireNamespaceForPipelines string = "REQUIRE_NAMESPACE_FOR_PIPELINES"
CompiledPipelineSpecPatch string = "COMPILED_PIPELINE_SPEC_PATCH"
)

func IsPipelineVersionUpdatedByDefault() bool {
Expand Down Expand Up @@ -142,3 +143,7 @@ func GetMetadataTLSEnabled() bool {
func GetCaBundleSecretName() string {
return GetStringConfigWithDefault(CaBundleSecretName, "")
}

func GetCompiledPipelineSpecPatch() string {
return GetStringConfigWithDefault(CompiledPipelineSpecPatch, "{}")
}
16 changes: 12 additions & 4 deletions backend/src/v2/compiler/argocompiler/argo.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ import (
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/structpb"
k8score "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource"
k8sres "k8s.io/apimachinery/pkg/api/resource"
k8smeta "k8s.io/apimachinery/pkg/apis/meta/v1"
)
Expand Down Expand Up @@ -185,8 +184,17 @@ func Compile(jobArg *pipelinespec.PipelineJob, kubernetesSpecArg *pipelinespec.S

// compile
err = compiler.Accept(job, kubernetesSpec, c)
if err != nil {
return nil, err
}

// Apply any workflow spec patches from environment variable
patchJSON := common.GetCompiledPipelineSpecPatch()
if err := c.ApplyWorkflowSpecPatch(patchJSON); err != nil {
return nil, fmt.Errorf("failed to apply workflow spec patch: %w", err)
}

return c.wf, err
return c.wf, nil
}

func retrieveLastValidString(s string) string {
Expand Down Expand Up @@ -534,15 +542,15 @@ func GetWorkspacePVC(
}
}

quantity, err := resource.ParseQuantity(sizeStr)
quantity, err := k8sres.ParseQuantity(sizeStr)
if err != nil {
return k8score.PersistentVolumeClaim{}, fmt.Errorf("invalid size value for workspace PVC: %v", err)
}
if quantity.Sign() < 0 {
return k8score.PersistentVolumeClaim{}, fmt.Errorf("negative size value for workspace PVC: %v", sizeStr)
}
if pvcSpec.Resources.Requests == nil {
pvcSpec.Resources.Requests = make(map[k8score.ResourceName]resource.Quantity)
pvcSpec.Resources.Requests = make(map[k8score.ResourceName]k8sres.Quantity)
}
pvcSpec.Resources.Requests[k8score.ResourceStorage] = quantity

Expand Down
55 changes: 55 additions & 0 deletions backend/src/v2/compiler/argocompiler/spec_patch.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,13 @@
package argocompiler

import (
"encoding/json"
"fmt"

wfapi "github.com/argoproj/argo-workflows/v3/pkg/apis/workflow/v1alpha1"
log "github.com/sirupsen/logrus"
"google.golang.org/protobuf/types/known/structpb"
"k8s.io/apimachinery/pkg/util/strategicpatch"
)

func (c *workflowCompiler) AddKubernetesSpec(name string, kubernetesSpec *structpb.Struct) error {
Expand All @@ -25,3 +31,52 @@ func (c *workflowCompiler) AddKubernetesSpec(name string, kubernetesSpec *struct
}
return nil
}

// ApplyWorkflowSpecPatch applies a JSON patch to the compiled workflow specification.
// It validates the JSON and applies it using Kubernetes strategic merge patch.
// Only the workflow's "spec" field can be patched for security reasons.
func (c *workflowCompiler) ApplyWorkflowSpecPatch(patchJSON string) error {
if c.wf == nil {
return fmt.Errorf("workflow is nil")
}

// Check for empty patch string
if patchJSON == "" {
log.Debug("Empty workflow spec patch string provided, skipping patching")
return nil
}

log.Debug("Applying workflow spec patch")

// Validate that the patch is valid JSON by attempting to unmarshal it
var specPatchValidation map[string]interface{}
if err := json.Unmarshal([]byte(patchJSON), &specPatchValidation); err != nil {
return fmt.Errorf("invalid JSON in COMPILED_PIPELINE_SPEC_PATCH: %w", err)
}

// Check if the patch is empty (no fields to patch)
if len(specPatchValidation) == 0 {
log.Debug("Empty workflow spec patch provided, skipping patching")
return nil
}

// Convert the current workflow spec to JSON
originalSpecJSON, err := json.Marshal(c.wf.Spec)
if err != nil {
return fmt.Errorf("failed to marshal workflow spec to JSON: %w", err)
}

// Apply the strategic merge patch to the spec directly
patchedSpecJSON, err := strategicpatch.StrategicMergePatch(originalSpecJSON, []byte(patchJSON), wfapi.WorkflowSpec{})
if err != nil {
return fmt.Errorf("failed to apply strategic merge patch to workflow spec: %w", err)
}

// Unmarshal the patched spec back into the workflow
if err := json.Unmarshal(patchedSpecJSON, &c.wf.Spec); err != nil {
return fmt.Errorf("failed to unmarshal patched workflow spec: %w", err)
}

log.Debug("Successfully applied workflow spec patch")
return nil
}
248 changes: 248 additions & 0 deletions backend/src/v2/compiler/argocompiler/spec_patch_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,248 @@
// Copyright 2025 The Kubeflow Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package argocompiler

import (
"testing"

wfapi "github.com/argoproj/argo-workflows/v3/pkg/apis/workflow/v1alpha1"
"github.com/stretchr/testify/assert"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)

func TestApplyWorkflowSpecPatch(t *testing.T) {
tests := []struct {
name string
patchJSON string
expectError bool
validateFunc func(t *testing.T, wf *wfapi.Workflow)
}{
{
name: "empty patch should skip patching",
patchJSON: "{}",
expectError: false,
validateFunc: func(t *testing.T, wf *wfapi.Workflow) {
// Should remain unchanged
assert.Equal(t, "original-sa", wf.Spec.ServiceAccountName)
},
},
{
name: "empty string should skip patching",
patchJSON: "",
expectError: false,
validateFunc: func(t *testing.T, wf *wfapi.Workflow) {
// Should remain unchanged
assert.Equal(t, "original-sa", wf.Spec.ServiceAccountName)
},
},
{
name: "patch service account name",
patchJSON: `{"serviceAccountName": "custom-sa"}`,
expectError: false,
validateFunc: func(t *testing.T, wf *wfapi.Workflow) {
assert.Equal(t, "custom-sa", wf.Spec.ServiceAccountName)
},
},
{
name: "patch node selector",
patchJSON: `{"nodeSelector": {"node-type": "gpu", "zone": "us-west1"}}`,
expectError: false,
validateFunc: func(t *testing.T, wf *wfapi.Workflow) {
assert.Equal(t, "gpu", wf.Spec.NodeSelector["node-type"])
assert.Equal(t, "us-west1", wf.Spec.NodeSelector["zone"])
},
},
{
name: "patch multiple fields",
patchJSON: `{
"serviceAccountName": "gpu-runner",
"nodeSelector": {"accelerator": "nvidia-tesla-k80"},
"tolerations": [
{"key": "nvidia.com/gpu", "operator": "Exists", "effect": "NoSchedule"}
]
}`,
expectError: false,
validateFunc: func(t *testing.T, wf *wfapi.Workflow) {
assert.Equal(t, "gpu-runner", wf.Spec.ServiceAccountName)
assert.Equal(t, "nvidia-tesla-k80", wf.Spec.NodeSelector["accelerator"])
assert.Len(t, wf.Spec.Tolerations, 1)
assert.Equal(t, "nvidia.com/gpu", wf.Spec.Tolerations[0].Key)
assert.Equal(t, corev1.TolerationOperator("Exists"), wf.Spec.Tolerations[0].Operator)
assert.Equal(t, corev1.TaintEffectNoSchedule, wf.Spec.Tolerations[0].Effect)
},
},
{
name: "patch pod metadata",
patchJSON: `{
"podMetadata": {
"labels": {"env": "prod", "team": "ml"},
"annotations": {"monitoring": "enabled"}
}
}`,
expectError: false,
validateFunc: func(t *testing.T, wf *wfapi.Workflow) {
assert.NotNil(t, wf.Spec.PodMetadata)
assert.Equal(t, "prod", wf.Spec.PodMetadata.Labels["env"])
assert.Equal(t, "ml", wf.Spec.PodMetadata.Labels["team"])
assert.Equal(t, "enabled", wf.Spec.PodMetadata.Annotations["monitoring"])
},
},
{
name: "merge with existing pod metadata",
patchJSON: `{
"podMetadata": {
"labels": {"new-label": "new-value"}
}
}`,
expectError: false,
validateFunc: func(t *testing.T, wf *wfapi.Workflow) {
assert.NotNil(t, wf.Spec.PodMetadata)
// Should merge with existing labels
assert.Equal(t, "existing-value", wf.Spec.PodMetadata.Labels["existing-label"])
assert.Equal(t, "new-value", wf.Spec.PodMetadata.Labels["new-label"])
},
},
{
name: "invalid JSON should return error",
patchJSON: `{"invalid": json}`,
expectError: true,
validateFunc: func(t *testing.T, wf *wfapi.Workflow) {
// Should remain unchanged
assert.Equal(t, "original-sa", wf.Spec.ServiceAccountName)
},
},
{
name: "malformed JSON should return error",
patchJSON: `{incomplete`,
expectError: true,
validateFunc: func(t *testing.T, wf *wfapi.Workflow) {
// Should remain unchanged
assert.Equal(t, "original-sa", wf.Spec.ServiceAccountName)
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create a test workflow with some initial values
wf := &wfapi.Workflow{
ObjectMeta: metav1.ObjectMeta{
Name: "test-workflow",
},
Spec: wfapi.WorkflowSpec{
ServiceAccountName: "original-sa",
PodMetadata: &wfapi.Metadata{
Labels: map[string]string{
"existing-label": "existing-value",
},
},
},
}

// Create a workflow compiler with the test workflow
compiler := &workflowCompiler{
wf: wf,
}

// Apply the patch
err := compiler.ApplyWorkflowSpecPatch(tt.patchJSON)

// Check error expectation
if tt.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}

// Run validation function
if tt.validateFunc != nil {
tt.validateFunc(t, wf)
}
})
}
}

func TestApplyWorkflowSpecPatch_NilWorkflow(t *testing.T) {
compiler := &workflowCompiler{
wf: nil,
}

err := compiler.ApplyWorkflowSpecPatch(`{"serviceAccountName": "test"}`)
assert.Error(t, err)
assert.Contains(t, err.Error(), "workflow is nil")
}

func TestApplyWorkflowSpecPatch_ComplexPatch(t *testing.T) {
// Test a more complex patch with various field types
wf := &wfapi.Workflow{
ObjectMeta: metav1.ObjectMeta{
Name: "complex-test-workflow",
},
Spec: wfapi.WorkflowSpec{
ServiceAccountName: "default",
},
}

compiler := &workflowCompiler{
wf: wf,
}

complexPatch := `{
"serviceAccountName": "complex-sa",
"activeDeadlineSeconds": 3600,
"parallelism": 5,
"nodeSelector": {
"kubernetes.io/arch": "amd64",
"node-pool": "compute"
},
"tolerations": [
{
"key": "dedicated",
"operator": "Equal",
"value": "ml-workload",
"effect": "NoSchedule"
}
],
"securityContext": {
"runAsUser": 1000,
"runAsGroup": 1000,
"fsGroup": 1000
},
"hostNetwork": false,
"dnsPolicy": "ClusterFirst"
}`

err := compiler.ApplyWorkflowSpecPatch(complexPatch)
assert.NoError(t, err)

// Validate all patched fields
assert.Equal(t, "complex-sa", wf.Spec.ServiceAccountName)
assert.Equal(t, int64(3600), *wf.Spec.ActiveDeadlineSeconds)
assert.Equal(t, int64(5), *wf.Spec.Parallelism)
assert.Equal(t, "amd64", wf.Spec.NodeSelector["kubernetes.io/arch"])
assert.Equal(t, "compute", wf.Spec.NodeSelector["node-pool"])
assert.Len(t, wf.Spec.Tolerations, 1)
assert.Equal(t, "dedicated", wf.Spec.Tolerations[0].Key)
assert.Equal(t, corev1.TolerationOperator("Equal"), wf.Spec.Tolerations[0].Operator)
assert.Equal(t, "ml-workload", wf.Spec.Tolerations[0].Value)
assert.Equal(t, corev1.TaintEffectNoSchedule, wf.Spec.Tolerations[0].Effect)
assert.NotNil(t, wf.Spec.SecurityContext)
assert.Equal(t, int64(1000), *wf.Spec.SecurityContext.RunAsUser)
assert.Equal(t, int64(1000), *wf.Spec.SecurityContext.RunAsGroup)
assert.Equal(t, int64(1000), *wf.Spec.SecurityContext.FSGroup)
assert.Equal(t, false, *wf.Spec.HostNetwork)
assert.Equal(t, corev1.DNSClusterFirst, *wf.Spec.DNSPolicy)
}
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@ spec:
value: ghcr.io/kubeflow/kfp-driver:2.14.3
- name: V2_LAUNCHER_IMAGE
value: ghcr.io/kubeflow/kfp-launcher:2.14.3
# JSON patch to apply to compiled workflow specifications
- name: COMPILED_PIPELINE_SPEC_PATCH
value: "{}"
image: ghcr.io/kubeflow/kfp-api-server:dummy
imagePullPolicy: IfNotPresent
name: ml-pipeline-api-server
Expand Down
Loading