PyTorch 基础篇(1):Pytorch 基础
Pytorch 学习开始
入门的材料来自两个地方:
第一个是官网教程:WELCOME TO PYTORCH TUTORIALS,特别是官网的六十分钟入门教程 DEEP LEARNING WITH PYTORCH: A 60 MINUTE BLITZ。
第二个是韩国大神 Yunjey Choi 的 Repo:pytorch-tutorial,代码写得干净整洁。
目的:我是直接把 Yunjey 的教程的 python 代码挪到 Jupyter Notebook 上来,一方面可以看到运行结果,另一方面可以添加注释和相关资料链接。方便后面查阅。
顺便一题,我的 Pytorch 的版本是 0.4.1
- import torch
- print(torch.version)
- 0.4.1
- # 包
- import torch
- import torchvision
- import torch.nn as nn
- import numpy as np
- import torchvision.transforms as transforms
autograd(自动求导 / 求梯度) 基础案例 1
- # 创建张量(tensors)
- x = torch.tensor(1., requires_grad=True)
- w = torch.tensor(2., requires_grad=True)
- b = torch.tensor(3., requires_grad=True)
- # 构建计算图( computational graph):前向计算
- y = w * x + b # y = 2 * x + 3
- # 反向传播,计算梯度(gradients)
- y.backward()
- # 输出梯度
- print(x.grad) # x.grad = 2
- print(w.grad) # w.grad = 1
- print(b.grad) # b.grad = 1
- tensor(2.)
- tensor(1.)
- tensor(1.)
autograd(自动求导 / 求梯度) 基础案例 2
- # 创建大小为 (10, 3) 和 (10, 2)的张量.
- x = torch.randn(10, 3)
- y = torch.randn(10, 2)
- # 构建全连接层(fully connected layer)
- linear = nn.Linear(3, 2)
- print ('w: ', linear.weight)
- print ('b: ', linear.bias)
- # 构建损失函数和优化器(loss function and optimizer)
- # 损失函数使用均方差
- # 优化器使用随机梯度下降,lr是learning rate
- criterion = nn.MSELoss()
- optimizer = torch.optim.SGD(linear.parameters(), lr=0.01)
- # 前向传播
- pred = linear(x)
- # 计算损失
- loss = criterion(pred, y)
- print('loss: ', loss.item())
- # 反向传播
- loss.backward()
- # 输出梯度
- print ('dL/dw: ', linear.weight.grad)
- print ('dL/db: ', linear.bias.grad)
- # 执行一步-梯度下降(1-step gradient descent)
- optimizer.step()
- # 更底层的实现方式是这样子的
- # linear.weight.data.sub_(0.01 * linear.weight.grad.data)
- # linear.bias.data.sub_(0.01 * linear.bias.grad.data)
- # 进行一次梯度下降之后,输出新的预测损失
- # loss的确变少了
- pred = linear(x)
- loss = criterion(pred, y)
- print(‘loss after 1 step optimization: ‘, loss.item())
- w: Parameter containing:
- tensor([[ 0.5180, 0.2238, -0.5470],
- [ 0.1531, 0.2152, -0.4022]], requires_grad=True)
- b: Parameter containing:
- tensor([-0.2110, -0.2629], requires_grad=True)
- loss: 0.8057981729507446
- dL/dw: tensor([[-0.0315, 0.1169, -0.8623],
- [ 0.4858, 0.5005, -0.0223]])
- dL/db: tensor([0.1065, 0.0955])
- loss after 1 step optimization: 0.7932316660881042
从 Numpy 装载数据
- # 创建Numpy数组
- x = np.array([[1, 2], [3, 4]])
- print(x)
- # 将numpy数组转换为torch的张量
- y = torch.from_numpy(x)
- print(y)
- # 将torch的张量转换为numpy数组
- z = y.numpy()
- print(z)
- [[1 2]
- [3 4]]
- tensor([[1, 2],
- [3, 4]])
- [[1 2]
- [3 4]]
输入工作流(Input pipeline)
- # 下载和构造CIFAR-10 数据集
- # Cifar-10数据集介绍:https://www.cs.toronto.edu/~kriz/cifar.html
- train_dataset = torchvision.datasets.CIFAR10(root=’…/…/…/data/’,
- train=True,
- transform=transforms.ToTensor(),
- download=True)
- # 获取一组数据对(从磁盘中读取)
- image, label = train_dataset[0]
- print (image.size())
- print (label)
- # 数据加载器(提供了队列和线程的简单实现)
- train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
- batch_size=64,
- shuffle=True)
- # 迭代的使用
- # 当迭代开始时,队列和线程开始从文件中加载数据
- data_iter = iter(train_loader)
- # 获取一组mini-batch
- images, labels = data_iter.next()
- # 正常的使用方式如下:
- for images, labels in train_loader:
- # 在此处添加训练用的代码
- pass
- Files already downloaded and verified
- torch.Size([3, 32, 32])
- 6
自定义数据集的 Input pipeline
- # 构建自定义数据集的方式如下:
- class CustomDataset(torch.utils.data.Dataset):
- def init(self):
- # TODO
- # 1. 初始化文件路径或者文件名
- pass
- def getitem(self, index):
- # TODO
- # 1. 从文件中读取一份数据(比如使用nump.fromfile,PIL.Image.open)
- # 2. 预处理数据(比如使用 torchvision.Transform)
- # 3. 返回数据对(比如 image和label)
- pass
- def len(self):
- # 将0替换成数据集的总长度
- return 0
- # 然后就可以使用预置的数据加载器(data loader)了
- custom_dataset = CustomDataset()
- train_loader = torch.utils.data.DataLoader(dataset=custom_dataset,
- batch_size=64,
- shuffle=True)
- 预训练模型
- # 下载并加载预训练好的模型 ResNet-18
- resnet = torchvision.models.resnet18(pretrained=True)
- # 如果想要在模型仅对Top Layer进行微调的话,可以设置如下:
- # requieres_grad设置为False的话,就不会进行梯度更新,就能保持原有的参数
- for param in resnet.parameters():
- param.requires_grad = False
- # 替换TopLayer,只对这一层做微调
- resnet.fc = nn.Linear(resnet.fc.in_features, 100) # 100 is an example.
- # 前向传播
- images = torch.randn(64, 3, 224, 224)
- outputs = resnet(images)
- print (outputs.size()) # (64, 100)
- torch.Size([64, 100])
保存和加载模型
- # 保存和加载整个模型
- torch.save(resnet, ‘model.ckpt’)
- model = torch.load(‘model.ckpt’)
- # 仅保存和加载模型的参数(推荐这个方式)
- torch.save(resnet.state_dict(), ‘params.ckpt’)
- resnet.load_state_dict(torch.load(‘params.ckpt’))
相关文章:
PyTorch 基础篇(1):Pytorch 基础
Pytorch 学习开始 入门的材料来自两个地方: 第一个是官网教程:WELCOME TO PYTORCH TUTORIALS,特别是官网的六十分钟入门教程 DEEP LEARNING WITH PYTORCH: A 60 MINUTE BLITZ。 第二个是韩国大神 Yunjey Choi 的 Repo:pytorch-t…...
掌握Selenium4:详解各种定位方式
Selenium4中有多种元素定位方式,主要包括以下几种: 通过ID属性定位:根据元素的id属性进行定位。通过name属性定位:当元素没有id属性而有name属性时,可以使用name属性进行元素定位。通过class name定位:可以…...
go-fastfds部署心得
我是windows系统安装 Docker Desktop部署 docker run --name go-fastdfs(任意的一个名称) --privilegedtrue -t -p 3666:8080 -v /data/fasttdfs_data:/data -e GO_FASTDFS_DIR/data sjqzhang/go-fastdfs:lastest docker run:该命令用于运…...
Python第三次练习
Python 一、如何判断一个字符串是否是另一个字符串的子串二、如何验证一个字符串中的每一个字符均在另一个字符串中出现三、如何判定一个字符串中既有数字又有字母四、做一个注册登录系统 一、如何判断一个字符串是否是另一个字符串的子串 实现代码: string1 inp…...
从Java8升级到Java17,特色优化点
从Java8升级到Java17,特色优化点 一、局部变量类型推断二、switch表达式三、文本块四、Records五、模式匹配instanceof六、密封类七、NullPointerException 从Java 8 到 Java 20,Java 已经走过了漫长的道路,自 Java 8 以来,Java 生…...
js实现富文本
当涉及到使用 JavaScript 实现富文本时,一种常见的方法是使用一些现成的富文本编辑器库,比如: Quill:一个功能强大、易于集成的富文本编辑器,支持自定义样式和格式,提供丰富的插件和API。 TinyMCE…...
每日OJ题_算法_双指针②_力扣1089. 复写零
目录 力扣1089. 复写零 解析代码 力扣1089. 复写零 1089. 复写零 - 力扣(LeetCode) 难度 简单 给你一个长度固定的整数数组 arr ,请你将该数组中出现的每个零都复写一遍,并将其余的元素向右平移。 注意:请不要在…...
C++——红黑树
作者:几冬雪来 时间:2023年12月7日 内容:C——红黑树讲解 目录 前言: 红黑树的概念: 红黑树的性质: 红黑树的路径计算: 最长路径和最短路径: AVL树与红黑树的区别ÿ…...
【神化世界】asp网页500内部服务器错误的解决方法
问题解决方案记录 一、问题 在asp网页调试的时候,不小心改错了,好好的页面突然出现如下错误信息了: 二、解决方法 终于找到了问题所在,是sql语句出错造成的,特别记录一下。 正确的写法 sql"select * from mem…...
java面试题6
1.什么是Java中的泛型(Generic)? 答案:泛型是一种参数化类型的机制,在编译时提供类型安全性检查和重用代码的能力。使用泛型可以在编译时检测类型错误,并减少类型转换的需要。 2.Java中的反射(…...
(03)vite 处理 css
文章目录 系列全集vite 处理css流程vite如何解决协同开发,样式重复覆盖的问题?使用less通过配置,更改vite的css默认行为vite 利用postcss样式兼容低版本浏览器 系列全集 (01)vite 从启动服务器开始 (02&am…...
阿里云上传文件出现的问题解决(跨域设置)
跨域设置引起的问题 起因:开通对象存储服务后,上传文件限制在5M 大小,无法上传大文件。 1.查看报错信息 2.分析阿里云服务端响应内容 <?xml version"1.0" encoding"UTF-8"?> <Error><Code>Invali…...
利用JavaFX生成验证码图片
以下是一个基于 JavaFX 的验证码图片生成小程序的示例代码: import javafx.application.Application; import javafx.embed.swing.SwingFXUtils; import javafx.scene.Scene; import javafx.scene.canvas.Canvas; import javafx.scene.canvas.GraphicsContext; import javafx…...
6-55.汽车类的继承
根据给定的汽车类vehicle(包含的数据成员有车轮个数wheels和车重weight)声明,完成其中成员函数的定义,之后再定义其派生类并完成测试。 小车类car是它的派生类,其中包含载人数passenger_load。每个类都有相关数据的输出…...
SCI论文——respectively用法
respectively用于配对两组(三组)事物,表明后一组与前一组按照相同的顺序排列,从而使句意明确。一般是在句子的最后,而且在respectively的前面需要一个逗号“,” 一 、两组事物: 原则是尽可能靠近第二组的…...
解决方案 | 法大大电子签约加速农牧业数字化进程
近年来,我国农业技术得到快速发展,并开发出一批实用的数字农业技术产品,建立了专用网络数字农业技术平台。数字农业是农业现代化的高级阶段,是创新推动农业农村信息化发展的有效手段,也是我国由农业大国迈向农业强国的…...
设计模式之GoF23介绍
深入探讨设计模式:构建可维护、可扩展的软件架构 一、设计模式的背景1.1 什么是设计模式1.2 设计模式的历史 二、设计模式的分类2.1 创建型模式2.2 结构型模式2.3 行为型模式 三、七大设计原则四、设计模式关系结论 :rocket: :rocket: :rocket: 在软件开发领域&…...
UDP协议实现群聊
服务端 import java.io.*; import java.net.*; import java.util.ArrayList; public class Server{public static ServerSocket server_socket;public static ArrayList<Socket> socketListnew ArrayList<Socket>(); public static void main(String []args){try{…...
lombok原理 @Slf4j 怎么生成get set log
Lombok是一种Java库,通过注解的方式提供了许多有用的功能,包括生成Getter、Setter、日志等。Slf4j注解是Lombok中的一种,它用于自动生成日志记录器(Logger)。 下面简要介绍一下Lombok的原理,以及Slf4j注解…...
【目标检测】进行实时检测计数时,在摄像头窗口显示实时计数个数
这里我是用我本地训练的基于yolov8环境的竹签计数模型,在打开摄像头窗口增加了实时计数显示的代码,可以直接运行,大家可以根据此代码进行修改,其底层原理时将检测出来的目标的个数显示了出来。 该项目链接:【目标检测…...
oracle与MySQL数据库之间数据同步的技术要点
Oracle与MySQL数据库之间的数据同步是一个涉及多个技术要点的复杂任务。由于Oracle和MySQL的架构差异,它们的数据同步要求既要保持数据的准确性和一致性,又要处理好性能问题。以下是一些主要的技术要点: 数据结构差异 数据类型差异ÿ…...
C++中string流知识详解和示例
一、概览与类体系 C 提供三种基于内存字符串的流,定义在 <sstream> 中: std::istringstream:输入流,从已有字符串中读取并解析。std::ostringstream:输出流,向内部缓冲区写入内容,最终取…...
拉力测试cuda pytorch 把 4070显卡拉满
import torch import timedef stress_test_gpu(matrix_size16384, duration300):"""对GPU进行压力测试,通过持续的矩阵乘法来最大化GPU利用率参数:matrix_size: 矩阵维度大小,增大可提高计算复杂度duration: 测试持续时间(秒&…...
全志A40i android7.1 调试信息打印串口由uart0改为uart3
一,概述 1. 目的 将调试信息打印串口由uart0改为uart3。 2. 版本信息 Uboot版本:2014.07; Kernel版本:Linux-3.10; 二,Uboot 1. sys_config.fex改动 使能uart3(TX:PH00 RX:PH01),并让boo…...
【HarmonyOS 5 开发速记】如何获取用户信息(头像/昵称/手机号)
1.获取 authorizationCode: 2.利用 authorizationCode 获取 accessToken:文档中心 3.获取手机:文档中心 4.获取昵称头像:文档中心 首先创建 request 若要获取手机号,scope必填 phone,permissions 必填 …...
今日学习:Spring线程池|并发修改异常|链路丢失|登录续期|VIP过期策略|数值类缓存
文章目录 优雅版线程池ThreadPoolTaskExecutor和ThreadPoolTaskExecutor的装饰器并发修改异常并发修改异常简介实现机制设计原因及意义 使用线程池造成的链路丢失问题线程池导致的链路丢失问题发生原因 常见解决方法更好的解决方法设计精妙之处 登录续期登录续期常见实现方式特…...
Webpack性能优化:构建速度与体积优化策略
一、构建速度优化 1、升级Webpack和Node.js 优化效果:Webpack 4比Webpack 3构建时间降低60%-98%。原因: V8引擎优化(for of替代forEach、Map/Set替代Object)。默认使用更快的md4哈希算法。AST直接从Loa…...
STM32---外部32.768K晶振(LSE)无法起振问题
晶振是否起振主要就检查两个1、晶振与MCU是否兼容;2、晶振的负载电容是否匹配 目录 一、判断晶振与MCU是否兼容 二、判断负载电容是否匹配 1. 晶振负载电容(CL)与匹配电容(CL1、CL2)的关系 2. 如何选择 CL1 和 CL…...
HubSpot推出与ChatGPT的深度集成引发兴奋与担忧
上周三,HubSpot宣布已构建与ChatGPT的深度集成,这一消息在HubSpot用户和营销技术观察者中引发了极大的兴奋,但同时也存在一些关于数据安全的担忧。 许多网络声音声称,这对SaaS应用程序和人工智能而言是一场范式转变。 但向任何技…...
Cilium动手实验室: 精通之旅---13.Cilium LoadBalancer IPAM and L2 Service Announcement
Cilium动手实验室: 精通之旅---13.Cilium LoadBalancer IPAM and L2 Service Announcement 1. LAB环境2. L2公告策略2.1 部署Death Star2.2 访问服务2.3 部署L2公告策略2.4 服务宣告 3. 可视化 ARP 流量3.1 部署新服务3.2 准备可视化3.3 再次请求 4. 自动IPAM4.1 IPAM Pool4.2 …...
