diff --git a/pix2pix/pix2pix.py b/pix2pix/pix2pix.py
index 2e8652a341..257f056c08 100644
--- a/pix2pix/pix2pix.py
+++ b/pix2pix/pix2pix.py
@@ -29,7 +29,6 @@ def __init__(self):
         self.data_loader = DataLoader(dataset_name=self.dataset_name,
                                       img_res=(self.img_rows, self.img_cols))
 
-
         # Calculate output shape of D (PatchGAN)
         patch = int(self.img_rows / 2**4)
         self.disc_patch = (patch, patch, 1)
@@ -56,18 +55,17 @@ def __init__(self):
 
         # Input images and their conditioning images
         img_A = Input(shape=self.img_shape)
-        img_B = Input(shape=self.img_shape)
 
         # By conditioning on B generate a fake version of A
-        fake_A = self.generator(img_B)
+        fake_A = self.generator(img_A)
 
         # For the combined model we will only train the generator
         self.discriminator.trainable = False
 
         # Discriminators determines validity of translated images / condition pairs
-        valid = self.discriminator([fake_A, img_B])
+        valid = self.discriminator([fake_A, img_A])
 
-        self.combined = Model(inputs=[img_A, img_B], outputs=[valid, fake_A])
+        self.combined = Model(inputs=img_A, outputs=[valid, fake_A])
         self.combined.compile(loss=['mse', 'mae'],
                               loss_weights=[1, 100],
                               optimizer=optimizer)
@@ -171,7 +169,7 @@ def train(self, epochs, batch_size=1, sample_interval=50):
             # -----------------
 
             # Train the generators
-            g_loss = self.combined.train_on_batch([imgs_A, imgs_B], [valid, imgs_A])
+            g_loss = self.combined.train_on_batch(imgs_B, [valid, imgs_A])
 
             elapsed_time = datetime.datetime.now() - start_time
             # Plot the progress
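
For context, the sketch below shows in minimal, self-contained form why the combined model can drop the second Input after this patch: the ground-truth image A only enters the generator update as the target of the L1 (mae) term, so it can be passed in the target list of train_on_batch instead of being a model input. This is not the repository's actual pipeline; the tiny Conv2D generator and discriminator are placeholder stand-ins for the U-Net and PatchGAN in pix2pix.py, and the 256x256 input size, the 16x16 patch shape, the Adam settings, and the random arrays are assumptions used only to make the example run.

# Minimal sketch of the combined-model wiring after this change.
# Placeholder generator/discriminator, NOT the repository's U-Net/PatchGAN.
import numpy as np
from keras.layers import Input, Conv2D, Concatenate
from keras.models import Model
from keras.optimizers import Adam

img_shape = (256, 256, 3)          # assumed default input size
disc_patch = (16, 16, 1)           # PatchGAN output: 256 / 2**4 = 16

# Placeholder generator: conditioning image in, translated image out
g_in = Input(shape=img_shape)
g_out = Conv2D(3, 3, padding='same', activation='tanh')(g_in)
generator = Model(g_in, g_out)

# Placeholder discriminator: (translated, conditioning) pair in, validity map out
d_in_a = Input(shape=img_shape)
d_in_b = Input(shape=img_shape)
d = Concatenate()([d_in_a, d_in_b])
d_out = Conv2D(1, 3, strides=16, padding='same')(d)
discriminator = Model([d_in_a, d_in_b], d_out)
discriminator.compile(loss='mse', optimizer=Adam(0.0002, 0.5))

# Combined model: only the conditioning image is a model input now
img_A = Input(shape=img_shape)     # fed with the conditioning batch (imgs_B)
fake_A = generator(img_A)
discriminator.trainable = False    # only the generator is trained here
valid = discriminator([fake_A, img_A])
combined = Model(inputs=img_A, outputs=[valid, fake_A])
combined.compile(loss=['mse', 'mae'], loss_weights=[1, 100],
                 optimizer=Adam(0.0002, 0.5))

# One generator step: the real image A appears only in the target list (L1 term)
imgs_A = np.random.rand(1, *img_shape).astype('float32')
imgs_B = np.random.rand(1, *img_shape).astype('float32')
valid_labels = np.ones((1,) + disc_patch)
g_loss = combined.train_on_batch(imgs_B, [valid_labels, imgs_A])

Note that the single remaining Input placeholder keeps the name img_A in the patch even though it is fed the conditioning batch imgs_B at train time; the wiring is equivalent, but removing img_A and keeping img_B instead would keep the placeholder's name aligned with the data it receives and with the "conditioning on B" comment above it.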