diff --git a/src/scdori/_core/models.py b/src/scdori/_core/models.py
index 8f03410..f2e9be9 100644
--- a/src/scdori/_core/models.py
+++ b/src/scdori/_core/models.py
@@ -109,6 +109,7 @@ def __init__(
         self.dim_encoder1 = dim_encoder1
         self.dim_encoder2 = dim_encoder2
         self.batch_norm = batch_norm
+        self._cache = {}  # detached forward-pass tensors cached for reuse, keyed by batch size
 
         # ENCODER for RNA
         self.encoder_rna = nn.Sequential(
@@ -228,6 +229,7 @@ def forward(
         num_cells,
         batch_onehot,
         phase="warmup_1",
+        use_cache=False,
     ):
         """
         Forward pass through scDoRI, producing predictions for ATAC, TF, and RNA reconstructions (Phase 1), as well as GRN-based RNA predictions in GRN phase (Phase 2).
@@ -254,95 +256,122 @@ def forward(
         phase : str, optional
            Which training phase: "warmup_1", "warmup_2", or "grn". If phase=="grn", the GRN-based RNA predictions are included.
+        use_cache : bool, optional
+            If True and phase=="grn", reuse cached encoder/decoder outputs instead of
+            recomputing them. Cache entries are keyed by batch size only, so two
+            different batches of the same size would collide; enable this only when
+            the identical minibatch is passed repeatedly. Defaults to False.
 
         Returns
         -------
         dict
             A dictionary with the following keys:
             - "theta": (B, num_topics), the softmaxed topic distribution.
             - "mu_theta": (B, num_topics), raw topic logits.
             - "preds_atac": (B, num_peaks), predicted peak accessibility.
             - "preds_tf": (B, num_tfs), predicted TF expression.
             - "mu_nb_tf": (B, num_tfs), TF negative binomial mean = preds_tf * TF library factor.
             - "preds_rna": (B, num_genes), predicted RNA expression.
             - "mu_nb_rna": (B, num_genes), RNA negative binomial mean = preds_rna * RNA library factor.
             - "preds_rna_from_grn": (B, num_genes), optional GRN-based RNA predictions.
             - "mu_nb_rna_grn": (B, num_genes), negative binomial mean of GRN-based RNA predictions.
             - "library_factor_tf": (B, 1), predicted library factor for TF.
             - "library_factor_rna": (B, 1), predicted library factor for RNA.
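+
+        Examples
+        --------
+        A minimal sketch of a GRN-phase call with cache reuse; the input tensors
+        and their shapes are illustrative only::
+
+            out = model(rna, atac, tf_exp, topic_tf, log_lib_rna, log_lib_atac,
+                        num_cells, batch_onehot, phase="grn", use_cache=True)
+            out["preds_rna_from_grn"]  # (B, num_genes)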
""" - # 1) ENCODE => topic distribution - theta, mu_theta = self.encode(rna_input, atac_input, log_lib_rna, log_lib_atac, num_cells) - - # 2) ATAC decoding - batch_factor_atac = torch.mm(batch_onehot, self.atac_batch_factor) - preds_atac = torch.mm(theta, self.topic_peak_decoder) + batch_factor_atac - preds_atac = self.atac_batch_norm(preds_atac) - preds_atac = F.softmax(preds_atac, dim=-1) - - # 3) TF decoding => library factor - batch_factor_tf = torch.mm(batch_onehot, self.tf_batch_factor) - tf_logits = torch.mm(theta, self.topic_tf_decoder) + batch_factor_tf - tf_logits = self.tf_batch_norm(tf_logits) - preds_tf = F.softmax(tf_logits, dim=-1) - # library MLP for TF - library_factor_tf = self.tf_library_factor(tf_input) - mu_nb_tf = preds_tf * library_factor_tf - - # 4) RNA from ATAC => library factor - topic_peak_denoised1 = F.softmax(self.topic_peak_decoder, dim=1) - topic_peak_min, _ = torch.min(topic_peak_denoised1, dim=0, keepdim=True) - topic_peak_max, _ = torch.max(topic_peak_denoised1, dim=0, keepdim=True) - topic_peak_denoised = (topic_peak_denoised1 - topic_peak_min) / (topic_peak_max - topic_peak_min + 1e-8) - gene_peak = (self.gene_peak_factor_learnt * self.gene_peak_factor_fixed).T - batch_factor_rna = torch.mm(batch_onehot, self.rna_batch_factor) - topicxgene = torch.mm(topic_peak_denoised, gene_peak) - rna_logits = torch.mm(theta, topicxgene) + batch_factor_rna - rna_logits = self.rna_batch_norm(rna_logits) - preds_rna = F.softmax(rna_logits, dim=-1) - - topic_peak_denoised1 = nn.Softmax(dim=1)(self.topic_peak_decoder) - - # library MLP for RNA - library_factor_rna = self.rna_library_factor(rna_input) - mu_nb_rna = preds_rna * library_factor_rna - - # 5) GRN => preds_rna_from_grn if phase=="grn" + skip_computations = use_cache & (phase == "grn") + cache_key = f"{rna_input.shape[0]}_{atac_input.shape[0]}" # Use batch sizes as cache key + + if skip_computations and cache_key in self._cache: + # Retrieve cached values + # Load detached tensors to avoid unnecessary computation + # and to save memory + cached = self._cache[cache_key] + theta = cached['theta'].detach() + mu_theta = cached['mu_theta'].detach() + preds_atac = cached['preds_atac'].detach() + preds_tf = cached['preds_tf'].detach() + mu_nb_tf = cached['mu_nb_tf'].detach() + preds_rna = cached['preds_rna'].detach() + mu_nb_rna = cached['mu_nb_rna'].detach() + topic_peak_denoised1 = cached['topic_peak_denoised1'].detach() + grn_atac_activator = cached['grn_atac_activator'].detach() + grn_atac_repressor = cached['grn_atac_repressor'].detach() + library_factor_tf = cached['library_factor_tf'].detach() + library_factor_rna = cached['library_factor_rna'].detach() + else: + # Compute all values + theta, mu_theta = self.encode(rna_input, atac_input, log_lib_rna, log_lib_atac, num_cells) + + # 2) ATAC decoding + batch_factor_atac = torch.mm(batch_onehot, self.atac_batch_factor) + preds_atac = torch.mm(theta, self.topic_peak_decoder) + batch_factor_atac + preds_atac = self.atac_batch_norm(preds_atac) + preds_atac = F.softmax(preds_atac, dim=-1) + + # 3) TF decoding => library factor + batch_factor_tf = torch.mm(batch_onehot, self.tf_batch_factor) + tf_logits = torch.mm(theta, self.topic_tf_decoder) + batch_factor_tf + tf_logits = self.tf_batch_norm(tf_logits) + preds_tf = F.softmax(tf_logits, dim=-1) + # library MLP for TF + library_factor_tf = self.tf_library_factor(tf_input) + mu_nb_tf = preds_tf * library_factor_tf + + # 4) RNA from ATAC => library factor + topic_peak_denoised1 = F.softmax(self.topic_peak_decoder, dim=1) + 
+            topic_peak_min, _ = torch.min(topic_peak_denoised1, dim=0, keepdim=True)
+            topic_peak_max, _ = torch.max(topic_peak_denoised1, dim=0, keepdim=True)
+            topic_peak_denoised = (topic_peak_denoised1 - topic_peak_min) / (topic_peak_max - topic_peak_min + 1e-8)
+            gene_peak = (self.gene_peak_factor_learnt * self.gene_peak_factor_fixed).T
+            batch_factor_rna = torch.mm(batch_onehot, self.rna_batch_factor)
+            topicxgene = torch.mm(topic_peak_denoised, gene_peak)
+            rna_logits = torch.mm(theta, topicxgene) + batch_factor_rna
+            rna_logits = self.rna_batch_norm(rna_logits)
+            preds_rna = F.softmax(rna_logits, dim=-1)
+
+            # library MLP for RNA
+            library_factor_rna = self.rna_library_factor(rna_input)
+            mu_nb_rna = preds_rna * library_factor_rna
+
+            # 5) GRN inputs that can be cached (computed only when phase=="grn")
+            if phase == "grn":
+                grn_atac_activator = torch.zeros(self.num_topics, self.num_tfs, self.num_genes, device=self.device)
+                grn_atac_repressor = torch.zeros(self.num_topics, self.num_tfs, self.num_genes, device=self.device)
+
+                # Calculate ATAC-based TF–gene links (activator/repressor) for each topic
+                for topic in range(self.num_topics):
+                    topic_gene_peak = topic_peak_denoised1[topic].unsqueeze(1) * gene_peak
+                    G_topic = torch.mm(self.tf_binding_matrix_activator.T, topic_gene_peak)
+                    G_topic = G_topic / (gene_peak.sum(dim=0, keepdim=True) + 1e-7)
+                    grn_atac_activator[topic] = G_topic
+
+                    topic_gene_peak = (1 / (topic_peak_denoised1[topic] + 1e-20)).unsqueeze(1) * gene_peak
+                    G_topic = torch.mm(self.tf_binding_matrix_repressor.T, topic_gene_peak)
+                    G_topic = G_topic / (gene_peak.sum(dim=0, keepdim=True) + 1e-7)
+                    grn_atac_repressor[topic] = G_topic
+
+                if use_cache:
+                    # Store detached copies so cache entries do not keep this pass's
+                    # autograd graph alive.
+                    self._cache[cache_key] = {
+                        "theta": theta.detach().clone(),
+                        "mu_theta": mu_theta.detach().clone(),
+                        "preds_atac": preds_atac.detach().clone(),
+                        "preds_tf": preds_tf.detach().clone(),
+                        "mu_nb_tf": mu_nb_tf.detach().clone(),
+                        "preds_rna": preds_rna.detach().clone(),
+                        "mu_nb_rna": mu_nb_rna.detach().clone(),
+                        "topic_peak_denoised1": topic_peak_denoised1.detach().clone(),
+                        "grn_atac_activator": grn_atac_activator.detach().clone(),
+                        "grn_atac_repressor": grn_atac_repressor.detach().clone(),
+                        "library_factor_tf": library_factor_tf.detach().clone(),
+                        "library_factor_rna": library_factor_rna.detach().clone(),
+                    }
+
+        # GRN predictions depend on the learnable GRN weights, so they are always
+        # recomputed and never served from the cache
         if phase == "grn":
-            grn_atac_activator = torch.empty(size=(self.num_topics, self.num_tfs, self.num_genes)).to(self.device)
-            grn_atac_repressor = torch.empty(size=(self.num_topics, self.num_tfs, self.num_genes)).to(self.device)
-
-            # Calculate ATAC-based TF–gene links (activator/repressor) for each topic
-            for topic in range(self.num_topics):
-                topic_gene_peak = topic_peak_denoised1[topic][:, None] * gene_peak
-                G_topic = self.tf_binding_matrix_activator.T @ topic_gene_peak
-                G_topic = G_topic / (gene_peak.sum(axis=0, keepdims=True) + 1e-7)
-                grn_atac_activator[topic] = G_topic
-
-                topic_gene_peak = (1 / (topic_peak_denoised1[topic] + 1e-20))[:, None] * gene_peak
-                G_topic = self.tf_binding_matrix_repressor.T @ topic_gene_peak
-                G_topic = G_topic / (gene_peak.sum(axis=0, keepdims=True) + 1e-7)
-                grn_atac_repressor[topic] = G_topic
-
-            C = torch.empty(size=(self.num_topics, self.num_genes)).to(self.device)
+            C = torch.zeros(self.num_topics, self.num_genes, device=self.device)
             tf_expression_input = topic_tf_input.to(self.device)
             for topic in range(self.num_topics):
                 gene_atac_activator_topic = grn_atac_activator[topic] / (grn_atac_activator[topic].max() + 1e-15)
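+                # Repressor scores grow as accessibility falls (they are built from
+                # inverse peak loadings), so they are scaled by their smallest entry,
+                # the analogue of the max-scaling applied to activator scores.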
                gene_atac_repressor_topic = grn_atac_repressor[topic] / (grn_atac_repressor[topic].min() + 1e-15)
-                G_act = gene_atac_activator_topic * torch.nn.functional.relu(self.tf_gene_topic_activator_grn[topic])
-                G_rep = (
-                    gene_atac_repressor_topic * -1 * torch.nn.functional.relu(self.tf_gene_topic_repressor_grn[topic])
-                )
+                G_act = gene_atac_activator_topic * F.relu(self.tf_gene_topic_activator_grn[topic])
+                G_rep = -gene_atac_repressor_topic * F.relu(self.tf_gene_topic_repressor_grn[topic])
 
-                C[topic] = tf_expression_input[topic] @ G_act + tf_expression_input[topic] @ G_rep
+                # (1, num_tfs) @ (num_tfs, num_genes) -> (1, num_genes); activator and
+                # repressor links are summed before the product (distributivity).
+                C[topic] = torch.mm(tf_expression_input[topic].unsqueeze(0), G_act + G_rep).squeeze(0)
 
             batch_factor_rna_grn = torch.mm(batch_onehot, self.rna_grn_batch_factor)
             preds_rna_from_grn = torch.mm(theta, C)
             preds_rna_from_grn = preds_rna_from_grn + batch_factor_rna_grn
             preds_rna_from_grn = self.rna_grn_batch_norm(preds_rna_from_grn)
-            preds_rna_from_grn = nn.Softmax(dim=1)(preds_rna_from_grn)
+            preds_rna_from_grn = F.softmax(preds_rna_from_grn, dim=1)
         else:
             preds_rna_from_grn = torch.zeros_like(preds_rna)
 
diff --git a/src/scdori/_core/train_grn.py b/src/scdori/_core/train_grn.py
index 5ca57e2..4f6fe56 100644
--- a/src/scdori/_core/train_grn.py
+++ b/src/scdori/_core/train_grn.py
@@ -96,14 +96,14 @@ def get_tf_expression(
     """
     Compute TF expression per topic.
 
-    If `tf_expression_mode` is "True", this function computes the mean TF expression
+    If `tf_expression_mode` is "actual", this function computes the mean TF expression
     for the top-k cells in each topic. Otherwise, it uses a normalized topic-TF
     decoder matrix from the model.
 
     Parameters
     ----------
     tf_expression_mode : str
-        Mode for TF expression. "True" calculates per-topic TF expression from top-k cells,
+        Mode for TF expression. "actual" calculates per-topic TF expression from top-k cells,
         "latent" uses the topic-TF decoder matrix.
     model : torch.nn.Module
         The scDoRI model containing encoder and decoder modules.
@@ -129,7 +129,7 @@ def get_tf_expression(
     torch.Tensor
         A (num_topics x num_tfs) tensor of TF expression values for each topic.
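+
+    Notes
+    -----
+    In both modes, the per-topic TF profile is min-max normalised to [0, 1] and
+    entries below ``config_file.tf_expression_clamp`` are zeroed before the tensor
+    is moved to ``device``.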
""" - if tf_expression_mode == "True": + if tf_expression_mode == "actual": latent_all_torch = get_latent_topics( model, device, train_loader, rna_anndata, atac_anndata, num_cells, tf_indices, encoding_batch_onehot ) @@ -144,25 +144,23 @@ def get_tf_expression( topic_tf = np.array([rna_tf_vals[top_k_indices[:, t], :].mean(axis=0) for t in range(model.num_topics)]) topic_tf = torch.from_numpy(topic_tf) - preds_tf_denoised_min, _ = torch.min(topic_tf, dim=1, keepdim=True) - preds_tf_denoised_max, _ = torch.max(topic_tf, dim=1, keepdim=True) - topic_tf = (topic_tf - preds_tf_denoised_min) / (preds_tf_denoised_max - preds_tf_denoised_min + 1e-9) - topic_tf[topic_tf < config_file.tf_expression_clamp] = 0 - topic_tf = topic_tf.to(device) - return topic_tf + preds_tf_denoised_min = topic_tf.min(dim=1, keepdim=True)[0] + preds_tf_denoised_max = topic_tf.max(dim=1, keepdim=True)[0] + normalized_tf = (topic_tf - preds_tf_denoised_min) / (preds_tf_denoised_max - preds_tf_denoised_min + 1e-9) + topic_tf = normalized_tf.clamp(min=config_file.tf_expression_clamp) + return topic_tf.to(device) else: import torch.nn as nn # Ensure this import is available if using nn.Softmax topic_tf = nn.Softmax(dim=1)(model.decoder.topic_tf_decoder.detach().cpu()) - preds_tf_denoised_min, _ = torch.min(topic_tf, dim=1, keepdim=True) - preds_tf_denoised_max, _ = torch.max(topic_tf, dim=1, keepdim=True) + preds_tf_denoised_min = topic_tf.min(dim=1, keepdim=True)[0] + preds_tf_denoised_max = topic_tf.max(dim=1, keepdim=True)[0] tf_normalised = (topic_tf - preds_tf_denoised_min) / (preds_tf_denoised_max - preds_tf_denoised_min + 1e-9) - tf_normalised[tf_normalised < config_file.tf_expression_clamp] = 0 - topic_tf = tf_normalised.to(device) - return topic_tf - + tf_normalised = tf_normalised.clamp(min=config_file.tf_expression_clamp) + return tf_normalised.to(device) +@torch.no_grad() def compute_eval_loss_grn( model, device, @@ -231,85 +229,85 @@ def compute_eval_loss_grn( config_file, ) - with torch.no_grad(): - for batch_data in eval_loader: - cell_indices = batch_data[0].to(device) - B = cell_indices.shape[0] - - input_matrix, tf_exp, library_size_value, num_cells_value, input_batch = create_minibatch( - device, cell_indices, rna_anndata, atac_anndata, num_cells, tf_indices, encoding_batch_onehot - ) - rna_input = input_matrix[:, : model.num_genes] - atac_input = input_matrix[:, model.num_genes :] - log_lib_rna = library_size_value[:, 0].reshape(-1, 1) - log_lib_atac = library_size_value[:, 1].reshape(-1, 1) - - out = model( - rna_input, - atac_input, - tf_exp, - topic_tf_input, - log_lib_rna, - log_lib_atac, - num_cells_value, - input_batch, - phase="grn", - ) - preds_atac = out["preds_atac"] - mu_nb_tf = out["mu_nb_tf"] - mu_nb_rna = out["mu_nb_rna"] - mu_nb_rna_grn = out["mu_nb_rna_grn"] - - criterion_poisson = torch.nn.PoissonNLLLoss(log_input=False, reduction="sum") - library_factor_peak = torch.exp(log_lib_atac.view(B, 1)) - preds_poisson = preds_atac * library_factor_peak - loss_atac = criterion_poisson(preds_poisson, atac_input) - - alpha_tf = torch.nn.functional.softplus(model.tf_alpha_nb).repeat(B, 1) - nb_tf_ll = log_nb_positive(tf_exp, mu_nb_tf, alpha_tf).sum(dim=1).mean() - loss_tf = -nb_tf_ll + criterion_poisson = torch.nn.PoissonNLLLoss(log_input=False, reduction="sum") - alpha_rna = torch.nn.functional.softplus(model.rna_alpha_nb).repeat(B, 1) - nb_rna_ll = log_nb_positive(rna_input, mu_nb_rna, alpha_rna).sum(dim=1).mean() - loss_rna = -nb_rna_ll - - nb_rna_grn_ll = log_nb_positive(rna_input, 
+    criterion_poisson = torch.nn.PoissonNLLLoss(log_input=False, reduction="sum")
+
+    for batch_data in eval_loader:
+        cell_indices = batch_data[0].to(device)
+        B = cell_indices.shape[0]
+
+        input_matrix, tf_exp, library_size_value, num_cells_value, input_batch = create_minibatch(
+            device, cell_indices, rna_anndata, atac_anndata, num_cells, tf_indices, encoding_batch_onehot
+        )
+        rna_input = input_matrix[:, : model.num_genes]
+        atac_input = input_matrix[:, model.num_genes :]
+        log_lib_rna = library_size_value[:, 0].reshape(-1, 1)
+        log_lib_atac = library_size_value[:, 1].reshape(-1, 1)
+
+        out = model(
+            rna_input,
+            atac_input,
+            tf_exp,
+            topic_tf_input,
+            log_lib_rna,
+            log_lib_atac,
+            num_cells_value,
+            input_batch,
+            phase="grn",
+        )
+        preds_atac = out["preds_atac"]
+        mu_nb_tf = out["mu_nb_tf"]
+        mu_nb_rna = out["mu_nb_rna"]
+        mu_nb_rna_grn = out["mu_nb_rna_grn"]
+
+        library_factor_peak = torch.exp(log_lib_atac.view(B, 1))
+        preds_poisson = preds_atac * library_factor_peak
+        loss_atac = criterion_poisson(preds_poisson, atac_input)
+
+        alpha_tf = torch.nn.functional.softplus(model.tf_alpha_nb).repeat(B, 1)
+        nb_tf_ll = log_nb_positive(tf_exp, mu_nb_tf, alpha_tf).sum(dim=1).mean()
+        loss_tf = -nb_tf_ll
+
+        alpha_rna = torch.nn.functional.softplus(model.rna_alpha_nb).repeat(B, 1)
+        nb_rna_ll = log_nb_positive(rna_input, mu_nb_rna, alpha_rna).sum(dim=1).mean()
+        loss_rna = -nb_rna_ll
+
+        nb_rna_grn_ll = log_nb_positive(rna_input, mu_nb_rna_grn, alpha_rna).sum(dim=1).mean()
+        loss_rna_grn = -nb_rna_grn_ll
+
+        l1_norm_tf = torch.norm(model.topic_tf_decoder.data, p=1)
+        l2_norm_tf = torch.norm(model.topic_tf_decoder.data, p=2)
+        l1_norm_peak = torch.norm(model.topic_peak_decoder.data, p=1)
+        l2_norm_peak = torch.norm(model.topic_peak_decoder.data, p=2)
+        l1_norm_gene_peak = torch.norm(model.gene_peak_factor_learnt.data, p=1)
+        l2_norm_gene_peak = torch.norm(model.gene_peak_factor_learnt.data, p=2)
+        l1_norm_grn_activator = torch.norm(model.tf_gene_topic_activator_grn.data, p=1)
+        l1_norm_grn_repressor = torch.norm(model.tf_gene_topic_repressor_grn.data, p=1)
+
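+        # The same L1/L2 penalties applied during training are added here so that
+        # evaluation and training losses remain directly comparable.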
+        loss_norm = (
+            config_file.l1_penalty_topic_tf * l1_norm_tf
+            + config_file.l2_penalty_topic_tf * l2_norm_tf
+            + config_file.l1_penalty_topic_peak * l1_norm_peak
+            + config_file.l2_penalty_topic_peak * l2_norm_peak
+            + config_file.l1_penalty_gene_peak * l1_norm_gene_peak
+            + config_file.l2_penalty_gene_peak * l2_norm_gene_peak
+            + config_file.l1_penalty_grn_activator * l1_norm_grn_activator
+            + config_file.l1_penalty_grn_repressor * l1_norm_grn_repressor
+        )
+
+        total_loss = (
+            config_file.weight_atac_grn * loss_atac
+            + config_file.weight_tf_grn * loss_tf
+            + config_file.weight_rna_grn * loss_rna
+            + config_file.weight_rna_from_grn * loss_rna_grn
+            + loss_norm
+        )
+
+        running_loss += total_loss.item()
+        running_loss_atac += loss_atac.item()
+        running_loss_tf += loss_tf.item()
+        running_loss_rna += loss_rna.item()
+        running_loss_rna_grn += loss_rna_grn.item()
+        nbatch += 1
 
     eval_loss = running_loss / max(1, nbatch)
     eval_loss_atac = running_loss_atac / max(1, nbatch)
@@ -366,25 +364,10 @@ def train_model_grn(
     torch.nn.Module
         The trained model after the GRN phase completes or early stopping occurs.
     """
-    if not config_file.update_encoder_in_grn:
-        set_encoder_frozen(model, freeze=True)
-    else:
-        set_encoder_frozen(model, freeze=False)
-
-    if not config_file.update_peak_gene_in_grn:
-        set_peak_gene_frozen(model, freeze=True)
-    else:
-        set_peak_gene_frozen(model, freeze=False)
-
-    if not config_file.update_topic_peak_in_grn:
-        set_topic_peak_frozen(model, freeze=True)
-    else:
-        set_topic_peak_frozen(model, freeze=False)
-
-    if not config_file.update_topic_tf_in_grn:
-        set_topic_tf_frozen(model, freeze=True)
-    else:
-        set_topic_tf_frozen(model, freeze=False)
+    set_encoder_frozen(model, freeze=not config_file.update_encoder_in_grn)
+    set_peak_gene_frozen(model, freeze=not config_file.update_peak_gene_in_grn)
+    set_topic_peak_frozen(model, freeze=not config_file.update_topic_peak_in_grn)
+    set_topic_tf_frozen(model, freeze=not config_file.update_topic_tf_in_grn)
 
     optimizer_grn = torch.optim.Adam(
         filter(lambda p: p.requires_grad, model.parameters()), lr=config_file.learning_rate_grn
     )
@@ -395,7 +378,7 @@ def train_model_grn(
     max_val_patience = config_file.grn_val_patience
 
     topic_tf_input = None
-    if config_file.tf_expression_mode == "True":
+    if config_file.tf_expression_mode == "actual":
         topic_tf_input = get_tf_expression(
             config_file.tf_expression_mode,
             model,
@@ -443,10 +426,8 @@ def train_model_grn(
            )
            rna_input = input_matrix[:, : model.num_genes]
            atac_input = input_matrix[:, model.num_genes :]
-            tf_input = tf_exp
            log_lib_rna = library_size_value[:, 0].reshape(-1, 1)
            log_lib_atac = library_size_value[:, 1].reshape(-1, 1)
-            batch_onehot = input_batch
 
            if config_file.tf_expression_mode == "latent":
                topic_tf_input = get_tf_expression(
@@ -465,12 +446,12 @@ def train_model_grn(
            out = model(
                rna_input,
                atac_input,
-                tf_input,
+                tf_exp,
                topic_tf_input,
                log_lib_rna,
                log_lib_atac,
                num_cells_value,
-                batch_onehot,
+                input_batch,
                phase="grn",
            )
            preds_atac = out["preds_atac"]
@@ -484,7 +465,7 @@ def train_model_grn(
            loss_atac = criterion_poisson(preds_poisson, atac_input)
 
            alpha_tf = torch.nn.functional.softplus(model.tf_alpha_nb).repeat(B, 1)
-            nb_tf_ll = log_nb_positive(tf_input, mu_nb_tf, alpha_tf).sum(dim=1).mean()
+            nb_tf_ll = log_nb_positive(tf_exp, mu_nb_tf, alpha_tf).sum(dim=1).mean()
            loss_tf = -nb_tf_ll
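+            # softplus keeps the learned negative binomial alpha parameters
+            # strictly positive.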
 
            alpha_rna = torch.nn.functional.softplus(model.rna_alpha_nb).repeat(B, 1)