代码拉取完成,页面将自动刷新
import os, time, multiprocessing
import numpy as np
import tensorflow as tf
import tensorlayer as tl
from glob import glob
from data import get_celebA, flags
from model import get_generator, get_discriminator
# Side length of the square grid used to tile sample images,
# i.e. floor(sqrt(flags.sample_size)).
num_tiles = int(np.floor(np.sqrt(flags.sample_size)))
def _train_step(G, D, g_optimizer, d_optimizer, batch_images):
    """Apply one generator update and one discriminator update on a batch.

    Both losses come from a single forward pass recorded on one persistent
    tape, so two separate gradients can be taken from it.

    Returns:
        Tuple ``(d_loss, g_loss)`` of the losses computed before the updates.
    """
    with tf.GradientTape(persistent=True) as tape:
        # Fresh standard-normal latent noise for every step.
        z = np.random.normal(loc=0.0, scale=1.0, size=[flags.batch_size, flags.z_dim]).astype(np.float32)
        d_logits = D(G(z))  # discriminator output on generated (fake) images
        d2_logits = D(batch_images)  # discriminator output on real images
        # discriminator: real images are labelled as 1
        d_loss_real = tl.cost.sigmoid_cross_entropy(d2_logits, tf.ones_like(d2_logits), name='dreal')
        # discriminator: images from generator (fake) are labelled as 0
        d_loss_fake = tl.cost.sigmoid_cross_entropy(d_logits, tf.zeros_like(d_logits), name='dfake')
        # combined loss for updating discriminator
        d_loss = d_loss_real + d_loss_fake
        # generator: try to fool discriminator to output 1
        g_loss = tl.cost.sigmoid_cross_entropy(d_logits, tf.ones_like(d_logits), name='gfake')

    g_grads = tape.gradient(g_loss, G.trainable_weights)
    g_optimizer.apply_gradients(zip(g_grads, G.trainable_weights))
    d_grads = tape.gradient(d_loss, D.trainable_weights)
    d_optimizer.apply_gradients(zip(d_grads, D.trainable_weights))
    # A persistent tape keeps its resources until it is released.
    del tape
    return d_loss, g_loss


def _save_snapshot(G, D, sample_z, epoch):
    """Write G/D weights and a tiled grid of generator samples for ``epoch``."""
    G.save_weights('{}/G.npz'.format(flags.checkpoint_dir), format='npz')
    D.save_weights('{}/D.npz'.format(flags.checkpoint_dir), format='npz')
    # Switch G to eval mode for sampling, then back to train mode.
    G.eval()
    result = G(sample_z)
    G.train()
    tl.visualize.save_images(result.numpy(), [num_tiles, num_tiles], '{}/train_{:02d}.png'.format(flags.sample_dir, epoch))


def train():
    """Train generator ``G`` and discriminator ``D`` adversarially.

    All hyper-parameters and paths are read from ``flags``. Every
    ``flags.save_every_epoch`` epochs, and after the final epoch, the weights of
    both networks are written to ``flags.checkpoint_dir``. At the same points a
    ``num_tiles`` x ``num_tiles`` grid of generated samples is written to
    ``flags.sample_dir``.
    """
    # NOTE(review): get_celebA receives n_epoch; if it repeats the dataset
    # n_epoch times, the inner loop below would consume every epoch's data in
    # the first outer epoch. Confirm against data.get_celebA.
    images, images_path = get_celebA(flags.output_size, flags.n_epoch, flags.batch_size)
    G = get_generator([None, flags.z_dim])
    D = get_discriminator([None, flags.output_size, flags.output_size, flags.c_dim])
    G.train()
    D.train()

    d_optimizer = tf.optimizers.Adam(flags.lr, beta_1=flags.beta1)
    g_optimizer = tf.optimizers.Adam(flags.lr, beta_1=flags.beta1)

    n_step_epoch = int(len(images_path) // flags.batch_size)

    # The output directories must exist before weights/images are written there.
    os.makedirs(flags.checkpoint_dir, exist_ok=True)
    os.makedirs(flags.sample_dir, exist_ok=True)

    # Fixed noise drawn once. This has three benefits:
    # - sample grids from different epochs are directly comparable;
    # - there are exactly num_tiles*num_tiles samples, which save_images
    #   requires (it rejects more images than grid cells);
    # - it is defined even if an epoch runs no full training batch.
    sample_z = np.random.normal(loc=0.0, scale=1.0, size=[num_tiles * num_tiles, flags.z_dim]).astype(np.float32)

    for epoch in range(flags.n_epoch):
        for step, batch_images in enumerate(images):
            if batch_images.shape[0] != flags.batch_size:  # if the remaining data in this epoch < batch_size
                break
            step_time = time.time()
            d_loss, g_loss = _train_step(G, D, g_optimizer, d_optimizer, batch_images)
            print("Epoch: [{}/{}] [{}/{}] took: {:.3f}, d_loss: {:.5f}, g_loss: {:.5f}".format(
                epoch, flags.n_epoch, step, n_step_epoch, time.time() - step_time, d_loss, g_loss))

        # Also save after the last epoch so the final model is never lost.
        is_last_epoch = epoch == flags.n_epoch - 1
        if epoch % flags.save_every_epoch == 0 or is_last_epoch:
            _save_snapshot(G, D, sample_z, epoch)
if __name__ == '__main__':
    # Script entry point: start training only when this file is executed
    # directly, not when it is imported as a module.
    train()
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。