459 lines
15 KiB
Python
459 lines
15 KiB
Python
"""Chat assistant persistence helpers."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
from datetime import datetime, timedelta
|
|
|
|
from app.db.llm_insights import repair_mojibake_json, repair_mojibake_text
|
|
from app.db.postgres_connection import ensure_migrations_once
|
|
from app.db.schema import get_conn
|
|
|
|
|
|
def _now() -> str:
    """Return the current local (naive) time as an ISO-8601 string, to the second."""
    moment = datetime.now()
    return moment.isoformat(timespec="seconds")
|
|
|
|
|
|
def _loads(value, fallback=None):
    """Best-effort decode of a stored JSON value, with mojibake repair.

    Args:
        value: A JSON string, an already-decoded object (passed through to the
            repair step), or ``None``.
        fallback: Returned when *value* is ``None``, a blank string, invalid
            JSON, JSON ``null``, or cannot be repaired. ``None`` means ``{}``.

    Returns:
        The decoded and repaired object, or the fallback.
    """
    default = fallback if fallback is not None else {}
    if value is None:
        return default
    if isinstance(value, str):
        # A blank string must not be returned as-is: callers expect a container
        # (e.g. they call ``.setdefault`` on the result).
        if not value.strip():
            return default
        try:
            value = json.loads(value)
        except (ValueError, RecursionError):  # JSONDecodeError is a ValueError
            return default
        if value is None:  # the literal JSON ``null``
            return default
    try:
        return repair_mojibake_json(value)
    except Exception:  # repair helper is opaque; decoding stays best-effort
        return default
|
|
|
|
|
|
def _dumps(value) -> str:
    """Serialize *value* to a JSON string after mojibake repair.

    ``None`` is stored as an empty object. Keys are sorted, non-ASCII text is
    kept verbatim, and values json cannot encode are converted with ``str()``.
    """
    payload = {} if value is None else value
    repaired = repair_mojibake_json(payload)
    return json.dumps(repaired, ensure_ascii=False, sort_keys=True, default=str)
|
|
|
|
|
|
def init_chat_tables():
    """Ensure the database schema is migrated before any chat query runs.

    Thin wrapper that delegates to ``ensure_migrations_once()``; every public
    helper in this module calls it first.
    """
    ensure_migrations_once()
|
|
|
|
|
|
def _normalize_title(title: str) -> str:
    """Strip a session title and cut it to 32 characters.

    Falsy or blank titles are replaced with the default "新对话" ("new chat").
    """
    cleaned = ("" if not title else str(title)).strip()
    truncated = cleaned[:32]
    if truncated:
        return truncated
    return "新对话"
|
|
|
|
|
|
def _load_session(row):
    """Turn a ``chat_sessions`` row into a plain dict.

    The raw ``memory_json`` column is removed and replaced by a decoded
    ``memory`` entry (``{}`` when missing or undecodable). Returns ``None`` for
    a falsy row.
    """
    if row:
        session = dict(row)
        raw_memory = session.pop("memory_json", "{}")
        session["memory"] = _loads(raw_memory, {})
        return session
    return None
|
|
|
|
|
|
def _load_message(row):
    """Turn a ``chat_messages`` row into a plain dict.

    ``content_text`` is mojibake-repaired in place. The raw ``content_json`` and
    ``context_json`` columns are replaced by decoded ``content`` and ``context``
    entries (``{}`` when missing or undecodable). Returns ``None`` for a falsy row.
    """
    if not row:
        return None
    message = dict(row)
    message["content_text"] = repair_mojibake_text(message.get("content_text", ""))
    for column, key in (("content_json", "content"), ("context_json", "context")):
        message[key] = _loads(message.pop(column, "{}"), {})
    return message
|
|
|
|
|
|
def _default_user_preferences() -> dict:
    """Return a fresh dict of default chat preferences (new list objects each call)."""
    return {
        "preferred_symbols": [],
        "preferred_timeframes": ["15m", "1h", "4h", "1d"],
        "answer_style": "two_stage",
        "risk_profile": "balanced",
        "last_intent": "",
        "last_symbol": "",
        "recent_topics": [],
    }


def get_user_preferences(user_id: int) -> dict:
    """Return the stored chat preferences for *user_id*, filled in with defaults.

    Reads ``chat_user_preferences.preferences_json``. Missing keys are filled
    from the defaults. A missing row, or a stored payload that is not a JSON
    object, yields the defaults.
    """
    init_chat_tables()
    conn = get_conn()
    try:
        row = conn.execute("SELECT * FROM chat_user_preferences WHERE user_id=%s", (int(user_id),)).fetchone()
    finally:
        conn.close()
    defaults = _default_user_preferences()
    if not row:
        return defaults
    prefs = _loads(row.get("preferences_json"), {})
    if not isinstance(prefs, dict):
        # A non-object payload (e.g. a JSON list) cannot be merged with the
        # defaults and would make ``setdefault`` raise AttributeError.
        return defaults
    for key, value in defaults.items():
        prefs.setdefault(key, value)
    return prefs
|
|
|
|
|
|
def _merge_recent(existing, incoming, cap: int = 12) -> list:
    """Merge two string lists most-recent-last, de-duplicated, keeping the last *cap*."""
    fresh = list(dict.fromkeys(str(x) for x in incoming if str(x).strip()))
    fresh_set = set(fresh)
    # Items present in the patch are moved to the end (most recent) instead of
    # keeping their old position, so the cap never evicts a just-mentioned item.
    kept = [str(x) for x in existing if str(x).strip() and str(x) not in fresh_set]
    return list(dict.fromkeys(kept + fresh))[-cap:]


def update_user_preferences(user_id: int, patch: dict) -> dict:
    """Apply *patch* to the user's stored preferences and upsert the result.

    List keys (``preferred_symbols``, ``preferred_timeframes``,
    ``recent_topics``) given as lists are merged with the stored values. The
    merge drops blanks, de-duplicates, keeps the most recent entries last and is
    capped at 12. Any other non-``None`` value replaces the stored one.

    Returns:
        The full, updated preferences dict that was written.
    """
    init_chat_tables()
    current = get_user_preferences(user_id)
    for key, value in (patch or {}).items():
        if key in ("preferred_symbols", "preferred_timeframes", "recent_topics") and isinstance(value, list):
            current[key] = _merge_recent(current.get(key) or [], value)
        elif value is not None:
            current[key] = value
    now = _now()
    conn = get_conn()
    try:
        conn.execute(
            """
            INSERT INTO chat_user_preferences (user_id, preferences_json, updated_at)
            VALUES (%s, %s, %s)
            ON CONFLICT(user_id) DO UPDATE SET
                preferences_json=excluded.preferences_json,
                updated_at=excluded.updated_at
            """,
            (int(user_id), _dumps(current), now),
        )
        conn.commit()
    finally:
        conn.close()
    return current
|
|
|
|
|
|
def create_chat_session(user_id: int, title: str = "", summary: str = "", last_symbol: str = "", last_intent: str = "") -> dict:
    """Insert a new chat session for *user_id* and return it decoded.

    The title goes through ``_normalize_title``, memory starts as an empty
    JSON object, and ``created_at``/``updated_at`` share one timestamp.
    """
    init_chat_tables()
    stamp = _now()
    conn = get_conn()
    try:
        values = (
            int(user_id),
            _normalize_title(title),
            summary or "",
            _dumps({}),
            last_symbol or "",
            last_intent or "",
            stamp,
            stamp,
        )
        inserted = conn.execute(
            """
            INSERT INTO chat_sessions (user_id, title, summary, memory_json, last_symbol, last_intent, created_at, updated_at)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
            RETURNING *
            """,
            values,
        ).fetchone()
        conn.commit()
    finally:
        conn.close()
    return _load_session(inserted)
|
|
|
|
|
|
def list_chat_sessions(user_id: int, limit: int = 20, offset: int = 0) -> dict:
    """Page through a user's non-archived chat sessions, most recently updated first.

    Each item is a decoded session (as produced by ``_load_session``) plus the
    latest message's text, role and timestamp and the session's message count.

    Args:
        user_id: Owner of the sessions.
        limit: Page size, clamped to 1..100 (falsy means 20).
        offset: Number of sessions to skip, clamped to >= 0.

    Returns:
        ``{"items", "total", "limit", "offset", "has_more"}``.
    """
    init_chat_tables()
    limit = max(1, min(int(limit or 20), 100))
    offset = max(0, int(offset or 0))
    conn = get_conn()
    try:
        total = conn.execute("SELECT COUNT(*) FROM chat_sessions WHERE user_id=%s AND COALESCE(archived_at, '')=''", (int(user_id),)).fetchone()[0]
        rows = conn.execute(
            """
            SELECT s.*,
                (SELECT m.content_text FROM chat_messages m WHERE m.session_id=s.id ORDER BY m.id DESC LIMIT 1) AS last_message_text,
                (SELECT m.role FROM chat_messages m WHERE m.session_id=s.id ORDER BY m.id DESC LIMIT 1) AS last_message_role,
                (SELECT m.created_at FROM chat_messages m WHERE m.session_id=s.id ORDER BY m.id DESC LIMIT 1) AS last_message_at,
                (SELECT COUNT(*) FROM chat_messages m WHERE m.session_id=s.id) AS message_count
            FROM chat_sessions s
            WHERE s.user_id=%s AND COALESCE(s.archived_at, '')=''
            ORDER BY s.updated_at DESC, s.id DESC
            LIMIT %s OFFSET %s
            """,
            (int(user_id), limit, offset),
        ).fetchall()
    finally:
        conn.close()
    total = int(total or 0)
    items = []
    for row in rows:
        item = _load_session(row)
        # The columns always exist in the row, so ``pop(key, "")`` never used
        # its default; the subqueries yield NULL for sessions without messages.
        # Coerce NULL to "" explicitly. Also repair the text like _load_message.
        item["last_message_text"] = repair_mojibake_text(item.pop("last_message_text", None) or "")
        item["last_message_role"] = item.pop("last_message_role", None) or ""
        item["last_message_at"] = item.pop("last_message_at", None) or ""
        item["message_count"] = int(item.pop("message_count", 0) or 0)
        items.append(item)
    return {
        "items": items,
        "total": total,
        "limit": limit,
        "offset": offset,
        "has_more": offset + len(items) < total,
    }
|
|
|
|
|
|
def get_chat_session(session_id: int, user_id: int) -> dict | None:
    """Fetch one non-archived session owned by *user_id*, or ``None`` if absent."""
    init_chat_tables()
    query = "SELECT * FROM chat_sessions WHERE id=%s AND user_id=%s AND COALESCE(archived_at, '')=''"
    conn = get_conn()
    try:
        found = conn.execute(query, (int(session_id), int(user_id))).fetchone()
    finally:
        conn.close()
    return _load_session(found)
|
|
|
|
|
|
def update_chat_session(session_id: int, user_id: int, **fields) -> dict | None:
    """Partially update a non-archived session owned by *user_id*.

    Only ``title``, ``summary``, ``memory_json``, ``last_symbol``,
    ``last_intent`` and ``archived_at`` are writable. Other keys and ``None``
    values are ignored.

    Value handling:
        - ``title`` is normalized.
        - A non-string ``memory_json`` is serialized.
        - ``archived_at="now"`` is replaced by the current timestamp.
        - ``updated_at`` is always refreshed when something changes.

    Returns:
        ``None`` if the session is not found (or already archived), the
        unchanged session if nothing writable was supplied, otherwise the
        updated row.
    """
    init_chat_tables()
    session = get_chat_session(session_id, user_id)
    if not session:
        return None
    writable = {"title", "summary", "memory_json", "last_symbol", "last_intent", "archived_at"}
    changes = {name: val for name, val in fields.items() if name in writable and val is not None}
    if not changes:
        return session
    changes["updated_at"] = _now()
    if "title" in changes:
        changes["title"] = _normalize_title(changes["title"])
    if "memory_json" in changes and not isinstance(changes["memory_json"], str):
        changes["memory_json"] = _dumps(changes["memory_json"])
    if changes.get("archived_at") == "now":
        changes["archived_at"] = _now()
    # Column names come only from the ``writable`` whitelist above (plus
    # updated_at), so interpolating them into the SQL text is safe.
    assignments = ", ".join(f"{name}=%s" for name in changes)
    params = (*changes.values(), int(session_id), int(user_id))
    conn = get_conn()
    try:
        updated = conn.execute(
            f"UPDATE chat_sessions SET {assignments} WHERE id=%s AND user_id=%s RETURNING *",
            params,
        ).fetchone()
        conn.commit()
    finally:
        conn.close()
    return _load_session(updated)
|
|
|
|
|
|
def list_chat_messages(session_id: int, user_id: int, limit: int = 50, offset: int = 0) -> dict:
    """Page through one session's messages in ascending id order.

    Rows are filtered by both ``session_id`` and ``user_id``. ``limit`` is
    clamped to 1..200 (falsy means 50) and ``offset`` to >= 0.

    Returns:
        ``{"items", "total", "limit", "offset", "has_more"}`` with decoded
        message dicts as items.
    """
    init_chat_tables()
    page_size = max(1, min(int(limit or 50), 200))
    skip = max(0, int(offset or 0))
    conn = get_conn()
    try:
        owner_filter = (int(session_id), int(user_id))
        total = conn.execute(
            "SELECT COUNT(*) FROM chat_messages WHERE session_id=%s AND user_id=%s",
            owner_filter,
        ).fetchone()[0]
        rows = conn.execute(
            """
            SELECT * FROM chat_messages
            WHERE session_id=%s AND user_id=%s
            ORDER BY id ASC
            LIMIT %s OFFSET %s
            """,
            owner_filter + (page_size, skip),
        ).fetchall()
    finally:
        conn.close()
    total_count = int(total or 0)
    return {
        "items": [_load_message(row) for row in rows],
        "total": total_count,
        "limit": page_size,
        "offset": skip,
        "has_more": skip + len(rows) < total_count,
    }
|
|
|
|
|
|
def append_chat_message(
    session_id: int,
    user_id: int,
    role: str,
    content_text: str = "",
    content_json=None,
    context_json=None,
    intent: str = "",
    symbol: str = "",
    timeframe: str = "",
    model: str = "",
) -> dict:
    """Insert one chat message and return it decoded.

    Empty role defaults to ``"user"``. ``content_text`` is mojibake-repaired,
    and the JSON payloads are serialized (falsy becomes ``{}``). This function
    does not touch the parent session row or verify that it belongs to
    *user_id* (NOTE(review): presumably the caller checks ownership).
    """
    init_chat_tables()
    created_at = _now()
    conn = get_conn()
    try:
        values = (
            int(session_id),
            int(user_id),
            str(role or "user"),
            repair_mojibake_text(str(content_text or "")),
            _dumps(content_json or {}),
            _dumps(context_json or {}),
            str(intent or ""),
            str(symbol or ""),
            str(timeframe or ""),
            str(model or ""),
            created_at,
        )
        inserted = conn.execute(
            """
            INSERT INTO chat_messages (
                session_id, user_id, role, content_text, content_json, context_json,
                intent, symbol, timeframe, model, created_at
            ) VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)
            RETURNING *
            """,
            values,
        ).fetchone()
        conn.commit()
    finally:
        conn.close()
    return _load_message(inserted)
|
|
|
|
|
|
def bootstrap_chat(user_id: int) -> dict:
    """Build the initial chat-screen payload for *user_id*.

    Returns the user's preferences, the first page (20) of their sessions and
    a fixed list of suggested starter prompts (Chinese UI strings).
    """
    starter_prompts = (
        "分析 BTC/USDT 现在的技术面",
        "解释当前看板里这条推荐为什么是等回踩",
        "看一下市场总览,今天是偏强还是偏弱",
        "这个币的舆情和技术面有没有共振",
        "帮我复盘最近一次纸面交易",
    )
    return {
        "preferences": get_user_preferences(user_id),
        "sessions": list_chat_sessions(user_id=user_id, limit=20, offset=0),
        "suggested_prompts": list(starter_prompts),
    }
|
|
|
|
|
|
def get_chat_admin_overview(hours: int = 24) -> dict:
    """Summarize chat activity for the admin dashboard.

    Args:
        hours: Look-back window in hours, compared against
            ``chat_messages.created_at``. ``None`` or ``""`` means the default
            of 24; ``0`` or a negative value disables the time filter.

    Returns:
        A dict with:
            - ``hours``: the effective window.
            - Total message, question, session and user counts.
            - The top 8 intents and top 10 symbols of user messages, with blank
              values reported as ``"unknown"``.
            - Up to 50 most recent user questions, joined with their session
              title and user email.
    """
    init_chat_tables()
    # ``int(hours or 24)`` coerced an explicit 0 ("all history") to 24, making
    # the unfiltered branch below reachable only via negative input.
    hours = max(0, int(24 if hours is None or hours == "" else hours))
    conn = get_conn()
    try:
        params = []
        where = "1=1"
        if hours > 0:
            where += " AND m.created_at >= %s"
            params.append((datetime.now().replace(microsecond=0) - timedelta(hours=hours)).isoformat())
        totals = conn.execute(
            f"""
            SELECT
                COUNT(*) AS total_messages,
                COUNT(*) FILTER (WHERE role='user') AS total_questions,
                COUNT(DISTINCT session_id) AS total_sessions,
                COUNT(DISTINCT user_id) AS total_users
            FROM chat_messages m
            WHERE {where}
            """,
            tuple(params),
        ).fetchone()
        top_intents = conn.execute(
            f"""
            SELECT COALESCE(NULLIF(intent, ''), 'unknown') AS intent, COUNT(*) AS n
            FROM chat_messages m
            WHERE role='user' AND {where}
            GROUP BY 1
            ORDER BY n DESC, intent ASC
            LIMIT 8
            """,
            tuple(params),
        ).fetchall()
        top_symbols = conn.execute(
            f"""
            SELECT COALESCE(NULLIF(symbol, ''), 'unknown') AS symbol, COUNT(*) AS n
            FROM chat_messages m
            WHERE role='user' AND {where}
            GROUP BY 1
            ORDER BY n DESC, symbol ASC
            LIMIT 10
            """,
            tuple(params),
        ).fetchall()
        recent = conn.execute(
            f"""
            SELECT m.id, m.session_id, m.user_id, m.role, m.content_text, m.intent, m.symbol, m.model, m.created_at,
                   s.title AS session_title, u.email AS user_email
            FROM chat_messages m
            LEFT JOIN chat_sessions s ON s.id=m.session_id
            LEFT JOIN app_user u ON u.id=m.user_id
            WHERE m.role='user' AND {where}
            ORDER BY m.created_at DESC, m.id DESC
            LIMIT 50
            """,
            tuple(params),
        ).fetchall()
    finally:
        conn.close()
    recent_questions = []
    for row in recent:
        entry = dict(row)
        # Repair text the same way list_chat_admin_questions does.
        entry["content_text"] = repair_mojibake_text(entry.get("content_text") or "")
        recent_questions.append(entry)
    return {
        "hours": hours,
        "total_messages": int((totals[0] if totals else 0) or 0),
        "total_questions": int((totals[1] if totals else 0) or 0),
        "total_sessions": int((totals[2] if totals else 0) or 0),
        "total_users": int((totals[3] if totals else 0) or 0),
        "top_intents": [dict(row) for row in top_intents],
        "top_symbols": [dict(row) for row in top_symbols],
        "recent_questions": recent_questions,
    }
|
|
|
|
|
|
def list_chat_admin_questions(hours: int = 24, intent: str = "", search: str = "", offset: int = 0, limit: int = 50) -> dict:
    """Page through user questions for the admin console, newest first.

    Args:
        hours: Look-back window on ``chat_messages.created_at``. ``None`` or
            ``""`` means 24; ``0`` or a negative value disables the filter.
        intent: Exact intent to match (blank stored intents count as
            ``"unknown"``). Empty or ``"all"`` disables the filter.
        search: Case-insensitive (ILIKE) substring matched literally against
            message text, session title and user email.
        offset: Rows to skip, clamped to >= 0.
        limit: Page size, clamped to 1..200 (falsy means 50).

    Returns:
        ``{"items", "total", "limit", "offset", "has_more"}``. Each item has
        decoded ``content``/``context`` and repaired ``content_text``.
    """
    init_chat_tables()
    # ``int(hours or 24)`` turned an explicit 0 ("all history") into 24.
    hours = max(0, int(24 if hours is None or hours == "" else hours))
    offset = max(0, int(offset or 0))
    limit = max(1, min(int(limit or 50), 200))
    intent = str(intent or "").strip()
    search = str(search or "").strip()
    conn = get_conn()
    try:
        params = []
        where = "m.role='user'"
        if hours > 0:
            where += " AND m.created_at >= %s"
            params.append((datetime.now().replace(microsecond=0) - timedelta(hours=hours)).isoformat())
        if intent and intent != "all":
            where += " AND COALESCE(NULLIF(m.intent, ''), 'unknown')=%s"
            params.append(intent)
        if search:
            where += " AND (m.content_text ILIKE %s OR COALESCE(s.title,'') ILIKE %s OR COALESCE(u.email,'') ILIKE %s)"
            # Escape LIKE metacharacters so '%' and '_' typed by the admin match
            # literally; backslash is PostgreSQL's default LIKE escape character.
            escaped = search.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
            like = f"%{escaped}%"
            params.extend([like, like, like])
        total = conn.execute(
            f"""
            SELECT COUNT(*)
            FROM chat_messages m
            LEFT JOIN chat_sessions s ON s.id=m.session_id
            LEFT JOIN app_user u ON u.id=m.user_id
            WHERE {where}
            """,
            tuple(params),
        ).fetchone()[0]
        rows = conn.execute(
            f"""
            SELECT m.id, m.session_id, m.user_id, m.role, m.content_text, m.content_json, m.context_json,
                   m.intent, m.symbol, m.timeframe, m.model, m.created_at,
                   s.title AS session_title, u.email AS user_email
            FROM chat_messages m
            LEFT JOIN chat_sessions s ON s.id=m.session_id
            LEFT JOIN app_user u ON u.id=m.user_id
            WHERE {where}
            ORDER BY m.created_at DESC, m.id DESC
            LIMIT %s OFFSET %s
            """,
            tuple(params + [limit, offset]),
        ).fetchall()
    finally:
        conn.close()
    total = int(total or 0)
    items = []
    for row in rows:
        item = dict(row)
        item["content"] = _loads(item.pop("content_json", "{}"), {})
        item["context"] = _loads(item.pop("context_json", "{}"), {})
        item["content_text"] = repair_mojibake_text(item.get("content_text", ""))
        items.append(item)
    return {
        "items": items,
        "total": total,
        "limit": limit,
        "offset": offset,
        "has_more": offset + len(items) < total,
    }
|
|
|
|
|
|
# Public API of this module. The two admin helpers were previously missing,
# so ``from ... import *`` did not expose them even though they are public.
__all__ = [
    "append_chat_message",
    "bootstrap_chat",
    "create_chat_session",
    "get_chat_admin_overview",
    "get_chat_session",
    "get_user_preferences",
    "init_chat_tables",
    "list_chat_admin_questions",
    "list_chat_messages",
    "list_chat_sessions",
    "update_chat_session",
    "update_user_preferences",
]
|