alphax/app/db/chat_assistant_db.py
2026-06-07 20:58:35 +08:00

459 lines
15 KiB
Python

"""Chat assistant persistence helpers."""
from __future__ import annotations
import json
from datetime import datetime, timedelta
from app.db.llm_insights import repair_mojibake_json, repair_mojibake_text
from app.db.postgres_connection import ensure_migrations_once
from app.db.schema import get_conn
def _now() -> str:
    """Return the current local (naive) time as an ISO-8601 string with second precision."""
    current = datetime.now()
    return current.isoformat(timespec="seconds")
def _loads(value, fallback=None):
    """Decode a stored JSON value, repairing mis-encoded (mojibake) text inside it.

    Args:
        value: A JSON string (as read from a ``*_json`` column) or an
            already-decoded object. ``None``, blank strings and a JSON ``null``
            are all treated as "nothing stored".
        fallback: Returned when there is nothing to decode or decoding fails.
            Defaults to a fresh empty dict.

    Returns:
        The decoded, mojibake-repaired object, or ``fallback``.
    """
    default = fallback if fallback is not None else {}
    if value is None:
        return default
    try:
        if isinstance(value, str):
            # Blank strings carry no data: return the fallback rather than
            # handing the raw string back to callers that expect a dict.
            if not value.strip():
                return default
            decoded = json.loads(value)
            # A stored JSON ``null`` is equally empty; don't leak ``None``.
            if decoded is None:
                return default
            return repair_mojibake_json(decoded)
        return repair_mojibake_json(value)
    except Exception:
        # Deliberate best-effort: corrupt JSON or a repair failure must not
        # break reads, so degrade to the fallback.
        return default
def _dumps(value) -> str:
    """Serialize ``value`` (``None`` becomes ``{}``) to a stable JSON string.

    Text is mojibake-repaired first; keys are sorted, non-ASCII is kept
    verbatim, and unknown types are stringified via ``str``.
    """
    payload = {} if value is None else value
    return json.dumps(
        repair_mojibake_json(payload),
        ensure_ascii=False,
        sort_keys=True,
        default=str,
    )
def init_chat_tables():
    """Ensure the chat schema is in place before any query runs.

    Delegates entirely to ``ensure_migrations_once``. The database helpers in
    this module call this first. The helper's name suggests repeat calls are
    cheap no-ops (NOTE(review): confirm in postgres_connection).
    """
    ensure_migrations_once()
def _normalize_title(title: str) -> str:
    """Return a display-safe session title: trimmed and at most 32 characters.

    A blank or missing title becomes the default "新对话" ("new conversation").
    Whitespace is stripped both before and after truncation, so a cut that
    lands just after a space does not leave a trailing blank.
    """
    cleaned = str(title or "").strip()[:32].rstrip()
    return cleaned or "新对话"
def _load_session(row):
    """Convert a ``chat_sessions`` row into a dict with a decoded ``memory`` field.

    Returns ``None`` for a missing/empty row. The raw ``memory_json`` column
    is removed and replaced by its decoded form under ``memory``.
    """
    if not row:
        return None
    session = dict(row)
    raw_memory = session.pop("memory_json", "{}")
    session["memory"] = _loads(raw_memory, {})
    return session
def _load_message(row):
    """Convert a ``chat_messages`` row into a dict with decoded JSON fields.

    Returns ``None`` for a missing/empty row. ``content_text`` is
    mojibake-repaired, and the raw ``content_json`` / ``context_json`` columns
    are replaced by decoded ``content`` / ``context`` entries.
    """
    if not row:
        return None
    message = dict(row)
    message["content_text"] = repair_mojibake_text(message.get("content_text", ""))
    for source_key, target_key in (("content_json", "content"), ("context_json", "context")):
        message[target_key] = _loads(message.pop(source_key, "{}"), {})
    return message
def _default_user_preferences() -> dict:
    """Return a fresh dict of default chat preferences (new list objects on every call)."""
    return {
        "preferred_symbols": [],
        "preferred_timeframes": ["15m", "1h", "4h", "1d"],
        "answer_style": "two_stage",
        "risk_profile": "balanced",
        "last_intent": "",
        "last_symbol": "",
        "recent_topics": [],
    }


def get_user_preferences(user_id: int) -> dict:
    """Load a user's chat preferences, filling any missing keys with defaults.

    Args:
        user_id: Owner of the ``chat_user_preferences`` row.

    Returns:
        The stored preferences dict with every default key present, or the
        full default set when the user has no stored row (or the stored value
        is not a JSON object).
    """
    init_chat_tables()
    conn = get_conn()
    try:
        row = conn.execute("SELECT * FROM chat_user_preferences WHERE user_id=%s", (int(user_id),)).fetchone()
    finally:
        conn.close()
    defaults = _default_user_preferences()
    if not row:
        return defaults
    prefs = _loads(row.get("preferences_json"), {})
    if not isinstance(prefs, dict):
        # A corrupt/legacy non-object value would crash on setdefault below.
        prefs = {}
    for key, value in defaults.items():
        prefs.setdefault(key, value)
    return prefs
def _merge_recent_values(existing, incoming, cap: int = 12) -> list:
    """Merge ``incoming`` into ``existing``: stringified, de-duplicated, at most ``cap`` items.

    Order follows first appearance (existing values, then new ones). When the
    merged list exceeds ``cap``, the oldest entries are evicted first, but a
    value the caller just supplied is never evicted in favour of an older one.
    """
    if not isinstance(existing, (list, tuple)):
        # Guard against a corrupt stored value (e.g. a string, which would be
        # iterated character by character).
        existing = []
    fresh = [str(x) for x in incoming if str(x).strip()]
    merged = list(dict.fromkeys([str(x) for x in existing if str(x).strip()] + fresh))
    overflow = len(merged) - cap
    if overflow > 0:
        just_used = set(fresh)
        kept = []
        for item in merged:
            if overflow > 0 and item not in just_used:
                overflow -= 1
                continue
            kept.append(item)
        merged = kept
    # Only needed when ``incoming`` alone exceeds the cap.
    return merged[-cap:] if cap > 0 else []


def update_user_preferences(user_id: int, patch: dict) -> dict:
    """Apply ``patch`` to a user's stored chat preferences and persist the result.

    List-valued keys (``preferred_symbols``, ``preferred_timeframes``,
    ``recent_topics``) given a list/tuple are merged with the stored values
    and capped at 12 entries. Any other non-``None`` value overwrites the
    stored one.

    Args:
        user_id: Owner of the preferences row (upserted on ``user_id``).
        patch: Partial preferences; ``None`` is treated as an empty patch.

    Returns:
        The full, updated preferences dict that was written.
    """
    init_chat_tables()
    current = get_user_preferences(user_id)
    for key, value in (patch or {}).items():
        if key in ("preferred_symbols", "preferred_timeframes", "recent_topics") and isinstance(value, (list, tuple)):
            current[key] = _merge_recent_values(current.get(key), value)
        elif value is not None:
            current[key] = value
    now = _now()
    conn = get_conn()
    try:
        conn.execute(
            """
            INSERT INTO chat_user_preferences (user_id, preferences_json, updated_at)
            VALUES (%s, %s, %s)
            ON CONFLICT(user_id) DO UPDATE SET
            preferences_json=excluded.preferences_json,
            updated_at=excluded.updated_at
            """,
            (int(user_id), _dumps(current), now),
        )
        conn.commit()
    finally:
        conn.close()
    return current
def create_chat_session(user_id: int, title: str = "", summary: str = "", last_symbol: str = "", last_intent: str = "") -> dict:
    """Insert a new chat session for ``user_id`` and return it as a dict.

    The title is normalized via ``_normalize_title``, memory starts as an
    empty JSON object, and ``created_at``/``updated_at`` share one timestamp.
    """
    init_chat_tables()
    stamp = _now()
    conn = get_conn()
    try:
        cursor = conn.execute(
            """
            INSERT INTO chat_sessions (user_id, title, summary, memory_json, last_symbol, last_intent, created_at, updated_at)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
            RETURNING *
            """,
            (
                int(user_id),
                _normalize_title(title),
                summary or "",
                _dumps({}),
                last_symbol or "",
                last_intent or "",
                stamp,
                stamp,
            ),
        )
        created = cursor.fetchone()
        conn.commit()
    finally:
        conn.close()
    return _load_session(created)
def list_chat_sessions(user_id: int, limit: int = 20, offset: int = 0) -> dict:
    """Return one page of a user's non-archived sessions, most recently updated first.

    Each item carries a preview of the session's latest message
    (``last_message_text``, ``last_message_role``, ``last_message_at``; empty
    strings when the session has no messages) and its ``message_count``.

    Args:
        user_id: Owner of the sessions.
        limit: Page size, clamped to 1..100 (falsy means 20).
        offset: Rows to skip, clamped to >= 0.

    Returns:
        ``{"items", "total", "limit", "offset", "has_more"}``.
    """
    init_chat_tables()
    limit = max(1, min(int(limit or 20), 100))
    offset = max(0, int(offset or 0))
    conn = get_conn()
    try:
        total = conn.execute(
            "SELECT COUNT(*) FROM chat_sessions WHERE user_id=%s AND COALESCE(archived_at, '')=''",
            (int(user_id),),
        ).fetchone()[0]
        # One LATERAL lookup fetches the latest message's columns together,
        # instead of three identical correlated subqueries.
        rows = conn.execute(
            """
            SELECT s.*,
            last_msg.content_text AS last_message_text,
            last_msg.role AS last_message_role,
            last_msg.created_at AS last_message_at,
            (SELECT COUNT(*) FROM chat_messages c WHERE c.session_id=s.id) AS message_count
            FROM chat_sessions s
            LEFT JOIN LATERAL (
                SELECT m.content_text, m.role, m.created_at
                FROM chat_messages m
                WHERE m.session_id=s.id
                ORDER BY m.id DESC
                LIMIT 1
            ) last_msg ON TRUE
            WHERE s.user_id=%s AND COALESCE(s.archived_at, '')=''
            ORDER BY s.updated_at DESC, s.id DESC
            LIMIT %s OFFSET %s
            """,
            (int(user_id), limit, offset),
        ).fetchall()
    finally:
        conn.close()
    items = []
    for row in rows:
        item = _load_session(row)
        # The LEFT JOIN yields NULLs for sessions without messages; normalize
        # to "" so the preview fields are always strings.
        item["last_message_text"] = repair_mojibake_text(item.pop("last_message_text", "") or "")
        item["last_message_role"] = item.pop("last_message_role", "") or ""
        item["last_message_at"] = item.pop("last_message_at", "") or ""
        item["message_count"] = int(item.pop("message_count", 0) or 0)
        items.append(item)
    total_count = int(total or 0)
    return {
        "items": items,
        "total": total_count,
        "limit": limit,
        "offset": offset,
        "has_more": offset + len(items) < total_count,
    }
def get_chat_session(session_id: int, user_id: int) -> dict | None:
    """Fetch one non-archived session owned by ``user_id``, or ``None`` if absent."""
    init_chat_tables()
    query = "SELECT * FROM chat_sessions WHERE id=%s AND user_id=%s AND COALESCE(archived_at, '')=''"
    conn = get_conn()
    try:
        found = conn.execute(query, (int(session_id), int(user_id))).fetchone()
    finally:
        conn.close()
    return _load_session(found)
def update_chat_session(session_id: int, user_id: int, **fields) -> dict | None:
    """Update whitelisted columns of a user's non-archived session.

    Only ``title``, ``summary``, ``memory_json``, ``last_symbol``,
    ``last_intent`` and ``archived_at`` are honoured; other keys and ``None``
    values are ignored. The whitelist also keeps the f-string SQL below
    limited to known column names.

    Returns:
        ``None`` when the session is not found (or archived); the unchanged
        session when nothing applicable was passed; otherwise the updated row.
    """
    init_chat_tables()
    existing = get_chat_session(session_id, user_id)
    if not existing:
        return None
    permitted = {"title", "summary", "memory_json", "last_symbol", "last_intent", "archived_at"}
    updates = {column: val for column, val in fields.items() if column in permitted and val is not None}
    if not updates:
        return existing
    updates["updated_at"] = _now()
    if "title" in updates:
        updates["title"] = _normalize_title(updates["title"])
    # Non-string memory (e.g. a dict) is serialized; strings are stored as given.
    if not isinstance(updates.get("memory_json", ""), str):
        updates["memory_json"] = _dumps(updates["memory_json"])
    # The sentinel "now" asks for the current timestamp.
    if updates.get("archived_at") == "now":
        updates["archived_at"] = _now()
    assignments = ", ".join(f"{column}=%s" for column in updates)
    bound = (*updates.values(), int(session_id), int(user_id))
    conn = get_conn()
    try:
        refreshed = conn.execute(
            f"UPDATE chat_sessions SET {assignments} WHERE id=%s AND user_id=%s RETURNING *",
            bound,
        ).fetchone()
        conn.commit()
    finally:
        conn.close()
    return _load_session(refreshed)
def list_chat_messages(session_id: int, user_id: int, limit: int = 50, offset: int = 0) -> dict:
    """Return one page of a session's messages in insertion (id) order.

    Args:
        session_id: Session whose messages are listed.
        user_id: Messages are also filtered by this owner.
        limit: Page size, clamped to 1..200 (falsy means 50).
        offset: Rows to skip, clamped to >= 0.

    Returns:
        ``{"items", "total", "limit", "offset", "has_more"}``.
    """
    init_chat_tables()
    page_size = max(1, min(int(limit or 50), 200))
    start = max(0, int(offset or 0))
    conn = get_conn()
    try:
        owner_filter = (int(session_id), int(user_id))
        total = conn.execute(
            "SELECT COUNT(*) FROM chat_messages WHERE session_id=%s AND user_id=%s",
            owner_filter,
        ).fetchone()[0]
        page_rows = conn.execute(
            """
            SELECT * FROM chat_messages
            WHERE session_id=%s AND user_id=%s
            ORDER BY id ASC
            LIMIT %s OFFSET %s
            """,
            owner_filter + (page_size, start),
        ).fetchall()
    finally:
        conn.close()
    messages = [_load_message(r) for r in page_rows]
    total_count = int(total or 0)
    return {
        "items": messages,
        "total": total_count,
        "limit": page_size,
        "offset": start,
        "has_more": start + len(page_rows) < total_count,
    }
def append_chat_message(
    session_id: int,
    user_id: int,
    role: str,
    content_text: str = "",
    content_json=None,
    context_json=None,
    intent: str = "",
    symbol: str = "",
    timeframe: str = "",
    model: str = "",
) -> dict:
    """Insert one message into ``chat_messages`` and return the stored row.

    Text fields are coerced to strings (``role`` defaults to ``"user"``);
    ``content_text`` is mojibake-repaired, and falsy JSON payloads are stored
    as ``{}``. Only the message table is written by this function.
    """
    init_chat_tables()
    created_at = _now()
    conn = get_conn()
    try:
        values = (
            int(session_id),
            int(user_id),
            str(role or "user"),
            repair_mojibake_text(str(content_text or "")),
            _dumps(content_json or {}),
            _dumps(context_json or {}),
            str(intent or ""),
            str(symbol or ""),
            str(timeframe or ""),
            str(model or ""),
            created_at,
        )
        inserted = conn.execute(
            """
            INSERT INTO chat_messages (
            session_id, user_id, role, content_text, content_json, context_json,
            intent, symbol, timeframe, model, created_at
            ) VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)
            RETURNING *
            """,
            values,
        ).fetchone()
        conn.commit()
    finally:
        conn.close()
    return _load_message(inserted)
def bootstrap_chat(user_id: int) -> dict:
    """Gather the initial chat payload: preferences, the first 20 sessions, and starter prompts."""
    preferences = get_user_preferences(user_id)
    recent_sessions = list_chat_sessions(user_id=user_id, limit=20, offset=0)
    # Fixed starter questions shown to the user (UI copy, kept verbatim).
    starter_prompts = [
        "分析 BTC/USDT 现在的技术面",
        "解释当前看板里这条推荐为什么是等回踩",
        "看一下市场总览,今天是偏强还是偏弱",
        "这个币的舆情和技术面有没有共振",
        "帮我复盘最近一次纸面交易",
    ]
    return dict(
        preferences=preferences,
        sessions=recent_sessions,
        suggested_prompts=starter_prompts,
    )
def get_chat_admin_overview(hours: int = 24) -> dict:
    """Summarise chat activity for the admin dashboard.

    Args:
        hours: Look-back window in hours, compared against
            ``chat_messages.created_at``. ``None``/``""`` select the default
            of 24; ``0`` or a negative value disables the time filter.

    Returns:
        A dict with the effective ``hours``; message, question, session and
        user totals; the top 8 intents and top 10 symbols among user
        questions; and the 50 most recent user questions (with session title
        and user email).
    """
    init_chat_tables()
    # Only None/"" fall back to the default: an explicit 0 must reach the
    # unfiltered branch below (``hours or 24`` would silently turn it into 24).
    hours = 24 if hours is None or hours == "" else max(0, int(hours))
    conn = get_conn()
    try:
        params = []
        where = "1=1"
        if hours > 0:
            where += " AND m.created_at >= %s"
            # Same naive, second-precision ISO format as _now(), so the text
            # comparison against created_at is lexicographically consistent.
            params.append((datetime.now().replace(microsecond=0) - timedelta(hours=hours)).isoformat())
        totals = conn.execute(
            f"""
            SELECT
            COUNT(*) AS total_messages,
            COUNT(*) FILTER (WHERE role='user') AS total_questions,
            COUNT(DISTINCT session_id) AS total_sessions,
            COUNT(DISTINCT user_id) AS total_users
            FROM chat_messages m
            WHERE {where}
            """,
            tuple(params),
        ).fetchone()
        top_intents = conn.execute(
            f"""
            SELECT COALESCE(NULLIF(intent, ''), 'unknown') AS intent, COUNT(*) AS n
            FROM chat_messages m
            WHERE role='user' AND {where}
            GROUP BY 1
            ORDER BY n DESC, intent ASC
            LIMIT 8
            """,
            tuple(params),
        ).fetchall()
        top_symbols = conn.execute(
            f"""
            SELECT COALESCE(NULLIF(symbol, ''), 'unknown') AS symbol, COUNT(*) AS n
            FROM chat_messages m
            WHERE role='user' AND {where}
            GROUP BY 1
            ORDER BY n DESC, symbol ASC
            LIMIT 10
            """,
            tuple(params),
        ).fetchall()
        recent = conn.execute(
            f"""
            SELECT m.id, m.session_id, m.user_id, m.role, m.content_text, m.intent, m.symbol, m.model, m.created_at,
            s.title AS session_title, u.email AS user_email
            FROM chat_messages m
            LEFT JOIN chat_sessions s ON s.id=m.session_id
            LEFT JOIN app_user u ON u.id=m.user_id
            WHERE m.role='user' AND {where}
            ORDER BY m.created_at DESC, m.id DESC
            LIMIT 50
            """,
            tuple(params),
        ).fetchall()
    finally:
        conn.close()
    return {
        "hours": hours,
        "total_messages": int((totals[0] if totals else 0) or 0),
        "total_questions": int((totals[1] if totals else 0) or 0),
        "total_sessions": int((totals[2] if totals else 0) or 0),
        "total_users": int((totals[3] if totals else 0) or 0),
        "top_intents": [dict(row) for row in top_intents],
        "top_symbols": [dict(row) for row in top_symbols],
        "recent_questions": [dict(row) for row in recent],
    }
def list_chat_admin_questions(hours: int = 24, intent: str = "", search: str = "", offset: int = 0, limit: int = 50) -> dict:
    """Return one page of user questions for the admin question browser, newest first.

    Args:
        hours: Look-back window in hours. ``None``/``""`` select the default
            of 24; ``0`` or a negative value disables the time filter.
        intent: Exact intent to filter on (empty intents match ``"unknown"``);
            empty or ``"all"`` disables the filter.
        search: Case-insensitive substring matched against message text,
            session title and user email. ``%`` and ``_`` are matched
            literally.
        offset: Rows to skip, clamped to >= 0.
        limit: Page size, clamped to 1..200 (falsy means 50).

    Returns:
        ``{"items", "total", "limit", "offset", "has_more"}``. Items are
        decoded like ``_load_message`` output, plus ``session_title`` and
        ``user_email``.
    """
    init_chat_tables()
    # Only None/"" fall back to the default: an explicit 0 must reach the
    # unfiltered branch below (``hours or 24`` would silently turn it into 24).
    hours = 24 if hours is None or hours == "" else max(0, int(hours))
    offset = max(0, int(offset or 0))
    limit = max(1, min(int(limit or 50), 200))
    intent = str(intent or "").strip()
    search = str(search or "").strip()
    conn = get_conn()
    try:
        params = []
        where = "m.role='user'"
        if hours > 0:
            where += " AND m.created_at >= %s"
            params.append((datetime.now().replace(microsecond=0) - timedelta(hours=hours)).isoformat())
        if intent and intent != "all":
            where += " AND COALESCE(NULLIF(m.intent, ''), 'unknown')=%s"
            params.append(intent)
        if search:
            where += " AND (m.content_text ILIKE %s OR COALESCE(s.title,'') ILIKE %s OR COALESCE(u.email,'') ILIKE %s)"
            # Escape LIKE metacharacters (PostgreSQL's default escape is a
            # backslash) so user input is matched literally, not as wildcards.
            escaped = search.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
            like = f"%{escaped}%"
            params.extend([like, like, like])
        total = conn.execute(
            f"""
            SELECT COUNT(*)
            FROM chat_messages m
            LEFT JOIN chat_sessions s ON s.id=m.session_id
            LEFT JOIN app_user u ON u.id=m.user_id
            WHERE {where}
            """,
            tuple(params),
        ).fetchone()[0]
        rows = conn.execute(
            f"""
            SELECT m.id, m.session_id, m.user_id, m.role, m.content_text, m.content_json, m.context_json,
            m.intent, m.symbol, m.timeframe, m.model, m.created_at,
            s.title AS session_title, u.email AS user_email
            FROM chat_messages m
            LEFT JOIN chat_sessions s ON s.id=m.session_id
            LEFT JOIN app_user u ON u.id=m.user_id
            WHERE {where}
            ORDER BY m.created_at DESC, m.id DESC
            LIMIT %s OFFSET %s
            """,
            tuple(params + [limit, offset]),
        ).fetchall()
    finally:
        conn.close()
    # Same decoding as regular message reads (repaired text, decoded content/context).
    items = [_load_message(row) for row in rows]
    total_count = int(total or 0)
    return {
        "items": items,
        "total": total_count,
        "limit": limit,
        "offset": offset,
        "has_more": offset + len(items) < total_count,
    }
# Public API of this module; includes the admin reporting helpers so that
# star-imports expose every non-underscore function defined here.
__all__ = [
    "append_chat_message",
    "bootstrap_chat",
    "create_chat_session",
    "get_chat_admin_overview",
    "get_chat_session",
    "get_user_preferences",
    "init_chat_tables",
    "list_chat_admin_questions",
    "list_chat_messages",
    "list_chat_sessions",
    "update_chat_session",
    "update_user_preferences",
]