288 lines
8.7 KiB
Python
288 lines
8.7 KiB
Python
import os
|
||
import pandas as pd
|
||
import numpy as np
|
||
from typing import Dict, Any, List, Optional, Tuple
|
||
import datetime
|
||
import json
|
||
import pickle
|
||
from sklearn.preprocessing import MinMaxScaler
|
||
|
||
|
||
class DataProcessor:
    """Prepare, normalise and persist cryptocurrency market data.

    The processor cleans raw OHLCV frames, derives technical indicators,
    scales features, builds sliding-window model inputs, and saves/loads
    both data frames (CSV) and fitted scalers (pickle) under
    ``storage_path/<symbol>/``.
    """

    def __init__(self, storage_path: str = "./data"):
        """Create the processor and make sure the storage directory exists.

        Args:
            storage_path: Root directory under which per-symbol folders
                are created.
        """
        self.storage_path = storage_path
        os.makedirs(storage_path, exist_ok=True)
        # Maps symbol -> {column name -> fitted scaler}.
        self.scalers = {}

    def preprocess_market_data(self, symbol: str, data: pd.DataFrame) -> pd.DataFrame:
        """Clean raw market data and append technical indicators.

        Args:
            symbol: Trading pair symbol, e.g. 'BTCUSDT'. Not used by the
                computation itself; kept for interface compatibility.
            data: Raw market data containing at least the columns
                'open', 'high', 'low', 'close' and 'volume'. An optional
                'timestamp' column becomes the index.

        Returns:
            A new frame with the OHLCV columns plus the indicator columns
            produced by ``add_technical_indicators``. The input frame is
            not modified.

        Raises:
            KeyError: If any of the required OHLCV columns is missing.
        """
        frame = data
        # Use the timestamp as the index when it is supplied as a column.
        if 'timestamp' in frame.columns:
            frame = frame.set_index('timestamp')

        # Drop everything except the OHLCV columns (returns a new frame).
        columns_to_keep = ['open', 'high', 'low', 'close', 'volume']
        frame = frame[columns_to_keep]

        # Fill gaps from the previous bar first so that no future value
        # leaks into the past; the backward fill only covers leading gaps
        # that have no earlier observation.
        frame = frame.ffill().bfill()

        return self.add_technical_indicators(frame)

    def add_technical_indicators(self, data: pd.DataFrame) -> pd.DataFrame:
        """Append moving averages, RSI, MACD, Bollinger bands and ATR.

        Args:
            data: Price data with 'high', 'low' and 'close' columns.

        Returns:
            A copy of ``data`` with the additional columns 'MA5', 'MA20',
            'MA50', 'MA100', 'RSI', 'MACD', 'MACD_Signal', 'MACD_Hist',
            'MA20_std', 'Bollinger_Upper', 'Bollinger_Lower' and 'ATR'.
            Warm-up NaN values are back-filled; a column whose window is
            longer than the data stays NaN.
        """
        df = data.copy()
        close = df['close']

        # Simple moving averages.
        for window in (5, 20, 50, 100):
            df[f'MA{window}'] = close.rolling(window=window).mean()

        # Relative Strength Index over 14 periods (simple-average variant).
        delta = close.diff()
        gain = delta.where(delta > 0, 0)
        loss = -delta.where(delta < 0, 0)
        avg_gain = gain.rolling(window=14).mean()
        avg_loss = loss.rolling(window=14).mean()
        # A zero average loss gives rs = inf, which maps to RSI = 100.
        # When both averages are zero the result is NaN and is handled by
        # the back-fill at the end.
        rs = avg_gain / avg_loss
        df['RSI'] = 100 - (100 / (1 + rs))

        # MACD (12/26 EMA difference) with a 9-period signal line.
        fast_ema = close.ewm(span=12, adjust=False).mean()
        slow_ema = close.ewm(span=26, adjust=False).mean()
        macd = fast_ema - slow_ema
        signal = macd.ewm(span=9, adjust=False).mean()
        df['MACD'] = macd
        df['MACD_Signal'] = signal
        df['MACD_Hist'] = macd - signal

        # Bollinger bands: 20-period mean +/- two standard deviations.
        df['MA20_std'] = close.rolling(window=20).std()
        df['Bollinger_Upper'] = df['MA20'] + (df['MA20_std'] * 2)
        df['Bollinger_Lower'] = df['MA20'] - (df['MA20_std'] * 2)

        # Average True Range over 14 periods.
        previous_close = close.shift()
        high_low = df['high'] - df['low']
        high_close = (df['high'] - previous_close).abs()
        low_close = (df['low'] - previous_close).abs()
        ranges = pd.concat([high_low, high_close, low_close], axis=1)
        true_range = ranges.max(axis=1)
        df['ATR'] = true_range.rolling(14).mean()

        # Fill the NaN values produced by the indicator warm-up periods.
        return df.bfill()

    def normalize_data(self, data: pd.DataFrame, symbol: str) -> Tuple[pd.DataFrame, Dict[str, Any]]:
        """Scale every column to the range [0, 1] with its own scaler.

        Args:
            data: Pre-processed data.
            symbol: Trading pair symbol used as the key under which the
                fitted scalers are remembered on this instance.

        Returns:
            A tuple ``(scaled_frame, scalers)`` where ``scalers`` maps each
            column name to the ``MinMaxScaler`` fitted on that column.
        """
        df = data.copy()
        scalers = {}

        # One scaler per column so each feature can be inverted on its own.
        for column in df.columns:
            scaler = MinMaxScaler(feature_range=(0, 1))
            df[column] = scaler.fit_transform(df[[column]])
            scalers[column] = scaler

        # Remember the scalers so they can be saved or reused later.
        self.scalers[symbol] = scalers

        return df, scalers

    def prepare_model_input(self, data: pd.DataFrame, window_size: int = 60) -> Tuple[np.ndarray, np.ndarray]:
        """Build sliding-window samples and next-step 'close' targets.

        Sample ``i`` consists of rows ``i .. i + window_size - 1`` and its
        target is the 'close' value of row ``i + window_size``.

        Args:
            data: Normalised data containing a 'close' column.
            window_size: Number of consecutive rows per sample.

        Returns:
            ``(X, y)`` where ``X`` has shape
            ``(n_samples, window_size, n_features)`` and ``y`` has shape
            ``(n_samples,)``. With too few rows ``n_samples`` is 0 and the
            remaining dimensions are still present.

        Raises:
            ValueError: If ``window_size`` is not positive.
            KeyError: If ``data`` has no 'close' column.
        """
        if window_size <= 0:
            raise ValueError(f"window_size must be positive, got {window_size}")

        values = data.to_numpy()
        targets = data['close'].to_numpy()
        n_samples = max(len(data) - window_size, 0)

        # Row i of `row_index` holds the row numbers of sample i; indexing
        # with it gathers all windows in one vectorised operation and keeps
        # the 3-D shape even when there are no samples.
        row_index = np.arange(n_samples)[:, None] + np.arange(window_size)[None, :]
        X = values[row_index]
        y = targets[window_size:window_size + n_samples]

        return X, y

    def _timestamped_path(self, symbol: str, stem: str, extension: str) -> str:
        """Return ``<storage>/<symbol>/<symbol>_<stem>_<now><extension>``, creating the folder."""
        symbol_dir = os.path.join(self.storage_path, symbol)
        os.makedirs(symbol_dir, exist_ok=True)
        timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
        return os.path.join(symbol_dir, f"{symbol}_{stem}_{timestamp}{extension}")

    def _latest_timestamped_file(self, symbol: str, stem: str, extension: str) -> Optional[str]:
        """Return the newest ``<symbol>_<stem>_<digits><extension>`` file of a symbol, or None."""
        symbol_dir = os.path.join(self.storage_path, symbol)
        if not os.path.isdir(symbol_dir):
            return None

        prefix = f"{symbol}_{stem}_"
        # Require the part between prefix and extension to be the numeric
        # timestamp, so that e.g. stem 'raw' does not pick up
        # 'raw_clean_...' or 'rawish_...' files.
        candidates = [
            name for name in os.listdir(symbol_dir)
            if name.startswith(prefix)
            and name.endswith(extension)
            and name[len(prefix):len(name) - len(extension)].isdigit()
        ]
        if not candidates:
            return None

        # Fixed-width timestamps sort chronologically as plain strings.
        return os.path.join(symbol_dir, max(candidates))

    def save_data(self, symbol: str, data: pd.DataFrame, data_type: str = 'raw') -> str:
        """Write a frame to a timestamped CSV file.

        Args:
            symbol: Trading pair symbol; also the name of the sub-folder.
            data: Frame to save (its index is written as the first column).
            data_type: Kind of data, e.g. 'raw', 'processed', 'normalized'.

        Returns:
            Path of the file that was written.
        """
        file_path = self._timestamped_path(symbol, data_type, ".csv")
        data.to_csv(file_path)
        return file_path

    def save_scalers(self, symbol: str) -> str:
        """Pickle the scalers remembered for ``symbol`` to a timestamped file.

        Args:
            symbol: Trading pair symbol.

        Returns:
            Path of the file that was written.

        Raises:
            ValueError: If no scalers are stored for ``symbol``.
        """
        if symbol not in self.scalers:
            raise ValueError(f"没有找到{symbol}的缩放器")

        file_path = self._timestamped_path(symbol, "scalers", ".pkl")
        with open(file_path, 'wb') as f:
            pickle.dump(self.scalers[symbol], f)

        return file_path

    def load_data(self, file_path: str) -> pd.DataFrame:
        """Load a CSV file, using its first column as the index.

        Args:
            file_path: Path of the CSV file.

        Returns:
            The loaded frame; the index is parsed as dates where possible.
        """
        return pd.read_csv(file_path, index_col=0, parse_dates=True)

    def load_scalers(self, file_path: str) -> Dict[str, Any]:
        """Load pickled scalers from a file.

        The loaded object is returned only; it is not stored on the
        instance.

        Args:
            file_path: Path of the pickle file.

        Returns:
            The unpickled scalers.
        """
        # SECURITY: unpickling executes arbitrary code embedded in the
        # file. Only load scaler files that this application wrote itself.
        with open(file_path, 'rb') as f:
            return pickle.load(f)

    def get_latest_data_file(self, symbol: str, data_type: str = 'processed') -> Optional[str]:
        """Find the most recent CSV file saved for a symbol and data type.

        Args:
            symbol: Trading pair symbol.
            data_type: Kind of data, as passed to ``save_data``.

        Returns:
            Path of the newest matching file, or None if there is none.
        """
        return self._latest_timestamped_file(symbol, data_type, ".csv")

    def get_latest_scaler_file(self, symbol: str) -> Optional[str]:
        """Find the most recent scaler pickle saved for a symbol.

        Args:
            symbol: Trading pair symbol.

        Returns:
            Path of the newest scaler file, or None if there is none.
        """
        return self._latest_timestamped_file(symbol, "scalers", ".pkl")