当前位置: 首页 > article >正文

基于 Flask的深度学习模型部署服务端详解

基于 Flask 的深度学习模型部署服务端详解

在深度学习领域,训练出一个高精度的模型只是第一步,将其部署到生产环境中,为实际业务提供服务才是最终目标。本文将详细解析一个基于 Flask 和 PyTorch 的深度学习模型部署服务端代码,帮助你理解如何将训练好的模型以 API 形式提供给客户端使用。

一、整体概述

这段代码的主要功能是搭建一个基于 Flask 的 Web 服务,用于接收客户端发送的图像数据,使用预训练的 PyTorch 模型对图像进行分类预测,并将预测结果以 JSON 格式返回给客户端。

二、代码详细解析

1. 导入必要的库

import io
import flask
import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn
from torchvision import transforms, models
  • io:用于处理二进制数据,这里主要用于将客户端发送的图像二进制数据转换为图像对象。
  • flask:一个轻量级的 Web 框架,用于搭建 Web 服务。
  • torchtorch.nn.functional:PyTorch 的核心库,用于深度学习模型的构建和计算。
  • PIL.Image:Python Imaging Library(PIL)的一部分,用于处理图像文件。
  • torch.nn:用于定义神经网络的层和模块。
  • torchvision.transformstorchvision.modelstransforms 用于图像预处理,models 提供了预训练的深度学习模型。

2. 初始化 Flask 应用和模型相关变量

app = flask.Flask(__name__)
model = None
use_gpu = False
  • app = flask.Flask(__name__):创建一个新的 Flask 应用实例,__name__ 参数用于确定应用的根路径。
  • model:用于存储加载的深度学习模型,初始化为 None
  • use_gpu:一个布尔变量,用于控制是否使用 GPU 进行模型推理,初始化为 False

3. 加载模型

def load_model():global modelmodel = models.resnet18()num_ftrs = model.fc.in_featuresmodel.fc = nn.Sequential(nn.Linear(num_ftrs, 102))checkpoint = torch.load('best.pth')model.load_state_dict(checkpoint['state_dict'])model.eval()if use_gpu:model.cuda()
  • global model:声明 model 为全局变量,以便在函数内部修改它。
  • model = models.resnet18():加载预训练的 ResNet-18 模型。
  • num_ftrs = model.fc.in_features:获取 ResNet-18 模型最后一层全连接层的输入特征数。
  • model.fc = nn.Sequential(nn.Linear(num_ftrs, 102)):修改最后一层全连接层,将输出维度改为 102,这里的 102 可以根据实际任务的类别数进行调整。
  • checkpoint = torch.load('best.pth'):从文件 best.pth 中加载训练好的模型参数。
  • model.load_state_dict(checkpoint['state_dict']):将加载的参数应用到模型中。
  • model.eval():将模型设置为评估模式,关闭一些在训练时使用的特殊层(如 Dropout)。
  • if use_gpu: model.cuda():如果 use_gpuTrue,将模型移动到 GPU 上。

4. 图像预处理

def prepare_image(image, target_size):if image.mode != 'RGB':image = image.convert('RGB')image = transforms.Resize(target_size)(image)image = transforms.ToTensor()(image)image = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(image)image = image[None]if use_gpu:image = image.cuda()return torch.tensor(image)
  • if image.mode != 'RGB': image = image.convert('RGB'):确保输入图像为 RGB 格式。
  • image = transforms.Resize(target_size)(image):将图像调整为指定的大小。
  • image = transforms.ToTensor()(image):将图像转换为 PyTorch 张量。
  • image = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(image):对图像进行归一化处理,使用的均值和标准差是在 ImageNet 数据集上计算得到的。
  • image = image[None]:增加一个维度,将图像转换为批量输入的格式。
  • if use_gpu: image = image.cuda():如果 use_gpuTrue,将图像移动到 GPU 上。

5. 定义预测接口

@app.route('/predict', methods=['POST'])
def predict():data = {'success': False}if flask.request.method == 'POST':if flask.request.files.get('image'):image = flask.request.files['image'].read()image = Image.open(io.BytesIO(image))image = prepare_image(image, target_size=(224, 224))preds = F.softmax(model(image), dim=1)results = torch.topk(preds.cpu().data, k=3, dim=1)results = (results[0].cpu().numpy(), results[1].cpu().numpy())data['prediction'] = list()for prob, label in zip(results[0][0], results[1][0]):r = {'label': str(label), 'probability': float(prob)}data['prediction'].append(r)data['success'] = Truereturn flask.jsonify(data)
  • @app.route('/predict', methods=['POST']):使用 Flask 的装饰器定义一个路由,当客户端向 /predict 路径发送 POST 请求时,会调用 predict 函数。
  • data = {'success': False}:初始化一个字典,用于存储预测结果和状态信息,初始状态为 success = False
  • if flask.request.method == 'POST':检查请求方法是否为 POST。
  • if flask.request.files.get('image'):检查请求中是否包含名为 image 的文件。
  • image = flask.request.files['image'].read():读取客户端发送的图像文件内容。
  • image = Image.open(io.BytesIO(image)):将二进制数据转换为图像对象。
  • image = prepare_image(image, target_size=(224, 224)):对图像进行预处理。
  • preds = F.softmax(model(image), dim=1):使用模型进行预测,并通过 softmax 函数将输出转换为概率分布。
  • results = torch.topk(preds.cpu().data, k=3, dim=1):获取概率最大的前 3 个结果。
  • results = (results[0].cpu().numpy(), results[1].cpu().numpy()):将结果转换为 NumPy 数组。
  • data['prediction'] = list():初始化一个列表,用于存储预测结果。
  • for prob, label in zip(results[0][0], results[1][0]):遍历前 3 个结果,将标签和概率封装成字典,并添加到 data['prediction'] 列表中。
  • data['success'] = True:将状态信息设置为 success = True,表示预测成功。
  • return flask.jsonify(data):将结果以 JSON 格式返回给客户端。

6. 启动服务

if __name__ == '__main__':print('Loading PyTorch model and Flask starting server ...')print('Please wait until server has fully started')load_model()app.run(host='192.168.1.20', port=5012)
  • if __name__ == '__main__':确保代码作为主程序运行时才执行以下操作。
  • print('Loading PyTorch model and Flask starting server ...')print('Please wait until server has fully started'):打印启动信息。
  • load_model():调用 load_model 函数加载模型。
  • app.run(host='192.168.1.20', port=5012):启动 Flask 服务,监听 192.168.1.20 地址的 5012 端口。运行结果如下
  • 在这里插入图片描述

三、总结

通过上述代码,我们成功搭建了一个基于 Flask 和 PyTorch 的深度学习模型部署服务端。客户端可以通过向 /predict 路径发送包含图像文件的 POST 请求,获取图像分类的预测结果。在实际应用中,可以根据需要对代码进行扩展,如增加更多的模型、优化图像预处理流程、添加错误处理机制等。希望本文能帮助你更好地理解深度学习模型的部署过程。

相关文章:

基于 Flask的深度学习模型部署服务端详解

基于 Flask 的深度学习模型部署服务端详解 在深度学习领域,训练出一个高精度的模型只是第一步,将其部署到生产环境中,为实际业务提供服务才是最终目标。本文将详细解析一个基于 Flask 和 PyTorch 的深度学习模型部署服务端代码,帮…...

洛谷 P1850 [NOIP 2016 提高组] 换教室

题目传送门 前言 终于自己想出概率期望 d p dp dp 的状态了,但是依旧没能相对转移方程。(招笑) 暴力 这题部分分和特殊情况分给的挺多的,所以先拿部分分。 一、思路 先跑一边 F l o y d Floyd Floyd 最短路求出两点间最短距…...

C#生成二维码和条形码

C# 实现二维码和条形码生成:从入门到实战 文章目录 C# 实现二维码和条形码生成:从入门到实战一、引言二、准备工作2.1 开发环境搭建2.2 引入相关库 三、生成条形码3.1 条形码基本概念3.2 使用[ZXing.Net](https://ZXing.Net)生成条形码3.2.1 核心代码实现…...

【金仓数据库征文】金仓数据库 KES:MySQL 迁移实用指南

我们都知道,现在企业数字化转型那可是势在必行,数据库迁移这事儿就变得特别关键。金仓数据库的 KingbaseES(简称 KES),就给咱从 MySQL 往 KES 迁移数据库提供了一套超好用的方案。下面咱就讲下 咋用金仓数据库来完成这…...

多态(c++详细版)

一.多态 1.1 多态的概念 多态(polymorphism)的概念:通俗来说,就是多种形态。多态分为编译时多态(静态多态)和运⾏时多态(动态多态),这⾥我们重点讲运⾏时多态,编译时多态(静态多态)和运⾏时多态(动态多态)。编译时多态(静态多态)主…...

内存泄漏系列专题分析之八:高通相机CamX内存泄漏内存占用分析--通用ION(dmabuf)内存拆解

【关注我,后续持续新增专题博文,谢谢!!!】 上一篇我们讲了:内存泄漏系列专题分析之七:高通相机CamX--Android通用ION(dmabuf)内存分配和释放原理 这一篇我们开始讲: 内存泄漏系列专题分析之八:高通相机CamX内存泄漏&内存占用分析--通用ION(dmabuf)内…...

后端项目进度汇报

项目概述 本项目致力于构建一个先进的智能任务自动化平台。其核心技术是一套由大型语言模型(LLM)驱动的后端系统。该系统能够模拟一个多角色协作的团队,通过一系列精心设计或动态生成的处理阶段,来高效完成各种复杂任务&#xff…...

数据结构——二叉树和堆(万字,最详细)

目录 1.树 1.1 树的概念与结构 1.2 树相关的术语 1.3 树的表示法 2.二叉树 2.1 概念与结构 2.2 特殊的二叉树 2.2.1 满二叉树 2.2.2 完全二叉树 2.3 二叉树存储结构 2.3.1 顺序结构 2.3.2 实现顺序结构二叉树 2.3.2.1 堆的概念与结构 2.3.2. 2 堆的插入与删除数据…...

MATLAB基于格拉姆角场与2DCNN-BiGRU的轴承故障诊断模型

本博客来源于CSDN机器鱼,未同意任何人转载。 更多内容,欢迎点击本专栏目录,查看更多内容。 目录 0 引言 1 格拉姆角场原理 2 2DCNN-BiGRU网络结构 3 应用实例 3.1 数据准备 3.2 格拉姆角场数据提取 3.3 网络模型搭建-重中之重 3.4 …...

正点原子IMX6U开发板移植Qt时出现乱码

移植Qt时出现乱码 1、前言2、问题3、总结 1、前言 记录一下正点原子IMX6U开发板移植Qt时出现乱码的解决方法,方便自己日后回顾,也可以给有需要的人提供帮助。 2、问题 用正点原子IMX6U开发板移植Qt时移植Qt后,sd卡里已经存储了Qt的各种库&…...

JVM局部变量表和操作数栈的内存布局

局部变量表和操作数栈 首先看一段Java源码 public class Add_Sample{public int add(int i, int j){int k 100;int result i j k;return result;}public static void main(String[] args){int result new Add_Sample().add(10,20);System.out.println(result);} }使用ja…...

Mockoon 使用教程

文章目录 一、简介二、模拟接口1、Get2、Post 一、简介 1、Mockoon 可以快速模拟API,无需远程部署,无需帐户,免费,跨平台且开源,适合离线环境。 2、支持get、post、put、delete等所有格式。 二、模拟接口 1、Get 左…...

使用 IDEA + Maven 搭建传统 Spring MVC 项目的详细步骤(非Spring Boot)

搭建Spring MVC项目 第一步:创建Maven项目第二步:配置pom.xml第三步:配置web.xml第四步:创建Spring配置文件第五步:创建控制器第六步:创建JSP视图第七步:配置Tomcat并运行目录结构常见问题解决与…...

以下是在 Ubuntu 上的几款PDF 阅读器,涵盖轻量级、功能丰富和特色工具:

默认工具:Evince(GNOME 文档查看器) 特点:Ubuntu 预装,轻量快速,支持基本标注和书签。 安装:已预装,或手动安装: sudo apt install evince功能全面:Okular&…...

3.2.3 掌握RDD转换算子 - 4. 按键归约算子 - reduceByKey()

在本节课中,我们深入学习了Spark RDD的reduceByKey()算子。reduceByKey()主要用于处理元素为(key, value)形式的RDD,能够将相同key的元素聚集并合并,最终返回一个新RDD,其元素类型与原RDD保持一致。通过案例演示,我们首…...

AI领域的MCP(Model-Centric Paradigm)

1. 什么是MCP(Model-Centric Paradigm)? MCP(Model-Centric Paradigm)是人工智能开发中的一种核心理念,强调以模型的优化与改进作为主要驱动因素来提升AI系统的表现。在MCP模式下,开发者专注于…...

Pandas比MySQL快?

知乎上有人问,处理百万级数据,Python列表、Pandas、Mysql哪个更快? Pands是Python中非常流行的数据处理库,拥有大量用户,所以拿它和Mysql对比也是情理之中。 实测来看,MySQL > Pandas > Python列表…...

模拟内存管理

文章目录 1. 实验六:内存管理2. 记录内存空间使用情况2.1 全局参数2.2 内存空间相关参数2.3 关键结构体定义2.4 内存系统初始化 3. 记录空闲分区3.1 采用位图的方式记录物理内存中的空闲帧3.1.1 记录方式3.1.2 举例分析 3.2 主要操作3.2.1 初始化空闲帧:…...

大模型调优方法与注意事项

大模型调优(Fine-tuning)是指对预训练的大型语言模型(如GPT、BERT、LLaMA等)进行二次训练,使其适应特定任务或领域的过程。以下是调优的关键步骤、方法和注意事项: 一、调优的核心步骤 任务定义与数据准备 …...

简易的考试系统设计(Web实验)

简易的考试系统设计(Web实验) 1.实验内容与设计思想(一)实验需求(二)设计思路 2.代码展示3.实验小结 1.实验内容与设计思想 (一)实验需求 1.编写两个页面程序,一个HTML…...

【嵌入式开发-SDIO】

嵌入式开发--SDIO ■ SDIO-简介■■■■■ ■ SDIO-简介 SDIO(Secure Digital Input and Output),即安全数字输入输出接口。它是在SD卡接口的基础上发展而来,它可以兼容之前的SD卡,并可以连接SDIO接口设备,比如:蓝牙、…...

基于Kubernetes的Apache Pulsar云原生架构解析与集群部署指南(上)

#作者:闫乾苓 文章目录 概念和架构概述主要特点消息传递核心概念Pulsar 的消息模型Pulsar 的消息存储与分发Pulsar 的高级特性架构BrokerBookKeeperZooKeeper 概念和架构 概述 Pulsar 是一个多租户、高性能的服务器到服务器消息传递解决方案。Pulsar 最初由雅虎开…...

车载网络TOP20核心概念科普

一、基础协议与总线技术 CAN总线 定义:控制器局域网,采用差分信号传输,速率最高1Mbps,适用于实时控制(如动力系统)。形象比喻:如同“神经系统”,负责传递关键控制信号。 LIN总线 定…...

使用JAVA对接Deepseek API实现首次访问和提问

一、标题 参考:https://www.cnblogs.com/saoge/p/18866776 使用JAVA对接Deepseek API实现首次访问和 提问:我有50万能做什么小本生意,举例3个! 二、代码 import java.io.BufferedReader; import java.io.InputStreamReader; import java.…...

【C语言】文件操作(续)

目录 复习: 一⽂件的顺序读写 例子: 前言: 在上篇文章中介绍了文件的类型,文件指针,流,操作的函数。 在本篇文章继续为大家带来文件细节分享,如 顺序读写等等。 复习: fopen是…...

基于CBOW模型的词向量训练实战:从原理到PyTorch实现

基于CBOW模型的词向量训练实战:从原理到PyTorch实现 在自然语言处理(NLP)领域,词向量是将单词映射为计算机可处理的数值向量的重要方式。通过词向量,单词之间的语义关系能够以数学形式表达,为后续的文本分…...

mac连接lniux服务器教学笔记

从你的检查结果看,容器内已经安装了 XFCE 桌面环境(xfce.desktop 和 xubuntu.desktop 的存在说明桌面环境已存在)。以下是针对 Docker 容器环境的远程桌面配置方案: 一、容器内快速配置远程桌面(XFCE VNC)…...

vue3 - keepAlive缓存组件

在Vue 3中&#xff0c;<keep-alive>组件用于缓存动态组件或路由组件的状态&#xff0c;避免重复渲染&#xff0c;提升性能。 我们新建两个组件&#xff0c;在每一个组件里面写一个input&#xff0c;在默认情况下当组件切换的时候&#xff0c;数据会被清空&#xff0c;但…...

阀门产业发展方向报告(石油化工阀门应用技术交流大会)

本文大部分内容来自中国通用机械工业协会副会长张宗列在“2024全国石油化工阀门应用技术交流大会”上发表的报告。 一、国外阀门产业发展 从全球阀门市场分布看&#xff0c;亚洲是最大的工业阀门市场&#xff0c;美洲是全球第二大工业阀门市场&#xff0c;欧洲位列第三。 从国…...

Windows Server 2025 安装AMD显卡驱动

运行显卡驱动安装程序&#xff0c;会提示出问题。但是此时资源已经解压 来到驱动路径 C:\AMD\AMD-Software-Installer\Packages\Drivers\Display\WT6A_INF 打开配置文件&#xff0c;把这两行替换掉 %ATI% ATI.Mfg, NTamd64.10.0...16299, NTamd64.10.0, NTamd64.6.0, NTamd64.…...