Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 21 additions & 29 deletions scripts/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import pdb
import csv

#Constants
# Constants
DATE_FORMAT_STR = "%Y-%m-%d:%H-%M-%S"

if __name__ == '__main__':
Expand All @@ -28,23 +28,23 @@
repo = git.Repo(search_parent_directories=True)
commit = repo.head.object
args.commit = commit.hexsha
print("OncoNet main running from commit: \n\n{}\n{}author: {}, date: {}".format(
commit.hexsha, commit.message, commit.author, commit.committed_date))
print(f"OncoNet main running from commit: {commit.hexsha}\n{commit.message}\nauthor: {commit.author}, date: {commit.committed_date}")

if args.get_dataset_stats:
print("\nComputing image mean and std...")
args.img_mean, args.img_std = get_dataset_stats(args)
print('Mean: {}'.format(args.img_mean))
print('Std: {}'.format(args.img_std))
print(f'Mean: {args.img_mean}\nStd: {args.img_std}')

print("\nLoading data-augmentation scheme...")
transformers = transformer_factory.get_transformers(
args.image_transformers, args.tensor_transformers, args)
test_transformers = transformer_factory.get_transformers(
args.test_image_transformers, args.test_tensor_transformers, args)

# Load dataset and add dataset specific information to args
print("\nLoading data...")
train_data, dev_data, test_data = dataset_factory.get_dataset(args, transformers, test_transformers)

# Load model and add model specific information to args
if args.snapshot is None:
model = model_factory.get_model(args)
Expand All @@ -55,11 +55,11 @@
model._model.pool = non_trained_model._model.pool
model._model.args = non_trained_model._model.args


print(model)

# Load run parameters if resuming that run.
args.model_path = state.get_model_path(args)
print('Trained model will be saved to [%s]' % args.model_path)
print(f'Trained model will be saved to [{args.model_path}]')
if args.resume:
try:
state_keeper = state.StateKeeper(args)
Expand All @@ -68,20 +68,15 @@
args.current_epoch = epoch
args.lr = lr
args.epoch_stats = epoch_stats
except:
args.optimizer_state = None
args.current_epoch = None
args.lr = None
args.epoch_stats = None
print("\n Error loading previous state. \n Starting run from scratch.")
except Exception as e:
print(f"\nError loading previous state: {e}\nStarting run from scratch.")
else:
print("\n Restarting run from scratch.")

print("\nRestarting run from scratch.")

print("\nParameters:")
for attr, value in sorted(args.__dict__.items()):
if attr not in ['optimizer_state', 'patient_to_partition_dict', 'path_to_hidden_dict', 'exam_to_year_dict', 'exam_to_device_dict']:
print("\t{}={}".format(attr.upper(), value))
print(f"\t{attr.upper()}={value}")

save_path = args.results_path
print()
Expand All @@ -91,54 +86,51 @@

if args.plot_losses:
visualize.viz_utils.plot_losses(epoch_stats)
print("Save train/dev results to {}".format(save_path))
print(f"Save train/dev results to {save_path}")
args_dict = vars(args)
pickle.dump(args_dict, open(save_path, 'wb'))

print()
if args.dev:
print("-------------\nDev")
args.dev_stats = train.compute_threshold_and_dev_stats(dev_data, model, args)
print("Save dev results to {}".format(save_path))
print(f"Save dev results to {save_path}")
args_dict = vars(args)
pickle.dump(args_dict, open(save_path, 'wb'))

if args.test:

print("-------------\nTest")
args.test_stats = train.eval_model(test_data, model, args)
print("Save test results to {}".format(save_path))
print(f"Save test results to {save_path}")
args_dict = vars(args)
pickle.dump(args_dict, open(save_path, 'wb'))

if (args.dev or args.test) and args.prediction_save_path is not None:
exams, probs = [], []
if args.dev:
exams.extend( args.dev_stats['exams'])
probs.extend( args.dev_stats['probs'])
exams.extend(args.dev_stats['exams'])
probs.extend(args.dev_stats['probs'])
if args.test:
exams.extend( args.test_stats['exams'])
probs.extend( args.test_stats['probs'])
exams.extend(args.test_stats['exams'])
probs.extend(args.test_stats['probs'])
legend = ['patient_exam_id']
if args.callibrator_snapshot is not None:
callibrator = pickle.load(open(args.callibrator_snapshot,'rb'))
for i in range(args.max_followup):
legend.append("{}_year_risk".format(i+1))
legend.append(f"{i+1}_year_risk")
export = {}
with open(args.prediction_save_path,'w') as out_file:
writer = csv.DictWriter(out_file, fieldnames=legend)
writer.writeheader()
for exam, arr in zip(exams, probs):
export['patient_exam_id'] = exam
for i in range(args.max_followup):
key = "{}_year_risk".format(i+1)
key = f"{i+1}_year_risk"
raw_val = arr[i]
if args.callibrator_snapshot is not None:
val = callibrator[i].predict_proba([[raw_val]])[0,1]
else:
val = raw_val
export[key] = val
writer.writerow(export)
print("Exported predictions to {}".format(args.prediction_save_path))


print(f"Exported predictions to {args.prediction_save_path}")