Files
wecom-ai-assistant/backend/app/routers/wecom.py
2026-02-05 16:36:32 +08:00

164 lines
5.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""企业微信回调GET 校验 + POST 消息回调验签、解密、echo 回复、会话入库)。"""
import logging
import random
import secrets
import string
import time

from fastapi import APIRouter, Request, Query, Depends
from fastapi.responses import PlainTextResponse
from sqlalchemy.ext.asyncio import AsyncSession

from app.database import get_db
from app.services.session_service import get_or_create_session, add_message
from app.services.wecom_crypto import (
    verify_and_decrypt_echostr,
    verify_signature,
    parse_encrypted_body,
    decrypt,
    parse_decrypted_xml,
    build_reply_xml,
    encrypt,
    make_reply_signature,
    build_encrypted_response,
)
from app.logging_config import get_trace_id
# Module logger; handlers attach structured fields (trace_id, msgid, ...) via `extra=`.
logger = logging.getLogger(__name__)
# Router exposing GET/POST "/callback"; the URL prefix it is mounted under is
# decided by whoever includes this router (not visible in this file).
router = APIRouter()
@router.get("/callback")
def wecom_verify(
request: Request,
signature: str = Query(None, alias="signature"),
msg_signature: str = Query(None, alias="msg_signature"),
timestamp: str = Query(..., alias="timestamp"),
nonce: str = Query(..., alias="nonce"),
echostr: str = Query(..., alias="echostr"),
):
"""企业微信 GET 验签:校验签名并解密 echostr原样返回明文。
兼容 signature 和 msg_signature 两种参数名。
"""
trace_id = get_trace_id()
# 兼容 signature 和 msg_signature 两种参数名
sig = msg_signature or signature
if not sig:
logger.warning(
"wecom verify missing signature",
extra={"trace_id": trace_id, "query_params": dict(request.query_params)},
)
return PlainTextResponse("", status_code=400)
plain = verify_and_decrypt_echostr(sig, timestamp, nonce, echostr)
if plain is None:
logger.warning(
"wecom verify failed",
extra={"trace_id": trace_id, "timestamp": timestamp, "nonce": nonce},
)
return PlainTextResponse("", status_code=400)
logger.info(
"wecom verify success",
extra={"trace_id": trace_id, "echostr_length": len(echostr)},
)
return PlainTextResponse(plain)
@router.post("/callback")
async def wecom_callback(
request: Request,
signature: str = Query(None, alias="signature"),
msg_signature: str = Query(None, alias="msg_signature"),
timestamp: str = Query(..., alias="timestamp"),
nonce: str = Query(..., alias="nonce"),
db: AsyncSession = Depends(get_db),
):
"""POST 消息回调:验签、解密、会话与消息入库、文本 echo 回复。
兼容 signature 和 msg_signature 两种参数名。
"""
trace_id = get_trace_id()
# 兼容 signature 和 msg_signature 两种参数名
sig = msg_signature or signature
if not sig:
logger.warning(
"wecom post missing signature",
extra={"trace_id": trace_id, "query_params": dict(request.query_params)},
)
return PlainTextResponse("", status_code=400)
body = await request.body()
encrypt_raw, err = parse_encrypted_body(body)
if err:
logger.warning(
"wecom post parse error",
extra={"trace_id": trace_id, "error": err},
)
return PlainTextResponse("", status_code=400)
if not verify_signature(sig, timestamp, nonce, encrypt_raw):
logger.warning(
"wecom post verify failed",
extra={"trace_id": trace_id, "timestamp": timestamp},
)
return PlainTextResponse("", status_code=400)
try:
plain_xml = decrypt(encrypt_raw)
except Exception as e:
logger.warning(
"wecom decrypt error",
extra={"trace_id": trace_id, "error": str(e)},
)
return PlainTextResponse("", status_code=400)
msg = parse_decrypted_xml(plain_xml)
if not msg:
logger.warning(
"wecom xml parse failed",
extra={"trace_id": trace_id},
)
return PlainTextResponse("", status_code=400)
to_user = msg.get("ToUserName", "")
from_user = msg.get("FromUserName", "") # external_userid
msg_id = msg.get("MsgId", "")
msg_type = msg.get("MsgType", "")
content = (msg.get("Content") or "").strip()
content_summary = content[:50] + "..." if len(content) > 50 else content
# 记录日志trace_id + external_userid + msgid + 内容摘要
logger.info(
"wecom message received",
extra={
"trace_id": trace_id,
"external_userid": from_user,
"msgid": msg_id,
"msg_type": msg_type,
"content_summary": content_summary or "(empty)",
},
)
# 会话入库external_user_id = from_user客户
session = await get_or_create_session(db, from_user, msg.get("Contact"))
await add_message(db, session.id, "user", content or "(非文本消息)")
# Echo 文本:回复"已收到:{用户消息}"
if msg_type == "text" and content:
reply_content = f"已收到:{content}"
else:
reply_content = "已收到"
await add_message(db, session.id, "assistant", reply_content)
# 回复给客户(被动回复 XML
reply_xml = build_reply_xml(from_user, to_user, reply_content)
enc = encrypt(reply_xml)
ts = str(int(time.time()))
reply_nonce = "".join(random.choices(string.ascii_letters + string.digits, k=16))
sig = make_reply_signature(enc, ts, reply_nonce)
resp_xml = build_encrypted_response(enc, sig, ts, reply_nonce)
logger.info(
"wecom reply sent",
extra={
"trace_id": trace_id,
"external_userid": from_user,
"msgid": msg_id,
"reply_summary": reply_content[:50] + "..." if len(reply_content) > 50 else reply_content,
},
)
return PlainTextResponse(resp_xml, media_type="application/xml")