当前位置: 首页 > article >正文

PyTorch深度学习实战 |SegNet

欢迎来到PyTorch深度学习实战的世界博客主页卿云阁欢迎关注点赞收藏⭐️留言首发时间2026年4月29日✉️希望可以和大家一起完成进阶之路作者水平很有限如果发现错误请留言轰炸哦万分感谢目录Camvid_11数据集介绍数据集下载​编辑项目结构SegNetConvBNReLU类DecoderBlock类整体代码datasetCamVidDataset(Dataset)初始化函数__init____len___find_label_path__getitem__build_datasetutils.pyfast_histSegmentationMetriccompute_class_weightsdecode_segmaptrain.pyparse_args()select_devicetrain_one_epochvalidatemain()Camvid_11数据集介绍CamVid数据集是一个用于强监督学习的精准标注图片集合包含700多张图片。这些图片被分为训练集、验证集和测试集适用于图像分割任务的训练和评估。数据集结构CamVid数据集的图片被分为以下三个部分训练集用于模型的训练。验证集用于模型的调优和验证。测试集用于最终模型的评估。类别标签CamVid数据集使用11种常用的类别进行分割精度的评估具体类别如下道路 (Road)交通标志 (Symbol)汽车 (Car)天空 (Sky)行人道 (Sidewalk)电线杆 (Pole)围墙 (Fence)行人 (Pedestrian)建筑物 (Building)自行车 (Bicyclist)树木 (Tree)注意背景类别被标记为0因此在计算类别总数时实际有12个类别包括背景。数据集下载CamVid数据集介绍:CamVid数据集介绍与图像分割任务支持 - AtomGit | GitCode因为分割标签图的像素类别取值范围是 130像素数值都很小映射到灰度显示时整体亮度极低整张标签图看起来几乎全黑肉眼无法直接区分不同类别区域虽然能读取到每个像素对应的类别数字但原图灰度视觉差异极小所以需要在每个语义区域中心标注对应类别数字搭配黑底白字样式即使背景全黑也能清晰看清每一块区域的类别编号。项目结构segnet_camvid/├── config.yaml # 唯一配置文件├── dataset.py # CamVidDataset├── model.py # SegNet├── utils.py # 指标 类别权重 上色├── train.py # 训练├── eval.py # 测试集评估├── inference.py # 单图/文件夹推理└── README.mdSegNetConvBNReLU类class ConvBNReLU(nn.Sequential): def __init__(self, in_ch, out_ch): super().__init__( nn.Conv2d(in_ch, out_ch, kernel_size3, padding1), nn.BatchNorm2d(out_ch), nn.ReLU(inplaceTrue), )DecoderBlock类定义解码器卷积块通过多层卷积完成特征提取与通道调整最后一层可输出分割所需的分类结果。class DecoderBlock(nn.Sequential): num_convs-1 个保持通道的 ConvBNReLU 末尾 1 个改通道(若 last 则用纯 Conv 输出 logits)。 def __init__(self, in_ch, out_ch, num_convs, lastFalse): layers [] for _ in range(num_convs - 1): layers.append(ConvBNReLU(in_ch, in_ch)) if last: layers.append(nn.Conv2d(in_ch, out_ch, kernel_size3, padding1)) else: layers.append(ConvBNReLU(in_ch, out_ch)) super().__init__(*layers)整体代码SegNet 模型(VGG16-BN backbone 对称解码器)。 import torch import torch.nn as nn from torchvision import models as tv_models class ConvBNReLU(nn.Sequential): def __init__(self, in_ch, out_ch): super().__init__( nn.Conv2d(in_ch, out_ch, kernel_size3, padding1), nn.BatchNorm2d(out_ch), nn.ReLU(inplaceTrue), ) class DecoderBlock(nn.Sequential): num_convs-1 个保持通道的 ConvBNReLU 末尾 1 个改通道(若 last 则用纯 Conv 输出 logits)。 def __init__(self, in_ch, out_ch, num_convs, lastFalse): layers [] for _ in range(num_convs - 1): layers.append(ConvBNReLU(in_ch, in_ch)) if last: layers.append(nn.Conv2d(in_ch, out_ch, kernel_size3, padding1)) else: layers.append(ConvBNReLU(in_ch, out_ch)) super().__init__(*layers) class SegNet(nn.Module): def __init__(self, num_classes11, pretrainedTrue): super().__init__() # -------- Encoder: 复用 vgg16_bn 的 conv-bn-relu -------- if hasattr(tv_models, VGG16_BN_Weights): weights tv_models.VGG16_BN_Weights.DEFAULT if pretrained else None vgg tv_models.vgg16_bn(weightsweights) else: vgg tv_models.vgg16_bn(pretrainedpretrained) feats list(vgg.features.children()) # vgg16_bn.features 中 MaxPool2d 位于索引 6, 13, 23, 33, 43 self.enc1 nn.Sequential(*feats[0:6]) # out 64 self.enc2 nn.Sequential(*feats[7:13]) # out 128 self.enc3 nn.Sequential(*feats[14:23]) # out 256 self.enc4 nn.Sequential(*feats[24:33]) # out 512 self.enc5 nn.Sequential(*feats[34:43]) # out 512 self.pool nn.MaxPool2d(kernel_size2, stride2, return_indicesTrue) self.unpool nn.MaxUnpool2d(kernel_size2, stride2) # -------- Decoder: 镜像编码器 -------- self.dec5 DecoderBlock(512, 512, num_convs3) self.dec4 DecoderBlock(512, 256, num_convs3) self.dec3 DecoderBlock(256, 128, num_convs3) self.dec2 DecoderBlock(128, 64, num_convs2) self.dec1 DecoderBlock(64, num_classes, num_convs2, lastTrue) # 解码器初始化(编码器保持 VGG 预训练) for blk in (self.dec5, self.dec4, self.dec3, self.dec2, self.dec1): for m in blk.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, modefan_out, nonlinearityrelu) if m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, nn.BatchNorm2d): nn.init.ones_(m.weight) nn.init.zeros_(m.bias) def forward(self, x): x self.enc1(x); s1 x.size(); x, ind1 self.pool(x) x self.enc2(x); s2 x.size(); x, ind2 self.pool(x) x self.enc3(x); s3 x.size(); x, ind3 self.pool(x) x self.enc4(x); s4 x.size(); x, ind4 self.pool(x) x self.enc5(x); s5 x.size(); x, ind5 self.pool(x) x self.unpool(x, ind5, output_sizes5); x self.dec5(x) x self.unpool(x, ind4, output_sizes4); x self.dec4(x) x self.unpool(x, ind3, output_sizes3); x self.dec3(x) x self.unpool(x, ind2, output_sizes2); x self.dec2(x) x self.unpool(x, ind1, output_sizes1); x self.dec1(x) return x # (B, num_classes, H, W) logits if __name__ __main__: # 自测(无需联网,关掉预训练) net SegNet(num_classes11, pretrainedFalse) y net(torch.randn(2, 3, 360, 480)) print(output:, y.shape) print(params:, sum(p.numel() for p in net.parameters()) / 1e6, M)SegNet 模型的输入为 3 通道 RGB 图像尺寸格式为 (B, 3, H, W)输出为与输入尺寸相同的语义分割 logits 图尺寸格式为 (B, num_classes, H, W)第 i 个通道 模型对「第 i 类物体」在整图每个像素的置信得分logits对应你 CamVid第 0 通道背景 / 无效类 Void 得分图第 1 通道天空 Sky 得分图第 2 通道建筑 Building 得分图第 3 通道电线杆 Pole 得分图第 4 通道道路 Road 得分图第 5 通道人行道 Pavement 得分图第 6 通道树木 Tree 得分图第 7 通道交通标志 SignSymbol 得分图第 8 通道围栏 Fence 得分图第 9 通道车辆 Car 得分图第 10 通道行人 Pedestrian 得分图每个通道都是一张热力图该位置越属于这一类 → 像素数值越大 → 可视化越亮该位置不属于这一类 → 数值越小 → 越暗datasetCamVidDataset(Dataset)Dataset说明这是一个数据集类初始化函数__init__# 构造函数初始化数据集类传入所有必要参数 # img_dir: 图片所在文件夹路径 # lbl_dir: 标签所在文件夹路径 # input_size: 模型输入尺寸 (高度, 宽度) # mean: 图像归一化均值 # std: 图像归一化标准差 # augment: 是否启用数据增强默认关闭 # hflip: 是否使用随机水平翻转默认开启 # void_value_in_label: 标签中无效区域的像素值原始标签为255 # ignore_index: 训练时损失函数要忽略的类别索引11类任务设为11 def __init__(self, img_dir, lbl_dir, input_size, mean, std, augmentFalse, hflipTrue, void_value_in_label255, ignore_index11): # 1. 把外部传入的参数保存到类内部 self.img_dir img_dir # 存储图片文件夹路径 self.lbl_dir lbl_dir # 存储标签文件夹路径 self.input_size tuple(input_size) # 统一图像尺寸转成元组格式 (H, W) self.mean list(mean) # 归一化均值转成列表 self.std list(std) # 归一化标准差转成列表 self.augment augment # 是否开启数据增强训练集开启验证/测试关闭 self.hflip hflip # 数据增强是否使用水平翻转 self.void_value_in_label void_value_in_label # 标签中无效区域的原始值255 self.ignore_index ignore_index # 训练时忽略的索引值11 # 2. 扫描图片文件夹获取所有有效图片文件名 # 遍历 img_dir 下所有文件只保留 .png/.jpg/.jpeg 格式的图片并按名称排序 self.img_names sorted([ f for f in os.listdir(img_dir) if f.lower().endswith((.png, .jpg, .jpeg)) ]) # 3. 检查是否找到图片没有则直接报错 if not self.img_names: raise RuntimeError(fNo images found in {img_dir})self.img_names [ 0001TP_007200.png, 0001TP_007500.png, 0001TP_007800.png, 0006TP_000900.png, ... ]__len__def __len__(self): return len(self.img_names)PyTorch 规定所有数据集必须实现这个方法作用是告诉程序这个数据集一共有多少张图片。_find_label_path_find_label_path 根据图像名自动搜索匹配的标签文件兼容多种命名格式保证图像和标签能正确一一对应。def _find_label_path(self, img_name): 兼容多种命名:同名 / 加 _L 后缀 / 改扩展名为 .png stem, ext os.path.splitext(img_name) for cand in (img_name, stem _L ext, stem .png, stem _L.png): p os.path.join(self.lbl_dir, cand) if os.path.exists(p): return p raise FileNotFoundError(fLabel not found for {img_name} in {self.lbl_dir})__getitem__传入一个序号 idx返回第 idx 张【处理好的图片 标签】喂给模型训练。# 核心函数根据索引 idx 获取 一张图片 一个标签 def __getitem__(self, idx): # 1. 通过索引 idx 拿到 图片文件名 img_name self.img_names[idx] # 2. 拼接路径 打开图片 → 强制转成 RGB 三通道彩色图 image Image.open(os.path.join(self.img_dir, img_name)).convert(RGB) # 3. 调用上面的函数自动找到对应的标签文件 → 打开标签图 label Image.open(self._find_label_path(img_name)) # 4. 标签必须是 单通道灰度图 modeL不是的话强制转换 if label.mode ! L: label label.convert(L) # 统一尺寸 # PIL 接收 (宽, 高)所以把 input_size(H,W) 反过来用 # 图像用双线性插值平滑 image image.resize((self.input_size[1], self.input_size[0]), Image.BILINEAR) # 标签用最近邻插值保证分类数字不变不模糊 label label.resize((self.input_size[1], self.input_size[0]), Image.NEAREST) # 数据增强随机水平翻转 # 训练时才开启50% 概率翻转 if self.augment and self.hflip and random.random() 0.5: image TF.hflip(image) # 图片翻转 label TF.hflip(label) # 标签必须同步翻转 # 图像预处理 # 转成 PyTorch 张量 (C, H, W)并归一化到 [0,1] image TF.to_tensor(image) # 减均值、除方差和预训练模型保持一致 image TF.normalize(image, self.mean, self.std) # 标签预处理 # 转成 numpy 数组类型必须是 int64分类标签要求 label np.array(label, dtypenp.int64) # 把标签里的无效值 255 → 替换成 11ignore_index if self.void_value_in_label ! self.ignore_index: label[label self.void_value_in_label] self.ignore_index # 大于11的无效标签 → 全部设为11忽略 label[label self.ignore_index] self.ignore_index # 转成 PyTorch 张量 label torch.from_numpy(label) # 最终返回处理好的图像张量 处理好的标签张量 return image, labelbuild_datasetdef build_dataset(cfg, split): 根据配置文件 cfg 和数据集划分 split(train/val/test) 来构造数据集。 # 1. 安全检查确保 split 只能是 train / val / test 三个值之一否则报错 # 防止传入错误参数 assert split in (train, val, test), split # 2. 拼接 图片路径 数据集根目录 对应集的图片文件夹 # 例如data/CamVid/train img_dir os.path.join(cfg[data_root], cfg[f{split}_images]) # 3. 拼接 标签路径 数据集根目录 对应集的标签文件夹 # 例如data/CamVid/train_labels lbl_dir os.path.join(cfg[data_root], cfg[f{split}_labels]) # 4. 创建 CamVidDataset 并返回 return CamVidDataset( img_dirimg_dir, # 图片路径 lbl_dirlbl_dir, # 标签路径 input_sizecfg[input_size], # 输入尺寸 (360,480) meancfg[mean], # 归一化均值 stdcfg[std], # 归一化方差 augment(split train), # 只有训练集才做数据增强 hflipcfg.get(hflip, True), # 是否水平翻转 void_value_in_labelcfg[void_value_in_label], # 标签无效值 255 ignore_indexcfg[ignore_index] # 忽略索引 11 )utils.pyfast_hist# 功能计算 混淆矩阵 histogram # label_true: 真实标签图片上每个像素的正确类别 # label_pred: 预测标签模型输出的类别 # num_classes: 类别总数比如CamVid11 # ignore_index: 要忽略的类别比如11不参与计算 def fast_hist(label_true, label_pred, num_classes, ignore_indexNone): # --------------------- 1. 筛选有效像素只算合法标签--------------------- # 只保留 真实标签 在 0 ~ num_classes-1 之间的像素 mask (label_true 0) (label_true num_classes) # 如果设置了忽略的索引如11再把这些像素也屏蔽掉 if ignore_index is not None: mask (label_true ! ignore_index) # --------------------- 2. 用公式计算混淆矩阵核心技巧--------------------- # 公式num_classes * 真实标签 预测标签 # 作用把 (真实,预测) 对 编码成一个数字方便统计 return np.bincount( num_classes * label_true[mask].astype(np.int64) label_pred[mask].astype(np.int64), minlengthnum_classes ** 2 # 固定矩阵大小类别数×类别数 ).reshape(num_classes, num_classes) # 最后 reshape 成方阵把每个像素的「真实类别」和「预测类别」统计一遍数清楚真实是 A 类、预测成 A 类有多少真实是 A 类、预测成 B 类有多少预测值 0 1 2 ... 10 真实 0 [50, 0, 0, ...] 真实 1 [2, 600, 5, ...] 真实 2 [0, 3, 120, ...] ... 真实 10 [...]SegmentationMetric# 语义分割 评估指标计算类 # 作用累积所有图片的混淆矩阵最后统一计算 PA / mAcc / mIoU 等指标 class SegmentationMetric: 累积混淆矩阵,提供 PA / mAcc / mIoU / 每类 IoU。 # 初始化传入类别数11、忽略索引11 def __init__(self, num_classes, ignore_indexNone): self.num_classes num_classes # 类别数CamVid 11类 self.ignore_index ignore_index # 要忽略的类别11 self.reset() # 初始化混淆矩阵 # 重置把混淆矩阵清零开始新一轮评估 def reset(self): # 创建一个 num_classes × num_classes 的全0矩阵 self.hist np.zeros((self.num_classes, self.num_classes), dtypenp.int64) # 更新传入一批 预测值 和 标签累积混淆矩阵 def update(self, preds, labels): # 如果是 PyTorch 张量先转到 CPU 并转成 numpy 数组 if torch.is_tensor(preds): preds preds.detach().cpu().numpy() if torch.is_tensor(labels): labels labels.detach().cpu().numpy() # 遍历每一张图片的预测和标签 for p, l in zip(preds, labels): # 调用 fast_hist 计算单张图的混淆矩阵并累加到总 hist 里 self.hist fast_hist( l.flatten(), # 真实标签展平 p.flatten(), # 预测标签展平 self.num_classes, self.ignore_index ) # 计算最终所有指标评估完所有图片后调用 def compute(self): h self.hist.astype(np.float64) # 混淆矩阵转浮点 eps 1e-10 # 防止除0 # 1. 像素准确率 PA (Pixel Accuracy) # 所有对角线上正确的像素 / 总像素 pixel_acc np.diag(h).sum() / max(h.sum(), eps) # 2. 每类 IoU 平均 mIoU # IoU 交集 / 并集 iou np.diag(h) / np.maximum(h.sum(axis1) h.sum(axis0) - np.diag(h), eps) miou np.nanmean(iou) # 所有类的平均IoU最重要指标 # 3. 每类准确率 平均 mAcc class_acc np.diag(h) / np.maximum(h.sum(axis1), eps) # 每类正确率 mean_acc np.nanmean(class_acc) # 平均正确率 # 返回所有指标 return { pixel_acc: pixel_acc, # 像素准确率 mean_acc: mean_acc, # 平均类别准确率 mIoU: miou, # 平均交并比核心指标 iou_per_class: iou # 每一类的IoU }compute_class_weights# --------------------------- 类别权重 --------------------------- # 功能计算每个类别的权重解决样本不平衡问题 # 方法中位数频率平衡 (Median Frequency Balancing)SegNet 论文官方使用 def compute_class_weights(dataset, num_classes, ignore_indexNone): # 1. 创建一个长度为 num_classes 的数组用来统计【每个类出现了多少个像素】 counts np.zeros(num_classes, dtypenp.int64) # 2. 遍历整个数据集统计每一类的像素总数 for _, label in dataset: # 把标签展平成一维方便统计 l label.numpy().flatten() # 如果有忽略的类别如11就把这些像素去掉不参与统计 if ignore_index is not None: l l[l ! ignore_index] # 统计这张图里每类像素的数量并累加到总 counts 中 counts np.bincount(l, minlengthnum_classes)[:num_classes] # 3. 计算每类像素出现的【频率】 freq counts.astype(np.float64) / max(counts.sum(), 1) # 4. 计算所有【非零频率】的中位数 median np.median(freq[freq 0]) # 5. 核心公式权重 中位数频率 / 当前类频率 # 出现越少的类别权重越大出现越多的类别权重越小 weights median / (freq 1e-10) # 6. 转成 PyTorch 张量返回给损失函数用 return torch.from_numpy(weights).float()decode_segmap# --------------------------- 可视化 --------------------------- # 功能将类别索引图单通道 H,W转为 RGB 彩色图3通道 H,W,3 # label_mask: 模型输出的类别图每个像素是 0~10 的数字 # class_colors: 每个类别对应的颜色列表如 [(0,0,0), (128,0,0)...] # void_color: 忽略区域的颜色默认黑色 # ignore_index: 忽略的类别编号如11 def decode_segmap(label_mask, class_colors, void_color(0, 0, 0), ignore_indexNone): 把 (H, W) 类别索引图渲染成 (H, W, 3) RGB。 # 获取索引图的高和宽 h, w label_mask.shape # 创建一张和原图一样大的空白 RGB 图全黑 rgb np.zeros((h, w, 3), dtypenp.uint8) # 遍历每一个类别 c以及它对应的颜色 color for c, color in enumerate(class_colors): # 把所有“类别等于 c”的像素全部涂成对应的 color rgb[label_mask c] color # 如果有忽略区域比如11把这些像素涂成黑色或指定颜色 if ignore_index is not None: rgb[label_mask ignore_index] void_color # 返回彩色图 return rgbtrain.pyparse_args()# 解析命令行参数运行程序时输入的指令 def parse_args(): # 1. 创建一个参数解析器 p argparse.ArgumentParser(description训练/测试配置解析) # 2. 定义可以接收的参数 # --config指定配置文件默认用 config.yaml不写就用这个 p.add_argument(--config, defaultconfig.yaml, helpyaml 配置文件) # 下面这些是常用超参数可以直接在命令行覆盖配置文件里的值 p.add_argument(--epochs, typeint, defaultNone) # 训练轮数 p.add_argument(--batch-size, typeint, defaultNone, destbatch_size) # 批次大小 p.add_argument(--lr, typefloat, defaultNone) # 学习率 p.add_argument(--workers, typeint, defaultNone) # 数据加载线程数 p.add_argument(--device, defaultNone, helpcpu / 0 / cuda) # 用CPU还是GPU # 3. 解析命令行输入的所有参数并返回 return p.parse_args()apply_cli_overrides# 功能用命令行参数(args) 覆盖 配置文件(cfg) 中的参数 # 优先级命令行输入 config.yaml 配置文件 def apply_cli_overrides(cfg, args): # 如果命令行传了 --epochs就用它替换 cfg 里的 epochs if args.epochs is not None: cfg[epochs] args.epochs # 如果命令行传了 --batch-size替换 cfg 里的 batch_size if args.batch_size is not None: cfg[batch_size] args.batch_size # 如果命令行传了 --lr替换学习率 if args.lr is not None: cfg[lr] args.lr # 如果命令行传了 --workers替换线程数 if args.workers is not None: cfg[num_workers] args.workersselect_device# 功能根据输入参数智能选择 计算设备 (GPU 或 CPU) def select_device(device_arg): # 情况1用户没指定设备 或 为空 # 自动选择有GPU用GPU没GPU用CPU if device_arg is None or device_arg : return torch.device(cuda if torch.cuda.is_available() else cpu) # 情况2用户明确指定用 CPU if str(device_arg).lower() cpu: return torch.device(cpu) # 情况3用户输入数字如 0代表使用 0 号 GPU if str(device_arg).isdigit() and torch.cuda.is_available(): return torch.device(fcuda:{device_arg}) # 情况4其他情况兜底选择有GPU用GPU否则CPU return torch.device(cuda if torch.cuda.is_available() else cpu)train_one_epoch# 功能训练模型 **一轮一个 epoch** # model模型SegNet # loader数据加载器训练集 # optimizer优化器更新参数 # criterion损失函数算误差 # deviceGPU / CPU def train_one_epoch(model, loader, optimizer, criterion, device): # 1. 把模型设为 **训练模式** (启用Dropout/BatchNorm等训练特性) model.train() # 2. 初始化总损失 total总样本数 n total_loss 0.0 num_samples 0 # 3. 遍历训练集中的 **每一批数据** for images, labels in loader: # 把数据搬到 GPU/CPU 上 images images.to(device, non_blockingTrue) labels labels.to(device, non_blockingTrue) # 4. 清空上一步的梯度必须做 optimizer.zero_grad() # 5. 前向传播喂入图片得到预测结果计算损失误差 outputs model(images) # 模型预测 loss criterion(outputs, labels)# 计算预测和标签的误差 # 6. 反向传播 更新模型参数 loss.backward() # 反向传播算梯度 optimizer.step() # 优化器根据梯度修改模型权重 # 7. 累计损失和样本数量用于最后算平均损失 batch_size images.size(0) # 这一批有多少张图 total_loss loss.item() * batch_size # 累加这一批的总损失 num_samples batch_size # 累加总图片数 # 8. 返回 **平均每张图片的损失** return total_loss / max(1, num_samples)validate# 功能模型在验证集 / 测试集上进行一轮评估 # 作用计算验证损失、mIoU、像素准确率等指标不更新模型参数 # 输入 # model训练的语义分割模型 # loader验证集数据加载器 # criterion损失函数 # device运行设备GPU/CPU # num_classes类别数量 # ignore_index需要忽略的类别索引不参与指标计算 # 返回包含loss、mIoU、像素准确率等指标的字典 def validate(model, loader, criterion, device, num_classes, ignore_index): # 将模型设置为【评估模式】 # 关闭 Dropout、BatchNorm 等训练时才用的随机操作保证预测稳定 model.eval() # 创建语义分割指标计算器用于累计混淆矩阵并计算 mIoU、PA 等指标 metric SegmentationMetric(num_classes, ignore_index) # 初始化变量total 累计总损失n 累计样本数量 total, n 0.0, 0 # 遍历验证集中的每一批数据图片x标签y for x, y in loader: # 将图片数据搬到指定设备GPU/CPU x x.to(device, non_blockingTrue) # 将标签数据搬到指定设备GPU/CPU y y.to(device, non_blockingTrue) # 模型前向推理得到输出 logits (B, num_classes, H, W) out model(x) # 计算当前批次的损失误差 loss criterion(out, y) # 获取当前批次的图片数量 batch size bs x.size(0) # 累计总损失loss.item() 是单张平均损失乘以 bs 得到本批次总损失 total loss.item() * bs # 累计总样本数 n bs # 将模型预测结果与真实标签传入指标计算器 # out.argmax(1)在通道维度取最大值得到类别索引图 (B, H, W) metric.update(out.argmax(1), y) # 所有图片计算完成计算最终指标pixel_acc、mean_acc、mIoU、每类IoU res metric.compute() # 把平均验证损失也加入结果字典 res[loss] total / max(1, n) # 返回所有评估结果 return resmain()def main(): # 1. 解析命令行参数--config / --epochs / --batch-size 等 args parse_args() # 2. 读取 yaml 配置文件数据集路径、超参、模型配置 with open(args.config, r, encodingutf-8) as f: cfg yaml.safe_load(f) # 3. 用命令行参数覆盖配置文件优先级命令行 yaml apply_cli_overrides(cfg, args) # 4. 创建保存文件夹模型权重目录 日志目录 os.makedirs(cfg[checkpoint_dir], exist_okTrue) os.makedirs(cfg[log_dir], exist_okTrue) # 5. 自动选择设备GPU 优先没有则用 CPU device select_device(args.device) # 数据集加载 # 构建训练集 验证集 train_set build_dataset(cfg, train) val_set build_dataset(cfg, val) # 训练集加载器打乱顺序、多线程、丢最后一批 train_loader DataLoader(train_set, batch_sizecfg[batch_size], shuffleTrue, num_workerscfg[num_workers], pin_memoryTrue, drop_lastTrue) # 验证集加载器不打乱、多线程 val_loader DataLoader(val_set, batch_sizecfg[batch_size], shuffleFalse, num_workerscfg[num_workers], pin_memoryTrue) # 模型构建 # 构建 SegNet 模型搬到 GPU/CPU model SegNet(num_classescfg[num_classes], pretrainedcfg[pretrained]).to(device) # 损失函数 # 是否使用类别权重解决样本不平衡SegNet 官方方法 if cfg.get(use_class_weights, True): cw compute_class_weights(train_set, cfg[num_classes], ignore_indexcfg[ignore_index]).to(device) else: cw None # 交叉熵损失语义分割标准损失支持类别权重 忽略索引 criterion nn.CrossEntropyLoss(weightcw, ignore_indexcfg[ignore_index]) # 优化器 学习率策略 # SGD 优化器带动量 权重衰减 optimizer SGD(model.parameters(), lrcfg[lr], momentumcfg[momentum], weight_decaycfg[weight_decay]) epochs cfg[epochs] power cfg.get(poly_power, 0.9) # Poly 学习率衰减学习率逐渐下降 (1 - e/epochs)^power scheduler torch.optim.lr_scheduler.LambdaLR( optimizer, lr_lambdalambda e: (1 - e / epochs) ** power ) # 日志初始化 log_path os.path.join(cfg[log_dir], train_log.csv) with open(log_path, w) as f: f.write(epoch,train_loss,val_loss,pixel_acc,mean_acc,mIoU,lr,time\n) # 训练主循环 best_miou 0.0 # 记录最好的 mIoU for epoch in range(1, epochs 1): t0 time.time() # 记录每轮开始时间 # 训练一轮更新模型权重 train_loss train_one_epoch(model, train_loader, optimizer, criterion, device) # 验证一轮计算 mIoU / 准确率 / 损失 val_res validate(model, val_loader, criterion, device, cfg[num_classes], cfg[ignore_index]) # 更新学习率 scheduler.step() # 计算本轮耗时 dt time.time() - t0 # 当前学习率 cur_lr optimizer.param_groups[0][lr] # 打印日志 print(fEpoch {epoch:3d}/{epochs} | fLoss: {train_loss:.4f}/{val_res[loss]:.4f} | fPA: {val_res[pixel_acc]:.4f} | fmAcc: {val_res[mean_acc]:.4f} | fmIoU: {val_res[mIoU]:.4f} | fLR: {cur_lr:.2e} | f{dt:.1f}s) # 把结果写入 CSV 文件 with open(log_path, a) as f: f.write(f{epoch},{train_loss:.6f},{val_res[loss]:.6f}, f{val_res[pixel_acc]:.6f},{val_res[mean_acc]:.6f}, f{val_res[mIoU]:.6f},{cur_lr:.6e},{dt:.2f}\n) # 保存模型 # 保存最新一轮模型 ckpt {epoch: epoch, state_dict: model.state_dict(), mIoU: float(val_res[mIoU]), config: cfg} torch.save(ckpt, os.path.join(cfg[checkpoint_dir], segnet_last.pth)) # 如果当前 mIoU 更高保存为最佳模型 if val_res[mIoU] best_miou: best_miou val_res[mIoU] torch.save(ckpt, os.path.join(cfg[checkpoint_dir], segnet_best.pth))

相关文章:

PyTorch深度学习实战 |SegNet

🌞欢迎来到PyTorch深度学习实战的世界 🌈博客主页:卿云阁 💌欢迎关注🎉点赞👍收藏⭐️留言📝 📆首发时间:🌹2026年4月29日🌹 ✉️希望可以和大家…...

Flowable 流程审计与排查:如何通过历史任务查询快速定位线上问题

Flowable 流程审计与排查:如何通过历史任务查询快速定位线上问题 当生产环境的审批流程突然停滞,或是某个关键业务环节出现异常时,运维团队往往面临巨大压力。上周我们遇到一个典型案例:某金融产品的开户流程在夜间批量处理时&…...

AI图像生成技术与提示词工程实战指南

1. AI图像生成技术概述AI图像生成技术是近年来计算机视觉领域最具突破性的进展之一。这项技术能够将自然语言描述转化为高质量的视觉内容,其核心在于深度学习模型对文本和图像之间复杂映射关系的理解与重建。目前主流的图像生成模型主要基于两种架构:生成…...

HiClaw 1.1.0:企业级 Agent 开发的基建升级

我最近在做一个企业 AI 培训项目,帮客户部署智能体平台。说实话,技术能力早就不是问题,真正的挑战是怎么让它在各种奇葩环境里稳稳当当跑起来。 上周刚交付一个项目,用的是 1.0.9 版本。客户验收那天说"挺稳的"&#x…...

新联合众香港展会圆满落幕,AI融合硬件矩阵获全球瞩目

2026年4月15日,中国北京​ – 随着香港环球资源消费电子展的帷幕缓缓落下,新联合众(北京)科技有限公司在此次行业盛会上圆满收官。为期四天的展会中,新联合众以“AI硬件融合”战略、一系列新品及完整的智能办公解决方案…...

收藏必备!小白程序员轻松掌握RAG大模型,让你的AI秒懂公司文档!

RAG 是什么:一句话类比 RAG(Retrieval-Augmented Generation) 先检索,再生成。 类比:RAG 就像开卷考试。模型本身是那个能写文章的学生,知识库是那一堆参考书。考试时不靠死记硬背,而是先翻书找…...

大数据开发场景下,总结并翻译 Oracle 中常见的错误(补充其他错误码:适合初学者)

Oracle大数据开发常见错误在Oracle大数据开发(如ETL、Hadoop抽取)中,常见错误分为五类:字段/表错误:如ORA-00904(无效列名)、ORA-00942(表不存在);数据类型/转…...

C++实现简单计算器

本文实例为大家分享了C实现简单计算器的具体代码,供大家参考,具体内容如下工具stackmap步骤初始化读取字符串去空格负号处理判断为空检查格式计算示例代码1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950…...

Unity游戏实时翻译终极指南:XUnity.AutoTranslator深度技术解析

Unity游戏实时翻译终极指南:XUnity.AutoTranslator深度技术解析 【免费下载链接】XUnity.AutoTranslator 项目地址: https://gitcode.com/gh_mirrors/xu/XUnity.AutoTranslator 在全球化游戏市场日益繁荣的今天,语言障碍成为玩家体验外语游戏的最…...

[Al+」数智升级,品牌种草营销新范式

AI给各行各业带来的革新有目共睹。在营销工作中,这个命题亦尤为迫切。AI如何嵌入具体场景、解决日常问题?过去一年,千瓜持续投入「AI」产品战略升级,现已覆盖“达人、内容、品牌”三大维度,实现从选人选号、内容创作到…...

脑矿奴隶起义:软件测试从业者的觉醒与革命

在当今数字化浪潮中,软件测试从业者常被戏称为“脑矿奴隶”——一群在代码矿山中日夜劳作的隐形工人,承受着高强度脑力压榨与价值低估。这场“脑矿奴隶起义”,不是历史上的血腥抗争,而是测试工程师们通过专业工具、自动化策略和集…...

Qwen3模型网络故障诊断辅助:图解常见错误与解决方案

Qwen3模型网络故障诊断辅助:图解常见错误与解决方案 网络一断,业务瘫痪。对于运维工程师来说,这可能是最让人心跳加速的时刻。面对屏幕上跳出的错误代码,从海量的日志和复杂的拓扑图中快速定位问题根源,无异于大海捞针…...

2026年小程序商城哪个平台最好?

2026年小程序商城哪个平台最好?小程序商城没有"最好的平台",只有"最匹配业务需求的平台"。选择平台的核心依据是功能匹配度、成本可控性和运营支持能力三者的平衡。从趋势来看,2023-2025年SaaS平台方案占比从约45%增长到…...

2026 AI存储行业迎来关键时刻:英伟达“补课”,华为存储“解题”

文 | 智能相对论作者 | 陈泊丞数十亿建成的万卡GPU集群,实际利用率不足40%。这不是某个智算中心的个例。在过去两年里,中国涌现了大大小小几十个智算中心项目,GPU买了一批又一批,但真正跑满的时候不多。问题不在芯片本身——而在数…...

Swoole+LLM长连接崩了?5个致命错误代码片段+4步热修复流程,现在不看明天宕机

更多请点击: https://intelliparadigm.com 第一章:SwooleLLM长连接崩了?5个致命错误代码片段4步热修复流程,现在不看明天宕机 当 Swoole 的 WebSocket Server 与 LLM 推理服务深度耦合后,长连接看似稳定,实…...

VS Code Copilot Next 工作流配置已进入“智能编排”时代:如何用3个JSON Schema + 1个DSL描述符接管全部重复性编码任务?

更多请点击: https://intelliparadigm.com 第一章:VS Code Copilot Next 工作流配置已进入“智能编排”时代 VS Code Copilot Next 不再仅是代码补全工具,而是演变为可感知上下文、理解任务意图、并自动串联多步骤开发动作的智能工作流引擎…...

git提交代码时,将大写文件改成小写,提交不上去了

主要原因:git add . 没成功把文件加入暂存区文件被 .gitignore 规则忽略了以后永久解决大小写问题git config core.ignorecase false...

环境一致性崩塌预警!Dev Containers 生产部署前必须验证的7项黄金检查项(含自动化校验脚本)

更多请点击: https://intelliparadigm.com 第一章:环境一致性崩塌预警!Dev Containers 生产部署前必须验证的7项黄金检查项(含自动化校验脚本) 当 Dev Containers 从本地开发跃迁至 CI/CD 流水线或预发环境时&#xf…...

构建高效测试反馈循环:从CI/CD到自动化测试的工程实践

1. 项目概述:一个关于测试与循环的探索最近在GitHub上看到一个名为suhuandds/test-pilot-loop的项目,这个标题本身就很有意思。test-pilot-loop,直译过来是“测试-飞行员-循环”,听起来像是一个航空领域的术语,但在软件…...

国产替代之2SK3704与VBMB1615参数对比报告

N沟道功率MOSFET参数对比分析报告一、产品概述2SK3704:三洋(SANYO)N沟道硅MOSFET,耐压60V,导通电阻低,开关速度快(超高速开关),采用4V驱动设计。封装:TO-220M…...

VS Code 远程容器开发环境崩溃实录(附完整日志解码手册):从 Dockerfile 语法错误到 OCI runtime error 的全链路排障指南

更多请点击: https://intelliparadigm.com 第一章:VS Code 远程容器开发环境崩溃现象全景速览 VS Code 的 Remote-Containers 扩展在现代云原生开发中广受青睐,但其稳定性在特定场景下存在显著挑战。开发者常遭遇容器意外退出、Dev Containe…...

BiliTools完整指南:如何轻松下载B站视频与弹幕

BiliTools完整指南:如何轻松下载B站视频与弹幕 【免费下载链接】BiliTools A cross-platform bilibili toolbox. 跨平台哔哩哔哩工具箱,支持下载视频、番剧等等各类资源 项目地址: https://gitcode.com/GitHub_Trending/bilit/BiliTools 还在为下…...

MinIO 国产平替,RustFS 发布 Beta 版本啦

历经 2850 次 Git 提交,99 个 alpha 版本,我们正式发布 RustFS Beta 版。 自从 2025 年 7 月正式开源以来,RustFS 累计获得 26.5k star,1.1k fork,全球贡献者数量超 130 位,DockerHub 镜像拉取次数更是超过…...

保姆级教程:用UE5的Cable组件和PhysicsConstraint做个会晃的吊灯(蓝图版)

用UE5打造逼真物理吊灯:Cable组件与PhysicsConstraint深度实战 在虚幻引擎5的虚拟世界中,物理交互是营造沉浸感的关键要素之一。想象一下中世纪城堡大厅里摇曳的烛光,或是现代loft空间中极具设计感的悬挂灯具——这些场景的核心,往…...

前端性能优化:可访问性优化详解

前端性能优化:可访问性优化详解 为什么可访问性优化如此重要? 在现代Web应用中,可访问性是一个常常被忽视的重要因素。合理的可访问性优化可以确保所有用户(包括残障人士)都能正常使用网站,同时也能提高搜…...

2025届学术党必备的五大AI论文方案解析与推荐

Ai论文网站排名(开题报告、文献综述、降aigc率、降重综合对比) TOP1. 千笔AI TOP2. aipasspaper TOP3. 清北论文 TOP4. 豆包 TOP5. kimi TOP6. deepseek 当下,主流的AI论文辅助工具,各自有着不同的特点,GPT呢&am…...

WS2812点阵驱动时序调不好?保姆级示波器抓波形与FPGA调试心得分享

WS2812点阵驱动时序调不好?保姆级示波器抓波形与FPGA调试心得分享 第一次接触WS2812点阵时,看着数据手册上那些以纳秒为单位的时间参数,我整个人都是懵的。1180ns、1280ns、300us——这些数字在示波器上看起来就像是在玩一场高精度的电子游戏…...

前端性能优化:构建工具优化详解

前端性能优化:构建工具优化详解 为什么构建工具优化如此重要? 在现代Web开发中,构建工具是前端开发流程的重要组成部分。合理使用构建工具可以显著提高开发效率,优化代码质量,提升页面性能。因此,构建工具优…...

数据库迁移中的索引管理:Blue/Green部署策略

在现代软件开发中,数据库迁移和部署策略对于保证系统的稳定性和可用性至关重要。Blue/Green部署是一种常见的无停机更新方式,它通过在两个独立的环境中分别运行旧版本(Blue)和新版本(Green)应用来实现。今天我们来探讨在这种部署策略下,如何在两个PostgreSQL数据库实例间…...

深入理解NumPy数组切片

引言 在科学计算和数据分析领域,NumPy库无疑是Python中最强大的工具之一。NumPy提供了多维数组对象和大量用于处理数组的函数,其中数组切片(slicing)是经常使用到的功能之一。今天我们将探讨如何在NumPy中对一维数组进行切片操作,并解决一些常见的困惑。 数组切片简介 …...