Generative Adversarial Network. Two-model architecture. TensorFlow.
- Tung San
- Jun 29, 2022
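Train a model to generate fake handwritten digits (0 to 9) in the style of the MNIST dataset. The setup uses two models: a Generator that produces images and a Discriminator that tries to tell real images from generated ones.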
Data Prep
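import numpy as np
from tensorflow.keras.datasets import mnist
np.random.seed(1000)
# Load only the 60,000 training images; the digit labels are not needed to train this GAN
(X_train, _), (_, _) = mnist.load_data()
# Scale pixel values from [0, 255] to [-1, 1] to match the generator's tanh output
X_train = (X_train.astype(np.float32) - 127.5)/127.5
X_train = X_train.reshape(60000, 784)  # flatten each 28x28 image into a 784-value vector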
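The scaling step maps pixel values from [0, 255] into [-1, 1], the same range as the generator's `tanh` output. The NumPy sketch below shows the forward mapping and its inverse (useful when turning generated images back into ordinary pixels); the helper names `to_tanh_range` and `to_pixels` are hypothetical and only for illustration.

```python
import numpy as np

def to_tanh_range(pixels):
    """Map uint8 pixel values in [0, 255] to float32 values in [-1, 1]."""
    return (pixels.astype(np.float32) - 127.5) / 127.5

def to_pixels(scaled):
    """Invert the scaling so [-1, 1] values become 0-255 pixels again."""
    return np.clip(scaled * 127.5 + 127.5, 0, 255).round().astype(np.uint8)

# A tiny 2x2 "image" with black, mid-grey, light-grey and white pixels
img = np.array([[0, 128], [200, 255]], dtype=np.uint8)
scaled = to_tanh_range(img)
print(scaled.min(), scaled.max())   # -1.0 1.0
print(to_pixels(scaled))            # the original pixel values come back
```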
Generator Architecture
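from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Dense, LeakyReLU, Dropout, Input
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import initializers
randomDim = 100  # length of the input noise vector (assumed value; not shown in this post)
# The generator maps a random noise vector to a fake 784-pixel (28x28) image
generator = Sequential()
# group1
generator.add(Dense(256, input_dim=randomDim))
generator.add(LeakyReLU(0.2))
# group2
generator.add(Dense(512))
generator.add(LeakyReLU(0.2))
# group3
generator.add(Dense(1024))
generator.add(LeakyReLU(0.2))
# output: tanh keeps every pixel in [-1, 1], the same range as the scaled training data
generator.add(Dense(784, activation='tanh'))
adam = Adam(learning_rate=0.0002, beta_1=0.5)
generator.compile(loss='binary_crossentropy', optimizer=adam)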
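To see what this layer stack does to shapes, here is a NumPy-only sketch of the same forward pass (Dense + LeakyReLU(0.2) three times, then Dense + tanh) with untrained random weights. The noise size of 100, the weight scale of 0.02, and the helper names are assumptions for illustration; Keras initializes and trains the real weights differently.

```python
import numpy as np

def leaky_relu(x, alpha=0.2):
    # Same negative slope (0.2) as LeakyReLU(0.2) in the Keras model
    return np.where(x > 0, x, alpha * x)

rng = np.random.default_rng(0)
noise_dim = 100                              # assumed value for randomDim
layer_sizes = [noise_dim, 256, 512, 1024, 784]

# One (weights, bias) pair per Dense layer; random and untrained, illustration only
weights = [rng.normal(0, 0.02, size=(m, n)) for m, n in zip(layer_sizes[:-1], layer_sizes[1:])]
biases = [np.zeros(n) for n in layer_sizes[1:]]

def generator_forward(z):
    h = z
    for W, b in zip(weights[:-1], biases[:-1]):
        h = leaky_relu(h @ W + b)                    # hidden Dense + LeakyReLU
    return np.tanh(h @ weights[-1] + biases[-1])     # output Dense with tanh

z = rng.normal(0, 1, size=(16, noise_dim))
fake = generator_forward(z)
print(fake.shape)                       # (16, 784)
print(fake.reshape(16, 28, 28).shape)   # (16, 28, 28)

# Parameters per Dense layer = inputs * outputs + outputs (bias)
n_params = sum(W.size + b.size for W, b in zip(weights, biases))
print(n_params)                         # 1486352
```

The parameter count is dominated by the last two layers (512x1024 and 1024x784 weight matrices), which is typical for a fully connected generator producing a 784-pixel image.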
Discriminator Architecture
# The discriminator outputs the probability that an input image is real (not generated)
discriminator = Sequential()
# group1
discriminator.add(Dense(1024, input_dim=784, kernel_initializer=initializers.RandomNormal(stddev=0.02)))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.3))
# group2
discriminator.add(Dense(512))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.3))
# group3
discriminator.add(Dense(256))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.3))
# output
discriminator.add(Dense(1, activation='sigmoid'))
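# Give the discriminator its own Adam instance: recent TensorFlow/Keras versions raise an error
# when one optimizer instance updates different sets of variables (discriminator vs. generator)
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.0002, beta_1=0.5))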
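The discriminator ends in a sigmoid and is trained with `binary_crossentropy`. For a predicted probability p and label y, the per-example loss is -[y log p + (1 - y) log(1 - p)], averaged over the batch. The NumPy helper below is a hand-written sketch of that formula, not the Keras implementation; the clipping epsilon of 1e-7 is an assumption that keeps `log()` finite.

```python
import numpy as np

def binary_crossentropy(y_true, p_pred, eps=1e-7):
    """Mean binary cross-entropy; eps (assumed) keeps log() away from log(0)."""
    p = np.clip(p_pred, eps, 1 - eps)
    return float(np.mean(-(y_true * np.log(p) + (1 - y_true) * np.log(1 - p))))

y = np.array([1.0, 0.0])   # one real image, one fake image

# A discriminator that is right: real -> 0.9, fake -> 0.1
print(binary_crossentropy(y, np.array([0.9, 0.1])))   # about 0.105 (= -ln 0.9)
# A discriminator that is fooled: real -> 0.1, fake -> 0.9
print(binary_crossentropy(y, np.array([0.1, 0.9])))   # about 2.303 (= -ln 0.1)
```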
Combining the two networks
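# The generator takes random noise as input and produces fake images
GANinput = Input(shape=(randomDim,))
fakeimgs = generator(GANinput)
# Freeze the discriminator inside the combined model (set before GAN.compile),
# so training GAN updates only the generator's weights
discriminator.trainable = False
# The discriminator scores whether each generated image looks real
GANoutput = discriminator(fakeimgs)
# Combined network: noise -> generator -> frozen discriminator -> probability "real"
GAN = Model(inputs=GANinput, outputs=GANoutput)
GAN.compile(loss='binary_crossentropy', optimizer=adam)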
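When the combined model is trained with target labels of 1 (as in the training loop), the `binary_crossentropy` term with y = 1 leaves only -log p, so the generator minimizes -mean(log D(G(z))), often called the non-saturating generator loss. The sketch below uses a hypothetical `generator_loss` helper (plain NumPy, not a Keras API) to show the loss shrinking as the discriminator's score on fakes rises toward 1.

```python
import numpy as np

def generator_loss(d_on_fake, eps=1e-7):
    """BCE against target 1, which reduces to -mean(log D(G(z)))."""
    d = np.clip(d_on_fake, eps, 1 - eps)
    return float(np.mean(-np.log(d)))

# Discriminator scores on a batch of 128 fakes, from "clearly fake" to "looks real"
for score in (0.01, 0.1, 0.5, 0.9):
    print(score, round(generator_loss(np.full(128, score)), 3))
```

With scores 0.01, 0.1, 0.5 and 0.9 this prints losses of about 4.605, 2.303, 0.693 and 0.105: the better the fakes fool the discriminator, the smaller the generator's loss.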
Define functions to show loss per epoch and generated images
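import os
import matplotlib.pyplot as plt
os.makedirs("./images/", exist_ok=True)  # don't fail if the folder already exists
# Loss history: one entry per epoch
disLosses = []
genLosses = []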
# Plot the stored losses (one value per epoch) and save the figure
def plotLoss(epoch):
    plt.figure(figsize=(10, 8))
    plt.plot(disLosses, label='Discriminator loss')
    plt.plot(genLosses, label='Generator loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig('images/gan_loss_epoch_%d.png' % epoch)
    plt.close()  # release the figure's memory
# Generate images from fresh noise and save them as a dim[0] x dim[1] grid
def saveGeneratedImages(epoch, examples=100, dim=(10, 10), figsize=(10, 10)):
    noise = np.random.normal(0, 1, size=[examples, randomDim])
    generatedImages = generator.predict(noise, verbose=0)
    generatedImages = generatedImages.reshape(examples, 28, 28)
    plt.figure(figsize=figsize)
    for i in range(generatedImages.shape[0]):
        plt.subplot(dim[0], dim[1], i+1)
        plt.imshow(generatedImages[i], interpolation='nearest', cmap='gray_r')
        plt.axis('off')
    plt.tight_layout()
    plt.savefig('images/gan_generated_image_epoch_%d.png' % epoch)
    plt.close()  # release the figure's memory
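Drawing 100 separate subplots works, but the same 10x10 layout can also be built as a single array with a reshape and transpose. This is an alternative technique, not the author's method; `tile_images` is a hypothetical helper written here for illustration.

```python
import numpy as np

def tile_images(images, rows, cols):
    """Arrange (rows*cols, h, w) images into one (rows*h, cols*w) array."""
    n, h, w = images.shape
    assert n == rows * cols, "need exactly rows*cols images"
    grid = images.reshape(rows, cols, h, w)   # axes: (row, col, y, x)
    grid = grid.transpose(0, 2, 1, 3)         # axes: (row, y, col, x)
    return grid.reshape(rows * h, cols * w)

# Stand-in for 100 generated 28x28 images
imgs = np.arange(100 * 28 * 28, dtype=np.float32).reshape(100, 28, 28)
grid = tile_images(imgs, 10, 10)
print(grid.shape)   # (280, 280)
```

The resulting array could then be saved in one call, e.g. `plt.imsave('grid.png', grid, cmap='gray_r')`.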
Training Loop
def train(epochs=1, batchSize=128):
    batchCount = int(X_train.shape[0] / batchSize)
    print('Epochs:', epochs)
    print('Batch size:', batchSize)
    print('Batches per epoch:', batchCount)
    for e in range(1, epochs+1):
        print('Epoch %d' % e)
        for i in range(batchCount):  # each epoch runs batchCount randomly sampled batches
            if i % 100 == 0:
                print(f"{i}-th run in the {e}-th epoch")
            #########################
            # Part 1: train the discriminator
            #########################
            # Get a random set of input noise and real images
            # (the generator at this point was trained in the previous iterations)
            noise = np.random.normal(0, 1, size=[batchSize, randomDim])
            imageBatch = X_train[np.random.randint(0, X_train.shape[0], size=batchSize)]
            # Generate fake MNIST images
            generatedImages = generator.predict(noise, verbose=0)
            # Stack REAL images (first half) and generated, fake images (second half)
            realandfake = np.concatenate([imageBatch, generatedImages])
            # Labels: 0 for the fake images
            labels_realfake = np.zeros(2*batchSize)
            # One-sided label smoothing: REAL images get 0.9 instead of 1
            labels_realfake[:batchSize] = 0.9
            # The discriminator learns to tell REAL images from fake ones
            discriminator.trainable = True
            disloss = discriminator.train_on_batch(realandfake, labels_realfake)
            #########################
            # Part 2: train the generator
            #########################
            # We want the discriminator to rate generated images as real (output close to 1),
            # so we feed fresh noise through the combined model with target labels of 1
            noiseInput = np.random.normal(0, 1, size=[batchSize, randomDim])
            labels_target = np.ones(batchSize)
            # KEY HERE! The discriminator is frozen, so only the generator learns in this step
            discriminator.trainable = False
            genloss = GAN.train_on_batch(noiseInput, labels_target)
            # The generator adjusts its weights to minimize this loss. A smaller loss means the
            # generated images are better at deceiving the discriminator: it should rate fakes
            # close to 0, but the generator pushes that rating toward 1.
            #########################
            # A run ends
            #########################
        # Store the loss of the most recent batch from this epoch
        disLosses.append(disloss)
        genLosses.append(genloss)
        # Make and save generated images after the first epoch and every 5th epoch
        if e == 1 or e % 5 == 0:
            saveGeneratedImages(e)
        #########################
        # An epoch ends
        #########################
    # Plot the losses from every epoch
    plotLoss(e)
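The bookkeeping in the loop can be checked by hand. With 60,000 training images and a batch size of 128, each epoch runs 468 batches, and because batches are drawn with `np.random.randint` (with replacement) an "epoch" here means 468 random batches rather than one exact pass over the data. The NumPy sketch below reproduces that arithmetic for `train(20, 128)` together with the label vectors used for each model; the variable names are illustrative.

```python
import numpy as np

n_train, batch_size, epochs = 60000, 128, 20
batch_count = n_train // batch_size       # same as int(60000 / 128)
print(batch_count)                        # 468
print(batch_count * epochs)               # 9360 updates for each model in total

# Discriminator labels for one batch: real images first, fakes second
labels = np.zeros(2 * batch_size)
labels[:batch_size] = 0.9                 # one-sided smoothing: only real labels are softened
# Generator labels: every fake is presented to the frozen discriminator as "real"
gen_labels = np.ones(batch_size)
print(labels[:3], labels[-3:])
```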
Running
train(20, 128)




Eventually, the generator learned how to deceive the discriminator.
