diff --git a/sam_audio/processor.py b/sam_audio/processor.py index a85a1974..79356445 100644 --- a/sam_audio/processor.py +++ b/sam_audio/processor.py @@ -119,8 +119,8 @@ def process_anchors(self, anchors: Optional[list[list[Anchor]]]): anchor_ids = pad_sequence( ids, batch_first=True, padding_value=anchor_dict[""] ) - self.anchor_ids = anchor_ids - self.anchor_alignment = anchor_alignment + self.anchor_ids = anchor_ids.to(self.audios.device) + self.anchor_alignment = anchor_alignment.to(self.audios.device) self.anchors = anchors