首页
/ torch-molecule 分子机器学习库使用指南

torch-molecule 分子机器学习库使用指南

2025-06-11 23:12:01作者:卓艾滢Kingsley

概述

torch-molecule 是一个基于 PyTorch 的分子机器学习库,专注于分子性质预测和分子生成任务。它为研究人员和开发者提供了一套完整的工具链,可以方便地进行分子相关的机器学习实验和应用开发。

分子性质预测

分子性质预测是计算化学和药物发现中的重要任务。torch-molecule 提供了 GREAMolecularPredictor 等预测器,支持多种图神经网络架构和自动超参数优化。

基本使用流程

  1. 定义搜索参数空间:首先需要定义模型架构和训练相关的参数搜索空间
search_GNN = {
    "gnn_type": ParameterSpec(ParameterType.CATEGORICAL, ["gin-virtual", "gcn-virtual", "gin", "gcn"]),
    "norm_layer": ParameterSpec(ParameterType.CATEGORICAL, ["batch_norm", "layer_norm"]),
    "graph_pooling": ParameterSpec(ParameterType.CATEGORICAL, ["mean", "sum", "max"]),
    "augmented_feature": ParameterSpec(ParameterType.CATEGORICAL, ["maccs,morgan", "maccs", "morgan", None]),
    "num_layer": ParameterSpec(ParameterType.INTEGER, (2, 5)),
    "hidden_size": ParameterSpec(ParameterType.INTEGER, (64, 512)),
    "drop_ratio": ParameterSpec(ParameterType.FLOAT, (0.0, 0.5)),
    "learning_rate": ParameterSpec(ParameterType.LOG_FLOAT, (1e-5, 1e-2)),
    "weight_decay": ParameterSpec(ParameterType.LOG_FLOAT, (1e-10, 1e-3)),
}
  1. 初始化预测器:根据任务类型选择合适的预测器
grea_model = GREAMolecularPredictor(
    num_task=num_task,
    task_type="regression",
    model_name="GREA_multitask",
    batch_size=BATCH_SIZE,
    epochs=N_epoch,
    evaluate_criterion='r2',
    evaluate_higher_better=True,
    verbose=True
)
  1. 自动拟合模型:使用 autofit 方法进行自动超参数搜索和模型训练
grea_model.autofit(
    X_train=X_train.tolist(),
    y_train=y_train,
    X_val=X_val.tolist(),
    y_val=y_val,
    n_experiments=N_trial,
    search_parameters=search_GREA
)

技术要点

  • 支持多种 GNN 架构:GIN、GCN 等
  • 提供多种图池化方法:mean、sum、max
  • 可添加分子指纹作为增强特征:MACCS、Morgan 指纹等
  • 内置自动超参数优化功能

分子生成

分子生成是药物发现中的关键环节,torch-molecule 提供了基于扩散模型的分子生成器 GraphDITMolecularGenerator。

基本使用流程

  1. 初始化生成器:指定任务类型和训练参数
model_cond = GraphDITMolecularGenerator(
    task_type=['regression'] * len(property_names),
    batch_size=1024,
    drop_condition=0.1,
    verbose=True,
    epochs=10000,
)
  1. 训练模型:使用已知分子和性质数据进行训练
model_cond.fit(train_smiles_list, train_property_array)
  1. 生成分子:根据目标性质生成新分子
generated_smiles_list = model_cond.generate(test_property_array)
  1. 有效性检查:验证生成的分子结构是否有效
def is_valid_smiles(smiles):
    if smiles is None:
        return False
    mol = Chem.MolFromSmiles(smiles)
    return mol is not None

技术要点

  • 基于扩散模型的分子生成方法
  • 支持条件生成(根据目标性质生成分子)
  • 内置重试机制处理无效分子
  • 与 RDKit 兼容,便于后续分析

预训练模型使用

torch-molecule 支持模型的保存和加载,便于模型共享和部署。

保存和加载模型

  1. 保存模型到本地或模型库
model.push_to_huggingface(
    repo_id=repo_id,
    task_id=f"{task_name}",
    metrics=metrics,
    commit_message=f"Upload GREA_{task_name} model with metrics: {metrics}",
    private=False
)
  1. 加载预训练模型
model = GREAMolecularPredictor()
model.load_model(f"{model_dir}/GREA_{task_name}.pt", repo_id=repo_id)
model.set_params(verbose=True)
  1. 使用模型进行预测
predictions = model.predict(smiles_list)

技术要点

  • 支持模型版本管理和共享
  • 保存完整的模型配置和训练指标
  • 便于模型复现和部署

最佳实践

  1. 数据准备:确保输入数据格式正确,SMILES 字符串需要转换为列表形式
  2. 参数调优:合理设置搜索空间,避免过大导致搜索效率低下
  3. 验证策略:使用独立的验证集评估模型性能
  4. 错误处理:对于分子生成任务,实现适当的重试机制
  5. 性能监控:关注训练过程中的关键指标变化

torch-molecule 为分子机器学习提供了全面的解决方案,无论是性质预测还是分子生成任务,都能通过简洁的 API 实现高效开发。开发者可以根据具体需求选择合适的组件,快速构建分子相关的机器学习应用。

登录后查看全文
热门项目推荐

项目优选

收起
ohos_react_nativeohos_react_native
React Native鸿蒙化仓库
C++
176
260
RuoYi-Vue3RuoYi-Vue3
🎉 (RuoYi)官方仓库 基于SpringBoot,Spring Security,JWT,Vue3 & Vite、Element Plus 的前后端分离权限管理系统
Vue
858
507
openGauss-serveropenGauss-server
openGauss kernel ~ openGauss is an open source relational database management system
C++
129
182
openHiTLSopenHiTLS
旨在打造算法先进、性能卓越、高效敏捷、安全可靠的密码套件,通过轻量级、可剪裁的软件技术架构满足各行业不同场景的多样化要求,让密码技术应用更简单,同时探索后量子等先进算法创新实践,构建密码前沿技术底座!
C
255
299
ShopXO开源商城ShopXO开源商城
🔥🔥🔥ShopXO企业级免费开源商城系统,可视化DIY拖拽装修、包含PC、H5、多端小程序(微信+支付宝+百度+头条&抖音+QQ+快手)、APP、多仓库、多商户、多门店、IM客服、进销存,遵循MIT开源协议发布、基于ThinkPHP8框架研发
JavaScript
93
15
Cangjie-ExamplesCangjie-Examples
本仓将收集和展示高质量的仓颉示例代码,欢迎大家投稿,让全世界看到您的妙趣设计,也让更多人通过您的编码理解和喜爱仓颉语言。
Cangjie
331
1.08 K
HarmonyOS-ExamplesHarmonyOS-Examples
本仓将收集和展示仓颉鸿蒙应用示例代码,欢迎大家投稿,在仓颉鸿蒙社区展现你的妙趣设计!
Cangjie
397
370
note-gennote-gen
一款跨平台的 Markdown AI 笔记软件,致力于使用 AI 建立记录和写作的桥梁。
TSX
83
4
CangjieCommunityCangjieCommunity
为仓颉编程语言开发者打造活跃、开放、高质量的社区环境
Markdown
1.07 K
0
kernelkernel
deepin linux kernel
C
21
5