22from tqdm import tqdm
33import time
44
5- from KD_Lib .common import BaseClass
5+ from KD_Lib .KD . common import BaseClass
66
77
8- class Pipeline () :
8+ class Pipeline :
99 """
1010 Pipeline of knowledge distillation, pruning and quantization methods
1111 supported by KD_Lib. Sequentially applies a list of methods on the student model.
12-
12+
1313 All the elements in list must implement either train_student, prune or quantize
1414 methods.
1515
1616 :param: steps (list) list of KD_Lib.KD or KD_Lib.Pruning or KD_Lib.Quantization
1717 :param: epochs (int) number of iterations through whole batch for each method in
18- list
18+ list
1919 :param: plot_losses (bool) Plot a graph of losses during training
2020 :param: save_model (bool) Save model after performing the list methods
2121 :param: save_model_pth (str) Path where model is saved if save_model is True
2222 :param: verbose (int) Verbose
2323 """
24+
2425 def __init__ (
25- self ,
26- steps ,
27- epochs = 5 ,
28- plot_losses = True ,
29- save_model = True ,
30- save_model_pth = "./models/student.pt" ,
31- verbose = 0 ):
26+ self ,
27+ steps ,
28+ epochs = 5 ,
29+ plot_losses = True ,
30+ save_model = True ,
31+ save_model_pth = "./models/student.pt" ,
32+ verbose = 0 ,
33+ ):
3234 self .steps = steps
3335 self .device = device
3436 self .verbose = verbose
@@ -43,10 +45,12 @@ def _validate_steps(self):
4345 name , process = zip (* self .steps )
4446
4547 for t in process :
46- if (not hasattr (t , ('train_student' , 'prune' , 'quantize' ))):
47- raise TypeError ("All the steps must support at least one of "
48- "train_student, prune or quantize method, {} is not"
49- " supported yet" .format (str (t )))
48+ if not hasattr (t , ("train_student" , "prune" , "quantize" )):
49+ raise TypeError (
50+ "All the steps must support at least one of "
51+ "train_student, prune or quantize method, {} is not"
52+ " supported yet" .format (str (t ))
53+ )
5054
5155 def get_steps (self ):
5256 return self .steps
@@ -65,38 +69,49 @@ def _fit(self):
6569 for idx , name , process in self ._iter ():
6670 print ("Starting {}" .format (name ))
6771 if idx != 0 :
68- if hasattr (process , ' train_student' ):
69- if hasattr (self .steps [idx - 1 ], ' train_student' ):
70- process .student_model = self .steps [idx - 1 ].student_model
72+ if hasattr (process , " train_student" ):
73+ if hasattr (self .steps [idx - 1 ], " train_student" ):
74+ process .student_model = self .steps [idx - 1 ].student_model
7175 else :
72- process .student_model = self .steps [idx - 1 ].model
76+ process .student_model = self .steps [idx - 1 ].model
7377 t1 = time .time ()
74- if hasattr (process , 'train_student' ):
75- process .train_student (self .epochs , self .plot_losses , self .save_model , self .save_model_path )
76- elif hasattr (proces , 'prune' ):
78+ if hasattr (process , "train_student" ):
79+ process .train_student (
80+ self .epochs , self .plot_losses , self .save_model , self .save_model_path
81+ )
82+ elif hasattr (proces , "prune" ):
7783 process .prune ()
78- elif hasattr (process , ' quantize' ):
84+ elif hasattr (process , " quantize" ):
7985 process .quantize ()
8086 else :
81- raise TypeError ("{} is not supported by the pipeline yet."
82- .format (process ))
87+ raise TypeError (
88+ "{} is not supported by the pipeline yet." .format (process )
89+ )
8390
8491 t2 = time .time () - t1
85- print ("{} completed in {}hr {}min {}s" .format (name , t2 // (60 * 60 ), t2 // 60 , t2 % 60 )
86-
92+ print (
93+ "{} completed in {}hr {}min {}s" .format (
94+ name , t2 // (60 * 60 ), t2 // 60 , t2 % 60
95+ )
96+ )
97+
8798 if self .verbose :
8899 pbar .update (1 )
89-
100+
90101 if self .verbose :
91102 pbar .close ()
92103
93104 def train (self ):
94105 """
95- Train the (student) model sequentially through the list.
106+ Train the (student) model sequentially through the list.
96107 """
97108 self ._validate_steps ()
98109
99110 t1 = time .time ()
100111 self ._fit ()
101112 t2 = time .time () - t1
102- print ("Pipeline execution completed in {}hr {}min {}s" .format (t2 // (60 * 60 ), t2 // 60 , t2 % 60 )
113+ print (
114+ "Pipeline execution completed in {}hr {}min {}s" .format (
115+ t2 // (60 * 60 ), t2 // 60 , t2 % 60
116+ )
117+ )
0 commit comments