crypto.ai/cryptoai/models/data_processor.py
2025-06-15 09:34:22 +08:00

288 lines
8.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import pandas as pd
import numpy as np
from typing import Dict, Any, List, Optional, Tuple
import datetime
import json
import pickle
from sklearn.preprocessing import MinMaxScaler
class DataProcessor:
    """Processes and stores cryptocurrency market data (preprocessing,
    technical indicators, normalization, and file persistence)."""

    def __init__(self, storage_path: str = "./data"):
        """
        Initialize the data processor.

        Args:
            storage_path: Directory under which all data files are stored;
                created on construction if it does not exist.
        """
        self.storage_path = storage_path
        os.makedirs(storage_path, exist_ok=True)
        # Per-symbol feature scalers, populated by normalize_data().
        self.scalers = {}
def preprocess_market_data(self, symbol: str, data: pd.DataFrame) -> pd.DataFrame:
"""
预处理市场数据
Args:
symbol: 交易对符号,例如 'BTCUSDT'
data: 原始市场数据
Returns:
预处理后的数据
"""
# 确保数据类型正确
data = data.copy()
# 将时间戳设置为索引
if 'timestamp' in data.columns:
data.set_index('timestamp', inplace=True)
# 移除不必要的列
columns_to_keep = ['open', 'high', 'low', 'close', 'volume']
data = data[columns_to_keep]
# 检查并处理缺失值
data.bfill(inplace=True) # 前向填充
data.bfill(inplace=True) # 后向填充
# 添加技术指标
data = self.add_technical_indicators(data)
return data
def add_technical_indicators(self, data: pd.DataFrame) -> pd.DataFrame:
"""
添加技术指标
Args:
data: 原始价格数据
Returns:
添加了技术指标的数据
"""
df = data.copy()
# 计算移动平均线
df['MA5'] = df['close'].rolling(window=5).mean()
df['MA20'] = df['close'].rolling(window=20).mean()
df['MA50'] = df['close'].rolling(window=50).mean()
df['MA100'] = df['close'].rolling(window=100).mean()
# 计算相对强弱指标(RSI)
delta = df['close'].diff()
gain = delta.where(delta > 0, 0)
loss = -delta.where(delta < 0, 0)
avg_gain = gain.rolling(window=14).mean()
avg_loss = loss.rolling(window=14).mean()
rs = avg_gain / avg_loss
df['RSI'] = 100 - (100 / (1 + rs))
# 计算MACD
exp1 = df['close'].ewm(span=12, adjust=False).mean()
exp2 = df['close'].ewm(span=26, adjust=False).mean()
macd = exp1 - exp2
signal = macd.ewm(span=9, adjust=False).mean()
df['MACD'] = macd
df['MACD_Signal'] = signal
df['MACD_Hist'] = macd - signal
# 计算布林带
df['MA20_std'] = df['close'].rolling(window=20).std()
df['Bollinger_Upper'] = df['MA20'] + (df['MA20_std'] * 2)
df['Bollinger_Lower'] = df['MA20'] - (df['MA20_std'] * 2)
# 计算ATR (Average True Range)
high_low = df['high'] - df['low']
high_close = abs(df['high'] - df['close'].shift())
low_close = abs(df['low'] - df['close'].shift())
ranges = pd.concat([high_low, high_close, low_close], axis=1)
true_range = ranges.max(axis=1)
df['ATR'] = true_range.rolling(14).mean()
# 填充计算指标产生的NaN值
df.bfill(inplace=True)
return df
def normalize_data(self, data: pd.DataFrame, symbol: str) -> Tuple[pd.DataFrame, Dict[str, Any]]:
"""
归一化数据
Args:
data: 预处理后的数据
symbol: 交易对符号
Returns:
归一化后的数据和缩放器
"""
df = data.copy()
scalers = {}
# 对每列进行归一化
for column in df.columns:
scaler = MinMaxScaler(feature_range=(0, 1))
df[column] = scaler.fit_transform(df[[column]])
scalers[column] = scaler
# 保存缩放器以便后续使用
self.scalers[symbol] = scalers
return df, scalers
def prepare_model_input(self, data: pd.DataFrame, window_size: int = 60) -> Tuple[np.ndarray, np.ndarray]:
"""
准备模型输入数据(时间序列窗口)
Args:
data: 归一化后的数据
window_size: 时间窗口大小
Returns:
X和y数据
"""
X, y = [], []
for i in range(len(data) - window_size):
X.append(data.iloc[i:(i + window_size)].values)
y.append(data.iloc[i + window_size]['close'])
return np.array(X), np.array(y)
def save_data(self, symbol: str, data: pd.DataFrame, data_type: str = 'raw') -> str:
"""
保存数据到文件
Args:
symbol: 交易对符号
data: 要保存的数据
data_type: 数据类型,例如 'raw', 'processed', 'normalized'
Returns:
保存文件的路径
"""
# 创建目录(如果不存在)
symbol_dir = os.path.join(self.storage_path, symbol)
os.makedirs(symbol_dir, exist_ok=True)
# 生成文件名
timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
filename = f"{symbol}_{data_type}_{timestamp}.csv"
file_path = os.path.join(symbol_dir, filename)
# 保存数据
data.to_csv(file_path)
return file_path
def save_scalers(self, symbol: str) -> str:
"""
保存缩放器到文件
Args:
symbol: 交易对符号
Returns:
保存文件的路径
"""
if symbol not in self.scalers:
raise ValueError(f"没有找到{symbol}的缩放器")
# 创建目录(如果不存在)
symbol_dir = os.path.join(self.storage_path, symbol)
os.makedirs(symbol_dir, exist_ok=True)
# 生成文件名
timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
filename = f"{symbol}_scalers_{timestamp}.pkl"
file_path = os.path.join(symbol_dir, filename)
# 保存缩放器
with open(file_path, 'wb') as f:
pickle.dump(self.scalers[symbol], f)
return file_path
def load_data(self, file_path: str) -> pd.DataFrame:
"""
从文件加载数据
Args:
file_path: 文件路径
Returns:
加载的数据
"""
return pd.read_csv(file_path, index_col=0, parse_dates=True)
def load_scalers(self, file_path: str) -> Dict[str, Any]:
"""
从文件加载缩放器
Args:
file_path: 文件路径
Returns:
加载的缩放器
"""
with open(file_path, 'rb') as f:
return pickle.load(f)
def get_latest_data_file(self, symbol: str, data_type: str = 'processed') -> Optional[str]:
"""
获取最新的数据文件路径
Args:
symbol: 交易对符号
data_type: 数据类型
Returns:
最新的数据文件路径如果不存在则返回None
"""
symbol_dir = os.path.join(self.storage_path, symbol)
if not os.path.exists(symbol_dir):
return None
# 查找匹配的文件
files = [f for f in os.listdir(symbol_dir) if f.startswith(f"{symbol}_{data_type}") and f.endswith('.csv')]
if not files:
return None
# 按文件名(包含时间戳)排序
files.sort(reverse=True)
return os.path.join(symbol_dir, files[0])
def get_latest_scaler_file(self, symbol: str) -> Optional[str]:
"""
获取最新的缩放器文件路径
Args:
symbol: 交易对符号
Returns:
最新的缩放器文件路径如果不存在则返回None
"""
symbol_dir = os.path.join(self.storage_path, symbol)
if not os.path.exists(symbol_dir):
return None
# 查找匹配的文件
files = [f for f in os.listdir(symbol_dir) if f.startswith(f"{symbol}_scalers") and f.endswith('.pkl')]
if not files:
return None
# 按文件名(包含时间戳)排序
files.sort(reverse=True)
return os.path.join(symbol_dir, files[0])