@@ -210,6 +210,17 @@ def decode_jpeg(
210210 raise ValueError ("All elements of the input list must be tensors." )
211211 if not all (t .device .type == "cpu" for t in input ):
212212 raise ValueError ("Input list must contain tensors on CPU." )
213+ custom_privateuse1_name = torch ._C ._get_privateuse1_backend_name ()
214+ if device .type == custom_privateuse1_name or device .type == "privateuseone" :
215+ # When the target device is privateuseone, switch to calling the custom decode_jpegs_privateuseone.
216+ # This operator needs to be pre-registered by the user through torch.library.define/impl.
217+ decoder = getattr (torch .ops .image , "decode_jpegs_privateuseone" , None )
218+ if decoder is None :
219+ raise RuntimeError (
220+ "decode_jpeg tensors on PrivateUse1 device require registering "
221+ "torch.ops.image.decode_jpegs_privateuseone."
222+ )
223+ return decoder (input , mode .value , apply_exif_orientation )
213224 if device .type == "cuda" :
214225 return torch .ops .image .decode_jpegs_cuda (input , mode .value , device )
215226 else :
@@ -218,6 +229,15 @@ def decode_jpeg(
218229 else : # input is tensor
219230 if input .device .type != "cpu" :
220231 raise ValueError ("Input tensor must be a CPU tensor" )
232+ if device .type == custom_privateuse1_name or device .type == "privateuseone" :
233+ custom_privateuse1_name = torch ._C ._get_privateuse1_backend_name ()
234+ decoder = getattr (torch .ops .image , "decode_jpegs_privateuseone" , None )
235+ if decoder is None :
236+ raise RuntimeError (
237+ "decode_jpeg tensor on PrivateUse1 device require registering "
238+ "torch.ops.image.decode_jpegs_privateuseone."
239+ )
240+ return decoder ([input ], mode .value , apply_exif_orientation )[0 ]
221241 if device .type == "cuda" :
222242 return torch .ops .image .decode_jpegs_cuda ([input ], mode .value , device )[0 ]
223243 else :
@@ -246,16 +266,36 @@ def encode_jpeg(
246266 _log_api_usage_once (encode_jpeg )
247267 if quality < 1 or quality > 100 :
248268 raise ValueError ("Image quality should be a positive number between 1 and 100" )
269+ custom_privateuse1_name = torch ._C ._get_privateuse1_backend_name ()
270+
249271 if isinstance (input , list ):
250272 if not input :
251273 raise ValueError ("encode_jpeg requires at least one input tensor when a list is passed" )
252- if input [0 ].device .type == "cuda" :
274+ device_type = input [0 ].device .type
275+ if device_type == custom_privateuse1_name or device_type == "privateuseone" :
276+ encoder = getattr (torch .ops .image , "encode_jpegs_privateuseone" , None )
277+ if encoder is None :
278+ raise RuntimeError (
279+ "encode_jpeg tensors on PrivateUse1 device require registering "
280+ "torch.ops.image.encode_jpegs_privateuseone."
281+ )
282+ return encoder (input , quality )
283+ if device_type == "cuda" :
253284 return torch .ops .image .encode_jpegs_cuda (input , quality )
254285 else :
255286 return [torch .ops .image .encode_jpeg (image , quality ) for image in input ]
256287 else : # single input tensor
257- if input .device .type == "cuda" :
288+ device_type = input .device .type
289+ if device_type == "cuda" :
258290 return torch .ops .image .encode_jpegs_cuda ([input ], quality )[0 ]
291+ elif device_type == custom_privateuse1_name or device_type == "privateuseone" :
292+ encoder = getattr (torch .ops .image , "encode_jpegs_privateuseone" , None )
293+ if encoder is None :
294+ raise RuntimeError (
295+ "encode_jpeg tensor on PrivateUse1 device require registering "
296+ "torch.ops.image.encode_jpegs_privateuseone."
297+ )
298+ return encoder ([input ], quality )[0 ]
259299 else :
260300 return torch .ops .image .encode_jpeg (input , quality )
261301
0 commit comments