当前位置: 首页 > news >正文

pytorch神经网络入门代码

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms# 定义神经网络结构
class SimpleNN(nn.Module):def __init__(self, input_size, hidden_size, num_classes):super(SimpleNN, self).__init__()self.fc1 = nn.Linear(input_size, hidden_size)self.relu = nn.ReLU()self.fc2 = nn.Linear(hidden_size, num_classes)def forward(self, x):out = self.fc1(x)out = self.relu(out)out = self.fc2(out)return out# 设置超参数
input_size = 784  # MNIST数据集的输入大小是28x28=784
hidden_size = 784
num_classes = 10learning_rate = 0.01
num_epochs = 10# 加载MNIST数据集
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor())# 数据加载器
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=100, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=100, shuffle=False)# 实例化模型
model = SimpleNN(input_size, hidden_size, num_classes)# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)# 训练模型
total_step = len(train_loader)
for epoch in range(num_epochs):for i, (images, labels) in enumerate(train_loader):# 将输入数据转换为一维向量images = images.reshape(-1, 28*28)# 前向传播outputs = model(images)loss = criterion(outputs, labels)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()if (i+1) % 100 == 0:print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, i+1, total_step, loss.item()))# 测试模型
with torch.no_grad():correct = 0total = 0for images, labels in test_loader:images = images.reshape(-1, 28*28)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Accuracy of the network on the 10000 test images: {} %'.format(100 * correct / total))# 获取模型参数
params = model.parameters()# 打印每个参数的名称和值
for name, param in model.named_parameters():print(f'Parameter name: {name}')print(f'Parameter value: {param}')

以下代码测试正确率为:99.37%

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms# 定义适合MNIST数据集的CNN模型
class MNISTCNN(nn.Module):def __init__(self):super(MNISTCNN, self).__init__()# 卷积块 1self.conv_block1 = nn.Sequential(nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2))# 卷积块 2self.conv_block2 = nn.Sequential(nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2))# 全连接层self.fc_layer = nn.Sequential(nn.Linear(64 * 7 * 7, 512),  # 假设经过前面的卷积和池化后特征图大小为7x7nn.ReLU(),nn.Dropout(p=0.5),nn.Linear(512, 10)  # MNIST有10个类别)def forward(self, x):x = self.conv_block1(x)x = self.conv_block2(x)# 将卷积层输出展平为一维向量x = x.view(x.size(0), -1)# 通过全连接层x = self.fc_layer(x)return x# 创建模型实例
model = MNISTCNN()# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 加载MNIST数据集并预处理
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)# 使用DataLoader加载批量数据
batch_size = 64
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)# 开始训练
num_epochs = 10
for epoch in range(num_epochs):for inputs, labels in train_loader:# 前向传播outputs = model(inputs)loss = criterion(outputs, labels)# 反向传播和优化optimizer.zero_grad()  # 清空梯度缓存loss.backward()  # 计算梯度optimizer.step()  # 更新参数# 每个epoch结束时打印损失print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')# 测试模型
model.eval()  # 将模型切换到评估模式(禁用Dropout和BatchNorm等)
with torch.no_grad():correct = 0total = 0for images, labels in test_loader:outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f'Test Accuracy: {100 * correct / total}%')

相关文章:

pytorch神经网络入门代码

import torch import torch.nn as nn import torch.optim as optim import torchvision import torchvision.transforms as transforms# 定义神经网络结构 class SimpleNN(nn.Module):def __init__(self, input_size, hidden_size, num_classes):super(SimpleNN, self).__init_…...

代码随想录算法训练营第三十四天|860.柠檬水找零 406.根据身高重建队列 452. 用最少数量的箭引爆气球

860.柠檬水找零 链接:力扣(LeetCode)官网 - 全球极客挚爱的技术成长平台 细节: 1. 首先根据题意就是只有5.的成本,然后就开始找钱,找钱也是10.和5. 2. 直接根据10 和 5 进行变量定义,然后去循环…...

Ditto:提升剪贴板体验的宝藏软件(复制粘贴效率翻倍、文本处理好助手)

名人说:莫道桑榆晚,为霞尚满天。——刘禹锡(刘梦得,诗豪) 创作者:Code_流苏(CSDN)(一个喜欢古诗词和编程的Coder😊) 目录 一、什么是Ditto?二、下载安装三、如…...

【自然语言处理-工具篇】spaCy<2>--模型的使用

前言 之前已经介绍了spaCy的安装,接下来我们要通过下载和加载模型去开始使用spaCy。 下载模型 经过训练的 spaCy 管道可以作为 Python 包安装。这意味着它们是应用程序的一个组件,就像任何其他模块一样。可以使用 spaCy download的命令安装模型,也可以通过将 pip 指向路径或…...

Java之通过Jsch库连接Linux实现文件传输

Java之通过JSch库连接Linux实现文件传输 文章目录 Java之通过JSch库连接Linux实现文件传输1. JSch2. Java通过Jsch连接Linux1. poxm.xml2. 工具类3. 调用案例 1. JSch 官网:JSch - Java Secure Channel (jcraft.com) JSch是SSH2的纯Java实现。 JSch 允许您连接到 ss…...

Nginx七层负载均衡之动静分离

思路: servera:负载均衡服务器 serverb:静态服务器 serverc:动态服务器 serverd:默认服务器 servera(192.168.233.132): # 安装 Nginx 服务器 yum install nginx -y#关闭防火墙和selinux systemctl stop firewalld setenforce 0# 切换到 Nginx 配置文…...

305_C++_定义了一个定时器池 TimerPool 类和相关的枚举类型和结构体

头文件:定义了一个定时器池 TimerPool 类和相关的枚举类型和结构体 #ifndef TIMERPOOL_H #define TIMERPOOL_H #include "rsglobal.h" #include "taskqueue.h" #incl...

大整数因数分解工具——yafu

一、安装 yafu--下载链接 二、配置环境变量,直接从cmd打开 1.找到yafu-x64.exe 所在的文件路径 2.点击设置——系统——系统信息——高级系统设置——环境变量——点击PATH(上下都可以)——新建 添加yafu-x64.exe 所在路径——点击确定 3…...

非关系型数据库(NOSQL)和关系型数据库(SQL)区别详解

前言: 在我们的日常开发中,关系型数据库和非关系型数据库的使用已经是一个成熟的软件产品开发过程中必不可却的存储数据的工具了。那么用了这么久的关系数据库和非关系型数据库你们都知道他们之间的区别了吗?下面我们来详细的介绍一下。 关系…...

7.Cloud-GateWay

0.概述 https://cloud.spring.io/spring-cloud-static/spring-cloud-gateway/2.2.1.RELEASE/reference/html/ 1.入门配置 1.1 POM <!--新增gateway--> <dependency><groupId>org.springframework.cloud</groupId><artifactId>spring-cloud-sta…...

【Linux】Framebuffer 应用

# 前置知识 LCD 操作原理 在 Linux 系统中通过 Framebuffer 驱动程序来控制 LCD。 Frame 是帧的意思&#xff0c; buffer 是缓冲的意思&#xff0c;这意味着 Framebuffer 就是一块内存&#xff0c;里面保存着一帧图像。 Framebuffer 中保存着一帧图像的每一个像素颜色值&…...

markdown绘制流程图相关代码片段记录

有时候会使用typora来绘制一些流程图&#xff0c;进行编码之类的工作&#xff0c;在网络搜集了一些笔记&#xff0c;做个记录&#xff0c;方便日后进行复习&#xff0c;相关的记录如下&#xff1a; 每次作图时&#xff0c;代码以「graph <布局方向>」开头&#xff0c;如…...

云计算基础-计算虚拟化-CPU虚拟化

CPU指令系统 在CPU的工作原理中&#xff0c;CPU有不同的指令集&#xff0c;如下图&#xff0c;CPU有4各指令集&#xff1a;Ring0-3&#xff0c;指令集是在服务器上运行的所有命令&#xff0c;最终都会在CPU上执行&#xff0c;但是CPU并不是说所有的命令都是一视同仁的&#xf…...

MySQL数据库⑪_C/C++连接MySQL_发送请求

目录 1. 下载库文件 2. 使用库 3. 链接MySQL函数 4. C/C链接示例 5. 发送SQL请求 6. 获取查询结果 本篇完。 1. 下载库文件 要使用C/C连接MySQL&#xff0c;需要使用MySQL官网提供的库。 进入MySQL官网选择适合自己平台的mysql connect库&#xff0c;然后点击下载就行…...

选择排序和快速排序(1)

目录 选择排序 基本思想 选择排序的实现 图片实现 代码实现 快速排序 基本思想 快速排序的实现 图片实现 代码实现 选择排序 基本思想 每一次从待排序的数据元素中选出最小&#xff08;最大&#xff09;的元素&#xff0c;存放在序列的起始位置&#xff0c;直到全部…...

得物面试:Redis用哈希槽,而不是一致性哈希,为什么?

尼恩说在前面 在40岁老架构师 尼恩的读者交流群(50)中&#xff0c;最近有小伙伴拿到了一线互联网企业如得物、阿里、滴滴、极兔、有赞、希音、百度、网易、美团的面试资格&#xff0c;遇到很多很重要的面试题&#xff1a; Redis为何用哈希槽而不用一致性哈希&#xff1f; 最近…...

matlab发送串口数据,并进行串口数据头的添加,我们来看下pwm解析后并通过串口输出的效果

uintt16位的话会在上面前面加上00&#xff0c;16位的话一定是两个字节&#xff0c;一共16位的数据 如果是unint8的话就不会&#xff0c; 注意这里给的是13&#xff0c;但是现实的00 0D&#xff0c;这是大小端的问题&#xff0c;在matlanb里设置&#xff0c;我们就默认用这个模式…...

二分、快排、堆排与双指针

二分 int Binary_Search(vector<int> A,int key){int nA.size();int low0,highn-1,mid;while(low<high){mid(lowhigh)/2;if(A[mid]key)return mid;else if(A[mid]>key)highmid-1;elselowmid1; }return -1; }折半插入排序 ——找到第一个 ≥ \ge ≥tem的元素 voi…...

微信小程序步数返还的时间戳为什么返回的全是1970?

微信小程序步数返还的时间戳为什么返回的全是1970&#xff1f; 将返回的时间 乘以 1000 再 new Date() 转化就对了 微信返回的是秒S单位的&#xff0c;我们要转化为毫秒ms单位&#xff0c;才能进行格式化日期。 微信给我们下了个坑&#xff0c; 参考&#xff1a; https://d…...

Python函数——函数介绍

一、引言 在Python编程中&#xff0c;函数是构建高效代码的关键。通过创建可重用的代码块&#xff0c;我们可以使程序更加清晰、易读且易于维护。在本文中&#xff0c;我们将深入了解Python函数的基本概念及其特性。 二、Python函数的基本概念 函数是一段具有特定功能的代码块…...

FreeRTOS定时器防抖实战:用STM32 HAL库+按键中断,告别按键连击烦恼

FreeRTOS定时器防抖实战&#xff1a;用STM32 HAL库按键中断&#xff0c;告别按键连击烦恼 在嵌入式开发中&#xff0c;按键处理看似简单却暗藏玄机。我曾在一个智能家居项目中遇到这样的尴尬场景&#xff1a;用户按下墙壁开关时&#xff0c;本该只触发一次的动作&#xff0c;由…...

IndexTTS-2-LLM新手教程:从部署到生成,完整流程详解

IndexTTS-2-LLM新手教程&#xff1a;从部署到生成&#xff0c;完整流程详解 1. 快速了解IndexTTS-2-LLM IndexTTS-2-LLM是一款基于大语言模型的智能语音合成系统&#xff0c;能够将文字转换为自然流畅的语音。相比传统语音合成技术&#xff0c;它具有以下特点&#xff1a; 声…...

跨平台文件同步:OpenClaw调用GLM-4.7-Flash智能归类方案

跨平台文件同步&#xff1a;OpenClaw调用GLM-4.7-Flash智能归类方案 1. 为什么需要智能文件同步 作为一个长期在多台设备间切换工作的开发者&#xff0c;我深受文件管理混乱的困扰。Mac上的设计稿、Windows里的会议记录、手机拍摄的参考图&#xff0c;最终都会堆积在某个临时…...

颠覆认知的5个Stagehand实战技巧:突破AI网页自动化瓶颈的进阶策略

颠覆认知的5个Stagehand实战技巧&#xff1a;突破AI网页自动化瓶颈的进阶策略 【免费下载链接】stagehand An AI web browsing framework focused on simplicity and extensibility. 项目地址: https://gitcode.com/GitHub_Trending/stag/stagehand 引言&#xff1a;从工…...

数据治理进阶——解读埃森哲大型央企数字化转型数据治理企业架构建设案例【附全文阅读】

该方案聚焦大型央国企数字化转型&#xff0c;适用于企业高层决策者、IT 部门负责人、业务部门管理者以及对数字化转型感兴趣的专业人士。方案主要内容围绕数字化转型展开&#xff0c;涵盖数据治理、企业架构建设等关键领域。在数字化转型部分&#xff0c;明确其目的是释放禁锢价…...

通过爱毕业AI的智能改写功能,五个方法助你快速降低论文重复率

嘿&#xff0c;大家好&#xff01;我是AI菌。今天咱们来聊聊一个让无数学生头疼的问题&#xff1a;论文重复率飙到30%以上怎么办&#xff1f;别慌&#xff0c;我这就分享5个实用降重技巧&#xff0c;帮你一次搞定&#xff0c;轻松压到合格线以下。这些方法都是我亲身试验过的&a…...

华为交换机-跨Vlan通信的实战配置指南

1. 华为交换机跨VLAN通信的核心原理 第一次接触跨VLAN通信时&#xff0c;我也被那些专业术语搞得一头雾水。直到把整个流程拆解成生活场景&#xff0c;才真正理解其中的奥妙。想象一下&#xff0c;VLAN就像公司里的不同部门&#xff0c;财务部、技术部、市场部各自在独立的办公…...

Kubernetes集群的搭建与DevOps实践(下)- 部署实践篇

需求清单&#xff1a; 100张数据表要迁移&#xff08;还要支持后续动态新增&#xff09; 双链路同步&#xff1a;MySQL到MySQL、MongoDB到PostgreSQL 不能写死配置&#xff0c;要能灵活扩展 工期不到1个月 技术约束&#xff1a; 源环境&#xff08;塔外&#xff09;和目标环境&…...

Java轻量级边缘运行时深度解析(OpenJDK GraalVM Substrate VM在ARM64 IoT设备上的实测压测报告)

第一章&#xff1a;Java轻量级边缘运行时概览与技术定位Java轻量级边缘运行时是面向资源受限边缘设备&#xff08;如工业网关、智能传感器、车载终端&#xff09;设计的精简型JVM执行环境&#xff0c;它在保持Java语言语义兼容性的同时&#xff0c;显著降低内存占用、启动延迟与…...

智能配置黑苹果终极指南:OpCore Simplify一键生成OpenCore EFI完整教程

智能配置黑苹果终极指南&#xff1a;OpCore Simplify一键生成OpenCore EFI完整教程 【免费下载链接】OpCore-Simplify A tool designed to simplify the creation of OpenCore EFI 项目地址: https://gitcode.com/GitHub_Trending/op/OpCore-Simplify 还在为繁琐的黑苹果…...