diff --git a/easyeditor/models/rome/rome_main.py b/easyeditor/models/rome/rome_main.py index fa63b15..f69bb3e 100644 --- a/easyeditor/models/rome/rome_main.py +++ b/easyeditor/models/rome/rome_main.py @@ -31,6 +31,7 @@ def set_min_to_zero(pos_matrix, matrix, possibility): indices_to_zero = np.argpartition(flattened_pos_matrix, num_zeros)[:num_zeros] flattened_matrix[indices_to_zero] = 0 modified_matrix = flattened_matrix.reshape(matrix.shape) + return modified_matrix def apply_rome_to_model( model: AutoModelForCausalLM,