File tree Expand file tree Collapse file tree 4 files changed +15
-0
lines changed Expand file tree Collapse file tree 4 files changed +15
-0
lines changed Original file line number Diff line number Diff line change 26
26
from torchao.quantization.quant_api import quantize_
27
27
from torchao.utils import TORCH_VERSION_AT_LEAST_2_6
28
28
29
+ if common_utils.SEED is None:
30
+     common_utils.SEED = 1234
31
+
29
32
try:
30
33
import gemlite  # noqa: F401
31
34
Original file line number Diff line number Diff line change 20
20
apply_activation_checkpointing,
21
21
)
22
22
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
23
+ from torch.testing._internal import common_utils
23
24
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
24
25
from torch.testing._internal.common_fsdp import FSDPTest
25
26
from torch.testing._internal.common_utils import (
29
30
run_tests,
30
31
)
31
32
33
+ if common_utils.SEED is None:
34
+     common_utils.SEED = 1234
35
+
32
36
import torchao
33
37
from packaging import version
34
38
from torchao.dtypes._nf4tensor_api import nf4_weight_only
Original file line number Diff line number Diff line change 15
15
import torch
16
16
import torch.distributed as dist
17
17
import torch.nn.functional as F
18
+ import torch.testing._internal.common_utils as common_utils
18
19
from torch import nn
19
20
from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard
20
21
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
40
41
)
41
42
from torchao.quantization.quant_api import quantize_
42
43
44
+ if common_utils.SEED is None:
45
+     common_utils.SEED = 1234
46
+
43
47
_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
44
48
45
49
Original file line number Diff line number Diff line change 16
16
OffloadPolicy,
17
17
fully_shard,
18
18
)
19
+ from torch.testing._internal import common_utils
19
20
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
20
21
from torch.testing._internal.common_fsdp import FSDPTest
21
22
from torch.testing._internal.common_utils import (
25
26
run_tests,
26
27
)
27
28
29
+ if common_utils.SEED is None:
30
+     common_utils.SEED = 1234
31
+
28
32
from packaging.version import Version
29
33
from torchao import optim
30
34
from torchao.optim.quant_utils import (
You can’t perform that action at this time.
0 commit comments