Lightning基础训练尝试实例
一、训练任务概述
动机:由于后续的课题中会用到类似图像去噪的算法,考虑先用U-Net,这里做一个前置的尝试。
训练任务:分割出图像中的细胞。
数据集:可私
数据集结构:
二、具体实现
U-Net的网络实现是现成的,只需要在网上找一个比较漂亮的实现(一般都是模块化,写的很漂亮)copy就可以了,需要特别注意的是最后整合的模型:
2.1 基础模型模块实现
双卷积模块
class DoubleConv(nn.Module):"""(convolution => [BN] => ReLU) * 2"""def __init__(self, in_channels, out_channels, mid_channels=None):super().__init__()if not mid_channels:mid_channels = out_channelsself.double_conv = nn.Sequential(nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),nn.BatchNorm2d(mid_channels),nn.ReLU(inplace=True),nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True))def forward(self, x):return self.double_conv(x)
上采样模块
class Up(nn.Module):"""Upscaling then double conv"""def __init__(self, in_channels, out_channels, bilinear=True):super().__init__()# if bilinear, use the normal convolutions to reduce the number of channelsif bilinear:self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)else:self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)self.conv = DoubleConv(in_channels, out_channels)def forward(self, x1, x2):x1 = self.up(x1)# input is CHWdiffY = x2.size()[2] - x1.size()[2]diffX = x2.size()[3] - x1.size()[3]x1 = torch.nn.functional.pad(x1, [diffX // 2, diffX - diffX // 2,diffY // 2, diffY - diffY // 2])x = torch.cat([x2, x1], dim=1)return self.conv(x)
下采样模块
class Down(nn.Module):"""Downscaling with maxpool then double conv"""def __init__(self, in_channels, out_channels):super().__init__()self.maxpool_conv = nn.Sequential(nn.MaxPool2d(2),DoubleConv(in_channels, out_channels))def forward(self, x):return self.maxpool_conv(x)
输出层
class OutConv(nn.Module):def __init__(self, in_channels, out_channels):super(OutConv, self).__init__()self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)def forward(self, x):return self.conv(x)
2.2 整合模块->模型
class UNet(L.LightningModule):def __init__(self, n_channels, n_classes, bilinear=False):super(UNet, self).__init__()self.n_channels = n_channelsself.n_classes = n_classesself.bilinear = bilinearself.inc = (DoubleConv(n_channels, 64))self.down1 = (Down(64, 128))self.down2 = (Down(128, 256))self.down3 = (Down(256, 512))factor = 2 if bilinear else 1self.down4 = (Down(512, 1024 // factor))self.up1 = (Up(1024, 512 // factor, bilinear))self.up2 = (Up(512, 256 // factor, bilinear))self.up3 = (Up(256, 128 // factor, bilinear))self.up4 = (Up(128, 64, bilinear))self.outc = (OutConv(64, n_classes))def forward(self, x):x1 = self.inc(x)x2 = self.down1(x1)x3 = self.down2(x2)x4 = self.down3(x3)x5 = self.down4(x4)x = self.up1(x5, x4)x = self.up2(x, x3)x = self.up3(x, x2)x = self.up4(x, x1)logits = self.outc(x)return logits# 对应的层设置检查点,节省显存m,可用可不用def use_checkpointing(self):self.inc = torch.utils.checkpoint(self.inc)self.down1 = torch.utils.checkpoint(self.down1)self.down2 = torch.utils.checkpoint(self.down2)self.down3 = torch.utils.checkpoint(self.down3)self.down4 = torch.utils.checkpoint(self.down4)self.up1 = torch.utils.checkpoint(self.up1)self.up2 = torch.utils.checkpoint(self.up2)self.up3 = torch.utils.checkpoint(self.up3)self.up4 = torch.utils.checkpoint(self.up4)self.outc = torch.utils.checkpoint(self.outc)# 定义优化器def configure_optimizers(self):optimizer = torch.optim.Adam(self.parameters(),lr=0.001)return optimizer# 定义train的单步流程def training_step(self,train_batch,batch_index):image,label = train_batchimage_hat = self.forward(image)# U-Net的lossloss = nn.functional.mse_loss(image_hat,label)return loss# 定义val的单步流程def validation_step(self, val_batch,batch_index):image,label = val_batchimage_hat = self.forward(image)# U-Net的lossloss = nn.functional.mse_loss(image_hat,label)self.log('val_loss',loss)return loss
注意:模块可以不需要继承自L.LightningModule,只要最后整合的时候继承自L.LightningModule就可以了。
2.3 数据划分
重定义Dataset类,供数据集划分函数调用,二者要相互配合
class UDataset(Dataset):def __init__(self,image_dir,mask_dir,transform=None):self.image_dir = image_dirself.mask_dir = mask_dirif transform is not None:self.transform = transformelse:self.transform = Nonedef __getitem__(self, index):image = Image.open(self.image_dir[index]).convert('RGB')label = Image.open(self.mask_dir[index]).convert('RGB')if self.transform is not None:image = self.transform(image)label = self.transform(label)return image,labeldef __len__(self):return len(self.image_dir)
定义数据集划分函数(包括"找出文件列表"、"定义数据预处理方式"、“定义批量大小”)
train_image_dir = "./data/train/image/*.png"
train_label_dir = "./data/train/label/*.png"
val_image_dir = "./data/val/image/*.png"
val_label_dir = "./data/val/label/*.png" def data_process(train_image_dir,train_label_dir,val_image_dir,val_label_dir):# 查找路径下的所有文件,返回文件路径列表train_image_list = glob.glob(train_image_dir)train_label_list = glob.glob(train_label_dir)val_image_list = glob.glob(val_image_dir)val_label_list = glob.glob(val_label_dir)# 数据处理train_data_transform = transforms.Compose([transforms.Resize((256,256)),transforms.ToTensor()])val_data_transform = transforms.Compose([ transforms.Resize((256,256)),transforms.ToTensor()])train_dataloader = data.DataLoader(UDataset(train_image_list,train_label_list,train_data_transform),batch_size=5,shuffle=True)val_dataloader = data.DataLoader(UDataset(val_image_list,val_label_list,val_data_transform),batch_size=5,shuffle=False)return train_dataloader,val_dataloader
2.4 模型验证
在训练之前,要看一下模型的结构有没有错误,用summary打印出网络的结构
# 模型验证device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model = UNet(n_channels=3,n_classes=1).to(device)print(summary(model,(3,512,512)))
也可以用其他的方法查看网络结构
2.5 模型训练
加入TensorBoardLogger是为了可视化训练Loss
训练的流程遵循前文的基本流程
# 创建 TensorBoardLoggerlogger = TensorBoardLogger("tb_logs", name="unet")# 创建 Trainertrainer = L.Trainer(max_epochs=20, logger=logger)# 划分数据集train_dataloader,val_dataloader = data_process(train_image_dir,train_label_dir,val_image_dir,val_label_dir)# 创建模型model = UNet(n_channels=3,n_classes=1)# 启动模型训练过程trainer.fit(model,train_dataloader,val_dataloader)# 保存模型权重torch.save(model.state_dict(),'./model.pth')
相关文章:

Lightning基础训练尝试实例
一、训练任务概述 动机:由于后续的课题中会用到类似图像去噪的算法,考虑先用U-Net,这里做一个前置的尝试。 训练任务:分割出图像中的细胞。 数据集:可私 数据集结构: 二、具体实现 U-Net的网络实现是现…...
osgearth视点坐标及鼠标交点坐标的信息显示(七)
核心函数如下: void COSGObject::addViewPointLabel() {//mRoot->addChild(osgEarth::Util::Controls::ControlCanvas::get(mViewer));//放开这句,球就卡住了。 为什么,shitosgEarth::Util::Controls::ControlCanvas* canvas = osgEarth::Util::Controls::ControlCanvas…...

动态规划 之 背包问题
文章目录 0-1背包问题2915.和为目标值的最长子序列的长度494.目标和 完全背包问题322.零钱兑换518.零钱兑换II 多重背包2585.获得分数的方法数 分组背包1155.掷骰子等于目标和的方法数 背包问题是动态规划一个很重要的一类题目,主要分为0-1背包问题以及完全背包问题…...

【Azure 架构师学习笔记】- Azure Databricks (11) -- UC搭建
本文属于【Azure 架构师学习笔记】系列。 本文属于【Azure Databricks】系列。 接上文 【Azure 架构师学习笔记】- Azure Databricks (10) – UC 使用 前言 由于ADB 的更新速度很快,在几个月之后重新搭建ADB 时发现UC 已经更新了很多,为了后续做ADB 的功…...
RTMP(Real-Time Messaging Protocol)
RTMP(Real-Time Messaging Protocol)是一种用于实时音视频和数据传输的协议,常见于直播和流媒体应用。 一 RTSP 协商消息 一、消息类型(Message Types) RTMP消息分为多种类型,通过Message Type ID标识&a…...

docker容器部署jar应用导入文件时候报缺少字体错误解决
如题,在导入文件时候报错如下: Handler dispatch failed; nested exception is java.lang.NoClassDefFoundError: Could not initialize class sun.awt.X11FontManager 经查是缺少对应字体,解决办法有两张: 第一种:…...

贪吃蛇解析
目录 文章结尾有代码可自取 Win32API 光标的隐藏 获取按键信息 控制光标位置 游戏开始前的准备 游戏准备及介绍 加载和欢迎界面 打印游戏指南 运行游戏 打印墙体和说明 设置蛇的各个信息 初始化及打印蛇 创造食物 运行游戏 1)打印得分情况 2&#…...

vue非组件的初学笔记
1.创建Vue实例,初始化渲染的核心 准备容器引包创建Vue实例new Vue() el用来指定控制的盒子data提供数据 2.插值表达式 作用利用表达式插值,将数据渲染到页面中 格式{{表达式}} 注意点 表达式的数据要在data中存在表达式是可计算结果的语句插值表达式…...

LeetCode 热题 100_单词搜索(60_79_中等_C++)(深度优先搜索(回溯))(初始化二维vector的大小)
LeetCode 热题 100_单词搜索(60_79) 题目描述:输入输出样例:题解:解题思路:思路一(深度优先搜索(回溯)): 代码实现代码实现(思路一&am…...
js闭包,跨域
js闭包,跨域 闭包 想象一下,你家有个大仓库(函数),仓库里放着各种东西(变量)。一般情况下,你从仓库外面是看不到也拿不到仓库里的东西的。但是,闭包就像是你在仓库里留…...
算法练习(力扣-BFS)——102. 二叉树的层序遍历
题目描述(简要概括) 题目链接:102. 二叉树的层序遍历 - 力扣(LeetCode) 题目要求对给定的二叉树进行层序遍历(从上到下,从左到右),并返回遍历的结果。层序遍历是一种基…...

Jetson Agx Orin平台preferred_stride调试记录--1924x720图像异常
1.问题描述 硬件: AGX Orin 在Jetpack 5.0.1和Jetpack 5.0.2上测试验证 图像分辨率在1920x720和1024x1920下图像采集正常 但是当采集图像分辨率为1924x720视频时,图像输出异常 像素格式:yuv_uyvy16 gstreamer命令如下 gst-launch-1.0 v4l2src device=/dev/video0 ! …...

nlp|微调大语言模型初探索(2),训练自己的聊天机器人
前言 上篇文章记录了具体的微调语言大模型步骤,以及在微调过程中可能遇见的各种报错,美中不足的是只是基于开源数据集的微调,今天来记录一下怎么基于自己的数据集去微调大语言模型,训练自己的智能机器人!!&…...

win11安装wsl报错:无法解析服务器的名称或地址(启用wsl2)
1. 启用wsl报错如下 # 查看可安装的 wsl --install wsl --list --online此原因是因为没有开启DNS的原因,所以需要我们手动开启DNS。 2. 按照如下配置即可 Google的DNS(8.8.8.8和8.8.4.4) 全国通用DNS地址 (114.114.114.114) 3. 运行以下命令来重启 WSL…...
Gentleman:优雅的Go语言HTTP客户端工具包
gentlemen介绍,特点等 插件驱动架构:Gentleman的核心特点是其插件系统,允许用户注册和重用各种自定义插件,如重试策略或动态服务器发现,以增强HTTP客户端的功能。 中间件层:项目内置了一个上下文感知的层次…...

解锁豆瓣高清海报(三)从深度爬虫到URL构造,实现极速下载
脚本地址: 项目地址: Gazer PosterBandit_v2.py 前瞻 之前的 PosterBandit.py 是按照深度爬虫的思路一步步进入海报界面来爬取, 是个值得学习的思路, 但缺点是它爬取慢, 仍然容易碰到豆瓣的 418 错误, 本文也会指出彻底解决旧版 418 错误的方法并提高爬取速度. 现在我将介绍…...

IDEA单元测试插件 SquareTest 延长试用期权限
SquareTest是一款强大的IDEA单元测试生成插件工具,具体使用方法就不过多介绍了,这里主要介绍变更试用期,方便大家使用 配置信息 我的电脑安装前提配置条件 IntelliJ IDEA 2023.2windows 系统 软件安装 IntelliJ IDEA 直接安装插件Squar…...
PLC的五个学习步骤
五个学习步骤详解: 1. 夯实电气基础 (第一步) 核心思想: PLC控制技术是建立在传统电气控制技术之上的,因此扎实的电气基础至关重要。学习内容: 电气元件原理: 深入理解继电器、接触器、按钮、三相异步电机等常用电气元件的工作原理。这是理解电气控制回…...

深度学习05 ResNet残差网络
目录 传统卷积神经网络存在的问题 如何解决 批量归一化BatchNormalization, BN 残差连接方式 残差结构 ResNet网络 ResNet 网络是在 2015年 由微软实验室中的何凯明等几位大神提出,斩获当年ImageNet竞赛中分类任务第一名,目标检测第一名。获得CO…...

卷积神经网络CNN
目录 一、CNN概述 二、图像基础知识 三、卷积层 3.1 卷积的计算 3.2 Padding 3.3 Stride 3.4 多通道卷积计算 3.5 多卷积核卷积计算 3.6 特征图大小计算 3.7 Pytorch 卷积层API 四、池化层 4.1 池化计算 4.2 Stride 4.3 Padding 4.4 多通道池化计算 4.5 Pytorc…...

SpringBoot-17-MyBatis动态SQL标签之常用标签
文章目录 1 代码1.1 实体User.java1.2 接口UserMapper.java1.3 映射UserMapper.xml1.3.1 标签if1.3.2 标签if和where1.3.3 标签choose和when和otherwise1.4 UserController.java2 常用动态SQL标签2.1 标签set2.1.1 UserMapper.java2.1.2 UserMapper.xml2.1.3 UserController.ja…...
Java 语言特性(面试系列2)
一、SQL 基础 1. 复杂查询 (1)连接查询(JOIN) 内连接(INNER JOIN):返回两表匹配的记录。 SELECT e.name, d.dept_name FROM employees e INNER JOIN departments d ON e.dept_id d.dept_id; 左…...

从WWDC看苹果产品发展的规律
WWDC 是苹果公司一年一度面向全球开发者的盛会,其主题演讲展现了苹果在产品设计、技术路线、用户体验和生态系统构建上的核心理念与演进脉络。我们借助 ChatGPT Deep Research 工具,对过去十年 WWDC 主题演讲内容进行了系统化分析,形成了这份…...

最新SpringBoot+SpringCloud+Nacos微服务框架分享
文章目录 前言一、服务规划二、架构核心1.cloud的pom2.gateway的异常handler3.gateway的filter4、admin的pom5、admin的登录核心 三、code-helper分享总结 前言 最近有个活蛮赶的,根据Excel列的需求预估的工时直接打骨折,不要问我为什么,主要…...
uniapp中使用aixos 报错
问题: 在uniapp中使用aixos,运行后报如下错误: AxiosError: There is no suitable adapter to dispatch the request since : - adapter xhr is not supported by the environment - adapter http is not available in the build 解决方案&…...
智能AI电话机器人系统的识别能力现状与发展水平
一、引言 随着人工智能技术的飞速发展,AI电话机器人系统已经从简单的自动应答工具演变为具备复杂交互能力的智能助手。这类系统结合了语音识别、自然语言处理、情感计算和机器学习等多项前沿技术,在客户服务、营销推广、信息查询等领域发挥着越来越重要…...
【Go语言基础【13】】函数、闭包、方法
文章目录 零、概述一、函数基础1、函数基础概念2、参数传递机制3、返回值特性3.1. 多返回值3.2. 命名返回值3.3. 错误处理 二、函数类型与高阶函数1. 函数类型定义2. 高阶函数(函数作为参数、返回值) 三、匿名函数与闭包1. 匿名函数(Lambda函…...
【LeetCode】3309. 连接二进制表示可形成的最大数值(递归|回溯|位运算)
LeetCode 3309. 连接二进制表示可形成的最大数值(中等) 题目描述解题思路Java代码 题目描述 题目链接:LeetCode 3309. 连接二进制表示可形成的最大数值(中等) 给你一个长度为 3 的整数数组 nums。 现以某种顺序 连接…...

GraphRAG优化新思路-开源的ROGRAG框架
目前的如微软开源的GraphRAG的工作流程都较为复杂,难以孤立地评估各个组件的贡献,传统的检索方法在处理复杂推理任务时可能不够有效,特别是在需要理解实体间关系或多跳知识的情况下。先说结论,看完后感觉这个框架性能上不会比Grap…...
Java多线程实现之Runnable接口深度解析
Java多线程实现之Runnable接口深度解析 一、Runnable接口概述1.1 接口定义1.2 与Thread类的关系1.3 使用Runnable接口的优势 二、Runnable接口的基本实现方式2.1 传统方式实现Runnable接口2.2 使用匿名内部类实现Runnable接口2.3 使用Lambda表达式实现Runnable接口 三、Runnabl…...