TensorFlow: Difference between revisions

1,584 bytes added ,  15 October 2019
Line 8: Line 8:
===Training Loop===
===Training Loop===
[https://www.tensorflow.org/guide/keras/train_and_evaluate#part_ii_writing_your_own_training_evaluation_loops_from_scratch Reference]<br>
[https://www.tensorflow.org/guide/keras/train_and_evaluate#part_ii_writing_your_own_training_evaluation_loops_from_scratch Reference]<br>
You can write your own training loop.
While you can train using <code>model.compile</code> and <code>model.fit</code>, using your own custom training loop is much more flexable and easier to understand.
You can write your own training loop by doing the following:
<syntaxhighlight code="python>
 
my_model= keras.Sequential([
    keras.layers.Dense(400, input_shape=400, activation='relu'),
    keras.layers.Dense(400, activation='relu'),
    keras.layers.Dense(400, activation='relu'),
    keras.layers.Dense(400, activation='relu'),
    keras.layers.Dense(400, activation='relu'),
    keras.layers.Dense(2)
])
 
training_loss = []
validation_loss = []
for epoch in range(100):
    print('Start of epoch %d' % (epoch,))
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        with tf.GradientTape() as tape:
            guess = my_model(x_batch_train)
            loss_value = my_custom_loss(y_batch_train, guess)
 
        # Use the gradient tape to automatically retrieve
        # the gradients of the trainable variables with respect to the loss.
        grads = tape.gradient(loss_value, my_model.trainable_weights)
 
        # Run one step of gradient descent by updating
        # the value of the variables to minimize the loss.
        optimizer.apply_gradients(zip(grads, my_model.trainable_weights))
 
        # Log every 200 batches.
        if step % 200 == 0:
            print('Training loss at step %s: %s' % (step, float(loss_value)))
        training_loss.append(loss_value)
        guess_validation = model(x_validation)
        validation_loss.append(my_custom_loss(y_validation, guess_validation))
</syntaxhighlight>


===Save and Load Models===
===Save and Load Models===