In a previous experiment, we explored the behaviour of, and the interaction between, Keras models, tf.functions, saved models and tf.data datasets. A summary of this experiment is available here as a notebook.

As a follow-up to that test, we compared the performance of an in-graph training loop (https://www.tensorflow.org/guide/function#advanced_example_an_in-graph_training_loop) with that of a training loop written in Python that calls a model whose __call__ function is annotated with tf.function.
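The two loop styles being compared can be sketched as follows. This is a minimal sketch under stated assumptions: a tiny tf.Module called Linear stands in for the Keras model, random tensors stand in for the naval dataset, and the plain gradient-descent update replaces whatever optimizer the experiment used. The names train_in_graph and train_python_loop are hypothetical.

```python
import tensorflow as tf

class Linear(tf.Module):
    """Stand-in for the Keras model used in the experiment (assumption)."""
    def __init__(self, n_in):
        super().__init__()
        self.w = tf.Variable(tf.zeros([n_in, 1]))
        self.b = tf.Variable(tf.zeros([1]))

    def __call__(self, x):
        return tf.matmul(x, self.w) + self.b

model = Linear(n_in=4)
learning_rate = 0.01

def train_step(x, y):
    # One step of plain gradient descent on mean squared error.
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(model(x) - y))
    grads = tape.gradient(loss, model.trainable_variables)
    for v, g in zip(model.trainable_variables, grads):
        v.assign_sub(learning_rate * g)
    return loss

# Synthetic data standing in for the naval dataset (assumption): 8 full batches of 32.
features = tf.random.normal([256, 4], seed=1)
targets = tf.reduce_sum(features, axis=1, keepdims=True)
dataset = tf.data.Dataset.from_tensor_slices((features, targets)).batch(32)

@tf.function
def train_in_graph(ds):
    """In-graph loop: AutoGraph turns the Python for-loop into a graph loop over ds."""
    loss = tf.constant(0.0)
    for x, y in ds:
        loss = train_step(x, y)
    return loss

compiled_step = tf.function(train_step)

def train_python_loop(ds):
    """Python loop: each iteration calls the compiled step from eager Python."""
    loss = tf.constant(0.0)
    for x, y in ds:
        loss = compiled_step(x, y)
    return loss

loss_graph = train_in_graph(dataset)      # the whole epoch runs inside one graph
loss_python = train_python_loop(dataset)  # one graph call per batch
```

In the first variant, Python is entered once per epoch; in the second, once per batch, which is what makes per-shape retracing of the step function visible.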

We ran training for 100, 200, 400, 800 and 1600 steps on the naval dataset with the following configuration:

The Python loop has an upfront cost, but over time it runs faster than the in-graph training loop.

The upfront cost of the Python loop can be attributed to retracing: a separate trace is performed for each distinct input shape. Once all the distinct shapes have been encountered, there are no further delays, and the loop runs faster than the in-graph loop.
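The per-shape retracing can be observed with a Python side effect, which runs only while tf.function is tracing, not on every call. A minimal sketch; the batch sizes 17 and 5 are hypothetical and mimic the smaller final batches a dataset can produce.

```python
import tensorflow as tf

traced_shapes = []  # appended to only while the function is being traced

@tf.function
def step(x):
    traced_shapes.append(tuple(x.shape))  # Python side effect: once per new trace
    return tf.reduce_sum(x * 2.0)

# Repeated shapes reuse an existing trace; each new shape triggers a fresh one.
for batch_size in [32, 32, 17, 32, 17, 5]:
    step(tf.ones([batch_size, 4]))

print(traced_shapes)
```

Six calls with three distinct shapes produce three traces; after that, further calls with those shapes hit the cached graphs.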

This finding is also consistent with the TensorFlow issue https://github.com/tensorflow/tensorflow/issues/35165

To avoid retracing for different shapes, we should provide an input_signature to the model's tf.function. However, the input_signature is known only once the (partial) shape information of the data is available, so instead of statically annotating the model with tf.function, we need to create a tf.function with the right input signature on the fly.
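Creating the tf.function on the fly might look like the sketch below. It assumes a tf.Module stand-in for the model and reads the feature count from the dataset's element_spec; the helper make_signed_call is hypothetical. The batch dimension is left as None, so batches of any size share one trace.

```python
import tensorflow as tf

class Linear(tf.Module):
    """Minimal stand-in model (assumption); the experiment used a Keras model."""
    def __init__(self, n_in):
        super().__init__()
        self.w = tf.Variable(tf.ones([n_in, 1]))
        self.b = tf.Variable(tf.zeros([1]))

    def __call__(self, x):
        return tf.matmul(x, self.w) + self.b

def make_signed_call(module, n_features, dtype=tf.float32):
    """Hypothetical helper: wrap module.__call__ with a batch-agnostic input_signature."""
    spec = tf.TensorSpec(shape=[None, n_features], dtype=dtype)
    return tf.function(module.__call__, input_signature=[spec])

# 70 rows batched by 32 gives batches of 32, 32 and 6 rows.
dataset = tf.data.Dataset.from_tensor_slices(tf.ones([70, 4])).batch(32)

# The feature count is only known once the data is inspected.
n_features = dataset.element_spec.shape[-1]
model = Linear(n_in=n_features)
call = make_signed_call(model, n_features)

outputs = [call(batch) for batch in dataset]
print("traces:", call.experimental_get_tracing_count())
```

Because the signature is built from the data rather than hard-coded in a decorator, the same code works for datasets with a different number of features.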

Comparison of training in a Python loop with retracing avoided:

Follow Up

As a follow-up to "Tf.function retracing vs specifying an input signature", training on the naval dataset was run for 5000 steps with various combinations of tf.function annotations. The results are as follows:

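#  | Model                     | Outer function                              | Retracing happens | Mean train step (discarding retracing times)
---|---------------------------|---------------------------------------------|-------------------|---------------------------------------------
1  | tf.func with signature    | python func                                 | no                | 0.28
2  | tf.func without signature | train_step is tf.func with signature        | no                | 0.22
3  | tf.func without signature | compute_gradient is tf.func with signature  | no                | 0.25
4  | tf.func without signature | compute_gradient is tf.func without signature | yes             | 0.24
5  | tf.func without signature | train_step is tf.func without signature     | yes               | 0.14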
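The table does not state how retracing times were discarded. One possible approach, assumed here, is to time every call and drop the calls during which the tracing count of the tf.function went up. The helper mean_step_time_excluding_traces and the demo function are hypothetical.

```python
import time
import tensorflow as tf

def mean_step_time_excluding_traces(step_fn, batches):
    """Hypothetical helper: mean wall-clock time per call of a tf.function,
    ignoring calls during which the function was (re)traced."""
    kept_times = []
    for batch in batches:
        traces_before = step_fn.experimental_get_tracing_count()
        start = time.perf_counter()
        result = step_fn(batch)
        result.numpy()  # wait for the result before stopping the clock
        elapsed = time.perf_counter() - start
        if step_fn.experimental_get_tracing_count() == traces_before:
            kept_times.append(elapsed)
    mean = sum(kept_times) / len(kept_times) if kept_times else float("nan")
    return mean, len(kept_times)

@tf.function
def demo_step(x):
    return tf.reduce_sum(tf.square(x))

# Calls 1 and 3 introduce new shapes and trace; calls 2 and 4 are timed.
batches = [tf.ones([32, 4]), tf.ones([32, 4]), tf.ones([17, 4]), tf.ones([32, 4])]
mean_time, n_kept = mean_step_time_excluding_traces(demo_step, batches)
print(f"mean step: {mean_time:.6f}s over {n_kept} calls")
```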

Refer to the attached notebook for details.

We can conclude that retracing is an expensive operation, but once the function has been traced for all the input shapes it encounters, the compiled function is faster.

Making train_step the tf.function is faster than making compute_gradient the tf.function: markedly so without an input signature (0.14 vs 0.24) and slightly with one (0.22 vs 0.25).
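The two placements compared in the table can be sketched as below, assuming a tf.Module stand-in for the model and a plain gradient-descent update in place of the experiment's optimizer. In variant A only the gradient computation is compiled and the variable updates run eagerly; in variant B the gradient computation and the updates are traced into a single graph. One plausible reason for B being faster is that it crosses the Python/graph boundary once per step instead of leaving the update ops to eager execution, though the source does not measure this directly.

```python
import tensorflow as tf

class Linear(tf.Module):
    """Minimal stand-in model (assumption); the experiment used a Keras model."""
    def __init__(self, n_in):
        super().__init__()
        self.w = tf.Variable(tf.zeros([n_in, 1]))
        self.b = tf.Variable(tf.zeros([1]))

    def __call__(self, x):
        return tf.matmul(x, self.w) + self.b

model = Linear(n_in=4)
LEARNING_RATE = 0.01

def compute_gradient(x, y):
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(model(x) - y))
    return loss, tape.gradient(loss, model.trainable_variables)

def apply_update(grads):
    # Plain gradient descent (assumption in place of the experiment's optimizer).
    for v, g in zip(model.trainable_variables, grads):
        v.assign_sub(LEARNING_RATE * g)

# Variant A: only compute_gradient is a tf.function; the update runs eagerly.
compiled_gradient = tf.function(compute_gradient)

def train_step_a(x, y):
    loss, grads = compiled_gradient(x, y)
    apply_update(grads)
    return loss

# Variant B: the whole train_step (gradients + update) is one traced graph.
@tf.function
def train_step_b(x, y):
    loss, grads = compute_gradient(x, y)
    apply_update(grads)
    return loss

x = tf.ones([8, 4])
y = tf.ones([8, 1])
loss_a = train_step_a(x, y)  # weights start at zero, so the loss is 1.0
loss_b = train_step_b(x, y)  # runs on the weights already updated by variant A
```

Both variants compute the same update; only the boundary of what gets compiled differs.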
