RhinoFreak 发表于 2022-4-15 13:40

离线强化学习(Offline RL)系列3: (算法篇) TD3+BC 算法详解 ...

【更新日志】


论文信息:Scott Fujimoto, Shixiang Shane Gu: “A Minimalist Approach to Offline Reinforcement Learning”, 2021; arXiv:2106.06860.
本文是Google Brain团队和McGill大学合作,由 TD3、BCQ的作者 Fujimoto 提出并发表在NeurIPS2021顶会上的文章,本文方法最大的优点是:方法简单、无任何复杂数学公式、可实现性强(开源)、对比实验非常充分(满分推荐),正如标题一样(A minimalist approach)。
摘要: 相比于几篇博客讲过的BCQ(通过扰动网络生成动作,不断将学习策略和行为策略拉进)、BEAR(通过支撑集匹配避免分布匹配的问题)、BRAC(通过VP和PR两个方法正则化)以及REM(通过随机集成混合方法对多个值函数求取凸优化最优的鲁棒性)方法。本文作者提出的TD3+BC方法,结构简单,仅在值函数上添加一个行为克隆(BC)的正则项,并对state进行normalizing,简单的对TD3修改了几行代码就可以与前几种方法相媲美,结果表明:TD3+BC效果好,训练时间也比其他少很多。1. Offline RL的一些挑战。


[*]实现和Tune的复杂性(Implementation and Tuning Complexities), 在强化学习中,算法的实现、论文的复现都是一个非常难的问题,很多算法并没法去复现,即使相同的seed有时候未必也能达到效果。同样在Offline中仍然存在,此外在Offline中还要解决分布偏移、OODd等之外的一些问题。
[*]额外算力需求(Extra Computation Requirement),由于过于复杂的数学优化、过多的超参数等算法的执行带来了很长的训练时间,导致不得不增加计算资源来学习算法使得其收敛。
[*]训练策略的不稳定性(Instability of Trained Policies),强化学习领域的不稳定性众所周知,所以Offline RL如何才能与Supervised leanring一样很稳定是一个重要的研究问题。
[*]Offline RL改进问题(algorithmic/Coding/Optimization),包括了代码层次的优化改进和理论结构方面的改进等。
其实本文并不是去解决传统的offline RL中的一些诸如分布偏移、OOD、过估计以及等等这些问题,而是去解决如何简单、快速、高效的实现算法的实现与高效运行问题,因此作者面对这些问题,发出疑问并给出方法:


2. TD3+BC原理

2.1 TD3+BC相比于其他的优势

下图是TD3+BC算法相对于CQL、Fish-BRC算法的复杂性对比,从表中我们可以看到CQL和Fish-BRC在算法(algorithmic)上有了很多的变种,使用生成网络,近似 https://www.zhihu.com/equation?tex=logsumexp 等,而TD3+BC仅仅添加了一个BC term和Normalized state,足够的简单。


2.2 理论部分

对于经典的DDPG、TD3等算法来讲, 策略梯度的计算根据David sliver提出的如下定义,即求解状态-动作值函数的期望值。

https://www.zhihu.com/equation?tex=%5Cpi%3D%5Coperatorname%7Bargmax%7D+%5Cmathbb%7BE%7D_%7B%28s%2C+a%29+%5Csim+%5Cmathcal%7BD%7D%7D%5BQ%28s%2C+%5Cpi%28s%29%29%5D+%5C%5C
本文中,作者为了尽可能的让两个动作接近添加了一个正则项 https://www.zhihu.com/equation?tex=%28%5Cpi%28s%29-a%29 以及   ,

https://www.zhihu.com/equation?tex=%5Cpi%3D%5Cunderset%7B%5Cpi%7D%7B%5Coperatorname%7Bargmax%7D%7D+%5Cmathbb%7BE%7D_%7Bs+%5Csim+%5Cmathcal%7BD%7D%7D%5BQ%28s%2C+%5Cpi%28s%29%29%5D+%5Crightarrow+%5Cpi%3D%5Cunderset%7B%5Cpi%7D%7B%5Coperatorname%7Bargmax%7D%7D+%5Cmathbb%7BE%7D_%7B%28s%2C+a%29+%5Csim+%5Cmathcal%7BD%7D%7D%5Cleft%5B%5Clambda+Q%28s%2C+%5Cpi%28s%29%29-%28%5Cpi%28s%29-a%29%5E%7B2%7D%5Cright%5D+%5C%5C
个人看法: 有点像BCQ中的让学习策略和行为策略之间的距离减少那种意思,只不过添加到正则项里面.

https://www.zhihu.com/equation?tex=%5Comega+%5Cleftarrow+%5Coperatorname%7Bargmin%7D_%7B%5Comega%7D+%5Csum%28a-%5Ctilde%7Ba%7D%29%5E%7B2%7D%2BD_%7B%5Cmathrm%7BKL%7D%7D%28%5Cmathcal%7BN%7D%28%5Cmu%2C+%5Csigma%29+%5C%7C+%5Cmathcal%7BN%7D%280%2C1%29%29+%5C%5C+%5Cphi+%5Cleftarrow+%5Coperatorname%7Bargmax%7D_%7B%5Cphi%7D+%5Csum+Q_%7B%5Ctheta_%7B1%7D%7D%5Cleft%28s%2C+a%2B%5Cxi_%7B%5Cphi%7D%28s%2C+a%2C+%5CPhi%29%5Cright%29%2C+a+%5Csim+G_%7B%5Comega%7D%28s%29%5C%5C
另外一个技术点就是从代码执行层面的优化,即Normalize State,具体的Normalize过程如公式所示:

https://www.zhihu.com/equation?tex=s_%7Bi%7D%3D%5Cfrac%7Bs_%7Bi%7D-%5Cmu_%7Bi%7D%7D%7B%5Csigma_%7Bi%7D%2B%5Cepsilon%7D+%5C%5C
其中的 https://www.zhihu.com/equation?tex=%5Cepsilon 表示一个normalization常量,作者在文中使用了 https://www.zhihu.com/equation?tex=%5Cepsilon%3D10%5E%7B-3%7D, https://www.zhihu.com/equation?tex=%5Cmu_%7Bi%7D 和 https://www.zhihu.com/equation?tex=%5Csigma_%7Bi%7D 表示期望和标准差(standard deviation)。
实验效果(关于纵坐标Percent difference后文有说明,本部分只看效果)

最后一个技术点就是关于   的求解,作者给出了计算公式,并在后文中说取值为 https://www.zhihu.com/equation?tex=%5Clambda%3D2.5 的时候效果最好, 实验部分有作者做的ablation实验证明。

https://www.zhihu.com/equation?tex=%5Clambda%3D%5Cfrac%7B%5Calpha%7D%7B%5Cfrac%7B1%7D%7BN%7D+%5Csum_%7B%5Cleft%28s_%7Bi%7D%2C+a_%7Bi%7D%5Cright%29%7D%5Cleft%7CQ%5Cleft%28s_%7Bi%7D%2C+a_%7Bi%7D%5Cright%29%5Cright%7C%7D+%5C%5C
最后贴出作者在TD3代码上的改动部分==》TD3+BC算法实现


2.3 经典的Rebuttal场面

此外,我们看一下作者如何rebuttle这些OpenReview提出的审稿意见,


其实这部分蛮有意思的,我们发现大多数普通人的工作还是集中在对算法的小部分优化(数学大佬和代码大神略过),这里作者教你手把手给审稿人回复(建议收藏,特别是第2条)
审稿人: (1)首先,该方法的新颖性似乎有点有限。作者似乎直接使 RL+BC 适应离线设置,只是他们添加了状态归一化,这也不是新的。作者也没有从理论上证明这种方法的合理性。例如,作者应该证明该方法可以保证安全的策略改进,并且享有可比或更好的策略改进保证 w.r.t.先前的方法。如果没有理论依据,并且考虑到该方法的当前形式,我认为该方法有点增量。 (2)此外,实证评估并不彻底。作者仅在 D4RL 中的简单 mujoco 环境中评估了该方法。目前尚不清楚该方法是否可以很好地执行更多无向多任务数据集,例如蚂蚁迷宫和厨房,以及更复杂的操作任务,例如 D4RL 中的 adroit。似乎该方法在随机数据集上表现不佳。这是一个主要限制吗?我还认为作者应该将状态归一化添加到所有基线以确保公平比较,因为状态归一化不是 RL 中的新技术。 (3)最后,我认为比较不完整。作者还应该将该方法与最近的无模型离线 RL 方法(如 )和基于模型的方法(如 )进行比较,后者在随机和中等重放数据集上获得了更好的性能。 总的来说,鉴于上述评论,我会投票支持弱拒绝。
下面我们看作者的神奇巧妙回复
作者回复: (1)关于新颖性:我们完全不同意我们的算法在新颖性方面是递增的(我们在相关工作中强调了许多类似的算法)。然而,我们的主要主张/贡献与其说这是最好的离线 RL 算法,或者说它特别新颖,不如说是令人惊讶的观察,即使用非常简单的技术可以匹配/优于当前算法。希望 TD3+BC 可以用作易于实现的基线或其他添加(例如 S4RL)的起点,同时消除更复杂方法所需的许多不必要的复杂性、超参数调整或计算成本. (2)关于经验评估:据我们所知,我们最强的基线 Fisher-BRC 被认为是无模型算法的 SOTA,最近在 ICML 上发表。 (3)由于 D4RL 结果的标准化,我们可以直接与建议的基线进行比较(我们会将这些结果包含在最终草案中)。我们在下面报告这些,但我们想说明两点: (4)MOReL 和 MOPO 来自不同的方法系列(基于模型),并且都使用特定于环境的超参数。 S4RL 与我们的方法相切,只需将 CQL 替换为 TD3+BC,就可以很容易地将其添加到我们的方法中。我们的方法可以说更适合基础算法的这些类型的添加,因为超参数更少,这意味着我们不必担心变化之间的交互作用。 最终,我们没有发现添加状态归一化可以为基线提供相同水平的好处,这可能是因为这些方法需要超参数调整来补偿额外的修改。
挺有意思的,学习收藏吧!
3. 实验及过程分析

3.1 实验超参数

这部分是作者实验的一些基础,挺良心的,具体到了每一个实验环境的版本号




这部分特意说明一下作者的良心部分:代码版本都放出来了


3.2 衡量指标:百分比差异(Percent Difference)

这部分公式是作者实验的参考基准计算方式,其中在博客也提出了关于差距百分比的疑问,特意查了了一下计算过程(备注,有的地方可能用了绝对值):


3.3 实验验证与结果简要分析

说明:关于D4RL数据集的组成、安装和解释请参考博文 离线强化学习(Offline RL)系列2: (环境篇)D4RL数据集简介、安装及错误解决
本实验参数 HC = HalfCheetah, Hop = Hopper, W = Walker, r = random, m = medium, mr = medium-replay, me = medium-expert, e = expert. While online algorithms (TD3) typically have small episode variances per trained policy (as they should at convergence),
3.3.1 D4RL验证讨论





3.3.2 运行训练时间讨论

可以从实验结果中很直白的看到,CQL、FishBRC与TD3+BC( https://www.zhihu.com/equation?tex=%5Cleq+39m )的运行时间, 其实这与算法的复杂性紧密相关,对于TD3来说只需要去根据超参数学习网络即可,但对于CQL等算法,需要学习一堆的参数。




3.3.3 消融(ablation)实验(如何确定 ?)

这部分其实对比了vanillaBC方法和区别,同时就参数做了对比得出了最好的。








4. 代码实例分析



def train(self, replay_buffer, batch_size=256):
                self.total_it += 1

                # Sample replay buffer
                state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)

                with torch.no_grad():
                        # Select action according to policy and add clipped noise
                        noise = (
                                torch.randn_like(action) * self.policy_noise
                        ).clamp(-self.noise_clip, self.noise_clip)
                       
                        next_action = (
                                self.actor_target(next_state) + noise
                        ).clamp(-self.max_action, self.max_action)

                        # Compute the target Q value
                        target_Q1, target_Q2 = self.critic_target(next_state, next_action)
                        target_Q = torch.min(target_Q1, target_Q2)
                        target_Q = reward + not_done * self.discount * target_Q

                # Get current Q estimates
                current_Q1, current_Q2 = self.critic(state, action)

                # Compute critic loss
                critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)

                # Optimize the critic
                self.critic_optimizer.zero_grad()
                critic_loss.backward()
                self.critic_optimizer.step()

                # Delayed policy updates
                if self.total_it % self.policy_freq == 0:

                        # Compute actor loss
                        pi = self.actor(state)
                        Q = self.critic.Q1(state, pi)
                        lmbda = self.alpha/Q.abs().mean().detach()

                        actor_loss = -lmbda * Q.mean() + F.mse_loss(pi, action)
                       
                        # Optimize the actor
                        self.actor_optimizer.zero_grad()
                        actor_loss.backward()
                        self.actor_optimizer.step()

                        # Update the frozen target models
                        for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
                                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

                        for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
                                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
def eval_policy(policy, env_name, seed, mean, std, seed_offset=100, eval_episodes=10):
        eval_env = gym.make(env_name)
        eval_env.seed(seed + seed_offset)

        avg_reward = 0.
        for _ in range(eval_episodes):
                state, done = eval_env.reset(), False
                while not done:
                        state = (np.array(state).reshape(1,-1) - mean)/std
                        action = policy.select_action(state)
                        state, reward, done, _ = eval_env.step(action)
                        avg_reward += reward

        avg_reward /= eval_episodes
        d4rl_score = eval_env.get_normalized_score(avg_reward) * 100

        print("---------------------------------------")
        print(f"Evaluation over {eval_episodes} episodes: {avg_reward:.3f}, D4RL score: {d4rl_score:.3f}")
        print("---------------------------------------")
        return d4rl_score参考文献

. Scott Fujimoto, Shixiang Shane Gu: “A Minimalist Approach to Offline Reinforcement Learning”, 2021; arXiv:2106.06860. . A Minimalist Approach to Offline Reinforcement Learning, OpenReview . percent-difference, percent-difference
<hr/>OfflineRL推荐阅读

离线强化学习(Offline RL)系列3: (算法篇) REM(Random Ensemble Mixture)算法详解与实现
离线强化学习(Offline RL)系列3: (算法篇)策略约束 - BRAC算法原理详解与实现(经验篇)
离线强化学习(Offline RL)系列3: (算法篇)策略约束 - BEAR算法原理详解与实现
离线强化学习(Offline RL)系列3: (算法篇)策略约束-BCQ算法详解与实现
离线强化学习(Offline RL)系列2: (环境篇)D4RL数据集简介、安装及错误解决
离线强化学习(Offline RL)系列1:离线强化学习原理入门

jquave 发表于 2022-4-15 13:42

更正一下,是NeurIPS2021不是2020

七彩极 发表于 2022-4-15 13:50

感谢反馈,已更正[握手]
页: [1]
查看完整版本: 离线强化学习(Offline RL)系列3: (算法篇) TD3+BC 算法详解 ...