|
【更新日志】
论文信息:Scott Fujimoto, Shixiang Shane Gu: “A Minimalist Approach to Offline Reinforcement Learning”, 2021; arXiv:2106.06860.
本文是Google Brain团队和McGill大学合作,由 TD3、BCQ的作者 Fujimoto 提出并发表在NeurIPS2021顶会上的文章,本文方法最大的优点是:方法简单、无任何复杂数学公式、可实现性强(开源)、对比实验非常充分(满分推荐),正如标题一样(A minimalist approach)。
摘要: 相比于几篇博客讲过的BCQ(通过扰动网络生成动作,不断将学习策略和行为策略拉进)、BEAR(通过支撑集匹配避免分布匹配的问题)、BRAC(通过VP和PR两个方法正则化)以及REM(通过随机集成混合方法对多个值函数求取凸优化最优的鲁棒性)方法。本文作者提出的TD3+BC方法,结构简单,仅在值函数上添加一个行为克隆(BC)的正则项,并对state进行normalizing,简单的对TD3修改了几行代码就可以与前几种方法相媲美,结果表明:TD3+BC效果好,训练时间也比其他少很多。 1. Offline RL的一些挑战。
- 实现和Tune的复杂性(Implementation and Tuning Complexities), 在强化学习中,算法的实现、论文的复现都是一个非常难的问题,很多算法并没法去复现,即使相同的seed有时候未必也能达到效果。同样在Offline中仍然存在,此外在Offline中还要解决分布偏移、OODd等之外的一些问题。
- 额外算力需求(Extra Computation Requirement),由于过于复杂的数学优化、过多的超参数等算法的执行带来了很长的训练时间,导致不得不增加计算资源来学习算法使得其收敛。
- 训练策略的不稳定性(Instability of Trained Policies),强化学习领域的不稳定性众所周知,所以Offline RL如何才能与Supervised leanring一样很稳定是一个重要的研究问题。
- Offline RL改进问题(algorithmic/Coding/Optimization),包括了代码层次的优化改进和理论结构方面的改进等。
其实本文并不是去解决传统的offline RL中的一些诸如分布偏移、OOD、过估计以及等等这些问题,而是去解决如何简单、快速、高效的实现算法的实现与高效运行问题,因此作者面对这些问题,发出疑问并给出方法:
2. TD3+BC原理
2.1 TD3+BC相比于其他的优势
下图是TD3+BC算法相对于CQL、Fish-BRC算法的复杂性对比,从表中我们可以看到CQL和Fish-BRC在算法(algorithmic)上有了很多的变种,使用生成网络,近似 等,而TD3+BC仅仅添加了一个BC term和Normalized state,足够的简单。
2.2 理论部分
对于经典的DDPG、TD3等算法来讲, 策略梯度的计算根据David sliver提出的如下定义,即求解状态-动作值函数的期望值。
本文中,作者为了尽可能的让两个动作接近添加了一个正则项 以及 ,
个人看法: 有点像BCQ中的让学习策略和行为策略之间的距离减少那种意思,只不过添加到正则项里面.
另外一个技术点就是从代码执行层面的优化,即Normalize State,具体的Normalize过程如公式所示:
其中的 表示一个normalization常量,作者在文中使用了 , 和 表示期望和标准差(standard deviation)。
实验效果(关于纵坐标Percent difference后文有说明,本部分只看效果)
最后一个技术点就是关于 的求解,作者给出了计算公式,并在后文中说取值为 的时候效果最好, 实验部分有作者做的ablation实验证明。
最后贴出作者在TD3代码上的改动部分==》TD3+BC算法实现
2.3 经典的Rebuttal场面
此外,我们看一下作者如何rebuttle这些OpenReview提出的审稿意见[1],[2]
其实这部分蛮有意思的,我们发现大多数普通人的工作还是集中在对算法的小部分优化(数学大佬和代码大神略过),这里作者教你手把手给审稿人回复(建议收藏,特别是第2条)
审稿人: (1)首先,该方法的新颖性似乎有点有限。作者似乎直接使 RL+BC 适应离线设置,只是他们添加了状态归一化,这也不是新的。作者也没有从理论上证明这种方法的合理性。例如,作者应该证明该方法可以保证安全的策略改进,并且享有可比或更好的策略改进保证 w.r.t.先前的方法。如果没有理论依据,并且考虑到该方法的当前形式,我认为该方法有点增量。 (2)此外,实证评估并不彻底。作者仅在 D4RL 中的简单 mujoco 环境中评估了该方法。目前尚不清楚该方法是否可以很好地执行更多无向多任务数据集,例如蚂蚁迷宫和厨房,以及更复杂的操作任务,例如 D4RL 中的 adroit。似乎该方法在随机数据集上表现不佳。这是一个主要限制吗?我还认为作者应该将状态归一化添加到所有基线以确保公平比较,因为状态归一化不是 RL 中的新技术。 (3)最后,我认为比较不完整。作者还应该将该方法与最近的无模型离线 RL 方法(如 [1])和基于模型的方法(如 [2,3])进行比较,后者在随机和中等重放数据集上获得了更好的性能。 总的来说,鉴于上述评论,我会投票支持弱拒绝。
下面我们看作者的神奇巧妙回复
作者回复: (1)关于新颖性:我们完全不同意我们的算法在新颖性方面是递增的(我们在相关工作中强调了许多类似的算法)。然而,我们的主要主张/贡献与其说这是最好的离线 RL 算法,或者说它特别新颖,不如说是令人惊讶的观察,即使用非常简单的技术可以匹配/优于当前算法。希望 TD3+BC 可以用作易于实现的基线或其他添加(例如 S4RL)的起点,同时消除更复杂方法所需的许多不必要的复杂性、超参数调整或计算成本. (2)关于经验评估:据我们所知,我们最强的基线 Fisher-BRC 被认为是无模型算法的 SOTA,最近在 ICML 上发表。 (3)由于 D4RL 结果的标准化,我们可以直接与建议的基线进行比较(我们会将这些结果包含在最终草案中)。我们在下面报告这些,但我们想说明两点: (4)MOReL 和 MOPO 来自不同的方法系列(基于模型),并且都使用特定于环境的超参数。 S4RL 与我们的方法相切,只需将 CQL 替换为 TD3+BC,就可以很容易地将其添加到我们的方法中。我们的方法可以说更适合基础算法的这些类型的添加,因为超参数更少,这意味着我们不必担心变化之间的交互作用。 最终,我们没有发现添加状态归一化可以为基线提供相同水平的好处,这可能是因为这些方法需要超参数调整来补偿额外的修改。
挺有意思的,学习收藏吧!
3. 实验及过程分析
3.1 实验超参数
这部分是作者实验的一些基础,挺良心的,具体到了每一个实验环境的版本号
这部分特意说明一下作者的良心部分:代码版本都放出来了
3.2 衡量指标:百分比差异(Percent Difference)
这部分公式是作者实验的参考基准计算方式,其中在博客也提出了关于差距百分比的疑问,特意查了了一下计算过程[3](备注,有的地方可能用了绝对值):
3.3 实验验证与结果简要分析
说明:关于D4RL数据集的组成、安装和解释请参考博文 离线强化学习(Offline RL)系列2: (环境篇)D4RL数据集简介、安装及错误解决
本实验参数 HC = HalfCheetah, Hop = Hopper, W = Walker, r = random, m = medium, mr = medium-replay, me = medium-expert, e = expert. While online algorithms (TD3) typically have small episode variances per trained policy (as they should at convergence),
3.3.1 D4RL验证讨论
3.3.2 运行训练时间讨论
可以从实验结果中很直白的看到,CQL、FishBRC与TD3+BC( )的运行时间, 其实这与算法的复杂性紧密相关,对于TD3来说只需要去根据超参数学习网络即可,但对于CQL等算法,需要学习一堆的参数。
3.3.3 消融(ablation)实验(如何确定 ?)
这部分其实对比了vanillaBC方法和区别,同时就参数 做了对比得出了最好的 。
4. 代码实例分析
def train(self, replay_buffer, batch_size=256):
self.total_it += 1
# Sample replay buffer
state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)
with torch.no_grad():
# Select action according to policy and add clipped noise
noise = (
torch.randn_like(action) * self.policy_noise
).clamp(-self.noise_clip, self.noise_clip)
next_action = (
self.actor_target(next_state) + noise
).clamp(-self.max_action, self.max_action)
# Compute the target Q value
target_Q1, target_Q2 = self.critic_target(next_state, next_action)
target_Q = torch.min(target_Q1, target_Q2)
target_Q = reward + not_done * self.discount * target_Q
# Get current Q estimates
current_Q1, current_Q2 = self.critic(state, action)
# Compute critic loss
critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)
# Optimize the critic
self.critic_optimizer.zero_grad()
critic_loss.backward()
self.critic_optimizer.step()
# Delayed policy updates
if self.total_it % self.policy_freq == 0:
# Compute actor loss
pi = self.actor(state)
Q = self.critic.Q1(state, pi)
lmbda = self.alpha/Q.abs().mean().detach()
actor_loss = -lmbda * Q.mean() + F.mse_loss(pi, action)
# Optimize the actor
self.actor_optimizer.zero_grad()
actor_loss.backward()
self.actor_optimizer.step()
# Update the frozen target models
for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
def eval_policy(policy, env_name, seed, mean, std, seed_offset=100, eval_episodes=10):
eval_env = gym.make(env_name)
eval_env.seed(seed + seed_offset)
avg_reward = 0.
for _ in range(eval_episodes):
state, done = eval_env.reset(), False
while not done:
state = (np.array(state).reshape(1,-1) - mean)/std
action = policy.select_action(state)
state, reward, done, _ = eval_env.step(action)
avg_reward += reward
avg_reward /= eval_episodes
d4rl_score = eval_env.get_normalized_score(avg_reward) * 100
print("---------------------------------------")
print(f"Evaluation over {eval_episodes} episodes: {avg_reward:.3f}, D4RL score: {d4rl_score:.3f}")
print("---------------------------------------")
return d4rl_score参考文献
[1]. Scott Fujimoto, Shixiang Shane Gu: “A Minimalist Approach to Offline Reinforcement Learning”, 2021; arXiv:2106.06860. [2]. A Minimalist Approach to Offline Reinforcement Learning, OpenReview [3]. percent-difference, percent-difference
<hr/>OfflineRL推荐阅读
离线强化学习(Offline RL)系列3: (算法篇) REM(Random Ensemble Mixture)算法详解与实现
离线强化学习(Offline RL)系列3: (算法篇)策略约束 - BRAC算法原理详解与实现(经验篇)
离线强化学习(Offline RL)系列3: (算法篇)策略约束 - BEAR算法原理详解与实现
离线强化学习(Offline RL)系列3: (算法篇)策略约束-BCQ算法详解与实现
离线强化学习(Offline RL)系列2: (环境篇)D4RL数据集简介、安装及错误解决
离线强化学习(Offline RL)系列1:离线强化学习原理入门 |
本帖子中包含更多资源
您需要 登录 才可以下载或查看,没有账号?立即注册
×
|