diff --git a/src/graphnet/models/model.py b/src/graphnet/models/model.py index 00acf910910712628a3b6f2936b8952c0c80cda4..04c7454422bb7bab585306595d7d6ba5331841b1 100644 --- a/src/graphnet/models/model.py +++ b/src/graphnet/models/model.py @@ -186,6 +186,7 @@ class Model(Logger, Configurable, LightningModule, ABC): index_column: str = "event_no", gpus: Optional[Union[List[int], int]] = None, distribution_strategy: Optional[str] = "auto", + placeholder: str = "untitled_column", ) -> pd.DataFrame: """Return predictions for `dataloader` as a DataFrame. @@ -211,7 +212,6 @@ class Model(Logger, Configurable, LightningModule, ABC): "doesn't resample batches; or do not request " "`additional_attributes`." ) - self.info(f"Column names for predictions are: \n {prediction_columns}") predictions_torch = self.predict( dataloader=dataloader, gpus=gpus, @@ -220,10 +220,12 @@ class Model(Logger, Configurable, LightningModule, ABC): predictions = ( torch.cat(predictions_torch, dim=1).detach().cpu().numpy() ) - assert len(prediction_columns) == predictions.shape[1], ( - f"Number of provided column names ({len(prediction_columns)}) and " - f"number of output columns ({predictions.shape[1]}) don't match." + prediction_columns = self._dataframe_safeguard( + prediction_columns, + out_shape=predictions.shape[1], + placeholder=placeholder, ) + self.info(f"Column names for predictions are: \n {prediction_columns}") # Get additional attributes attributes: Dict[str, List[np.ndarray]] = OrderedDict( @@ -317,3 +319,44 @@ class Model(Logger, Configurable, LightningModule, ABC): ), f"Argument `source` of type ({type(source)}) is not a `ModelConfig" return source._construct_model(trust, load_modules) + + def _dataframe_safeguard( + self, + prediction_columns: List[str], + out_shape: int, + placeholder: str = "untitled_column", + ) -> List[str]: + """Make prediction_columns have the correct length to create dataframe. + + Arguments: + prediction_columns: Prediction_columns from _predict_as_dataframe. + out_shape: Number of output predictions at final layer. + placeholder: Name of additional columns used to fill up prediction_columns. + + Returns: A list of column names of length out_shape. + """ + prediction_columns_deficit = len(prediction_columns) - out_shape + common_error_message = f""" + Dimensional mismatch between number prediction columns + ({len(prediction_columns)}) and the output shape ({out_shape}) + """ + if prediction_columns_deficit == 0: + return prediction_columns + corrected_prediction_columns: List[str] = prediction_columns + if prediction_columns_deficit > 0: + additional_error_message = f""" + Appending {prediction_columns_deficit} columns titled + {placeholder}_idx + """ + for idx in range(prediction_columns_deficit): + corrected_prediction_columns.append(f"{placeholder}{idx}") + else: + additional_error_message = f""" + Only using the first {out_shape} columns + {prediction_columns[:out_shape]} + """ + corrected_prediction_columns = corrected_prediction_columns[ + :out_shape + ] + self.warning(f"{common_error_message}\n{additional_error_message}") + return corrected_prediction_columns