比较与torch.optim.Optimizer.step的功能差异

torch.optim.Optimizer.step

torch.optim.Optimizer.step(closure)

更多内容详见torch.optim.Optimizer.step

mindspore.nn.TrainOneStepCell

class mindspore.nn.TrainOneStepCell(
    network,
    optimizer,
    sens=1.0
)((*inputs))

更多内容详见mindspore.nn.TrainOneStepCell

使用方式

PyTorch:是Optimizer这个抽象类的抽象方法,需要由Optimizer的子类继承后具体实现,返回损失值。

MindSpore:是1个类,需要把networkoptimizer作为参数传入,且需要调用construct方法返回损失值。