Comparison of train_step timings on the naval dataset for different tf.function strategies

We compared the timings of the train_step function (forward pass, backward pass and optimizer update).
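A minimal sketch of what one timed train_step can look like, assuming TensorFlow 2.x. The linear model, `N_FEATURES`, the plain SGD update and the number of steps are hypothetical stand-ins, not the naval model. Writing the array through an open file handle means `np.save` does not append a `.npy` suffix, which is one way extension-less files like `timings_step` could have been produced (an assumption).

```python
import os
import tempfile
import time

import numpy as np
import tensorflow as tf

# Hypothetical stand-in for the naval model: one linear layer.
N_FEATURES = 16
LEARNING_RATE = 0.01
w = tf.Variable(tf.zeros([N_FEATURES, 1]))
b = tf.Variable(tf.zeros([1]))


def train_step(x, y):
    """Forward pass, backward pass and a plain SGD update."""
    with tf.GradientTape() as tape:
        y_pred = tf.matmul(x, w) + b                  # forward pass
        loss = tf.reduce_mean(tf.square(y - y_pred))
    grads = tape.gradient(loss, [w, b])               # backward pass
    for var, grad in zip([w, b], grads):
        var.assign_sub(LEARNING_RATE * grad)          # optimizer update
    return loss


timings = []
for _ in range(20):
    x = tf.random.normal([32, N_FEATURES])
    y = tf.random.normal([32, 1])
    start = time.perf_counter()
    train_step(x, y).numpy()                          # .numpy() waits for the result
    timings.append(time.perf_counter() - start)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "timings_step")
    with open(path, "wb") as f:                       # file handle: no ".npy" appended
        np.save(f, np.array(timings))
    timings_step = np.load(path)

print(timings_step.shape)  # (20,)
```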

| # | Model call function | Outer function | Retracing happens | Mean train step time (excluding retracing steps) |
|---|---|---|---|---|
| 1 | tf.function with signature | Python function | no | 0.28 |
| 2 | tf.function without signature | train_step is a tf.function with signature | no | 0.22 |
| 3 | tf.function without signature | compute_gradient is a tf.function with signature | no | 0.25 |
| 4 | tf.function without signature | compute_gradient is a tf.function without signature | yes | 0.24 |
| 5 | tf.function without signature | train_step is a tf.function without signature | yes | 0.14 |

1. Model call is a tf.function with a signature

Only the model call function is annotated with a generic shape signature, so retracing is avoided. The train_step function is a plain Python function.
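A sketch of strategy 1, assuming TensorFlow 2.x and a hypothetical linear model: only the model's `__call__` carries an `input_signature` with a `None` batch dimension, and `train_step` stays a plain Python function. The Python counter increments only while TensorFlow traces, so it shows how often tracing actually happens.

```python
import tensorflow as tf

N_FEATURES = 16
trace_count = 0  # incremented only while TensorFlow traces the model call


class LinearModel(tf.Module):
    """Hypothetical stand-in for the naval model."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(tf.zeros([N_FEATURES, 1]))
        self.b = tf.Variable(tf.zeros([1]))

    # `None` in the batch dimension lets one trace serve every batch size.
    @tf.function(input_signature=[tf.TensorSpec([None, N_FEATURES], tf.float32)])
    def __call__(self, x):
        global trace_count
        trace_count += 1  # Python side effect: runs at trace time only
        return tf.matmul(x, self.w) + self.b


model = LinearModel()


def train_step(x, y, lr=0.01):
    """Plain Python: the tape and the SGD update run eagerly."""
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(y - model(x)))
    grads = tape.gradient(loss, [model.w, model.b])
    for var, grad in zip([model.w, model.b], grads):
        var.assign_sub(lr * grad)
    return loss


for batch_size in [32, 17, 32, 5]:
    train_step(tf.random.normal([batch_size, N_FEATURES]),
               tf.random.normal([batch_size, 1]))

print(trace_count)  # 1: one trace despite three distinct batch sizes
```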

In [30]:
import numpy as np
from matplotlib import pyplot as plt
%matplotlib inline
In [31]:
timings_step = np.load("/home/saswata.chakravarty/trainings/naval_model_dynamic_signature/timings_step")
plt.plot(timings_step)
Out[31]:
[Plot: timings_step, time of each training step]

Only the first step is expensive, because tracing is performed once. Mean excluding the first step:

In [32]:
np.mean(timings_step[1:])
Out[32]:
0.2794752272636229
In [33]:
timings_eval = np.load("/home/saswata.chakravarty/trainings/naval_model_dynamic_signature/timings_eval")
plt.plot(timings_eval)
Out[33]:
[Plot: timings_eval, time of each evaluation run]
In [34]:
np.mean(timings_eval)
Out[34]:
52.39069647789002

2. train_step is a tf.function with a signature

The train_step function is a tf.function annotated with a generic signature to avoid retracing. The model call function is a tf.function without a signature.
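A sketch of strategy 2 under the same assumptions (TensorFlow 2.x, hypothetical linear model, plain SGD update in place of the real optimizer). Because the inner `model_call` is only ever called while the outer `train_step` is being traced, it sees the symbolic `[None, N_FEATURES]` shape, so both functions are traced once.

```python
import tensorflow as tf

N_FEATURES = 16
traces = {"model_call": 0, "train_step": 0}

w = tf.Variable(tf.zeros([N_FEATURES, 1]))
b = tf.Variable(tf.zeros([1]))


@tf.function  # no signature: would retrace per shape if called eagerly
def model_call(x):
    traces["model_call"] += 1
    return tf.matmul(x, w) + b


@tf.function(input_signature=[
    tf.TensorSpec([None, N_FEATURES], tf.float32),
    tf.TensorSpec([None, 1], tf.float32),
])
def train_step(x, y):
    traces["train_step"] += 1
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(y - model_call(x)))
    grads = tape.gradient(loss, [w, b])
    for var, grad in zip([w, b], grads):
        var.assign_sub(0.01 * grad)  # plain SGD, inside the graph
    return loss


for batch_size in [32, 17, 32, 5]:
    train_step(tf.random.normal([batch_size, N_FEATURES]),
               tf.random.normal([batch_size, 1]))

print(traces)  # both functions are traced exactly once
```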

In [35]:
timings_step = np.load("/home/saswata.chakravarty/trainings/naval_model_train_step_signature/timings_step")
plt.plot(timings_step)
Out[35]:
[Plot: timings_step, time of each training step]

As in strategy 1, only the first step is expensive. Mean excluding the first step:

In [36]:
np.mean(timings_step[1:])
Out[36]:
0.22285934719330647
In [37]:
timings_eval = np.load("/home/saswata.chakravarty/trainings/naval_model_train_step_signature/timings_eval")
plt.plot(timings_eval)
Out[37]:
[Plot: timings_eval, time of each evaluation run]
In [38]:
np.mean(timings_eval)
Out[38]:
45.985404329299925

Annotating the top-level train_step (0.22 per step) is faster than annotating only the model call (0.28 per step).

3. compute_gradient is a tf.function with a signature

The compute_gradient function is a tf.function annotated with a generic signature to avoid retracing. The model call function is a tf.function without a signature.
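A sketch of strategy 3, assuming TensorFlow 2.x and the same hypothetical linear model. Only the forward and backward pass live in the signature-annotated `compute_gradient`; the SGD update (a stand-in for the real optimizer) runs eagerly in a Python `train_step`.

```python
import numpy as np
import tensorflow as tf

N_FEATURES = 16
w = tf.Variable(tf.zeros([N_FEATURES, 1]))
b = tf.Variable(tf.zeros([1]))
traces = {"model_call": 0, "compute_gradient": 0}


@tf.function  # model call without a signature
def model_call(x):
    traces["model_call"] += 1
    return tf.matmul(x, w) + b


@tf.function(input_signature=[
    tf.TensorSpec([None, N_FEATURES], tf.float32),
    tf.TensorSpec([None, 1], tf.float32),
])
def compute_gradient(x, y):
    """Forward and backward pass only; traced once thanks to the signature."""
    traces["compute_gradient"] += 1
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(y - model_call(x)))
    return loss, tape.gradient(loss, [w, b])


def train_step(x, y, lr=0.01):
    """Plain Python: the SGD update runs eagerly, outside the graph."""
    loss, grads = compute_gradient(x, y)
    for var, grad in zip([w, b], grads):
        var.assign_sub(lr * grad)
    return loss


for batch_size in [32, 17, 32, 5]:
    train_step(tf.random.normal([batch_size, N_FEATURES]),
               tf.random.normal([batch_size, 1]))

print(traces)  # each function traced once
```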

In [39]:
timings_step = np.load("/home/saswata.chakravarty/trainings/naval_compute_grad_signature/timings_step")
plt.plot(timings_step)
Out[39]:
[Plot: timings_step, time of each training step]
In [40]:
np.mean(timings_step)
Out[40]:
0.25844784026145934
In [41]:
timings_eval = np.load("/home/saswata.chakravarty/trainings/naval_compute_grad_signature/timings_eval")
plt.plot(timings_eval)
np.mean(timings_eval)
Out[41]:
45.57868378639221

4. compute_gradient is a tf.function without a signature

The compute_gradient function is a tf.function without a signature. The model call function is a tf.function without a signature.
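A sketch of strategy 4, assuming TensorFlow 2.x with default `tf.function` settings and the same hypothetical linear model. Without an `input_signature`, every new batch shape triggers a new trace of `compute_gradient` (and of `model_call` inside it), while a repeated shape reuses its cached trace.

```python
import tensorflow as tf

N_FEATURES = 16
w = tf.Variable(tf.zeros([N_FEATURES, 1]))
b = tf.Variable(tf.zeros([1]))
traces = {"model_call": 0, "compute_gradient": 0}


@tf.function  # model call without a signature
def model_call(x):
    traces["model_call"] += 1
    return tf.matmul(x, w) + b


@tf.function  # no input_signature: each new input shape triggers a retrace
def compute_gradient(x, y):
    traces["compute_gradient"] += 1
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(y - model_call(x)))
    return loss, tape.gradient(loss, [w, b])


def train_step(x, y, lr=0.01):
    loss, grads = compute_gradient(x, y)
    for var, grad in zip([w, b], grads):
        var.assign_sub(lr * grad)  # eager SGD update
    return loss


for batch_size in [32, 17, 32, 5]:
    train_step(tf.random.normal([batch_size, N_FEATURES]),
               tf.random.normal([batch_size, 1]))

# Shapes 32, 17 and 5 each cause a trace; the second 32 reuses the first.
print(traces)
```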

In [42]:
timings_step = np.load("/home/saswata.chakravarty/trainings/naval_model_master/timings_step")
plt.plot(timings_step)
Out[42]:
[Plot: timings_step, time of each training step]

There are multiple expensive steps, one for each new input shape that triggers a retrace. Discarding the first 2000 steps, where the expensive steps occur:

In [43]:
np.mean(timings_step[2000:])
Out[43]:
0.24018074917793275
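Slicing off a fixed prefix such as `[2000:]` assumes all retraces happen early. A hypothetical alternative, sketched below with NumPy, drops any step slower than a multiple of the median instead; the helper name `mean_without_spikes` and the `factor=5.0` threshold are assumptions to be tuned against the plot.

```python
import numpy as np


def mean_without_spikes(timings, factor=5.0):
    """Mean step time ignoring steps slower than `factor` x the median.

    Returns the mean of the kept steps and the number of dropped steps.
    """
    timings = np.asarray(timings, dtype=float)
    threshold = factor * np.median(timings)
    keep = timings <= threshold
    return timings[keep].mean(), int((~keep).sum())


# Synthetic example: steady 0.2 steps with three tracing spikes.
timings = np.full(100, 0.2)
timings[[0, 40, 75]] = 3.0
mean, n_dropped = mean_without_spikes(timings)
print(mean, n_dropped)  # mean close to 0.2, 3 steps dropped
```

Because the threshold is relative to the median, the same call works for the eval timings, whose scale is very different.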
In [44]:
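timings_eval = np.load("/home/saswata.chakravarty/trainings/naval_model_master/timings_eval")
plt.plot(timings_eval)
np.mean(timings_eval)
Out[44]:
45.57868378639221

This value is identical to Out[41]: it was computed when the cell assigned the loaded file to timings_step, so plt.plot and np.mean still used the strategy 3 timings_eval. Rerun the cell to get the eval mean for this strategy.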

5. train_step is a tf.function without a signature

The train_step function is a tf.function without a signature. The model call function is a tf.function without a signature.
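A sketch of strategy 5, assuming TensorFlow 2.x and the same hypothetical linear model with a plain SGD update. Here forward pass, backward pass and update all live in one `tf.function` without a signature; instead of a Python counter, it reads the trace count with `experimental_get_tracing_count()`.

```python
import tensorflow as tf

N_FEATURES = 16
LEARNING_RATE = 0.01
w = tf.Variable(tf.zeros([N_FEATURES, 1]))
b = tf.Variable(tf.zeros([1]))


@tf.function  # model call without a signature
def model_call(x):
    return tf.matmul(x, w) + b


@tf.function  # no input_signature: a new input shape means a new trace
def train_step(x, y):
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(y - model_call(x)))
    grads = tape.gradient(loss, [w, b])
    for var, grad in zip([w, b], grads):
        var.assign_sub(LEARNING_RATE * grad)  # update is part of the graph
    return loss


for batch_size in [32, 17, 32, 5, 17]:
    train_step(tf.random.normal([batch_size, N_FEATURES]),
               tf.random.normal([batch_size, 1]))

# Three distinct batch sizes, so three traces; repeated sizes reuse theirs.
print(train_step.experimental_get_tracing_count())
```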

In [45]:
timings_step = np.load("/home/saswata.chakravarty/trainings/naval_model_train_step_generic/timings_step")
plt.plot(timings_step)
Out[45]:
[Plot: timings_step, time of each training step]

Mean excluding the first 2000 steps, where the expensive steps occur:

In [46]:
np.mean(timings_step[2000:])
Out[46]:
0.14269281848271687
In [47]:
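timings_eval = np.load("/home/saswata.chakravarty/trainings/naval_model_train_step_generic/timings_eval")
plt.plot(timings_eval)
np.mean(timings_eval)
Out[47]:
45.57868378639221

This value is identical to Out[41]: it was computed when the cell assigned the loaded file to timings_step, so plt.plot and np.mean still used the strategy 3 timings_eval. Rerun the cell to get the eval mean for this strategy.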