Changes from all commits (33 commits)
7633453
continual learning
Jan 22, 2023
a6b34d7
Delete trvae_model.py
MohammedZidane Jan 26, 2023
cfbf601
Add files via upload
MohammedZidane Jan 26, 2023
3b5e928
Update _base.py
MohammedZidane Jan 26, 2023
3b008d3
Update trainer.py
MohammedZidane Jan 26, 2023
207e389
Update trainer.py
MohammedZidane Jan 31, 2023
8db77f7
Delete EWC_pancreas.ipynb
MohammedZidane Jan 31, 2023
4327e6e
Delete LR_EWC_Pancreas.ipynb
MohammedZidane Jan 31, 2023
104aaef
Delete latent_replay_pancreas.ipynb
MohammedZidane Jan 31, 2023
370d5ad
Delete rehearsal_pancreas.ipynb
MohammedZidane Jan 31, 2023
c6b2d2e
Delete standared.ipynb
MohammedZidane Jan 31, 2023
a0ef00b
Add files via upload
MohammedZidane Jan 31, 2023
3f93067
Delete rehearsal_PBMC.ipynb
MohammedZidane Feb 1, 2023
0ad1900
Add files via upload
MohammedZidane Feb 1, 2023
218915c
Add files via upload
MohammedZidane Feb 1, 2023
1283f0e
Delete surgery_pancreas.ipynb
MohammedZidane Feb 1, 2023
a9cdbcf
Add files via upload
MohammedZidane Feb 1, 2023
0d0dd49
Add files via upload
MohammedZidane Feb 9, 2023
f26da54
Add files via upload
MohammedZidane Feb 22, 2023
7388b5b
Add files via upload
MohammedZidane Feb 22, 2023
e0f7d31
Update _base.py
MohammedZidane Feb 23, 2023
a813fcf
Delete implementations_compared_pancreas.ipynb
MohammedZidane Feb 28, 2023
514293c
Delete EWC_PBMC.ipynb
MohammedZidane Feb 28, 2023
a74747e
Delete LR_EWC_PBMC.ipynb
MohammedZidane Feb 28, 2023
a41ddc3
Delete latent_replay_PBMC.ipynb
MohammedZidane Feb 28, 2023
548e02d
Add files via upload
MohammedZidane Feb 28, 2023
b777d4d
Add files via upload
MohammedZidane Mar 3, 2023
91f8486
Update trainer.py
MohammedZidane Mar 3, 2023
052f8fa
Update _base.py
MohammedZidane Mar 13, 2023
a4dde2d
Create Readme
MohammedZidane Mar 21, 2023
b4d0f1e
Update modules.py
MohammedZidane Mar 27, 2023
23bd0b0
Update unsupervised.py
MohammedZidane May 10, 2023
2619a66
Update trainer.py
MohammedZidane Oct 27, 2023
87 changes: 75 additions & 12 deletions scarches/models/base/_base.py
@@ -12,6 +12,8 @@

from ._utils import _validate_var_names

from ...trainers.trvae.trainer import Trainer


class BaseMixin:
""" Adapted from
@@ -185,9 +187,12 @@ class SurgeryMixin:
def load_query_data(
cls,
adata: AnnData,
reference_model: Union[str, 'Model'],
freeze: bool = True,
freeze_expression: bool = True,
learning_approach: str = None,
model: Union[str, 'Model'] = None,
ID: int = 0,
remove_dropout: bool = True,
**kwargs
):
@@ -197,7 +202,7 @@ def load_query_data(
----------
adata
Query anndata object.
reference_model
model
A model to expand or a path to a model folder.
freeze: Boolean
If 'True' freezes every part of the network except the first layers of encoder/decoder.
@@ -213,14 +218,72 @@ def load_query_data(
new_model
New model to train on query data.
"""
if isinstance(reference_model, str):
attr_dict, model_state_dict, var_names = cls._load_params(reference_model)
adata = _validate_var_names(adata, var_names)
else:
attr_dict = reference_model._get_public_attributes()
model_state_dict = reference_model.model.state_dict()
adata = _validate_var_names(adata, reference_model.adata.var_names)

if learning_approach is None:
freeze = False
freeze_expression = False
if isinstance(model, str):
attr_dict, model_state_dict, var_names = cls._load_params(model)
adata = _validate_var_names(adata, var_names)
else:
attr_dict = model._get_public_attributes()
model_state_dict = model.model.state_dict()
adata = _validate_var_names(adata, model.adata.var_names)
elif learning_approach == 'Surgery':
freeze = True
freeze_expression = True
if isinstance(model, str):
attr_dict, model_state_dict, var_names = cls._load_params(model)
adata = _validate_var_names(adata, var_names)
else:
attr_dict = model._get_public_attributes()
model_state_dict = model.model.state_dict()
adata = _validate_var_names(adata, model.adata.var_names)

elif learning_approach == 'ewc':
freeze = False
freeze_expression = False
if isinstance(model, str):
attr_dict, model_state_dict, var_names = cls._load_params(model)
adata = _validate_var_names(adata, var_names)
else:
attr_dict = model._get_public_attributes()
model_state_dict = model.model.state_dict()
adata = _validate_var_names(adata, model.adata.var_names)

elif learning_approach == 'latent replay':
freeze = False
freeze_expression = False
if isinstance(model, str):
attr_dict, model_state_dict, var_names = cls._load_params(model)
adata = _validate_var_names(adata, var_names)
else:
attr_dict = model._get_public_attributes()
model_state_dict = model.model.state_dict()
adata = _validate_var_names(adata, model.adata.var_names)

elif learning_approach == 'LR+EWC':
freeze = False
freeze_expression = False
if isinstance(model, str):
attr_dict, model_state_dict, var_names = cls._load_params(model)
adata = _validate_var_names(adata, var_names)
else:
attr_dict = model._get_public_attributes()
model_state_dict = model.model.state_dict()
adata = _validate_var_names(adata, model.adata.var_names)

elif learning_approach == 'rehearsal':
freeze = False
freeze_expression = False
if isinstance(model, str):
attr_dict, model_state_dict, var_names = cls._load_params(model)
adata = _validate_var_names(adata, var_names)
else:
attr_dict = model._get_public_attributes()
model_state_dict = model.model.state_dict()
adata = _validate_var_names(adata, model.adata.var_names)

init_params = deepcopy(cls._get_init_params_from_dict(attr_dict))

conditions = init_params['conditions']
@@ -233,6 +296,7 @@ def load_query_data(
if item not in conditions:
new_conditions.append(item)


# Add new conditions to overall conditions
for condition in new_conditions:
conditions.append(condition)
@@ -244,7 +308,6 @@ def load_query_data(

new_model = cls(adata, **init_params)
new_model._load_expand_params_from_dict(model_state_dict)

if freeze:
new_model.model.freeze = True
for name, p in new_model.model.named_parameters():
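All six `learning_approach` branches in `load_query_data` load the reference model identically and differ only in the two freeze flags: 'Surgery' freezes the network, every other approach keeps all weights trainable. A table-driven sketch of that shared logic follows; it reuses the file's own `_load_params`, `_get_public_attributes` and `_validate_var_names`, but the names `FREEZE_FLAGS` and `_load_reference` are hypothetical and not part of this diff.

# Hypothetical condensation of the branch cascade above; 'Surgery' is the
# only approach that freezes the network, all others fine-tune every weight.
FREEZE_FLAGS = {
    None: (False, False),
    'Surgery': (True, True),
    'ewc': (False, False),
    'latent replay': (False, False),
    'LR+EWC': (False, False),
    'rehearsal': (False, False),
}

def _load_reference(cls, adata, model):
    """Load attributes and weights from a model instance or a saved folder."""
    if isinstance(model, str):
        attr_dict, model_state_dict, var_names = cls._load_params(model)
        adata = _validate_var_names(adata, var_names)
    else:
        attr_dict = model._get_public_attributes()
        model_state_dict = model.model.state_dict()
        adata = _validate_var_names(adata, model.adata.var_names)
    return attr_dict, model_state_dict, adata

# Inside load_query_data the cascade would then reduce to:
#   freeze, freeze_expression = FREEZE_FLAGS[learning_approach]
#   attr_dict, model_state_dict, adata = _load_reference(cls, adata, model)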
137 changes: 127 additions & 10 deletions scarches/models/trvae/modules.py
@@ -3,6 +3,10 @@

from ._utils import one_hot_encoder

# Module-level buffers for latent replay: activations cached from previous tasks.
data_list = []
next_data_list = []
old_data_initial_list = []
old_data_final_list = []

class CondLayers(nn.Module):
def __init__(
@@ -21,8 +25,9 @@ def __init__(
def forward(self, x: torch.Tensor):
if self.n_cond == 0:
out = self.expr_L(x)

else:
expr, cond = torch.split(x, [x.shape[1] - self.n_cond, self.n_cond], dim=1)
expr, cond = torch.split(x, [x.shape[1] - self.n_cond, self.n_cond], dim=1) # x.shape[1] is the number of columns (features + one-hot conditions)
out = self.expr_L(expr) + self.cond_L(cond)
return out

@@ -63,9 +68,11 @@ def __init__(self,
self.n_classes = num_classes
self.FC = None
if len(layer_sizes) > 1:
print("Encoder Architecture:")
self.FC = nn.Sequential()
for i, (in_size, out_size) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
for i, (in_size, out_size) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])): # e.g. layer_sizes = [1000, 128, 128] zips into the pairs [(1000, 128), (128, 128)]
if i == 0:
print("\tInput Layer in, out and cond:", in_size, out_size, self.n_classes)
self.FC.add_module(name="L{:d}".format(i), module=CondLayers(in_size,
@@ -81,22 +88,133 @@ def __init__(self,
self.FC.add_module("N{:d}".format(i), module=nn.LayerNorm(out_size, elementwise_affine=False))
self.FC.add_module(name="A{:d}".format(i), module=nn.ReLU())
if use_dr:
latent_layer = False
self.FC.add_module(name="D{:d}".format(i), module=nn.Dropout(p=dr_rate))
print("\tMean/Var Layer in/out:", layer_sizes[-1], latent_dim)
self.mean_encoder = nn.Linear(layer_sizes[-1], latent_dim)
self.log_var_encoder = nn.Linear(layer_sizes[-1], latent_dim)

def forward(self, x, batch=None):
def forward(self, x, batch=None, external_memory=0, dataset_counter=0, first_epoch=0, replay_layer=0):
torch.autograd.set_detect_anomaly(True) # debugging aid for autograd errors; slows training
import random # local import; used below to subsample replayed activations
if batch is not None:
batch = one_hot_encoder(batch, n_cls=self.n_classes)
x = torch.cat((x, batch), dim=-1)
if self.FC is not None:
x = self.FC(x)
x = torch.cat((x, batch), dim=-1)

if self.FC is not None:
x = self.FC(x)
per = 1 # fraction of cached task-i batches to mix into task i+1 (1 = all of them)
if external_memory == 1: # replay mode: manage the external activation memory
if first_epoch == 1 and replay_layer == 1: # cache the training activations only on the first epoch
x = x.detach()
data_list.append(x) # save task i
if dataset_counter != 0:
next_data_list.append(x) # save task i+1

else:
if dataset_counter != 0: # from task i+1 onward, rebuild each batch from new data (task i+1) and old data (task i)
new_data = x
new_batches = len(data_list) - len(next_data_list)

if new_batches == len(data_list):
old_data = data_list
else:
old_data = data_list[:new_batches]

mb_size = x.size(0)
cur_sz = int(mb_size * 0.2) # share of the mini-batch reserved for new data from task i+1
n2inject = mb_size - cur_sz # share of the mini-batch filled with replayed data from task i

old_data_percent = random.sample(old_data, int(len(old_data) * per)) # randomly draw a fraction of the cached task-i batches
cuda0 = torch.device('cuda:0')

# Zero-pad the cached activations so their study-label (condition) columns match the current width
for idxe, item in enumerate(old_data_percent):
for idxe2, item2 in enumerate(item):
if x.size(-1) == item2.size(-1):
element = item2
else:
item2 = torch.unsqueeze(item2, 0)
element = torch.cat((item2,
torch.zeros(item2.size(0), x.size(-1) - item2.size(-1), device=cuda0)), dim=-1)
element = torch.squeeze(element, 0)
old_data_initial_list.append(element)
catted = torch.stack(old_data_initial_list, 0)

old_data_final_list.append(catted)

old_tensor = torch.cat(old_data_final_list, 0)
del old_data_initial_list[:] # clear the buffer: the next task's study labels may have a different length
del old_data_final_list[:] # clear the buffer: the next task's study labels may have a different length

# checking if padding data is needed to fix the batch dimensions in the old_tensor
n_missing = old_tensor.shape[0] % mb_size
if n_missing > 0:
surplus = 1
else:
surplus = 0
# computing iters over old_tensor
old_loop = old_tensor.shape[0] // mb_size + surplus # e.g. 5 // 2 = 2, plus one extra iteration for the leftover rows
# padding data to fix batch dimensions
if n_missing > 0:
n_to_add = mb_size - n_missing
old_tensor = torch.cat((old_tensor[:n_to_add], old_tensor))


new_tensor = x
# checking if padding data is needed to fix the batch dimensions in the new_tensor
n_missing = new_tensor.shape[0] % mb_size
if n_missing > 0:
surplus = 1
else:
surplus = 0
# computing iters over new_tensor
new_loop = new_tensor.shape[0] // mb_size + surplus # e.g. 5 // 2 = 2, plus one extra iteration for the leftover rows
# padding data to fix the batch dimensions in the new_tensor
if n_missing > 0:
n_to_add = mb_size - n_missing
new_tensor = torch.cat((new_tensor[:n_to_add], new_tensor))

# loop over the old_tensor and new_tensor then concatenate
if new_loop > old_loop:
for it in range(new_loop):
start_new = it * cur_sz
end_new = (it + 1) * cur_sz

start_previous = it * n2inject
end_previous = (it + 1) * n2inject
x = torch.cat((new_tensor[start_new:end_new], old_tensor[start_previous:end_previous]), 0)

it += 1
if it > old_loop:
it = random.randrange(old_loop)
start_previous = it * n2inject
end_previous = (it + 1) * n2inject
x = torch.cat((new_tensor[start_new:end_new], old_tensor[start_previous:end_previous]), 0)

elif new_loop < old_loop:
for it in range(old_loop):
start_new = it * cur_sz
end_new = (it + 1) * cur_sz

start_previous = it * n2inject
end_previous = (it + 1) * n2inject
x = torch.cat((new_tensor[start_new:end_new], old_tensor[start_previous:end_previous]), 0)

it += 1
if it > new_loop:
it = random.randrange(new_loop)
start_new = it * cur_sz
end_new = (it + 1) * cur_sz
x = torch.cat((new_tensor[start_new:end_new], old_tensor[start_previous:end_previous]), 0)

means = self.mean_encoder(x)
log_vars = self.log_var_encoder(x)
return means, log_vars

return means, log_vars


class Decoder(nn.Module):
"""ScArches Decoder class. Constructs the decoder sub-network of TRVAE or CVAE networks. It will transform the
constructed latent space to the previous space of data with n_dimensions = x_dimension.
@@ -194,7 +312,6 @@ def forward(self, z, batch=None):
x = self.HiddenL(dec_latent)
else:
x = dec_latent

# Compute Decoder Output
if self.recon_loss == "mse":
recon_x = self.recon_decoder(x)
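The heart of the new encoder `forward` is the replay mixing: on the first epoch of a task the post-`FC` activations are cached in `data_list`, and afterwards each mini-batch is rebuilt as roughly 20% fresh activations plus 80% replayed ones (`cur_sz` vs. `n2inject`). A stripped-down sketch of that mixing step, assuming all cached activations already have the current feature width; `mix_replay_batch` is a hypothetical helper, not part of this diff.

import random
import torch

def mix_replay_batch(new_x: torch.Tensor, cached: list, new_frac: float = 0.2) -> torch.Tensor:
    """Rebuild one training batch from new activations plus replayed old ones.

    new_x    -- activations of the current task, shape (mb_size, d)
    cached   -- activation tensors saved from earlier tasks, each of shape (n_i, d)
    new_frac -- share of the batch taken from the new task (0.2 in this PR)
    """
    mb_size = new_x.size(0)
    cur_sz = int(mb_size * new_frac)   # slots for new data
    n2inject = mb_size - cur_sz        # slots for replayed data

    old = torch.cat(cached, dim=0)
    # sample replayed rows without replacement so each batch sees a fresh mix
    idx = random.sample(range(old.size(0)), min(n2inject, old.size(0)))
    return torch.cat((new_x[:cur_sz], old[idx]), dim=0)

The PR additionally zero-pads cached activations whose condition columns are narrower than the current task's; the sketch leaves that step out.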
12 changes: 5 additions & 7 deletions scarches/models/trvae/trvae.py
@@ -10,7 +10,6 @@
from ._utils import one_hot_encoder
from ..base._base import CVAELatentsModelMixin


class trVAE(nn.Module, CVAELatentsModelMixin):
"""ScArches model class. This class contains the implementation of Conditional Variational Auto-encoder.

@@ -93,7 +92,7 @@ def __init__(self,

self.hidden_layer_sizes = hidden_layer_sizes
encoder_layer_sizes = self.hidden_layer_sizes.copy()
encoder_layer_sizes.insert(0, self.input_dim)
encoder_layer_sizes.insert(0, self.input_dim) # e.g. [128, 128] becomes [1000, 128, 128] after prepending the input dimension
decoder_layer_sizes = self.hidden_layer_sizes.copy()
decoder_layer_sizes.reverse()
decoder_layer_sizes.append(self.input_dim)
@@ -112,13 +111,13 @@ def __init__(self,
self.use_dr,
self.dr_rate,
self.n_conditions)

def forward(self, x=None, batch=None, sizefactor=None, labeled=None):
def forward(self, x=None, batch=None, sizefactor=None, labeled=None, external_memory=0, dataset_counter=0, first_epoch=0, replay_layer=0):
x_log = torch.log(1 + x)
if self.recon_loss == 'mse':
x_log = x

z1_mean, z1_log_var = self.encoder(x_log, batch)
z1_mean, z1_log_var = self.encoder(x_log, batch, external_memory, dataset_counter=dataset_counter,
first_epoch=first_epoch, replay_layer=replay_layer)
z1 = self.sampling(z1_mean, z1_log_var)
outputs = self.decoder(z1, batch)

@@ -153,5 +152,4 @@ def forward(self, x=None, batch=None, sizefactor=None, labeled=None):
mmd_loss = mmd(z1, batch,self.n_conditions, self.beta, self.mmd_boundary)
else:
mmd_loss = mmd(y1, batch,self.n_conditions, self.beta, self.mmd_boundary)

return recon_loss, kl_div, mmd_loss
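The widened `forward` signature only threads the replay state down to the encoder. A call for the first epoch of a second task might look like this; the `model` instance and the tensor shapes are illustrative, not taken from the diff.

import torch

# Illustrative: 128 cells, 1000 genes, two conditions.
x = torch.rand(128, 1000)
batch = torch.randint(0, 2, (128,))

recon_loss, kl_div, mmd_loss = model(
    x=x, batch=batch,
    external_memory=1,   # replay memory is active
    dataset_counter=1,   # second task in the sequence
    first_epoch=1,       # so activations are cached, not mixed
    replay_layer=1,      # this encoder layer participates in replay
)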
6 changes: 5 additions & 1 deletion scarches/models/trvae/trvae_model.py
@@ -114,8 +114,12 @@ def train(
n_epochs: int = 400,
lr: float = 1e-3,
eps: float = 0.01,
ID: int = None,
learning_approach: str = None,
**kwargs
):
if learning_approach is not None:
print(learning_approach, 'is happening!')
"""Train the model.

Parameters
@@ -134,7 +138,7 @@ def train(
self.adata,
condition_key=self.condition_key_,
**kwargs)
self.trainer.train(n_epochs, lr, eps)
self.trainer.train(n_epochs, lr, eps, ID=ID, learning_approach=learning_approach)
self.is_trained_ = True

@classmethod
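Taken together, query-side training with one of the new strategies would be driven from the model API roughly like this. This is a sketch following scArches conventions: `query_adata` and the reference path are placeholders, and the argument values are illustrative.

import scarches as sca

# Expand a pretrained reference model with a query dataset via latent replay.
new_model = sca.models.TRVAE.load_query_data(
    adata=query_adata,                  # query AnnData object (placeholder)
    model='path/to/reference_model',    # saved folder or an in-memory model
    learning_approach='latent replay',  # None, 'Surgery', 'ewc', 'latent replay', 'LR+EWC' or 'rehearsal'
    ID=1,                               # task identifier forwarded to the trainer
)
new_model.train(
    n_epochs=400,
    lr=1e-3,
    ID=1,
    learning_approach='latent replay',
)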