Skip to content

Commit fd1d878

Browse files
author
dmoi
committed
fix saving & reload in learning scripts
1 parent ee2a67b commit fd1d878

File tree

7 files changed

+2131
-1747
lines changed

7 files changed

+2131
-1747
lines changed

config_notebook_1k_epochs.yaml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@ run_name: "notebook_replication_1k_epochs"
99

1010
# Training hyperparameters (from notebook)
1111
epochs: 1000
12-
batch_size: 20
12+
batch_size: 10
1313
gradient_accumulation_steps: 1
1414
seed: 0
1515

1616
# Model architecture (from notebook)
17-
hidden_size: 100
17+
hidden_size: 200
1818
num_embeddings: 40
1919
embedding_dim: 128
2020

@@ -71,6 +71,8 @@ se3_transformer: false
7171
output_fft: false
7272
output_rt: false
7373

74+
gpus: 2
75+
7476
# Notes:
7577
# - The notebook uses loss weight schedulers (currently commented out in training loop)
7678
# - To enable loss weight scheduling in the future, you would need to add that feature

foldtree2/learn_lightning.py

Lines changed: 406 additions & 445 deletions
Large diffs are not rendered by default.

foldtree2/learn_monodecoder.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,11 @@
152152
parser.add_argument('--ss-weight', type=float, default=0.25,
153153
help='Weight for secondary structure loss (default: 0.25)')
154154

155+
# Tensor Core precision
156+
parser.add_argument('--tensor-core-precision', type=str, default='high',
157+
choices=['highest', 'high', 'medium'],
158+
help='Float32 matrix multiplication precision for Tensor Cores (default: high)')
159+
155160
# Print an overview of the arguments and example command if no arguments provided
156161
if len(sys.argv) == 1:
157162
print('No arguments provided. Use -h for help.')
@@ -200,6 +205,11 @@
200205
torch.backends.cudnn.deterministic = True
201206
torch.backends.cudnn.benchmark = False
202207

208+
# Set tensor core precision for better performance on modern GPUs
209+
if torch.cuda.is_available():
210+
torch.set_float32_matmul_precision(args.tensor_core_precision)
211+
print(f"Tensor Core precision set to: {args.tensor_core_precision}")
212+
203213
if args.EMA:
204214
print("Using Exponential Moving Average for encoder codebook")
205215
else:
@@ -278,8 +288,10 @@
278288

279289
if os.path.exists(args.output_dir) and args.overwrite:
280290
# Remove existing best encoder/decoder checkpoints
281-
if os.path.exists(os.path.join(args.output_dir, args.model_name + '_best.pkl')):
282-
os.remove(os.path.join(args.output_dir, args.model_name + '_best.pkl'))
291+
if os.path.exists(os.path.join(args.output_dir, args.model_name + '_best_encoder.pt')):
292+
os.remove(os.path.join(args.output_dir, args.model_name + '_best_encoder.pt'))
293+
if os.path.exists(os.path.join(args.output_dir, args.model_name + '_best_decoder.pt')):
294+
os.remove(os.path.join(args.output_dir, args.model_name + '_best_decoder.pt'))
283295

284296
# Data setup
285297
datadir = '../../datasets/foldtree2/'
@@ -334,15 +346,17 @@
334346

335347

336348
# Initialize or load model
337-
if os.path.exists(os.path.join(modeldir, modelname + '_best.pkl')) and args.overwrite == False:
338-
print(f"Loading existing model from {os.path.join(modeldir, modelname + '_best.pkl')}")
349+
encoder_path = os.path.join(modeldir, modelname + '_best_encoder.pt')
350+
decoder_path = os.path.join(modeldir, modelname + '_best_decoder.pt')
351+
if os.path.exists(encoder_path) and os.path.exists(decoder_path) and args.overwrite == False:
352+
print(f"Loading existing model from {encoder_path} and {decoder_path}")
339353
if os.path.exists(os.path.join(modeldir, modelname + '_info.txt')):
340354
with open(os.path.join(modeldir, modelname + '_info.txt'), 'r') as f:
341355
model_info = f.read()
342356
print("Model info:", model_info)
343357
# Load encoder and decoder from saved model
344-
with open(os.path.join(modeldir, modelname + '_best.pkl'), 'rb') as f:
345-
encoder, decoder = pickle.load(f)
358+
encoder = torch.load(encoder_path, map_location=device, weights_only=False)
359+
decoder = torch.load(decoder_path, map_location=device, weights_only=False)
346360
else:
347361
print("Creating new model...")
348362
# Model setup
@@ -934,6 +948,6 @@ def analyze_gradient_norms(model, top_k=3):
934948
# Close TensorBoard writer
935949
writer.close()
936950

937-
print(f"Training complete! Final model saved to {os.path.join(modeldir, modelname + '.pkl')}")
951+
print(f"Training complete! Final model saved to {os.path.join(modeldir, modelname + '_encoder_final.pt')} and {os.path.join(modeldir, modelname + '_decoder_final.pt')}")
938952
print(f"TensorBoard logs saved to: {tensorboard_log_dir}")
939953
print(f"View with: tensorboard --logdir={args.tensorboard_dir}")

0 commit comments

Comments
 (0)