
Enable inverse transform of predictions at inference-time

Created by: asogaard

Regressing targets that have highly non-Gaussian distributions, and/or whose numerical scales in their natural units are very far from O(1), is a common problem (e.g. energy and interaction time). This can be solved in (at least) two ways:

  1. By having the model predict the target in its natural units and then, within the loss function and only for the purposes of numerically stable training, transforming both prediction and target into a space with characteristic scales of O(1). This is what the transform_prediction_and_target callable argument does here.
  2. By transforming the target into a space with characteristic scales of O(1); having the model predict this quantity; and then performing the inverse transform of the predictions at inference-time to obtain the actual physical quantity of interest. (A minimal sketch of both options follows this list.)
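
To make the two options concrete, here is a minimal sketch assuming a typical case: regressing an energy that spans several orders of magnitude, using a log10 transform. The variable names simply mirror the callable arguments mentioned above; the example itself is illustrative and not taken from the codebase.

```python
import torch

# Illustrative example (not from the codebase): regressing an energy whose
# natural scale spans many orders of magnitude; log10 maps it to roughly O(1).

# Option (1): predictions and targets stay in natural units; both are
# transformed inside the loss function only. This is the role of the
# `transform_prediction_and_target` callable.
transform_prediction_and_target = torch.log10

# Option (2), first half: train the model directly on the transformed target.
transform_target = torch.log10

# Option (2), second half (currently missing): invert the transform on the
# model outputs at inference-time to recover the physical quantity.
def inference_transform(x: torch.Tensor) -> torch.Tensor:
    return torch.pow(10.0, x)
```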

Currently, the transform_target callable argument here covers the first half of option (2.) above, but in order to fully enable this option we also need:

  1. an argument like inference_transform to Model here that, if specified, is called on the model outputs when an inference flag is set.
  2. such an inference flag, e.g. via an .inference() method on Model, analogous to (and interplaying with) the .train() and .eval() methods of torch.nn.Module, cf. here. (A rough sketch of how these two pieces could fit together follows this list.)
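
Below is a rough sketch of how these two pieces could look together. The class structure, the attribute names (_inference, _inference_transform) and the _forward placeholder are assumptions made for illustration only, not an actual implementation from the codebase.

```python
from typing import Callable, Optional

import torch
from torch import nn


class Model(nn.Module):
    """Minimal sketch of the proposed API; all names are illustrative only."""

    def __init__(
        self,
        inference_transform: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
    ) -> None:
        super().__init__()
        self._inference_transform = inference_transform
        self._inference = False  # proposed flag, analogous to `self.training`

    def inference(self, mode: bool = True) -> "Model":
        """Set the inference flag, mirroring `.train()` / `.eval()`."""
        self._inference = mode
        if mode:
            self.eval()  # inference implies evaluation mode
        return self

    def train(self, mode: bool = True) -> "Model":
        """Switching back to training mode clears the inference flag."""
        if mode:
            self._inference = False
        return super().train(mode)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        outputs = self._forward(x)  # the model's usual prediction step
        # Only apply the inverse transform when explicitly in inference mode,
        # so that training still happens in the O(1)-scaled space.
        if self._inference and self._inference_transform is not None:
            outputs = self._inference_transform(outputs)
        return outputs

    def _forward(self, x: torch.Tensor) -> torch.Tensor:  # placeholder
        raise NotImplementedError
```

With something like this in place, a user would call model.inference() before predicting to get outputs in physical units (e.g. 10**prediction for a log10-transformed energy), while model.train() keeps training in the O(1)-scaled space.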