Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <memory>
#include <ranges>
#include <span>
#include <tuple>
#include <utility>
#include <vector>

Expand Down Expand Up @@ -90,4 +91,142 @@ class QuadraticModelNode : public ScalarOutputMixin<ArrayNode> {
QuadraticModel quadratic_model_;
};

/// Defines a Zephyr lattice.
/// See https://docs.dwavequantum.com/en/latest/concepts/index.html#term-Zephyr
struct ZephyrLattice {
/// A single Zephyr cell.
constexpr ZephyrLattice() noexcept : ZephyrLattice(1) {}

/// A Zephyr lattice with grid parameter `m` and tile parameter `t`.
explicit(false) constexpr ZephyrLattice(ssize_t m, ssize_t t = 4) : m(m), t(t) {
if (m <= 0) throw std::invalid_argument("m must be positive");
if (t <= 0) throw std::invalid_argument("t must be positive");
}

/// Two Zephyr graphs are equivalent if they have the same grid and tile
/// parameters.
constexpr friend bool operator==(const ZephyrLattice& lhs, const ZephyrLattice& rhs) noexcept {
return lhs.m == rhs.m && lhs.t == rhs.t;
}

/// Return a vector of edges in the Zephyr lattice, sorted lexicographically.
std::vector<std::tuple<int, int>> edges() const;

/// The number of edges in a Zephyr lattice
constexpr ssize_t num_edges() const {
assert(m > 0 && "m must be positive");
assert(t > 0 && "t must be positive");
if (m == 1) return 2 * t * (8 * t + 3);
return 2 * t * ((8 * t + 8) * m * m - 2 * m - 3);
}

/// The number of nodes in a Zephyr lattice
constexpr ssize_t num_nodes() const {
assert(m > 0 && "m must be positive");
assert(t > 0 && "t must be positive");
return 4 * t * m * (2 * m + 1);
}

/// The grid parameter.
ssize_t m;

/// The tile parameter.
ssize_t t;
};

/// The lattice types we currently support.
template <typename T>
concept Lattice = std::same_as<T, ZephyrLattice>;

/// A node representing a quadratic model with linear and quadratic biases
/// structured according to the given lattice.
template <Lattice LatticeType>
class LatticeNode : public ScalarOutputMixin<ArrayNode> {
public:
LatticeNode() = delete;

/// Return a LatticeNode with all zero biases.
LatticeNode(ArrayNode* x_ptr, LatticeType lattice);

/// Construct a LatticeNode with the biases provided by the linear and quadratic functions.
LatticeNode(ArrayNode* x_ptr, LatticeType lattice,
std::function<double(int)> linear, // function to get the linear biases
std::function<double(int, int)> quadratic); // function to get the quadratic biases

/// @copydoc Array::buff()
double const* buff(const State& state) const override;

/// @copydoc Node::commit()
void commit(State& state) const override;

/// @copydoc Array::diff()
std::span<const Update> diff(const State& state) const override;

/// @copydoc Node::initialize_state()
void initialize_state(State& state) const override;

/// Return a reference to the lattice structure of the node.
const LatticeType& lattice() const noexcept { return lattice_; }

template<class... Args>
static ssize_t lattice_num_edges(Args... args) {
return LatticeType(args...).num_edges();
}

template<class... Args>
static ssize_t lattice_num_nodes(Args... args) {
return LatticeType(args...).num_nodes();
}

/// Get the linear bias associated with `u`. Returns `0` if `u` is out-of-bounds.
double linear(int u) const noexcept;

/// @copydoc Node::propagate()
void propagate(State& state) const override;

/// Get the linear bias associated with `u` and `v`.
/// Returns `0` if `u` or `v` are out of bounds or if they have no interaction.
double quadratic(int u, int v) const noexcept;

/// @copydoc Node::revert()
void revert(State& state) const override;

private:
struct StateData;

ArrayNode* x_ptr_;

LatticeType lattice_;

struct neighbor {
neighbor() noexcept : neighbor(-1, 0) {}
neighbor(int v) noexcept : neighbor(v, 0) {}
neighbor(int v, double bias) noexcept : v(v), bias(bias) {}

// We want to be able to sort neighborhoods, so make neighbor weakly ordered
friend bool operator==(const neighbor& lhs, const neighbor& rhs) noexcept {
return lhs.v == rhs.v;
}
friend bool operator!=(const neighbor& lhs, const neighbor& rhs) noexcept {
return lhs.v != rhs.v;
}
friend std::weak_ordering operator<=>(const neighbor& lhs, const neighbor& rhs) noexcept {
return lhs.v <=> rhs.v;
}

int v; // the variable index
double bias; // the quadratic bias
};

struct neighborhood {
std::vector<neighbor> neighbors;
double bias; // the linear bias
};

// Store the graph in an adjacency format
std::vector<neighborhood> adj_;
};

using ZephyrNode = LatticeNode<ZephyrLattice>;

} // namespace dwave::optimization
7 changes: 7 additions & 0 deletions dwave/optimization/libcpp/__init__.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from libcpp.functional cimport function

# As of Cython 3.0.8 these are not in Cython's libcpp

cdef extern from "<functional>" namespace "std" nogil:
# We just do the overloads we need
function[double(int)] bind_front(double(void*, int), void*)
function[double(int, int)] bind_front(double(void*, int, int), void*)

cdef extern from "<span>" namespace "std" nogil:
cdef cppclass span[T]:
ctypedef size_t size_type
Expand Down
8 changes: 8 additions & 0 deletions dwave/optimization/libcpp/nodes.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,14 @@ cdef extern from "dwave-optimization/nodes/quadratic_model.hpp" namespace "dwave
cdef cppclass QuadraticModelNode(ArrayNode):
QuadraticModel* get_quadratic_model()

cdef cppclass ZephyrNode(ArrayNode):
@staticmethod
Py_ssize_t lattice_num_nodes(Py_ssize_t m)
@staticmethod
Py_ssize_t lattice_num_edges(Py_ssize_t m)
double linear(int v)
double quadratic(int u, int v)


cdef extern from "dwave-optimization/nodes/testing.hpp" namespace "dwave::optimization" nogil:
cdef cppclass ArrayValidationNode(Node):
Expand Down
Loading