Simple implementation of Generative Adversarial Nets using chainer
data.py (MNIST download and loading helper; the training script below imports it as data):
import gzip
import os

import numpy as np
import six
from six.moves.urllib import request

parent = 'https://yann.lecun.com/exdb/mnist'
train_images = 'train-images-idx3-ubyte.gz'
train_labels = 'train-labels-idx1-ubyte.gz'
test_images = 't10k-images-idx3-ubyte.gz'
test_labels = 't10k-labels-idx1-ubyte.gz'
num_train = 60000
num_test = 10000
dim = 784


def load_mnist(images, labels, num):
    # Parse a pair of gzipped IDX files into flat uint8 arrays.
    data = np.zeros(num * dim, dtype=np.uint8).reshape((num, dim))
    target = np.zeros(num, dtype=np.uint8).reshape((num, ))
    with gzip.open(images, 'rb') as f_images,\
            gzip.open(labels, 'rb') as f_labels:
        f_images.read(16)  # skip the 16-byte IDX header of the image file
        f_labels.read(8)   # skip the 8-byte IDX header of the label file
        for i in six.moves.range(num):
            target[i] = ord(f_labels.read(1))
            for j in six.moves.range(dim):
                data[i, j] = ord(f_images.read(1))
    return data, target


def download_mnist_data():
    print('Downloading {:s}...'.format(train_images))
    request.urlretrieve('{:s}/{:s}'.format(parent, train_images), train_images)
    print('Done')
    print('Downloading {:s}...'.format(train_labels))
    request.urlretrieve('{:s}/{:s}'.format(parent, train_labels), train_labels)
    print('Done')
    print('Downloading {:s}...'.format(test_images))
    request.urlretrieve('{:s}/{:s}'.format(parent, test_images), test_images)
    print('Done')
    print('Downloading {:s}...'.format(test_labels))
    request.urlretrieve('{:s}/{:s}'.format(parent, test_labels), test_labels)
    print('Done')

    print('Converting training data...')
    data_train, target_train = load_mnist(train_images, train_labels,
                                          num_train)
    print('Done')
    print('Converting test data...')
    data_test, target_test = load_mnist(test_images, test_labels, num_test)

    mnist = {}
    mnist['data'] = np.append(data_train, data_test, axis=0)
    mnist['target'] = np.append(target_train, target_test, axis=0)

    print('Done')
    print('Save output...')
    with open('mnist.pkl', 'wb') as output:
        six.moves.cPickle.dump(mnist, output, -1)
    print('Done')
    print('Convert completed')


def load_mnist_data():
    # Download and convert MNIST on first use, then reuse the cached pickle.
    if not os.path.exists('mnist.pkl'):
        download_mnist_data()
    with open('mnist.pkl', 'rb') as mnist_pickle:
        mnist = six.moves.cPickle.load(mnist_pickle)
    return mnist
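
A quick way to exercise the loader above; this is a hypothetical snippet, assuming the file is saved as data.py, which is how the training script below imports it:

import numpy as np

import data

mnist = data.load_mnist_data()               # downloads, converts and caches mnist.pkl on first use
x = mnist['data'].astype(np.float32) / 255   # same preprocessing as the training script
print(x.shape, mnist['target'].shape)        # expected: (70000, 784) (70000,)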
Training script (a simple GAN on MNIST using chainer):
import sys
import os
import logging
import argparse
import pickle

import numpy as np
from PIL import Image
import chainer
from chainer import cuda
import chainer.functions as F
import chainer.optimizers

import data


class GANModel(chainer.FunctionSet):

    n_hidden = 100

    def __init__(self):
        super(GANModel, self).__init__(
            g_fc0=F.Linear(self.n_hidden, 500),
            g_fc1=F.Linear(500, 500),
            g_fc2=F.Linear(500, 784),
            d_fc0=F.Linear(784, 240),
            d_fc1=F.Linear(240, 240),
            d_fc2=F.Linear(240, 1))

    @property
    def generators(self):
        return [self.g_fc0, self.g_fc1, self.g_fc2]

    @property
    def discriminators(self):
        return [self.d_fc0, self.d_fc1, self.d_fc2]

    def make_z(self, n):
        # Latent noise: n vectors drawn from 0.2 * N(0, I).
        return 0.2 * np.asarray(
            np.random.randn(n, self.n_hidden),
            dtype=np.float32)

    def make_generator(self, z, train=True):
        h = F.dropout(F.relu(self.g_fc0(z)), train=train)
        h = F.dropout(F.relu(self.g_fc1(h)), train=train)
        return self.g_fc2(h)
        # return F.sigmoid(self.g_fc2(h))

    def make_discriminator(self, x, t, train=True):
        h = F.relu(self.d_fc0(x))
        h = F.relu(self.d_fc1(h))
        h = self.d_fc2(h)
        return F.sigmoid_cross_entropy(h, t)

    def collect_generator_parameters(self):
        parameters = (
            sum((f.parameters for f in self.generators), ()),
            sum((f.gradients for f in self.generators), ()),
        )
        return parameters

    def collect_discriminator_parameters(self):
        parameters = (
            sum((f.parameters for f in self.discriminators), ()),
            sum((f.gradients for f in self.discriminators), ()),
        )
        return parameters

    def generate(self, z_data, train=True):
        z = chainer.Variable(z_data)
        x = self.make_generator(z, train=train)
        return x.data

    def forward_xy(self, x_data, t_data):
        # Discriminator loss on given samples x with labels t.
        x = chainer.Variable(x_data)
        t = chainer.Variable(t_data)
        loss = self.make_discriminator(x, t)
        return loss

    def forward_zy(self, z_data, t_data):
        # Discriminator loss on generated samples G(z) with labels t.
        z = chainer.Variable(z_data)
        t = chainer.Variable(t_data)
        x = self.make_generator(z)
        loss = self.make_discriminator(x, t)
        return loss


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch', type=int, default=50)
    parser.add_argument('--epoch', type=int, default=30)
    parser.add_argument('-g', '--gpu', type=int, default=-1)
    parser.add_argument('--display', type=int, default=100)
    parser.add_argument('--image', default='./')
    parser.add_argument('--dst', default='model.pkl')
    args = parser.parse_args()

    logging.basicConfig(
        format="%(asctime)s [%(levelname)s] %(message)s",
        level=logging.DEBUG)
    logger = logging.getLogger(__name__)

    if 0 <= args.gpu:
        cuda.init(args.gpu)

    logger.info('loading MNIST dataset...')
    mnist = data.load_mnist_data()
    mnist['data'] = mnist['data'].astype(np.float32)
    mnist['data'] /= 255
    x_train = mnist['data']

    logger.info('constructing GAN model...')
    model = GANModel()
    if 0 <= args.gpu:
        model = model.to_gpu(args.gpu)

    logger.info('initializing two optimizers...')
    # Negative alpha: the generator ascends the discriminator's loss
    # (see the comment thread below).
    g_optimizer = chainer.optimizers.Adam(alpha=-1e-5)
    g_optimizer.setup(model.collect_generator_parameters())
    d_optimizer = chainer.optimizers.Adam(alpha=1e-5)
    d_optimizer.setup(model.collect_discriminator_parameters())

    # Fixed noise used to render a 10x10 preview grid at each epoch.
    example_z = model.make_z(100)
    if 0 <= args.gpu:
        example_z = cuda.to_gpu(example_z, args.gpu)

    iteration = 0
    for epoch in xrange(1, args.epoch + 1):
        logger.info('epoch %i', epoch)
        perm = np.random.permutation(x_train.shape[0])
        sum_dloss, sum_gloss = 0.0, 0.0

        example_x = cuda.to_cpu(model.generate(example_z, train=False))
        example_x = example_x.reshape(10, 10, 28, 28).transpose([0, 2, 1, 3])
        example_x = np.clip(
            255 * example_x.reshape(280, 280), 0.0, 255.0).astype(np.uint8)
        img = Image.fromarray(example_x)
        img.save(os.path.join(args.image, "{:03}.png".format(epoch)))

        for i in xrange(0, x_train.shape[0], args.batch):
            iteration += 1
            batchsize = min(i + args.batch, x_train.shape[0]) - i

            # update discriminator
            x_batch = np.empty(
                (2 * batchsize, x_train.shape[1]), dtype=np.float32)
            t_batch = np.empty((2 * batchsize, 1), dtype=np.int32)
            z_batch = model.make_z(batchsize)
            if 0 <= args.gpu:
                z_batch = cuda.to_gpu(z_batch, args.gpu)
            x_batch[:batchsize] = cuda.to_cpu(model.generate(z_batch))
            x_batch[batchsize:] = x_train[perm[i:i + batchsize]]
            t_batch[:batchsize] = 0  # generated samples
            t_batch[batchsize:] = 1  # real samples
            if 0 <= args.gpu:
                x_batch = cuda.to_gpu(x_batch, args.gpu)
                t_batch = cuda.to_gpu(t_batch, args.gpu)

            d_optimizer.zero_grads()
            # g_optimizer.zero_grads()
            dloss = model.forward_xy(x_batch, t_batch)
            dloss.backward()
            d_optimizer.update()

            dloss_data = float(cuda.to_cpu(dloss.data))
            sum_dloss += dloss_data * batchsize * 2

            # update generator
            z_batch = model.make_z(batchsize)
            # Fakes keep label 0; the negative alpha turns the update into
            # maximizing the discriminator's loss on them.
            t_batch = np.zeros((batchsize, 1), dtype=np.int32)
            if 0 <= args.gpu:
                z_batch = cuda.to_gpu(z_batch, args.gpu)
                t_batch = cuda.to_gpu(t_batch, args.gpu)

            # d_optimizer.zero_grads()
            g_optimizer.zero_grads()
            gloss = model.forward_zy(z_batch, t_batch)
            gloss.backward()
            g_optimizer.update()

            gloss_data = float(cuda.to_cpu(gloss.data))
            sum_gloss += gloss_data * batchsize

            if iteration % args.display == 0:
                logger.info(
                    'loss D:%.3e G:%.3e iter:%i',
                    dloss_data, gloss_data, iteration)

        ave_dloss = sum_dloss / (2 * x_train.shape[0])
        ave_gloss = sum_gloss / x_train.shape[0]
        logger.info(
            'train mean loss D:%.3e G:%.3e epoch:%i',
            ave_dloss, ave_gloss, epoch)

    logger.info('done. now pickling model...')
    with open(args.dst, 'wb') as fp:
        pickle.dump(model, fp)


if __name__ == "__main__":
    sys.exit(main())
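
For reference, a hypothetical sampling sketch after training, under the same chainer 1.x setup. It assumes model.pkl was written by the script above and that GANModel is defined in the calling module (for example, the code is run from the script above), since pickle needs the class to rebuild the object:

import pickle

import numpy as np
from PIL import Image

with open('model.pkl', 'rb') as fp:
    model = pickle.load(fp)             # note: a model pickled on GPU needs CUDA available

z = model.make_z(100)                   # 100 latent vectors of dimension n_hidden
x = model.generate(z, train=False)      # generated images, shape (100, 784)

# Tile the samples into a 10x10 grid, exactly as the training loop does for its previews.
grid = x.reshape(10, 10, 28, 28).transpose([0, 2, 1, 3]).reshape(280, 280)
Image.fromarray(np.clip(255 * grid, 0, 255).astype(np.uint8)).save('samples.png')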
@vebmaylrie asked: Is Line 169 np.ones(*)? This is because the generator should be trained to make the discriminator misclassify the sampled data. [Line 169 is the t_batch = np.zeros(...) label line in the generator update.]

@rezoo: Sorry for replying late. Exactly: in order to make the discriminator misclassify the sampled data, I set the alpha of the generator's optimizer to a negative value.
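
The sign trick in that reply can be checked with a toy calculation. The sketch below is illustrative only (plain numpy, independent of the gist): for sigmoid cross entropy L(h, t) with dL/dh = sigmoid(h) - t, a gradient step on the fake-labelled loss taken with a negative learning rate moves the discriminator's logit h for a generated sample in the same direction as an ordinary step on the real-labelled loss would. Adam rescales the step size, but like SGD it reverses direction when alpha is negative.

import numpy as np

def sigmoid(h):
    return 1.0 / (1.0 + np.exp(-h))

h = -0.7                                # D's logit for a generated sample (currently "fake")

# (a) what the gist does: keep labels t = 0 but use alpha = -1e-5 (ascent on L(h, 0))
step_a = -(-1e-5) * (sigmoid(h) - 0.0)  # SGD-style step: -alpha * dL/dh

# (b) the usual alternative: labels t = 1 with a positive alpha (descent on L(h, 1))
step_b = -(+1e-5) * (sigmoid(h) - 1.0)

print(step_a > 0, step_b > 0)           # True True: both updates push D(G(z)) toward "real"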