In a previous experiment, we explored the behaviour of, and interactions between, Keras models, tf.function, SavedModel, and tf.data datasets. A summary of that experiment is available here as a notebook.


As a follow-up to that test, we compared the performance of an in-graph training loop with a Python loop around a model whose __call__ function is annotated with tf.function.
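To make the two styles concrete, here is a minimal sketch (not the experiment's actual code) contrasting an in-graph loop, where dataset iteration is compiled into the graph, with a Python loop that calls a compiled step function once per batch. The toy dataset and function names are illustrative only.

```python
import tensorflow as tf

# Toy dataset standing in for the real training data.
ds = tf.data.Dataset.from_tensor_slices(tf.range(8, dtype=tf.float32)).batch(4)

@tf.function
def in_graph_loop(dataset):
    # The for-loop over the dataset is traced into the graph itself.
    total = tf.constant(0.0)
    for batch in dataset:
        total += tf.reduce_sum(batch)
    return total

@tf.function
def step(batch):
    # Only a single step is compiled; iteration stays in Python.
    return tf.reduce_sum(batch)

graph_total = in_graph_loop(ds)

python_total = tf.constant(0.0)
for batch in ds:
    python_total += step(batch)
```

Both variants compute the same result; they differ only in where the loop lives, which is what the timings below compare.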

We performed training for 100, 200, 400, 800, and 1600 steps on the naval dataset with the following config:

The Python loop has some upfront initial cost, but over time it runs faster than the in-graph training loop.

The initial cost of the Python loop can be attributed to the multiple traces performed, one for each distinct input shape. Once all the different shapes have been encountered, there are no further delays and it runs faster than the in-graph loop.
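The per-shape tracing behaviour can be observed directly: a tf.function without an input_signature retraces whenever it sees a new input shape, and reuses the cached trace otherwise. A small sketch (the function is illustrative, not from the experiment):

```python
import tensorflow as tf

@tf.function
def double(x):
    return x * 2

double(tf.constant([1.0]))        # shape (1,): first trace
double(tf.constant([1.0, 2.0]))   # shape (2,): new shape, retrace
double(tf.constant([3.0, 4.0]))   # shape (2,) again: cached trace reused
```

After these three calls, `double.experimental_get_tracing_count()` reports 2 traces, one per distinct shape.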

This finding also corroborates the issue

To avoid retracing for different shapes, we should provide an input_signature to the model's tf.function. The input signature is known only once partial shape information is available, so instead of statically annotating with tf.function, we need to create a tf.function with the right input signature on the fly.

Comparison of training in a Python loop while avoiding retracing:

Follow Up

As a follow-up to Tf.function retracing vs specifying an input signature, training on the naval dataset is run for 5000 steps with various combinations of tf.function annotations. The results are as follows:

| # | Model | Outer function | Retracing happens | Mean train step (discarding retracing times) |
|---|-------|----------------|-------------------|----------------------------------------------|
| 1 | tf.function with signature | Python function | no | 0.28 |
| 2 | tf.function without signature | train_step is a tf.function with signature | no | 0.22 |
| 3 | tf.function without signature | compute_gradient is a tf.function with signature | no | 0.25 |
| 4 | tf.function without signature | compute_gradient is a tf.function without signature | yes | 0.24 |
| 5 | tf.function without signature | train_step is a tf.function without signature | yes | 0.14 |

Refer to the attached notebook for details.

We can conclude that retracing is an expensive operation, but once the function has been traced for all required shapes, the compiled function is faster.

Making train_step the tf.function is much faster than making compute_gradient the tf.function.
