@@ -35,7 +35,7 @@ def _backend_device(backend: Union[_backend.TensorCPU, _backend.TensorGPU]) -> D
3535 if isinstance (backend , _backend .TensorCPU ):
3636 return Device ("cpu" )
3737 elif isinstance (backend , _backend .TensorGPU ):
38- return Device ("gpu" , backend .device ())
38+ return Device ("gpu" , backend .device_id ())
3939 else :
4040 raise ValueError (f"Unsupported backend type: { type (backend )} " )
4141
@@ -121,8 +121,22 @@ def __init__(
121121 self .assign (fn .cast (data , dtype , device = device ).evaluate ())
122122 elif isinstance (data , TensorSlice ):
123123 self ._slice = data
124- elif hasattr (data , "__dlpack__" ):
125- self ._backend = _backend .TensorCPU (data , layout )
124+ elif hasattr (data , "__dlpack_device__" ):
125+ dl_device_type , device_id = data .__dlpack_device__ ()
126+ if int (dl_device_type ) == 1 : # CPU
127+ self ._backend = _backend .TensorCPU (data .__dlpack__ (), layout )
128+ elif int (dl_device_type ) == 2 : # GPU
129+ # If the current context is on the same device, use the same stream.
130+ ctx = _EvalContext .get ()
131+ if ctx .device .device_id == device_id :
132+ stream = ctx .cuda_stream
133+ args = {"stream" : stream .handle }
134+ else :
135+ # TODO(michalz): Come up with better stream semantics
136+ args = {}
137+ self ._backend = _backend .TensorGPU (data .__dlpack__ (** args ), layout )
138+ else :
139+ raise ValueError (f"Unsupported device type: { dl_device_type } " )
126140 self ._wraps_external_data = True
127141 elif hasattr (data , "__array__" ):
128142 self ._backend = _backend .TensorCPU (data , layout )
@@ -159,10 +173,14 @@ def __init__(
159173 copied = True
160174 self ._wraps_external_data = False
161175
162- if device is not None :
163- device = self ._device = device if isinstance (device , Device ) else Device (device )
176+ if self ._backend is not None :
177+ self ._device = _backend_device (self ._backend )
178+ if device is None :
179+ device = self ._device
164180 else :
165- device = self ._device = Device ("cpu" )
181+ if device is None :
182+ device = Device ("cpu" )
183+ self ._device = device
166184
167185 if self ._backend is not None :
168186 self ._shape = self ._backend .shape ()
@@ -214,6 +232,8 @@ def to_device(self, device: Device, force_copy: bool = False) -> "Tensor":
214232 return fn .copy (self , device = device .device_type )
215233
216234 def assign (self , other : "Tensor" ):
235+ if other is self :
236+ return
217237 self ._device = other ._device
218238 self ._shape = other ._shape
219239 self ._dtype = other ._dtype
0 commit comments