- class mindspore_rl.utils.BatchWrite[源代码]
将一组(list)参数批量写入,覆盖目标参数的值。
警告
这是一个实验特性,未来有可能被修改或删除。
- 支持平台:
GPU
CPU
样例:
>>> import mindspore
>>> from mindspore import nn, Tensor
>>> from mindspore import dtype as mstype
>>> from mindspore.common.parameter import Parameter, ParameterTuple
>>> from mindspore_rl.utils import BatchWrite
>>> class SourceNet(nn.Cell):
...     def __init__(self):
...         super(SourceNet, self).__init__()
...         self.a = Parameter(Tensor(0.5, mstype.float32), name="a")
...         self.dense = nn.Dense(in_channels=16, out_channels=1, weight_init=0)
>>> class DstNet(nn.Cell):
...     def __init__(self):
...         super(DstNet, self).__init__()
...         self.a = Parameter(Tensor(0.1, mstype.float32), name="a")
...         self.dense = nn.Dense(in_channels=16, out_channels=1)
>>> class Write(nn.Cell):
...     def __init__(self, dst, src):
...         super(Write, self).__init__()
...         self.w = BatchWrite()
...         self.dst = ParameterTuple(dst.trainable_params())
...         self.src = ParameterTuple(src.trainable_params())
...     def construct(self):
...         success = self.w(self.dst, self.src)
...         return success
>>> dst_net = DstNet()
>>> source_net = SourceNet()
>>> nets = nn.CellList()
>>> nets.append(dst_net)
>>> nets.append(source_net)
>>> success = Write(nets[0], nets[1])()
- class mindspore_rl.utils.BatchRead[源代码]
批量读取一组(list)参数,覆盖目标参数的值。
警告
这是一个实验特性,未来有可能被修改或删除。
- 支持平台:
GPU
CPU
样例:
>>> import mindspore
>>> from mindspore import nn, Tensor
>>> from mindspore import dtype as mstype
>>> from mindspore.common.parameter import Parameter, ParameterTuple
>>> from mindspore_rl.utils import BatchRead
>>> class SNet(nn.Cell):
...     def __init__(self):
...         super(SNet, self).__init__()
...         self.a = Parameter(Tensor(0.5, mstype.float32), name="a")
...         self.dense = nn.Dense(in_channels=16, out_channels=1, weight_init=0)
>>> class DNet(nn.Cell):
...     def __init__(self):
...         super(DNet, self).__init__()
...         self.a = Parameter(Tensor(0.1, mstype.float32), name="a")
...         self.dense = nn.Dense(in_channels=16, out_channels=1)
>>> class Read(nn.Cell):
...     def __init__(self, dst, src):
...         super(Read, self).__init__()
...         self.read = BatchRead()
...         self.dst = ParameterTuple(dst.trainable_params())
...         self.src = ParameterTuple(src.trainable_params())
...     def construct(self):
...         success = self.read(self.dst, self.src)
...         return success
>>> dst_net = DNet()
>>> source_net = SNet()
>>> nets = nn.CellList()
>>> nets.append(dst_net)
>>> nets.append(source_net)
>>> success = Read(nets[0], nets[1])()