百度飞浆ResNet50大模型微调实现十二种猫图像分类
12种猫分类比赛传送门
要求很简单,给train和test集,训练模型实现图像分类。
这里使用的是残差连接模型,这个平台有预训练好的模型,可以直接拿来主义。
训练十几个迭代,每个批次60左右,准确率达到90%以上
一、导入库,解压文件
import os import zipfile import random import json import cv2 import numpy as np from PIL import Imageimport matplotlib.pyplot as plt from sklearn.model_selection import train_test_split import paddle import paddle.nn as nn from paddle.io import Dataset,DataLoader from paddle.nn import \Layer, \Conv2D, Linear, \Embedding, MaxPool2D, \BatchNorm2D, ReLUimport paddle.vision.transforms as transforms from paddle.vision.models import resnet50 from paddle.metric import Accuracytrain_parameters = {"input_size": [3, 224, 224], # 输入图片的shape"class_dim": 12, # 分类数"src_path":"data/data10954/cat_12_train.zip", # 原始数据集路径"src_test_path":"data/data10954/cat_12_test.zip", # 原始数据集路径"target_path":"/home/aistudio/data/dataset", # 要解压的路径 "train_list_path": "./train.txt", # train_data.txt路径"eval_list_path": "./eval.txt", # eval_data.txt路径"label_dict":{}, # 标签字典"readme_path": "/home/aistudio/data/readme.json",# readme.json路径"num_epochs":6, # 训练轮数"train_batch_size": 16, # 批次的大小"learning_strategy": { # 优化函数相关的配置"lr": 0.0005 # 超参数学习率} }scr_path=train_parameters['src_path'] target_path=train_parameters['target_path'] src_test_path=train_parameters["src_test_path"] z = zipfile.ZipFile(scr_path, 'r') z.extractall(path=target_path) z = zipfile.ZipFile(src_test_path, 'r') z.extractall(path=target_path) z.close() for imgpath in os.listdir(target_path + '/cat_12_train'):src = os.path.join(target_path + '/cat_12_train/', imgpath)img = Image.open(src)if img.mode != 'RGB':img = img.convert('RGB')img.save(src)for imgpath in os.listdir(target_path + '/cat_12_test'):src = os.path.join(target_path + '/cat_12_test/', imgpath)img = Image.open(src)if img.mode != 'RGB':img = img.convert('RGB')img.save(src)
解压后将所有图像变为RGB图像
二、加载训练集,进行预处理、数据增强、格式变换
transform = transforms.Compose([transforms.Resize(size=224),transforms.ColorJitter(0.2, 0.2, 0.2, 0.2),transforms.RandomHorizontalFlip(),transforms.RandomRotation(15),transforms.RandomResizedCrop(size=224, scale=(0.8, 1.0)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ])x_train,x_eval,y_train=[],[],[]#获取训练图像和标签、测试图像和标签 contents=[] with open('data/data10954/train_list.txt')as f:contents=f.read().split('\n')for item in contents:if item=='':continuepath='data/dataset/'+item.split('\t')[0]data=np.array(Image.open(path).convert('RGB'))data=np.array(transform(data))x_train.append(data)y_train.append(int(item.split('\t')[-1]))contetns=os.listdir('data/dataset/cat_12_test') for item in contetns:path='data/dataset/cat_12_test/'+itemdata=np.array(Image.open(path).convert('RGB'))data=np.array(transform(data))x_eval.append(data)
重点是transforms变换的预处理
三、划分训练集和测试集
x_train=np.array(x_train)y_train=np.array(y_train)x_eval=np.array(x_eval)x_train,x_test,y_train,y_test=train_test_split(x_train,y_train,test_size=0.2,random_state=42,stratify=y_train)x_train=paddle.to_tensor(x_train,dtype='float32') y_train=paddle.to_tensor(y_train,dtype='int64') x_test=paddle.to_tensor(x_test,dtype='float32') y_test=paddle.to_tensor(y_test,dtype='int64') x_eval=paddle.to_tensor(x_eval,dtype='float32')
这是必要的,可以随时利用测试集查看准确率
四、加载预训练模型,选择损失函数和优化器
learning_rate=0.001 epochs =5 # 迭代轮数 batch_size = 50 # 批次大小 weight_decay=1e-5 num_class=12cnn=resnet50(pretrained=True) checkpoint=paddle.load('checkpoint.pdparams')for param in cnn.parameters():param.requires_grad=False cnn.fc = nn.Linear(2048, num_class) cnn.set_dict(checkpoint['cnn_state_dict']) criterion=nn.CrossEntropyLoss() optimizer = paddle.optimizer.Adam(learning_rate=learning_rate, parameters=cnn.fc.parameters(),weight_decay=weight_decay)
第一次训练把加载模型注释掉即可,优化器包含最后一层全连接的参数
五、模型训练
if x_train.shape[3]==3:x_train=paddle.transpose(x_train,perm=(0,3,1,2))dataset = paddle.io.TensorDataset([x_train, y_train]) data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True) for epoch in range(epochs):for batch_data, batch_labels in data_loader:outputs = cnn(batch_data)loss = criterion(outputs, batch_labels)print(epoch)loss.backward()optimizer.step()optimizer.clear_grad()print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss.numpy()[0]}")#保存参数 paddle.save({'cnn_state_dict': cnn.state_dict(),}, 'checkpoint.pdparams')
使用批处理,这个很重要,不然平台分分钟炸了
六、测试集准确率
num_class=12 batch_size=64 cnn=resnet50(pretrained=True) checkpoint=paddle.load('checkpoint.pdparams')for param in cnn.parameters():param.requires_grad=False cnn.fc = nn.Linear(2048, num_class) cnn.set_dict(checkpoint['cnn_state_dict'])cnn.eval()if x_test.shape[3]==3:x_test=paddle.transpose(x_test,perm=(0,3,1,2)) dataset = paddle.io.TensorDataset([x_test, y_test]) data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)with paddle.no_grad():score=0for batch_data, batch_labels in data_loader:predictions = cnn(batch_data)predicted_probabilities = paddle.nn.functional.softmax(predictions, axis=1)predicted_labels = paddle.argmax(predicted_probabilities, axis=1) print(predicted_labels)for i in range(len(predicted_labels)):if predicted_labels[i].numpy()==batch_labels[i]:score+=1print(score/len(y_test))
设置eval模式,使用批处理测试准确率
相关文章:
百度飞浆ResNet50大模型微调实现十二种猫图像分类
12种猫分类比赛传送门 要求很简单,给train和test集,训练模型实现图像分类。 这里使用的是残差连接模型,这个平台有预训练好的模型,可以直接拿来主义。 训练十几个迭代,每个批次60左右,准确率达到90%以上…...
多服务器云探针源码(服务器云监控)/多服务器多节点_云监控程序python源码
源码简介: 多服务器云探针源码(服务器云监控),支持python多服务器多节点,云监控程序源码。它是一款很实用的云探针和服务器云监控程序源码。使用它可以帮助管理员能够快速监控和管理各种服务器和节点,实用性强。 源码链接: 网盘…...
ESP8266 WiFi物联网智能插座—下位机软件实现
目录 1、软件架构 2、开发环境 3、软件功能 4、程序设计 4.1、初始化 4.2、主循环状态机 4.3、初始化模式 4.4、配置模式 4.5、运行模式 4.6、重启模式 4.7、升级模式 5、程序功能特点 5.1、日志管理 5.2、数据缓存队列 本篇博文开始讲解下位机插座节点的MCU软件…...
微信小程序--下拉选择框组件封装,可CV直接使用
一、起因 接到的项目需求,查看ui设计图后,由于微信小程序官方设计的下拉选择框不符合需求,而且常用的第三方库也没有封装类似的,所以选择自己自定义组件。在此记录一下,方便日后复用。 ui设计图如下: 微信官方提供的选择框 对比发现并不能实现我们想要的功能。 二、自定义组件…...
代码随想录算法训练营第五十九天 |647. 回文子串、516.最长回文子序列、动态规划总结篇
一、647. 回文子串 题目链接/文章讲解:代码随想录 思考: 1.确定dp数组(dp table)以及下标的含义 如果本题定义dp[i] 为 下标i结尾的字符串有 dp[i]个回文串的话: 会发现很难找到递归关系,dp[i] 和 dp[i-1]…...
互联网性能和可用性优化CDN和DNS
当涉及到互联网性能和可用性优化时,DNS(Domain Name System)和CDN(Content Delivery Network)是两个至关重要的元素。它们各自发挥着关键作用,以确保用户能够快速、可靠地访问网站和应用程序。在本文中&…...
使用 ErrorStack 在出现报错 ORA-14402 时产生的日志量
0、测试结论: 测试结果:设置 ErrorStack 级别为 1 时产生 Trace 的日志量最小,大小为 308K,同时在 alert 日志中也存在记录。 1、准备测试数据: sqlplus / as sysdba show pdbs alter session set containerpdb; …...
详解Spring-ApplicationContext
加载器目前有两种选择:ContextLoaderListener和ContextLoaderServlet。 这两者在功能上完全等同,只是一个是基于Servlet2.3版本中新引入的Listener接口实现,而另一个基于Servlet接口实现。开发中可根据目标Web容器的实际情况进行选择。 配…...
关键字extern、static与const
关键字extern、static与const extern关键字与include的区别 extern:于声明某个函数或变量是外部的(其他源文件中)include:用于批量引入 项目中可以根据需要引入的函数或变量数量决定使用extern还是include static关键字 static关键字用于限制函数和全局变量的作用域仅在当…...
虹科方案|国庆出游季,古建筑振动监测让历史古迹不再受损
全文导读: 国庆长假即将到来,各位小伙伴是不是都做好了出游计划呢?今年中秋、国庆“双节”连休八天,多地预计游客接待量将创下新高,而各地的名胜古迹更是人流爆满。迎接游客的同时,如何保障历史古迹不因巨大…...
Python学习笔记-使用哈希算法Hash,Hashlib进行数据加密
文章目录 一、概述1.1 哈希算法1.2 常见算法分类1.2.1 SHA算法1.2.2 MD4算法1.2.3 MD5算法 1.3 Hash算法的特性1.4 Hash算法的应用场景1.4.1 数据校验1.4.2 安全加密1.4.3 数字签名 二、Hash算法使用2.1 使用hash函数直接获取hash值2.2 使用hashlib库进行hash计算2.2.1 基本使用…...
跨境电商能否成为黄河流域产业带的新引擎?
近年来,随着全球贸易格局的不断演变和中国经济的快速崛起,跨境电商已经成为中国外贸的一大亮点。而在中国国内,黄河流域产业带一直以其丰富的资源和悠久的历史而闻名,但也面临着转型升级的挑战。那么,跨境电商是否有潜…...
从数据到决策:企业投资信息查询API的关键作用
前言 在现代商业环境中,数据是一项无价的资产。企业不仅需要访问大量数据,还需要将这些数据转化为有用的见解,以支持战略决策。对于企业投资而言,准确的信息和实时的市场数据至关重要。在这个信息时代,企业投资信息查…...
NSIC2050JBT3G 车规级120V 50mA ±15% 用于LED照明的线性恒流调节器(CCR) 增强汽车安全
随着汽车行业的巨大变革,高品质的汽车氛围灯效、仪表盘等LED指示灯效已成为汽车内饰设计中不可或缺的元素。深力科安森美LED驱动芯片系列赋能智能座舱灯效充满艺术感和科技感——NSIC2050JBT3G LED驱动芯片,实现对每路LED亮度和颜色进行细腻控制…...
LuatOS-SOC接口文档(air780E)-- ftp - ftp 客户端
ftp.login(adapter,ip_addr,port,username,password)# FTP客户端 参数 传入值类型 解释 int 适配器序号, 只能是socket.ETH0, socket.STA, socket.AP,如果不填,会选择平台自带的方式,然后是最后一个注册的适配器 string ip_addr 地址 string port 端口,默认21 string…...
第二证券:市净率高好还是低好?
市净率是一个衡量公司股票投资价值的指标,通过比较公司股票价格和公司每股净资产的比值来评估公司股票的估值水平。市净率高好还是低好这个问题并没有一个简单的答案,取决于具体的市场环境和投资者的需求。本文将从多个角度分析市净率高好还是低好。 首…...
HTTP协议是什么
HTTP (全称为 “超文本传输协议”) 是一种应用非常广泛的 应用层协议,是一种网络通信协议。 超文本:所谓 “超文本” 的含义, 就是传输的内容不仅仅是文本(比如 html, css 这个就是文本), 还可以是一些其他的资源, 比如图片, 视频, 音频等二进制的数据。…...
微服务09-Sentinel的入门
文章目录 微服务中的雪崩现象解决办法:1. 超时处理2. 舱壁模式3. 熔断降级4.流量控制 Sentinel1.介绍2.使用操作3.限流规则4.实战:流量监控5.高级选项功能的使用1.关联模式2.链路模式3.总结 流控效果1.预热模式2.排队等待模式3.总结4.热点参数限流5.实战…...
2023-2024-1 高级语言程序设计实验一: 选择结构
7-1 古时年龄称谓知多少? 输入一个人的年龄(岁),判断出他属于哪个年龄段 ? 0-9 :垂髫之年; 10-19: 志学之年; 20-29 :弱冠之年; 30-39 &#…...
js事件循环详解
事件循环简介 JavaScript的事件循环是一种处理异步事件和回调函数的机制,它是在浏览器或Node.js环境中运行,用于管理任务队列和调用栈,以及在适当的时候执行回调函数。 事件循环的基本原理是,JavaScript引擎在空闲时等待事件的到…...
地震勘探——干扰波识别、井中地震时距曲线特点
目录 干扰波识别反射波地震勘探的干扰波 井中地震时距曲线特点 干扰波识别 有效波:可以用来解决所提出的地质任务的波;干扰波:所有妨碍辨认、追踪有效波的其他波。 地震勘探中,有效波和干扰波是相对的。例如,在反射波…...
【根据当天日期输出明天的日期(需对闰年做判定)。】2022-5-15
缘由根据当天日期输出明天的日期(需对闰年做判定)。日期类型结构体如下: struct data{ int year; int month; int day;};-编程语言-CSDN问答 struct mdata{ int year; int month; int day; }mdata; int 天数(int year, int month) {switch (month){case 1: case 3:…...
工业自动化时代的精准装配革新:迁移科技3D视觉系统如何重塑机器人定位装配
AI3D视觉的工业赋能者 迁移科技成立于2017年,作为行业领先的3D工业相机及视觉系统供应商,累计完成数亿元融资。其核心技术覆盖硬件设计、算法优化及软件集成,通过稳定、易用、高回报的AI3D视觉系统,为汽车、新能源、金属制造等行…...
微信小程序云开发平台MySQL的连接方式
注:微信小程序云开发平台指的是腾讯云开发 先给结论:微信小程序云开发平台的MySQL,无法通过获取数据库连接信息的方式进行连接,连接只能通过云开发的SDK连接,具体要参考官方文档: 为什么? 因为…...
IoT/HCIP实验-3/LiteOS操作系统内核实验(任务、内存、信号量、CMSIS..)
文章目录 概述HelloWorld 工程C/C配置编译器主配置Makefile脚本烧录器主配置运行结果程序调用栈 任务管理实验实验结果osal 系统适配层osal_task_create 其他实验实验源码内存管理实验互斥锁实验信号量实验 CMISIS接口实验还是得JlINKCMSIS 简介LiteOS->CMSIS任务间消息交互…...
是否存在路径(FIFOBB算法)
题目描述 一个具有 n 个顶点e条边的无向图,该图顶点的编号依次为0到n-1且不存在顶点与自身相连的边。请使用FIFOBB算法编写程序,确定是否存在从顶点 source到顶点 destination的路径。 输入 第一行两个整数,分别表示n 和 e 的值(1…...
企业如何增强终端安全?
在数字化转型加速的今天,企业的业务运行越来越依赖于终端设备。从员工的笔记本电脑、智能手机,到工厂里的物联网设备、智能传感器,这些终端构成了企业与外部世界连接的 “神经末梢”。然而,随着远程办公的常态化和设备接入的爆炸式…...
Springboot社区养老保险系统小程序
一、前言 随着我国经济迅速发展,人们对手机的需求越来越大,各种手机软件也都在被广泛应用,但是对于手机进行数据信息管理,对于手机的各种软件也是备受用户的喜爱,社区养老保险系统小程序被用户普遍使用,为方…...
Hive 存储格式深度解析:从 TextFile 到 ORC,如何选对数据存储方案?
在大数据处理领域,Hive 作为 Hadoop 生态中重要的数据仓库工具,其存储格式的选择直接影响数据存储成本、查询效率和计算资源消耗。面对 TextFile、SequenceFile、Parquet、RCFile、ORC 等多种存储格式,很多开发者常常陷入选择困境。本文将从底…...
【Java学习笔记】BigInteger 和 BigDecimal 类
BigInteger 和 BigDecimal 类 二者共有的常见方法 方法功能add加subtract减multiply乘divide除 注意点:传参类型必须是类对象 一、BigInteger 1. 作用:适合保存比较大的整型数 2. 使用说明 创建BigInteger对象 传入字符串 3. 代码示例 import j…...
