首页
/ Learnware-LAMDA项目:Learnware模型准备与上传全流程指南

Learnware-LAMDA项目:Learnware模型准备与上传全流程指南

2025-06-19 11:09:10作者:卓艾滢Kingsley

概述

Learnware-LAMDA项目中的Learnware系统是一个创新性的机器学习模型共享平台,允许开发者上传和共享训练好的机器学习模型(称为"learnware")。本文将详细介绍如何准备并上传一个符合规范的learnware到Learnware市场。

Learnware基本组成

一个有效的learnware必须包含以下四个核心文件,打包为一个zip压缩包:

  1. learnware.yaml - 模型配置文件
  2. __init__.py - 模型调用接口文件
  3. stat.json - 模型统计规范文件
  4. 环境依赖文件(environment.yamlrequirements.txt

模型调用文件(init.py)详解

__init__.py是learnware的核心文件,定义了模型的使用接口。以下是关键要点:

基本接口要求

必须实现以下接口(至少需要predict接口):

  • fit(X, y) - 模型训练接口
  • predict(X) - 模型预测接口(必需)
  • finetune(X, y) - 模型微调接口

代码模板示例

import os
import pickle
import numpy as np
from learnware.model import BaseModel

class MyModel(BaseModel):
    def __init__(self):
        super(MyModel, self).__init__(input_shape=(37,), output_shape=(1,))
        dir_path = os.path.dirname(os.path.abspath(__file__))
        model_path = os.path.join(dir_path, "model.pkl")
        with open(model_path, "rb") as f:
            self.model = pickle.load(f)

    def fit(self, X: np.ndarray, y: np.ndarray):
        self.model = self.model.fit(X)

    def predict(self, X: np.ndarray) -> np.ndarray:
        return self.model.predict(X)

    def finetune(self, X: np.ndarray, y: np.ndarray):
        pass

输入输出维度说明

  • input_shape: 单个输入样本的维度
  • output_shape: 单个样本的输出维度

特殊场景处理:

  • 文本数据:input_shape可设为None
  • 变长输出任务(如目标检测):output_shape可设为None
  • 分类任务:
    • 直接输出标签:output_shape=(1,),标签应从0开始
    • 输出logits:output_shape=(class_num,)

统计规范文件(stat.json)生成

统计规范文件描述了模型的数据特征,使用训练数据生成:

from learnware.specification import generate_stat_spec

data_type = "table"  # 支持类型: ["table", "image", "text"]
spec = generate_stat_spec(type=data_type, X=train_x)
spec.save("stat.json")

注意事项:

  • 此过程完全在本地运行,不会上传任何数据
  • 如果训练数据过大,建议先采样再生成规范

配置文件(learnware.yaml)详解

配置文件连接各个组件,示例:

model:
  class_name: MyModel
  kwargs: {}
stat_specifications:
  - module_path: learnware.specification
    class_name: RKMETableSpecification
    file_name: stat.json
    kwargs: {}

注意不同数据类型的规范类名:

  • 表格数据: RKMETableSpecification
  • 图像数据: RKMEImageSpecification
  • 文本数据: RKMETextSpecification

环境依赖文件准备

支持两种方式指定运行环境:

1. conda环境(environment.yaml)

生成命令:

conda env export | grep -v "^prefix: " > environment.yaml

验证环境无冲突:

conda env create --name test_env --file environment.yaml

2. pip依赖(requirements.txt)

示例内容:

numpy==1.23.5
scikit-learn==1.2.2

可使用pipreqs自动生成:

pip install pipreqs
pipreqs ./

语义规范准备

语义规范描述任务和模型特征,示例:

from learnware.specification import generate_semantic_spec

input_description = {
    "Dimension": 5,
    "Description": {
        "0": "age",
        "1": "weight",
        "2": "body length",
        "3": "animal type",
        "4": "claw length"
    },
}

output_description = {
    "Dimension": 3,
    "Description": {
        "0": "cat",
        "1": "dog",
        "2": "bird",
    },
}

semantic_spec = generate_semantic_spec(
    name="learnware_example",
    description="示例模型",
    data_type="Table",
    task_type="Classification",
    library_type="Scikit-learn",
    scenarios=["Business", "Financial"],
    license="MIT",
    input_description=input_description,
    output_description=output_description,
)

上传Learnware

准备完成后,使用以下代码上传:

from learnware.market import BaseChecker
from learnware.market import instantiate_learnware_market

demo_market = instantiate_learnware_market(market_id="demo", name="hetero", rebuild=True)
learnware_id, learnware_status = demo_market.add_learnware(zip_path, semantic_spec)

assert learnware_status != BaseChecker.INVALID_LEARNWARE, "上传失败!"

删除Learnware

具有权限的管理员可删除learnware:

demo_market.delete_learnware(learnware_id)

最佳实践建议

  1. 模块导入使用相对路径(如from .package import module
  2. 环境依赖尽量精简,只包含必要包
  3. 对于敏感包(如torch),务必指定版本号
  4. 上传前本地测试所有接口功能
  5. 确保文件编码为UTF-8(特别是environment.yaml)

通过遵循以上指南,您可以顺利地将自己的机器学习模型作为learnware共享到Learnware市场,同时确保其他用户能够正确使用您的模型。

登录后查看全文
热门项目推荐

项目优选

收起