173 changes: 101 additions & 72 deletions src/scdori/_core/models.py
@@ -109,6 +109,7 @@ def __init__(
self.dim_encoder1 = dim_encoder1
self.dim_encoder2 = dim_encoder2
self.batch_norm = batch_norm
self._cache = {}  # cache of forward-pass tensors, reused during the GRN phase

# ENCODER for RNA
self.encoder_rna = nn.Sequential(
@@ -228,6 +229,7 @@ def forward(
num_cells,
batch_onehot,
phase="warmup_1",
use_cache=True,
):
"""
Forward pass through scDoRI, producing predictions for ATAC, TF, and RNA reconstructions (Phase 1), as well as GRN-based RNA predictions in GRN phase (Phase 2).
@@ -254,95 +256,122 @@
phase : str, optional
Which training phase: "warmup_1", "warmup_2", or "grn".
If phase=="grn", the GRN-based RNA predictions are included.
use_cache : bool, optional
If True, reuse tensors cached from an earlier GRN-phase call with the
same batch sizes instead of recomputing them. Only applies when phase=="grn".

Returns
-------
dict
A dictionary with the following keys:
- "theta": (B, num_topics), the softmaxed topic distribution.
- "mu_theta": (B, num_topics), raw topic logits.
- "preds_atac": (B, num_peaks), predicted peak accessibility.
- "preds_tf": (B, num_tfs), predicted TF expression.
- "mu_nb_tf": (B, num_tfs), TF negative binomial mean = preds_tf * TF library factor.
- "preds_rna": (B, num_genes), predicted RNA expression.
- "mu_nb_rna": (B, num_genes), RNA negative binomial mean = preds_rna * RNA library factor.
- "preds_rna_from_grn": (B, num_genes), optional GRN-based RNA predictions.
- "mu_nb_rna_grn": (B, num_genes), negative binomial mean of GRN-based RNA predictions.
- "library_factor_tf": (B, 1), predicted library factor for TF.
- "library_factor_rna": (B, 1), predicted library factor for RNA.

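Examples
--------
A minimal sketch of a GRN-phase call. Keyword names follow the forward
signature and body above; the input tensors themselves are illustrative
placeholders, not fixtures from this repository:

>>> out = model(
...     rna_input=rna, atac_input=atac,
...     log_lib_rna=log_lib_rna, log_lib_atac=log_lib_atac,
...     tf_input=tf, topic_tf_input=topic_tf,
...     num_cells=num_cells, batch_onehot=batch_onehot,
...     phase="grn", use_cache=True,
... )
>>> out["preds_rna_from_grn"].shape  # (B, num_genes)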
"""

# Decide whether tensors cached from a previous GRN-phase call can be reused
skip_computations = use_cache and (phase == "grn")
# NOTE: the key encodes only the batch sizes, so different minibatches of the
# same size share an entry; reuse assumes a fixed, non-shuffled batch order
cache_key = f"{rna_input.shape[0]}_{atac_input.shape[0]}"

if skip_computations and cache_key in self._cache:
# Reuse cached tensors; detach so no gradients flow back through the
# graph that produced them
cached = self._cache[cache_key]
theta = cached['theta'].detach()
mu_theta = cached['mu_theta'].detach()
preds_atac = cached['preds_atac'].detach()
preds_tf = cached['preds_tf'].detach()
mu_nb_tf = cached['mu_nb_tf'].detach()
preds_rna = cached['preds_rna'].detach()
mu_nb_rna = cached['mu_nb_rna'].detach()
topic_peak_denoised1 = cached['topic_peak_denoised1'].detach()
grn_atac_activator = cached['grn_atac_activator'].detach()
grn_atac_repressor = cached['grn_atac_repressor'].detach()
library_factor_tf = cached['library_factor_tf'].detach()
library_factor_rna = cached['library_factor_rna'].detach()
else:
# 1) ENCODE => topic distribution
theta, mu_theta = self.encode(rna_input, atac_input, log_lib_rna, log_lib_atac, num_cells)

# 2) ATAC decoding
batch_factor_atac = torch.mm(batch_onehot, self.atac_batch_factor)
preds_atac = torch.mm(theta, self.topic_peak_decoder) + batch_factor_atac
preds_atac = self.atac_batch_norm(preds_atac)
preds_atac = F.softmax(preds_atac, dim=-1)

# 3) TF decoding => library factor
batch_factor_tf = torch.mm(batch_onehot, self.tf_batch_factor)
tf_logits = torch.mm(theta, self.topic_tf_decoder) + batch_factor_tf
tf_logits = self.tf_batch_norm(tf_logits)
preds_tf = F.softmax(tf_logits, dim=-1)
# library MLP for TF
library_factor_tf = self.tf_library_factor(tf_input)
mu_nb_tf = preds_tf * library_factor_tf

# 4) RNA from ATAC => library factor
topic_peak_denoised1 = F.softmax(self.topic_peak_decoder, dim=1)
topic_peak_min, _ = torch.min(topic_peak_denoised1, dim=0, keepdim=True)
topic_peak_max, _ = torch.max(topic_peak_denoised1, dim=0, keepdim=True)
topic_peak_denoised = (topic_peak_denoised1 - topic_peak_min) / (topic_peak_max - topic_peak_min + 1e-8)
gene_peak = (self.gene_peak_factor_learnt * self.gene_peak_factor_fixed).T
batch_factor_rna = torch.mm(batch_onehot, self.rna_batch_factor)
topicxgene = torch.mm(topic_peak_denoised, gene_peak)
rna_logits = torch.mm(theta, topicxgene) + batch_factor_rna
rna_logits = self.rna_batch_norm(rna_logits)
preds_rna = F.softmax(rna_logits, dim=-1)

# library MLP for RNA
library_factor_rna = self.rna_library_factor(rna_input)
mu_nb_rna = preds_rna * library_factor_rna

# Compute GRN link matrices that can be cached
if phase == "grn":
grn_atac_activator = torch.zeros(size=(self.num_topics, self.num_tfs, self.num_genes)).to(self.device)
grn_atac_repressor = torch.zeros(size=(self.num_topics, self.num_tfs, self.num_genes)).to(self.device)

# Calculate ATAC-based TF–gene links (activator/repressor) for each topic
for topic in range(self.num_topics):
topic_gene_peak = topic_peak_denoised1[topic].unsqueeze(1) * gene_peak
G_topic = torch.mm(self.tf_binding_matrix_activator.T, topic_gene_peak)
G_topic = G_topic / (gene_peak.sum(dim=0, keepdim=True) + 1e-7)
grn_atac_activator[topic] = G_topic

topic_gene_peak = (1 / (topic_peak_denoised1[topic] + 1e-20)).unsqueeze(1) * gene_peak
G_topic = torch.mm(self.tf_binding_matrix_repressor.T, topic_gene_peak)
G_topic = G_topic / (gene_peak.sum(dim=0, keepdim=True) + 1e-7)
grn_atac_repressor[topic] = G_topic

# Cache detached copies for reuse on later GRN-phase calls with the same
# batch sizes; detaching lets this pass's autograd graph be freed
self._cache[cache_key] = {
'theta': theta.detach().clone(),
'mu_theta': mu_theta.detach().clone(),
'preds_atac': preds_atac.detach().clone(),
'preds_tf': preds_tf.detach().clone(),
'mu_nb_tf': mu_nb_tf.detach().clone(),
'preds_rna': preds_rna.detach().clone(),
'mu_nb_rna': mu_nb_rna.detach().clone(),
'topic_peak_denoised1': topic_peak_denoised1.detach().clone(),
'grn_atac_activator': grn_atac_activator.detach().clone(),
'grn_atac_repressor': grn_atac_repressor.detach().clone(),
'library_factor_tf': library_factor_tf.detach().clone(),
'library_factor_rna': library_factor_rna.detach().clone(),
}

# GRN readout: recomputed on every call because it depends on the trainable
# activator/repressor weights, so it is never served from the cache
if phase == "grn":
C = torch.zeros(size=(self.num_topics, self.num_genes)).to(self.device)
tf_expression_input = topic_tf_input.to(self.device)
for topic in range(self.num_topics):
gene_atac_activator_topic = grn_atac_activator[topic] / (grn_atac_activator[topic].max() + 1e-15)
gene_atac_repressor_topic = grn_atac_repressor[topic] / (grn_atac_repressor[topic].min() + 1e-15)

G_act = gene_atac_activator_topic * F.relu(self.tf_gene_topic_activator_grn[topic])
G_rep = gene_atac_repressor_topic * -1 * F.relu(self.tf_gene_topic_repressor_grn[topic])

C[topic] = torch.mm(tf_expression_input[topic].unsqueeze(0), G_act).squeeze(0) + torch.mm(tf_expression_input[topic].unsqueeze(0), G_rep).squeeze(0)

batch_factor_rna_grn = torch.mm(batch_onehot, self.rna_grn_batch_factor)
preds_rna_from_grn = torch.mm(theta, C)
preds_rna_from_grn = preds_rna_from_grn + batch_factor_rna_grn
preds_rna_from_grn = self.rna_grn_batch_norm(preds_rna_from_grn)
preds_rna_from_grn = F.softmax(preds_rna_from_grn, dim=1)
else:
preds_rna_from_grn = torch.zeros_like(preds_rna)
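
Because `cache_key` encodes only the batch sizes, any later batch with the same
sizes is served the first batch's tensors. The toy module below, a self-contained
sketch that is not part of scDoRI, shows why `use_cache=True` is only safe under
a fixed, non-shuffled batch order:

import torch
import torch.nn as nn

class CachedSquare(nn.Module):
    """Toy stand-in for the size-keyed cache used above."""

    def __init__(self):
        super().__init__()
        self._cache = {}

    def forward(self, x, use_cache=True):
        key = f"{x.shape[0]}"  # batch size only, as in the scDoRI cache key
        if use_cache and key in self._cache:
            return self._cache[key].detach()  # cache hit: reuse, no gradients
        out = x ** 2
        self._cache[key] = out.detach().clone()
        return out

m = CachedSquare()
print(m(torch.tensor([1.0, 2.0])))                   # tensor([1., 4.]), computed and cached
print(m(torch.tensor([3.0, 4.0])))                   # tensor([1., 4.]), stale same-size hit
print(m(torch.tensor([3.0, 4.0]), use_cache=False))  # tensor([9., 16.]), recomputed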
