《Python实战进阶》No34:卷积神经网络(CNN)图像分类实战
第34集:卷积神经网络(CNN)图像分类实战
2025年3月28日更新 增加了 CNN和AI大模型关系的说明。
2025年3月29日更新了代码,优化损失系数曲线可视化。
详细环境配置依赖和可一次性复制的完整代码见文末。
摘要
最近大模型推陈出新迭代简直眼花缭乱,其中多模态成为主流和趋势,以通义千问为例,Qwen2.5-VL-32B-Instruct 3月底刚刚出来没有两天,支持全模态的Qwen2.5-VL-Omni 模型魔搭链接这两天就出来了,看来各大AI厂商都是憋足了劲要镀金一季度产品报告。在大模型向全模态发展的趋势背景下,卷积神经网络(CNN)作为AI多模态大模型的基础技术之一,重要性不言而喻… …
卷积神经网络(CNN)是计算机视觉领域的核心技术,特别擅长处理图像分类任务。本集将简要介绍CNN和AI大模型的关系,然后结合实战深入讲解 CNN 的核心组件(卷积层、池化层、全连接层),并演示如何使用 PyTorch 构建一个完整的 CNN 模型,在 CIFAR-10 经典图像分类数据集上实现图像分类,所有代码在Python3.11.5版本环境验证跑通,提供程序输出截图。我们还将探讨数据增强和正则化技术(如 Dropout 和 BatchNorm)对模型性能的影响。

卷积神经网络与AI大模型的关系
在众多深度学习架构中,卷积神经网络(Convolutional Neural Networks, CNNs)和大规模预训练模型(通常称为“AI大模型”)是两个关键的技术支柱。尽管它们的应用场景和技术特点有所不同,但二者之间存在着密切的联系。
一、技术原理:CNN 是 AI 大模型的重要组成部分
-
卷积神经网络的核心特性
卷积神经网络是一种专门设计用于处理具有网格状拓扑结构数据的深度学习模型,例如图像、视频等。其核心特性包括:- 局部感知野:通过卷积核提取局部特征。
- 权值共享:同一卷积核在整个输入数据上共享参数,减少计算量。
- 池化操作:通过降采样保留重要信息,降低计算复杂度。
这些特性使得 CNN 在计算机视觉任务中表现出色,尤其是在图像分类、目标检测和图像分割等领域。
-
AI 大模型的多模态特性
AI 大模型通常指参数量巨大的深度学习模型,例如 GPT 系列、BERT 和 Vision Transformer(ViT)。这些模型通过大规模数据训练,具备强大的泛化能力和跨领域迁移能力。AI 大模型的特点包括:- 超大规模参数量:支持更复杂的特征表示。
- 多模态融合:能够同时处理文本、图像、音频等多种类型的数据。
- 自监督学习:利用无标签数据进行预训练,提升模型性能。
-
CNN 与 AI 大模型的结合
在 AI 大模型的发展过程中,CNN 的思想和技术被广泛借鉴和融合。例如:- 图像处理模块:许多多模态大模型(如 CLIP、DALL·E)在处理图像数据时仍然依赖于 CNN 的架构或其变体。
- 特征提取器:在一些混合模型中,CNN 被用作底层特征提取器,为后续的 Transformer 模块提供高质量的输入特征。
- 轻量化设计:为了提高效率,一些大模型在特定任务中采用轻量化的 CNN 结构,以平衡计算资源和性能。
二、应用场景:从单一任务到多模态融合
-
CNN 的传统应用场景
CNN 最初主要用于计算机视觉任务,例如:- 图像分类(ImageNet)
- 目标检测(YOLO、Faster R-CNN)
- 图像分割(U-Net、Mask R-CNN)
这些任务通常针对单一模态(图像)进行优化,专注于局部特征的提取和空间结构的理解。
-
AI 大模型的多模态扩展
AI 大模型则更多地关注跨模态任务,例如:- 图文生成:DALL·E 和 Stable Diffusion 利用文本描述生成高质量图像。
- 图文匹配:CLIP 模型通过联合训练文本和图像数据,实现跨模态检索。
- 语音与文本转换:Whisper 模型可以同时处理语音识别和翻译任务。
-
CNN 在 AI 大模型中的角色
尽管 AI 大模型在多模态任务中表现出色,但 CNN 仍然是不可或缺的一部分。例如:- 在图文生成任务中,CNN 负责图像的空间特征提取。
- 在跨模态检索任务中,CNN 提供了对图像内容的高效编码。
- 在实时应用中,CNN 的轻量化版本(如 MobileNet)被用于加速推理过程。
三、发展趋势:从独立架构到深度融合
-
CNN 的局限性与改进方向
尽管 CNN 在图像处理领域取得了巨大成功,但它也存在一定的局限性:- 全局建模能力不足:CNN 主要关注局部特征,难以捕捉长距离依赖关系。
- 对小样本学习的支持有限:需要大量标注数据才能达到最佳性能。
针对这些问题,研究人员提出了多种改进方案,例如引入注意力机制(Attention)和图神经网络(Graph Neural Networks, GNNs)。
-
AI 大模型的挑战与机遇
AI 大模型虽然功能强大,但也面临以下挑战:- 计算资源需求高:训练和部署大模型需要昂贵的硬件支持。
- 可解释性差:模型内部的决策过程难以理解。
- 数据隐私问题:大规模数据收集可能引发隐私争议。
-
CNN 与 AI 大模型的融合趋势
未来,CNN 和 AI 大模型的融合将更加紧密,主要体现在以下几个方面:- 混合架构设计:结合 CNN 的局部特征提取能力和 Transformer 的全局建模能力,构建更高效的多模态模型。
- 轻量化与边缘计算:通过优化 CNN 和 Transformer 的结构,在资源受限的设备上实现高性能推理。
- 自监督学习与迁移学习:利用 CNN 和 Transformer 的互补优势,开发适用于小样本场景的通用模型。
卷积神经网络和 AI 大模型在技术原理、应用场景和发展趋势上既有区别又有联系。CNN 以其卓越的局部特征提取能力奠定了计算机视觉领域的基础,而 AI 大模型则通过多模态融合和自监督学习实现了更广泛的智能化应用。在未来,随着深度学习技术的不断进步,CNN 和 AI 大模型将进一步深度融合,共同推动人工智能技术的发展,为各行各业带来更大的价值。
总结关键词:卷积神经网络(CNN)、AI 大模型、多模态融合、自监督学习、混合架构

核心概念和知识点
1. CNN 的核心组件
- 卷积层:通过滤波器(Filter)提取局部特征(如边缘、纹理)。
- 池化层:通过下采样(如最大池化)减少参数数量,增强特征鲁棒性。
- 全连接层:将提取的特征映射到分类标签。
2. 数据增强技术
- 常用方法:随机水平翻转、随机裁剪、色彩抖动(调整亮度、对比度)。
- 作用:增加训练数据的多样性,防止过拟合。
3. 过拟合与正则化
- 过拟合:模型在训练集表现优异,但在测试集性能下降。
- 正则化方法:
- Dropout:随机关闭部分神经元,减少对特定特征的依赖。
- BatchNorm:标准化每层的输入,加速训练并提升泛化能力。
4. 与 AI 大模型的关联
- 基础架构角色:CNN 是许多大模型(如 ResNet、EfficientNet)的核心组件。
- 迁移学习:通过预训练的 CNN 模型(如 ImageNet 权重)快速适应新任务。
- 自监督学习:利用 CNN 提取特征,用于无标签数据的预训练。
实战案例:使用 CNN 分类 CIFAR-10 数据集
背景
CIFAR-10 包含 60,000 张 32x32 彩色图像,分为 10 个类别(飞机、汽车、鸟类等)。我们将构建一个轻量级 CNN 模型,结合数据增强和正则化技术提升分类性能。

代码实现
1. 环境准备
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader
import torch.nn.functional as F #3月29日新增
2. 数据加载和预处理
# 设置随机种子以确保结果可复现
torch.manual_seed(42)
np.random.seed(42)# 检查是否可以使用GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")def load_data():# 数据增强transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.RandomRotation(15),transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])transform_test = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])# 加载CIFAR-10数据集trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)return trainset, testset
3. 构建CNN模型
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()# 第一个卷积块self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)self.bn1 = nn.BatchNorm2d(32)self.conv2 = nn.Conv2d(32, 32, kernel_size=3, padding=1)self.bn2 = nn.BatchNorm2d(32)self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)self.dropout1 = nn.Dropout(0.25)# 第二个卷积块self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)self.bn3 = nn.BatchNorm2d(64)self.conv4 = nn.Conv2d(64, 64, kernel_size=3, padding=1)self.bn4 = nn.BatchNorm2d(64)self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)self.dropout2 = nn.Dropout(0.25)# 第三个卷积块self.conv5 = nn.Conv2d(64, 128, kernel_size=3, padding=1)self.bn5 = nn.BatchNorm2d(128)self.conv6 = nn.Conv2d(128, 128, kernel_size=3, padding=1)self.bn6 = nn.BatchNorm2d(128)self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)self.dropout3 = nn.Dropout(0.25)# 全连接层self.fc1 = nn.Linear(128 * 4 * 4, 512)self.dropout4 = nn.Dropout(0.5)self.fc2 = nn.Linear(512, 10)def forward(self, x):# 第一个卷积块x = self.pool1(F.relu(self.bn2(self.conv2(F.relu(self.bn1(self.conv1(x)))))))x = self.dropout1(x)# 第二个卷积块x = self.pool2(F.relu(self.bn4(self.conv4(F.relu(self.bn3(self.conv3(x)))))))x = self.dropout2(x)# 第三个卷积块x = self.pool3(F.relu(self.bn6(self.conv6(F.relu(self.bn5(self.conv5(x)))))))x = self.dropout3(x)# 全连接层x = x.view(-1, 128 * 4 * 4)x = self.dropout4(F.relu(self.fc1(x)))x = self.fc2(x)return x
4. 训练和评估
def train_model(model, trainloader, criterion, optimizer, device):model.train()running_loss = 0.0correct = 0total = 0batch_losses = [] # 用于记录每批数据的损失for i, data in enumerate(trainloader):inputs, labels = data[0].to(device), data[1].to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()batch_losses.append(loss.item()) # 记录每批的损失_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()if (i + 1) % 100 == 0:print(f'Batch [{i + 1}], Loss: {running_loss/100:.4f}, 'f'Acc: {100.*correct/total:.2f}%')running_loss = 0.0return batch_losses # 返回所有批次的损失值def evaluate_model(model, testloader, device):model.eval()correct = 0total = 0with torch.no_grad():for data in testloader:images, labels = data[0].to(device), data[1].to(device)outputs = model(images)_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()accuracy = 100. * correct / totalprint(f'测试集准确率: {accuracy:.2f}%')return accuracy
5. 可视化训练过程
def plot_training_history(train_losses, test_accuracies):plt.figure(figsize=(12, 4))# 绘制训练损失 改为英文不使用中文plt.subplot(1, 2, 1)plt.plot(train_losses)plt.title('Training Loss Curve')plt.xlabel('Batch')plt.ylabel('Loss Value')# 绘制测试准确率plt.subplot(1, 2, 2)plt.plot(test_accuracies)plt.title('Test Accuracy Curve')plt.xlabel('Epoch')plt.ylabel('Accuracy (%)')plt.tight_layout()plt.show()
6. 定义主函数和程序入口
def main():# 设置超参数batch_size = 128epochs = 50learning_rate = 0.001# 加载数据print("正在加载数据...")trainset, testset = load_data()trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)testloader = DataLoader(testset, batch_size=batch_size, shuffle=False)# 创建模型print("正在创建模型...")model = CNN().to(device)# 定义损失函数和优化器criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(model.parameters(), lr=learning_rate)# 训练模型print("开始训练模型...")train_losses = []test_accuracies = []for epoch in range(epochs):print(f'\nEpoch {epoch + 1}/{epochs}')batch_losses = train_model(model, trainloader, criterion, optimizer, device)train_losses.extend(batch_losses) # 将本epoch的所有批次损失添加到训练损失列表accuracy = evaluate_model(model, testloader, device)test_accuracies.append(accuracy)# 保存模型torch.save(model.state_dict(), 'cifar10_cnn.pth')# 可视化训练过程print("\n绘制训练历史...")plot_training_history(train_losses, test_accuracies)if __name__ == "__main__":main()
程序输出结果:
使用设备: cuda
正在加载数据...
正在创建模型...
开始训练模型...Epoch 1/50
Batch [100], Loss: 1.9564, Acc: 27.57%
Batch [200], Loss: 1.6755, Acc: 32.42%
Batch [300], Loss: 1.5559, Acc: 35.44%
测试集准确率: 48.49%
。。。
。。。
Epoch 49/50
Batch [100], Loss: 0.4992, Acc: 82.85%
Batch [200], Loss: 0.4891, Acc: 83.20%
Batch [300], Loss: 0.4730, Acc: 83.39%
测试集准确率: 85.98%Epoch 50/50
Batch [100], Loss: 0.4758, Acc: 83.62%
Batch [200], Loss: 0.4895, Acc: 83.46%
Batch [300], Loss: 0.4888, Acc: 83.39%
测试集准确率: 86.50%绘制训练历史...
程序输出图像:3月29日更新

总结
通过本集的学习,我们掌握了 CNN 的核心组件和正则化技术,并通过 CIFAR-10 图像分类任务验证了模型的有效性。CNN 的卷积层和池化层能够有效提取图像特征,而数据增强与 Dropout/BatchNorm 的结合显著提升了模型的泛化能力。
扩展思考
1. 迁移学习提升模型性能
- 使用预训练模型(如 ResNet-18)作为特征提取器,仅微调最后几层。
- 代码示例:
import torchvision.models as models resnet = models.resnet18(pretrained=True) # 冻结卷积层 for param in resnet.parameters():param.requires_grad = False # 替换最后的全连接层 resnet.fc = nn.Linear(resnet.fc.in_features, 10)
2. 自监督学习的潜力
- 自监督学习通过无标签数据预训练模型(如通过图像旋转预测任务),可在小数据集上取得更好的效果。
- 例如,使用 MoCo 框架预训练 CNN 编码器。
专栏链接:(Python实战进阶)
下期预告:No35:循环神经网络(RNN)时间序列预测
环境依赖文件(里面部分是多余的,因为几个程序的依赖混在一个环境中,只要安装代码前面import中提到的模块):
accelerate==1.5.2
addict==2.4.0
aiohappyeyeballs==2.6.1
aiohttp==3.11.14
aiosignal==1.3.2
annotated-types==0.7.0
anyio==4.9.0
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
asttokens==3.0.0
async-lru==2.0.5
attrs==25.3.0
av==14.2.0
babel==2.17.0
beautifulsoup4==4.13.3
bleach==6.2.0
certifi==2025.1.31
cffi==1.17.1
charset-normalizer==3.4.1
cloudpickle==3.1.1
colorama==0.4.6
comm==0.2.2
contourpy==1.3.1
cycler==0.12.1
datasets==3.4.1
debugpy==1.8.13
decorator==5.2.1
defusedxml==0.7.1
dill==0.3.8
distro==1.9.0
einops==0.8.1
executing==2.2.0
fastjsonschema==2.21.1
filelock==3.13.1
fonttools==4.56.0
fqdn==1.5.1
frozenlist==1.5.0
fsspec==2024.6.1
gym==0.26.2
gym-notices==0.0.8
h11==0.14.0
httpcore==1.0.7
httpx==0.28.1
huggingface-hub==0.29.3
idna==3.10
ipykernel==6.29.5
ipython==9.0.2
ipython_pygments_lexers==1.1.1
isoduration==20.11.0
jedi==0.19.2
Jinja2==3.1.6
jiter==0.9.0
json5==0.10.0
jsonpointer==3.0.0
jsonschema==4.23.0
jsonschema-specifications==2024.10.1
jupyter-events==0.12.0
jupyter-lsp==2.2.5
jupyter_client==8.6.3
jupyter_core==5.7.2
jupyter_server==2.15.0
jupyter_server_terminals==0.5.3
jupyterlab==4.3.6
jupyterlab_pygments==0.3.0
jupyterlab_server==2.27.3
kiwisolver==1.4.8
MarkupSafe==3.0.2
matplotlib==3.10.1
matplotlib-inline==0.1.7
mistune==3.1.3
modelscope==1.24.0
mpmath==1.3.0
multidict==6.2.0
multiprocess==0.70.16
nbclient==0.10.2
nbconvert==7.16.6
nbformat==5.10.4
nest-asyncio==1.6.0
networkx==3.3
notebook==7.3.3
notebook_shim==0.2.4
numpy==2.1.2
openai==1.68.2
overrides==7.7.0
packaging==24.2
pandas==2.2.3
pandocfilters==1.5.1
parso==0.8.4
pillow==11.0.0
platformdirs==4.3.7
prometheus_client==0.21.1
prompt_toolkit==3.0.50
propcache==0.3.1
protobuf==6.30.2
psutil==7.0.0
pure_eval==0.2.3
pyarrow==19.0.1
pycparser==2.22
pydantic==2.10.6
pydantic_core==2.27.2
pygame==2.6.1
Pygments==2.19.1
pyparsing==3.2.3
python-dateutil==2.9.0.post0
python-json-logger==3.3.0
pytz==2025.2
pywin32==310
pywinpty==2.0.15
PyYAML==6.0.2
pyzmq==26.3.0
qwen-vl-utils==0.0.10
referencing==0.36.2
regex==2024.11.6
requests==2.32.3
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rpds-py==0.23.1
safetensors==0.5.3
Send2Trash==1.8.3
six==1.17.0
sniffio==1.3.1
soupsieve==2.6
stack-data==0.6.3
sympy==1.13.1
terminado==0.18.1
tiktoken==0.9.0
tinycss2==1.4.0
tokenizers==0.21.1
torch==2.6.0+cu126
torchaudio==2.6.0+cu126
torchvision==0.21.0+cu126
tornado==6.4.2
tqdm==4.67.1
traitlets==5.14.3
transformers==4.50.1
transformers-stream-generator==0.0.5
types-python-dateutil==2.9.0.20241206
typing_extensions==4.13.0
tzdata==2025.2
uri-template==1.3.0
urllib3==2.3.0
wcwidth==0.2.13
webcolors==24.11.1
webencodings==0.5.1
websocket-client==1.8.0
xxhash==3.5.0
yarl==1.18.3
完整代码:(可一次性复制)
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader
import torch.nn.functional as F# 设置随机种子以确保结果可复现
torch.manual_seed(42)
np.random.seed(42)# 检查是否可以使用GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")def load_data():# 数据增强transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.RandomRotation(15),transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])transform_test = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])# 加载CIFAR-10数据集trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)return trainset, testsetclass CNN(nn.Module):def __init__(self):super(CNN, self).__init__()# 第一个卷积块self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)self.bn1 = nn.BatchNorm2d(32)self.conv2 = nn.Conv2d(32, 32, kernel_size=3, padding=1)self.bn2 = nn.BatchNorm2d(32)self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)self.dropout1 = nn.Dropout(0.25)# 第二个卷积块self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)self.bn3 = nn.BatchNorm2d(64)self.conv4 = nn.Conv2d(64, 64, kernel_size=3, padding=1)self.bn4 = nn.BatchNorm2d(64)self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)self.dropout2 = nn.Dropout(0.25)# 第三个卷积块self.conv5 = nn.Conv2d(64, 128, kernel_size=3, padding=1)self.bn5 = nn.BatchNorm2d(128)self.conv6 = nn.Conv2d(128, 128, kernel_size=3, padding=1)self.bn6 = nn.BatchNorm2d(128)self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)self.dropout3 = nn.Dropout(0.25)# 全连接层self.fc1 = nn.Linear(128 * 4 * 4, 512)self.dropout4 = nn.Dropout(0.5)self.fc2 = nn.Linear(512, 10)def forward(self, x):# 第一个卷积块x = self.pool1(F.relu(self.bn2(self.conv2(F.relu(self.bn1(self.conv1(x)))))))x = self.dropout1(x)# 第二个卷积块x = self.pool2(F.relu(self.bn4(self.conv4(F.relu(self.bn3(self.conv3(x)))))))x = self.dropout2(x)# 第三个卷积块x = self.pool3(F.relu(self.bn6(self.conv6(F.relu(self.bn5(self.conv5(x)))))))x = self.dropout3(x)# 全连接层x = x.view(-1, 128 * 4 * 4)x = self.dropout4(F.relu(self.fc1(x)))x = self.fc2(x)return xdef train_model(model, trainloader, criterion, optimizer, device):model.train()running_loss = 0.0correct = 0total = 0batch_losses = [] # 用于记录每批数据的损失for i, data in enumerate(trainloader):inputs, labels = data[0].to(device), data[1].to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()batch_losses.append(loss.item()) # 记录每批的损失_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()if (i + 1) % 100 == 0:print(f'Batch [{i + 1}], Loss: {running_loss/100:.4f}, 'f'Acc: {100.*correct/total:.2f}%')running_loss = 0.0return batch_losses # 返回所有批次的损失值def evaluate_model(model, testloader, device):model.eval()correct = 0total = 0with torch.no_grad():for data in testloader:images, labels = data[0].to(device), data[1].to(device)outputs = model(images)_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()accuracy = 100. * correct / totalprint(f'测试集准确率: {accuracy:.2f}%')return accuracydef plot_training_history(train_losses, test_accuracies):plt.figure(figsize=(12, 4))# 绘制训练损失 改为英文不使用中文plt.subplot(1, 2, 1)plt.plot(train_losses)plt.title('Training Loss Curve')plt.xlabel('Batch')plt.ylabel('Loss Value')# 绘制测试准确率plt.subplot(1, 2, 2)plt.plot(test_accuracies)plt.title('Test Accuracy Curve')plt.xlabel('Epoch')plt.ylabel('Accuracy (%)')plt.tight_layout()plt.show()def main():# 设置超参数batch_size = 128epochs = 50learning_rate = 0.001# 加载数据print("正在加载数据...")trainset, testset = load_data()trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)testloader = DataLoader(testset, batch_size=batch_size, shuffle=False)# 创建模型print("正在创建模型...")model = CNN().to(device)# 定义损失函数和优化器criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(model.parameters(), lr=learning_rate)# 训练模型print("开始训练模型...")train_losses = []test_accuracies = []for epoch in range(epochs):print(f'\nEpoch {epoch + 1}/{epochs}')batch_losses = train_model(model, trainloader, criterion, optimizer, device)train_losses.extend(batch_losses) # 将本epoch的所有批次损失添加到训练损失列表accuracy = evaluate_model(model, testloader, device)test_accuracies.append(accuracy)# 保存模型torch.save(model.state_dict(), 'cifar10_cnn.pth')# 可视化训练过程print("\n绘制训练历史...")plot_training_history(train_losses, test_accuracies)if __name__ == "__main__":main()
相关文章:
《Python实战进阶》No34:卷积神经网络(CNN)图像分类实战
第34集:卷积神经网络(CNN)图像分类实战 2025年3月28日更新 增加了 CNN和AI大模型关系的说明。 2025年3月29日更新了代码,优化损失系数曲线可视化。 详细环境配置依赖和可一次性复制的完整代码见文末。 摘要 最近大模型推陈出新迭…...
嵌入式Linux网络编程:UNIX Domain Socket进程间通信(IPC)
嵌入式Linux网络编程:UNIX Domain Socket进程间通信(IPC) 【本文代码已在Linux平台验证通过】 一、UNIX Domain Socket核心优势 1.1 本地IPC方案对比 特性UNIX Domain Socket管道(Pipe)消息队列(Message Queue)共享内存(Shared Memory)跨进…...
【qt】 布局器
参考博客:https://blog.csdn.net/Fdog_/article/details/107522283 目录 布局管理器概念常见的布局管理器及特点🔵QHBoxLayout水平布局🔵QVBoxLayout垂直布局 🔵QGridLayout网格布局 🔵QFormLayout表单布局 QT 高级布…...
Hosts文件与DNS的关系:原理、应用场景与安全风险
目录 引言 Hosts文件与DNS的基本概念 2.1 什么是Hosts文件? 2.2 什么是DNS? Hosts文件与DNS的关系 Hosts文件的应用场景 4.1 本地开发与测试 4.2 屏蔽广告与恶意网站 4.3 绕过DNS污染或劫持 Hosts文件的优势 5.1 解析速度快 5.2 不受DNS缓存影…...
VMware Windows Tools 存在认证绕过漏洞(CVE-2025-22230)
漏洞概述 博通公司(Broadcom)近日修复了 VMware Windows Tools 中存在的一个高危认证绕过漏洞,该漏洞编号为 CVE-2025-22230(CVSS 评分为 9.8)。VMware Windows Tools 是一套实用程序套件,可提升运行在 VM…...
pnpm 依赖升级终极指南:从语义化版本控制到 Monorepo 全局更新的企业级实践
要使用 pnpm 更新所有依赖包,可以通过以下命令实现: 1. 更新所有依赖到符合语义化版本的范围 pnpm update该命令会根据 package.json 中定义的版本范围(如 ^1.0.0 或 ~2.3.4)更新依赖包到最新兼容版本,但不会突破版本…...
Sentinel[超详细讲解]-2
异常处理 默认情况下,Sentinel 会抛出 BlockException 异常,如果希望自定义异常,则可以使用 SentinelResource 注解的 blockHandler 属性。 1、自定义异常处理 BlockExceptionHandler 自定义异常处理类实现 BlockExceptionHandler 接口&#…...
【问题解决】Linux安装conda修改~/.bashrc配置文件后,root 用户下显示 -bash-4.2#
问题描述 在Linux安装conda下的python环境时候,修改了~/.bashrc文件,修改完成后,再次进入服务器后,登录时候显示的不是正常的[rootlocalhost ~]#,而是-bash-4.2# 原因分析: 网上原因有:/root下…...
优化webpack打包体积思路
Webpack 打包过大的问题通常会导致页面加载变慢,影响用户体验。可以从代码优化、依赖优化、构建优化等多个角度入手来减少打包体积: 代码优化 (1)按需加载(代码拆分) ① 路由懒加载 如果你的项目使用 Vu…...
RabbitMQ 技术详解:异步消息通信的核心原理与实践
这里写目录标题 RabbitMQ 技术详解:异步消息通信的核心原理与实践一、RabbitMQ 本质剖析核心架构组件 二、核心功能与应用场景主要作用典型应用场景 三、工作流程深度解析消息传递流程关键协议机制 四、Java 实现示例1. 依赖配置(Maven)2. 消…...
CF每日5题Day4(1400)
好困,感觉很累,今天想赶紧写完题早睡。睡眠不足感觉做题都慢了。 1- 1761C 构造 void solve(){int n;cin>>n;vector<vector<int>>a(n1);forr(i,1,n){//保证每个集合不同a[i].push_back(i);}forr(i,1,n){string s;cin>>s;forr(…...
LLM架构解析:NLP基础(第一部分)—— 模型、核心技术与发展历程全解析
本专栏深入探究从循环神经网络(RNN)到Transformer等自然语言处理(NLP)模型的架构,以及基于这些模型构建的应用程序。 本系列文章内容: NLP自然语言处理基础(本文)词嵌入࿰…...
k近邻算法K-Nearest Neighbors(KNN)
算法核心 KNN算法的核心思想是“近朱者赤,近墨者黑”。对于一个待分类或预测的样本点,它会查找训练集中与其距离最近的K个样本点(即“最近邻”)。然后根据这K个最近邻的标签信息来对当前样本进行分类或回归。 在分类任务中&#…...
Dubbo(21)如何配置Dubbo的注册中心?
在分布式系统中,注册中心是一个关键组件,用于服务的注册和发现。Dubbo 支持多种注册中心,包括 ZooKeeper、Nacos、Consul、Etcd 等。下面详细介绍如何配置 Dubbo 的注册中心,以 ZooKeeper 为例。 配置步骤 引入依赖:…...
【Android15 ShellTransitions】(九)结束动画+Android原生ANR问题分析
finishTransition这部分的内容不多,并且我个人的实际工作中很少接触这块,因此我之前都觉得没有必要专门开一篇去分析最后留下的这一丁点儿的动画流程。但是最近碰到了一个google原生ANR问题,正好是和这块相关的,也让我意识到了fin…...
如何让DeepSeek-R1在内网稳定运行并实现随时随地远程在线调用
前言:最近,国产AI圈里的新星——Deepseek,简直是火到不行。但是,你是不是已经对那些千篇一律的手机APP和网页版体验感到腻味了?别急,今天就带你解锁一个超炫的操作:在你的Windows电脑上本地部署…...
STM32通用定时器结构框图
STM32单片机快速入门 通用定时器框图 TIM9和TIM12 通用定时器框图 TIM9和TIM12 (二) 通用定时器框图...
How to install vmware workstation pro on Linux mint 22
概述 VMware 是一家专注于虚拟化技术和云计算解决方案的全球领先软件公司,成立于1998年,总部位于美国加州。它的核心技术是通过“虚拟化”将一台物理计算机的硬件资源(如CPU、内存、存储等)分割成多个独立的虚拟环境(…...
深度学习 Deep Learning 第11章 实用方法论
深度学习 Deep Learning 第11章 实用方法论 章节概述 本章深入探讨了机器学习在实际应用中的方法论,强调了从确定目标到逐步优化的系统性过程。在机器学习项目中,明确的目标和性能指标是指导整个开发过程的关键。通过建立初始的端到端系统,…...
【常用的中间件】
中间件(Middleware)是位于客户端和服务器之间的软件层,用于处理客户端请求和服务器响应之间的各种任务。中间件可以提供多种功能,如负载均衡、消息队列、缓存、身份验证等。以下是常用的中间件及其作用: 1. 消息队列中…...
如何排查C++程序的CPU占用过高的问题
文章目录 可能的原因程序设计的BUG系统资源问题恶意软件硬件问题 通常步骤一个简单的问题代码在windows平台上如何排查Windows Process ExplorerWinDBG 在Linux平台如何排查使用TOP GDBPerf 可能的原因 程序设计的BUG 有死循环低效算法与数据结构滥用自旋锁频繁的系统调用&a…...
个人学习编程(3-29) leetcode刷题
最后一个单词的长度: 思路:跳过末尾的空格,可以从后向前遍历 然后再利用 while(i>0 && s[i] ! ) 可以得到字符串的长度, int lengthOfLastWord(char* s) {int length 0;int i strlen(s) - 1; //从字符串末尾开始//…...
Linux云计算SRE-第二十一周
构建单节点prometheus,部署node exporter和mongo exporter。构建kibana大盘。包含主机PU使用率,主机MEM使用率,主机网络包速度。mongo db大盘,包含节点在线状态,读操作延迟等 一、实验环境准备 - 节点信息࿱…...
无人机,云台参数设置,PWM输出控制云台俯仰
目录 1、云台与飞控的连接 2、PX4飞控控制云台,QGC地面站的设置 3、遥控器映射通道设置 4、其他设置 4.1、COM_PREARM_MODE,预解锁模式 4.2、RC9_DZ ,遥控器通道死区设置 1、云台与飞控的连接 首先确定一下,设置飞控第几路…...
EtherCAT转ProfiNet协议转换网关构建西门子PLC与海克斯康机器人的冗余通信链路
一、案例背景 某电子制造企业的5G通信模块组装线,采用西门子S7-1200PLC(ProfiNet主站)进行产线调度,而精密组装工序由3台海克斯康工业机器人(EtherCAT从站)完成。由于协议差异,机器人动作与PLC…...
Android R adb remount 调用流程
目的:调查adb remount 与adb shell进去后执行remount的差异 调试方法:添加log编译adbd,替换system\apex\com.android.adbd\bin\adbd 一、调查adb remount实现 关键代码:system\core\adb\daemon\services.cpp unique_fd daemon_service_to…...
网络中常用协议
一, TCP协议 TCP(Transmission Control Protocol,传输控制协议)是互联网核心协议之一,位于传输层,为应用层提供可靠的、面向连接的数据传输服务。 1. TCP的核心特点 特性说明面向连接通信前需通过三次握手建立连接&a…...
自动驾驶04:点云预处理03
点云组帧 感知算法人员在完成点云的运动畸变补偿后,会发现一个问题:激光雷达发送的点云数据包中的点云数量其实非常少,完全无法用来进行后续感知和定位层面的处理工作。 此时,感知算法人员就需要对这些数据包进行点云组帧的处理…...
Linux内核软中断分析
一、软中断类型 在Linux内核中,中断处理分为上半部(硬中断)和下半部。上半部负责快速响应硬件事件,而下半部用于处理耗时任务,避免阻塞系统。下半部有三种机制:软中断(Softirq)、小任…...
Linux修改默认shell为zsh
一、修改模型shell为zsh 1、检查当前使用的shell echo $SHELL 2、检查当前系统支持的shell cat /etc/shells# 输出结果显示如下: """ /bin/sh /bin/bash /usr/bin/sh /usr/bin/bash /bin/csh /bin/tcsh /usr/bin/csh /usr/bin/tcsh /usr/bin/zsh…...
