Skip to content

Make training scripts simpler

Created by: RasmusOrsoe

In `/examples/train_model.py`, the minimal working example (MWE) of a training script is 150 lines of code. If we introduce a default configuration for each model (`model.config`) and a `Session` class, we could simplify this greatly. Here is some conceptual code:

from graphnet.models.training.utils import TrainingSession, InferenceSession, make_dataloaders
from graphnet.models.gnn import DynEdge
from graphnet.components.loss_functions import LogCoshLoss
from graphnet.models.task.reconstruction import EnergyReconstruction
from graphnet.models.detector.icecube import IceCubeDeepCore

# Main function definition
def main(model_config=None, target="energy"):
    """Train a DynEdge model, then save its validation and test predictions.

    Args:
        model_config: Optional replacement for ``model.config``. When ``None``,
            a default energy-reconstruction configuration is filled in below.
        target: Name of the reconstruction target. It is only used to build
            the run name. It defaults to ``"energy"`` to match the default
            ``EnergyReconstruction`` task.
    """
    # The snippet called `save_results` without importing it.
    # NOTE(review): assumed to live next to the other training utilities in
    # this module; confirm its location and signature once the Session API
    # exists.
    from graphnet.models.training.utils import save_results

    archive = "/groups/icecube/asogaard/gnn/results/"
    # `config` was never defined here; build the run name from `target`.
    run_name = "dynedge_{}_example".format(target)
    selection = None

    model = DynEdge()
    # Configuration
    if model_config is None:
        model.config["loss_function"] = LogCoshLoss
        model.config["task"] = EnergyReconstruction
        model.config["data_path"] = "/groups/icecube/asogaard/data/sqlite/dev_lvl7_robustness_muon_neutrino_0000/data/dev_lvl7_robustness_muon_neutrino_0000.db"
        model.config["pulsemap"] = "SRTTWOfflinePulsesDC"
        model.config["detector"] = IceCubeDeepCore
        model.config["selection"] = selection
    else:
        model.config = model_config

    # Make DataLoaders
    train_dataloader, valid_dataloader, test_dataloader = make_dataloaders(model)

    # Setup Sessions. Both are bound to `archive` and `run_name` here, so the
    # `start` calls below do not repeat them.
    training_session = TrainingSession(model, archive, run_name)
    inference_session = InferenceSession(model, archive, run_name)

    # Training session
    trained_model = training_session.start(
        train_dataloader=train_dataloader,
        validation_dataloader=valid_dataloader,
    )

    # Inference only makes sense for a converged model. Without it there are
    # no results to save, so stop here instead of touching unbound names.
    if not trained_model.converged:
        return

    # Inference sessions: each split is evaluated on its own dataloader.
    validation_results = inference_session.start(
        model=trained_model, dataloader=valid_dataloader, tag="validation"
    )
    test_results = inference_session.start(
        model=trained_model, dataloader=test_dataloader, tag="test"
    )

    # Save results to csv
    save_results(validation_results, trained_model, run_name, archive)
    save_results(test_results, trained_model, run_name, archive)


# Script entry point: run the example with its built-in default configuration.
if __name__ == "__main__":
    main(model_config=None)