27
27
from sklearn .decomposition import NMF
28
28
from skimage .morphology import disk
29
29
from sklearn .preprocessing import normalize
30
- import tensorflow as tf
30
+ import torch
31
+ from torch .utils .data import DataLoader , TensorDataset
31
32
from time import time
32
33
33
34
import caiman
39
40
high_pass_filter_space , sliding_window ,
40
41
register_translation_3d , apply_shifts_dft )
41
42
import caiman .paths
43
+ from caiman .pytorch_model_arch import PyTorchCNN
42
44
from caiman .source_extraction .cnmf .cnmf import CNMF
43
45
from caiman .source_extraction .cnmf .estimates import Estimates
44
46
from caiman .source_extraction .cnmf .initialization import imblur , initialize_components , hals , downscale
50
52
import caiman .summary_images
51
53
from caiman .utils .nn_models import (fit_NL_model , create_LN_model , quantile_loss , rate_scheduler )
52
54
from caiman .utils .stats import pd_solve
53
- from caiman .utils .utils import save_dict_to_hdf5 , load_dict_from_hdf5 , parmap , load_graph
55
+ from caiman .utils .utils import save_dict_to_hdf5 , load_dict_from_hdf5 , parmap
54
56
55
57
try :
56
58
cv2 .setNumThreads (0 )
57
59
except ():
58
60
pass
59
61
60
- #FIXME ???
61
62
try :
62
63
profile
63
64
except :
@@ -357,34 +358,13 @@ def _prepare_object(self, Yr, T, new_dims=None, idx_components=None):
357
358
if self .params .get ('online' , 'path_to_model' ) is None or self .params .get ('online' , 'sniper_mode' ) is False :
358
359
loaded_model = None
359
360
self .params .set ('online' , {'sniper_mode' : False })
360
- self .tf_in = None
361
- self .tf_out = None
362
361
else :
363
- try :
364
- from tensorflow .keras .models import model_from_json
365
- logger .info ('Using Keras' )
366
- use_keras = True
367
- except (ModuleNotFoundError ):
368
- use_keras = False
369
- logger .info ('Using Tensorflow' )
370
- if use_keras :
371
- path = self .params .get ('online' , 'path_to_model' ).split ("." )[:- 1 ]
372
- json_path = "." .join (path + ["json" ])
373
- model_path = "." .join (path + ["h5" ])
374
- json_file = open (json_path , 'r' )
375
- loaded_model_json = json_file .read ()
376
- json_file .close ()
377
- loaded_model = model_from_json (loaded_model_json )
378
- loaded_model .load_weights (model_path )
379
- self .tf_in = None
380
- self .tf_out = None
381
- else :
382
- path = self .params .get ('online' , 'path_to_model' ).split ("." )[:- 1 ]
383
- model_path = '.' .join (path + ['h5' , 'pb' ])
384
- loaded_model = load_graph (model_path )
385
- self .tf_in = loaded_model .get_tensor_by_name ('prefix/conv2d_1_input:0' )
386
- self .tf_out = loaded_model .get_tensor_by_name ('prefix/output_node0:0' )
387
- loaded_model = tf .Session (graph = loaded_model )
362
+ logger .info ('Using Torch' )
363
+ path = self .params .get ('online' , 'path_to_model' ).split ("." )[:- 1 ]
364
+ model_path = '.' .join (path + ['pt' ])
365
+ loaded_model = PyTorchCNN ()
366
+ loaded_model .load_state_dict (torch .load (model_path ))
367
+
388
368
self .loaded_model = loaded_model
389
369
390
370
if self .is1p :
@@ -585,7 +565,6 @@ def fit_next(self, t, frame_in, num_iters_hals=3):
585
565
sniper_mode = self .params .get ('online' , 'sniper_mode' ),
586
566
use_peak_max = self .params .get ('online' , 'use_peak_max' ),
587
567
mean_buff = self .estimates .mean_buff ,
588
- tf_in = self .tf_in , tf_out = self .tf_out ,
589
568
ssub_B = ssub_B , W = self .estimates .W if self .is1p else None ,
590
569
b0 = self .estimates .b0 if self .is1p else None ,
591
570
corr_img = self .estimates .corr_img if use_corr else None ,
@@ -1238,7 +1217,7 @@ def fit_online(self, **kwargs):
1238
1217
else :
1239
1218
activity = 0.
1240
1219
# frame = frame.astype(np.float32) - activity
1241
- frame = frame - np .squeeze (model_LN .predict (np .expand_dims (np .expand_dims (frame .astype (np .float32 ) - activity , 0 ), - 1 )))
1220
+ frame = frame - np .squeeze (model_LN .predict (np .expand_dims (np .expand_dims (frame .astype (np .float32 ) - activity , 0 ), - 1 ), verbose = 0 ))
1242
1221
frame = np .maximum (frame , 0 )
1243
1222
frame_count += 1
1244
1223
t_frame_start = time ()
@@ -1252,6 +1231,13 @@ def fit_online(self, **kwargs):
1252
1231
+ str (self .estimates .Ab .shape [- 1 ] - self .params .get ('init' , 'nb' )))
1253
1232
old_comps = self .N
1254
1233
1234
+ if np .isnan (np .sum (frame )):
1235
+ raise Exception (f'Frame { frame_count } contains NaN' )
1236
+ if t % 500 == 0 :
1237
+ logger .info (f'Epoch: { iter + 1 } . { t } frames have been processed.'
1238
+ f'{ self .N - old_comps } new components were added. Total: { self .N } ' )
1239
+ old_comps = self .N
1240
+
1255
1241
# Downsample and normalize
1256
1242
frame_ = frame .copy ().astype (np .float32 )
1257
1243
if self .params .get ('online' , 'ds_factor' ) > 1 :
@@ -2040,8 +2026,7 @@ def get_candidate_components(sv, dims, Yres_buf, min_num_trial=3, gSig=(5, 5),
2040
2026
gHalf = (5 , 5 ), sniper_mode = True , rval_thr = 0.85 ,
2041
2027
patch_size = 50 , loaded_model = None , test_both = False ,
2042
2028
thresh_CNN_noisy = 0.5 , use_peak_max = False ,
2043
- thresh_std_peak_resid = 1 , mean_buff = None ,
2044
- tf_in = None , tf_out = None ):
2029
+ thresh_std_peak_resid = 1 , mean_buff = None ):
2045
2030
"""
2046
2031
Extract new candidate components from the residual buffer and test them
2047
2032
using space correlation or the CNN classifier. The function runs the CNN
@@ -2122,11 +2107,23 @@ def get_candidate_components(sv, dims, Yres_buf, min_num_trial=3, gSig=(5, 5),
2122
2107
Ain2 /= np .std (Ain2 ,axis = 1 )[:,None ]
2123
2108
Ain2 = np .reshape (Ain2 ,(- 1 ,) + tuple (np .diff (ijSig_cnn ).squeeze ()),order = 'F' )
2124
2109
Ain2 = np .stack ([cv2 .resize (ain ,(patch_size ,patch_size )) for ain in Ain2 ])
2125
- if tf_in is None :
2126
- predictions = loaded_model .predict (Ain2 [:,:,:,np .newaxis ], batch_size = min_num_trial , verbose = 0 )
2127
- else :
2128
- predictions = loaded_model .run (tf_out , feed_dict = {tf_in : Ain2 [:, :, :, np .newaxis ]})
2129
- keep_cnn = list (np .where (predictions [:, 0 ] > thresh_CNN_noisy )[0 ])
2110
+
2111
+ final_crops = Ain2 [:, :, :, np .newaxis ]
2112
+ final_crops_tensor = torch .tensor (final_crops , dtype = torch .float32 ).permute (0 , 3 , 1 , 2 )
2113
+
2114
+ #Create DataLoader for batching
2115
+ dataset = TensorDataset (final_crops_tensor )
2116
+ loader = DataLoader (dataset , batch_size = int (min_num_trial ), shuffle = False )
2117
+
2118
+ loaded_model .eval ()
2119
+ all_predictions = []
2120
+ with torch .no_grad ():
2121
+ for batch in loader :
2122
+ outputs = loaded_model (batch [0 ])
2123
+ all_predictions .append (outputs )
2124
+
2125
+ predictions = torch .cat (all_predictions ).cpu ().numpy ()
2126
+ keep_cnn = list (np .where (predictions [:,0 ] > thresh_CNN_noisy )[0 ])
2130
2127
cnn_pos = Ain2 [keep_cnn ]
2131
2128
else :
2132
2129
keep_cnn = [] # list(range(len(Ain_cnn)))
@@ -2175,8 +2172,7 @@ def update_num_components(t, sv, Ab, Cf, Yres_buf, Y_buf, rho_buf,
2175
2172
mean_buff = None , ssub_B = 1 , W = None , b0 = None ,
2176
2173
corr_img = None , first_moment = None , second_moment = None ,
2177
2174
crosscorr = None , col_ind = None , row_ind = None , corr_img_mode = None ,
2178
- max_img = None , downscale_matrix = None , upscale_matrix = None ,
2179
- tf_in = None , tf_out = None ):
2175
+ max_img = None , downscale_matrix = None , upscale_matrix = None ):
2180
2176
"""
2181
2177
Checks for new components in the residual buffer and incorporates them if they pass the acceptance tests
2182
2178
"""
@@ -2205,8 +2201,7 @@ def update_num_components(t, sv, Ab, Cf, Yres_buf, Y_buf, rho_buf,
2205
2201
min_num_trial = min_num_trial , gSig = gSig , gHalf = gHalf ,
2206
2202
sniper_mode = sniper_mode , rval_thr = rval_thr , patch_size = 50 ,
2207
2203
loaded_model = loaded_model , thresh_CNN_noisy = thresh_CNN_noisy ,
2208
- use_peak_max = use_peak_max , test_both = test_both , mean_buff = mean_buff ,
2209
- tf_in = tf_in , tf_out = tf_out )
2204
+ use_peak_max = use_peak_max , test_both = test_both , mean_buff = mean_buff )
2210
2205
2211
2206
ind_new_all = ijsig_all
2212
2207
@@ -2596,4 +2591,4 @@ def load_OnlineCNMF(filename, dview = None):
2596
2591
return new_obj
2597
2592
2598
2593
def inv_mat_vec (A ):
2599
- return np .linalg .solve (A [0 ], A [1 ])
2594
+ return np .linalg .solve (A [0 ], A [1 ])
0 commit comments