当前位置: 首页 > news >正文

Lightning基础训练尝试实例

一、训练任务概述

动机:由于后续的课题中会用到类似图像去噪的算法,考虑先用U-Net,这里做一个前置的尝试。

训练任务:分割出图像中的细胞。

数据集:可私

数据集结构:

二、具体实现

U-Net的网络实现是现成的,只需要在网上找一个比较漂亮的实现(一般都是模块化,写的很漂亮)copy就可以了,需要特别注意的是最后整合的模型

2.1 基础模型模块实现

双卷积模块

class DoubleConv(nn.Module):"""(convolution => [BN] => ReLU) * 2"""def __init__(self, in_channels, out_channels, mid_channels=None):super().__init__()if not mid_channels:mid_channels = out_channelsself.double_conv = nn.Sequential(nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),nn.BatchNorm2d(mid_channels),nn.ReLU(inplace=True),nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True))def forward(self, x):return self.double_conv(x)

上采样模块

class Up(nn.Module):"""Upscaling then double conv"""def __init__(self, in_channels, out_channels, bilinear=True):super().__init__()# if bilinear, use the normal convolutions to reduce the number of channelsif bilinear:self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)else:self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)self.conv = DoubleConv(in_channels, out_channels)def forward(self, x1, x2):x1 = self.up(x1)# input is CHWdiffY = x2.size()[2] - x1.size()[2]diffX = x2.size()[3] - x1.size()[3]x1 = torch.nn.functional.pad(x1, [diffX // 2, diffX - diffX // 2,diffY // 2, diffY - diffY // 2])x = torch.cat([x2, x1], dim=1)return self.conv(x)

下采样模块

class Down(nn.Module):"""Downscaling with maxpool then double conv"""def __init__(self, in_channels, out_channels):super().__init__()self.maxpool_conv = nn.Sequential(nn.MaxPool2d(2),DoubleConv(in_channels, out_channels))def forward(self, x):return self.maxpool_conv(x)

输出层

class OutConv(nn.Module):def __init__(self, in_channels, out_channels):super(OutConv, self).__init__()self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)def forward(self, x):return self.conv(x)

2.2 整合模块->模型

class UNet(L.LightningModule):def __init__(self, n_channels, n_classes, bilinear=False):super(UNet, self).__init__()self.n_channels = n_channelsself.n_classes = n_classesself.bilinear = bilinearself.inc = (DoubleConv(n_channels, 64))self.down1 = (Down(64, 128))self.down2 = (Down(128, 256))self.down3 = (Down(256, 512))factor = 2 if bilinear else 1self.down4 = (Down(512, 1024 // factor))self.up1 = (Up(1024, 512 // factor, bilinear))self.up2 = (Up(512, 256 // factor, bilinear))self.up3 = (Up(256, 128 // factor, bilinear))self.up4 = (Up(128, 64, bilinear))self.outc = (OutConv(64, n_classes))def forward(self, x):x1 = self.inc(x)x2 = self.down1(x1)x3 = self.down2(x2)x4 = self.down3(x3)x5 = self.down4(x4)x = self.up1(x5, x4)x = self.up2(x, x3)x = self.up3(x, x2)x = self.up4(x, x1)logits = self.outc(x)return logits# 对应的层设置检查点,节省显存m,可用可不用def use_checkpointing(self):self.inc = torch.utils.checkpoint(self.inc)self.down1 = torch.utils.checkpoint(self.down1)self.down2 = torch.utils.checkpoint(self.down2)self.down3 = torch.utils.checkpoint(self.down3)self.down4 = torch.utils.checkpoint(self.down4)self.up1 = torch.utils.checkpoint(self.up1)self.up2 = torch.utils.checkpoint(self.up2)self.up3 = torch.utils.checkpoint(self.up3)self.up4 = torch.utils.checkpoint(self.up4)self.outc = torch.utils.checkpoint(self.outc)# 定义优化器def configure_optimizers(self):optimizer = torch.optim.Adam(self.parameters(),lr=0.001)return optimizer# 定义train的单步流程def training_step(self,train_batch,batch_index):image,label = train_batchimage_hat = self.forward(image)# U-Net的lossloss = nn.functional.mse_loss(image_hat,label)return loss# 定义val的单步流程def validation_step(self, val_batch,batch_index):image,label = val_batchimage_hat = self.forward(image)# U-Net的lossloss = nn.functional.mse_loss(image_hat,label)self.log('val_loss',loss)return loss

注意:模块可以不需要继承自L.LightningModule,只要最后整合的时候继承自L.LightningModule就可以了。

2.3 数据划分

重定义Dataset类,供数据集划分函数调用,二者要相互配合

class UDataset(Dataset):def __init__(self,image_dir,mask_dir,transform=None):self.image_dir = image_dirself.mask_dir = mask_dirif transform is not None:self.transform = transformelse:self.transform = Nonedef __getitem__(self, index):image = Image.open(self.image_dir[index]).convert('RGB')label = Image.open(self.mask_dir[index]).convert('RGB')if self.transform is not None:image = self.transform(image)label = self.transform(label)return image,labeldef __len__(self):return len(self.image_dir)

 定义数据集划分函数(包括"找出文件列表"、"定义数据预处理方式"、“定义批量大小”)

train_image_dir = "./data/train/image/*.png"
train_label_dir = "./data/train/label/*.png"
val_image_dir = "./data/val/image/*.png"
val_label_dir = "./data/val/label/*.png"  def data_process(train_image_dir,train_label_dir,val_image_dir,val_label_dir):# 查找路径下的所有文件,返回文件路径列表train_image_list = glob.glob(train_image_dir)train_label_list = glob.glob(train_label_dir)val_image_list = glob.glob(val_image_dir)val_label_list = glob.glob(val_label_dir)# 数据处理train_data_transform = transforms.Compose([transforms.Resize((256,256)),transforms.ToTensor()])val_data_transform = transforms.Compose([  transforms.Resize((256,256)),transforms.ToTensor()])train_dataloader = data.DataLoader(UDataset(train_image_list,train_label_list,train_data_transform),batch_size=5,shuffle=True)val_dataloader = data.DataLoader(UDataset(val_image_list,val_label_list,val_data_transform),batch_size=5,shuffle=False)return train_dataloader,val_dataloader

2.4 模型验证

在训练之前,要看一下模型的结构有没有错误,用summary打印出网络的结构

    # 模型验证device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model = UNet(n_channels=3,n_classes=1).to(device)print(summary(model,(3,512,512)))

也可以用其他的方法查看网络结构

2.5 模型训练

加入TensorBoardLogger是为了可视化训练Loss

训练的流程遵循前文的基本流程

    # 创建 TensorBoardLoggerlogger = TensorBoardLogger("tb_logs", name="unet")# 创建 Trainertrainer = L.Trainer(max_epochs=20, logger=logger)# 划分数据集train_dataloader,val_dataloader = data_process(train_image_dir,train_label_dir,val_image_dir,val_label_dir)# 创建模型model = UNet(n_channels=3,n_classes=1)# 启动模型训练过程trainer.fit(model,train_dataloader,val_dataloader)# 保存模型权重torch.save(model.state_dict(),'./model.pth')

相关文章:

Lightning基础训练尝试实例

一、训练任务概述 动机:由于后续的课题中会用到类似图像去噪的算法,考虑先用U-Net,这里做一个前置的尝试。 训练任务:分割出图像中的细胞。 数据集:可私 数据集结构: 二、具体实现 U-Net的网络实现是现…...

osgearth视点坐标及鼠标交点坐标的信息显示(七)

核心函数如下: void COSGObject::addViewPointLabel() {//mRoot->addChild(osgEarth::Util::Controls::ControlCanvas::get(mViewer));//放开这句,球就卡住了。 为什么,shitosgEarth::Util::Controls::ControlCanvas* canvas = osgEarth::Util::Controls::ControlCanvas…...

动态规划 之 背包问题

文章目录 0-1背包问题2915.和为目标值的最长子序列的长度494.目标和 完全背包问题322.零钱兑换518.零钱兑换II 多重背包2585.获得分数的方法数 分组背包1155.掷骰子等于目标和的方法数 背包问题是动态规划一个很重要的一类题目,主要分为0-1背包问题以及完全背包问题…...

【Azure 架构师学习笔记】- Azure Databricks (11) -- UC搭建

本文属于【Azure 架构师学习笔记】系列。 本文属于【Azure Databricks】系列。 接上文 【Azure 架构师学习笔记】- Azure Databricks (10) – UC 使用 前言 由于ADB 的更新速度很快,在几个月之后重新搭建ADB 时发现UC 已经更新了很多,为了后续做ADB 的功…...

RTMP(Real-Time Messaging Protocol)

RTMP(Real-Time Messaging Protocol)是一种用于实时音视频和数据传输的协议,常见于直播和流媒体应用。 一 RTSP 协商消息 一、消息类型(Message Types) RTMP消息分为多种类型,通过Message Type ID标识&a…...

docker容器部署jar应用导入文件时候报缺少字体错误解决

如题,在导入文件时候报错如下: Handler dispatch failed; nested exception is java.lang.NoClassDefFoundError: Could not initialize class sun.awt.X11FontManager 经查是缺少对应字体,解决办法有两张: 第一种:…...

贪吃蛇解析

目录 文章结尾有代码可自取 Win32API 光标的隐藏 获取按键信息 控制光标位置 游戏开始前的准备 游戏准备及介绍 加载和欢迎界面 打印游戏指南 运行游戏 打印墙体和说明 设置蛇的各个信息 初始化及打印蛇 创造食物 运行游戏 1)打印得分情况 2&#…...

vue非组件的初学笔记

1.创建Vue实例,初始化渲染的核心 准备容器引包创建Vue实例new Vue() el用来指定控制的盒子data提供数据 2.插值表达式 作用利用表达式插值,将数据渲染到页面中 格式{{表达式}} 注意点 表达式的数据要在data中存在表达式是可计算结果的语句插值表达式…...

LeetCode 热题 100_单词搜索(60_79_中等_C++)(深度优先搜索(回溯))(初始化二维vector的大小)

LeetCode 热题 100_单词搜索(60_79) 题目描述:输入输出样例:题解:解题思路:思路一(深度优先搜索(回溯)): 代码实现代码实现(思路一&am…...

js闭包,跨域

js闭包,跨域 闭包 想象一下,你家有个大仓库(函数),仓库里放着各种东西(变量)。一般情况下,你从仓库外面是看不到也拿不到仓库里的东西的。但是,闭包就像是你在仓库里留…...

算法练习(力扣-BFS)——102. 二叉树的层序遍历

题目描述(简要概括) 题目链接:102. 二叉树的层序遍历 - 力扣(LeetCode) 题目要求对给定的二叉树进行层序遍历(从上到下,从左到右),并返回遍历的结果。层序遍历是一种基…...

Jetson Agx Orin平台preferred_stride调试记录--1924x720图像异常

1.问题描述 硬件: AGX Orin 在Jetpack 5.0.1和Jetpack 5.0.2上测试验证 图像分辨率在1920x720和1024x1920下图像采集正常 但是当采集图像分辨率为1924x720视频时,图像输出异常 像素格式:yuv_uyvy16 gstreamer命令如下 gst-launch-1.0 v4l2src device=/dev/video0 ! …...

nlp|微调大语言模型初探索(2),训练自己的聊天机器人

前言 上篇文章记录了具体的微调语言大模型步骤,以及在微调过程中可能遇见的各种报错,美中不足的是只是基于开源数据集的微调,今天来记录一下怎么基于自己的数据集去微调大语言模型,训练自己的智能机器人!!&…...

win11安装wsl报错:无法解析服务器的名称或地址(启用wsl2)

1. 启用wsl报错如下 # 查看可安装的 wsl --install wsl --list --online此原因是因为没有开启DNS的原因,所以需要我们手动开启DNS。 2. 按照如下配置即可 Google的DNS(8.8.8.8和8.8.4.4) 全国通用DNS地址 (114.114.114.114) 3. 运行以下命令来重启 WSL…...

Gentleman:优雅的Go语言HTTP客户端工具包

gentlemen介绍,特点等 插件驱动架构:Gentleman的核心特点是其插件系统,允许用户注册和重用各种自定义插件,如重试策略或动态服务器发现,以增强HTTP客户端的功能。 中间件层:项目内置了一个上下文感知的层次…...

解锁豆瓣高清海报(三)从深度爬虫到URL构造,实现极速下载

脚本地址: 项目地址: Gazer PosterBandit_v2.py 前瞻 之前的 PosterBandit.py 是按照深度爬虫的思路一步步进入海报界面来爬取, 是个值得学习的思路, 但缺点是它爬取慢, 仍然容易碰到豆瓣的 418 错误, 本文也会指出彻底解决旧版 418 错误的方法并提高爬取速度. 现在我将介绍…...

IDEA单元测试插件 SquareTest 延长试用期权限

SquareTest是一款强大的IDEA单元测试生成插件工具,具体使用方法就不过多介绍了,这里主要介绍变更试用期,方便大家使用 配置信息 我的电脑安装前提配置条件 IntelliJ IDEA 2023.2windows 系统 软件安装 IntelliJ IDEA 直接安装插件Squar…...

PLC的五个学习步骤

五个学习步骤详解: 1. 夯实电气基础 (第一步) 核心思想: PLC控制技术是建立在传统电气控制技术之上的,因此扎实的电气基础至关重要。学习内容: 电气元件原理: 深入理解继电器、接触器、按钮、三相异步电机等常用电气元件的工作原理。这是理解电气控制回…...

深度学习05 ResNet残差网络

目录 传统卷积神经网络存在的问题 如何解决 批量归一化BatchNormalization, BN 残差连接方式 ​残差结构 ResNet网络 ResNet 网络是在 2015年 由微软实验室中的何凯明等几位大神提出,斩获当年ImageNet竞赛中分类任务第一名,目标检测第一名。获得CO…...

卷积神经网络CNN

目录 一、CNN概述 二、图像基础知识 三、卷积层 3.1 卷积的计算 3.2 Padding 3.3 Stride 3.4 多通道卷积计算 3.5 多卷积核卷积计算 3.6 特征图大小计算 3.7 Pytorch 卷积层API 四、池化层 4.1 池化计算 4.2 Stride 4.3 Padding 4.4 多通道池化计算 4.5 Pytorc…...

Lombok 的 @Data 注解失效,未生成 getter/setter 方法引发的HTTP 406 错误

HTTP 状态码 406 (Not Acceptable) 和 500 (Internal Server Error) 是两类完全不同的错误,它们的含义、原因和解决方法都有显著区别。以下是详细对比: 1. HTTP 406 (Not Acceptable) 含义: 客户端请求的内容类型与服务器支持的内容类型不匹…...

UE5 学习系列(三)创建和移动物体

这篇博客是该系列的第三篇,是在之前两篇博客的基础上展开,主要介绍如何在操作界面中创建和拖动物体,这篇博客跟随的视频链接如下: B 站视频:s03-创建和移动物体 如果你不打算开之前的博客并且对UE5 比较熟的话按照以…...

srs linux

下载编译运行 git clone https:///ossrs/srs.git ./configure --h265on make 编译完成后即可启动SRS # 启动 ./objs/srs -c conf/srs.conf # 查看日志 tail -n 30 -f ./objs/srs.log 开放端口 默认RTMP接收推流端口是1935,SRS管理页面端口是8080,可…...

苍穹外卖--缓存菜品

1.问题说明 用户端小程序展示的菜品数据都是通过查询数据库获得,如果用户端访问量比较大,数据库访问压力随之增大 2.实现思路 通过Redis来缓存菜品数据,减少数据库查询操作。 缓存逻辑分析: ①每个分类下的菜品保持一份缓存数据…...

04-初识css

一、css样式引入 1.1.内部样式 <div style"width: 100px;"></div>1.2.外部样式 1.2.1.外部样式1 <style>.aa {width: 100px;} </style> <div class"aa"></div>1.2.2.外部样式2 <!-- rel内表面引入的是style样…...

c#开发AI模型对话

AI模型 前面已经介绍了一般AI模型本地部署&#xff0c;直接调用现成的模型数据。这里主要讲述讲接口集成到我们自己的程序中使用方式。 微软提供了ML.NET来开发和使用AI模型&#xff0c;但是目前国内可能使用不多&#xff0c;至少实践例子很少看见。开发训练模型就不介绍了&am…...

多种风格导航菜单 HTML 实现(附源码)

下面我将为您展示 6 种不同风格的导航菜单实现&#xff0c;每种都包含完整 HTML、CSS 和 JavaScript 代码。 1. 简约水平导航栏 <!DOCTYPE html> <html lang"zh-CN"> <head><meta charset"UTF-8"><meta name"viewport&qu…...

在鸿蒙HarmonyOS 5中使用DevEco Studio实现企业微信功能

1. 开发环境准备 ​​安装DevEco Studio 3.1​​&#xff1a; 从华为开发者官网下载最新版DevEco Studio安装HarmonyOS 5.0 SDK ​​项目配置​​&#xff1a; // module.json5 {"module": {"requestPermissions": [{"name": "ohos.permis…...

DeepSeek源码深度解析 × 华为仓颉语言编程精粹——从MoE架构到全场景开发生态

前言 在人工智能技术飞速发展的今天&#xff0c;深度学习与大模型技术已成为推动行业变革的核心驱动力&#xff0c;而高效、灵活的开发工具与编程语言则为技术创新提供了重要支撑。本书以两大前沿技术领域为核心&#xff0c;系统性地呈现了两部深度技术著作的精华&#xff1a;…...

Python实现简单音频数据压缩与解压算法

Python实现简单音频数据压缩与解压算法 引言 在音频数据处理中&#xff0c;压缩算法是降低存储成本和传输效率的关键技术。Python作为一门灵活且功能强大的编程语言&#xff0c;提供了丰富的库和工具来实现音频数据的压缩与解压。本文将通过一个简单的音频数据压缩与解压算法…...