
Commit d614040

Merge pull request #1515 from manuelpaeza/cai-pytorch
PyTorch ver 1.0: keras 3.+ with pytorch backend for nn_models.py and …
2 parents 065b9fe + 5245d0a commit d614040

36 files changed: +1997 -5199 lines

caiman/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@
 from caiman.base.timeseries import concatenate
 from caiman.cluster import start_server, stop_server
 from caiman.mmapping import load_memmap, save_memmap, save_memmap_each, save_memmap_join
+from caiman.pytorch_model_arch import PyTorchCNN
 from caiman.summary_images import local_correlations
 
 __version__ = importlib.metadata.version('caiman')

caiman/components_evaluation.py

Lines changed: 22 additions & 45 deletions
@@ -6,17 +6,17 @@
 import numpy as np
 import os
 import peakutils
-import tensorflow as tf
 import scipy
 from scipy.sparse import csc_matrix
 from scipy.stats import norm
+import torch
 from typing import Any, Union
 import warnings
 
 import caiman
 from caiman.paths import caiman_datadir
+from caiman.pytorch_model_arch import PyTorchCNN
 import caiman.utils.stats
-import caiman.utils.utils
 
 try:
     cv2.setNumThreads(0)
@@ -270,45 +270,22 @@ def evaluate_components_CNN(A,
         then this code will try not to use a GPU. Otherwise it will use one if it finds it.
     """
     logger = logging.getLogger("caiman")
-
-    # TODO: Find a less ugly way to do this
     if not isGPU and 'CAIMAN_ALLOW_GPU' not in os.environ:
-        print("GPU run not requested, disabling use of GPUs")
+        logger.info("GPU run not requested, disabling use of GPUs")
         os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
-    try:
-        os.environ["KERAS_BACKEND"] = "tensorflow"
-        from tensorflow.keras.models import model_from_json
-        use_keras = True
-        logger.info('Using Keras')
-    except (ModuleNotFoundError):
-        use_keras = False
-        logger.info('Using Tensorflow')
 
-    if loaded_model is None:
-        if use_keras:
-            if os.path.isfile(os.path.join(caiman_datadir(), model_name + ".json")):
-                model_file = os.path.join(caiman_datadir(), model_name + ".json")
-                model_weights = os.path.join(caiman_datadir(), model_name + ".h5")
-            elif os.path.isfile(model_name + ".json"):
-                model_file = model_name + ".json"
-                model_weights = model_name + ".h5"
-            else:
-                raise FileNotFoundError(f"File for requested model {model_name} not found")
-            with open(model_file, 'r') as json_file:
-                print(f"USING MODEL (keras API): {model_file}")
-                loaded_model_json = json_file.read()
+    logger.info('Using Torch')
 
-            loaded_model = model_from_json(loaded_model_json)
-            loaded_model.load_weights(model_name + '.h5')
+    if loaded_model is None:
+        if os.path.isfile(os.path.join(caiman_datadir(), 'model', 'pytorch-models', model_name + ".pt")):
+            model_file = os.path.join(caiman_datadir(), 'model', 'pytorch-models', model_name + ".pt")
+        elif os.path.isfile(model_name + ".pt"):
+            model_file = model_name + ".pt"
         else:
-            if os.path.isfile(os.path.join(caiman_datadir(), model_name + ".h5.pb")):
-                model_file = os.path.join(caiman_datadir(), model_name + ".h5.pb")
-            elif os.path.isfile(model_name + ".h5.pb"):
-                model_file = model_name + ".h5.pb"
-            else:
-                raise FileNotFoundError(f"File for requested model {model_name} not found")
-            print(f"USING MODEL (tensorflow API): {model_file}")
-            loaded_model = caiman.utils.utils.load_graph(model_file)
+            raise FileNotFoundError(f"File for requested model {model_name} not found")
+        logger.info(f"Using model: {model_file}")
+        loaded_model = PyTorchCNN()
+        loaded_model.load_state_dict(torch.load(model_file))
 
     logger.debug("Loaded model from disk")
 
@@ -322,16 +299,16 @@ def evaluate_components_CNN(A,
                          half_crop[1]:com[1] + half_crop[1]] for mm, com in zip(A.tocsc().T, coms)
     ]
    final_crops = np.array([cv2.resize(im / np.linalg.norm(im), (patch_size, patch_size)) for im in crop_imgs])
-    if use_keras:
-        predictions = loaded_model.predict(final_crops[:, :, :, np.newaxis], batch_size=32, verbose=1)
-    else:
-        tf_in = loaded_model.get_tensor_by_name('prefix/conv2d_20_input:0')
-        tf_out = loaded_model.get_tensor_by_name('prefix/output_node0:0')
-        with tf.Session(graph=loaded_model) as sess:
-            predictions = sess.run(tf_out, feed_dict={tf_in: final_crops[:, :, :, np.newaxis]})
-            sess.close()
+
+    # Numpy to PyTorch and add a channel dimension using unsqueeze
+    final_crops = torch.tensor(final_crops, dtype=torch.float32).unsqueeze(1)
+
+    # Pass the preprocessed image crops through the model to get predictions
+    with torch.no_grad():
+        predictions = loaded_model(final_crops)
 
-    return predictions, final_crops
+    predictions_numpy = predictions.cpu().numpy()
+    return predictions_numpy, final_crops
 
 def evaluate_components(Y: np.ndarray,
                         traces: np.ndarray,
caiman/pytorch_model_arch.py

Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
+#!/usr/bin/env python
+"""
+Contains the model architecture for cnn_model.pt and cnn_model_online.pt. Those
+files store the trained model weights, which are loaded into this
+architecture at evaluation time.
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class PyTorchCNN(nn.Module):
+    def __init__(self):
+        super(PyTorchCNN, self).__init__()
+        # First convolutional block
+        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3)
+        self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3)
+        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
+        self.dropout1 = nn.Dropout(p=0.25)
+
+        # Second convolutional block
+        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding='same')
+        self.conv4 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3)
+        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
+        self.dropout2 = nn.Dropout(p=0.25)
+
+        # Flattening and fully connected layers
+        self.flatten = nn.Flatten()
+        self.fc1 = nn.Linear(in_features=6400, out_features=512)
+        self.dropout3 = nn.Dropout(p=0.5)
+        self.fc2 = nn.Linear(in_features=512, out_features=2)
+
+    def forward(self, x):
+        # Convolutional block 1
+        x = F.relu(self.conv1(x))
+        x = F.relu(self.conv2(x))
+        x = self.pool1(x)
+        x = self.dropout1(x)
+
+        # Convolutional block 2
+        x = F.relu(self.conv3(x))
+        x = F.relu(self.conv4(x))
+        x = self.pool2(x)
+        x = self.dropout2(x)
+
+        # Flatten and fully connected layers
+        x = self.flatten(x)
+        x = F.relu(self.fc1(x))
+        x = self.dropout3(x)
+        x = self.fc2(x)
+        return F.softmax(x, dim=1)
+
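A quick shape check on the architecture above, a sketch rather than part of the commit: with the 50x50 patches used elsewhere in this PR (patch_size=50), conv1, conv2 and conv4 each trim 2 pixels (no padding), conv3 preserves size ('same' padding), and the two 2x2 poolings bring a 50x50 input down to 64 feature maps of 10x10, i.e. 64 * 10 * 10 = 6400 features, matching fc1's in_features:

import torch
from caiman.pytorch_model_arch import PyTorchCNN

model = PyTorchCNN()
model.eval()                    # dropout off, deterministic output
x = torch.zeros(4, 1, 50, 50)   # (batch, channels, height, width)
probs = model(x)
print(probs.shape)              # torch.Size([4, 2])
print(probs.sum(dim=1))         # softmax rows each sum to 1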

caiman/source_extraction/cnmf/online_cnmf.py

Lines changed: 39 additions & 44 deletions
@@ -27,7 +27,8 @@
 from sklearn.decomposition import NMF
 from skimage.morphology import disk
 from sklearn.preprocessing import normalize
-import tensorflow as tf
+import torch
+from torch.utils.data import DataLoader, TensorDataset
 from time import time
 
 import caiman
@@ -39,6 +40,7 @@
                                  high_pass_filter_space, sliding_window,
                                  register_translation_3d, apply_shifts_dft)
 import caiman.paths
+from caiman.pytorch_model_arch import PyTorchCNN
 from caiman.source_extraction.cnmf.cnmf import CNMF
 from caiman.source_extraction.cnmf.estimates import Estimates
 from caiman.source_extraction.cnmf.initialization import imblur, initialize_components, hals, downscale
@@ -50,14 +52,13 @@
 import caiman.summary_images
 from caiman.utils.nn_models import (fit_NL_model, create_LN_model, quantile_loss, rate_scheduler)
 from caiman.utils.stats import pd_solve
-from caiman.utils.utils import save_dict_to_hdf5, load_dict_from_hdf5, parmap, load_graph
+from caiman.utils.utils import save_dict_to_hdf5, load_dict_from_hdf5, parmap
 
 try:
     cv2.setNumThreads(0)
 except():
     pass
 
-#FIXME ???
 try:
     profile
 except:
@@ -357,34 +358,13 @@ def _prepare_object(self, Yr, T, new_dims=None, idx_components=None):
        if self.params.get('online', 'path_to_model') is None or self.params.get('online', 'sniper_mode') is False:
             loaded_model = None
             self.params.set('online', {'sniper_mode': False})
-            self.tf_in = None
-            self.tf_out = None
         else:
-            try:
-                from tensorflow.keras.models import model_from_json
-                logger.info('Using Keras')
-                use_keras = True
-            except(ModuleNotFoundError):
-                use_keras = False
-                logger.info('Using Tensorflow')
-            if use_keras:
-                path = self.params.get('online', 'path_to_model').split(".")[:-1]
-                json_path = ".".join(path + ["json"])
-                model_path = ".".join(path + ["h5"])
-                json_file = open(json_path, 'r')
-                loaded_model_json = json_file.read()
-                json_file.close()
-                loaded_model = model_from_json(loaded_model_json)
-                loaded_model.load_weights(model_path)
-                self.tf_in = None
-                self.tf_out = None
-            else:
-                path = self.params.get('online', 'path_to_model').split(".")[:-1]
-                model_path = '.'.join(path + ['h5', 'pb'])
-                loaded_model = load_graph(model_path)
-                self.tf_in = loaded_model.get_tensor_by_name('prefix/conv2d_1_input:0')
-                self.tf_out = loaded_model.get_tensor_by_name('prefix/output_node0:0')
-                loaded_model = tf.Session(graph=loaded_model)
+            logger.info('Using Torch')
+            path = self.params.get('online', 'path_to_model').split(".")[:-1]
+            model_path = '.'.join(path + ['pt'])
+            loaded_model = PyTorchCNN()
+            loaded_model.load_state_dict(torch.load(model_path))
+
         self.loaded_model = loaded_model
 
         if self.is1p:
@@ -585,7 +565,6 @@ def fit_next(self, t, frame_in, num_iters_hals=3):
                 sniper_mode=self.params.get('online', 'sniper_mode'),
                 use_peak_max=self.params.get('online', 'use_peak_max'),
                 mean_buff=self.estimates.mean_buff,
-                tf_in=self.tf_in, tf_out=self.tf_out,
                 ssub_B=ssub_B, W=self.estimates.W if self.is1p else None,
                 b0=self.estimates.b0 if self.is1p else None,
                 corr_img=self.estimates.corr_img if use_corr else None,
@@ -1238,7 +1217,7 @@ def fit_online(self, **kwargs):
                 else:
                     activity = 0.
                 # frame = frame.astype(np.float32) - activity
-                frame = frame - np.squeeze(model_LN.predict(np.expand_dims(np.expand_dims(frame.astype(np.float32) - activity, 0), -1)))
+                frame = frame - np.squeeze(model_LN.predict(np.expand_dims(np.expand_dims(frame.astype(np.float32) - activity, 0), -1), verbose=0))
                 frame = np.maximum(frame, 0)
                 frame_count += 1
                 t_frame_start = time()
@@ -1252,6 +1231,13 @@
                                     + str(self.estimates.Ab.shape[-1] - self.params.get('init', 'nb')))
                         old_comps = self.N
 
+                if np.isnan(np.sum(frame)):
+                    raise Exception(f'Frame {frame_count} contains NaN')
+                if t % 500 == 0:
+                    logger.info(f'Epoch: {iter + 1}. {t} frames have been processed. '
+                                f'{self.N - old_comps} new components were added. Total: {self.N}')
+                    old_comps = self.N
+
                 # Downsample and normalize
                 frame_ = frame.copy().astype(np.float32)
                 if self.params.get('online', 'ds_factor') > 1:
@@ -2040,8 +2026,7 @@ def get_candidate_components(sv, dims, Yres_buf, min_num_trial=3, gSig=(5, 5),
                              gHalf=(5, 5), sniper_mode=True, rval_thr=0.85,
                              patch_size=50, loaded_model=None, test_both=False,
                              thresh_CNN_noisy=0.5, use_peak_max=False,
-                             thresh_std_peak_resid = 1, mean_buff=None,
-                             tf_in=None, tf_out=None):
+                             thresh_std_peak_resid = 1, mean_buff=None):
     """
     Extract new candidate components from the residual buffer and test them
     using space correlation or the CNN classifier. The function runs the CNN
@@ -2122,11 +2107,23 @@
         Ain2 /= np.std(Ain2,axis=1)[:,None]
         Ain2 = np.reshape(Ain2,(-1,) + tuple(np.diff(ijSig_cnn).squeeze()),order= 'F')
         Ain2 = np.stack([cv2.resize(ain,(patch_size ,patch_size)) for ain in Ain2])
-        if tf_in is None:
-            predictions = loaded_model.predict(Ain2[:,:,:,np.newaxis], batch_size=min_num_trial, verbose=0)
-        else:
-            predictions = loaded_model.run(tf_out, feed_dict={tf_in: Ain2[:, :, :, np.newaxis]})
-        keep_cnn = list(np.where(predictions[:, 0] > thresh_CNN_noisy)[0])
+
+        final_crops = Ain2[:, :, :, np.newaxis]
+        final_crops_tensor = torch.tensor(final_crops, dtype=torch.float32).permute(0, 3, 1, 2)
+
+        # Create DataLoader for batching
+        dataset = TensorDataset(final_crops_tensor)
+        loader = DataLoader(dataset, batch_size=int(min_num_trial), shuffle=False)
+
+        loaded_model.eval()
+        all_predictions = []
+        with torch.no_grad():
+            for batch in loader:
+                outputs = loaded_model(batch[0])
+                all_predictions.append(outputs)
+
+        predictions = torch.cat(all_predictions).cpu().numpy()
+        keep_cnn = list(np.where(predictions[:,0] > thresh_CNN_noisy)[0])
         cnn_pos = Ain2[keep_cnn]
     else:
         keep_cnn = [] # list(range(len(Ain_cnn)))
@@ -2175,8 +2172,7 @@ def update_num_components(t, sv, Ab, Cf, Yres_buf, Y_buf, rho_buf,
                           mean_buff=None, ssub_B=1, W=None, b0=None,
                           corr_img=None, first_moment=None, second_moment=None,
                           crosscorr=None, col_ind=None, row_ind=None, corr_img_mode=None,
-                          max_img=None, downscale_matrix=None, upscale_matrix=None,
-                          tf_in=None, tf_out=None):
+                          max_img=None, downscale_matrix=None, upscale_matrix=None):
     """
     Checks for new components in the residual buffer and incorporates them if they pass the acceptance tests
     """
@@ -2205,8 +2201,7 @@
             min_num_trial=min_num_trial, gSig=gSig, gHalf=gHalf,
             sniper_mode=sniper_mode, rval_thr=rval_thr, patch_size=50,
             loaded_model=loaded_model, thresh_CNN_noisy=thresh_CNN_noisy,
-            use_peak_max=use_peak_max, test_both=test_both, mean_buff=mean_buff,
-            tf_in=tf_in, tf_out=tf_out)
+            use_peak_max=use_peak_max, test_both=test_both, mean_buff=mean_buff)
 
     ind_new_all = ijsig_all
 
@@ -2596,4 +2591,4 @@ def load_OnlineCNMF(filename, dview = None):
     return new_obj
 
 def inv_mat_vec(A):
-    return np.linalg.solve(A[0], A[1])
\ No newline at end of file
+    return np.linalg.solve(A[0], A[1])
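
For orientation, a minimal sketch (not part of the commit) of the batched prediction path that get_candidate_components now uses. The random Ain2 stands in for the stacked candidate patches, an untrained PyTorchCNN stands in for the model that the real code loads from path_to_model, and 0.5 plays the role of thresh_CNN_noisy:

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

from caiman.pytorch_model_arch import PyTorchCNN

Ain2 = np.random.rand(5, 50, 50).astype(np.float32)   # stand-in candidate patches
crops = torch.tensor(Ain2[:, :, :, np.newaxis], dtype=torch.float32).permute(0, 3, 1, 2)

loader = DataLoader(TensorDataset(crops), batch_size=3, shuffle=False)

model = PyTorchCNN()   # the real code loads trained weights with load_state_dict first
model.eval()           # evaluation mode, as in the diff
with torch.no_grad():
    predictions = torch.cat([model(batch) for (batch,) in loader]).cpu().numpy()

keep_cnn = np.where(predictions[:, 0] > 0.5)[0]
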
Lines changed: 6 additions & 0 deletions
@@ -1 +1,7 @@
+#!/usr/bin/env python
 
+from . import config
+from . import model
+from . import neurons
+from . import utils
+from . import visualize
