diff --git a/app.json b/app.json index 63fd639..3aa0f2f 100644 --- a/app.json +++ b/app.json @@ -85,6 +85,7 @@ }, "extra": { "API_ENDPOINT": "http://192.168.31.107:8081/api", + "WEBSOCKET_ENDPOINT": "ws://192.168.31.107:8081/ws/chat", "router": {}, "eas": { "projectId": "04721dd4-6b15-495a-b9ec-98187c613172" diff --git a/app/(tabs)/_layout.tsx b/app/(tabs)/_layout.tsx index ff49b19..514e602 100644 --- a/app/(tabs)/_layout.tsx +++ b/app/(tabs)/_layout.tsx @@ -6,6 +6,7 @@ import { Colors } from '@/constants/Colors'; import { useColorScheme } from '@/hooks/useColorScheme'; import { prefetchChats } from '@/lib/prefetch'; import { fetchApi } from '@/lib/server-api-util'; +import { webSocketManager, WebSocketStatus } from '@/lib/websocket-util'; import * as Notifications from 'expo-notifications'; import { Tabs } from 'expo-router'; import * as SecureStore from 'expo-secure-store'; @@ -30,6 +31,7 @@ export default function TabLayout() { const tokenInterval = useRef(null); const isMounted = useRef(true); const [token, setToken] = useState(''); + const [wsStatus, setWsStatus] = useState('disconnected'); const sendNotification = async (item: PollingData) => { // 请求通知权限 const granted = await requestNotificationPermission(); @@ -67,6 +69,16 @@ export default function TabLayout() { }; }, []); + useEffect(() => { + const handleStatusChange = (status: WebSocketStatus) => { + setWsStatus(status); + }; + webSocketManager.subscribeStatus(handleStatusChange); + return () => { + webSocketManager.unsubscribeStatus(handleStatusChange); + }; + }, []); + // 轮询获取推送消息 const startPolling = useCallback(async (interval: number = 5000) => { @@ -365,7 +377,7 @@ export default function TabLayout() { }} /> - + ); } diff --git a/app/(tabs)/ask.tsx b/app/(tabs)/ask.tsx index 4de2030..1d8ed53 100644 --- a/app/(tabs)/ask.tsx +++ b/app/(tabs)/ask.tsx @@ -4,7 +4,8 @@ import AskHello from "@/components/ask/hello"; import SendMessage from "@/components/ask/send"; import { ThemedText } from "@/components/ThemedText"; import { fetchApi } from "@/lib/server-api-util"; -import { Message } from "@/types/ask"; +import { webSocketManager, WsMessage } from "@/lib/websocket-util"; +import { Assistant, Message } from "@/types/ask"; import { router, useFocusEffect, useLocalSearchParams } from "expo-router"; import { useCallback, useEffect, useRef, useState } from 'react'; import { @@ -78,6 +79,86 @@ export default function AskScreen() { }; }, [isHello]); + useFocusEffect( + useCallback(() => { + webSocketManager.connect(); + + const handleChatStream = (message: WsMessage) => { + if (message.type === 'ChatStream') { + setUserMessages(prevMessages => { + const newMessages = [...prevMessages]; + const lastMessage = newMessages[newMessages.length - 1]; + + if (lastMessage && lastMessage.role === Assistant) { + if (typeof lastMessage.content === 'string') { + if (lastMessage.content === 'keepSearchIng') { + // 第一次收到流式消息,替换占位符 + lastMessage.content = message.chunk; + } else { + // 持续追加流式消息 + lastMessage.content += message.chunk; + } + } else { + // 如果 content 是数组,则更新第一个 text 部分 + const textPart = lastMessage.content.find(p => p.type === 'text'); + if (textPart) { + textPart.text = (textPart.text || '') + message.chunk; + } + } + } + return newMessages; + }); + } + }; + + const handleChatStreamEnd = (message: WsMessage) => { + if (message.type === 'ChatStreamEnd') { + setUserMessages(prevMessages => { + const newMessages = [...prevMessages]; + const lastMessage = newMessages[newMessages.length - 1]; + if (lastMessage && lastMessage.role === Assistant) { + // 使用最终消息替换流式消息,确保 message.message 存在 + if (message.message) { + newMessages[newMessages.length - 1] = message.message as Message; + } else { + // 如果最终消息为空,则移除 'keepSearchIng' 占位符 + return prevMessages.filter(m => !(typeof m.content === 'string' && m.content === 'keepSearchIng')); + } + } + return newMessages; + }); + } + }; + + const handleError = (message: WsMessage) => { + if (message.type === 'Error') { + console.error(`WebSocket Error: ${message.code} - ${message.message}`); + // 可以在这里添加错误提示,例如替换最后一条消息为错误信息 + setUserMessages(prev => { + const newMessages = [...prev]; + const lastMessage = newMessages[newMessages.length - 1]; + if (lastMessage && typeof lastMessage.content === 'string' && lastMessage.content === 'keepSearchIng') { + lastMessage.content = `Error: ${message.message}`; + } + return newMessages; + }) + } + }; + + webSocketManager.subscribe('ChatStream', handleChatStream); + webSocketManager.subscribe('ChatStreamEnd', handleChatStreamEnd); + webSocketManager.subscribe('Error', handleError); + + return () => { + webSocketManager.unsubscribe('ChatStream', handleChatStream); + webSocketManager.unsubscribe('ChatStreamEnd', handleChatStreamEnd); + webSocketManager.unsubscribe('Error', handleError); + // 可以在这里选择断开连接,或者保持连接以加快下次进入页面的速度 + // webSocketManager.disconnect(); + }; + }, []) + ); + useEffect(() => { if (sessionId) { setConversationId(sessionId); diff --git a/components/ask/send.tsx b/components/ask/send.tsx index 756b258..352dc46 100644 --- a/components/ask/send.tsx +++ b/components/ask/send.tsx @@ -15,7 +15,8 @@ import { import { Message } from '@/types/ask'; import { useTranslation } from 'react-i18next'; import { ThemedText } from '../ThemedText'; -import { createNewConversation, getConversation } from './utils'; +import { createNewConversation } from './utils'; +import { webSocketManager } from '@/lib/websocket-util'; interface Props { setIsHello: Dispatch>, @@ -88,37 +89,28 @@ export default function SendMessage(props: Props) { timestamp: new Date().toISOString() } ])); - // 如果没有对话ID,创建新对话并获取消息,否则直接获取消息 - if (!conversationId) { - const data = await createNewConversation(text); - setConversationId(data); - const response = await getConversation({ session_id: data, user_text: text, material_ids: [] }); - setSelectedImages([]); - setUserMessages((prev: Message[]) => { - const newMessages = [...(prev || [])]; - if (response) { - newMessages.push(response); - } - return newMessages.filter((item: Message) => - item?.content?.text !== 'keepSearchIng' - ); + let currentSessionId = conversationId; + // 如果没有对话ID,先创建一个新对话 + if (!currentSessionId) { + currentSessionId = await createNewConversation(text); + setConversationId(currentSessionId); + } + + // 通过 WebSocket 发送消息 + if (currentSessionId) { + webSocketManager.send({ + type: 'Chat', + session_id: currentSessionId, + message: text, + image_material_ids: selectedImages.length > 0 ? selectedImages : undefined, }); + setSelectedImages([]); } else { - const response = await getConversation({ - session_id: conversationId, - user_text: text, - material_ids: selectedImages - }); - setSelectedImages([]); - setUserMessages((prev: Message[]) => { - const newMessages = [...(prev || [])]; - if (response) { - newMessages.push(response); - } - return newMessages.filter((item: Message) => - item?.content?.text !== 'keepSearchIng' - ); - }); + console.error("无法获取 session_id,消息发送失败。"); + // 可以在这里处理错误,例如显示一个提示 + setUserMessages(prev => prev.filter(item => + !(typeof item.content === 'string' && item.content === 'keepSearchIng') + )); } // 将输入框清空 setInputValue(''); @@ -127,7 +119,7 @@ export default function SendMessage(props: Props) { Keyboard.dismiss(); } } - }, [inputValue, conversationId, selectedImages, createNewConversation, getConversation]); + }, [inputValue, conversationId, selectedImages, createNewConversation]); const handleQuitly = (type: string) => { setIsHello(false) diff --git a/components/layout/ask.tsx b/components/layout/ask.tsx index e933f68..73a657d 100644 --- a/components/layout/ask.tsx +++ b/components/layout/ask.tsx @@ -2,6 +2,7 @@ import ChatInSvg from "@/assets/icons/svg/chatIn.svg"; import ChatNotInSvg from "@/assets/icons/svg/chatNotIn.svg"; import PersonInSvg from "@/assets/icons/svg/personIn.svg"; import PersonNotInSvg from "@/assets/icons/svg/personNotIn.svg"; +import { WebSocketStatus } from "@/lib/websocket-util"; import { router, usePathname } from "expo-router"; import React, { useCallback, useEffect, useMemo } from 'react'; import { Dimensions, Image, StyleSheet, TouchableOpacity, View } from 'react-native'; @@ -41,11 +42,28 @@ const CenterButtonSvg = React.memo(() => ( )); -const AskNavbar = () => { +interface AskNavbarProps { + wsStatus: WebSocketStatus; +} + +const AskNavbar = ({ wsStatus }: AskNavbarProps) => { // 获取设备尺寸 const { width } = useMemo(() => Dimensions.get('window'), []); const pathname = usePathname(); + const statusColor = useMemo(() => { + switch (wsStatus) { + case 'connected': + return '#4CAF50'; // Green + case 'connecting': + case 'reconnecting': + return '#FFC107'; // Amber + case 'disconnected': + default: + return '#F44336'; // Red + } + }, [wsStatus]); + // 预加载目标页面 useEffect(() => { const preloadPages = async () => { @@ -128,8 +146,20 @@ const AskNavbar = () => { borderRadius: 50, backgroundColor: 'transparent', zIndex: 10, + }, + statusIndicator: { + position: 'absolute', + top: 15, + right: 15, + width: 10, + height: 10, + borderRadius: 5, + borderWidth: 1, + borderColor: '#FFF', + backgroundColor: statusColor, + zIndex: 11, } - }), [width]); + }), [width, statusColor]); // 如果当前路径是ask页面,则不渲染导航栏 if (pathname != '/memo-list' && pathname != '/owner') { @@ -155,6 +185,7 @@ const AskNavbar = () => { onPress={() => navigateTo('/ask')} style={styles.centerButton} > + diff --git a/lib/websocket-util.ts b/lib/websocket-util.ts new file mode 100644 index 0000000..7fc7775 --- /dev/null +++ b/lib/websocket-util.ts @@ -0,0 +1,226 @@ +import Constants from 'expo-constants'; +import * as SecureStore from 'expo-secure-store'; +import { Platform } from 'react-native'; + +// 从环境变量或默认值中定义 WebSocket 端点 +export const WEBSOCKET_ENDPOINT = Constants.expoConfig?.extra?.WEBSOCKET_ENDPOINT || "wss://api.memorywake.com/ws"; + +export type WebSocketStatus = 'connecting' | 'connected' | 'disconnected' | 'reconnecting'; + +type StatusListener = (status: WebSocketStatus) => void; + +// 消息监听器类型 +type MessageListener = (data: any) => void; + +// 根据后端 Rust 定义的 WsMessage 枚举创建 TypeScript 类型 +export type WsMessage = + | { type: 'Chat', session_id: string, message: string, image_material_ids?: string[], video_material_ids?: string[] } + | { type: 'ChatResponse', session_id: string, message: any, message_id?: string } + | { type: 'ChatStream', session_id: string, chunk: string } + | { type: 'ChatStreamEnd', session_id: string, message: any } + | { type: 'Error', code: string, message: string } + | { type: 'Ping' } + | { type: 'Pong' } + | { type: 'Connected', user_id: number }; + +class WebSocketManager { + private ws: WebSocket | null = null; + private status: WebSocketStatus = 'disconnected'; + private messageListeners: Map void>> = new Map(); + private statusListeners: Set = new Set(); + private reconnectAttempts = 0; + private readonly maxReconnectAttempts = 1; + private readonly reconnectInterval = 1000; // 初始重连间隔为1秒 + private pingIntervalId: ReturnType | null = null; + private readonly pingInterval = 30000; // 30秒发送一次心跳 + + constructor() { + // 这是一个单例类,连接通过调用 connect() 方法来启动 + } + + /** + * 获取当前 WebSocket 连接状态。 + */ + public getStatus(): WebSocketStatus { + return this.status; + } + + /** + * 启动 WebSocket 连接。 + * 会自动获取并使用存储的认证 token。 + */ + public async connect() { + if (this.ws && (this.status === 'connected' || this.status === 'connecting')) { + if (this.status === 'connected' || this.status === 'connecting') { + return; + } + } + + this.setStatus('connecting'); + + let token = ""; + if (Platform.OS === 'web') { + token = localStorage.getItem('token') || ""; + } else { + token = await SecureStore.getItemAsync('token') || ""; + } + + if (!token) { + console.error('WebSocket: 未找到认证 token,无法连接。'); + this.setStatus('disconnected'); + return; + } else { + console.log('WebSocket: 认证 token:', token); + } + + const url = `${WEBSOCKET_ENDPOINT}?token=${token}`; + console.log('WebSocket: 连接 URL:', url); + this.ws = new WebSocket(url); + + this.ws.onopen = () => { + console.log('WebSocket connected'); + this.setStatus('connected'); + this.reconnectAttempts = 0; // 重置重连尝试次数 + this.startPing(); + }; + + this.ws.onmessage = (event) => { + try { + const message: WsMessage = JSON.parse(event.data); + // 根据消息类型分发 + const eventListeners = this.messageListeners.get(message.type); + if (eventListeners) { + eventListeners.forEach(callback => callback(message)); + } + // 可以在这里处理通用的消息,比如 Pong + if (message.type === 'Pong') { + // console.log('Received Pong'); + } + } catch (error) { + console.error('处理 WebSocket 消息失败:', error); + } + }; + + this.ws.onerror = (error) => { + console.error('WebSocket 发生错误:', error); + }; + + this.ws.onclose = () => { + console.log('WebSocket disconnected'); + this.ws = null; + this.stopPing(); + // 只有在不是手动断开连接时才重连 + if (this.status !== 'disconnected') { + this.setStatus('reconnecting'); + this.handleReconnect(); + } + }; + } + + /** + * 处理自动重连逻辑,使用指数退避策略。 + */ + private handleReconnect() { + if (this.reconnectAttempts < this.maxReconnectAttempts) { + this.reconnectAttempts++; + const delay = this.reconnectInterval * Math.pow(2, this.reconnectAttempts - 1); + console.log(`${delay / 1000}秒后尝试重新连接 (第 ${this.reconnectAttempts} 次)...`); + setTimeout(() => { + this.connect(); + }, delay); + } else { + console.error('WebSocket 重连失败,已达到最大尝试次数。'); + this.setStatus('disconnected'); + } + } + + /** + * 发送消息到 WebSocket 服务器。 + * @param message 要发送的消息对象,必须包含 type 字段。 + */ + public send(message: WsMessage) { + if (this.status !== 'connected' || !this.ws) { + console.error('WebSocket 未连接,无法发送消息。'); + return; + } + this.ws.send(JSON.stringify(message)); + } + + /** + * 订阅指定消息类型的消息。 + * @param type 消息类型,例如 'ChatResponse'。 + * @param callback 收到消息时的回调函数。 + */ + public subscribe(type: WsMessage['type'], callback: (message: WsMessage) => void) { + if (!this.messageListeners.has(type)) { + this.messageListeners.set(type, new Set()); + } + this.messageListeners.get(type)?.add(callback); + } + + /** + * 取消订阅指定消息类型的消息。 + * @param type 消息类型。 + * @param callback 要移除的回调函数。 + */ + public unsubscribe(type: WsMessage['type'], callback: (message: WsMessage) => void) { + const eventListeners = this.messageListeners.get(type); + if (eventListeners) { + eventListeners.delete(callback); + if (eventListeners.size === 0) { + this.messageListeners.delete(type); + } + } + } + + /** + * 手动断开 WebSocket 连接。 + */ + public disconnect() { + this.setStatus('disconnected'); + if (this.ws) { + this.ws.close(); + } + this.stopPing(); + } + + private setStatus(status: WebSocketStatus) { + if (this.status !== status) { + this.status = status; + this.statusListeners.forEach(listener => listener(status)); + } + } + + public subscribeStatus(listener: StatusListener) { + this.statusListeners.add(listener); + // Immediately invoke with current status + listener(this.status); + } + + public unsubscribeStatus(listener: StatusListener) { + this.statusListeners.delete(listener); + } + + /** + * 启动心跳机制。 + */ + private startPing() { + this.stopPing(); // 先停止任何可能正在运行的计时器 + this.pingIntervalId = setInterval(() => { + this.send({ type: 'Ping' }); + }, this.pingInterval); + } + + /** + * 停止心跳机制。 + */ + private stopPing() { + if (this.pingIntervalId) { + clearInterval(this.pingIntervalId); + this.pingIntervalId = null; + } + } +} + +// 导出一个单例,确保整个应用共享同一个 WebSocket 连接 +export const webSocketManager = new WebSocketManager();