
import torch
import torch.distributed as dist
+import torch.distributed._symmetric_memory as symm_mem
from torch._C._autograd import DeviceType
from torch._C._distributed_c10d import _SymmetricMemory
from torch._inductor.utils import fresh_inductor_cache, run_and_get_triton_code
@@ -81,9 +82,25 @@ def _init_process(self):
            rank=self.rank,
            store=store,
        )
-        enable_symm_mem_for_group(dist.group.WORLD.group_name)
        torch.manual_seed(42 + self.rank)

+    @skipIfRocm
+    @skip_if_lt_x_gpu(2)
+    def test_cuda_nvlink_connectivity_detection(self) -> None:
+        from torch._C._distributed_c10d import _detect_dma_connectivity
+
+        connectivity = _detect_dma_connectivity(DeviceType.CUDA, "nvlink")
+        self.assertEqual(connectivity.device_type, DeviceType.CUDA)
+        self.assertEqual(connectivity.connection_type, "nvlink")
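+        # The connectivity matrix is expected to be square: one row per local CUDA device.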
+        self.assertEqual(len(connectivity.matrix), torch.cuda.device_count())
+        for row in connectivity.matrix:
+            self.assertEqual(len(row), torch.cuda.device_count())
+
+    @skipIfRocm
+    def test_large_alloc(self) -> None:
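+        # Sanity-check a 2 GiB allocation through the symmetric memory allocator.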
+        t = symm_mem.empty(2 * 1024**3, dtype=torch.uint8, device="cuda")
+        self.assertEqual(t.numel() * t.element_size(), 2 * 1024**3)
+
    def _get_test_alloc_args(self):
        shape = (64, 64)
        stride = (64, 1)
@@ -92,64 +109,56 @@ def _get_test_alloc_args(self):
        group_name = "0"
        return (shape, stride, dtype, device, group_name)

-    def _verify_symmetric_memory(self, symm_mem):
-        self.assertEqual(symm_mem.world_size, 2)
+    def _verify_symmetric_memory(self, symm_mem_hdl):
+        self.assertEqual(symm_mem_hdl.world_size, 2)

-        buf = symm_mem.get_buffer(0, (symm_mem.buffer_size // 4,), torch.float32)
+        buf = symm_mem_hdl.get_buffer(
+            0, (symm_mem_hdl.buffer_size // 4,), torch.float32
+        )
        self.assertEqual(buf.storage_offset(), 0)
-        self.assertEqual(buf.untyped_storage().size(), symm_mem.buffer_size)
+        self.assertEqual(buf.untyped_storage().size(), symm_mem_hdl.buffer_size)

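+        # Rank 1 fills the buffer with 42 and signals rank 0, which blocks until the signal arrives.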
-        if symm_mem.rank == 0:
-            symm_mem.wait_signal(src_rank=1)
+        if symm_mem_hdl.rank == 0:
+            symm_mem_hdl.wait_signal(src_rank=1)
            self.assertTrue(buf.eq(42).all())
        else:
            buf.fill_(42)
-            symm_mem.put_signal(dst_rank=0)
+            symm_mem_hdl.put_signal(dst_rank=0)

-        symm_mem.barrier()
+        symm_mem_hdl.barrier()

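+        # Second round: rank 1 writes 43 and both ranks synchronize with barriers only.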
-        if symm_mem.rank == 0:
-            symm_mem.barrier()
+        if symm_mem_hdl.rank == 0:
+            symm_mem_hdl.barrier()
            self.assertTrue(buf.eq(43).all())
        else:
            buf.fill_(43)
-            symm_mem.barrier()
+            symm_mem_hdl.barrier()

-        symm_mem.barrier()
-
-    @skipIfRocm
-    @skip_if_lt_x_gpu(2)
-    def test_cuda_nvlink_connectivity_detection(self) -> None:
-        from torch._C._distributed_c10d import _detect_dma_connectivity
-
-        connectivity = _detect_dma_connectivity(DeviceType.CUDA, "nvlink")
-        self.assertEqual(connectivity.device_type, DeviceType.CUDA)
-        self.assertEqual(connectivity.connection_type, "nvlink")
-        self.assertEqual(len(connectivity.matrix), torch.cuda.device_count())
-        for row in connectivity.matrix:
-            self.assertEqual(len(row), torch.cuda.device_count())
+        symm_mem_hdl.barrier()

    @skipIfRocm
    @skip_if_lt_x_gpu(2)
    def test_empty_strided_p2p(self) -> None:
        self._init_process()
+        enable_symm_mem_for_group(dist.group.WORLD.group_name)

        alloc_args = self._get_test_alloc_args()

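+        # rendezvous() returns None for tensors that were not allocated as symmetric memory.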
        t = torch.empty((64, 64), device=self.device)
        self.assertIsNone(_SymmetricMemory.rendezvous(t))

        t = _SymmetricMemory.empty_strided_p2p(*alloc_args)
-        symm_mem = _SymmetricMemory.rendezvous(t)
+        symm_mem_hdl = _SymmetricMemory.rendezvous(t)

        del t
-        self._verify_symmetric_memory(symm_mem)
+        self._verify_symmetric_memory(symm_mem_hdl)
        dist.destroy_process_group()

    @skipIfRocm
    @skip_if_lt_x_gpu(2)
    def test_empty_strided_p2p_persistent(self) -> None:
        self._init_process()
+        enable_symm_mem_for_group(dist.group.WORLD.group_name)

        alloc_args = self._get_test_alloc_args()

@@ -168,51 +177,47 @@ def test_empty_strided_p2p_persistent(self) -> None:
        t = _SymmetricMemory.empty_strided_p2p(*alloc_args, alloc_id=42)
        self.assertEqual(t.data_ptr(), data_ptr)

-        symm_mem = _SymmetricMemory.rendezvous(t)
-        self._verify_symmetric_memory(symm_mem)
+        symm_mem_hdl = _SymmetricMemory.rendezvous(t)
+        self._verify_symmetric_memory(symm_mem_hdl)
        dist.destroy_process_group()

    @skipIfRocm
    @skip_if_lt_x_gpu(2)
    def test_get_signal_pad(self) -> None:
        self._init_process()

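+        # Allocate through the symm_mem Python API and rendezvous over the default (WORLD) group.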
-        t = _SymmetricMemory.empty_strided_p2p(*self._get_test_alloc_args())
-        symm_mem = _SymmetricMemory.rendezvous(t)
+        t = symm_mem.empty(1, device="cuda")
+        symm_mem_hdl = symm_mem.rendezvous(t, group=dist.group.WORLD)
        peer_rank = (self.rank + 1) % self.world_size

-        signal_pad = symm_mem.get_signal_pad(self.rank)
-        self.assertEqual(signal_pad.data_ptr(), symm_mem.signal_pad_ptrs[symm_mem.rank])
+        signal_pad = symm_mem_hdl.get_signal_pad(self.rank)
+        self.assertEqual(
+            signal_pad.data_ptr(), symm_mem_hdl.signal_pad_ptrs[symm_mem_hdl.rank]
+        )

-        signal_pad = symm_mem.get_signal_pad(peer_rank)
+        signal_pad = symm_mem_hdl.get_signal_pad(peer_rank)
        self.assertEqual(signal_pad.dtype, torch.uint32)
-        self.assertEqual(signal_pad.numel(), symm_mem.signal_pad_size // 4)
+        self.assertEqual(signal_pad.numel(), symm_mem_hdl.signal_pad_size // 4)

        # Only specify sizes
-        signal_pad = symm_mem.get_signal_pad(peer_rank, (8, 8))
+        signal_pad = symm_mem_hdl.get_signal_pad(peer_rank, (8, 8))
        self.assertEqual(signal_pad.dtype, torch.uint32)
        self.assertEqual(signal_pad.numel(), 64)

        # Only specify dtype
-        signal_pad = symm_mem.get_signal_pad(peer_rank, dtype=torch.uint64)
+        signal_pad = symm_mem_hdl.get_signal_pad(peer_rank, dtype=torch.uint64)
        self.assertEqual(signal_pad.dtype, torch.uint64)
-        self.assertEqual(signal_pad.numel(), symm_mem.signal_pad_size // 8)
+        self.assertEqual(signal_pad.numel(), symm_mem_hdl.signal_pad_size // 8)

        # Specify both sizes and dtype
-        signal_pad = symm_mem.get_signal_pad(peer_rank, (8, 8), dtype=torch.uint64)
+        signal_pad = symm_mem_hdl.get_signal_pad(peer_rank, (8, 8), dtype=torch.uint64)
        self.assertEqual(signal_pad.dtype, torch.uint64)
        self.assertEqual(signal_pad.numel(), 64)

        # Sanity check that writes to buffer doesn't corrupt signal_pad
-        t = _SymmetricMemory.empty_strided_p2p(
-            (0,),
-            (0,),
-            torch.float32,
-            self.device,
-            dist.group.WORLD.group_name,
-        )
-        symm_mem = _SymmetricMemory.rendezvous(t)
-        signal_pad = symm_mem.get_signal_pad(self.rank)
+        t = symm_mem.empty(0, device="cuda")
+        symm_mem_hdl = symm_mem.rendezvous(t)
+        signal_pad = symm_mem_hdl.get_signal_pad(self.rank)
        signal_pad.fill_(42)
        t.fill_(0)
        self.assertTrue(signal_pad.eq(42).all())
@@ -224,14 +229,12 @@ def test_get_signal_pad(self) -> None:
    def test_barrier_timeout(self) -> None:
        self._init_process()

-        alloc_args = self._get_test_alloc_args()
-
-        t = _SymmetricMemory.empty_strided_p2p(*alloc_args)
-        symm_mem = _SymmetricMemory.rendezvous(t)
+        t = symm_mem.empty(1, device="cuda")
+        symm_mem_hdl = _SymmetricMemory.rendezvous(t, group=dist.group.WORLD)

        if self.rank == 0:
            with self.assertRaises(RuntimeError):
-                symm_mem.barrier(timeout_ms=1000)
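+                # Only rank 0 reaches this barrier, so it is expected to time out.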
+                symm_mem_hdl.barrier(timeout_ms=1000)
            torch.cuda.synchronize()
        else:
            torch.cuda.synchronize()
@@ -247,17 +250,15 @@ def test_barrier_timeout(self) -> None:
    def test_put_signal_timeout(self) -> None:
        self._init_process()

-        alloc_args = self._get_test_alloc_args()
-
-        t = _SymmetricMemory.empty_strided_p2p(*alloc_args)
-        symm_mem = _SymmetricMemory.rendezvous(t)
+        t = symm_mem.empty(1, device="cuda")
+        symm_mem_hdl = _SymmetricMemory.rendezvous(t, group=dist.group.WORLD)

        if self.rank == 0:
            with self.assertRaises(RuntimeError):
                # First, put a signal into rank 1's signal pad. Since rank 1
                # doesn't wait on this signal, the subsequent put will timeout.
-                symm_mem.put_signal(dst_rank=1)
-                symm_mem.put_signal(dst_rank=1, timeout_ms=1000)
+                symm_mem_hdl.put_signal(dst_rank=1)
+                symm_mem_hdl.put_signal(dst_rank=1, timeout_ms=1000)
            torch.cuda.synchronize()
        else:
            torch.cuda.synchronize()
@@ -273,14 +274,12 @@ def test_put_signal_timeout(self) -> None:
    def test_wait_signal_timeout(self) -> None:
        self._init_process()

-        alloc_args = self._get_test_alloc_args()
-
-        t = _SymmetricMemory.empty_strided_p2p(*alloc_args)
-        symm_mem = _SymmetricMemory.rendezvous(t)
+        t = symm_mem.empty(1, device="cuda")
+        symm_mem_hdl = _SymmetricMemory.rendezvous(t, group=dist.group.WORLD)

        if self.rank == 0:
            with self.assertRaises(RuntimeError):
-                symm_mem.wait_signal(src_rank=1, timeout_ms=1000)
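+                # Rank 1 never sends a signal, so the wait is expected to time out.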
+                symm_mem_hdl.wait_signal(src_rank=1, timeout_ms=1000)
            torch.cuda.synchronize()
        else:
            torch.cuda.synchronize()
@@ -685,7 +684,6 @@ def _init_process(self):
            rank=self.rank,
            store=store,
        )
-        enable_symm_mem_for_group(dist.group.WORLD.group_name)
        torch.manual_seed(42 + self.rank)

    @skipIfRocm
@@ -699,18 +697,10 @@ def test_subgroup(self) -> None:

        world = dist.group.WORLD
        subgroup = subgroup_0 if world.rank() < world.size() // 2 else subgroup_1
-        enable_symm_mem_for_group(subgroup.group_name)

-        t = _SymmetricMemory.empty_strided_p2p(
-            size=(64,),
-            stride=(1,),
-            dtype=torch.float32,
-            device=self.device,
-        )
-        symm_mem_world = _SymmetricMemory.rendezvous(t, group_name=world.group_name)
-        symm_mem_subgroup = _SymmetricMemory.rendezvous(
-            t, group_name=subgroup.group_name
-        )
+        t = symm_mem.empty(64, device="cuda")
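+        # Rendezvous the same tensor over both the world group and the subgroup.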
+        symm_mem_world = symm_mem.rendezvous(t, group=world)
+        symm_mem_subgroup = symm_mem.rendezvous(t, group=subgroup)

        self.assertEqual(symm_mem_world.world_size, world.size())
        self.assertEqual(symm_mem_world.rank, world.rank())