class mindspore_rl.utils.BatchWrite[源代码]

写一个list的参数覆盖到目标值。

警告

  • 这是一个实验特性,未来有可能被修改或删除。

支持平台:

GPU CPU

样例:

>>> import mindspore
>>> from mindspore import nn
>>> from mindspore.common.parameter import Parameter, ParameterTuple
>>> from mindspore_rl.utils import BatchWrite
>>> class SourceNet(nn.Cell):
...   def __init__(self):
...     super(SourceNet, self).__init__()
...     self.a = Parameter(Tensor(0.5, mstype.float32), name="a")
...     self.dense = nn.Dense(in_channels=16, out_channels=1, weight_init=0)
>>> class DstNet(nn.Cell):
...   def __init__(self):
...     super(DstNet, self).__init__()
...     self.a = Parameter(Tensor(0.1, mstype.float32), name="a")
...     self.dense = nn.Dense(in_channels=16, out_channels=1)
>>> class Write(nn.Cell):
...   def __init__(self, dst, src):
...     super(Write, self).__init__()
...     self.w = BatchWrite()
...     self.dst = ParameterTuple(dst.trainable_params())
...     self.src = ParameterTuple(src.trainable_params())
...   def construct(self):
...     success = self.w(self.dst, self.src)
...     return success
>>> dst_net = DstNet()
>>> source_net = SourceNet()
>>> nets = nn.CellList()
>>> nets.append(dst_net)
>>> nets.append(source_net)
>>> success = Write(nets[0], nets[1])()
construct(dst, src)[源代码]

src 中的参数覆盖到 dst

参数:
  • dst (tuple(Parameters)) - 目标位置的参数列表。

  • src (tuple(Parameters)) - 源位置的参数列表。

返回:

True。

class mindspore_rl.utils.BatchRead[源代码]

读一个list的参数覆盖到目标值。

警告

  • 这是一个实验特性,未来有可能被修改或删除。

支持平台:

GPU CPU

样例:

>>> import mindspore
>>> from mindspore import nn
>>> from mindspore.common.parameter import Parameter, ParameterTuple
>>> from mindspore_rl.utils import BatchRead
>>> class SNet(nn.Cell):
...   def __init__(self):
...     super(SNet, self).__init__()
...     self.a = Parameter(Tensor(0.5, mstype.float32), name="a")
...     self.dense = nn.Dense(in_channels=16, out_channels=1, weight_init=0)
>>> class DNet(nn.Cell):
...   def __init__(self):
...     super(DNet, self).__init__()
...     self.a = Parameter(Tensor(0.1, mstype.float32), name="a")
...     self.dense = nn.Dense(in_channels=16, out_channels=1)
>>> class Read(nn.Cell):
...   def __init__(self, dst, src):
...     super(Read, self).__init__()
...     self.read = BatchRead()
...     self.dst = ParameterTuple(dst.trainable_params())
...     self.src = ParameterTuple(src.trainable_params())
...   def construct(self):
...     success = self.read(self.dst, self.src)
...     return success
>>> dst_net = DNet()
>>> source_net = SNet()
>>> nets = nn.CellList()
>>> nets.append(dst_net)
>>> nets.append(source_net)
>>> success = Read(nets[0], nets[1])()
construct(dst, src)[源代码]

读取 src 中的参数覆盖到 dst

参数:
  • dst (tuple(Parameters)) - 目标位置的参数列表。

  • src (tuple(Parameters)) - 源位置的参数列表。

返回:

True。