3636from torchao .float8 .float8_scaling_utils import (
3737 hp_tensor_to_float8_dynamic ,
3838)
39- from torchao .float8 .float8_tensor import GemmInputRole , LinearMMConfig , ScaledMMConfig
39+ from torchao .float8 .float8_training_tensor import (
40+ GemmInputRole ,
41+ LinearMMConfig ,
42+ ScaledMMConfig ,
43+ )
4044from torchao .testing .training .test_utils import get_test_float8_linear_config
4145
4246
@@ -238,7 +242,7 @@ def forward(self, x):
238242 "CUDA with capability 9.0 or greater not available" ,
239243 )
240244 def test_float8_with_graph_break_in_the_middle (self ):
241- """Test that having Float8Tensor object at the boundary of a subgraph"""
245+ """Test that having a Float8TrainingTensor object at the boundary of a subgraph works"""
242246 cnts = CompileCounterWithBackend ("inductor" )
243247 mod = self .MockLinear (graph_break = True ).cuda ()
244248 compiled_mod = copy .deepcopy (mod )
@@ -254,7 +258,7 @@ def test_float8_with_graph_break_in_the_middle(self):
254258 "CUDA with float8 support not available" ,
255259 )
256260 def test_float8_graph_input (self ):
257- """Test that having Float8Tensor object as a graph input"""
261+ """Test that having a Float8TrainingTensor object as a graph input works"""
258262
259263 def to_float (x ):
260264 return x .to_original_precision ()
@@ -278,7 +282,7 @@ def to_float(x):
278282 "CUDA with float8 support not available" ,
279283 )
280284 def test_float8_graph_output (self ):
281- """Test that having Float8Tensor object as a graph output works"""
285+ """Test that having Float8TrainingTensor object as a graph output works"""
282286 cnts = CompileCounterWithBackend ("inductor" )
283287 mod = self .MockLinear (graph_break = False ).cuda ()
284288 compiled_mod = torch .compile (mod , backend = cnts )
@@ -290,14 +294,14 @@ def test_float8_graph_output(self):
290294 for tensor in tensors :
291295 assert not isinstance (
292296 getattr (y_compiled , tensor ), torch ._subclasses .fake_tensor .FakeTensor
293- ), "Float8Tensor should not contain any FakeTensors!"
297+ ), "Float8TrainingTensor should not contain any FakeTensors!"
294298 assert isinstance (y_compiled ._orig_dtype , torch .dtype ), (
295- "Float8Tensor ._orig_dtype should be a dtype but got {}" .format (
299+ "Float8TrainingTensor ._orig_dtype should be a dtype but got {}" .format (
296300 type (y_compiled ._orig_dtype )
297301 )
298302 )
299303 assert isinstance (y_compiled ._linear_mm_config .output .emulate , bool ), (
300- "Float8Tensor ._emulate should be a bool but got {}" .format (
304+ "Float8TrainingTensor ._emulate should be a bool but got {}" .format (
301305 type (y_compiled ._linear_mm_config .output .emulate )
302306 )
303307 )
0 commit comments