首页/文章/ 详情

线性神经网络3-线性回归的简洁实现

1天前浏览3















































































import numpy as npimport torchfrom torch.utils import datafrom torch import nn
# 1. 生成数据集def syntheticData(w:torch.Tensor, b:float, numExaples:int)->tuple[torch.Tensor]:  '''  生成y=Xw+b+噪声  '''  X:torch.Tensor = torch.normal(01, (numExaples, len(w)))  y:torch.Tensor = torch.matmul(X, w) + b  y += torch.normal(00.01, y.shape) # 添加噪声数据  return X, y.reshape((-11))trueW:torch.Tensor = torch.tensor([2.0, -3.4])trueB:float = 4.2features, labels = syntheticData(trueW, trueB, 1000)
# 2. 读取数据集def loadArray(dataArrays:torch.Tensor, batchSize:int, isTrain:bool=True):  '''  构造一个Pytorch数据迭代器  dataArrays: 数据集,包括特征数据和标签数据  batchSize:  小批量样本数  isTrain:    表示迭代器对象是否在每个迭代周期内打乱数据  '''  dataSet:data.TensorDataset = data.TensorDataset(*dataArrays)  return data.DataLoader(dataSet, batchSize, shuffle=isTrain)batchSize:int = 10dataIter = loadArray((features, labels), batchSize)
# 3. 定义模型# a. Sequential类将多个层串联在一起。当给定输入数据时,Sequential实例将数据传入到第一层,# 然后将第一层的输出作为第二层的输入# b. 将两个参数传递到nn.Linear中。第一个指定输入特征形状,即2,# 第二个指定输出特征形状,输出特征形状为单个标量,因此为1。net:nn.Sequential = nn.Sequential(nn.Linear(21))
# 4. 初始化模型参数# 通过net[0]选择网络中的第一个图层net[0].weight.data.normal_(0.00.01)net[0].bias.data.fill_(0.0)
# 5. 定义损失函数# 计算均方误差使用的是MSELoss类,也称为平方L2范数。loss = nn.MSELoss()
# 6. 定义优化算法# 指定优化的参数(可通过net.parameters()从我们的模型中获得)# 以及优化算法所需的超参数字典, 小批量随机梯度下降只需要设置lr值。trainer = torch.optim.SGD(net.parameters(), lr=0.03)
# 7. 训练# 模型训练步骤# a. 通过调用net(X)生成预测并计算损失l(前向传播)。# b. 通过进行反向传播来计算梯度。# c. 通过调用优化器来更新模型参数。numEpochs:int = 3for epoch in range(numEpochs):  for X, y in dataIter:    # a. 通过调用net(X)生成预测并计算损失l(前向传播)。    l = loss(net(X), y)    trainer.zero_grad()        # b. 通过进行反向传播来计算梯度。    l.backward()        # c. 通过调用优化器来更新模型参数。    trainer.step()    # 每一次循环后计算全局损失的平均值  l = loss(net(features), labels)  print(f'epoch {epoch+1}, loss {l:f}')
# 完成模型训练后显示训练参数和真实参数间的误差w = net[0].weight.datab = net[0].bias.dataprint(f'w的估计误差: {trueW - w.reshape(trueW.shape)}')print(f'b的估计误差: {trueB - b}')

来源:檐苔
UM
著作权归作者所有,欢迎分享,未经许可,不得转载
首次发布时间:2025-08-26
最近编辑:1天前
青瓦松
硕士 签名征集中
获赞 17粉丝 0文章 45课程 0
点赞
收藏
作者推荐

ANSA二次开发:模型结构树重构(2)

在基于ANSA的二次开发中,许多功能开发都需要利用到模型的结构树,如模型任务分发功能定制,连接(焊点、粘胶、焊缝等)的自动化生成都需要用到结构树。下面介绍三种结构树构造方法并对每种方案的性能进行说明。方法二主要思路:获取原始结构树中的所有节点,并通过关键字Hierarchy读取其原始节点链字符串构建一个一一对应的字典,然后从根节点开始从上向下一个分支一个分支的逐层构建。前期将原始节点链字符串进行保存,获取若由于某种原因(如网格划分)导致结构树破坏,利用该方法再将原始结构树进行复原。针对一个完整的车身(BIW)CAD数据,进行构造,64核电脑用时0.12秒。from ansa import guitkfrom ansa import basefrom typing import List, Tuple, Dict, Uniondeck = base.CurrentDeck()PARTTYPETUPLE:Tuple[str] = ('ANSAPART', 'ANSAGROUP')VIEWCOLNAMETUPLE:Tuple[str] = ('Module Id', 'Name', 'Status') # ListView列名称def setItemText(item:object, name:str, id:str='')->None: ''' 对输入的item设置名称和ID ''' guitk.BCListViewItemSetText(item, VIEWCOLNAMETUPLE.index('Name'), name) if id: guitk.BCListViewItemSetText(item, VIEWCOLNAMETUPLE.index('Module Id'), id) return None def createItem(parentItem:object, part:base.Entity, isTop:bool=False)->object: ''' 功能: 为输入的part在parentItem下创建一个新的item 参数: parentItem object 准备新创建item的父节点,为item或listView类型 part: Entity 一个part或group对象 isTop bool 判断需要创建节点是否为根节点的判断器 返回值: 返回生成的item对象 ''' if not isTop: item:object = guitk.BCListViewItemAddChild(parentItem) else: item:object = guitk.BCListViewAddTopLevelItem(parentItem) if part.ansa_type(deck)==PARTTYPETUPLE[0]: partId: str = part.get_entity_values(deck, ['Module Id', ])['Module Id'] else: partId: str = '' setItemText(item, part._name.strip(), id=partId) guitk.BCListViewItemSetUserData(item, part) return itemdef loadModelTree02()->None: partIter:List[base.Entity] = base.CollectEntitiesI(deck, None, PARTTYPETUPLE) hierToPartDict:Dict[str,List[base.Entity]] = dict() for part in partIter: tree:str = part.get_entity_values(deck, ['Hierarchy', ])['Hierarchy'] tree = tree.strip() try: hierToPartDict[tree].append(part) except KeyError: hierToPartDict[tree] = [part] crtTreeBranch(leftModelView, '', hierToPartDict, isTop=True)def crtTreeBranch(parent:object, name:str, hierToPartDict:Dict[str,List[base.Entity]], isTop:bool=False)->None: children:List[base.Entity] = hierToPartDict.get(name, None) if not children: return None for child in children: item = leftPartToItemDict.get(child, None) if not item: item = createItem(parent, child, isTop=isTop) leftPartToItemDict[child] = item if child.ansa_type(deck)==PARTTYPETUPLE[0]:continue crtTreeBranch(item, f'{name}{child._name.strip()}/', hierToPartDict) return Noneif __name__ == '__main__': # 用于存储结构树数据,实现part和新创建的item一一对应 leftPartToItemDict:Dict[str, List[object]] = dict() mainWindow:object = guitk.BCWindowCreate('结构树构造', guitk.constants.BCOnExitDestroy) mainBox:object = guitk.BCVBoxCreate(mainWindow) leftModelView:object = guitk.BCListViewCreate(mainBox, 3, VIEWCOLNAMETUPLE, True) guitk.BCListViewSetIsRootDecorated(leftModelView, True) guitk.BCListViewSetFilterEnabled(leftModelView, True) guitk.BCListViewSetSelectionMode(leftModelView, guitk.constants.BCMulti) loadModelTree02() guitk.BCShow(mainWindow)未经作者同意,不得转载该文!!!来源:檐苔

未登录
还没有评论
课程
培训
服务
行家
VIP会员 学习计划 福利任务
下载APP
联系我们
帮助与反馈