PyTorch 基础篇(1):Pytorch 基础
Pytorch 学习开始
入门的材料来自两个地方:
第一个是官网教程:WELCOME TO PYTORCH TUTORIALS,特别是官网的六十分钟入门教程 DEEP LEARNING WITH PYTORCH: A 60 MINUTE BLITZ。
第二个是韩国大神 Yunjey Choi 的 Repo:pytorch-tutorial,代码写得干净整洁。
目的:我是直接把 Yunjey 的教程的 python 代码挪到 Jupyter Notebook 上来,一方面可以看到运行结果,另一方面可以添加注释和相关资料链接。方便后面查阅。
顺便一题,我的 Pytorch 的版本是 0.4.1
- import torch
- print(torch.version)
- 0.4.1
- # 包
- import torch
- import torchvision
- import torch.nn as nn
- import numpy as np
- import torchvision.transforms as transforms
autograd(自动求导 / 求梯度) 基础案例 1
- # 创建张量(tensors)
- x = torch.tensor(1., requires_grad=True)
- w = torch.tensor(2., requires_grad=True)
- b = torch.tensor(3., requires_grad=True)
- # 构建计算图( computational graph):前向计算
- y = w * x + b # y = 2 * x + 3
- # 反向传播,计算梯度(gradients)
- y.backward()
- # 输出梯度
- print(x.grad) # x.grad = 2
- print(w.grad) # w.grad = 1
- print(b.grad) # b.grad = 1
- tensor(2.)
- tensor(1.)
- tensor(1.)
autograd(自动求导 / 求梯度) 基础案例 2
- # 创建大小为 (10, 3) 和 (10, 2)的张量.
- x = torch.randn(10, 3)
- y = torch.randn(10, 2)
- # 构建全连接层(fully connected layer)
- linear = nn.Linear(3, 2)
- print ('w: ', linear.weight)
- print ('b: ', linear.bias)
- # 构建损失函数和优化器(loss function and optimizer)
- # 损失函数使用均方差
- # 优化器使用随机梯度下降,lr是learning rate
- criterion = nn.MSELoss()
- optimizer = torch.optim.SGD(linear.parameters(), lr=0.01)
- # 前向传播
- pred = linear(x)
- # 计算损失
- loss = criterion(pred, y)
- print('loss: ', loss.item())
- # 反向传播
- loss.backward()
- # 输出梯度
- print ('dL/dw: ', linear.weight.grad)
- print ('dL/db: ', linear.bias.grad)
- # 执行一步-梯度下降(1-step gradient descent)
- optimizer.step()
- # 更底层的实现方式是这样子的
- # linear.weight.data.sub_(0.01 * linear.weight.grad.data)
- # linear.bias.data.sub_(0.01 * linear.bias.grad.data)
- # 进行一次梯度下降之后,输出新的预测损失
- # loss的确变少了
- pred = linear(x)
- loss = criterion(pred, y)
- print(‘loss after 1 step optimization: ‘, loss.item())
- w: Parameter containing:
- tensor([[ 0.5180, 0.2238, -0.5470],
- [ 0.1531, 0.2152, -0.4022]], requires_grad=True)
- b: Parameter containing:
- tensor([-0.2110, -0.2629], requires_grad=True)
- loss: 0.8057981729507446
- dL/dw: tensor([[-0.0315, 0.1169, -0.8623],
- [ 0.4858, 0.5005, -0.0223]])
- dL/db: tensor([0.1065, 0.0955])
- loss after 1 step optimization: 0.7932316660881042
从 Numpy 装载数据
- # 创建Numpy数组
- x = np.array([[1, 2], [3, 4]])
- print(x)
- # 将numpy数组转换为torch的张量
- y = torch.from_numpy(x)
- print(y)
- # 将torch的张量转换为numpy数组
- z = y.numpy()
- print(z)
- [[1 2]
- [3 4]]
- tensor([[1, 2],
- [3, 4]])
- [[1 2]
- [3 4]]
输入工作流(Input pipeline)
- # 下载和构造CIFAR-10 数据集
- # Cifar-10数据集介绍:https://www.cs.toronto.edu/~kriz/cifar.html
- train_dataset = torchvision.datasets.CIFAR10(root=’…/…/…/data/’,
- train=True,
- transform=transforms.ToTensor(),
- download=True)
- # 获取一组数据对(从磁盘中读取)
- image, label = train_dataset[0]
- print (image.size())
- print (label)
- # 数据加载器(提供了队列和线程的简单实现)
- train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
- batch_size=64,
- shuffle=True)
- # 迭代的使用
- # 当迭代开始时,队列和线程开始从文件中加载数据
- data_iter = iter(train_loader)
- # 获取一组mini-batch
- images, labels = data_iter.next()
- # 正常的使用方式如下:
- for images, labels in train_loader:
- # 在此处添加训练用的代码
- pass
- Files already downloaded and verified
- torch.Size([3, 32, 32])
- 6
自定义数据集的 Input pipeline
- # 构建自定义数据集的方式如下:
- class CustomDataset(torch.utils.data.Dataset):
- def init(self):
- # TODO
- # 1. 初始化文件路径或者文件名
- pass
- def getitem(self, index):
- # TODO
- # 1. 从文件中读取一份数据(比如使用nump.fromfile,PIL.Image.open)
- # 2. 预处理数据(比如使用 torchvision.Transform)
- # 3. 返回数据对(比如 image和label)
- pass
- def len(self):
- # 将0替换成数据集的总长度
- return 0
- # 然后就可以使用预置的数据加载器(data loader)了
- custom_dataset = CustomDataset()
- train_loader = torch.utils.data.DataLoader(dataset=custom_dataset,
- batch_size=64,
- shuffle=True)
- 预训练模型
- # 下载并加载预训练好的模型 ResNet-18
- resnet = torchvision.models.resnet18(pretrained=True)
- # 如果想要在模型仅对Top Layer进行微调的话,可以设置如下:
- # requieres_grad设置为False的话,就不会进行梯度更新,就能保持原有的参数
- for param in resnet.parameters():
- param.requires_grad = False
- # 替换TopLayer,只对这一层做微调
- resnet.fc = nn.Linear(resnet.fc.in_features, 100) # 100 is an example.
- # 前向传播
- images = torch.randn(64, 3, 224, 224)
- outputs = resnet(images)
- print (outputs.size()) # (64, 100)
- torch.Size([64, 100])
保存和加载模型
- # 保存和加载整个模型
- torch.save(resnet, ‘model.ckpt’)
- model = torch.load(‘model.ckpt’)
- # 仅保存和加载模型的参数(推荐这个方式)
- torch.save(resnet.state_dict(), ‘params.ckpt’)
- resnet.load_state_dict(torch.load(‘params.ckpt’))
相关文章:
PyTorch 基础篇(1):Pytorch 基础
Pytorch 学习开始 入门的材料来自两个地方: 第一个是官网教程:WELCOME TO PYTORCH TUTORIALS,特别是官网的六十分钟入门教程 DEEP LEARNING WITH PYTORCH: A 60 MINUTE BLITZ。 第二个是韩国大神 Yunjey Choi 的 Repo:pytorch-t…...
掌握Selenium4:详解各种定位方式
Selenium4中有多种元素定位方式,主要包括以下几种: 通过ID属性定位:根据元素的id属性进行定位。通过name属性定位:当元素没有id属性而有name属性时,可以使用name属性进行元素定位。通过class name定位:可以…...
go-fastfds部署心得
我是windows系统安装 Docker Desktop部署 docker run --name go-fastdfs(任意的一个名称) --privilegedtrue -t -p 3666:8080 -v /data/fasttdfs_data:/data -e GO_FASTDFS_DIR/data sjqzhang/go-fastdfs:lastest docker run:该命令用于运…...
Python第三次练习
Python 一、如何判断一个字符串是否是另一个字符串的子串二、如何验证一个字符串中的每一个字符均在另一个字符串中出现三、如何判定一个字符串中既有数字又有字母四、做一个注册登录系统 一、如何判断一个字符串是否是另一个字符串的子串 实现代码: string1 inp…...
从Java8升级到Java17,特色优化点
从Java8升级到Java17,特色优化点 一、局部变量类型推断二、switch表达式三、文本块四、Records五、模式匹配instanceof六、密封类七、NullPointerException 从Java 8 到 Java 20,Java 已经走过了漫长的道路,自 Java 8 以来,Java 生…...
js实现富文本
当涉及到使用 JavaScript 实现富文本时,一种常见的方法是使用一些现成的富文本编辑器库,比如: Quill:一个功能强大、易于集成的富文本编辑器,支持自定义样式和格式,提供丰富的插件和API。 TinyMCE…...
每日OJ题_算法_双指针②_力扣1089. 复写零
目录 力扣1089. 复写零 解析代码 力扣1089. 复写零 1089. 复写零 - 力扣(LeetCode) 难度 简单 给你一个长度固定的整数数组 arr ,请你将该数组中出现的每个零都复写一遍,并将其余的元素向右平移。 注意:请不要在…...
C++——红黑树
作者:几冬雪来 时间:2023年12月7日 内容:C——红黑树讲解 目录 前言: 红黑树的概念: 红黑树的性质: 红黑树的路径计算: 最长路径和最短路径: AVL树与红黑树的区别ÿ…...
【神化世界】asp网页500内部服务器错误的解决方法
问题解决方案记录 一、问题 在asp网页调试的时候,不小心改错了,好好的页面突然出现如下错误信息了: 二、解决方法 终于找到了问题所在,是sql语句出错造成的,特别记录一下。 正确的写法 sql"select * from mem…...
java面试题6
1.什么是Java中的泛型(Generic)? 答案:泛型是一种参数化类型的机制,在编译时提供类型安全性检查和重用代码的能力。使用泛型可以在编译时检测类型错误,并减少类型转换的需要。 2.Java中的反射(…...
(03)vite 处理 css
文章目录 系列全集vite 处理css流程vite如何解决协同开发,样式重复覆盖的问题?使用less通过配置,更改vite的css默认行为vite 利用postcss样式兼容低版本浏览器 系列全集 (01)vite 从启动服务器开始 (02&am…...
阿里云上传文件出现的问题解决(跨域设置)
跨域设置引起的问题 起因:开通对象存储服务后,上传文件限制在5M 大小,无法上传大文件。 1.查看报错信息 2.分析阿里云服务端响应内容 <?xml version"1.0" encoding"UTF-8"?> <Error><Code>Invali…...
利用JavaFX生成验证码图片
以下是一个基于 JavaFX 的验证码图片生成小程序的示例代码: import javafx.application.Application; import javafx.embed.swing.SwingFXUtils; import javafx.scene.Scene; import javafx.scene.canvas.Canvas; import javafx.scene.canvas.GraphicsContext; import javafx…...
6-55.汽车类的继承
根据给定的汽车类vehicle(包含的数据成员有车轮个数wheels和车重weight)声明,完成其中成员函数的定义,之后再定义其派生类并完成测试。 小车类car是它的派生类,其中包含载人数passenger_load。每个类都有相关数据的输出…...
SCI论文——respectively用法
respectively用于配对两组(三组)事物,表明后一组与前一组按照相同的顺序排列,从而使句意明确。一般是在句子的最后,而且在respectively的前面需要一个逗号“,” 一 、两组事物: 原则是尽可能靠近第二组的…...
解决方案 | 法大大电子签约加速农牧业数字化进程
近年来,我国农业技术得到快速发展,并开发出一批实用的数字农业技术产品,建立了专用网络数字农业技术平台。数字农业是农业现代化的高级阶段,是创新推动农业农村信息化发展的有效手段,也是我国由农业大国迈向农业强国的…...
设计模式之GoF23介绍
深入探讨设计模式:构建可维护、可扩展的软件架构 一、设计模式的背景1.1 什么是设计模式1.2 设计模式的历史 二、设计模式的分类2.1 创建型模式2.2 结构型模式2.3 行为型模式 三、七大设计原则四、设计模式关系结论 :rocket: :rocket: :rocket: 在软件开发领域&…...
UDP协议实现群聊
服务端 import java.io.*; import java.net.*; import java.util.ArrayList; public class Server{public static ServerSocket server_socket;public static ArrayList<Socket> socketListnew ArrayList<Socket>(); public static void main(String []args){try{…...
lombok原理 @Slf4j 怎么生成get set log
Lombok是一种Java库,通过注解的方式提供了许多有用的功能,包括生成Getter、Setter、日志等。Slf4j注解是Lombok中的一种,它用于自动生成日志记录器(Logger)。 下面简要介绍一下Lombok的原理,以及Slf4j注解…...
【目标检测】进行实时检测计数时,在摄像头窗口显示实时计数个数
这里我是用我本地训练的基于yolov8环境的竹签计数模型,在打开摄像头窗口增加了实时计数显示的代码,可以直接运行,大家可以根据此代码进行修改,其底层原理时将检测出来的目标的个数显示了出来。 该项目链接:【目标检测…...
Python:操作 Excel 折叠
💖亲爱的技术爱好者们,热烈欢迎来到 Kant2048 的博客!我是 Thomas Kant,很开心能在CSDN上与你们相遇~💖 本博客的精华专栏: 【自动化测试】 【测试经验】 【人工智能】 【Python】 Python 操作 Excel 系列 读取单元格数据按行写入设置行高和列宽自动调整行高和列宽水平…...
【AI学习】三、AI算法中的向量
在人工智能(AI)算法中,向量(Vector)是一种将现实世界中的数据(如图像、文本、音频等)转化为计算机可处理的数值型特征表示的工具。它是连接人类认知(如语义、视觉特征)与…...
令牌桶 滑动窗口->限流 分布式信号量->限并发的原理 lua脚本分析介绍
文章目录 前言限流限制并发的实际理解限流令牌桶代码实现结果分析令牌桶lua的模拟实现原理总结: 滑动窗口代码实现结果分析lua脚本原理解析 限并发分布式信号量代码实现结果分析lua脚本实现原理 双注解去实现限流 并发结果分析: 实际业务去理解体会统一注…...
使用 Streamlit 构建支持主流大模型与 Ollama 的轻量级统一平台
🎯 使用 Streamlit 构建支持主流大模型与 Ollama 的轻量级统一平台 📌 项目背景 随着大语言模型(LLM)的广泛应用,开发者常面临多个挑战: 各大模型(OpenAI、Claude、Gemini、Ollama)接口风格不统一;缺乏一个统一平台进行模型调用与测试;本地模型 Ollama 的集成与前…...
C++.OpenGL (14/64)多光源(Multiple Lights)
多光源(Multiple Lights) 多光源渲染技术概览 #mermaid-svg-3L5e5gGn76TNh7Lq {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-3L5e5gGn76TNh7Lq .error-icon{fill:#552222;}#mermaid-svg-3L5e5gGn76TNh7Lq .erro…...
Netty从入门到进阶(二)
二、Netty入门 1. 概述 1.1 Netty是什么 Netty is an asynchronous event-driven network application framework for rapid development of maintainable high performance protocol servers & clients. Netty是一个异步的、基于事件驱动的网络应用框架,用于…...
深入浅出深度学习基础:从感知机到全连接神经网络的核心原理与应用
文章目录 前言一、感知机 (Perceptron)1.1 基础介绍1.1.1 感知机是什么?1.1.2 感知机的工作原理 1.2 感知机的简单应用:基本逻辑门1.2.1 逻辑与 (Logic AND)1.2.2 逻辑或 (Logic OR)1.2.3 逻辑与非 (Logic NAND) 1.3 感知机的实现1.3.1 简单实现 (基于阈…...
Selenium常用函数介绍
目录 一,元素定位 1.1 cssSeector 1.2 xpath 二,操作测试对象 三,窗口 3.1 案例 3.2 窗口切换 3.3 窗口大小 3.4 屏幕截图 3.5 关闭窗口 四,弹窗 五,等待 六,导航 七,文件上传 …...
通过 Ansible 在 Windows 2022 上安装 IIS Web 服务器
拓扑结构 这是一个用于通过 Ansible 部署 IIS Web 服务器的实验室拓扑。 前提条件: 在被管理的节点上安装WinRm 准备一张自签名的证书 开放防火墙入站tcp 5985 5986端口 准备自签名证书 PS C:\Users\azureuser> $cert New-SelfSignedCertificate -DnsName &…...
es6+和css3新增的特性有哪些
一:ECMAScript 新特性(ES6) ES6 (2015) - 革命性更新 1,记住的方法,从一个方法里面用到了哪些技术 1,let /const块级作用域声明2,**默认参数**:函数参数可以设置默认值。3&#x…...
