Reproducing GoogLeNet in PyTorch: Study Notes
A brief set of study notes on implementing five-class flower classification; only the details of the reproduction are covered here.
If you want to learn more about the network itself, please read the paper "Going Deeper with Convolutions".
A quick note on the dataset: it is the same data as in the AlexNet post (see that post for the download link), so it is not described again here.
I. Environment Setup
See my earlier blog post, which walks through the setup in detail and also recommends 炮哥's environment-setup tutorial.
- Anaconda3 (recommended)
- python=3.6/3.7/3.8
- pycharm (IDE)
- pytorch=1.11.0 (pip package)
- torchvision=0.12.0 (pip package)
- cudatoolkit=11.3
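For reference, a pip command matching these versions would be along the lines of `pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 --extra-index-url https://download.pytorch.org/whl/cu113`; check the official PyTorch previous-versions page for the command that fits your own CUDA setup.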
II. Model Building and Training
1. Overall Architecture
[Figure: GoogLeNet overall architecture diagram; the two red boxes mark where the auxiliary classifiers attach]
[Figure: Inception module structure]
[Figure: auxiliary classifier structure]
Notes:
As the winner of the 2014 ILSVRC competition, GoogLeNet has fewer than one tenth as many parameters as VGG.
Its key innovations:
- Introduces the Inception module
- Uses 1×1 convolutions for dimensionality reduction and feature projection (see the parameter-count sketch after this list)
- Adds two auxiliary classifiers to assist training
- Replaces the large fully connected layers with average pooling (greatly reducing the parameter count)
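To make the 1×1-convolution point concrete, here is a small parameter-count comparison (my own illustration, not from the original post, using the channel numbers of the 5×5 branch of inception3a below):

```python
import torch.nn as nn

def n_params(m):
    return sum(p.numel() for p in m.parameters())

# direct 5x5 convolution: 192 -> 32 channels
direct = nn.Conv2d(192, 32, kernel_size=5, padding=2)

# same mapping with a 1x1 reduction first: 192 -> 16 -> 32 channels
reduced = nn.Sequential(
    nn.Conv2d(192, 16, kernel_size=1),
    nn.Conv2d(16, 32, kernel_size=5, padding=2),
)

print(n_params(direct))   # 153632
print(n_params(reduced))  # 15920, roughly a 10x saving
```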
2. net.py
Code for the overall network structure (save it as model.py, since train.py and predict.py below import it with `from model import GoogLeNet`):

```python
import torch.nn as nn
import torch
import torch.nn.functional as F


class GoogLeNet(nn.Module):
    def __init__(self, num_classes=1000, aux_logits=True, init_weights=False):
        super(GoogLeNet, self).__init__()
        self.aux_logits = aux_logits

        self.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.conv2 = BasicConv2d(64, 64, kernel_size=1)
        self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1)
        self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)
        self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)
        self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)
        self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)
        self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)
        self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)
        self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)
        self.maxpool4 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
        self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)

        if self.aux_logits:
            self.aux1 = InceptionAux(512, num_classes)
            self.aux2 = InceptionAux(528, num_classes)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(0.4)
        self.fc = nn.Linear(1024, num_classes)
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        # N x 3 x 224 x 224
        x = self.conv1(x)
        # N x 64 x 112 x 112
        x = self.maxpool1(x)
        # N x 64 x 56 x 56
        x = self.conv2(x)
        # N x 64 x 56 x 56
        x = self.conv3(x)
        # N x 192 x 56 x 56
        x = self.maxpool2(x)

        # N x 192 x 28 x 28
        x = self.inception3a(x)
        # N x 256 x 28 x 28
        x = self.inception3b(x)
        # N x 480 x 28 x 28
        x = self.maxpool3(x)
        # N x 480 x 14 x 14
        x = self.inception4a(x)
        # N x 512 x 14 x 14
        if self.training and self.aux_logits:  # auxiliary classifiers run only in training mode
            aux1 = self.aux1(x)

        x = self.inception4b(x)
        # N x 512 x 14 x 14
        x = self.inception4c(x)
        # N x 512 x 14 x 14
        x = self.inception4d(x)
        # N x 528 x 14 x 14
        if self.training and self.aux_logits:  # auxiliary classifiers run only in training mode
            aux2 = self.aux2(x)

        x = self.inception4e(x)
        # N x 832 x 14 x 14
        x = self.maxpool4(x)
        # N x 832 x 7 x 7
        x = self.inception5a(x)
        # N x 832 x 7 x 7
        x = self.inception5b(x)
        # N x 1024 x 7 x 7

        x = self.avgpool(x)
        # N x 1024 x 1 x 1
        x = torch.flatten(x, 1)
        # N x 1024
        x = self.dropout(x)
        x = self.fc(x)
        # N x 1000 (num_classes)
        if self.training and self.aux_logits:
            return x, aux2, aux1
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)


class Inception(nn.Module):
    def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):
        super(Inception, self).__init__()

        self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1)

        self.branch2 = nn.Sequential(
            BasicConv2d(in_channels, ch3x3red, kernel_size=1),
            BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1)  # padding keeps the spatial size unchanged
        )

        self.branch3 = nn.Sequential(
            BasicConv2d(in_channels, ch5x5red, kernel_size=1),
            # The official torchvision implementation actually uses a 3x3 kernel here, not 5x5;
            # see https://github.com/pytorch/vision/issues/906 for details. Kept as 5x5 to match the paper.
            BasicConv2d(ch5x5red, ch5x5, kernel_size=5, padding=2)  # padding keeps the spatial size unchanged
        )

        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            BasicConv2d(in_channels, pool_proj, kernel_size=1)
        )

    def forward(self, x):
        branch1 = self.branch1(x)
        branch2 = self.branch2(x)
        branch3 = self.branch3(x)
        branch4 = self.branch4(x)

        outputs = [branch1, branch2, branch3, branch4]
        return torch.cat(outputs, 1)  # concatenate the four branches along the channel dimension


class InceptionAux(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(InceptionAux, self).__init__()
        self.averagePool = nn.AvgPool2d(kernel_size=5, stride=3)
        self.conv = BasicConv2d(in_channels, 128, kernel_size=1)  # output[batch, 128, 4, 4]

        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, num_classes)

    def forward(self, x):
        # aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14
        x = self.averagePool(x)
        # aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4
        x = self.conv(x)
        # N x 128 x 4 x 4
        x = torch.flatten(x, 1)
        x = F.dropout(x, 0.5, training=self.training)
        # N x 2048
        x = F.relu(self.fc1(x), inplace=True)
        x = F.dropout(x, 0.5, training=self.training)
        # N x 1024
        x = self.fc2(x)
        # N x num_classes
        return x


class BasicConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.relu(x)
        return x


if __name__ == "__main__":
    x = torch.rand([1, 3, 224, 224])
    model = GoogLeNet(num_classes=5)
    y = model(x)
    print(y)

    # count the model parameters
    # total = 0
    # for name, param in model.named_parameters():
    #     num = 1
    #     for size in param.shape:
    #         num *= size
    #     total += num
    #     # print("{:30s} : {}".format(name, param.shape))
    # print("total param num {}".format(total))  # total param num 10,318,655
```
After writing the code, save it and run it to check for errors.
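Beyond confirming that it runs, you can also check the model's two forward signatures: in training mode it returns three tensors (main logits, then aux2, then aux1), while in eval mode it returns a single tensor. A quick sanity check (my own sketch; it assumes the file is saved as model.py, matching the `from model import GoogLeNet` in train.py below):

```python
import torch
from model import GoogLeNet

net = GoogLeNet(num_classes=5, aux_logits=True)
x = torch.rand(2, 3, 224, 224)

net.train()
logits, aux2, aux1 = net(x)  # note the return order: (main, aux2, aux1)
print(logits.shape, aux2.shape, aux1.shape)  # torch.Size([2, 5]) three times

net.eval()
print(net(x).shape)  # single output: torch.Size([2, 5])
```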
If you want to print the number of model parameters, just uncomment the block at the bottom; GoogLeNet comes out to 10,318,655 parameters.
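As a side note, the commented-out counting loop can be replaced with a one-liner based on `Tensor.numel()` (an equivalent sketch, to be placed in the same `__main__` block):

```python
# equivalent to the commented-out counting loop above
total = sum(p.numel() for p in model.parameters())
print("total param num {:,}".format(total))  # 10,318,655
```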
3. Data Splitting
The split is the same as in the AlexNet post.
[Figure: the dataset directory structure after splitting]
Run the code below to split the data into training and validation sets at a fixed ratio.
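A minimal sketch of such a split script follows. The `data/flower_photos/<class_name>/` source layout and the 9:1 ratio are my assumptions; the AlexNet post has the original version:

```python
import os
import random
import shutil

random.seed(0)

src_root = 'data/flower_photos'  # assumed location of the raw dataset
split_rate = 0.1                 # assumed fraction of each class sent to the validation set

for cla in os.listdir(src_root):
    cla_dir = os.path.join(src_root, cla)
    if not os.path.isdir(cla_dir):
        continue
    images = os.listdir(cla_dir)
    # randomly pick the validation subset for this class
    val_images = set(random.sample(images, k=int(len(images) * split_rate)))
    for img in images:
        subset = 'val' if img in val_images else 'train'
        dst_dir = os.path.join('data', subset, cla)
        os.makedirs(dst_dir, exist_ok=True)
        shutil.copy(os.path.join(cla_dir, img), os.path.join(dst_dir, img))

print("done")
```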

4. train.py
The training code. When training finishes, it plots the loss and accuracy curves for both the training and validation sets.

```python
import json
import torch
from torch import nn
from torchvision import transforms, datasets
from torch import optim
from torch.optim import lr_scheduler
from tqdm import tqdm  # progress bar
from model import GoogLeNet
import matplotlib.pyplot as plt
import os
import sys


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

    # paths to the training and validation sets
    ROOT_TRAIN = 'data/train'
    ROOT_TEST = 'data/val'

    batch_size = 16

    train_dataset = datasets.ImageFolder(root=ROOT_TRAIN, transform=data_transform["train"])
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    train_num = len(train_dataset)

    flower_list = train_dataset.class_to_idx  # ImageFolder provides this class-name -> index dict
    # {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}
    cla_dict = dict((val, key) for key, val in flower_list.items())  # invert keys and values
    # {0: 'daisy', 1: 'dandelion', 2: 'roses', 3: 'sunflowers', 4: 'tulips'}
    # write the dict into a json file so the prediction script can map indices back to names
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    val_dataset = datasets.ImageFolder(root=ROOT_TEST, transform=data_transform["val"])
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    val_num = len(val_dataset)

    print("using {} images for training, {} images for validation.".format(train_num, val_num))
    net = GoogLeNet(num_classes=5, aux_logits=True, init_weights=True)

    # optionally resume from a saved checkpoint
    # weights_path = "save_model/best_model.pth"
    # assert os.path.exists(weights_path), "file: '{}' does not exist.".format(weights_path)
    # missing_keys, unexpected_keys = net.load_state_dict(torch.load(weights_path), strict=False)

    net.to(device)
    loss_fc = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0003)
    # halve the learning rate every 10 epochs
    lr_s = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

    # training function
    def train(dataloader, net, loss_fn, optimizer, i, epoch):
        net.train()
        loss, current, n = 0.0, 0.0, 0
        train_bar = tqdm(dataloader, file=sys.stdout)
        for batch, (x, y) in enumerate(train_bar):
            # forward pass
            image, y = x.to(device), y.to(device)
            logits, aux_logits1, aux_logits2 = net(image)
            loss0 = loss_fn(logits, y)
            loss1 = loss_fn(aux_logits1, y)
            loss2 = loss_fn(aux_logits2, y)
            cur_loss = loss0 + loss1 * 0.3 + loss2 * 0.3  # the paper weights each auxiliary loss by 0.3
            _, pred = torch.max(logits, axis=-1)
            cur_acc = torch.sum(y == pred) / logits.shape[0]
            # backward pass
            optimizer.zero_grad()  # clear gradients
            cur_loss.backward()
            optimizer.step()
            loss += cur_loss.item()  # .item() detaches the value so the epoch averages can be plotted later
            current += cur_acc.item()
            n += 1
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(i + 1, epoch, cur_loss)
        train_loss = loss / n
        train_acc = current / n

        print(f"train_loss:{train_loss}")
        print(f"train_acc:{train_acc}")
        return train_loss, train_acc

    def val(dataloader, net, loss_fn, i, epoch):
        # evaluation mode: the auxiliary classifiers are skipped, so net() returns a single output
        net.eval()
        loss, current, n = 0.0, 0.0, 0
        with torch.no_grad():
            val_bar = tqdm(dataloader, file=sys.stdout)
            for batch, (x, y) in enumerate(val_bar):
                # forward pass
                image, y = x.to(device), y.to(device)
                output = net(image)
                cur_loss = loss_fn(output, y)
                _, pred = torch.max(output, axis=-1)
                cur_acc = torch.sum(y == pred) / output.shape[0]
                loss += cur_loss.item()
                current += cur_acc.item()
                val_bar.desc = "val epoch[{}/{}] loss:{:.3f}".format(i + 1, epoch, cur_loss)
                n += 1
        val_loss = loss / n
        val_acc = current / n
        print(f"val_loss:{val_loss}")
        print(f"val_acc:{val_acc}")
        return val_loss, val_acc

    # font settings so matplotlib can render Chinese characters if needed
    plt.rcParams['font.sans-serif'] = ['SimHei']
    plt.rcParams['axes.unicode_minus'] = False

    # plotting functions
    def matplot_loss(train_loss, val_loss):
        plt.figure()
        plt.plot(train_loss, label='train_loss')
        plt.plot(val_loss, label='val_loss')
        plt.legend(loc='best')
        plt.ylabel('loss', fontsize=12)
        plt.xlabel('epoch', fontsize=12)
        plt.title("train vs. val loss")
        plt.savefig('./loss.jpg')

    def matplot_acc(train_acc, val_acc):
        plt.figure()  # open a new figure so the two plots do not overlap
        plt.plot(train_acc, label='train_acc')
        plt.plot(val_acc, label='val_acc')
        plt.legend(loc='best')
        plt.ylabel('acc', fontsize=12)
        plt.xlabel('epoch', fontsize=12)
        plt.title("train vs. val accuracy")
        plt.savefig('./acc.jpg')

    # start training
    train_loss_list = []
    val_loss_list = []
    train_acc_list = []
    val_acc_list = []

    epoch = 60
    max_acc = 0

    for i in range(epoch):
        train_loss, train_acc = train(train_loader, net, loss_fc, optimizer, i, epoch)

        val_loss, val_acc = val(val_loader, net, loss_fc, i, epoch)

        lr_s.step()  # step the scheduler once per epoch, after the optimizer updates

        train_loss_list.append(train_loss)
        train_acc_list.append(train_acc)
        val_acc_list.append(val_acc)
        val_loss_list.append(val_loss)

        # save the best model weights
        if val_acc > max_acc:
            folder = 'save_model'
            if not os.path.exists(folder):
                os.mkdir(folder)
            max_acc = val_acc
            print(f'saving best model at epoch {i + 1}')
            torch.save(net.state_dict(), 'save_model/best_model.pth')
        # save the final epoch as well
        if i == epoch - 1:
            torch.save(net.state_dict(), 'save_model/last_model.pth')

    print("done")

    # draw the curves
    matplot_loss(train_loss_list, val_loss_list)
    matplot_acc(train_acc_list, val_acc_list)


if __name__ == "__main__":
    os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
    main()
```
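As a small aside, here is what the StepLR schedule above does to the learning rate over the 60 epochs (a standalone sketch with a dummy parameter, not part of the training script):

```python
import torch
from torch import optim
from torch.optim import lr_scheduler

# one dummy parameter, same optimizer/scheduler settings as in train.py
params = [torch.zeros(1, requires_grad=True)]
opt = optim.Adam(params, lr=0.0003)
sched = lr_scheduler.StepLR(opt, step_size=10, gamma=0.5)

for epoch in range(30):
    opt.step()    # in train.py this is where an epoch of training happens
    sched.step()  # called once per epoch, after the optimizer steps
    if (epoch + 1) % 10 == 0:
        print(epoch + 1, opt.param_groups[0]['lr'])
# 10 0.00015
# 20 7.5e-05
# 30 3.75e-05
```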
Results from the final epoch: 60 epochs in total, 3360 training images, 364 validation images, about 2 hours of training with batch size 16.
After training you get the loss and accuracy comparison plots for the training and validation sets.
A quick assessment: compared with the earlier AlexNet post, the validation accuracy is much improved.
III. Model Inference
The test code. The "test set" used here is actually the validation set from training; strictly speaking, a separate test set should have been created.
Change the image path below to whichever image you want to run inference on.

```python
import os
import json
import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
from model import GoogLeNet


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # load image
    img_path = "data/val/daisy/476856232_7c35952f40_n.jpg"
    assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)
    img = Image.open(img_path)
    plt.imshow(img)

    # [N, C, H, W] normalization
    img = data_transform(img)
    # expand the batch dimension
    img = torch.unsqueeze(img, dim=0)

    # read the class index mapping saved by train.py
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)

    with open(json_path, "r") as f:
        class_indict = json.load(f)

    # instantiate the model without auxiliary classifiers (not needed at inference time)
    model = GoogLeNet(num_classes=5, aux_logits=False).to(device)

    # load the weights; strict=False skips the auxiliary-classifier weights in the checkpoint
    weights_path = "save_model/best_model.pth"
    assert os.path.exists(weights_path), "file: '{}' does not exist.".format(weights_path)
    missing_keys, unexpected_keys = model.load_state_dict(torch.load(weights_path, map_location=device),
                                                          strict=False)
    model.eval()
    with torch.no_grad():
        # predict
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()

    # most likely class
    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].numpy())
    plt.title(print_res)
    # probabilities for all five classes
    for i in range(len(predict)):
        print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],
                                                  predict[i].numpy()))
    plt.show()


if __name__ == "__main__":
    os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
    main()
```
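One detail behind the `strict=False` above: the checkpoint was saved from a model that had the two auxiliary classifiers, while this model is built with `aux_logits=False`, so those weights end up in `unexpected_keys` and are simply ignored. A quick way to see them (my own sketch):

```python
import torch

state = torch.load("save_model/best_model.pth", map_location="cpu")
aux_keys = [k for k in state if k.startswith(("aux1.", "aux2."))]
print(aux_keys[:4])  # present in the checkpoint, absent from the aux-free model
```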
Run the code to perform inference. You can grab a few images from the web; here I simply took one from the validation set.
Below is a photo of a daisy together with the predicted probabilities for the five flower classes (shown on the right).
Summary
GoogLeNet performs quite well.
Typing out the code yourself teaches you a lot that you would otherwise miss.
Finally: read more, learn more, experiment more, and one day you will become an expert!