feat: websocket util

This commit is contained in:
Junhui Chen 2025-08-03 21:15:01 +08:00
parent 94340feabb
commit 9acdec2347
6 changed files with 378 additions and 35 deletions

View File

@ -85,6 +85,7 @@
}, },
"extra": { "extra": {
"API_ENDPOINT": "http://192.168.31.107:8081/api", "API_ENDPOINT": "http://192.168.31.107:8081/api",
"WEBSOCKET_ENDPOINT": "ws://192.168.31.107:8081/ws/chat",
"router": {}, "router": {},
"eas": { "eas": {
"projectId": "04721dd4-6b15-495a-b9ec-98187c613172" "projectId": "04721dd4-6b15-495a-b9ec-98187c613172"

View File

@ -6,6 +6,7 @@ import { Colors } from '@/constants/Colors';
import { useColorScheme } from '@/hooks/useColorScheme'; import { useColorScheme } from '@/hooks/useColorScheme';
import { prefetchChats } from '@/lib/prefetch'; import { prefetchChats } from '@/lib/prefetch';
import { fetchApi } from '@/lib/server-api-util'; import { fetchApi } from '@/lib/server-api-util';
import { webSocketManager, WebSocketStatus } from '@/lib/websocket-util';
import * as Notifications from 'expo-notifications'; import * as Notifications from 'expo-notifications';
import { Tabs } from 'expo-router'; import { Tabs } from 'expo-router';
import * as SecureStore from 'expo-secure-store'; import * as SecureStore from 'expo-secure-store';
@ -30,6 +31,7 @@ export default function TabLayout() {
const tokenInterval = useRef<NodeJS.Timeout | number>(null); const tokenInterval = useRef<NodeJS.Timeout | number>(null);
const isMounted = useRef(true); const isMounted = useRef(true);
const [token, setToken] = useState(''); const [token, setToken] = useState('');
const [wsStatus, setWsStatus] = useState<WebSocketStatus>('disconnected');
const sendNotification = async (item: PollingData) => { const sendNotification = async (item: PollingData) => {
// 请求通知权限 // 请求通知权限
const granted = await requestNotificationPermission(); const granted = await requestNotificationPermission();
@ -67,6 +69,16 @@ export default function TabLayout() {
}; };
}, []); }, []);
useEffect(() => {
const handleStatusChange = (status: WebSocketStatus) => {
setWsStatus(status);
};
webSocketManager.subscribeStatus(handleStatusChange);
return () => {
webSocketManager.unsubscribeStatus(handleStatusChange);
};
}, []);
// 轮询获取推送消息 // 轮询获取推送消息
const startPolling = useCallback(async (interval: number = 5000) => { const startPolling = useCallback(async (interval: number = 5000) => {
@ -365,7 +377,7 @@ export default function TabLayout() {
}} }}
/> />
</Tabs > </Tabs >
<AskNavbar /> <AskNavbar wsStatus={wsStatus} />
</> </>
); );
} }

View File

@ -4,7 +4,8 @@ import AskHello from "@/components/ask/hello";
import SendMessage from "@/components/ask/send"; import SendMessage from "@/components/ask/send";
import { ThemedText } from "@/components/ThemedText"; import { ThemedText } from "@/components/ThemedText";
import { fetchApi } from "@/lib/server-api-util"; import { fetchApi } from "@/lib/server-api-util";
import { Message } from "@/types/ask"; import { webSocketManager, WsMessage } from "@/lib/websocket-util";
import { Assistant, Message } from "@/types/ask";
import { router, useFocusEffect, useLocalSearchParams } from "expo-router"; import { router, useFocusEffect, useLocalSearchParams } from "expo-router";
import { useCallback, useEffect, useRef, useState } from 'react'; import { useCallback, useEffect, useRef, useState } from 'react';
import { import {
@ -78,6 +79,86 @@ export default function AskScreen() {
}; };
}, [isHello]); }, [isHello]);
useFocusEffect(
useCallback(() => {
webSocketManager.connect();
const handleChatStream = (message: WsMessage) => {
if (message.type === 'ChatStream') {
setUserMessages(prevMessages => {
const newMessages = [...prevMessages];
const lastMessage = newMessages[newMessages.length - 1];
if (lastMessage && lastMessage.role === Assistant) {
if (typeof lastMessage.content === 'string') {
if (lastMessage.content === 'keepSearchIng') {
// 第一次收到流式消息,替换占位符
lastMessage.content = message.chunk;
} else {
// 持续追加流式消息
lastMessage.content += message.chunk;
}
} else {
// 如果 content 是数组,则更新第一个 text 部分
const textPart = lastMessage.content.find(p => p.type === 'text');
if (textPart) {
textPart.text = (textPart.text || '') + message.chunk;
}
}
}
return newMessages;
});
}
};
const handleChatStreamEnd = (message: WsMessage) => {
if (message.type === 'ChatStreamEnd') {
setUserMessages(prevMessages => {
const newMessages = [...prevMessages];
const lastMessage = newMessages[newMessages.length - 1];
if (lastMessage && lastMessage.role === Assistant) {
// 使用最终消息替换流式消息,确保 message.message 存在
if (message.message) {
newMessages[newMessages.length - 1] = message.message as Message;
} else {
// 如果最终消息为空,则移除 'keepSearchIng' 占位符
return prevMessages.filter(m => !(typeof m.content === 'string' && m.content === 'keepSearchIng'));
}
}
return newMessages;
});
}
};
const handleError = (message: WsMessage) => {
if (message.type === 'Error') {
console.error(`WebSocket Error: ${message.code} - ${message.message}`);
// 可以在这里添加错误提示,例如替换最后一条消息为错误信息
setUserMessages(prev => {
const newMessages = [...prev];
const lastMessage = newMessages[newMessages.length - 1];
if (lastMessage && typeof lastMessage.content === 'string' && lastMessage.content === 'keepSearchIng') {
lastMessage.content = `Error: ${message.message}`;
}
return newMessages;
})
}
};
webSocketManager.subscribe('ChatStream', handleChatStream);
webSocketManager.subscribe('ChatStreamEnd', handleChatStreamEnd);
webSocketManager.subscribe('Error', handleError);
return () => {
webSocketManager.unsubscribe('ChatStream', handleChatStream);
webSocketManager.unsubscribe('ChatStreamEnd', handleChatStreamEnd);
webSocketManager.unsubscribe('Error', handleError);
// 可以在这里选择断开连接,或者保持连接以加快下次进入页面的速度
// webSocketManager.disconnect();
};
}, [])
);
useEffect(() => { useEffect(() => {
if (sessionId) { if (sessionId) {
setConversationId(sessionId); setConversationId(sessionId);

View File

@ -15,7 +15,8 @@ import {
import { Message } from '@/types/ask'; import { Message } from '@/types/ask';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { ThemedText } from '../ThemedText'; import { ThemedText } from '../ThemedText';
import { createNewConversation, getConversation } from './utils'; import { createNewConversation } from './utils';
import { webSocketManager } from '@/lib/websocket-util';
interface Props { interface Props {
setIsHello: Dispatch<SetStateAction<boolean>>, setIsHello: Dispatch<SetStateAction<boolean>>,
@ -88,37 +89,28 @@ export default function SendMessage(props: Props) {
timestamp: new Date().toISOString() timestamp: new Date().toISOString()
} }
])); ]));
// 如果没有对话ID创建新对话并获取消息否则直接获取消息 let currentSessionId = conversationId;
if (!conversationId) { // 如果没有对话ID先创建一个新对话
const data = await createNewConversation(text); if (!currentSessionId) {
setConversationId(data); currentSessionId = await createNewConversation(text);
const response = await getConversation({ session_id: data, user_text: text, material_ids: [] }); setConversationId(currentSessionId);
setSelectedImages([]); }
setUserMessages((prev: Message[]) => {
const newMessages = [...(prev || [])]; // 通过 WebSocket 发送消息
if (response) { if (currentSessionId) {
newMessages.push(response); webSocketManager.send({
} type: 'Chat',
return newMessages.filter((item: Message) => session_id: currentSessionId,
item?.content?.text !== 'keepSearchIng' message: text,
); image_material_ids: selectedImages.length > 0 ? selectedImages : undefined,
}); });
setSelectedImages([]);
} else { } else {
const response = await getConversation({ console.error("无法获取 session_id消息发送失败。");
session_id: conversationId, // 可以在这里处理错误,例如显示一个提示
user_text: text, setUserMessages(prev => prev.filter(item =>
material_ids: selectedImages !(typeof item.content === 'string' && item.content === 'keepSearchIng')
}); ));
setSelectedImages([]);
setUserMessages((prev: Message[]) => {
const newMessages = [...(prev || [])];
if (response) {
newMessages.push(response);
}
return newMessages.filter((item: Message) =>
item?.content?.text !== 'keepSearchIng'
);
});
} }
// 将输入框清空 // 将输入框清空
setInputValue(''); setInputValue('');
@ -127,7 +119,7 @@ export default function SendMessage(props: Props) {
Keyboard.dismiss(); Keyboard.dismiss();
} }
} }
}, [inputValue, conversationId, selectedImages, createNewConversation, getConversation]); }, [inputValue, conversationId, selectedImages, createNewConversation]);
const handleQuitly = (type: string) => { const handleQuitly = (type: string) => {
setIsHello(false) setIsHello(false)

View File

@ -2,6 +2,7 @@ import ChatInSvg from "@/assets/icons/svg/chatIn.svg";
import ChatNotInSvg from "@/assets/icons/svg/chatNotIn.svg"; import ChatNotInSvg from "@/assets/icons/svg/chatNotIn.svg";
import PersonInSvg from "@/assets/icons/svg/personIn.svg"; import PersonInSvg from "@/assets/icons/svg/personIn.svg";
import PersonNotInSvg from "@/assets/icons/svg/personNotIn.svg"; import PersonNotInSvg from "@/assets/icons/svg/personNotIn.svg";
import { WebSocketStatus } from "@/lib/websocket-util";
import { router, usePathname } from "expo-router"; import { router, usePathname } from "expo-router";
import React, { useCallback, useEffect, useMemo } from 'react'; import React, { useCallback, useEffect, useMemo } from 'react';
import { Dimensions, Image, StyleSheet, TouchableOpacity, View } from 'react-native'; import { Dimensions, Image, StyleSheet, TouchableOpacity, View } from 'react-native';
@ -41,11 +42,28 @@ const CenterButtonSvg = React.memo(() => (
</Svg> </Svg>
)); ));
const AskNavbar = () => { interface AskNavbarProps {
wsStatus: WebSocketStatus;
}
const AskNavbar = ({ wsStatus }: AskNavbarProps) => {
// 获取设备尺寸 // 获取设备尺寸
const { width } = useMemo(() => Dimensions.get('window'), []); const { width } = useMemo(() => Dimensions.get('window'), []);
const pathname = usePathname(); const pathname = usePathname();
const statusColor = useMemo(() => {
switch (wsStatus) {
case 'connected':
return '#4CAF50'; // Green
case 'connecting':
case 'reconnecting':
return '#FFC107'; // Amber
case 'disconnected':
default:
return '#F44336'; // Red
}
}, [wsStatus]);
// 预加载目标页面 // 预加载目标页面
useEffect(() => { useEffect(() => {
const preloadPages = async () => { const preloadPages = async () => {
@ -128,8 +146,20 @@ const AskNavbar = () => {
borderRadius: 50, borderRadius: 50,
backgroundColor: 'transparent', backgroundColor: 'transparent',
zIndex: 10, zIndex: 10,
},
statusIndicator: {
position: 'absolute',
top: 15,
right: 15,
width: 10,
height: 10,
borderRadius: 5,
borderWidth: 1,
borderColor: '#FFF',
backgroundColor: statusColor,
zIndex: 11,
} }
}), [width]); }), [width, statusColor]);
// 如果当前路径是ask页面则不渲染导航栏 // 如果当前路径是ask页面则不渲染导航栏
if (pathname != '/memo-list' && pathname != '/owner') { if (pathname != '/memo-list' && pathname != '/owner') {
@ -155,6 +185,7 @@ const AskNavbar = () => {
onPress={() => navigateTo('/ask')} onPress={() => navigateTo('/ask')}
style={styles.centerButton} style={styles.centerButton}
> >
<View style={styles.statusIndicator} />
<CenterButtonSvg /> <CenterButtonSvg />
</TouchableOpacity> </TouchableOpacity>

226
lib/websocket-util.ts Normal file
View File

@ -0,0 +1,226 @@
import Constants from 'expo-constants';
import * as SecureStore from 'expo-secure-store';
import { Platform } from 'react-native';
// 从环境变量或默认值中定义 WebSocket 端点
export const WEBSOCKET_ENDPOINT = Constants.expoConfig?.extra?.WEBSOCKET_ENDPOINT || "wss://api.memorywake.com/ws";
export type WebSocketStatus = 'connecting' | 'connected' | 'disconnected' | 'reconnecting';
type StatusListener = (status: WebSocketStatus) => void;
// 消息监听器类型
type MessageListener = (data: any) => void;
// 根据后端 Rust 定义的 WsMessage 枚举创建 TypeScript 类型
export type WsMessage =
| { type: 'Chat', session_id: string, message: string, image_material_ids?: string[], video_material_ids?: string[] }
| { type: 'ChatResponse', session_id: string, message: any, message_id?: string }
| { type: 'ChatStream', session_id: string, chunk: string }
| { type: 'ChatStreamEnd', session_id: string, message: any }
| { type: 'Error', code: string, message: string }
| { type: 'Ping' }
| { type: 'Pong' }
| { type: 'Connected', user_id: number };
class WebSocketManager {
private ws: WebSocket | null = null;
private status: WebSocketStatus = 'disconnected';
private messageListeners: Map<string, Set<(message: WsMessage) => void>> = new Map();
private statusListeners: Set<StatusListener> = new Set();
private reconnectAttempts = 0;
private readonly maxReconnectAttempts = 1;
private readonly reconnectInterval = 1000; // 初始重连间隔为1秒
private pingIntervalId: ReturnType<typeof setInterval> | null = null;
private readonly pingInterval = 30000; // 30秒发送一次心跳
constructor() {
// 这是一个单例类,连接通过调用 connect() 方法来启动
}
/**
* WebSocket
*/
public getStatus(): WebSocketStatus {
return this.status;
}
/**
* WebSocket
* 使 token
*/
public async connect() {
if (this.ws && (this.status === 'connected' || this.status === 'connecting')) {
if (this.status === 'connected' || this.status === 'connecting') {
return;
}
}
this.setStatus('connecting');
let token = "";
if (Platform.OS === 'web') {
token = localStorage.getItem('token') || "";
} else {
token = await SecureStore.getItemAsync('token') || "";
}
if (!token) {
console.error('WebSocket: 未找到认证 token无法连接。');
this.setStatus('disconnected');
return;
} else {
console.log('WebSocket: 认证 token:', token);
}
const url = `${WEBSOCKET_ENDPOINT}?token=${token}`;
console.log('WebSocket: 连接 URL:', url);
this.ws = new WebSocket(url);
this.ws.onopen = () => {
console.log('WebSocket connected');
this.setStatus('connected');
this.reconnectAttempts = 0; // 重置重连尝试次数
this.startPing();
};
this.ws.onmessage = (event) => {
try {
const message: WsMessage = JSON.parse(event.data);
// 根据消息类型分发
const eventListeners = this.messageListeners.get(message.type);
if (eventListeners) {
eventListeners.forEach(callback => callback(message));
}
// 可以在这里处理通用的消息,比如 Pong
if (message.type === 'Pong') {
// console.log('Received Pong');
}
} catch (error) {
console.error('处理 WebSocket 消息失败:', error);
}
};
this.ws.onerror = (error) => {
console.error('WebSocket 发生错误:', error);
};
this.ws.onclose = () => {
console.log('WebSocket disconnected');
this.ws = null;
this.stopPing();
// 只有在不是手动断开连接时才重连
if (this.status !== 'disconnected') {
this.setStatus('reconnecting');
this.handleReconnect();
}
};
}
/**
* 使退
*/
private handleReconnect() {
if (this.reconnectAttempts < this.maxReconnectAttempts) {
this.reconnectAttempts++;
const delay = this.reconnectInterval * Math.pow(2, this.reconnectAttempts - 1);
console.log(`${delay / 1000}秒后尝试重新连接 (第 ${this.reconnectAttempts} 次)...`);
setTimeout(() => {
this.connect();
}, delay);
} else {
console.error('WebSocket 重连失败,已达到最大尝试次数。');
this.setStatus('disconnected');
}
}
/**
* WebSocket
* @param message type
*/
public send(message: WsMessage) {
if (this.status !== 'connected' || !this.ws) {
console.error('WebSocket 未连接,无法发送消息。');
return;
}
this.ws.send(JSON.stringify(message));
}
/**
*
* @param type 'ChatResponse'
* @param callback
*/
public subscribe(type: WsMessage['type'], callback: (message: WsMessage) => void) {
if (!this.messageListeners.has(type)) {
this.messageListeners.set(type, new Set());
}
this.messageListeners.get(type)?.add(callback);
}
/**
*
* @param type
* @param callback
*/
public unsubscribe(type: WsMessage['type'], callback: (message: WsMessage) => void) {
const eventListeners = this.messageListeners.get(type);
if (eventListeners) {
eventListeners.delete(callback);
if (eventListeners.size === 0) {
this.messageListeners.delete(type);
}
}
}
/**
* WebSocket
*/
public disconnect() {
this.setStatus('disconnected');
if (this.ws) {
this.ws.close();
}
this.stopPing();
}
private setStatus(status: WebSocketStatus) {
if (this.status !== status) {
this.status = status;
this.statusListeners.forEach(listener => listener(status));
}
}
public subscribeStatus(listener: StatusListener) {
this.statusListeners.add(listener);
// Immediately invoke with current status
listener(this.status);
}
public unsubscribeStatus(listener: StatusListener) {
this.statusListeners.delete(listener);
}
/**
*
*/
private startPing() {
this.stopPing(); // 先停止任何可能正在运行的计时器
this.pingIntervalId = setInterval(() => {
this.send({ type: 'Ping' });
}, this.pingInterval);
}
/**
*
*/
private stopPing() {
if (this.pingIntervalId) {
clearInterval(this.pingIntervalId);
this.pingIntervalId = null;
}
}
}
// 导出一个单例,确保整个应用共享同一个 WebSocket 连接
export const webSocketManager = new WebSocketManager();