当前位置: 首页 > news >正文

百度飞浆ResNet50大模型微调实现十二种猫图像分类

12种猫分类比赛传送门

要求很简单,给train和test集,训练模型实现图像分类。

这里使用的是残差连接模型,这个平台有预训练好的模型,可以直接拿来主义。

训练十几个迭代,每个批次60左右,准确率达到90%以上

一、导入库,解压文件

import os
import zipfile
import random
import json
import cv2
import numpy as np
from PIL import Imageimport matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import paddle
import paddle.nn as nn
from paddle.io import Dataset,DataLoader
from paddle.nn import \Layer, \Conv2D, Linear, \Embedding, MaxPool2D, \BatchNorm2D, ReLUimport paddle.vision.transforms as transforms
from paddle.vision.models import resnet50
from paddle.metric import Accuracytrain_parameters = {"input_size": [3, 224, 224],                     # 输入图片的shape"class_dim": 12,                                 # 分类数"src_path":"data/data10954/cat_12_train.zip",   # 原始数据集路径"src_test_path":"data/data10954/cat_12_test.zip",   # 原始数据集路径"target_path":"/home/aistudio/data/dataset",     # 要解压的路径 "train_list_path": "./train.txt",                # train_data.txt路径"eval_list_path": "./eval.txt",                  # eval_data.txt路径"label_dict":{},                                 # 标签字典"readme_path": "/home/aistudio/data/readme.json",# readme.json路径"num_epochs":6,                                 # 训练轮数"train_batch_size": 16,                          # 批次的大小"learning_strategy": {                           # 优化函数相关的配置"lr": 0.0005                                  # 超参数学习率} 
}scr_path=train_parameters['src_path']
target_path=train_parameters['target_path']
src_test_path=train_parameters["src_test_path"]
z = zipfile.ZipFile(scr_path, 'r')
z.extractall(path=target_path)
z = zipfile.ZipFile(src_test_path, 'r')
z.extractall(path=target_path)
z.close()
for imgpath in os.listdir(target_path + '/cat_12_train'):src = os.path.join(target_path + '/cat_12_train/', imgpath)img = Image.open(src)if img.mode != 'RGB':img = img.convert('RGB')img.save(src)for imgpath in os.listdir(target_path + '/cat_12_test'):src = os.path.join(target_path + '/cat_12_test/', imgpath)img = Image.open(src)if img.mode != 'RGB':img = img.convert('RGB')img.save(src)

 解压后将所有图像变为RGB图像

二、加载训练集,进行预处理、数据增强、格式变换

transform = transforms.Compose([transforms.Resize(size=224),transforms.ColorJitter(0.2, 0.2, 0.2, 0.2),transforms.RandomHorizontalFlip(),transforms.RandomRotation(15),transforms.RandomResizedCrop(size=224, scale=(0.8, 1.0)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])x_train,x_eval,y_train=[],[],[]#获取训练图像和标签、测试图像和标签
contents=[]
with open('data/data10954/train_list.txt')as f:contents=f.read().split('\n')for item in contents:if item=='':continuepath='data/dataset/'+item.split('\t')[0]data=np.array(Image.open(path).convert('RGB'))data=np.array(transform(data))x_train.append(data)y_train.append(int(item.split('\t')[-1]))contetns=os.listdir('data/dataset/cat_12_test')
for item in contetns:path='data/dataset/cat_12_test/'+itemdata=np.array(Image.open(path).convert('RGB'))data=np.array(transform(data))x_eval.append(data)

重点是transforms变换的预处理

三、划分训练集和测试集

x_train=np.array(x_train)y_train=np.array(y_train)x_eval=np.array(x_eval)x_train,x_test,y_train,y_test=train_test_split(x_train,y_train,test_size=0.2,random_state=42,stratify=y_train)x_train=paddle.to_tensor(x_train,dtype='float32')
y_train=paddle.to_tensor(y_train,dtype='int64')
x_test=paddle.to_tensor(x_test,dtype='float32')
y_test=paddle.to_tensor(y_test,dtype='int64')
x_eval=paddle.to_tensor(x_eval,dtype='float32')

 这是必要的,可以随时利用测试集查看准确率

四、加载预训练模型,选择损失函数和优化器

learning_rate=0.001
epochs =5  # 迭代轮数
batch_size = 50  # 批次大小
weight_decay=1e-5
num_class=12cnn=resnet50(pretrained=True)
checkpoint=paddle.load('checkpoint.pdparams')for param in cnn.parameters():param.requires_grad=False
cnn.fc = nn.Linear(2048, num_class)
cnn.set_dict(checkpoint['cnn_state_dict'])
criterion=nn.CrossEntropyLoss()
optimizer = paddle.optimizer.Adam(learning_rate=learning_rate, parameters=cnn.fc.parameters(),weight_decay=weight_decay)

第一次训练把加载模型注释掉即可,优化器包含最后一层全连接的参数

五、模型训练 

if x_train.shape[3]==3:x_train=paddle.transpose(x_train,perm=(0,3,1,2))dataset = paddle.io.TensorDataset([x_train, y_train])
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
for epoch in range(epochs):for batch_data, batch_labels in data_loader:outputs = cnn(batch_data)loss = criterion(outputs, batch_labels)print(epoch)loss.backward()optimizer.step()optimizer.clear_grad()print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss.numpy()[0]}")#保存参数
paddle.save({'cnn_state_dict': cnn.state_dict(),}, 'checkpoint.pdparams')

 使用批处理,这个很重要,不然平台分分钟炸了

六、测试集准确率

num_class=12
batch_size=64
cnn=resnet50(pretrained=True)
checkpoint=paddle.load('checkpoint.pdparams')for param in cnn.parameters():param.requires_grad=False
cnn.fc = nn.Linear(2048, num_class)
cnn.set_dict(checkpoint['cnn_state_dict'])cnn.eval()if x_test.shape[3]==3:x_test=paddle.transpose(x_test,perm=(0,3,1,2))
dataset = paddle.io.TensorDataset([x_test, y_test])
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)with paddle.no_grad():score=0for batch_data, batch_labels in data_loader:predictions = cnn(batch_data)predicted_probabilities = paddle.nn.functional.softmax(predictions, axis=1)predicted_labels = paddle.argmax(predicted_probabilities, axis=1) print(predicted_labels)for i in range(len(predicted_labels)):if predicted_labels[i].numpy()==batch_labels[i]:score+=1print(score/len(y_test))

设置eval模式,使用批处理测试准确率 

相关文章:

百度飞浆ResNet50大模型微调实现十二种猫图像分类

12种猫分类比赛传送门 要求很简单,给train和test集,训练模型实现图像分类。 这里使用的是残差连接模型,这个平台有预训练好的模型,可以直接拿来主义。 训练十几个迭代,每个批次60左右,准确率达到90%以上…...

多服务器云探针源码(服务器云监控)/多服务器多节点_云监控程序python源码

源码简介: 多服务器云探针源码(服务器云监控),支持python多服务器多节点,云监控程序源码。它是一款很实用的云探针和服务器云监控程序源码。使用它可以帮助管理员能够快速监控和管理各种服务器和节点,实用性强。 源码链接: 网盘…...

ESP8266 WiFi物联网智能插座—下位机软件实现

目录 1、软件架构 2、开发环境 3、软件功能 4、程序设计 4.1、初始化 4.2、主循环状态机 4.3、初始化模式 4.4、配置模式 4.5、运行模式 4.6、重启模式 4.7、升级模式 5、程序功能特点 5.1、日志管理 5.2、数据缓存队列 本篇博文开始讲解下位机插座节点的MCU软件…...

微信小程序--下拉选择框组件封装,可CV直接使用

一、起因 接到的项目需求,查看ui设计图后,由于微信小程序官方设计的下拉选择框不符合需求,而且常用的第三方库也没有封装类似的,所以选择自己自定义组件。在此记录一下,方便日后复用。 ui设计图如下: 微信官方提供的选择框 对比发现并不能实现我们想要的功能。 二、自定义组件…...

代码随想录算法训练营第五十九天 |647. 回文子串、516.最长回文子序列、动态规划总结篇

一、647. 回文子串 题目链接/文章讲解:代码随想录 思考: 1.确定dp数组(dp table)以及下标的含义 如果本题定义dp[i] 为 下标i结尾的字符串有 dp[i]个回文串的话: 会发现很难找到递归关系,dp[i] 和 dp[i-1]…...

互联网性能和可用性优化CDN和DNS

当涉及到互联网性能和可用性优化时,DNS(Domain Name System)和CDN(Content Delivery Network)是两个至关重要的元素。它们各自发挥着关键作用,以确保用户能够快速、可靠地访问网站和应用程序。在本文中&…...

使用 ErrorStack 在出现报错 ORA-14402 时产生的日志量

0、测试结论: 测试结果:设置 ErrorStack 级别为 1 时产生 Trace 的日志量最小,大小为 308K,同时在 alert 日志中也存在记录。 1、准备测试数据: sqlplus / as sysdba show pdbs alter session set containerpdb; …...

详解Spring-ApplicationContext

加载器目前有两种选择:ContextLoaderListener和ContextLoaderServlet。 这两者在功能上完全等同,只是一个是基于Servlet2.3版本中新引入的Listener接口实现,而另一个基于Servlet接口实现。开发中可根据目标Web容器的实际情况进行选择。 配…...

关键字extern、static与const

关键字extern、static与const extern关键字与include的区别 extern:于声明某个函数或变量是外部的(其他源文件中)include:用于批量引入 项目中可以根据需要引入的函数或变量数量决定使用extern还是include static关键字 static关键字用于限制函数和全局变量的作用域仅在当…...

虹科方案|国庆出游季,古建筑振动监测让历史古迹不再受损

全文导读: 国庆长假即将到来,各位小伙伴是不是都做好了出游计划呢?今年中秋、国庆“双节”连休八天,多地预计游客接待量将创下新高,而各地的名胜古迹更是人流爆满。迎接游客的同时,如何保障历史古迹不因巨大…...

Python学习笔记-使用哈希算法Hash,Hashlib进行数据加密

文章目录 一、概述1.1 哈希算法1.2 常见算法分类1.2.1 SHA算法1.2.2 MD4算法1.2.3 MD5算法 1.3 Hash算法的特性1.4 Hash算法的应用场景1.4.1 数据校验1.4.2 安全加密1.4.3 数字签名 二、Hash算法使用2.1 使用hash函数直接获取hash值2.2 使用hashlib库进行hash计算2.2.1 基本使用…...

跨境电商能否成为黄河流域产业带的新引擎?

近年来,随着全球贸易格局的不断演变和中国经济的快速崛起,跨境电商已经成为中国外贸的一大亮点。而在中国国内,黄河流域产业带一直以其丰富的资源和悠久的历史而闻名,但也面临着转型升级的挑战。那么,跨境电商是否有潜…...

从数据到决策:企业投资信息查询API的关键作用

前言 在现代商业环境中,数据是一项无价的资产。企业不仅需要访问大量数据,还需要将这些数据转化为有用的见解,以支持战略决策。对于企业投资而言,准确的信息和实时的市场数据至关重要。在这个信息时代,企业投资信息查…...

NSIC2050JBT3G 车规级120V 50mA ±15% 用于LED照明的线性恒流调节器(CCR) 增强汽车安全

随着汽车行业的巨大变革,高品质的汽车氛围灯效、仪表盘等LED指示灯效已成为汽车内饰设计中不可或缺的元素。深力科安森美LED驱动芯片系列赋能智能座舱灯效充满艺术感和科技感——NSIC2050JBT3G LED驱动芯片,实现对每路LED亮度和颜色进行细腻控制&#xf…...

LuatOS-SOC接口文档(air780E)-- ftp - ftp 客户端

ftp.login(adapter,ip_addr,port,username,password)# FTP客户端 参数 传入值类型 解释 int 适配器序号, 只能是socket.ETH0, socket.STA, socket.AP,如果不填,会选择平台自带的方式,然后是最后一个注册的适配器 string ip_addr 地址 string port 端口,默认21 string…...

第二证券:市净率高好还是低好?

市净率是一个衡量公司股票投资价值的指标,通过比较公司股票价格和公司每股净资产的比值来评估公司股票的估值水平。市净率高好还是低好这个问题并没有一个简单的答案,取决于具体的市场环境和投资者的需求。本文将从多个角度分析市净率高好还是低好。 首…...

HTTP协议是什么

HTTP (全称为 “超文本传输协议”) 是一种应用非常广泛的 应用层协议,是一种网络通信协议。 超文本:所谓 “超文本” 的含义, 就是传输的内容不仅仅是文本(比如 html, css 这个就是文本), 还可以是一些其他的资源, 比如图片, 视频, 音频等二进制的数据。…...

微服务09-Sentinel的入门

文章目录 微服务中的雪崩现象解决办法:1. 超时处理2. 舱壁模式3. 熔断降级4.流量控制 Sentinel1.介绍2.使用操作3.限流规则4.实战:流量监控5.高级选项功能的使用1.关联模式2.链路模式3.总结 流控效果1.预热模式2.排队等待模式3.总结4.热点参数限流5.实战…...

2023-2024-1 高级语言程序设计实验一: 选择结构

7-1 古时年龄称谓知多少? 输入一个人的年龄(岁),判断出他属于哪个年龄段 ? 0-9 :垂髫之年; 10-19: 志学之年; 20-29 :弱冠之年; 30-39 &#…...

js事件循环详解

事件循环简介 JavaScript的事件循环是一种处理异步事件和回调函数的机制,它是在浏览器或Node.js环境中运行,用于管理任务队列和调用栈,以及在适当的时候执行回调函数。 事件循环的基本原理是,JavaScript引擎在空闲时等待事件的到…...

Java 语言特性(面试系列2)

一、SQL 基础 1. 复杂查询 (1)连接查询(JOIN) 内连接(INNER JOIN):返回两表匹配的记录。 SELECT e.name, d.dept_name FROM employees e INNER JOIN departments d ON e.dept_id d.dept_id; 左…...

label-studio的使用教程(导入本地路径)

文章目录 1. 准备环境2. 脚本启动2.1 Windows2.2 Linux 3. 安装label-studio机器学习后端3.1 pip安装(推荐)3.2 GitHub仓库安装 4. 后端配置4.1 yolo环境4.2 引入后端模型4.3 修改脚本4.4 启动后端 5. 标注工程5.1 创建工程5.2 配置图片路径5.3 配置工程类型标签5.4 配置模型5.…...

【JavaEE】-- HTTP

1. HTTP是什么? HTTP(全称为"超文本传输协议")是一种应用非常广泛的应用层协议,HTTP是基于TCP协议的一种应用层协议。 应用层协议:是计算机网络协议栈中最高层的协议,它定义了运行在不同主机上…...

【Linux】C语言执行shell指令

在C语言中执行Shell指令 在C语言中&#xff0c;有几种方法可以执行Shell指令&#xff1a; 1. 使用system()函数 这是最简单的方法&#xff0c;包含在stdlib.h头文件中&#xff1a; #include <stdlib.h>int main() {system("ls -l"); // 执行ls -l命令retu…...

java 实现excel文件转pdf | 无水印 | 无限制

文章目录 目录 文章目录 前言 1.项目远程仓库配置 2.pom文件引入相关依赖 3.代码破解 二、Excel转PDF 1.代码实现 2.Aspose.License.xml 授权文件 总结 前言 java处理excel转pdf一直没找到什么好用的免费jar包工具,自己手写的难度,恐怕高级程序员花费一年的事件,也…...

2.Vue编写一个app

1.src中重要的组成 1.1main.ts // 引入createApp用于创建应用 import { createApp } from "vue"; // 引用App根组件 import App from ./App.vue;createApp(App).mount(#app)1.2 App.vue 其中要写三种标签 <template> <!--html--> </template>…...

家政维修平台实战20:权限设计

目录 1 获取工人信息2 搭建工人入口3 权限判断总结 目前我们已经搭建好了基础的用户体系&#xff0c;主要是分成几个表&#xff0c;用户表我们是记录用户的基础信息&#xff0c;包括手机、昵称、头像。而工人和员工各有各的表。那么就有一个问题&#xff0c;不同的角色&#xf…...

新能源汽车智慧充电桩管理方案:新能源充电桩散热问题及消防安全监管方案

随着新能源汽车的快速普及&#xff0c;充电桩作为核心配套设施&#xff0c;其安全性与可靠性备受关注。然而&#xff0c;在高温、高负荷运行环境下&#xff0c;充电桩的散热问题与消防安全隐患日益凸显&#xff0c;成为制约行业发展的关键瓶颈。 如何通过智慧化管理手段优化散…...

华为云Flexus+DeepSeek征文|DeepSeek-V3/R1 商用服务开通全流程与本地部署搭建

华为云FlexusDeepSeek征文&#xff5c;DeepSeek-V3/R1 商用服务开通全流程与本地部署搭建 前言 如今大模型其性能出色&#xff0c;华为云 ModelArts Studio_MaaS大模型即服务平台华为云内置了大模型&#xff0c;能助力我们轻松驾驭 DeepSeek-V3/R1&#xff0c;本文中将分享如何…...

大语言模型(LLM)中的KV缓存压缩与动态稀疏注意力机制设计

随着大语言模型&#xff08;LLM&#xff09;参数规模的增长&#xff0c;推理阶段的内存占用和计算复杂度成为核心挑战。传统注意力机制的计算复杂度随序列长度呈二次方增长&#xff0c;而KV缓存的内存消耗可能高达数十GB&#xff08;例如Llama2-7B处理100K token时需50GB内存&a…...