Initial commit: 浼佷笟寰俊 AI 鏈哄櫒浜哄姪鐞?MVP

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
bujie9527
2026-02-05 16:36:32 +08:00
commit 59275ed4dc
126 changed files with 9120 additions and 0 deletions

11
backend/Dockerfile Normal file
View File

@@ -0,0 +1,11 @@
FROM python:3.12-slim
WORKDIR /app
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
COPY . .
ENV PYTHONPATH=/app
EXPOSE 8000
CMD ["sh", "-c", "alembic upgrade head && uvicorn app.main:app --host 0.0.0.0 --port 8000"]

40
backend/alembic.ini Normal file
View File

@@ -0,0 +1,40 @@
[alembic]
script_location = alembic
prepend_sys_path = .
sqlalchemy.url = driver://user:pass@localhost/dbname
[post_write_hooks]
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARN
handlers = console
qualname =
[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

45
backend/alembic/env.py Normal file
View File

@@ -0,0 +1,45 @@
from logging.config import fileConfig
from alembic import context
from sqlalchemy import engine_from_config
from sqlalchemy.engine import Connection
from sqlalchemy import pool
from app.config import settings
from app.models import Base
config = context.config
if config.config_file_name is not None:
fileConfig(config.config_file_name)
config.set_main_option("sqlalchemy.url", settings.database_url_sync)
target_metadata = Base.metadata
def run_migrations_offline() -> None:
url = config.get_main_option("sqlalchemy.url")
context.configure(url=url, target_metadata=target_metadata, literal_binds=True)
with context.begin_transaction():
context.run_migrations()
def do_run_migrations(connection: Connection) -> None:
context.configure(connection=connection, target_metadata=target_metadata)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online() -> None:
connectable = engine_from_config(
config.get_section(config.config_ini_section, {}),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
do_run_migrations(connection)
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()

View File

@@ -0,0 +1,26 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade() -> None:
${upgrades if upgrades else "pass"}
def downgrade() -> None:
${downgrades if downgrades else "pass"}

View File

@@ -0,0 +1,52 @@
"""users and audit_logs
Revision ID: 001
Revises:
Create Date: 2025-02-05
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
revision: str = "001"
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"users",
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("username", sa.String(64), nullable=False),
sa.Column("password_hash", sa.String(256), nullable=False),
sa.Column("role", sa.String(32), nullable=False, server_default="admin"),
sa.Column("is_active", sa.Boolean(), nullable=False, server_default="true"),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(op.f("ix_users_username"), "users", ["username"], unique=True)
op.create_table(
"audit_logs",
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("actor_user_id", postgresql.UUID(as_uuid=True), nullable=True),
sa.Column("action", sa.String(128), nullable=False),
sa.Column("meta_json", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=True),
sa.ForeignKeyConstraint(["actor_user_id"], ["users.id"], ondelete="SET NULL"),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(op.f("ix_audit_logs_actor_user_id"), "audit_logs", ["actor_user_id"], unique=False)
op.create_index(op.f("ix_audit_logs_action"), "audit_logs", ["action"], unique=False)
def downgrade() -> None:
op.drop_index(op.f("ix_audit_logs_action"), table_name="audit_logs")
op.drop_index(op.f("ix_audit_logs_actor_user_id"), table_name="audit_logs")
op.drop_table("audit_logs")
op.drop_index(op.f("ix_users_username"), table_name="users")
op.drop_table("users")

View File

@@ -0,0 +1,26 @@
"""stamp 002 (empty migration to match DB state)
Revision ID: 002
Revises: 001
Create Date: 2025-02-05
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
revision: str = "002"
down_revision: Union[str, None] = "001"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# 空迁移:仅用于对齐数据库中的版本号
pass
def downgrade() -> None:
# 空迁移:仅用于对齐数据库中的版本号
pass

View File

@@ -0,0 +1,36 @@
"""add missing columns if users table exists without them
Revision ID: 003
Revises: 002
Create Date: 2025-02-05
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
revision: str = "003"
down_revision: Union[str, None] = "002"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# 检查 users 表是否存在 role 列,若不存在则添加
conn = op.get_bind()
inspector = sa.inspect(conn)
columns = [c["name"] for c in inspector.get_columns("users")] if inspector.has_table("users") else []
if "users" in inspector.get_table_names():
if "role" not in columns:
op.add_column("users", sa.Column("role", sa.String(32), nullable=False, server_default="admin"))
if "is_active" not in columns:
op.add_column("users", sa.Column("is_active", sa.Boolean(), nullable=False, server_default="true"))
if "created_at" not in columns:
op.add_column("users", sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=True))
def downgrade() -> None:
# 可选:移除这些列(通常不需要)
pass

View File

@@ -0,0 +1,59 @@
"""Create chat_sessions and messages tables.
Revision ID: 004
Revises: 003
Create Date: 2025-02-05 15:30:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = "004"
down_revision: Union[str, None] = "003"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
from sqlalchemy import inspect
conn = op.get_bind()
inspector = inspect(conn)
tables = inspector.get_table_names()
if "chat_sessions" not in tables:
op.create_table(
"chat_sessions",
sa.Column("id", sa.Integer(), nullable=False, autoincrement=True),
sa.Column("external_user_id", sa.String(128), nullable=False),
sa.Column("external_name", sa.String(128), nullable=True),
sa.Column("status", sa.String(32), nullable=False, server_default="open"),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(op.f("ix_chat_sessions_external_user_id"), "chat_sessions", ["external_user_id"], unique=False)
if "messages" not in tables:
op.create_table(
"messages",
sa.Column("id", sa.Integer(), nullable=False, autoincrement=True),
sa.Column("session_id", sa.Integer(), nullable=False),
sa.Column("role", sa.String(16), nullable=False),
sa.Column("content", sa.Text(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
sa.ForeignKeyConstraint(["session_id"], ["chat_sessions.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(op.f("ix_messages_session_id"), "messages", ["session_id"], unique=False)
def downgrade() -> None:
op.drop_index(op.f("ix_messages_session_id"), table_name="messages")
op.drop_table("messages")
op.drop_index(op.f("ix_chat_sessions_external_user_id"), table_name="chat_sessions")
op.drop_table("chat_sessions")

1
backend/app/__init__.py Normal file
View File

@@ -0,0 +1 @@
# backend app

28
backend/app/config.py Normal file
View File

@@ -0,0 +1,28 @@
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="ignore")
api_host: str = "0.0.0.0"
api_port: int = 8000
database_url: str = "postgresql+asyncpg://wecom:wecom_secret@localhost:5432/wecom_ai"
database_url_sync: str = "postgresql://wecom:wecom_secret@localhost:5432/wecom_ai"
jwt_secret: str = "change-me"
jwt_algorithm: str = "HS256"
jwt_expire_minutes: int = 60
wecom_corp_id: str = ""
wecom_agent_id: str = ""
wecom_secret: str = ""
wecom_token: str = ""
wecom_encoding_aes_key: str = ""
wecom_api_base: str = "https://qyapi.weixin.qq.com"
wecom_api_timeout: int = 10
wecom_api_retries: int = 2
log_level: str = "INFO"
log_json: bool = True
settings = Settings()

27
backend/app/database.py Normal file
View File

@@ -0,0 +1,27 @@
"""异步数据库会话DATABASE_URL 来自环境变量。"""
from collections.abc import AsyncGenerator
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from app.config import settings
engine = create_async_engine(
settings.database_url,
echo=False,
pool_pre_ping=True,
)
async_session_factory = async_sessionmaker(
engine, class_=AsyncSession, expire_on_commit=False, autoflush=False
)
async def get_db() -> AsyncGenerator[AsyncSession, None]:
async with async_session_factory() as session:
try:
yield session
await session.commit()
except Exception:
await session.rollback()
raise
finally:
await session.close()

30
backend/app/deps.py Normal file
View File

@@ -0,0 +1,30 @@
"""依赖get_db、JWT 校验。"""
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
from app.models import User
from app.services.auth_service import decode_access_token
security = HTTPBearer(auto_error=False)
async def get_current_user(
db: AsyncSession = Depends(get_db),
credentials: HTTPAuthorizationCredentials | None = Depends(security),
) -> User:
if not credentials:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="未提供认证信息")
subject = decode_access_token(credentials.credentials)
if not subject:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="无效或已过期的 token")
# subject 存 username
r = await db.execute(select(User).where(User.username == subject))
user = r.scalar_one_or_none()
if not user:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="用户不存在")
if not user.is_active:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="账号已禁用")
return user

View File

@@ -0,0 +1,49 @@
"""结构化 JSON 日志 + trace_id。"""
import logging
import sys
import uuid
from contextvars import ContextVar
from pythonjsonlogger import jsonlogger
trace_id_var: ContextVar[str] = ContextVar("trace_id", default="")
def get_trace_id() -> str:
t = trace_id_var.get()
if not t:
t = str(uuid.uuid4())
trace_id_var.set(t)
return t
def set_trace_id(tid: str) -> None:
trace_id_var.set(tid)
class TraceIdFilter(logging.Filter):
def filter(self, record: logging.LogRecord) -> bool:
record.trace_id = get_trace_id()
return True
class JsonFormatter(jsonlogger.JsonFormatter):
def add_fields(self, log_record: dict, record: logging.LogRecord, message_dict: dict) -> None:
super().add_fields(log_record, record, message_dict)
log_record["trace_id"] = getattr(record, "trace_id", "")
log_record["level"] = record.levelname
if record.exc_info:
log_record["exception"] = self.formatException(record.exc_info)
def setup_logging(log_level: str = "INFO", log_json: bool = True) -> None:
root = logging.getLogger()
root.handlers.clear()
handler = logging.StreamHandler(sys.stdout)
if log_json:
handler.setFormatter(JsonFormatter("%(timestamp)s %(level)s %(message)s %(trace_id)s"))
else:
handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s [%(trace_id)s] %(message)s"))
handler.addFilter(TraceIdFilter())
root.addHandler(handler)
root.setLevel(getattr(logging, log_level.upper(), logging.INFO))

68
backend/app/main.py Normal file
View File

@@ -0,0 +1,68 @@
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
from starlette.exceptions import HTTPException as StarletteHTTPException
from app.routers import auth, wecom
from app.routers.admin import sessions, tickets, kb, settings, users
from app.logging_config import get_trace_id
app = FastAPI(title="企微AI助手", version="0.1.0")
# CORS 必须在最前面
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.exception_handler(StarletteHTTPException)
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
return JSONResponse(
status_code=exc.status_code,
content={"code": exc.status_code, "message": str(exc.detail), "data": None, "trace_id": get_trace_id()},
headers={"Access-Control-Allow-Origin": "*"},
)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
return JSONResponse(
status_code=422,
content={"code": 422, "message": "参数校验失败", "data": exc.errors(), "trace_id": get_trace_id()},
headers={"Access-Control-Allow-Origin": "*"},
)
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
import traceback
traceback.print_exc()
return JSONResponse(
status_code=500,
content={"code": 500, "message": str(exc), "data": None, "trace_id": get_trace_id()},
headers={"Access-Control-Allow-Origin": "*"},
)
app.include_router(auth.router, prefix="/api/auth", tags=["auth"])
app.include_router(wecom.router, prefix="/api/wecom", tags=["wecom"])
app.include_router(sessions.router, prefix="/api/admin", tags=["admin"])
app.include_router(tickets.router, prefix="/api/admin", tags=["admin"])
app.include_router(kb.router, prefix="/api/admin", tags=["admin"])
app.include_router(settings.router, prefix="/api/admin", tags=["admin"])
app.include_router(users.router, prefix="/api/admin", tags=["admin"])
@app.get("/api/health")
def health():
return {"status": "up", "service": "backend"}
@app.get("/api/ready")
def ready():
return {"ready": True, "service": "backend"}

View File

@@ -0,0 +1,7 @@
from app.models.base import Base
from app.models.user import User
from app.models.audit_log import AuditLog
from app.models.session import ChatSession
from app.models.message import Message
__all__ = ["Base", "User", "AuditLog", "ChatSession", "Message"]

View File

@@ -0,0 +1,23 @@
"""审计日志最简id、actor_user_id、action、meta_json、created_at。"""
import uuid
from datetime import datetime
from sqlalchemy import String, DateTime, Text, ForeignKey
from sqlalchemy.dialects.postgresql import UUID, JSONB
from sqlalchemy.orm import Mapped, mapped_column
from app.models.base import Base
class AuditLog(Base):
__tablename__ = "audit_logs"
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
)
actor_user_id: Mapped[uuid.UUID | None] = mapped_column(
UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True, index=True
)
action: Mapped[str] = mapped_column(String(128), nullable=False, index=True)
meta_json: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=datetime.utcnow)

View File

@@ -0,0 +1,5 @@
from sqlalchemy.orm import DeclarativeBase
class Base(DeclarativeBase):
pass

View File

@@ -0,0 +1,18 @@
"""单条消息(仅存 public 可见内容,隔离内部信息)。"""
from datetime import datetime
from sqlalchemy import String, Text, DateTime, ForeignKey
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.models.base import Base
class Message(Base):
__tablename__ = "messages"
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
session_id: Mapped[int] = mapped_column(ForeignKey("chat_sessions.id", ondelete="CASCADE"), nullable=False, index=True)
role: Mapped[str] = mapped_column(String(16), nullable=False) # user / assistant / system
content: Mapped[str] = mapped_column(Text, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=datetime.utcnow)
session: Mapped["ChatSession"] = relationship("ChatSession", back_populates="messages")

View File

@@ -0,0 +1,19 @@
"""外部客户会话(企微单聊/群聊维度)。"""
from datetime import datetime
from sqlalchemy import String, DateTime, ForeignKey
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.models.base import Base
class ChatSession(Base):
__tablename__ = "chat_sessions"
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
external_user_id: Mapped[str] = mapped_column(String(128), nullable=False, index=True) # 企微 external_userid
external_name: Mapped[str | None] = mapped_column(String(128), nullable=True)
status: Mapped[str] = mapped_column(String(32), default="open") # open / transferred / closed
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=datetime.utcnow)
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=datetime.utcnow, onupdate=datetime.utcnow)
messages: Mapped[list["Message"]] = relationship("Message", back_populates="session", order_by="Message.id")

View File

@@ -0,0 +1,17 @@
"""转人工工单。"""
from datetime import datetime
from sqlalchemy import String, Text, DateTime, ForeignKey
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.models.base import Base
class Ticket(Base):
__tablename__ = "tickets"
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
session_id: Mapped[int] = mapped_column(ForeignKey("chat_sessions.id", ondelete="CASCADE"), nullable=False, index=True)
reason: Mapped[str | None] = mapped_column(Text, nullable=True)
status: Mapped[str] = mapped_column(String(32), default="open") # open / handling / closed
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=datetime.utcnow)
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=datetime.utcnow, onupdate=datetime.utcnow)

View File

@@ -0,0 +1,22 @@
"""后台用户id(uuid)、username、password_hash、role、is_active、created_at。"""
import uuid
from datetime import datetime
from sqlalchemy import String, Boolean, DateTime
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Mapped, mapped_column
from app.models.base import Base
class User(Base):
__tablename__ = "users"
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
)
username: Mapped[str] = mapped_column(String(64), unique=True, nullable=False, index=True)
password_hash: Mapped[str] = mapped_column(String(256), nullable=False)
role: Mapped[str] = mapped_column(String(32), nullable=False, default="admin")
is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=datetime.utcnow)

View File

@@ -0,0 +1,2 @@
# routers
from app.routers import auth, wecom

View File

@@ -0,0 +1 @@
# admin routers

View File

@@ -0,0 +1,31 @@
"""Admin API: GET/POST /api/admin/kb/docs/upload。"""
from fastapi import APIRouter, Depends, UploadFile
from app.deps import get_current_user
from app.logging_config import get_trace_id
router = APIRouter()
@router.get("/kb/docs")
async def list_kb_docs(_user=Depends(get_current_user)):
"""知识库文档列表(占位)。"""
return {
"code": 0,
"message": "ok",
"data": [
{"id": "1", "filename": "faq.pdf", "size": 102400, "uploaded_at": "2025-02-05T10:00:00Z"},
],
"trace_id": get_trace_id(),
}
@router.post("/kb/docs/upload")
async def upload_kb_doc(file: UploadFile, _user=Depends(get_current_user)):
"""上传知识库文档(占位:先存本地/对象存储)。"""
# 占位:实际应保存到对象存储或本地卷
return {
"code": 0,
"message": "ok",
"data": {"id": "new_1", "filename": file.filename, "size": 0},
"trace_id": get_trace_id(),
}

View File

@@ -0,0 +1,40 @@
"""Admin API: GET /api/admin/sessions, GET /api/admin/sessions/{id}"""
from fastapi import APIRouter, Depends
from app.deps import get_current_user
from app.logging_config import get_trace_id
router = APIRouter()
@router.get("/sessions")
async def list_sessions(_user=Depends(get_current_user)):
"""会话列表(占位)。"""
return {
"code": 0,
"message": "ok",
"data": [
{"id": "1", "external_user_id": "ext_001", "external_name": "客户A", "status": "open", "created_at": "2025-02-05T10:00:00Z"},
{"id": "2", "external_user_id": "ext_002", "external_name": "客户B", "status": "transferred", "created_at": "2025-02-05T11:00:00Z"},
],
"trace_id": get_trace_id(),
}
@router.get("/sessions/{id}")
async def get_session(id: str, _user=Depends(get_current_user)):
"""会话详情:消息列表(占位)。"""
return {
"code": 0,
"message": "ok",
"data": {
"id": id,
"external_user_id": "ext_001",
"external_name": "客户A",
"status": "open",
"messages": [
{"id": 1, "role": "user", "content": "你好", "created_at": "2025-02-05T10:00:00Z"},
{"id": 2, "role": "assistant", "content": "您好,有什么可以帮您?", "created_at": "2025-02-05T10:00:01Z"},
],
},
"trace_id": get_trace_id(),
}

View File

@@ -0,0 +1,37 @@
"""Admin API: GET/PATCH /api/admin/settings。"""
from fastapi import APIRouter, Depends
from pydantic import BaseModel
from app.deps import get_current_user
from app.logging_config import get_trace_id
router = APIRouter()
class UpdateSettingsBody(BaseModel):
model_name: str | None = None
strategy: dict | None = None
@router.get("/settings")
async def get_settings(_user=Depends(get_current_user)):
"""获取设置(占位)。"""
return {
"code": 0,
"message": "ok",
"data": {
"model_name": "gpt-4",
"strategy": {"faq_priority": True, "rag_enabled": False},
},
"trace_id": get_trace_id(),
}
@router.patch("/settings")
async def update_settings(body: UpdateSettingsBody, _user=Depends(get_current_user)):
"""更新设置(占位)。"""
return {
"code": 0,
"message": "ok",
"data": {"model_name": body.model_name or "gpt-4", "strategy": body.strategy or {}},
"trace_id": get_trace_id(),
}

View File

@@ -0,0 +1,52 @@
"""Admin API: GET/POST/PATCH /api/admin/tickets。"""
from fastapi import APIRouter, Depends
from pydantic import BaseModel
from app.deps import get_current_user
from app.logging_config import get_trace_id
router = APIRouter()
class CreateTicketBody(BaseModel):
session_id: str
reason: str = ""
class UpdateTicketBody(BaseModel):
status: str | None = None
reason: str | None = None
@router.get("/tickets")
async def list_tickets(_user=Depends(get_current_user)):
"""工单列表(占位)。"""
return {
"code": 0,
"message": "ok",
"data": [
{"id": "1", "session_id": "1", "reason": "转人工", "status": "open", "created_at": "2025-02-05T10:00:00Z"},
],
"trace_id": get_trace_id(),
}
@router.post("/tickets")
async def create_ticket(body: CreateTicketBody, _user=Depends(get_current_user)):
"""创建工单(占位)。"""
return {
"code": 0,
"message": "ok",
"data": {"id": "new_1", "session_id": body.session_id, "reason": body.reason, "status": "open"},
"trace_id": get_trace_id(),
}
@router.patch("/tickets/{id}")
async def update_ticket(id: str, body: UpdateTicketBody, _user=Depends(get_current_user)):
"""更新工单(占位)。"""
return {
"code": 0,
"message": "ok",
"data": {"id": id, "status": body.status or "open", "reason": body.reason},
"trace_id": get_trace_id(),
}

View File

@@ -0,0 +1,74 @@
"""Admin API: GET/POST/PATCH/DELETE /api/admin/users仅管理员可见"""
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from app.deps import get_current_user
from app.logging_config import get_trace_id
router = APIRouter()
class CreateUserBody(BaseModel):
username: str
password: str
role: str = "admin"
is_active: bool = True
class UpdateUserBody(BaseModel):
password: str | None = None
role: str | None = None
is_active: bool | None = None
@router.get("/users")
async def list_users(current_user=Depends(get_current_user)):
"""用户列表(仅管理员)。"""
if current_user.role != "admin":
raise HTTPException(status_code=403, detail="仅管理员可访问")
return {
"code": 0,
"message": "ok",
"data": [
{"id": "1", "username": "admin", "role": "admin", "is_active": True, "created_at": "2025-02-05T10:00:00Z"},
],
"trace_id": get_trace_id(),
}
@router.post("/users")
async def create_user(body: CreateUserBody, current_user=Depends(get_current_user)):
"""创建用户(仅管理员)。"""
if current_user.role != "admin":
raise HTTPException(status_code=403, detail="仅管理员可访问")
return {
"code": 0,
"message": "ok",
"data": {"id": "new_1", "username": body.username, "role": body.role, "is_active": body.is_active},
"trace_id": get_trace_id(),
}
@router.patch("/users/{id}")
async def update_user(id: str, body: UpdateUserBody, current_user=Depends(get_current_user)):
"""更新用户(仅管理员)。"""
if current_user.role != "admin":
raise HTTPException(status_code=403, detail="仅管理员可访问")
return {
"code": 0,
"message": "ok",
"data": {"id": id, "role": body.role, "is_active": body.is_active},
"trace_id": get_trace_id(),
}
@router.delete("/users/{id}")
async def delete_user(id: str, current_user=Depends(get_current_user)):
"""删除用户(仅管理员)。"""
if current_user.role != "admin":
raise HTTPException(status_code=403, detail="仅管理员可访问")
return {
"code": 0,
"message": "ok",
"data": None,
"trace_id": get_trace_id(),
}

View File

@@ -0,0 +1,47 @@
"""Auth APIPOST /api/auth/login、GET /api/auth/me。"""
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
from app.deps import get_current_user
from app.models import User
from app.services.auth_service import (
get_user_by_username,
verify_password,
create_access_token,
)
router = APIRouter()
class LoginBody(BaseModel):
username: str
password: str
class LoginResponse(BaseModel):
access_token: str
token_type: str = "bearer"
@router.post("/login", response_model=LoginResponse)
async def login(body: LoginBody, db: AsyncSession = Depends(get_db)):
user = await get_user_by_username(db, body.username)
if not user or not verify_password(body.password, user.password_hash):
raise HTTPException(status_code=401, detail="用户名或密码错误")
if not user.is_active:
raise HTTPException(status_code=403, detail="账号已禁用")
token = create_access_token(subject=user.username)
return LoginResponse(access_token=token, token_type="bearer")
@router.get("/me")
async def me(current_user: User = Depends(get_current_user)):
return {
"id": str(current_user.id),
"username": current_user.username,
"role": current_user.role,
"is_active": current_user.is_active,
"created_at": current_user.created_at.isoformat() if current_user.created_at else None,
}

View File

@@ -0,0 +1,15 @@
"""健康检查:供负载均衡与 CI 验证。"""
from fastapi import APIRouter
from app.logging_config import get_trace_id
router = APIRouter()
@router.get("/health")
def health():
return {
"code": 0,
"message": "ok",
"data": {"status": "up"},
"trace_id": get_trace_id(),
}

View File

@@ -0,0 +1,12 @@
"""知识库上传占位:先落本地卷/对象存储占位。"""
from fastapi import APIRouter, UploadFile
from app.logging_config import get_trace_id
router = APIRouter()
@router.post("/upload")
async def upload_file(file: UploadFile):
"""上传知识库文件,占位落盘。"""
# 占位:保存到 backend/uploads 或配置的存储
return {"code": 0, "message": "ok", "data": {"filename": file.filename}, "trace_id": get_trace_id()}

View File

@@ -0,0 +1,53 @@
"""会话列表与消息:从 DB 读取,需登录。"""
from fastapi import APIRouter, Depends
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
from app.deps import get_current_user_id
from app.logging_config import get_trace_id
from app.models import ChatSession, Message
router = APIRouter()
def _session_row(s: ChatSession) -> dict:
return {
"id": s.id,
"external_user_id": s.external_user_id,
"external_name": s.external_name,
"status": s.status,
"created_at": s.created_at.isoformat() if s.created_at else None,
}
def _message_row(m: Message) -> dict:
return {
"id": m.id,
"role": m.role,
"content": m.content,
"created_at": m.created_at.isoformat() if m.created_at else None,
}
@router.get("")
async def list_sessions(
db: AsyncSession = Depends(get_db),
_user: str = Depends(get_current_user_id),
):
"""会话列表。"""
r = await db.execute(select(ChatSession).order_by(ChatSession.updated_at.desc()))
rows = r.scalars().all()
return {"code": 0, "message": "ok", "data": [_session_row(s) for s in rows], "trace_id": get_trace_id()}
@router.get("/{session_id}/messages")
async def list_messages(
session_id: int,
db: AsyncSession = Depends(get_db),
_user: str = Depends(get_current_user_id),
):
"""某会话消息列表。"""
r = await db.execute(select(Message).where(Message.session_id == session_id).order_by(Message.id))
rows = r.scalars().all()
return {"code": 0, "message": "ok", "data": [_message_row(m) for m in rows], "trace_id": get_trace_id()}

View File

@@ -0,0 +1,15 @@
"""设置页占位:仅占位接口。"""
from fastapi import APIRouter
from app.logging_config import get_trace_id
router = APIRouter()
@router.get("")
def get_settings():
return {"code": 0, "message": "ok", "data": {}, "trace_id": get_trace_id()}
@router.put("")
def update_settings():
return {"code": 0, "message": "ok", "data": None, "trace_id": get_trace_id()}

View File

@@ -0,0 +1,67 @@
"""工单转人工:创建工单入库、手动回复调企业微信 API。"""
import logging
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
from app.deps import get_current_user_id
from app.logging_config import get_trace_id
from app.models import ChatSession, Ticket
from app.services.wecom_api import send_text_to_external
logger = logging.getLogger(__name__)
router = APIRouter()
class CreateTicketBody(BaseModel):
session_id: int
reason: str = ""
class SendReplyBody(BaseModel):
session_id: int
content: str
@router.post("")
async def create_ticket(
body: CreateTicketBody,
db: AsyncSession = Depends(get_db),
_user: str = Depends(get_current_user_id),
):
"""创建转人工工单并更新会话状态。"""
r = await db.execute(select(ChatSession).where(ChatSession.id == body.session_id))
session = r.scalar_one_or_none()
if not session:
raise HTTPException(status_code=404, detail="会话不存在")
ticket = Ticket(session_id=body.session_id, reason=body.reason or None)
db.add(ticket)
session.status = "transferred"
await db.flush()
return {
"code": 0,
"message": "ok",
"data": {"ticket_id": str(ticket.id)},
"trace_id": get_trace_id(),
}
@router.post("/reply")
async def send_reply(
body: SendReplyBody,
db: AsyncSession = Depends(get_db),
_user: str = Depends(get_current_user_id),
):
"""手动回复:通过企业微信 API 发给客户。"""
r = await db.execute(select(ChatSession).where(ChatSession.id == body.session_id))
session = r.scalar_one_or_none()
if not session:
raise HTTPException(status_code=404, detail="会话不存在")
try:
await send_text_to_external(session.external_user_id, body.content)
except Exception as e:
logger.exception("wecom send reply failed")
raise HTTPException(status_code=502, detail=str(e))
return {"code": 0, "message": "ok", "data": None, "trace_id": get_trace_id()}

View File

@@ -0,0 +1,163 @@
"""企业微信回调GET 校验 + POST 消息回调验签、解密、echo 回复、会话入库)。"""
import time
import logging
import random
import string
from fastapi import APIRouter, Request, Query, Depends
from fastapi.responses import PlainTextResponse
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
from app.services.session_service import get_or_create_session, add_message
from app.services.wecom_crypto import (
verify_and_decrypt_echostr,
verify_signature,
parse_encrypted_body,
decrypt,
parse_decrypted_xml,
build_reply_xml,
encrypt,
make_reply_signature,
build_encrypted_response,
)
from app.logging_config import get_trace_id
logger = logging.getLogger(__name__)
router = APIRouter()
@router.get("/callback")
def wecom_verify(
request: Request,
signature: str = Query(None, alias="signature"),
msg_signature: str = Query(None, alias="msg_signature"),
timestamp: str = Query(..., alias="timestamp"),
nonce: str = Query(..., alias="nonce"),
echostr: str = Query(..., alias="echostr"),
):
"""企业微信 GET 验签:校验签名并解密 echostr原样返回明文。
兼容 signature 和 msg_signature 两种参数名。
"""
trace_id = get_trace_id()
# 兼容 signature 和 msg_signature 两种参数名
sig = msg_signature or signature
if not sig:
logger.warning(
"wecom verify missing signature",
extra={"trace_id": trace_id, "query_params": dict(request.query_params)},
)
return PlainTextResponse("", status_code=400)
plain = verify_and_decrypt_echostr(sig, timestamp, nonce, echostr)
if plain is None:
logger.warning(
"wecom verify failed",
extra={"trace_id": trace_id, "timestamp": timestamp, "nonce": nonce},
)
return PlainTextResponse("", status_code=400)
logger.info(
"wecom verify success",
extra={"trace_id": trace_id, "echostr_length": len(echostr)},
)
return PlainTextResponse(plain)
@router.post("/callback")
async def wecom_callback(
request: Request,
signature: str = Query(None, alias="signature"),
msg_signature: str = Query(None, alias="msg_signature"),
timestamp: str = Query(..., alias="timestamp"),
nonce: str = Query(..., alias="nonce"),
db: AsyncSession = Depends(get_db),
):
"""POST 消息回调:验签、解密、会话与消息入库、文本 echo 回复。
兼容 signature 和 msg_signature 两种参数名。
"""
trace_id = get_trace_id()
# 兼容 signature 和 msg_signature 两种参数名
sig = msg_signature or signature
if not sig:
logger.warning(
"wecom post missing signature",
extra={"trace_id": trace_id, "query_params": dict(request.query_params)},
)
return PlainTextResponse("", status_code=400)
body = await request.body()
encrypt_raw, err = parse_encrypted_body(body)
if err:
logger.warning(
"wecom post parse error",
extra={"trace_id": trace_id, "error": err},
)
return PlainTextResponse("", status_code=400)
if not verify_signature(sig, timestamp, nonce, encrypt_raw):
logger.warning(
"wecom post verify failed",
extra={"trace_id": trace_id, "timestamp": timestamp},
)
return PlainTextResponse("", status_code=400)
try:
plain_xml = decrypt(encrypt_raw)
except Exception as e:
logger.warning(
"wecom decrypt error",
extra={"trace_id": trace_id, "error": str(e)},
)
return PlainTextResponse("", status_code=400)
msg = parse_decrypted_xml(plain_xml)
if not msg:
logger.warning(
"wecom xml parse failed",
extra={"trace_id": trace_id},
)
return PlainTextResponse("", status_code=400)
to_user = msg.get("ToUserName", "")
from_user = msg.get("FromUserName", "") # external_userid
msg_id = msg.get("MsgId", "")
msg_type = msg.get("MsgType", "")
content = (msg.get("Content") or "").strip()
content_summary = content[:50] + "..." if len(content) > 50 else content
# 记录日志trace_id + external_userid + msgid + 内容摘要
logger.info(
"wecom message received",
extra={
"trace_id": trace_id,
"external_userid": from_user,
"msgid": msg_id,
"msg_type": msg_type,
"content_summary": content_summary or "(empty)",
},
)
# 会话入库external_user_id = from_user客户
session = await get_or_create_session(db, from_user, msg.get("Contact"))
await add_message(db, session.id, "user", content or "(非文本消息)")
# Echo 文本:回复"已收到:{用户消息}"
if msg_type == "text" and content:
reply_content = f"已收到:{content}"
else:
reply_content = "已收到"
await add_message(db, session.id, "assistant", reply_content)
# 回复给客户(被动回复 XML
reply_xml = build_reply_xml(from_user, to_user, reply_content)
enc = encrypt(reply_xml)
ts = str(int(time.time()))
reply_nonce = "".join(random.choices(string.ascii_letters + string.digits, k=16))
sig = make_reply_signature(enc, ts, reply_nonce)
resp_xml = build_encrypted_response(enc, sig, ts, reply_nonce)
logger.info(
"wecom reply sent",
extra={
"trace_id": trace_id,
"external_userid": from_user,
"msgid": msg_id,
"reply_summary": reply_content[:50] + "..." if len(reply_content) > 50 else reply_content,
},
)
return PlainTextResponse(resp_xml, media_type="application/xml")

View File

@@ -0,0 +1 @@
# services

View File

@@ -0,0 +1,39 @@
"""密码 bcrypt hashJWT 创建与解码,带过期时间。"""
from datetime import datetime, timedelta, timezone
from jose import JWTError, jwt
from passlib.context import CryptContext
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import settings
from app.models import User
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def hash_password(password: str) -> str:
return pwd_context.hash(password)
def verify_password(plain: str, hashed: str) -> bool:
return pwd_context.verify(plain, hashed)
def create_access_token(subject: str) -> str:
expire = datetime.now(timezone.utc) + timedelta(minutes=settings.jwt_expire_minutes)
to_encode = {"sub": subject, "exp": expire}
return jwt.encode(to_encode, settings.jwt_secret, algorithm=settings.jwt_algorithm)
def decode_access_token(token: str) -> str | None:
try:
payload = jwt.decode(token, settings.jwt_secret, algorithms=[settings.jwt_algorithm])
return payload.get("sub")
except JWTError:
return None
async def get_user_by_username(db: AsyncSession, username: str) -> User | None:
r = await db.execute(select(User).where(User.username == username))
return r.scalar_one_or_none()

View File

@@ -0,0 +1,30 @@
"""会话与消息入库;仅存 public 可见内容,隔离内部信息。"""
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models import ChatSession, Message
async def get_or_create_session(
db: AsyncSession,
external_user_id: str,
external_name: str | None = None,
) -> ChatSession:
r = await db.execute(select(ChatSession).where(ChatSession.external_user_id == external_user_id))
row = r.scalar_one_or_none()
if row:
if external_name is not None and row.external_name != external_name:
row.external_name = external_name
await db.flush()
return row
session = ChatSession(external_user_id=external_user_id, external_name=external_name or None)
db.add(session)
await db.flush()
return session
async def add_message(db: AsyncSession, session_id: int, role: str, content: str) -> Message:
msg = Message(session_id=session_id, role=role, content=content)
db.add(msg)
await db.flush()
return msg

View File

@@ -0,0 +1,54 @@
"""企业微信 API 调用:超时与重试,配置来自环境变量。"""
import logging
from typing import Any
import httpx
from app.config import settings
logger = logging.getLogger(__name__)
TIMEOUT = settings.wecom_api_timeout
RETRIES = settings.wecom_api_retries
BASE = settings.wecom_api_base.rstrip("/")
async def _request(method: str, path: str, **kwargs: Any) -> dict | None:
url = f"{BASE}{path}"
for attempt in range(RETRIES + 1):
try:
async with httpx.AsyncClient(timeout=TIMEOUT) as client:
r = await client.request(method, url, **kwargs)
r.raise_for_status()
return r.json()
except Exception as e:
logger.warning("wecom api attempt %s failed: %s", attempt + 1, e)
if attempt == RETRIES:
raise
return None
async def get_access_token() -> str:
"""获取 corpid + secret 的 access_token。"""
r = await _request(
"GET",
"/cgi-bin/gettoken",
params={"corpid": settings.wecom_corp_id, "corpsecret": settings.wecom_secret},
)
if not r or r.get("errcode") != 0:
raise RuntimeError(r.get("errmsg", "get token failed"))
return r["access_token"]
async def send_text_to_external(external_user_id: str, content: str) -> None:
"""发送文本消息给外部联系人(客户联系-发送消息到客户)。"""
token = await get_access_token()
body = {
"touser": [external_user_id],
"sender": settings.wecom_agent_id,
"msgtype": "text",
"text": {"content": content},
}
# 企业微信文档:发送消息到客户 send_message_to_user
r = await _request("POST", f"/cgi-bin/externalcontact/send_message_to_user?access_token={token}", json=body)
if not r or r.get("errcode") != 0:
raise RuntimeError(r.get("errmsg", "send failed"))

View File

@@ -0,0 +1,119 @@
"""企业微信回调加解密与验签(与企微文档一致)。"""
import base64
import hashlib
import struct
import xml.etree.ElementTree as ET
from typing import Tuple
from Crypto.Cipher import AES
from app.config import settings
def _sha1(s: str) -> str:
return hashlib.sha1(s.encode()).hexdigest()
def _check_signature(signature: str, timestamp: str, nonce: str, echostr_or_encrypt: str) -> bool:
token = settings.wecom_token
lst = [token, timestamp, nonce, echostr_or_encrypt]
lst.sort()
return _sha1("".join(lst)) == signature
def _aes_key() -> bytes:
key_b64 = settings.wecom_encoding_aes_key + "="
return base64.b64decode(key_b64)[:32]
def decrypt(encrypt: str) -> str:
"""解密企微回调密文echostr 或 Encrypt 节点内容)。"""
key = _aes_key()
iv = key[:16]
raw = base64.b64decode(encrypt)
cipher = AES.new(key, AES.MODE_CBC, iv)
dec = cipher.decrypt(raw)
# 16 随机字节 + 4 字节长度(big-endian) + 消息 + corpid先按长度取消息避免 padding 差异
msg_len = struct.unpack(">I", dec[16:20])[0]
return dec[20 : 20 + msg_len].decode("utf-8")
def encrypt(plain: str) -> str:
"""加密回复内容(明文为 XML 或文本)。"""
import os
key = _aes_key()
iv = key[:16]
corpid = settings.wecom_corp_id or "placeholder"
msg = plain.encode("utf-8")
msg_len = struct.pack(">I", len(msg))
rand = os.urandom(16)
to_enc = rand + msg_len + msg + corpid.encode("utf-8")
from Crypto.Util.Padding import pad
to_enc = pad(to_enc, 16)
cipher = AES.new(key, AES.MODE_CBC, iv)
enc = cipher.encrypt(to_enc)
return base64.b64encode(enc).decode("ascii")
def verify_signature(msg_signature: str, timestamp: str, nonce: str, encrypt: str) -> bool:
"""校验签名GET 或 POST 的 Encrypt"""
return _check_signature(msg_signature, timestamp, nonce, encrypt)
def verify_and_decrypt_echostr(msg_signature: str, timestamp: str, nonce: str, echostr: str) -> str | None:
"""GET 校验:验签并解密 echostr返回明文失败返回 None。"""
if not verify_signature(msg_signature, timestamp, nonce, echostr):
return None
return decrypt(echostr)
def parse_encrypted_body(body: bytes) -> Tuple[str | None, str | None]:
"""解析 POST 请求体 XML取 Encrypt验签用 msg_signature/timestamp/nonce 从 query 传。返回 (encrypt_raw, None) 或 (None, error)。"""
try:
root = ET.fromstring(body)
encrypt_el = root.find("Encrypt")
if encrypt_el is None or encrypt_el.text is None:
return None, "missing Encrypt"
return encrypt_el.text.strip(), None
except Exception as e:
return None, str(e)
def parse_decrypted_xml(plain_xml: str) -> dict | None:
"""解密后的 XML 解析为 dictToUserName, FromUserName, MsgType, Content 等)。"""
try:
root = ET.fromstring(plain_xml)
d = {}
for c in root:
if c.text:
d[c.tag] = c.text
return d
except Exception:
return None
def build_reply_xml(to_user: str, from_user: str, content: str) -> str:
"""构造文本回复 XML明文"""
return f"""<xml>
<ToUserName><![CDATA[{to_user}]]></ToUserName>
<FromUserName><![CDATA[{from_user}]]></FromUserName>
<CreateTime>{int(__import__("time").time())}</CreateTime>
<MsgType><![CDATA[text]]></MsgType>
<Content><![CDATA[{content}]]></Content>
</xml>"""
def make_reply_signature(encrypt: str, timestamp: str, nonce: str) -> str:
lst = [settings.wecom_token, timestamp, nonce, encrypt]
lst.sort()
return _sha1("".join(lst))
def build_encrypted_response(encrypt: str, signature: str, timestamp: str, nonce: str) -> str:
"""构造 POST 回复的加密 XML。"""
return f"""<xml>
<Encrypt><![CDATA[{encrypt}]]></Encrypt>
<MsgSignature><![CDATA[{signature}]]></MsgSignature>
<TimeStamp>{timestamp}</TimeStamp>
<Nonce><![CDATA[{nonce}]]></Nonce>
</xml>"""

26
backend/config.py Normal file
View File

@@ -0,0 +1,26 @@
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="ignore")
api_host: str = "0.0.0.0"
api_port: int = 8000
database_url: str = "postgresql+asyncpg://wecom:wecom_secret@localhost:5432/wecom_ai"
database_url_sync: str = "postgresql://wecom:wecom_secret@localhost:5432/wecom_ai"
jwt_secret: str = "change-me"
jwt_algorithm: str = "HS256"
jwt_expire_minutes: int = 60
wecom_corp_id: str = ""
wecom_agent_id: str = ""
wecom_secret: str = ""
wecom_token: str = ""
wecom_encoding_aes_key: str = ""
wecom_api_base: str = "https://qyapi.weixin.qq.com"
wecom_api_timeout: int = 10
wecom_api_retries: int = 2
log_level: str = "INFO"
log_json: bool = True
settings = Settings()

4
backend/pyproject.toml Normal file
View File

@@ -0,0 +1,4 @@
[tool.pytest.ini_options]
asyncio_mode = "auto"
testpaths = ["tests"]
pythonpath = ["."]

22
backend/requirements.txt Normal file
View File

@@ -0,0 +1,22 @@
fastapi==0.115.5
uvicorn[standard]==0.32.1
pydantic-settings==2.6.1
python-multipart==0.0.9
python-dotenv==1.0.1
# DB
sqlalchemy[asyncio]==2.0.36
asyncpg==0.30.0
alembic==1.14.0
psycopg2-binary==2.9.10
# Auth
python-jose[cryptography]==3.3.0
passlib[bcrypt]==1.7.4
bcrypt==4.1.2
# Logging
python-json-logger==2.0.7
# WeCom Crypto
pycryptodome==3.21.0

View File

@@ -0,0 +1 @@
# tests

View File

@@ -0,0 +1,26 @@
"""登录接口测试。"""
import pytest
from fastapi.testclient import TestClient
from app.main import app
client = TestClient(app)
def test_login_fail_wrong_password():
r = client.post("/api/auth/login", json={"username": "admin", "password": "wrong"})
assert r.status_code == 401
def test_login_fail_wrong_user():
r = client.post("/api/auth/login", json={"username": "nobody", "password": "admin"})
assert r.status_code == 401
def test_login_returns_json():
"""无 DB 时可能 401有 DB 且 admin 存在时 200。仅断言响应为 JSON 且含 code。"""
r = client.post("/api/auth/login", json={"username": "admin", "password": "admin"})
assert r.headers.get("content-type", "").startswith("application/json")
data = r.json()
assert "code" in data
assert "trace_id" in data

View File

@@ -0,0 +1,16 @@
"""Health 接口测试。"""
import pytest
from fastapi.testclient import TestClient
from app.main import app
client = TestClient(app)
def test_health():
r = client.get("/api/health")
assert r.status_code == 200
data = r.json()
assert data.get("code") == 0
assert data.get("data", {}).get("status") == "up"
assert "trace_id" in data

View File

@@ -0,0 +1,26 @@
"""企业微信回调验签逻辑测试(不依赖真实 Token/Key"""
import pytest
from unittest.mock import patch
from app.services.wecom_crypto import (
verify_signature,
verify_and_decrypt_echostr,
_sha1,
)
def test_sha1():
h = _sha1("abc")
assert len(h) == 40
assert h == "a9993e364706816aba3e25717850c26c9cd0d89d"
def test_verify_signature():
# 用固定 token 时,签名为 sha1(sort(token, ts, nonce, encrypt))
with patch("app.services.wecom_crypto.settings") as s:
s.wecom_token = "mytoken"
lst = ["mytoken", "123", "456", "echostr"]
lst.sort()
expected = _sha1("".join(lst))
assert verify_signature(expected, "123", "456", "echostr") is True
assert verify_signature("wrong", "123", "456", "echostr") is False