diff --git a/.gitmodules b/.gitmodules deleted file mode 100644 index ed43ff8..0000000 --- a/.gitmodules +++ /dev/null @@ -1,3 +0,0 @@ -[submodule "XMem_utilities"] - path = XMem_utilities - url = git@github.com:max810/Xmem_utility_scripts.git diff --git a/create_mattes_from_colored_mattes.py b/create_mattes_from_colored_mattes.py new file mode 100644 index 0000000..fdb8f8a --- /dev/null +++ b/create_mattes_from_colored_mattes.py @@ -0,0 +1,247 @@ +import cv2 +import numpy as np +import re +import os +import sys +import csv +import ast +import argparse + + +BGR_MATTE_COLOR = (0, 0, 128) + + +def get_files_in_path(colormatte_path): + """Return the list of files in the provided path.""" + file_names = os.listdir(colormatte_path) + + return file_names + + +def get_frame_number(filename): + """Get the frame number from the path to the matte.""" + + # regex may be more concise but less readable. + # Find the position of the last dot + last_dot_index = filename.rfind('.') + + # Find the position of the second-to-last dot + second_last_dot_index = filename.rfind('.', 0, last_dot_index) + + # Extract the substring between the last two dots + if last_dot_index != -1 and second_last_dot_index != -1: + frame = filename[second_last_dot_index + 1 : last_dot_index] + return frame + else: + print("Could not extract frame number.") + + +def get_frame_range(folder_path): + """Get the min and max frame numbers for the files in the given path.""" + + frame_numbers = [] + + # looks for a sequence of digits (\d+) followed by ".png" and ensures that it is a whole word + # using word boundaries (\b) + pattern = re.compile(r'\b(\d+)\.png\b') + + for filename in os.listdir(folder_path): + match = pattern.search(filename) + if match: + frame_numbers.append(int(match.group(1))) + + return min(frame_numbers), max(frame_numbers) + + +def create_black_matte_filename(existing_file, frame_number): + """Create a filename (no extension) for a black matte.""" + prefix, frame, suffix = existing_file.split('.') + frame = frame_number + new_file = '.'.join([prefix, str(frame).zfill(8)]) + new_file = new_file.replace("colormatte", "matte") + return new_file + + +def get_image_width_height(image): + """Given an image, obtain it's width and height.""" + height, width, channels = image.shape + return width, height + + +def get_matte_path(colormatte_path): + """Get the path for b/w mattes. Input path must contain 'colormatte'.""" + if 'colormatte' in colormatte_path: + matte_path = colormatte_path.replace("colormatte", "mattes") + return matte_path + else: + print("The provided colormattes must have a 'colormatte' folder in their path.") + sys.exit(1) + + +def get_denoise_path(colormatte_path): + """Get the path for denoise frames. Input path must contain 'colormatte'.""" + if 'colormatte' in colormatte_path: + denoise_path = colormatte_path.replace("colormatte", "denoise") + + # Always use v001 for the denoise path. The number of frames will be the same in all versions. + denoise_index = denoise_path.find('denoise\\') + prefix = denoise_path[:denoise_index + len('denoise\\')] + + # Construct the new path with 'v001' + denoise_path = os.path.join(prefix, 'v001') + return denoise_path + else: + print("The provided colormattes must have a 'colormatte' folder in their path.") + sys.exit(1) + + +def get_matte_filename(image_path): + """Get the matte filename from the input path""" + filename = os.path.basename(image_path) + + prefix, frame, suffix = filename.split('.') + filename_without_extension = '.'.join([prefix, str(frame).zfill(8)]).replace("colormatte", "matte") + + return filename_without_extension + + +def get_matte_colors(image): + """Get the list of colors that were used for mattes in the given image.""" + flattened_image = image.reshape((-1, 3)) + + unique_colors = np.unique(flattened_image, axis=0) + unique_colors_list = [tuple(color) for color in unique_colors] + black = (0, 0, 0) + unique_colors_list.remove(black) + return unique_colors_list + + +def get_element_bw_matte_from_colormatte(image, target_color): + """Create a black and white matte from the area matted by the target color.""" + # Convert the target color to a NumPy array + target_color_np = np.array(target_color, dtype=np.uint8) + + # Create a b/w matte for the exact color using an equality check + matte = np.all(image == target_color_np, axis=-1).astype(np.uint8) * 255 + + return matte + + +def create_black_matte(width, height): + """Create a solid black matte. Used for missing frames or missing element in a frame.""" + return np.zeros((height, width, 3), dtype=np.uint8) + + +def save_bw_matte(matte, element_matte_path, matte_filename): + """Save the black and white matte""" + os.makedirs(element_matte_path, exist_ok=True) + + try: + cv2.imwrite(rf'{element_matte_path}\{matte_filename}.png', matte) + print(rf'Saved image {element_matte_path}\{matte_filename}.png') + except Exception as e: + print(f"Could not save file: {e}") + + +def create_frame_dict(files): + """Create a lookup by frame number for a set of files.""" + frame_dict = {} + for file in files: + frame_number = get_frame_number(file) + frame_dict[int(frame_number)] = file + return frame_dict + + +parser = argparse.ArgumentParser(description='Create black and white matte for element in color matte.') + +# Required parameters +parser.add_argument('colormatte_path', + help='Path to the colormatte files from which to extract black and white mattes. The last folder in the path must be the element name.') +# parser.add_argument('element', type=str, help='Element that needs a black and white matte') + +# Optional parameters +parser.add_argument('-sf', '--start_frame', type=int, help='Start processing at this frame') +parser.add_argument('-ef', '--end_frame', type=int, help='End processing at this frame') +parser.add_argument('-incr', '--increment', type=int, default=1, help='Frame increment') + +args = parser.parse_args() + +element = os.path.basename(args.colormatte_path) + + +# Get the default frame range from the denoise path +denoise_path = get_denoise_path(args.colormatte_path) + +if args.start_frame is None or args.end_frame is None: + denoise_start_frame, denoise_end_frame = get_frame_range(denoise_path) + print('denoise_start_frame', denoise_start_frame) + print('denoise_end_frame', denoise_end_frame) + +if args.start_frame is None: + matte_start_frame = denoise_start_frame +else: + matte_start_frame = args.start_frame + + +if args.end_frame is None: + matte_end_frame = denoise_end_frame +else: + matte_end_frame = args.end_frame + + +print('colormatte_path', args.colormatte_path) +print('element', element) +print('start_frame', matte_start_frame) +print('end_frame', matte_end_frame) +print('increment', args.increment) + +matte_path = get_matte_path(args.colormatte_path) + +colormatte_files = get_files_in_path(args.colormatte_path) + +image_path = os.path.join(args.colormatte_path, colormatte_files[0]) +image = cv2.imread(image_path) + +IMAGE_WIDTH, IMAGE_HEIGHT = get_image_width_height(image) +print('IMAGE_WIDTH', IMAGE_WIDTH) +print('IMAGE_HEIGHT', IMAGE_HEIGHT) + +colormatte_lookup = create_frame_dict(colormatte_files) + +for frame in range(matte_start_frame, matte_end_frame + 1, args.increment): + # if colormatte_lookup[frame]: + if frame in colormatte_lookup: + # image_path = os.path.join(args.colormatte_path, file) + image_path = os.path.join(args.colormatte_path, colormatte_lookup[frame]) + matte_filename = get_matte_filename(image_path) + + # Load the colormatte image + image = cv2.imread(image_path) + + matte_colors = get_matte_colors(image) + + # element is part of matte_path + element_matte_path = rf'{matte_path}' + # if that color is in the colormatte for that frame, create a b/w matte matte for the element + if BGR_MATTE_COLOR in matte_colors: + matte = get_element_bw_matte_from_colormatte(image, BGR_MATTE_COLOR) + else: + matte = create_black_matte(IMAGE_WIDTH, IMAGE_HEIGHT) + + save_bw_matte(matte, element_matte_path, matte_filename) + + else: + # frame is missing - create black matte for element + + # create an output file path for that frame, based on the input file paths + black_matte_filename = create_black_matte_filename(colormatte_files[0], frame) + # create a black matte using the image width and height of the input images + black_matte = create_black_matte(IMAGE_WIDTH, IMAGE_HEIGHT) + + # element is part of matte_path + element_matte_path = rf'{matte_path}' + + save_bw_matte(black_matte, element_matte_path, black_matte_filename) + +print(f"Generation of mattes for frames {matte_start_frame} through {matte_end_frame} with increment {args.increment}" + f" is complete.") diff --git a/inference/interact/fbrs/inference/predictors/brs_functors.py b/inference/interact/fbrs/inference/predictors/brs_functors.py index 92a5d99..222b61c 100644 --- a/inference/interact/fbrs/inference/predictors/brs_functors.py +++ b/inference/interact/fbrs/inference/predictors/brs_functors.py @@ -72,7 +72,7 @@ def __call__(self, x): self._last_mask = current_mask loss.backward() - f_grad = opt_params.grad.cpu().numpy().ravel().astype(np.float) + f_grad = opt_params.grad.cpu().numpy().ravel().astype(np.cfloat) return [f_val, f_grad] diff --git a/inference/interact/gui.py b/inference/interact/gui.py index 6d17ac4..5e80c97 100644 --- a/inference/interact/gui.py +++ b/inference/interact/gui.py @@ -109,12 +109,29 @@ def __init__(self, net: XMem, self.spacebar = QShortcut(QKeySequence(Qt.Key_Space), self) self.spacebar.activated.connect(self.pause_propagation) + + # Have two text boxes, the first can be updated to choose the frame + self.current_frame = QTextEdit() + self.current_frame.setReadOnly(False) + self.current_frame.setMaximumHeight(28) + self.current_frame.setFixedWidth(60) + self.current_frame.setText('{: 4d}'.format(0)) + self.current_frame.installEventFilter(self) + # self.current_frame.returnPressed.connect(self.show_chosen_frame) + + self.last_frame = QTextEdit() + self.last_frame.setReadOnly(True) + self.last_frame.setMaximumHeight(28) + self.last_frame.setFixedWidth(60) + self.last_frame.setText('{: 4d}'.format(self.num_frames-1)) + + # LCD - self.lcd = QTextEdit() - self.lcd.setReadOnly(True) - self.lcd.setMaximumHeight(28) - self.lcd.setFixedWidth(120) - self.lcd.setText('{: 4d} / {: 4d}'.format(0, self.num_frames-1)) + # self.lcd = QTextEdit() + # self.lcd.setReadOnly(True) + # self.lcd.setMaximumHeight(28) + # self.lcd.setFixedWidth(120) + # self.lcd.setText('{: 4d} / {: 4d}'.format(0, self.num_frames-1)) # timeline slider self.tl_slider = QSlider(Qt.Horizontal) @@ -247,7 +264,9 @@ def __init__(self, net: XMem, # navigator navi = QHBoxLayout() - navi.addWidget(self.lcd) + # navi.addWidget(self.lcd) + navi.addWidget(self.current_frame) + navi.addWidget(self.last_frame) navi.addWidget(self.play_button) interact_subbox = QVBoxLayout() @@ -580,7 +599,10 @@ def show_current_frame(self, fast=False): self.update_interact_vis() self.update_minimap() - self.lcd.setText('{: 3d} / {: 3d}'.format(self.cursur, self.num_frames-1)) + + # self.lcd.setText('{: 3d} / {: 3d}'.format(self.cursur, self.num_frames-1)) + self.current_frame.setText('{: 3d}'.format(self.cursur)) + self.last_frame.setText('{: 3d}'.format(self.num_frames-1)) self.tl_slider.setValue(self.cursur) if self.cursur in self.reference_ids: @@ -661,6 +683,24 @@ def save_current_mask(self): # save mask to hard disk self.res_man.save_mask(self.cursur, self.current_mask) + def eventFilter(self, obj, event): + if obj == self.current_frame and event.type() == event.KeyPress: + if event.key() == Qt.Key_Return or event.key() == Qt.Key_Enter: + self.show_chosen_frame() + return True # Event handled + + return super().eventFilter(obj, event) + + def show_chosen_frame(self): + self.console_push_text('Change to the selected frame.') + # frame_num = int(self.current_frame.toPlainText()) + frame_txt = self.current_frame.toPlainText() + self.console_push_text(frame_txt) + frame_int = int(frame_txt) + self.cursur = frame_int + self.load_current_image_mask() + self.show_current_frame() + def tl_slide(self): # if we are propagating, the on_run function will take care of everything # don't do duplicate work here @@ -904,7 +944,7 @@ def on_play_video(self): self.timer.stop() self.play_button.setText('Play Video') else: - self.timer.start(1000 / 30) + self.timer.start(33) self.play_button.setText('Stop Video') def on_reset_mask(self): @@ -932,7 +972,9 @@ def set_navi_enable(self, boolean): self.run_button.setEnabled(boolean) self.tl_slider.setEnabled(boolean) self.play_button.setEnabled(boolean) - self.lcd.setEnabled(boolean) + # self.lcd.setEnabled(boolean) + self.current_frame.setEnabled(boolean) + self.last_frame.setEnables(boolean) def hit_number_key(self, number): if number == self.current_object: diff --git a/inference/interact/resource_manager.py b/inference/interact/resource_manager.py index 80680fc..93c3c75 100644 --- a/inference/interact/resource_manager.py +++ b/inference/interact/resource_manager.py @@ -113,6 +113,7 @@ def __init__(self, config): # read all frame names self.names = sorted(os.listdir(self.image_dir)) + self.img_extension = self.names[0][-4:] if self.names else None self.names = [f[:-4] for f in self.names] # remove extensions self.length = len(self.names) @@ -249,7 +250,7 @@ def _get_image_unbuffered(self, ti): # returns H*W*3 uint8 array assert 0 <= ti < self.length - image = Image.open(path.join(self.image_dir, self.names[ti]+'.jpg')) + image = Image.open(path.join(self.image_dir, self.names[ti]+self.img_extension)) image = np.array(image) return image diff --git a/interactive_demo.py b/interactive_demo.py index 7f4e82a..e8cfb9a 100644 --- a/interactive_demo.py +++ b/interactive_demo.py @@ -3,6 +3,7 @@ """ import os + # fix for Windows if 'QT_QPA_PLATFORM_PLUGIN_PATH' not in os.environ: os.environ['QT_QPA_PLATFORM_PLUGIN_PATH'] = '' @@ -50,14 +51,14 @@ # Long-memory options # Defaults. Some can be changed in the GUI. - parser.add_argument('--max_mid_term_frames', help='T_max in paper, decrease to save memory', type=int, default=10) - parser.add_argument('--min_mid_term_frames', help='T_min in paper, decrease to save memory', type=int, default=5) + parser.add_argument('--max_mid_term_frames', help='T_max in paper, decrease to save memory', type=int, default=2) + parser.add_argument('--min_mid_term_frames', help='T_min in paper, decrease to save memory', type=int, default=1) parser.add_argument('--max_long_term_elements', help='LT_max in paper, increase if objects disappear for a long time', - type=int, default=10000) - parser.add_argument('--num_prototypes', help='P in paper', type=int, default=128) + type=int, default=1000) + parser.add_argument('--num_prototypes', help='P in paper', type=int, default=32) parser.add_argument('--top_k', type=int, default=30) - parser.add_argument('--mem_every', type=int, default=10) + parser.add_argument('--mem_every', type=int, default=1) parser.add_argument('--deep_update_every', help='Leave -1 normally to synchronize with mem_every', type=int, default=-1) parser.add_argument('--no_amp', help='Turn off AMP', action='store_true') parser.add_argument('--size', default=480, type=int,