当前位置: 首页 > news >正文

【深度学习】2.单层感知机

目标:

实现一个简单的二分类模型的训练过程,通过模拟数据集进行训练和优化,训练目标是使模型能够根据输入特征正确分类数据。

演示:

1.通过PyTorch生成了一个模拟的二分类数据集,包括特征矩阵data_x和对应的标签数据data_y。标签数据通过基于特征的线性组合生成,并转换成独热编码的形式。

import torch
# 从torch库中导入神经网络模块nn,用于构建神经网络模型
from torch import nn
# 导入torch.nn模块中的functional子模块,可用于访问各种函数,例如激活函数
import torch.nn.functional as Fn_item = 1000
n_feature = 2
learning_rate = 0.01
epochs = 100# 生成一个模拟的数据集,其中包括一个随机生成的特征矩阵data_x和相应生成的标签数据data_y。标签数据通过基于特征的线性组合生成,并且转换成独热编码的形式。# 设置随机数生成器的种子为123,通过设置随机种子,我们可以确保在每次运行代码时生成的随机数相同,这对于结果的可重现性非常重要。
torch.manual_seed(123)
# 生成一个随机数矩阵data_x,其中包含n_item行和n_feature列。矩阵中的元素是从标准正态分布(均值为0,标准差为1)中随机采样的。
data_x = torch.randn(size=(n_item, n_feature)).float()
# torch.where(...): 根据条件返回两个张量中相应位置的值。如果条件成立,将为0,否则为1。  long(): 用于将张量转换为Long型数据类型。
data_y = torch.where(torch.subtract(data_x[:, 0]*0.5, data_x[:, 1]*1.5)+0.02 > 0, 0, 1).long()
# 将标签数据data_y转换为独热编码形式,即将每个标签转换为一个相应长度的独热向量
data_y = F.one_hot(data_y)# print(data_x)
# print(data_y)

2.定义了一个简单的二分类模型BinaryClassificationModel,包含一个单层感知器(Single Perceptron)结构,其中使用了一个线性层和sigmoid激活函数,用于将输入特征映射到概率空间。

# 定义了一个简单的二分类模型,采用单层感知器的结构,包含一个线性层和sigmoid激活函数,用于将输入特征映射到概率空间。这样的模型可以用来对数据集进行二分类任务的预测。# 定义了一个名为BinaryClassificationModel的类,其继承自nn.Module类,这意味着这个类是一个PyTorch模型。
class BinaryClassificationModel(nn.Module):def __init__(self, in_feature):# 调用了父类nn.Module的构造函数,确保正确初始化模型。super(BinaryClassificationModel, self).__init__()"""single perception"""# 这行代码定义了模型的第一层,是一个线性层(Fully Connected Layer)。in_features参数指定输入特征的数量,out_features指定输出特征的数量,这里设置为2表示二分类问题。bias=True表示该层包含偏置项。self.layer_1 = nn.Linear(in_features=in_feature, out_features=2, bias=True)# 定义模型前向传播的方法,即输入数据x通过模型前向计算得到输出。def forward(self, x):# 输入数据x首先通过定义的线性层self.layer_1进行线性变换,然后通过F.sigmoid()函数进行激活函数处理。return F.sigmoid(self.layer_1(x))

3.创建了该二分类模型的实例model、使用随机梯度下降(SGD)优化器opt、以及二分类问题常用的损失函数BCELoss(Binary Cross Entropy Loss)。

4.在训练过程中,通过多个epoch和每个样本的批处理(在这里是一次处理一个样本),计算模型预测输出和真实标签之间的损失值,进行反向传播计算梯度,并更新模型参数以最小化损失函数。

# 完成对模型的训练过程,每个epoch中通过优化器进行参数更新,计算损失,反向传播更新梯度。最终我们会得到训练过程中每个epoch的损失值,并可以观察损失的变化情况。# 创建了一个二分类模型实例model,参数n_feature表示输入特征的数量。
model = BinaryClassificationModel(n_feature)
# 创建了一个随机梯度下降(SGD)优化器opt,用于根据计算出的梯度更新模型参数。
opt = torch.optim.SGD(model.parameters(), lr=learning_rate)
# 创建了一个二分类问题常用的损失函数BCELoss(Binary Cross Entropy Loss),用于衡量模型输出与真实标签之间的差异。
criteria = nn.BCELoss()for epoch in range(epochs):# 对每个样本进行训练。for step in range(n_item):x = data_x[step]y = data_y[step]# 梯度清零,避免梯度累加影响优化结果。opt.zero_grad()# 将输入特征x通过模型前向传播得到预测输出y_hat。unsqueeze(0)是因为我们的模型期望输入是(batch_size, n_feature)的形式。y_hat = model(x.unsqueeze(0))# 计算预测输出y_hat和真实标签y之间的损失值。loss = criteria(y_hat, y.unsqueeze(0).float())# 反向传播计算梯度。loss.backward()# 根据计算出的梯度更新模型参数。opt.step()print("Epoch: %03d, Loss: %.3f" % (epoch, loss.item()))

5.打印出每个epoch的序号和损失值,用于监控训练过程中损失值的变化情况。

相关文章:

【深度学习】2.单层感知机

目标: 实现一个简单的二分类模型的训练过程,通过模拟数据集进行训练和优化,训练目标是使模型能够根据输入特征正确分类数据。 演示: 1.通过PyTorch生成了一个模拟的二分类数据集,包括特征矩阵data_x和对应的标签数据data_y。标签…...

JS经常碰见的报错问题

语法错误:由于 JavaScript 是一种动态语言,因此编写代码期间可能会出现语法错误。这可能包括拼写错误、漏掉分号或括号等问题。 作用域问题:JavaScript 中存在全局作用域和局部作用域的概念,有时候可能会出现变量作用域混乱导致的…...

纯前端实现截图功能

纯前端实现截图功能 一、插件二、主要代码 一、插件 一、安装html2canvas、vue-cropper npm i html2canvas --save //用于将指定区域转为图片 npm i vue-cropper -S //将图片进行裁剪二、在main.js注册vue-cropper组件 import VueCropper from vue-cropper Vue.use(VueCropper…...

【网络协议】应用层协议--HTTP

文章目录 一、HTTP是什么?二、HTTP协议工作过程三、HTTP协议1. fiddler2. Fiddler抓包的原理3. 代理服务器是什么?4. HTTP协议格式1.1 请求1.2 响应 四、认识HTTP的请求1.认识HTTP请求的方法2.认识请求头(header)3.认识URL3.1 URL是什么&…...

【图书推荐】《Vue.js 3.x+Element Plus从入门到精通(视频教学版)》

配套示例源码与PPT课件下载 百度网盘链接: https://pan.baidu.com/s/1nBQLd9UugetofFKE57BE5g?pwdqm9f 自学能力强的,估计不要书就能看代码学会吧。 内容简介 本书通过对Vue.js(简称Vue)的示例和综合案例的介绍与演练,使读者…...

抖店如何打造出爆品?学好这几招,轻松打爆新品流量

大家好,我是电商花花。 近年来,抖店商家越来越多,而选品,爆品就是我们商家竞争的核心了,谁能选出好的新品,打造出爆品,谁的会赚的多,销量多。 做抖音小店想出单,想赚钱…...

软件需求规范说明模板

每个软件开发组织都会为自己的项目选用一个或多个标准的软件需求规范说明模板。有许多软件需求规范说明模板可以使用(例如ISO/IEC/IEEE2011;Robertson and Robertson2013)。如果你的组织要处理各种类型或规模的项目,例如新的大型系统开发或是对现有系统进行微调&…...

vs2013使用qt Linguist以及tr不生效问题

一、qt Linguist(语言家)步骤流程 1、创建翻译文件,在qt选项中 2.选择对应所需的语言,得到.ts后缀的翻译文件 3.创建.pro文件,并将.ts配置在.pro文件中 3.使用qt Linguist 打开创建好的以.ts为后缀的翻译文件,按图所示…...

Leetcode 3163. String Compression III

Leetcode 3163. String Compression III 1. 解题思路2. 代码实现 题目链接:3163. String Compression III 1. 解题思路 这一题的话就是一个简单的贪婪算法,把相同的字符进行归并,然后按照题目中的表示方法进行表示一下即可。 2. 代码实现…...

Java匿名内部类的使用

演示匿名内部类的使用,很重要 package com.shedu.Inner;/*** 演示匿名内部类的使用*/ public class AnonymousInnerClass {//外部其他类public static void main(String[] args) {Outer04 outer04 new Outer04();outer04.method();} }class Outer04{//外部类priva…...

把自己的垃圾代码发布到官方中央仓库

参考博客:将组件发布到maven中央仓库-CSDN博客 感谢这位博主。但是他的步骤有漏缺,相对进行补充 访问管理页面 网址:Maven Central 新注册账号,或者使用github快捷登录,建议使用github快捷登录 添加命名空间 注意&…...

单机一天轻松300+ 最新微信小程序拼多多+京东全自动掘金项目、

现代互联网经济的发展带来了新型的盈利方式,这种方法通过微信小程序的拼多多和京东进行商品自动巡视,以此给商家带来增加的流量,同时为使用者带来利润。实践这一手段无需复杂操作,用户仅需启动相应程序,商品信息便会被…...

线性回归模型之套索回归

概述 本案例是基于之前的岭回归的案例的。之前案例的完整代码如下: import numpy as np import matplotlib.pyplot as plt from sklearn.linear_model import Ridge, LinearRegression from sklearn.datasets import make_regression from sklearn.model_selectio…...

解决文件夹打开出错问题:原因、数据恢复与预防措施

在我们日常使用电脑或移动设备时,有时会遇到一个非常棘手的问题——文件夹打开出错。这种错误可能会让您无法访问重要的文件和数据,给工作和生活带来极大的不便。本文将带您深入了解文件夹打开出错的原因,并提供有效的数据恢复方案&#xff0…...

Spring:面向切面(AOP)

1. 代理模式 二十三种设计模式中的一种,属于结构型模式。它的作用就是通过提供一个代理类,让我们在调用目标方法的时候,不再是直接对目标方法进行调用,而是通过代理类**间接**调用。让不属于目标方法核心逻辑的代码从目标方法中剥…...

本地镜像文件怎么导入docker desktop

docker tag d1134b7b2d5a new_repo:new_tag...

【机器学习-23】关联规则(Apriori)算法:介绍、应用与实现

在现代数据分析中,经常需要从大规模数据集中挖掘有用的信息。关联规则挖掘是一种强大的技术,可以揭示数据中的隐藏关系和规律。本文将介绍如何使用Python进行关联规则挖掘,以帮助您发现数据中的有趣模式。 一、引言 1. 简要介绍关联规则学习…...

Gradle筑基——Gradle Maven仓库管理

基础概念: 1.POM pom:全名Project Object Model 项目对象模型,用来描述当前maven项目发布模块的基础信息 pom主要节点信息如下: 配置描述举例(com.android.tools.build:gradle:4.1.1)groupId组织 / 公司的名称com.…...

c++11:智能指针的种类以及使用场景

指针管理困境 内存释放,指针没有置空;内存泄漏;资源重复释放 怎样解决? RAII 智能指针种类 shared_ptr 实现原理:多个指针指向同一资源,引用计数清零,再调用析构函数释放内存。 使用场景…...

RabbitMQ-默认读、写方式介绍

1、RabbitMQ简介 rabbitmq是一个开源的消息中间件,主要有以下用途,分别是: 应用解耦:通过使用RabbitMQ,不同的应用程序之间可以通过消息进行通信,从而降低应用程序之间的直接依赖性,提高系统的…...

低空经济新蓝海:海事监测无人机技术全解析与应用展望

低空经济新蓝海:海事监测无人机技术全解析与应用展望 引言 大家好!随着“低空经济”被正式列为国家战略性新兴产业,无人机技术的应用版图正以前所未有的速度从我们熟悉的陆地,向更为广阔的海洋延伸。在这片“新蓝海”中&#xff0…...

渗透测试之信息收集指南

目录 信息收集基础 一、域名信息收集 1. WHOIS查询 2. 备案查询 3. 子域名查询 3.1 搜索引擎查询语法 3.2 CT证书查询 3.3 JS文件查询 3.4 网络空间安全搜索引擎 3.5 Python脚本工具 4. 网站信息收集 4.1 网站目录扫描工具 4.4 网站系统等信息收集 二、IP信息收集 1. 域名查询I…...

k3s-ansible高级定制:私有镜像仓库和自定义CNI配置

k3s-ansible高级定制:私有镜像仓库和自定义CNI配置 【免费下载链接】k3s-ansible 项目地址: https://gitcode.com/gh_mirrors/k3s/k3s-ansible K3s-ansible是一个使用Ansible自动化部署轻量级Kubernetes集群k3s的强大工具。本指南将详细介绍如何通过k3s-ans…...

终极Node.js最佳实践指南:2024年102个开发技巧大揭秘

终极Node.js最佳实践指南:2024年102个开发技巧大揭秘 【免费下载链接】nodebestpractices :white_check_mark: The Node.js best practices list (July 2024) 项目地址: https://gitcode.com/GitHub_Trending/no/nodebestpractices Node.js开发者在构建企业级…...

辩题直击:AI是正向生产力?OpenClaw裁员给出答案

近期科技圈的辩论愈演愈烈:AI到底是推动时代的正向生产力,还是引发失业危机的“负作用制造者”?一边是甲骨文凌晨裁员3万人,直言“AI可替代人力”,郑州某软件公司部署OpenClaw后裁撤一半员工,HR哭诉“被一行…...

Python 环境构建艺术:虚拟环境、包管理与开发工具链

# 002、环境构建艺术:虚拟环境、包管理与开发工具链上周帮同事调试一个老项目,问题出得挺典型:本地跑得好好的脚本,放到服务器上就报依赖冲突。日志里赫然一行“numpy版本不匹配导致内存布局错误”,两个人对着屏幕查了…...

芒果文件编码转换工具 非常好用的代码转ANSI转UTF8格式小工具

群里的大佬 写的小工具 试了下很好用 下载链接...

OpenClaw技能市场指南:为千问3.5-9B寻找合适的功能扩展

OpenClaw技能市场指南:为千问3.5-9B寻找合适的功能扩展 1. 为什么需要技能市场 当我第一次在本地部署完OpenClaw并成功接入千问3.5-9B模型时,发现这个组合虽然能完成基础的对话和简单任务,但面对实际工作场景中的复杂需求时总显得力不从心。…...

OpenClaw替代方案:当Kimi-VL-A3B-Thinking不可用时的应急处理

OpenClaw替代方案:当Kimi-VL-A3B-Thinking不可用时的应急处理 1. 为什么需要制定模型故障应对策略 上周五凌晨3点,我被一阵急促的报警声惊醒。手机屏幕上闪烁着OpenClaw的异常通知——我部署的Kimi-VL-A3B-Thinking模型服务突然不可用。这个模型负责处…...

Android Jetpack Compose - 修饰符顺序的影响、Divider(分隔线)、DropdownMenu(下拉菜单)、NavigationBar(导航栏)

一、修饰符顺序的影响 红色背景区域:200 - 50 * 2 100 * 100 dp,点击区域:200 - 50 * 2 100 * 100 dp val context LocalContext.currentBox(Modifier.size(200.dp).padding(50.dp).background(Color.Red).clickable {Toast.makeText(cont…...