当前位置: 首页 > news >正文

采用自动微分进行模型的训练

 自动微分训练模型

 简单代码实现:

import torch
import torch.nn as nn
import torch.optim as optim# 定义一个简单的线性回归模型
class LinearRegression(nn.Module):def __init__(self):super(LinearRegression, self).__init__()self.linear = nn.Linear(1, 1)  # 输入维度是1,输出维度也是1def forward(self, x):return self.linear(x)# 准备训练数据
x_train = torch.tensor([[1.0], [2.0], [3.0]])
y_train = torch.tensor([[2.0], [4.0], [6.0]])# 实例化模型、损失函数和优化器
model = LinearRegression()
criterion = nn.MSELoss()  # 均方误差损失函数
optimizer = optim.SGD(model.parameters(), lr=0.01)  # 随机梯度下降优化器# 训练模型
epochs = 1000
for epoch in range(epochs):# 前向传播outputs = model(x_train)loss = criterion(outputs, y_train)# 反向传播optimizer.zero_grad()  # 清空之前的梯度loss.backward()  # 自动计算梯度optimizer.step()  # 更新模型参数if (epoch+1) % 100 == 0:print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')# 测试模型
x_test = torch.tensor([[4.0]])
predicted = model(x_test)
print(f'预测值: {predicted.item():.4f}')

代码分解:

1.定义一个简单的线性回归模型:

  • LinearRegression 类继承自nn.Module,这是所有神经网络模型的基类
  • 在 __init__ 方法中,定义了一个线性层 self.linear,它的输入维度是1,输出维度也是1。
  • forward 方法定义了数据在模型中的传播路径,即输入 x 经过 self.linear 层后得到输出。
    class LinearRegression(nn.Module):def __init__(self):super(LinearRegression, self).__init__()self.linear = nn.Linear(1, 1)  # 输入维度是1,输出维度也是1def forward(self, x):return self.linear(x)
    

2.准备训练数据:

  • x_train 和 y_train 分别是输入和目标输出的训练数据。每个张量表示一个样本,x_train 中的每个元素是一个维度为1的张量,因为模型的输入维度是1。
    x_train = torch.tensor([[1.0], [2.0], [3.0]])
    y_train = torch.tensor([[2.0], [4.0], [6.0]])
    

3.实例化模型,损失函数和优化器:

  • model 是我们定义的 LinearRegression 类的一个实例,即我们要训练的线性回归模型。
  • criterion 是损失函数,这里选择了均方误差损失(MSE Loss),用于衡量预测值与实际值之间的差异。
  • optimizer 是优化器,这里选择了随机梯度下降(SGD),用于更新模型参数以最小化损失。
    model = LinearRegression()
    criterion = nn.MSELoss()  # 均方误差损失函数
    optimizer = optim.SGD(model.parameters(), lr=0.01)  # 随机梯度下降优化器
    

4.训练模型:

  • 这里进行了1000次迭代的训练过程。
  • 在每个迭代中,首先进行前向传播,计算模型对 x_train 的预测输出 outputs,然后计算损失 loss
  • 调用 optimizer.zero_grad() 来清空之前的梯度,然后调用 loss.backward() 自动计算梯度,最后调用 optimizer.step() 来更新模型参数
    epochs = 1000
    for epoch in range(epochs):# 前向传播outputs = model(x_train)loss = criterion(outputs, y_train)# 反向传播optimizer.zero_grad()  # 清空之前的梯度loss.backward()  # 自动计算梯度optimizer.step()  # 更新模型参数if (epoch+1) % 100 == 0:print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')
    

5.测试模型:

  • x_test 是用来测试模型的输入数据,这里表示输入为4.0。
  • model(x_test) 对 x_test 进行前向传播,得到预测结果 predicted
  • predicted.item() 取出预测结果的标量值并打印出来。
    x_test = torch.tensor([[4.0]])
    predicted = model(x_test)
    print(f'预测值: {predicted.item():.4f}')
    

运行结果:

运行结果如下:

 

相关文章:

采用自动微分进行模型的训练

自动微分训练模型 简单代码实现: import torch import torch.nn as nn import torch.optim as optim# 定义一个简单的线性回归模型 class LinearRegression(nn.Module):def __init__(self):super(LinearRegression, self).__init__()self.linear nn.Linear(1, 1) …...

k8s怎么配置secret呢?

在Kubernetes中,配置Secret主要涉及到创建、查看和使用Secret的过程。以下是配置Secret的详细步骤和相关信息: ### 1. Secret的概念 * Secret是Kubernetes用来保存密码、token、密钥等敏感数据的资源对象。 * 这些敏感数据可以存放在Pod或镜像中&#x…...

算法篇 滑动窗口 leetcode 长度最小的子数组

长度最小的子数组 1. 题目描述2. 算法图分析2.1 暴力图解2.2 滑动窗口图解 3. 代码演示 1. 题目描述 2. 算法图分析 2.1 暴力图解 2.2 滑动窗口图解 3. 代码演示...

数据库作业d8

要求: 一备份 1 mysqldump -u root -p booksDB > booksDB_all_tables.sql 2 mysqldump -u root -p booksDB books > booksDB_books_table.sql 3 mysqldump -u root -p --databases booksDB test > booksDB_and_test_databases.sql 4 mysql -u roo…...

前后端数据交互设计到的跨域问题

前后端分离项目的跨域问题及解决办法 一、跨域简述 1、问题描述 这里前端vue项目的端口号为9000,后端springboot项目的端口号为8080 2、什么是跨域 当一个请求url的协议、域名、端口三者之间任意一个与当前页面url不同即为跨域 当前页面url被请求页面url是否…...

非洲猪瘟监测设备的作用是什么?

TH-H160非洲猪瘟监测设备的主要作用是迅速、准确地检测出非洲猪瘟病毒&#xff0c;从而帮助控制和预防疫情的扩散。这些设备利用先进的生物传感技术和PCR分子生物学方法&#xff0c;能够在极短的时间内提供精确的检测结果<sup>1</sup><sup>2</sup><…...

移动硬盘损坏无法读取?专业恢复策略全解析

在数字化信息爆炸的今天&#xff0c;移动硬盘作为我们存储和传输大量数据的重要工具&#xff0c;其安全性和稳定性直接关系到个人与企业的数据安全。然而&#xff0c;当移动硬盘突然遭遇损坏&#xff0c;无法正常读取时&#xff0c;我们该如何应对&#xff1f;本文将深入探讨移…...

神经网络以及简单的神经网络模型实现

神经网络基本概念&#xff1a; 神经元&#xff08;Neuron&#xff09;&#xff1a; 神经网络的基本单元&#xff0c;接收输入&#xff0c;应用权重并通过激活函数生成输出。 层&#xff08;Layer&#xff09;&#xff1a; 神经网络由多层神经元组成。常见的层包括输入层、隐藏层…...

java中压缩文件的解析方式(解析文件)

背景了解&#xff1a;java中存在IO流的方式&#xff0c;支持我们对文件进行读取&#xff08;Input&#xff0c;从磁盘到内存&#xff09;或写入&#xff08;output&#xff0c;从内存到磁盘&#xff09;&#xff0c;那么我们在面对 “zip”格式或者 “rar” 格式的压缩文件&…...

巧用 VScode 网页版 IDE 搭建个人笔记知识库!

[ 知识是人生的灯塔&#xff0c;只有不断学习&#xff0c;才能照亮前行的道路 ] 巧用 VScode 网页版 IDE 搭建个人笔记知识库! 描述&#xff1a;最近自己在腾讯云轻量云服务器中部署了一个使用在线 VScode 搭建部署的个人Markdown在线笔记&#xff0c;考虑到在线 VScode 支持终…...

Jupyter Lab 使用

Jupyter Lab 使用详解 Jupyter Lab 是一个基于 Web 的交互式开发环境&#xff0c;提供了比 Jupyter Notebook 更加灵活和强大的用户界面和功能。以下是使用 Jupyter Lab 的详细指南&#xff0c;包括安装、基本使用、设置根目录和扩展功能等内容。 一、Jupyter Lab 安装与启动…...

MyBatis where标签内嵌foreach标签查询报错‘缺失右括号‘或‘命令未正确结束‘

MyBatis <where>标签内嵌<foreach>标签查询报错’缺失右括号’或’命令未正确结束’ <where>标签内嵌<foreach>标签 截取一段脱敏xml&#xff0c;写明大概意思 <select id"queryLogByIds" resultMap"BaseResultMap">SELE…...

重生奇迹MU 群战王牌

圣导师是重生奇迹MU游戏中八大职业之一&#xff0c;拥有风度翩翩、潇洒自如的形象和神一样的实力。无论是刷怪、PK、打boss还是混战&#xff0c;圣导师都表现出压制其他职业的强大气势。因此&#xff0c;这个职业在游戏中备受欢迎&#xff0c;人气非常高。 实力强大的二代隐藏…...

SpinalHDL之VHDL 和 Verilog 生成

本文作为SpinalHDL学习笔记第十六篇&#xff0c;记录使用SpinalHDL代码生成Verilog/VHDL代码的方法。 SpinalHDL学习笔记总纲链接如下&#xff1a; SpinalHDL 学习笔记_spinalhdl blackbox-CSDN博客 目录&#xff1a; 1.从 SpinalHDL 组件生成 VHDL 和 Verilog 2.生成的 VHD…...

c语言中的字符串函数

strstr函数 函数介绍 strstr 用于在一个字符串中查找另一个字符串的首次出现。 我们来看这个函数的参数名字&#xff1a;haysytack&#xff08;干草堆&#xff09;needle&#xff08;针&#xff09;,这个其实就是外国的一句谚语&#xff1a;在干草堆中找一根针&#xff0c;就…...

[AI 大模型] 百度 文心一言

文章目录 [AI 大模型] 百度 文心一言简介模型架构发展新技术和优势API 代码示例 [AI 大模型] 百度 文心一言 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-0DwAIh0T-1720667576892)(https://i-blog.csdnimg.cn/direct/283919e5d78b4951ba1ade5dcfc…...

机器学习开源分子生成系列(2)-基于三维形状和静电相似性的DeepFMPO v3D安装及使用

前言 本文是基于 3D 的分子生成方法DeepFMPO v3D的介绍及安装使用。 一、DeepFMPO v3D是什么&#xff1f; github代码介绍文章 在药物发现中&#xff0c;如何寻找具新颖性和结构多样性的候选分子是颇受药物设计科学家关注的问题。通过虚拟筛选的化学空间搜索往往会受限于筛选…...

机器学习-16-分布式梯度提升库XGBoost的应用

参考XGBoost库 1 XGBoost分布式梯度提升库 XGBoost,全称为eXtreme Gradient Boosting,是一个优化的分布式梯度提升库,旨在高效、灵活且便携。它在Gradient Boosting框架下实现了机器学习算法,并广泛用于分类、回归和排序任务。XGBoost之所以受到广泛欢迎,主要归功于它的…...

视觉/AIGC面经->多模态

1.ocr检测如何做?qwen的文本检测是否合理? paligemma: <loc0110><loc0124><loc0224><loc0389> plate ; <loc0244><loc0130><loc0281><loc0430> plate ; <loc0364><loc0820><loc0403><loc0951> pl…...

<数据集>钢板缺陷检测数据集<目标检测>

数据集格式&#xff1a;VOCYOLO格式 图片数量&#xff1a;1986张 标注数量(xml文件个数)&#xff1a;1986 标注数量(txt文件个数)&#xff1a;1986 标注类别数&#xff1a;7 标注类别名称&#xff1a;[crescent gap, silk spot, water spot, weld line, oil spot, punchin…...

棒板电极流注放电与氩气等离子体仿真的COMSOL研究

棒板电极流注放电&#xff0c; COMSOL&#xff0c;氩气形成的贯穿流注 氩气放电等离子体仿真。在高压实验室里见过那种细金属棒和接地板之间突然爆发的紫色放电吗&#xff1f;那玩意儿专业名称叫棒板电极流注放电。今天咱们用COMSOL扒开这朵"电火花"的外衣&#xff0…...

新手入门指南:在快马平台上用openclaw重启版本实现首个爬虫项目

最近在学习网络爬虫&#xff0c;发现openclaw重启版本对新手特别友好&#xff0c;于是尝试在InsCode(快马)平台上做了一个简单的新闻头条抓取项目。整个过程比想象中顺利&#xff0c;分享下我的学习路径和踩坑经验。 环境准备与库安装 传统爬虫项目最头疼的就是环境配置&#x…...

在Vue3中推荐使用的函数定义方法

const funcName (argName) > {}; 和 function funcName(argName) {} 这两种方式&#xff0c;哪种定义函数比较好一点呢&#xff1f;两种方式各有适用场景&#xff0c;简单总结&#xff1a; 箭头函数 const fn () > {} 没有自己的 this&#xff0c;继承外层作用域的 thi…...

基于Django REST framework的共享充电桩后台管理系统架构设计与实现

1. 为什么选择Django REST framework构建充电桩后台 第一次接触共享充电桩项目时&#xff0c;我对比了Node.js、Spring Boot和Django三个技术栈。最终选择Django REST framework&#xff08;DRF&#xff09;的原因很实在——它用30%的代码量就能实现其他框架80%的功能。特别是在…...

DA14531 实战指南(一)从调试到量产:OTP与Flash的权衡艺术

1. 初识DA14531的存储双刃剑 第一次拿到DA14531开发板时&#xff0c;最让我纠结的就是这个32KB的OTP存储器。就像给你一支只能写一次的钢笔&#xff0c;虽然墨水充足&#xff08;32KB对BLE应用绰绰有余&#xff09;&#xff0c;但每次落笔都要反复斟酌。实际开发中我发现&#…...

Java Swing 实战:手把手教你写一个拼图小游戏(一)

1.前言本文基于 Java Swing 实现带登录注册的拼图小游戏&#xff08;跟随 B 站黑马程序员教程练习&#xff09;&#xff0c;适合 Java 初学者、课设练手使用。本文为系列第一篇&#xff0c;主要讲解项目整体结构、登录界面&#xff08;LoginJFrame&#xff09;和注册界面&#…...

如何让Flash内容重获新生?CefFlashBrowser全方位应用指南

如何让Flash内容重获新生&#xff1f;CefFlashBrowser全方位应用指南 【免费下载链接】CefFlashBrowser Flash浏览器 / Flash Browser 项目地址: https://gitcode.com/gh_mirrors/ce/CefFlashBrowser 随着Adobe Flash Player的正式退役&#xff0c;大量依赖Flash技术的网…...

B站成分检测器深度解析:5大革新特性重塑评论区交互体验

B站成分检测器深度解析&#xff1a;5大革新特性重塑评论区交互体验 【免费下载链接】bilibili-comment-checker B站评论区自动标注成分油猴脚本&#xff0c;主要为原神玩家识别 项目地址: https://gitcode.com/gh_mirrors/bi/bilibili-comment-checker 在B站的海量评论互…...

强制脑机接口:某公司用神经监测防员工摸鱼

在科技伦理与管理方式交织的灰色地带&#xff0c;一则关于某公司计划引入脑机接口技术用于监测员工注意力、防止“摸鱼”的传闻&#xff0c;正在引发轩然大波。这并非科幻电影中的场景&#xff0c;而是随着神经技术快速商业化&#xff0c;正悄然逼近的现实可能。对于身处科技行…...

别再乱改组策略了!深入理解CredSSP更新与远程桌面安全的正确配置姿势

深入解析CredSSP安全机制与远程桌面连接的最佳实践 1. CredSSP协议与加密Oracle漏洞的本质 CredSSP&#xff08;Credential Security Support Provider&#xff09;协议是微软开发的一种身份验证协议&#xff0c;主要用于远程桌面连接等场景下的凭据安全传输。2018年曝光的CVE-…...