Make training scripts simpler
Created by: RasmusOrsoe
In `/examples/train_model.py`, the minimal working example (MWE) of a training script is 150 lines of code. If we introduce a default configuration for each model (`model.config`) and a `Session` class, we could simplify this greatly. Here is some conceptual code:
from graphnet.models.training.utils import TrainingSession, InferenceSession, make_dataloaders
from graphnet.models.gnn import DynEdge
from graphnet.components.loss_functions import LogCoshLoss
from graphnet.models.task.reconstruction import EnergyReconstruction
from graphnet.models.detector.icecube import IceCubeDeepCore
# Main function definition
def main(model_config=None):
    """Train a DynEdge model, then run inference on the validation and test sets.

    Conceptual sketch of a simplified training script. It is built on
    per-model default configurations (``model.config``) and on
    training/inference session objects.

    Args:
        model_config: Optional configuration that replaces ``model.config``
            wholesale. When ``None``, an energy-reconstruction setup using
            ``LogCoshLoss`` on the ``SRTTWOfflinePulsesDC`` pulsemap is
            configured instead.

    Returns:
        ``(validation_results, test_results)`` if the trained model reports
        ``converged``, otherwise ``None``.
    """
    archive = "/groups/icecube/asogaard/gnn/results/"
    selection = None
    model = DynEdge()

    # Configuration
    if model_config is None:
        model.config['loss_function'] = LogCoshLoss
        model.config['task'] = EnergyReconstruction
        model.config['data_path'] = "/groups/icecube/asogaard/data/sqlite/dev_lvl7_robustness_muon_neutrino_0000/data/dev_lvl7_robustness_muon_neutrino_0000.db"
        model.config['pulsemap'] = 'SRTTWOfflinePulsesDC'
        model.config['detector'] = IceCubeDeepCore
        model.config['selection'] = selection
    else:
        model.config = model_config

    # The run name depends on the configuration, so build it only after
    # `model.config` has been set.
    # NOTE(review): assumes `model.config` is dict-like with an optional
    # "target" key; "energy" matches the default EnergyReconstruction task.
    run_name = "dynedge_{}_example".format(model.config.get('target', 'energy'))

    # Make DataLoaders
    train_dataloader, valid_dataloader, test_dataloader = make_dataloaders(model)

    # Setup Sessions (both are bound to `archive` and `run_name` here, so
    # those values do not need to be passed again to `start`).
    training_session = TrainingSession(model, archive, run_name)
    inference_session = InferenceSession(model, archive, run_name)

    # Training session
    trained_model = training_session.start(
        train_dataloader=train_dataloader,
        validation_dataloader=valid_dataloader,
    )

    # Inference and saving only make sense for a converged model; otherwise
    # the result variables would never be bound.
    if not trained_model.converged:
        return None

    validation_results = inference_session.start(
        model=trained_model, dataloader=valid_dataloader, tag='validation'
    )
    # Test inference must use the held-out test split, not the validation one.
    test_results = inference_session.start(
        model=trained_model, dataloader=test_dataloader, tag='test'
    )

    # Save results to csv.
    # NOTE(review): `save_results` is not among the imports shown; confirm
    # where it is defined and import it.
    save_results(validation_results, trained_model, run_name, archive)
    save_results(test_results, trained_model, run_name, archive)

    return validation_results, test_results
# Script entry point: run the training pipeline only when executed directly,
# not when this module is imported.
if "__main__" == __name__:
    main()