diff --git a/romatch/models/matcher.py b/romatch/models/matcher.py index 8108a92..1fac0aa 100644 --- a/romatch/models/matcher.py +++ b/romatch/models/matcher.py @@ -1,18 +1,19 @@ -import os import math +import os +import warnings +from warnings import warn + import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange -import warnings -from warnings import warn from PIL import Image - from romatch.utils import get_tuple_transform_ops +from romatch.utils.kde import kde from romatch.utils.local_correlation import local_correlation from romatch.utils.utils import cls_to_flow_refine, get_autocast_params -from romatch.utils.kde import kde + class ConvRefiner(nn.Module): def __init__( @@ -654,7 +655,12 @@ def match( test_transform = get_tuple_transform_ops( resize=(hs, ws), normalize=True ) - im_A, im_B = test_transform((Image.open(im_A_path).convert('RGB'), Image.open(im_B_path).convert('RGB'))) + if isinstance(im_A_path, (str, os.PathLike)): + im_A, im_B = Image.open(im_A_path).convert("RGB"), Image.open(im_B_path).convert("RGB") + else: + im_A, im_B = im_A_path, im_B_path + im_A, im_B = test_transform((im_A, im_B)) + im_A, im_B = im_A[None].to(device), im_B[None].to(device) scale_factor = math.sqrt(self.upsample_res[0] * self.upsample_res[1] / (self.w_resized * self.h_resized)) batch = {"im_A": im_A, "im_B": im_B, "corresps": finest_corresps} diff --git a/romatch/models/model_zoo/__init__.py b/romatch/models/model_zoo/__init__.py index d0470ca..c22026f 100644 --- a/romatch/models/model_zoo/__init__.py +++ b/romatch/models/model_zoo/__init__.py @@ -1,4 +1,4 @@ -from typing import Union +from typing import Union, Tuple import torch from .roma_models import roma_model, tiny_roma_v1_model @@ -27,7 +27,7 @@ def tiny_roma_v1_outdoor(device, weights = None, xfeat = None): return tiny_roma_v1_model(weights = weights, xfeat = xfeat).to(device) -def roma_outdoor(device, weights=None, dinov2_weights=None, coarse_res: Union[int,tuple[int,int]] = 560, upsample_res: Union[int,tuple[int,int]] = 864, amp_dtype: torch.dtype = torch.float16): +def roma_outdoor(device, weights=None, dinov2_weights=None, coarse_res: Union[int,Tuple[int,int]] = 560, upsample_res: Union[int,Tuple[int,int]] = 864, amp_dtype: torch.dtype = torch.float16): if isinstance(coarse_res, int): coarse_res = (coarse_res, coarse_res) if isinstance(upsample_res, int): @@ -51,7 +51,7 @@ def roma_outdoor(device, weights=None, dinov2_weights=None, coarse_res: Union[in print(f"Using coarse resolution {coarse_res}, and upsample res {model.upsample_res}") return model -def roma_indoor(device, weights=None, dinov2_weights=None, coarse_res: Union[int,tuple[int,int]] = 560, upsample_res: Union[int,tuple[int,int]] = 864, amp_dtype: torch.dtype = torch.float16): +def roma_indoor(device, weights=None, dinov2_weights=None, coarse_res: Union[int,Tuple[int,int]] = 560, upsample_res: Union[int,Tuple[int,int]] = 864, amp_dtype: torch.dtype = torch.float16): if isinstance(coarse_res, int): coarse_res = (coarse_res, coarse_res) if isinstance(upsample_res, int):