Debugging ML Models

** PyTorch has many augmentations in [https://pytorch.org/docs/stable/torchvision/transforms.html torchvision.transforms].
** TF/Keras has augmentations as well in tf.image. See [https://www.tensorflow.org/tutorials/images/data_augmentation Tutorials: Data Augmentation].
==NaNs and Infs==
You can get NaNs and Infs in either the forward pass or the backward pass. 
To determine where, check the outputs of each layer and check the gradients.
In PyTorch, you can do:
<syntaxhighlight lang="python">
assert torch.isfinite(my_tensor).all(), "my_tensor has NaNs or Infs"
</syntaxhighlight>
In Tensorflow, you can do:
<syntaxhighlight lang="python">
def all_finite(tensor):
  is_finite = tf.math.is_finite(tensor)
  all_finite = tf.math.reduce_all(is_finite)
  return all_finite.numpy()
assert all_finite(my_tensor), "my_tensor has NaNs or Infs"
</syntaxhighlight>
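To narrow a forward-pass NaN/Inf down to a specific layer, you can attach forward hooks that run the same finiteness check on every layer's output. Below is a minimal PyTorch sketch; the <code>model</code> and input <code>x</code> are hypothetical placeholders.
<syntaxhighlight lang="python">
import torch
import torch.nn as nn

# Hypothetical model and input, purely for illustration.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 1))
x = torch.randn(8, 16)

def make_hook(name):
    def hook(module, inputs, output):
        # Flag any layer whose output contains a NaN or Inf.
        if torch.is_tensor(output) and not torch.isfinite(output).all():
            print(f"Non-finite output in layer: {name}")
    return hook

# Register a hook on every leaf module (modules with no children).
handles = [module.register_forward_hook(make_hook(name))
           for name, module in model.named_modules()
           if len(list(module.children())) == 0]

model(x)             # run the forward pass; hooks fire layer by layer

for handle in handles:
    handle.remove()  # clean up once you are done debugging
</syntaxhighlight>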
Typically, you get Infs and NaNs when there is a division by 0 in the forward or backward pass. 
However, it is also possible that the learning rate is too high or that your model is broken. 
I typically debug by:
* Dropping the learning rate to something very small, e.g. <code>1e-10</code>.
* Determining whether the NaN/Inf is in the forward or backward pass (see the sketch after this list).
* Checking that the training data has no NaNs or Infs.
* Checking that there are no divides anywhere in the code or that all divides are safe.
* Checking the gradients of trig functions in the code.
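For the backward pass (the second step above), PyTorch's anomaly detection raises an error at the operation that produced a non-finite gradient, and you can also inspect each parameter's gradient directly. A sketch, reusing the hypothetical <code>model</code> and <code>x</code> from the forward-hook example:
<syntaxhighlight lang="python">
import torch

# Anomaly detection adds checks to the backward pass and raises at the op
# that produced a NaN/Inf gradient. It is slow, so enable it only while debugging.
torch.autograd.set_detect_anomaly(True)

loss = model(x).mean()
loss.backward()

# Alternatively, check every parameter's gradient after backward().
for name, param in model.named_parameters():
    if param.grad is not None and not torch.isfinite(param.grad).all():
        print(f"Non-finite gradient in parameter: {name}")
</syntaxhighlight>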
\(
\DeclareMathOperator{\arcsec}{arcsec}
\DeclareMathOperator{\arccot}{arccot}
\DeclareMathOperator{\arccsc}{arccsc}
\)
The following functions have unstable (e.g. Inf) gradients at certain points:
* The derivative of \(\log(x)\) is \(1/x\), which is Inf near 0.
* The derivatives of \(\arcsin(x)\) and \(\arccos(x)\) are \(\frac{\pm 1}{\sqrt{1-x^2}}\), which are Inf near \(\pm 1\).
* The derivatives of \(\arccsc(x)\) and \(\arcsec(x)\) are \(\frac{\mp 1}{|x|\sqrt{x^2-1}}\), which are Inf near \(\pm 1\).
If you must use one of these functions near the unstable points, I suggest switching to a linear approximation near those points, as in the sketch below.
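For example, a version of \(\arcsin\) that switches to its tangent line near \(\pm 1\) keeps the gradient bounded. This is only a sketch; <code>eps</code> is a hypothetical tolerance you would tune for your precision.
<syntaxhighlight lang="python">
import torch

def linearized_arcsin(x, eps=1e-4):
    # Inside [-1 + eps, 1 - eps] use arcsin as usual; outside, continue along
    # the tangent line at the boundary, so the gradient never exceeds
    # 1 / sqrt(1 - (1 - eps)^2).
    x0 = 1.0 - eps
    slope = 1.0 / (1.0 - x0 ** 2) ** 0.5
    inner = torch.asin(torch.clamp(x, -x0, x0))
    outer = slope * (x.abs() - x0) * torch.sign(x)
    return inner + torch.where(x.abs() > x0, outer, torch.zeros_like(x))
</syntaxhighlight>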
Other ways to mitigate NaNs and Infs are:
* Decreasing the learning rate and increasing the batch size.
* Gradient clipping.
** For PyTorch see [https://pytorch.org/docs/master/generated/torch.nn.utils.clip_grad_norm_.html torch.nn.utils.clip_grad_norm_] and [https://pytorch.org/docs/master/generated/torch.nn.utils.clip_grad_value_.html torch.nn.utils.clip_grad_value_].
** For Tensorflow see [https://www.tensorflow.org/api_docs/python/tf/clip_by_norm tf.clip_by_norm] and [https://www.tensorflow.org/api_docs/python/tf/clip_by_value tf.clip_by_value].
* Using a safe divide that forces the denominator to have absolute value of at least some small EPS (see the sketch below).
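A sketch of the last two points, with a hypothetical <code>safe_div</code> helper and a training step that clips gradients; the <code>model</code>, <code>optimizer</code>, and <code>loss_fn</code> names are placeholders.
<syntaxhighlight lang="python">
import torch

def safe_div(numerator, denominator, eps=1e-8):
    # Push the denominator's magnitude up to at least eps, preserving its sign,
    # so the division can never blow up to Inf/NaN.
    sign = torch.where(denominator >= 0,
                       torch.ones_like(denominator),
                       -torch.ones_like(denominator))
    return numerator / (sign * torch.clamp(denominator.abs(), min=eps))

def train_step(model, optimizer, loss_fn, batch):
    # Hypothetical training step showing where gradient clipping fits in.
    optimizer.zero_grad()
    loss = loss_fn(model(batch))
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    return loss.item()
</syntaxhighlight>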