当前位置: 首页 > news >正文

PyTorch实现逻辑回归

最终效果

先看下最终效果:
1
这里用一条直线把二维平面上不同的点分开。

生成随机数据

#创建训练数据
x = torch.rand(10,1)*10 #shape(10,1)
y = 2*x + (5 + torch.randn(10,1))#构建线性回归参数
w = torch.randn((1))#随机初始化w,要用到自动梯度求导
b = torch.zeros((1))#使用0初始化b,要用到自动梯度求导n_data = torch.ones(100, 2)
xy0 = torch.normal(2 * n_data, 1.5)  # 生成均值为2.标准差为1.5的随机数组成的矩阵
c0 = torch.zeros(100)
xy1 = torch.normal(-2 * n_data, 1.5)  # 生成均值为-2.标准差为1.5的随机数组成的矩阵
c1 = torch.ones(100)x,y = torch.cat((xy0,xy1),0).type(torch.FloatTensor).split(1, dim=1)
x = x.squeeze()
y = y.squeeze()
c = torch.cat((c0,c1),0).type(torch.FloatTensor)

数据可视化

def plot(x, y, c):ax = plt.gca()sc = ax.scatter(x, y, color='black')paths = []for i in range(len(x)):if c[i].item() == 0:marker_obj = mmarkers.MarkerStyle('o')else:marker_obj = mmarkers.MarkerStyle('x')path = marker_obj.get_path().transformed(marker_obj.get_transform())paths.append(path)sc.set_paths(paths)return sc
plot(x, y, c)
plt.show()

使用x和o来表示两种不同类别的数据。
1

定义模型和损失函数

#构建逻辑回归参数
w = torch.tensor([1.,],requires_grad=True)  # 随机初始化w
b = torch.zeros((1),requires_grad=True)  # 使用0初始化bwx = torch.mul(w,x) # w*x
y_pred = torch.add(wx,b) # y = w*x + b
loss = (0.5*(y-y_pred)**2).mean()

这里使用了平方损失函数来估算模型准确度。

训练模型

最多训练100次,每次都会更新模型参数,当损失值小于0.03时停止训练。

xx = torch.arange(-4, 5)
lr = 0.02 #学习率
for iteration in range(100):#前向传播loss = ((torch.sigmoid(x*w+b-y) - c)**2).mean()#反向传播loss.backward()#更新参数b.data.sub_(lr*b.grad) # b = b - lr*b.gradw.data.sub_(lr*w.grad) # w = w - lr*w.grad#绘图if iteration % 3 == 0:plot(x, y, c)yy = w*xx + bplt.plot(xx.data.numpy(),yy.data.numpy(),'r-',lw=5)plt.text(-4,2,'Loss=%.4f'%loss.data.numpy(),fontdict={'size':20,'color':'black'})plt.xlim(-4,4)plt.ylim(-4,4)plt.title("Iteration:{}\nw:{},b:{}".format(iteration,w.data.numpy(),b.data.numpy()))plt.show()if loss.data.numpy() < 0.03:  # 停止条件break

全部代码

import torch
import matplotlib.pyplot as plt
import matplotlib.markers as mmarkers#创建训练数据
x = torch.rand(10,1)*10 #shape(10,1)
y = 2*x + (5 + torch.randn(10,1))#构建线性回归参数
w = torch.randn((1))#随机初始化w,要用到自动梯度求导
b = torch.zeros((1))#使用0初始化b,要用到自动梯度求导wx = torch.mul(w,x) # w*x
y_pred = torch.add(wx,b) # y = w*x + bn_data = torch.ones(100, 2)
xy0 = torch.normal(2 * n_data, 1.5)  # 生成均值为2.标准差为1.5的随机数组成的矩阵
c0 = torch.zeros(100)
xy1 = torch.normal(-2 * n_data, 1.5)  # 生成均值为-2.标准差为1.5的随机数组成的矩阵
c1 = torch.ones(100)x,y = torch.cat((xy0,xy1),0).type(torch.FloatTensor).split(1, dim=1)
x = x.squeeze()
y = y.squeeze()
c = torch.cat((c0,c1),0).type(torch.FloatTensor)def plot(x, y, c):ax = plt.gca()sc = ax.scatter(x, y, color='black')paths = []for i in range(len(x)):if c[i].item() == 0:marker_obj = mmarkers.MarkerStyle('o')else:marker_obj = mmarkers.MarkerStyle('x')path = marker_obj.get_path().transformed(marker_obj.get_transform())paths.append(path)sc.set_paths(paths)return sc
plot(x, y, c)
plt.show()#构建逻辑回归参数
w = torch.tensor([1.,],requires_grad=True)#随机初始化w
b = torch.zeros((1),requires_grad=True)#使用0初始化bwx = torch.mul(w,x) # w*x
y_pred = torch.add(wx,b) # y = w*x + b
loss = (0.5*(y-y_pred)**2).mean()xx = torch.arange(-4, 5)
lr = 0.02 #学习率
for iteration in range(100):#前向传播loss = ((torch.sigmoid(x*w+b-y) - c)**2).mean()#反向传播loss.backward()#更新参数b.data.sub_(lr*b.grad) # b = b - lr*b.gradw.data.sub_(lr*w.grad) # w = w - lr*w.grad#绘图if iteration % 3 == 0:plot(x, y, c)yy = w*xx + bplt.plot(xx.data.numpy(),yy.data.numpy(),'r-',lw=5)plt.text(-4,2,'Loss=%.4f'%loss.data.numpy(),fontdict={'size':20,'color':'black'})plt.xlim(-4,4)plt.ylim(-4,4)plt.title("Iteration:{}\nw:{},b:{}".format(iteration,w.data.numpy(),b.data.numpy()))plt.show()if loss.data.numpy() < 0.03:#停止条件break

相关文章:

PyTorch实现逻辑回归

最终效果 先看下最终效果&#xff1a; 这里用一条直线把二维平面上不同的点分开。 生成随机数据 #创建训练数据 x torch.rand(10,1)*10 #shape(10,1) y 2*x (5 torch.randn(10,1))#构建线性回归参数 w torch.randn((1))#随机初始化w&#xff0c;要用到自动梯度求导 b …...

什么是FPGA原型验证?

EDA工具的使用主要分为设计、验证和制造三大类。验证工作贯穿整个芯片设计流程&#xff0c;可以说芯片的验证阶段占据了整个芯片开发的大部分时间。从芯片需求定义、功能设计开发到物理实现制造&#xff0c;每个环节都需要进行大量的验证。 现如今验证方法也越来越多&#xff…...

基于VUE3+Layui从头搭建通用后台管理系统(前端篇)十四:系统设置模块相关功能实现

一、本章内容 本章使用已实现的公共组件实现系统管理中的系统设置模块相关功能,包括菜单管理、角色管理、日志管理、用户管理、系统配置、数据字典等。 1. 详细课程地址: 待发布 2. 源码下载地址: 待发布 二、界面预览 三、开发视频 3.1 B站视频地址:...

使用Visual Studio(VS)创建空项目的Win32桌面应用程序【main函数入口变WinMain】

前言 在Visual Studio中直接新建Windows桌面应用程序会有很多多余的代码生成&#xff0c;本文将提供从空项目创建Win32项目的方法&#xff0c;解决新建空项目直接使用WinMain代码编译报错的问题 例如&#xff1a;LNK2019 &#xff1a;无法解析的外部符号 参考博客&#xff1…...

基于自动化脚本批量上传依赖到nexus内网私服

前言 因为某些原因某些企业希望私服是不能连接外网的&#xff0c;所以需要某些开源依赖需要我们手动导入到nexus中&#xff0c;尽管nexus为我们提供了web页面。但是一个个手动导入显然是一个庞大的工程。 对此我们就不妨基于脚本的方式实现这一过程。 预期效果 笔者本地仓库…...

Linux中ps命令使用指南

目录 1 前言2 ps命令的含义和作用3 ps命令的基本使用4 常用选项参数5 一些常用情景5.1 查看系统中的所有进程&#xff08;标准语法&#xff09;5.2 使用 BSD 语法查看系统中的所有进程5.3 打印进程树5.4 获取线程信息5.5 获取安全信息5.6 查看以 root 用户身份&#xff08;实际…...

PHP开发语言中,网页端常用的标签

在PHP开发语言中&#xff0c;网页端常用的标签包括以下几种&#xff1a; <html>&#xff1a;用于定义整个HTML文档。<head>&#xff1a;用于定义文档的头部&#xff0c;包含元数据、样式表和脚本等。<title>&#xff1a;用于定义文档的标题&#xff0c;显示…...

Java 入门第四篇 集合

Java 入门第四篇 集合 一&#xff0c;什么是集合 在Java中&#xff0c;集合&#xff08;Collection&#xff09;是一种用于存储和操作一组对象的容器类。它提供了一系列的方法和功能&#xff0c;用于方便地管理和操作对象的集合。集合框架是Java中非常重要和常用的一部分&…...

VBA技术资料MF93:将多个Excel表插入PowerPoint不同位置

我给VBA的定义&#xff1a;VBA是个人小型自动化处理的有效工具。利用好了&#xff0c;可以大大提高自己的工作效率&#xff0c;而且可以提高数据的准确度。我的教程一共九套&#xff0c;分为初级、中级、高级三大部分。是对VBA的系统讲解&#xff0c;从简单的入门&#xff0c;到…...

STM32 MCU的易坑点收集

IIC配置中的Clock No Stretch Mode Clock Stretch Mode时钟延长模式&#xff1a; 时钟延长是一个术语&#xff0c;某些从设备可以把时钟线拉低&#xff0c;主设备发现自己释放时钟线之后时钟线还没有变成高电平&#xff0c;就会停止发送数据&#xff0c;然后等待从设备释放时钟…...

Vue3项目filter.js组件封装

1、element-plus(el-table)修改table的行样式 export function elTableRowClassName({ row, rowIndex }) {if (rowIndex % 2 ! 0) {return default-row} }2、时间戳转换格式 export function parseTimeFilter(dateTime, dateType) {if (dateTime || dateTime undefined ||…...

Linux: pwd命令查看当前工作目录

pwd 是 Linux 和其他类 Unix 操作系统中的一个命令&#xff0c;用于显示当前工作目录的绝对路径。 语法 pwd 描述 pwd 是 "print working directory" 的缩写&#xff0c;它用于打印当前工作目录的完整路径。这对于确定当前目录位置非常有用&#xff0c;特别是在嵌…...

【深度学习】PHP操作mysql数据库总结

一.PHP数据库的扩展分类 1.MySQL 扩展是针对 MySQL 4.1.3 或更早版本设计的&#xff0c;是 PHP 与 MySQL数据库交互的早期扩展。由于其不支持 MySQL 数据库服务器的新特性&#xff0c;且安全性差&#xff0c;在项目开发中不建议使用&#xff0c;可用 MySQLi 扩展代替。 2.MySQ…...

【送书活动】探究AIGC、AGI、GPT和人工智能大模型

文章目录 前言01 《ChatGPT 驱动软件开发》推荐语 02 《ChatGPT原理与实战》推荐语 03 《神经网络与深度学习》推荐语 04 《AIGC重塑教育》推荐语 05 《通用人工智能》推荐语 后记赠书活动 前言 人工智能技术在过去几年中发展迅猛&#xff0c;得益于大数据、云计算、深度学习等…...

Apple Find My「查找」认证芯片找哪家,认准伦茨科技ST17H6x芯片

深圳市伦茨科技有限公司&#xff08;以下简称“伦茨科技”&#xff09;发布ST17H6x Soc平台。成为继Nordic之后全球第二家取得Apple Find My「查找」认证的芯片厂家&#xff0c;该平台提供可通过Apple Find My认证的Apple查找&#xff08;Find My&#xff09;功能集成解决方案。…...

java.lang.IllegalArgumentException: Could not resolve placeholder XXX‘ in value

问题描述 使用Springcloudalibaba的nacos作为配置中心&#xff0c;服务启动时报错&#xff1a; java.lang.IllegalArgumentException: Could not resolve placeholder XXX‘ in value java.lang.IllegalArgumentException: Param ‘serviceName’ is illegal, serviceName is …...

自动机器学习是什么?概念及应用

自动机器学习 (Auto Machine Learning) 的应用和方法 随着众多企业在大量场景中开始采用机器学习&#xff0c;前后期处理和优化的数据量及规模指数级增长。企业很难雇用充足的人手来完成与高级机器学习模型相关的所有工作&#xff0c;因此机器学习自动化工具是未来人工智能 (A…...

el-date-picker限制选择7天内禁止内框选择

需求&#xff1a;elementPlus时间段选择框需要满足&#xff1a;①最多选7天时间。②不能手动输入。 <el-date-picker v-model"timeArrange" focus"timeEditable" :editable"false" type"datetimerange" range-separator"至&qu…...

Navicat 技术指引 | 适用于 GaussDB 分布式的调试器

Navicat Premium&#xff08;16.3.3 Windows 版或以上&#xff09;正式支持 GaussDB 分布式数据库。GaussDB 分布式模式更适合对系统可用性和数据处理能力要求较高的场景。Navicat 工具不仅提供可视化数据查看和编辑功能&#xff0c;还提供强大的高阶功能&#xff08;如模型、结…...

人工智能导论习题集(3)

第五章&#xff1a;不确定性推理 题1题2题3题4题5题6题7题8 题1 题2 题3 题4 题5 题6 题7 题8...

springboot 百货中心供应链管理系统小程序

一、前言 随着我国经济迅速发展&#xff0c;人们对手机的需求越来越大&#xff0c;各种手机软件也都在被广泛应用&#xff0c;但是对于手机进行数据信息管理&#xff0c;对于手机的各种软件也是备受用户的喜爱&#xff0c;百货中心供应链管理系统被用户普遍使用&#xff0c;为方…...

RocketMQ延迟消息机制

两种延迟消息 RocketMQ中提供了两种延迟消息机制 指定固定的延迟级别 通过在Message中设定一个MessageDelayLevel参数&#xff0c;对应18个预设的延迟级别指定时间点的延迟级别 通过在Message中设定一个DeliverTimeMS指定一个Long类型表示的具体时间点。到了时间点后&#xf…...

Docker 运行 Kafka 带 SASL 认证教程

Docker 运行 Kafka 带 SASL 认证教程 Docker 运行 Kafka 带 SASL 认证教程一、说明二、环境准备三、编写 Docker Compose 和 jaas文件docker-compose.yml代码说明&#xff1a;server_jaas.conf 四、启动服务五、验证服务六、连接kafka服务七、总结 Docker 运行 Kafka 带 SASL 认…...

【位运算】消失的两个数字(hard)

消失的两个数字&#xff08;hard&#xff09; 题⽬描述&#xff1a;解法&#xff08;位运算&#xff09;&#xff1a;Java 算法代码&#xff1a;更简便代码 题⽬链接&#xff1a;⾯试题 17.19. 消失的两个数字 题⽬描述&#xff1a; 给定⼀个数组&#xff0c;包含从 1 到 N 所有…...

React Native在HarmonyOS 5.0阅读类应用开发中的实践

一、技术选型背景 随着HarmonyOS 5.0对Web兼容层的增强&#xff0c;React Native作为跨平台框架可通过重新编译ArkTS组件实现85%以上的代码复用率。阅读类应用具有UI复杂度低、数据流清晰的特点。 二、核心实现方案 1. 环境配置 &#xff08;1&#xff09;使用React Native…...

论文解读:交大港大上海AI Lab开源论文 | 宇树机器人多姿态起立控制强化学习框架(一)

宇树机器人多姿态起立控制强化学习框架论文解析 论文解读&#xff1a;交大&港大&上海AI Lab开源论文 | 宇树机器人多姿态起立控制强化学习框架&#xff08;一&#xff09; 论文解读&#xff1a;交大&港大&上海AI Lab开源论文 | 宇树机器人多姿态起立控制强化…...

【论文阅读28】-CNN-BiLSTM-Attention-(2024)

本文把滑坡位移序列拆开、筛优质因子&#xff0c;再用 CNN-BiLSTM-Attention 来动态预测每个子序列&#xff0c;最后重构出总位移&#xff0c;预测效果超越传统模型。 文章目录 1 引言2 方法2.1 位移时间序列加性模型2.2 变分模态分解 (VMD) 具体步骤2.3.1 样本熵&#xff08;S…...

Spring Cloud Gateway 中自定义验证码接口返回 404 的排查与解决

Spring Cloud Gateway 中自定义验证码接口返回 404 的排查与解决 问题背景 在一个基于 Spring Cloud Gateway WebFlux 构建的微服务项目中&#xff0c;新增了一个本地验证码接口 /code&#xff0c;使用函数式路由&#xff08;RouterFunction&#xff09;和 Hutool 的 Circle…...

RabbitMQ入门4.1.0版本(基于java、SpringBoot操作)

RabbitMQ 一、RabbitMQ概述 RabbitMQ RabbitMQ最初由LShift和CohesiveFT于2007年开发&#xff0c;后来由Pivotal Software Inc.&#xff08;现为VMware子公司&#xff09;接管。RabbitMQ 是一个开源的消息代理和队列服务器&#xff0c;用 Erlang 语言编写。广泛应用于各种分布…...

LabVIEW双光子成像系统技术

双光子成像技术的核心特性 双光子成像通过双低能量光子协同激发机制&#xff0c;展现出显著的技术优势&#xff1a; 深层组织穿透能力&#xff1a;适用于活体组织深度成像 高分辨率观测性能&#xff1a;满足微观结构的精细研究需求 低光毒性特点&#xff1a;减少对样本的损伤…...