首页
/ 告别股价预测难题:用LSTM构建多框架股票预测系统(附PyTorch/Keras/TensorFlow全实现)

告别股价预测难题:用LSTM构建多框架股票预测系统(附PyTorch/Keras/TensorFlow全实现)

2026-01-16 10:39:38作者:伍希望

你是否还在为股价预测模型搭建繁琐而苦恼?尝试过多种框架却难以统一代码风格?本文将带你从零开始,使用stock_predict_with_LSTM开源项目构建专业级股票预测系统,掌握三大主流深度学习框架的实现细节,解决数据预处理、模型调优、增量训练等核心痛点。

读完本文你将获得

  • 掌握LSTM在时间序列预测中的核心配置参数(含20+关键参数详解)
  • 学会在PyTorch/Keras/TensorFlow中实现相同预测逻辑的对比方法
  • 获得处理股票数据的标准化流程(含特征工程与归一化代码)
  • 理解增量训练与连续预测的技术细节(附实战代码)
  • 掌握模型评估与可视化的完整方案(含MSE计算与趋势图绘制)

项目架构概览

stock_predict_with_LSTM采用模块化设计,将数据处理、模型构建、训练预测解耦为独立组件。项目结构如下:

stock_predict_with_LSTM/
├── data/                # 数据目录
│   └── stock_data.csv   # 股票历史数据
├── model/               # 多框架模型实现
│   ├── model_keras.py   # Keras实现
│   ├── model_pytorch.py # PyTorch实现
│   └── model_tensorflow.py # TensorFlow实现
├── figure/              # 预测结果可视化
├── main.py              # 主程序入口
└── requirements.txt     # 依赖清单

核心特性对比表:

特性 PyTorch实现 Keras实现 TensorFlow实现
支持增量训练 ✅ 需手动管理hidden状态 ❌ 原生不支持 ❌ 需自定义图结构
连续预测能力 ✅ 支持状态传递 ❌ 需修改RNNCell ✅ 需复用计算图
GPU加速配置 use_cuda=True gpu_train_init() sess_config配置
模型保存格式 .pth .h5 .ckpt
可视化工具 Visdom TensorBoard TensorBoardX

环境准备与数据预处理

快速开始

# 克隆项目
git clone https://gitcode.com/gh_mirrors/st/stock_predict_with_LSTM
cd stock_predict_with_LSTM

# 安装依赖
pip install -r requirements.txt

数据格式要求

项目使用data/stock_data.csv作为输入,需包含以下字段(示例):

日期 开盘价 最高价 最低价 收盘价 成交量 ...
2020-01-01 123.45 125.67 122.10 124.32 15600 ...

核心配置参数

main.pyConfig类中定义了所有关键参数,以下是影响预测效果的核心配置:

class Config:
    # 数据参数
    feature_columns = list(range(2, 9))  # 使用的特征列索引
    label_columns = [4, 5]               # 预测目标列(最高价、最低价)
    predict_day = 1                      # 预测未来天数
    
    # LSTM网络参数
    input_size = len(feature_columns)    # 输入特征数
    hidden_size = 128                    # 隐藏层神经元数
    lstm_layers = 2                      # LSTM堆叠层数
    dropout_rate = 0.2                   # Dropout比例
    time_step = 20                       # 时间窗口大小(关键参数)
    
    # 训练参数
    batch_size = 64
    learning_rate = 0.001
    epoch = 20
    patience = 5                         # 早停机制阈值

⚠️ 关键参数说明time_step决定使用前多少天数据预测未来价格,建议设置为20-60天(需保证训练数据量>该值)

多框架LSTM模型实现对比

数据处理核心逻辑

Data类实现了从CSV读取到生成训练样本的完整流程,核心方法get_train_and_valid_data()的工作流程:

flowchart TD
    A[读取CSV数据] --> B[归一化处理]
    B --> C[划分训练/测试集]
    C --> D[生成时间序列样本]
    D --> E[划分训练/验证集]
    E --> F[返回train_X,train_Y,valid_X,valid_Y]

归一化代码实现:

self.mean = np.mean(self.data, axis=0)              # 计算均值
self.std = np.std(self.data, axis=0)                # 计算标准差
self.norm_data = (self.data - self.mean)/self.std   # Z-Score归一化

PyTorch实现详解

PyTorch版本采用类封装方式,核心是Net类的定义:

class Net(Module):
    def __init__(self, config):
        super(Net, self).__init__()
        self.lstm = LSTM(
            input_size=config.input_size,
            hidden_size=config.hidden_size,
            num_layers=config.lstm_layers,
            batch_first=True,
            dropout=config.dropout_rate
        )
        self.linear = Linear(config.hidden_size, config.output_size)

    def forward(self, x, hidden=None):
        lstm_out, hidden = self.lstm(x, hidden)
        linear_out = self.linear(lstm_out)
        return linear_out, hidden

连续训练关键代码

# 保留hidden状态但清除梯度信息
h_0, c_0 = hidden_train
h_0.detach_(), c_0.detach_()    # 关键操作:分离计算图
hidden_train = (h_0, c_0)

Keras实现对比

Keras版本采用函数式API,代码更简洁:

def get_keras_model(config):
    input1 = Input(shape=(config.time_step, config.input_size))
    lstm = input1
    # 堆叠LSTM层
    for i in range(config.lstm_layers):
        lstm = LSTM(units=config.hidden_size,
                   dropout=config.dropout_rate,
                   return_sequences=True)(lstm)
    output = Dense(config.output_size)(lstm)
    model = Model(input1, output)
    model.compile(loss='mse', optimizer='adam')
    return model

TensorFlow实现要点

TensorFlow版本需手动定义计算图:

def net(self):
    def dropout_cell():
        basicLstm = tf.nn.rnn_cell.LSTMCell(self.config.hidden_size)
        return tf.nn.rnn_cell.DropoutWrapper(
            basicLstm, output_keep_prob=1-self.config.dropout_rate)
    
    cell = tf.nn.rnn_cell.MultiRNNCell(
        [dropout_cell() for _ in range(self.config.lstm_layers)])
    output_rnn, _ = tf.nn.dynamic_rnn(cell=cell, inputs=self.X, dtype=tf.float32)
    self.pred = tf.layers.dense(inputs=output_rnn, units=self.config.output_size)

三种框架的核心差异对比表:

实现细节 PyTorch Keras TensorFlow
网络定义 类继承Module 函数式API 手动构建计算图
训练循环 手动实现 fit()接口 Session.run()
状态管理 显式传递hidden 自动管理 需手动保存
损失计算 内置MSELoss compile指定 手动定义loss

模型训练与预测完整流程

训练流程控制

主程序通过frame参数选择框架:

frame = "pytorch"  # 可选:"keras", "pytorch", "tensorflow"
if frame == "pytorch":
    from model.model_pytorch import train, predict
elif frame == "keras":
    from model.model_keras import train, predict
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = '3'
elif frame == "tensorflow":
    from model.model_tensorflow import train, predict

训练主逻辑:

if config.do_train:
    train_X, valid_X, train_Y, valid_Y = data_gainer.get_train_and_valid_data()
    train(config, logger, [train_X, train_Y, valid_X, valid_Y])

if config.do_predict:
    test_X, test_Y = data_gainer.get_test_data(return_label_data=True)
    pred_result = predict(config, test_X)
    draw(config, data_gainer, logger, pred_result)

增量训练实现

设置add_train=True启用增量训练:

# 训练参数配置
add_train = True           # 启用增量训练
shuffle_train_data = False # 增量训练时禁止shuffle

PyTorch增量训练代码:

if config.add_train:
    model.load_state_dict(torch.load(config.model_save_path + config.model_name))

预测结果可视化

draw()函数实现预测结果可视化与评估:

def draw(config: Config, origin_data: Data, logger, predict_norm_data: np.ndarray):
    # 还原归一化数据
    predict_data = predict_norm_data * origin_data.std[config.label_in_feature_index] + \
                   origin_data.mean[config.label_in_feature_index]
    # 计算MSE
    loss = np.mean((label_data[config.predict_day:] - predict_data[:-config.predict_day]) ** 2, axis=0)
    # 绘制预测对比图
    plt.plot(label_X, label_data[:, i], label='label')
    plt.plot(predict_X, predict_data[:, i], label='predict')

典型预测结果展示(PyTorch框架):

The mean squared error of stock ['最高价', '最低价'] is [0.0234 0.0198]
The predicted stock 最高价 for the next 1 day(s) is: [156.78]
The predicted stock 最低价 for the next 1 day(s) is: [152.34]

高级特性与参数调优

连续预测实现

设置do_continue_train=True启用连续预测:

do_continue_train = True    # 启用连续训练
continue_flag = "continue_"
if do_continue_train:
    shuffle_train_data = False
    batch_size = 1           # 连续训练需batch_size=1

连续预测样本生成逻辑:

train_x = [feature_data[start_index + i*self.config.time_step : start_index + (i+1)*self.config.time_step]
           for start_index in range(self.config.time_step)
           for i in range((self.train_num - start_index) // self.config.time_step)]

参数调优指南

影响模型性能的关键参数调优建议:

参数 推荐范围 调优策略
hidden_size 64-256 从128开始,验证集损失不再下降时增大
lstm_layers 1-3层 层数增加需对应增加hidden_size
time_step 20-60 观察数据周期性,设置为周期的2-3倍
dropout_rate 0.1-0.3 过拟合时增大,欠拟合时减小
batch_size 32-128 GPU内存允许时增大,加速训练

常见问题解决方案

  1. 训练不稳定

    • 设置固定随机种子:np.random.seed(config.random_seed)
    • 减小学习率:learning_rate=0.0005
    • 增加batch_size
  2. 预测偏差大

    • 检查time_step是否合理
    • 增加训练数据量
    • 调整特征列选择
  3. GPU内存不足

    • 减小batch_size
    • 降低hidden_size
    • 使用梯度累积

项目扩展与实际应用

多指标同时预测

修改label_columns支持多指标预测:

label_columns = [3,4,5]  # 同时预测收盘价、最高价、最低价

长期趋势预测

调整predict_day参数预测多天价格:

predict_day = 5  # 预测未来5天价格

实时数据预测扩展

可通过添加以下代码实现实时数据预测:

def fetch_realtime_data(symbol):
    # 实现从API获取实时数据的逻辑
    # 返回格式需与stock_data.csv一致
    pass

# 预测前获取最新数据
realtime_data = fetch_realtime_data("AAPL")
update_training_data(realtime_data)  # 更新训练数据

总结与展望

本文详细介绍了stock_predict_with_LSTM项目的核心实现与使用方法,通过对比三种主流框架的实现细节,展示了LSTM在股票预测中的应用。项目的模块化设计使得添加新特征、尝试新模型结构变得简单。

未来改进方向

  1. 添加注意力机制(Attention)提升长序列预测能力
  2. 实现多模型集成预测,降低单一模型风险
  3. 增加技术指标自动生成模块(MACD, RSI等)
  4. 开发Web界面实现可视化操作

希望本教程能帮助你快速掌握股票预测系统的构建方法。若有任何问题,欢迎提交Issue或PR参与项目改进。

请点赞+收藏本文,以便后续查阅多框架实现对比细节。下期将带来《LSTM超参数调优实战》,敬请关注!

登录后查看全文
热门项目推荐
相关项目推荐