Image super-resolution (ISR) reconstructs a high-resolution image from a lower-resolution version of it. ESRGAN (Enhanced Super-Resolution Generative Adversarial Network) is a widely used model for this task, known for producing realistic high-resolution images while preserving fine detail.
This article is a step-by-step tutorial on implementing ESRGAN with PyTorch. We cover the architecture, its main components, and the implementation process.
What is ESRGAN?
ESRGAN is an improved version of SRGAN, the model that introduced GANs to super-resolution. It addresses several of SRGAN's shortcomings to improve visual quality and recover finer texture in the generated images. The main enhancements are:
- Residual-in-Residual Dense Block (RRDB): Replaces the conventional residual blocks and removes batch normalization from the generator, which improves feature learning and training stability.
- Improved perceptual loss: ESRGAN computes the perceptual loss on VGG features taken before activation rather than after, which yields more realistic textures.
- Relativistic GAN: Instead of the standard GAN loss, ESRGAN uses a relativistic loss, in which the discriminator estimates whether a real image is relatively more realistic than a fake one, rather than whether an image is real or fake in absolute terms.
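The relativistic idea in the last bullet can be sketched in a few lines. This is a minimal illustration of a relativistic average loss, assuming the discriminator returns raw logits; the function names `relativistic_d_loss` and `relativistic_g_loss` are hypothetical and not part of any library.

```python
import torch
import torch.nn.functional as F

def relativistic_d_loss(real_logits, fake_logits):
    """Discriminator: real images should score above the average fake, and fakes below the average real."""
    real_vs_fake = real_logits - fake_logits.mean()
    fake_vs_real = fake_logits - real_logits.mean()
    loss_real = F.binary_cross_entropy_with_logits(real_vs_fake, torch.ones_like(real_vs_fake))
    loss_fake = F.binary_cross_entropy_with_logits(fake_vs_real, torch.zeros_like(fake_vs_real))
    return loss_real + loss_fake

def relativistic_g_loss(real_logits, fake_logits):
    """Generator: the same comparison with the targets swapped."""
    real_logits = real_logits.detach()  # no gradient is needed through the real branch
    real_vs_fake = real_logits - fake_logits.mean()
    fake_vs_real = fake_logits - real_logits.mean()
    loss_real = F.binary_cross_entropy_with_logits(real_vs_fake, torch.zeros_like(real_vs_fake))
    loss_fake = F.binary_cross_entropy_with_logits(fake_vs_real, torch.ones_like(fake_vs_real))
    return loss_real + loss_fake
```

Because each score is measured against the mean score of the opposite class, the generator receives gradients from both real and generated batches rather than from the generated batch alone.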
Components of ESRGAN Architecture
ESRGAN consists of two networks, a generator and a discriminator, trained together with the help of a perceptual loss.
- Generator: Converts low-resolution (LR) images into high-resolution (HR) images. It uses RRDB blocks for feature extraction, followed by upsampling layers. Most of ESRGAN's architectural improvement is located here.
- Discriminator: Judges whether an image is a real HR image or a generated super-resolved (SR) one. It is trained to tell the two apart.
- Perceptual Loss: Measures how perceptually similar the generated image is to the ground truth by comparing feature maps extracted from a pre-trained VGG network, which encourages outputs that look natural to the human eye. SRGAN already used a perceptual loss; ESRGAN refines it by taking the features before activation.
Loss Functions Used in ESRGAN
- Adversarial Loss: Pushes the generated image to look as realistic as possible. It is computed from the discriminator's outputs on generated and real images.
- Content Loss: Measures the pixel-level difference between the generated image and the high-resolution ground truth. The ESRGAN paper uses an L1 loss for this term; the code in this tutorial uses Mean Squared Error (MSE).
- Perceptual Loss: As described above, compares feature maps of the generated image and the ground truth, extracted with a pre-trained network such as VGG.
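The three terms above are usually combined into a single weighted generator objective. The helper below is a hypothetical sketch of that combination; the default weights follow the values reported in the ESRGAN paper and should be treated as starting points rather than fixed constants.

```python
def generator_total_loss(perceptual, adversarial, content,
                         adv_weight=5e-3, content_weight=1e-2):
    """Weighted sum of the three generator loss terms.

    Works with Python floats or scalar torch tensors. The weights are
    assumptions taken from the ESRGAN paper's reported settings.
    """
    return perceptual + adv_weight * adversarial + content_weight * content
```

With these weights the perceptual term dominates, while the adversarial and pixel-level terms act as smaller corrections.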
Implementing the ESRGAN Model in PyTorch
In this section, we implement ESRGAN in PyTorch through the following steps:
1. Set Up the Environment
Make sure PyTorch is installed.
pip install torch torchvision
You also need matplotlib, Pillow (PIL), and NumPy for loading and displaying images.
pip install matplotlib pillow numpy
2. Data Preparation
Prepare a dataset containing both low-resolution and high-resolution images. You can use an existing dataset such as DIV2K, or build your own by downscaling high-resolution images to produce their low-resolution counterparts.
Python
from PIL import Image
import numpy as np
import os
from torch.utils.data import Dataset, DataLoader

class ImageDataset(Dataset):
    """Loads every file found directly inside root_dir as an image."""
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_files = os.listdir(root_dir)

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self.image_files[idx])
        # Force 3 channels so grayscale or RGBA files do not break batching
        image = Image.open(img_name).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image
3. ESRGAN Model Architecture
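Next, we define the ESRGAN model, starting with the Residual-in-Residual Dense Block (RRDB). The block below is a simplified stand-in: three convolutions with a single skip connection, without the dense connections of the full RRDB.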
Python
import torch
import torch.nn as nn

class RRDB(nn.Module):
    # Simplified residual block: three 3x3 convolutions plus one skip connection
    def __init__(self, in_channels):
        super(RRDB, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.relu = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        out = self.relu(self.conv1(x))
        out = self.relu(self.conv2(out))
        out = self.conv3(out)
        return x + out  # Residual connection

class Generator(nn.Module):
    # NOTE: this simplified generator has no upsampling layers, so the
    # output has the same spatial size as the input.
    def __init__(self, in_channels=3, num_rrdb=23):
        super(Generator, self).__init__()
        self.initial_conv = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1)
        self.rrdb_blocks = nn.Sequential(*[RRDB(64) for _ in range(num_rrdb)])
        self.final_conv = nn.Conv2d(64, in_channels, kernel_size=3, padding=1)

    def forward(self, x):
        initial_feature = self.initial_conv(x)
        out = self.rrdb_blocks(initial_feature)
        out = self.final_conv(out)
        return out
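For comparison, a block closer to the RRDB described earlier uses densely connected convolutions, residual scaling, and a generator that actually upsamples. The sketch below is one possible arrangement under stated assumptions: the names `DenseBlock`, `RRDBFull`, and `UpscalingGenerator`, the growth width of 32, the residual scale of 0.2, and nearest-neighbor upsampling are illustrative choices, not the tutorial's own code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DenseBlock(nn.Module):
    """Five 3x3 convs; each conv sees the block input plus all earlier outputs."""
    def __init__(self, channels=64, growth=32, res_scale=0.2):
        super().__init__()
        self.convs = nn.ModuleList([
            nn.Conv2d(channels + i * growth, growth if i < 4 else channels, 3, padding=1)
            for i in range(5)
        ])
        self.act = nn.LeakyReLU(0.2)
        self.res_scale = res_scale

    def forward(self, x):
        feats = [x]
        for i, conv in enumerate(self.convs):
            out = conv(torch.cat(feats, dim=1))
            if i < 4:  # the last conv has no activation
                out = self.act(out)
                feats.append(out)
        return x + self.res_scale * out

class RRDBFull(nn.Module):
    """Three dense blocks wrapped in an outer scaled residual connection."""
    def __init__(self, channels=64, growth=32, res_scale=0.2):
        super().__init__()
        self.blocks = nn.Sequential(*[DenseBlock(channels, growth, res_scale) for _ in range(3)])
        self.res_scale = res_scale

    def forward(self, x):
        return x + self.res_scale * self.blocks(x)

class UpscalingGenerator(nn.Module):
    """RRDB trunk followed by repeated 2x nearest-neighbor upsampling."""
    def __init__(self, in_channels=3, channels=64, num_rrdb=23, scale=4):
        super().__init__()
        if scale not in (2, 4):
            raise ValueError("this sketch supports scale 2 or 4 only")
        self.head = nn.Conv2d(in_channels, channels, 3, padding=1)
        self.body = nn.Sequential(*[RRDBFull(channels) for _ in range(num_rrdb)])
        self.body_conv = nn.Conv2d(channels, channels, 3, padding=1)
        self.up_convs = nn.ModuleList(
            [nn.Conv2d(channels, channels, 3, padding=1) for _ in range(scale // 2)]
        )
        self.act = nn.LeakyReLU(0.2)
        self.tail = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(channels, in_channels, 3, padding=1),
        )

    def forward(self, x):
        feat = self.head(x)
        feat = feat + self.body_conv(self.body(feat))  # long skip over the trunk
        for conv in self.up_convs:
            feat = self.act(conv(F.interpolate(feat, scale_factor=2, mode="nearest")))
        return self.tail(feat)
```

With `scale=4`, a 16x16 input produces a 64x64 output, which is the behavior expected of a super-resolution generator.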
4. Discriminator
Next, we define the discriminator, whose job is to distinguish generated super-resolved images from real high-resolution ones.
Python
class Discriminator(nn.Module):
    def __init__(self, in_channels=3):
        super(Discriminator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Conv2d(in_feat, out_feat, 4, stride=2, padding=1)]
            if normalize:
                layers.append(nn.BatchNorm2d(out_feat))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(in_channels, 64, normalize=False),
            *block(64, 128),
            *block(128, 256),
            *block(256, 512),
            nn.Conv2d(512, 1, 3, stride=1, padding=1)
        )

    def forward(self, img):
        return self.model(img)
5. Loss Functions
We now define the loss functions used for ESRGAN training. The content loss and perceptual loss are defined as modules below; the adversarial loss is computed directly in the training loop.
Python
import torch.nn.functional as F

class ContentLoss(nn.Module):
    def __init__(self):
        super(ContentLoss, self).__init__()

    def forward(self, sr, hr):
        return F.mse_loss(sr, hr)

class PerceptualLoss(nn.Module):
    def __init__(self, vgg_model):
        super(PerceptualLoss, self).__init__()
        # For VGG19, features[:36] ends at the ReLU after the last conv layer
        self.vgg = vgg_model.features[:36]  # Use pre-trained VGG features
        self.vgg.eval()
        # VGG is a fixed feature extractor, so freeze its weights
        for param in self.vgg.parameters():
            param.requires_grad = False

    def forward(self, sr, hr):
        sr_features = self.vgg(sr)
        hr_features = self.vgg(hr)
        return F.mse_loss(sr_features, hr_features)
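A variant of the perceptual loss normalizes its inputs with the ImageNet statistics that torchvision's pre-trained VGG models expect, and skips gradient tracking for the ground-truth branch. The following is a hedged sketch: `NormalizedPerceptualLoss` is a hypothetical name, and the class accepts any feature extractor module so that it does not depend on a particular torchvision version.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class NormalizedPerceptualLoss(nn.Module):
    """MSE between features of ImageNet-normalized SR and HR images."""
    def __init__(self, feature_extractor):
        super().__init__()
        self.features = feature_extractor.eval()
        for param in self.features.parameters():
            param.requires_grad_(False)  # fixed feature extractor
        # ImageNet channel statistics, stored as buffers so .to(device) moves them
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, sr, hr):
        sr_feat = self.features((sr - self.mean) / self.std)
        with torch.no_grad():  # the target needs no gradient
            hr_feat = self.features((hr - self.mean) / self.std)
        return F.mse_loss(sr_feat, hr_feat)
```

With torchvision's VGG19, passing `vgg.features[:35]` instead of `vgg.features[:36]` stops at the last convolution, giving the before-activation features that ESRGAN uses. The inputs are assumed to be RGB tensors in the range [0, 1].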
6. Training Loop
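Finally, we define the training loop, in which the generator and discriminator are updated alternately. Note that in this simplified loop the generator is optimized with content and perceptual loss only; the discriminator's output is not fed back into the generator loss.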
Python
def train(generator, discriminator, dataloader, num_epochs, optimizer_G, optimizer_D, criterion_content, criterion_perceptual):
    for epoch in range(num_epochs):
        for i, img in enumerate(dataloader):
            # Train Generator
            # NOTE: the same batch serves as both input and target here
            optimizer_G.zero_grad()
            sr_image = generator(img)
            content_loss = criterion_content(sr_image, img)
            perceptual_loss = criterion_perceptual(sr_image, img)
            g_loss = content_loss + perceptual_loss
            g_loss.backward()
            optimizer_G.step()

            # Train Discriminator
            optimizer_D.zero_grad()
            real_output = discriminator(img)
            # detach() keeps discriminator gradients out of the generator
            fake_output = discriminator(sr_image.detach())
            d_loss = F.binary_cross_entropy_with_logits(real_output, torch.ones_like(real_output)) + \
                     F.binary_cross_entropy_with_logits(fake_output, torch.zeros_like(fake_output))
            d_loss.backward()
            optimizer_D.step()

            if i % 10 == 0:
                print(f"Epoch {epoch}/{num_epochs}, Step {i}, G Loss: {g_loss.item()}, D Loss: {d_loss.item()}")
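For a loop that trains on separate LR and HR batches and also feeds the discriminator's judgment back into the generator, a single training step might look like the sketch below. The function name `train_step`, the L1 content term, and the `adv_weight` default are assumptions for illustration; the perceptual term is left out to keep the example short.

```python
import torch
import torch.nn.functional as F

def train_step(generator, discriminator, lr_img, hr_img, opt_g, opt_d, adv_weight=5e-3):
    """One generator update followed by one discriminator update."""
    # Generator: match the HR target and try to fool the discriminator
    opt_g.zero_grad()
    sr_img = generator(lr_img)
    fake_logits = discriminator(sr_img)
    content = F.l1_loss(sr_img, hr_img)
    adversarial = F.binary_cross_entropy_with_logits(fake_logits, torch.ones_like(fake_logits))
    g_loss = content + adv_weight * adversarial
    g_loss.backward()
    opt_g.step()

    # Discriminator: zero_grad also clears gradients left by the generator step
    opt_d.zero_grad()
    real_logits = discriminator(hr_img)
    fake_logits = discriminator(sr_img.detach())
    d_loss = (
        F.binary_cross_entropy_with_logits(real_logits, torch.ones_like(real_logits))
        + F.binary_cross_entropy_with_logits(fake_logits, torch.zeros_like(fake_logits))
    )
    d_loss.backward()
    opt_d.step()
    return g_loss.item(), d_loss.item()
```

The step expects a generator whose output matches the spatial size of `hr_img`, so it requires a generator that upsamples its input.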
7. Running Inference
Once the model is trained, you can use the generator to upscale low-resolution images.
Python
def upscale_image(generator, lr_image):
    generator.eval()
    with torch.no_grad():
        sr_image = generator(lr_image)
    return sr_image
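The generator returns a float tensor whose values can fall slightly outside [0, 1]. A small helper such as the following can convert that output to a PIL image for saving or display; `tensor_to_pil` is a hypothetical name, and the sketch assumes a 3-channel tensor scaled to [0, 1].

```python
import numpy as np
import torch
from PIL import Image

def tensor_to_pil(sr_tensor):
    """Convert a (3, H, W) or (1, 3, H, W) float tensor to an RGB PIL image."""
    img = sr_tensor.detach().cpu()
    if img.dim() == 4:
        img = img[0]  # take the first image of the batch
    img = img.clamp(0.0, 1.0)  # out-of-range values would wrap around as uint8
    array = (img.permute(1, 2, 0).numpy() * 255.0).round().astype(np.uint8)
    return Image.fromarray(np.ascontiguousarray(array))
```

For example, `tensor_to_pil(upscale_image(generator, lr_image)).save("sr.png")` writes the first image of the batch to disk.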
The complete example below puts these steps together: installing dependencies, loading data, defining the ESRGAN architecture, training the model, and testing it on a sample image.
For simplicity, we use a pre-trained VGG model for the perceptual loss and a small dataset for demonstration.
Import Required Libraries
Python
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import os
from torchvision.utils import save_image
Data Preparation
We use the BSDS300 dataset (Berkeley Segmentation Dataset) for training.
!wget https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/BSDS300-images.tgz
!tar -xvzf BSDS300-images.tgz
Python
# Define dataset class for loading and preprocessing images
class ImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_files = os.listdir(root_dir)

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self.image_files[idx])
        image = Image.open(img_name).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image

# Define transforms (resize images for simplicity)
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor()
])

# Create Dataset and DataLoader
dataset = ImageDataset(root_dir="BSDS300/images/train", transform=transform)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
Define ESRGAN Model Architecture
Python
class RRDB(nn.Module):
    def __init__(self, in_channels):
        super(RRDB, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.relu = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        out = self.relu(self.conv1(x))
        out = self.relu(self.conv2(out))
        out = self.conv3(out)
        return x + out  # Residual connection

class Generator(nn.Module):
    def __init__(self, in_channels=3, num_rrdb=23):
        super(Generator, self).__init__()
        self.initial_conv = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1)
        self.rrdb_blocks = nn.Sequential(*[RRDB(64) for _ in range(num_rrdb)])
        self.final_conv = nn.Conv2d(64, in_channels, kernel_size=3, padding=1)

    def forward(self, x):
        initial_feature = self.initial_conv(x)
        out = self.rrdb_blocks(initial_feature)
        out = self.final_conv(out)
        return out

class Discriminator(nn.Module):
    def __init__(self, in_channels=3):
        super(Discriminator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Conv2d(in_feat, out_feat, 4, stride=2, padding=1)]
            if normalize:
                layers.append(nn.BatchNorm2d(out_feat))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(in_channels, 64, normalize=False),
            *block(64, 128),
            *block(128, 256),
            *block(256, 512),
            nn.Conv2d(512, 1, 3, stride=1, padding=1)
        )

    def forward(self, img):
        return self.model(img)
Loss Functions (Content and Perceptual Loss)
We use a pre-trained VGG19 model for the perceptual loss.
Python
class ContentLoss(nn.Module):
    def __init__(self):
        super(ContentLoss, self).__init__()

    def forward(self, sr, hr):
        return F.mse_loss(sr, hr)

class PerceptualLoss(nn.Module):
    def __init__(self, vgg_model):
        super(PerceptualLoss, self).__init__()
        self.vgg = vgg_model.features[:36]  # Use pre-trained VGG features
        self.vgg.eval()
        # VGG is a fixed feature extractor, so freeze its weights
        for param in self.vgg.parameters():
            param.requires_grad = False

    def forward(self, sr, hr):
        sr_features = self.vgg(sr)
        hr_features = self.vgg(hr)
        return F.mse_loss(sr_features, hr_features)
Training Loop
Python
def train(generator, discriminator, dataloader, num_epochs, optimizer_G, optimizer_D, criterion_content, criterion_perceptual, device):
    generator.to(device)
    discriminator.to(device)
    for epoch in range(num_epochs):
        for i, img in enumerate(dataloader):
            img = img.to(device)

            # Generate super-resolved image
            # NOTE: the same batch serves as both input and target here
            sr_image = generator(img)

            # Train Generator
            optimizer_G.zero_grad()
            content_loss = criterion_content(sr_image, img)
            perceptual_loss = criterion_perceptual(sr_image, img)
            g_loss = content_loss + perceptual_loss
            g_loss.backward()
            optimizer_G.step()

            # Train Discriminator
            optimizer_D.zero_grad()
            real_output = discriminator(img)
            # detach() keeps discriminator gradients out of the generator
            fake_output = discriminator(sr_image.detach())
            d_loss = F.binary_cross_entropy_with_logits(real_output, torch.ones_like(real_output)) + \
                     F.binary_cross_entropy_with_logits(fake_output, torch.zeros_like(fake_output))
            d_loss.backward()
            optimizer_D.step()

            if i % 10 == 0:
                print(f"Epoch {epoch}/{num_epochs}, Step {i}, G Loss: {g_loss.item()}, D Loss: {d_loss.item()}")
Initialize Model and Start Training
Python
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Initialize Generator, Discriminator, and Optimizers
generator = Generator()
discriminator = Discriminator()
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002)
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002)
# Load pre-trained VGG model for Perceptual Loss
# (`pretrained=True` is deprecated in torchvision >= 0.13; the `weights` argument replaces it)
vgg = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).to(device)
criterion_content = ContentLoss()
criterion_perceptual = PerceptualLoss(vgg)
# Train ESRGAN
train(generator, discriminator, dataloader, num_epochs=2, optimizer_G=optimizer_G, optimizer_D=optimizer_D,
      criterion_content=criterion_content, criterion_perceptual=criterion_perceptual, device=device)
Output:
Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth
100%|██████████| 548M/548M [00:07<00:00, 72.1MB/s]
Epoch 0/2, Step 0, G Loss: 0.6165198087692261, D Loss: 1.3769948482513428
Epoch 0/2, Step 10, G Loss: 0.21817927062511444, D Loss: 1.336503267288208
Epoch 0/2, Step 20, G Loss: 0.12613624334335327, D Loss: 1.1241037845611572
Epoch 0/2, Step 30, G Loss: 0.18434345722198486, D Loss: 0.7252156138420105
Epoch 0/2, Step 40, G Loss: 0.05957853049039841, D Loss: 0.7124162912368774
Epoch 1/2, Step 0, G Loss: 0.03649333491921425, D Loss: 0.7204165458679199
Epoch 1/2, Step 10, G Loss: 0.04035758227109909, D Loss: 1.0853936672210693
Epoch 1/2, Step 20, G Loss: 0.02555007115006447, D Loss: 0.5146785974502563
Epoch 1/2, Step 30, G Loss: 0.034172821789979935, D Loss: 0.27036136388778687
Epoch 1/2, Step 40, G Loss: 0.024695610627532005, D Loss: 0.35952311754226685
Testing with a Sample Image
Python
# Load a test image
test_image = Image.open("BSDS300/images/test/3096.jpg").convert("RGB")
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor()
])
test_image = transform(test_image).unsqueeze(0).to(device)
# Generate super-resolved image
generator.eval()
with torch.no_grad():
    sr_image = generator(test_image)
# Save and Display Results
save_image(sr_image, "sr_image.png")
save_image(test_image, "lr_image.png")
# Show images (clamp the output to [0, 1], the range imshow expects for float RGB)
plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plt.title("Low-Resolution Image")
plt.imshow(np.transpose(test_image.squeeze().cpu().numpy(), (1, 2, 0)))
plt.subplot(1, 2, 2)
plt.title("Super-Resolved Image")
plt.imshow(np.transpose(sr_image.squeeze().clamp(0, 1).cpu().numpy(), (1, 2, 0)))
plt.show()
Output:
[Figure: ESRGAN Model in PyTorch]
Conclusion
ESRGAN is an effective technique for improving the quality of upscaled images. Implementing it in PyTorch involves building the RRDB-based generator and the discriminator, then training them with adversarial, content, and perceptual losses. The resulting model can be applied to tasks such as image restoration, medical imaging, and video processing.