diff --git a/data/Reaching-Mackenzie-2018-08-30/dlc-models/iteration-0/ReachingAug30-trainset95shuffle1/train/pose_cfg.yaml b/data/Reaching-Mackenzie-2018-08-30/dlc-models/iteration-0/ReachingAug30-trainset95shuffle1/train/pose_cfg.yaml
index f07fd25..b42d26b 100755
--- a/data/Reaching-Mackenzie-2018-08-30/dlc-models/iteration-0/ReachingAug30-trainset95shuffle1/train/pose_cfg.yaml
+++ b/data/Reaching-Mackenzie-2018-08-30/dlc-models/iteration-0/ReachingAug30-trainset95shuffle1/train/pose_cfg.yaml
@@ -33,19 +33,19 @@ minsize: 100
 mirror: false
 multi_step:
 - - 0.005
-  - 10000
+  - 5000
 - - 0.02
-  - 430000
+  - 5000
 - - 0.002
-  - 730000
+  - 5000
 - - 0.001
-  - 1030000
+  - 5000
 net_type: resnet_50
 num_joints: 5
 pos_dist_thresh: 17
 project_path: data/Reaching-Mackenzie-2018-08-30
 rightwidth: 400
-save_iters: 50000
+save_iters: 1000
 scale_jitter_lo: 0.5
 scale_jitter_up: 1.25
 topheight: 400
diff --git a/demo/run_dgp_demo.py b/demo/run_dgp_demo.py
index 875a6bd..de95d2e 100644
--- a/demo/run_dgp_demo.py
+++ b/demo/run_dgp_demo.py
@@ -107,8 +107,7 @@ def get_model_cfg_path(base_path, dtype):
 
 def get_init_weights_path(base_path):
     return join(
-        base_path, 'src', 'DeepLabCut', 'deeplabcut', 'pose_estimation_tensorflow',
-        'models', 'pretrained', 'resnet_v1_50.ckpt')
+        base_path, 'resnet_v1_50.ckpt')
 
 
 if __name__ == '__main__':
@@ -133,7 +132,7 @@ def get_init_weights_path(base_path):
     parser.add_argument(
         "--batch_size",
         type=int,
-        default=10,
+        default=1,
         help="size of the batch, if there are memory issues, decrease it value")
     parser.add_argument("--test", action='store_true', default=False)
@@ -143,7 +142,7 @@ def get_init_weights_path(base_path):
     dlcpath = input_params.dlcpath
     shuffle = input_params.shuffle
     dlcsnapshot = input_params.dlcsnapshot
-    batch_size = input_params.batch_size
+    batch_size = input_params.batch_size
     test = input_params.test
 
     update_configs = False
@@ -155,6 +154,10 @@ def get_init_weights_path(base_path):
     # ------------------------------------------------------------------------------------
     # Train models
     # ------------------------------------------------------------------------------------
+    import tensorflow as tf
+    config = tf.compat.v1.ConfigProto()
+    config.gpu_options.allow_growth = True
+    sess = tf.compat.v1.Session(config=config)
 
     try:
 
@@ -177,7 +180,7 @@ def get_init_weights_path(base_path):
                 displayiters=1)
         else:
             fit_dlc(snapshot, dlcpath, shuffle=shuffle, step=0)
-        snapshot = 'snapshot-step0-final--0'  # snapshot for step 1
+        snapshot = 'snapshot-step0-final-0'  # snapshot for step 1
     else:
         # use the specified DLC snapshot to initialize DGP, and skip step 0
         snapshot = dlcsnapshot  # snapshot for step 1
@@ -200,7 +203,7 @@ def get_init_weights_path(base_path):
                     dlcpath,
                     shuffle=shuffle,
                     step=1,
-                    maxiters=2,
+                    maxiters=1000,
                     displayiters=1)
         else:
             fit_dgp_labeledonly(snapshot,
@@ -208,7 +211,7 @@ def get_init_weights_path(base_path):
                                 shuffle=shuffle,
                                 step=1)
 
-    snapshot = 'snapshot-step1-final--0'
+    snapshot = 'snapshot-step1-final-0'
     # %% step 2 DGP
     print(
         '''
@@ -244,7 +247,7 @@ def get_init_weights_path(base_path):
                 gm2=gm2,
                 gm3=gm3)
 
-    snapshot = 'snapshot-step{}-final--0'.format(step)
+    snapshot = 'snapshot-step{}-final-0'.format(step)
 
     # --------------------------------------------------------------------------------
     # Test DGP model
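Note: the hunk at `@@ -155,6 +154,10 @@` enables GPU memory growth through a `tf.compat.v1` session, so TensorFlow grabs GPU memory on demand rather than pre-allocating the whole device. For reference, a minimal sketch of the TF2-native equivalent (an illustration, not part of the patch):

```python
import tensorflow as tf

# TF2-native alternative to the ConfigProto/Session pattern added above:
# grow GPU memory on demand instead of pre-allocating the full device.
for gpu in tf.config.list_physical_devices('GPU'):
    tf.config.experimental.set_memory_growth(gpu, True)
```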
diff --git a/demo/run_video_pred.py b/demo/run_video_pred.py
new file mode 100644
index 0000000..5a9fec1
--- /dev/null
+++ b/demo/run_video_pred.py
@@ -0,0 +1,273 @@
+# If you have collected labels using DLC's GUI you can run DGP with the following
+"""Main fitting function for DGP.
+   step 0: run DLC
+   step 1: run DGP with labeled frames only
+   step 2: run DGP with spatial clique
+   step 3: do prediction on all videos
+"""
+import argparse
+import os
+from os import listdir
+from os.path import isfile, join, split
+from pathlib import Path
+import sys
+import yaml
+import cv2
+
+import pandas as pd
+from deeplabcut.utils.video_processor import (
+    VideoProcessorCV as vp,
+)  # used to CreateVideo
+from deeplabcut.utils import auxiliaryfunctions, CreateVideo, visualization
+
+
+if sys.platform == 'darwin':
+    import wx
+    if int(wx.__version__[0]) > 3:
+        wx.Thread_IsMain = wx.IsMainThread
+
+os.environ["DLClight"] = "True"
+os.environ["Colab"] = "True"
+from deeplabcut.utils import auxiliaryfunctions
+
+from deepgraphpose.models.fitdgp import fit_dlc, fit_dgp, fit_dgp_labeledonly
+from deepgraphpose.models.fitdgp_util import get_snapshot_path
+from deepgraphpose.models.eval import plot_dgp
+
+
+def update_config_files(dlcpath):
+    base_path = os.getcwd()
+
+    # project config
+    proj_cfg_path = join(base_path, dlcpath, 'config.yaml')
+    with open(proj_cfg_path, 'r') as f:
+        yaml_cfg = yaml.load(f, Loader=yaml.SafeLoader)
+        yaml_cfg['project_path'] = join(base_path, dlcpath)
+        video_loc = join(base_path, dlcpath, 'videos', 'reachingvideo1.avi')
+        try:
+            yaml_cfg['video_sets'][video_loc] = yaml_cfg['video_sets'].pop(
+                join('videos', 'reachingvideo1.avi'))
+        except:
+            yaml_cfg['video_sets'][video_loc] = yaml_cfg['video_sets'].pop(video_loc)
+    with open(proj_cfg_path, 'w') as f:
+        yaml.dump(yaml_cfg, f)
+
+    # train model config
+    model_cfg_path = get_model_cfg_path(base_path, 'train')
+    with open(model_cfg_path, 'r') as f:
+        yaml_cfg = yaml.load(f, Loader=yaml.SafeLoader)
+        yaml_cfg['init_weights'] = get_init_weights_path(base_path)
+        yaml_cfg['project_path'] = join(base_path, dlcpath)
+    with open(model_cfg_path, 'w') as f:
+        yaml.dump(yaml_cfg, f)
+
+    # download resnet weights if necessary
+    if not os.path.exists(yaml_cfg['init_weights']):
+        raise FileNotFoundError('Must download resnet-50 weights; see README for instructions')
+
+    # test model config
+    model_cfg_path = get_model_cfg_path(base_path, 'test')
+    with open(model_cfg_path, 'r') as f:
+        yaml_cfg = yaml.load(f, Loader=yaml.SafeLoader)
+        yaml_cfg['init_weights'] = get_init_weights_path(base_path)
+    with open(model_cfg_path, 'w') as f:
+        yaml.dump(yaml_cfg, f)
+
+    return join(base_path, dlcpath)
+
+
+def return_configs():
+    base_path = os.getcwd()
+    dlcpath = join('data', 'Reaching-Mackenzie-2018-08-30')
+
+    # project config
+    proj_cfg_path = join(base_path, dlcpath, 'config.yaml')
+    with open(proj_cfg_path, 'r') as f:
+        yaml_cfg = yaml.load(f, Loader=yaml.SafeLoader)
+        yaml_cfg['project_path'] = dlcpath
+        video_loc = join(base_path, dlcpath, 'videos', 'reachingvideo1.avi')
+        yaml_cfg['video_sets'][join('videos', 'reachingvideo1.avi')] = \
+            yaml_cfg['video_sets'].pop(video_loc)
+    with open(proj_cfg_path, 'w') as f:
+        yaml.dump(yaml_cfg, f)
+
+    # train model config
+    model_cfg_path = get_model_cfg_path(base_path, 'train')
+    with open(model_cfg_path, 'r') as f:
+        yaml_cfg = yaml.load(f, Loader=yaml.SafeLoader)
+        yaml_cfg['init_weights'] = 'resnet_v1_50.ckpt'
+        yaml_cfg['project_path'] = dlcpath
+    with open(model_cfg_path, 'w') as f:
+        yaml.dump(yaml_cfg, f)
+
+    # test model config
+    model_cfg_path = get_model_cfg_path(base_path, 'test')
+    with open(model_cfg_path, 'r') as f:
+        yaml_cfg = yaml.load(f, Loader=yaml.SafeLoader)
+        yaml_cfg['init_weights'] = 'resnet_v1_50.ckpt'
+    with open(model_cfg_path, 'w') as f:
+        yaml.dump(yaml_cfg, f)
+
+
+def get_model_cfg_path(base_path, dtype):
+    return join(
+        base_path, dlcpath, 'dlc-models', 'iteration-0',
+        'ReachingAug30-trainset95shuffle1', dtype, 'pose_cfg.yaml')
+
+
+def get_init_weights_path(base_path):
+    return join(
+        base_path, 'resnet_v1_50.ckpt')
+
+
+if __name__ == '__main__':
+
+    # %% set up dlcpath for DLC project and hyperparameters
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--dlcpath",
+        type=str,
+        default=None,
+        help="the absolute path of the DLC project",
+    )
+    parser.add_argument(
+        "--dlcsnapshot",
+        type=str,
+        default=None,
+        help="use the DLC snapshot to initialize DGP",
+    )
+    parser.add_argument(
+        "--snapshot",
+        type=str,
+        default=None,
+        help="use the DGP snapshot",
+    )
+    parser.add_argument(
+        "--video-path",
+        type=str,
+        default=None,
+        help="path to video",
+    )
+    parser.add_argument(
+        "--video-path-out",
+        type=str,
+        default=None,
+        help="path to output video",
+    )
+    parser.add_argument("--shuffle", type=int, default=1, help="Project shuffle")
+
+    input_params = parser.parse_known_args()[0]
+    print(input_params)
+
+    dlcpath = input_params.dlcpath
+    shuffle = input_params.shuffle
+    snapshot = input_params.snapshot
+    video_path = input_params.video_path
+    video_path_out = input_params.video_path_out
+
+    print(dlcpath)
+
+    cfg_yaml = dlcpath + '/config.yaml'
+    print(cfg_yaml)
+
+    # ------------------------------------------------------------------------------------
+    # Train models
+    # ------------------------------------------------------------------------------------
+    import tensorflow as tf
+    config = tf.compat.v1.ConfigProto()
+    config.gpu_options.allow_growth = True
+    sess = tf.compat.v1.Session(config=config)
+
+    # --------------------------------------------------------------------------------
+    # Test DGP model
+    # --------------------------------------------------------------------------------
+    # %% step 3 predict on all videos in videos_dgp folder
+    print(
+        '''
+        ==========================
+        |                        |
+        |                        |
+        |    Predict with DGP    |
+        |                        |
+        |                        |
+        ==========================
+        '''
+        , flush=True)
+    cfg = auxiliaryfunctions.read_config(cfg_yaml)
+    bodyparts2connect = cfg["skeleton"]
+    skeleton_color = cfg["skeleton_color"]
+    draw_skeleton = True
+    color_by = 'bodypart'
+    displaycropped = False
+    bodyparts = auxiliaryfunctions.IntersectionofBodyPartsandOnesGivenbyUser(
+        cfg, "all"
+    )
+    cropping = False
+    x1, x2, y1, y2 = 0, 0, 0, 0
+    trailpoints = 0
+    if not os.path.exists(video_path):
+        print(video_path + " does not exist!")
+        video_sets = list(cfg['video_sets'])
+    else:
+        video_sets = [
+            join(video_path, f) for f in listdir(video_path)
+            if isfile(join(video_path, f)) and (
+                f.find('avi') > 0 or f.find('mp4') > 0 or f.find('mov') > 0 or f.find(
+                    'mkv') > 0)
+        ]
+
+    video_pred_path = video_path_out
+    if not os.path.exists(video_pred_path):
+        os.makedirs(video_pred_path)
+
+    print('video_sets', video_sets, flush=True)
+    for video_file in video_sets:
+        plot_dgp(str(video_file),
+                 str(video_pred_path),
+                 proj_cfg_file=str(cfg_yaml),
+                 dgp_model_file=str(snapshot),
+                 shuffle=shuffle)
+
+        filename = video_pred_path + "/" + os.path.basename(video_file)
+        videooutname = filename.split(".")[0] + "_dgp_labeled.mp4"
+        print("VIDEO OUT NAME")
+        print(videooutname)
+        clip = vp(fname=video_file, sname=videooutname, codec="mp4v")
+        filepath = filename.split(".")[0] + "_labeled.h5"
+        df = pd.read_hdf(filepath)
+
+        labeled_bpts = [
+            bp
+            for bp in df.columns.get_level_values("bodyparts").unique()
+            if bp in bodyparts
+        ]
+
+        CreateVideo(
+            clip,
+            df,
+            cfg["pcutoff"],
+            cfg["dotsize"],
+            cfg["colormap"],
+            labeled_bpts,
+            trailpoints,
+            cropping,
+            x1,
+            x2,
+            y1,
+            y2,
+            bodyparts2connect,
+            skeleton_color,
+            draw_skeleton,
+            displaycropped,
+            color_by,
+        )
\ No newline at end of file
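Note: for each video, the script above reduces to a single `plot_dgp` call plus the `CreateVideo` overlay. A hypothetical standalone invocation of the same entry point (all paths and the snapshot name are placeholders):

```python
# Hypothetical direct use of the prediction entry point exercised by
# run_video_pred.py; paths and the snapshot name are placeholders.
from deepgraphpose.models.eval import plot_dgp

plot_dgp(
    'videos_dgp/session1.avi',   # input video
    'videos_pred',               # output directory for the labeled video and .h5
    proj_cfg_file='data/Reaching-Mackenzie-2018-08-30/config.yaml',
    dgp_model_file='snapshot-step2-final-0',
    shuffle=1,
)
```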
diff --git a/src/DeepLabCut/setup.py b/src/DeepLabCut/setup.py
index 4c6084e..f2e5ba6 100644
--- a/src/DeepLabCut/setup.py
+++ b/src/DeepLabCut/setup.py
@@ -30,7 +30,7 @@
         'matplotlib==3.0.3','moviepy','numpy>=1.16.4','opencv-python~=3.4',
         'pandas','patsy','python-dateutil','pyyaml>=5.1','requests',
         'ruamel.yaml~=0.15','setuptools','scikit-image','scikit-learn',
-        'scipy','six','statsmodels==0.10.1','tables==3.4.3',
+        'scipy','six','statsmodels==0.10.1','tables',
         'tensorpack>=0.9.7.1',
         'tqdm','wheel'],
     scripts=['deeplabcut/pose_estimation_tensorflow/models/pretrained/download.sh'],
diff --git a/src/deepgraphpose/dataset.py b/src/deepgraphpose/dataset.py
index eea8c1c..53b0604 100644
--- a/src/deepgraphpose/dataset.py
+++ b/src/deepgraphpose/dataset.py
@@ -9,7 +9,7 @@
 import numpy as np
 import scipy.io as sio
 import tensorflow as tf
-import tensorflow.contrib.slim as slim
+import tf_slim as slim
 import yaml
 from moviepy.editor import VideoFileClip
 from skimage.util import img_as_ubyte
@@ -320,7 +320,7 @@ def __init__(self, video_path, dlc_config, paths):
         #     1.0)
         # TO DO: del
         self.n_frames = calculate_num_frames(self.video_clip)
-        self.nj = self.dlc_config.num_joints
+        self.nj = self.dlc_config['num_joints']
         # to fill upon creating batches
         self.ny_in, self.nx_in = self.video_clip.size
         self.nx_out, self.ny_out = self._compute_pred_dims()  # x, y dims of model output
@@ -348,23 +348,23 @@ def __str__(self):
     def _compute_pred_dims(self):
         """Compute output dims of dgp prediction layer by pushing fake data through network."""
         from deepgraphpose.models.fitdgp_util import dgp_prediction_layer
-        from deeplabcut.pose_estimation_tensorflow.nnet.net_factory import pose_net
+        from deeplabcut.pose_estimation_tensorflow.nnets import PoseNetFactory
 
-        TF.reset_default_graph()
+        TF.compat.v1.reset_default_graph()
 
         nc = 3
-        inputs = TF.placeholder(TF.float32, shape=[None, self.nx_in, self.ny_in, nc])
+        inputs = TF.compat.v1.placeholder(TF.float32, shape=[None, self.nx_in, self.ny_in, nc])
 
-        pn = pose_net(self.dlc_config)
+        pn = PoseNetFactory.create(self.dlc_config)
         conv_inputs, end_points = pn.extract_features(inputs)
         x = dgp_prediction_layer(
             None, None, self.dlc_config, conv_inputs, 'confidencemap', self.nj, 0, nc,
             1)
 
-        sess = TF.Session(config=TF.ConfigProto())
-        sess.run(TF.global_variables_initializer())
-        sess.run(TF.local_variables_initializer())
+        sess = TF.compat.v1.Session(config=TF.compat.v1.ConfigProto())
+        sess.run(TF.compat.v1.global_variables_initializer())
+        sess.run(TF.compat.v1.local_variables_initializer())
 
         feed_dict = {inputs: np.zeros([1, self.nx_in, self.ny_in, nc])}
         x_np = sess.run(x, feed_dict)
@@ -586,14 +586,13 @@ def _add_labels_to_batches(self):
 
     def _compute_targets(self):
 
-        from deeplabcut.pose_estimation_tensorflow.dataset.factory import \
-            create as create_dataset
+        from deeplabcut.pose_estimation_tensorflow.datasets import PoseDatasetFactory
 
         dlc_config = copy.deepcopy(self.dlc_config)
         dlc_config['deterministic'] = True
-        # switch to default dataset_type to produce expected batch output
-        dlc_config['dataset_type'] = 'default'
-        dataset = create_dataset(dlc_config)
+        # switch dataset_type to produce expected batch output
+        dlc_config['dataset_type'] = 'deterministic'
+        dataset = PoseDatasetFactory.create(dlc_config)
 
         nt = len(self.idxs['vis']['train'])  # number of training frames
         # assert nt >= 1
         nj = max([dat_.joints[0].shape[0] for dat_ in dataset.data])
@@ -617,7 +616,6 @@ def extract_frame_num(img_path):
             # pairwise_targets = 5
             # pairwise_mask = 6
             # data_item = 7
-
             im_path = data[data_keys[5]].im_path
             # skip if frame belongs to another video
             im_path_split = os.path.normpath(im_path).split(os.sep)
@@ -837,11 +835,12 @@ def __init__(self, config_yaml, video_sets=None, shuffle=None, S0=None):
         if video_sets is None:
             self.proj_config['video_path'] = self.proj_config['video_sets']  # backwards compat
         else:
+            print(self.proj_config['video_sets'])
             video_set_keys = self.proj_config['video_sets'].keys()
             video_set_keys = [split(v)[-1] for v in video_set_keys]
-            #print('video_set_keys: ', video_set_keys)
+            print('video_set_keys: ', video_set_keys)
             video_set_input = [split(v)[-1] for v in video_sets]
-            #print('video_set_input: ', video_set_input)
+            print('video_set_input: ', video_set_input)
             if set(video_set_keys)==set(video_set_input):
                 self.proj_config['video_path'] = self.proj_config['video_sets']
             else:
@@ -853,19 +852,24 @@ def __init__(self, config_yaml, video_sets=None, shuffle=None, S0=None):
             join(self.proj_config['project_path'], key): val
             for key, val in self.proj_config['video_sets'].items()
         }
+        print(self.proj_config['video_sets'])
+        print("project path")
+        print(self.proj_config['project_path'])
 
         self.dlc_config = get_train_config(self.proj_config, shuffle)
 
         # save path info
-        self.paths['project'] = Path(self.dlc_config.project_path)
-        self.paths['dlc_model'] = Path(self.dlc_config.snapshot_prefix).parent
+        self.paths['project'] = Path(self.dlc_config['project_path'])
+        self.paths['dlc_model'] = Path(self.dlc_config['snapshot_prefix']).parent
         self.paths['batched_data'] = ''
 
         # create a dataset for each video
         self.video_files = self.proj_config['video_sets'].keys()
+        print(self.proj_config['video_sets'])
         assert len(self.video_files) > 0
 
         self.batch_ratios = []
         for video_file in self.video_files:
+            print(video_file)
             self.datasets.append(Dataset(video_file, self.dlc_config, self.paths))
             self.batch_ratios.append(len(self.datasets[-1].idxs['vis']['train']))
         self.batch_ratios = np.array(self.batch_ratios) / np.sum(self.batch_ratios)
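Note: the recurring `dlc_config.num_joints` → `dlc_config['num_joints']` rewrite in this file (and in the diffs below) reflects that newer DeepLabCut returns the train config as a plain dict rather than an attribute-style config object. A two-line sketch of the difference, for illustration only:

```python
# Attribute access worked on the old attribute-style config object but raises
# AttributeError on a plain dict, hence the bracketed form used throughout:
nj = dlc_config['num_joints']   # new, dict-style access
# nj = dlc_config.num_joints    # old, attribute-style access
```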
diff --git a/src/deepgraphpose/models/eval.py b/src/deepgraphpose/models/eval.py
index 545add6..f0a8f72 100644
--- a/src/deepgraphpose/models/eval.py
+++ b/src/deepgraphpose/models/eval.py
@@ -4,7 +4,7 @@
 import matplotlib.pyplot as plt
 import numpy as np
 import tensorflow as tf
-import tensorflow.contrib.slim as slim
+import tf_slim as slim
 import yaml
 from moviepy.editor import VideoFileClip
 from skimage.draw import circle
@@ -18,7 +18,7 @@
     TF = tf.compat.v1
 else:
     TF = tf
-config = tf.ConfigProto()
+config = tf.compat.v1.ConfigProto()
 config.gpu_options.allow_growth = True
 
 
@@ -112,7 +112,7 @@ def add_marker(get_frame, t):
 
         return frame
 
-    clip_marked = clip.fl(add_marker)
+    clip_marked = clip.transform(add_marker)
     clip_marked.write_videofile(str(filename), codec="mpeg4", fps=fps, bitrate="1000k")
     clip_marked.close()
 
@@ -169,42 +169,41 @@ def setup_dgp_eval_graph(dlc_cfg, dgp_model_file, loc_ref=False, gauss_len=1, ga
         inputs (tf.Tensor)
 
     """
-    from deeplabcut.pose_estimation_tensorflow.nnet.net_factory import pose_net
+    from deeplabcut.pose_estimation_tensorflow.nnets import PoseNetFactory
     from deepgraphpose.models.fitdgp_util import argmax_2d_from_cm, dgp_prediction_layer
 
     # -------------------
     # define model
    # -------------------
-    TF.reset_default_graph()
-    inputs = TF.placeholder(tf.float32, shape=[1, None, None, 3])
-    pn = pose_net(dlc_cfg)
+    TF.compat.v1.reset_default_graph()
+    inputs = TF.compat.v1.placeholder(tf.float32, shape=[1, None, None, 3])
+    pn = PoseNetFactory.create(dlc_cfg)
     # extract resnet outputs
     net, end_points = pn.extract_features(inputs)
-    with tf.variable_scope('pose', reuse=None):
+    with tf.compat.v1.variable_scope('pose', reuse=None):
         scmap = dgp_prediction_layer(
             None, None, dlc_cfg, net, name='part_pred',
-            num_outputs=dlc_cfg.num_joints, init_flag=False, nc=None, train_flag=True,
-            stride=dlc_cfg.deconvolutionstride)
+            num_outputs=dlc_cfg['num_joints'], init_flag=False, nc=None, train_flag=True)
         if loc_ref:
             locref = dgp_prediction_layer(
                 None, None, dlc_cfg, net, name='locref_pred',
-                num_outputs=dlc_cfg.num_joints * 2, init_flag=False, nc=None,
-                train_flag=True, stride=dlc_cfg.deconvolutionstride)
+                num_outputs=dlc_cfg['num_joints'] * 2, init_flag=False, nc=None,
+                train_flag=True)
         else:
             locref = None
 
     variables_to_restore = slim.get_variables_to_restore()
-    restorer = TF.train.Saver(variables_to_restore)
+    restorer = TF.compat.v1.train.Saver(variables_to_restore)
     weights_location = str(dgp_model_file)
 
-    mu_n, softmax_tensor = argmax_2d_from_cm(scmap, dlc_cfg.num_joints, gamma, gauss_len)
+    mu_n, softmax_tensor = argmax_2d_from_cm(scmap, dlc_cfg['num_joints'], gamma, gauss_len)
 
     # initialize tf session
-    config_TF = TF.ConfigProto()
+    config_TF = TF.compat.v1.ConfigProto()
     config_TF.gpu_options.allow_growth = True
-    sess = TF.Session(config=config_TF)
+    sess = TF.compat.v1.Session(config=config_TF)
 
     # initialize weights
-    sess.run(TF.global_variables_initializer())
-    sess.run(TF.local_variables_initializer())
+    sess.run(TF.compat.v1.global_variables_initializer())
+    sess.run(TF.compat.v1.local_variables_initializer())
 
     # restore resnet from dlc trained weights
     print('loading resnet model weights from %s...' % weights_location, end='')
@@ -269,10 +268,10 @@ def estimate_pose(proj_cfg_file, dgp_model_file, video_file, output_dir, shuffle
     # extract pose
     # -------------------
     try:
-        dlc_cfg.net_type = 'resnet_50'
+        dlc_cfg['net_type'] = 'resnet_50'
         sess, mu_n, _, scmap, _, inputs = setup_dgp_eval_graph(dlc_cfg, dgp_model_file)
     except:
-        dlc_cfg.net_type = 'resnet_101'
+        dlc_cfg['net_type'] = 'resnet_101'
         sess, mu_n, _, scmap, _, inputs = setup_dgp_eval_graph(dlc_cfg, dgp_model_file)
 
     print('\n')
@@ -290,13 +289,13 @@ def estimate_pose(proj_cfg_file, dgp_model_file, video_file, output_dir, shuffle
             pbar.update(1)
     """
     # %%
-    nj = dlc_cfg.num_joints
+    nj = dlc_cfg['num_joints']
     nx, ny = video_clip.size
-    nx_out, ny_out = int((nx - dlc_cfg.stride / 2) / dlc_cfg.stride + 1) + 5, int(
-        (ny - dlc_cfg.stride / 2) / dlc_cfg.stride + 1) + 5
+    nx_out, ny_out = int((nx - dlc_cfg['stride'] / 2) / dlc_cfg['stride'] + 1) + 5, int(
+        (ny - dlc_cfg['stride'] / 2) / dlc_cfg['stride'] + 1) + 5
 
     # %%
-    markers = np.zeros((n_frames, dlc_cfg.num_joints, 2))
+    markers = np.zeros((n_frames, dlc_cfg['num_joints'], 2))
     mu_likelihoods = np.zeros((n_frames, nj, 2)).astype('int')
     likelihoods = np.zeros((n_frames, nj))
 
@@ -349,8 +348,8 @@ def estimate_pose(proj_cfg_file, dgp_model_file, video_file, output_dir, shuffle
     video_clip.close()
 
     # %%
-    xr = markers[:, :, 1] * dlc_cfg.stride + 0.5 * dlc_cfg.stride  # T x nj
-    yr = markers[:, :, 0] * dlc_cfg.stride + 0.5 * dlc_cfg.stride
+    xr = markers[:, :, 1] * dlc_cfg['stride'] + 0.5 * dlc_cfg['stride']  # T x nj
+    yr = markers[:, :, 0] * dlc_cfg['stride'] + 0.5 * dlc_cfg['stride']
 
     # %%
     # true xr
     xr *= scale_x
@@ -368,7 +367,7 @@ def estimate_pose(proj_cfg_file, dgp_model_file, video_file, output_dir, shuffle
     if not Path(save_file).parent.exists():
         os.makedirs(os.path.dirname(save_file))
     export_pose_like_dlc(labels, os.path.basename(dgp_model_file),
-                         dlc_cfg.all_joints_names, save_file)
+                         dlc_cfg['all_joints_names'], save_file)
 
     return labels
 
@@ -424,10 +423,10 @@ def estimate_pose_obsolete(proj_cfg_file, dgp_model_file, video_file, output_dir
     # extract pose
     # -------------------
    try:
-        dlc_cfg.net_type = 'resnet_50'
+        dlc_cfg['net_type'] = 'resnet_50'
         sess, mu_n, _, scmap, _, inputs = setup_dgp_eval_graph(dlc_cfg, dgp_model_file)
     except:
-        dlc_cfg.net_type = 'resnet_101'
+        dlc_cfg['net_type'] = 'resnet_101'
         sess, mu_n, _, scmap, _, inputs = setup_dgp_eval_graph(dlc_cfg, dgp_model_file)
 
     print('\n')
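Note: eval.py routes all TF1-style graph and session calls through a `TF` alias (visible in the `@@ -18,7 +18,7 @@` hunk context above). A minimal sketch of that shim; the eager-execution line is my assumption, since `tf.compat.v1.placeholder` only works in graph mode on TF2:

```python
import tensorflow as tf

# Version shim as used in eval.py: on TF2 installs, send TF1-style calls
# through the compat.v1 namespace; on TF1, use tf directly.
if tf.__version__.startswith('2'):
    TF = tf.compat.v1
else:
    TF = tf

# Assumption (not shown in the patch): placeholders require graph mode on TF2,
# so eager execution must be disabled before the graph is built.
tf.compat.v1.disable_eager_execution()
```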
diff --git a/src/deepgraphpose/models/fitdgp.py b/src/deepgraphpose/models/fitdgp.py
index 2eed3b8..0a76a23 100644
--- a/src/deepgraphpose/models/fitdgp.py
+++ b/src/deepgraphpose/models/fitdgp.py
@@ -21,19 +21,21 @@
 import numpy as np
 import tensorflow as tf
-import tensorflow.contrib.slim as slim
+import tf_slim as slim
 
 import deeplabcut
 from deeplabcut.pose_estimation_tensorflow.config import load_config
-from deeplabcut.pose_estimation_tensorflow.dataset.factory import (
-    create as create_dataset, )
-from deeplabcut.pose_estimation_tensorflow.dataset.pose_defaultdataset import PoseDataset
-from deeplabcut.pose_estimation_tensorflow.nnet.net_factory import pose_net
-from deeplabcut.pose_estimation_tensorflow.train import LearningRate, get_batch_spec, \
+from deeplabcut.pose_estimation_tensorflow.datasets import (
+    PoseDatasetFactory,
+    ImgaugPoseDataset
+)
+from deeplabcut.pose_estimation_tensorflow.nnets import (
+    PoseNetFactory,
+    PoseResnet
+)
+from deeplabcut.pose_estimation_tensorflow.core.train import LearningRate, get_batch_spec, \
     setup_preloading, start_preloading, get_optimizer
 from deeplabcut.utils import auxiliaryfunctions
-from deeplabcut.pose_estimation_tensorflow.nnet.pose_net import PoseNet, losses, \
-    prediction_layer
+from deeplabcut.pose_estimation_tensorflow.nnets.layers import prediction_layer
 
 from deepgraphpose.dataset import MultiDataset, coord2map
 from deepgraphpose.models.fitdgp_util import gen_batch, argmax_2d_from_cm, combine_all_marker, build_aug, data_aug, learn_wt
@@ -51,7 +53,7 @@
 # %%
 def fit_dlc(
-        snapshot, dlcpath, shuffle=1, step=0, saveiters=1000, displayiters=100, maxiters=200000,
+        snapshot, dlcpath, shuffle=1, step=0, saveiters=1000, displayiters=100, maxiters=50000,
         trainingsetindex=0):
     """Run the original DLC code.
 
     Parameters
@@ -90,103 +92,105 @@ def fit_dlc(
     # Change dlc_cfg as we want, here we set the default values
     # TODO: it would be better to set the default values when making config.yaml and pose_cfg.yaml, double check the default setting for these two yamls.
     dlc_cfg = load_config(pose_config_yaml)
-    dlc_cfg.crop = True
-    dlc_cfg.cropratio = 0.4
-    dlc_cfg.global_scale = 0.8
-    dlc_cfg.multi_step = [[0.001, 10000], [0.005, 430000], [0.002, 730000],
+    print(pose_config_yaml)
+    print(dlc_cfg)
+    dlc_cfg['crop'] = False
+    dlc_cfg['cropratio'] = 0.4
+    dlc_cfg['global_scale'] = 0.8
+    dlc_cfg['multi_step'] = [[0.001, 10000], [0.005, 430000], [0.002, 730000],
                           [0.001, 1030000]]
     if "snapshot" in snapshot:
         train_path = dlc_base_path / modelfoldername / 'train'
         init_weights = str(train_path / snapshot)
     else:
         parent_path = Path(os.path.dirname(deeplabcut.__file__))
-        snapshot = dlc_cfg.net_type.split('_')[0] + '_v1_' + dlc_cfg.net_type.split('_')[1] + '.ckpt'
+        snapshot = dlc_cfg['net_type'].split('_')[0] + '_v1_' + dlc_cfg['net_type'].split('_')[1] + '.ckpt'
         init_weights = str(
             parent_path / join('pose_estimation_tensorflow', 'models', 'pretrained',
                                snapshot))
-    dlc_cfg.init_weights = init_weights
-    dlc_cfg.pos_dist_thresh = 8
-    dlc_cfg.output_stride = 16
+    dlc_cfg['init_weights'] = init_weights
+    dlc_cfg['pos_dist_thresh'] = 8
+    dlc_cfg['output_stride'] = 16
 
     # skip this DLC step if it's already done.
-    model_name = dlc_cfg.snapshot_prefix + '-step0-final--0.index'
+    model_name = dlc_cfg['snapshot_prefix'] + '-step0-final-0.index'
     if os.path.isfile(model_name):
         print(model_name, ' exists! The original DLC has already been run.', flush=True)
         return None
 
     # Build loss function
-    TF.reset_default_graph()
-
-    dataset = create_dataset(dlc_cfg)
+    TF.compat.v1.reset_default_graph()
+    dlc_cfg['dataset_type'] = 'deterministic'
+    dataset = PoseDatasetFactory.create(dlc_cfg)
     batch_spec = get_batch_spec(dlc_cfg)
     batch, enqueue_op, placeholders = setup_preloading(batch_spec)
-    losses = pose_net(dlc_cfg).train(batch)
+    losses = PoseNetFactory.create(dlc_cfg).train(batch)
     total_loss = losses["total_loss"]
-
+    print(losses.items())
     for k, t in losses.items():
-        TF.summary.scalar(k, t)
-    merged_summaries = TF.summary.merge_all()
+        TF.compat.v1.summary.scalar(k, t)
+    merged_summaries = TF.compat.v1.summary.merge_all()
 
-    if "snapshot" in Path(dlc_cfg.init_weights).stem:
+    if "snapshot" in Path(dlc_cfg['init_weights']).stem:
         print("Loading already trained DLC with backbone:",
-              dlc_cfg.net_type,
+              dlc_cfg['net_type'],
               flush=True)
         variables_to_restore = slim.get_variables_to_restore()
     else:
-        print("Loading ImageNet-pretrained", dlc_cfg.net_type, flush=True)
+        print("Loading ImageNet-pretrained", dlc_cfg['net_type'], flush=True)
         # loading backbone from ResNet, MobileNet etc.
-        if "resnet" in dlc_cfg.net_type:
+        if "resnet" in dlc_cfg['net_type']:
             variables_to_restore = slim.get_variables_to_restore(
                 include=["resnet_v1"])
-        elif "mobilenet" in dlc_cfg.net_type:
+        elif "mobilenet" in dlc_cfg['net_type']:
             variables_to_restore = slim.get_variables_to_restore(
                 include=["MobilenetV2"])
         else:
             print("Wait for DLC 2.3.")
 
-    restorer = TF.train.Saver(variables_to_restore)
-    saver = TF.train.Saver(
+    restorer = TF.compat.v1.train.Saver(variables_to_restore)
+    saver = TF.compat.v1.train.Saver(
         max_to_keep=5
     )  # selects how many snapshots are stored,
     # see https://github.com/AlexEMG/DeepLabCut/issues/8#issuecomment-387404835
 
     allow_growth = True
     if allow_growth:
-        config = TF.ConfigProto()
+        config = TF.compat.v1.ConfigProto()
         config.gpu_options.allow_growth = True
-        sess = TF.Session(config=config)
+        sess = TF.compat.v1.Session(config=config)
     else:
-        sess = TF.Session()
+        sess = TF.compat.v1.Session()
 
     coord, thread = start_preloading(sess, enqueue_op, dataset, placeholders)
-    train_writer = TF.summary.FileWriter(dlc_cfg.log_dir, sess.graph)
-    learning_rate, train_op = get_optimizer(total_loss, dlc_cfg)
+    train_writer = TF.compat.v1.summary.FileWriter(dlc_cfg['log_dir'], sess.graph)
+    learning_rate, train_op, tstep = get_optimizer(total_loss, dlc_cfg)
 
-    sess.run(TF.global_variables_initializer())
-    sess.run(TF.local_variables_initializer())
+    sess.run(TF.compat.v1.global_variables_initializer())
+    sess.run(TF.compat.v1.local_variables_initializer())
 
     # Restore variables from disk.
-    restorer.restore(sess, dlc_cfg.init_weights)
+    restorer.restore(sess, dlc_cfg['init_weights'])
 
     # Run iterations
     if displayiters is None:
-        display_iters = max(1, int(dlc_cfg.display_iters))
+        display_iters = max(1, int(dlc_cfg['display_iters']))
     else:
         display_iters = max(1, int(displayiters))
         print("Display_iters overwritten as", display_iters, flush=True)
 
     if saveiters is None:
-        save_iters = max(1, int(dlc_cfg.save_iters))
+        save_iters = max(1, int(dlc_cfg['save_iters']))
     else:
         save_iters = max(1, int(saveiters))
         print("Save_iters overwritten as", save_iters, flush=True)
 
     if maxiters is None:
-        max_iter = int(dlc_cfg.multi_step[-1][1])
+        max_iter = int(dlc_cfg['multi_step'][-1][1])
     else:
-        max_iter = min(int(dlc_cfg.multi_step[-1][1]), int(maxiters))
+        max_iter = min(int(dlc_cfg['multi_step'][-1][1]), int(maxiters))
         print("Max_iters overwritten as", max_iter, flush=True)
 
     lr_gen = LearningRate(dlc_cfg)  # learning rate
@@ -206,7 +210,7 @@ def fit_dlc(
             # collect loss
             partloss += alllosses["part_loss"]  # scoremap loss
-            if dlc_cfg.location_refinement:
+            if dlc_cfg['location_refinement']:
                 locrefloss += alllosses["locref_loss"]
             cumloss += loss_val
             train_writer.add_summary(summary, it)
@@ -236,11 +240,11 @@ def fit_dlc(
         # Save snapshot
         if (it % save_iters == 0 and it != 0) or it == max_iter:
-            model_name = dlc_cfg.snapshot_prefix + '-step' + str(
-                step) + '-'
+            model_name = dlc_cfg['snapshot_prefix'] + '-step' + str(
+                step)
             saver.save(sess, model_name, global_step=it)
             if it == max_iter:
-                model_name = dlc_cfg.snapshot_prefix + '-step' + str(
-                    step) + '-final-'
+                model_name = dlc_cfg['snapshot_prefix'] + '-step' + str(
+                    step) + '-final'
                 saver.save(sess, model_name, global_step=0)
@@ -340,26 +344,26 @@ def fit_dgp_labeledonly(
         shuffle=shuffle,
         S0=S0)
     dgp_cfg = data_batcher.dlc_config
-    dgp_cfg.ws = 0  # the spatial clique parameter
-    dgp_cfg.ws_max = 1.2  # the multiplier for the upper bound of spatial distance
-    dgp_cfg.wt = 0  # the temporal clique parameter
-    dgp_cfg.wt_max = 0  # the upper bound of temporal distance
-    dgp_cfg.wn_visible = 1  # the network clique parameter for visible frames
-    dgp_cfg.wn_hidden = 0  # the network clique parameter for hidden frames
-    dgp_cfg.gamma = 1  # the multiplier for the softmax confidence map
-    dgp_cfg.gauss_len = 1  # the length scale for the Gaussian kernel convolving the softmax confidence map
-    dgp_cfg.lengthscale = 1  # the length scale for the Gaussian target map
-    dgp_cfg.max_to_keep = 5  # max number of snapshots to keep
-    dgp_cfg.batch_size = 1  # batch size
-    dgp_cfg.n_times_all_frames = 100  # the number of times each selected frames is iterated over
-    dgp_cfg.lr = 0.005  # learning rate
-    # dgp_cfg.net_type = 'resnet_50'
-    dgp_cfg.gm2 = 0  # scale target by confidence level
-    dgp_cfg.gm3 = 0  # scale hidden loss by confidence level
-    dgp_cfg.aug = aug  # data augmentation
+    dgp_cfg['ws'] = 0  # the spatial clique parameter
+    dgp_cfg['ws_max'] = 1.2  # the multiplier for the upper bound of spatial distance
+    dgp_cfg['wt'] = 0  # the temporal clique parameter
+    dgp_cfg['wt_max'] = 0  # the upper bound of temporal distance
+    dgp_cfg['wn_visible'] = 1  # the network clique parameter for visible frames
+    dgp_cfg['wn_hidden'] = 0  # the network clique parameter for hidden frames
+    dgp_cfg['gamma'] = 1  # the multiplier for the softmax confidence map
+    dgp_cfg['gauss_len'] = 1  # the length scale for the Gaussian kernel convolving the softmax confidence map
+    dgp_cfg['lengthscale'] = 1  # the length scale for the Gaussian target map
+    dgp_cfg['max_to_keep'] = 5  # max number of snapshots to keep
+    dgp_cfg['batch_size'] = 1  # batch size
+    dgp_cfg['n_times_all_frames'] = 100  # the number of times each selected frames is iterated over
+    dgp_cfg['lr'] = 0.005  # learning rate
+    # dgp_cfg['net_type'] = 'resnet_50'
+    dgp_cfg['gm2'] = 0  # scale target by confidence level
+    dgp_cfg['gm3'] = 0  # scale hidden loss by confidence level
+    dgp_cfg['aug'] = aug  # data augmentation
 
     # skip this DGP with labeled frames only step if it's already done.
-    model_name = dgp_cfg.snapshot_prefix + '-step1-final--0.index'
+    model_name = dgp_cfg['snapshot_prefix'] + '-step1-final-0.index'
     if os.path.isfile(model_name):
         print(model_name, ' exists! DGP with labeled frames has already been run.', flush=True)
         return None
@@ -386,9 +390,9 @@ def fit_dgp_labeledonly(
     # ------------------------------------------------------------------------------------
     # Build model
     # ------------------------------------------------------------------------------------
-    TF.reset_default_graph()
+    TF.compat.v1.reset_default_graph()
     loss, total_loss, total_loss_visible, placeholders = dgp_loss(data_batcher, dgp_cfg)
-    learning_rate = TF.placeholder(tf.float32, shape=[])
+    learning_rate = TF.compat.v1.placeholder(tf.float32, shape=[])
 
     # Restore network parameters for RESNET and COVNET
     variables_to_restore0 = slim.get_variables_to_restore(
@@ -396,29 +400,29 @@ def fit_dgp_labeledonly(
         include=['pose/part_pred'])
     variables_to_restore1 = slim.get_variables_to_restore(
         include=['pose/locref_pred'])
     variables_to_restore2 = slim.get_variables_to_restore(include=['resnet'])
-    restorer = TF.train.Saver(variables_to_restore0 + variables_to_restore1 +
+    restorer = TF.compat.v1.train.Saver(variables_to_restore0 + variables_to_restore1 +
                               variables_to_restore2)
-    saver = TF.train.Saver(max_to_keep=dgp_cfg.max_to_keep)
+    saver = TF.compat.v1.train.Saver(max_to_keep=dgp_cfg['max_to_keep'])
 
     # Set up session
     allow_growth = True
     if allow_growth:
-        config = TF.ConfigProto()
+        config = TF.compat.v1.ConfigProto()
         config.gpu_options.allow_growth = True
-        sess = TF.Session(config=config)
+        sess = TF.compat.v1.Session(config=config)
     else:
-        sess = TF.Session()
+        sess = TF.compat.v1.Session()
 
     # Set up optimizer
-    all_train_vars = TF.trainable_variables()
-    optimizer = TF.train.MomentumOptimizer(learning_rate=learning_rate, momentum=0.9)
+    all_train_vars = TF.compat.v1.trainable_variables()
+    optimizer = TF.compat.v1.train.MomentumOptimizer(learning_rate=learning_rate, momentum=0.9)
     gradients, variables = zip(
         *optimizer.compute_gradients(total_loss_visible, var_list=all_train_vars))
     gradients, _ = TF.clip_by_global_norm(gradients, 10.0)
     train_op = optimizer.apply_gradients(zip(gradients, variables))
 
-    sess.run(TF.global_variables_initializer())
-    sess.run(TF.local_variables_initializer())
+    sess.run(TF.compat.v1.global_variables_initializer())
+    sess.run(TF.compat.v1.local_variables_initializer())
 
     # Restore RESNET var
     print('restoring resnet weights from %s' % init_weights)
@@ -428,7 +432,7 @@ def fit_dgp_labeledonly(
     # ------------------------------------------------------------------------------------
     # Begin training
     # ------------------------------------------------------------------------------------
-    nepoch = np.min([int(n_visible_frames_total * dgp_cfg.n_times_all_frames), maxiters])
+    nepoch = np.min([int(n_visible_frames_total * dgp_cfg['n_times_all_frames']), maxiters])
     visible_frame_total_dict = []
     for i, v in enumerate(visible_frame_total):
         for vv in v:
@@ -438,18 +442,19 @@ def fit_dgp_labeledonly(
     save_iters = saveiters
     maxiters = batch_ind_all.shape[0]
 
-    pdata = PoseDataset(dgp_cfg)
+    dgp_cfg['dataset_type'] = 'deterministic'
+    pdata = PoseDatasetFactory.create(dgp_cfg)
     data_batcher.reset()
 
     # %%
     print('Begin Training for {} iterations'.format(maxiters))
 
-    if dgp_cfg.aug:
+    if dgp_cfg['aug']:
         pipeline = build_aug(apply_prob=0.8)
 
     time_start = time.time()
     for it in range(maxiters):
 
-        current_lr = dgp_cfg.lr
+        current_lr = dgp_cfg['lr']
 
         # get batch index
         visible_batch_ind = visible_frame_total_dict[batch_ind_all[it]]
@@ -471,14 +476,14 @@ def fit_dgp_labeledonly(
         visible_frame_within_batch = [np.where(all_frame == i)[0][0] for i in visible_frame]
 
         # batch data for placeholders
-        if dgp_cfg.wt > 0:
+        if dgp_cfg['wt'] > 0:
             vector_field = learn_wt(all_data_batch)  # vector field from optical flow
         else:
             vector_field = np.zeros((1,1,1))
-        wt_batch = np.ones(nt_batch - 1, ) * dgp_cfg.wt
+        wt_batch = np.ones(nt_batch - 1, ) * dgp_cfg['wt']
 
         # data augmentation for visible frames
-        if dgp_cfg.aug and dgp_cfg.wt == 0 and len(visible_frame_within_batch) > 0:
+        if dgp_cfg['aug'] and dgp_cfg['wt'] == 0 and len(visible_frame_within_batch) > 0:
             all_data_batch, joint_loc = data_aug(all_data_batch, visible_frame_within_batch, joint_loc, pipeline, dgp_cfg)
             locref_targets_batch, locref_mask_batch = coord2map(pdata, joint_loc, nx_out, ny_out, nj)
@@ -532,11 +537,11 @@ def fit_dgp_labeledonly(
         # Save snapshot
         if (it % save_iters == 0) or (it + 1) == maxiters:
-            model_name = dgp_cfg.snapshot_prefix + '-step' + str(step) + '-'
+            model_name = dgp_cfg['snapshot_prefix'] + '-step' + str(step)
             saver.save(sess, model_name, global_step=it)
             saver.save(sess, model_name, global_step=0)
         if (it + 1) == maxiters:
-            model_name = dgp_cfg.snapshot_prefix + '-step' + str(step) + '-final-'
+            model_name = dgp_cfg['snapshot_prefix'] + '-step' + str(step) + '-final'
             saver.save(sess, model_name, global_step=0)
 
     time_end = time.time()
@@ -547,8 +552,8 @@ def fit_dgp_labeledonly(
 def fit_dgp(
-        snapshot, dlcpath, batch_size=10, shuffle=1, step=2, saveiters=1000, displayiters=5,
-        maxiters=200000, ns=10, nc=2048, n_max_frames=2000, gm2=0, gm3=0, nepoch=100, wt=0, aug=True,
+        snapshot, dlcpath, batch_size=1, shuffle=1, step=2, saveiters=500, displayiters=10,
+        maxiters=30000, ns=10, nc=2048, n_max_frames=2000, gm2=0, gm3=0, nepoch=100, wt=0, aug=True,
         debug='', trainingsetindex=0):
     """Run DGP.
 
     Parameters
@@ -635,26 +640,26 @@ def fit_dgp(
         shuffle=shuffle,
         S0=S0)
     dgp_cfg = data_batcher.dlc_config
-    dgp_cfg.ws = 1000  # the spatial clique parameter
-    dgp_cfg.ws_max = 1.2  # the multiplier for the upper bound of spatial distance
-    dgp_cfg.wt = wt  # the temporal clique parameter
-    dgp_cfg.wt_max = 0  # the upper bound of temporal distance
-    dgp_cfg.wn_visible = 5  # the network clique parameter for visible frames
-    dgp_cfg.wn_hidden = 3  # the network clique parameter for hidden frames
-    dgp_cfg.gamma = 1  # the multiplier for the softmax confidence map
-    dgp_cfg.gauss_len = 1  # the length scale for the Gaussian kernel convolving the softmax confidence map
-    dgp_cfg.lengthscale = 1  # the length scale for the Gaussian target map
-    dgp_cfg.max_to_keep = 5  # max number of snapshots to keep
-    dgp_cfg.batch_size = batch_size  # batch size
-    dgp_cfg.n_times_all_frames = nepoch  # the number of times each selected frames is iterated over
-    dgp_cfg.lr = 0.005  # learning rate
-    # dgp_cfg.net_type = 'resnet_50'
-    dgp_cfg.gm2 = gm2  # scale target by confidence level
-    dgp_cfg.gm3 = gm3  # scale hidden loss by confidence level
-    dgp_cfg.aug = aug  # data augmentation
+    dgp_cfg['ws'] = 1000  # the spatial clique parameter
+    dgp_cfg['ws_max'] = 1.2  # the multiplier for the upper bound of spatial distance
+    dgp_cfg['wt'] = wt  # the temporal clique parameter
+    dgp_cfg['wt_max'] = 0  # the upper bound of temporal distance
+    dgp_cfg['wn_visible'] = 5  # the network clique parameter for visible frames
+    dgp_cfg['wn_hidden'] = 3  # the network clique parameter for hidden frames
+    dgp_cfg['gamma'] = 1  # the multiplier for the softmax confidence map
+    dgp_cfg['gauss_len'] = 1  # the length scale for the Gaussian kernel convolving the softmax confidence map
+    dgp_cfg['lengthscale'] = 1  # the length scale for the Gaussian target map
+    dgp_cfg['max_to_keep'] = 5  # max number of snapshots to keep
+    dgp_cfg['batch_size'] = batch_size  # batch size
+    dgp_cfg['n_times_all_frames'] = nepoch  # the number of times each selected frames is iterated over
+    dgp_cfg['lr'] = 0.005  # learning rate
+    # dgp_cfg['net_type'] = 'resnet_50'
+    dgp_cfg['gm2'] = gm2  # scale target by confidence level
+    dgp_cfg['gm3'] = gm3  # scale hidden loss by confidence level
+    dgp_cfg['aug'] = aug  # data augmentation
 
     # skip this DGP step if it's already done.
-    model_name = dgp_cfg.snapshot_prefix + '-step{}{}-final--0.index'.format(step, debug)
+    model_name = dgp_cfg['snapshot_prefix'] + '-step{}-final-0.index'.format(step)
     if os.path.isfile(model_name):
         print(model_name, ' exists! DGP has already been run.', flush=True)
         return None
@@ -681,9 +686,9 @@ def fit_dgp(
     # ------------------------------------------------------------------------------------
     # Build model
     # ------------------------------------------------------------------------------------
-    TF.reset_default_graph()
+    TF.compat.v1.reset_default_graph()
     loss, total_loss, total_loss_visible, placeholders = dgp_loss(data_batcher, dgp_cfg)
-    learning_rate = TF.placeholder(tf.float32, shape=[])
+    learning_rate = TF.compat.v1.placeholder(tf.float32, shape=[])
 
     # Restore network parameters for RESNET and COVNET
     variables_to_restore0 = slim.get_variables_to_restore(
@@ -691,29 +696,29 @@ def fit_dgp(
         include=['pose/part_pred'])
     variables_to_restore1 = slim.get_variables_to_restore(
         include=['pose/locref_pred'])
     variables_to_restore2 = slim.get_variables_to_restore(include=['resnet'])
-    restorer = TF.train.Saver(variables_to_restore0 + variables_to_restore1 +
+    restorer = TF.compat.v1.train.Saver(variables_to_restore0 + variables_to_restore1 +
                               variables_to_restore2)
-    saver = TF.train.Saver(max_to_keep=dgp_cfg.max_to_keep)
+    saver = TF.compat.v1.train.Saver(max_to_keep=dgp_cfg['max_to_keep'])
 
     # Set up session
     allow_growth = True
     if allow_growth:
-        config = TF.ConfigProto()
+        config = TF.compat.v1.ConfigProto()
         config.gpu_options.allow_growth = True
-        sess = TF.Session(config=config)
+        sess = TF.compat.v1.Session(config=config)
     else:
-        sess = TF.Session()
+        sess = TF.compat.v1.Session()
 
     # Set up optimizer
-    all_train_vars = TF.trainable_variables()
-    optimizer = TF.train.MomentumOptimizer(learning_rate=learning_rate, momentum=0.9)
+    all_train_vars = TF.compat.v1.trainable_variables()
+    optimizer = TF.compat.v1.train.MomentumOptimizer(learning_rate=learning_rate, momentum=0.9)
     gradients, variables = zip(
         *optimizer.compute_gradients(total_loss, var_list=all_train_vars))
     gradients, _ = TF.clip_by_global_norm(gradients, 10.0)
     train_op = optimizer.apply_gradients(zip(gradients, variables))
 
-    sess.run(TF.global_variables_initializer())
-    sess.run(TF.local_variables_initializer())
+    sess.run(TF.compat.v1.global_variables_initializer())
+    sess.run(TF.compat.v1.local_variables_initializer())
 
     # Restore RESNET var
     print('restoring resnet weights from %s' % init_weights)
@@ -724,21 +729,23 @@ def fit_dgp(
     # ------------------------------------------------------------------------------------
     # Begin training
     # ------------------------------------------------------------------------------------
     batch_ind_all = gen_batch(visible_frame_total, hidden_frame_total, all_frame_total, dgp_cfg, maxiters)
-    save_iters = np.int(saveiters / dgp_cfg.batch_size)
+    save_iters = int(saveiters / dgp_cfg['batch_size'])
     maxiters = len(batch_ind_all)
 
-    pdata = PoseDataset(dgp_cfg)
+    dgp_cfg['dataset_type'] = 'deterministic'
+    pdata = PoseDatasetFactory.create(dgp_cfg)
     data_batcher.reset()
 
     # %%
     print('Begin Training for {} iterations'.format(maxiters))
 
-    if dgp_cfg.aug:
+    if dgp_cfg['aug']:
         pipeline = build_aug(apply_prob=0.8)
 
     time_start = time.time()
     for it in range(maxiters):
 
-        current_lr = dgp_cfg.lr
+        current_lr = dgp_cfg['lr']
 
         # get batch index
         batch_ind = batch_ind_all[it]
@@ -768,14 +775,14 @@ def fit_dgp(
         visible_frame_within_batch = [np.where(all_frame == i)[0][0] for i in visible_frame]
 
         # batch data for placeholders
-        if dgp_cfg.wt > 0:
+        if dgp_cfg['wt'] > 0:
             vector_field = learn_wt(all_data_batch)  # vector field from optical flow
         else:
             vector_field = np.zeros((1,1,1))
-        wt_batch = np.ones(nt_batch - 1, ) * dgp_cfg.wt
+        wt_batch = np.ones(nt_batch - 1, ) * dgp_cfg['wt']
 
         # data augmentation for visible frames
-        if dgp_cfg.aug and dgp_cfg.wt == 0 and len(visible_frame_within_batch) > 0:
+        if dgp_cfg['aug'] and dgp_cfg['wt'] == 0 and len(visible_frame_within_batch) > 0:
             all_data_batch, joint_loc = data_aug(all_data_batch, visible_frame_within_batch, joint_loc, pipeline, dgp_cfg)
             locref_targets_batch, locref_mask_batch = coord2map(pdata, joint_loc, nx_out, ny_out, nj)
@@ -829,12 +836,12 @@ def fit_dgp(
         # Save snapshot
         if (it % save_iters == 0) or (it + 1) == maxiters:
-            model_name = dgp_cfg.snapshot_prefix + '-step' + str(step) + '{}'.format(debug) + '-'
+            model_name = dgp_cfg['snapshot_prefix'] + '-step' + str(step)
             #print('Storing model {}'.format(model_name))
             saver.save(sess, model_name, global_step=it)
             saver.save(sess, model_name, global_step=0)
         if (it + 1) == maxiters:
-            model_name = dgp_cfg.snapshot_prefix + '-step' + str(step) + '{}'.format(debug) + '-final-'
+            model_name = dgp_cfg['snapshot_prefix'] + '-step' + str(step) + '-final'
             #print('Storing model {}'.format(model_name))
             saver.save(sess, model_name, global_step=0)
@@ -884,69 +891,69 @@ def dgp_loss(data_batcher, dgp_cfg):
     limb_full[np.abs(limb_full) > 1e5] = 0
     limb_full = np.reshape(limb_full, [joint_loc_full.shape[0], 2, -1])
     limb_full = np.sqrt(np.sum(np.square(limb_full), 1))
-    limb_full = limb_full.T * dgp_cfg.stride + dgp_cfg.stride / 2
-    ws_max = np.max(np.nan_to_num(limb_full), 1) * dgp_cfg.ws_max
+    limb_full = limb_full.T * dgp_cfg['stride'] + dgp_cfg['stride'] / 2
+    ws_max = np.max(np.nan_to_num(limb_full), 1) * dgp_cfg['ws_max']
     limb_full = np.true_divide(limb_full.sum(1), (limb_full != 0).sum(1))
     ws = 1 / (np.nan_to_num(
-        limb_full) + 1e-20) * dgp_cfg.ws  # spatial clique parameter based on the limb length and dlc_cfg.ws
+        limb_full) + 1e-20) * dgp_cfg['ws']  # spatial clique parameter based on the limb length and dlc_cfg.ws
 
     # Define placeholders
     # input and output
-    inputs = TF.placeholder(TF.float32, shape=[None, None, None, 3])
-    targets = TF.placeholder(TF.float32, shape=[None, nj, 2])
-    targets_nonan = TF.where(TF.is_nan(targets), TF.ones_like(targets) * 0, targets)  # set nan to be 0 in targets
+    inputs = TF.compat.v1.placeholder(TF.float32, shape=[None, None, None, 3])
+    targets = TF.compat.v1.placeholder(TF.float32, shape=[None, nj, 2])
+    targets_nonan = TF.where(TF.compat.v1.is_nan(targets), TF.ones_like(targets) * 0, targets)  # set nan to be 0 in targets
 
     # local refinement
-    locref_map = TF.placeholder(TF.float32, shape=[None, None, None, nj * 2])
-    locref_mask = TF.placeholder(TF.float32, shape=[None, None, None, nj * 2])
+    locref_map = TF.compat.v1.placeholder(TF.float32, shape=[None, None, None, nj * 2])
+    locref_mask = TF.compat.v1.placeholder(TF.float32, shape=[None, None, None, nj * 2])
 
     # placeholders for parameters
-    visible_marker_pl = TF.placeholder(TF.int32, shape=[
-        None,
-    ])  # placeholder for visible marker index in the batch
-    hidden_marker_pl = TF.placeholder(TF.int32, shape=[
-        None,
-    ])  # placeholder for hidden marker index in the batch
-    visible_marker_in_targets_pl = TF.placeholder(TF.int32, shape=[
-        None,
-    ])  # placeholder for visible marker index in targets/visible frames
-    nt_batch_pl = TF.placeholder(TF.int32, shape=[])  # placeholder for the total number of frames in the batch
+    visible_marker_pl = TF.compat.v1.placeholder(TF.int32, shape=[
+        None,
+    ])  # placeholder for visible marker index in the batch
+    hidden_marker_pl = TF.compat.v1.placeholder(TF.int32, shape=[
+        None,
+    ])  # placeholder for hidden marker index in the batch
+    visible_marker_in_targets_pl = TF.compat.v1.placeholder(TF.int32, shape=[
+        None,
+    ])  # placeholder for visible marker index in targets/visible frames
+    nt_batch_pl = TF.compat.v1.placeholder(TF.int32, shape=[])  # placeholder for the total number of frames in the batch
 
-    wt_batch_pl = TF.placeholder(TF.float32, shape=[
-        None,
-    ])  # placeholder for the temporal clique wt; it's a vector which can contain different clique values for different frames
-    wt_batch_mask_pl = TF.placeholder(TF.float32, shape=[
-        None,
-    ])  # placeholder for the batch mask for wt, 1 means wt is in the batch; 0 means wt is not in the batch
+    wt_batch_pl = TF.compat.v1.placeholder(TF.float32, shape=[
+        None,
+    ])  # placeholder for the temporal clique wt; it's a vector which can contain different clique values for different frames
+    wt_batch_mask_pl = TF.compat.v1.placeholder(TF.float32, shape=[
+        None,
+    ])  # placeholder for the batch mask for wt, 1 means wt is in the batch; 0 means wt is not in the batch
     wt_batch_tf = TF.multiply(wt_batch_pl, wt_batch_mask_pl)  # wt vector for the batch
-    wt_max_tf = TF.constant(dgp_cfg.wt_max, TF.float32)  # placeholder for the upper bounds for the temporal clique wt
+    wt_max_tf = TF.constant(dgp_cfg['wt_max'], TF.float32)  # placeholder for the upper bounds for the temporal clique wt
 
-    wn_visible_tf = TF.constant(dgp_cfg.wn_visible,
+    wn_visible_tf = TF.constant(dgp_cfg['wn_visible'],
                                 TF.float32)  # placeholder for the upper bounds for the spatial clique ws; it varies across joints
-    wn_hidden_tf = TF.constant(dgp_cfg.wn_hidden,
+    wn_hidden_tf = TF.constant(dgp_cfg['wn_hidden'],
                                TF.float32)  # placeholder for the upper bounds for the spatial clique ws; it varies across joints
     ws_tf = TF.constant(ws, TF.float32)  # placeholder for the spatial clique ws; it varies across joints
     ws_max_tf = TF.constant(ws_max, TF.float32)  # placeholder for the upper bounds for the spatial clique ws; it varies across joints
 
-    vector_field_tf = TF.placeholder(TF.float32, shape=[None, None, None])  # placeholder for the vector fields
+    vector_field_tf = TF.compat.v1.placeholder(TF.float32, shape=[None, None, None])  # placeholder for the vector fields
 
     # Build the network
-    pn = PoseNet(dgp_cfg)
+    pn = PoseResnet(dgp_cfg)
     net, end_points = pn.extract_features(inputs)
     scope = "pose"
     reuse = None
     heads = {}
     # two convnets, one is the prediction network, the other is the local refinement network.
-    with TF.variable_scope(scope, reuse=reuse):
+    with TF.compat.v1.variable_scope(scope, reuse=reuse):
         heads["part_pred"] = prediction_layer(dgp_cfg, net, "part_pred", nj)
         heads["locref"] = prediction_layer(dgp_cfg, net, "locref_pred", nj * 2)
 
     # Read the 2D targets from pred
     pred = heads['part_pred']
     nx_out, ny_out = tf.shape(pred)[1], tf.shape(pred)[2]
-    targets_pred, confidencemap_softmax = argmax_2d_from_cm(pred, nj, dgp_cfg.gamma, dgp_cfg.gauss_len)
+    targets_pred, confidencemap_softmax = argmax_2d_from_cm(pred, nj, dgp_cfg['gamma'], dgp_cfg['gauss_len'])
     targets_pred_marker = TF.reshape(targets_pred, [-1, 2])  # 2d locations for all markers
     targets_pred_hidden_marker = TF.gather(targets_pred_marker, hidden_marker_pl)  # 2d locations for hidden markers, predicted targets from the network
@@ -964,12 +971,12 @@ def dgp_loss(data_batcher, dgp_cfg):
     target_expand = TF.expand_dims(TF.expand_dims(targets_all_marker, 2), 3)  # (nt*nj) x 2 x 1 x 1
 
     # 2d grid of the output
-    alpha_tf = TF.placeholder(tf.float32, shape=[2, None, None], name="2dgrid")
+    alpha_tf = TF.compat.v1.placeholder(tf.float32, shape=[2, None, None], name="2dgrid")
     alpha_expand = TF.expand_dims(alpha_tf, 0)  # 1 x 2 x nx_out x ny_out
 
     # normalize the Gaussian bump for the target so that the peak is 1, nt * nx_out * ny_out * nj
     targets_gauss = TF.exp(-TF.reduce_sum(TF.square(alpha_expand - target_expand), axis=1) /
-                           (2 * (dgp_cfg.lengthscale ** 2)))
+                           (2 * (dgp_cfg['lengthscale'] ** 2)))
     gauss_max = TF.reduce_max(TF.reduce_max(targets_gauss, [1]), [1]) + TF.constant(1e-5, TF.float32)
     gauss_max = TF.expand_dims(TF.expand_dims(gauss_max, [1]), [2])
     targets_gauss = targets_gauss / gauss_max
@@ -991,7 +998,7 @@ def dgp_loss(data_batcher, dgp_cfg):
     pred_v = TF.gather(pred, visible_marker_pl)  # output pred for visible markers
     pred_h = TF.gather(pred, hidden_marker_pl)  # output pred for hidden markers
 
-    if dgp_cfg.gm2 == 1:
+    if dgp_cfg['gm2'] == 1:
         # scale crossentropy loss terms by confidence
         pred_h_sigmoid = tf.sigmoid(pred_h)
         pgm_h1 = tf.reduce_max(tf.reduce_max(pred_h_sigmoid, [1]), [1])  # + EPSILON_tf
@@ -1000,9 +1007,9 @@ def dgp_loss(data_batcher, dgp_cfg):
         targets_gauss_h = targets_gauss_h * pgm_h2
         # scale network
         pred_h_scaled = pred_h_sigmoid * pgm_h2
-        pred_h_scaled1 = -tf.log(1 - pred_h_scaled + 1e-20) + tf.log(
+        pred_h_scaled1 = -tf.compat.v1.log(1 - pred_h_scaled + 1e-20) + tf.compat.v1.log(
             pred_h_scaled + 1e-20)
-    elif dgp_cfg.gm2 == 2:
+    elif dgp_cfg['gm2'] == 2:
         # scale crossentropy loss input by confidence
         pred_h_sigmoid = tf.sigmoid(pred_h)
         pgm_h1 = tf.reduce_max(tf.reduce_max(pred_h_sigmoid, [1]), [1])  # + EPSILON_tf
@@ -1013,24 +1020,24 @@ def dgp_loss(data_batcher, dgp_cfg):
         targets_gauss_h = targets_gauss_h  # *pgm_h2
         # scaled the network output
         pred_h_scaled = pred_h_sigmoid * pgm_h2
-        pred_h_scaled1 = -tf.log(1 - pred_h_scaled + 1e-20) + tf.log(
+        pred_h_scaled1 = -tf.compat.v1.log(1 - pred_h_scaled + 1e-20) + tf.compat.v1.log(
             pred_h_scaled + 1e-20)
-    elif dgp_cfg.gm2 == 0:
+    elif dgp_cfg['gm2'] == 0:
         pass
     else:
         raise Exception('Not implemented')
 
     # %%
     loss = {}
-    loss["visible_loss_pred"] = TF.losses.sigmoid_cross_entropy(targets_gauss_v, pred_v, 1.0)
-    if dgp_cfg.gm3 == 3:
-        loss["hidden_loss_pred"] = TF.losses.sigmoid_cross_entropy(targets_gauss_h, pred_h_scaled1,
+    loss["visible_loss_pred"] = TF.compat.v1.losses.sigmoid_cross_entropy(targets_gauss_v, pred_v, 1.0)
+    if dgp_cfg['gm3'] == 3:
+        loss["hidden_loss_pred"] = TF.compat.v1.losses.sigmoid_cross_entropy(targets_gauss_h, pred_h_scaled1,
                                                                    weights=(1 - pgm_h2)) * \
                                    n_visible_frames_total / n_hidden_frames_total * \
                                    n_hidden_frames_batch / n_visible_frames_batch * wn_hidden_tf / wn_visible_tf
-    elif dgp_cfg.gm3 == 0:
-        loss["hidden_loss_pred"] = TF.losses.sigmoid_cross_entropy(targets_gauss_h, pred_h, 1.0) * \
+    elif dgp_cfg['gm3'] == 0:
+        loss["hidden_loss_pred"] = TF.compat.v1.losses.sigmoid_cross_entropy(targets_gauss_h, pred_h, 1.0) * \
                                    n_visible_frames_total / n_hidden_frames_total * \
                                    n_hidden_frames_batch / n_visible_frames_batch * wn_hidden_tf / wn_visible_tf
     else:
@@ -1049,9 +1056,9 @@ def dgp_loss(data_batcher, dgp_cfg):
     locref_map_v = TF.gather(locref_map_reshape, visible_marker_pl)
     locref_mask_reshape = TF.reshape(TF.transpose(locref_mask, [0, 3, 1, 2]), [-1, 2, nx_out, ny_out])
     locref_mask_v = TF.gather(locref_mask_reshape, visible_marker_pl)
-
-    loss_func = losses.huber_loss if dgp_cfg.locref_huber_loss else TF.losses.mean_squared_error
-    loss['visible_loss_locref'] = dgp_cfg.locref_loss_weight * loss_func(locref_map_v, locref_pred_v, locref_mask_v)
+
+    loss_func = tf.compat.v1.losses.huber_loss if dgp_cfg['locref_huber_loss'] else TF.compat.v1.losses.mean_squared_error
+    loss['visible_loss_locref'] = dgp_cfg['locref_loss_weight'] * loss_func(locref_map_v, locref_pred_v, locref_mask_v)
     total_loss = total_loss + loss['visible_loss_locref']
 
@@ -1063,7 +1070,7 @@ def dgp_loss(data_batcher, dgp_cfg):
     if nl > 0:
         S = TF.constant(S0, dtype=TF.float32)
         targets_all_marker_spatial = TF.reshape(TF.transpose(targets_all_marker_3c, [1, 2, 0]),
-                                                [nj, -1]) * dgp_cfg.stride + 0.5 * dgp_cfg.stride
+                                                [nj, -1]) * dgp_cfg['stride'] + 0.5 * dgp_cfg['stride']
         dist_targets = TF.sqrt(
             TF.reduce_sum(
                 TF.square(TF.reshape(TF.matmul(S, targets_all_marker_spatial), [nl, 2, -1])), [1]))
@@ -1076,7 +1083,7 @@ def dgp_loss(data_batcher, dgp_cfg):
         total_loss += loss['ws_loss']
 
     # Temporal clique
-    if dgp_cfg.wt > 0:
-        targets_all_marker_temporal = targets_all_marker_3c * dgp_cfg.stride + 0.5 * dgp_cfg.stride
+    if dgp_cfg['wt'] > 0:
+        targets_all_marker_temporal = targets_all_marker_3c * dgp_cfg['stride'] + 0.5 * dgp_cfg['stride']
         targets_all_marker_temporal0 = targets_all_marker_temporal[:-1, :, :]
         targets_all_marker_temporal1 = targets_all_marker_temporal[1:, :, :]
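Note: the `targets_gauss` term in `dgp_loss` above builds a unit-peak Gaussian bump around each 2D target before the sigmoid cross-entropy losses are computed. A NumPy sketch of that construction for a single joint (illustration only; the TF code operates on batched tensors):

```python
import numpy as np

# Unit-peak Gaussian target map: exp(-||grid - target||^2 / (2 * lengthscale^2)),
# renormalised so the peak value is 1, mirroring targets_gauss in dgp_loss
# (the TF code adds 1e-5 to the max for numerical stability).
def gaussian_target_map(target_yx, nx_out, ny_out, lengthscale=1.0):
    grid_y, grid_x = np.mgrid[0:nx_out, 0:ny_out]
    d2 = (grid_y - target_yx[0]) ** 2 + (grid_x - target_yx[1]) ** 2
    gauss = np.exp(-d2 / (2 * lengthscale ** 2))
    return gauss / (gauss.max() + 1e-5)
```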
diff --git a/src/deepgraphpose/models/fitdgp_util.py b/src/deepgraphpose/models/fitdgp_util.py
index 07ba235..8421754 100644
--- a/src/deepgraphpose/models/fitdgp_util.py
+++ b/src/deepgraphpose/models/fitdgp_util.py
@@ -1,7 +1,7 @@
 import tensorflow as tf
 from pathlib import Path
 import numpy as np
-import tensorflow.contrib.slim as slim
+import tf_slim as slim
 from deeplabcut.utils import auxiliaryfunctions
 import imgaug.augmenters as iaa
 from imgaug.augmentables import Keypoint, KeypointsOnImage
@@ -52,8 +52,8 @@ def dgp_prediction_layer(weight_dlc,
             activation_fn=None,
             normalizer_fn=None,
             weights_regularizer=slim.l2_regularizer(
-                dlc_cfg.weight_decay)):
-        with TF.variable_scope(name, reuse=tf.AUTO_REUSE):
+                dlc_cfg['weight_decay'])):
+        with TF.compat.v1.variable_scope(name, reuse=tf.compat.v1.AUTO_REUSE):
             if init_flag:
                 pred = slim.conv2d_transpose(
                     inputs,
@@ -160,10 +160,10 @@ def gen_batch(visible_frame_total, hidden_frame_total, all_frame_total, dgp_cfg,
 
     """
-    batch_size = dgp_cfg.batch_size
+    batch_size = dgp_cfg['batch_size']
     n_frames_total = np.sum([len(v) for v in all_frame_total])
     n_datasets = len(all_frame_total)
-    nepoch = np.min([int(n_frames_total * dgp_cfg.n_times_all_frames /
+    nepoch = np.min([int(n_frames_total * dgp_cfg['n_times_all_frames'] /
                          batch_size), maxiters])
     print('nepoch: ', nepoch)
@@ -175,7 +175,7 @@ def gen_batch(visible_frame_total, hidden_frame_total, all_frame_total, dgp_cfg,
         index_vh_i = list(all_frame_total[i]) + list(hidden_frame_total[i])
         index_all_i = np.unique(list(index_v_i) + list(index_vh_i))
 
-        batch_size = dgp_cfg.batch_size
+        batch_size = dgp_cfg['batch_size']
         batchsize_i = max([1, int(nepoch / n_frames_total * len(index_all_i))])
 
         if len(index_all_i) < batch_size:
@@ -279,7 +279,7 @@ def rank(tensor):
 
 # Make Gaussian kernel following SciPy logic
 def make_gaussian_2d_kernel(sigma, truncate=1.0, dtype=TF.float32):
-    radius = TF.to_int32(sigma * truncate)
+    radius = TF.compat.v1.to_int32(sigma * truncate)
     x = TF.cast(TF.range(-radius, radius + 1), dtype=dtype)
     k = TF.exp(-0.5 * TF.square(x / sigma))
     k = k / TF.reduce_sum(k)
@@ -396,7 +396,7 @@ def argmax_2d_from_cm(tensor, nj, gamma=1, gauss_len=2, th=None):
 
     # Multiply (with broadcasting) and reduce over image dimensions to get the result
     # of shape [N, C, 2]
-    spatial_soft_argmax = TF.reduce_sum(softmax_tensor * image_coords, reduction_indices=[1, 2])
+    spatial_soft_argmax = TF.compat.v1.reduce_sum(softmax_tensor * image_coords, reduction_indices=[1, 2])
 
     # stack and return 2D coordinates
     return spatial_soft_argmax, softmax_tensor0
@@ -440,10 +440,10 @@ def data_aug(all_data_batch, visible_frame_within_batch, joint_loc, pipeline, dg
     visible_data = all_data_batch[visible_frame_within_batch, :, :, :].astype(np.uint8)
 
     # array to list
-    joint_loc_list = array2list(np.flip(joint_loc, 2) * dgp_cfg.stride + dgp_cfg.stride / 2, direct=1)
+    joint_loc_list = array2list(np.flip(joint_loc, 2) * dgp_cfg['stride'] + dgp_cfg['stride'] / 2, direct=1)
     batch_images, batch_joints = pipeline(images=visible_data, keypoints=joint_loc_list)
     # list to array
-    joint_loc_aug = np.flip(array2list(batch_joints, direct=-1) / dgp_cfg.stride - 0.5, 2)
+    joint_loc_aug = np.flip(array2list(batch_joints, direct=-1) / dgp_cfg['stride'] - 0.5, 2)
 
     all_data_batch_aug = np.copy(all_data_batch)
     all_data_batch_aug[visible_frame_within_batch, :, :, :] = batch_images
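Note: `make_gaussian_2d_kernel` above follows SciPy's truncation logic. The same 1D profile in plain NumPy, for reference (illustration only):

```python
import numpy as np

# Mirror of make_gaussian_2d_kernel's 1D profile: truncate the kernel at
# `truncate` standard deviations and normalise it to sum to 1.
def gaussian_kernel_1d(sigma, truncate=1.0):
    radius = int(sigma * truncate)
    x = np.arange(-radius, radius + 1, dtype=np.float64)
    k = np.exp(-0.5 * np.square(x / sigma))
    return k / k.sum()
```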
diff --git a/src/deepgraphpose/preprocess/get_morig_labeled_data.py b/src/deepgraphpose/preprocess/get_morig_labeled_data.py
index 65f5b95..ae2e67c 100644
--- a/src/deepgraphpose/preprocess/get_morig_labeled_data.py
+++ b/src/deepgraphpose/preprocess/get_morig_labeled_data.py
@@ -111,7 +111,7 @@ def create_labels(task, date, overwrite_flag=False, check_labels=False, verbose=
     frames2pick = np.sort(
         np.random.choice(frames_index_keep, cfg["numframes2pick"], replace=False)
     )
-    local_extract_frames(path_config_file, frames2pick)
+    local_extract_frames(path_config_file, frames2pick, crop=True)
     frames = os.listdir(frames_dir)
 
     #%%
diff --git a/src/deepgraphpose/utils_data.py b/src/deepgraphpose/utils_data.py
index 8703106..86a16eb 100644
--- a/src/deepgraphpose/utils_data.py
+++ b/src/deepgraphpose/utils_data.py
@@ -95,7 +95,7 @@ def local_extract_frames_md(
     config_path,
     frames2pick,
     video,
-    crop=False,
+    crop=True,
     opencv=True,
     full_path=False,
 ):
diff --git a/src/deepgraphpose/utils_model.py b/src/deepgraphpose/utils_model.py
index a9f9b45..0310582 100644
--- a/src/deepgraphpose/utils_model.py
+++ b/src/deepgraphpose/utils_model.py
@@ -105,8 +105,8 @@ def get_train_config(cfg, shuffle=0):
             "It seems the model for shuffle %s and trainFraction %s does not exist."
             % (shuffle, TrainingFraction))
 
     # from get_model_config
-    dlc_cfg.video_path = cfg['video_path']
-    dlc_cfg.project_path = cfg['project_path']
+    dlc_cfg['video_path'] = cfg['video_path']
+    dlc_cfg['project_path'] = cfg['project_path']
 
     return dlc_cfg
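Note: with the dict-style config migration applied end to end, loading a train config now looks like the following sketch (the project path is a placeholder; `read_config` is DeepLabCut's standard config reader):

```python
# Hypothetical usage of the dict-style train config returned by get_train_config
# after this patch; the project path is a placeholder.
from deeplabcut.utils import auxiliaryfunctions
from deepgraphpose.utils_model import get_train_config

proj_cfg = auxiliaryfunctions.read_config('data/Reaching-Mackenzie-2018-08-30/config.yaml')
proj_cfg['video_path'] = proj_cfg['video_sets']  # as MultiDataset does for backwards compat
dlc_cfg = get_train_config(proj_cfg, shuffle=1)
print(dlc_cfg['project_path'], dlc_cfg['num_joints'])
```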