diff --git a/core/src/main/java/ai/timefold/solver/core/api/score/stream/ConstraintFactory.java b/core/src/main/java/ai/timefold/solver/core/api/score/stream/ConstraintFactory.java index 54e121dee2..6d10533104 100644 --- a/core/src/main/java/ai/timefold/solver/core/api/score/stream/ConstraintFactory.java +++ b/core/src/main/java/ai/timefold/solver/core/api/score/stream/ConstraintFactory.java @@ -192,6 +192,18 @@ public interface ConstraintFactory { */ @NonNull BiConstraintStream forEachUniquePair(@NonNull Class sourceClass, @NonNull BiJoiner... joiners); + // ************************************************************************ + // staticData + //************************************************************************ + + /** + * Computes and caches the tuples that would be produced by the given stream. + * As this is cached, it is vital the stream does not reference any variables + * (genuine or otherwise). + */ + @NonNull Stream_ + staticData(StaticDataSupplier<@NonNull Stream_> staticDataSupplier); + // ************************************************************************ // from* (deprecated) // ************************************************************************ diff --git a/core/src/main/java/ai/timefold/solver/core/api/score/stream/StaticDataFactory.java b/core/src/main/java/ai/timefold/solver/core/api/score/stream/StaticDataFactory.java new file mode 100644 index 0000000000..6f8763ad50 --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/api/score/stream/StaticDataFactory.java @@ -0,0 +1,7 @@ +package ai.timefold.solver.core.api.score.stream; + +import ai.timefold.solver.core.api.score.stream.uni.UniConstraintStream; + +public interface StaticDataFactory { + UniConstraintStream forEachUnfiltered(Class sourceClass); +} diff --git a/core/src/main/java/ai/timefold/solver/core/api/score/stream/StaticDataSupplier.java b/core/src/main/java/ai/timefold/solver/core/api/score/stream/StaticDataSupplier.java new file mode 100644 index 0000000000..afbe6d9d01 --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/api/score/stream/StaticDataSupplier.java @@ -0,0 +1,8 @@ +package ai.timefold.solver.core.api.score.stream; + +import org.jspecify.annotations.NullMarked; + +@NullMarked +public interface StaticDataSupplier { + Stream_ get(StaticDataFactory dataFactory); +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/AbstractSession.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/AbstractSession.java index 38b65835ae..04a766af15 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/bavet/AbstractSession.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/AbstractSession.java @@ -3,43 +3,49 @@ import java.util.IdentityHashMap; import java.util.Map; -import ai.timefold.solver.core.impl.bavet.uni.AbstractForEachUniNode; -import ai.timefold.solver.core.impl.bavet.uni.AbstractForEachUniNode.LifecycleOperation; +import ai.timefold.solver.core.impl.bavet.common.BavetRootNode; +import ai.timefold.solver.core.impl.bavet.common.BavetRootNode.LifecycleOperation; public abstract class AbstractSession { private final NodeNetwork nodeNetwork; - private final Map, AbstractForEachUniNode[]> insertEffectiveClassToNodeArrayMap; - private final Map, AbstractForEachUniNode[]> updateEffectiveClassToNodeArrayMap; - private final Map, AbstractForEachUniNode[]> retractEffectiveClassToNodeArrayMap; + private final Map, BavetRootNode[]> insertEffectiveClassToNodeArrayMap; + private final Map, BavetRootNode[]> updateEffectiveClassToNodeArrayMap; + private final Map, BavetRootNode[]> retractEffectiveClassToNodeArrayMap; + private final BavetRootNode[] settleNodes; protected AbstractSession(NodeNetwork nodeNetwork) { this.nodeNetwork = nodeNetwork; this.insertEffectiveClassToNodeArrayMap = new IdentityHashMap<>(nodeNetwork.forEachNodeCount()); this.updateEffectiveClassToNodeArrayMap = new IdentityHashMap<>(nodeNetwork.forEachNodeCount()); this.retractEffectiveClassToNodeArrayMap = new IdentityHashMap<>(nodeNetwork.forEachNodeCount()); + this.settleNodes = nodeNetwork.getAllTupleSourceRootNodes() + .filter(node -> node.supports(LifecycleOperation.SETTLE)) + .toArray(BavetRootNode[]::new); } public final void insert(Object fact) { var factClass = fact.getClass(); - for (var node : findNodes(factClass, LifecycleOperation.INSERT)) { + for (var node : findNodes(factClass, BavetRootNode.LifecycleOperation.INSERT)) { node.insert(fact); } } @SuppressWarnings("unchecked") - private AbstractForEachUniNode[] findNodes(Class factClass, LifecycleOperation lifecycleOperation) { + private BavetRootNode[] findNodes(Class factClass, LifecycleOperation lifecycleOperation) { var effectiveClassToNodeArrayMap = switch (lifecycleOperation) { case INSERT -> insertEffectiveClassToNodeArrayMap; case UPDATE -> updateEffectiveClassToNodeArrayMap; case RETRACT -> retractEffectiveClassToNodeArrayMap; + case SETTLE -> + throw new IllegalArgumentException("impossible state: findNodes should not be called for settle nodes"); }; // Map.computeIfAbsent() would have created lambdas on the hot path, this will not. var nodeArray = effectiveClassToNodeArrayMap.get(factClass); if (nodeArray == null) { - nodeArray = nodeNetwork.getForEachNodes(factClass) + nodeArray = nodeNetwork.getTupleSourceRootNodes(factClass) .filter(node -> node.supports(lifecycleOperation)) - .toArray(AbstractForEachUniNode[]::new); + .toArray(BavetRootNode[]::new); effectiveClassToNodeArrayMap.put(factClass, nodeArray); } return nodeArray; @@ -47,19 +53,22 @@ private AbstractForEachUniNode[] findNodes(Class factClass, Lifecycle public final void update(Object fact) { var factClass = fact.getClass(); - for (var node : findNodes(factClass, LifecycleOperation.UPDATE)) { + for (var node : findNodes(factClass, BavetRootNode.LifecycleOperation.UPDATE)) { node.update(fact); } } public final void retract(Object fact) { var factClass = fact.getClass(); - for (var node : findNodes(factClass, LifecycleOperation.RETRACT)) { + for (var node : findNodes(factClass, BavetRootNode.LifecycleOperation.RETRACT)) { node.retract(fact); } } public void settle() { + for (var node : settleNodes) { + node.settle(); + } nodeNetwork.settle(); } diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/NodeNetwork.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/NodeNetwork.java index 62077a3ed4..ef6ff7ba7e 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/bavet/NodeNetwork.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/NodeNetwork.java @@ -7,8 +7,8 @@ import java.util.stream.Stream; import ai.timefold.solver.core.api.domain.solution.PlanningSolution; +import ai.timefold.solver.core.impl.bavet.common.BavetRootNode; import ai.timefold.solver.core.impl.bavet.common.Propagator; -import ai.timefold.solver.core.impl.bavet.uni.AbstractForEachUniNode; /** * Represents Bavet's network of nodes, specific to a particular session. @@ -19,7 +19,7 @@ * @param layeredNodes nodes grouped first by their layer, then by their index within the layer; * propagation needs to happen in this order. */ -public record NodeNetwork(Map, List>> declaredClassToNodeMap, +public record NodeNetwork(Map, List>> declaredClassToNodeMap, Propagator[][] layeredNodes) { public static final NodeNetwork EMPTY = new NodeNetwork(Map.of(), new Propagator[0][0]); @@ -32,14 +32,21 @@ public int layerCount() { return layeredNodes.length; } - public Stream> getForEachNodes(Class factClass) { + public Stream> getAllTupleSourceRootNodes() { // The node needs to match the fact, or the node needs to be applicable to the entire solution. // The latter is for FromSolution nodes. return declaredClassToNodeMap.entrySet() .stream() - .filter(entry -> factClass == PlanningSolution.class || entry.getKey().isAssignableFrom(factClass)) - .map(Map.Entry::getValue) - .flatMap(List::stream); + .flatMap(entry -> entry.getValue().stream()); + } + + public Stream> getTupleSourceRootNodes(Class factClass) { + // The node needs to match the fact, or the node needs to be applicable to the entire solution. + // The latter is for FromSolution nodes. + return declaredClassToNodeMap.entrySet() + .stream() + .flatMap(entry -> entry.getValue().stream()) + .filter(tupleSourceRoot -> factClass == PlanningSolution.class || tupleSourceRoot.allowsInstancesOf(factClass)); } public void settle() { diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/bi/StaticDataBiNode.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/bi/StaticDataBiNode.java new file mode 100644 index 0000000000..813e7699fe --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/bi/StaticDataBiNode.java @@ -0,0 +1,28 @@ +package ai.timefold.solver.core.impl.bavet.bi; + +import ai.timefold.solver.core.impl.bavet.NodeNetwork; +import ai.timefold.solver.core.impl.bavet.common.AbstractStaticDataNode; +import ai.timefold.solver.core.impl.bavet.common.tuple.BiTuple; +import ai.timefold.solver.core.impl.bavet.common.tuple.RecordingTupleLifecycle; +import ai.timefold.solver.core.impl.bavet.common.tuple.TupleLifecycle; + +import org.jspecify.annotations.NullMarked; + +@NullMarked +public final class StaticDataBiNode extends AbstractStaticDataNode> { + private final int outputStoreSize; + + public StaticDataBiNode(NodeNetwork nodeNetwork, + RecordingTupleLifecycle> recordingTupleNode, + int outputStoreSize, + TupleLifecycle> nextNodesTupleLifecycle, + Class[] sourceClasses) { + super(nodeNetwork, recordingTupleNode, nextNodesTupleLifecycle, sourceClasses); + this.outputStoreSize = outputStoreSize; + } + + @Override + protected BiTuple remapTuple(BiTuple tuple) { + return new BiTuple<>(tuple.factA, tuple.factB, outputStoreSize); + } +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractNodeBuildHelper.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractNodeBuildHelper.java index 4bb364e006..fda966c91f 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractNodeBuildHelper.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractNodeBuildHelper.java @@ -16,7 +16,6 @@ import ai.timefold.solver.core.impl.bavet.common.tuple.LeftTupleLifecycle; import ai.timefold.solver.core.impl.bavet.common.tuple.RightTupleLifecycle; import ai.timefold.solver.core.impl.bavet.common.tuple.TupleLifecycle; -import ai.timefold.solver.core.impl.bavet.uni.AbstractForEachUniNode; public abstract class AbstractNodeBuildHelper { @@ -47,7 +46,7 @@ public void addNode(AbstractNode node, Stream_ creator) { public void addNode(AbstractNode node, Stream_ creator, Stream_ parent) { reversedNodeList.add(node); nodeCreatorMap.put(node, creator); - if (!(node instanceof AbstractForEachUniNode)) { + if (!(node instanceof BavetRootNode)) { if (parent == null) { throw new IllegalStateException("Impossible state: The node (%s) has no parent (%s)." .formatted(node, parent)); @@ -148,7 +147,7 @@ public AbstractNode findParentNode(Stream_ childNodeCreator) { } public static NodeNetwork buildNodeNetwork(List nodeList, - Map, List>> declaredClassToNodeMap) { + Map, List>> declaredClassToNodeMap) { var layerMap = new TreeMap>(); for (var node : nodeList) { layerMap.computeIfAbsent(node.getLayerIndex(), k -> new ArrayList<>()) @@ -206,7 +205,7 @@ public > List long determineLayerIndex(AbstractNode node, AbstractNodeBuildHelper buildHelper) { - if (node instanceof AbstractForEachUniNode) { // ForEach nodes, and only they, are in layer 0. + if (node instanceof BavetRootNode) { // TupleSourceRoot nodes, and only they, are in layer 0. return 0; } else if (node instanceof AbstractTwoInputNode joinNode) { var nodeCreator = (BavetStreamBinaryOperation) buildHelper.getNodeCreatingStream(joinNode); diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractStaticDataNode.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractStaticDataNode.java new file mode 100644 index 0000000000..b4bf7c8474 --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractStaticDataNode.java @@ -0,0 +1,181 @@ +package ai.timefold.solver.core.impl.bavet.common; + +import java.util.ArrayList; +import java.util.IdentityHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import ai.timefold.solver.core.impl.bavet.NodeNetwork; +import ai.timefold.solver.core.impl.bavet.common.tuple.AbstractTuple; +import ai.timefold.solver.core.impl.bavet.common.tuple.RecordingTupleLifecycle; +import ai.timefold.solver.core.impl.bavet.common.tuple.TupleLifecycle; +import ai.timefold.solver.core.impl.bavet.common.tuple.TupleState; +import ai.timefold.solver.core.impl.bavet.uni.AbstractForEachUniNode; +import ai.timefold.solver.core.impl.util.CollectionUtils; + +import org.jspecify.annotations.NullMarked; +import org.jspecify.annotations.Nullable; + +@NullMarked +public abstract class AbstractStaticDataNode extends AbstractNode + implements BavetRootNode { + private final StaticPropagationQueue propagationQueue; + private final Map> tupleMap = new IdentityHashMap<>(1000); + private final NodeNetwork innerNodeNetwork; + private final RecordingTupleLifecycle recordingTupleNode; + private final Class[] sourceClasses; + private final Set queuedInsertSet = CollectionUtils.newIdentityHashSet(32); + private final Set queuedUpdateSet = CollectionUtils.newIdentityHashSet(32); + private final Set queuedRetractSet = CollectionUtils.newIdentityHashSet(32); + + protected AbstractStaticDataNode(NodeNetwork innerNodeNetwork, + RecordingTupleLifecycle recordingTupleNode, + TupleLifecycle nextNodesTupleLifecycle, + Class[] sourceClasses) { + this.innerNodeNetwork = innerNodeNetwork; + this.propagationQueue = new StaticPropagationQueue<>(nextNodesTupleLifecycle); + this.sourceClasses = sourceClasses; + this.recordingTupleNode = recordingTupleNode; + } + + @Override + public final Propagator getPropagator() { + return propagationQueue; + } + + @Override + public final boolean allowsInstancesOf(Class clazz) { + for (var sourceClass : sourceClasses) { + if (sourceClass.isAssignableFrom(clazz)) { + return true; + } + } + return false; + } + + @Override + public final Class[] getSourceClasses() { + return sourceClasses; + } + + @Override + public final boolean supports(BavetRootNode.LifecycleOperation lifecycleOperation) { + return true; + } + + @Override + public final void insert(Object a) { + // do not remove a retract of the same fact (a fact was updated) + queuedInsertSet.add(a); + } + + @Override + public final void update(Object a) { + queuedUpdateSet.add(a); + } + + private void updateExisting(@Nullable Object a, Tuple_ tuple) { + var state = tuple.state; + if (state.isDirty()) { + if (state == TupleState.DYING || state == TupleState.ABORTING) { + throw new IllegalStateException("The fact (%s) was retracted, so it cannot update." + .formatted(a)); + } + // CREATING or UPDATING is ignored; it's already in the queue. + } else { + propagationQueue.update(tuple); + } + } + + @Override + public final void retract(Object a) { + // remove an insert then retract (a fact was inserted but retracted before settling) + // do not remove a retract then insert (a fact was updated) + if (!queuedInsertSet.remove(a)) { + queuedRetractSet.add(a); + } + } + + @Override + public final void settle() { + if (!queuedRetractSet.isEmpty() || !queuedInsertSet.isEmpty()) { + invalidateCache(); + queuedUpdateSet.removeAll(queuedRetractSet); + queuedUpdateSet.removeAll(queuedInsertSet); + // Do not remove queued retracts from inserts; if a fact property + // change, there will be both a retract and insert for that fact + queuedRetractSet.forEach(this::retractFromInnerNodeNetwork); + queuedInsertSet.forEach(this::insertIntoInnerNodeNetwork); + queuedRetractSet.clear(); + queuedInsertSet.clear(); + + // settle the inner node network, so the inserts/retracts do not interfere + // with the recording of the first object's tuples + innerNodeNetwork.settle(); + recalculateTuples(); + } + for (var updatedObject : queuedUpdateSet) { + for (var updatedTuple : tupleMap.get(updatedObject)) { + updateExisting(updatedObject, updatedTuple); + } + } + queuedUpdateSet.clear(); + } + + private void insertNew(Tuple_ tuple) { + var state = tuple.state; + if (state != TupleState.CREATING) { + propagationQueue.insert(tuple); + } + } + + private void retractExisting(Tuple_ tuple) { + var state = tuple.state; + if (state.isDirty()) { + if (state == TupleState.DYING || state == TupleState.ABORTING) { + // We already retracted this tuple from another list, so we + // don't need to do anything + return; + } + propagationQueue.retract(tuple, state == TupleState.CREATING ? TupleState.ABORTING : TupleState.DYING); + } else { + propagationQueue.retract(tuple, TupleState.DYING); + } + } + + private void insertIntoInnerNodeNetwork(Object toInsert) { + tupleMap.put(toInsert, new ArrayList<>()); + innerNodeNetwork.getTupleSourceRootNodes(toInsert.getClass()) + .forEach(node -> ((AbstractForEachUniNode) node).insert(toInsert)); + } + + private void retractFromInnerNodeNetwork(Object toRetract) { + tupleMap.remove(toRetract); + innerNodeNetwork.getTupleSourceRootNodes(toRetract.getClass()) + .forEach(node -> ((AbstractForEachUniNode) node).retract(toRetract)); + } + + private void invalidateCache() { + tupleMap.values().stream().flatMap(List::stream).forEach(this::retractExisting); + recordingTupleNode.tupleRecorder().reset(); + } + + private void recalculateTuples() { + var recorder = recordingTupleNode.tupleRecorder(); + for (var mappedTupleEntry : tupleMap.entrySet()) { + mappedTupleEntry.getValue().clear(); + var invalidated = mappedTupleEntry.getKey(); + recorder.recordingInto(mappedTupleEntry.getValue(), this::remapTuple, () -> { + // Do a fake update on the object and settle the network; this will update precisely the + // tuples mapped to this node, which will then be recorded + innerNodeNetwork.getTupleSourceRootNodes(invalidated.getClass()) + .forEach(node -> ((AbstractForEachUniNode) node).update(invalidated)); + innerNodeNetwork.settle(); + }); + } + tupleMap.values().stream().flatMap(List::stream).forEach(this::insertNew); + } + + protected abstract Tuple_ remapTuple(Tuple_ tuple); +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/BavetRootNode.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/BavetRootNode.java new file mode 100644 index 0000000000..ba90ad9857 --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/BavetRootNode.java @@ -0,0 +1,57 @@ +package ai.timefold.solver.core.impl.bavet.common; + +import org.jspecify.annotations.NullMarked; + +@NullMarked +public interface BavetRootNode { + void insert(A a); + + void update(A a); + + void retract(A a); + + void settle(); + + boolean allowsInstancesOf(Class clazz); + + Class[] getSourceClasses(); + + /** + * Determines if this node supports the given lifecycle operation. + * Unsupported nodes will not be called during that lifecycle operation. + * + * @param lifecycleOperation the lifecycle operation to check + * @return {@code true} if the given lifecycle operation is supported; otherwise, {@code false}. + */ + boolean supports(BavetRootNode.LifecycleOperation lifecycleOperation); + + /** + * Represents the various lifecycle operations that can be performed + * on tuples within a node in Bavet. + */ + enum LifecycleOperation { + /** + * Represents the operation of inserting a new tuple into the node. + * This operation is typically performed when a new fact is added to the working solution + * and needs to be propagated through the node network. + */ + INSERT, + /** + * Represents the operation of updating an existing tuple within the node. + * This operation is typically triggered when a fact in the working solution + * is modified, requiring the corresponding tuple to be updated and its changes + * propagated through the node network. + */ + UPDATE, + /** + * Represents the operation of retracting or removing an existing tuple from the node. + * This operation is typically used when a fact is removed from the working solution + * and its corresponding tuple needs to be removed from the node network. + */ + RETRACT, + /** + * Represents the operation of recalculating the score, just prior to all queued operations being propagated. + */ + SETTLE + } +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/TupleRecorder.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/TupleRecorder.java new file mode 100644 index 0000000000..2a9a57bb69 --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/TupleRecorder.java @@ -0,0 +1,36 @@ +package ai.timefold.solver.core.impl.bavet.common; + +import java.util.IdentityHashMap; +import java.util.List; +import java.util.function.UnaryOperator; + +import ai.timefold.solver.core.impl.bavet.common.tuple.AbstractTuple; + +public final class TupleRecorder { + List recordedTupleList; + UnaryOperator mapper; + IdentityHashMap inputTupleToOutputTuple = new IdentityHashMap<>(); + + public void reset() { + inputTupleToOutputTuple.clear(); + } + + public void recordingInto(List recordedTupleList, UnaryOperator mapper, + Runnable runner) { + this.recordedTupleList = recordedTupleList; + this.mapper = mapper; + runner.run(); + this.recordedTupleList = null; + this.mapper = null; + } + + public boolean isRecording() { + return recordedTupleList != null; + } + + public void recordTuple(Tuple_ tuple) { + if (recordedTupleList != null) { + recordedTupleList.add(inputTupleToOutputTuple.computeIfAbsent(tuple, mapper)); + } + } +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/tuple/RecordingTupleLifecycle.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/tuple/RecordingTupleLifecycle.java new file mode 100644 index 0000000000..38197da584 --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/tuple/RecordingTupleLifecycle.java @@ -0,0 +1,26 @@ +package ai.timefold.solver.core.impl.bavet.common.tuple; + +import ai.timefold.solver.core.impl.bavet.common.TupleRecorder; + +public record RecordingTupleLifecycle(TupleRecorder tupleRecorder) + implements + TupleLifecycle { + + @Override + public void insert(Tuple_ tuple) { + if (tupleRecorder.isRecording()) { + throw new IllegalStateException("Impossible state: tuple %s was inserted during recording".formatted(tuple)); + } + } + + @Override + public void update(Tuple_ tuple) { + tupleRecorder.recordTuple(tuple); + } + + @Override + public void retract(Tuple_ tuple) { + // Not illegal; a filter can retract a never inserted tuple on update, + // since it does not remember what tuples it accepted to save memory + } +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/tuple/TupleLifecycle.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/tuple/TupleLifecycle.java index a17752a3db..1d2490fee7 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/tuple/TupleLifecycle.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/tuple/TupleLifecycle.java @@ -7,6 +7,7 @@ import ai.timefold.solver.core.api.function.QuadPredicate; import ai.timefold.solver.core.api.function.TriPredicate; +import ai.timefold.solver.core.impl.bavet.common.TupleRecorder; public interface TupleLifecycle { @@ -51,6 +52,10 @@ static TupleLifecycle> conditionally(TupleLifecycle< tuple -> predicate.test(tuple.factA, tuple.factB, tuple.factC, tuple.factD)); } + static TupleLifecycle recording() { + return new RecordingTupleLifecycle<>(new TupleRecorder<>()); + } + void insert(Tuple_ tuple); void update(Tuple_ tuple); diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/quad/StaticDataQuadNode.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/quad/StaticDataQuadNode.java new file mode 100644 index 0000000000..874bd63bf1 --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/quad/StaticDataQuadNode.java @@ -0,0 +1,29 @@ +package ai.timefold.solver.core.impl.bavet.quad; + +import ai.timefold.solver.core.impl.bavet.NodeNetwork; +import ai.timefold.solver.core.impl.bavet.common.AbstractStaticDataNode; +import ai.timefold.solver.core.impl.bavet.common.tuple.QuadTuple; +import ai.timefold.solver.core.impl.bavet.common.tuple.RecordingTupleLifecycle; +import ai.timefold.solver.core.impl.bavet.common.tuple.TupleLifecycle; + +import org.jspecify.annotations.NullMarked; + +@NullMarked +public final class StaticDataQuadNode extends AbstractStaticDataNode> { + private final int outputStoreSize; + + public StaticDataQuadNode(NodeNetwork nodeNetwork, + RecordingTupleLifecycle> recordingTupleNode, + int outputStoreSize, + TupleLifecycle> nextNodesTupleLifecycle, + Class[] sourceClasses) { + super(nodeNetwork, recordingTupleNode, nextNodesTupleLifecycle, sourceClasses); + this.outputStoreSize = outputStoreSize; + } + + @Override + protected QuadTuple remapTuple(QuadTuple tuple) { + return new QuadTuple<>(tuple.factA, tuple.factB, tuple.factC, tuple.factD, + outputStoreSize); + } +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/tri/StaticDataTriNode.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/tri/StaticDataTriNode.java new file mode 100644 index 0000000000..8ba46d346f --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/tri/StaticDataTriNode.java @@ -0,0 +1,28 @@ +package ai.timefold.solver.core.impl.bavet.tri; + +import ai.timefold.solver.core.impl.bavet.NodeNetwork; +import ai.timefold.solver.core.impl.bavet.common.AbstractStaticDataNode; +import ai.timefold.solver.core.impl.bavet.common.tuple.RecordingTupleLifecycle; +import ai.timefold.solver.core.impl.bavet.common.tuple.TriTuple; +import ai.timefold.solver.core.impl.bavet.common.tuple.TupleLifecycle; + +import org.jspecify.annotations.NullMarked; + +@NullMarked +public final class StaticDataTriNode extends AbstractStaticDataNode> { + private final int outputStoreSize; + + public StaticDataTriNode(NodeNetwork nodeNetwork, + RecordingTupleLifecycle> recordingTupleNode, + int outputStoreSize, + TupleLifecycle> nextNodesTupleLifecycle, + Class[] sourceClasses) { + super(nodeNetwork, recordingTupleNode, nextNodesTupleLifecycle, sourceClasses); + this.outputStoreSize = outputStoreSize; + } + + @Override + protected TriTuple remapTuple(TriTuple tuple) { + return new TriTuple<>(tuple.factA, tuple.factB, tuple.factC, outputStoreSize); + } +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/uni/AbstractForEachUniNode.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/uni/AbstractForEachUniNode.java index 510b3586cc..a5d893db3a 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/bavet/uni/AbstractForEachUniNode.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/uni/AbstractForEachUniNode.java @@ -4,6 +4,7 @@ import java.util.Map; import ai.timefold.solver.core.impl.bavet.common.AbstractNode; +import ai.timefold.solver.core.impl.bavet.common.BavetRootNode; import ai.timefold.solver.core.impl.bavet.common.Propagator; import ai.timefold.solver.core.impl.bavet.common.StaticPropagationQueue; import ai.timefold.solver.core.impl.bavet.common.tuple.TupleLifecycle; @@ -24,6 +25,7 @@ @NullMarked public abstract sealed class AbstractForEachUniNode extends AbstractNode + implements BavetRootNode permits ForEachFilteredUniNode, ForEachUnfilteredUniNode { private final Class forEachClass; @@ -38,6 +40,17 @@ protected AbstractForEachUniNode(Class forEachClass, TupleLifecycle(nextNodesTupleLifecycle); } + @Override + public boolean allowsInstancesOf(Class clazz) { + return forEachClass.isAssignableFrom(clazz); + } + + @Override + public Class[] getSourceClasses() { + return new Class[] { forEachClass }; + } + + @Override public void insert(@Nullable A a) { var tuple = new UniTuple<>(a, outputStoreSize); var old = tupleMap.put(a, tuple); @@ -48,8 +61,6 @@ public void insert(@Nullable A a) { propagationQueue.insert(tuple); } - public abstract void update(@Nullable A a); - protected final void updateExisting(@Nullable A a, UniTuple tuple) { var state = tuple.state; if (state.isDirty()) { @@ -63,6 +74,7 @@ protected final void updateExisting(@Nullable A a, UniTuple tuple) { } } + @Override public void retract(@Nullable A a) { var tuple = tupleMap.remove(a); if (tuple == null) { @@ -85,6 +97,11 @@ protected void retractExisting(@Nullable A a, UniTuple tuple) { } } + @Override + public final void settle() { + // We don't need to do any operations + } + @Override public Propagator getPropagator() { return propagationQueue; @@ -94,45 +111,10 @@ public final Class getForEachClass() { return forEachClass; } - /** - * Determines if this node supports the given lifecycle operation. - * Unsupported nodes will not be called during that lifecycle operation. - * - * @param lifecycleOperation the lifecycle operation to check - * @return {@code true} if the given lifecycle operation is supported; otherwise, {@code false}. - */ - public abstract boolean supports(LifecycleOperation lifecycleOperation); - @Override public final String toString() { return "%s(%s)" .formatted(getClass().getSimpleName(), forEachClass.getSimpleName()); } - /** - * Represents the various lifecycle operations that can be performed - * on tuples within a node in Bavet. - */ - public enum LifecycleOperation { - /** - * Represents the operation of inserting a new tuple into the node. - * This operation is typically performed when a new fact is added to the working solution - * and needs to be propagated through the node network. - */ - INSERT, - /** - * Represents the operation of updating an existing tuple within the node. - * This operation is typically triggered when a fact in the working solution - * is modified, requiring the corresponding tuple to be updated and its changes - * propagated through the node network. - */ - UPDATE, - /** - * Represents the operation of retracting or removing an existing tuple from the node. - * This operation is typically used when a fact is removed from the working solution - * and its corresponding tuple needs to be removed from the node network. - */ - RETRACT - } - } diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/uni/ForEachFilteredUniNode.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/uni/ForEachFilteredUniNode.java index 9514b2cf39..9f7b7fdc3c 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/bavet/uni/ForEachFilteredUniNode.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/uni/ForEachFilteredUniNode.java @@ -3,6 +3,7 @@ import java.util.Objects; import java.util.function.Predicate; +import ai.timefold.solver.core.impl.bavet.common.BavetRootNode; import ai.timefold.solver.core.impl.bavet.common.tuple.TupleLifecycle; import ai.timefold.solver.core.impl.bavet.common.tuple.UniTuple; @@ -51,7 +52,7 @@ public void retract(@Nullable A a) { } @Override - public boolean supports(LifecycleOperation lifecycleOperation) { + public boolean supports(BavetRootNode.LifecycleOperation lifecycleOperation) { return true; } diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/uni/ForEachUnfilteredUniNode.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/uni/ForEachUnfilteredUniNode.java index 3660cad1e9..f08516bd1c 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/bavet/uni/ForEachUnfilteredUniNode.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/uni/ForEachUnfilteredUniNode.java @@ -1,5 +1,6 @@ package ai.timefold.solver.core.impl.bavet.uni; +import ai.timefold.solver.core.impl.bavet.common.BavetRootNode; import ai.timefold.solver.core.impl.bavet.common.tuple.TupleLifecycle; import ai.timefold.solver.core.impl.bavet.common.tuple.UniTuple; @@ -26,7 +27,7 @@ public void update(@Nullable A a) { } @Override - public boolean supports(LifecycleOperation lifecycleOperation) { + public boolean supports(BavetRootNode.LifecycleOperation lifecycleOperation) { return true; } diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/uni/StaticDataUniNode.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/uni/StaticDataUniNode.java new file mode 100644 index 0000000000..91345d4819 --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/uni/StaticDataUniNode.java @@ -0,0 +1,28 @@ +package ai.timefold.solver.core.impl.bavet.uni; + +import ai.timefold.solver.core.impl.bavet.NodeNetwork; +import ai.timefold.solver.core.impl.bavet.common.AbstractStaticDataNode; +import ai.timefold.solver.core.impl.bavet.common.tuple.RecordingTupleLifecycle; +import ai.timefold.solver.core.impl.bavet.common.tuple.TupleLifecycle; +import ai.timefold.solver.core.impl.bavet.common.tuple.UniTuple; + +import org.jspecify.annotations.NullMarked; + +@NullMarked +public final class StaticDataUniNode extends AbstractStaticDataNode> { + private final int outputStoreSize; + + public StaticDataUniNode(NodeNetwork nodeNetwork, + RecordingTupleLifecycle> recordingTupleNode, + int outputStoreSize, + TupleLifecycle> nextNodesTupleLifecycle, + Class[] sourceClasses) { + super(nodeNetwork, recordingTupleNode, nextNodesTupleLifecycle, sourceClasses); + this.outputStoreSize = outputStoreSize; + } + + @Override + protected UniTuple remapTuple(UniTuple tuple) { + return new UniTuple<>(tuple.factA, outputStoreSize); + } +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/DatasetSessionFactory.java b/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/DatasetSessionFactory.java index 6d0a4f103d..879f50cffc 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/DatasetSessionFactory.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/DatasetSessionFactory.java @@ -9,6 +9,7 @@ import ai.timefold.solver.core.impl.bavet.NodeNetwork; import ai.timefold.solver.core.impl.bavet.common.AbstractNodeBuildHelper; +import ai.timefold.solver.core.impl.bavet.common.BavetRootNode; import ai.timefold.solver.core.impl.bavet.uni.AbstractForEachUniNode; import ai.timefold.solver.core.impl.neighborhood.stream.enumerating.common.AbstractEnumeratingStream; import ai.timefold.solver.core.impl.neighborhood.stream.enumerating.common.DataNodeBuildHelper; @@ -38,7 +39,7 @@ public DatasetSession buildSession(SessionContext context) private NodeNetwork buildNodeNetwork(Set> enumeratingStreamSet, DataNodeBuildHelper buildHelper, Consumer nodeNetworkVisualizationConsumer) { - var declaredClassToNodeMap = new LinkedHashMap, List>>(); + var declaredClassToNodeMap = new LinkedHashMap, List>>(); var nodeList = buildHelper.buildNodeList(enumeratingStreamSet, buildHelper, AbstractEnumeratingStream::buildNode, node -> { if (!(node instanceof AbstractForEachUniNode forEachUniNode)) { diff --git a/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/uni/AbstractForEachEnumeratingStream.java b/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/uni/AbstractForEachEnumeratingStream.java index 3b64467bca..a3f46ef2d6 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/uni/AbstractForEachEnumeratingStream.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/uni/AbstractForEachEnumeratingStream.java @@ -1,6 +1,6 @@ package ai.timefold.solver.core.impl.neighborhood.stream.enumerating.uni; -import static ai.timefold.solver.core.impl.bavet.uni.AbstractForEachUniNode.LifecycleOperation; +import static ai.timefold.solver.core.impl.bavet.common.BavetRootNode.LifecycleOperation; import java.util.Objects; import java.util.Set; diff --git a/core/src/main/java/ai/timefold/solver/core/impl/score/director/stream/BavetConstraintStreamScoreDirector.java b/core/src/main/java/ai/timefold/solver/core/impl/score/director/stream/BavetConstraintStreamScoreDirector.java index 3cd43f8132..9e16407b75 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/score/director/stream/BavetConstraintStreamScoreDirector.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/score/director/stream/BavetConstraintStreamScoreDirector.java @@ -180,11 +180,17 @@ public void afterProblemFactAdded(Object problemFact) { super.afterProblemFactAdded(problemFact); } - // public void beforeProblemPropertyChanged(Object problemFactOrEntity) // Do nothing + @Override + public void beforeProblemPropertyChanged(Object problemFactOrEntity) { + // Since this is called when a fact (not a variable) changes, + // we need to retract and reinsert to update cached static data + super.beforeProblemPropertyChanged(problemFactOrEntity); + session.retract(problemFactOrEntity); + } @Override public void afterProblemPropertyChanged(Object problemFactOrEntity) { - session.update(problemFactOrEntity); + session.insert(problemFactOrEntity); super.afterProblemPropertyChanged(problemFactOrEntity); } diff --git a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/BavetConstraintFactory.java b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/BavetConstraintFactory.java index 11720db829..f14788de11 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/BavetConstraintFactory.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/BavetConstraintFactory.java @@ -7,7 +7,9 @@ import java.util.function.Function; import java.util.function.Predicate; +import ai.timefold.solver.core.api.score.stream.ConstraintStream; import ai.timefold.solver.core.api.score.stream.Joiners; +import ai.timefold.solver.core.api.score.stream.StaticDataSupplier; import ai.timefold.solver.core.api.score.stream.uni.UniConstraintStream; import ai.timefold.solver.core.config.solver.EnvironmentMode; import ai.timefold.solver.core.impl.bavet.common.BavetAbstractConstraintStream; @@ -15,8 +17,20 @@ import ai.timefold.solver.core.impl.domain.solution.descriptor.SolutionDescriptor; import ai.timefold.solver.core.impl.domain.variable.declarative.ConsistencyTracker; import ai.timefold.solver.core.impl.score.constraint.ConstraintMatchPolicy; +import ai.timefold.solver.core.impl.score.stream.bavet.bi.BavetAbstractBiConstraintStream; +import ai.timefold.solver.core.impl.score.stream.bavet.bi.BavetStaticDataBiConstraintStream; import ai.timefold.solver.core.impl.score.stream.bavet.common.ConstraintNodeBuildHelper; +import ai.timefold.solver.core.impl.score.stream.bavet.common.bridge.BavetAftBridgeBiConstraintStream; +import ai.timefold.solver.core.impl.score.stream.bavet.common.bridge.BavetAftBridgeQuadConstraintStream; +import ai.timefold.solver.core.impl.score.stream.bavet.common.bridge.BavetAftBridgeTriConstraintStream; +import ai.timefold.solver.core.impl.score.stream.bavet.common.bridge.BavetAftBridgeUniConstraintStream; +import ai.timefold.solver.core.impl.score.stream.bavet.quad.BavetAbstractQuadConstraintStream; +import ai.timefold.solver.core.impl.score.stream.bavet.quad.BavetStaticDataQuadConstraintStream; +import ai.timefold.solver.core.impl.score.stream.bavet.tri.BavetAbstractTriConstraintStream; +import ai.timefold.solver.core.impl.score.stream.bavet.tri.BavetStaticDataTriConstraintStream; +import ai.timefold.solver.core.impl.score.stream.bavet.uni.BavetAbstractUniConstraintStream; import ai.timefold.solver.core.impl.score.stream.bavet.uni.BavetForEachUniConstraintStream; +import ai.timefold.solver.core.impl.score.stream.bavet.uni.BavetStaticDataUniConstraintStream; import ai.timefold.solver.core.impl.score.stream.common.ForEachFilteringCriteria; import ai.timefold.solver.core.impl.score.stream.common.InnerConstraintFactory; import ai.timefold.solver.core.impl.score.stream.common.RetrievalSemantics; @@ -108,18 +122,23 @@ public Predicate apply(@NonNull ConstraintNodeBuildHelper helpe private @NonNull UniConstraintStream forEachForCriteria(@NonNull Class sourceClass, ForEachFilteringCriteria criteria) { + return forEachForCriteria(sourceClass, criteria, RetrievalSemantics.STANDARD); + } + + private @NonNull UniConstraintStream forEachForCriteria(@NonNull Class sourceClass, + ForEachFilteringCriteria criteria, RetrievalSemantics retrievalSemantics) { assertValidFromType(sourceClass); var entityDescriptor = solutionDescriptor.findEntityDescriptor(sourceClass); if (entityDescriptor == null || criteria == ForEachFilteringCriteria.ALL) { // Not genuine or shadow entity, or filtering was not requested; no need for filtering. - return share(new BavetForEachUniConstraintStream<>(this, sourceClass, null, RetrievalSemantics.STANDARD)); + return share(new BavetForEachUniConstraintStream<>(this, sourceClass, null, retrievalSemantics)); } var listVariableDescriptor = solutionDescriptor.getListVariableDescriptor(); if (listVariableDescriptor == null || !listVariableDescriptor.acceptsValueType(sourceClass)) { // No applicable list variable; don't need to check inverse relationships. return share(new BavetForEachUniConstraintStream<>(this, sourceClass, new ForEachFilteringCriteriaPredicateFunction<>(entityDescriptor, criteria), - RetrievalSemantics.STANDARD)); + retrievalSemantics)); } var entityClass = listVariableDescriptor.getEntityDescriptor().getEntityClass(); if (entityClass == sourceClass) { @@ -138,7 +157,7 @@ public Predicate apply(@NonNull ConstraintNodeBuildHelper helpe } else { // We have the inverse relation variable, so we can read its value directly. return share(new BavetForEachUniConstraintStream<>(this, sourceClass, new ForEachFilteringCriteriaPredicateFunction<>(entityDescriptor, criteria), - RetrievalSemantics.STANDARD)); + retrievalSemantics)); } } @@ -157,6 +176,10 @@ public Predicate apply(@NonNull ConstraintNodeBuildHelper helpe return forEachForCriteria(sourceClass, ForEachFilteringCriteria.ALL); } + @NonNull UniConstraintStream forEachUnfilteredStatic(@NonNull Class sourceClass) { + return forEachForCriteria(sourceClass, ForEachFilteringCriteria.ALL, RetrievalSemantics.STATIC); + } + // Required for node sharing, since using a lambda will create different instances private record PredicateSupplier( Predicate suppliedPredicate) implements Function, Predicate> { @@ -179,6 +202,39 @@ public Predicate apply(@NonNull ConstraintNodeBuildHelper helpe } } + @Override + @SuppressWarnings("unchecked") + public @NonNull Stream_ + staticData(StaticDataSupplier staticDataSupplier) { + var bavetStream = Objects.requireNonNull(staticDataSupplier.get(new BavetStaticDataFactory<>(this))); + // TODO: Use switch here in JDK 21 + if (bavetStream instanceof BavetAbstractUniConstraintStream uniStream) { + var out = new BavetStaticDataUniConstraintStream<>(this, + (BavetAbstractUniConstraintStream) uniStream); + return (Stream_) share(new BavetAftBridgeUniConstraintStream<>(this, out), + out::setAftBridge); + } else if (bavetStream instanceof BavetAbstractBiConstraintStream biStream) { + var out = new BavetStaticDataBiConstraintStream<>(this, + (BavetAbstractBiConstraintStream) biStream); + return (Stream_) share(new BavetAftBridgeBiConstraintStream<>(this, out), + out::setAftBridge); + } else if (bavetStream instanceof BavetAbstractTriConstraintStream triStream) { + var out = new BavetStaticDataTriConstraintStream<>(this, + (BavetAbstractTriConstraintStream) triStream); + return (Stream_) share(new BavetAftBridgeTriConstraintStream<>(this, out), + out::setAftBridge); + } else if (bavetStream instanceof BavetAbstractQuadConstraintStream quadStream) { + var out = new BavetStaticDataQuadConstraintStream<>(this, + (BavetAbstractQuadConstraintStream) quadStream); + return (Stream_) share(new BavetAftBridgeQuadConstraintStream<>(this, out), + out::setAftBridge); + } else { + throw new IllegalStateException( + "impossible state: the supplier (%s) returned a stream (%s) that not an instance of any Bavet ConstraintStream" + .formatted(staticDataSupplier, bavetStream)); + } + } + @Override public @NonNull UniConstraintStream fromUnfiltered(@NonNull Class fromClass) { assertValidFromType(fromClass); diff --git a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/BavetConstraintSessionFactory.java b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/BavetConstraintSessionFactory.java index afeb0562f6..c38a493cde 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/BavetConstraintSessionFactory.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/BavetConstraintSessionFactory.java @@ -15,6 +15,7 @@ import ai.timefold.solver.core.impl.bavet.NodeNetwork; import ai.timefold.solver.core.impl.bavet.common.AbstractNodeBuildHelper; import ai.timefold.solver.core.impl.bavet.common.BavetAbstractConstraintStream; +import ai.timefold.solver.core.impl.bavet.common.BavetRootNode; import ai.timefold.solver.core.impl.bavet.uni.AbstractForEachUniNode; import ai.timefold.solver.core.impl.bavet.visual.NodeGraph; import ai.timefold.solver.core.impl.domain.solution.descriptor.SolutionDescriptor; @@ -124,23 +125,33 @@ private static > NodeNetwork buildNodeNe AbstractScoreInliner scoreInliner, Consumer nodeNetworkVisualizationConsumer) { var buildHelper = new ConstraintNodeBuildHelper<>(consistencyTracker, constraintStreamSet, scoreInliner); - var declaredClassToNodeMap = new LinkedHashMap, List>>(); + var declaredClassToNodeMap = new LinkedHashMap, List>>(); var nodeList = buildHelper.buildNodeList(constraintStreamSet, buildHelper, BavetAbstractConstraintStream::buildNode, node -> { - if (!(node instanceof AbstractForEachUniNode forEachUniNode)) { + if (!(node instanceof BavetRootNode tupleSourceRoot)) { return; } - var forEachClass = forEachUniNode.getForEachClass(); - var forEachUniNodeList = - declaredClassToNodeMap.computeIfAbsent(forEachClass, k -> new ArrayList<>(2)); - if (forEachUniNodeList.size() == 3) { - // Each class can have at most three forEach nodes: one including everything, one including consistent + null vars, the last consistent + no null vars. - throw new IllegalStateException( - "Impossible state: For class (%s) there are already 3 nodes (%s), not adding another (%s)." - .formatted(forEachClass, forEachUniNodeList, forEachUniNode)); + + if (tupleSourceRoot instanceof AbstractForEachUniNode forEachUniNode) { + var forEachClass = forEachUniNode.getForEachClass(); + var forEachUniNodeList = + declaredClassToNodeMap.computeIfAbsent(forEachClass, k -> new ArrayList<>(2)); + if (forEachUniNodeList.stream().filter(sourceNode -> sourceNode instanceof AbstractForEachUniNode) + .count() == 3) { + // Each class can have at most three forEach nodes: one including everything, one including consistent + null vars, the last consistent + no null vars. + throw new IllegalStateException( + "Impossible state: For class (%s) there are already 3 nodes (%s), not adding another (%s)." + .formatted(forEachClass, forEachUniNodeList, forEachUniNode)); + } + forEachUniNodeList.add(forEachUniNode); + } else { + for (var sourceClass : tupleSourceRoot.getSourceClasses()) { + var forEachUniNodeList = + declaredClassToNodeMap.computeIfAbsent(sourceClass, k -> new ArrayList<>(2)); + forEachUniNodeList.add(tupleSourceRoot); + } } - forEachUniNodeList.add(forEachUniNode); }); if (nodeNetworkVisualizationConsumer != null) { var constraintSet = scoreInliner.getConstraints(); diff --git a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/BavetStaticDataFactory.java b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/BavetStaticDataFactory.java new file mode 100644 index 0000000000..a7b5feed35 --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/BavetStaticDataFactory.java @@ -0,0 +1,12 @@ +package ai.timefold.solver.core.impl.score.stream.bavet; + +import ai.timefold.solver.core.api.score.stream.StaticDataFactory; +import ai.timefold.solver.core.api.score.stream.uni.UniConstraintStream; + +public record BavetStaticDataFactory( + BavetConstraintFactory constraintFactory) implements StaticDataFactory { + @Override + public UniConstraintStream forEachUnfiltered(Class sourceClass) { + return constraintFactory.forEachUnfilteredStatic(sourceClass); + } +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/bi/BavetRecordingBiConstraintStream.java b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/bi/BavetRecordingBiConstraintStream.java new file mode 100644 index 0000000000..60556e6d19 --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/bi/BavetRecordingBiConstraintStream.java @@ -0,0 +1,21 @@ +package ai.timefold.solver.core.impl.score.stream.bavet.bi; + +import ai.timefold.solver.core.api.score.Score; +import ai.timefold.solver.core.impl.bavet.common.BavetAbstractConstraintStream; +import ai.timefold.solver.core.impl.bavet.common.tuple.TupleLifecycle; +import ai.timefold.solver.core.impl.score.stream.bavet.BavetConstraintFactory; +import ai.timefold.solver.core.impl.score.stream.bavet.common.ConstraintNodeBuildHelper; + +public class BavetRecordingBiConstraintStream extends BavetAbstractBiConstraintStream { + protected BavetRecordingBiConstraintStream( + BavetConstraintFactory constraintFactory, + BavetAbstractConstraintStream parent) { + super(constraintFactory, parent); + } + + @Override + public > void buildNode(ConstraintNodeBuildHelper buildHelper) { + assertEmptyChildStreamList(); + buildHelper.putInsertUpdateRetract(this, TupleLifecycle.recording()); + } +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/bi/BavetStaticDataBiConstraintStream.java b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/bi/BavetStaticDataBiConstraintStream.java new file mode 100644 index 0000000000..478be44354 --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/bi/BavetStaticDataBiConstraintStream.java @@ -0,0 +1,51 @@ +package ai.timefold.solver.core.impl.score.stream.bavet.bi; + +import java.util.Set; + +import ai.timefold.solver.core.api.score.Score; +import ai.timefold.solver.core.impl.bavet.bi.StaticDataBiNode; +import ai.timefold.solver.core.impl.bavet.common.BavetAbstractConstraintStream; +import ai.timefold.solver.core.impl.bavet.common.TupleSource; +import ai.timefold.solver.core.impl.bavet.common.tuple.BiTuple; +import ai.timefold.solver.core.impl.score.stream.bavet.BavetConstraintFactory; +import ai.timefold.solver.core.impl.score.stream.bavet.common.BavetStaticDataBuildHelper; +import ai.timefold.solver.core.impl.score.stream.bavet.common.ConstraintNodeBuildHelper; +import ai.timefold.solver.core.impl.score.stream.bavet.common.bridge.BavetAftBridgeBiConstraintStream; +import ai.timefold.solver.core.impl.score.stream.common.RetrievalSemantics; + +public class BavetStaticDataBiConstraintStream extends BavetAbstractBiConstraintStream + implements TupleSource { + private final BavetAbstractConstraintStream recordingStaticConstraintStream; + private BavetAftBridgeBiConstraintStream aftStream; + + public BavetStaticDataBiConstraintStream( + BavetConstraintFactory constraintFactory, + BavetAbstractConstraintStream staticConstraintStream) { + super(constraintFactory, RetrievalSemantics.STANDARD); + this.recordingStaticConstraintStream = new BavetRecordingBiConstraintStream<>(constraintFactory, + staticConstraintStream); + staticConstraintStream.getChildStreamList().add(recordingStaticConstraintStream); + } + + public void setAftBridge(BavetAftBridgeBiConstraintStream aftStream) { + this.aftStream = aftStream; + } + + @Override + public > void buildNode(ConstraintNodeBuildHelper buildHelper) { + var staticDataBuildHelper = new BavetStaticDataBuildHelper>(recordingStaticConstraintStream); + var outputStoreSize = buildHelper.extractTupleStoreSize(aftStream); + + buildHelper.addNode(new StaticDataBiNode<>(staticDataBuildHelper.getNodeNetwork(), + staticDataBuildHelper.getRecordingTupleLifecycle(), + outputStoreSize, + buildHelper.getAggregatedTupleLifecycle(aftStream.getChildStreamList()), + staticDataBuildHelper.getSourceClasses()), + this); + } + + @Override + public void collectActiveConstraintStreams(Set> constraintStreamSet) { + constraintStreamSet.add(this); + } +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/common/BavetStaticDataBuildHelper.java b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/common/BavetStaticDataBuildHelper.java new file mode 100644 index 0000000000..f28c7e1023 --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/common/BavetStaticDataBuildHelper.java @@ -0,0 +1,81 @@ +package ai.timefold.solver.core.impl.score.stream.bavet.common; + +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.List; + +import ai.timefold.solver.core.impl.bavet.NodeNetwork; +import ai.timefold.solver.core.impl.bavet.common.AbstractNodeBuildHelper; +import ai.timefold.solver.core.impl.bavet.common.BavetAbstractConstraintStream; +import ai.timefold.solver.core.impl.bavet.common.BavetRootNode; +import ai.timefold.solver.core.impl.bavet.common.tuple.AbstractTuple; +import ai.timefold.solver.core.impl.bavet.common.tuple.RecordingTupleLifecycle; +import ai.timefold.solver.core.impl.domain.variable.declarative.ConsistencyTracker; +import ai.timefold.solver.core.impl.score.buildin.SimpleScoreDefinition; +import ai.timefold.solver.core.impl.score.constraint.ConstraintMatchPolicy; +import ai.timefold.solver.core.impl.score.stream.common.inliner.AbstractScoreInliner; + +public final class BavetStaticDataBuildHelper { + private final NodeNetwork nodeNetwork; + private final RecordingTupleLifecycle recordingTupleLifecycle; + private final Class[] sourceClasses; + + public BavetStaticDataBuildHelper(BavetAbstractConstraintStream staticConstraintStream) { + var streamList = new ArrayList>(); + var queue = new ArrayDeque>(); + queue.addLast(staticConstraintStream); + + while (!queue.isEmpty()) { + var current = queue.pollFirst(); + streamList.add(current); + if (current instanceof BavetConstraintStreamBinaryOperation binaryOperation) { + queue.addLast((BavetAbstractConstraintStream) binaryOperation.getLeftParent()); + queue.addLast((BavetAbstractConstraintStream) binaryOperation.getRightParent()); + } else { + if (current.getParent() != null) { + queue.addLast(current.getParent()); + } + } + } + Collections.reverse(streamList); + var streamSet = new LinkedHashSet<>(streamList); + + var buildHelper = new ConstraintNodeBuildHelper<>(new ConsistencyTracker<>(), streamSet, + AbstractScoreInliner.buildScoreInliner(new SimpleScoreDefinition(), Collections.emptyMap(), + ConstraintMatchPolicy.DISABLED)); + + var declaredClassToNodeMap = new LinkedHashMap, List>>(); + var nodeList = buildHelper.buildNodeList(streamSet, buildHelper, + BavetAbstractConstraintStream::buildNode, + node -> { + if (!(node instanceof BavetRootNode sourceRootNode)) { + return; + } + var nodeSourceClasses = sourceRootNode.getSourceClasses(); + for (Class nodeSourceClass : nodeSourceClasses) { + var sourceNodeList = declaredClassToNodeMap.computeIfAbsent(nodeSourceClass, k -> new ArrayList<>(2)); + sourceNodeList.add(sourceRootNode); + } + }); + + this.nodeNetwork = AbstractNodeBuildHelper.buildNodeNetwork(nodeList, declaredClassToNodeMap); + this.recordingTupleLifecycle = + (RecordingTupleLifecycle) buildHelper.getAggregatedTupleLifecycle(List.of(staticConstraintStream)); + this.sourceClasses = declaredClassToNodeMap.keySet().toArray(new Class[0]); + } + + public NodeNetwork getNodeNetwork() { + return nodeNetwork; + } + + public RecordingTupleLifecycle getRecordingTupleLifecycle() { + return recordingTupleLifecycle; + } + + public Class[] getSourceClasses() { + return sourceClasses; + } +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/quad/BavetRecordingQuadConstraintStream.java b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/quad/BavetRecordingQuadConstraintStream.java new file mode 100644 index 0000000000..ffbaaf6dba --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/quad/BavetRecordingQuadConstraintStream.java @@ -0,0 +1,22 @@ +package ai.timefold.solver.core.impl.score.stream.bavet.quad; + +import ai.timefold.solver.core.api.score.Score; +import ai.timefold.solver.core.impl.bavet.common.BavetAbstractConstraintStream; +import ai.timefold.solver.core.impl.bavet.common.tuple.TupleLifecycle; +import ai.timefold.solver.core.impl.score.stream.bavet.BavetConstraintFactory; +import ai.timefold.solver.core.impl.score.stream.bavet.common.ConstraintNodeBuildHelper; + +public class BavetRecordingQuadConstraintStream + extends BavetAbstractQuadConstraintStream { + protected BavetRecordingQuadConstraintStream( + BavetConstraintFactory constraintFactory, + BavetAbstractConstraintStream parent) { + super(constraintFactory, parent); + } + + @Override + public > void buildNode(ConstraintNodeBuildHelper buildHelper) { + assertEmptyChildStreamList(); + buildHelper.putInsertUpdateRetract(this, TupleLifecycle.recording()); + } +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/quad/BavetStaticDataQuadConstraintStream.java b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/quad/BavetStaticDataQuadConstraintStream.java new file mode 100644 index 0000000000..911ed8d0b5 --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/quad/BavetStaticDataQuadConstraintStream.java @@ -0,0 +1,52 @@ +package ai.timefold.solver.core.impl.score.stream.bavet.quad; + +import java.util.Set; + +import ai.timefold.solver.core.api.score.Score; +import ai.timefold.solver.core.impl.bavet.bi.StaticDataBiNode; +import ai.timefold.solver.core.impl.bavet.common.BavetAbstractConstraintStream; +import ai.timefold.solver.core.impl.bavet.common.TupleSource; +import ai.timefold.solver.core.impl.bavet.common.tuple.BiTuple; +import ai.timefold.solver.core.impl.score.stream.bavet.BavetConstraintFactory; +import ai.timefold.solver.core.impl.score.stream.bavet.common.BavetStaticDataBuildHelper; +import ai.timefold.solver.core.impl.score.stream.bavet.common.ConstraintNodeBuildHelper; +import ai.timefold.solver.core.impl.score.stream.bavet.common.bridge.BavetAftBridgeQuadConstraintStream; +import ai.timefold.solver.core.impl.score.stream.common.RetrievalSemantics; + +public class BavetStaticDataQuadConstraintStream + extends BavetAbstractQuadConstraintStream + implements TupleSource { + private final BavetAbstractConstraintStream recordingStaticConstraintStream; + private BavetAftBridgeQuadConstraintStream aftStream; + + public BavetStaticDataQuadConstraintStream( + BavetConstraintFactory constraintFactory, + BavetAbstractConstraintStream staticConstraintStream) { + super(constraintFactory, RetrievalSemantics.STANDARD); + this.recordingStaticConstraintStream = new BavetRecordingQuadConstraintStream<>(constraintFactory, + staticConstraintStream); + staticConstraintStream.getChildStreamList().add(recordingStaticConstraintStream); + } + + public void setAftBridge(BavetAftBridgeQuadConstraintStream aftStream) { + this.aftStream = aftStream; + } + + @Override + public > void buildNode(ConstraintNodeBuildHelper buildHelper) { + var staticDataBuildHelper = new BavetStaticDataBuildHelper>(recordingStaticConstraintStream); + var outputStoreSize = buildHelper.extractTupleStoreSize(aftStream); + + buildHelper.addNode(new StaticDataBiNode<>(staticDataBuildHelper.getNodeNetwork(), + staticDataBuildHelper.getRecordingTupleLifecycle(), + outputStoreSize, + buildHelper.getAggregatedTupleLifecycle(aftStream.getChildStreamList()), + staticDataBuildHelper.getSourceClasses()), + this); + } + + @Override + public void collectActiveConstraintStreams(Set> constraintStreamSet) { + constraintStreamSet.add(this); + } +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/tri/BavetRecordingTriConstraintStream.java b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/tri/BavetRecordingTriConstraintStream.java new file mode 100644 index 0000000000..23c93eb3cd --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/tri/BavetRecordingTriConstraintStream.java @@ -0,0 +1,22 @@ +package ai.timefold.solver.core.impl.score.stream.bavet.tri; + +import ai.timefold.solver.core.api.score.Score; +import ai.timefold.solver.core.impl.bavet.common.BavetAbstractConstraintStream; +import ai.timefold.solver.core.impl.bavet.common.tuple.TupleLifecycle; +import ai.timefold.solver.core.impl.score.stream.bavet.BavetConstraintFactory; +import ai.timefold.solver.core.impl.score.stream.bavet.common.ConstraintNodeBuildHelper; + +public class BavetRecordingTriConstraintStream + extends BavetAbstractTriConstraintStream { + protected BavetRecordingTriConstraintStream( + BavetConstraintFactory constraintFactory, + BavetAbstractConstraintStream parent) { + super(constraintFactory, parent); + } + + @Override + public > void buildNode(ConstraintNodeBuildHelper buildHelper) { + assertEmptyChildStreamList(); + buildHelper.putInsertUpdateRetract(this, TupleLifecycle.recording()); + } +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/tri/BavetStaticDataTriConstraintStream.java b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/tri/BavetStaticDataTriConstraintStream.java new file mode 100644 index 0000000000..f95612e91a --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/tri/BavetStaticDataTriConstraintStream.java @@ -0,0 +1,51 @@ +package ai.timefold.solver.core.impl.score.stream.bavet.tri; + +import java.util.Set; + +import ai.timefold.solver.core.api.score.Score; +import ai.timefold.solver.core.impl.bavet.bi.StaticDataBiNode; +import ai.timefold.solver.core.impl.bavet.common.BavetAbstractConstraintStream; +import ai.timefold.solver.core.impl.bavet.common.TupleSource; +import ai.timefold.solver.core.impl.bavet.common.tuple.BiTuple; +import ai.timefold.solver.core.impl.score.stream.bavet.BavetConstraintFactory; +import ai.timefold.solver.core.impl.score.stream.bavet.common.BavetStaticDataBuildHelper; +import ai.timefold.solver.core.impl.score.stream.bavet.common.ConstraintNodeBuildHelper; +import ai.timefold.solver.core.impl.score.stream.bavet.common.bridge.BavetAftBridgeTriConstraintStream; +import ai.timefold.solver.core.impl.score.stream.common.RetrievalSemantics; + +public class BavetStaticDataTriConstraintStream extends BavetAbstractTriConstraintStream + implements TupleSource { + private final BavetAbstractConstraintStream recordingStaticConstraintStream; + private BavetAftBridgeTriConstraintStream aftStream; + + public BavetStaticDataTriConstraintStream( + BavetConstraintFactory constraintFactory, + BavetAbstractConstraintStream staticConstraintStream) { + super(constraintFactory, RetrievalSemantics.STANDARD); + this.recordingStaticConstraintStream = new BavetRecordingTriConstraintStream<>(constraintFactory, + staticConstraintStream); + staticConstraintStream.getChildStreamList().add(recordingStaticConstraintStream); + } + + public void setAftBridge(BavetAftBridgeTriConstraintStream aftStream) { + this.aftStream = aftStream; + } + + @Override + public > void buildNode(ConstraintNodeBuildHelper buildHelper) { + var staticDataBuildHelper = new BavetStaticDataBuildHelper>(recordingStaticConstraintStream); + var outputStoreSize = buildHelper.extractTupleStoreSize(aftStream); + + buildHelper.addNode(new StaticDataBiNode<>(staticDataBuildHelper.getNodeNetwork(), + staticDataBuildHelper.getRecordingTupleLifecycle(), + outputStoreSize, + buildHelper.getAggregatedTupleLifecycle(aftStream.getChildStreamList()), + staticDataBuildHelper.getSourceClasses()), + this); + } + + @Override + public void collectActiveConstraintStreams(Set> constraintStreamSet) { + constraintStreamSet.add(this); + } +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/uni/BavetRecordingUniConstraintStream.java b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/uni/BavetRecordingUniConstraintStream.java new file mode 100644 index 0000000000..4302c7bbed --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/uni/BavetRecordingUniConstraintStream.java @@ -0,0 +1,21 @@ +package ai.timefold.solver.core.impl.score.stream.bavet.uni; + +import ai.timefold.solver.core.api.score.Score; +import ai.timefold.solver.core.impl.bavet.common.BavetAbstractConstraintStream; +import ai.timefold.solver.core.impl.bavet.common.tuple.TupleLifecycle; +import ai.timefold.solver.core.impl.score.stream.bavet.BavetConstraintFactory; +import ai.timefold.solver.core.impl.score.stream.bavet.common.ConstraintNodeBuildHelper; + +public class BavetRecordingUniConstraintStream extends BavetAbstractUniConstraintStream { + protected BavetRecordingUniConstraintStream( + BavetConstraintFactory constraintFactory, + BavetAbstractConstraintStream parent) { + super(constraintFactory, parent); + } + + @Override + public > void buildNode(ConstraintNodeBuildHelper buildHelper) { + assertEmptyChildStreamList(); + buildHelper.putInsertUpdateRetract(this, TupleLifecycle.recording()); + } +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/uni/BavetStaticDataUniConstraintStream.java b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/uni/BavetStaticDataUniConstraintStream.java new file mode 100644 index 0000000000..edf9fd27dc --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/uni/BavetStaticDataUniConstraintStream.java @@ -0,0 +1,51 @@ +package ai.timefold.solver.core.impl.score.stream.bavet.uni; + +import java.util.Set; + +import ai.timefold.solver.core.api.score.Score; +import ai.timefold.solver.core.impl.bavet.common.BavetAbstractConstraintStream; +import ai.timefold.solver.core.impl.bavet.common.TupleSource; +import ai.timefold.solver.core.impl.bavet.common.tuple.UniTuple; +import ai.timefold.solver.core.impl.bavet.uni.StaticDataUniNode; +import ai.timefold.solver.core.impl.score.stream.bavet.BavetConstraintFactory; +import ai.timefold.solver.core.impl.score.stream.bavet.common.BavetStaticDataBuildHelper; +import ai.timefold.solver.core.impl.score.stream.bavet.common.ConstraintNodeBuildHelper; +import ai.timefold.solver.core.impl.score.stream.bavet.common.bridge.BavetAftBridgeUniConstraintStream; +import ai.timefold.solver.core.impl.score.stream.common.RetrievalSemantics; + +public class BavetStaticDataUniConstraintStream extends BavetAbstractUniConstraintStream + implements TupleSource { + private final BavetAbstractConstraintStream recordingStaticConstraintStream; + private BavetAftBridgeUniConstraintStream aftStream; + + public BavetStaticDataUniConstraintStream( + BavetConstraintFactory constraintFactory, + BavetAbstractConstraintStream staticConstraintStream) { + super(constraintFactory, RetrievalSemantics.STANDARD); + this.recordingStaticConstraintStream = new BavetRecordingUniConstraintStream<>(constraintFactory, + staticConstraintStream); + staticConstraintStream.getChildStreamList().add(recordingStaticConstraintStream); + } + + public void setAftBridge(BavetAftBridgeUniConstraintStream aftStream) { + this.aftStream = aftStream; + } + + @Override + public > void buildNode(ConstraintNodeBuildHelper buildHelper) { + var staticDataBuildHelper = new BavetStaticDataBuildHelper>(recordingStaticConstraintStream); + var outputStoreSize = buildHelper.extractTupleStoreSize(aftStream); + + buildHelper.addNode(new StaticDataUniNode<>(staticDataBuildHelper.getNodeNetwork(), + staticDataBuildHelper.getRecordingTupleLifecycle(), + outputStoreSize, + buildHelper.getAggregatedTupleLifecycle(aftStream.getChildStreamList()), + staticDataBuildHelper.getSourceClasses()), + this); + } + + @Override + public void collectActiveConstraintStreams(Set> constraintStreamSet) { + constraintStreamSet.add(this); + } +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/common/RetrievalSemantics.java b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/common/RetrievalSemantics.java index fa19cc252c..8001c3f686 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/common/RetrievalSemantics.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/common/RetrievalSemantics.java @@ -2,12 +2,13 @@ import ai.timefold.solver.core.api.domain.variable.PlanningVariable; import ai.timefold.solver.core.api.score.stream.ConstraintFactory; +import ai.timefold.solver.core.api.score.stream.StaticDataFactory; /** * Determines the behavior of joins and conditional propagation * based on whether they are coming off of a constraint stream started by - * either {@link ConstraintFactory#from(Class)} - * or {@link ConstraintFactory#forEach(Class)} + * either {@link ConstraintFactory#from(Class)}, {@link ConstraintFactory#forEach(Class)}, + * or {@link StaticDataFactory#forEachUnfiltered(Class)} * family of methods. * *

@@ -29,6 +30,15 @@ public enum RetrievalSemantics { * Applies when the stream comes off of a {@link ConstraintFactory#forEach(Class)} family of methods. */ STANDARD, + + /** + * Joins and conditional propagation always include entities with null planning variables, + * regardless of whether their planning variables allow unassigned values. + *

+ * Applies when the stream comes off of a {@link StaticDataFactory#forEachUnfiltered(Class)} family of methods. + */ + STATIC, + /** * Joins include entities with null planning variables if these variables allow unassigned values. * Conditional propagation always includes entities with null planning variables, diff --git a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/common/bi/InnerBiConstraintStream.java b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/common/bi/InnerBiConstraintStream.java index e3125ff6d3..34c8c71ab1 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/common/bi/InnerBiConstraintStream.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/common/bi/InnerBiConstraintStream.java @@ -1,7 +1,5 @@ package ai.timefold.solver.core.impl.score.stream.common.bi; -import static ai.timefold.solver.core.impl.score.stream.common.RetrievalSemantics.STANDARD; - import java.math.BigDecimal; import java.util.Arrays; import java.util.Collection; @@ -45,53 +43,55 @@ static BiFunction> createDefaultIndictedObjectsMappin @Override default @NonNull TriConstraintStream join(@NonNull Class otherClass, TriJoiner... joiners) { - if (getRetrievalSemantics() == STANDARD) { - return join(getConstraintFactory().forEach(otherClass), joiners); - } else { - return join(getConstraintFactory().from(otherClass), joiners); - } + return switch (getRetrievalSemantics()) { + case STANDARD -> join(getConstraintFactory().forEach(otherClass), joiners); + case STATIC -> join(getConstraintFactory().forEachUnfiltered(otherClass), joiners); + case LEGACY -> join(getConstraintFactory().from(otherClass), joiners); + }; } @Override default @NonNull BiConstraintStream ifExists(@NonNull Class otherClass, TriJoiner @NonNull... joiners) { - if (getRetrievalSemantics() == STANDARD) { - return ifExists(getConstraintFactory().forEach(otherClass), joiners); - } else { + return switch (getRetrievalSemantics()) { + case STANDARD -> ifExists(getConstraintFactory().forEach(otherClass), joiners); + case STATIC -> ifExists(getConstraintFactory().forEachUnfiltered(otherClass), joiners); // Calls fromUnfiltered() for backward compatibility only - return ifExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); - } + case LEGACY -> ifExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); + }; } @Override default @NonNull BiConstraintStream ifExistsIncludingUnassigned(@NonNull Class otherClass, @NonNull TriJoiner @NonNull... joiners) { - if (getRetrievalSemantics() == STANDARD) { - return ifExists(getConstraintFactory().forEachIncludingUnassigned(otherClass), joiners); - } else { - return ifExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); - } + return switch (getRetrievalSemantics()) { + case STANDARD -> ifExists(getConstraintFactory().forEachIncludingUnassigned(otherClass), joiners); + case STATIC -> ifExists(getConstraintFactory().forEachUnfiltered(otherClass), joiners); + // Calls fromUnfiltered() for backward compatibility only + case LEGACY -> ifExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); + }; } @Override default @NonNull BiConstraintStream ifNotExists(@NonNull Class otherClass, @NonNull TriJoiner... joiners) { - if (getRetrievalSemantics() == STANDARD) { - return ifNotExists(getConstraintFactory().forEach(otherClass), joiners); - } else { + return switch (getRetrievalSemantics()) { + case STANDARD -> ifNotExists(getConstraintFactory().forEach(otherClass), joiners); + case STATIC -> ifNotExists(getConstraintFactory().forEachUnfiltered(otherClass), joiners); // Calls fromUnfiltered() for backward compatibility only - return ifNotExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); - } + case LEGACY -> ifNotExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); + }; } @Override default @NonNull BiConstraintStream ifNotExistsIncludingUnassigned(@NonNull Class otherClass, @NonNull TriJoiner... joiners) { - if (getRetrievalSemantics() == STANDARD) { - return ifNotExists(getConstraintFactory().forEachIncludingUnassigned(otherClass), joiners); - } else { - return ifNotExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); - } + return switch (getRetrievalSemantics()) { + case STANDARD -> ifNotExists(getConstraintFactory().forEachIncludingUnassigned(otherClass), joiners); + case STATIC -> ifNotExists(getConstraintFactory().forEachUnfiltered(otherClass), joiners); + // Calls fromUnfiltered() for backward compatibility only + case LEGACY -> ifNotExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); + }; } @Override diff --git a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/common/quad/InnerQuadConstraintStream.java b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/common/quad/InnerQuadConstraintStream.java index 51911c1737..816d6e2662 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/common/quad/InnerQuadConstraintStream.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/common/quad/InnerQuadConstraintStream.java @@ -44,43 +44,45 @@ static QuadFunction> createDefaultIndicte @Override default @NonNull QuadConstraintStream ifExists(@NonNull Class otherClass, @NonNull PentaJoiner @NonNull... joiners) { - if (getRetrievalSemantics() == RetrievalSemantics.STANDARD) { - return ifExists(getConstraintFactory().forEach(otherClass), joiners); - } else { + return switch (getRetrievalSemantics()) { + case STANDARD -> ifExists(getConstraintFactory().forEach(otherClass), joiners); + case STATIC -> ifExists(getConstraintFactory().forEachUnfiltered(otherClass), joiners); // Calls fromUnfiltered() for backward compatibility only - return ifExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); - } + case LEGACY -> ifExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); + }; } @Override default @NonNull QuadConstraintStream ifExistsIncludingUnassigned(@NonNull Class otherClass, @NonNull PentaJoiner @NonNull... joiners) { - if (getRetrievalSemantics() == RetrievalSemantics.STANDARD) { - return ifExists(getConstraintFactory().forEachIncludingUnassigned(otherClass), joiners); - } else { - return ifExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); - } + return switch (getRetrievalSemantics()) { + case STANDARD -> ifExists(getConstraintFactory().forEachIncludingUnassigned(otherClass), joiners); + case STATIC -> ifExists(getConstraintFactory().forEachUnfiltered(otherClass), joiners); + // Calls fromUnfiltered() for backward compatibility only + case LEGACY -> ifExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); + }; } @Override default @NonNull QuadConstraintStream ifNotExists(@NonNull Class otherClass, @NonNull PentaJoiner @NonNull... joiners) { - if (getRetrievalSemantics() == RetrievalSemantics.STANDARD) { - return ifNotExists(getConstraintFactory().forEach(otherClass), joiners); - } else { + return switch (getRetrievalSemantics()) { + case STANDARD -> ifNotExists(getConstraintFactory().forEach(otherClass), joiners); + case STATIC -> ifNotExists(getConstraintFactory().forEachUnfiltered(otherClass), joiners); // Calls fromUnfiltered() for backward compatibility only - return ifNotExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); - } + case LEGACY -> ifNotExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); + }; } @Override default @NonNull QuadConstraintStream ifNotExistsIncludingUnassigned(@NonNull Class otherClass, @NonNull PentaJoiner @NonNull... joiners) { - if (getRetrievalSemantics() == RetrievalSemantics.STANDARD) { - return ifNotExists(getConstraintFactory().forEachIncludingUnassigned(otherClass), joiners); - } else { - return ifNotExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); - } + return switch (getRetrievalSemantics()) { + case STANDARD -> ifNotExists(getConstraintFactory().forEachIncludingUnassigned(otherClass), joiners); + case STATIC -> ifNotExists(getConstraintFactory().forEachUnfiltered(otherClass), joiners); + // Calls fromUnfiltered() for backward compatibility only + case LEGACY -> ifNotExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); + }; } @Override diff --git a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/common/tri/InnerTriConstraintStream.java b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/common/tri/InnerTriConstraintStream.java index 6d975c6cf5..c531cdf8c7 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/common/tri/InnerTriConstraintStream.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/common/tri/InnerTriConstraintStream.java @@ -1,7 +1,5 @@ package ai.timefold.solver.core.impl.score.stream.common.tri; -import static ai.timefold.solver.core.impl.score.stream.common.RetrievalSemantics.STANDARD; - import java.math.BigDecimal; import java.util.Arrays; import java.util.Collection; @@ -46,53 +44,55 @@ static TriFunction> createDefaultIndictedObject @Override default @NonNull QuadConstraintStream join(@NonNull Class otherClass, @NonNull QuadJoiner @NonNull... joiners) { - if (getRetrievalSemantics() == STANDARD) { - return join(getConstraintFactory().forEach(otherClass), joiners); - } else { - return join(getConstraintFactory().from(otherClass), joiners); - } + return switch (getRetrievalSemantics()) { + case STANDARD -> join(getConstraintFactory().forEach(otherClass), joiners); + case STATIC -> join(getConstraintFactory().forEachUnfiltered(otherClass), joiners); + case LEGACY -> join(getConstraintFactory().from(otherClass), joiners); + }; } @Override default @NonNull TriConstraintStream ifExists(@NonNull Class otherClass, @NonNull QuadJoiner @NonNull... joiners) { - if (getRetrievalSemantics() == STANDARD) { - return ifExists(getConstraintFactory().forEach(otherClass), joiners); - } else { + return switch (getRetrievalSemantics()) { + case STANDARD -> ifExists(getConstraintFactory().forEach(otherClass), joiners); + case STATIC -> ifExists(getConstraintFactory().forEachUnfiltered(otherClass), joiners); // Calls fromUnfiltered() for backward compatibility only - return ifExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); - } + case LEGACY -> ifExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); + }; } @Override default @NonNull TriConstraintStream ifExistsIncludingUnassigned(@NonNull Class otherClass, @NonNull QuadJoiner @NonNull... joiners) { - if (getRetrievalSemantics() == STANDARD) { - return ifExists(getConstraintFactory().forEachIncludingUnassigned(otherClass), joiners); - } else { - return ifExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); - } + return switch (getRetrievalSemantics()) { + case STANDARD -> ifExists(getConstraintFactory().forEachIncludingUnassigned(otherClass), joiners); + case STATIC -> ifExists(getConstraintFactory().forEachUnfiltered(otherClass), joiners); + // Calls fromUnfiltered() for backward compatibility only + case LEGACY -> ifExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); + }; } @Override default @NonNull TriConstraintStream ifNotExists(@NonNull Class otherClass, @NonNull QuadJoiner @NonNull... joiners) { - if (getRetrievalSemantics() == STANDARD) { - return ifNotExists(getConstraintFactory().forEach(otherClass), joiners); - } else { + return switch (getRetrievalSemantics()) { + case STANDARD -> ifNotExists(getConstraintFactory().forEach(otherClass), joiners); + case STATIC -> ifNotExists(getConstraintFactory().forEachUnfiltered(otherClass), joiners); // Calls fromUnfiltered() for backward compatibility only - return ifNotExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); - } + case LEGACY -> ifNotExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); + }; } @Override default @NonNull TriConstraintStream ifNotExistsIncludingUnassigned(@NonNull Class otherClass, @NonNull QuadJoiner @NonNull... joiners) { - if (getRetrievalSemantics() == STANDARD) { - return ifNotExists(getConstraintFactory().forEachIncludingUnassigned(otherClass), joiners); - } else { - return ifNotExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); - } + return switch (getRetrievalSemantics()) { + case STANDARD -> ifNotExists(getConstraintFactory().forEachIncludingUnassigned(otherClass), joiners); + case STATIC -> ifNotExists(getConstraintFactory().forEachUnfiltered(otherClass), joiners); + // Calls fromUnfiltered() for backward compatibility only + case LEGACY -> ifNotExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); + }; } @Override diff --git a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/common/uni/InnerUniConstraintStream.java b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/common/uni/InnerUniConstraintStream.java index 65efad8fe9..12932d3685 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/common/uni/InnerUniConstraintStream.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/common/uni/InnerUniConstraintStream.java @@ -1,7 +1,5 @@ package ai.timefold.solver.core.impl.score.stream.common.uni; -import static ai.timefold.solver.core.impl.score.stream.common.RetrievalSemantics.STANDARD; - import java.math.BigDecimal; import java.util.Collection; import java.util.Collections; @@ -47,11 +45,11 @@ static Function> createDefaultIndictedObjectsMapping() { @Override default @NonNull BiConstraintStream join(@NonNull Class otherClass, @NonNull BiJoiner... joiners) { - if (getRetrievalSemantics() == STANDARD) { - return join(getConstraintFactory().forEach(otherClass), joiners); - } else { - return join(getConstraintFactory().from(otherClass), joiners); - } + return switch (getRetrievalSemantics()) { + case STANDARD -> join(getConstraintFactory().forEach(otherClass), joiners); + case STATIC -> join(getConstraintFactory().forEachUnfiltered(otherClass), joiners); + case LEGACY -> join(getConstraintFactory().from(otherClass), joiners); + }; } /** @@ -66,42 +64,44 @@ static Function> createDefaultIndictedObjectsMapping() { @Override default @NonNull UniConstraintStream ifExists(@NonNull Class otherClass, @NonNull BiJoiner... joiners) { - if (getRetrievalSemantics() == STANDARD) { - return ifExists(getConstraintFactory().forEach(otherClass), joiners); - } else { + return switch (getRetrievalSemantics()) { + case STANDARD -> ifExists(getConstraintFactory().forEach(otherClass), joiners); + case STATIC -> ifExists(getConstraintFactory().forEachUnfiltered(otherClass), joiners); // Calls fromUnfiltered() for backward compatibility only - return ifExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); - } + case LEGACY -> ifExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); + }; } @Override default @NonNull UniConstraintStream ifExistsIncludingUnassigned(@NonNull Class otherClass, @NonNull BiJoiner... joiners) { - if (getRetrievalSemantics() == STANDARD) { - return ifExists(getConstraintFactory().forEachIncludingUnassigned(otherClass), joiners); - } else { - return ifExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); - } + return switch (getRetrievalSemantics()) { + case STANDARD -> ifExists(getConstraintFactory().forEachIncludingUnassigned(otherClass), joiners); + case STATIC -> ifExists(getConstraintFactory().forEachUnfiltered(otherClass), joiners); + // Calls fromUnfiltered() for backward compatibility only + case LEGACY -> ifExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); + }; } @Override default @NonNull UniConstraintStream ifNotExists(@NonNull Class otherClass, @NonNull BiJoiner... joiners) { - if (getRetrievalSemantics() == STANDARD) { - return ifNotExists(getConstraintFactory().forEach(otherClass), joiners); - } else { + return switch (getRetrievalSemantics()) { + case STANDARD -> ifNotExists(getConstraintFactory().forEach(otherClass), joiners); + case STATIC -> ifNotExists(getConstraintFactory().forEachUnfiltered(otherClass), joiners); // Calls fromUnfiltered() for backward compatibility only - return ifNotExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); - } + case LEGACY -> ifNotExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); + }; } @Override default @NonNull UniConstraintStream ifNotExistsIncludingUnassigned(@NonNull Class otherClass, @NonNull BiJoiner... joiners) { - if (getRetrievalSemantics() == STANDARD) { - return ifNotExists(getConstraintFactory().forEachIncludingUnassigned(otherClass), joiners); - } else { - return ifNotExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); - } + return switch (getRetrievalSemantics()) { + case STANDARD -> ifNotExists(getConstraintFactory().forEachIncludingUnassigned(otherClass), joiners); + case STATIC -> ifNotExists(getConstraintFactory().forEachUnfiltered(otherClass), joiners); + // Calls fromUnfiltered() for backward compatibility only + case LEGACY -> ifNotExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); + }; } @Override diff --git a/core/src/test/java/ai/timefold/solver/core/impl/score/stream/common/bi/AbstractBiConstraintStreamTest.java b/core/src/test/java/ai/timefold/solver/core/impl/score/stream/common/bi/AbstractBiConstraintStreamTest.java index 0549b286ce..cf964a5b85 100644 --- a/core/src/test/java/ai/timefold/solver/core/impl/score/stream/common/bi/AbstractBiConstraintStreamTest.java +++ b/core/src/test/java/ai/timefold/solver/core/impl/score/stream/common/bi/AbstractBiConstraintStreamTest.java @@ -52,6 +52,7 @@ import ai.timefold.solver.core.testdomain.score.lavish.TestdataLavishValueGroup; import org.junit.jupiter.api.TestTemplate; +import org.mockito.Mockito; public abstract class AbstractBiConstraintStreamTest extends AbstractConstraintStreamTest implements ConstraintStreamFunctionalTest { @@ -3306,4 +3307,84 @@ public void joinerEqualsAndSameness() { assertMatch(entity3, entity2)); } + @TestTemplate + public void staticData_join_filter_map_entity_right() { + var solution = TestdataLavishSolution.generateSolution(); + var entityGroup = new TestdataLavishEntityGroup("MyEntityGroup"); + var valueGroup = new TestdataLavishValueGroup("MyValueGroup"); + solution.getEntityGroupList().add(entityGroup); + solution.getValueGroupList().add(valueGroup); + + var value1 = Mockito.spy(new TestdataLavishValue("MyValue 1", valueGroup)); + solution.getValueList().add(value1); + var value2 = Mockito.spy(new TestdataLavishValue("MyValue 2", valueGroup)); + solution.getValueList().add(value2); + var value3 = Mockito.spy(new TestdataLavishValue("MyValue 3", null)); + solution.getValueList().add(value3); + + var entity1 = Mockito.spy(new TestdataLavishEntity("MyEntity 1", entityGroup, value1)); + solution.getEntityList().add(entity1); + var entity2 = new TestdataLavishEntity("MyEntity 2", entityGroup, value1); + solution.getEntityList().add(entity2); + var entity3 = new TestdataLavishEntity("MyEntity 3", solution.getFirstEntityGroup(), + value1); + solution.getEntityList().add(entity3); + + var scoreDirector = + buildScoreDirector(factory -> factory.staticData(data -> data.forEachUnfiltered(TestdataLavishValue.class) + .join(TestdataLavishEntity.class) + .filter((value, entity) -> entity.getEntityGroup() == entityGroup + && value.getValueGroup() == valueGroup)) + .filter((value, entity) -> entity.getValue() == value1) + .penalize(SimpleScore.ONE) + .asConstraint(TEST_CONSTRAINT_NAME)); + + // From scratch + Mockito.reset(entity1); + scoreDirector.setWorkingSolution(solution); + assertScore(scoreDirector, + assertMatch(value1, entity1), + assertMatch(value2, entity1), + assertMatch(value1, entity2), + assertMatch(value2, entity2)); + Mockito.verify(entity1, Mockito.atLeastOnce()).getEntityGroup(); + + // Incrementally update a variable + Mockito.reset(entity1); + scoreDirector.beforeVariableChanged(entity1, "value"); + entity1.setValue(solution.getFirstValue()); + scoreDirector.afterVariableChanged(entity1, "value"); + assertScore(scoreDirector, + assertMatch(value1, entity2), + assertMatch(value2, entity2)); + Mockito.verify(entity1, Mockito.never()).getEntityGroup(); + + // Incrementally update a fact + scoreDirector.beforeProblemPropertyChanged(entity3); + entity3.setEntityGroup(entityGroup); + scoreDirector.afterProblemPropertyChanged(entity3); + assertScore(scoreDirector, + assertMatch(value1, entity2), + assertMatch(value2, entity2), + assertMatch(value1, entity3), + assertMatch(value2, entity3)); + + // Remove entity + scoreDirector.beforeEntityRemoved(entity3); + solution.getEntityList().remove(entity3); + scoreDirector.afterEntityRemoved(entity3); + assertScore(scoreDirector, + assertMatch(value1, entity2), + assertMatch(value2, entity2)); + + // Add it back again, to make sure it was properly removed before + scoreDirector.beforeEntityAdded(entity3); + solution.getEntityList().add(entity3); + scoreDirector.afterEntityAdded(entity3); + assertScore(scoreDirector, + assertMatch(value1, entity2), + assertMatch(value2, entity2), + assertMatch(value1, entity3), + assertMatch(value2, entity3)); + } } diff --git a/core/src/test/java/ai/timefold/solver/core/impl/score/stream/common/uni/AbstractUniConstraintStreamTest.java b/core/src/test/java/ai/timefold/solver/core/impl/score/stream/common/uni/AbstractUniConstraintStreamTest.java index f525f0af14..85a376e4a6 100644 --- a/core/src/test/java/ai/timefold/solver/core/impl/score/stream/common/uni/AbstractUniConstraintStreamTest.java +++ b/core/src/test/java/ai/timefold/solver/core/impl/score/stream/common/uni/AbstractUniConstraintStreamTest.java @@ -72,6 +72,7 @@ import ai.timefold.solver.core.testdomain.unassignedvar.TestdataAllowsUnassignedSolution; import org.junit.jupiter.api.TestTemplate; +import org.mockito.Mockito; public abstract class AbstractUniConstraintStreamTest extends AbstractConstraintStreamTest @@ -3814,6 +3815,238 @@ public void fromIncludesNullWhenAllowsUnassigned() { assertMatch(entityWithValue)); } + @TestTemplate + public void staticData_filter_entity() { + var solution = TestdataLavishSolution.generateSolution(); + var entityGroup = new TestdataLavishEntityGroup("MyEntityGroup"); + solution.getEntityGroupList().add(entityGroup); + var entity1 = Mockito.spy(new TestdataLavishEntity("MyEntity 1", entityGroup, solution.getFirstValue())); + solution.getEntityList().add(entity1); + var entity2 = new TestdataLavishEntity("MyEntity 2", entityGroup, solution.getFirstValue()); + solution.getEntityList().add(entity2); + var entity3 = new TestdataLavishEntity("MyEntity 3", solution.getFirstEntityGroup(), + solution.getFirstValue()); + solution.getEntityList().add(entity3); + + var scoreDirector = + buildScoreDirector(factory -> factory.staticData(data -> data.forEachUnfiltered(TestdataLavishEntity.class) + .filter(entity -> entity.getEntityGroup() == entityGroup)) + .penalize(SimpleScore.ONE) + .asConstraint(TEST_CONSTRAINT_NAME)); + + // From scratch + Mockito.reset(entity1); + scoreDirector.setWorkingSolution(solution); + assertScore(scoreDirector, + assertMatch(entity1), + assertMatch(entity2)); + Mockito.verify(entity1, Mockito.atLeastOnce()).getEntityGroup(); + + // Incrementally update a variable + Mockito.reset(entity1); + scoreDirector.beforeVariableChanged(entity1, "value"); + entity1.setValue(new TestdataLavishValue()); + scoreDirector.afterVariableChanged(entity1, "value"); + assertScore(scoreDirector, + assertMatch(entity1), + assertMatch(entity2)); + Mockito.verify(entity1, Mockito.never()).getEntityGroup(); + + // Incrementally update a fact + scoreDirector.beforeProblemPropertyChanged(entity3); + entity3.setEntityGroup(entityGroup); + scoreDirector.afterProblemPropertyChanged(entity3); + assertScore(scoreDirector, + assertMatch(entity1), + assertMatch(entity2), + assertMatch(entity3)); + + // Remove entity + scoreDirector.beforeEntityRemoved(entity3); + solution.getEntityList().remove(entity3); + scoreDirector.afterEntityRemoved(entity3); + assertScore(scoreDirector, + assertMatch(entity1), + assertMatch(entity2)); + + // Add it back again, to make sure it was properly removed before + scoreDirector.beforeEntityAdded(entity3); + solution.getEntityList().add(entity3); + scoreDirector.afterEntityAdded(entity3); + assertScore(scoreDirector, + assertMatch(entity1), + assertMatch(entity2), + assertMatch(entity3)); + } + + @TestTemplate + public void staticData_join_filter_map_entity_left() { + var solution = TestdataLavishSolution.generateSolution(); + var entityGroup = new TestdataLavishEntityGroup("MyEntityGroup"); + var valueGroup = new TestdataLavishValueGroup("MyValueGroup"); + solution.getEntityGroupList().add(entityGroup); + solution.getValueGroupList().add(valueGroup); + + var value1 = Mockito.spy(new TestdataLavishValue("MyValue 1", valueGroup)); + solution.getValueList().add(value1); + var value2 = Mockito.spy(new TestdataLavishValue("MyValue 2", valueGroup)); + solution.getValueList().add(value2); + var value3 = Mockito.spy(new TestdataLavishValue("MyValue 3", null)); + solution.getValueList().add(value3); + + var entity1 = Mockito.spy(new TestdataLavishEntity("MyEntity 1", entityGroup, value1)); + solution.getEntityList().add(entity1); + var entity2 = new TestdataLavishEntity("MyEntity 2", entityGroup, value1); + solution.getEntityList().add(entity2); + var entity3 = new TestdataLavishEntity("MyEntity 3", solution.getFirstEntityGroup(), + value1); + solution.getEntityList().add(entity3); + + record EntityValuePair(TestdataLavishEntity entity, TestdataLavishValue value) { + } + var scoreDirector = + buildScoreDirector(factory -> factory.staticData(data -> data.forEachUnfiltered(TestdataLavishEntity.class) + .join(TestdataLavishValue.class) + .filter((entity, value) -> entity.getEntityGroup() == entityGroup + && value.getValueGroup() == valueGroup) + .map(EntityValuePair::new)) + .filter(pair -> pair.entity.getValue() == value1) + .penalize(SimpleScore.ONE) + .asConstraint(TEST_CONSTRAINT_NAME)); + + // From scratch + Mockito.reset(entity1); + scoreDirector.setWorkingSolution(solution); + assertScore(scoreDirector, + assertMatch(new EntityValuePair(entity1, value1)), + assertMatch(new EntityValuePair(entity1, value2)), + assertMatch(new EntityValuePair(entity2, value1)), + assertMatch(new EntityValuePair(entity2, value2))); + Mockito.verify(entity1, Mockito.atLeastOnce()).getEntityGroup(); + + // Incrementally update a variable + Mockito.reset(entity1); + scoreDirector.beforeVariableChanged(entity1, "value"); + entity1.setValue(solution.getFirstValue()); + scoreDirector.afterVariableChanged(entity1, "value"); + assertScore(scoreDirector, + assertMatch(new EntityValuePair(entity2, value1)), + assertMatch(new EntityValuePair(entity2, value2))); + Mockito.verify(entity1, Mockito.never()).getEntityGroup(); + + // Incrementally update a fact + scoreDirector.beforeProblemPropertyChanged(entity3); + entity3.setEntityGroup(entityGroup); + scoreDirector.afterProblemPropertyChanged(entity3); + assertScore(scoreDirector, + assertMatch(new EntityValuePair(entity2, value1)), + assertMatch(new EntityValuePair(entity2, value2)), + assertMatch(new EntityValuePair(entity3, value1)), + assertMatch(new EntityValuePair(entity3, value2))); + + // Remove entity + scoreDirector.beforeEntityRemoved(entity3); + solution.getEntityList().remove(entity3); + scoreDirector.afterEntityRemoved(entity3); + assertScore(scoreDirector, + assertMatch(new EntityValuePair(entity2, value1)), + assertMatch(new EntityValuePair(entity2, value2))); + + // Add it back again, to make sure it was properly removed before + scoreDirector.beforeEntityAdded(entity3); + solution.getEntityList().add(entity3); + scoreDirector.afterEntityAdded(entity3); + assertScore(scoreDirector, + assertMatch(new EntityValuePair(entity2, value1)), + assertMatch(new EntityValuePair(entity2, value2)), + assertMatch(new EntityValuePair(entity3, value1)), + assertMatch(new EntityValuePair(entity3, value2))); + } + + @TestTemplate + public void staticData_join_filter_map_entity_right() { + var solution = TestdataLavishSolution.generateSolution(); + var entityGroup = new TestdataLavishEntityGroup("MyEntityGroup"); + var valueGroup = new TestdataLavishValueGroup("MyValueGroup"); + solution.getEntityGroupList().add(entityGroup); + solution.getValueGroupList().add(valueGroup); + + var value1 = Mockito.spy(new TestdataLavishValue("MyValue 1", valueGroup)); + solution.getValueList().add(value1); + var value2 = Mockito.spy(new TestdataLavishValue("MyValue 2", valueGroup)); + solution.getValueList().add(value2); + var value3 = Mockito.spy(new TestdataLavishValue("MyValue 3", null)); + solution.getValueList().add(value3); + + var entity1 = Mockito.spy(new TestdataLavishEntity("MyEntity 1", entityGroup, value1)); + solution.getEntityList().add(entity1); + var entity2 = new TestdataLavishEntity("MyEntity 2", entityGroup, value1); + solution.getEntityList().add(entity2); + var entity3 = new TestdataLavishEntity("MyEntity 3", solution.getFirstEntityGroup(), + value1); + solution.getEntityList().add(entity3); + + record EntityValuePair(TestdataLavishValue value, TestdataLavishEntity entity) { + } + var scoreDirector = + buildScoreDirector(factory -> factory.staticData(data -> data.forEachUnfiltered(TestdataLavishValue.class) + .join(TestdataLavishEntity.class) + .filter((value, entity) -> entity.getEntityGroup() == entityGroup + && value.getValueGroup() == valueGroup) + .map(EntityValuePair::new)) + .filter(pair -> pair.entity.getValue() == value1) + .penalize(SimpleScore.ONE) + .asConstraint(TEST_CONSTRAINT_NAME)); + + // From scratch + Mockito.reset(entity1); + scoreDirector.setWorkingSolution(solution); + assertScore(scoreDirector, + assertMatch(new EntityValuePair(value1, entity1)), + assertMatch(new EntityValuePair(value2, entity1)), + assertMatch(new EntityValuePair(value1, entity2)), + assertMatch(new EntityValuePair(value2, entity2))); + Mockito.verify(entity1, Mockito.atLeastOnce()).getEntityGroup(); + + // Incrementally update a variable + Mockito.reset(entity1); + scoreDirector.beforeVariableChanged(entity1, "value"); + entity1.setValue(solution.getFirstValue()); + scoreDirector.afterVariableChanged(entity1, "value"); + assertScore(scoreDirector, + assertMatch(new EntityValuePair(value1, entity2)), + assertMatch(new EntityValuePair(value2, entity2))); + Mockito.verify(entity1, Mockito.never()).getEntityGroup(); + + // Incrementally update a fact + scoreDirector.beforeProblemPropertyChanged(entity3); + entity3.setEntityGroup(entityGroup); + scoreDirector.afterProblemPropertyChanged(entity3); + assertScore(scoreDirector, + assertMatch(new EntityValuePair(value1, entity2)), + assertMatch(new EntityValuePair(value2, entity2)), + assertMatch(new EntityValuePair(value1, entity3)), + assertMatch(new EntityValuePair(value2, entity3))); + + // Remove entity + scoreDirector.beforeEntityRemoved(entity3); + solution.getEntityList().remove(entity3); + scoreDirector.afterEntityRemoved(entity3); + assertScore(scoreDirector, + assertMatch(new EntityValuePair(value1, entity2)), + assertMatch(new EntityValuePair(value2, entity2))); + + // Add it back again, to make sure it was properly removed before + scoreDirector.beforeEntityAdded(entity3); + solution.getEntityList().add(entity3); + scoreDirector.afterEntityAdded(entity3); + assertScore(scoreDirector, + assertMatch(new EntityValuePair(value1, entity2)), + assertMatch(new EntityValuePair(value2, entity2)), + assertMatch(new EntityValuePair(value1, entity3)), + assertMatch(new EntityValuePair(value2, entity3))); + } + @TestTemplate public void constraintProvidedFromUnknownPackage() throws ClassNotFoundException, NoSuchMethodException, InvocationTargetException, IllegalAccessException { diff --git a/core/src/test/java/ai/timefold/solver/core/testconstraint/TestConstraintFactory.java b/core/src/test/java/ai/timefold/solver/core/testconstraint/TestConstraintFactory.java index 8e269e8122..d54825bc73 100644 --- a/core/src/test/java/ai/timefold/solver/core/testconstraint/TestConstraintFactory.java +++ b/core/src/test/java/ai/timefold/solver/core/testconstraint/TestConstraintFactory.java @@ -3,6 +3,8 @@ import java.util.Objects; import ai.timefold.solver.core.api.score.Score; +import ai.timefold.solver.core.api.score.stream.ConstraintStream; +import ai.timefold.solver.core.api.score.stream.StaticDataSupplier; import ai.timefold.solver.core.api.score.stream.uni.UniConstraintStream; import ai.timefold.solver.core.impl.domain.solution.descriptor.SolutionDescriptor; import ai.timefold.solver.core.impl.score.stream.common.InnerConstraintFactory; @@ -43,6 +45,11 @@ public SolutionDescriptor getSolutionDescriptor() { throw new UnsupportedOperationException(); } + @Override + public @NonNull Stream_ staticData(StaticDataSupplier stream) { + throw new UnsupportedOperationException(); + } + @Override public @NonNull UniConstraintStream from(@NonNull Class fromClass) { throw new UnsupportedOperationException();