1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46
# Seed both NumPy's and TensorFlow's global random generators.
np.random.seed(42)
tf.random.set_seed(42)

# Length of each random input vector fed to the generator.
codings_size = 30

# Generator: maps a codings vector to a 28x28 output whose values are
# squashed into (0, 1) by the final sigmoid.
generator = keras.models.Sequential(
    [
        keras.layers.Dense(
            units=100, input_shape=[codings_size], activation="selu"
        ),
        keras.layers.Dense(units=150, activation="selu"),
        keras.layers.Dense(units=28 * 28, activation="sigmoid"),
        keras.layers.Reshape(target_shape=[28, 28]),
    ]
)

# Discriminator: flattens a 28x28 input and emits a single sigmoid score.
discriminator = keras.models.Sequential(
    [
        keras.layers.Flatten(input_shape=[28, 28]),
        keras.layers.Dense(units=150, activation="selu"),
        keras.layers.Dense(units=100, activation="selu"),
        keras.layers.Dense(units=1, activation="sigmoid"),
    ]
)

# Stacked model: generator output is fed straight into the discriminator.
gan = keras.models.Sequential([generator, discriminator])

# Order matters here: the discriminator is compiled while still trainable,
# then flagged non-trainable before the stacked model is compiled.
# NOTE(review): presumably so that training `gan` only updates the
# generator's weights -- confirm against the Keras version in use.
discriminator.compile(optimizer="rmsprop", loss="binary_crossentropy")
discriminator.trainable = False
gan.compile(optimizer="rmsprop", loss="binary_crossentropy")

# Input pipeline: shuffle with a 1000-element buffer, cut into fixed-size
# batches (the incomplete final batch is dropped), and prefetch one batch.
# X_train is defined elsewhere; assumes it holds 28x28 arrays -- TODO confirm.
batch_size = 32
dataset = (
    tf.data.Dataset.from_tensor_slices(X_train)
    .shuffle(buffer_size=1000)
    .batch(batch_size, drop_remainder=True)
    .prefetch(buffer_size=1)
)
def train_gan(gan, dataset, batch_size, codings_size, n_epochs=50):
    """Train a GAN by alternating discriminator and generator updates.

    For every batch, the discriminator is first trained on a mix of freshly
    generated images (labelled 0) and real images (labelled 1); the stacked
    model is then trained on new noise with every target set to 1.  After
    each epoch the images generated for the last batch are plotted.

    Args:
        gan: Model whose ``layers`` unpack to exactly
            ``(generator, discriminator)``.
        dataset: Iterable yielding batches of real images.  Every batch must
            contain exactly ``batch_size`` items, because the label tensors
            are sized from ``batch_size``.
        batch_size: Number of real images per batch; the same number of
            noise vectors is drawn for each phase.
        codings_size: Length of each random noise vector.
        n_epochs: Number of passes over ``dataset``.
    """
    generator, discriminator = gan.layers

    # The targets depend only on batch_size, so build them once instead of
    # re-creating identical tensors for every batch.
    # Rows 0..batch_size-1 label the generated images (0), the rest label
    # the real images (1), matching the concat order used below.
    fake_then_real_labels = tf.constant(
        [[0.]] * batch_size + [[1.]] * batch_size
    )
    all_real_labels = tf.constant([[1.]] * batch_size)

    # Stays None until at least one batch has been processed, so an empty
    # dataset skips plotting instead of raising UnboundLocalError.
    generated_images = None

    for epoch in range(n_epochs):
        print("Epoch {}/{}".format(epoch + 1, n_epochs))
        for X_batch in dataset:
            # Phase 1: train the discriminator on generated + real images.
            noise = tf.random.normal(shape=[batch_size, codings_size])
            generated_images = generator(noise)
            X_fake_and_real = tf.concat([generated_images, X_batch], axis=0)
            # NOTE(review): the trainable flag is flipped around each phase
            # exactly as before; whether the flip has any effect depends on
            # how the Keras version in use treats post-compile changes to
            # `trainable` -- confirm before removing.
            discriminator.trainable = True
            discriminator.train_on_batch(
                X_fake_and_real, fake_then_real_labels
            )

            # Phase 2: train the stacked model on fresh noise, with every
            # target set to 1.
            noise = tf.random.normal(shape=[batch_size, codings_size])
            discriminator.trainable = False
            gan.train_on_batch(noise, all_real_labels)

        if generated_images is not None:
            # Show the images produced during the most recent batch.
            # The meaning of the second argument (8) is defined by
            # plot_multiple_images, which lives outside this block.
            plot_multiple_images(generated_images, 8)
            plt.show()


# Run a single training epoch with the models and dataset built above.
train_gan(gan, dataset, batch_size, codings_size, n_epochs=1)
|