251 lines
9.0 KiB
Python
251 lines
9.0 KiB
Python
"""用户自选股 API"""
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException
|
|
from pydantic import BaseModel
|
|
from sqlalchemy import text
|
|
|
|
from app.core.deps import get_current_user
|
|
from app.db.database import get_db
|
|
from app.engine.watchlist import analyze_watchlist_for_all_users, analyze_watchlist_item
|
|
|
|
# Every endpoint in this module is mounted under /api/watchlists and grouped
# under the "watchlists" tag in the generated OpenAPI docs.
router = APIRouter(tags=["watchlists"], prefix="/api/watchlists")
|
|
|
|
|
|
class WatchlistCreateRequest(BaseModel):
    """Request body for adding a stock to the current user's watchlist.

    ``create_watchlist`` strips/normalizes these values before storing them.
    """

    # Stock code; stored stripped and upper-cased.
    ts_code: str
    # Stock display name; must be non-empty after stripping.
    name: str
    # Free-form user note (stripped before storing).
    note: str = ""
    # Watchlist group; must be one of WATCH_GROUPS after strip + lower-case.
    watch_group: str = "observe"
    # Optional cost basis; a missing or non-positive value is stored as NULL.
    cost_price: float | None = None
|
|
|
|
|
|
class WatchlistUpdateRequest(BaseModel):
    """Partial update of an existing watchlist entry.

    Any field left as ``None`` is not touched by ``update_watchlist``.
    """

    # New note (stripped before storing).
    note: str | None = None
    # New group; must be one of WATCH_GROUPS after strip + lower-case.
    watch_group: str | None = None
    # New cost basis; a value <= 0 clears the stored cost price (sets NULL).
    cost_price: float | None = None
|
|
|
|
|
|
# Allowed values for ``watch_group``. A frozenset so this validation whitelist
# cannot be mutated at runtime by accident; membership tests are unchanged.
WATCH_GROUPS = frozenset({"observe", "focus", "candidate", "holding"})
|
|
|
|
|
|
@router.get("")
async def list_watchlists(current_user: dict = Depends(get_current_user)):
    """List the current user's active watchlist entries, newest first.

    Each entry is LEFT JOINed with its most recent row in
    ``watchlist_analyses`` (latest ``created_at``, ties broken by highest
    ``id``); the analysis columns are NULL for entries never analyzed.

    Returns:
        A list of dicts, one per active entry.
    """
    async with get_db() as db:
        rows = (await db.execute(
            text(
                "SELECT w.id, w.ts_code, w.name, w.note, w.watch_group, w.cost_price, w.created_at, "
                "a.conclusion, a.advice, a.trigger_condition, a.risk_note, a.summary, a.created_at AS analysis_created_at "
                "FROM user_watchlists w "
                "LEFT JOIN watchlist_analyses a ON a.id = ("
                "  SELECT id FROM watchlist_analyses "
                "  WHERE watchlist_id = w.id ORDER BY created_at DESC, id DESC LIMIT 1"
                ") "
                # A NULL is_active is treated as active, like the other queries in this module.
                "WHERE w.user_id = :uid AND COALESCE(w.is_active, 1) = 1 "
                # Tie-break on id so entries sharing a created_at value keep a stable order.
                "ORDER BY w.created_at DESC, w.id DESC"
            ),
            {"uid": current_user["id"]},
        )).fetchall()

    return [dict(row._mapping) for row in rows]
|
|
|
|
|
|
@router.post("")
async def create_watchlist(req: WatchlistCreateRequest, current_user: dict = Depends(get_current_user)):
    """Add a stock to the current user's watchlist and run a first analysis.

    Input is normalized: the code is stripped and upper-cased, name and note
    are stripped, the group is stripped and lower-cased, and a missing or
    non-positive cost price is stored as NULL.

    The new row is committed before the analysis runs. If the analysis then
    fails, the entry still exists, so the failure is logged and reported via
    ``"analyzed": False`` instead of surfacing as a 500. A 500 there would make
    a client retry fail with "already in list".

    Returns:
        ``{"status": "ok", "message": ..., "watchlist_id": ..., "analyzed": bool}``

    Raises:
        HTTPException(400): empty code/name, unknown group, or the stock is
            already an active entry for this user.
        HTTPException(500): the id of the inserted row could not be determined.
    """
    import logging  # local import keeps the file-level import block untouched

    user_id = current_user["id"]
    normalized_code = req.ts_code.strip().upper()
    normalized_name = req.name.strip()
    normalized_note = req.note.strip()
    normalized_group = (req.watch_group or "observe").strip().lower()
    # Missing, zero and negative prices all mean "no cost basis".
    normalized_cost = req.cost_price if req.cost_price and req.cost_price > 0 else None

    if not normalized_code or not normalized_name:
        raise HTTPException(status_code=400, detail="股票代码和名称不能为空")
    if normalized_group not in WATCH_GROUPS:
        raise HTTPException(status_code=400, detail="无效的自选分组")

    async with get_db() as db:
        # NOTE(review): this check-then-insert is not atomic; two concurrent
        # requests can both pass it unless the table enforces uniqueness on
        # (user_id, ts_code) for active rows — confirm against the schema.
        exists = (await db.execute(
            text(
                "SELECT id FROM user_watchlists "
                "WHERE user_id = :uid AND ts_code = :code AND COALESCE(is_active, 1) = 1"
            ),
            {"uid": user_id, "code": normalized_code},
        )).fetchone()
        if exists:
            raise HTTPException(status_code=400, detail="该股票已在自选列表中")

        result = await db.execute(
            text(
                "INSERT INTO user_watchlists (user_id, ts_code, name, note, watch_group, cost_price, is_active) "
                "VALUES (:uid, :code, :name, :note, :watch_group, :cost_price, 1)"
            ),
            {
                "uid": user_id,
                "code": normalized_code,
                "name": normalized_name,
                "note": normalized_note,
                "watch_group": normalized_group,
                "cost_price": normalized_cost,
            },
        )
        await db.commit()

        # lastrowid may be absent or falsy depending on the DB driver; fall
        # back to the newest row for this user/code in that case.
        watchlist_id = getattr(result, "lastrowid", None)
        if not watchlist_id:
            inserted = (await db.execute(
                text(
                    "SELECT id FROM user_watchlists "
                    "WHERE user_id = :uid AND ts_code = :code "
                    "ORDER BY id DESC LIMIT 1"
                ),
                {"uid": user_id, "code": normalized_code},
            )).fetchone()
            if not inserted:
                raise HTTPException(status_code=500, detail="自选股创建失败")
            watchlist_id = inserted._mapping["id"]

    # The entry is already persisted; an analysis failure must not be reported
    # as a failed creation.
    try:
        await analyze_watchlist_item(
            watchlist_id=watchlist_id,
            user_id=user_id,
            ts_code=normalized_code,
            name=normalized_name,
            note=normalized_note,
            watch_group=normalized_group,
            cost_price=normalized_cost,
            mode="manual",
        )
    except Exception:
        logging.getLogger(__name__).exception(
            "Initial analysis failed for watchlist %s (%s)", watchlist_id, normalized_code
        )
        return {
            "status": "ok",
            "message": "已加入自选，首次分析失败，可稍后重新分析",
            "watchlist_id": watchlist_id,
            "analyzed": False,
        }

    return {
        "status": "ok",
        "message": "已加入自选并完成首次分析",
        "watchlist_id": watchlist_id,
        "analyzed": True,
    }
|
|
|
|
|
|
@router.patch("/{watchlist_id}")
async def update_watchlist(watchlist_id: int, req: WatchlistUpdateRequest, current_user: dict = Depends(get_current_user)):
    """Apply a partial update to one of the current user's active watchlist entries.

    Only fields of *req* that are not ``None`` are written, and ``updated_at``
    is always refreshed. A ``cost_price`` that is not > 0 clears the stored
    price (NULL).

    Returns:
        ``{"status": "ok", "item": {...}}`` with the entry re-read after the update.

    Raises:
        HTTPException(400): unknown group, or no updatable field supplied.
        HTTPException(404): no active entry with this id belongs to the user.
    """
    owner_id = current_user["id"]

    # Column name -> new value. Each column name doubles as its bind-parameter
    # name; only these fixed identifiers are ever interpolated into the SQL.
    changes: dict = {}
    if req.note is not None:
        changes["note"] = req.note.strip()
    if req.watch_group is not None:
        group = req.watch_group.strip().lower()
        if group not in WATCH_GROUPS:
            raise HTTPException(status_code=400, detail="无效的自选分组")
        changes["watch_group"] = group
    if req.cost_price is not None:
        changes["cost_price"] = req.cost_price if req.cost_price > 0 else None

    if not changes:
        raise HTTPException(status_code=400, detail="没有可更新的字段")

    assignments = [f"{column} = :{column}" for column in changes]
    assignments.append("updated_at = CURRENT_TIMESTAMP")
    set_clause = ", ".join(assignments)

    async with get_db() as db:
        outcome = await db.execute(
            text(
                f"UPDATE user_watchlists SET {set_clause} "
                "WHERE id = :id AND user_id = :uid AND COALESCE(is_active, 1) = 1"
            ),
            {"id": watchlist_id, "uid": owner_id, **changes},
        )
        await db.commit()

        if outcome.rowcount == 0:
            raise HTTPException(status_code=404, detail="自选股不存在")

        refreshed = (await db.execute(
            text(
                "SELECT id, user_id, ts_code, name, note, watch_group, cost_price "
                "FROM user_watchlists "
                "WHERE id = :id AND user_id = :uid"
            ),
            {"id": watchlist_id, "uid": owner_id},
        )).fetchone()

        return {"status": "ok", "item": dict(refreshed._mapping)}
|
|
|
|
|
|
@router.delete("/{watchlist_id}")
async def delete_watchlist(watchlist_id: int, current_user: dict = Depends(get_current_user)):
    """Soft-delete one of the current user's watchlist entries.

    Sets ``is_active = 0`` and stamps ``updated_at``, as ``update_watchlist``
    does for edits. Only rows that are still active are touched, so repeating
    the call does not re-stamp an already deleted entry.

    The call is idempotent: it returns ``{"status": "ok"}`` whether or not a
    matching active entry existed.
    """
    async with get_db() as db:
        await db.execute(
            text(
                "UPDATE user_watchlists SET is_active = 0, updated_at = CURRENT_TIMESTAMP "
                "WHERE id = :id AND user_id = :uid AND COALESCE(is_active, 1) = 1"
            ),
            {"id": watchlist_id, "uid": current_user["id"]},
        )
        await db.commit()
    return {"status": "ok"}
|
|
|
|
|
|
@router.post("/{watchlist_id}/analyze")
async def analyze_single_watchlist(watchlist_id: int, current_user: dict = Depends(get_current_user)):
    """Run a manual analysis of one of the current user's active watchlist entries.

    Returns:
        ``{"status": "ok", "result": ...}``, where ``result`` is whatever
        ``analyze_watchlist_item`` returned.

    Raises:
        HTTPException(404): no active entry with this id belongs to the user.
    """
    lookup = text(
        "SELECT id, user_id, ts_code, name, note, watch_group, cost_price FROM user_watchlists "
        "WHERE id = :id AND user_id = :uid AND COALESCE(is_active, 1) = 1"
    )
    async with get_db() as db:
        found = (await db.execute(lookup, {"id": watchlist_id, "uid": current_user["id"]})).fetchone()
        if not found:
            raise HTTPException(status_code=404, detail="自选股不存在")

    # The analysis runs after the lookup session has been closed.
    fields = found._mapping
    analysis = await analyze_watchlist_item(
        watchlist_id=fields["id"],
        user_id=fields["user_id"],
        ts_code=fields["ts_code"],
        name=fields["name"],
        note=fields["note"] or "",
        watch_group=fields["watch_group"] or "observe",
        cost_price=fields["cost_price"],
        mode="manual",
    )
    return {"status": "ok", "result": analysis}
|
|
|
|
|
|
@router.post("/analyze-all")
async def analyze_all_watchlists(current_user: dict = Depends(get_current_user)):
    """Run a manual analysis of every active watchlist entry of the current user.

    Entries are analyzed one after another. A failure on one entry is logged
    and counted rather than aborting the batch. Previously, one failing stock
    turned the whole request into a 500 even though earlier entries had
    already been processed.

    Returns:
        ``{"status": "ok", "count": <succeeded>, "failed": <failed>, "message": ...}``
    """
    import logging  # local import keeps the file-level import block untouched

    logger = logging.getLogger(__name__)

    async with get_db() as db:
        rows = (await db.execute(
            text(
                "SELECT id, user_id, ts_code, name, note, watch_group, cost_price FROM user_watchlists "
                "WHERE user_id = :uid AND COALESCE(is_active, 1) = 1"
            ),
            {"uid": current_user["id"]},
        )).fetchall()

    count = 0
    failed = 0
    for row in rows:
        item = row._mapping
        try:
            await analyze_watchlist_item(
                watchlist_id=item["id"],
                user_id=item["user_id"],
                ts_code=item["ts_code"],
                name=item["name"],
                note=item["note"] or "",
                watch_group=item["watch_group"] or "observe",
                cost_price=item["cost_price"],
                mode="manual",
            )
        except Exception:
            # Best-effort batch: record the failure and move on to the next entry.
            failed += 1
            logger.exception(
                "Manual analysis failed for watchlist %s (%s)", item["id"], item["ts_code"]
            )
        else:
            count += 1

    message = f"已完成 {count} 条自选股分析"
    if failed:
        message += f"，{failed} 条分析失败"
    return {"status": "ok", "count": count, "failed": failed, "message": message}
|
|
|
|
|
|
@router.get("/{watchlist_id}/history")
async def watchlist_history(watchlist_id: int, current_user: dict = Depends(get_current_user)):
    """Return up to 20 of the most recent analyses of one of the user's watchlist entries.

    Ownership is enforced by the join on ``user_watchlists``. An unknown or
    foreign id yields an empty list, and soft-deleted entries are not filtered
    out.

    Rows are ordered newest first with ``id`` as tie-breaker. This is the same
    order ``list_watchlists`` uses to pick an entry's latest analysis, so the
    first history row agrees with it.
    """
    async with get_db() as db:
        rows = (await db.execute(
            text(
                "SELECT a.* FROM watchlist_analyses a "
                "INNER JOIN user_watchlists w ON w.id = a.watchlist_id "
                "WHERE a.watchlist_id = :wid AND w.user_id = :uid "
                "ORDER BY a.created_at DESC, a.id DESC LIMIT 20"
            ),
            {"wid": watchlist_id, "uid": current_user["id"]},
        )).fetchall()
        return [dict(row._mapping) for row in rows]
|