【Pytorch和Keras】使用transformer库进行图像分类
目录
- 一、环境准备
- 二、基于Pytorch的预训练模型
- 1、准备数据集
- 2、加载预训练模型
- 3、 使用pytorch进行模型构建
- 三、基于keras的预训练模型
- 四、模型测试
- 五、参考
现在大多数的模型都会上传到huggface平台进行统一的管理,transformer库能关联到huggface中对应的模型,并且提供简洁的transformer模型调用,这大大提高了开发人员的开发效率。本博客主要利用transformer库实现一个简单的模型微调,以进行图像分类的任务。
一、环境准备
使用终端命令行安装对应的第三方包,具体安装命令输入如下:
pip install transformers datasets evaluate
二、基于Pytorch的预训练模型
由于下面这些内容需要在huggface上申请账号权限,才能进行模型和数据集加载,如果之前有从huggface上拉取模型和数据集的经验,可以略过,如果没有配置过,可以参考笔者之前的文章https://blog.csdn.net/qq_40734883/article/details/143922095,然后直接申请Write
权限就可以。
后续所有涉及到的数据集food101和transformer模型都需要参考上述文章进行直接下载,才能运行整个程序,或者在google的colab直接运行。
如果在google的colab上运行,请提前设置好电脑的GPU资源,同时加入huggface登录代码,具体如下:
from huggingface_hub import notebook_login
notebook_login()
运行之后会提示进行token输入,按之前获取到的token输入即可。
1、准备数据集
这里以food101
数据集作为微调数据集,在imagenet-21k上训练完成的transformer模型(vit-base-patch16-224
)进行优化
from datasets import load_datasetfood = load_dataset("food101", split="train[:5000]")# 划分数据集,训练集:测试集=8:2,food有两个键:一个train,一个test
food = food.train_test_split(test_size=0.2) # 标签转换
labels = food["train"].features["label"].names
label2id, id2label = dict(), dict()
for i, label in enumerate(labels):label2id[label] = str(i)id2label[str(i)] = label
id2label
为通过id访问标签的字典,后续会使用到。
2、加载预训练模型
from transformers import AutoImageProcessorcheckpoint = "google/vit-base-patch16-224-in21k" # ImageNet-21k上的预训练模型
image_processor = AutoImageProcessor.from_pretrained(checkpoint) # 从huggface拉取并加载模型
3、 使用pytorch进行模型构建
from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor# 数据预处理操作定义
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
size = (image_processor.size["shortest_edge"]if "shortest_edge" in image_processor.sizeelse (image_processor.size["height"], image_processor.size["width"])
)
_transforms = Compose([RandomResizedCrop(size), ToTensor(), normalize])# 对原始数据进行RGB及字典化
def transforms(examples):examples["pixel_values"] = [_transforms(img.convert("RGB")) for img in examples["image"]]del examples["image"]return examplesfood = food.with_transform(transforms)
# 验证
import evaluate
# 指定验证过程中的评价指标-准确率
accuracy = evaluate.load("accuracy")import numpy as np
def compute_metrics(eval_pred):predictions, labels = eval_predpredictions = np.argmax(predictions, axis=1)return accuracy.compute(predictions=predictions, references=labels)
训练设置和运行,具体输入代码如下:
# 整合训练中的数据,以便在模型训练或评估过程中使用
from transformers import DefaultDataCollator
data_collator = DefaultDataCollator()from transformers import AutoModelForImageClassification, TrainingArguments, Trainer# 初始化模型
model = AutoModelForImageClassification.from_pretrained(checkpoint,num_labels=len(labels),id2label=id2label,label2id=label2id,
)# 设置模型优化参数
training_args = TrainingArguments(output_dir="my_awesome_food_model",remove_unused_columns=False,evaluation_strategy="epoch",save_strategy="epoch",learning_rate=5e-5,per_device_train_batch_size=16,gradient_accumulation_steps=4,per_device_eval_batch_size=16,num_train_epochs=3,warmup_ratio=0.1,logging_steps=10,load_best_model_at_end=True,metric_for_best_model="accuracy",push_to_hub=True,
)# 初始化训练实例
trainer = Trainer(model=model,args=training_args,data_collator=data_collator,train_dataset=food["train"],eval_dataset=food["test"],tokenizer=image_processor,compute_metrics=compute_metrics,
)trainer.train() # 开始训练trainer.push_to_hub() # 推送到huggfacehub
经过上述设置训练完成之后,会将模型微调结果推送到huggface平台,如果不想推送,可以不运行相关的命令行,并且training_args
中的push_to_hub=False
。
训练结果如下图所示:
默认需要选择是否关联wandb,如果不想选择,直接根据设置提示跳过即可。
如果选择了推送到huggfacehub(trainer.push_to_hub()
)的话,在个人的huggface上会有一个名为my_awesome_food_model
的模型,里面包含了模型训练的各个参数设置和测试结果。
三、基于keras的预训练模型
使用transflow的keras API 进行模型的搭建,具体代码如下:
from transformers import create_optimizer# 超参数设置
batch_size = 16
num_epochs = 5
num_train_steps = len(food["train"]) * num_epochs
learning_rate = 3e-5
weight_decay_rate = 0.01# 定义优化方式和策略
optimizer, lr_schedule = create_optimizer(init_lr=learning_rate, num_train_steps=num_train_steps, weight_decay_rate=weight_decay_rate, num_warmup_steps=0)# 定义分类器
from transformers import TFAutoModelForImageClassification
model = TFAutoModelForImageClassification.from_pretrained(checkpoint, id2label=id2label, label2id=label2id)# converting our train dataset to tf.data.Dataset
tf_train_dataset = food["train"].to_tf_dataset(columns="pixel_values", label_cols="label", shuffle=True, batch_size=batch_size, collate_fn=data_collator)# converting our test dataset to tf.data.Dataset
tf_eval_dataset = food["test"].to_tf_dataset(columns="pixel_values", label_cols="label", shuffle=False, batch_size=batch_size, collate_fn=data_collator)# 定义损失函数
from tensorflow.keras.losses import SparseCategoricalCrossentropy
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)model.compile(optimizer=optimizer, loss=loss)from transformers.keras_callbacks import KerasMetricCallback, PushToHubCallback
# 定义验证指标
metric_callback = KerasMetricCallback(metric_fn=compute_metrics, eval_dataset=tf_eval_dataset)
# 推送到huggface回调函数
push_to_hub_callback = PushToHubCallback(output_dir="food_classifier", tokenizer=image_processor, save_strategy="no")
callbacks = [metric_callback, push_to_hub_callback]# 开始训练
model.fit(tf_train_dataset, validation_data=tf_eval_dataset, epochs=num_epochs, callbacks=callbacks)
四、模型测试
这里使用微调好的模型在food101上找一张验证图像进行简单的验证测试,具体代码如下:
# 验证food中验证集的某一张图像
ds = load_dataset("food101", split="validation[-5:-1]")
image = ds["image"][-1]# visualize image
import matplotlib.pyplot as plt
plt.imshow(image)
plt.axis('off')
plt.show()
测试图像如下所示:
from transformers import pipeline
# initialize classifier instance
classifier = pipeline("image-classification", model="my_awesome_food_model")
classifier(image)from transformers import AutoImageProcessor
import torch
# load pre-trained image processor
image_processor = AutoImageProcessor.from_pretrained("my_awesome_food_model")
inputs = image_processor(image, return_tensors="pt")from transformers import AutoModelForImageClassification
# laod pre-trained model
model = AutoModelForImageClassification.from_pretrained("my_awesome_food_model")
with torch.no_grad():logits = model(**inputs).logits# 输出测试结果
predicted_label = logits.argmax(-1).item()
print(model.config.id2label[predicted_label])
输出结果如下所示:
Device set to use cuda:0
[{'label': 'ramen', 'score': 0.9517934918403625},{'label': 'bruschetta', 'score': 0.7566707730293274},{'label': 'hamburger', 'score': 0.7004948854446411},{'label': 'chicken_wings', 'score': 0.6275856494903564},{'label': 'prime_rib', 'score': 0.5991673469543457}]
预测结果为:ramen
。
释义:“ramen”一词源于日语“ラーメン”,是“拉面”的意思。它进一步追溯至汉语“拉面”,是一种起源于中国、流行于日本及其他东亚地区的面条食品。在日本,拉面通常由小麦面粉制成的面条,搭配肉汤和各种配料,如叉烧、鸡蛋、蔬菜等。
五、参考
[1] https://huggingface.co/docs/transformers/main/tasks/image_classification
[2] https://github.com/huggingface/transformers/blob/main/docs/source/en/installation.md
相关文章:

【Pytorch和Keras】使用transformer库进行图像分类
目录 一、环境准备二、基于Pytorch的预训练模型1、准备数据集2、加载预训练模型3、 使用pytorch进行模型构建 三、基于keras的预训练模型四、模型测试五、参考 现在大多数的模型都会上传到huggface平台进行统一的管理,transformer库能关联到huggface中对应的模型&am…...
快速了解 c++ 异常处理 基础知识
相关代码概览: #include<stdexcept>std::runtime_errorcatch (const std::runtime_error& e) e.what() 相信大家一定见过这些代码,那么这些代码具体什么意思呢?我们一起来看一下 知识精讲: 异常处理是C中非常重要…...

deepseek API 调用-python
【1】创建 API keys 【2】安装openai SDK pip3 install openai 【3】代码: https://download.csdn.net/download/notfindjob/90343352...
玩转Gin框架:Golang使用Gin完成登录流程
文章目录 背景基于Token认证机制简介常见的Token类型Token的生成和验证在项目工程里创建jwt.go文件根目录新建.env文件 创建登录接口 /loginToken认证机制的优点 背景 登录流程,相信大家都很熟悉的。传统网站采用session后端验证登录状态,大致流程如下&…...

Linux学习笔记16---高精度延时实验
延时函数是很常用的 API 函数,在前面的实验中我们使用循环来实现延时函数,但是使用循环来实现的延时函数不准确,误差会很大。虽然使用到延时函数的地方精度要求都不会很严格( 要求严格的话就使用硬件定时器了 ) ,但是延时函数肯定…...

vue2:如何动态控制el-form-item之间的行间距
需求 某页面有查看和编辑两种状态: 编辑: 查看: 可以看到,查看时,行间距太大导致页面不紧凑,所以希望缩小查看是的行间距。 行间距设置 行间距通常是通过 CSS 的 margin 或 padding 属性来控制的。在 Element UI 的样式表中,.el-form-item 的下边距(margin-bottom)…...

deepseek从网络拓扑图生成说明文字实例
deepseek对话页面中输入问题指令: 我是安全测评工程师,正在撰写系统测评报告,现在需要对系统网络架构进行详细说明,请根据附件网络拓扑图输出详细说明文字。用总分的段落结构,先介绍各网络区域,再介绍网络…...

两种文件类型(pdf/图片)打印A4半张纸方法
环境:windows10、Adobe Reader XI v11.0.23 Pdf: 1.把内容由横排变为纵排: 2.点击打印按钮: 3.选择打印页范围和多页: 4.内容打印在纸张上部 图片: 1.右键图片点击打印: 2.选择打印类型: 3.打印配置&am…...

HTB:UnderPass[WriteUP]
目录 连接至HTB服务器并启动靶机 信息收集 使用rustscan对靶机TCP端口进行开放扫描 使用nmap对靶机TCP开放端口进行脚本、服务扫描 使用nmap对靶机TCP开放端口进行漏洞、系统扫描 使用nmap对靶机常用UDP端口进行开放扫描 使用nmap对靶机UDP开放端口进行脚本、服务扫描 …...

【deepseek实战】绿色好用,不断网
前言 最佳deepseek火热网络,我也开发一款windows的电脑端,接入了deepseek,基本是复刻了网页端,还加入一些特色功能。 助力国内AI,发出自己的热量 说一下开发过程和内容的使用吧。 目录 一、介绍 二、具体工作 1.1、引…...
MySQL 进阶专题:索引(索引原理/操作/优缺点/B+树)
在数据库的秋招面试中,索引(Index)是一个经典且高频的题目。索引的作用类似于书中的目录📖,它能够显著加快数据库查询的速度。本文将深入探讨索引的概念、作用、优缺点以及背后的数据结构,帮助你从原理到应…...

用NeuralProphet预测股价:AI金融新利器(附源码)
作者:老余捞鱼 原创不易,转载请标明出处及原作者。 写在前面的话:我用NeuralProphet模型预测了股票价格,发现其通过结合时间序列分析和神经网络算法,确实能提供比传统Last Value方法更精准的预测。经过一系列超参数调优…...
【Elasticsearch】parent aggregation
在Elasticsearch中,Parent Aggregation是一种特殊的单桶聚合,用于选择具有指定类型的父文档,这些类型是通过一个join字段定义的。以下是关于Parent Aggregation的详细介绍: 1.基本概念 Parent Aggregation是一种聚合操作&#x…...
IDEA使用Auto-dev+DeepSeek 10分钟快速集成,让java开发起飞
在当今的软件开发领域,AI 工具的辅助作用愈发凸显,DeepSeek AI 便是其中的佼佼者。它凭借强大的自然语言处理能力和高效的代码生成能力,成为众多开发者的得力助手。而 IntelliJ IDEA 作为一款广受欢迎的集成开发环境(IDE),若能与 DeepSeek AI 无缝集成,无疑将为开发者带…...

ASP.NET Core中间件Markdown转换器
目录 需求 文本编码检测 Markdown→HTML 注意 实现 需求 Markdown是一种文本格式;不被浏览器支持;编写一个在服务器端把Markdown转换为HTML的中间件。我们开发的中间件是构建在ASP.NET Core内置的StaticFiles中间件之上,并且在它之前运…...

使用page assist浏览器插件结合deepseek-r1 7b本地模型
为本地部署的DeepSeek R1 7b模型安装Page Assist,可以按照以下步骤进行: 一、下载并安装Ollama 首先,你需要下载并安装Ollama,这是部署DeepSeek所必需的工具。你可以访问Ollama的官方网站(ollama.com)下…...
【华为OD-E卷 - 108 最大矩阵和 100分(python、java、c++、js、c)】
【华为OD-E卷 - 最大矩阵和 100分(python、java、c、js、c)】 题目 给定一个二维整数矩阵,要在这个矩阵中选出一个子矩阵,使得这个子矩阵内所有的数字和尽量大,我们把这个子矩阵称为和最大子矩阵,子矩阵的…...

【Reading Notes】Favorite Articles from 2025
文章目录 1、January2、February3、March4、April5、May6、June7、July8、August9、September10、October11、November12、December 1、January 极越之后,中国车市只会倒下更多人(2025年01月01日) 在这波枪林弹雨中,合资品牌中最…...
云计算行业分析
云计算作为数字经济的核心基础设施,未来十年将持续重塑全球科技格局,并渗透到几乎所有行业的数字化转型中。 一、云计算的发展潜力 1. 技术融合驱动爆发式创新 AI与云计算的深度耦合 - **智能云服务**:云厂商将提供预训练模型、自动化ML工…...

【Linux系统】线程:线程的优点 / 缺点 / 超线程技术 / 异常 / 用途
1、线程的优点 创建和删除线程代价较小 创建一个新线程的代价要比创建一个新进程小得多,删除代价也小。这种说法主要基于以下几个方面: (1)资源共享 内存空间:每个进程都有自己独立的内存空间,包括代码段…...

应用分享 | 精准生成和时序控制!AWG在确定性三量子比特纠缠光子源中的应用
在量子技术飞速发展的今天,实现高效稳定的量子态操控是推动量子计算、量子通信等领域迈向实用化的关键。任意波形发生器(AWG)作为精准信号控制的核心设备,在量子实验中发挥着不可或缺的作用。丹麦哥本哈根大学的研究团队基于单个量…...
用Ai学习wxWidgets笔记——在 VS Code 中使用 CMake 搭建 wxWidgets 开发工程
声明:本文整理筛选Ai工具生成的内容辅助写作,仅供参考 >> 在 VS Code 中使用 CMake 搭建 wxWidgets 开发工程 下面是一步步指导如何在 VS Code 中配置 wxWidgets 开发环境,包括跨平台设置(Windows 和 Linux)。…...

什么是预训练?深入解读大模型AI的“高考集训”
1. 预训练的通俗理解:AI的“高考集训” 我们可以将预训练(Pre-training) 形象地理解为大模型AI的“高考集训”。就像学霸在高考前需要刷五年高考三年模拟一样,大模型在正式诞生前,也要经历一场声势浩大的“题海战术”…...

ISO 17387——解读自动驾驶相关标准法规(LCDAS)
Intelligent transport systems — Lane change decision aid systems (LCDAS) — Performance requirements and test procedures(First edition: 2008-05-01) 原文链接:https://cdn.standards.iteh.ai/samples/43654/701fd49bde7b4d3db165444b7c6f0c53/ISO-17387…...
Nginx+Tomcat负载均衡集群
目录 一、Tomcat 基础与单节点部署 (一)Tomcat 概述 (二)单节点部署案例 1. 案例环境 2. 实施准备 3. 安装 JDK 4. 查看 JDK 安装情况 5. 安装配置 Tomcat 6. 启动 Tomcat 7. 访问测试 8. 关闭 Tomcat (三…...

Model Context Protocol (MCP) 是一个前沿框架
微软发布了 Model Context Protocol (MCP) 课程:mcp-for-beginners。 Model Context Protocol (MCP) 是一个前沿框架,涵盖 C#、Java、JavaScript、TypeScript 和 Python 等主流编程语言,规范 AI 模型与客户端应用之间的交互。 MCP 课程结构 …...
PTC过流保护器件工作原理及选型方法
PTC过流保护器件 (Positive Temperature Coefficient,正温度系数热敏电阻)是一种过流保护元件,其工作原理基于电阻值随温度变化的特性。当电路正常工作时,PTC的阻值很小,电流可以顺畅通过;但当…...
Windows 下搭建 Zephyr 开发环境
1. 系统要求 操作系统:Windows 10/11(64位)磁盘空间:至少 8GB 可用空间(Zephyr 及其工具链较大)权限:管理员权限(部分工具需要) 2. 安装必要工具 winget安装依赖工具&am…...
使用 C/C++ 和 OpenCV 提取图像的感兴趣区域 (ROI)
使用 C/C 和 OpenCV 提取图像的感兴趣区域 (ROI) 在计算机视觉中,感兴趣区域 (Region of Interest, ROI) 是指从图像中选择的一个特定区域,我们希望对其进行进一步的处理或分析。例如,在人脸识别中,ROI 就是包含人脸的矩形框。Op…...

selinux firewalld
一、selinux 1.说明 SELinux 是 Security-Enhanced Linux 的缩写,意思是安全强化的 linux; SELinux 主要由美国国家安全局(NSA)开发,当初开发的目的是为了避免资源的误用 DAC(Discretionary Access Cont…...