Keras: Is there a way to save and load GAN models without losing the optimizer state?

Created on 30 Jul 2018 · 10 Comments · Source: keras-team/keras

Since the discriminator is part of the GAN but also has to be trained on its own, how do you save and load GANs? Currently I save the generator and discriminator separately and recompile the GAN for each training episode, but this way I lose the optimizer state. It used to be possible to extract the optimizer state, but that was removed a few releases ago.
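The workflow in the question can be sketched as follows. This assumes Keras 3 (import keras) and the .keras file format rather than the 2018-era standalone Keras; the layer sizes, file names and the build_gan helper are hypothetical. The rebuilt GAN produces the same outputs as before, because the weights come back with the reloaded parts, but its newly created optimizer starts again at step 0.

```python
import os
import tempfile

import numpy as np
import keras

LATENT, FEATURES = 8, 4  # hypothetical sizes, for illustration only


def build_gan(generator, discriminator):
    """Chain G and D into one model, freezing D only around compile()."""
    discriminator.trainable = False
    noise = keras.Input(shape=(LATENT,))
    gan = keras.Model(noise, discriminator(generator(noise)))
    gan.compile(optimizer=keras.optimizers.Adam(2e-4), loss="binary_crossentropy")
    discriminator.trainable = True
    return gan


generator = keras.Sequential([keras.Input(shape=(LATENT,)), keras.layers.Dense(FEATURES)])
discriminator = keras.Sequential(
    [keras.Input(shape=(FEATURES,)), keras.layers.Dense(1, activation="sigmoid")]
)
discriminator.compile(optimizer="adam", loss="binary_crossentropy")
gan = build_gan(generator, discriminator)

# One step on the combined model advances its optimizer's step counter.
z = np.random.normal(size=(16, LATENT)).astype("float32")
gan.train_on_batch(z, np.ones((16, 1), dtype="float32"))
steps_before = int(gan.optimizer.iterations.numpy())
before = gan.predict(z, verbose=0)

# Save the two parts separately, as described in the question.
tmp = tempfile.mkdtemp()
generator.save(os.path.join(tmp, "generator.keras"))
discriminator.save(os.path.join(tmp, "discriminator.keras"))

# Reload the parts, then rebuild and recompile the combined model.
generator2 = keras.models.load_model(os.path.join(tmp, "generator.keras"))
discriminator2 = keras.models.load_model(os.path.join(tmp, "discriminator.keras"))
gan2 = build_gan(generator2, discriminator2)
steps_after = int(gan2.optimizer.iterations.numpy())
after = gan2.predict(z, verbose=0)

# Same weights, but the new optimizer has no record of the step already taken.
print(steps_before, steps_after, np.allclose(before, after))
```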


All 10 comments

If you are using a custom optimizer, you need to import its definition. Alternatively, define the optimizer again after calling load_model and compile the model with it; the model won't lose its weights.
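A minimal sketch of the load-then-recompile approach, assuming Keras 3 (import keras) and the .keras format; the model, file name and optimizer settings are hypothetical. Loading with compile=False restores the architecture and weights only, and compile() then attaches a new optimizer.

```python
import os
import tempfile

import numpy as np
import keras

model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1, activation="sigmoid")])
model.compile(optimizer=keras.optimizers.Adam(2e-4), loss="binary_crossentropy")

path = os.path.join(tempfile.mkdtemp(), "model.keras")
model.save(path)

# compile=False skips the saved optimizer and loss, so a custom optimizer class
# should only be needed where compile() is called below.
restored = keras.models.load_model(path, compile=False)
restored.compile(optimizer=keras.optimizers.Adam(2e-4), loss="binary_crossentropy")

# The weights survive the round trip, so predictions match the original model.
x = np.random.normal(size=(8, 4)).astype("float32")
same_outputs = np.allclose(model.predict(x, verbose=0), restored.predict(x, verbose=0))
print(same_outputs)
```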

Thanks, but I was thinking of the optimizer's state, e.g. when using learning-rate decay or momentum. If I recompile, the weights are fine, but the optimizer state is lost.
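Whether the optimizer state survives depends on how the model is reloaded. A minimal sketch, assuming Keras 3 (import keras) and the .keras format, whose saved file includes the optimizer's variables for a compiled model; the model and file name are hypothetical. After three training steps, reloading with the default compile=True should bring back the optimizer's step counter (and Adam's moment estimates), while compile=False followed by a new compile() starts from zero.

```python
import os
import tempfile

import numpy as np
import keras

model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1, activation="sigmoid")])
model.compile(optimizer=keras.optimizers.Adam(1e-3), loss="binary_crossentropy")

# A few training steps so that Adam has a step count and moment estimates.
x = np.random.normal(size=(32, 4)).astype("float32")
y = np.random.randint(0, 2, size=(32, 1)).astype("float32")
for _ in range(3):
    model.train_on_batch(x, y)

path = os.path.join(tempfile.mkdtemp(), "model.keras")
model.save(path)  # architecture, weights, compile config and optimizer variables

# Default compile=True: the optimizer comes back with its saved state.
resumed = keras.models.load_model(path)
resumed_steps = int(resumed.optimizer.iterations.numpy())

# compile=False plus a fresh compile(): same weights, brand-new optimizer.
recompiled = keras.models.load_model(path, compile=False)
recompiled.compile(optimizer=keras.optimizers.Adam(1e-3), loss="binary_crossentropy")
recompiled_steps = int(recompiled.optimizer.iterations.numpy())

print(resumed_steps, recompiled_steps)
```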

Furthermore, this script demonstrates that for CGANs (I'm assuming it also applies to the non-conditional variants), saving and loading does not preserve the discriminator's state.

import numpy as np
from keras import Input, Model, losses, optimizers
from keras.layers import Dense, concatenate
from keras.models import load_model  # public import path (keras.engine.saving is internal)

# Arbitrary constants
N_LATENT = 100
N_FEATURES = 100
N_FLAGS = 7
N_ROWS = 100

print("Building inputs")
noise = Input(shape=(N_LATENT,), name="noise")
flags = Input(shape=(N_FLAGS,), name="flags")
features = Input(shape=(N_FEATURES,), name="features")

print("Discriminator")
d = concatenate([features, flags])
d = Dense(52, activation='relu')(d)
d = Dense(52, activation='relu')(d)
# Sigmoid output so that binary_crossentropy receives probabilities
d_out = Dense(1, activation='sigmoid', name='d_out')(d)
D = Model([features, flags], d_out, name="D")
D.compile(
    loss=losses.binary_crossentropy,
    optimizer=optimizers.Adadelta(),
)
D.summary()

print("Generator")
g = concatenate([features, noise])
g = Dense(52, activation='relu')(g)
g = Dense(52, activation='relu')(g)
g_out = Dense(N_FLAGS, activation='sigmoid', name='g_out')(g)
G = Model([features, noise], g_out, name="G")
G.summary()

print("GAN")
# Freeze D's layers for the combined model. D was compiled above while still
# trainable, but these flags are still False when D.save() runs below.
for l in D.layers:
    l.trainable = False
# D takes [features, flags], so the generated flags go in the second slot.
gan_out = D([features, G([features, noise])])
GAN = Model([features, noise], gan_out)
GAN.compile(
    loss=losses.binary_crossentropy,
    optimizer=optimizers.Adadelta(),
)
GAN.summary()

# Random data (not used by the save/load check below)
x_features = np.random.normal(0, 1, (N_ROWS, N_FEATURES))
x_noise = np.random.normal(0, 1, (N_ROWS, N_LATENT))
x_flags = np.random.uniform(0, 1, (N_ROWS, N_FLAGS))
ones = np.ones((N_ROWS, 1))

# Save
D.save('./D.h5')
G.save('./G.h5')
GAN.save('./GAN.h5')

del D
del G
del GAN

print("D")
D = load_model('./D.h5')
D.summary()

print("G")
G = load_model('./G.h5')
G.summary()

print("GAN")
GAN = load_model('./GAN.h5')
GAN.summary()

After running the script, examine the output of D.summary(): all parameters are marked as non-trainable, and the optimizer state has been reset.

Are there any fixes for this?

Running into the same issue. I also found a similar issue here: https://github.com/keras-team/keras/issues/9589

Solved it by upgrading Keras to 2.2.4 and using pickle. I was having the issue with a pix2pix model.

@ismailsimsek Do you have a quick example I can test?

Also, I created a related question on StackOverflow: https://stackoverflow.com/questions/52463551/how-do-you-train-gans-using-multiple-gpus-with-keras

I'm also struggling to resolve this error. Is there a way to easily save and load a GAN without losing information?

I think I finally managed to solve this issue, after much frustration and eventually switching to tensorflow.keras. Here is a summary.

Standalone keras doesn't seem to respect trainable when reloading a model. If a model contains an inner submodel with submodel.trainable = False, then after reloading the model and calling model.summary() you will notice that all layers are trainable, and you get the optimizer state warning when loading the model.

Interestingly, this isn't the case with tensorflow.keras. There, if you set submodel.trainable = False and reload the model later, model.summary() does in fact report a large number of non-trainable parameters.

Another thing to keep in mind is that submodel.trainable is read at different times for training and for saving. For training, the value of trainable at the time model.compile() is called is what takes effect. When calling model.save(), however, only the value of trainable at the time of the save call matters (the value at compile time is ignored).
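A small check of the save-time half of this claim, assuming Keras 3 (import keras) and the .keras format; the model and file names are hypothetical. It compiles a discriminator while trainable, flips trainable afterwards, and saves it once frozen and once unfrozen; the reloaded copies carry whatever flags were set at save time. It does not test the compile-time half, since whether a post-compile change to trainable affects training may depend on the Keras version.

```python
import os
import tempfile

import keras

# Hypothetical discriminator, compiled while trainable.
d = keras.Sequential([
    keras.Input(shape=(4,)),
    keras.layers.Dense(8, activation="relu"),
    keras.layers.Dense(1, activation="sigmoid"),
])
d.compile(optimizer="adam", loss="binary_crossentropy")

tmp = tempfile.mkdtemp()
d.trainable = False  # changed after compile(); save() records this value
d.save(os.path.join(tmp, "frozen.keras"))
d.trainable = True
d.save(os.path.join(tmp, "unfrozen.keras"))

# compile=False: only the architecture (including trainable flags) and weights matter here.
frozen = keras.models.load_model(os.path.join(tmp, "frozen.keras"), compile=False)
unfrozen = keras.models.load_model(os.path.join(tmp, "unfrozen.keras"), compile=False)

# Per-layer flags of the two Dense layers in each reloaded copy.
frozen_flags = [layer.trainable for layer in frozen.layers]
unfrozen_flags = [layer.trainable for layer in unfrozen.layers]
print(frozen_flags, unfrozen_flags)
```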

So in the context of GANs, one needs to ensure the following:

  1. Drop standalone keras and switch to tensorflow.keras.
  2. Assuming gan is the combined model and generator and discriminator are its submodels, construct the models as follows:
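from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.optimizers import Adam


def create_generator():
    generator = Sequential()
    ...

    return generator


def create_discriminator():
    discriminator = Sequential()
    ...
    return discriminator


def create_gan(generator, discriminator):
    # The value of trainable at compile() time is what gan's training uses,
    # so freeze the discriminator before compiling the combined model.
    discriminator.trainable = False

    gan_input = Input(shape=(INPUT_SIZE,))  # INPUT_SIZE: generator input size, defined elsewhere
    generator_output = generator(gan_input)
    gan_output = discriminator(generator_output)

    gan = Model(inputs=gan_input, outputs=gan_output)
    gan.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.0002, beta_1=0.5))

    gan.summary()
    # Unfreeze so the standalone discriminator can still be trained.
    discriminator.trainable = True
    return gan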

One can carry out saving and loading as follows:

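from tensorflow.keras.models import load_model, save_model


def save(gan, generator, discriminator):
    # Freeze the discriminator so it is stored as non-trainable inside gan.
    discriminator.trainable = False
    save_model(gan, 'gan')
    # Unfreeze before saving the standalone discriminator.
    discriminator.trainable = True
    save_model(generator, 'generator')
    save_model(discriminator, 'discriminator')


def load():
    discriminator = load_model('discriminator')
    generator = load_model('generator')
    gan = load_model('gan')
    # gan should report the discriminator's parameters as non-trainable.
    gan.summary()
    discriminator.summary()
    generator.summary()

    return gan, generator, discriminator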

Note that you have to set discriminator.trainable = False before saving gan, so that the discriminator part of it is stored as non-trainable and is therefore reloaded correctly later. Whether it was set at compile time doesn't matter.
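For anyone on a current release, here is a self-contained sketch of the same recipe, assuming Keras 3 (import keras) and .keras files instead of the tensorflow.keras SavedModel directories above; the layer sizes, file names and folder argument are hypothetical stand-ins for the elided parts. It only checks what gets saved: the discriminator nested in the reloaded gan comes back frozen, while the standalone discriminator comes back trainable. It does not train anything, because whether the compile-time freeze is honoured during training may differ between Keras versions.

```python
import os
import tempfile

import keras

LATENT, FEATURES = 8, 4  # hypothetical sizes


def create_generator():
    return keras.Sequential(
        [keras.Input(shape=(LATENT,)), keras.layers.Dense(FEATURES)], name="generator")


def create_discriminator():
    d = keras.Sequential(
        [keras.Input(shape=(FEATURES,)), keras.layers.Dense(1, activation="sigmoid")],
        name="discriminator")
    d.compile(optimizer=keras.optimizers.Adam(2e-4, beta_1=0.5), loss="binary_crossentropy")
    return d


def create_gan(generator, discriminator):
    discriminator.trainable = False
    noise = keras.Input(shape=(LATENT,))
    gan = keras.Model(noise, discriminator(generator(noise)), name="gan")
    gan.compile(optimizer=keras.optimizers.Adam(2e-4, beta_1=0.5), loss="binary_crossentropy")
    discriminator.trainable = True
    return gan


def save(gan, generator, discriminator, folder):
    discriminator.trainable = False  # recorded as frozen inside the saved gan
    gan.save(os.path.join(folder, "gan.keras"))
    discriminator.trainable = True  # recorded as trainable in its own file
    generator.save(os.path.join(folder, "generator.keras"))
    discriminator.save(os.path.join(folder, "discriminator.keras"))


def load(folder):
    gan = keras.models.load_model(os.path.join(folder, "gan.keras"))
    generator = keras.models.load_model(os.path.join(folder, "generator.keras"))
    discriminator = keras.models.load_model(os.path.join(folder, "discriminator.keras"))
    return gan, generator, discriminator


folder = tempfile.mkdtemp()
generator = create_generator()
discriminator = create_discriminator()
save(create_gan(generator, discriminator), generator, discriminator, folder)
gan2, generator2, discriminator2 = load(folder)

# Flags of the Dense layers in the nested and in the standalone discriminator.
nested_flags = [layer.trainable for layer in gan2.get_layer("discriminator").layers]
standalone_flags = [layer.trainable for layer in discriminator2.layers]

# Each file loads as an independent model: zeroing the standalone
# discriminator's weights leaves the copy inside gan2 untouched.
discriminator2.set_weights([w * 0 for w in discriminator2.get_weights()])
nested_still_nonzero = any(
    abs(w).sum() > 0 for w in gan2.get_layer("discriminator").get_weights())
print(nested_flags, standalone_flags, nested_still_nonzero)
```

One consequence of load() worth knowing: each load_model call builds a separate model, so after reloading, the discriminator inside gan no longer shares weights with the standalone discriminator, and training one does not update the other unless the combined model is rebuilt from the reloaded parts.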

It's also important to note that this example will only work on tensorflow.keras and not keras. Let me know if this helps or if anyone finds any issues.

@bradsheppard, can you confirm that your suggested method works?

@bradsheppard 's solution worked for me (currently using TensorFlow 2.1.0)

