Files
tobiichiGPT/api/server.py
2025-12-21 16:49:44 +08:00

205 lines
5.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
API 轉接層 - 偽裝 OpenAI API將請求轉為人工回覆隊列
"""
import asyncio
import json
import os
import time
import uuid
from datetime import datetime
from typing import Optional

import asyncpg
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel
# FastAPI application object: every route and lifecycle hook in this module is
# registered on it, and the __main__ guard serves it with uvicorn.
app = FastAPI()
# Shared asyncpg connection pool: init_db() creates it, shutdown() closes it.
db_pool = None

# PostgreSQL connection settings; each value can be overridden through an
# environment variable, otherwise the literal default below is used.
# NOTE(review): a fallback password is committed to source — consider
# requiring DB_PASSWORD to be set explicitly in deployed environments.
DB_CONFIG = dict(
    host=os.getenv("DB_HOST", "postgres"),
    port=int(os.getenv("DB_PORT", 5432)),
    database=os.getenv("DB_NAME", "tobiichiGPT"),
    user=os.getenv("DB_USER", "tobiichi3227"),
    password=os.getenv("DB_PASSWORD", "tobiichi_password"),
)
class Message(BaseModel):
    """One entry of the OpenAI-style ``messages`` array in a chat request."""

    role: str  # sender role; chat_completions only looks for entries whose role is "user"
    # Message text. NOTE(review): OpenAI clients may send content as a list of
    # content parts; a plain `str` field will not accept that — confirm if needed.
    content: str
class ChatRequest(BaseModel):
    """Request body of ``POST /v1/chat/completions`` (a subset of OpenAI's schema)."""

    model: str  # echoed back in the response's "model" field
    messages: list[Message]  # conversation history; only the last "user" entry is queued
    stream: Optional[bool] = False  # the OpenAI "stream" request flag
async def init_db():
    """Open the shared connection pool and make sure the queue schema exists.

    Creates an asyncpg pool (min 2 / max 10 connections) from ``DB_CONFIG``,
    creates the ``reply_queue`` table and its indexes if they are missing,
    and only then publishes the pool through the module-level ``db_pool``.

    Raises:
        Any exception from ``asyncpg.create_pool`` or from the DDL statements.
        If the DDL fails, the freshly opened pool is closed before the
        exception propagates, and ``db_pool`` is left unchanged.
    """
    global db_pool
    pool = await asyncpg.create_pool(**DB_CONFIG, min_size=2, max_size=10)
    try:
        async with pool.acquire() as conn:
            # Queue of user messages waiting for a human reply; new rows get
            # status 'pending' by column default.
            await conn.execute("""
CREATE TABLE IF NOT EXISTS reply_queue (
id SERIAL PRIMARY KEY,
conversation_id VARCHAR(50) UNIQUE NOT NULL,
user_message TEXT NOT NULL,
admin_reply TEXT,
status VARCHAR(20) DEFAULT 'pending',
created_at TIMESTAMP DEFAULT NOW(),
replied_at TIMESTAMP
)
""")
            # Indexes for lookups by status and by conversation_id.
            # NOTE(review): the UNIQUE constraint on conversation_id already
            # creates an index, so idx_conversation_id looks redundant.
            await conn.execute("""
CREATE INDEX IF NOT EXISTS idx_status ON reply_queue(status);
CREATE INDEX IF NOT EXISTS idx_conversation_id ON reply_queue(conversation_id);
""")
    except BaseException:
        # Startup is aborting: close the pool here, since it was never
        # published to db_pool and shutdown() may not run to close it.
        await pool.close()
        raise
    db_pool = pool
@app.on_event("startup")
async def startup():
    """Open the database pool and create the queue schema when the app starts.

    NOTE(review): recent FastAPI versions deprecate ``on_event`` in favour of a
    ``lifespan`` handler passed to ``FastAPI(...)``.
    """
    await init_db()
    print("✅ 資料庫連接成功")
@app.on_event("shutdown")
async def shutdown():
    """Close the shared connection pool, if startup managed to create one."""
    if not db_pool:
        return  # nothing was opened, nothing to close
    await db_pool.close()
    print("👋 資料庫連接已關閉")
@app.get("/")
async def root():
    """Health check: report that the service is up."""
    return dict(status="ok", service="TobiichiGPT API")
@app.get("/v1/models")
async def list_models():
    """Mimic OpenAI's ``GET /v1/models`` by listing the single ``human-admin`` model.

    ``created`` is the current Unix time at the moment of the call.
    """
    model_card = {
        "id": "human-admin",
        "object": "model",
        "created": int(time.time()),
        "owned_by": "tobiichi",
        "permission": [],
        "root": "human-admin",
        "parent": None,
    }
    return {"object": "list", "data": [model_card]}
# How long a request waits for a human reply, and how often the queue is polled.
_MAX_WAIT_SECONDS = 900  # 15 minutes
_CHECK_INTERVAL_SECONDS = 3
# Canned assistant reply returned when no admin answers before the deadline.
_TIMEOUT_MESSAGE = "抱歉,管理員目前忙碌中,請稍後再試。"


def _last_user_message(messages):
    """Return the content of the most recent ``user`` message, or ``None`` if there is none."""
    return next((msg.content for msg in reversed(messages) if msg.role == "user"), None)


def _completion_body(conversation_id, model, content, usage=None):
    """Build a non-streaming OpenAI ``chat.completion`` response dict.

    ``usage`` is added as the last key only when it is given.
    """
    body = {
        "id": f"chatcmpl-{conversation_id}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": model,
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": content},
                "finish_reason": "stop",
            }
        ],
    }
    if usage is not None:
        body["usage"] = usage
    return body


def _sse_chunk(conversation_id, model, delta, finish_reason=None):
    """Encode one ``chat.completion.chunk`` object as a server-sent-events ``data:`` frame."""
    chunk = {
        "id": f"chatcmpl-{conversation_id}",
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "model": model,
        "choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}],
    }
    return f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"


async def _fetch_reply(conversation_id):
    """Return the admin reply for ``conversation_id`` once its row is 'replied', else ``None``."""
    async with db_pool.acquire() as conn:
        row = await conn.fetchrow(
            "SELECT admin_reply, status FROM reply_queue WHERE conversation_id = $1",
            conversation_id,
        )
    # An empty admin_reply is treated as "not answered yet".
    if row and row['status'] == 'replied' and row['admin_reply']:
        return row['admin_reply']
    return None


async def _poll_reply(conversation_id):
    """Poll the queue every check interval until answered or the deadline passes.

    Yields ``None`` after each unanswered poll and the reply text once it
    arrives, then stops. The deadline is measured with a monotonic clock, so
    time spent in the queries themselves counts toward the limit.
    """
    deadline = time.monotonic() + _MAX_WAIT_SECONDS
    while time.monotonic() < deadline:
        await asyncio.sleep(_CHECK_INTERVAL_SECONDS)
        reply = await _fetch_reply(conversation_id)
        yield reply
        if reply is not None:
            return


async def _wait_for_reply(conversation_id):
    """Wait for the admin's answer; return the reply text, or ``None`` on timeout."""
    async for reply in _poll_reply(conversation_id):
        if reply is not None:
            return reply
    return None


async def _stream_completion(conversation_id, model):
    """Yield an OpenAI-style SSE stream answering ``conversation_id``.

    While waiting, an SSE comment line (``: keep-alive``) is sent after every
    poll so the otherwise idle connection keeps carrying data; SSE parsers
    skip comment lines. The reply (or the timeout message) is then sent as a
    single content chunk, followed by a ``finish_reason="stop"`` chunk and
    the ``data: [DONE]`` terminator.
    """
    reply = None
    async for reply in _poll_reply(conversation_id):
        if reply is None:
            yield ": keep-alive\n\n"
    if reply is None:
        print(f"⏰ 等待超時 [{conversation_id}]")
        reply = _TIMEOUT_MESSAGE
    else:
        print(f"✅ 管理員已回覆 [{conversation_id}]")
    yield _sse_chunk(conversation_id, model, {"role": "assistant", "content": reply})
    yield _sse_chunk(conversation_id, model, {}, "stop")
    yield "data: [DONE]\n\n"


@app.post("/v1/chat/completions")
async def chat_completions(request: ChatRequest):
    """Mimic OpenAI's Chat Completions API, answered by a human admin.

    The content of the last ``user`` message is inserted into ``reply_queue``
    with status ``'pending'``. The request then waits (up to 15 minutes,
    polling every 3 seconds) for the row to be marked ``'replied'``.

    * ``stream`` false/absent: returns one ``chat.completion`` object.
      ``usage`` counts characters, not real tokens, and is omitted on timeout.
    * ``stream=True``: returns a ``text/event-stream`` of
      ``chat.completion.chunk`` events ending with ``data: [DONE]``.

    On timeout a canned "admin is busy" assistant message is returned.
    If there is no user message, or its content is empty, the response is
    HTTP 400 with ``{"error": "No user message found"}``.

    NOTE(review): timed-out rows stay 'pending', so an admin may still answer
    a request nobody is waiting for; consider marking them expired.
    """
    user_message = _last_user_message(request.messages)
    if not user_message:
        return JSONResponse(
            status_code=400,
            content={"error": "No user message found"}
        )

    # Enqueue the message for a human reply under a fresh conversation id.
    conversation_id = str(uuid.uuid4())
    async with db_pool.acquire() as conn:
        await conn.execute(
            """
INSERT INTO reply_queue (conversation_id, user_message, status)
VALUES ($1, $2, 'pending')
""",
            conversation_id, user_message
        )
    print(f"📝 收到訊息 [{conversation_id}]: {user_message[:50]}...")

    # Streaming clients expect SSE frames; a plain JSON body is not parsed
    # as events by them, so honour the flag instead of ignoring it.
    if request.stream:
        return StreamingResponse(
            _stream_completion(conversation_id, request.model),
            media_type="text/event-stream",
        )

    admin_reply = await _wait_for_reply(conversation_id)
    if admin_reply is None:
        print(f"⏰ 等待超時 [{conversation_id}]")
        return _completion_body(conversation_id, request.model, _TIMEOUT_MESSAGE)

    print(f"✅ 管理員已回覆 [{conversation_id}]")
    # Character counts stand in for token counts.
    usage = {
        "prompt_tokens": len(user_message),
        "completion_tokens": len(admin_reply),
        "total_tokens": len(user_message) + len(admin_reply),
    }
    return _completion_body(conversation_id, request.model, admin_reply, usage)
if __name__ == "__main__":
    # Direct-run entry point: serve the app with uvicorn on all interfaces, port 8000.
    from uvicorn import run

    run(app, host="0.0.0.0", port=8000)