463 lines
16 KiB
Python
463 lines
16 KiB
Python
import os
|
||
import json
|
||
import requests
|
||
from typing import Dict, Any, List, Optional, Tuple
|
||
import time
|
||
import logging
|
||
import datetime
|
||
|
||
# 配置日志
|
||
logging.basicConfig(
|
||
level=logging.INFO,
|
||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||
handlers=[
|
||
logging.FileHandler("deepseek_token_usage.log"),
|
||
logging.StreamHandler()
|
||
]
|
||
)
|
||
|
||
class DeepSeekAPI:
|
||
"""DeepSeek API交互类,用于进行市场分析和预测"""
|
||
|
||
def __init__(self, api_key: str, model: str = "deepseek-moe-16b-chat"):
|
||
"""
|
||
初始化DeepSeek API
|
||
|
||
Args:
|
||
api_key: DeepSeek API密钥
|
||
model: 使用的模型名称
|
||
"""
|
||
self.api_key = api_key
|
||
self.model = model
|
||
self.base_url = "https://api.deepseek.com/v1"
|
||
self.headers = {
|
||
"Content-Type": "application/json",
|
||
"Authorization": f"Bearer {api_key}"
|
||
}
|
||
|
||
# Token 使用统计
|
||
self.token_usage = {
|
||
"total_prompt_tokens": 0,
|
||
"total_completion_tokens": 0,
|
||
"total_tokens": 0,
|
||
"calls": []
|
||
}
|
||
|
||
# 创建日志记录器
|
||
self.logger = logging.getLogger("DeepSeekAPI")
|
||
|
||
def analyze_market_data(self, market_data: Dict[str, Any]) -> Dict[str, Any]:
|
||
"""
|
||
分析市场数据
|
||
|
||
Args:
|
||
market_data: 包含市场数据的字典,例如价格、交易量等
|
||
|
||
Returns:
|
||
分析结果
|
||
"""
|
||
# 将市场数据格式化为适合大模型的格式
|
||
formatted_data = self._format_market_data(market_data)
|
||
|
||
# 构建提示词
|
||
prompt = self._build_market_analysis_prompt(formatted_data)
|
||
|
||
# 调用API获取分析
|
||
response, usage = self._call_api(prompt, task_type="市场分析", symbol=market_data.get("symbol", "未知"))
|
||
|
||
# 解析响应
|
||
result = self._parse_analysis_response(response)
|
||
|
||
# 添加token使用信息
|
||
if usage:
|
||
result["_token_usage"] = usage
|
||
|
||
return result
|
||
|
||
def predict_price_trend(self, symbol: str, historical_data: Dict[str, Any]) -> Dict[str, Any]:
|
||
"""
|
||
预测价格趋势
|
||
|
||
Args:
|
||
symbol: 交易对符号,例如 'BTCUSDT'
|
||
historical_data: 历史数据
|
||
|
||
Returns:
|
||
预测结果
|
||
"""
|
||
# 格式化历史数据
|
||
formatted_data = self._format_historical_data(symbol, historical_data)
|
||
|
||
# 构建提示词
|
||
prompt = self._build_price_prediction_prompt(symbol, formatted_data)
|
||
|
||
# 调用API获取预测
|
||
response, usage = self._call_api(prompt, task_type="价格预测", symbol=symbol)
|
||
|
||
# 解析响应
|
||
result = self._parse_prediction_response(response)
|
||
|
||
# 添加token使用信息
|
||
if usage:
|
||
result["_token_usage"] = usage
|
||
|
||
return result
|
||
|
||
def generate_trading_strategy(self, symbol: str, analysis_result: Dict[str, Any], risk_level: str) -> Dict[str, Any]:
|
||
"""
|
||
生成交易策略
|
||
|
||
Args:
|
||
symbol: 交易对符号,例如 'BTCUSDT'
|
||
analysis_result: 分析结果
|
||
risk_level: 风险等级,'low', 'medium', 'high'
|
||
|
||
Returns:
|
||
交易策略
|
||
"""
|
||
# 构建提示词
|
||
prompt = self._build_trading_strategy_prompt(symbol, analysis_result, risk_level)
|
||
|
||
# 调用API获取策略
|
||
response, usage = self._call_api(prompt, task_type="交易策略", symbol=symbol)
|
||
|
||
# 解析响应
|
||
result = self._parse_strategy_response(response)
|
||
|
||
# 添加token使用信息
|
||
if usage:
|
||
result["_token_usage"] = usage
|
||
|
||
return result
|
||
|
||
def get_token_usage_stats(self) -> Dict[str, Any]:
|
||
"""
|
||
获取Token使用统计信息
|
||
|
||
Returns:
|
||
包含使用统计的字典
|
||
"""
|
||
return {
|
||
"total_prompt_tokens": self.token_usage["total_prompt_tokens"],
|
||
"total_completion_tokens": self.token_usage["total_completion_tokens"],
|
||
"total_tokens": self.token_usage["total_tokens"],
|
||
"total_calls": len(self.token_usage["calls"]),
|
||
"average_tokens_per_call": self.token_usage["total_tokens"] / len(self.token_usage["calls"]) if self.token_usage["calls"] else 0,
|
||
"detailed_calls": self.token_usage["calls"][-10:] # 仅返回最近10次调用详情
|
||
}
|
||
|
||
def export_token_usage(self, file_path: str = None, format: str = "json") -> str:
|
||
"""
|
||
导出Token使用数据到文件
|
||
|
||
Args:
|
||
file_path: 文件路径,如果为None则自动生成
|
||
format: 导出格式,支持'json'或'csv'
|
||
|
||
Returns:
|
||
导出文件的路径
|
||
"""
|
||
if file_path is None:
|
||
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||
file_path = f"deepseek_token_usage_{timestamp}.{format}"
|
||
|
||
try:
|
||
if format.lower() == "json":
|
||
with open(file_path, 'w', encoding='utf-8') as f:
|
||
json.dump(self.token_usage, f, indent=2, ensure_ascii=False)
|
||
elif format.lower() == "csv":
|
||
import csv
|
||
|
||
with open(file_path, 'w', newline='', encoding='utf-8') as f:
|
||
writer = csv.writer(f)
|
||
# 写入表头
|
||
writer.writerow([
|
||
"timestamp", "task_type", "symbol", "model",
|
||
"prompt_tokens", "completion_tokens", "total_tokens",
|
||
"duration_seconds"
|
||
])
|
||
|
||
# 写入数据
|
||
for call in self.token_usage["calls"]:
|
||
writer.writerow([
|
||
call.get("timestamp", ""),
|
||
call.get("task_type", ""),
|
||
call.get("symbol", ""),
|
||
call.get("model", ""),
|
||
call.get("prompt_tokens", 0),
|
||
call.get("completion_tokens", 0),
|
||
call.get("total_tokens", 0),
|
||
call.get("duration_seconds", 0)
|
||
])
|
||
|
||
# 写入总计
|
||
writer.writerow([])
|
||
writer.writerow([
|
||
f"总计 (调用次数: {len(self.token_usage['calls'])})",
|
||
"", "", "",
|
||
self.token_usage["total_prompt_tokens"],
|
||
self.token_usage["total_completion_tokens"],
|
||
self.token_usage["total_tokens"],
|
||
""
|
||
])
|
||
else:
|
||
raise ValueError(f"不支持的格式: {format},仅支持 'json' 或 'csv'")
|
||
|
||
self.logger.info(f"Token使用数据已导出到: {file_path}")
|
||
return file_path
|
||
|
||
except Exception as e:
|
||
error_msg = f"导出Token使用数据时出错: {e}"
|
||
self.logger.error(error_msg)
|
||
return ""
|
||
|
||
def _call_api(self, prompt: str, task_type: str = "未知任务", symbol: str = "未知") -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||
"""
|
||
调用DeepSeek API
|
||
|
||
Args:
|
||
prompt: 提示词
|
||
task_type: 任务类型
|
||
symbol: 交易对符号
|
||
|
||
Returns:
|
||
(API响应, token使用信息)
|
||
"""
|
||
usage_info = {}
|
||
|
||
try:
|
||
endpoint = f"{self.base_url}/chat/completions"
|
||
|
||
payload = {
|
||
"model": self.model,
|
||
"messages": [
|
||
{"role": "system", "content": "你是一个专业的加密货币分析助手,擅长分析市场趋势、预测价格走向和提供交易建议。请始终使用中文回复,并确保输出格式规范的JSON。"},
|
||
{"role": "user", "content": prompt}
|
||
],
|
||
"temperature": 0.2, # 低温度使输出更加确定性
|
||
"max_tokens": 2000
|
||
}
|
||
|
||
start_time = time.time()
|
||
response = requests.post(endpoint, headers=self.headers, json=payload)
|
||
response.raise_for_status()
|
||
response_data = response.json()
|
||
end_time = time.time()
|
||
|
||
# 记录token使用情况
|
||
if 'usage' in response_data:
|
||
prompt_tokens = response_data['usage'].get('prompt_tokens', 0)
|
||
completion_tokens = response_data['usage'].get('completion_tokens', 0)
|
||
total_tokens = response_data['usage'].get('total_tokens', 0)
|
||
|
||
usage_info = {
|
||
"prompt_tokens": prompt_tokens,
|
||
"completion_tokens": completion_tokens,
|
||
"total_tokens": total_tokens,
|
||
"task_type": task_type,
|
||
"symbol": symbol,
|
||
"model": self.model,
|
||
"timestamp": datetime.datetime.now().isoformat(),
|
||
"duration_seconds": round(end_time - start_time, 2)
|
||
}
|
||
|
||
# 更新总计
|
||
self.token_usage["total_prompt_tokens"] += prompt_tokens
|
||
self.token_usage["total_completion_tokens"] += completion_tokens
|
||
self.token_usage["total_tokens"] += total_tokens
|
||
self.token_usage["calls"].append(usage_info)
|
||
|
||
# 记录到日志
|
||
self.logger.info(
|
||
f"DeepSeek API调用 - 任务: {task_type}, 符号: {symbol}, "
|
||
f"输入tokens: {prompt_tokens}, 输出tokens: {completion_tokens}, "
|
||
f"总tokens: {total_tokens}, 耗时: {round(end_time - start_time, 2)}秒"
|
||
)
|
||
|
||
return response_data, usage_info
|
||
|
||
except Exception as e:
|
||
error_msg = f"调用DeepSeek API时出错: {e}"
|
||
self.logger.error(error_msg)
|
||
return {}, usage_info
|
||
|
||
def _format_market_data(self, market_data: Dict[str, Any]) -> str:
|
||
"""
|
||
格式化市场数据为适合大模型的格式
|
||
|
||
Args:
|
||
market_data: 市场数据
|
||
|
||
Returns:
|
||
格式化的数据字符串
|
||
"""
|
||
# 这里可以根据实际情况调整格式化方式
|
||
return json.dumps(market_data, indent=2)
|
||
|
||
def _format_historical_data(self, symbol: str, historical_data: Dict[str, Any]) -> str:
|
||
"""
|
||
格式化历史数据为适合大模型的格式
|
||
|
||
Args:
|
||
symbol: 交易对符号
|
||
historical_data: 历史数据
|
||
|
||
Returns:
|
||
格式化的数据字符串
|
||
"""
|
||
# 可以根据实际情况调整格式化方式
|
||
return json.dumps(historical_data, indent=2)
|
||
|
||
def _build_market_analysis_prompt(self, formatted_data: str) -> str:
|
||
"""
|
||
构建市场分析提示词
|
||
|
||
Args:
|
||
formatted_data: 格式化的市场数据
|
||
|
||
Returns:
|
||
提示词
|
||
"""
|
||
return f"""请分析以下加密货币市场数据,并提供详细的市场分析。请使用中文回复。
|
||
|
||
数据:
|
||
{formatted_data}
|
||
|
||
请包括以下内容:
|
||
1. 市场总体趋势
|
||
2. 主要支撑位和阻力位
|
||
3. 交易量分析
|
||
4. 市场情绪评估
|
||
5. 关键技术指标解读(如RSI、MACD等)
|
||
|
||
请以JSON格式回复,包含以下字段:
|
||
- market_trend: 市场趋势 (牛市, 熊市, 震荡)
|
||
- support_levels: 支撑位列表
|
||
- resistance_levels: 阻力位列表
|
||
- volume_analysis: 交易量分析
|
||
- market_sentiment: 市场情绪
|
||
- technical_indicators: 技术指标分析
|
||
- summary: 总结
|
||
|
||
请确保回复为有效的JSON格式,并使用中文进行分析。"""
|
||
|
||
def _build_price_prediction_prompt(self, symbol: str, formatted_data: str) -> str:
|
||
"""
|
||
构建价格预测提示词
|
||
|
||
Args:
|
||
symbol: 交易对符号
|
||
formatted_data: 格式化的历史数据
|
||
|
||
Returns:
|
||
提示词
|
||
"""
|
||
return f"""请基于以下{symbol}的历史数据,预测未来24小时、7天和30天的价格走势。请使用中文回复。
|
||
|
||
历史数据:
|
||
{formatted_data}
|
||
|
||
请考虑市场趋势、技术指标、历史模式和当前市场情况,提供详细的预测分析。
|
||
|
||
请以JSON格式回复,包含以下字段:
|
||
- symbol: 交易对符号
|
||
- current_price: 当前价格
|
||
- prediction_24h: 24小时预测 (包含 price_range价格区间, trend趋势, confidence置信度)
|
||
- prediction_7d: 7天预测 (包含 price_range价格区间, trend趋势, confidence置信度)
|
||
- prediction_30d: 30天预测 (包含 price_range价格区间, trend趋势, confidence置信度)
|
||
- key_factors: 影响预测的关键因素
|
||
- risk_assessment: 风险评估
|
||
|
||
请确保回复为有效的JSON格式,并使用中文进行分析。"""
|
||
|
||
def _build_trading_strategy_prompt(self, symbol: str, analysis_result: Dict[str, Any], risk_level: str) -> str:
|
||
"""
|
||
构建交易策略提示词
|
||
|
||
Args:
|
||
symbol: 交易对符号
|
||
analysis_result: 分析结果
|
||
risk_level: 风险等级
|
||
|
||
Returns:
|
||
提示词
|
||
"""
|
||
analysis_json = json.dumps(analysis_result, indent=2)
|
||
|
||
return f"""请基于以下{symbol}的市场分析结果,生成一个风险等级为{risk_level}的交易策略。请使用中文回复。
|
||
|
||
分析结果:
|
||
{analysis_json}
|
||
|
||
请考虑市场趋势、技术指标、风险等级和当前市场情况,提供详细的交易策略。
|
||
|
||
请以JSON格式回复,包含以下字段:
|
||
- symbol: 交易对符号
|
||
- risk_level: 风险等级 (low低风险, medium中风险, high高风险)
|
||
- position: 建议仓位 (买入、卖出、持有)
|
||
- entry_points: 入场点列表
|
||
- exit_points: 出场点列表
|
||
- stop_loss: 止损位
|
||
- take_profit: 止盈位
|
||
- time_frame: 建议的交易时间框架
|
||
- strategy_type: 策略类型 (例如:趋势跟踪、反转、突破等)
|
||
- reasoning: 策略推理过程
|
||
|
||
请确保回复为有效的JSON格式,并使用中文进行分析。"""
|
||
|
||
def _parse_analysis_response(self, response: Dict[str, Any]) -> Dict[str, Any]:
|
||
"""
|
||
解析分析响应
|
||
|
||
Args:
|
||
response: API响应
|
||
|
||
Returns:
|
||
解析后的分析结果
|
||
"""
|
||
try:
|
||
if 'choices' in response and len(response['choices']) > 0:
|
||
content = response['choices'][0]['message']['content']
|
||
|
||
# 尝试从响应中提取JSON
|
||
start_idx = content.find('{')
|
||
end_idx = content.rfind('}') + 1
|
||
|
||
if start_idx != -1 and end_idx != -1:
|
||
json_str = content[start_idx:end_idx]
|
||
return json.loads(json_str)
|
||
|
||
return {"error": "无法从响应中提取JSON", "raw_content": content}
|
||
|
||
return {"error": "API响应格式不正确", "raw_response": response}
|
||
|
||
except Exception as e:
|
||
error_msg = f"解析分析响应时出错: {e}"
|
||
self.logger.error(error_msg)
|
||
return {"error": str(e), "raw_response": response}
|
||
|
||
def _parse_prediction_response(self, response: Dict[str, Any]) -> Dict[str, Any]:
|
||
"""
|
||
解析预测响应
|
||
|
||
Args:
|
||
response: API响应
|
||
|
||
Returns:
|
||
解析后的预测结果
|
||
"""
|
||
# 与_parse_analysis_response相同的实现
|
||
return self._parse_analysis_response(response)
|
||
|
||
def _parse_strategy_response(self, response: Dict[str, Any]) -> Dict[str, Any]:
|
||
"""
|
||
解析策略响应
|
||
|
||
Args:
|
||
response: API响应
|
||
|
||
Returns:
|
||
解析后的策略结果
|
||
"""
|
||
# 与_parse_analysis_response相同的实现
|
||
return self._parse_analysis_response(response) |