11from typing import Union
22import torch
33from .roma_models import roma_model , tiny_roma_v1_model
4- from loguru import logger
54
65
76weight_urls = {
@@ -37,20 +36,9 @@ def roma_outdoor(
3736 upsample_res : Union [int , tuple [int , int ]] = 864 ,
3837 amp_dtype : torch .dtype = torch .float16 ,
3938 symmetric = True ,
40- use_custom_corr = False ,
39+ use_custom_corr = True ,
4140 upsample_preds = True ,
4241):
43- if isinstance (coarse_res , int ):
44- coarse_res = (coarse_res , coarse_res )
45- if isinstance (upsample_res , int ):
46- upsample_res = (upsample_res , upsample_res )
47-
48- if str (device ) == "cpu" :
49- amp_dtype = torch .float32
50-
51- assert coarse_res [0 ] % 14 == 0 , "Needs to be multiple of 14 for backbone"
52- assert coarse_res [1 ] % 14 == 0 , "Needs to be multiple of 14 for backbone"
53-
5442 if weights is None :
5543 weights = torch .hub .load_state_dict_from_url (
5644 weight_urls ["romatch" ]["outdoor" ], map_location = device
@@ -68,10 +56,7 @@ def roma_outdoor(
6856 amp_dtype = amp_dtype ,
6957 symmetric = symmetric ,
7058 use_custom_corr = use_custom_corr ,
71- )
72- model .upsample_res = upsample_res
73- logger .info (
74- f"Using coarse resolution { coarse_res } , and upsample res { model .upsample_res } "
59+ upsample_res = upsample_res ,
7560 )
7661 return model
7762
@@ -83,15 +68,10 @@ def roma_indoor(
8368 coarse_res : Union [int , tuple [int , int ]] = 560 ,
8469 upsample_res : Union [int , tuple [int , int ]] = 864 ,
8570 amp_dtype : torch .dtype = torch .float16 ,
71+ symmetric = True ,
72+ use_custom_corr = True ,
73+ upsample_preds = True ,
8674):
87- if isinstance (coarse_res , int ):
88- coarse_res = (coarse_res , coarse_res )
89- if isinstance (upsample_res , int ):
90- upsample_res = (upsample_res , upsample_res )
91-
92- assert coarse_res [0 ] % 14 == 0 , "Needs to be multiple of 14 for backbone"
93- assert coarse_res [1 ] % 14 == 0 , "Needs to be multiple of 14 for backbone"
94-
9575 if weights is None :
9676 weights = torch .hub .load_state_dict_from_url (
9777 weight_urls ["romatch" ]["indoor" ], map_location = device
@@ -102,14 +82,13 @@ def roma_indoor(
10282 )
10383 model = roma_model (
10484 resolution = coarse_res ,
105- upsample_preds = True ,
85+ upsample_preds = upsample_preds ,
10686 weights = weights ,
10787 dinov2_weights = dinov2_weights ,
10888 device = device ,
10989 amp_dtype = amp_dtype ,
90+ symmetric = symmetric ,
91+ use_custom_corr = use_custom_corr ,
92+ upsample_res = upsample_res ,
11093 )
111- model .upsample_res = upsample_res
112- logger .info (
113- f"Using coarse resolution { coarse_res } , and upsample res { model .upsample_res } "
114- )
115- return model
94+ return model
0 commit comments