[{"data":1,"prerenderedAt":508},["ShallowReactive",2],{"content-query-48KVZBoQL8":3},{"_path":4,"_dir":5,"_draft":6,"_partial":6,"_locale":7,"title":8,"description":9,"date":10,"cover":11,"type":12,"category":13,"body":14,"_type":502,"_id":503,"_source":504,"_file":505,"_stem":506,"_extension":507},"/technology-blogs/zh/3213","zh",false,"","一文教你在昇思MindSpore中实现A2C算法训练","作者：irrational                                                                                 来源：华为云社区","2024-07-04","https://obs-mindspore-file.obs.cn-north-4.myhuaweicloud.com/file/2024/07/05/44ed8ee5c50743848bac13c8698b8e4c.png","technology-blogs","实践",{"type":15,"children":16,"toc":499},"root",[17,25,44,49,57,65,85,93,101,106,116,124,132,137,142,147,152,157,162,167,172,177,182,187,192,200,205,213,218,226,231,244,249,257,265,273,281,289,297,315,320,325,330,335,340,348,356,364,372,385,390,395,400,405,410,415,423,431,439,462,470,478,486,494],{"type":18,"tag":19,"props":20,"children":22},"element","h1",{"id":21},"一文教你在昇思mindspore中实现a2c算法训练",[23],{"type":24,"value":8},"text",{"type":18,"tag":26,"props":27,"children":28},"p",{},[29,31,37,39],{"type":24,"value":30},"**作者：**",{"type":18,"tag":32,"props":33,"children":34},"strong",{},[35],{"type":24,"value":36},"irrational",{"type":24,"value":38}," ",{"type":18,"tag":32,"props":40,"children":41},{},[42],{"type":24,"value":43},"来源：华为云社区",{"type":18,"tag":26,"props":45,"children":46},{},[47],{"type":24,"value":48},"Advantage Actor-Critic 
(A2C)算法是一个强化学习算法，它结合了策略梯度（Actor）和价值函数（Critic）的方法。A2C算法在许多强化学习任务中表现优越，因为它能够利用价值函数来减少策略梯度的方差，同时直接优化策略。",{"type":18,"tag":26,"props":50,"children":51},{},[52],{"type":18,"tag":32,"props":53,"children":54},{},[55],{"type":24,"value":56},"01",{"type":18,"tag":26,"props":58,"children":59},{},[60],{"type":18,"tag":32,"props":61,"children":62},{},[63],{"type":24,"value":64},"A2C算法的核心思想",{"type":18,"tag":66,"props":67,"children":68},"ul",{},[69,75,80],{"type":18,"tag":70,"props":71,"children":72},"li",{},[73],{"type":24,"value":74},"**Actor：**根据当前策略选择动作。",{"type":18,"tag":70,"props":76,"children":77},{},[78],{"type":24,"value":79},"**Critic：**评估一个状态-动作对的值（通常是使用状态值函数或动作值函数）。",{"type":18,"tag":70,"props":81,"children":82},{},[83],{"type":24,"value":84},"**优势函数（Advantage Function）：**用来衡量某个动作相对于平均水平的好坏，通常定义为A(s,a)=Q(s,a)−V(s)。",{"type":18,"tag":26,"props":86,"children":87},{},[88],{"type":18,"tag":32,"props":89,"children":90},{},[91],{"type":24,"value":92},"02",{"type":18,"tag":26,"props":94,"children":95},{},[96],{"type":18,"tag":32,"props":97,"children":98},{},[99],{"type":24,"value":100},"A2C算法的伪代码",{"type":18,"tag":26,"props":102,"children":103},{},[104],{"type":24,"value":105},"以下是A2C算法的伪代码：",{"type":18,"tag":107,"props":108,"children":110},"pre",{"code":109},"Initialize policy network (actor) π with parameters θ\nInitialize value network (critic) V with parameters w\nInitialize learning rates α_θ for policy network and α_w for value network\n\nfor each episode do\n    Initialize state s\n    while state s is not terminal do\n        # Actor: select action a according to the current policy π(a|s; θ)\n        a = select_action(s, θ)\n\n        # Execute action a in the environment, observe reward r and next state s'\n        r, s' = environment.step(a)\n\n        # Critic: compute the value of the current state V(s; w)\n        V_s = V(s, w)\n\n        # Critic: compute the value of the next state V(s'; w)\n        V_s_prime = V(s', w)\n\n        # Compute the TD error 
(δ)\n        δ = r + γ * V_s_prime - V_s\n\n        # Critic: update the value network parameters w\n        w = w + α_w * δ * ∇_w V(s; w)\n\n        # Compute the advantage function A(s, a)\n        A = δ\n\n        # Actor: update the policy network parameters θ\n        θ = θ + α_θ * A * ∇_θ log π(a|s; θ)\n\n        # Move to the next state\n        s = s'\n    end while\nend for\n",[111],{"type":18,"tag":112,"props":113,"children":114},"code",{"__ignoreMap":7},[115],{"type":24,"value":109},{"type":18,"tag":26,"props":117,"children":118},{},[119],{"type":18,"tag":32,"props":120,"children":121},{},[122],{"type":24,"value":123},"03",{"type":18,"tag":26,"props":125,"children":126},{},[127],{"type":18,"tag":32,"props":128,"children":129},{},[130],{"type":24,"value":131},"解释",{"type":18,"tag":26,"props":133,"children":134},{},[135],{"type":24,"value":136},"**1、初始化：**初始化策略网络（Actor）和价值网络（Critic）的参数，以及它们的学习率。",{"type":18,"tag":26,"props":138,"children":139},{},[140],{"type":24,"value":141},"**2、循环每个Episode：**在每个Episode开始时，初始化状态。",{"type":18,"tag":26,"props":143,"children":144},{},[145],{"type":24,"value":146},"**3、选择动作：**根据当前策略从Actor中选择动作。",{"type":18,"tag":26,"props":148,"children":149},{},[150],{"type":24,"value":151},"**4、执行动作：**在环境中执行动作，并观察奖励和下一个状态。",{"type":18,"tag":26,"props":153,"children":154},{},[155],{"type":24,"value":156},"**5、计算状态值：**用Critic评估当前状态和下一个状态的值。",{"type":18,"tag":26,"props":158,"children":159},{},[160],{"type":24,"value":161},"**6、计算TD误差：**计算时序差分误差（Temporal Difference 
Error），它是当前奖励加上下一个状态的折扣值与当前状态值的差。",{"type":18,"tag":26,"props":163,"children":164},{},[165],{"type":24,"value":166},"**7、更新Critic：**根据TD误差更新价值网络的参数。",{"type":18,"tag":26,"props":168,"children":169},{},[170],{"type":24,"value":171},"**8、计算优势函数：**使用TD误差计算优势函数。",{"type":18,"tag":26,"props":173,"children":174},{},[175],{"type":24,"value":176},"**9、更新Actor：**根据优势函数更新策略网络的参数。",{"type":18,"tag":26,"props":178,"children":179},{},[180],{"type":24,"value":181},"**10、更新状态：**移动到下一个状态，重复上述步骤，直到Episode结束。",{"type":18,"tag":26,"props":183,"children":184},{},[185],{"type":24,"value":186},"这个伪代码展示了A2C算法的核心步骤，实际实现中可能会有更多细节，如使用折扣因子γ、多个并行环境等。",{"type":18,"tag":26,"props":188,"children":189},{},[190],{"type":24,"value":191},"代码如下：",{"type":18,"tag":107,"props":193,"children":195},{"code":194},"import argparse\n\nfrom mindspore import context\nfrom mindspore import dtype as mstype\nfrom mindspore.communication import init\n\nfrom mindspore_rl.algorithm.a2c import config\nfrom mindspore_rl.algorithm.a2c.a2c_session import A2CSession\nfrom mindspore_rl.algorithm.a2c.a2c_trainer import A2CTrainer\n\nparser = argparse.ArgumentParser(description=\"MindSpore Reinforcement A2C\")\nparser.add_argument(\"--episode\", type=int, default=10000, help=\"total episode numbers.\")\nparser.add_argument(\n    \"--device_target\",\n    type=str,\n    default=\"CPU\",\n    choices=[\"CPU\", \"GPU\", \"Ascend\", \"Auto\"],\n    help=\"Choose a devioptions.device_targece to run the ac example(Default: Auto).\",\n)\nparser.add_argument(\n    \"--precision_mode\",\n    type=str,\n    default=\"fp32\",\n    choices=[\"fp32\", \"fp16\"],\n    help=\"Precision mode\",\n)\nparser.add_argument(\n    \"--env_yaml\",\n    type=str,\n    default=\"../env_yaml/CartPole-v0.yaml\",\n    help=\"Choose an environment yaml to update the a2c example(Default: CartPole-v0.yaml).\",\n)\nparser.add_argument(\n    \"--algo_yaml\",\n    type=str,\n    default=None,\n    help=\"Choose an algo yaml to update the a2c example(Default: 
None).\",\n)\nparser.add_argument(\n    \"--enable_distribute\",\n    type=bool,\n    default=False,\n    help=\"Train in distribute mode (Default: False).\",\n)\nparser.add_argument(\n    \"--worker_num\",\n    type=int,\n    default=2,\n    help=\"Worker num (Default: 2).\",\n)\noptions, _ = parser.parse_known_args()\n",[196],{"type":18,"tag":112,"props":197,"children":198},{"__ignoreMap":7},[199],{"type":24,"value":194},{"type":18,"tag":26,"props":201,"children":202},{},[203],{"type":24,"value":204},"首先初始化参数，然后运行。",{"type":18,"tag":107,"props":206,"children":208},{"code":207},"episode=options.episode\n\"\"\"Train a2c\"\"\"\nif options.device_target != \"Auto\":\n    context.set_context(device_target=options.device_target)\nif context.get_context(\"device_target\") in [\"CPU\", \"GPU\"]:\n    context.set_context(enable_graph_kernel=True)\ncontext.set_context(mode=context.GRAPH_MODE)\ncompute_type = (\n    mstype.float32 if options.precision_mode == \"fp32\" else mstype.float16\n)\nconfig.algorithm_config[\"policy_and_network\"][\"params\"][\n    \"compute_type\"\n] = compute_type\nif compute_type == mstype.float16 and options.device_target != \"Ascend\":\n    raise ValueError(\"Fp16 mode is supported by Ascend backend.\")\nis_distribte = options.enable_distribute\nif is_distribte:\n    init()\n    context.set_context(enable_graph_kernel=False)\n    config.deploy_config[\"worker_num\"] = options.worker_num\na2c_session = A2CSession(options.env_yaml, options.algo_yaml, is_distribte)\n",[209],{"type":18,"tag":112,"props":210,"children":211},{"__ignoreMap":7},[212],{"type":24,"value":207},{"type":18,"tag":26,"props":214,"children":215},{},[216],{"type":24,"value":217},"设置上下文管理器。",{"type":18,"tag":107,"props":219,"children":221},{"code":220},"import sys\nimport time\nfrom io import StringIO\n\nclass RealTimeCaptureAndDisplayOutput(object):\n    def __init__(self):\n        self._original_stdout = sys.stdout\n        self._original_stderr = sys.stderr\n        
self.captured_output = StringIO()\n\n    def write(self, text):\n        self._original_stdout.write(text)  # 实时打印\n        self.captured_output.write(text)   # 保存到缓冲区\n\n    def flush(self):\n        self._original_stdout.flush()\n        self.captured_output.flush()\n\n    def __enter__(self):\n        sys.stdout = self\n        sys.stderr = self\n        return self\n\n    def __exit__(self, exc_type, exc_val, exc_tb):\n        sys.stdout = self._original_stdout\n        sys.stderr = self._original_stderr\n\nepisode=10\n# dqn_session.run(class_type=DQNTrainer, episode=episode)\nwith RealTimeCaptureAndDisplayOutput() as captured_new:\n    a2c_session.run(class_type=A2CTrainer, episode=episode)\n\nimport re\nimport matplotlib.pyplot as plt\n\n# 原始输出\nraw_output = captured_new.captured_output.getvalue()\n\n# 使用正则表达式从输出中提取loss和rewards\nloss_pattern = r\"loss=(\\d+\\.\\d+)\"\nreward_pattern = r\"running_reward=(\\d+\\.\\d+)\"\nloss_values = [float(match.group(1)) for match in re.finditer(loss_pattern, raw_output)]\nreward_values = [float(match.group(1)) for match in re.finditer(reward_pattern, raw_output)]\n\n# 绘制loss曲线\nplt.plot(loss_values, label='Loss')\nplt.xlabel('Episode')\nplt.ylabel('Loss')\nplt.title('Loss Curve')\nplt.legend()\nplt.show()\n\n# 绘制reward曲线\nplt.plot(reward_values, label='Rewards')\nplt.xlabel('Episode')\nplt.ylabel('Rewards')\nplt.title('Rewards 
Curve')\nplt.legend()\nplt.show()\n",[222],{"type":18,"tag":112,"props":223,"children":224},{"__ignoreMap":7},[225],{"type":24,"value":220},{"type":18,"tag":26,"props":227,"children":228},{},[229],{"type":24,"value":230},"展示结果：",{"type":18,"tag":26,"props":232,"children":233},{},[234,239,240],{"type":18,"tag":235,"props":236,"children":238},"img",{"alt":7,"src":237},"https://obs-mindspore-file.obs.cn-north-4.myhuaweicloud.com/file/2024/07/05/f7ce31aaa9584f74b4727ba697635e7f.png",[],{"type":24,"value":38},{"type":18,"tag":235,"props":241,"children":243},{"alt":7,"src":242},"https://obs-mindspore-file.obs.cn-north-4.myhuaweicloud.com/file/2024/07/05/61f79071e6324520a112d68af55b52ad.png",[],{"type":18,"tag":26,"props":245,"children":246},{},[247],{"type":24,"value":248},"下面将详细解释昇思MindSpore A2C 算法训练配置参数的含义：",{"type":18,"tag":26,"props":250,"children":251},{},[252],{"type":18,"tag":32,"props":253,"children":254},{},[255],{"type":24,"value":256},"04",{"type":18,"tag":26,"props":258,"children":259},{},[260],{"type":18,"tag":32,"props":261,"children":262},{},[263],{"type":24,"value":264},"Actor 配置",{"type":18,"tag":107,"props":266,"children":268},{"code":267},"'actor': {\n  'number': 1,\n  'type': mindspore_rl.algorithm.a2c.a2c.A2CActor,\n  'params': {\n    'collect_environment': PyFuncWrapper\u003C\n       (_envs): GymEnvironment\u003C>\n     >,\n   'eval_environment': PyFuncWrapper\u003C\n     (_envs): GymEnvironment\u003C>\n     >,\n   'replay_buffer': None,\n   'a2c_net': ActorCriticNet\u003C\n     (common): Dense\n     (actor): Dense\n     (critic): Dense\n     (relu): LeakyReLU\u003C>\n     >},\n  'policies': [],\n  'networks': ['a2c_net']\n}\nnumber: Actor 的实例数量，这里设置为1，表示使用一个 Actor 实例。\ntype: Actor 的类型，这里使用 mindspore_rl.algorithm.a2c.a2c.A2CActor。\nparams: Actor 的参数配置。\ncollect_environment 和 eval_environment: 使用 PyFuncWrapper 包装的 GymEnvironment，用于数据收集和评估环境。\nreplay_buffer: 设置为 None，表示不使用经验回放缓冲区。\na2c_net: Actor-Critic 网络，包含一个公共层、一个 Actor 层和一个 Critic 层，以及一个 Leaky 
ReLU 激活函数。\npolicies 和 networks: Actor 关联的策略和网络，这里主要是 a2c_net。\n",[269],{"type":18,"tag":112,"props":270,"children":271},{"__ignoreMap":7},[272],{"type":24,"value":267},{"type":18,"tag":26,"props":274,"children":275},{},[276],{"type":18,"tag":32,"props":277,"children":278},{},[279],{"type":24,"value":280},"05",{"type":18,"tag":26,"props":282,"children":283},{},[284],{"type":18,"tag":32,"props":285,"children":286},{},[287],{"type":24,"value":288},"Learner配置",{"type":18,"tag":107,"props":290,"children":292},{"code":291},"'learner': {\n  'number': 1,\n  'type': mindspore_rl.algorithm.a2c.a2c.A2CLearner,\n  'params': {\n    'gamma': 0.99,\n    'state_space_dim': 4,\n    'action_space_dim': 2,\n    'a2c_net': ActorCriticNet\u003C\n      (common): Dense\n      (actor): Dense\n      (critic): Dense\n      (relu): LeakyReLU\u003C>\n    >,\n    'a2c_net_train': TrainOneStepCell\u003C\n      (network): Loss\u003C\n        (a2c_net): ActorCriticNet\u003C\n          (common): Dense\n          (actor): Dense\n          (critic): Dense\n          (relu): LeakyReLU\u003C>\n        >\n        (smoothl1_loss): SmoothL1Loss\u003C>\n      >\n      (optimizer): Adam\u003C>\n      (grad_reducer): Identity\u003C>\n    >\n  },\n  'networks': ['a2c_net_train', 'a2c_net']\n}\n",[293],{"type":18,"tag":112,"props":294,"children":295},{"__ignoreMap":7},[296],{"type":24,"value":291},{"type":18,"tag":66,"props":298,"children":299},{},[300,305,310],{"type":18,"tag":70,"props":301,"children":302},{},[303],{"type":24,"value":304},"**number：**Learner的实例数量，这里设置为1，表示使用一个 
Learner实例。",{"type":18,"tag":70,"props":306,"children":307},{},[308],{"type":24,"value":309},"**type：**Learner的类型，这里使用mindspore_rl.algorithm.a2c.a2c.A2CLearner。",{"type":18,"tag":70,"props":311,"children":312},{},[313],{"type":24,"value":314},"**params：**Learner的参数配置。",{"type":18,"tag":26,"props":316,"children":317},{},[318],{"type":24,"value":319},"**gamma：**折扣因子，用于未来奖励的折扣计算。",{"type":18,"tag":26,"props":321,"children":322},{},[323],{"type":24,"value":324},"**state_space_dim：**状态空间的维度，这里为4。",{"type":18,"tag":26,"props":326,"children":327},{},[328],{"type":24,"value":329},"**action_space_dim：**动作空间的维度，这里为2。",{"type":18,"tag":26,"props":331,"children":332},{},[333],{"type":24,"value":334},"**a2c_net：**Actor-Critic网络定义，与Actor中相同。",{"type":18,"tag":26,"props":336,"children":337},{},[338],{"type":24,"value":339},"**a2c_net_train：**用于训练的网络，包含损失函数（SmoothL1Loss）、优化器（Adam）和梯度缩减器（Identity）。",{"type":18,"tag":66,"props":341,"children":342},{},[343],{"type":18,"tag":70,"props":344,"children":345},{},[346],{"type":24,"value":347},"**networks：**Learner关联的网络，包括a2c_net_train和a2c_net。",{"type":18,"tag":26,"props":349,"children":350},{},[351],{"type":18,"tag":32,"props":352,"children":353},{},[354],{"type":24,"value":355},"06",{"type":18,"tag":26,"props":357,"children":358},{},[359],{"type":18,"tag":32,"props":360,"children":361},{},[362],{"type":24,"value":363},"Policy and Network 配置",{"type":18,"tag":107,"props":365,"children":367},{"code":366},"'policy_and_network': {\n  'type': mindspore_rl.algorithm.a2c.a2c.A2CPolicyAndNetwork,\n  'params': {\n    'lr': 0.01,\n    'state_space_dim': 4,\n    'action_space_dim': 2,\n    'hidden_size': 128,\n    'gamma': 0.99,\n    'compute_type': mindspore.float32,\n    'environment_config': {\n      'id': 'CartPole-v0',\n      'entry_point': 'gym.envs.classic_control:CartPoleEnv',\n      'reward_threshold': 195.0,\n      'nondeterministic': False,\n      'max_episode_steps': 200,\n      '_kwargs': {},\n      '_env_name': 'CartPole'\n    }\n  
}\n}\n",[368],{"type":18,"tag":112,"props":369,"children":370},{"__ignoreMap":7},[371],{"type":24,"value":366},{"type":18,"tag":66,"props":373,"children":374},{},[375,380],{"type":18,"tag":70,"props":376,"children":377},{},[378],{"type":24,"value":379},"**type：**策略和网络的类型，这里使用mindspore_rl.algorithm.a2c.a2c.A2CPolicyAndNetwork。",{"type":18,"tag":70,"props":381,"children":382},{},[383],{"type":24,"value":384},"**params：**策略和网络的参数配置。",{"type":18,"tag":26,"props":386,"children":387},{},[388],{"type":24,"value":389},"**lr：**学习率，这里为0.01。",{"type":18,"tag":26,"props":391,"children":392},{},[393],{"type":24,"value":394},"**state_space_dim 和 action_space_dim：**状态和动作空间的维度。",{"type":18,"tag":26,"props":396,"children":397},{},[398],{"type":24,"value":399},"**hidden_size：**隐藏层的大小，这里为128。",{"type":18,"tag":26,"props":401,"children":402},{},[403],{"type":24,"value":404},"**gamma：**折扣因子。",{"type":18,"tag":26,"props":406,"children":407},{},[408],{"type":24,"value":409},"**compute_type：**计算类型，这里为mindspore.float32。",{"type":18,"tag":26,"props":411,"children":412},{},[413],{"type":24,"value":414},"**environment_config：**环境配置，包括环境ID、入口、奖励阈值、最大步数等。",{"type":18,"tag":26,"props":416,"children":417},{},[418],{"type":18,"tag":32,"props":419,"children":420},{},[421],{"type":24,"value":422},"07",{"type":18,"tag":26,"props":424,"children":425},{},[426],{"type":18,"tag":32,"props":427,"children":428},{},[429],{"type":24,"value":430},"Collect Environment 配置",{"type":18,"tag":107,"props":432,"children":434},{"code":433},"'collect_environment': {\n  'number': 1,\n  'type': mindspore_rl.environment.gym_environment.GymEnvironment,\n  'wrappers': [mindspore_rl.environment.pyfunc_wrapper.PyFuncWrapper],\n  'params': {\n    'GymEnvironment': {\n      'name': 'CartPole-v0',\n      'seed': 42\n    },\n    'name': 'CartPole-v0'\n  
}\n}\n",[435],{"type":18,"tag":112,"props":436,"children":437},{"__ignoreMap":7},[438],{"type":24,"value":433},{"type":18,"tag":66,"props":440,"children":441},{},[442,447,452,457],{"type":18,"tag":70,"props":443,"children":444},{},[445],{"type":24,"value":446},"**number：**环境实例数量，这里为1。",{"type":18,"tag":70,"props":448,"children":449},{},[450],{"type":24,"value":451},"**type：**环境的类型，这里使用mindspore_rl.environment.gym_environment.GymEnvironment。",{"type":18,"tag":70,"props":453,"children":454},{},[455],{"type":24,"value":456},"**wrappers：**环境使用的包装器，这里是PyFuncWrapper。",{"type":18,"tag":70,"props":458,"children":459},{},[460],{"type":24,"value":461},"**params：**环境的参数配置，包括环境名称CartPole-v0和随机种子 42。",{"type":18,"tag":26,"props":463,"children":464},{},[465],{"type":18,"tag":32,"props":466,"children":467},{},[468],{"type":24,"value":469},"08",{"type":18,"tag":26,"props":471,"children":472},{},[473],{"type":18,"tag":32,"props":474,"children":475},{},[476],{"type":24,"value":477},"Eval Environment 配置",{"type":18,"tag":107,"props":479,"children":481},{"code":480},"'eval_environment': {\n  'number': 1,\n  'type': mindspore_rl.environment.gym_environment.GymEnvironment,\n  'wrappers': [mindspore_rl.environment.pyfunc_wrapper.PyFuncWrapper],\n  'params': {\n    'GymEnvironment': {\n      'name': 'CartPole-v0',\n      'seed': 42\n    },\n    'name': 'CartPole-v0'\n  
}\n}\n",[482],{"type":18,"tag":112,"props":483,"children":484},{"__ignoreMap":7},[485],{"type":24,"value":480},{"type":18,"tag":66,"props":487,"children":488},{},[489],{"type":18,"tag":70,"props":490,"children":491},{},[492],{"type":24,"value":493},"配置与Collect_Environment类似，用于评估模型性能。",{"type":18,"tag":26,"props":495,"children":496},{},[497],{"type":24,"value":498},"总结一下，这些配置定义了Actor-Critic算法在昇思MindSpore框架中的具体实现，包括Actor和Learner的设置、策略和网络的参数，以及训练和评估环境的配置。这个还是比较基础的。",{"title":7,"searchDepth":500,"depth":500,"links":501},4,[],"markdown","content:technology-blogs:zh:3213.md","content","technology-blogs/zh/3213.md","technology-blogs/zh/3213","md",1776506127298]