@@ -11,6 +11,7 @@ import (
1111 "github.com/NexusGPU/tensor-fusion/internal/constants"
1212 "github.com/NexusGPU/tensor-fusion/internal/gpuallocator"
1313 "github.com/NexusGPU/tensor-fusion/internal/gpuallocator/filter"
14+ "github.com/NexusGPU/tensor-fusion/internal/utils"
1415 "github.com/samber/lo/mutable"
1516 corev1 "k8s.io/api/core/v1"
1617 errors "k8s.io/apimachinery/pkg/api/errors"
@@ -277,12 +278,15 @@ func (e *NodeExpander) simulateSchedulingWithoutGPU(ctx context.Context, pod *co
277278 return nil , fmt .Errorf ("pod labels is nil, pod: %s" , pod .Name )
278279 }
279280
280- // Disable the tensor fusion label to simulate scheduling without GPU plugins
281+ // Disable the tensor fusion component label to simulate scheduling without GPU plugins
281282 // NOTE: must apply patch after `go mod vendor`, FindNodesThatFitPod is not exported from Kubernetes
282283 // Run `git apply ./patches/scheduler-sched-one.patch` once or `bash scripts/patch-scheduler.sh`
283- pod .Labels [constants .TensorFusionEnabledLabelKey ] = constants .FalseStringValue
284+ if ! utils .IsTensorFusionPod (pod ) {
285+ return nil , fmt .Errorf ("pod to check expansion is not a tensor fusion worker pod: %s" , pod .Name )
286+ }
287+ delete (pod .Labels , constants .LabelComponent )
284288 scheduleResult , _ , err := e .scheduler .FindNodesThatFitPod (ctx , fwkInstance , state , pod )
285- pod .Labels [constants .TensorFusionEnabledLabelKey ] = constants .TrueStringValue
289+ pod .Labels [constants .LabelComponent ] = constants .ComponentWorker
286290 if len (scheduleResult ) == 0 {
287291 return nil , err
288292 }
@@ -382,32 +386,34 @@ func (e *NodeExpander) checkGPUFitForNewNode(pod *corev1.Pod, gpus []*tfv1.GPU)
382386
383387func (e * NodeExpander ) createGPUNodeClaim (ctx context.Context , pod * corev1.Pod , preparedNode * corev1.Node ) error {
384388 owners := preparedNode .GetOwnerReferences ()
389+ isKarpenterNodeClaim := false
390+ isGPUNodeClaim := false
385391 controlledBy := & metav1.OwnerReference {}
386392 for _ , owner := range owners {
387- if owner .Controller != nil && * owner .Controller {
388- controlledBy = & owner
393+ controlledBy = & owner
394+ // Karpenter owner reference is not controller reference
395+ if owner .Kind == constants .KarpenterNodeClaimKind {
396+ isKarpenterNodeClaim = true
397+ break
398+ } else if owner .Kind == tfv1 .GPUNodeClaimKind {
399+ isGPUNodeClaim = true
389400 break
390401 }
391402 }
392- if controlledBy . Kind == "" {
393- e .logger .Info ("node is not owned by any provisioner, skip expansion" , "node" , preparedNode .Name )
403+ if ! isKarpenterNodeClaim && ! isGPUNodeClaim {
404+ e .logger .Info ("node is not owned by any known provisioner, skip expansion" , "node" , preparedNode .Name )
394405 return nil
395406 }
396407 e .logger .Info ("start expanding node from existing template node" , "tmplNode" , preparedNode .Name )
397-
398- switch controlledBy .Kind {
399- case constants .KarpenterNodeClaimKind :
408+ if isKarpenterNodeClaim {
400409 // Check if controllerMeta's parent is GPUNodeClaim using unstructured object
401410 return e .handleKarpenterNodeClaim (ctx , pod , preparedNode , controlledBy )
402- case tfv1 . GPUNodeClaimKind :
411+ } else if isGPUNodeClaim {
403412 // Running in Provisioning mode, clone the parent GPUNodeClaim and apply
404413 e .logger .Info ("node is controlled by GPUNodeClaim, cloning another to expand node" , "tmplNode" , preparedNode .Name )
405414 return e .cloneGPUNodeClaim (ctx , pod , preparedNode , controlledBy )
406- default :
407- e .logger .Info ("node is not controlled by any known provisioner, skip expansion" , "tmplNode" , preparedNode .Name ,
408- "controller" , controlledBy .Kind )
409- return nil
410415 }
416+ return nil
411417}
412418
413419// handleKarpenterNodeClaim handles the case where the controller is a Karpenter NodeClaim
@@ -424,8 +430,12 @@ func (e *NodeExpander) handleKarpenterNodeClaim(ctx context.Context, pod *corev1
424430 // Check if the NodeClaim has owner references
425431 ownerRefs := nodeClaim .GetOwnerReferences ()
426432 var nodeClaimParent * metav1.OwnerReference
433+ hasNodePoolParent := false
427434
428435 for _ , owner := range ownerRefs {
436+ if owner .Kind == constants .KarpenterNodePoolKind {
437+ hasNodePoolParent = true
438+ }
429439 if owner .Controller != nil && * owner .Controller {
430440 nodeClaimParent = & owner
431441 break
@@ -437,13 +447,13 @@ func (e *NodeExpander) handleKarpenterNodeClaim(ctx context.Context, pod *corev1
437447 e .logger .Info ("NodeClaim parent is GPUNodeClaim, cloning another to expand node" ,
438448 "nodeClaimName" , controlledBy .Name , "gpuNodeClaimParent" , nodeClaimParent .Name )
439449 return e .cloneGPUNodeClaim (ctx , pod , preparedNode , nodeClaimParent )
440- } else if nodeClaimParent != nil {
441- // No GPUNodeClaim parent , create karpenter NodeClaim directly with special label identifier
450+ } else if hasNodePoolParent {
451+ // owned by Karpenter node pool , create NodeClaim directly with special label identifier
442452 e .logger .Info ("NodeClaim owned by Karpenter Pool, creating Karpenter NodeClaim to expand node" ,
443453 "nodeClaimName" , controlledBy .Name )
444454 return e .createKarpenterNodeClaimDirect (ctx , pod , preparedNode , nodeClaim )
445455 } else {
446- return fmt .Errorf ("NodeClaim has no parent, can not expand node, should not happen" )
456+ return fmt .Errorf ("NodeClaim has no valid parent, can not expand node, should not happen" )
447457 }
448458}
449459
0 commit comments