比较与torch.optim.Optimizer.step的功能差异
torch.optim.Optimizer.step
torch.optim.Optimizer.step(closure)
更多内容详见torch.optim.Optimizer.step。
mindspore.nn.TrainOneStepCell
class mindspore.nn.TrainOneStepCell(
network,
optimizer,
sens=1.0
)((*inputs))
使用方式
PyTorch:是Optimizer这个抽象类的抽象方法,需要由Optimizer的子类继承后具体实现,返回损失值。
MindSpore:是1个类,需要把network和optimizer作为参数传入,且需要调用construct方法返回损失值。
