astock-agent/backend/app/api/watchlists.py
2026-04-22 11:02:19 +08:00

251 lines
9.0 KiB
Python

"""用户自选股 API"""
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from sqlalchemy import text
from app.core.deps import get_current_user
from app.db.database import get_db
from app.engine.watchlist import analyze_watchlist_for_all_users, analyze_watchlist_item
# Every watchlist endpoint below is registered under the /api/watchlists prefix
# and tagged "watchlists" for the generated API docs.
router = APIRouter(tags=["watchlists"], prefix="/api/watchlists")
class WatchlistCreateRequest(BaseModel):
    # Request body for POST /api/watchlists (create_watchlist).
    # The model does no normalization itself: create_watchlist strips all text
    # fields, upper-cases ts_code and lower-cases watch_group.
    ts_code: str
    name: str
    note: str = ""
    # Must be one of WATCH_GROUPS; the endpoint (not the model) rejects other values with 400.
    watch_group: str = "observe"
    # None, 0 or a negative number is stored as NULL by create_watchlist.
    cost_price: float | None = None
class WatchlistUpdateRequest(BaseModel):
    # Partial-update body for PATCH /api/watchlists/{watchlist_id} (update_watchlist).
    # A field left as None is not written; the endpoint returns 400 if every field is None.
    note: str | None = None
    # Must be one of WATCH_GROUPS when provided (checked by the endpoint).
    watch_group: str | None = None
    # A value <= 0 clears the stored cost price (written as NULL).
    cost_price: float | None = None
# Allowed values for a watchlist entry's group. create_watchlist and
# update_watchlist reject anything else with HTTP 400.
WATCH_GROUPS = {
    "observe",  # default group when the client does not supply one
    "focus",
    "candidate",
    "holding",
}
@router.get("")
async def list_watchlists(current_user: dict = Depends(get_current_user)):
    """List the current user's active watchlist entries, newest first.

    Each entry is LEFT JOINed with its single most recent analysis (latest
    ``created_at``, ties broken by the highest ``id``). The analysis columns
    are None for an entry that has never been analyzed.
    """
    # COALESCE(..., 1) treats rows whose is_active is NULL as active.
    listing_sql = text(
        "SELECT w.id, w.ts_code, w.name, w.note, w.watch_group, w.cost_price, w.created_at, "
        "a.conclusion, a.advice, a.trigger_condition, a.risk_note, a.summary, a.created_at AS analysis_created_at "
        "FROM user_watchlists w "
        "LEFT JOIN watchlist_analyses a ON a.id = ("
        " SELECT id FROM watchlist_analyses "
        " WHERE watchlist_id = w.id ORDER BY created_at DESC, id DESC LIMIT 1"
        ") "
        "WHERE w.user_id = :uid AND COALESCE(w.is_active, 1) = 1 "
        "ORDER BY w.created_at DESC"
    )
    async with get_db() as db:
        cursor = await db.execute(listing_sql, {"uid": current_user["id"]})
        records = cursor.fetchall()
    return [dict(record._mapping) for record in records]
@router.post("")
async def create_watchlist(req: WatchlistCreateRequest, current_user: dict = Depends(get_current_user)):
    """Add a stock to the current user's watchlist, then run its first analysis.

    Input is normalized before use: text fields are stripped, the code is
    upper-cased, the group is lower-cased (defaulting to "observe"), and a
    missing or non-positive cost price is stored as NULL.

    Raises:
        HTTPException(400): empty code or name, an unknown group, or the code
            is already an active entry for this user.
        HTTPException(500): the id of the freshly inserted row could not be found.
    """
    code = req.ts_code.strip().upper()
    name = req.name.strip()
    note = req.note.strip()
    group = (req.watch_group or "observe").strip().lower()
    cost = None
    if req.cost_price and req.cost_price > 0:
        cost = req.cost_price

    if not (code and name):
        raise HTTPException(status_code=400, detail="股票代码和名称不能为空")
    if group not in WATCH_GROUPS:
        raise HTTPException(status_code=400, detail="无效的自选分组")

    uid = current_user["id"]
    key_params = {"uid": uid, "code": code}

    async with get_db() as db:
        duplicate = (
            await db.execute(
                text(
                    "SELECT id FROM user_watchlists "
                    "WHERE user_id = :uid AND ts_code = :code AND COALESCE(is_active, 1) = 1"
                ),
                key_params,
            )
        ).fetchone()
        if duplicate:
            raise HTTPException(status_code=400, detail="该股票已在自选列表中")

        insert_result = await db.execute(
            text(
                "INSERT INTO user_watchlists (user_id, ts_code, name, note, watch_group, cost_price, is_active) "
                "VALUES (:uid, :code, :name, :note, :watch_group, :cost_price, 1)"
            ),
            {**key_params, "name": name, "note": note, "watch_group": group, "cost_price": cost},
        )
        await db.commit()

        # Not every driver exposes lastrowid; if it is missing, read back the
        # newest row for this (user, code) pair instead.
        new_id = getattr(insert_result, "lastrowid", None)
        if not new_id:
            fallback = (
                await db.execute(
                    text(
                        "SELECT id FROM user_watchlists "
                        "WHERE user_id = :uid AND ts_code = :code "
                        "ORDER BY id DESC LIMIT 1"
                    ),
                    key_params,
                )
            ).fetchone()
            if not fallback:
                raise HTTPException(status_code=500, detail="自选股创建失败")
            new_id = fallback._mapping["id"]

    # The first analysis runs after the DB session above has been released.
    await analyze_watchlist_item(
        watchlist_id=new_id,
        user_id=uid,
        ts_code=code,
        name=name,
        note=note,
        watch_group=group,
        cost_price=cost,
        mode="manual",
    )
    return {"status": "ok", "message": "已加入自选并完成首次分析", "watchlist_id": new_id}
@router.patch("/{watchlist_id}")
async def update_watchlist(watchlist_id: int, req: WatchlistUpdateRequest, current_user: dict = Depends(get_current_user)):
    """Partially update an active watchlist entry owned by the current user.

    Only request fields that are not None are written, and ``updated_at`` is
    always refreshed. A ``cost_price`` <= 0 clears the stored value (NULL).

    Returns:
        ``{"status": "ok", "item": {...}}`` holding the entry's columns after
        the update.

    Raises:
        HTTPException(400): no updatable field supplied, or an unknown watch_group.
        HTTPException(404): the entry does not exist, is inactive, belongs to
            another user, or could not be re-read after the update.
    """
    updates: list[str] = []
    params: dict = {"id": watchlist_id, "uid": current_user["id"]}
    if req.note is not None:
        updates.append("note = :note")
        params["note"] = req.note.strip()
    if req.watch_group is not None:
        normalized_group = req.watch_group.strip().lower()
        if normalized_group not in WATCH_GROUPS:
            raise HTTPException(status_code=400, detail="无效的自选分组")
        updates.append("watch_group = :watch_group")
        params["watch_group"] = normalized_group
    if req.cost_price is not None:
        updates.append("cost_price = :cost_price")
        # Non-positive values clear the stored cost price.
        params["cost_price"] = req.cost_price if req.cost_price > 0 else None
    if not updates:
        raise HTTPException(status_code=400, detail="没有可更新的字段")
    updates.append("updated_at = CURRENT_TIMESTAMP")

    async with get_db() as db:
        result = await db.execute(
            text(
                # Interpolating `updates` is safe: it only ever contains the
                # fixed assignment fragments built above; user-supplied values
                # are passed as bound parameters.
                f"UPDATE user_watchlists SET {', '.join(updates)} "
                "WHERE id = :id AND user_id = :uid AND COALESCE(is_active, 1) = 1"
            ),
            params,
        )
        await db.commit()
        if result.rowcount == 0:
            raise HTTPException(status_code=404, detail="自选股不存在")
        row = (await db.execute(
            text(
                "SELECT id, user_id, ts_code, name, note, watch_group, cost_price "
                "FROM user_watchlists "
                "WHERE id = :id AND user_id = :uid"
            ),
            {"id": watchlist_id, "uid": current_user["id"]},
        )).fetchone()

    # The row may be removed between the UPDATE and this SELECT. Report 404
    # instead of failing with AttributeError on None._mapping (an HTTP 500).
    if row is None:
        raise HTTPException(status_code=404, detail="自选股不存在")
    return {"status": "ok", "item": dict(row._mapping)}
@router.delete("/{watchlist_id}")
async def delete_watchlist(watchlist_id: int, current_user: dict = Depends(get_current_user)):
    """Soft-delete a watchlist entry owned by the current user.

    The row is kept and only marked inactive (``is_active = 0``). Its
    ``updated_at`` is refreshed, as in update_watchlist. Rows in
    watchlist_analyses are not touched here. Deleting one of the user's own
    entries that is already inactive still matches the row and returns ok.

    Raises:
        HTTPException(404): no entry with this id belongs to the current user.
    """
    async with get_db() as db:
        result = await db.execute(
            text(
                "UPDATE user_watchlists SET is_active = 0, updated_at = CURRENT_TIMESTAMP "
                "WHERE id = :id AND user_id = :uid"
            ),
            {"id": watchlist_id, "uid": current_user["id"]},
        )
        await db.commit()
        # Same existence/ownership check as update_watchlist. Previously an
        # unknown id, or another user's id, silently reported success.
        if result.rowcount == 0:
            raise HTTPException(status_code=404, detail="自选股不存在")
    return {"status": "ok"}
@router.post("/{watchlist_id}/analyze")
async def analyze_single_watchlist(watchlist_id: int, current_user: dict = Depends(get_current_user)):
    """Run a manual analysis for one active watchlist entry of the current user.

    A NULL note is passed to the analyzer as "" and a NULL group as "observe".

    Raises:
        HTTPException(404): the entry does not exist, is inactive, or belongs
            to another user.
    """
    lookup_sql = text(
        "SELECT id, user_id, ts_code, name, note, watch_group, cost_price FROM user_watchlists "
        "WHERE id = :id AND user_id = :uid AND COALESCE(is_active, 1) = 1"
    )
    async with get_db() as db:
        found = (await db.execute(lookup_sql, {"id": watchlist_id, "uid": current_user["id"]})).fetchone()

    if not found:
        raise HTTPException(status_code=404, detail="自选股不存在")

    entry = found._mapping
    analysis = await analyze_watchlist_item(
        watchlist_id=entry["id"],
        user_id=entry["user_id"],
        ts_code=entry["ts_code"],
        name=entry["name"],
        note=entry["note"] or "",
        watch_group=entry["watch_group"] or "observe",
        cost_price=entry["cost_price"],
        mode="manual",
    )
    return {"status": "ok", "result": analysis}
@router.post("/analyze-all")
async def analyze_all_watchlists(current_user: dict = Depends(get_current_user)):
    """Manually re-analyze every active watchlist entry of the current user.

    Entries are analyzed one at a time, in the order the query returns them.
    There is no per-entry error handling, so the first failing analysis
    propagates and skips the remaining entries.

    Returns:
        ``{"status": "ok", "count": n, "message": ...}`` where ``n`` is the
        number of entries analyzed.
    """
    async with get_db() as db:
        fetched = await db.execute(
            text(
                "SELECT id, user_id, ts_code, name, note, watch_group, cost_price FROM user_watchlists "
                "WHERE user_id = :uid AND COALESCE(is_active, 1) = 1"
            ),
            {"uid": current_user["id"]},
        )
        entries = [record._mapping for record in fetched.fetchall()]

    for entry in entries:
        await analyze_watchlist_item(
            watchlist_id=entry["id"],
            user_id=entry["user_id"],
            ts_code=entry["ts_code"],
            name=entry["name"],
            note=entry["note"] or "",
            watch_group=entry["watch_group"] or "observe",
            cost_price=entry["cost_price"],
            mode="manual",
        )

    # Reached only when every analysis finished, so this equals the number analyzed.
    total = len(entries)
    return {"status": "ok", "count": total, "message": f"已完成 {total} 条自选股分析"}
@router.get("/{watchlist_id}/history")
async def watchlist_history(watchlist_id: int, current_user: dict = Depends(get_current_user)):
    """Return the 20 most recent analyses of one of the current user's watchlist entries.

    Ownership is enforced by joining on ``user_watchlists.user_id``. An unknown
    id, or an id belonging to another user, yields an empty list rather than
    a 404. There is no ``is_active`` filter, so history remains readable for
    soft-deleted entries.
    """
    async with get_db() as db:
        rows = (await db.execute(
            text(
                "SELECT a.* FROM watchlist_analyses a "
                "INNER JOIN user_watchlists w ON w.id = a.watchlist_id "
                "WHERE a.watchlist_id = :wid AND w.user_id = :uid "
                # Break created_at ties by id, as the "latest analysis"
                # subquery in list_watchlists does. Without this, the LIMIT
                # window and the first item are not deterministic when
                # analyses share a timestamp.
                "ORDER BY a.created_at DESC, a.id DESC LIMIT 20"
            ),
            {"wid": watchlist_id, "uid": current_user["id"]},
        )).fetchall()
    return [dict(row._mapping) for row in rows]