【我的创作纪念日】使用pix2pixgan实现barts2020数据集的处理(完整版本)
使用pix2pixgan (pytorch)实现T1 -> T2的基本代码
使用 https://github.com/eriklindernoren/PyTorch-GAN/ 这里面的pix2pixgan代码进行实现。
进去之后我们需要重新处理数据集,并且源代码里面先训练的生成器,后训练鉴别器。
一般情况下,先训练判别器而后训练生成器是因为这种训练顺序在理论和实践上更加稳定和有效。我们需要改变顺序以及一些代码:
以下是一些原因:
- 判别器的任务相对简单:判别器的任务是将真实样本与生成样本区分开来。这相对于生成器而言是一个相对简单的分类任务,因为它只需要区分两种类型的样本。通过先训练判别器,我们可以确保其具有足够的能力来准确识别真实和生成的样本。
- 生成器依赖于判别器的反馈:生成器的目标是生成逼真的样本,以尽可能地欺骗判别器。通过先训练判别器,我们可以得到关于生成样本质量的反馈信息。生成器可以根据判别器的反馈进行调整,并逐渐提高生成样本的质量。
- 训练稳定性:在GAN的早期训练阶段,生成器产生的样本可能会非常不真实。如果首先训练生成器,那么判别器可能会很容易辨别这些低质量的生成样本,导致梯度更新不稳定。通过先训练判别器,我们可以使生成器更好地适应判别器的反馈,从而增加训练的稳定性。
- 避免模式崩溃:在GAN训练过程中,存在模式坍塌的问题,即生成器只学会生成少数几种样本而不是整个数据分布。通过先训练判别器,我们可以提供更多样本的多样性,帮助生成器避免陷入模式崩溃现象。
尽管先训练鉴别器再训练生成器是一种常见的做法,但并不意味着这是唯一正确的方式。根据特定的问题和数据集,有时候也可以尝试其他训练策略,例如逆向训练(先训练生成器)。选择何种顺序取决于具体情况和实验结果。
数据集使用的是BraTs2020数据集,他的介绍和处理方法在我的知识链接里面。目前使用的是个人电脑的GPU跑的。然后数据也只取了前200个训练集,并且20%分出来作为测试集。
并且我们在训练的时候,每隔一定的batch使用matplotlib将T1,生成的T1,真实的T2进行展示,并且将生成器和鉴别器的loss进行展示。
通过比较可以发现使用了逐像素的L1 LOSS可以让生成的结果更好。

训练10个epoch时的结果图:

此时的测试结果:
PSNR mean: 21.1621928375993 PSNR std: 1.1501189362634836
 NMSE mean: 0.14920212 NMSE std: 0.03501928
 SSIM mean: 0.5401535398016223 SSIM std: 0.019281408927679166
代码:
dataloader.py
# dataloader for fine-tuning
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image
import torch.utils.data as data
import numpy as np
from PIL import ImageEnhance, Image
import random
import osdef cv_random_flip(img, label):# left right flipflip_flag = random.randint(0, 2)if flip_flag == 1:img = np.flip(img, 0).copy()label = np.flip(label, 0).copy()if flip_flag == 2:img = np.flip(img, 1).copy()label = np.flip(label, 1).copy()return img, labeldef randomCrop(image, label):border = 30image_width = image.size[0]image_height = image.size[1]crop_win_width = np.random.randint(image_width - border, image_width)crop_win_height = np.random.randint(image_height - border, image_height)random_region = ((image_width - crop_win_width) >> 1, (image_height - crop_win_height) >> 1, (image_width + crop_win_width) >> 1,(image_height + crop_win_height) >> 1)return image.crop(random_region), label.crop(random_region)def randomRotation(image, label):rotate = random.randint(0, 1)if rotate == 1:rotate_time = random.randint(1, 3)image = np.rot90(image, rotate_time).copy()label = np.rot90(label, rotate_time).copy()return image, labeldef colorEnhance(image):bright_intensity = random.randint(7, 13) / 10.0image = ImageEnhance.Brightness(image).enhance(bright_intensity)contrast_intensity = random.randint(4, 11) / 10.0image = ImageEnhance.Contrast(image).enhance(contrast_intensity)color_intensity = random.randint(7, 13) / 10.0image = ImageEnhance.Color(image).enhance(color_intensity)sharp_intensity = random.randint(7, 13) / 10.0image = ImageEnhance.Sharpness(image).enhance(sharp_intensity)return imagedef randomGaussian(img, mean=0.002, sigma=0.002):def gaussianNoisy(im, mean=mean, sigma=sigma):for _i in range(len(im)):im[_i] += random.gauss(mean, sigma)return imflag = random.randint(0, 3)if flag == 1:width, height = img.shapeimg = gaussianNoisy(img[:].flatten(), mean, sigma)img = img.reshape([width, height])return imgdef randomPeper(img):flag = random.randint(0, 3)if flag == 1:noiseNum = int(0.0015 * img.shape[0] * img.shape[1])for i in range(noiseNum):randX = random.randint(0, img.shape[0] - 1)randY = random.randint(0, img.shape[1] - 1)if random.randint(0, 1) == 0:img[randX, randY] = 0else:img[randX, randY] = 1return imgclass BraTS_Train_Dataset(data.Dataset):def __init__(self, source_modal, target_modal, img_size,image_root, data_rate, sort=False, argument=False, random=False):self.source = source_modalself.target = target_modalself.modal_list = ['t1', 't2']self.image_root = image_rootself.data_rate = data_rateself.images = [self.image_root + f for f in os.listdir(self.image_root) if f.endswith('.npy')]self.images.sort(key=lambda x: int(x.split(image_root)[1].split(".npy")[0]))self.img_transform = transforms.Compose([transforms.ToTensor(),transforms.Resize(img_size)])self.gt_transform = transforms.Compose([transforms.ToTensor(),transforms.Resize(img_size, Image.NEAREST)])self.sort = sortself.argument = argumentself.random = randomself.subject_num = len(self.images) // 60if self.random == True:subject = np.arange(self.subject_num)np.random.shuffle(subject)self.LUT = []for i in subject:for j in range(60):self.LUT.append(i * 60 + j)# print('slice number:', self.__len__())def __getitem__(self, index):if self.random == True:index = self.LUT[index]npy = np.load(self.images[index])img = npy[self.modal_list.index(self.source), :, :]gt = npy[self.modal_list.index(self.target), :, :]if self.argument == True:img, gt = cv_random_flip(img, gt)img, gt = randomRotation(img, gt)img = img * 255img = Image.fromarray(img.astype(np.uint8))img = colorEnhance(img)img = img.convert('L')img = self.img_transform(img)gt = self.img_transform(gt)return img, gtdef __len__(self):return int(len(self.images) * self.data_rate)def get_loader(batchsize, shuffle, pin_memory=True, source_modal='t1', target_modal='t2',img_size=256, img_root='data/train/', data_rate=0.1, num_workers=8, sort=False, argument=False,random=False):dataset = BraTS_Train_Dataset(source_modal=source_modal, target_modal=target_modal,img_size=img_size, image_root=img_root, data_rate=data_rate, sort=sort,argument=argument, random=random)data_loader = data.DataLoader(dataset=dataset, batch_size=batchsize, shuffle=shuffle,pin_memory=pin_memory, num_workers=num_workers)return data_loader# if __name__=='__main__':
#     data_loader = get_loader(batchsize=1, shuffle=True, pin_memory=True, source_modal='t1',
#                              target_modal='t2', img_size=256, num_workers=8,
#                              img_root='data/train/', data_rate=0.1, argument=True, random=False)
#     length = len(data_loader)
#     print("data_loader的长度为:", length)
#     # 将 data_loader 转换为迭代器
#     data_iter = iter(data_loader)
#
#     # 获取第一批数据
#     batch = next(data_iter)
#
#     # 打印第一批数据的大小
#     print("第一批数据的大小:", batch[0].shape)  # 输入图像的张量
#     print("第一批数据的大小:", batch[1].shape)  # 目标图像的张量
#     print(batch.shape)models.py
import torch.nn as nn
import torch.nn.functional as F
import torchdef weights_init_normal(m):classname = m.__class__.__name__if classname.find("Conv") != -1:torch.nn.init.normal_(m.weight.data, 0.0, 0.02)elif classname.find("BatchNorm2d") != -1:torch.nn.init.normal_(m.weight.data, 1.0, 0.02)torch.nn.init.constant_(m.bias.data, 0.0)##############################
#           U-NET
##############################class UNetDown(nn.Module):def __init__(self, in_size, out_size, normalize=True, dropout=0.0):super(UNetDown, self).__init__()layers = [nn.Conv2d(in_size, out_size, 4, 2, 1, bias=False)]if normalize:layers.append(nn.InstanceNorm2d(out_size))layers.append(nn.LeakyReLU(0.2))if dropout:layers.append(nn.Dropout(dropout))self.model = nn.Sequential(*layers)def forward(self, x):return self.model(x)class UNetUp(nn.Module):def __init__(self, in_size, out_size, dropout=0.0):super(UNetUp, self).__init__()layers = [nn.ConvTranspose2d(in_size, out_size, 4, 2, 1, bias=False),nn.InstanceNorm2d(out_size),nn.ReLU(inplace=True),]if dropout:layers.append(nn.Dropout(dropout))self.model = nn.Sequential(*layers)def forward(self, x, skip_input):x = self.model(x)x = torch.cat((x, skip_input), 1)return xclass GeneratorUNet(nn.Module):def __init__(self, in_channels=3, out_channels=3):super(GeneratorUNet, self).__init__()self.down1 = UNetDown(in_channels, 64, normalize=False)self.down2 = UNetDown(64, 128)self.down3 = UNetDown(128, 256)self.down4 = UNetDown(256, 512, dropout=0.5)self.down5 = UNetDown(512, 512, dropout=0.5)self.down6 = UNetDown(512, 512, dropout=0.5)self.down7 = UNetDown(512, 512, dropout=0.5)self.down8 = UNetDown(512, 512, normalize=False, dropout=0.5)self.up1 = UNetUp(512, 512, dropout=0.5)self.up2 = UNetUp(1024, 512, dropout=0.5)self.up3 = UNetUp(1024, 512, dropout=0.5)self.up4 = UNetUp(1024, 512, dropout=0.5)self.up5 = UNetUp(1024, 256)self.up6 = UNetUp(512, 128)self.up7 = UNetUp(256, 64)self.final = nn.Sequential(nn.Upsample(scale_factor=2),nn.ZeroPad2d((1, 0, 1, 0)),nn.Conv2d(128, out_channels, 4, padding=1),nn.Tanh(),)def forward(self, x):# U-Net generator with skip connections from encoder to decoderd1 = self.down1(x)d2 = self.down2(d1)d3 = self.down3(d2)d4 = self.down4(d3)d5 = self.down5(d4)d6 = self.down6(d5)d7 = self.down7(d6)d8 = self.down8(d7)u1 = self.up1(d8, d7)u2 = self.up2(u1, d6)u3 = self.up3(u2, d5)u4 = self.up4(u3, d4)u5 = self.up5(u4, d3)u6 = self.up6(u5, d2)u7 = self.up7(u6, d1)return self.final(u7)##############################
#        Discriminator
##############################class Discriminator(nn.Module):def __init__(self, in_channels=3):super(Discriminator, self).__init__()def discriminator_block(in_filters, out_filters, normalization=True):"""Returns downsampling layers of each discriminator block"""layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]if normalization:layers.append(nn.InstanceNorm2d(out_filters))layers.append(nn.LeakyReLU(0.2, inplace=True))return layersself.model = nn.Sequential(*discriminator_block(in_channels * 2, 64, normalization=False),*discriminator_block(64, 128),*discriminator_block(128, 256),*discriminator_block(256, 512),nn.ZeroPad2d((1, 0, 1, 0)),nn.Conv2d(512, 1, 4, padding=1, bias=False))def forward(self, img_A, img_B):# Concatenate image and condition image by channels to produce inputimg_input = torch.cat((img_A, img_B), 1)return self.model(img_input)
pix2pix.py
import argparse
import os
import numpy as np
import math
import itertools
import time
import datetime
import sys
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from torchvision.utils import save_imagefrom torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variablefrom models import *
from dataloader import *import torch.nn as nn
import torch.nn.functional as F
import torch
if __name__=='__main__':parser = argparse.ArgumentParser()parser.add_argument("--epoch", type=int, default=0, help="epoch to start training from")parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")parser.add_argument("--dataset_name", type=str, default="basta2020", help="name of the dataset")parser.add_argument("--batch_size", type=int, default=2, help="size of the batches")parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")parser.add_argument("--decay_epoch", type=int, default=100, help="epoch from which to start lr decay")parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")parser.add_argument("--img_height", type=int, default=256, help="size of image height")parser.add_argument("--img_width", type=int, default=256, help="size of image width")parser.add_argument("--channels", type=int, default=3, help="number of image channels")parser.add_argument("--sample_interval", type=int, default=500, help="interval between sampling of images from generators")parser.add_argument("--checkpoint_interval", type=int, default=10, help="interval between model checkpoints")opt = parser.parse_args()print(opt)os.makedirs("images/%s" % opt.dataset_name, exist_ok=True)os.makedirs("saved_models/%s" % opt.dataset_name, exist_ok=True)cuda = True if torch.cuda.is_available() else False# Loss functionscriterion_GAN = torch.nn.MSELoss()criterion_pixelwise = torch.nn.L1Loss()# Loss weight of L1 pixel-wise loss between translated image and real imagelambda_pixel = 100# Calculate output of image discriminator (PatchGAN)patch = (1, opt.img_height // 2 ** 4, opt.img_width // 2 ** 4)# Initialize generator and discriminatorgenerator = GeneratorUNet(in_channels=1, out_channels=1)discriminator = Discriminator(in_channels=1)if cuda:generator = generator.cuda()discriminator = discriminator.cuda()criterion_GAN.cuda()criterion_pixelwise.cuda()if opt.epoch != 0:# Load pretrained modelsgenerator.load_state_dict(torch.load("saved_models/%s/generator_%d.pth" % (opt.dataset_name, opt.epoch)))discriminator.load_state_dict(torch.load("saved_models/%s/discriminator_%d.pth" % (opt.dataset_name, opt.epoch)))else:# Initialize weightsgenerator.apply(weights_init_normal)discriminator.apply(weights_init_normal)# Optimizersoptimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))# Configure dataloaderstransforms_ = [transforms.Resize((opt.img_height, opt.img_width), Image.BICUBIC),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),]dataloader = get_loader(batchsize=4, shuffle=True, pin_memory=True, source_modal='t1',target_modal='t2', img_size=256, num_workers=8,img_root='data/train/', data_rate=0.1, argument=True, random=False)# dataloader = DataLoader(#     ImageDataset("../../data/%s" % opt.dataset_name, transforms_=transforms_),#     batch_size=opt.batch_size,#     shuffle=True,#     num_workers=opt.n_cpu,# )# val_dataloader = DataLoader(#     ImageDataset("../../data/%s" % opt.dataset_name, transforms_=transforms_, mode="val"),#     batch_size=10,#     shuffle=True,#     num_workers=1,# )# Tensor typeTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor# def sample_images(batches_done):#     """Saves a generated sample from the validation set"""#     imgs = next(iter(val_dataloader))#     real_A = Variable(imgs["B"].type(Tensor))#     real_B = Variable(imgs["A"].type(Tensor))#     fake_B = generator(real_A)#     img_sample = torch.cat((real_A.data, fake_B.data, real_B.data), -2)#     save_image(img_sample, "images/%s/%s.png" % (opt.dataset_name, batches_done), nrow=5, normalize=True)# ----------#  Training# ----------prev_time = time.time()# 创建空列表用于保存损失值losses_G = []losses_D = []for epoch in range(opt.epoch, opt.n_epochs):for i, batch in enumerate(dataloader):# Model inputsreal_A = Variable(batch[0].type(Tensor))real_B = Variable(batch[1].type(Tensor))# print(real_A == real_B)# Adversarial ground truthsvalid = Variable(Tensor(np.ones((real_A.size(0), *patch))), requires_grad=False)fake = Variable(Tensor(np.zeros((real_A.size(0), *patch))), requires_grad=False)# ---------------------#  Train Discriminator# ---------------------optimizer_D.zero_grad()# Real losspred_real = discriminator(real_B, real_A)loss_real = criterion_GAN(pred_real, valid)# Fake lossfake_B = generator(real_A)pred_fake = discriminator(fake_B.detach(), real_A)loss_fake = criterion_GAN(pred_fake, fake)# Total lossloss_D = 0.5 * (loss_real + loss_fake)loss_D.backward()optimizer_D.step()# ------------------#  Train Generators# ------------------optimizer_G.zero_grad()# GAN losspred_fake = discriminator(fake_B, real_A)loss_GAN = criterion_GAN(pred_fake, valid)# Pixel-wise lossloss_pixel = criterion_pixelwise(fake_B, real_B)# Total lossloss_G = loss_GAN + lambda_pixel * loss_pixel   # 希望生成的接近1loss_G.backward()optimizer_G.step()# --------------#  Log Progress# --------------# Determine approximate time leftbatches_done = epoch * len(dataloader) + ibatches_left = opt.n_epochs * len(dataloader) - batches_donetime_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))prev_time = time.time()# Print logsys.stdout.write("\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, pixel: %f, adv: %f] ETA: %s"% (epoch,opt.n_epochs,i,len(dataloader),loss_D.item(),loss_G.item(),loss_pixel.item(),loss_GAN.item(),time_left,))mat = [real_A, fake_B, real_B]if (batches_done + 1) % 200 == 0:plt.figure(dpi=400)ax = plt.subplot(131)for i, img in enumerate(mat):ax = plt.subplot(1, 3, i + 1)  #get positionimg = img.permute([0, 2, 3, 1])  # b c h w ->b h w cif img.shape[0] != 1:   # 有多个就只取第一个img = img[1]img = img.squeeze(0)   # b h w c -> h w cif img.shape[2] == 1:img = img.repeat(1, 1, 3)  # process gray imgimg = img.cpu()ax.imshow(img.data)ax.set_xticks([])ax.set_yticks([])plt.show()if (batches_done + 1) % 20 ==0:losses_G.append(loss_G.item())losses_D.append(loss_D.item())if (batches_done + 1) % 200 == 0:  # 每20个batch添加一次损失# 保存损失值plt.figure(figsize=(10, 5))plt.plot(range(int((batches_done + 1) / 20)), losses_G, label="Generator Loss")plt.plot(range(int((batches_done + 1) / 20)), losses_D, label="Discriminator Loss")plt.xlabel("Epoch")plt.ylabel("Loss")plt.title("GAN Training Loss Curve")plt.legend()plt.show()# # If at sample interval save image# if batches_done % opt.sample_interval == 0:#     sample_images(batches_done)if opt.checkpoint_interval != -1 and epoch % opt.checkpoint_interval == 0:# Save model checkpointstorch.save(generator.state_dict(), "saved_models/%s/generator_%d.pth" % (opt.dataset_name, epoch))torch.save(discriminator.state_dict(), "saved_models/%s/discriminator_%d.pth" % (opt.dataset_name, epoch))
processing.py 数据预处理
import numpy as np
from matplotlib import pylab as plt
import nibabel as nib
import random
import glob
import os
from PIL import Image
import imageiodef normalize(image, mask=None, percentile_lower=0.2, percentile_upper=99.8):if mask is None:mask = image != image[0, 0, 0]cut_off_lower = np.percentile(image[mask != 0].ravel(), percentile_lower)cut_off_upper = np.percentile(image[mask != 0].ravel(), percentile_upper)res = np.copy(image)res[(res < cut_off_lower) & (mask != 0)] = cut_off_lowerres[(res > cut_off_upper) & (mask != 0)] = cut_off_upperres = res / res.max()  # 0-1return resdef visualize(t1_data, t2_data, flair_data, t1ce_data, gt_data):plt.figure(figsize=(8, 8))plt.subplot(231)plt.imshow(t1_data[:, :], cmap='gray')plt.title('Image t1')plt.subplot(232)plt.imshow(t2_data[:, :], cmap='gray')plt.title('Image t2')plt.subplot(233)plt.imshow(flair_data[:, :], cmap='gray')plt.title('Image flair')plt.subplot(234)plt.imshow(t1ce_data[:, :], cmap='gray')plt.title('Image t1ce')plt.subplot(235)plt.imshow(gt_data[:, :])plt.title('GT')plt.show()def visualize_to_gif(t1_data, t2_data, t1ce_data, flair_data):transversal = []coronal = []sagittal = []slice_num = t1_data.shape[2]for i in range(slice_num):sagittal_plane = np.concatenate((t1_data[:, :, i], t2_data[:, :, i],t1ce_data[:, :, i], flair_data[:, :, i]), axis=1)coronal_plane = np.concatenate((t1_data[i, :, :], t2_data[i, :, :],t1ce_data[i, :, :], flair_data[i, :, :]), axis=1)transversal_plane = np.concatenate((t1_data[:, i, :], t2_data[:, i, :],t1ce_data[:, i, :], flair_data[:, i, :]), axis=1)transversal.append(transversal_plane)coronal.append(coronal_plane)sagittal.append(sagittal_plane)imageio.mimsave("./transversal_plane.gif", transversal, duration=0.01)imageio.mimsave("./coronal_plane.gif", coronal, duration=0.01)imageio.mimsave("./sagittal_plane.gif", sagittal, duration=0.01)returnif __name__ == '__main__':t1_list = sorted(glob.glob('../data/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData/*/*t1.*'))t2_list = sorted(glob.glob('../data/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData/*/*t2.*'))data_len = len(t1_list)train_len = int(data_len * 0.8)test_len = data_len - train_lentrain_path = '../data/train/'test_path = '../data/test/'os.makedirs(train_path, exist_ok=True)os.makedirs(test_path, exist_ok=True)for i, (t1_path, t2_path) in enumerate(zip(t1_list, t2_list)):print('preprocessing the', i + 1, 'th subject')t1_img = nib.load(t1_path)  # (240,140,155)t2_img = nib.load(t2_path)# to numpyt1_data = t1_img.get_fdata()t2_data = t2_img.get_fdata()t1_data = normalize(t1_data)  # normalize to [0,1]t2_data = normalize(t2_data)tensor = np.stack([t1_data, t2_data])  # (2, 240, 240, 155)if i < train_len:for j in range(60):Tensor = tensor[:, 10:210, 25:225, 50 + j]np.save(train_path + str(60 * i + j + 1) + '.npy', Tensor)else:for j in range(60):Tensor = tensor[:, 10:210, 25:225, 50 + j]np.save(test_path + str(60 * (i - train_len) + j + 1) + '.npy', Tensor)testutil.py
#-*- codeing = utf-8 -*-
#@Time : 2023/9/23 0023 17:21
#@Author : Tom
#@File : testutil.py.py
#@Software : PyCharm
import argparsefrom math import log10, sqrt
import numpy as np
from skimage.metrics import structural_similarity as ssimdef psnr(res,gt):mse = np.mean((res - gt) ** 2)if(mse == 0):return 100max_pixel = 1psnr = 20 * log10(max_pixel / sqrt(mse))return psnrdef nmse(res,gt):Norm = np.linalg.norm((gt * gt),ord=2)if np.all(Norm == 0):return 0else:nmse = np.linalg.norm(((res - gt) * (res - gt)),ord=2) / Normreturn nmse
test.py
#-*- codeing = utf-8 -*-
#@Time : 2023/9/23 0023 16:14
#@Author : Tom
#@File : test.py.py
#@Software : PyCharmimport torch
from models import *
from dataloader import *
from testutil import *if __name__ == '__main__':images_save = "images_save/"slice_num = 4os.makedirs(images_save, exist_ok=True)device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")model = GeneratorUNet(in_channels=1, out_channels=1)data_loader = get_loader(batchsize=4, shuffle=True, pin_memory=True, source_modal='t1',target_modal='t2', img_size=256, num_workers=8,img_root='data/test/', data_rate=1, argument=True, random=False)model = model.to(device)model.load_state_dict(torch.load("saved_models/basta2020/generator_0.pth", map_location=torch.device(device)), strict=False)PSNR = []NMSE = []SSIM = []for i, (img, gt) in enumerate(data_loader):batch_size = img.size()[0]img = img.to(device, dtype=torch.float)gt = gt.to(device, dtype=torch.float)with torch.no_grad():pred = model(img)for j in range(batch_size):a = pred[j]save_image([pred[j]], images_save + str(i * batch_size + j + 1) + '.png', normalize=True)print(images_save + str(i * batch_size + j + 1) + '.png')pred, gt = pred.cpu().detach().numpy().squeeze(), gt.cpu().detach().numpy().squeeze()for j in range(batch_size):PSNR.append(psnr(pred[j], gt[j]))NMSE.append(nmse(pred[j], gt[j]))SSIM.append(ssim(pred[j], gt[j]))PSNR = np.asarray(PSNR)NMSE = np.asarray(NMSE)SSIM = np.asarray(SSIM)PSNR = PSNR.reshape(-1, slice_num)NMSE = NMSE.reshape(-1, slice_num)SSIM = SSIM.reshape(-1, slice_num)PSNR = np.mean(PSNR, axis=1)print(PSNR.size)NMSE = np.mean(NMSE, axis=1)SSIM = np.mean(SSIM, axis=1)print("PSNR mean:", np.mean(PSNR), "PSNR std:", np.std(PSNR))print("NMSE mean:", np.mean(NMSE), "NMSE std:", np.std(NMSE))print("SSIM mean:", np.mean(SSIM), "SSIM std:", np.std(SSIM))
相关文章:
 
【我的创作纪念日】使用pix2pixgan实现barts2020数据集的处理(完整版本)
使用pix2pixgan (pytorch)实现T1 -> T2的基本代码 使用 https://github.com/eriklindernoren/PyTorch-GAN/ 这里面的pix2pixgan代码进行实现。 进去之后我们需要重新处理数据集,并且源代码里面先训练的生成器,后训练鉴别器。 一般情况下…...
背包算法(Knapsack problem)
背包算法(Knapsack problem)是一种常见的动态规划问题,它的基本思想是利用动态规划思想求解给定重量和价值下的最优解。具体来说,背包算法用于解决一个整数背包问题,即给定一组物品,每个物品有自己的重量和…...
 
“童”趣迎国庆 安全“童”行-柿铺梁坡社区开展迎国庆活动
“金秋十月好心境,举国欢腾迎国庆。”国庆节来临之际,为进一步加强梁坡社区未成年人爱国主义教育,丰富文化生活,营造热烈喜庆、文明和谐的节日氛围。9月24日上午,樊城区柿铺街道梁坡社区新时代文明实践站联合襄阳市和时…...
 
常用压缩解压缩命令
在Linux中常见的压缩格式有.zip、.rar、.tar.gz.、tar.bz2等压缩格式。不同的压缩格式需要用不同的压缩命令和工具。须知,在Linux系统中.tar.gz为标准格式的压缩和解压缩格式,因此本文也会着重讲解tar.gz格式压缩包的压缩和解压缩命令。须知,…...
第四十一章 持久对象和SQL - Storage
文章目录 第四十一章 持久对象和SQL - StorageStorage存储定义概览持久类使用的Globals注意 第四十一章 持久对象和SQL - Storage Storage 每个持久类定义都包含描述类属性如何映射到实际存储它们的Global的信息。类编译器为类生成此信息,并在修改和重新编译时更新…...
 
【Java接口性能优化】skywalking使用
skywalking使用 提示:微服务中-skywalking使用 文章目录 skywalking使用一、进入skywalking主页二、进入具体服务1.查看接口 一、进入skywalking主页 二、进入具体服务 可以点击列表或搜索后,点击进入具体服务 依次选择日期、小时、分钟 1.查看接口 依次…...
 
大学各个专业介绍
计算机类 五米高考-计算机类 注:此处平均薪酬为毕业五年平均薪酬,薪酬数据仅供参考 来源: 掌上高考 电气类 五米高考-电气类 机械类 五米高考-机械类 电子信息类 五米高考-电子信息类 土木类 五米高考-土木类...
linux 列出网络上所有活动的主机
列出网络上所有活动的主机 #!/bin/bash# {start..end}会由shell对其进行扩展生成一组ip地址for ip in 192.168.0.{1..255} ;do ping $ip -c 2 &> /dev/null ; # $?获取退出状态,顺利退出则为0 if [ $? -eq 0 ]; then echo $ip is alive fidone https://zh…...
 
基于vue+Element Table Popover 弹出框内置表格的封装
文章目录 项目场景:实现效果认识组件代码效果分析 封装:代码封装思路页面中使用 项目场景: 在选择数据的时候需要在已选择的数据中对比选择,具体就是点击一个按钮,弹出一个小的弹出框,但不像对话框那样还需…...
 
机器人过程自动化(RPA)入门 4. 数据处理
到目前为止,我们已经了解了RPA的基本知识,以及如何使用流程图或序列来组织工作流中的步骤。我们现在了解了UiPath组件,并对UiPath Studio有了全面的了解。我们用几个简单的例子制作了我们的第一个机器人。在我们继续之前,我们应该了解UiPath中的变量和数据操作。它与其他编…...
 
java导出word(含图片、表格)
1.pom 引入 <!--word报告生成依赖--><dependency><groupId>org.apache.poi</groupId><artifactId>poi</artifactId><version>4.1.2</version></dependency><dependency><groupId>org.apache.poi</groupI…...
MySQL数据库记录的修改与更新
数据的修改和更新是数据库管理的核心任务之一,尤其是在动态和快速变化的环境下。本文将深入探讨如何在MySQL数据库中有效地进行记录的修改和更新。特别是将通过使用《三国志》游戏数据作为例子,来具体展示这些操作如何实施。文章主要面向具有基础数据库知识的读者。 文章目录…...
 
开具数电票如何减少认证频次?
“数电票”开具需多次刷脸认证,如何减少认证频次? 法定代表人、财务负责人可以在“身份认证频次设置”功能自行设置身份认证时间间隔,方法如下: 第一步 登录电子税务局。企业法定代表人或财务负责人通过手机APP“扫一扫”&#x…...
 
【进阶C语言】动态内存分配
本章大致内容介绍: 1.malloc函数和free函数 2.calloc函数 3.realloc函数 4.常见错误案例 5.笔试题详解 6.柔性数组 一、malloc和free 1.malloc函数 (1)函数原型 函数参数:根据用户的需求需要开辟多大的字节空间ÿ…...
 
手机上记录的备忘录内容怎么分享到电脑上查看?
手机已经成为了我们生活中不可或缺的一部分,我们用它来处理琐碎事务,记录生活点滴,手机备忘录就是我们常用的工具之一。但随着工作的需要,我们往往会遇到一个问题:手机上记录的备忘录内容,如何方便地分享到…...
 
LeetCode 2251. 花期内花的数目:排序 + 二分
【LetMeFly】2251.花期内花的数目:排序 二分 力扣题目链接:https://leetcode.cn/problems/number-of-flowers-in-full-bloom/ 给你一个下标从 0 开始的二维整数数组 flowers ,其中 flowers[i] [starti, endi] 表示第 i 朵花的 花期 从 st…...
 
【3】贪心算法-最优装载问题-加勒比海盗
算法背景 在北美洲东南部,有一片神秘的海域,那里碧海蓝天、阳光 明媚,这正是传说中海盗最活跃的加勒比海(Caribbean Sea)。 有一天,海盗们截获了一艘装满各种各样古董的货船,每一 件古董都价值连…...
JavaScript 的 for 循环应该如何学习?
JS for 循环语法 JS for 循环适合在已知循环次数时使用,语法格式如下: for(initialization; condition; increment) {// 要执行的代码 }for 循环中包含三个可选的表达式 initialization、condition 和 increment,其中: initial…...
 
C++核心编程--对象篇
4.2、对象 4.2.1、对象的初始化和清理 用于对对象进行初始化设置,以及对象销毁前的清理数据的设置。 构造函数和析构函数 防止对象初始化和清理也是非常重要的安全问题 一个对象或变量没有初始化状态,对其使用后果是未知的同样使用完一个对象或变量&…...
安装php扩展XLSXWriter,解决php导入excel表格时获取日期变成浮点数的方法
安装php扩展XLSXWriter 1、下载安装包 PECL :: Package :: xlswriter #例如选择下载1.3.6版本 2、解压下载包 tar -zxvf xlswriter-1.3.6.tgz 3、进入文件夹,编译 cd xlswriter-1.3.6 phpize ./configure --with-php-config=/usr/local/php7.1/bin/php-config make&am…...
 
【大模型RAG】拍照搜题技术架构速览:三层管道、两级检索、兜底大模型
摘要 拍照搜题系统采用“三层管道(多模态 OCR → 语义检索 → 答案渲染)、两级检索(倒排 BM25 向量 HNSW)并以大语言模型兜底”的整体框架: 多模态 OCR 层 将题目图片经过超分、去噪、倾斜校正后,分别用…...
 
Chapter03-Authentication vulnerabilities
文章目录 1. 身份验证简介1.1 What is authentication1.2 difference between authentication and authorization1.3 身份验证机制失效的原因1.4 身份验证机制失效的影响 2. 基于登录功能的漏洞2.1 密码爆破2.2 用户名枚举2.3 有缺陷的暴力破解防护2.3.1 如果用户登录尝试失败次…...
 
TDengine 快速体验(Docker 镜像方式)
简介 TDengine 可以通过安装包、Docker 镜像 及云服务快速体验 TDengine 的功能,本节首先介绍如何通过 Docker 快速体验 TDengine,然后介绍如何在 Docker 环境下体验 TDengine 的写入和查询功能。如果你不熟悉 Docker,请使用 安装包的方式快…...
 
【单片机期末】单片机系统设计
主要内容:系统状态机,系统时基,系统需求分析,系统构建,系统状态流图 一、题目要求 二、绘制系统状态流图 题目:根据上述描述绘制系统状态流图,注明状态转移条件及方向。 三、利用定时器产生时…...
 
使用 SymPy 进行向量和矩阵的高级操作
在科学计算和工程领域,向量和矩阵操作是解决问题的核心技能之一。Python 的 SymPy 库提供了强大的符号计算功能,能够高效地处理向量和矩阵的各种操作。本文将深入探讨如何使用 SymPy 进行向量和矩阵的创建、合并以及维度拓展等操作,并通过具体…...
 
排序算法总结(C++)
目录 一、稳定性二、排序算法选择、冒泡、插入排序归并排序随机快速排序堆排序基数排序计数排序 三、总结 一、稳定性 排序算法的稳定性是指:同样大小的样本 **(同样大小的数据)**在排序之后不会改变原始的相对次序。 稳定性对基础类型对象…...
 
群晖NAS如何在虚拟机创建飞牛NAS
套件中心下载安装Virtual Machine Manager 创建虚拟机 配置虚拟机 飞牛官网下载 https://iso.liveupdate.fnnas.com/x86_64/trim/fnos-0.9.2-863.iso 群晖NAS如何在虚拟机创建飞牛NAS - 个人信息分享...
django blank 与 null的区别
1.blank blank控制表单验证时是否允许字段为空 2.null null控制数据库层面是否为空 但是,要注意以下几点: Django的表单验证与null无关:null参数控制的是数据库层面字段是否可以为NULL,而blank参数控制的是Django表单验证时字…...
 
破解路内监管盲区:免布线低位视频桩重塑停车管理新标准
城市路内停车管理常因行道树遮挡、高位设备盲区等问题,导致车牌识别率低、逃费率高,传统模式在复杂路段束手无策。免布线低位视频桩凭借超低视角部署与智能算法,正成为破局关键。该设备安装于车位侧方0.5-0.7米高度,直接规避树枝遮…...
LLaMA-Factory 微调 Qwen2-VL 进行人脸情感识别(二)
在上一篇文章中,我们详细介绍了如何使用LLaMA-Factory框架对Qwen2-VL大模型进行微调,以实现人脸情感识别的功能。本篇文章将聚焦于微调完成后,如何调用这个模型进行人脸情感识别的具体代码实现,包括详细的步骤和注释。 模型调用步骤 环境准备:确保安装了必要的Python库。…...
