import random
import torch

# 1. Generate the dataset
def syntheticData(w:torch.Tensor, b:float, numExamples:int)->tuple[torch.Tensor, torch.Tensor]:
    ''' Generate y = Xw + b + noise '''
    X:torch.Tensor = torch.normal(0.0, 1.0, (numExamples, len(w)))
    y:torch.Tensor = torch.matmul(X, w) + b
    y += torch.normal(0.0, 0.01, y.shape)  # add Gaussian noise
    return X, y.reshape((-1, 1))

trueW:torch.Tensor = torch.tensor([2.0, -3.4])
trueB:float = 4.2
features, labels = syntheticData(trueW, trueB, 1000)

# 2. Read the dataset
def dataIter(batchSize:int, features:torch.Tensor, labels:torch.Tensor):
    numExamples:int = len(features)
    indices:list[int] = list(range(numExamples))
    # Samples are read in random order, with no particular structure
    random.shuffle(indices)
    for idx in range(0, numExamples, batchSize):
        batchIndices:torch.Tensor = torch.tensor(
            indices[idx:min(idx+batchSize, numExamples)])
        yield features[batchIndices], labels[batchIndices]

# 3. Initialize model parameters
initW:torch.Tensor = torch.normal(0.0, 0.01, size=(2, 1), requires_grad=True)
initB:torch.Tensor = torch.zeros(1, requires_grad=True)

# 4. Define the model
def linearRegModel(X:torch.Tensor, w:torch.Tensor, b:torch.Tensor)->torch.Tensor:
    ''' Linear regression model '''
    return torch.matmul(X, w) + b

# 5. Define the loss function
def squaredLoss(predY:torch.Tensor, trueY:torch.Tensor)->torch.Tensor:
    ''' Squared loss (per sample, halved) '''
    return (predY - trueY.reshape(predY.shape))**2/2.0

# 6. Define the optimization algorithm
def SGD(params:list[torch.Tensor], lr:float, batchSize:int):
    ''' Minibatch stochastic gradient descent
    params: list of model parameters
    lr: learning rate, controls the size of each update step
    batchSize: number of samples per minibatch
    '''
    with torch.no_grad():
        for param in params:
            param -= lr*param.grad/batchSize
            param.grad.zero_()

# 7. Training
lr:float = 0.03       # learning rate
numEpochs:int = 3     # number of epochs
batchSize:int = 10    # minibatch size
net = linearRegModel
loss = squaredLoss
for epoch in range(numEpochs):
    for X, y in dataIter(batchSize, features, labels):
        # a. Compute the loss
        l = loss(net(X, initW, initB), y)  # minibatch loss on X and y
        # b. Backpropagate to compute gradients
        # Because l has shape (batchSize, 1) rather than being a scalar, all
        # elements of l are summed together, and the gradients with respect
        # to [w, b] are computed from that sum
        l.sum().backward()
        # c. Update the parameters
        SGD([initW, initB], lr, batchSize)  # update parameters using their gradients
    # After each epoch, report the mean loss over the whole dataset
    with torch.no_grad():
        train_l = loss(net(features, initW, initB), labels)
        print(f'epoch {epoch+1}, loss {float(train_l.mean()):f}')

# After training, show the error between the learned and the true parameters
print(f'estimation error of w: {trueW - initW.reshape(trueW.shape)}')
print(f'estimation error of b: {trueB - initB}')

Source: 檐苔
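Why `l.sum().backward()` followed by the division by `batchSize` inside `SGD` amounts to averaging the per-sample gradients: summing the losses before calling `backward()` accumulates the gradient of every sample, and the division recovers their mean. A minimal sketch to confirm this, assuming the definitions above have already been run (`Xb`, `yb`, `w1`, and `w2` are names introduced here for illustration only):

# Draw one minibatch from the generator defined in step 2
Xb, yb = next(dataIter(batchSize, features, labels))

# Gradient of the summed loss
w1 = initW.detach().clone().requires_grad_(True)
squaredLoss(linearRegModel(Xb, w1, initB.detach()), yb).sum().backward()

# Gradient of the mean loss
w2 = initW.detach().clone().requires_grad_(True)
squaredLoss(linearRegModel(Xb, w2, initB.detach()), yb).mean().backward()

# The gradient of the sum divided by batchSize equals the gradient of the
# mean -- exactly the update that SGD applies in step 6
print(torch.allclose(w1.grad/batchSize, w2.grad))  # True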
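For cross-checking, the same pipeline can also be written with PyTorch's built-in `nn` and `optim` modules. This is a sketch, not part of the original from-scratch exercise; it reuses `features`, `labels`, `lr`, `numEpochs`, and `batchSize` from above. Note that `nn.MSELoss` averages over the batch and omits the 1/2 factor used by `squaredLoss`, so the two versions match only up to a factor of 2 in the effective learning rate:

from torch import nn
from torch.utils import data

# Wrap the synthetic tensors from step 1 in a shuffling DataLoader
dataset = data.TensorDataset(features, labels)
loader = data.DataLoader(dataset, batch_size=batchSize, shuffle=True)

net2 = nn.Linear(2, 1)              # replaces linearRegModel, initW, and initB
nn.init.normal_(net2.weight, 0.0, 0.01)
nn.init.zeros_(net2.bias)
mse = nn.MSELoss()                  # replaces squaredLoss (averaged, no 1/2 factor)
optimizer = torch.optim.SGD(net2.parameters(), lr=lr)

for epoch in range(numEpochs):
    for X, y in loader:
        optimizer.zero_grad()
        mse(net2(X), y).backward()
        optimizer.step()
    with torch.no_grad():
        print(f'epoch {epoch+1}, loss {mse(net2(features), labels):f}')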