Initial commit: 浼佷笟寰俊 AI 鏈哄櫒浜哄姪鐞?MVP
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
11
backend/Dockerfile
Normal file
11
backend/Dockerfile
Normal file
@@ -0,0 +1,11 @@
|
||||
FROM python:3.12-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY requirements.txt .
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
COPY . .
|
||||
|
||||
ENV PYTHONPATH=/app
|
||||
EXPOSE 8000
|
||||
CMD ["sh", "-c", "alembic upgrade head && uvicorn app.main:app --host 0.0.0.0 --port 8000"]
|
||||
40
backend/alembic.ini
Normal file
40
backend/alembic.ini
Normal file
@@ -0,0 +1,40 @@
|
||||
[alembic]
|
||||
script_location = alembic
|
||||
prepend_sys_path = .
|
||||
sqlalchemy.url = driver://user:pass@localhost/dbname
|
||||
|
||||
[post_write_hooks]
|
||||
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
[handlers]
|
||||
keys = console
|
||||
|
||||
[formatters]
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = WARN
|
||||
handlers = console
|
||||
qualname =
|
||||
|
||||
[logger_sqlalchemy]
|
||||
level = WARN
|
||||
handlers =
|
||||
qualname = sqlalchemy.engine
|
||||
|
||||
[logger_alembic]
|
||||
level = INFO
|
||||
handlers =
|
||||
qualname = alembic
|
||||
|
||||
[handler_console]
|
||||
class = StreamHandler
|
||||
args = (sys.stderr,)
|
||||
level = NOTSET
|
||||
formatter = generic
|
||||
|
||||
[formatter_generic]
|
||||
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||
datefmt = %H:%M:%S
|
||||
45
backend/alembic/env.py
Normal file
45
backend/alembic/env.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from logging.config import fileConfig
|
||||
|
||||
from alembic import context
|
||||
from sqlalchemy import engine_from_config
|
||||
from sqlalchemy.engine import Connection
|
||||
from sqlalchemy import pool
|
||||
|
||||
from app.config import settings
|
||||
from app.models import Base
|
||||
|
||||
config = context.config
|
||||
if config.config_file_name is not None:
|
||||
fileConfig(config.config_file_name)
|
||||
config.set_main_option("sqlalchemy.url", settings.database_url_sync)
|
||||
|
||||
target_metadata = Base.metadata
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
url = config.get_main_option("sqlalchemy.url")
|
||||
context.configure(url=url, target_metadata=target_metadata, literal_binds=True)
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def do_run_migrations(connection: Connection) -> None:
|
||||
context.configure(connection=connection, target_metadata=target_metadata)
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
connectable = engine_from_config(
|
||||
config.get_section(config.config_ini_section, {}),
|
||||
prefix="sqlalchemy.",
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
with connectable.connect() as connection:
|
||||
do_run_migrations(connection)
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
run_migrations_online()
|
||||
26
backend/alembic/script.py.mako
Normal file
26
backend/alembic/script.py.mako
Normal file
@@ -0,0 +1,26 @@
|
||||
"""${message}
|
||||
|
||||
Revision ID: ${up_revision}
|
||||
Revises: ${down_revision | comma,n}
|
||||
Create Date: ${create_date}
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = ${repr(up_revision)}
|
||||
down_revision: Union[str, None] = ${repr(down_revision)}
|
||||
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
${downgrades if downgrades else "pass"}
|
||||
52
backend/alembic/versions/001_users_and_audit_logs.py
Normal file
52
backend/alembic/versions/001_users_and_audit_logs.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""users and audit_logs
|
||||
|
||||
Revision ID: 001
|
||||
Revises:
|
||||
Create Date: 2025-02-05
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
revision: str = "001"
|
||||
down_revision: Union[str, None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"users",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column("username", sa.String(64), nullable=False),
|
||||
sa.Column("password_hash", sa.String(256), nullable=False),
|
||||
sa.Column("role", sa.String(32), nullable=False, server_default="admin"),
|
||||
sa.Column("is_active", sa.Boolean(), nullable=False, server_default="true"),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(op.f("ix_users_username"), "users", ["username"], unique=True)
|
||||
|
||||
op.create_table(
|
||||
"audit_logs",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column("actor_user_id", postgresql.UUID(as_uuid=True), nullable=True),
|
||||
sa.Column("action", sa.String(128), nullable=False),
|
||||
sa.Column("meta_json", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=True),
|
||||
sa.ForeignKeyConstraint(["actor_user_id"], ["users.id"], ondelete="SET NULL"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(op.f("ix_audit_logs_actor_user_id"), "audit_logs", ["actor_user_id"], unique=False)
|
||||
op.create_index(op.f("ix_audit_logs_action"), "audit_logs", ["action"], unique=False)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index(op.f("ix_audit_logs_action"), table_name="audit_logs")
|
||||
op.drop_index(op.f("ix_audit_logs_actor_user_id"), table_name="audit_logs")
|
||||
op.drop_table("audit_logs")
|
||||
op.drop_index(op.f("ix_users_username"), table_name="users")
|
||||
op.drop_table("users")
|
||||
26
backend/alembic/versions/002_stamp.py
Normal file
26
backend/alembic/versions/002_stamp.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""stamp 002 (empty migration to match DB state)
|
||||
|
||||
Revision ID: 002
|
||||
Revises: 001
|
||||
Create Date: 2025-02-05
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision: str = "002"
|
||||
down_revision: Union[str, None] = "001"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# 空迁移:仅用于对齐数据库中的版本号
|
||||
pass
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# 空迁移:仅用于对齐数据库中的版本号
|
||||
pass
|
||||
36
backend/alembic/versions/003_add_missing_columns.py
Normal file
36
backend/alembic/versions/003_add_missing_columns.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""add missing columns if users table exists without them
|
||||
|
||||
Revision ID: 003
|
||||
Revises: 002
|
||||
Create Date: 2025-02-05
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision: str = "003"
|
||||
down_revision: Union[str, None] = "002"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# 检查 users 表是否存在 role 列,若不存在则添加
|
||||
conn = op.get_bind()
|
||||
inspector = sa.inspect(conn)
|
||||
columns = [c["name"] for c in inspector.get_columns("users")] if inspector.has_table("users") else []
|
||||
|
||||
if "users" in inspector.get_table_names():
|
||||
if "role" not in columns:
|
||||
op.add_column("users", sa.Column("role", sa.String(32), nullable=False, server_default="admin"))
|
||||
if "is_active" not in columns:
|
||||
op.add_column("users", sa.Column("is_active", sa.Boolean(), nullable=False, server_default="true"))
|
||||
if "created_at" not in columns:
|
||||
op.add_column("users", sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# 可选:移除这些列(通常不需要)
|
||||
pass
|
||||
59
backend/alembic/versions/004_chat_sessions_and_messages.py
Normal file
59
backend/alembic/versions/004_chat_sessions_and_messages.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""Create chat_sessions and messages tables.
|
||||
|
||||
Revision ID: 004
|
||||
Revises: 003
|
||||
Create Date: 2025-02-05 15:30:00.000000
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "004"
|
||||
down_revision: Union[str, None] = "003"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
from sqlalchemy import inspect
|
||||
|
||||
conn = op.get_bind()
|
||||
inspector = inspect(conn)
|
||||
tables = inspector.get_table_names()
|
||||
|
||||
if "chat_sessions" not in tables:
|
||||
op.create_table(
|
||||
"chat_sessions",
|
||||
sa.Column("id", sa.Integer(), nullable=False, autoincrement=True),
|
||||
sa.Column("external_user_id", sa.String(128), nullable=False),
|
||||
sa.Column("external_name", sa.String(128), nullable=True),
|
||||
sa.Column("status", sa.String(32), nullable=False, server_default="open"),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(op.f("ix_chat_sessions_external_user_id"), "chat_sessions", ["external_user_id"], unique=False)
|
||||
|
||||
if "messages" not in tables:
|
||||
op.create_table(
|
||||
"messages",
|
||||
sa.Column("id", sa.Integer(), nullable=False, autoincrement=True),
|
||||
sa.Column("session_id", sa.Integer(), nullable=False),
|
||||
sa.Column("role", sa.String(16), nullable=False),
|
||||
sa.Column("content", sa.Text(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||
sa.ForeignKeyConstraint(["session_id"], ["chat_sessions.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(op.f("ix_messages_session_id"), "messages", ["session_id"], unique=False)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index(op.f("ix_messages_session_id"), table_name="messages")
|
||||
op.drop_table("messages")
|
||||
op.drop_index(op.f("ix_chat_sessions_external_user_id"), table_name="chat_sessions")
|
||||
op.drop_table("chat_sessions")
|
||||
1
backend/app/__init__.py
Normal file
1
backend/app/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# backend app
|
||||
28
backend/app/config.py
Normal file
28
backend/app/config.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="ignore")
|
||||
|
||||
api_host: str = "0.0.0.0"
|
||||
api_port: int = 8000
|
||||
database_url: str = "postgresql+asyncpg://wecom:wecom_secret@localhost:5432/wecom_ai"
|
||||
database_url_sync: str = "postgresql://wecom:wecom_secret@localhost:5432/wecom_ai"
|
||||
|
||||
jwt_secret: str = "change-me"
|
||||
jwt_algorithm: str = "HS256"
|
||||
jwt_expire_minutes: int = 60
|
||||
|
||||
wecom_corp_id: str = ""
|
||||
wecom_agent_id: str = ""
|
||||
wecom_secret: str = ""
|
||||
wecom_token: str = ""
|
||||
wecom_encoding_aes_key: str = ""
|
||||
wecom_api_base: str = "https://qyapi.weixin.qq.com"
|
||||
wecom_api_timeout: int = 10
|
||||
wecom_api_retries: int = 2
|
||||
log_level: str = "INFO"
|
||||
log_json: bool = True
|
||||
|
||||
|
||||
settings = Settings()
|
||||
27
backend/app/database.py
Normal file
27
backend/app/database.py
Normal file
@@ -0,0 +1,27 @@
|
||||
"""异步数据库会话,DATABASE_URL 来自环境变量。"""
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from app.config import settings
|
||||
|
||||
engine = create_async_engine(
|
||||
settings.database_url,
|
||||
echo=False,
|
||||
pool_pre_ping=True,
|
||||
)
|
||||
async_session_factory = async_sessionmaker(
|
||||
engine, class_=AsyncSession, expire_on_commit=False, autoflush=False
|
||||
)
|
||||
|
||||
|
||||
async def get_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
async with async_session_factory() as session:
|
||||
try:
|
||||
yield session
|
||||
await session.commit()
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise
|
||||
finally:
|
||||
await session.close()
|
||||
30
backend/app/deps.py
Normal file
30
backend/app/deps.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""依赖:get_db、JWT 校验。"""
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import get_db
|
||||
from app.models import User
|
||||
from app.services.auth_service import decode_access_token
|
||||
|
||||
security = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
credentials: HTTPAuthorizationCredentials | None = Depends(security),
|
||||
) -> User:
|
||||
if not credentials:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="未提供认证信息")
|
||||
subject = decode_access_token(credentials.credentials)
|
||||
if not subject:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="无效或已过期的 token")
|
||||
# subject 存 username
|
||||
r = await db.execute(select(User).where(User.username == subject))
|
||||
user = r.scalar_one_or_none()
|
||||
if not user:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="用户不存在")
|
||||
if not user.is_active:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="账号已禁用")
|
||||
return user
|
||||
49
backend/app/logging_config.py
Normal file
49
backend/app/logging_config.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""结构化 JSON 日志 + trace_id。"""
|
||||
import logging
|
||||
import sys
|
||||
import uuid
|
||||
from contextvars import ContextVar
|
||||
|
||||
from pythonjsonlogger import jsonlogger
|
||||
|
||||
trace_id_var: ContextVar[str] = ContextVar("trace_id", default="")
|
||||
|
||||
|
||||
def get_trace_id() -> str:
|
||||
t = trace_id_var.get()
|
||||
if not t:
|
||||
t = str(uuid.uuid4())
|
||||
trace_id_var.set(t)
|
||||
return t
|
||||
|
||||
|
||||
def set_trace_id(tid: str) -> None:
|
||||
trace_id_var.set(tid)
|
||||
|
||||
|
||||
class TraceIdFilter(logging.Filter):
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
record.trace_id = get_trace_id()
|
||||
return True
|
||||
|
||||
|
||||
class JsonFormatter(jsonlogger.JsonFormatter):
|
||||
def add_fields(self, log_record: dict, record: logging.LogRecord, message_dict: dict) -> None:
|
||||
super().add_fields(log_record, record, message_dict)
|
||||
log_record["trace_id"] = getattr(record, "trace_id", "")
|
||||
log_record["level"] = record.levelname
|
||||
if record.exc_info:
|
||||
log_record["exception"] = self.formatException(record.exc_info)
|
||||
|
||||
|
||||
def setup_logging(log_level: str = "INFO", log_json: bool = True) -> None:
|
||||
root = logging.getLogger()
|
||||
root.handlers.clear()
|
||||
handler = logging.StreamHandler(sys.stdout)
|
||||
if log_json:
|
||||
handler.setFormatter(JsonFormatter("%(timestamp)s %(level)s %(message)s %(trace_id)s"))
|
||||
else:
|
||||
handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s [%(trace_id)s] %(message)s"))
|
||||
handler.addFilter(TraceIdFilter())
|
||||
root.addHandler(handler)
|
||||
root.setLevel(getattr(logging, log_level.upper(), logging.INFO))
|
||||
68
backend/app/main.py
Normal file
68
backend/app/main.py
Normal file
@@ -0,0 +1,68 @@
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||
|
||||
from app.routers import auth, wecom
|
||||
from app.routers.admin import sessions, tickets, kb, settings, users
|
||||
from app.logging_config import get_trace_id
|
||||
|
||||
app = FastAPI(title="企微AI助手", version="0.1.0")
|
||||
|
||||
# CORS 必须在最前面
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
@app.exception_handler(StarletteHTTPException)
|
||||
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content={"code": exc.status_code, "message": str(exc.detail), "data": None, "trace_id": get_trace_id()},
|
||||
headers={"Access-Control-Allow-Origin": "*"},
|
||||
)
|
||||
|
||||
|
||||
@app.exception_handler(RequestValidationError)
|
||||
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
||||
return JSONResponse(
|
||||
status_code=422,
|
||||
content={"code": 422, "message": "参数校验失败", "data": exc.errors(), "trace_id": get_trace_id()},
|
||||
headers={"Access-Control-Allow-Origin": "*"},
|
||||
)
|
||||
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def global_exception_handler(request: Request, exc: Exception):
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={"code": 500, "message": str(exc), "data": None, "trace_id": get_trace_id()},
|
||||
headers={"Access-Control-Allow-Origin": "*"},
|
||||
)
|
||||
|
||||
|
||||
app.include_router(auth.router, prefix="/api/auth", tags=["auth"])
|
||||
app.include_router(wecom.router, prefix="/api/wecom", tags=["wecom"])
|
||||
app.include_router(sessions.router, prefix="/api/admin", tags=["admin"])
|
||||
app.include_router(tickets.router, prefix="/api/admin", tags=["admin"])
|
||||
app.include_router(kb.router, prefix="/api/admin", tags=["admin"])
|
||||
app.include_router(settings.router, prefix="/api/admin", tags=["admin"])
|
||||
app.include_router(users.router, prefix="/api/admin", tags=["admin"])
|
||||
|
||||
|
||||
@app.get("/api/health")
|
||||
def health():
|
||||
return {"status": "up", "service": "backend"}
|
||||
|
||||
|
||||
@app.get("/api/ready")
|
||||
def ready():
|
||||
return {"ready": True, "service": "backend"}
|
||||
7
backend/app/models/__init__.py
Normal file
7
backend/app/models/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from app.models.base import Base
|
||||
from app.models.user import User
|
||||
from app.models.audit_log import AuditLog
|
||||
from app.models.session import ChatSession
|
||||
from app.models.message import Message
|
||||
|
||||
__all__ = ["Base", "User", "AuditLog", "ChatSession", "Message"]
|
||||
23
backend/app/models/audit_log.py
Normal file
23
backend/app/models/audit_log.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""审计日志(最简):id、actor_user_id、action、meta_json、created_at。"""
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import String, DateTime, Text, ForeignKey
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.models.base import Base
|
||||
|
||||
|
||||
class AuditLog(Base):
|
||||
__tablename__ = "audit_logs"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||
)
|
||||
actor_user_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True, index=True
|
||||
)
|
||||
action: Mapped[str] = mapped_column(String(128), nullable=False, index=True)
|
||||
meta_json: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=datetime.utcnow)
|
||||
5
backend/app/models/base.py
Normal file
5
backend/app/models/base.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
18
backend/app/models/message.py
Normal file
18
backend/app/models/message.py
Normal file
@@ -0,0 +1,18 @@
|
||||
"""单条消息(仅存 public 可见内容,隔离内部信息)。"""
|
||||
from datetime import datetime
|
||||
from sqlalchemy import String, Text, DateTime, ForeignKey
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.models.base import Base
|
||||
|
||||
|
||||
class Message(Base):
|
||||
__tablename__ = "messages"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
|
||||
session_id: Mapped[int] = mapped_column(ForeignKey("chat_sessions.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||
role: Mapped[str] = mapped_column(String(16), nullable=False) # user / assistant / system
|
||||
content: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=datetime.utcnow)
|
||||
|
||||
session: Mapped["ChatSession"] = relationship("ChatSession", back_populates="messages")
|
||||
19
backend/app/models/session.py
Normal file
19
backend/app/models/session.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""外部客户会话(企微单聊/群聊维度)。"""
|
||||
from datetime import datetime
|
||||
from sqlalchemy import String, DateTime, ForeignKey
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.models.base import Base
|
||||
|
||||
|
||||
class ChatSession(Base):
|
||||
__tablename__ = "chat_sessions"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
|
||||
external_user_id: Mapped[str] = mapped_column(String(128), nullable=False, index=True) # 企微 external_userid
|
||||
external_name: Mapped[str | None] = mapped_column(String(128), nullable=True)
|
||||
status: Mapped[str] = mapped_column(String(32), default="open") # open / transferred / closed
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=datetime.utcnow)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
messages: Mapped[list["Message"]] = relationship("Message", back_populates="session", order_by="Message.id")
|
||||
17
backend/app/models/ticket.py
Normal file
17
backend/app/models/ticket.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""转人工工单。"""
|
||||
from datetime import datetime
|
||||
from sqlalchemy import String, Text, DateTime, ForeignKey
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.models.base import Base
|
||||
|
||||
|
||||
class Ticket(Base):
|
||||
__tablename__ = "tickets"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
|
||||
session_id: Mapped[int] = mapped_column(ForeignKey("chat_sessions.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||
reason: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
status: Mapped[str] = mapped_column(String(32), default="open") # open / handling / closed
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=datetime.utcnow)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
22
backend/app/models/user.py
Normal file
22
backend/app/models/user.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""后台用户:id(uuid)、username、password_hash、role、is_active、created_at。"""
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import String, Boolean, DateTime
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.models.base import Base
|
||||
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "users"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||
)
|
||||
username: Mapped[str] = mapped_column(String(64), unique=True, nullable=False, index=True)
|
||||
password_hash: Mapped[str] = mapped_column(String(256), nullable=False)
|
||||
role: Mapped[str] = mapped_column(String(32), nullable=False, default="admin")
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=datetime.utcnow)
|
||||
2
backend/app/routers/__init__.py
Normal file
2
backend/app/routers/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# routers
|
||||
from app.routers import auth, wecom
|
||||
1
backend/app/routers/admin/__init__.py
Normal file
1
backend/app/routers/admin/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# admin routers
|
||||
31
backend/app/routers/admin/kb.py
Normal file
31
backend/app/routers/admin/kb.py
Normal file
@@ -0,0 +1,31 @@
|
||||
"""Admin API: GET/POST /api/admin/kb/docs/upload。"""
|
||||
from fastapi import APIRouter, Depends, UploadFile
|
||||
from app.deps import get_current_user
|
||||
from app.logging_config import get_trace_id
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/kb/docs")
|
||||
async def list_kb_docs(_user=Depends(get_current_user)):
|
||||
"""知识库文档列表(占位)。"""
|
||||
return {
|
||||
"code": 0,
|
||||
"message": "ok",
|
||||
"data": [
|
||||
{"id": "1", "filename": "faq.pdf", "size": 102400, "uploaded_at": "2025-02-05T10:00:00Z"},
|
||||
],
|
||||
"trace_id": get_trace_id(),
|
||||
}
|
||||
|
||||
|
||||
@router.post("/kb/docs/upload")
|
||||
async def upload_kb_doc(file: UploadFile, _user=Depends(get_current_user)):
|
||||
"""上传知识库文档(占位:先存本地/对象存储)。"""
|
||||
# 占位:实际应保存到对象存储或本地卷
|
||||
return {
|
||||
"code": 0,
|
||||
"message": "ok",
|
||||
"data": {"id": "new_1", "filename": file.filename, "size": 0},
|
||||
"trace_id": get_trace_id(),
|
||||
}
|
||||
40
backend/app/routers/admin/sessions.py
Normal file
40
backend/app/routers/admin/sessions.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""Admin API: GET /api/admin/sessions, GET /api/admin/sessions/{id}。"""
|
||||
from fastapi import APIRouter, Depends
|
||||
from app.deps import get_current_user
|
||||
from app.logging_config import get_trace_id
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/sessions")
|
||||
async def list_sessions(_user=Depends(get_current_user)):
|
||||
"""会话列表(占位)。"""
|
||||
return {
|
||||
"code": 0,
|
||||
"message": "ok",
|
||||
"data": [
|
||||
{"id": "1", "external_user_id": "ext_001", "external_name": "客户A", "status": "open", "created_at": "2025-02-05T10:00:00Z"},
|
||||
{"id": "2", "external_user_id": "ext_002", "external_name": "客户B", "status": "transferred", "created_at": "2025-02-05T11:00:00Z"},
|
||||
],
|
||||
"trace_id": get_trace_id(),
|
||||
}
|
||||
|
||||
|
||||
@router.get("/sessions/{id}")
|
||||
async def get_session(id: str, _user=Depends(get_current_user)):
|
||||
"""会话详情:消息列表(占位)。"""
|
||||
return {
|
||||
"code": 0,
|
||||
"message": "ok",
|
||||
"data": {
|
||||
"id": id,
|
||||
"external_user_id": "ext_001",
|
||||
"external_name": "客户A",
|
||||
"status": "open",
|
||||
"messages": [
|
||||
{"id": 1, "role": "user", "content": "你好", "created_at": "2025-02-05T10:00:00Z"},
|
||||
{"id": 2, "role": "assistant", "content": "您好,有什么可以帮您?", "created_at": "2025-02-05T10:00:01Z"},
|
||||
],
|
||||
},
|
||||
"trace_id": get_trace_id(),
|
||||
}
|
||||
37
backend/app/routers/admin/settings.py
Normal file
37
backend/app/routers/admin/settings.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""Admin API: GET/PATCH /api/admin/settings。"""
|
||||
from fastapi import APIRouter, Depends
|
||||
from pydantic import BaseModel
|
||||
from app.deps import get_current_user
|
||||
from app.logging_config import get_trace_id
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class UpdateSettingsBody(BaseModel):
|
||||
model_name: str | None = None
|
||||
strategy: dict | None = None
|
||||
|
||||
|
||||
@router.get("/settings")
|
||||
async def get_settings(_user=Depends(get_current_user)):
|
||||
"""获取设置(占位)。"""
|
||||
return {
|
||||
"code": 0,
|
||||
"message": "ok",
|
||||
"data": {
|
||||
"model_name": "gpt-4",
|
||||
"strategy": {"faq_priority": True, "rag_enabled": False},
|
||||
},
|
||||
"trace_id": get_trace_id(),
|
||||
}
|
||||
|
||||
|
||||
@router.patch("/settings")
|
||||
async def update_settings(body: UpdateSettingsBody, _user=Depends(get_current_user)):
|
||||
"""更新设置(占位)。"""
|
||||
return {
|
||||
"code": 0,
|
||||
"message": "ok",
|
||||
"data": {"model_name": body.model_name or "gpt-4", "strategy": body.strategy or {}},
|
||||
"trace_id": get_trace_id(),
|
||||
}
|
||||
52
backend/app/routers/admin/tickets.py
Normal file
52
backend/app/routers/admin/tickets.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""Admin API: GET/POST/PATCH /api/admin/tickets。"""
|
||||
from fastapi import APIRouter, Depends
|
||||
from pydantic import BaseModel
|
||||
from app.deps import get_current_user
|
||||
from app.logging_config import get_trace_id
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class CreateTicketBody(BaseModel):
|
||||
session_id: str
|
||||
reason: str = ""
|
||||
|
||||
|
||||
class UpdateTicketBody(BaseModel):
|
||||
status: str | None = None
|
||||
reason: str | None = None
|
||||
|
||||
|
||||
@router.get("/tickets")
|
||||
async def list_tickets(_user=Depends(get_current_user)):
|
||||
"""工单列表(占位)。"""
|
||||
return {
|
||||
"code": 0,
|
||||
"message": "ok",
|
||||
"data": [
|
||||
{"id": "1", "session_id": "1", "reason": "转人工", "status": "open", "created_at": "2025-02-05T10:00:00Z"},
|
||||
],
|
||||
"trace_id": get_trace_id(),
|
||||
}
|
||||
|
||||
|
||||
@router.post("/tickets")
|
||||
async def create_ticket(body: CreateTicketBody, _user=Depends(get_current_user)):
|
||||
"""创建工单(占位)。"""
|
||||
return {
|
||||
"code": 0,
|
||||
"message": "ok",
|
||||
"data": {"id": "new_1", "session_id": body.session_id, "reason": body.reason, "status": "open"},
|
||||
"trace_id": get_trace_id(),
|
||||
}
|
||||
|
||||
|
||||
@router.patch("/tickets/{id}")
|
||||
async def update_ticket(id: str, body: UpdateTicketBody, _user=Depends(get_current_user)):
|
||||
"""更新工单(占位)。"""
|
||||
return {
|
||||
"code": 0,
|
||||
"message": "ok",
|
||||
"data": {"id": id, "status": body.status or "open", "reason": body.reason},
|
||||
"trace_id": get_trace_id(),
|
||||
}
|
||||
74
backend/app/routers/admin/users.py
Normal file
74
backend/app/routers/admin/users.py
Normal file
@@ -0,0 +1,74 @@
|
||||
"""Admin API: GET/POST/PATCH/DELETE /api/admin/users(仅管理员可见)。"""
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from app.deps import get_current_user
|
||||
from app.logging_config import get_trace_id
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class CreateUserBody(BaseModel):
|
||||
username: str
|
||||
password: str
|
||||
role: str = "admin"
|
||||
is_active: bool = True
|
||||
|
||||
|
||||
class UpdateUserBody(BaseModel):
|
||||
password: str | None = None
|
||||
role: str | None = None
|
||||
is_active: bool | None = None
|
||||
|
||||
|
||||
@router.get("/users")
|
||||
async def list_users(current_user=Depends(get_current_user)):
|
||||
"""用户列表(仅管理员)。"""
|
||||
if current_user.role != "admin":
|
||||
raise HTTPException(status_code=403, detail="仅管理员可访问")
|
||||
return {
|
||||
"code": 0,
|
||||
"message": "ok",
|
||||
"data": [
|
||||
{"id": "1", "username": "admin", "role": "admin", "is_active": True, "created_at": "2025-02-05T10:00:00Z"},
|
||||
],
|
||||
"trace_id": get_trace_id(),
|
||||
}
|
||||
|
||||
|
||||
@router.post("/users")
|
||||
async def create_user(body: CreateUserBody, current_user=Depends(get_current_user)):
|
||||
"""创建用户(仅管理员)。"""
|
||||
if current_user.role != "admin":
|
||||
raise HTTPException(status_code=403, detail="仅管理员可访问")
|
||||
return {
|
||||
"code": 0,
|
||||
"message": "ok",
|
||||
"data": {"id": "new_1", "username": body.username, "role": body.role, "is_active": body.is_active},
|
||||
"trace_id": get_trace_id(),
|
||||
}
|
||||
|
||||
|
||||
@router.patch("/users/{id}")
|
||||
async def update_user(id: str, body: UpdateUserBody, current_user=Depends(get_current_user)):
|
||||
"""更新用户(仅管理员)。"""
|
||||
if current_user.role != "admin":
|
||||
raise HTTPException(status_code=403, detail="仅管理员可访问")
|
||||
return {
|
||||
"code": 0,
|
||||
"message": "ok",
|
||||
"data": {"id": id, "role": body.role, "is_active": body.is_active},
|
||||
"trace_id": get_trace_id(),
|
||||
}
|
||||
|
||||
|
||||
@router.delete("/users/{id}")
|
||||
async def delete_user(id: str, current_user=Depends(get_current_user)):
|
||||
"""删除用户(仅管理员)。"""
|
||||
if current_user.role != "admin":
|
||||
raise HTTPException(status_code=403, detail="仅管理员可访问")
|
||||
return {
|
||||
"code": 0,
|
||||
"message": "ok",
|
||||
"data": None,
|
||||
"trace_id": get_trace_id(),
|
||||
}
|
||||
47
backend/app/routers/auth.py
Normal file
47
backend/app/routers/auth.py
Normal file
@@ -0,0 +1,47 @@
|
||||
"""Auth API:POST /api/auth/login、GET /api/auth/me。"""
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import get_db
|
||||
from app.deps import get_current_user
|
||||
from app.models import User
|
||||
from app.services.auth_service import (
|
||||
get_user_by_username,
|
||||
verify_password,
|
||||
create_access_token,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class LoginBody(BaseModel):
|
||||
username: str
|
||||
password: str
|
||||
|
||||
|
||||
class LoginResponse(BaseModel):
|
||||
access_token: str
|
||||
token_type: str = "bearer"
|
||||
|
||||
|
||||
@router.post("/login", response_model=LoginResponse)
|
||||
async def login(body: LoginBody, db: AsyncSession = Depends(get_db)):
|
||||
user = await get_user_by_username(db, body.username)
|
||||
if not user or not verify_password(body.password, user.password_hash):
|
||||
raise HTTPException(status_code=401, detail="用户名或密码错误")
|
||||
if not user.is_active:
|
||||
raise HTTPException(status_code=403, detail="账号已禁用")
|
||||
token = create_access_token(subject=user.username)
|
||||
return LoginResponse(access_token=token, token_type="bearer")
|
||||
|
||||
|
||||
@router.get("/me")
|
||||
async def me(current_user: User = Depends(get_current_user)):
|
||||
return {
|
||||
"id": str(current_user.id),
|
||||
"username": current_user.username,
|
||||
"role": current_user.role,
|
||||
"is_active": current_user.is_active,
|
||||
"created_at": current_user.created_at.isoformat() if current_user.created_at else None,
|
||||
}
|
||||
15
backend/app/routers/health.py
Normal file
15
backend/app/routers/health.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""健康检查:供负载均衡与 CI 验证。"""
|
||||
from fastapi import APIRouter
|
||||
from app.logging_config import get_trace_id
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
def health():
|
||||
return {
|
||||
"code": 0,
|
||||
"message": "ok",
|
||||
"data": {"status": "up"},
|
||||
"trace_id": get_trace_id(),
|
||||
}
|
||||
12
backend/app/routers/knowledge.py
Normal file
12
backend/app/routers/knowledge.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""知识库上传占位:先落本地卷/对象存储占位。"""
|
||||
from fastapi import APIRouter, UploadFile
|
||||
from app.logging_config import get_trace_id
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/upload")
|
||||
async def upload_file(file: UploadFile):
|
||||
"""上传知识库文件,占位落盘。"""
|
||||
# 占位:保存到 backend/uploads 或配置的存储
|
||||
return {"code": 0, "message": "ok", "data": {"filename": file.filename}, "trace_id": get_trace_id()}
|
||||
53
backend/app/routers/sessions.py
Normal file
53
backend/app/routers/sessions.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""会话列表与消息:从 DB 读取,需登录。"""
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import get_db
|
||||
from app.deps import get_current_user_id
|
||||
from app.logging_config import get_trace_id
|
||||
from app.models import ChatSession, Message
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _session_row(s: ChatSession) -> dict:
|
||||
return {
|
||||
"id": s.id,
|
||||
"external_user_id": s.external_user_id,
|
||||
"external_name": s.external_name,
|
||||
"status": s.status,
|
||||
"created_at": s.created_at.isoformat() if s.created_at else None,
|
||||
}
|
||||
|
||||
|
||||
def _message_row(m: Message) -> dict:
|
||||
return {
|
||||
"id": m.id,
|
||||
"role": m.role,
|
||||
"content": m.content,
|
||||
"created_at": m.created_at.isoformat() if m.created_at else None,
|
||||
}
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_sessions(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_user: str = Depends(get_current_user_id),
|
||||
):
|
||||
"""会话列表。"""
|
||||
r = await db.execute(select(ChatSession).order_by(ChatSession.updated_at.desc()))
|
||||
rows = r.scalars().all()
|
||||
return {"code": 0, "message": "ok", "data": [_session_row(s) for s in rows], "trace_id": get_trace_id()}
|
||||
|
||||
|
||||
@router.get("/{session_id}/messages")
|
||||
async def list_messages(
|
||||
session_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_user: str = Depends(get_current_user_id),
|
||||
):
|
||||
"""某会话消息列表。"""
|
||||
r = await db.execute(select(Message).where(Message.session_id == session_id).order_by(Message.id))
|
||||
rows = r.scalars().all()
|
||||
return {"code": 0, "message": "ok", "data": [_message_row(m) for m in rows], "trace_id": get_trace_id()}
|
||||
15
backend/app/routers/settings.py
Normal file
15
backend/app/routers/settings.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""设置页占位:仅占位接口。"""
|
||||
from fastapi import APIRouter
|
||||
from app.logging_config import get_trace_id
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("")
|
||||
def get_settings():
|
||||
return {"code": 0, "message": "ok", "data": {}, "trace_id": get_trace_id()}
|
||||
|
||||
|
||||
@router.put("")
|
||||
def update_settings():
|
||||
return {"code": 0, "message": "ok", "data": None, "trace_id": get_trace_id()}
|
||||
67
backend/app/routers/tickets.py
Normal file
67
backend/app/routers/tickets.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""工单转人工:创建工单入库、手动回复调企业微信 API。"""
|
||||
import logging
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import get_db
|
||||
from app.deps import get_current_user_id
|
||||
from app.logging_config import get_trace_id
|
||||
from app.models import ChatSession, Ticket
|
||||
from app.services.wecom_api import send_text_to_external
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class CreateTicketBody(BaseModel):
|
||||
session_id: int
|
||||
reason: str = ""
|
||||
|
||||
|
||||
class SendReplyBody(BaseModel):
|
||||
session_id: int
|
||||
content: str
|
||||
|
||||
|
||||
@router.post("")
|
||||
async def create_ticket(
|
||||
body: CreateTicketBody,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_user: str = Depends(get_current_user_id),
|
||||
):
|
||||
"""创建转人工工单并更新会话状态。"""
|
||||
r = await db.execute(select(ChatSession).where(ChatSession.id == body.session_id))
|
||||
session = r.scalar_one_or_none()
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="会话不存在")
|
||||
ticket = Ticket(session_id=body.session_id, reason=body.reason or None)
|
||||
db.add(ticket)
|
||||
session.status = "transferred"
|
||||
await db.flush()
|
||||
return {
|
||||
"code": 0,
|
||||
"message": "ok",
|
||||
"data": {"ticket_id": str(ticket.id)},
|
||||
"trace_id": get_trace_id(),
|
||||
}
|
||||
|
||||
|
||||
@router.post("/reply")
|
||||
async def send_reply(
|
||||
body: SendReplyBody,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_user: str = Depends(get_current_user_id),
|
||||
):
|
||||
"""手动回复:通过企业微信 API 发给客户。"""
|
||||
r = await db.execute(select(ChatSession).where(ChatSession.id == body.session_id))
|
||||
session = r.scalar_one_or_none()
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="会话不存在")
|
||||
try:
|
||||
await send_text_to_external(session.external_user_id, body.content)
|
||||
except Exception as e:
|
||||
logger.exception("wecom send reply failed")
|
||||
raise HTTPException(status_code=502, detail=str(e))
|
||||
return {"code": 0, "message": "ok", "data": None, "trace_id": get_trace_id()}
|
||||
163
backend/app/routers/wecom.py
Normal file
163
backend/app/routers/wecom.py
Normal file
@@ -0,0 +1,163 @@
|
||||
"""企业微信回调:GET 校验 + POST 消息回调(验签、解密、echo 回复、会话入库)。"""
|
||||
import time
|
||||
import logging
|
||||
import random
|
||||
import string
|
||||
from fastapi import APIRouter, Request, Query, Depends
|
||||
from fastapi.responses import PlainTextResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import get_db
|
||||
from app.services.session_service import get_or_create_session, add_message
|
||||
from app.services.wecom_crypto import (
|
||||
verify_and_decrypt_echostr,
|
||||
verify_signature,
|
||||
parse_encrypted_body,
|
||||
decrypt,
|
||||
parse_decrypted_xml,
|
||||
build_reply_xml,
|
||||
encrypt,
|
||||
make_reply_signature,
|
||||
build_encrypted_response,
|
||||
)
|
||||
from app.logging_config import get_trace_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/callback")
|
||||
def wecom_verify(
|
||||
request: Request,
|
||||
signature: str = Query(None, alias="signature"),
|
||||
msg_signature: str = Query(None, alias="msg_signature"),
|
||||
timestamp: str = Query(..., alias="timestamp"),
|
||||
nonce: str = Query(..., alias="nonce"),
|
||||
echostr: str = Query(..., alias="echostr"),
|
||||
):
|
||||
"""企业微信 GET 验签:校验签名并解密 echostr,原样返回明文。
|
||||
兼容 signature 和 msg_signature 两种参数名。
|
||||
"""
|
||||
trace_id = get_trace_id()
|
||||
# 兼容 signature 和 msg_signature 两种参数名
|
||||
sig = msg_signature or signature
|
||||
if not sig:
|
||||
logger.warning(
|
||||
"wecom verify missing signature",
|
||||
extra={"trace_id": trace_id, "query_params": dict(request.query_params)},
|
||||
)
|
||||
return PlainTextResponse("", status_code=400)
|
||||
plain = verify_and_decrypt_echostr(sig, timestamp, nonce, echostr)
|
||||
if plain is None:
|
||||
logger.warning(
|
||||
"wecom verify failed",
|
||||
extra={"trace_id": trace_id, "timestamp": timestamp, "nonce": nonce},
|
||||
)
|
||||
return PlainTextResponse("", status_code=400)
|
||||
logger.info(
|
||||
"wecom verify success",
|
||||
extra={"trace_id": trace_id, "echostr_length": len(echostr)},
|
||||
)
|
||||
return PlainTextResponse(plain)
|
||||
|
||||
|
||||
@router.post("/callback")
|
||||
async def wecom_callback(
|
||||
request: Request,
|
||||
signature: str = Query(None, alias="signature"),
|
||||
msg_signature: str = Query(None, alias="msg_signature"),
|
||||
timestamp: str = Query(..., alias="timestamp"),
|
||||
nonce: str = Query(..., alias="nonce"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""POST 消息回调:验签、解密、会话与消息入库、文本 echo 回复。
|
||||
兼容 signature 和 msg_signature 两种参数名。
|
||||
"""
|
||||
trace_id = get_trace_id()
|
||||
# 兼容 signature 和 msg_signature 两种参数名
|
||||
sig = msg_signature or signature
|
||||
if not sig:
|
||||
logger.warning(
|
||||
"wecom post missing signature",
|
||||
extra={"trace_id": trace_id, "query_params": dict(request.query_params)},
|
||||
)
|
||||
return PlainTextResponse("", status_code=400)
|
||||
body = await request.body()
|
||||
encrypt_raw, err = parse_encrypted_body(body)
|
||||
if err:
|
||||
logger.warning(
|
||||
"wecom post parse error",
|
||||
extra={"trace_id": trace_id, "error": err},
|
||||
)
|
||||
return PlainTextResponse("", status_code=400)
|
||||
if not verify_signature(sig, timestamp, nonce, encrypt_raw):
|
||||
logger.warning(
|
||||
"wecom post verify failed",
|
||||
extra={"trace_id": trace_id, "timestamp": timestamp},
|
||||
)
|
||||
return PlainTextResponse("", status_code=400)
|
||||
try:
|
||||
plain_xml = decrypt(encrypt_raw)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"wecom decrypt error",
|
||||
extra={"trace_id": trace_id, "error": str(e)},
|
||||
)
|
||||
return PlainTextResponse("", status_code=400)
|
||||
msg = parse_decrypted_xml(plain_xml)
|
||||
if not msg:
|
||||
logger.warning(
|
||||
"wecom xml parse failed",
|
||||
extra={"trace_id": trace_id},
|
||||
)
|
||||
return PlainTextResponse("", status_code=400)
|
||||
to_user = msg.get("ToUserName", "")
|
||||
from_user = msg.get("FromUserName", "") # external_userid
|
||||
msg_id = msg.get("MsgId", "")
|
||||
msg_type = msg.get("MsgType", "")
|
||||
content = (msg.get("Content") or "").strip()
|
||||
content_summary = content[:50] + "..." if len(content) > 50 else content
|
||||
|
||||
# 记录日志:trace_id + external_userid + msgid + 内容摘要
|
||||
logger.info(
|
||||
"wecom message received",
|
||||
extra={
|
||||
"trace_id": trace_id,
|
||||
"external_userid": from_user,
|
||||
"msgid": msg_id,
|
||||
"msg_type": msg_type,
|
||||
"content_summary": content_summary or "(empty)",
|
||||
},
|
||||
)
|
||||
|
||||
# 会话入库:external_user_id = from_user(客户)
|
||||
session = await get_or_create_session(db, from_user, msg.get("Contact"))
|
||||
await add_message(db, session.id, "user", content or "(非文本消息)")
|
||||
|
||||
# Echo 文本:回复"已收到:{用户消息}"
|
||||
if msg_type == "text" and content:
|
||||
reply_content = f"已收到:{content}"
|
||||
else:
|
||||
reply_content = "已收到"
|
||||
|
||||
await add_message(db, session.id, "assistant", reply_content)
|
||||
|
||||
# 回复给客户(被动回复 XML)
|
||||
reply_xml = build_reply_xml(from_user, to_user, reply_content)
|
||||
enc = encrypt(reply_xml)
|
||||
ts = str(int(time.time()))
|
||||
reply_nonce = "".join(random.choices(string.ascii_letters + string.digits, k=16))
|
||||
sig = make_reply_signature(enc, ts, reply_nonce)
|
||||
resp_xml = build_encrypted_response(enc, sig, ts, reply_nonce)
|
||||
|
||||
logger.info(
|
||||
"wecom reply sent",
|
||||
extra={
|
||||
"trace_id": trace_id,
|
||||
"external_userid": from_user,
|
||||
"msgid": msg_id,
|
||||
"reply_summary": reply_content[:50] + "..." if len(reply_content) > 50 else reply_content,
|
||||
},
|
||||
)
|
||||
|
||||
return PlainTextResponse(resp_xml, media_type="application/xml")
|
||||
1
backend/app/services/__init__.py
Normal file
1
backend/app/services/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# services
|
||||
39
backend/app/services/auth_service.py
Normal file
39
backend/app/services/auth_service.py
Normal file
@@ -0,0 +1,39 @@
|
||||
"""密码 bcrypt hash;JWT 创建与解码,带过期时间。"""
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from jose import JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings
|
||||
from app.models import User
|
||||
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
return pwd_context.hash(password)
|
||||
|
||||
|
||||
def verify_password(plain: str, hashed: str) -> bool:
|
||||
return pwd_context.verify(plain, hashed)
|
||||
|
||||
|
||||
def create_access_token(subject: str) -> str:
|
||||
expire = datetime.now(timezone.utc) + timedelta(minutes=settings.jwt_expire_minutes)
|
||||
to_encode = {"sub": subject, "exp": expire}
|
||||
return jwt.encode(to_encode, settings.jwt_secret, algorithm=settings.jwt_algorithm)
|
||||
|
||||
|
||||
def decode_access_token(token: str) -> str | None:
|
||||
try:
|
||||
payload = jwt.decode(token, settings.jwt_secret, algorithms=[settings.jwt_algorithm])
|
||||
return payload.get("sub")
|
||||
except JWTError:
|
||||
return None
|
||||
|
||||
|
||||
async def get_user_by_username(db: AsyncSession, username: str) -> User | None:
|
||||
r = await db.execute(select(User).where(User.username == username))
|
||||
return r.scalar_one_or_none()
|
||||
30
backend/app/services/session_service.py
Normal file
30
backend/app/services/session_service.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""会话与消息入库;仅存 public 可见内容,隔离内部信息。"""
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models import ChatSession, Message
|
||||
|
||||
|
||||
async def get_or_create_session(
|
||||
db: AsyncSession,
|
||||
external_user_id: str,
|
||||
external_name: str | None = None,
|
||||
) -> ChatSession:
|
||||
r = await db.execute(select(ChatSession).where(ChatSession.external_user_id == external_user_id))
|
||||
row = r.scalar_one_or_none()
|
||||
if row:
|
||||
if external_name is not None and row.external_name != external_name:
|
||||
row.external_name = external_name
|
||||
await db.flush()
|
||||
return row
|
||||
session = ChatSession(external_user_id=external_user_id, external_name=external_name or None)
|
||||
db.add(session)
|
||||
await db.flush()
|
||||
return session
|
||||
|
||||
|
||||
async def add_message(db: AsyncSession, session_id: int, role: str, content: str) -> Message:
|
||||
msg = Message(session_id=session_id, role=role, content=content)
|
||||
db.add(msg)
|
||||
await db.flush()
|
||||
return msg
|
||||
54
backend/app/services/wecom_api.py
Normal file
54
backend/app/services/wecom_api.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""企业微信 API 调用:超时与重试,配置来自环境变量。"""
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
TIMEOUT = settings.wecom_api_timeout
|
||||
RETRIES = settings.wecom_api_retries
|
||||
BASE = settings.wecom_api_base.rstrip("/")
|
||||
|
||||
|
||||
async def _request(method: str, path: str, **kwargs: Any) -> dict | None:
|
||||
url = f"{BASE}{path}"
|
||||
for attempt in range(RETRIES + 1):
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=TIMEOUT) as client:
|
||||
r = await client.request(method, url, **kwargs)
|
||||
r.raise_for_status()
|
||||
return r.json()
|
||||
except Exception as e:
|
||||
logger.warning("wecom api attempt %s failed: %s", attempt + 1, e)
|
||||
if attempt == RETRIES:
|
||||
raise
|
||||
return None
|
||||
|
||||
|
||||
async def get_access_token() -> str:
|
||||
"""获取 corpid + secret 的 access_token。"""
|
||||
r = await _request(
|
||||
"GET",
|
||||
"/cgi-bin/gettoken",
|
||||
params={"corpid": settings.wecom_corp_id, "corpsecret": settings.wecom_secret},
|
||||
)
|
||||
if not r or r.get("errcode") != 0:
|
||||
raise RuntimeError(r.get("errmsg", "get token failed"))
|
||||
return r["access_token"]
|
||||
|
||||
|
||||
async def send_text_to_external(external_user_id: str, content: str) -> None:
|
||||
"""发送文本消息给外部联系人(客户联系-发送消息到客户)。"""
|
||||
token = await get_access_token()
|
||||
body = {
|
||||
"touser": [external_user_id],
|
||||
"sender": settings.wecom_agent_id,
|
||||
"msgtype": "text",
|
||||
"text": {"content": content},
|
||||
}
|
||||
# 企业微信文档:发送消息到客户 send_message_to_user
|
||||
r = await _request("POST", f"/cgi-bin/externalcontact/send_message_to_user?access_token={token}", json=body)
|
||||
if not r or r.get("errcode") != 0:
|
||||
raise RuntimeError(r.get("errmsg", "send failed"))
|
||||
119
backend/app/services/wecom_crypto.py
Normal file
119
backend/app/services/wecom_crypto.py
Normal file
@@ -0,0 +1,119 @@
|
||||
"""企业微信回调加解密与验签(与企微文档一致)。"""
|
||||
import base64
|
||||
import hashlib
|
||||
import struct
|
||||
import xml.etree.ElementTree as ET
|
||||
from typing import Tuple
|
||||
|
||||
from Crypto.Cipher import AES
|
||||
|
||||
from app.config import settings
|
||||
|
||||
|
||||
def _sha1(s: str) -> str:
|
||||
return hashlib.sha1(s.encode()).hexdigest()
|
||||
|
||||
|
||||
def _check_signature(signature: str, timestamp: str, nonce: str, echostr_or_encrypt: str) -> bool:
|
||||
token = settings.wecom_token
|
||||
lst = [token, timestamp, nonce, echostr_or_encrypt]
|
||||
lst.sort()
|
||||
return _sha1("".join(lst)) == signature
|
||||
|
||||
|
||||
def _aes_key() -> bytes:
|
||||
key_b64 = settings.wecom_encoding_aes_key + "="
|
||||
return base64.b64decode(key_b64)[:32]
|
||||
|
||||
|
||||
def decrypt(encrypt: str) -> str:
|
||||
"""解密企微回调密文(echostr 或 Encrypt 节点内容)。"""
|
||||
key = _aes_key()
|
||||
iv = key[:16]
|
||||
raw = base64.b64decode(encrypt)
|
||||
cipher = AES.new(key, AES.MODE_CBC, iv)
|
||||
dec = cipher.decrypt(raw)
|
||||
# 16 随机字节 + 4 字节长度(big-endian) + 消息 + corpid;先按长度取消息,避免 padding 差异
|
||||
msg_len = struct.unpack(">I", dec[16:20])[0]
|
||||
return dec[20 : 20 + msg_len].decode("utf-8")
|
||||
|
||||
|
||||
def encrypt(plain: str) -> str:
|
||||
"""加密回复内容(明文为 XML 或文本)。"""
|
||||
import os
|
||||
key = _aes_key()
|
||||
iv = key[:16]
|
||||
corpid = settings.wecom_corp_id or "placeholder"
|
||||
msg = plain.encode("utf-8")
|
||||
msg_len = struct.pack(">I", len(msg))
|
||||
rand = os.urandom(16)
|
||||
to_enc = rand + msg_len + msg + corpid.encode("utf-8")
|
||||
from Crypto.Util.Padding import pad
|
||||
to_enc = pad(to_enc, 16)
|
||||
cipher = AES.new(key, AES.MODE_CBC, iv)
|
||||
enc = cipher.encrypt(to_enc)
|
||||
return base64.b64encode(enc).decode("ascii")
|
||||
|
||||
|
||||
def verify_signature(msg_signature: str, timestamp: str, nonce: str, encrypt: str) -> bool:
|
||||
"""校验签名(GET 或 POST 的 Encrypt)。"""
|
||||
return _check_signature(msg_signature, timestamp, nonce, encrypt)
|
||||
|
||||
|
||||
def verify_and_decrypt_echostr(msg_signature: str, timestamp: str, nonce: str, echostr: str) -> str | None:
|
||||
"""GET 校验:验签并解密 echostr,返回明文;失败返回 None。"""
|
||||
if not verify_signature(msg_signature, timestamp, nonce, echostr):
|
||||
return None
|
||||
return decrypt(echostr)
|
||||
|
||||
|
||||
def parse_encrypted_body(body: bytes) -> Tuple[str | None, str | None]:
|
||||
"""解析 POST 请求体 XML,取 Encrypt;验签用 msg_signature/timestamp/nonce 从 query 传。返回 (encrypt_raw, None) 或 (None, error)。"""
|
||||
try:
|
||||
root = ET.fromstring(body)
|
||||
encrypt_el = root.find("Encrypt")
|
||||
if encrypt_el is None or encrypt_el.text is None:
|
||||
return None, "missing Encrypt"
|
||||
return encrypt_el.text.strip(), None
|
||||
except Exception as e:
|
||||
return None, str(e)
|
||||
|
||||
|
||||
def parse_decrypted_xml(plain_xml: str) -> dict | None:
|
||||
"""解密后的 XML 解析为 dict(ToUserName, FromUserName, MsgType, Content 等)。"""
|
||||
try:
|
||||
root = ET.fromstring(plain_xml)
|
||||
d = {}
|
||||
for c in root:
|
||||
if c.text:
|
||||
d[c.tag] = c.text
|
||||
return d
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def build_reply_xml(to_user: str, from_user: str, content: str) -> str:
|
||||
"""构造文本回复 XML(明文)。"""
|
||||
return f"""<xml>
|
||||
<ToUserName><![CDATA[{to_user}]]></ToUserName>
|
||||
<FromUserName><![CDATA[{from_user}]]></FromUserName>
|
||||
<CreateTime>{int(__import__("time").time())}</CreateTime>
|
||||
<MsgType><![CDATA[text]]></MsgType>
|
||||
<Content><![CDATA[{content}]]></Content>
|
||||
</xml>"""
|
||||
|
||||
|
||||
def make_reply_signature(encrypt: str, timestamp: str, nonce: str) -> str:
|
||||
lst = [settings.wecom_token, timestamp, nonce, encrypt]
|
||||
lst.sort()
|
||||
return _sha1("".join(lst))
|
||||
|
||||
|
||||
def build_encrypted_response(encrypt: str, signature: str, timestamp: str, nonce: str) -> str:
|
||||
"""构造 POST 回复的加密 XML。"""
|
||||
return f"""<xml>
|
||||
<Encrypt><![CDATA[{encrypt}]]></Encrypt>
|
||||
<MsgSignature><![CDATA[{signature}]]></MsgSignature>
|
||||
<TimeStamp>{timestamp}</TimeStamp>
|
||||
<Nonce><![CDATA[{nonce}]]></Nonce>
|
||||
</xml>"""
|
||||
26
backend/config.py
Normal file
26
backend/config.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="ignore")
|
||||
|
||||
api_host: str = "0.0.0.0"
|
||||
api_port: int = 8000
|
||||
database_url: str = "postgresql+asyncpg://wecom:wecom_secret@localhost:5432/wecom_ai"
|
||||
database_url_sync: str = "postgresql://wecom:wecom_secret@localhost:5432/wecom_ai"
|
||||
jwt_secret: str = "change-me"
|
||||
jwt_algorithm: str = "HS256"
|
||||
jwt_expire_minutes: int = 60
|
||||
wecom_corp_id: str = ""
|
||||
wecom_agent_id: str = ""
|
||||
wecom_secret: str = ""
|
||||
wecom_token: str = ""
|
||||
wecom_encoding_aes_key: str = ""
|
||||
wecom_api_base: str = "https://qyapi.weixin.qq.com"
|
||||
wecom_api_timeout: int = 10
|
||||
wecom_api_retries: int = 2
|
||||
log_level: str = "INFO"
|
||||
log_json: bool = True
|
||||
|
||||
|
||||
settings = Settings()
|
||||
4
backend/pyproject.toml
Normal file
4
backend/pyproject.toml
Normal file
@@ -0,0 +1,4 @@
|
||||
[tool.pytest.ini_options]
|
||||
asyncio_mode = "auto"
|
||||
testpaths = ["tests"]
|
||||
pythonpath = ["."]
|
||||
22
backend/requirements.txt
Normal file
22
backend/requirements.txt
Normal file
@@ -0,0 +1,22 @@
|
||||
fastapi==0.115.5
|
||||
uvicorn[standard]==0.32.1
|
||||
pydantic-settings==2.6.1
|
||||
python-multipart==0.0.9
|
||||
python-dotenv==1.0.1
|
||||
|
||||
# DB
|
||||
sqlalchemy[asyncio]==2.0.36
|
||||
asyncpg==0.30.0
|
||||
alembic==1.14.0
|
||||
psycopg2-binary==2.9.10
|
||||
|
||||
# Auth
|
||||
python-jose[cryptography]==3.3.0
|
||||
passlib[bcrypt]==1.7.4
|
||||
bcrypt==4.1.2
|
||||
|
||||
# Logging
|
||||
python-json-logger==2.0.7
|
||||
|
||||
# WeCom Crypto
|
||||
pycryptodome==3.21.0
|
||||
1
backend/tests/__init__.py
Normal file
1
backend/tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# tests
|
||||
26
backend/tests/test_auth.py
Normal file
26
backend/tests/test_auth.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""登录接口测试。"""
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.main import app
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
def test_login_fail_wrong_password():
|
||||
r = client.post("/api/auth/login", json={"username": "admin", "password": "wrong"})
|
||||
assert r.status_code == 401
|
||||
|
||||
|
||||
def test_login_fail_wrong_user():
|
||||
r = client.post("/api/auth/login", json={"username": "nobody", "password": "admin"})
|
||||
assert r.status_code == 401
|
||||
|
||||
|
||||
def test_login_returns_json():
|
||||
"""无 DB 时可能 401;有 DB 且 admin 存在时 200。仅断言响应为 JSON 且含 code。"""
|
||||
r = client.post("/api/auth/login", json={"username": "admin", "password": "admin"})
|
||||
assert r.headers.get("content-type", "").startswith("application/json")
|
||||
data = r.json()
|
||||
assert "code" in data
|
||||
assert "trace_id" in data
|
||||
16
backend/tests/test_health.py
Normal file
16
backend/tests/test_health.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""Health 接口测试。"""
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.main import app
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
def test_health():
|
||||
r = client.get("/api/health")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data.get("code") == 0
|
||||
assert data.get("data", {}).get("status") == "up"
|
||||
assert "trace_id" in data
|
||||
26
backend/tests/test_wecom.py
Normal file
26
backend/tests/test_wecom.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""企业微信回调验签逻辑测试(不依赖真实 Token/Key)。"""
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
|
||||
from app.services.wecom_crypto import (
|
||||
verify_signature,
|
||||
verify_and_decrypt_echostr,
|
||||
_sha1,
|
||||
)
|
||||
|
||||
|
||||
def test_sha1():
|
||||
h = _sha1("abc")
|
||||
assert len(h) == 40
|
||||
assert h == "a9993e364706816aba3e25717850c26c9cd0d89d"
|
||||
|
||||
|
||||
def test_verify_signature():
|
||||
# 用固定 token 时,签名为 sha1(sort(token, ts, nonce, encrypt))
|
||||
with patch("app.services.wecom_crypto.settings") as s:
|
||||
s.wecom_token = "mytoken"
|
||||
lst = ["mytoken", "123", "456", "echostr"]
|
||||
lst.sort()
|
||||
expected = _sha1("".join(lst))
|
||||
assert verify_signature(expected, "123", "456", "echostr") is True
|
||||
assert verify_signature("wrong", "123", "456", "echostr") is False
|
||||
Reference in New Issue
Block a user