diff --git a/utils/load_data_sets.py b/utils/load_data_sets.py index 95f306f..806421e 100755 --- a/utils/load_data_sets.py +++ b/utils/load_data_sets.py @@ -88,6 +88,10 @@ def get_batch(self, batch_size): @staticmethod def from_positions_w_context(positions_w_context, is_test=False, extract_move_prob=False): positions, next_moves, results = zip(*positions_w_context) + + # Remove None types to prevent error in wrt_result + results = [i for i in results if i.result is not None] + extracted_features = bulk_extract_features(positions) if extract_move_prob: encoded_moves = np.asarray(next_moves)