Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -130,23 +130,15 @@ private static AssignmentPlan mergePlans(
return finalPlanBuilder.build();
}

private static void copyAssignments(
AssignmentPlan source,
AssignmentPlan.Builder dest,
Map<String, AssignmentPlan.Node> originalNodeById
) {
for (AssignmentPlan.Deployment m : source.deployments()) {
Map<AssignmentPlan.Node, Integer> nodeAssignments = source.assignments(m).orElse(Map.of());
for (Map.Entry<AssignmentPlan.Node, Integer> assignment : nodeAssignments.entrySet()) {
AssignmentPlan.Node originalNode = originalNodeById.get(assignment.getKey().id());
dest.assignModelToNode(m, originalNode, assignment.getValue());
if (m.currentAllocationsByNodeId().containsKey(originalNode.id())) {
// TODO (#101612) requiredMemory should be calculated by the AssignmentPlan.Builder
// As the node has all its available memory we need to manually account memory of models with
// current allocations.
long requiredMemory = m.estimateMemoryUsageBytes(m.currentAllocationsByNodeId().get(originalNode.id()));
dest.accountMemory(m, originalNode, requiredMemory);
}
/**
* Transfers assignments from the source AssignmentPlan to the destination AssignmentPlan.Builder.
*/
static void copyAssignments(AssignmentPlan source, AssignmentPlan.Builder dest, Map<String, AssignmentPlan.Node> originalNodeById) {
for (AssignmentPlan.Deployment deployment : source.deployments()) {
Map<AssignmentPlan.Node, Integer> sourceNodeAssignments = source.assignments(deployment).orElse(Map.of());
for (Map.Entry<AssignmentPlan.Node, Integer> sourceAssignment : sourceNodeAssignments.entrySet()) {
AssignmentPlan.Node node = originalNodeById.get(sourceAssignment.getKey().id());
dest.assignModelToNode(deployment, node, sourceAssignment.getValue());
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now, calling assignedModelToNode is enough to correctly account for allocated memory.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a unit test for the copyAssignments method in TrainedModelAssignmentRebalancerTests

}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,23 +223,43 @@ public int compareTo(AssignmentPlan o) {
return Comparator.comparing(AssignmentPlan::computeQuality).compare(this, o);
}

/**
* Checks whether all deployments in the current {@link AssignmentPlan} have at least as many
* allocations as currently assigned.
*/
public boolean satisfiesCurrentAssignments() {
return deployments().stream().allMatch(this::isSatisfyingCurrentAssignmentsForModel);
}

/**
* Checks whether the current assignments for a given {@link Deployment} meet its allocation requirements.
*
* It ensures that the total number of allocations assigned to the deployment across all nodes is
* at least equal to the deployment's current assigned allocations.
*/
private boolean isSatisfyingCurrentAssignmentsForModel(Deployment m) {
if (m.currentAllocationsByNodeId().isEmpty()) {
return true;
}
Map<Node, Integer> nodeAssignments = assignments.get(m);
int currentAllocations = nodeAssignments.values().stream().mapToInt(Integer::intValue).sum();
return currentAllocations >= m.getCurrentAssignedAllocations();
int inPlanAssignedAllocations = nodeAssignments.values().stream().mapToInt(Integer::intValue).sum();
return inPlanAssignedAllocations >= m.getCurrentAssignedAllocations();
}

public boolean satisfiesAllocations(Deployment m) {
return remainingModelAllocations.getOrDefault(m, 0) == 0;
/**
* Checks if the current assignments satisfy the deployment's allocation requirements.
* @param deployment the deployment to check
* @return true if the current assignments satisfy the deployment's allocation requirements, false otherwise
*/
public boolean satisfiesAllocations(Deployment deployment) {
return remainingModelAllocations.getOrDefault(deployment, 0) == 0;
}

/**
* Checks if the current assignments satisfy all deployments' allocation requirements. This means that
* each deployment has no remaining allocations left to assign.
* @return true if the current assignments satisfy the deployments' allocation requirements, false otherwise
*/
public boolean satisfiesAllModels() {
return deployments().stream().allMatch(this::satisfiesAllocations);
}
Expand Down Expand Up @@ -424,8 +444,7 @@ public Builder assignModelToNode(Deployment deployment, Node node, int allocatio
if (allocations <= 0) {
return this;
}
if (/*isAlreadyAssigned(deployment, node) == false
&&*/ requiredMemory > remainingNodeMemory.get(node)) {
if (requiredMemory > remainingNodeMemory.get(node)) {
throw new IllegalArgumentException(
"not enough memory on node ["
+ node.id()
Expand All @@ -450,7 +469,7 @@ public Builder assignModelToNode(Deployment deployment, Node node, int allocatio
);
}

assignments.get(deployment).compute(node, (n, remAllocations) -> remAllocations + allocations);
assignments.get(deployment).compute(node, (n, assignedAllocations) -> assignedAllocations + allocations);
accountMemory(deployment, node, requiredMemory);

if (deployment.priority == Priority.NORMAL) {
Expand All @@ -461,23 +480,10 @@ public Builder assignModelToNode(Deployment deployment, Node node, int allocatio
}

private int getAssignedAllocations(Deployment deployment, Node node) {
int currentAllocations = getCurrentAllocations(deployment, node);
int assignmentAllocations = assignments.get(deployment).get(node);
return currentAllocations + assignmentAllocations;
}

private static int getCurrentAllocations(Deployment m, Node n) {
return m.currentAllocationsByNodeId.containsKey(n.id()) ? m.currentAllocationsByNodeId.get(n.id()) : 0;
}

public void accountMemory(Deployment m, Node n) {
// TODO (#101612) remove or refactor unused method
long requiredMemory = getDeploymentMemoryRequirement(m, n, getCurrentAllocations(m, n));
accountMemory(m, n, requiredMemory);
return assignments.get(deployment).get(node);
}

public void accountMemory(Deployment m, Node n, long requiredMemory) {
// TODO (#101612) computation of required memory should be done internally
remainingNodeMemory.computeIfPresent(n, (k, v) -> v - requiredMemory);
if (remainingNodeMemory.containsKey(n) && remainingNodeMemory.get(n) < 0) {
throw new IllegalArgumentException("not enough memory on node [" + n.id() + "] to assign model [" + m.deploymentId() + "]");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,26 @@ public AssignmentPlan computePlan() {
return computePlan(true);
}

public AssignmentPlan computePlan(boolean tryAssigningPreviouslyAssignedModels) {
/**
* Computes an {@link AssignmentPlan} for the given nodes and deployments.
* If {@code tryAssigningAllPreviouslyAllocatedModels} is true, then the plan will
* attempt to assign at least one allocation to previously assigned models.
* Otherwise, it will only ensure that deployments assigned to existing nodes will preserve at least one allocation
*
* @param tryAssigningAllPreviouslyAllocatedModels whether to do the best effort assigning previously assigned models somewhere
* with at least one allocation
* @return the computed assignment plan
*/
public AssignmentPlan computePlan(boolean tryAssigningAllPreviouslyAllocatedModels) {
logger.debug(() -> format("Computing plan for nodes = %s; deployments = %s", nodes, deployments));

AssignmentPlan bestPlan;
AssignmentPlan planSatisfyingCurrentAssignments = solveSatisfyingCurrentAssignments();
logger.debug(() -> "Plan satisfying current assignments =\n" + planSatisfyingCurrentAssignments.prettyPrint());
if (planSatisfyingCurrentAssignments.arePreviouslyAssignedModelsAssigned() == false && tryAssigningPreviouslyAssignedModels) {
if (planSatisfyingCurrentAssignments.arePreviouslyAssignedModelsAssigned() || tryAssigningAllPreviouslyAllocatedModels == false) {
bestPlan = planSatisfyingCurrentAssignments;
} else {
// try to reuse any deployment that would otherwise drop to zero allocations
AssignmentPlan planAllocatingAtLeastOnceModelsThatWerePreviouslyAllocated =
solveAllocatingAtLeastOnceModelsThatWerePreviouslyAllocated();
logger.debug(
Expand All @@ -82,28 +95,37 @@ public AssignmentPlan computePlan(boolean tryAssigningPreviouslyAssignedModels)
? planSatisfyingCurrentAssignments
: planAllocatingAtLeastOnceModelsThatWerePreviouslyAllocated;
}
} else {
bestPlan = planSatisfyingCurrentAssignments;
}

logger.debug(() -> "Best plan =\n" + bestPlan.prettyPrint());
logger.debug(() -> prettyPrintOverallStats(bestPlan));
return bestPlan;
}

/**
* Computes the best assignment plan from two strategies:
* 1. Preserving one allocation on current assignments, which is the most flexible
* 2. Preserving all allocations on current assignments, which is more conservative
* @return the best assignment plan
*/
private AssignmentPlan solveSatisfyingCurrentAssignments() {
AssignmentPlan bestPlan;
// First solve preserving one allocation per assignment because that is most flexible
AssignmentPlan planKeepingOneAllocationOnCurrentAssignments = solveKeepingOneAllocationOnCurrentAssignments();
if (planKeepingOneAllocationOnCurrentAssignments.satisfiesCurrentAssignments() == false) {

if (planKeepingOneAllocationOnCurrentAssignments.satisfiesAllModels()) {
// If the plan satisfies all models, then we can use it as is
bestPlan = planKeepingOneAllocationOnCurrentAssignments;
} else if (planKeepingOneAllocationOnCurrentAssignments.satisfiesCurrentAssignments() == false) {
// If in the new assignment plan, some deployments have fewer allocations than in the current assignments,
// try explicitly preserving all allocations on current assignments.
bestPlan = solvePreservingAllAllocationsOnCurrentAssignments();
} else if (planKeepingOneAllocationOnCurrentAssignments.satisfiesAllModels() == false) {
} else {
// Choose the best strategy according to {@link AssignmentPlan#computeQuality(AssignmentPlan)}
AssignmentPlan planKeepingAllAllocationsOnCurrentAssignments = solvePreservingAllAllocationsOnCurrentAssignments();
bestPlan = planKeepingAllAllocationsOnCurrentAssignments.compareTo(planKeepingOneAllocationOnCurrentAssignments) >= 0
? planKeepingAllAllocationsOnCurrentAssignments
: planKeepingOneAllocationOnCurrentAssignments;
} else {
bestPlan = planKeepingOneAllocationOnCurrentAssignments;
}
return bestPlan;
}
Expand All @@ -120,7 +142,7 @@ private AssignmentPlan solveAllocatingAtLeastOnceModelsThatWerePreviouslyAllocat
1,
m.threadsPerAllocation(),
// don't rely on the current allocation
new HashMap<>(),
Map.of(),
m.maxAssignedAllocations(),
m.getAdaptiveAllocationsSettings(),
m.perDeploymentMemoryBytes(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,36 +182,39 @@ private AssignmentPlan computePlanAcrossAllNodes(List<AssignmentPlan> plans) {
List<AssignmentPlan.Deployment> planDeployments = preserveAllAllocations.modelsPreservingAllocations();
AssignmentPlan plan = new LinearProgrammingPlanSolver(planNodes, planDeployments).solvePlan(false);
plan = preserveAllAllocations.mergePreservedAllocations(plan);
return swapOriginalModelsInPlan(plan, allNodes, modelsAccountingPlans);
return swapOriginalDeploymentsInPlan(plan, allNodes, modelsAccountingPlans);
}

private AssignmentPlan swapOriginalModelsInPlan(
/**
* The method is responsible for reconstructing an AssignmentPlan
* by replacing the deployments and nodes in the given plan with their original counterparts.
* This ensures that the final plan uses the original objects while preserving the assignments
* and memory accounting from the input plan.
*
* @param plan AssignmentPlan to reconstruct with original models and nodes
* @param allNodes List of all nodes in the system, used to find original nodes
* @param planDeployments List of deployments in the plan, not the original deployments
* @return final plan with original models and nodes swapped in
*/
private AssignmentPlan swapOriginalDeploymentsInPlan(
AssignmentPlan plan,
List<Node> allNodes,
List<AssignmentPlan.Deployment> planDeployments
) {
final Map<String, AssignmentPlan.Deployment> originalModelById = deployments.stream()
final Map<String, AssignmentPlan.Deployment> originalDeploymentsById = deployments.stream()
.collect(Collectors.toMap(AssignmentPlan.Deployment::deploymentId, Function.identity()));
final Map<String, Node> originalNodeById = allNodes.stream().collect(Collectors.toMap(Node::id, Function.identity()));
AssignmentPlan.Builder planBuilder = AssignmentPlan.builder(allNodes, deployments);
for (AssignmentPlan.Deployment m : planDeployments) {
AssignmentPlan.Deployment originalDeployment = originalModelById.get(m.deploymentId());
Map<Node, Integer> nodeAssignments = plan.assignments(m).orElse(Map.of());
AssignmentPlan.Builder finalPlanBuilder = AssignmentPlan.builder(allNodes, deployments);

for (AssignmentPlan.Deployment planDeployment : planDeployments) {
AssignmentPlan.Deployment originalDeployment = originalDeploymentsById.get(planDeployment.deploymentId());
Map<Node, Integer> nodeAssignments = plan.assignments(planDeployment).orElse(Map.of());
for (Map.Entry<Node, Integer> assignment : nodeAssignments.entrySet()) {
Node originalNode = originalNodeById.get(assignment.getKey().id());
planBuilder.assignModelToNode(originalDeployment, originalNode, assignment.getValue());
if (originalDeployment.currentAllocationsByNodeId().containsKey(originalNode.id())) {
// TODO (#101612) requiredMemory should be calculated by the AssignmentPlan.Builder
// As the node has all its available memory we need to manually account memory of models with
// current allocations.
long requiredMemory = originalDeployment.estimateMemoryUsageBytes(
originalDeployment.currentAllocationsByNodeId().get(originalNode.id())
);
planBuilder.accountMemory(m, originalNode, requiredMemory);
}
finalPlanBuilder.assignModelToNode(originalDeployment, originalNode, assignment.getValue());
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now, calling assignedModelToNode is enough to correctly account for allocated memory.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please make AssignmentPlan::accountMemory() a private method so it is not accidentally called by one of the planners

}
}
return planBuilder.build();
return finalPlanBuilder.build();
}

private Map<String, Map<String, Integer>> mergeAllocationsByNodeIdByDeploymentId(List<AssignmentPlan> plans) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,16 @@
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentMetadata;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlan;
import org.elasticsearch.xpack.ml.job.NodeLoad;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Collectors;

import static org.hamcrest.Matchers.aMapWithSize;
import static org.hamcrest.Matchers.anEmptyMap;
Expand Down Expand Up @@ -1127,6 +1130,74 @@ public void testRebalance_GivenFirstModelToAdd_GivenScalingProcessorSetting() {
assertThat(assignment.getReason().isPresent(), is(false));
}

public void testCopyAssignments() {
// Create test nodes
AssignmentPlan.Node node1 = new AssignmentPlan.Node("node-1", ByteSizeValue.ofGb(1).getBytes(), 4);
AssignmentPlan.Node node2 = new AssignmentPlan.Node("node-2", ByteSizeValue.ofGb(1).getBytes(), 8);
List<AssignmentPlan.Node> nodes = List.of(node1, node2);

// Create test deployments
AssignmentPlan.Deployment deployment1 = new AssignmentPlan.Deployment(
"deployment-1",
"model-1",
ByteSizeValue.ofMb(100).getBytes(),
2,
1,
Map.of(),
0,
null,
Priority.NORMAL,
0,
0
);
AssignmentPlan.Deployment deployment2 = new AssignmentPlan.Deployment(
"deployment-2",
"model-2",
ByteSizeValue.ofMb(100).getBytes(),
1,
2,
Map.of(),
0,
null,
Priority.LOW,
0,
0
);
List<AssignmentPlan.Deployment> deployments = List.of(deployment1, deployment2);

// Create source plan and assign models to nodes
AssignmentPlan.Builder sourceBuilder = AssignmentPlan.builder(nodes, deployments);
sourceBuilder.assignModelToNode(deployment1, node1, 1);
sourceBuilder.assignModelToNode(deployment1, node2, 1);
sourceBuilder.assignModelToNode(deployment2, node2, 1);
AssignmentPlan source = sourceBuilder.build();

// Create destination plan
AssignmentPlan.Builder dest = AssignmentPlan.builder(nodes, deployments);

// Create map of node IDs to original nodes
Map<String, AssignmentPlan.Node> originalNodeById = nodes.stream()
.collect(Collectors.toMap(AssignmentPlan.Node::id, Function.identity()));

// Call copyAssignments
TrainedModelAssignmentRebalancer.copyAssignments(source, dest, originalNodeById);

// Build the destination plan
AssignmentPlan result = dest.build();

// Verify assignments
Optional<Map<AssignmentPlan.Node, Integer>> deployment1Assignments = result.assignments(deployment1);
assertThat(deployment1Assignments.isPresent(), is(true));
assertThat(deployment1Assignments.get().size(), equalTo(2));
assertThat(deployment1Assignments.get().get(node1), equalTo(1));
assertThat(deployment1Assignments.get().get(node2), equalTo(1));

Optional<Map<AssignmentPlan.Node, Integer>> deployment2Assignments = result.assignments(deployment2);
assertThat(deployment2Assignments.isPresent(), is(true));
assertThat(deployment2Assignments.get().size(), equalTo(1));
assertThat(deployment2Assignments.get().get(node2), equalTo(1));
}

private static StartTrainedModelDeploymentAction.TaskParams lowPriorityParams(String deploymentId, long modelSize) {
return lowPriorityParams(deploymentId, deploymentId, modelSize);
}
Expand Down
Loading