In TF1, you first build a computational graph by chaining operations on placeholders.
Then, you execute the graph in a <code>tf.Session</code>, feeding concrete values for each placeholder.
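A minimal sketch of this two-phase workflow (here written against the TF2 <code>tf.compat.v1</code> API; the tensor names are illustrative only):
<syntaxhighlight lang="python">
import tensorflow as tf

# Placeholders require graph mode; under TF2, eager execution must be disabled.
tf.compat.v1.disable_eager_execution()

# Phase 1: build the graph. Nothing is computed yet.
x = tf.compat.v1.placeholder(dtype=tf.float32, shape=())
y = x * 2 + 1

# Phase 2: execute the graph in a session, feeding a value for the placeholder.
with tf.compat.v1.Session() as sess:
    print(sess.run(y, {x: 3.0}))  # prints 7.0
</syntaxhighlight>
The key point is that <code>y</code> is a symbolic node in the graph, not a value; computation only happens inside <code>sess.run</code>.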
{{hidden | TF1 MNIST Example |
<syntaxhighlight lang="python">
import tensorflow as tf
from tensorflow import keras
import numpy as np

# Needed when running under TF2; in TF1, graph mode is the default.
tf.compat.v1.disable_eager_execution()

NUM_EPOCHS = 10
BATCH_SIZE = 64

(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
rng = np.random.default_rng()

# A small convolutional classifier. The final Dense layer has no
# activation so that it outputs raw logits for the cross-entropy loss.
classification_model = keras.Sequential([
    keras.Input(shape=(28, 28, 1)),
    keras.layers.Conv2D(16, 3, padding="SAME"),
    keras.layers.ReLU(),
    keras.layers.Conv2D(16, 3, padding="SAME"),
    keras.layers.ReLU(),
    keras.layers.Flatten(),
    keras.layers.Dense(10),
])

# Build the graph: placeholders stand in for a batch of images and labels.
x_in = tf.compat.v1.placeholder(dtype=tf.float32, shape=(None, 28, 28, 1))
logits = classification_model(x_in)
gt_classes = tf.compat.v1.placeholder(dtype=tf.int32, shape=(None,))
loss = tf.compat.v1.losses.softmax_cross_entropy(tf.one_hot(gt_classes, 10), logits)
optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=0.0001).minimize(loss)

# Execute the graph in a session, feeding one shuffled minibatch at a time.
with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())
    global_step = 0
    for epoch in range(NUM_EPOCHS):
        x_count = x_train.shape[0]
        image_ordering = rng.permutation(x_count)
        current_idx = 0
        while current_idx < x_count:
            my_indices = image_ordering[current_idx:current_idx + BATCH_SIZE]
            # Add a channel dimension and scale pixel values to [0, 1].
            x = x_train[my_indices][:, :, :, None] / 255
            logits_val, loss_val, _ = sess.run((logits, loss, optimizer), {
                x_in: x,
                gt_classes: y_train[my_indices]
            })
            if global_step % 100 == 0:
                print("Loss", loss_val)
            current_idx += BATCH_SIZE
            global_step += 1
</syntaxhighlight>
}}

==Estimators==