简介

PyTorch是一个开源的深度学习框架,它由Facebook的人工智能研究实验室(Facebook AI Research, FAIR)开发和维护。该框架提供了灵活的张量计算和动态计算图的功能,使得在深度学习任务中定义和训练神经网络变得更加直观和灵活。

下图是 scikit-learn 官方给的帮助选择算法的图

两个函数

dir():帮助查看某个包里面的内容

help():查看某个函数的用法

数据加载

Dataset 和 Dataloder

Dataset 是一个抽象类,子类需要实现 __getitem__ 方法和 __len__ 方法

path = os.path.join(root_dir, label)

os.path.join 用来拼接两个目录,接收两个字符串,字符串开头和末尾不需要 “\“

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from torch.utils.data import Dataset
from PIL import Image
import os

class MyData(Dataset):

def __init__(self,root_dir,label):
self.root_dir = root_dir
self.label = label
self.path =os.path.join(self.root_dir, self.label)
self.img_path = os.listdir(self.path)


def __getitem__(self, idx):
img_name = self.img_path[idx]
img_item_path = os.path.join(self.root_dir,self.label,img_name)
img = Image.open(img_item_path)
label = self.label
print(img_name)
return img, label

def __len__(self):
return len(self.img_path)


root_dir = "/home/Disk-2T/suchanghui/dataset/hymenoptera_data/train"
label = "ants"
ants_dataset = MyData(root_dir,label)

img, label = ants_dataset[1]
# img.show()
img.save("/home/Disk-2T/suchanghui/pytorch/1.png")

用自己定义的 MyData 类和不同的数据文件夹创建了多个 ***_dataset 实例,这些实例可以直接相加组成一个大的数据集:

train_dataset = ants_dataset + bees_dataset

Tensorboard 的使用