import torch
from torch.utils import data
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt
import typing
from matplotlib_inline import backend_inline
from IPython import display
import plotShow
class Accumulator:
    '''
    Keep running sums for n variables at once.

    The totals live in ``self.data`` as a list of floats, one entry per variable.
    '''
    def __init__(self, n:int):
        # One float slot per tracked variable, all starting at zero.
        self.data = [0.0 for _ in range(n)]

    def add(self, *args):
        '''
        Add each positional argument to the running total at the same position.

        Every argument is converted with float(). Pairing is done with zip, so
        the stored list is truncated to the shorter of the current totals and args.
        '''
        updated = []
        for total, increment in zip(self.data, args):
            updated.append(total+float(increment))
        self.data = updated

    def reset(self):
        '''
        Set every running total back to 0.0, keeping the number of slots.
        '''
        self.data = [0.0 for _ in self.data]

    def __getitem__(self, idx):
        '''
        Return the running total(s) stored at idx.
        '''
        return self.data[idx]
class Animator:
    '''
    Plot data incrementally, redrawing the figure each time points are added.

    Only the first axes (``self.axes[0]``) is drawn on and configured.
    '''
    def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None, ylim=None, xscale='linear',
                 yscale='linear', fmts=('-', 'm--', 'g-.', 'r:'), nrows=1, ncols=1, figsize=(3.5,2.5)):
        '''
        Create the figure and remember how its first axes should be configured.

        xlabel, ylabel, xlim, ylim, xscale, yscale, legend: forwarded unchanged to
            plotShow.setAxes every time the plot is redrawn.
        fmts: matplotlib format strings, one per curve.
        nrows, ncols, figsize: forwarded to plt.subplots.
        '''
        legend = [] if legend is None else legend
        backend_inline.set_matplotlib_formats('svg')
        self.fig, self.axes = plt.subplots(nrows, ncols, figsize=figsize)
        if nrows*ncols==1:
            # A single subplot comes back as a bare Axes; wrap it so axes[0] works.
            self.axes = [self.axes, ]

        def configAxes():
            # Captures the constructor arguments so add() can re-apply them after cla().
            return plotShow.setAxes(self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)

        self.configAxes = configAxes
        self.X = None
        self.Y = None
        self.fmts = fmts

    def add(self, x, y):
        '''
        Append one point per curve and redraw the figure.

        x: a scalar shared by all curves, or one x value per curve.
        y: a scalar, or one y value per curve.
        Pairs in which either value is None are skipped.
        '''
        if not hasattr(y, '__len__'):
            y = [y]
        numCurves = len(y)
        if not hasattr(x, '__len__'):
            x = [x]*numCurves
        # Lazily create one point list per curve on the first call.
        if not self.X:
            self.X = [[] for _ in range(numCurves)]
        if not self.Y:
            self.Y = [[] for _ in range(numCurves)]
        for curve, (xValue, yValue) in enumerate(zip(x, y)):
            if xValue is None or yValue is None:
                continue
            self.X[curve].append(xValue)
            self.Y[curve].append(yValue)
        axis = self.axes[0]
        axis.cla()
        for xs, ys, fmt in zip(self.X, self.Y, self.fmts):
            axis.plot(xs, ys, fmt)
        # cla() wiped labels, limits and legend, so apply them again.
        self.configAxes()
        display.display(self.fig)
        plt.draw()
        plt.pause(0.001)
        display.clear_output(wait=True)

    def show(self):
        '''
        Display the current figure through IPython.
        '''
        display.display(self.fig)
def getFashionMnistLabels(labels:typing.Sequence):
    '''
    Map numeric Fashion-MNIST class indices to their text labels.

    labels: iterable of class indices; each element is converted with int(),
        so plain ints and tensor elements both work.
    Returns a list of label strings in the same order as labels.
    Raises IndexError for an index outside the ten label entries.
    '''
    textLabels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                  'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [textLabels[int(i)] for i in labels]
def showImages(imgs:list, numRows:int, numCols:int, titles:list=None, scale:float=1.5):
    '''
    Draw a grid of images with both axes hidden.

    imgs: images to draw, one per grid cell; tensors are converted to numpy arrays.
    numRows, numCols: grid shape passed to plt.subplots.
    titles: optional titles, indexed in the same order as imgs.
    scale: size of each grid cell, used to build figsize.
    Returns the flattened array of matplotlib axes.
    '''
    figsize = (numCols*scale, numRows*scale)
    # squeeze=False always yields a 2-D array of axes, so a 1x1 grid can be
    # flattened too instead of coming back as a bare Axes object.
    _, axes = plt.subplots(numRows, numCols, figsize=figsize, squeeze=False)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if torch.is_tensor(img):
            # numpy() rejects tensors that require grad or live on an accelerator.
            img = img.detach().cpu().numpy()
        ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if titles:
            ax.set_title(titles[i])
    return axes
def getDataloaderWorkers()->int:
    '''
    Return the number of worker processes used to read data (always 4).
    '''
    workerCount = 4
    return workerCount
def loadDataFashionMnist(batchSize:int, resize=None)->tuple:
    '''
    Download Fashion-MNIST into '../data' and wrap both splits in DataLoaders.

    batchSize: batch size used by both loaders.
    resize: when truthy, a transforms.Resize(resize) step runs before ToTensor.
    Returns (trainLoader, testLoader); only the training loader shuffles.
    '''
    if resize:
        steps = [transforms.Resize(resize), transforms.ToTensor()]
    else:
        steps = [transforms.ToTensor()]
    pipeline = transforms.Compose(steps)
    trainSet, testSet = [
        torchvision.datasets.FashionMNIST(root='../data', train=isTrain, transform=pipeline, download=True)
        for isTrain in (True, False)
    ]
    trainLoader = data.DataLoader(trainSet, batchSize, shuffle=True, num_workers=getDataloaderWorkers())
    testLoader = data.DataLoader(testSet, batchSize, shuffle=False, num_workers=getDataloaderWorkers())
    return (trainLoader, testLoader)
def softmax(X:torch.Tensor)->torch.Tensor:
    '''
    Row-wise softmax of a 2-D tensor (each row is one sample).

    1. Subtract each row's maximum so exp() cannot overflow; softmax is
       unchanged by a per-row shift, so the result is mathematically the same.
    2. Exponentiate every entry.
    3. Sum each row to get that sample's normalization constant.
    4. Divide each row by its constant so the row sums to 1.
    '''
    if X.shape[1] == 0:
        # max() cannot reduce over an empty dimension, and there is nothing to normalize.
        return torch.exp(X)
    shifted:torch.Tensor = X - X.max(dim=1, keepdim=True).values
    XExp:torch.Tensor = torch.exp(shifted)
    partition:torch.Tensor = XExp.sum(1, keepdim=True)
    return XExp/partition
def accuracy(yHat:torch.Tensor, y:torch.Tensor)->float:
    '''
    Count how many predictions match the labels.

    yHat: either per-class scores with more than one column (reduced with
        argmax over axis 1) or already-predicted class values.
    y: the true labels.
    Returns the number of correct predictions as a float (a count, not a ratio).
    '''
    if len(yHat.shape)>1 and yHat.shape[1]>1:
        yHat = yHat.argmax(axis=1)
    if yHat.shape != y.shape and yHat.numel() == y.numel():
        # e.g. (n, 1) predictions against (n,) labels: without this the ==
        # below broadcasts to (n, n) and counts far more matches than samples.
        yHat = yHat.reshape(y.shape)
    matches = yHat.type(y.dtype)==y
    return float(matches.type(y.dtype).sum())
def evaluateAccuracy(net:torch.nn.Module, dataIter:typing.Tuple[torch.Tensor])->float:
    '''
    Compute the accuracy of a model over a dataset.

    net: the model; switched to eval mode when it is a torch.nn.Module.
    dataIter: iterable yielding (X, y) batches.
    Returns correct predictions divided by the total number of labels.
    '''
    if isinstance(net, torch.nn.Module):
        net.eval()
    correct = 0.0
    seen = 0.0
    # Evaluation only: no gradients need to be tracked.
    with torch.no_grad():
        for features, labels in dataIter:
            correct += float(accuracy(net(features), labels))
            seen += float(labels.numel())
    return correct/seen
def trainEpochCh3(net:torch.nn.Module,
                  trainIter:typing.Tuple[torch.Tensor],
                  loss:typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
                  updater:torch.optim.Optimizer)->typing.Tuple[float]:
    '''
    Train the model for one epoch.

    net: the model; switched to train mode when it is a torch.nn.Module.
    trainIter: iterable yielding (X, y) batches.
    loss: callable returning the loss tensor for (yHat, y).
    updater: a torch.optim.Optimizer, or a callable that takes the batch size.
    Returns (average loss per label, training accuracy) for the epoch.
    '''
    if isinstance(net, torch.nn.Module):
        net.train()
    usesOptimizer = isinstance(updater, torch.optim.Optimizer)
    # Running totals: summed loss, correct predictions, number of labels.
    metric = Accumulator(3)
    for X, y in trainIter:
        yHat:torch.Tensor = net(X)
        batchLoss:torch.Tensor = loss(yHat, y)
        if not usesOptimizer:
            # Custom updater: backprop the summed loss and hand over the batch size.
            batchLoss.sum().backward()
            updater(X.shape[0])
        else:
            # Built-in optimizer: backprop the mean loss.
            updater.zero_grad()
            batchLoss.mean().backward()
            updater.step()
        metric.add(float(batchLoss.sum()), accuracy(yHat, y), y.numel())
    lossTotal, correctTotal, labelTotal = metric[0], metric[1], metric[2]
    return lossTotal/labelTotal, correctTotal/labelTotal
def trainCh3(net:torch.nn.Module,
             trainIter:data.DataLoader,
             testIter:data.DataLoader,
             loss:typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
             numEpochs: int,
             updater: torch.optim.Optimizer)->None:
    '''
    Train the model for numEpochs epochs, plotting progress after each one.

    After every epoch the train loss, train accuracy and test accuracy are
    added to an Animator. Once training ends, the final metrics are checked
    with assertions: train loss below 0.5 and both accuracies in (0.7, 1].
    '''
    animator = Animator(xlabel='epoch', xlim=[1, numEpochs], ylim=[0.3,0.9],
                        legend=['train loss', 'train acc', 'test acc'])
    for epochNumber in range(1, numEpochs+1):
        trainMetrics = trainEpochCh3(net, trainIter, loss, updater)
        testAcc = evaluateAccuracy(net, testIter)
        animator.add(epochNumber, trainMetrics+(testAcc, ))
    # Sanity checks on the metrics of the last epoch.
    trainLoss, trainAcc = trainMetrics
    assert trainLoss<0.5, trainLoss
    assert 0.7<trainAcc<=1, trainAcc
    assert 0.7<testAcc<=1, testAcc
def SGD(params:list[torch.Tensor], lr: float, batchSize:int):
    '''
    Minibatch stochastic gradient descent.

    params: the model parameters to update in place.
    lr: learning rate, which sets the size of each update step.
    batchSize: number of samples in the batch; gradients are divided by it.

    Parameters whose grad is None (no gradient has been computed for them)
    are left untouched instead of raising a TypeError.
    '''
    with torch.no_grad():
        for param in params:
            if param.grad is None:
                continue
            param -= lr*param.grad/batchSize
            # Clear the gradient so the next backward pass starts from zero.
            param.grad.zero_()
def predictCh3(net, testIter, n=6):
    '''
    Predict labels for the first batch and show up to n images in one row.

    net: callable mapping a batch X to per-class scores.
    testIter: iterable yielding (X, y) batches; only the first batch is used.
    n: maximum number of images to show; clamped to the batch size.
    Each image is titled with its true label and, on a second line, the prediction.
    Raises ValueError when testIter yields no batch.
    '''
    firstBatch = next(iter(testIter), None)
    if firstBatch is None:
        raise ValueError('testIter yielded no batches')
    X, y = firstBatch
    # A batch smaller than n would make the reshape below fail.
    count = min(n, len(X))
    trues = getFashionMnistLabels(y[:count])
    preds = getFashionMnistLabels(net(X).argmax(axis=1)[:count])
    titles = [true + '\n' + pred for true, pred in zip(trues, preds)]
    # reshape drops the channel dimension, so each sample must hold 28*28 values.
    showImages(X[:count].reshape((count, 28, 28)), 1, count, titles=titles)
    plt.show()