Skip to content

jay-zc/Uncertainty_TSRAG

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

1 Commit
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

TS-RAG: 检索增强时间序列预测框架

TS-RAG 是一个模块化、配置驱动的时间序列预测框架,基于 Chronos-Bolt 模型并通过检索增强生成 (RAG) 提升预测精度。本项目采用 Hydra 进行分层配置管理。


🧠 模型架构与原理解析 (Architecture Insight)

为什么随机初始化的 RAG 模块没有破坏模型效果?

本项目采用 Residual Connection (残差连接) 机制,确保了 Base Model 的强大能力即使在 RAG 模块未训练时也能被完整保留。

                            Input Context
                                  │
                                  ▼
                   ┌─────────────────────────────┐
                   │  T5 Encoder (Pre-trained)   │
                   └──────────────┬──────────────┘
                                  │
                                  ▼
                   ┌─────────────────────────────┐
                   │  T5 Decoder (Pre-trained)   │
                   └──────────────┬──────────────┘
                                  │
      ┌───────────────────────────┴───────────────────────────┐
      │                                                       │
      ▼                                                       ▼
[ Sequence Output ]                             [ MoE / Gate Layers ] (Random Init)
 (Strong Baseline)                                        ▲
      │                                                   │
      │                                           [ Retrieved Context ]
      │                                                   │
      │                                                   ▼
      │                                            [ Fused Output ]
      │                                          (Small Random Noise)
      │                                                   │
      └───────────────────────────► (+) ◄─────────────────┘
                                   │
                                   ▼
                          [ Output Projection ]
                                   │
                                   ▼
                          [ Quantile Forecast ]

核心机制

  • Base Output Keep: 主干网络的输出 (SeqOut) 直接通过残差连接传递给最终输出。
  • RAG as Perturbation: 随机初始化的 RAG 模块(MoE/Gate)产生的输出(Fusion)由于数值较小,实际上只充当了“微小扰动”的角色,而没有破坏底座模型的预测。
  • Training Goal: 训练过程就是让这个 Fusion 从“随机噪声”变成“有意义的修正信号”。

📂 项目目录结构

TS-RAG/
├── configs/                # [Hydra 配置中心]
│   ├── config.yaml         # 全局共享配置 (Trainer, Logger, Seed)
│   ├── pretrain/           # 预训练配置
│   │   └── default.yaml    # 默认预训练参数 (继承 config.yaml)
│   └── benchmark/          # 评测配置
│       └── default.yaml    # 默认评测参数 (继承 config.yaml)
├── pretrain/               # [预训练工作流]
│   ├── download_chronos.py # 1. 数据下载与索引构建
│   ├── process_chronos.py  # 2. 数据与检索库构建
│   └── run.py              # 3. 预训练主程序
├── benchmark/              # [评测工作流]
│   ├── prepare_benchmarks.py # 4. 评测数据准备
│   └── evaluate.py         # 5. 评测主程序
├── src/                    # [核心代码库]
│   ├── data/               # Dataset, Retriever
│   ├── models/             # ChronosBolt Wrapper
│   └── trainers/           # Lightning Trainer
├── datasets/               # [数据存储]
│   ├── chronos_raw/        # 原始预训练数据
│   ├── pretrain_stage/     # 预处理后的训练数据
│   └── benchmark_stage/    # 评测用数据
└── outputs/                # [实验日志与权重]
    ├── pretrain/           # 预训练日志
    └── benchmark/          # 评测日志

🚀 快速上手指南 (Step-by-Step Workflow)

阶段一:预训练 (Pre-training)

1. 下载原始数据与构建索引

下载 Hugging Face 上的 Chronos 数据集,并生成 dataset_index.json

python pretrain/download_chronos.py

2. 生成检索数据库与训练集

读取索引文件,使用 Chronos 模型生成 Embedding 并在 datasets/pretrain_stage/ 构建检索库。

  • 配置来源: configs/pretrain/default.yaml (process 部分)
python pretrain/process_chronos.py

若需修改处理参数(如文件数量):

python pretrain/process_chronos.py process.num_files=100

3. 启动预训练

加载处理后的数据进行训练。Checkpoint 将保存在 outputs/pretrain/YYYY-MM-DD/HH-MM-SS/checkpoints/

  • 配置来源: configs/pretrain/default.yaml
python pretrain/run.py

自定义实验(例如修改学习率):

# 方式 A: 命令行覆盖
python pretrain/run.py trainer.learning_rate=0.001

# 方式 B: 创建新排配置 configs/pretrain/exp1.yaml (继承 default)
python pretrain/run.py config_name=pretrain/exp1

阶段二:基准评测 (Benchmarking)

4. 准备评测数据

处理标准数据集 (ETT, Weather 等) 为滑动窗口格式,生成到 datasets/benchmark_stage/

  • 配置来源: configs/benchmark/default.yaml
python benchmark/prepare_benchmarks.py

5. 执行评测

加载预训练权重,对所有注册的数据集进行 Zero-shot 评测。

  • 配置来源: configs/benchmark/default.yaml
  • 注意: 修改 checkpoint_path 指向你训练好的模型。
python benchmark/evaluate.py data.checkpoint_path="outputs/pretrain/.../checkpoints/checkpoint_epoch_9.pth"

查看结果:

  • 终端输出摘要表格。
  • 详细 CSV 报告保存在 batch_evaluation_report.csv

⚙️ 配置系统详解 (Configuration inheritance)

本项目使用 Hydra 的组合与继承机制。

  1. Shared Config (configs/config.yaml):

    • 定义了所有阶段共用的 trainer (训练器参数) 和 logger (日志参数)。
  2. Default Configs (configs/*/default.yaml):

    • pretrain/default.yaml: 继承 Shared Config,并定义 modelprocess 参数。
    • benchmark/default.yaml: 继承 Shared Config,并定义 datasets 列表。
  3. 如何新建实验: 如果你想固定一组参数进行实验,不要直接修改 default.yaml。 请在 configs/pretrain/ 下新建 my_exp.yaml:

    defaults:
      - default    # 继承默认
      - _self_
    
    # 覆盖特定参数
    trainer:
      max_epochs: 50
    model:
      dropout: 0.2

    然后运行:

    python pretrain/run.py config_name=pretrain/my_exp

About

No description, website, or topics provided.

Resources

Stars

Watchers

Forks

Packages

 
 
 

Contributors

Languages