Adds an improved implementation of DynEdge
Created by: AMHermansen
This PR adds a slightly improved implementation of DynEdge. The model uses a slightly different set of weights than the current implementation, but the two models are mathematically equivalent.
The optimization is based on how the first post-processing step is done. The current implementation appends all "skip-connection" graphs to a list, concatenates the list into a single graph containing all features, and then applies a linear layer. This implementation avoids the append-and-concatenate step by applying the linear layer "piece-wise".
Mathematically this corresponds to (let `W` be the weights of the first linear layer in the post-processing step, `b` the biases, and `g` the concatenated graph):

`W g + b = [W_0 | W_1 | ... | W_n] g + b = W_0 x_0 + W_1 x_1 + ... + W_n x_n + b`

where `x_0` is the input graph, `x_i` (i >= 1) is the output of the i-th DynEdge convolution, and `W_i` is the corresponding column block of `W`. In this implementation, each DynEdge convolution therefore has its own linear layer without biases, and the input graph has a linear layer with biases.
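For illustration, here is a minimal PyTorch sketch (not the actual GraphNeT code; the feature sizes, `hidden_dim`, and variable names are made up) showing that concatenating and applying one `torch.nn.Linear` gives the same output as the piece-wise layers, where only the layer acting on the input graph carries the bias:

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical feature sizes: the input graph plus three convolution outputs.
sizes = [17, 336, 336, 336]   # made-up numbers, not GraphNeT's actual dims
hidden_dim = 128
xs = [torch.randn(1000, s) for s in sizes]  # 1000 nodes

# Current implementation: concatenate, then one Linear with bias.
full = nn.Linear(sum(sizes), hidden_dim)

# Piece-wise implementation: one Linear per block; only the first has a bias.
pieces = nn.ModuleList(
    [nn.Linear(s, hidden_dim, bias=(i == 0)) for i, s in enumerate(sizes)]
)

# Copy the column blocks of the full weight matrix into the piece-wise layers
# so both parametrisations describe the same function.
with torch.no_grad():
    offset = 0
    for layer, s in zip(pieces, sizes):
        layer.weight.copy_(full.weight[:, offset:offset + s])
        offset += s
    pieces[0].bias.copy_(full.bias)

out_concat = full(torch.cat(xs, dim=1))
out_piecewise = sum(layer(x) for layer, x in zip(pieces, xs))

print(torch.allclose(out_concat, out_piecewise, atol=1e-5))  # True, up to float error
```

Since the sum only needs one bias term, the per-convolution layers are created with `bias=False`, so the total number of learnable parameters matches the single concatenated layer.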
I've also included profiling logs from a run with this new implementation and a run with the current implementation; I found a roughly 3% speedup, which admittedly isn't a lot, but is still something.
Merge request reports
Activity
requested review from @jprado
Created by: RasmusOrsoe
Review: Commented
Hi @AMHermansen!
I think it's great that you're thinking along these lines. Could you elaborate on how you arrived at the speed-improvement number?
When I consult the logs you've attached here, I note that the mean duration of the function calls that should appear faster, if this change actually makes the forward pass faster, is in fact slower. I.e. `run_training_batch` - presumably the time it takes to train on a single batch (forward + backward + optimizer step) - appears with a mean duration of 0.023597 and 0.024962 for the original DynEdge model and your modified version, respectively. Given that this is a mean over 120k function calls, I'd think it is statistically significant. I realize the mean epoch duration is lower for your new implementation by around 20 seconds, but that could be due to other circumstances that have nothing to do with your change. I think a more straightforward evaluation of the impact of this change would be to instantiate the two models and call `model(mock_batch)`, where `mock_batch` is a test batch held in memory, 200k times and measure the average call times. If this appears faster, which these logs suggest it shouldn't, then I think we can call this a speed-up.

I'm also not entirely convinced that this is mathematically equivalent to the original implementation. In a fully connected layer, which `torch.nn.Linear` is, each neuron is evaluated on the entire set of columns. The equation you share above is the math for a single neuron. As far as I can tell, you construct a fully connected layer for the input graph and each state graph separately, which means each neural network reasons over a portion of the data and not its entirety. I think that could lead to a difference in performance. The number of learnable parameters should be `tensor_columns*(hidden_dim + 1)`. Have you counted the entries in the state dict and seen they agree 1:1?
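To make those two checks concrete, here is a rough sketch of the measurements suggested above. The model construction is only indicated in comments, since the class names and constructor arguments of the two variants are placeholders and would need to be filled in:

```python
import time
import torch


def mean_forward_time(model, batch, n_calls=1000, warmup=50):
    """Average wall-clock time of model(batch) over n_calls forward passes."""
    model.eval()
    with torch.no_grad():
        for _ in range(warmup):   # let any lazy initialisation/caching happen first
            model(batch)
        start = time.perf_counter()
        for _ in range(n_calls):
            model(batch)
        return (time.perf_counter() - start) / n_calls


def parameter_summary(model):
    """Number of state_dict entries and total learnable parameters."""
    n_entries = len(model.state_dict())
    n_params = sum(p.numel() for p in model.parameters())
    return n_entries, n_params


# Hypothetical usage -- substitute the two actual model instances and an
# in-memory mock batch (e.g. a torch_geometric Data/Batch object):
# original = DynEdge(...)        # current implementation
# modified = DynEdgeNew(...)     # this PR's implementation (placeholder name)
# print(parameter_summary(original), parameter_summary(modified))
# print(mean_forward_time(original, mock_batch), mean_forward_time(modified, mock_batch))
```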