Debugging ML Models
** Check that your loss is implemented correctly and computed against the correct ground-truth image.
* Dump all inputs and outputs into [[TensorBoard]]. You may have an unexpected input or output somewhere.
* Make sure there is no activation on the final layer.
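A quick way to catch a mis-implemented loss is to evaluate it on inputs whose value you can work out by hand. The sketch below uses a hand-rolled mean squared error as a stand-in for whatever loss you are debugging (the function and values are illustrative, not from any framework):

```python
# Sanity-check a loss implementation on hand-computable inputs
# (illustrative sketch; in practice you would test your framework's
# loss call the same way).

def mse_loss(pred, target):
    """Mean squared error over flat lists of floats."""
    assert len(pred) == len(target), "prediction/target shape mismatch"
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)

# Identical prediction and target must give exactly zero loss.
assert mse_loss([0.5, 0.5], [0.5, 0.5]) == 0.0
# A known case: ((1-0)^2 + (0-1)^2) / 2 = 1.0
assert mse_loss([1.0, 0.0], [0.0, 1.0]) == 1.0
```

The same trick catches the "wrong ground truth" bug: pass a deliberately mismatched target and confirm the loss is clearly nonzero.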
* Increase or decrease the learning rate by one order of magnitude.
* Make sure the batch size is a power of 2. Try increasing it for more stable gradient updates, or decreasing it for faster iterations.
* Try disabling any training tricks you are using, such as dropout.
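The effect of moving the learning rate by one order of magnitude shows up even on a toy problem. This sketch runs plain gradient descent on f(x) = x&sup2; with three rates; all names and values here are invented for illustration:

```python
# Toy learning-rate sweep: plain gradient descent on f(x) = x**2
# (illustrative only; no framework involved).

def final_error(lr, steps=50, x0=5.0):
    """Run gradient descent on f(x) = x**2 and return |x| at the end."""
    x = x0
    for _ in range(steps):
        x -= lr * 2 * x  # gradient of x**2 is 2x
    return abs(x)

base_lr = 0.1
for lr in (base_lr / 10, base_lr, base_lr * 10):
    print(f"lr={lr:g}: final error {final_error(lr):.2e}")
```

Here the smallest rate converges too slowly, the middle one reaches the optimum, and the largest oscillates without converging — the same qualitative behavior you look for when sweeping the rate on a real model's training curves.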
==Overfitting==
Overfitting occurs when your training loss keeps decreasing while your validation loss stalls or rises, leaving a growing gap between the two.
Historically this was a major concern for ML models, and practitioners relied heavily on regularization to address it.
More recently, overfitting has become less of a concern as models have grown larger.
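One way to make the train-versus-validation criterion concrete is a small helper that flags the point where validation loss stops improving while training loss still falls. This is an illustrative sketch, not a standard API; the function name and patience window are assumptions:

```python
# Illustrative overfitting check: flag when validation loss has not
# improved for `patience` epochs while training loss kept decreasing.

def is_overfitting(train_losses, val_losses, patience=3):
    """Return True if the last `patience` validation losses never beat
    the best earlier validation loss while training loss still fell."""
    if len(val_losses) <= patience:
        return False  # not enough history to judge
    best_earlier_val = min(val_losses[:-patience])
    recent_val_improved = any(v < best_earlier_val
                              for v in val_losses[-patience:])
    train_still_falling = train_losses[-1] < train_losses[-patience - 1]
    return train_still_falling and not recent_val_improved

train_hist = [1.0, 0.8, 0.6, 0.5, 0.4, 0.3]
val_hist = [1.1, 0.9, 0.8, 0.85, 0.9, 1.0]
print(is_overfitting(train_hist, val_hist))  # prints True
```

The same signal is what early-stopping callbacks in common frameworks act on.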
* Increase the capacity and depth of the model.
* Add more training data if you can.
* Depending on your task, you can also try data augmentation (e.g. random crops, random rotations, random hue).
** PyTorch has many augmentations in [https://pytorch.org/docs/stable/torchvision/transforms.html torchvision.transforms].
** TF/Keras has augmentations as well in tf.image. See [https://www.tensorflow.org/tutorials/images/data_augmentation Tutorials: Data Augmentation].
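For intuition, a random crop can be sketched in a few lines of pure Python on a nested-list "image"; in practice you would use torchvision.transforms.RandomCrop or tf.image.random_crop instead:

```python
import random

# Toy random-crop augmentation on a nested-list "image" (illustrative
# only; real pipelines should use a framework's augmentation ops).
def random_crop(image, size, rng=random):
    """Return a size x size patch taken at a random position."""
    h, w = len(image), len(image[0])
    top = rng.randrange(h - size + 1)
    left = rng.randrange(w - size + 1)
    return [row[left:left + size] for row in image[top:top + size]]

img = [[r * 4 + c for c in range(4)] for r in range(4)]  # 4x4 toy image
patch = random_crop(img, 2)
print(len(patch), len(patch[0]))  # always a 2x2 patch
```

Because each epoch sees a different patch, the model effectively trains on many more distinct inputs than the raw dataset contains, which is what makes augmentation a regularizer.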