Skip to content

Commit 63b9ecd

Browse files
committed
[CP-SAT] tweak and improve code
1 parent 8ff5dbe commit 63b9ecd

File tree

5 files changed

+59
-57
lines changed

5 files changed

+59
-57
lines changed

ortools/sat/integer.cc

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -983,7 +983,8 @@ int IntegerTrail::FindTrailIndexOfVarBefore(IntegerVariable var,
983983
int IntegerTrail::FindLowestTrailIndexThatExplainBound(
984984
IntegerLiteral i_lit) const {
985985
DCHECK_LE(i_lit.bound, var_lbs_[i_lit.var]);
986-
if (i_lit.bound <= LevelZeroLowerBound(i_lit.var)) return -1;
986+
DCHECK(!IsTrueAtLevelZero(i_lit));
987+
987988
int trail_index = var_trail_index_[i_lit.var];
988989

989990
// Check the validity of the cached index and use it if possible. This caching
@@ -1003,6 +1004,7 @@ int IntegerTrail::FindLowestTrailIndexThatExplainBound(
10031004

10041005
int prev_trail_index = trail_index;
10051006
while (true) {
1007+
++work_done_in_explain_lower_than_;
10061008
if (trail_index >= var_trail_index_cache_threshold_) {
10071009
var_trail_index_cache_[i_lit.var] = trail_index;
10081010
}
@@ -1171,10 +1173,9 @@ std::vector<Literal>* IntegerTrail::InitializeConflict(
11711173
lazy_reasons_.back().Explain(conflict, &tmp_queue_);
11721174
} else {
11731175
conflict->assign(literals_reason.begin(), literals_reason.end());
1174-
const int num_vars = var_lbs_.size();
11751176
for (const IntegerLiteral& literal : bounds_reason) {
1176-
const int trail_index = FindLowestTrailIndexThatExplainBound(literal);
1177-
if (trail_index >= num_vars) tmp_queue_.push_back(trail_index);
1177+
if (IsTrueAtLevelZero(literal)) continue;
1178+
tmp_queue_.push_back(FindLowestTrailIndexThatExplainBound(literal));
11781179
}
11791180
}
11801181
return conflict;
@@ -1553,9 +1554,8 @@ bool IntegerTrail::EnqueueInternal(
15531554
// efficiency and a potential smaller reason.
15541555
auto* conflict = InitializeConflict(i_lit, use_lazy_reason, literal_reason,
15551556
integer_reason);
1556-
{
1557-
const int trail_index = FindLowestTrailIndexThatExplainBound(ub_reason);
1558-
if (trail_index >= 0) tmp_queue_.push_back(trail_index);
1557+
if (!IsTrueAtLevelZero(ub_reason)) {
1558+
tmp_queue_.push_back(FindLowestTrailIndexThatExplainBound(ub_reason));
15591559
}
15601560
MergeReasonIntoInternal(conflict, NextConflictId());
15611561
return false;
@@ -1771,12 +1771,10 @@ absl::Span<const int> IntegerTrail::Dependencies(int reason_index) const {
17711771

17721772
int new_size = 0;
17731773
int* data = trail_index_reason_buffer_.data() + start;
1774-
const int num_vars = var_lbs_.size();
17751774
for (int i = start; i < end; ++i) {
1776-
const int dep =
1777-
FindLowestTrailIndexThatExplainBound(bounds_reason_buffer_[i]);
1778-
if (dep >= num_vars) {
1779-
data[new_size++] = dep;
1775+
const IntegerLiteral to_explain = bounds_reason_buffer_[i];
1776+
if (!IsTrueAtLevelZero(to_explain)) {
1777+
data[new_size++] = FindLowestTrailIndexThatExplainBound(to_explain);
17801778
}
17811779
}
17821780
cached_sizes_[reason_index] = new_size;
@@ -1818,14 +1816,10 @@ std::vector<Literal> IntegerTrail::ReasonFor(IntegerLiteral literal) const {
18181816
void IntegerTrail::MergeReasonInto(absl::Span<const IntegerLiteral> literals,
18191817
std::vector<Literal>* output) const {
18201818
DCHECK(tmp_queue_.empty());
1821-
const int num_vars = var_lbs_.size();
18221819
for (const IntegerLiteral& literal : literals) {
18231820
if (literal.IsAlwaysTrue()) continue;
1824-
const int trail_index = FindLowestTrailIndexThatExplainBound(literal);
1825-
1826-
// Any indices lower than that means that there is no reason needed.
1827-
// Note that it is important for size to be signed because of -1 indices.
1828-
if (trail_index >= num_vars) tmp_queue_.push_back(trail_index);
1821+
if (IsTrueAtLevelZero(literal)) continue;
1822+
tmp_queue_.push_back(FindLowestTrailIndexThatExplainBound(literal));
18291823
}
18301824
return MergeReasonIntoInternal(output, -1);
18311825
}

ortools/sat/integer.h

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -523,6 +523,7 @@ class IntegerTrail final : public SatPropagator {
523523
// Returns the current value (if known) of an IntegerLiteral.
524524
bool IntegerLiteralIsTrue(IntegerLiteral l) const;
525525
bool IntegerLiteralIsFalse(IntegerLiteral l) const;
526+
bool IsTrueAtLevelZero(IntegerLiteral l) const;
526527

527528
// Returns globally valid lower/upper bound on the given integer variable.
528529
IntegerValue LevelZeroLowerBound(IntegerVariable var) const;
@@ -796,39 +797,38 @@ class IntegerTrail final : public SatPropagator {
796797
void AddAllGreaterThanConstantReason(absl::Span<AffineExpression> exprs,
797798
IntegerValue target_min,
798799
std::vector<int>* indices) const {
799-
int64_t num_processed = 0;
800+
constexpr int64_t check_period = 1e6;
801+
int64_t limit_check = work_done_in_explain_lower_than_ + check_period;
800802
for (const AffineExpression& expr : exprs) {
801803
if (expr.IsConstant()) {
802804
DCHECK_GE(expr.constant, target_min);
803805
continue;
804806
}
805807
DCHECK_NE(expr.var, kNoIntegerVariable);
808+
const IntegerLiteral to_explain = expr.GreaterOrEqual(target_min);
809+
if (IsTrueAtLevelZero(to_explain)) continue;
806810

807811
// On large routing problems, we can spend a lot of time in this loop.
808-
// We check the time limit every 5 processed expressions.
809-
if (++num_processed % 5 == 0 && time_limit_->LimitReached()) return;
812+
if (work_done_in_explain_lower_than_ > limit_check) {
813+
limit_check = work_done_in_explain_lower_than_ + check_period;
814+
if (time_limit_->LimitReached()) return;
815+
}
810816

811817
// Skip if we already have an explanation for expr >= target_min. Note
812818
// that we already do that while processing the returned indices, so this
813819
// mainly save a FindLowestTrailIndexThatExplainBound() call per skipped
814820
// indices, which can still be costly.
815821
{
816-
const int index = tmp_var_to_trail_index_in_queue_[expr.var];
822+
const int index = tmp_var_to_trail_index_in_queue_[to_explain.var];
817823
if (index == std::numeric_limits<int>::max()) continue;
818-
if (index > 0 &&
819-
expr.ValueAt(integer_trail_[index].bound) >= target_min) {
824+
if (index > 0 && integer_trail_[index].bound >= to_explain.bound) {
820825
has_dependency_ = true;
821826
continue;
822827
}
823828
}
824829

825830
// We need to find the index that explain the bound.
826-
// Note that this will skip if the condition is true at level zero.
827-
const int index =
828-
FindLowestTrailIndexThatExplainBound(expr.GreaterOrEqual(target_min));
829-
if (index >= 0) {
830-
indices->push_back(index);
831-
}
831+
indices->push_back(FindLowestTrailIndexThatExplainBound(to_explain));
832832
}
833833
}
834834

@@ -885,8 +885,8 @@ class IntegerTrail final : public SatPropagator {
885885
int64_t conflict_id) const;
886886

887887
// Returns the lowest trail index of a TrailEntry that can be used to explain
888-
// the given IntegerLiteral. The literal must be currently true (CHECKed).
889-
// Returns -1 if the explanation is trivial.
888+
// the given IntegerLiteral. The literal must be currently true but not true
889+
// at level zero (DCHECKed).
890890
int FindLowestTrailIndexThatExplainBound(IntegerLiteral i_lit) const;
891891

892892
// This must be called before Dependencies() or AppendLiteralsReason().
@@ -1033,6 +1033,8 @@ class IntegerTrail final : public SatPropagator {
10331033
std::vector<SparseBitset<IntegerVariable>*> watchers_;
10341034
std::vector<ReversibleInterface*> reversible_classes_;
10351035

1036+
mutable int64_t work_done_in_explain_lower_than_ = 0;
1037+
10361038
mutable Domain temp_domain_;
10371039
DelayedRootLevelDeduction* delayed_to_fix_;
10381040
IntegerDomains* domains_;
@@ -1417,6 +1419,10 @@ inline bool IntegerTrail::IntegerLiteralIsFalse(IntegerLiteral l) const {
14171419
return l.bound > UpperBound(l.var);
14181420
}
14191421

1422+
inline bool IntegerTrail::IsTrueAtLevelZero(IntegerLiteral l) const {
1423+
return l.bound <= LevelZeroLowerBound(l.var);
1424+
}
1425+
14201426
// The level zero bounds are stored at the beginning of the trail and they also
14211427
// serves as sentinels. Their index match the variables index.
14221428
inline IntegerValue IntegerTrail::LevelZeroLowerBound(

ortools/sat/integer_base.cc

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -214,26 +214,6 @@ IntegerValue BestBinaryRelationBounds::GetUpperBound(
214214
return kMaxIntegerValue;
215215
}
216216

217-
// TODO(user): Maybe introduce a CanonicalizedLinear2 class so we automatically
218-
// get the better function, and it documents when we have canonicalized
219-
// expression.
220-
IntegerValue BestBinaryRelationBounds::UpperBoundWhenCanonicalized(
221-
LinearExpression2 expr) const {
222-
DCHECK_EQ(expr.DivideByGcd(), 1);
223-
DCHECK(expr.IsCanonicalized());
224-
const bool negated = expr.NegateForCanonicalization();
225-
const auto it = best_bounds_.find(expr);
226-
if (it != best_bounds_.end()) {
227-
const auto [known_lb, known_ub] = it->second;
228-
if (negated) {
229-
return -known_lb;
230-
} else {
231-
return known_ub;
232-
}
233-
}
234-
return kMaxIntegerValue;
235-
}
236-
237217
std::vector<std::pair<LinearExpression2, IntegerValue>>
238218
BestBinaryRelationBounds::GetSortedNonTrivialUpperBounds() const {
239219
std::vector<std::pair<LinearExpression2, IntegerValue>> root_relations_sorted;

ortools/sat/integer_base.h

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -559,6 +559,28 @@ std::ostream& operator<<(std::ostream& os, const ValueLiteralPair& p);
559559
DEFINE_STRONG_INDEX_TYPE(IntervalVariable);
560560
const IntervalVariable kNoIntervalVariable(-1);
561561

562+
// This functions appears in hot spot, and so it is important to inline it.
563+
//
564+
// TODO(user): Maybe introduce a CanonicalizedLinear2 class so we automatically
565+
// get the better function, and it documents when we have canonicalized
566+
// expression.
567+
inline IntegerValue BestBinaryRelationBounds::UpperBoundWhenCanonicalized(
568+
LinearExpression2 expr) const {
569+
DCHECK_EQ(expr.DivideByGcd(), 1);
570+
DCHECK(expr.IsCanonicalized());
571+
const bool negated = expr.NegateForCanonicalization();
572+
const auto it = best_bounds_.find(expr);
573+
if (it != best_bounds_.end()) {
574+
const auto [known_lb, known_ub] = it->second;
575+
if (negated) {
576+
return -known_lb;
577+
} else {
578+
return known_ub;
579+
}
580+
}
581+
return kMaxIntegerValue;
582+
}
583+
562584
// ============================================================================
563585
// Implementation.
564586
// ============================================================================
@@ -599,8 +621,8 @@ inline IntegerLiteral AffineExpression::GreaterOrEqual(
599621
: IntegerLiteral::FalseLiteral();
600622
}
601623
DCHECK_GT(coeff, 0);
602-
return IntegerLiteral::GreaterOrEqual(var,
603-
CeilRatio(bound - constant, coeff));
624+
return IntegerLiteral::GreaterOrEqual(
625+
var, coeff == 1 ? bound - constant : CeilRatio(bound - constant, coeff));
604626
}
605627

606628
// var * coeff + constant <= bound.
@@ -610,7 +632,8 @@ inline IntegerLiteral AffineExpression::LowerOrEqual(IntegerValue bound) const {
610632
: IntegerLiteral::FalseLiteral();
611633
}
612634
DCHECK_GT(coeff, 0);
613-
return IntegerLiteral::LowerOrEqual(var, FloorRatio(bound - constant, coeff));
635+
return IntegerLiteral::LowerOrEqual(
636+
var, coeff == 1 ? bound - constant : FloorRatio(bound - constant, coeff));
614637
}
615638

616639
} // namespace sat

ortools/sat/precedences.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1943,8 +1943,7 @@ IntegerValue Linear2Bounds::NonTrivialUpperBoundForGcd1(
19431943
}
19441944
DCHECK_NE(expr.coeffs[1], 0);
19451945
DCHECK_EQ(1, expr.DivideByGcd());
1946-
IntegerValue ub = kMaxIntegerValue;
1947-
ub = std::min(ub, root_level_bounds_->GetUpperBoundNoTrail(expr));
1946+
IntegerValue ub = root_level_bounds_->GetUpperBoundNoTrail(expr);
19481947
ub = std::min(ub, enforced_bounds_->GetUpperBoundFromEnforced(expr));
19491948
ub = std::min(ub, linear3_bounds_->GetUpperBoundFromLinear3(expr));
19501949
return ub;

0 commit comments

Comments
 (0)