当前位置: 首页 > news >正文

百度飞浆ResNet50大模型微调实现十二种猫图像分类

12种猫分类比赛传送门

要求很简单,给train和test集,训练模型实现图像分类。

这里使用的是残差连接模型,这个平台有预训练好的模型,可以直接拿来主义。

训练十几个迭代,每个批次60左右,准确率达到90%以上

一、导入库,解压文件

import os
import zipfile
import random
import json
import cv2
import numpy as np
from PIL import Imageimport matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import paddle
import paddle.nn as nn
from paddle.io import Dataset,DataLoader
from paddle.nn import \Layer, \Conv2D, Linear, \Embedding, MaxPool2D, \BatchNorm2D, ReLUimport paddle.vision.transforms as transforms
from paddle.vision.models import resnet50
from paddle.metric import Accuracytrain_parameters = {"input_size": [3, 224, 224],                     # 输入图片的shape"class_dim": 12,                                 # 分类数"src_path":"data/data10954/cat_12_train.zip",   # 原始数据集路径"src_test_path":"data/data10954/cat_12_test.zip",   # 原始数据集路径"target_path":"/home/aistudio/data/dataset",     # 要解压的路径 "train_list_path": "./train.txt",                # train_data.txt路径"eval_list_path": "./eval.txt",                  # eval_data.txt路径"label_dict":{},                                 # 标签字典"readme_path": "/home/aistudio/data/readme.json",# readme.json路径"num_epochs":6,                                 # 训练轮数"train_batch_size": 16,                          # 批次的大小"learning_strategy": {                           # 优化函数相关的配置"lr": 0.0005                                  # 超参数学习率} 
}scr_path=train_parameters['src_path']
target_path=train_parameters['target_path']
src_test_path=train_parameters["src_test_path"]
z = zipfile.ZipFile(scr_path, 'r')
z.extractall(path=target_path)
z = zipfile.ZipFile(src_test_path, 'r')
z.extractall(path=target_path)
z.close()
for imgpath in os.listdir(target_path + '/cat_12_train'):src = os.path.join(target_path + '/cat_12_train/', imgpath)img = Image.open(src)if img.mode != 'RGB':img = img.convert('RGB')img.save(src)for imgpath in os.listdir(target_path + '/cat_12_test'):src = os.path.join(target_path + '/cat_12_test/', imgpath)img = Image.open(src)if img.mode != 'RGB':img = img.convert('RGB')img.save(src)

 解压后将所有图像变为RGB图像

二、加载训练集,进行预处理、数据增强、格式变换

transform = transforms.Compose([transforms.Resize(size=224),transforms.ColorJitter(0.2, 0.2, 0.2, 0.2),transforms.RandomHorizontalFlip(),transforms.RandomRotation(15),transforms.RandomResizedCrop(size=224, scale=(0.8, 1.0)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])x_train,x_eval,y_train=[],[],[]#获取训练图像和标签、测试图像和标签
contents=[]
with open('data/data10954/train_list.txt')as f:contents=f.read().split('\n')for item in contents:if item=='':continuepath='data/dataset/'+item.split('\t')[0]data=np.array(Image.open(path).convert('RGB'))data=np.array(transform(data))x_train.append(data)y_train.append(int(item.split('\t')[-1]))contetns=os.listdir('data/dataset/cat_12_test')
for item in contetns:path='data/dataset/cat_12_test/'+itemdata=np.array(Image.open(path).convert('RGB'))data=np.array(transform(data))x_eval.append(data)

重点是transforms变换的预处理

三、划分训练集和测试集

x_train=np.array(x_train)y_train=np.array(y_train)x_eval=np.array(x_eval)x_train,x_test,y_train,y_test=train_test_split(x_train,y_train,test_size=0.2,random_state=42,stratify=y_train)x_train=paddle.to_tensor(x_train,dtype='float32')
y_train=paddle.to_tensor(y_train,dtype='int64')
x_test=paddle.to_tensor(x_test,dtype='float32')
y_test=paddle.to_tensor(y_test,dtype='int64')
x_eval=paddle.to_tensor(x_eval,dtype='float32')

 这是必要的,可以随时利用测试集查看准确率

四、加载预训练模型,选择损失函数和优化器

learning_rate=0.001
epochs =5  # 迭代轮数
batch_size = 50  # 批次大小
weight_decay=1e-5
num_class=12cnn=resnet50(pretrained=True)
checkpoint=paddle.load('checkpoint.pdparams')for param in cnn.parameters():param.requires_grad=False
cnn.fc = nn.Linear(2048, num_class)
cnn.set_dict(checkpoint['cnn_state_dict'])
criterion=nn.CrossEntropyLoss()
optimizer = paddle.optimizer.Adam(learning_rate=learning_rate, parameters=cnn.fc.parameters(),weight_decay=weight_decay)

第一次训练把加载模型注释掉即可,优化器包含最后一层全连接的参数

五、模型训练 

if x_train.shape[3]==3:x_train=paddle.transpose(x_train,perm=(0,3,1,2))dataset = paddle.io.TensorDataset([x_train, y_train])
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
for epoch in range(epochs):for batch_data, batch_labels in data_loader:outputs = cnn(batch_data)loss = criterion(outputs, batch_labels)print(epoch)loss.backward()optimizer.step()optimizer.clear_grad()print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss.numpy()[0]}")#保存参数
paddle.save({'cnn_state_dict': cnn.state_dict(),}, 'checkpoint.pdparams')

 使用批处理,这个很重要,不然平台分分钟炸了

六、测试集准确率

num_class=12
batch_size=64
cnn=resnet50(pretrained=True)
checkpoint=paddle.load('checkpoint.pdparams')for param in cnn.parameters():param.requires_grad=False
cnn.fc = nn.Linear(2048, num_class)
cnn.set_dict(checkpoint['cnn_state_dict'])cnn.eval()if x_test.shape[3]==3:x_test=paddle.transpose(x_test,perm=(0,3,1,2))
dataset = paddle.io.TensorDataset([x_test, y_test])
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)with paddle.no_grad():score=0for batch_data, batch_labels in data_loader:predictions = cnn(batch_data)predicted_probabilities = paddle.nn.functional.softmax(predictions, axis=1)predicted_labels = paddle.argmax(predicted_probabilities, axis=1) print(predicted_labels)for i in range(len(predicted_labels)):if predicted_labels[i].numpy()==batch_labels[i]:score+=1print(score/len(y_test))

设置eval模式,使用批处理测试准确率 

相关文章:

百度飞浆ResNet50大模型微调实现十二种猫图像分类

12种猫分类比赛传送门 要求很简单,给train和test集,训练模型实现图像分类。 这里使用的是残差连接模型,这个平台有预训练好的模型,可以直接拿来主义。 训练十几个迭代,每个批次60左右,准确率达到90%以上…...

多服务器云探针源码(服务器云监控)/多服务器多节点_云监控程序python源码

源码简介: 多服务器云探针源码(服务器云监控),支持python多服务器多节点,云监控程序源码。它是一款很实用的云探针和服务器云监控程序源码。使用它可以帮助管理员能够快速监控和管理各种服务器和节点,实用性强。 源码链接: 网盘…...

ESP8266 WiFi物联网智能插座—下位机软件实现

目录 1、软件架构 2、开发环境 3、软件功能 4、程序设计 4.1、初始化 4.2、主循环状态机 4.3、初始化模式 4.4、配置模式 4.5、运行模式 4.6、重启模式 4.7、升级模式 5、程序功能特点 5.1、日志管理 5.2、数据缓存队列 本篇博文开始讲解下位机插座节点的MCU软件…...

微信小程序--下拉选择框组件封装,可CV直接使用

一、起因 接到的项目需求,查看ui设计图后,由于微信小程序官方设计的下拉选择框不符合需求,而且常用的第三方库也没有封装类似的,所以选择自己自定义组件。在此记录一下,方便日后复用。 ui设计图如下: 微信官方提供的选择框 对比发现并不能实现我们想要的功能。 二、自定义组件…...

代码随想录算法训练营第五十九天 |647. 回文子串、516.最长回文子序列、动态规划总结篇

一、647. 回文子串 题目链接/文章讲解:代码随想录 思考: 1.确定dp数组(dp table)以及下标的含义 如果本题定义dp[i] 为 下标i结尾的字符串有 dp[i]个回文串的话: 会发现很难找到递归关系,dp[i] 和 dp[i-1]…...

互联网性能和可用性优化CDN和DNS

当涉及到互联网性能和可用性优化时,DNS(Domain Name System)和CDN(Content Delivery Network)是两个至关重要的元素。它们各自发挥着关键作用,以确保用户能够快速、可靠地访问网站和应用程序。在本文中&…...

使用 ErrorStack 在出现报错 ORA-14402 时产生的日志量

0、测试结论: 测试结果:设置 ErrorStack 级别为 1 时产生 Trace 的日志量最小,大小为 308K,同时在 alert 日志中也存在记录。 1、准备测试数据: sqlplus / as sysdba show pdbs alter session set containerpdb; …...

详解Spring-ApplicationContext

加载器目前有两种选择:ContextLoaderListener和ContextLoaderServlet。 这两者在功能上完全等同,只是一个是基于Servlet2.3版本中新引入的Listener接口实现,而另一个基于Servlet接口实现。开发中可根据目标Web容器的实际情况进行选择。 配…...

关键字extern、static与const

关键字extern、static与const extern关键字与include的区别 extern:于声明某个函数或变量是外部的(其他源文件中)include:用于批量引入 项目中可以根据需要引入的函数或变量数量决定使用extern还是include static关键字 static关键字用于限制函数和全局变量的作用域仅在当…...

虹科方案|国庆出游季,古建筑振动监测让历史古迹不再受损

全文导读: 国庆长假即将到来,各位小伙伴是不是都做好了出游计划呢?今年中秋、国庆“双节”连休八天,多地预计游客接待量将创下新高,而各地的名胜古迹更是人流爆满。迎接游客的同时,如何保障历史古迹不因巨大…...

Python学习笔记-使用哈希算法Hash,Hashlib进行数据加密

文章目录 一、概述1.1 哈希算法1.2 常见算法分类1.2.1 SHA算法1.2.2 MD4算法1.2.3 MD5算法 1.3 Hash算法的特性1.4 Hash算法的应用场景1.4.1 数据校验1.4.2 安全加密1.4.3 数字签名 二、Hash算法使用2.1 使用hash函数直接获取hash值2.2 使用hashlib库进行hash计算2.2.1 基本使用…...

跨境电商能否成为黄河流域产业带的新引擎?

近年来,随着全球贸易格局的不断演变和中国经济的快速崛起,跨境电商已经成为中国外贸的一大亮点。而在中国国内,黄河流域产业带一直以其丰富的资源和悠久的历史而闻名,但也面临着转型升级的挑战。那么,跨境电商是否有潜…...

从数据到决策:企业投资信息查询API的关键作用

前言 在现代商业环境中,数据是一项无价的资产。企业不仅需要访问大量数据,还需要将这些数据转化为有用的见解,以支持战略决策。对于企业投资而言,准确的信息和实时的市场数据至关重要。在这个信息时代,企业投资信息查…...

NSIC2050JBT3G 车规级120V 50mA ±15% 用于LED照明的线性恒流调节器(CCR) 增强汽车安全

随着汽车行业的巨大变革,高品质的汽车氛围灯效、仪表盘等LED指示灯效已成为汽车内饰设计中不可或缺的元素。深力科安森美LED驱动芯片系列赋能智能座舱灯效充满艺术感和科技感——NSIC2050JBT3G LED驱动芯片,实现对每路LED亮度和颜色进行细腻控制&#xf…...

LuatOS-SOC接口文档(air780E)-- ftp - ftp 客户端

ftp.login(adapter,ip_addr,port,username,password)# FTP客户端 参数 传入值类型 解释 int 适配器序号, 只能是socket.ETH0, socket.STA, socket.AP,如果不填,会选择平台自带的方式,然后是最后一个注册的适配器 string ip_addr 地址 string port 端口,默认21 string…...

第二证券:市净率高好还是低好?

市净率是一个衡量公司股票投资价值的指标,通过比较公司股票价格和公司每股净资产的比值来评估公司股票的估值水平。市净率高好还是低好这个问题并没有一个简单的答案,取决于具体的市场环境和投资者的需求。本文将从多个角度分析市净率高好还是低好。 首…...

HTTP协议是什么

HTTP (全称为 “超文本传输协议”) 是一种应用非常广泛的 应用层协议,是一种网络通信协议。 超文本:所谓 “超文本” 的含义, 就是传输的内容不仅仅是文本(比如 html, css 这个就是文本), 还可以是一些其他的资源, 比如图片, 视频, 音频等二进制的数据。…...

微服务09-Sentinel的入门

文章目录 微服务中的雪崩现象解决办法:1. 超时处理2. 舱壁模式3. 熔断降级4.流量控制 Sentinel1.介绍2.使用操作3.限流规则4.实战:流量监控5.高级选项功能的使用1.关联模式2.链路模式3.总结 流控效果1.预热模式2.排队等待模式3.总结4.热点参数限流5.实战…...

2023-2024-1 高级语言程序设计实验一: 选择结构

7-1 古时年龄称谓知多少? 输入一个人的年龄(岁),判断出他属于哪个年龄段 ? 0-9 :垂髫之年; 10-19: 志学之年; 20-29 :弱冠之年; 30-39 &#…...

js事件循环详解

事件循环简介 JavaScript的事件循环是一种处理异步事件和回调函数的机制,它是在浏览器或Node.js环境中运行,用于管理任务队列和调用栈,以及在适当的时候执行回调函数。 事件循环的基本原理是,JavaScript引擎在空闲时等待事件的到…...

(十)学生端搭建

本次旨在将之前的已完成的部分功能进行拼装到学生端,同时完善学生端的构建。本次工作主要包括: 1.学生端整体界面布局 2.模拟考场与部分个人画像流程的串联 3.整体学生端逻辑 一、学生端 在主界面可以选择自己的用户角色 选择学生则进入学生登录界面…...

React Native 开发环境搭建(全平台详解)

React Native 开发环境搭建(全平台详解) 在开始使用 React Native 开发移动应用之前,正确设置开发环境是至关重要的一步。本文将为你提供一份全面的指南,涵盖 macOS 和 Windows 平台的配置步骤,如何在 Android 和 iOS…...

《从零掌握MIPI CSI-2: 协议精解与FPGA摄像头开发实战》-- CSI-2 协议详细解析 (一)

CSI-2 协议详细解析 (一) 1. CSI-2层定义(CSI-2 Layer Definitions) 分层结构 :CSI-2协议分为6层: 物理层(PHY Layer) : 定义电气特性、时钟机制和传输介质(导线&#…...

【Redis技术进阶之路】「原理分析系列开篇」分析客户端和服务端网络诵信交互实现(服务端执行命令请求的过程 - 初始化服务器)

服务端执行命令请求的过程 【专栏简介】【技术大纲】【专栏目标】【目标人群】1. Redis爱好者与社区成员2. 后端开发和系统架构师3. 计算机专业的本科生及研究生 初始化服务器1. 初始化服务器状态结构初始化RedisServer变量 2. 加载相关系统配置和用户配置参数定制化配置参数案…...

postgresql|数据库|只读用户的创建和删除(备忘)

CREATE USER read_only WITH PASSWORD 密码 -- 连接到xxx数据库 \c xxx -- 授予对xxx数据库的只读权限 GRANT CONNECT ON DATABASE xxx TO read_only; GRANT USAGE ON SCHEMA public TO read_only; GRANT SELECT ON ALL TABLES IN SCHEMA public TO read_only; GRANT EXECUTE O…...

高等数学(下)题型笔记(八)空间解析几何与向量代数

目录 0 前言 1 向量的点乘 1.1 基本公式 1.2 例题 2 向量的叉乘 2.1 基础知识 2.2 例题 3 空间平面方程 3.1 基础知识 3.2 例题 4 空间直线方程 4.1 基础知识 4.2 例题 5 旋转曲面及其方程 5.1 基础知识 5.2 例题 6 空间曲面的法线与切平面 6.1 基础知识 6.2…...

【配置 YOLOX 用于按目录分类的图片数据集】

现在的图标点选越来越多,如何一步解决,采用 YOLOX 目标检测模式则可以轻松解决 要在 YOLOX 中使用按目录分类的图片数据集(每个目录代表一个类别,目录下是该类别的所有图片),你需要进行以下配置步骤&#x…...

站群服务器的应用场景都有哪些?

站群服务器主要是为了多个网站的托管和管理所设计的,可以通过集中管理和高效资源的分配,来支持多个独立的网站同时运行,让每一个网站都可以分配到独立的IP地址,避免出现IP关联的风险,用户还可以通过控制面板进行管理功…...

Redis:现代应用开发的高效内存数据存储利器

一、Redis的起源与发展 Redis最初由意大利程序员Salvatore Sanfilippo在2009年开发,其初衷是为了满足他自己的一个项目需求,即需要一个高性能的键值存储系统来解决传统数据库在高并发场景下的性能瓶颈。随着项目的开源,Redis凭借其简单易用、…...

Web中间件--tomcat学习

Web中间件–tomcat Java虚拟机详解 什么是JAVA虚拟机 Java虚拟机是一个抽象的计算机,它可以执行Java字节码。Java虚拟机是Java平台的一部分,Java平台由Java语言、Java API和Java虚拟机组成。Java虚拟机的主要作用是将Java字节码转换为机器代码&#x…...