@@ -482,6 +482,86 @@ func extendPodSpecPatch(
482482 podSpec .Containers [0 ].VolumeMounts = append (podSpec .Containers [0 ].VolumeMounts , emptyDirVolumeMount )
483483 }
484484
485+ // Get node affinity information
486+ if nodeAffinityTerms := kubernetesExecutorConfig .GetNodeAffinity (); len (nodeAffinityTerms ) > 0 {
487+ var requiredTerms []k8score.NodeSelectorTerm
488+ var preferredTerms []k8score.PreferredSchedulingTerm
489+
490+ for i , nodeAffinityTerm := range nodeAffinityTerms {
491+ if nodeAffinityTerm .GetNodeAffinityJson () == nil &&
492+ len (nodeAffinityTerm .GetMatchExpressions ()) == 0 &&
493+ len (nodeAffinityTerm .GetMatchFields ()) == 0 {
494+ glog .Warningf ("NodeAffinityTerm %d is empty, skipping" , i )
495+ continue
496+ }
497+ if nodeAffinityTerm .GetNodeAffinityJson () != nil {
498+ var k8sNodeAffinity json.RawMessage
499+ err := resolveK8sJsonParameter (ctx , opts , dag , pipeline , mlmd ,
500+ nodeAffinityTerm .GetNodeAffinityJson (), inputParams , & k8sNodeAffinity )
501+ if err != nil {
502+ return fmt .Errorf ("failed to resolve node affinity json: %w" , err )
503+ }
504+
505+ var nodeAffinity k8score.NodeAffinity
506+ if err := json .Unmarshal (k8sNodeAffinity , & nodeAffinity ); err != nil {
507+ return fmt .Errorf ("failed to unmarshal node affinity json: %w" , err )
508+ }
509+
510+ if nodeAffinity .RequiredDuringSchedulingIgnoredDuringExecution != nil {
511+ requiredTerms = append (requiredTerms , nodeAffinity .RequiredDuringSchedulingIgnoredDuringExecution .NodeSelectorTerms ... )
512+ }
513+ preferredTerms = append (preferredTerms , nodeAffinity .PreferredDuringSchedulingIgnoredDuringExecution ... )
514+ } else {
515+ nodeSelectorTerm := k8score.NodeSelectorTerm {}
516+
517+ for _ , expr := range nodeAffinityTerm .GetMatchExpressions () {
518+ nodeSelectorRequirement := k8score.NodeSelectorRequirement {
519+ Key : expr .GetKey (),
520+ Operator : k8score .NodeSelectorOperator (expr .GetOperator ()),
521+ Values : expr .GetValues (),
522+ }
523+ nodeSelectorTerm .MatchExpressions = append (nodeSelectorTerm .MatchExpressions , nodeSelectorRequirement )
524+ }
525+
526+ for _ , field := range nodeAffinityTerm .GetMatchFields () {
527+ nodeSelectorRequirement := k8score.NodeSelectorRequirement {
528+ Key : field .GetKey (),
529+ Operator : k8score .NodeSelectorOperator (field .GetOperator ()),
530+ Values : field .GetValues (),
531+ }
532+ nodeSelectorTerm .MatchFields = append (nodeSelectorTerm .MatchFields , nodeSelectorRequirement )
533+ }
534+
535+ if nodeAffinityTerm .Weight != nil {
536+ preferredTerms = append (preferredTerms , k8score.PreferredSchedulingTerm {
537+ Weight : * nodeAffinityTerm .Weight ,
538+ Preference : nodeSelectorTerm ,
539+ })
540+ glog .V (4 ).Infof ("Added preferred node affinity: %+v" , nodeSelectorTerm )
541+ } else {
542+ requiredTerms = append (requiredTerms , nodeSelectorTerm )
543+ glog .V (4 ).Infof ("Added required node affinity: %+v" , nodeSelectorTerm )
544+ }
545+
546+ }
547+ }
548+
549+ if len (requiredTerms ) > 0 || len (preferredTerms ) > 0 {
550+ if podSpec .Affinity == nil {
551+ podSpec .Affinity = & k8score.Affinity {}
552+ }
553+ podSpec .Affinity .NodeAffinity = & k8score.NodeAffinity {}
554+ if len (requiredTerms ) > 0 {
555+ podSpec .Affinity .NodeAffinity .RequiredDuringSchedulingIgnoredDuringExecution = & k8score.NodeSelector {
556+ NodeSelectorTerms : requiredTerms ,
557+ }
558+ }
559+ if len (preferredTerms ) > 0 {
560+ podSpec .Affinity .NodeAffinity .PreferredDuringSchedulingIgnoredDuringExecution = preferredTerms
561+ }
562+ }
563+ }
564+
485565 return nil
486566}
487567
0 commit comments