Skip to content

Commit f7edc74

Browse files
committed
fix: parameters and cost scaling
1 parent 8e2b650 commit f7edc74

File tree

4 files changed

+24
-15
lines changed

4 files changed

+24
-15
lines changed

examples/sampling_c3/anything/parameters/sampling_c3_options.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ warm_start: false
1111
end_on_qp_step: false
1212
solve_time_filter_alpha: 0.95
1313
publish_frequency: 0
14-
penalize_changes_in_u_across_solves: false # Penalize (u-u_prev) instead of u.
14+
penalize_changes_in_u_across_solves: true # Penalize (u-u_prev) instead of u.
1515
num_friction_directions: 2
1616

1717
N: 5

examples/sampling_c3/anything/parameters/sampling_c3plus_options.yaml

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,9 @@ scale_lcs: true
1212
end_on_qp_step: false
1313
solve_time_filter_alpha: 0.95
1414
publish_frequency: 0
15-
penalize_input_change: false # Penalize (u-u_prev) instead of u.
15+
penalize_input_change: true # Penalize (u-u_prev) instead of u.
1616
num_friction_directions: 2
1717
spring_stiffness: 0.0 # Not used in C3+.
18-
final_augmented_cost_scaling: 1000.0
1918

2019
N: 7
2120
gamma: 1.0 # discount factor on MPC costs
@@ -88,7 +87,7 @@ q_vector: [0.01, 0.01, 0.01, 0.1, 0.1, 0.1, 0.1, 150, 150, 120, 0.1, 0.1, 0.1, 0
8887
r_vector: [0.01, 0.01, 1]
8988

9089
# Penalty on matching projected variables
91-
g_x: [950, 950, 950, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]
90+
g_x: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
9291
g_u: [0, 0, 0]
9392
g_gamma_list: [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]
9493
g_lambda_n_list: [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]
@@ -124,7 +123,7 @@ q_vector_position: [0.01, 0.01, 0.01, 0, 0, 0, 0, 200, 200, 120, 0, 0, 0, 0, 200
124123
r_vector_position: [0.01, 0.01, 1]
125124

126125
# Penalty on matching projected variables
127-
g_x_position: [950, 950, 950, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]
126+
g_x_position: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
128127
g_u_position: [0, 0, 0]
129128
g_gamma_position_list: [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]
130129
g_lambda_n_position_list: [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]
@@ -156,7 +155,7 @@ u_eta_position_list: [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
156155
# use_predicted_x0_reset_mechanism
157156
dt: 0 # instead: planning_dt_pose, planning_dt_position
158157
# solve_dt: 0 # unused
159-
dt_cost: 0
158+
dt_cost: 0
160159
mu: [] # instead based on indexing into mu_per_pair_type
161160
num_contacts: 0 # instead based on summing index of resolve_contacts_to_lists
162161
# Instead for the below, index into their _list versions.
@@ -177,3 +176,5 @@ u_eta_slack: []
177176
u_eta_n: []
178177
u_eta_t: []
179178
u_eta: []
179+
final_augmented_cost_contact_scaling: 1000
180+
final_augmented_cost_contact_indices: [0, 1, 2, 3]

examples/sampling_c3/multiyaml_rewrite.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -249,9 +249,6 @@ def update_c3_options(is_c3_plus, samp_c3_options_yaml_path):
249249

250250
samp_c3_options_yaml["q_vector_position"] = q_vector_position
251251

252-
samp_c3_options_yaml["g_x"] = (
253-
[950] * 3 + [1] * (7 * num_objects) + [0.1] * (3 + 6 * num_objects)
254-
)
255252
samp_c3_options_yaml["g_gamma_list"] = [
256253
[1] * calculate_contacts(num_objects, include_walls * num_objects)
257254
]
@@ -263,6 +260,10 @@ def update_c3_options(is_c3_plus, samp_c3_options_yaml_path):
263260
]
264261

265262
if is_c3_plus:
263+
n_x = 6 + (13 * num_objects)
264+
n_lambda = 4 * calculate_contacts(num_objects, include_walls * num_objects) - 2 * sum(samp_c3_options_yaml["resolve_as_planar_contacts_list"])
265+
n_u = 3
266+
samp_c3_options_yaml["g_x"] = [0] * (3 + 7 * num_objects + 3 + 6 * num_objects)
266267
samp_c3_options_yaml["g_eta_slack_list"] = [
267268
[1] * calculate_contacts(num_objects, include_walls * num_objects)
268269
]
@@ -278,7 +279,12 @@ def update_c3_options(is_c3_plus, samp_c3_options_yaml_path):
278279
samp_c3_options_yaml["g_lambda_list"] = [
279280
[2] * (4 * calculate_contacts(num_objects, include_walls * num_objects))
280281
]
282+
samp_c3_options_yaml["final_augmented_cost_contact_scaling"] = 1000
283+
samp_c3_options_yaml["final_augmented_cost_contact_indices"] = [i for i in range(4)]
281284
else:
285+
samp_c3_options_yaml["g_x"] = (
286+
[950] * 3 + [1] * (7 * num_objects) + [0.1] * (3 + 6 * num_objects)
287+
)
282288
samp_c3_options_yaml["g_lambda_list"] = [
283289
[0.05] * (4 * calculate_contacts(num_objects, include_walls * num_objects))
284290
]
@@ -336,6 +342,9 @@ def update_c3_options(is_c3_plus, samp_c3_options_yaml_path):
336342
]
337343

338344
if is_c3_plus:
345+
n_x = 6 + (13 * num_objects)
346+
n_lambda = 4 * calculate_contacts(num_objects, include_walls * num_objects) - 2 * sum(samp_c3_options_yaml["resolve_as_planar_contacts_list"])
347+
n_u = 3
339348
samp_c3_options_yaml["g_eta_slack_position_list"] = [
340349
[1] * calculate_contacts(num_objects, include_walls * num_objects)
341350
]
@@ -351,14 +360,11 @@ def update_c3_options(is_c3_plus, samp_c3_options_yaml_path):
351360
samp_c3_options_yaml["g_eta_position_list"] = [
352361
[1] * (4 * calculate_contacts(num_objects, include_walls * num_objects))
353362
]
354-
samp_c3_options_yaml["g_x_position"] = (
355-
[950] * 3 + [1] * (7 * num_objects) + [0.1] * (3 + 6 * num_objects)
356-
)
357363
else:
358364
samp_c3_options_yaml["g_lambda_position_list"] = [
359365
[0.005] * (4 * calculate_contacts(num_objects, include_walls * num_objects))
360366
]
361-
samp_c3_options_yaml["g_x_position"] = [0] * (6 + (13 * num_objects))
367+
samp_c3_options_yaml["g_x_position"] = [0] * (6 + (13 * num_objects))
362368

363369
samp_c3_options_yaml["u_gamma_position_list"] = [
364370
[1] * calculate_contacts(num_objects, include_walls * num_objects)
@@ -551,7 +557,6 @@ def main():
551557

552558
min_max_z = min(max_zs_world)
553559

554-
555560
# Update c3_options
556561
is_c3_plus = "plus" in samp_c3_options_yaml_path
557562
update_c3_options(is_c3_plus, samp_c3_options_yaml_path)

examples/sampling_c3/parameter_headers/sampling_c3_options.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,10 @@ struct SamplingC3Options : C3Options, LCSFactoryOptions {
402402
c3_options->M = M;
403403
c3_options->qp_projection_alpha = qp_projection_alpha;
404404
c3_options->qp_projection_scaling = qp_projection_scaling;
405-
c3_options->final_augmented_cost_scaling = final_augmented_cost_scaling;
405+
c3_options->final_augmented_cost_contact_scaling =
406+
final_augmented_cost_contact_scaling;
407+
c3_options->final_augmented_cost_contact_indices =
408+
final_augmented_cost_contact_indices;
406409

407410
lcs_factory_options->contact_model = contact_model;
408411
lcs_factory_options->num_contacts = num_contacts;

0 commit comments

Comments
 (0)