@@ -213,7 +213,7 @@ def test_scriptability_lru(self) -> None:
213213 torch .jit .script (mcc_ec )
214214
215215 @unittest .skipIf (
216- torch .cuda .device_count () < 1 ,
216+ torch .cuda .device_count () < 2 ,
217217 "Not enough GPUs, this test requires at least one GPUs" ,
218218 )
219219 # pyre-ignore [56]
@@ -292,7 +292,7 @@ def test_zch_hash_train_to_inf_block_bucketize(
292292 )
293293
294294 @unittest .skipIf (
295- torch .cuda .device_count () < 1 ,
295+ torch .cuda .device_count () < 2 ,
296296 "Not enough GPUs, this test requires at least one GPUs" ,
297297 )
298298 # pyre-ignore [56]
@@ -404,13 +404,13 @@ def test_zch_hash_train_rescales_two(self, hash_size: int) -> None:
404404 )
405405
406406 @unittest .skipIf (
407- torch .cuda .device_count () < 1 ,
407+ torch .cuda .device_count () < 2 ,
408408 "Not enough GPUs, this test requires at least one GPUs" ,
409409 )
410410 # pyre-ignore [56]
411411 @given (hash_size = st .sampled_from ([0 , 80 ]))
412412 @settings (max_examples = 5 , deadline = None )
413- def test_zch_hash_train_rescales_four (self , hash_size : int ) -> None :
413+ def test_zch_hash_train_rescales_one (self , hash_size : int ) -> None :
414414 keep_original_indices = True
415415 kjt = KeyedJaggedTensor (
416416 keys = ["f" ],
@@ -446,23 +446,20 @@ def test_zch_hash_train_rescales_four(self, hash_size: int) -> None:
446446 ),
447447 )
448448
449- # start with world_size = 4
450- world_size = 4
449+ # start with world_size = 2
450+ world_size = 2
451451 block_sizes = torch .tensor (
452452 [(size + world_size - 1 ) // world_size for size in [hash_size ]],
453453 dtype = torch .int64 ,
454454 device = "cuda" ,
455455 )
456456
457- m1_1 = m0 .rebuild_with_output_id_range ((0 , 10 ))
458- m2_1 = m0 .rebuild_with_output_id_range ((10 , 20 ))
459- m3_1 = m0 .rebuild_with_output_id_range ((20 , 30 ))
460- m4_1 = m0 .rebuild_with_output_id_range ((30 , 40 ))
457+ m1_1 = m0 .rebuild_with_output_id_range ((0 , 20 ))
458+ m2_1 = m0 .rebuild_with_output_id_range ((20 , 40 ))
461459
462- # shard, now world size 2!
463- # start with world_size = 4
460+ # shard, now world size 1!
464461 if hash_size > 0 :
465- world_size = 2
462+ world_size = 1
466463 block_sizes = torch .tensor (
467464 [(size + world_size - 1 ) // world_size for size in [hash_size ]],
468465 dtype = torch .int64 ,
@@ -476,7 +473,7 @@ def test_zch_hash_train_rescales_four(self, hash_size: int) -> None:
476473 keep_original_indices = keep_original_indices ,
477474 output_permute = True ,
478475 )
479- in1_2 , in2_2 = bucketized_kjt .split ([len (kjt .keys ())] * world_size )
476+ in1_2 = bucketized_kjt .split ([len (kjt .keys ())] * world_size )[ 0 ]
480477 else :
481478 bucketized_kjt , permute = bucketize_kjt_before_all2all (
482479 kjt ,
@@ -492,14 +489,8 @@ def test_zch_hash_train_rescales_four(self, hash_size: int) -> None:
492489 values = torch .cat ([kjts [0 ].values (), kjts [1 ].values ()], dim = 0 ),
493490 lengths = torch .cat ([kjts [0 ].lengths (), kjts [1 ].lengths ()], dim = 0 ),
494491 )
495- in2_2 = KeyedJaggedTensor (
496- keys = kjts [2 ].keys (),
497- values = torch .cat ([kjts [2 ].values (), kjts [3 ].values ()], dim = 0 ),
498- lengths = torch .cat ([kjts [2 ].lengths (), kjts [3 ].lengths ()], dim = 0 ),
499- )
500492
501- m1_2 = m0 .rebuild_with_output_id_range ((0 , 20 ))
502- m2_2 = m0 .rebuild_with_output_id_range ((20 , 40 ))
493+ m1_2 = m0 .rebuild_with_output_id_range ((0 , 40 ))
503494 m1_zch_identities = torch .cat (
504495 [
505496 m1_1 .state_dict ()["_hash_zch_identities" ],
@@ -516,53 +507,30 @@ def test_zch_hash_train_rescales_four(self, hash_size: int) -> None:
516507 state_dict ["_hash_zch_identities" ] = m1_zch_identities
517508 state_dict ["_hash_zch_metadata" ] = m1_zch_metadata
518509 m1_2 .load_state_dict (state_dict )
519-
520- m2_zch_identities = torch .cat (
521- [
522- m3_1 .state_dict ()["_hash_zch_identities" ],
523- m4_1 .state_dict ()["_hash_zch_identities" ],
524- ]
525- )
526- m2_zch_metadata = torch .cat (
527- [
528- m3_1 .state_dict ()["_hash_zch_metadata" ],
529- m4_1 .state_dict ()["_hash_zch_metadata" ],
530- ]
531- )
532- state_dict = m2_2 .state_dict ()
533- state_dict ["_hash_zch_identities" ] = m2_zch_identities
534- state_dict ["_hash_zch_metadata" ] = m2_zch_metadata
535- m2_2 .load_state_dict (state_dict )
536-
537510 _ = m1_2 (in1_2 .to_dict ())
538- _ = m2_2 (in2_2 .to_dict ())
539511
540512 m0 .reset_inference_mode () # just clears out training state
541513 full_zch_identities = torch .cat (
542514 [
543515 m1_2 .state_dict ()["_hash_zch_identities" ],
544- m2_2 .state_dict ()["_hash_zch_identities" ],
545516 ]
546517 )
547518 state_dict = m0 .state_dict ()
548519 state_dict ["_hash_zch_identities" ] = full_zch_identities
549520 m0 .load_state_dict (state_dict )
550521
551- # now set all models to eval, and run kjt
552522 m1_2 .eval ()
553- m2_2 .eval ()
554523 assert m0 .training is False
555524
556525 inf_input = kjt .to_dict ()
557- inf_output = m0 (inf_input )
558526
527+ inf_output = m0 (inf_input )
559528 o1_2 = m1_2 (in1_2 .to_dict ())
560- o2_2 = m2_2 (in2_2 .to_dict ())
561529 self .assertTrue (
562530 torch .allclose (
563531 inf_output ["f" ].values (),
564532 torch .index_select (
565- torch . cat ([ x [ "f" ].values () for x in [ o1_2 , o2_2 ]] ),
533+ o1_2 [ "f" ].values (),
566534 dim = 0 ,
567535 index = cast (torch .Tensor , permute ),
568536 ),
0 commit comments