首页
/ 【亲测免费】 Tianshou 强化学习平台教程

【亲测免费】 Tianshou 强化学习平台教程

2026-01-19 10:51:09作者:冯爽妲Honey

项目介绍

Tianshou(天授)是一个基于纯 PyTorch 的强化学习平台。与现有的主要基于 TensorFlow 的强化学习库不同,Tianshou 提供了快速的框架和友好的 Python API,用于构建深度强化学习代理。Tianshou 支持多种强化学习算法,并且具有多 GPU 训练的能力。

项目快速启动

安装

你可以通过以下命令从 PyPI 安装 Tianshou:

pip install tianshou

如果你使用 Anaconda 或 Miniconda,可以通过以下命令从 conda-forge 安装 Tianshou:

conda install tianshou -c conda-forge

快速启动示例

以下是一个简单的 DQN 示例代码:

import gym
import tianshou as ts
from tianshou.policy import DQNPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer

# 创建环境
env = gym.make('CartPole-v0')
train_envs = ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(8)])
test_envs = ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(100)])

# 定义策略
net = ts.net.DQN((env.observation_space.shape[0],), env.action_space.n)
optim = ts.optim.Adam(net.parameters(), lr=1e-3)
policy = DQNPolicy(net, optim, discount_factor=0.99)

# 数据收集器和回放缓冲区
buffer = ReplayBuffer(20000)
train_collector = Collector(policy, train_envs, buffer)
test_collector = Collector(policy, test_envs)

# 训练
result = offpolicy_trainer(
    policy, train_collector, test_collector,
    max_epoch=10, step_per_epoch=10000, collect_per_step=10,
    batch_size=64, test_in_train=False
)

print(result)

应用案例和最佳实践

案例一:多智能体强化学习

Tianshou 支持多智能体强化学习(MARL),可以用于解决多个智能体在同一环境中的协同或竞争问题。以下是一个简单的多智能体示例:

import tianshou as ts
from tianshou.policy import MultiAgentPolicyManager, DQNPolicy
from tianshou.data import Collector

# 创建环境
env = ts.env.MultiAgentEnv('simple_spread')

# 定义策略
policies = [DQNPolicy for _ in range(env.n_agents)]
manager = MultiAgentPolicyManager(policies, env)

# 数据收集器
collector = Collector(manager, env)

# 训练
result = ts.trainer.onpolicy_trainer(
    manager, collector,
    max_epoch=10, step_per_epoch=10000, collect_per_step=10,
    batch_size=64, test_in_train=False
)

print(result)

案例二:自定义强化学习算法

Tianshou 允许用户自定义强化学习算法。以下是一个自定义算法的示例:

import tianshou as ts
from tianshou.policy import BasePolicy

class CustomPolicy(BasePolicy):
    def __init__(self, net, optim):
        super().__init__()
        self.net = net
        self.optim = optim

    def forward(self, batch, state=None):
        # 自定义前向传播逻辑
        pass

    def learn(self, batch):
        # 自定义学习逻辑
        pass

# 创建环境
env = gym.make('CartPole-v0')

# 定义网络和优化器
net = ts.net.MLP((env.observation_space.shape[0],), env.action_space.n)
optim = ts.optim.Adam(net.parameters(), lr=1e-3)
登录后查看全文
热门项目推荐
相关项目推荐