|
152 | 152 | parser.add_argument('--ss-weight', type=float, default=0.25, |
153 | 153 | help='Weight for secondary structure loss (default: 0.25)') |
154 | 154 |
|
| 155 | +# Tensor Core precision |
| 156 | +parser.add_argument('--tensor-core-precision', type=str, default='high', |
| 157 | + choices=['highest', 'high', 'medium'], |
| 158 | + help='Float32 matrix multiplication precision for Tensor Cores (default: high)') |
| 159 | + |
155 | 160 | # Print an overview of the arguments and example command if no arguments provided |
156 | 161 | if len(sys.argv) == 1: |
157 | 162 | print('No arguments provided. Use -h for help.') |
|
200 | 205 | torch.backends.cudnn.deterministic = True |
201 | 206 | torch.backends.cudnn.benchmark = False |
202 | 207 |
|
| 208 | +# Set tensor core precision for better performance on modern GPUs |
| 209 | +if torch.cuda.is_available(): |
| 210 | + torch.set_float32_matmul_precision(args.tensor_core_precision) |
| 211 | + print(f"Tensor Core precision set to: {args.tensor_core_precision}") |
| 212 | + |
203 | 213 | if args.EMA: |
204 | 214 | print("Using Exponential Moving Average for encoder codebook") |
205 | 215 | else: |
|
278 | 288 |
|
279 | 289 | if os.path.exists(args.output_dir) and args.overwrite: |
280 | 290 | #remove existing model |
281 | | - if os.path.exists(os.path.join(args.output_dir, args.model_name + '_best.pkl')): |
282 | | - os.remove(os.path.join(args.output_dir, args.model_name + '_best.pkl')) |
| 291 | + if os.path.exists(os.path.join(args.output_dir, args.model_name + '_best_encoder.pt')): |
| 292 | + os.remove(os.path.join(args.output_dir, args.model_name + '_best_encoder.pt')) |
| 293 | + if os.path.exists(os.path.join(args.output_dir, args.model_name + '_best_decoder.pt')): |
| 294 | + os.remove(os.path.join(args.output_dir, args.model_name + '_best_decoder.pt')) |
283 | 295 |
|
284 | 296 | # Data setup |
285 | 297 | datadir = '../../datasets/foldtree2/' |
|
334 | 346 |
|
335 | 347 |
|
336 | 348 | # Initialize or load model |
337 | | -if os.path.exists(os.path.join(modeldir, modelname + '_best.pkl')) and args.overwrite == False: |
338 | | - print(f"Loading existing model from {os.path.join(modeldir, modelname + '_best.pkl')}") |
| 349 | +encoder_path = os.path.join(modeldir, modelname + '_best_encoder.pt') |
| 350 | +decoder_path = os.path.join(modeldir, modelname + '_best_decoder.pt') |
| 351 | +if os.path.exists(encoder_path) and os.path.exists(decoder_path) and args.overwrite == False: |
| 352 | + print(f"Loading existing model from {encoder_path} and {decoder_path}") |
339 | 353 | if os.path.exists(os.path.join(modeldir, modelname + '_info.txt')): |
340 | 354 | with open(os.path.join(modeldir, modelname + '_info.txt'), 'r') as f: |
341 | 355 | model_info = f.read() |
342 | 356 | print("Model info:", model_info) |
343 | 357 | # Load encoder and decoder from saved model |
344 | | - with open(os.path.join(modeldir, modelname + '_best.pkl'), 'rb') as f: |
345 | | - encoder, decoder = pickle.load(f) |
| 358 | + encoder = torch.load(encoder_path, map_location=device, weights_only=False) |
| 359 | + decoder = torch.load(decoder_path, map_location=device, weights_only=False) |
346 | 360 | else: |
347 | 361 | print("Creating new model...") |
348 | 362 | # Model setup |
@@ -934,6 +948,6 @@ def analyze_gradient_norms(model, top_k=3): |
934 | 948 | # Close TensorBoard writer |
935 | 949 | writer.close() |
936 | 950 |
|
937 | | -print(f"Training complete! Final model saved to {os.path.join(modeldir, modelname + '.pkl')}") |
| 951 | +print(f"Training complete! Final model saved to {os.path.join(modeldir, modelname + '_encoder_final.pt')} and {os.path.join(modeldir, modelname + '_decoder_final.pt')}") |
938 | 952 | print(f"TensorBoard logs saved to: {tensorboard_log_dir}") |
939 | 953 | print(f"View with: tensorboard --logdir={args.tensorboard_dir}") |
0 commit comments