pytorch classification的.py.zip

  • Betty_2001
    了解作者
  • Python
    开发工具
  • 1KB
    文件大小
  • zip
    文件格式
  • 0
    收藏次数
  • 1 积分
    下载积分
  • 1
    下载次数
  • 2020-04-19 10:29
    上传日期
pytorch 实现mnist手写体分类
pytorch classification的.py.zip
  • pytorch classification的副本.py
    3.5KB
内容介绍
import torch import torchvision import torchvision.transforms as transforms import numpy as np from torch.utils.data import Dataset, DataLoader import torch.nn as nn import torch.nn.functional as F dtype = torch.float32 device = torch.device("cpu") class MnistDataset(torch.utils.data.Dataset): def __init__(self, transform, data, label): super(MnistDataset, self).__init__() self.transform = transform self.images = data self.labels = label def __getitem__(self, idx): img = self.images[idx] img = self.transform(img) label = self.labels[idx] return img, label def __len__(self): return len(self.images) transforms = transforms.Compose( [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))] ) train_data = np.load("train_img.npy") train_label = np.load("train_label.npy") test_data = np.load("test_img.npy") test_label = np.load("test_label.npy") trainset = MnistDataset(transform=transforms, data=train_data, label=train_label) #dataset,readin number each time, random, parallel process #return a iterable trainloader = torch.utils.data.DataLoader( trainset, batch_size=4, shuffle=True, num_workers=0 ) testset = MnistDataset(transform=transforms, data=test_data, label=test_label) testloader = torch.utils.data.DataLoader( testset, batch_size=4, shuffle=False, num_workers=0 ) class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(1, 4, 5) self.pool = nn.MaxPool2d(2, 2) self.conv2 = nn.Conv2d(4, 10, 5) #fully connected self.fc1 = nn.Linear(10 * 4 * 4, 100) self.fc2 = nn.Linear(100, 70) self.fc3 = nn.Linear(70, 10) def forward(self, x): x = self.pool(F.relu(self.conv1(x))) x = self.pool(F.relu(self.conv2(x))) x = x.view(-1, 10 * 4 * 4) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x net = Net() import torch.optim as optim criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(net.parameters(), lr=1e-4, momentum=0.9) running_loss = 0.0 for epoch in range(10): for i, data in enumerate(trainloader, 0): inputs, labels = data optimizer.zero_grad() inputs = torch.tensor(inputs, dtype=dtype) labels = torch.tensor(labels, dtype=torch.long) outputs = net(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step() running_loss += loss.item() if i % 2000 == 1999: print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000)) running_loss = 0.0 torch.save(net.state_dict(), "./cifar_net.pth") net = Net() net.load_state_dict(torch.load("./cifar_net.pth")) class_correct = list(0. for i in range(10)) class_total = list(0. for i in range(10)) classes = ('0', '1', '2', '3', '4', '5', '6', '7', '8', '9') with torch.no_grad(): for data in testloader: images, labels = data images = torch.tensor(images, dtype=dtype) labels = torch.tensor(labels, dtype=torch.long) outputs = net(images) _, predicted = torch.max(outputs, 1) c = (predicted == labels).squeeze() for i in range(4): label = labels[i] class_correct[label] += c[i].item() class_total[label] += 1 for i in range(10): print("Accuracy of %5s : %f %%" % ( classes[i], 100 * class_correct[i] / class_total[i] ))
评论
    相关推荐
    • 基于pytorch的猫狗分类
      基于pytorch实现简单的猫狗分类。采用了全连接网络;可以用来了解数据加载过程,网络搭建、训练过程
    • Pytorch 实现文本分类
      文本分类的标准代码,Pytorch实现 数据集Dataset - IMDB - SST - Trec ### 模型 - FastText - BasicCNN (KimCNN,MultiLayerCNN, Multi-perspective CNN) - InceptionCNN - LSTM (BILSTM, StackLSTM) - LSTM with...
    • Classifier:pytorch分类
      简要介绍了经典的深度学习分类算法, 并使用PyTorch实现了其中的部分算法(粗体): (1998) (2012) (2013.11) (2014.9) (2013.12) (2014.9) (2015.2) (2015.12) (2016.2) (2016.10) (2015.7) ResNet(2015.12) ...
    • pytorch
      pytorch
    • pytorch.zip
      pytorch入门的一些练习程序,包含Tensor的创建,线性分类器,主流CNN模型搭建,二分类网络模型搭建,Optimizer优化器的使用等
    • pytorch
      pytorch
    • ASL:官方Pytorch实施
      多标签分类的不对称损失 | 官方PyTorch实施 伊曼纽尔·本·巴鲁克(Emanuel Ben-Baruch),塔尔·里德尼克(Tal Ridnik),纳达夫·扎米尔(Nadav Zamir),阿萨夫·诺伊(Asaf Noy),伊塔玛·弗里德曼(Itamar ...
    • Pytorch 实现自己的残差网络图片分类
      如果对代码有疑问可以看一下我的博客《Pytorch 实现自己的残差网络图片分类器》和压缩包中的README.docx。也欢迎大家在博客下面提问或者指出文中的错误,谢谢大家。
    • 分类_pytorch
      分类_pytorch
    • pytorch中文文本分类训练数据.rar
      pytorch中文文本分类训练数据