diff --git a/backend/src/apiserver/common/config.go b/backend/src/apiserver/common/config.go index 273b16790c1..05aa2d5310c 100644 --- a/backend/src/apiserver/common/config.go +++ b/backend/src/apiserver/common/config.go @@ -35,6 +35,7 @@ const ( MetadataTLSEnabled string = "METADATA_TLS_ENABLED" CaBundleSecretName string = "CABUNDLE_SECRET_NAME" RequireNamespaceForPipelines string = "REQUIRE_NAMESPACE_FOR_PIPELINES" + CompiledPipelineSpecPatch string = "COMPILED_PIPELINE_SPEC_PATCH" ) func IsPipelineVersionUpdatedByDefault() bool { @@ -142,3 +143,7 @@ func GetMetadataTLSEnabled() bool { func GetCaBundleSecretName() string { return GetStringConfigWithDefault(CaBundleSecretName, "") } + +func GetCompiledPipelineSpecPatch() string { + return GetStringConfigWithDefault(CompiledPipelineSpecPatch, "{}") +} diff --git a/backend/src/v2/compiler/argocompiler/argo.go b/backend/src/v2/compiler/argocompiler/argo.go index 2ed9e135fac..207feaa9236 100644 --- a/backend/src/v2/compiler/argocompiler/argo.go +++ b/backend/src/v2/compiler/argocompiler/argo.go @@ -31,7 +31,6 @@ import ( "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/structpb" k8score "k8s.io/api/core/v1" - "k8s.io/apimachinery/pkg/api/resource" k8sres "k8s.io/apimachinery/pkg/api/resource" k8smeta "k8s.io/apimachinery/pkg/apis/meta/v1" ) @@ -185,8 +184,17 @@ func Compile(jobArg *pipelinespec.PipelineJob, kubernetesSpecArg *pipelinespec.S // compile err = compiler.Accept(job, kubernetesSpec, c) + if err != nil { + return nil, err + } + + // Apply any workflow spec patches from environment variable + patchJSON := common.GetCompiledPipelineSpecPatch() + if err := c.ApplyWorkflowSpecPatch(patchJSON); err != nil { + return nil, fmt.Errorf("failed to apply workflow spec patch: %w", err) + } - return c.wf, err + return c.wf, nil } func retrieveLastValidString(s string) string { @@ -534,7 +542,7 @@ func GetWorkspacePVC( } } - quantity, err := resource.ParseQuantity(sizeStr) + quantity, err := k8sres.ParseQuantity(sizeStr) if err != nil { return k8score.PersistentVolumeClaim{}, fmt.Errorf("invalid size value for workspace PVC: %v", err) } @@ -542,7 +550,7 @@ func GetWorkspacePVC( return k8score.PersistentVolumeClaim{}, fmt.Errorf("negative size value for workspace PVC: %v", sizeStr) } if pvcSpec.Resources.Requests == nil { - pvcSpec.Resources.Requests = make(map[k8score.ResourceName]resource.Quantity) + pvcSpec.Resources.Requests = make(map[k8score.ResourceName]k8sres.Quantity) } pvcSpec.Resources.Requests[k8score.ResourceStorage] = quantity diff --git a/backend/src/v2/compiler/argocompiler/spec_patch.go b/backend/src/v2/compiler/argocompiler/spec_patch.go index 06be1963ea8..af1c79f2c75 100644 --- a/backend/src/v2/compiler/argocompiler/spec_patch.go +++ b/backend/src/v2/compiler/argocompiler/spec_patch.go @@ -15,7 +15,13 @@ package argocompiler import ( + "encoding/json" + "fmt" + + wfapi "github.com/argoproj/argo-workflows/v3/pkg/apis/workflow/v1alpha1" + log "github.com/sirupsen/logrus" "google.golang.org/protobuf/types/known/structpb" + "k8s.io/apimachinery/pkg/util/strategicpatch" ) func (c *workflowCompiler) AddKubernetesSpec(name string, kubernetesSpec *structpb.Struct) error { @@ -25,3 +31,52 @@ func (c *workflowCompiler) AddKubernetesSpec(name string, kubernetesSpec *struct } return nil } + +// ApplyWorkflowSpecPatch applies a JSON patch to the compiled workflow specification. +// It validates the JSON and applies it using Kubernetes strategic merge patch. +// Only the workflow's "spec" field can be patched for security reasons. +func (c *workflowCompiler) ApplyWorkflowSpecPatch(patchJSON string) error { + if c.wf == nil { + return fmt.Errorf("workflow is nil") + } + + // Check for empty patch string + if patchJSON == "" { + log.Debug("Empty workflow spec patch string provided, skipping patching") + return nil + } + + log.Debug("Applying workflow spec patch") + + // Validate that the patch is valid JSON by attempting to unmarshal it + var specPatchValidation map[string]interface{} + if err := json.Unmarshal([]byte(patchJSON), &specPatchValidation); err != nil { + return fmt.Errorf("invalid JSON in COMPILED_PIPELINE_SPEC_PATCH: %w", err) + } + + // Check if the patch is empty (no fields to patch) + if len(specPatchValidation) == 0 { + log.Debug("Empty workflow spec patch provided, skipping patching") + return nil + } + + // Convert the current workflow spec to JSON + originalSpecJSON, err := json.Marshal(c.wf.Spec) + if err != nil { + return fmt.Errorf("failed to marshal workflow spec to JSON: %w", err) + } + + // Apply the strategic merge patch to the spec directly + patchedSpecJSON, err := strategicpatch.StrategicMergePatch(originalSpecJSON, []byte(patchJSON), wfapi.WorkflowSpec{}) + if err != nil { + return fmt.Errorf("failed to apply strategic merge patch to workflow spec: %w", err) + } + + // Unmarshal the patched spec back into the workflow + if err := json.Unmarshal(patchedSpecJSON, &c.wf.Spec); err != nil { + return fmt.Errorf("failed to unmarshal patched workflow spec: %w", err) + } + + log.Debug("Successfully applied workflow spec patch") + return nil +} diff --git a/backend/src/v2/compiler/argocompiler/spec_patch_test.go b/backend/src/v2/compiler/argocompiler/spec_patch_test.go new file mode 100644 index 00000000000..836a0cb2399 --- /dev/null +++ b/backend/src/v2/compiler/argocompiler/spec_patch_test.go @@ -0,0 +1,248 @@ +// Copyright 2025 The Kubeflow Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package argocompiler + +import ( + "testing" + + wfapi "github.com/argoproj/argo-workflows/v3/pkg/apis/workflow/v1alpha1" + "github.com/stretchr/testify/assert" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +func TestApplyWorkflowSpecPatch(t *testing.T) { + tests := []struct { + name string + patchJSON string + expectError bool + validateFunc func(t *testing.T, wf *wfapi.Workflow) + }{ + { + name: "empty patch should skip patching", + patchJSON: "{}", + expectError: false, + validateFunc: func(t *testing.T, wf *wfapi.Workflow) { + // Should remain unchanged + assert.Equal(t, "original-sa", wf.Spec.ServiceAccountName) + }, + }, + { + name: "empty string should skip patching", + patchJSON: "", + expectError: false, + validateFunc: func(t *testing.T, wf *wfapi.Workflow) { + // Should remain unchanged + assert.Equal(t, "original-sa", wf.Spec.ServiceAccountName) + }, + }, + { + name: "patch service account name", + patchJSON: `{"serviceAccountName": "custom-sa"}`, + expectError: false, + validateFunc: func(t *testing.T, wf *wfapi.Workflow) { + assert.Equal(t, "custom-sa", wf.Spec.ServiceAccountName) + }, + }, + { + name: "patch node selector", + patchJSON: `{"nodeSelector": {"node-type": "gpu", "zone": "us-west1"}}`, + expectError: false, + validateFunc: func(t *testing.T, wf *wfapi.Workflow) { + assert.Equal(t, "gpu", wf.Spec.NodeSelector["node-type"]) + assert.Equal(t, "us-west1", wf.Spec.NodeSelector["zone"]) + }, + }, + { + name: "patch multiple fields", + patchJSON: `{ + "serviceAccountName": "gpu-runner", + "nodeSelector": {"accelerator": "nvidia-tesla-k80"}, + "tolerations": [ + {"key": "nvidia.com/gpu", "operator": "Exists", "effect": "NoSchedule"} + ] + }`, + expectError: false, + validateFunc: func(t *testing.T, wf *wfapi.Workflow) { + assert.Equal(t, "gpu-runner", wf.Spec.ServiceAccountName) + assert.Equal(t, "nvidia-tesla-k80", wf.Spec.NodeSelector["accelerator"]) + assert.Len(t, wf.Spec.Tolerations, 1) + assert.Equal(t, "nvidia.com/gpu", wf.Spec.Tolerations[0].Key) + assert.Equal(t, corev1.TolerationOperator("Exists"), wf.Spec.Tolerations[0].Operator) + assert.Equal(t, corev1.TaintEffectNoSchedule, wf.Spec.Tolerations[0].Effect) + }, + }, + { + name: "patch pod metadata", + patchJSON: `{ + "podMetadata": { + "labels": {"env": "prod", "team": "ml"}, + "annotations": {"monitoring": "enabled"} + } + }`, + expectError: false, + validateFunc: func(t *testing.T, wf *wfapi.Workflow) { + assert.NotNil(t, wf.Spec.PodMetadata) + assert.Equal(t, "prod", wf.Spec.PodMetadata.Labels["env"]) + assert.Equal(t, "ml", wf.Spec.PodMetadata.Labels["team"]) + assert.Equal(t, "enabled", wf.Spec.PodMetadata.Annotations["monitoring"]) + }, + }, + { + name: "merge with existing pod metadata", + patchJSON: `{ + "podMetadata": { + "labels": {"new-label": "new-value"} + } + }`, + expectError: false, + validateFunc: func(t *testing.T, wf *wfapi.Workflow) { + assert.NotNil(t, wf.Spec.PodMetadata) + // Should merge with existing labels + assert.Equal(t, "existing-value", wf.Spec.PodMetadata.Labels["existing-label"]) + assert.Equal(t, "new-value", wf.Spec.PodMetadata.Labels["new-label"]) + }, + }, + { + name: "invalid JSON should return error", + patchJSON: `{"invalid": json}`, + expectError: true, + validateFunc: func(t *testing.T, wf *wfapi.Workflow) { + // Should remain unchanged + assert.Equal(t, "original-sa", wf.Spec.ServiceAccountName) + }, + }, + { + name: "malformed JSON should return error", + patchJSON: `{incomplete`, + expectError: true, + validateFunc: func(t *testing.T, wf *wfapi.Workflow) { + // Should remain unchanged + assert.Equal(t, "original-sa", wf.Spec.ServiceAccountName) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a test workflow with some initial values + wf := &wfapi.Workflow{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-workflow", + }, + Spec: wfapi.WorkflowSpec{ + ServiceAccountName: "original-sa", + PodMetadata: &wfapi.Metadata{ + Labels: map[string]string{ + "existing-label": "existing-value", + }, + }, + }, + } + + // Create a workflow compiler with the test workflow + compiler := &workflowCompiler{ + wf: wf, + } + + // Apply the patch + err := compiler.ApplyWorkflowSpecPatch(tt.patchJSON) + + // Check error expectation + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + + // Run validation function + if tt.validateFunc != nil { + tt.validateFunc(t, wf) + } + }) + } +} + +func TestApplyWorkflowSpecPatch_NilWorkflow(t *testing.T) { + compiler := &workflowCompiler{ + wf: nil, + } + + err := compiler.ApplyWorkflowSpecPatch(`{"serviceAccountName": "test"}`) + assert.Error(t, err) + assert.Contains(t, err.Error(), "workflow is nil") +} + +func TestApplyWorkflowSpecPatch_ComplexPatch(t *testing.T) { + // Test a more complex patch with various field types + wf := &wfapi.Workflow{ + ObjectMeta: metav1.ObjectMeta{ + Name: "complex-test-workflow", + }, + Spec: wfapi.WorkflowSpec{ + ServiceAccountName: "default", + }, + } + + compiler := &workflowCompiler{ + wf: wf, + } + + complexPatch := `{ + "serviceAccountName": "complex-sa", + "activeDeadlineSeconds": 3600, + "parallelism": 5, + "nodeSelector": { + "kubernetes.io/arch": "amd64", + "node-pool": "compute" + }, + "tolerations": [ + { + "key": "dedicated", + "operator": "Equal", + "value": "ml-workload", + "effect": "NoSchedule" + } + ], + "securityContext": { + "runAsUser": 1000, + "runAsGroup": 1000, + "fsGroup": 1000 + }, + "hostNetwork": false, + "dnsPolicy": "ClusterFirst" + }` + + err := compiler.ApplyWorkflowSpecPatch(complexPatch) + assert.NoError(t, err) + + // Validate all patched fields + assert.Equal(t, "complex-sa", wf.Spec.ServiceAccountName) + assert.Equal(t, int64(3600), *wf.Spec.ActiveDeadlineSeconds) + assert.Equal(t, int64(5), *wf.Spec.Parallelism) + assert.Equal(t, "amd64", wf.Spec.NodeSelector["kubernetes.io/arch"]) + assert.Equal(t, "compute", wf.Spec.NodeSelector["node-pool"]) + assert.Len(t, wf.Spec.Tolerations, 1) + assert.Equal(t, "dedicated", wf.Spec.Tolerations[0].Key) + assert.Equal(t, corev1.TolerationOperator("Equal"), wf.Spec.Tolerations[0].Operator) + assert.Equal(t, "ml-workload", wf.Spec.Tolerations[0].Value) + assert.Equal(t, corev1.TaintEffectNoSchedule, wf.Spec.Tolerations[0].Effect) + assert.NotNil(t, wf.Spec.SecurityContext) + assert.Equal(t, int64(1000), *wf.Spec.SecurityContext.RunAsUser) + assert.Equal(t, int64(1000), *wf.Spec.SecurityContext.RunAsGroup) + assert.Equal(t, int64(1000), *wf.Spec.SecurityContext.FSGroup) + assert.Equal(t, false, *wf.Spec.HostNetwork) + assert.Equal(t, corev1.DNSClusterFirst, *wf.Spec.DNSPolicy) +} diff --git a/manifests/kustomize/base/pipeline/ml-pipeline-apiserver-deployment.yaml b/manifests/kustomize/base/pipeline/ml-pipeline-apiserver-deployment.yaml index cd9f3dd2962..66787a81835 100644 --- a/manifests/kustomize/base/pipeline/ml-pipeline-apiserver-deployment.yaml +++ b/manifests/kustomize/base/pipeline/ml-pipeline-apiserver-deployment.yaml @@ -126,6 +126,9 @@ spec: value: ghcr.io/kubeflow/kfp-driver:2.14.3 - name: V2_LAUNCHER_IMAGE value: ghcr.io/kubeflow/kfp-launcher:2.14.3 + # JSON patch to apply to compiled workflow specifications + - name: COMPILED_PIPELINE_SPEC_PATCH + value: "{}" image: ghcr.io/kubeflow/kfp-api-server:dummy imagePullPolicy: IfNotPresent name: ml-pipeline-api-server