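Generating Synthetic Images of Marine Plastic Using Generative Adversarial Networks
Problem Statement
Over the past decade, marine plastic pollution has been at the forefront of climate issues. Plastic in the ocean not only kills marine life through strangulation or starvation, but is also cited as a major factor in warming the oceans by trapping carbon dioxide.
In recent years, the nonprofit The Ocean Cleanup has made multiple attempts to remove the plastic that surrounds our oceans. The problem with many cleanup processes is that they require manual labor and are not cost-effective.
A good deal of research has sought to automate this process by detecting marine debris with computer vision and deep learning, and by carrying out the cleanup with ROVs (remotely operated vehicles) and AUVs (autonomous underwater vehicles).
The main problem with this approach is the availability of datasets for training computer vision models. The JAMSTEC-JEDI dataset collects marine debris from the seafloor off the coast of Japan.
Beyond this dataset, however, there is a huge gap in dataset availability. For this reason, I turned to generative adversarial networks.
DCGANs in particular are geared toward synthesizing datasets which, in theory, could come to closely resemble real datasets over time.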
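GANs and DCGANs
In 2014, Ian Goodfellow et al. proposed GANs, or generative adversarial networks. A GAN consists of two simple components, called the generator and the discriminator.
The process works as follows: the generator's role is to produce new data, while the discriminator's role is to tell generated data apart from real data. Ideally, the discriminator can no longer distinguish generated data from real data, which yields ideal synthetic data points.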
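In the original GAN formulation by Goodfellow et al., this contest between the two components is written as a two-player minimax game. The notation below comes from that paper rather than from this article: x is a real sample drawn from the data distribution, z is a noise vector drawn from a prior, G(z) is a generated sample, and D(·) is the discriminator's estimated probability that its input is real.

```latex
\min_{G} \max_{D} V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}(x)}\left[\log D(x)\right]
  + \mathbb{E}_{z \sim p_{z}(z)}\left[\log\left(1 - D(G(z))\right)\right]
```

The discriminator tries to maximize V by scoring real samples high and generated samples low, while the generator tries to minimize it by making D(G(z)) large.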
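A DCGAN is a direct extension of the GAN architecture above, except that it explicitly uses deep convolutional layers in the discriminator and deep transposed-convolutional layers in the generator. It was first described by Radford et al. in the paper "Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks". The discriminator is made up of strided convolution layers, while the generator is made up of transposed convolution layers.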
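To make the role of the two layer types concrete, here is a small sketch of the standard output-size arithmetic for strided and transposed convolutions (no dilation, no output padding). The layer settings used (kernel 4, stride 2, padding 1) and the starting sizes are assumptions chosen for illustration, and the helper names are hypothetical.

```python
def conv_out(size, kernel, stride, padding):
    """Spatial output size of a strided convolution (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose_out(size, kernel, stride, padding):
    """Spatial output size of a transposed convolution
    (no dilation, no output padding)."""
    return (size - 1) * stride - 2 * padding + kernel


# Assumed settings for illustration: kernel 4, stride 2, padding 1.
# A strided convolution halves the spatial size at each layer.
down = [64]
for _ in range(3):
    down.append(conv_out(down[-1], kernel=4, stride=2, padding=1))

# A transposed convolution with the same settings doubles it.
up = [4]
for _ in range(3):
    up.append(conv_transpose_out(up[-1], kernel=4, stride=2, padding=1))

print("discriminator-style downsampling:", down)
print("generator-style upsampling:", up)
```

This is why a discriminator can shrink an image down to a single score while a generator can grow a small feature map into a full image, without any pooling or fixed upsampling layers.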
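PyTorch Implementation
In this approach, a DCGAN is applied to the DeepTrash dataset. If you are not familiar with the DeepTrash dataset, consider reading the paper.
DeepTrash is a collection of images of plastic in the surface and deep-sea layers of the ocean, intended for marine plastic detection using computer vision.
Let's get coding!
Code
Setup
We start by importing the basic libraries needed to build the GAN model, such as Matplotlib and NumPy.
We will also use PyTorch's tooling (such as its neural network modules and transforms).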
from __future__ import print_function
#%matplotlib inline
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
import wandb  # Weights & Biases, used below for experiment tracking
# Set random seed for reproducibility
manualSeed = 999
#manualSeed = random.randint(1, 10000) # use if you want new results
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)
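Initializing the Hyperparameters
This step is fairly simple. We set the hyperparameters used to train the neural networks. They come directly from the paper and from the PyTorch training tutorial.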
# Root directory for dataset
# NOTE you don't have to create this. It will be created for you in the next block!
dataroot = "/content/pgan"
# Number of workers for dataloader
workers = 4
# Batch size during training
batch_size = 128
# Spatial size of training images. All images will be resized to this
# size using a transformer.
image_size = 64
# Number of channels in the training images. For color images this is 3
nc = 3
# Size of z latent vector (i.e. size of generator input)
nz = 100
# Size of feature maps in generator
ngf = 64
# Size of feature maps in discriminator
ndf = 64
# Number of training epochs
num_epochs = 300
# Learning rate for optimizers
lr = 0.0002
# Beta1 hyperparam for Adam optimizers
beta1 = 0.5
# Number of GPUs available. Use 0 for CPU mode.
ngpu = 1
Generator and Discriminator
Now we define the architectures of the generator and the discriminator.
# Generator: maps a latent vector of shape (nz, 1, 1) to an (nc, 32, 32) image.
# Note: with these layers the output is 32 x 32; image_size above is not used here.
class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # (nz) x 1 x 1 -> (ngf*8) x 4 x 4
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # (ngf*8) x 4 x 4 -> (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # (ngf*4) x 8 x 8 -> (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # (ngf*2) x 16 x 16 -> (nc) x 32 x 32, values in [-1, 1]
            nn.ConvTranspose2d(ngf * 2, nc, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, input):
        return self.main(input)


# Discriminator: maps an (nc, 32, 32) image to a single probability of being real
class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # (nc) x 32 x 32 -> (ndf) x 16 x 16
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # (ndf) x 16 x 16 -> (ndf*2) x 8 x 8
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # (ndf*2) x 8 x 8 -> (ndf*4) x 4 x 4
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # (ndf*4) x 4 x 4 -> 1 x 1 x 1
            nn.Conv2d(ndf * 4, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)
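The training script further down calls a weights_init function that is never defined in this article. The following is a minimal sketch of such a helper, modeled on the initialization recommended in the DCGAN paper (weights drawn from a normal distribution with mean 0 and standard deviation 0.02) and on the helper of the same name in the PyTorch DCGAN tutorial. Whether the author used exactly this version is an assumption.

```python
import torch
import torch.nn as nn


def weights_init(m):
    """Re-initialize conv and batch-norm layers in the DCGAN style.

    Assumed to match the helper the article relies on; not taken from it.
    """
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        # Covers both Conv2d and ConvTranspose2d
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)


# Small demonstration on a throwaway module (not the article's networks)
torch.manual_seed(0)
demo = nn.Sequential(
    nn.Conv2d(3, 64, 4, 2, 1, bias=False),
    nn.BatchNorm2d(64),
)
demo.apply(weights_init)  # apply() visits every submodule recursively
conv_std = demo[0].weight.std().item()
bn_bias_max = demo[1].bias.abs().max().item()
print("conv weight std:", conv_std)
print("largest batch-norm bias:", bn_bias_max)
```

Because nn.Module.apply walks all submodules, calling netG.apply(weights_init) or netD.apply(weights_init) re-initializes every convolutional and batch-norm layer in one call.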
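Defining the Training Function
After defining the generator and discriminator classes, we move on to the training function.
The training function takes the generator, the discriminator, their optimizers, and the current epoch number as arguments. We train the generator and the discriminator by calling the train function once per epoch until the desired number of epochs is reached.
Within each call, we iterate over the data loader, update the discriminator on a real batch and on a batch of new images from the generator, then update the generator, computing the loss at each step.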
def train(args, gen, disc, device, dataloader, optimizerG, optimizerD, criterion, epoch, iters):
    gen.train()
    disc.train()
    img_list = []
    # args is the wandb config object passed in by the caller
    fixed_noise = torch.randn(64, args.nz, 1, 1, device=device)

    # Establish convention for real and fake labels during training (with label smoothing)
    real_label = 0.9
    fake_label = 0.1
    for i, data in enumerate(dataloader, 0):

        #*****
        # Update Discriminator
        #*****
        ## Train with all-real batch
        disc.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        # Forward pass real batch through D
        output = disc(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, args.nz, 1, 1, device=device)
        # Generate fake image batch with G
        fake = gen(noise)
        label.fill_(fake_label)
        # Classify all fake batch with D (detach so no gradients flow into G)
        output = disc(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Total D loss; the gradients were already accumulated by the two backward calls
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()

        #*****
        # Update Generator
        #*****
        gen.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = disc(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()

        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
                  % (epoch, args.epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
            wandb.log({
                "Gen Loss": errG.item(),
                "Disc Loss": errD.item()})

        # Check how the generator is doing by saving G's output on fixed_noise.
        # Epochs are numbered from 1, so the last epoch is args.epochs.
        if (iters % 500 == 0) or ((epoch == args.epochs) and (i == len(dataloader) - 1)):
            with torch.no_grad():
                fake = gen(fixed_noise).detach().cpu()
            img_list.append(wandb.Image(vutils.make_grid(fake, padding=2, normalize=True)))
            wandb.log({
                "Generated Images": img_list})
        # iters is a local counter; the caller supplies its starting value for this epoch
        iters += 1
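The training function above uses smoothed labels (0.9 for real, 0.1 for fake) together with a binary cross-entropy criterion. As a rough illustration of what that does to the loss, here is a plain-Python version of the per-sample binary cross-entropy formula. The bce helper is hypothetical and stands in for the criterion on a single prediction.

```python
import math


def bce(p, target):
    """Binary cross-entropy for one prediction p in (0, 1)
    against a (possibly soft) target in [0, 1]."""
    return -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))


real_label = 0.9  # smoothed "real" target, as in the training function

# Loss for a real image at different discriminator confidences
for p in (0.5, 0.9, 0.99):
    print(f"D(x) = {p:.2f} -> loss {bce(p, real_label):.4f}")
```

With a target of 0.9 the loss is smallest when the discriminator outputs 0.9 and rises again as the output approaches 1, so the discriminator is discouraged from becoming overconfident on real images.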
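Monitoring and Training the DCGAN
With the generator, the discriminator, and the training function in place, the last step is simply to call the training function for the number of epochs we defined. I also used Wandb, which lets us monitor the training run.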
# A run must be started before wandb.config can be written to.
# Pass project="..." to wandb.init to choose a project name.
wandb.init()
wandb.watch_called = False
# WandB: config is a variable that holds and saves hyperparameters and inputs
config = wandb.config  # Initialize config
config.batch_size = batch_size
config.epochs = num_epochs
config.lr = lr
config.beta1 = beta1
config.nz = nz
config.no_cuda = False
config.seed = manualSeed  # random seed (set to 999 above)
config.log_interval = 10  # how many batches to wait before logging training status
def main():
    use_cuda = not config.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # Set random seeds and deterministic pytorch for reproducibility
    random.seed(config.seed)        # python random seed
    torch.manual_seed(config.seed)  # pytorch random seed
    np.random.seed(config.seed)     # numpy random seed
    torch.backends.cudnn.deterministic = True

    # Load the dataset
    # NOTE: as published, this listing loads CIFAR-10 (32 x 32 images). To train on
    # DeepTrash images instead, a dataset such as
    # dset.ImageFolder(root=dataroot, transform=transform) could be used, assuming the
    # images sit in class subfolders and are resized to 32 x 32.
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    trainset = dset.CIFAR10(root='./data', train=True,
                            download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=config.batch_size,
                                              shuffle=True, num_workers=workers,
                                              pin_memory=use_cuda)

    # Create the generator
    netG = Generator(ngpu).to(device)
    # Handle multi-gpu if desired
    if (device.type == 'cuda') and (ngpu > 1):
        netG = nn.DataParallel(netG, list(range(ngpu)))
    # Apply the weights_init function to randomly initialize all weights.
    # weights_init is not defined in this listing; the DCGAN paper uses mean=0, stdev=0.02.
    netG.apply(weights_init)

    # Create the Discriminator
    netD = Discriminator(ngpu).to(device)
    # Handle multi-gpu if desired
    if (device.type == 'cuda') and (ngpu > 1):
        netD = nn.DataParallel(netD, list(range(ngpu)))
    # Apply the same weight initialization to the discriminator
    netD.apply(weights_init)

    # Initialize BCELoss function
    criterion = nn.BCELoss()

    # Setup Adam optimizers for both G and D
    optimizerD = optim.Adam(netD.parameters(), lr=config.lr, betas=(config.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=config.lr, betas=(config.beta1, 0.999))

    # WandB: wandb.watch() fetches layer dimensions, gradients and model parameters
    # and logs them to your dashboard.
    # Using log="all" logs histograms of parameter values in addition to gradients
    wandb.watch(netG, log="all")
    wandb.watch(netD, log="all")

    for epoch in range(1, config.epochs + 1):
        # Global iteration count at the start of this epoch
        iters = (epoch - 1) * len(trainloader)
        train(config, netG, netD, device, trainloader, optimizerG, optimizerD, criterion, epoch, iters)

    # WandB: save the model checkpoint. This uploads the file and associates it with the current run.
    torch.save(netG.state_dict(), "model.h5")
    wandb.save('model.h5')


if __name__ == '__main__':
    main()
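Results
We plot the generator and discriminator losses during training. The plotting code expects G_losses and D_losses to be lists of per-iteration loss values collected during training.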
plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses,label="G")
plt.plot(D_losses,label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()
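The plot above reads from G_losses and D_losses, but the training function shown earlier only sends the losses to Wandb. One possible way to populate those lists is sketched below. The record_losses helper and the sample values are hypothetical; in the real training loop the arguments would be errG.item() and errD.item().

```python
# Module-level lists that the plotting code reads from
G_losses = []
D_losses = []


def record_losses(g_loss, d_loss):
    """Append one iteration's generator and discriminator loss as floats."""
    G_losses.append(float(g_loss))
    D_losses.append(float(d_loss))


# Stand-in values for illustration only, not real training results
for g, d in [(2.1, 0.9), (1.7, 1.1), (1.4, 1.2)]:
    record_losses(g, d)

print(len(G_losses), "generator losses recorded")
```

Calling such a helper once per iteration, right after optimizerG.step(), would give the plot one point per training step.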
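We can also look at the images the generator produced, to see how the generated images differ from real ones. The code expects img_list to hold the image grids saved during training.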
#%%capture
fig = plt.figure(figsize=(8,8))
plt.axis("off")
ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)
HTML(ani.to_jshtml())
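The output looks like this (an animation of the generated image grids):
Conclusion
In this article, we discussed using deep convolutional generative adversarial networks to generate synthetic images of marine plastic, which researchers can mix with real images to extend their current marine plastic datasets.
As the results show, the GAN still needs a lot of work. The ocean is a complex environment, with varying illumination, turbidity, blur, and so on.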
Original title: Generating Synthetic Images of Marine Plastic Using Generative Adversarial Networks
