Enable inverse transform of predictions at inference-time
Created by: asogaard
Regressing targets that have highly non-Gaussian distributions and/or numerical scales that, in their natural units, are very far from O(1) (e.g. energy and interaction time) is a common problem. This can be solved in (at least) two ways:
1. By having the model predict the target in its natural units and then (within the loss function, and only for the purposes of numerically stable training) transforming both the prediction and the target into a space with characteristic scales of O(1). This is what the transform_prediction_and_target callable argument does here.
2. By transforming the target into a space with characteristic scales of O(1); having the model predict this transformed quantity; and then performing the inverse transform of the predictions at inference time to obtain the actual physical quantity of interest.
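To make the difference between the two approaches concrete, here is a minimal sketch in plain Python. It is not the library's implementation: the function names (transform, inverse_transform, loss_option_1, loss_option_2, predict_option_2), the choice of log10 as the transform, and the squared-error loss are all assumptions made purely for illustration.

```python
import math


# Hypothetical transform pair for a positive, energy-like target.
# log10 is an assumed choice; any invertible map to an O(1) scale would do.
def transform(x: float) -> float:
    """Map a positive quantity in natural units to an O(1) scale."""
    return math.log10(x)


def inverse_transform(y: float) -> float:
    """Undo `transform`, returning to natural units."""
    return 10.0 ** y


def squared_error(a: float, b: float) -> float:
    return (a - b) ** 2


# Option 1: the model predicts in natural units; both the prediction and
# the target are transformed inside the loss only.
def loss_option_1(prediction: float, target: float) -> float:
    return squared_error(transform(prediction), transform(target))


# Option 2: the model predicts in the transformed space; only the target
# is transformed for the loss ...
def loss_option_2(prediction_transformed: float, target: float) -> float:
    return squared_error(prediction_transformed, transform(target))


# ... and the inverse transform is applied to predictions at inference time.
def predict_option_2(prediction_transformed: float) -> float:
    return inverse_transform(prediction_transformed)
```

In this sketch, option 1 needs no extra step at inference time because the model output is already in natural units, whereas option 2 yields a physical quantity only after predict_option_2 has been applied. That inverse step is the piece this issue is about.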
Currently, the transform_target callable argument here does the first half of (2.) above, but in order to enable this second option fully we need: