# Source code for mindspore_rl.policy.epsilon_greedy_policy

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
EpsilonGreedyPolicy.
"""

import mindspore
import numpy as np
from mindspore import Tensor

from mindspore_rl.policy import GreedyPolicy, RandomPolicy, policy


class EpsilonGreedyPolicy(policy.Policy):
    r"""
    Produces a sample action based on the given epsilon-greedy policy.

    For every call, a greedy action (from `input_network`) and a random action
    are both computed; the random one is selected where a uniform random draw
    is less than epsilon, otherwise the greedy one is selected. When
    `epsi_high` differs from `epsi_low`, epsilon is computed from the step as
    ``epsi_low + (epsi_high - epsi_low) * exp(-step / decay)``; otherwise the
    constant `epsi_low` is used.

    Args:
        input_network (Cell): A network returns policy action.
        size (int): Shape of epsilon. It is passed to ``np.ones`` to build the
            epsilon-related tensors. NOTE(review): the decaying branch of
            `construct` slices epsilon with 2-D begin/size arguments, so a 2-D
            shape such as ``(1, 1)`` is presumably required when
            ``epsi_high != epsi_low`` - confirm against callers.
        epsi_high (float): A high epsilon for exploration between [0, 1].
        epsi_low (float): A low epsilon for exploration between [0, epsi_high].
        decay (float): A decay factor applied to epsilon.
        action_space_dim (int): Dimensions of the action space.
        shape (tuple, optional): Shape of output action in random policy, it should be the same as
            action get in greedy policy. Default: (1,).

    Examples:
        >>> state_dim, hidden_dim, action_dim = (4, 10, 2)
        >>> input_net = FullyConnectedNet(state_dim, hidden_dim, action_dim)
        >>> policy = EpsilonGreedyPolicy(input_net, 1, 0.1, 0.1, 100, action_dim)
        >>> state = Tensor(np.ones([1, state_dim]).astype(np.float32))
        >>> step = Tensor(np.array([10,]).astype(np.float32))
        >>> output = policy(state, step)
        >>> print(output.shape)
        (1,)
    """

    def __init__(
        self,
        input_network,
        size,
        epsi_high,
        epsi_low,
        decay,
        action_space_dim,
        shape=(1,),
    ):
        super().__init__()
        self._input_network = input_network

        # Operators used by `construct`.
        self.sub = mindspore.ops.Sub()
        self.add = mindspore.ops.Add()
        self.div = mindspore.ops.Div()
        self.mul = mindspore.ops.Mul()
        self.exp = mindspore.ops.Exp()
        self.slice = mindspore.ops.Slice()
        self.squeeze = mindspore.ops.Squeeze(1)
        self.less = mindspore.ops.Less()
        self.select = mindspore.ops.Select()
        self.randreal = mindspore.ops.UniformReal()

        # Epsilon only needs to be recomputed per step when the two bounds differ.
        self.decay_epsilon = epsi_high != epsi_low
        self.epsi_low = epsi_low
        self._size = size
        # Keep the caller-supplied action shape (previously hard-coded to (1,),
        # which silently ignored the `shape` argument).
        self._shape = shape

        # Constant arrays of shape `size`, later wrapped as float32 tensors.
        self._elow_arr = np.ones(self._size) * epsi_low
        self._ehigh_arr = np.ones(self._size) * epsi_high
        self._steps_arr = np.ones(self._size)
        self._decay_arr = np.ones(self._size) * decay
        self._mone_arr = np.ones(self._size) * -1

        self._epsi_high = Tensor(self._ehigh_arr, mindspore.float32)
        self._epsi_low = Tensor(self._elow_arr, mindspore.float32)
        self._decay = Tensor(self._decay_arr, mindspore.float32)
        self._mins_one = Tensor(self._mone_arr, mindspore.float32)

        self._action_space_dim = action_space_dim
        self.greedy_policy = GreedyPolicy(self._input_network)
        self.random_policy = RandomPolicy(self._action_space_dim, shape=shape)

    # pylint:disable=W0221
    def construct(self, state, step):
        """
        The interface of the construct function.

        Args:
            state (Tensor): The input tensor for network.
            step (Tensor): The current step, effects the epsilon decay.

        Returns:
            The output action.
        """
        greedy_action = self.greedy_policy(state)
        random_action = self.random_policy()

        if self.decay_epsilon:
            # epsi = epsi_low + (epsi_high - epsi_low) * exp(-step / decay)
            epsi_sub = self.sub(self._epsi_high, self._epsi_low)
            epsi_exp = self.exp(self.mul(self._mins_one, self.div(step, self._decay)))
            epsi_mul = self.mul(epsi_sub, epsi_exp)
            epsi = self.add(self._epsi_low, epsi_mul)
            # Take the 1x1 slice starting at (0, 0), then drop axis 1.
            epsi = self.slice(epsi, (0, 0), (1, 1))
            epsi = self.squeeze(epsi)
        else:
            # Bounds are equal: use the constant epsilon value as given.
            epsi = self.epsi_low

        # Pick the random action where the random draw falls below epsilon,
        # and the greedy action elsewhere.
        cond = self.less(self.randreal(random_action.shape), epsi)
        output_action = self.select(cond, random_action, greedy_action)
        return output_action