找回密码
 立即注册
查看: 247|回复: 2

遇强则强(七):遗传算法做得到吗?

[复制链接]
发表于 2023-1-20 10:14 | 显示全部楼层 |阅读模式
之前的若干文章中一直在尝试使用强化学习解决Gym环境中的若干控制问题。对于控制问题,还可以采用动态规划来求解。而在深度学习之前,所谓的智能算法大多数都指向了“进化算法”及其各种变种。
强化学习本质上也是优化算法,可用来求解控制问题,目的与“进化算法”这类元启发式优化算法是一致的,即找到最优解。那么今天来试试用遗传算法来控制Cart Pole吧。
遗传算法的步骤就不再介绍了。相信大家经过新冠,对优胜劣汰、适者生存这套丛林法则有了切身体验。遗传算法就是以下这些基本元素(当然,时髦的算法会再加上各种“动物”属性,灰狼、白鲸、黑熊?)。
个体

依然采用与前篇中相同的神经网络结构,含有一个隐藏层的MLP:
net = nn.Sequential(nn.Linear(in_dim, 64),
     nn.Tanh(),
     nn.Linear(64, 64),
     nn.Tanh(),
     nn.Linear(64, out_dim))
那么总参数量为
(4\times64+64)+(64\times 64+64)+(64\times 2+2)=4610 \\
接下来就是用遗传算法优化这4610个参数。每个个体是4610个浮点数组成的向量。从向量转换到网络:
def set_params(net, params):
    i = 0
    for layerid, layer in enumerate(net):
        if hasattr(layer, 'weight'):
            net[layerid].weight = params
            i += 1
        if hasattr(layer, 'bias'):
            net[layerid].bias = params
            i += 1
    return net
再从网络转换为向量:
def get_params(net):
    params = []
    for layer in net:
        if hasattr(layer, 'weight'):
            params.append(layer.weight)
        if hasattr(layer, 'bias'):
            params.append(layer.bias)
    return params
种群

种群即为个体的集合,其规模固定,随机权重初始化为:
def init_pop(net):
    base = get_params(net)
    shapes = [param.shape for param in base]
    print(shapes)
    pop = []
    for _ in range(POP_SIZE):
        entity = []
        for shape in shapes:
            try:
                rand_tensor = nn.init.kaiming_uniform_(torch.empty(shape)).to(device)
            except ValueError:
                rand_tensor = nn.init.uniform_(torch.empty(shape), -0.5, 0.5).to(device)            entity.append((torch.nn.parameter.Parameter(rand_tensor)))
        pop.append(entity)
    return pop
评估

评估是在Gym环境中进行的,与PPO的区别在于Agent直接选取输出值较大所代表的行动,而不是选择一个概率。遗传算法的探索行为在其随机的交叉变异上,或许在行动决策时就不随机了。另外这个评估值是带有随机性的,但我们仅采样一次作为其评估值
def fitness(solution, net):
    global global_step
    net = set_params(net, solution)
    ob = env.reset()
    done = False
    while not done:
        ob = torch.tensor(ob).float().unsqueeze(0).to(device)
        q_vals = net(ob)
        act = torch.argmax(q_vals.cpu()).item()
        ob_next, _, done, info = env.step(act)
        global_step+=1
        ob = ob_next
        if 'episode' in info.keys():
            print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
            return info['episode']['r']
选择

“选择”就是从种群中根据其适应性函数的评估值,选择出优势个体并复制,比如这么选:
def select(pop, fitnesses):
    idx = np.random.choice(np.arange(POP_SIZE), size=POP_SIZE, replace=True, p=fitnesses/fitnesses.sum())
    return [pop for i in idx]
交叉

接下来就是让这些优势个体交换部分参数,以超参数CROSS_RATE的概率,在种群中随机找一个跟parent1交换参数产生新个体:
def crossover(parent1, pop):
    if np.random.rand() < CROSS_RATE:
        i = np.random.randint(0, POP_SIZE, size=1)[0]
        parent2 = pop
        child = []
        for p1l, p2l in zip(parent1, parent2):
            split = np.random.randint(0, len(p1l), size=1)[0]
            new_param = nn.parameter.Parameter(torch.cat([p1l[:split], p2l[split:]]))
            child.append(new_param)
        return child
    else:
        return parent1
变异

让个体的参数,也就是神经网络的参数随机漂移:
def mutate(child):
    for i in range(len(child)):
        for j in range(len(child)):
            child[j] += torch.randn(child[j].shape).to(device)*MUTATION_FACTOR
    return child
数值实验

继续使用Cart Pole环境,并跟PPO进行对比,其episode的得分跟全局使用步数如下图所示,PPO大约在10万步前就能稳定在最大值500了。


注意遗传算法中,进行评估的一整个种群中总有一些不太正常的个体,因此遗传算法的return非常震荡,所以我们选取每代最好的个体出来表演。
第五代的最优个体表现如图:


第十代最优个体的表现:


第二十代也就是最后一代最好的个体的表现:


勉强能顶得住一会儿,不太稳的样子。
总结

遗传算法作为元启发优化算法,也能实现倒立摆的控制,不过需要的step数量相对PPO高了不少,效果貌似也一般。或许进化优化的大佬来调优能改善,但也只能怪进化优化领域开源气氛不太好吧。 代码在github上
可能有用的资料
Salimans, T.; Ho, J.; Chen, X.; Sidor, S.; Sutskever, I. Evolution Strategies as a Scalable Alternative to Reinforcement Learning. arXiv September 7, 2017. http://arxiv.org/abs/1703.03864
<hr/>本文投稿自专栏投稿-运筹的黎明

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

×
发表于 2023-1-20 10:15 | 显示全部楼层
不是吧,我遗传算法很快就找到最优解了,网络不用那么大
发表于 2023-1-20 10:20 | 显示全部楼层
在github上,您改改参数pr就行,或者也可以直接分享下代码
懒得打字嘛,点击右侧快捷回复 【右侧内容,后台自定义】
您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Unity开发者联盟 ( 粤ICP备20003399号 )

GMT+8, 2024-11-16 04:36 , Processed in 0.138201 second(s), 26 queries .

Powered by Discuz! X3.5 Licensed

© 2001-2024 Discuz! Team.

快速回复 返回顶部 返回列表