# Source listing metadata (from the original paste): 565 lines, 19 KiB, Python
"""
|
||
API 轉接層 - 偽裝 OpenAI API,整合 Rocket.Chat 作為管理員回覆介面
|
||
"""
|
||
|
||
import asyncio
import hashlib
import hashlib
import json
import os
import re
import time
import uuid
from datetime import datetime
from typing import Optional

import asyncpg
import httpx
from fastapi import FastAPI, Request, Header
from fastapi.responses import JSONResponse
from pydantic import BaseModel


# FastAPI application serving the OpenAI-compatible endpoints defined below.
app = FastAPI()

# asyncpg connection pool; stays None until init_db() assigns it at startup.
db_pool: Optional["asyncpg.Pool"] = None


# Database settings. Every entry can be overridden through its environment
# variable (DB_HOST, DB_PORT, DB_NAME, DB_USER, DB_PASSWORD); the mapping is
# unpacked straight into asyncpg.create_pool() by init_db().
# NOTE(review): the fallback credentials are hard-coded; confirm production
# deployments always set DB_PASSWORD.
DB_CONFIG = dict(
    host=os.getenv("DB_HOST", "postgres"),
    port=int(os.getenv("DB_PORT", 5432)),
    database=os.getenv("DB_NAME", "tobiichiGPT"),
    user=os.getenv("DB_USER", "tobiichi3227"),
    password=os.getenv("DB_PASSWORD", "tobiichi_password"),
)


# Rocket.Chat connection settings, overridable through environment variables.
# NOTE(review): the fallback login is the literal admin/admin pair; confirm
# deployments always set ROCKETCHAT_USER / ROCKETCHAT_PASSWORD.
ROCKETCHAT_URL = os.getenv("ROCKETCHAT_URL", "http://rocketchat:3000")
ROCKETCHAT_USER, ROCKETCHAT_PASSWORD = (
    os.getenv("ROCKETCHAT_USER", "admin"),
    os.getenv("ROCKETCHAT_PASSWORD", "admin"),
)

# Global auth state: the {"X-Auth-Token": ..., "X-User-Id": ...} header dict
# stored by rocketchat_login(); None until a login has succeeded.
rocketchat_auth = None


class Message(BaseModel):
    """One entry of an OpenAI-style ``messages`` array."""

    # Sender role; chat_completions() only looks for entries whose role is "user".
    role: str
    # Message text. NOTE(review): declared as plain ``str``, so OpenAI-style
    # multimodal content (a list of parts) or a null content will not
    # validate — confirm clients only ever send text.
    content: str


class ChatRequest(BaseModel):
    """Request body of POST /v1/chat/completions (an OpenAI-compatible subset)."""

    # Model name requested by the client; echoed back in the response.
    model: str
    # Conversation history; only the most recent "user" entry is forwarded.
    messages: list[Message]
    # Accepted for compatibility. NOTE(review): chat_completions() never reads
    # this field and always answers with a non-streaming JSON body.
    stream: Optional[bool] = False


async def rocketchat_login():
    """Log in to Rocket.Chat and cache the REST authentication headers.

    On success the module-level ``rocketchat_auth`` is replaced by a dict with
    the ``X-Auth-Token`` and ``X-User-Id`` headers, and that dict is returned.

    Returns:
        The header dict, or ``None`` on a non-200 answer, a network error or
        an unexpected payload (all logged, never raised). A failed attempt
        leaves any previously cached headers untouched.
    """
    global rocketchat_auth

    try:
        credentials = {"username": ROCKETCHAT_USER, "password": ROCKETCHAT_PASSWORD}
        async with httpx.AsyncClient(timeout=10.0) as client:
            response = await client.post(f"{ROCKETCHAT_URL}/api/v1/login", json=credentials)

            if response.status_code != 200:
                print(f"⚠️ Rocket.Chat 登入失敗: {response.status_code}")
                return None

            session = response.json()["data"]
            rocketchat_auth = {
                "X-Auth-Token": session["authToken"],
                "X-User-Id": session["userId"],
            }
            print("✅ Rocket.Chat 登入成功")
            return rocketchat_auth
    except Exception as exc:
        print(f"❌ Rocket.Chat 登入錯誤: {exc}")
        return None


async def get_or_create_chat_channel(chat_id: str, user_name: str = None):
    """Return the room ID of the Rocket.Chat channel dedicated to one conversation.

    Every ``chat_id`` gets its own public channel (threads are not used). The
    channel is named ``<sanitized user name>-<first 8 chars of chat_id>``. If
    ``channels.create`` answers HTTP 400 (e.g. the name is already taken),
    the existing channel with that name is looked up via ``channels.info``.

    Args:
        chat_id: Conversation identifier; its first 8 characters keep the
            channel name unique per conversation.
        user_name: Display name of the end user; ``None``/empty falls back to
            ``"user"``.

    Returns:
        The room ``_id`` string, or ``None`` when no Rocket.Chat login is
        available or any Rocket.Chat call fails (logged, never raised).
    """
    if not rocketchat_auth:
        await rocketchat_login()

    if not rocketchat_auth:
        return None

    # Channel name: <user name>-<first 8 chars of chat_id>
    display_name = user_name if user_name else "user"
    chat_suffix = chat_id[:8]
    clean_name = display_name.replace(" ", "-").replace("@", "").lower()
    # Rocket.Chat's default channel-name validation only accepts
    # [0-9a-zA-Z-_.]. Any other character (e.g. CJK names) made both
    # channels.create and the channels.info fallback fail, so every request
    # from such a user ended in a 500. Drop those characters.
    clean_name = re.sub(r"[^0-9a-z_.-]", "", clean_name) or "user"
    # Names are capped at 50 characters. Truncate only the user-name part so
    # the chat suffix always survives; truncating the joined string let long
    # user names cut the suffix off and merge different chats into one channel.
    clean_name = clean_name[: 50 - len(chat_suffix) - 1]
    channel_name = f"{clean_name}-{chat_suffix}"

    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            # Try to create the channel first.
            resp = await client.post(
                f"{ROCKETCHAT_URL}/api/v1/channels.create",
                headers=rocketchat_auth,
                json={"name": channel_name, "members": []},
            )

            if resp.status_code == 200:
                room_id = resp.json()["channel"]["_id"]
                print(f"✅ 創建對話頻道: {channel_name}")

                # Best-effort channel description; its response is not checked.
                await client.post(
                    f"{ROCKETCHAT_URL}/api/v1/channels.setDescription",
                    headers=rocketchat_auth,
                    json={
                        "roomId": room_id,
                        "description": f"用戶: {display_name} | 對話: {chat_id[:8]}",
                    },
                )
                return room_id

            # HTTP 400 is treated as "channel already exists": look it up by name.
            if resp.status_code == 400:
                resp = await client.get(
                    f"{ROCKETCHAT_URL}/api/v1/channels.info",
                    headers=rocketchat_auth,
                    params={"roomName": channel_name},
                )
                if resp.status_code == 200:
                    room_id = resp.json()["channel"]["_id"]
                    print(f"✅ 使用現有頻道: {channel_name}")
                    return room_id

                print(f"⚠️ 取得頻道資訊失敗: {resp.status_code}")
                return None

            print(f"⚠️ 創建頻道失敗: {resp.status_code}")
            return None

    except Exception as e:
        print(f"❌ Rocket.Chat 頻道操作錯誤: {e}")
        return None


async def send_user_message(room_id: str, user_message: str, user_name: str = None):
    """Post the end user's message into a conversation channel (no threads).

    The text is posted as ``**💬 <name>:**`` followed by the message on a new
    line; wait_for_admin_reply() checks for this prefix to skip such posts.

    Args:
        room_id: Target Rocket.Chat room ID.
        user_message: Text typed by the end user.
        user_name: Sender label; ``None``/empty falls back to ``"用戶"``.

    Returns:
        ``{"message_id": ..., "ts": ...}`` of the posted message, or ``None``
        when no login is available or the post fails (logged, never raised).
    """
    if not rocketchat_auth:
        await rocketchat_login()

    if not rocketchat_auth:
        return None

    sender = user_name or "用戶"
    payload = {"roomId": room_id, "text": f"**💬 {sender}:**\n{user_message}"}

    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.post(
                f"{ROCKETCHAT_URL}/api/v1/chat.postMessage",
                headers=rocketchat_auth,
                json=payload,
            )

            if resp.status_code != 200:
                print(f"⚠️ 發送訊息失敗: {resp.status_code}")
                return None

            posted = resp.json()["message"]
            message_id, message_ts = posted["_id"], posted["ts"]
            print(f"✅ 發送用戶訊息: {message_id}")
            return {"message_id": message_id, "ts": message_ts}
    except Exception as err:
        print(f"❌ 發送訊息錯誤: {err}")
        return None


def _find_admin_reply(messages, after_ts, exclude_msg_id, bot_user_id):
    """Return the first entry of a channels.messages list that counts as an admin reply, or None."""
    for msg in messages:
        reply_text = msg.get("msg") or ""

        # 1. Basic filtering
        if msg.get("t"):  # system message (join, topic change, ...)
            continue
        if msg.get("_id") == exclude_msg_id:  # the forwarded user message itself
            continue

        # 2. Timestamp check. Both sides are ISO strings compared lexicographically
        # (assumes Rocket.Chat returns them in one format). Equal timestamps are
        # kept on purpose because of possible precision ties.
        if (msg.get("ts") or "") < after_ts:
            continue

        # 3. Posts made with the bot account are either user messages forwarded
        # by send_user_message() ("**💬 name:**\n..." prefix) or an admin
        # answering from the same account; only the forwarded ones are skipped.
        sender_id = (msg.get("u") or {}).get("_id", "")
        if (
            sender_id == bot_user_id
            and reply_text.strip().startswith("**💬")
            and ":**" in reply_text
        ):
            continue

        # 4. Any remaining non-empty text is a reply
        if reply_text:
            return msg
    return None


async def wait_for_admin_reply(room_id: str, after_ts: str, exclude_msg_id: str = None, timeout: int = 600):
    """Poll a conversation channel until an administrator reply appears.

    Every 2 seconds the latest 20 channel messages are fetched. The first
    qualifying one, in API order, is returned. A message qualifies when:

    - it is not a system message;
    - it is not ``exclude_msg_id``;
    - its ``ts`` is not earlier than ``after_ts``;
    - it has non-empty text;
    - if posted by the bot account, it lacks the forwarded-message prefix.

    A failed poll (network error, per-request timeout, malformed JSON) is
    logged and retried instead of aborting the wait. An HTTP 401 triggers a
    fresh Rocket.Chat login so an expired token does not stall the wait.

    Args:
        room_id: Rocket.Chat room ID of the conversation channel.
        after_ts: ``ts`` of the forwarded user message (ISO string).
        exclude_msg_id: ID of the forwarded user message; never a reply.
        timeout: Maximum number of seconds to wait.

    Returns:
        The reply text. A fixed apology string is returned on timeout or
        unexpected error. ``None`` is returned if no Rocket.Chat login could
        be obtained.
    """
    if not rocketchat_auth:
        await rocketchat_login()

    if not rocketchat_auth:
        return None

    check_interval = 2  # poll every 2 seconds
    # Monotonic clock: wall-clock adjustments (NTP sync, manual changes) must
    # not stretch or cut short the wait.
    start_time = time.monotonic()

    # Bot account user ID (used to tell forwarded posts apart)
    bot_user_id = rocketchat_auth.get("X-User-Id")

    print(f"🔄 等待回覆 (room: {room_id[:8]}..., after: {after_ts}, exclude: {exclude_msg_id})")

    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            while time.monotonic() - start_time < timeout:
                try:
                    # The global headers are re-read every round, so a re-login
                    # below takes effect on the next poll.
                    resp = await client.get(
                        f"{ROCKETCHAT_URL}/api/v1/channels.messages",
                        headers=rocketchat_auth,
                        params={"roomId": room_id, "count": 20},
                    )
                    data = resp.json() if resp.status_code == 200 else {}
                except (httpx.HTTPError, ValueError) as e:
                    # One failed poll must not abort a wait that can last minutes.
                    print(f"⚠️ 輪詢失敗,稍後重試: {e}")
                    await asyncio.sleep(check_interval)
                    continue

                if resp.status_code == 200:
                    reply = _find_admin_reply(
                        data.get("messages", []), after_ts, exclude_msg_id, bot_user_id
                    )
                    if reply is not None:
                        reply_text = reply["msg"]
                        sender_name = (reply.get("u") or {}).get("username", "?")
                        msg_id = reply.get("_id")
                        print(f"✅ 收到管理員回覆 (from {sender_name}, id: {msg_id}): {reply_text[:50]}...")
                        return reply_text

                    # Nothing yet
                    elapsed = int(time.monotonic() - start_time)
                    if elapsed % 10 == 0:
                        print(f"🔄 等待中... ({elapsed}s)")
                else:
                    print(f"⚠️ API 錯誤: {resp.status_code}")
                    if resp.status_code == 401:
                        # The auth token can expire during a long wait; log in again.
                        await rocketchat_login()

                await asyncio.sleep(check_interval)

        print("⚠️ 等待回覆超時")
        return "抱歉,目前客服繁忙,請稍後再試。"

    except Exception as e:
        print(f"❌ 檢查回覆錯誤: {e}")
        return "系統錯誤,請稍後再試。"


async def get_user_name(user_id: str):
    """Look up a display name for ``user_id`` in the ``"user"`` table.

    Kept as a utility function. The ``name`` column is preferred, with
    ``email`` as fallback.

    Returns:
        The name, or ``None`` when ``user_id`` is empty, no row matches,
        neither column is set, or the query fails (the error is logged).
    """
    if not user_id:
        return None

    name = None
    try:
        async with db_pool.acquire() as conn:
            row = await conn.fetchrow(
                'SELECT name, email FROM "user" WHERE id = $1',
                user_id
            )
        if row:
            name = row['name'] or row['email']
    except Exception as e:
        print(f"⚠️ 獲取用戶名稱失敗: {e}")

    return name or None


async def init_db():
    """Create the asyncpg pool and make sure the ``reply_queue`` table exists.

    Assigns the module-level ``db_pool`` (2 to 10 connections, configured
    from DB_CONFIG). It then runs ``CREATE ... IF NOT EXISTS`` statements,
    so repeated startups are harmless.
    """
    global db_pool

    # reply_queue records every forwarded message and its reply.
    # NOTE(review): chat_completions() stores the forwarded Rocket.Chat message
    # ID in rocketchat_thread_id (threads are no longer used).
    table_ddl = """
        CREATE TABLE IF NOT EXISTS reply_queue (
            id SERIAL PRIMARY KEY,
            conversation_id VARCHAR(50) UNIQUE NOT NULL,
            user_id VARCHAR(255),
            chat_id VARCHAR(255),
            user_name VARCHAR(255),
            user_message TEXT NOT NULL,
            admin_reply TEXT,
            status VARCHAR(20) DEFAULT 'pending',
            created_at TIMESTAMP DEFAULT NOW(),
            replied_at TIMESTAMP,
            rocketchat_room_id VARCHAR(100),
            rocketchat_thread_id VARCHAR(100)
        )
    """
    # NOTE(review): PostgreSQL index names are schema-wide, so generic names
    # such as idx_user_id may clash with other tables' indexes in a shared
    # database. idx_conversation_id duplicates the UNIQUE constraint's index.
    index_ddl = """
        CREATE INDEX IF NOT EXISTS idx_status ON reply_queue(status);
        CREATE INDEX IF NOT EXISTS idx_conversation_id ON reply_queue(conversation_id);
        CREATE INDEX IF NOT EXISTS idx_chat_id ON reply_queue(chat_id);
        CREATE INDEX IF NOT EXISTS idx_user_id ON reply_queue(user_id);
    """

    db_pool = await asyncpg.create_pool(**DB_CONFIG, min_size=2, max_size=10)

    async with db_pool.acquire() as conn:
        for ddl in (table_ddl, index_ddl):
            await conn.execute(ddl)


@app.on_event("startup")
async def startup():
    """Application startup hook.

    Creates the database pool and tables via init_db() (any exception there
    propagates out of this hook). It then attempts a Rocket.Chat login.
    """
    await init_db()
    print("✅ 資料庫連接成功")

    # Try to log in to Rocket.Chat. rocketchat_login() logs failures and
    # returns None instead of raising; the request helpers retry the login
    # while rocketchat_auth is still unset.
    # NOTE(review): newer FastAPI versions deprecate on_event in favour of
    # lifespan handlers.
    await rocketchat_login()


@app.on_event("shutdown")
async def shutdown():
    """Application shutdown hook: close the asyncpg pool if init_db() created one."""
    if not db_pool:
        return
    await db_pool.close()
    print("👋 資料庫連接已關閉")


@app.get("/")
async def root():
    """Root endpoint returning a static service descriptor."""
    return dict(status="ok", service="TobiichiGPT API", chat_backend="Rocket.Chat")


@app.get("/v1/models")
async def list_models():
    """Mimic OpenAI's GET /v1/models, advertising the single "tobiichiGPT" model."""
    model_card = {
        "id": "tobiichiGPT",
        "object": "model",
        # NOTE(review): recomputed on every call rather than a fixed creation time.
        "created": int(time.time()),
        "owned_by": "tobiichi",
        "permission": [],
        "root": "tobiichiGPT",
        "parent": None,
    }
    return {"object": "list", "data": [model_card]}


def _completion_response(completion_id: str, model: str, content: str, usage: dict) -> dict:
    """Build an OpenAI-style ``chat.completion`` body holding one assistant message."""
    return {
        "id": completion_id,
        "object": "chat.completion",
        "created": int(time.time()),
        "model": model,
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": content},
                "finish_reason": "stop",
            }
        ],
        "usage": usage,
    }


@app.post("/v1/chat/completions")
async def chat_completions(
    request_data: ChatRequest,
    http_request: Request,
    authorization: Optional[str] = Header(None)
):
    """Emulate OpenAI's Chat Completions API with a human answering via Rocket.Chat.

    The latest user message is posted to a per-conversation Rocket.Chat
    channel and logged in ``reply_queue``. The request then blocks until
    wait_for_admin_reply() returns, and that text becomes the assistant
    message. Open WebUI background prompts (starting with ``### Task:``) are
    answered immediately with ``"{}"`` and never reach Rocket.Chat.

    Returns:
        An OpenAI ``chat.completion`` dict, or a JSONResponse with status 400
        (no user message) or 500 (Rocket.Chat channel creation or post failed).
    """
    # Latest user turn in the history
    user_message = next(
        (m.content for m in reversed(request_data.messages) if m.role == "user"),
        None,
    )
    if not user_message:
        return JSONResponse(status_code=400, content={"error": "No user message found"})

    # User information forwarded by Open WebUI in the request headers
    headers_dict = dict(http_request.headers)

    # Debug: dump the identity-related headers
    print("🔍 收到的 Headers:")
    for key, value in headers_dict.items():
        if any(word in key.lower() for word in ("user", "chat", "name")):
            print(f" {key}: {value}")

    # Open WebUI headers first, then generic fallbacks
    user_id = (
        headers_dict.get("x-openwebui-user-id")
        or headers_dict.get("x-user-id")
        or headers_dict.get("user-id")
    )
    chat_id = (
        headers_dict.get("x-openwebui-chat-id")
        or headers_dict.get("x-chat-id")
        or headers_dict.get("chat-id")
    )
    user_name = headers_dict.get("x-openwebui-user-name")

    message_id = str(uuid.uuid4())

    # Last resort: derive a user ID from the bearer token (or this message),
    # and treat the message as its own chat.
    if not user_id:
        seed = authorization if authorization else message_id
        user_id = hashlib.md5(seed.encode()).hexdigest()[:16]
    chat_id = chat_id or message_id

    # Skip Open WebUI system tasks (title / tag / follow-up generation, ...)
    if user_message.strip().startswith("### Task:"):
        print(f"⏭️ 跳過系統任務訊息: {user_message[:50]}...")
        # An empty JSON object as content lets Open WebUI carry on
        return _completion_response(
            f"chatcmpl-{uuid.uuid4()}",
            request_data.model,
            "{}",
            {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
        )

    print("📝 收到訊息")
    print(f" 用戶: {user_name} ({user_id[:8]}...)")
    print(f" 對話: {chat_id[:8]}")
    print(f" 內容: {user_message[:50]}...")

    # 1. One dedicated channel per chat_id
    room_id = await get_or_create_chat_channel(chat_id, user_name)
    if not room_id:
        return JSONResponse(
            status_code=500,
            content={"error": "Failed to create Rocket.Chat channel"}
        )

    # 2. Forward the user message into the channel
    posted = await send_user_message(room_id, user_message, user_name)
    if not posted:
        return JSONResponse(
            status_code=500,
            content={"error": "Failed to send message to Rocket.Chat"}
        )
    forwarded_id, forwarded_ts = posted["message_id"], posted["ts"]

    # 3. Log the pending request
    async with db_pool.acquire() as conn:
        await conn.execute(
            """
            INSERT INTO reply_queue (
                conversation_id, user_id, chat_id, user_name, user_message,
                status, rocketchat_room_id, rocketchat_thread_id
            )
            VALUES ($1, $2, $3, $4, $5, 'pending', $6, $7)
            """,
            message_id, user_id, chat_id, user_name, user_message, room_id, forwarded_id
        )

    # 4. Block until the admin answers (matched by timestamp)
    admin_reply = await wait_for_admin_reply(room_id, forwarded_ts, exclude_msg_id=forwarded_id)

    # 5. Store the reply.
    # NOTE(review): wait_for_admin_reply()'s timeout/error fallback texts are
    # stored with status 'replied' as well.
    async with db_pool.acquire() as conn:
        await conn.execute(
            """
            UPDATE reply_queue
            SET admin_reply = $1, status = 'replied', replied_at = NOW()
            WHERE conversation_id = $2
            """,
            admin_reply, message_id
        )

    print(f"✅ 完成回覆 [chat:{chat_id[:8]}]")

    # 6. OpenAI-shaped response; "token" counts are character counts
    prompt_len, reply_len = len(user_message), len(admin_reply)
    return _completion_response(
        f"chatcmpl-{message_id}",
        request_data.model,
        admin_reply,
        {
            "prompt_tokens": prompt_len,
            "completion_tokens": reply_len,
            "total_tokens": prompt_len + reply_len,
        },
    )


@app.get("/health")
async def health_check():
    """Health probe reporting in-process state only.

    "database" reflects whether ``db_pool`` is set. "rocketchat" reflects
    whether auth headers are cached. No query or API call is made, and
    "status" is always "healthy".
    """
    return {
        "status": "healthy",
        "database": "ok" if db_pool else "disconnected",
        "rocketchat": "ok" if rocketchat_auth else "not_authenticated",
    }


if __name__ == "__main__":
    import uvicorn

    # Bind address and port follow the file's env-driven configuration style;
    # the defaults are the previously hard-coded 0.0.0.0:8000.
    uvicorn.run(
        app,
        host=os.getenv("API_HOST", "0.0.0.0"),
        port=int(os.getenv("API_PORT", "8000")),
    )