Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,18 @@ public interface ConstraintFactory {
*/
<A> @NonNull BiConstraintStream<A, A> forEachUniquePair(@NonNull Class<A> sourceClass, @NonNull BiJoiner<A, A>... joiners);

// ************************************************************************
// staticData
//************************************************************************

/**
* Computes and caches the tuples that would be produced by the given stream.
* As this is cached, it is vital the stream does not reference any variables
* (genuine or otherwise).
*/
<Stream_ extends ConstraintStream> @NonNull Stream_
staticData(StaticDataSupplier<@NonNull Stream_> staticDataSupplier);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: naming


// ************************************************************************
// from* (deprecated)
// ************************************************************************
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package ai.timefold.solver.core.api.score.stream;

import ai.timefold.solver.core.api.score.stream.uni.UniConstraintStream;

public interface StaticDataFactory {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: naming

<A> UniConstraintStream<A> forEachUnfiltered(Class<A> sourceClass);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package ai.timefold.solver.core.api.score.stream;

import org.jspecify.annotations.NullMarked;

@NullMarked
public interface StaticDataSupplier<Stream_ extends ConstraintStream> {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: naming

Stream_ get(StaticDataFactory dataFactory);
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,63 +3,72 @@
import java.util.IdentityHashMap;
import java.util.Map;

import ai.timefold.solver.core.impl.bavet.uni.AbstractForEachUniNode;
import ai.timefold.solver.core.impl.bavet.uni.AbstractForEachUniNode.LifecycleOperation;
import ai.timefold.solver.core.impl.bavet.common.BavetRootNode;
import ai.timefold.solver.core.impl.bavet.common.BavetRootNode.LifecycleOperation;

public abstract class AbstractSession {

private final NodeNetwork nodeNetwork;
private final Map<Class<?>, AbstractForEachUniNode<Object>[]> insertEffectiveClassToNodeArrayMap;
private final Map<Class<?>, AbstractForEachUniNode<Object>[]> updateEffectiveClassToNodeArrayMap;
private final Map<Class<?>, AbstractForEachUniNode<Object>[]> retractEffectiveClassToNodeArrayMap;
private final Map<Class<?>, BavetRootNode<Object>[]> insertEffectiveClassToNodeArrayMap;
private final Map<Class<?>, BavetRootNode<Object>[]> updateEffectiveClassToNodeArrayMap;
private final Map<Class<?>, BavetRootNode<Object>[]> retractEffectiveClassToNodeArrayMap;
private final BavetRootNode<Object>[] settleNodes;

protected AbstractSession(NodeNetwork nodeNetwork) {
this.nodeNetwork = nodeNetwork;
this.insertEffectiveClassToNodeArrayMap = new IdentityHashMap<>(nodeNetwork.forEachNodeCount());
this.updateEffectiveClassToNodeArrayMap = new IdentityHashMap<>(nodeNetwork.forEachNodeCount());
this.retractEffectiveClassToNodeArrayMap = new IdentityHashMap<>(nodeNetwork.forEachNodeCount());
this.settleNodes = nodeNetwork.getAllTupleSourceRootNodes()
.filter(node -> node.supports(LifecycleOperation.SETTLE))
.toArray(BavetRootNode[]::new);
}

public final void insert(Object fact) {
var factClass = fact.getClass();
for (var node : findNodes(factClass, LifecycleOperation.INSERT)) {
for (var node : findNodes(factClass, BavetRootNode.LifecycleOperation.INSERT)) {
node.insert(fact);
}
}

@SuppressWarnings("unchecked")
private AbstractForEachUniNode<Object>[] findNodes(Class<?> factClass, LifecycleOperation lifecycleOperation) {
private BavetRootNode<Object>[] findNodes(Class<?> factClass, LifecycleOperation lifecycleOperation) {
var effectiveClassToNodeArrayMap = switch (lifecycleOperation) {
case INSERT -> insertEffectiveClassToNodeArrayMap;
case UPDATE -> updateEffectiveClassToNodeArrayMap;
case RETRACT -> retractEffectiveClassToNodeArrayMap;
case SETTLE ->
throw new IllegalArgumentException("impossible state: findNodes should not be called for settle nodes");
};
// Map.computeIfAbsent() would have created lambdas on the hot path, this will not.
var nodeArray = effectiveClassToNodeArrayMap.get(factClass);
if (nodeArray == null) {
nodeArray = nodeNetwork.getForEachNodes(factClass)
nodeArray = nodeNetwork.getTupleSourceRootNodes(factClass)
.filter(node -> node.supports(lifecycleOperation))
.toArray(AbstractForEachUniNode[]::new);
.toArray(BavetRootNode[]::new);
effectiveClassToNodeArrayMap.put(factClass, nodeArray);
}
return nodeArray;
}

public final void update(Object fact) {
var factClass = fact.getClass();
for (var node : findNodes(factClass, LifecycleOperation.UPDATE)) {
for (var node : findNodes(factClass, BavetRootNode.LifecycleOperation.UPDATE)) {
node.update(fact);
}
}

public final void retract(Object fact) {
var factClass = fact.getClass();
for (var node : findNodes(factClass, LifecycleOperation.RETRACT)) {
for (var node : findNodes(factClass, BavetRootNode.LifecycleOperation.RETRACT)) {
node.retract(fact);
}
}

public void settle() {
for (var node : settleNodes) {
node.settle();
}
nodeNetwork.settle();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import java.util.stream.Stream;

import ai.timefold.solver.core.api.domain.solution.PlanningSolution;
import ai.timefold.solver.core.impl.bavet.common.BavetRootNode;
import ai.timefold.solver.core.impl.bavet.common.Propagator;
import ai.timefold.solver.core.impl.bavet.uni.AbstractForEachUniNode;

/**
* Represents Bavet's network of nodes, specific to a particular session.
Expand All @@ -19,7 +19,7 @@
* @param layeredNodes nodes grouped first by their layer, then by their index within the layer;
* propagation needs to happen in this order.
*/
public record NodeNetwork(Map<Class<?>, List<AbstractForEachUniNode<?>>> declaredClassToNodeMap,
public record NodeNetwork(Map<Class<?>, List<BavetRootNode<?>>> declaredClassToNodeMap,
Propagator[][] layeredNodes) {

public static final NodeNetwork EMPTY = new NodeNetwork(Map.of(), new Propagator[0][0]);
Expand All @@ -32,14 +32,21 @@ public int layerCount() {
return layeredNodes.length;
}

public Stream<AbstractForEachUniNode<?>> getForEachNodes(Class<?> factClass) {
public Stream<BavetRootNode<?>> getAllTupleSourceRootNodes() {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the type changed, the method name should change as well. In this case, probably getRootNodes() would suffice.

// The node needs to match the fact, or the node needs to be applicable to the entire solution.
// The latter is for FromSolution nodes.
return declaredClassToNodeMap.entrySet()
.stream()
.filter(entry -> factClass == PlanningSolution.class || entry.getKey().isAssignableFrom(factClass))
.map(Map.Entry::getValue)
.flatMap(List::stream);
.flatMap(entry -> entry.getValue().stream());
}

public Stream<BavetRootNode<?>> getTupleSourceRootNodes(Class<?> factClass) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dtto.

// The node needs to match the fact, or the node needs to be applicable to the entire solution.
// The latter is for FromSolution nodes.
return declaredClassToNodeMap.entrySet()
.stream()
.flatMap(entry -> entry.getValue().stream())
.filter(tupleSourceRoot -> factClass == PlanningSolution.class || tupleSourceRoot.allowsInstancesOf(factClass));
}

public void settle() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package ai.timefold.solver.core.impl.bavet.bi;

import ai.timefold.solver.core.impl.bavet.NodeNetwork;
import ai.timefold.solver.core.impl.bavet.common.AbstractStaticDataNode;
import ai.timefold.solver.core.impl.bavet.common.tuple.BiTuple;
import ai.timefold.solver.core.impl.bavet.common.tuple.RecordingTupleLifecycle;
import ai.timefold.solver.core.impl.bavet.common.tuple.TupleLifecycle;

import org.jspecify.annotations.NullMarked;

@NullMarked
public final class StaticDataBiNode<A, B> extends AbstractStaticDataNode<BiTuple<A, B>> {
private final int outputStoreSize;

public StaticDataBiNode(NodeNetwork nodeNetwork,
RecordingTupleLifecycle<BiTuple<A, B>> recordingTupleNode,
int outputStoreSize,
TupleLifecycle<BiTuple<A, B>> nextNodesTupleLifecycle,
Class<?>[] sourceClasses) {
super(nodeNetwork, recordingTupleNode, nextNodesTupleLifecycle, sourceClasses);
this.outputStoreSize = outputStoreSize;
}

@Override
protected BiTuple<A, B> remapTuple(BiTuple<A, B> tuple) {
return new BiTuple<>(tuple.factA, tuple.factB, outputStoreSize);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import ai.timefold.solver.core.impl.bavet.common.tuple.LeftTupleLifecycle;
import ai.timefold.solver.core.impl.bavet.common.tuple.RightTupleLifecycle;
import ai.timefold.solver.core.impl.bavet.common.tuple.TupleLifecycle;
import ai.timefold.solver.core.impl.bavet.uni.AbstractForEachUniNode;

public abstract class AbstractNodeBuildHelper<Stream_ extends BavetStream> {

Expand Down Expand Up @@ -47,7 +46,7 @@ public void addNode(AbstractNode node, Stream_ creator) {
public void addNode(AbstractNode node, Stream_ creator, Stream_ parent) {
reversedNodeList.add(node);
nodeCreatorMap.put(node, creator);
if (!(node instanceof AbstractForEachUniNode<?>)) {
if (!(node instanceof BavetRootNode<?>)) {
if (parent == null) {
throw new IllegalStateException("Impossible state: The node (%s) has no parent (%s)."
.formatted(node, parent));
Expand Down Expand Up @@ -148,7 +147,7 @@ public AbstractNode findParentNode(Stream_ childNodeCreator) {
}

public static NodeNetwork buildNodeNetwork(List<AbstractNode> nodeList,
Map<Class<?>, List<AbstractForEachUniNode<?>>> declaredClassToNodeMap) {
Map<Class<?>, List<BavetRootNode<?>>> declaredClassToNodeMap) {
var layerMap = new TreeMap<Long, List<Propagator>>();
for (var node : nodeList) {
layerMap.computeIfAbsent(node.getLayerIndex(), k -> new ArrayList<>())
Expand Down Expand Up @@ -206,7 +205,7 @@ public <BuildHelper_ extends AbstractNodeBuildHelper<Stream_>> List<AbstractNode
@SuppressWarnings("unchecked")
private static <Stream_ extends BavetStream> long determineLayerIndex(AbstractNode node,
AbstractNodeBuildHelper<Stream_> buildHelper) {
if (node instanceof AbstractForEachUniNode<?>) { // ForEach nodes, and only they, are in layer 0.
if (node instanceof BavetRootNode<?>) { // TupleSourceRoot nodes, and only they, are in layer 0.
return 0;
} else if (node instanceof AbstractTwoInputNode<?, ?> joinNode) {
var nodeCreator = (BavetStreamBinaryOperation<?>) buildHelper.getNodeCreatingStream(joinNode);
Expand Down
Loading
Loading