Skip to content

Commit 3d48575

Browse files
cdtwiggfacebook-github-bot
authored andcommitted
Python bindings for VertexVertexDistanceErrorFunction. (#578)
Summary: Pull Request resolved: #578 ### Added Python Bindings for VertexVertexDistanceErrorFunction #### Summary This diff adds Python bindings for the `VertexVertexDistanceErrorFunction` class, allowing it to be used directly from Python. This change enables the use of the `VertexVertexDistanceErrorFunction` in Python, expanding the functionality of the `pymomentum` library. #### Example Use Case With this change, you can now create and use `VertexVertexDistanceErrorFunction` objects directly from Python, like this: ```python import pymomentum error_function = pymomentum.solver2.VertexVertexDistanceErrorFunction() ``` Reviewed By: jeongseok-meta Differential Revision: D82848959 fbshipit-source-id: d36b4dc85b5d3578dd8f67066ed3022717822995
1 parent 286a0ee commit 3d48575

File tree

2 files changed

+298
-0
lines changed

2 files changed

+298
-0
lines changed

pymomentum/solver2/solver2_error_functions.cpp

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include <momentum/character_solver/skeleton_error_function.h>
2525
#include <momentum/character_solver/state_error_function.h>
2626
#include <momentum/character_solver/vertex_projection_error_function.h>
27+
#include <momentum/character_solver/vertex_vertex_distance_error_function.h>
2728
#include <pymomentum/solver2/solver2_utility.h>
2829

2930
#include <fmt/format.h>
@@ -1238,6 +1239,162 @@ distance is greater than zero (ie. the point being above).)")
12381239
py::arg("name") = std::optional<std::vector<std::string>>{});
12391240
}
12401241

1242+
void defVertexVertexDistanceErrorFunction(py::module_& m) {
1243+
py::class_<mm::VertexVertexDistanceConstraintT<float>>(
1244+
m, "VertexVertexDistanceConstraint")
1245+
.def(
1246+
"__repr__",
1247+
[](const mm::VertexVertexDistanceConstraintT<float>& self) {
1248+
return fmt::format(
1249+
"VertexVertexDistanceConstraint(vertex_index1={}, vertex_index2={}, weight={}, target_distance={})",
1250+
self.vertexIndex1,
1251+
self.vertexIndex2,
1252+
self.weight,
1253+
self.targetDistance);
1254+
})
1255+
.def_readonly(
1256+
"vertex_index1",
1257+
&mm::VertexVertexDistanceConstraintT<float>::vertexIndex1,
1258+
"The index of the first vertex")
1259+
.def_readonly(
1260+
"vertex_index2",
1261+
&mm::VertexVertexDistanceConstraintT<float>::vertexIndex2,
1262+
"The index of the second vertex")
1263+
.def_readonly(
1264+
"weight",
1265+
&mm::VertexVertexDistanceConstraintT<float>::weight,
1266+
"The weight of the constraint")
1267+
.def_readonly(
1268+
"target_distance",
1269+
&mm::VertexVertexDistanceConstraintT<float>::targetDistance,
1270+
"The target distance between the two vertices");
1271+
1272+
py::class_<
1273+
mm::VertexVertexDistanceErrorFunctionT<float>,
1274+
mm::SkeletonErrorFunction,
1275+
std::shared_ptr<mm::VertexVertexDistanceErrorFunctionT<float>>>(
1276+
m,
1277+
"VertexVertexDistanceErrorFunction",
1278+
R"(Error function that minimizes the distance between pairs of vertices on the character's mesh.
1279+
1280+
This is useful for constraints where you want to maintain specific distances between
1281+
different parts of the character mesh, such as keeping fingers at a certain distance
1282+
or maintaining the width of body parts.)")
1283+
.def(
1284+
"__repr__",
1285+
[](const mm::VertexVertexDistanceErrorFunctionT<float>& self) {
1286+
return fmt::format(
1287+
"VertexVertexDistanceErrorFunction(weight={}, num_constraints={})",
1288+
self.getWeight(),
1289+
self.numConstraints());
1290+
})
1291+
.def(
1292+
py::init<>(
1293+
[](const mm::Character& character, float weight)
1294+
-> std::shared_ptr<
1295+
mm::VertexVertexDistanceErrorFunctionT<float>> {
1296+
validateWeight(weight, "weight");
1297+
auto result = std::make_shared<
1298+
mm::VertexVertexDistanceErrorFunctionT<float>>(character);
1299+
result->setWeight(weight);
1300+
return result;
1301+
}),
1302+
R"(Initialize a VertexVertexDistanceErrorFunction.
1303+
1304+
:param character: The character to use.
1305+
:param weight: The weight applied to the error function.)",
1306+
py::keep_alive<1, 2>(),
1307+
py::arg("character"),
1308+
py::kw_only(),
1309+
py::arg("weight") = 1.0f)
1310+
.def(
1311+
"add_constraint",
1312+
[](mm::VertexVertexDistanceErrorFunctionT<float>& self,
1313+
int vertexIndex1,
1314+
int vertexIndex2,
1315+
float weight,
1316+
float targetDistance) {
1317+
validateVertexIndex(
1318+
vertexIndex1, "vertex_index1", self.getCharacter());
1319+
validateVertexIndex(
1320+
vertexIndex2, "vertex_index2", self.getCharacter());
1321+
validateWeight(weight, "weight");
1322+
self.addConstraint(
1323+
vertexIndex1, vertexIndex2, weight, targetDistance);
1324+
},
1325+
R"(Adds a vertex-to-vertex distance constraint to the error function.
1326+
1327+
:param vertex_index1: The index of the first vertex.
1328+
:param vertex_index2: The index of the second vertex.
1329+
:param weight: The weight of the constraint.
1330+
:param target_distance: The desired distance between the two vertices.)",
1331+
py::arg("vertex_index1"),
1332+
py::arg("vertex_index2"),
1333+
py::arg("weight"),
1334+
py::arg("target_distance"))
1335+
.def(
1336+
"add_constraints",
1337+
[](mm::VertexVertexDistanceErrorFunctionT<float>& self,
1338+
const py::array_t<int>& vertexIndex1,
1339+
const py::array_t<int>& vertexIndex2,
1340+
const py::array_t<float>& weight,
1341+
const py::array_t<float>& targetDistance) {
1342+
ArrayShapeValidator validator;
1343+
const int nConsIdx = -1;
1344+
validator.validate(
1345+
vertexIndex1, "vertex_index1", {nConsIdx}, {"n_cons"});
1346+
validateVertexIndex(
1347+
vertexIndex1, "vertex_index1", self.getCharacter());
1348+
validator.validate(
1349+
vertexIndex2, "vertex_index2", {nConsIdx}, {"n_cons"});
1350+
validateVertexIndex(
1351+
vertexIndex2, "vertex_index2", self.getCharacter());
1352+
validator.validate(weight, "weight", {nConsIdx}, {"n_cons"});
1353+
validateWeights(weight, "weight");
1354+
validator.validate(
1355+
targetDistance, "target_distance", {nConsIdx}, {"n_cons"});
1356+
1357+
auto vertexIndex1Acc = vertexIndex1.unchecked<1>();
1358+
auto vertexIndex2Acc = vertexIndex2.unchecked<1>();
1359+
auto weightAcc = weight.unchecked<1>();
1360+
auto targetDistanceAcc = targetDistance.unchecked<1>();
1361+
1362+
py::gil_scoped_release release;
1363+
1364+
for (py::ssize_t i = 0; i < vertexIndex1.shape(0); ++i) {
1365+
self.addConstraint(
1366+
vertexIndex1Acc(i),
1367+
vertexIndex2Acc(i),
1368+
weightAcc(i),
1369+
targetDistanceAcc(i));
1370+
}
1371+
},
1372+
R"(Adds multiple vertex-to-vertex distance constraints to the error function.
1373+
1374+
:param vertex_index1: A numpy array of indices for the first vertices.
1375+
:param vertex_index2: A numpy array of indices for the second vertices.
1376+
:param weight: A numpy array of weights for the constraints.
1377+
:param target_distance: A numpy array of desired distances between vertex pairs.)",
1378+
py::arg("vertex_index1"),
1379+
py::arg("vertex_index2"),
1380+
py::arg("weight"),
1381+
py::arg("target_distance"))
1382+
.def(
1383+
"clear_constraints",
1384+
&mm::VertexVertexDistanceErrorFunctionT<float>::clearConstraints,
1385+
"Clears all vertex-to-vertex distance constraints from the error function.")
1386+
.def_property_readonly(
1387+
"constraints",
1388+
[](const mm::VertexVertexDistanceErrorFunctionT<float>& self) {
1389+
return self.getConstraints();
1390+
},
1391+
"Returns the list of vertex-to-vertex distance constraints.")
1392+
.def(
1393+
"num_constraints",
1394+
&mm::VertexVertexDistanceErrorFunctionT<float>::numConstraints,
1395+
"Returns the number of constraints.");
1396+
}
1397+
12411398
} // namespace
12421399

12431400
void addErrorFunctions(py::module_& m) {
@@ -2237,6 +2394,9 @@ rotation matrix to a target rotation.)")
22372394

22382395
// Vertex Projection error function
22392396
defVertexProjectionErrorFunction(m);
2397+
2398+
// Vertex-to-vertex distance error function
2399+
defVertexVertexDistanceErrorFunction(m);
22402400
}
22412401

22422402
} // namespace pymomentum

pymomentum/test/test_solver2.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1448,6 +1448,144 @@ def test_vertex_sequence_error_function(self) -> None:
14481448
self.assertIn("VertexSequenceErrorFunction", repr_str)
14491449
self.assertIn("num_constraints=2", repr_str)
14501450

1451+
def test_vertex_vertex_distance_constraint(self) -> None:
1452+
"""Test VertexVertexDistanceErrorFunction to ensure vertices are pulled to target distance."""
1453+
1454+
# Create a test character
1455+
character = pym_geometry.create_test_character(num_joints=4)
1456+
1457+
n_params = character.parameter_transform.size
1458+
1459+
# Ensure repeatability in the rng:
1460+
torch.manual_seed(0)
1461+
model_params_init = torch.zeros(n_params, dtype=torch.float32)
1462+
1463+
# Choose two vertices to constrain - use vertices that are initially far apart
1464+
vertex_index1 = 0
1465+
vertex_index2 = character.mesh.vertices.shape[0] - 1 # Last vertex
1466+
target_distance = 0.5 # Target distance between the two vertices
1467+
weight = 1.0
1468+
1469+
# Get initial positions of the vertices
1470+
skel_state_init = pym_geometry.model_parameters_to_skeleton_state(
1471+
character, model_params_init
1472+
)
1473+
initial_mesh = character.skin_points(skel_state_init)
1474+
initial_pos1 = initial_mesh[vertex_index1, :3]
1475+
initial_pos2 = initial_mesh[vertex_index2, :3]
1476+
initial_distance = torch.norm(initial_pos2 - initial_pos1).item()
1477+
1478+
# Create VertexVertexDistanceErrorFunction
1479+
vertex_distance_error = pym_solver2.VertexVertexDistanceErrorFunction(character)
1480+
1481+
# Test basic properties
1482+
self.assertEqual(vertex_distance_error.num_constraints(), 0)
1483+
self.assertEqual(len(vertex_distance_error.constraints), 0)
1484+
1485+
# Add a single constraint
1486+
vertex_distance_error.add_constraint(
1487+
vertex_index1=vertex_index1,
1488+
vertex_index2=vertex_index2,
1489+
weight=weight,
1490+
target_distance=target_distance,
1491+
)
1492+
1493+
# Verify constraint was added
1494+
self.assertEqual(vertex_distance_error.num_constraints(), 1)
1495+
self.assertEqual(len(vertex_distance_error.constraints), 1)
1496+
1497+
constraint = vertex_distance_error.constraints[0]
1498+
self.assertEqual(constraint.vertex_index1, vertex_index1)
1499+
self.assertEqual(constraint.vertex_index2, vertex_index2)
1500+
self.assertAlmostEqual(constraint.weight, weight)
1501+
self.assertAlmostEqual(constraint.target_distance, target_distance)
1502+
1503+
# Create solver function with the vertex distance error
1504+
solver_function = pym_solver2.SkeletonSolverFunction(
1505+
character, [vertex_distance_error]
1506+
)
1507+
1508+
# Set solver options
1509+
solver_options = pym_solver2.GaussNewtonSolverOptions()
1510+
solver_options.max_iterations = 100
1511+
solver_options.regularization = 1e-5
1512+
1513+
# Create and run the solver
1514+
solver = pym_solver2.GaussNewtonSolver(solver_function, solver_options)
1515+
model_params_final = solver.solve(model_params_init.numpy())
1516+
1517+
# Convert final model parameters to skeleton state
1518+
skel_state_final = pym_geometry.model_parameters_to_skeleton_state(
1519+
character, torch.from_numpy(model_params_final)
1520+
)
1521+
1522+
# Compute final mesh and vertex positions
1523+
final_mesh = character.skin_points(skel_state_final)
1524+
final_pos1 = final_mesh[vertex_index1, :3]
1525+
final_pos2 = final_mesh[vertex_index2, :3]
1526+
final_distance = torch.norm(final_pos2 - final_pos1).item()
1527+
1528+
# Assert that the final distance is close to the target distance
1529+
self.assertAlmostEqual(
1530+
final_distance,
1531+
target_distance,
1532+
delta=1e-3,
1533+
msg=f"Final distance {final_distance} does not match target {target_distance}",
1534+
)
1535+
1536+
# Verify that the distance actually changed from the initial distance
1537+
self.assertNotAlmostEqual(
1538+
initial_distance,
1539+
final_distance,
1540+
delta=1e-1,
1541+
msg=f"Distance did not change significantly from initial {initial_distance} to final {final_distance}",
1542+
)
1543+
1544+
# Test multiple constraints using add_constraints
1545+
vertex_distance_error.clear_constraints()
1546+
self.assertEqual(vertex_distance_error.num_constraints(), 0)
1547+
1548+
# Add multiple constraints
1549+
vertex_indices1 = np.array([0, 1], dtype=np.int32)
1550+
vertex_indices2 = np.array([2, 3], dtype=np.int32)
1551+
weights = np.array([1.0, 2.0], dtype=np.float32)
1552+
target_distances = np.array([0.3, 0.7], dtype=np.float32)
1553+
1554+
vertex_distance_error.add_constraints(
1555+
vertex_index1=vertex_indices1,
1556+
vertex_index2=vertex_indices2,
1557+
weight=weights,
1558+
target_distance=target_distances,
1559+
)
1560+
1561+
# Verify multiple constraints were added
1562+
self.assertEqual(vertex_distance_error.num_constraints(), 2)
1563+
constraints = vertex_distance_error.constraints
1564+
self.assertEqual(len(constraints), 2)
1565+
1566+
# Check first constraint
1567+
self.assertEqual(constraints[0].vertex_index1, 0)
1568+
self.assertEqual(constraints[0].vertex_index2, 2)
1569+
self.assertAlmostEqual(constraints[0].weight, 1.0)
1570+
self.assertAlmostEqual(constraints[0].target_distance, 0.3)
1571+
1572+
# Check second constraint
1573+
self.assertEqual(constraints[1].vertex_index1, 1)
1574+
self.assertEqual(constraints[1].vertex_index2, 3)
1575+
self.assertAlmostEqual(constraints[1].weight, 2.0)
1576+
self.assertAlmostEqual(constraints[1].target_distance, 0.7)
1577+
1578+
# Test string representation
1579+
repr_str = repr(vertex_distance_error)
1580+
self.assertIn("VertexVertexDistanceErrorFunction", repr_str)
1581+
self.assertIn("num_constraints=2", repr_str)
1582+
1583+
# Test constraint string representation
1584+
constraint_repr = repr(constraints[0])
1585+
self.assertIn("VertexVertexDistanceConstraint", constraint_repr)
1586+
self.assertIn("vertex_index1=0", constraint_repr)
1587+
self.assertIn("vertex_index2=2", constraint_repr)
1588+
14511589
def test_weight_validation(self) -> None:
14521590
"""Test that error functions throw ValueError when negative weights are passed."""
14531591

0 commit comments

Comments
 (0)