Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
300 changes: 203 additions & 97 deletions src/deeprm/inference/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import numpy as np
import tqdm

from deeprm.inference import inference_preprocess_python as stream_prep
from deeprm.inference.inference_dataloader import load_dataset
from deeprm.inference.pileup_deeprm import main as pileup_main
from deeprm.utils import check_deps
Expand All @@ -37,7 +38,39 @@ def add_arguments(parser: argparse.ArgumentParser):
Returns:
None
"""
parser.add_argument("--input", "-i", dest="data", type=str, required=True, help="Data path")
parser.add_argument("--input", "-i", dest="data", type=str, default=None, help="Preprocessed data path")
parser.add_argument(
"--preprocess-mode",
choices=["disk", "stream"],
default="disk",
help="disk: read from --input directory, stream: preprocess on-the-fly from --pod5.",
)
parser.add_argument("--pod5", type=str, default=None, help="POD5 path for --preprocess-mode stream")
parser.add_argument(
"--stream-save-dir",
type=str,
default=None,
help="Optional output directory to also persist streamed preprocessed chunks as .npz",
)
# Streaming preprocess knobs (mirrors preprocess CLI).
parser.add_argument(
"--prep-thread",
type=int,
default=max(1, int((os.cpu_count() or 1) * 0.95)),
help="Prep threads",
)
parser.add_argument("--prep-qcut", type=int, default=0, help="Prep BQ cutoff")
parser.add_argument("--prep-chunk", type=int, default=16000, help="Prep chunk size")
parser.add_argument("--prep-max-token-len", type=int, default=200, help="Prep max token length")
parser.add_argument("--prep-sampling", type=int, default=6, help="Prep sampling rate")
parser.add_argument("--prep-boi", type=str, default="A", help="Prep base of interest")
parser.add_argument("--prep-kmer-len", type=int, default=5, help="Prep k-mer length")
parser.add_argument("--prep-cb-len", type=int, default=21, help="Prep context block length")
parser.add_argument("--prep-bam-thread", type=int, default=4, help="Prep BAM thread per process")
parser.add_argument("--prep-process-once", type=int, default=1000, help="Prep reads per processing batch")
parser.add_argument("--prep-dwell-shift", type=int, default=10, help="Prep dwell shift")
parser.add_argument("--prep-sig-window", type=int, default=5, help="Prep signal window")
parser.add_argument("--prep-label-div", type=int, default=10**9, help="Prep label divisor")
parser.add_argument("--bam", "-b", type=str, required=True, help="BAM file path")
parser.add_argument("--output", "-o", type=str, required=True, help="Output path")
parser.add_argument("--model", "-m", type=str, default=None, help="Model path")
Expand Down Expand Up @@ -79,6 +112,18 @@ def _validate_args(args: argparse.Namespace) -> None:
raise ValueError("--threshold must satisfy 0 <= threshold < 1.")
if args.epsilon <= 0:
raise ValueError("--epsilon must be positive.")
if args.preprocess_mode == "disk":
if not args.data:
raise ValueError("--input is required when --preprocess-mode=disk.")
if args.data.endswith("/"):
args.data = args.data[:-1]
if not os.path.isdir(args.data):
raise ValueError("Invalid data path. It should be a directory containing data files.")
else:
if not args.pod5:
raise ValueError("--pod5 is required when --preprocess-mode=stream.")
if args.resume:
raise ValueError("--resume is only supported when --preprocess-mode=disk.")


def _normalize_gpu_config(args: argparse.Namespace) -> None:
Expand Down Expand Up @@ -127,15 +172,11 @@ def main(args: argparse.Namespace):
args.model = os.path.join(deeprm_root, "weight", "deeprm_weights.pt")
if not args.model.endswith(".pt"):
raise ValueError("Invalid model path. It should be a .pt file.")
if args.data.endswith("/"):
args.data = args.data[:-1]
if not os.path.isdir(args.data):
raise ValueError("Invalid data path. It should be a directory containing data files.")

_validate_args(args)
_normalize_gpu_config(args)

output = f"{args.output}/{os.path.basename(args.data)}"
dataset_name = os.path.basename(args.data) if args.preprocess_mode == "disk" else "streamed-preprocess"
output = f"{args.output}/{dataset_name}"
if len(args.postfix) > 0:
output = f"{output}-{args.postfix}"
args.output = output
Expand Down Expand Up @@ -177,13 +218,166 @@ def run_inference(args):
log.info(f"Model path: {args.model}")
log.info(f"Output directory: {args.output}")

if args.preprocess_mode == "stream":
run_inference_stream(args)
return None

if args.num_gpu > 0:
mp.spawn(inference_worker, nprocs=args.num_gpu, args=(vars(args),), join=True)
else:
inference_worker(0, vars(args))
return None


def _run_chunk_inference(args_dict, rank, device, model, chunk_idx, chunk_data_cpu):
    """Run the model over one preprocessed chunk and save its predictions.

    The chunk is cut into mini-batches of ``args_dict["batch"]`` rows. Every
    mini-batch is moved to ``device`` via ``to_device`` and passed to
    ``model``. Predictions, label ids and read ids are staged as NumPy arrays
    and handed to ``_flush_cpu_buffers`` once ``args_dict["flush"]`` batches
    have accumulated (and once more after the last batch). The merged arrays
    are written to ``{output}/inference_{rank}_{chunk_idx}.npz``.

    Args:
        args_dict: Run configuration; the ``"batch"``, ``"flush"``,
            ``"output_id"`` and ``"output"`` entries are read here.
        rank: Worker rank; appears in log messages and the output file name.
        device: Target device forwarded to ``to_device``.
        model: Callable invoked as ``model(*batch)``.
        chunk_idx: Chunk index; appears in log messages and the file name.
        chunk_data_cpu: Mapping of field name to tensor. Must contain
            ``"label_id"`` and ``"read_id"``.

    Returns:
        bool: ``True`` when an ``.npz`` file was written, ``False`` when the
        chunk had no rows or no predictions were collected.
    """
    if int(chunk_data_cpu["label_id"].shape[0]) == 0:
        log.warning("Skipping empty input shard at rank=%d chunk=%d.", rank, chunk_idx)
        return False

    # Split every field with the same batch size so index i lines up across fields.
    per_key_batches = {}
    for key, tensor in chunk_data_cpu.items():
        per_key_batches[key] = torch.split(tensor, args_dict["batch"])
    total_batches = len(per_key_batches["label_id"])

    # "staged_*" hold per-batch arrays; "done_*" are the lists handed to
    # _flush_cpu_buffers as its output arguments.
    # NOTE(review): _flush_cpu_buffers is defined elsewhere; presumably it moves
    # the staged arrays into the done lists and empties the staging lists - confirm.
    staged_pred, staged_label, staged_read = [], [], []
    done_pred, done_label, done_read = [], [], []

    def _drain():
        _flush_cpu_buffers(staged_pred, staged_label, staged_read, done_pred, done_label, done_read)

    for bidx in range(total_batches):
        batch_cpu = {key: pieces[bidx] for key, pieces in per_key_batches.items()}
        outputs = model(*to_device(batch_cpu, device))
        if args_dict["output_id"] is not None:
            # Keep only the requested entry of the model output.
            outputs = outputs[args_dict["output_id"]]

        staged_pred.append(outputs.detach().cpu().numpy())
        staged_label.append(batch_cpu["label_id"].detach().cpu().numpy())
        staged_read.append(batch_cpu["read_id"].detach().cpu().numpy())

        if len(staged_pred) >= args_dict["flush"]:
            _drain()

    # Final flush for whatever is still staged after the last batch.
    _drain()
    if not done_pred:
        log.warning("No predictions were produced for rank=%d chunk=%d.", rank, chunk_idx)
        return False

    def _merge(parts):
        # A single part is returned as-is; several parts are joined along axis 0.
        if len(parts) == 1:
            return parts[0]
        return np.concatenate(parts, axis=0)

    preds = _merge(done_pred)
    label_ids = _merge(done_label)
    read_ids = _merge(done_read)

    out_path = f"{args_dict['output']}/inference_{rank}_{chunk_idx}.npz"
    np.savez_compressed(out_path, label_id=label_ids, read_id=read_ids, pred=preds)
    return True


def _init_model_for_device(args_dict, rank):
    """Load the checkpoint and build an eval-mode model on this rank's device.

    Args:
        args_dict: Run configuration; the ``"num_gpu"``, ``"gpu_pool"``,
            ``"model"`` (checkpoint path) and ``"model_type"`` entries are
            read here.
        rank: Index into ``args_dict["gpu_pool"]`` when GPUs are in use.

    Returns:
        tuple: ``(model, model_config, device)`` where ``model`` is already
        moved to ``device`` and switched to eval mode, and ``model_config`` is
        the ``"model_config"`` dict stored in the checkpoint (with ``"model"``
        overridden when ``args_dict["model_type"]`` is set).
    """
    if args_dict["num_gpu"] > 0:
        device = torch.device(f"cuda:{args_dict['gpu_pool'][rank]}")
        # Remap tensors saved on cuda:0 onto the GPU assigned to this rank.
        map_location = {"cuda:0": str(device)}
    else:
        device = torch.device("cpu")
        map_location = "cpu"

    # NOTE(review): weights_only=False allows arbitrary unpickling, so only
    # checkpoints from a trusted source should be passed in.
    checkpoint = torch.load(args_dict["model"], map_location=map_location, weights_only=False)
    model_config = checkpoint["model_config"]

    # An explicit model type takes precedence over the one stored in the checkpoint.
    if args_dict["model_type"] is not None:
        model_config["model"] = args_dict["model_type"]

    model_module = importlib.import_module(f"deeprm.model.{model_config['model']}")
    model_cls = model_module.TransformerModel

    # (checkpoint config key, constructor keyword) pairs, in constructor order.
    config_to_kwarg = (
        ("enc_dim", "d_model"),
        ("head", "n_heads"),
        ("lin_dim", "d_ff"),
        ("enc_layer", "n_layers"),
        ("lin_layer", "lin_depth"),
        ("t_act", "t_act"),
        ("lin_act", "lin_act"),
        ("enc_dropout", "encoder_dropout"),
        ("lin_dropout", "lin_dropout"),
        ("kmer_size", "kmer_size"),
        ("signal_size", "signal_size"),
        ("block_len", "block_len"),
        ("seq_len", "seq_len"),
        ("signal_stride", "signal_stride"),
    )
    ctor_kwargs = {kwarg: model_config[cfg_key] for cfg_key, kwarg in config_to_kwarg}
    # dwell_bq_dim is fixed at 3 here rather than read from the checkpoint.
    model = model_cls(**ctor_kwargs, dwell_bq_dim=3)

    # strict=False: key mismatches between checkpoint and model do not raise.
    model.load_state_dict(state_dict=checkpoint["model_state_dict"], strict=False)
    # Empty the checkpoint dict; model_config is a separate dict and stays valid.
    checkpoint.clear()

    model.to(device)
    model.eval()
    return model, model_config, device


def run_inference_stream(args):
    """Preprocess input on the fly and run inference on every emitted chunk.

    Chunks produced by ``stream_prep.stream`` are converted to tensors and
    passed straight to ``_run_chunk_inference``, so no preprocessed data has
    to be staged on disk first. When ``args.stream_save_dir`` is set, the
    preprocessing step is additionally asked to write its chunks to that
    directory.

    Args:
        args: Parsed CLI namespace. Reads ``num_gpu``, ``pod5``, ``bam``,
            ``output``, ``stream_save_dir`` and the ``prep_*`` options; the
            full namespace (as a dict) is forwarded to the model loader and
            to the per-chunk inference helper.

    Returns:
        None

    Raises:
        ValueError: If more than one GPU is requested.
    """
    if args.num_gpu > 1:
        raise ValueError("Stream mode currently supports CPU or a single GPU. Use --gpu 1 or --gpu-pool <one id>.")

    args_dict = vars(args)
    model, _, device = _init_model_for_device(args_dict, rank=0)

    # Translate the "--prep-*" inference options into the attribute names that
    # the preprocessing namespace uses.
    prep_ns = argparse.Namespace(
        pod5=args.pod5,
        bam=args.bam,
        output=args.stream_save_dir or args.output,
        thread=args.prep_thread,
        qcut=args.prep_qcut,
        chunk=args.prep_chunk,
        max_token_len=args.prep_max_token_len,
        sampling=args.prep_sampling,
        boi=args.prep_boi,
        kmer_len=args.prep_kmer_len,
        cb_len=args.prep_cb_len,
        bam_thread=args.prep_bam_thread,
        process_once=args.prep_process_once,
        dwell_shift=args.prep_dwell_shift,
        sig_window=args.prep_sig_window,
        label_div=args.prep_label_div,
        # NOTE(review): 276 == 0x4 | 0x10 | 0x100; presumably a SAM FLAG mask
        # (unmapped / reverse strand / secondary) - confirm against stream_prep.
        filter_flag=276,
    )

    chunk_idx = 0
    processed_any = False

    def _emit_chunk(chunk_arrays):
        """Convert one preprocessed chunk to tensors and run inference on it."""
        nonlocal chunk_idx, processed_any
        # The three per-token scalar features are stacked on a new last axis.
        dwell_bq = np.stack(
            (chunk_arrays["dwell_motor_token"], chunk_arrays["dwell_pore_token"], chunk_arrays["bq_token"]),
            axis=-1,
        )
        chunk_data_cpu = {
            "read_id": torch.as_tensor(chunk_arrays["read_id"]),
            "label_id": torch.as_tensor(chunk_arrays["label_id"]),
            "segment_len": torch.as_tensor(chunk_arrays["segment_len_arr"], dtype=torch.int32),
            "signal_token": torch.as_tensor(chunk_arrays["signal_token"], dtype=torch.float32),
            "kmer_token": torch.as_tensor(chunk_arrays["kmer_token"], dtype=torch.int32),
            "dwell_bq_token": torch.as_tensor(dwell_bq, dtype=torch.float32),
        }
        processed_any = _run_chunk_inference(args_dict, 0, device, model, chunk_idx, chunk_data_cpu) or processed_any
        # The index advances even for skipped chunks so output file numbers
        # keep matching the order in which chunks were emitted.
        chunk_idx += 1

    # Build only the context that will actually be used: mixed precision on
    # CUDA, a no-op context on CPU.
    if device.type == "cuda":
        amp_ctx = autocast(device_type="cuda", enabled=True, cache_enabled=True)
    else:
        amp_ctx = nullcontext()

    with amp_ctx, torch.no_grad():
        stream_prep.stream(
            prep_ns,
            emit_chunk=_emit_chunk,
            write_to_disk=args.stream_save_dir is not None,
            disk_output_path=args.stream_save_dir,
        )

    if not processed_any:
        log.warning("No streamed chunks were processed. No inference outputs were written.")
    return None


def _discover_resume_point(output_dir: str, rank: int) -> int:
pattern = os.path.join(output_dir, f"inference_{rank}_*.npz")
paths = glob.glob(pattern)
Expand Down Expand Up @@ -230,49 +424,12 @@ def inference_worker(rank, args_dict):
Returns:
None
"""
use_gpu = args_dict["num_gpu"] > 0
device = torch.device(f"cuda:{args_dict['gpu_pool'][rank]}") if use_gpu else torch.device("cpu")

if use_gpu:
map_location = {"cuda:0": str(device)}
else:
map_location = "cpu"
save_dict = torch.load(args_dict["model"], map_location=map_location, weights_only=False)
model_config = save_dict["model_config"]

if args_dict["model_type"] is not None:
model_config["model"] = args_dict["model_type"]

dwell_bq_dim = 3
TransformerModel = importlib.import_module(f"deeprm.model.{model_config['model']}").TransformerModel

model = TransformerModel(
d_model=model_config["enc_dim"],
n_heads=model_config["head"],
d_ff=model_config["lin_dim"],
n_layers=model_config["enc_layer"],
lin_depth=model_config["lin_layer"],
t_act=model_config["t_act"],
lin_act=model_config["lin_act"],
encoder_dropout=model_config["enc_dropout"],
lin_dropout=model_config["lin_dropout"],
kmer_size=model_config["kmer_size"],
signal_size=model_config["signal_size"],
block_len=model_config["block_len"],
seq_len=model_config["seq_len"],
signal_stride=model_config["signal_stride"],
dwell_bq_dim=dwell_bq_dim,
)
model, model_config, device = _init_model_for_device(args_dict, rank)

if rank == 0:
total_params = sum(parameter.numel() for parameter in model.parameters())
log.info("Model parameters: %s", f"{total_params:,}")

model.load_state_dict(state_dict=save_dict["model_state_dict"], strict=False)
save_dict.clear()
model.to(device)
model.eval()

resume_from = _discover_resume_point(args_dict["output"], rank) if args_dict["resume"] else 0
if resume_from > 0:
log.info("Rank %d resuming from input shard index %d.", rank, resume_from)
Expand Down Expand Up @@ -355,58 +512,7 @@ def inference_loop(args_dict, rank, device, model, data_loader, start_index=0):
log.warning("Skipping empty input shard at rank=%d chunk=%d.", rank, chunk_idx)
continue

processed_any = True
batch_splits = {k: torch.split(v, args_dict["batch"]) for k, v in chunk_data_cpu.items()}
n_batches = len(batch_splits["label_id"])

pred_buffer = []
label_id_buffer = []
read_id_buffer = []
pred_parts = []
label_parts = []
read_parts = []

for bidx in range(n_batches):
batch_data_cpu = {k: v[bidx] for k, v in batch_splits.items()}
batch_data_gpu = to_device(batch_data_cpu, device)
pred = model(*batch_data_gpu)

if args_dict["output_id"] is not None:
pred = pred[args_dict["output_id"]]

pred_buffer.append(pred.detach().cpu().numpy())
label_id_buffer.append(batch_data_cpu["label_id"].detach().cpu().numpy())
read_id_buffer.append(batch_data_cpu["read_id"].detach().cpu().numpy())

if len(pred_buffer) >= args_dict["flush"]:
_flush_cpu_buffers(
pred_buffer,
label_id_buffer,
read_id_buffer,
pred_parts,
label_parts,
read_parts,
)

_flush_cpu_buffers(
pred_buffer,
label_id_buffer,
read_id_buffer,
pred_parts,
label_parts,
read_parts,
)

if not pred_parts:
log.warning("No predictions were produced for rank=%d chunk=%d.", rank, chunk_idx)
continue

preds = pred_parts[0] if len(pred_parts) == 1 else np.concatenate(pred_parts, axis=0)
label_ids = label_parts[0] if len(label_parts) == 1 else np.concatenate(label_parts, axis=0)
read_ids = read_parts[0] if len(read_parts) == 1 else np.concatenate(read_parts, axis=0)

out_path = f"{args_dict['output']}/inference_{rank}_{chunk_idx}.npz"
np.savez_compressed(out_path, label_id=label_ids, read_id=read_ids, pred=preds)
processed_any = _run_chunk_inference(args_dict, rank, device, model, chunk_idx, chunk_data_cpu) or processed_any

if not processed_any:
log.warning("No input shards were processed for rank=%d. No inference outputs were written.", rank)
Expand Down
Loading
Loading