
Generative Adversarial Network. Two-model architecture. TensorFlow.

  • Writer: Tung San
  • Jun 29, 2022
  • 2 min read
Train an AI to generate fake handwritten digits 0 to 9. Two-model architecture: Generator & Discriminator.
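Under the hood, the two models play a minimax game: the discriminator D learns to tell real images x from fakes G(z), while the generator G learns to fool it (the classic objective from Goodfellow et al., 2014):

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]$$

D is pushed to output 1 on real images and 0 on fakes; G is pushed to make D output 1 on its fakes.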

Data Prep

import numpy as np
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Dense, Dropout, Input, LeakyReLU
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import initializers

np.random.seed(1000)

# Dimension of the generator's random input vector
# (100 is assumed here; it is the usual choice for this architecture)
randomDim = 100

(X_train, _), (_, _) = mnist.load_data()

# Scale pixels from [0, 255] to [-1, 1] to match the generator's tanh output
X_train = (X_train.astype(np.float32) - 127.5)/127.5

# Flatten each 28x28 image into a 784-dimensional vector
X_train = X_train.reshape(60000, 784)
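An optional sanity check: after scaling, every pixel should lie in [-1, 1], the same range the generator's final tanh layer produces.

print(X_train.min(), X_train.max())  # -1.0 1.0
print(X_train.shape)                 # (60000, 784)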



Generator Architecture

# The generator turns a random noise vector into a fake image
generator = Sequential()

# group1: project the noise vector up to 256 units
generator.add(Dense(256, input_dim=randomDim))
generator.add(LeakyReLU(0.2))

# group2
generator.add(Dense(512))
generator.add(LeakyReLU(0.2))

# group3
generator.add(Dense(1024))
generator.add(LeakyReLU(0.2))

# output: 784 values in [-1, 1], reshaped later into a 28x28 image
generator.add(Dense(784, activation='tanh'))

adam = Adam(learning_rate=0.0002, beta_1=0.5)
# (Compiling the generator alone is optional; it is actually trained
# through the combined GAN model defined below.)
generator.compile(loss='binary_crossentropy', optimizer=adam)
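As a quick check, the model summary should show the Dense stack ending in 784 outputs, one per pixel of a flattened 28x28 image:

generator.summary()  # Dense 256 -> 512 -> 1024 -> 784 (tanh)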



Discriminator Architecture

# The discriminator detects whether an image is real or fake
discriminator = Sequential()

# group1: 784 flattened pixels in; small-stddev normal init stabilizes early training
discriminator.add(Dense(1024, input_dim=784, kernel_initializer=initializers.RandomNormal(stddev=0.02)))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.3))

# group2
discriminator.add(Dense(512))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.3))

# group3
discriminator.add(Dense(256))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.3))

# output: a single probability that the input image is real
discriminator.add(Dense(1, activation='sigmoid'))
discriminator.compile(loss='binary_crossentropy', optimizer=adam)
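An optional smoke test: the discriminator should map a batch of flattened images to one probability each.

dummy = np.random.uniform(-1, 1, size=(5, 784))
print(discriminator.predict(dummy).shape)  # (5, 1)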



Combining the two networks

# The generator maps random noise to fake images
GANinput = Input(shape=(randomDim,))
fakeimgs = generator(GANinput)

# The discriminator judges whether the generated image is real or fake.
# Freeze it here so that training the combined model only updates the generator.
discriminator.trainable = False
GANoutput = discriminator(fakeimgs)

# Combined network: GAN
GAN = Model(inputs=GANinput, outputs=GANoutput)
GAN.compile(loss='binary_crossentropy', optimizer=adam)
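Per the Keras FAQ, the trainable flag is captured when a model is compiled: the combined GAN (compiled after the flag was set to False) never updates the discriminator, while the standalone discriminator (compiled earlier, while the flag was True) still learns. This is also why the training loop below re-sets the flag before each train_on_batch call. A quick way to see what the flag controls:

discriminator.trainable = True
print(len(discriminator.trainable_weights))  # 8: kernel + bias for each of the 4 Dense layers
discriminator.trainable = False
print(len(discriminator.trainable_weights))  # 0: nothing to update through the GAN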



Define functions to plot the loss per epoch and to save generated images

import os
import matplotlib.pyplot as plt

# exist_ok avoids an error if the folder already exists on re-runs
os.makedirs("./images/", exist_ok=True)

disLosses = []
genLosses = []


# Plot the loss recorded at the end of each epoch
def plotLoss(epoch):
    plt.figure(figsize=(10, 8))
    plt.plot(disLosses, label='Discriminator loss')
    plt.plot(genLosses, label='Generator loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig('images/gan_loss_epoch_%d.png' % epoch)


def saveGeneratedImages(epoch, examples=100, dim=(10, 10), figsize=(10, 10)):
    noise = np.random.normal(0, 1, size=[examples, randomDim])
    generatedImages = generator.predict(noise)
    generatedImages = generatedImages.reshape(examples, 28, 28)

    plt.figure(figsize=figsize)
    for i in range(generatedImages.shape[0]):
        plt.subplot(dim[0], dim[1], i+1)
        plt.imshow(generatedImages[i], interpolation='nearest', cmap='gray_r')
        plt.axis('off')
    plt.tight_layout()
    plt.savefig('images/gan_generated_image_epoch_%d.png' % epoch)
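Before any training, you can call the function once to see the baseline; an untrained generator should produce a 10x10 grid of pure noise.

saveGeneratedImages(0)  # writes images/gan_generated_image_epoch_0.png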




Training

def train(epochs=1, batchSize=128):
    batchCount = int(X_train.shape[0] / batchSize)
    print('Epochs:', epochs)
    print('Batch size:', batchSize)
    print('Batches per epoch:', batchCount)

    for e in range(1, epochs+1):
        print('Epoch %d' % e)
        for i in range(batchCount):  # each epoch samples batchCount random batches
            if i % 100 == 0:
                print(f"{i}-th run in the {e}-th epoch")

            #########################
            # Part 1: train the discriminator
            #########################
            # Get a random set of input noise and real images.
            # The generator at this point carries what it learned in earlier runs.
            noise = np.random.normal(0, 1, size=[batchSize, randomDim])
            imageBatch = X_train[np.random.randint(0, X_train.shape[0], size=batchSize)]

            # Generate fake MNIST images
            generatedImages = generator.predict(noise)
            # Stack REAL images and generated (fake) images into one batch
            realandfake = np.concatenate([imageBatch, generatedImages])
            # Labels: fake images get 0
            labels_realfake = np.zeros(2*batchSize)
            # One-sided label smoothing: REAL images get 0.9 instead of 1
            labels_realfake[:batchSize] = 0.9

            # Train the discriminator to distinguish REAL images from fakes
            discriminator.trainable = True
            disloss = discriminator.train_on_batch(realandfake, labels_realfake)

            #########################
            # Part 2: train the generator
            #########################
            # We want the generator's images to be judged as real (output close
            # to 1) by the discriminator, so we feed random noise into the
            # combined GAN with target labels of 1.
            noiseInput = np.random.normal(0, 1, size=[batchSize, randomDim])
            labels_target1 = np.ones(batchSize)

            # KEY HERE: the discriminator must not learn during this step
            discriminator.trainable = False
            genloss = GAN.train_on_batch(noiseInput, labels_target1)
            # The generator adjusts its weights to minimize this loss. A smaller
            # loss means the generated images are better at deceiving the
            # discriminator: it should rate fakes close to 0, but the generator
            # is pushing that rating toward 1.

        # Store the loss of the most recent batch from this epoch
        disLosses.append(disloss)
        genLosses.append(genloss)

        # Save a grid of generated images at the end of selected epochs
        if e == 1 or e % 5 == 0:
            saveGeneratedImages(e)

    # Plot the losses collected from every epoch
    plotLoss(e)



Running

train(20, 128)
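Once training finishes, new digits can be sampled directly from the trained generator with fresh noise, for example:

noise = np.random.normal(0, 1, size=[16, randomDim])
samples = generator.predict(noise).reshape(16, 28, 28)  # 16 new fake digits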
Eventually, the generator learned how to deceive the discriminator.