Skip to content

Commit 71fbdd5

Browse files
Simplify PlotCallback
1 parent 08e709c commit 71fbdd5

File tree

9 files changed

+25
-23
lines changed

9 files changed

+25
-23
lines changed

trajopt/include/trajopt/plot_callback.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ Returns a callback function suitable for an Optimizer.
1212
This callback will plot the trajectory (with translucent copies of the robot) as
1313
well as all of the Cost and Constraint functions with plot methods
1414
*/
15-
sco::Optimizer::Callback PlotCallback(TrajOptProb& prob, const tesseract_visualization::Visualization::Ptr& plotter);
15+
sco::Optimizer::Callback PlotCallback(const tesseract_visualization::Visualization::Ptr& plotter);
16+
1617
/**
1718
* @brief Returns a callback suitable for an optimizer but does not require the problem
1819
* @param plotter

trajopt/src/plot_callback.cpp

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,16 +47,17 @@ void PlotCosts(const tesseract_visualization::Visualization::Ptr& plotter,
4747
plotter->waitForInput();
4848
}
4949

50-
sco::Optimizer::Callback PlotCallback(TrajOptProb& prob, const tesseract_visualization::Visualization::Ptr& plotter)
50+
sco::Optimizer::Callback PlotCallback(const tesseract_visualization::Visualization::Ptr& plotter)
5151
{
52-
return [&prob, plotter](sco::OptProb*, sco::OptResults& results) {
53-
auto state_solver = prob.GetEnv()->getStateSolver();
52+
return [plotter](sco::OptProb* prob, sco::OptResults& results) {
53+
auto& trajopt_prob = dynamic_cast<TrajOptProb&>(*prob);
54+
auto state_solver = trajopt_prob.GetEnv()->getStateSolver();
5455
PlotCosts(plotter,
5556
*state_solver,
56-
prob.GetKin()->getJointNames(),
57-
std::ref(prob.getCosts()),
58-
prob.getConstraints(),
59-
std::ref(prob.GetVars()),
57+
trajopt_prob.GetKin()->getJointNames(),
58+
std::ref(trajopt_prob.getCosts()),
59+
trajopt_prob.getConstraints(),
60+
std::ref(trajopt_prob.GetVars()),
6061
results);
6162
};
6263
}

trajopt/src/problem_description.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -394,7 +394,7 @@ TrajOptResult::Ptr OptimizeProblem(const TrajOptProb::Ptr& prob,
394394
param.improve_ratio_threshold = .2;
395395
param.initial_merit_error_coeff = 20;
396396
if (plotter)
397-
opt.addCallback(PlotCallback(*prob, plotter));
397+
opt.addCallback(PlotCallback(plotter));
398398
opt.initialize(trajToDblVec(prob->GetInitTraj()));
399399
opt.optimize();
400400
return std::make_shared<TrajOptResult>(opt.results(), *prob);

trajopt/test/cast_cost_attached_unit.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ TEST_F(CastAttachedTest, LinkWithGeom) // NOLINT
126126

127127
sco::BasicTrustRegionSQP opt(prob);
128128
if (plotting)
129-
opt.addCallback(PlotCallback(*prob, plotter_));
129+
opt.addCallback(PlotCallback(plotter_));
130130
opt.initialize(trajToDblVec(prob->GetInitTraj()));
131131
opt.optimize();
132132

@@ -176,7 +176,7 @@ TEST_F(CastAttachedTest, LinkWithoutGeom) // NOLINT
176176

177177
sco::BasicTrustRegionSQP opt(prob);
178178
if (plotting)
179-
opt.addCallback(PlotCallback(*prob, plotter_));
179+
opt.addCallback(PlotCallback(plotter_));
180180
opt.initialize(trajToDblVec(prob->GetInitTraj()));
181181
opt.optimize();
182182

trajopt/test/cast_cost_octomap_unit.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ TEST_F(CastOctomapTest, boxes) // NOLINT
125125

126126
sco::BasicTrustRegionSQP opt(prob);
127127
if (plotting)
128-
opt.addCallback(PlotCallback(*prob, plotter_));
128+
opt.addCallback(PlotCallback(plotter_));
129129
opt.initialize(trajToDblVec(prob->GetInitTraj()));
130130
opt.optimize();
131131

trajopt/test/cast_cost_unit.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ TEST_F(CastTest, boxes) // NOLINT
8686

8787
sco::BasicTrustRegionSQP opt(prob);
8888
if (plotting)
89-
opt.addCallback(PlotCallback(*prob, plotter_));
89+
opt.addCallback(PlotCallback(plotter_));
9090
opt.initialize(trajToDblVec(prob->GetInitTraj()));
9191
opt.optimize();
9292

trajopt/test/cast_cost_world_unit.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ TEST_F(CastWorldTest, boxes) // NOLINT
109109

110110
sco::BasicTrustRegionSQP opt(prob);
111111
if (plotting)
112-
opt.addCallback(PlotCallback(*prob, plotter_));
112+
opt.addCallback(PlotCallback(plotter_));
113113
opt.initialize(trajToDblVec(prob->GetInitTraj()));
114114
opt.optimize();
115115

trajopt/test/joint_costs_unit.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ TEST_F(CostsTest, equality_jointPos) // NOLINT
107107
sco::BasicTrustRegionSQP opt(prob);
108108
if (plotting)
109109
{
110-
opt.addCallback(PlotCallback(*prob, plotter_));
110+
opt.addCallback(PlotCallback(plotter_));
111111
}
112112

113113
opt.initialize(trajToDblVec(prob->GetInitTraj()));
@@ -214,7 +214,7 @@ TEST_F(CostsTest, inequality_jointPos) // NOLINT
214214
sco::BasicTrustRegionSQP opt(prob);
215215
if (plotting)
216216
{
217-
opt.addCallback(PlotCallback(*prob, plotter_));
217+
opt.addCallback(PlotCallback(plotter_));
218218
}
219219

220220
opt.initialize(trajToDblVec(prob->GetInitTraj()));
@@ -309,7 +309,7 @@ TEST_F(CostsTest, equality_jointVel) // NOLINT
309309
sco::BasicTrustRegionSQP opt(prob);
310310
if (plotting)
311311
{
312-
opt.addCallback(PlotCallback(*prob, plotter_));
312+
opt.addCallback(PlotCallback(plotter_));
313313
}
314314

315315
opt.initialize(trajToDblVec(prob->GetInitTraj()));
@@ -416,7 +416,7 @@ TEST_F(CostsTest, inequality_jointVel) // NOLINT
416416
sco::BasicTrustRegionSQP opt(prob);
417417
if (plotting)
418418
{
419-
opt.addCallback(PlotCallback(*prob, plotter_));
419+
opt.addCallback(PlotCallback(plotter_));
420420
}
421421

422422
opt.initialize(trajToDblVec(prob->GetInitTraj()));
@@ -514,7 +514,7 @@ TEST_F(CostsTest, equality_jointVel_time) // NOLINT
514514
sco::BasicTrustRegionSQP opt(prob);
515515
if (plotting)
516516
{
517-
opt.addCallback(PlotCallback(*prob, plotter_));
517+
opt.addCallback(PlotCallback(plotter_));
518518
}
519519

520520
opt.initialize(trajToDblVec(prob->GetInitTraj()));
@@ -627,7 +627,7 @@ TEST_F(CostsTest, inequality_jointVel_time) // NOLINT
627627
sco::BasicTrustRegionSQP opt(prob);
628628
if (plotting)
629629
{
630-
opt.addCallback(PlotCallback(*prob, plotter_));
630+
opt.addCallback(PlotCallback(plotter_));
631631
}
632632

633633
opt.initialize(trajToDblVec(prob->GetInitTraj()));
@@ -722,7 +722,7 @@ TEST_F(CostsTest, equality_jointAcc) // NOLINT
722722
sco::BasicTrustRegionSQP opt(prob);
723723
if (plotting)
724724
{
725-
opt.addCallback(PlotCallback(*prob, plotter_));
725+
opt.addCallback(PlotCallback(plotter_));
726726
}
727727

728728
opt.initialize(trajToDblVec(prob->GetInitTraj()));
@@ -831,7 +831,7 @@ TEST_F(CostsTest, inequality_jointAcc) // NOLINT
831831
sco::BasicTrustRegionSQP opt(prob);
832832
if (plotting)
833833
{
834-
opt.addCallback(PlotCallback(*prob, plotter_));
834+
opt.addCallback(PlotCallback(plotter_));
835835
}
836836

837837
opt.initialize(trajToDblVec(prob->GetInitTraj()));

trajopt/test/simple_collision_unit.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ TEST_F(SimpleCollisionTest, spheres) // NOLINT
8686

8787
sco::BasicTrustRegionSQP opt(prob);
8888
if (plotting)
89-
opt.addCallback(PlotCallback(*prob, plotter_));
89+
opt.addCallback(PlotCallback(plotter_));
9090
opt.initialize(trajToDblVec(prob->GetInitTraj()));
9191
opt.optimize();
9292

0 commit comments

Comments
 (0)