
Make prediction column names self-contained

Created by: RasmusOrsoe

Hi

When making predictions, the repo currently leans towards this approach:

```python
predictions = get_predictions(trainer, model, validation_dataloader, [target + '_pred'])
```

This means that anyone using the trained model must know the column names of the model's output in advance.

I'd like to suggest that we instead move to a self-contained approach that would allow syntax of the form:

```python
predictions = get_predictions(trainer, model, validation_dataloader, model.output_column_names)
```
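A minimal, framework-free sketch of what "self-contained" could mean here: the model derives its own output column names from its target when it is constructed, so the caller just passes `model.output_column_names` along. `ToyModel` and `get_predictions_sketch` are hypothetical stand-ins (not graphnet's `Model` or `get_predictions`), and the `target + "_pred"` naming rule is simply carried over from the current convention above.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Sequence


@dataclass
class ToyModel:
    """Hypothetical stand-in for a trained model; not the graphnet Model class."""

    target: str
    # Derived at construction time, so users of the trained model
    # never need to know or guess the output column names.
    output_column_names: List[str] = field(init=False)

    def __post_init__(self) -> None:
        self.output_column_names = [self.target + "_pred"]

    def predict(self, batch: Sequence[float]) -> List[List[float]]:
        # Placeholder "inference": one output value per event.
        return [[2.0 * x] for x in batch]


def get_predictions_sketch(
    model: ToyModel,
    batches: Sequence[Sequence[float]],
    column_names: Sequence[str],
) -> Dict[str, List[float]]:
    """Hypothetical helper mimicking the call signature discussed above."""
    columns: Dict[str, List[float]] = {name: [] for name in column_names}
    for batch in batches:
        for row in model.predict(batch):
            # Guard against a mismatch between model outputs and column names.
            if len(row) != len(column_names):
                raise ValueError("number of model outputs does not match column names")
            for name, value in zip(column_names, row):
                columns[name].append(value)
    return columns


model = ToyModel(target="energy")
predictions = get_predictions_sketch(model, [[1.0, 2.0], [3.0]], model.output_column_names)
print(predictions)  # {'energy_pred': [2.0, 4.0, 6.0]}
```

Because the names live on the model, renaming the target (or adding outputs) only requires changing the model, not every script that calls the prediction helper.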

In addition, I think it would be good for portability to expose attributes on trained models such as:

```python
model_predicts_this_target = model.target
model_was_trained_on_this_pulsemap_and_therefore_requires_it_for_predictions = model.pulsemap
```
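One way these attributes could aid portability, sketched under assumptions: store `target` and `pulsemap` as small metadata that is serialized alongside the weights, restore it on load, and check the pulsemap before predicting. `ModelMetadata`, `PortableModel`, `metadata_to_json`, `from_metadata_json` and `require_pulsemap` are all hypothetical names, and `"SRTInIcePulses"` is only an example pulsemap name; none of this is an existing graphnet API.

```python
import json
from dataclasses import asdict, dataclass
from typing import Iterable


@dataclass
class ModelMetadata:
    """Hypothetical metadata bundled with a trained model."""

    target: str
    pulsemap: str


class PortableModel:
    """Hypothetical model wrapper exposing `target` and `pulsemap`."""

    def __init__(self, target: str, pulsemap: str) -> None:
        self.metadata = ModelMetadata(target=target, pulsemap=pulsemap)

    @property
    def target(self) -> str:
        return self.metadata.target

    @property
    def pulsemap(self) -> str:
        return self.metadata.pulsemap

    def metadata_to_json(self) -> str:
        # Serialized next to the weights so the information travels with the model.
        return json.dumps(asdict(self.metadata))

    @classmethod
    def from_metadata_json(cls, text: str) -> "PortableModel":
        fields = json.loads(text)
        return cls(target=fields["target"], pulsemap=fields["pulsemap"])

    def require_pulsemap(self, available_pulsemaps: Iterable[str]) -> None:
        # Fail early if the data lacks the pulsemap the model was trained on.
        if self.pulsemap not in set(available_pulsemaps):
            raise ValueError(
                f"model was trained on pulsemap {self.pulsemap!r}, which is not available"
            )


saved = PortableModel(target="energy", pulsemap="SRTInIcePulses").metadata_to_json()
restored = PortableModel.from_metadata_json(saved)
print(restored.target, restored.pulsemap)  # energy SRTInIcePulses
restored.require_pulsemap(["SRTInIcePulses", "InIcePulses"])
```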

Rasmus