import { useCallback, useEffect, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import { getWebSocketManager, WsMessage, getWebSocketErrorMessage } from '@/lib/websocket-util';
import { Message, Assistant } from '@/types/ask';
import { Dispatch, SetStateAction } from 'react';

interface UseWebSocketStreamHandlerOptions {
  /** State setter for the chat message list that incoming WebSocket messages update. */
  setUserMessages: Dispatch<SetStateAction<Message[]>>;
  /** While false, every incoming WebSocket message is ignored. */
  isMounted: boolean;
  /** Whether to queue streamed chunks and render them in batches (default: false). */
  enableBatching?: boolean;
  /** Batch render interval in milliseconds, used when batching is enabled (default: 50). */
  renderInterval?: number;
}

/**
 * Content of an assistant message that has not received any text yet. This hook replaces it
 * with the first streamed text, removes it when a stream ends without text, or replaces it
 * with an error message.
 */
const PLACEHOLDER_CONTENT = 'keepSearchIng';

/**
 * Returns `messages` with `text` appended to the last message when that message is from the
 * assistant; otherwise (or when `text` is empty) returns `messages` unchanged.
 *
 * - String content: the placeholder is replaced by `text`; any other string is extended.
 * - Array content: `text` is appended to the first part whose `type` is `'text'`; when no
 *   such part exists the message is left unchanged.
 *
 * Pure (never mutates its input), so it is safe to call from a React state updater, which
 * React may invoke more than once.
 */
const appendToLastAssistantMessage = (messages: Message[], text: string): Message[] => {
  const lastMessage = messages[messages.length - 1];
  if (!lastMessage || lastMessage.role !== Assistant || text === '') {
    return messages;
  }

  const { content } = lastMessage;

  if (typeof content === 'string') {
    const base = content === PLACEHOLDER_CONTENT ? '' : content;
    return [...messages.slice(0, -1), { ...lastMessage, content: base + text }];
  }

  if (Array.isArray(content)) {
    const textPartIndex = content.findIndex(part => part.type === 'text');
    if (textPartIndex === -1) {
      // NOTE(review): multi-part content without a text part drops the chunk (unchanged
      // from the previous realtime behaviour); consider appending a new text part.
      return messages;
    }
    const updatedContent = [...content];
    updatedContent[textPartIndex] = {
      ...updatedContent[textPartIndex],
      text: (updatedContent[textPartIndex].text || '') + text,
    };
    return [...messages.slice(0, -1), { ...lastMessage, content: updatedContent }];
  }

  return messages;
};

/** Returns `messages` without any message whose content is still the placeholder. */
const removePlaceholderMessages = (messages: Message[]): Message[] =>
  messages.filter(m => m.content !== PLACEHOLDER_CONTENT);

/**
 * Applies chat-related WebSocket messages (`ChatStream`, `ChatStreamEnd`, `ChatResponse`,
 * `Error`) to the caller's message list via `setUserMessages`.
 *
 * Streamed chunks are appended to the last message when it is an assistant message. With
 * `enableBatching`, chunks are queued and flushed every `renderInterval` ms instead of
 * causing one state update per chunk.
 *
 * @returns `subscribeToWebSocket` (connects, subscribes all handlers, and returns a function
 *   that unsubscribes them and calls `cleanup`), the individual handlers, and `cleanup`
 *   (stops the batch timer and discards queued chunks).
 */
export const useWebSocketStreamHandler = ({
  setUserMessages,
  isMounted,
  enableBatching = false,
  renderInterval = 50,
}: UseWebSocketStreamHandlerOptions) => {
  const { t } = useTranslation();
  // Mirrors `isMounted` so handlers and the batch timer read the latest value.
  const isMountedRef = useRef(isMounted);

  // Batched mode: chunks received but not yet rendered, and the timer that flushes them.
  const chunkQueue = useRef<string[]>([]);
  const renderIntervalRef = useRef<ReturnType<typeof setInterval> | null>(null);

  // Keep the mounted flag in sync with the prop.
  useEffect(() => {
    isMountedRef.current = isMounted;
  }, [isMounted]);

  /** Stops the batch timer (if running) and discards any queued chunks. */
  const cleanup = useCallback(() => {
    if (renderIntervalRef.current !== null) {
      clearInterval(renderIntervalRef.current);
      renderIntervalRef.current = null;
    }
    chunkQueue.current = [];
  }, []);

  /** Empties the chunk queue and returns its contents joined into one string. */
  const takeQueuedText = useCallback(() => {
    const text = chunkQueue.current.join('');
    chunkQueue.current = [];
    return text;
  }, []);

  /** Timer callback: renders all queued chunks at once; stops the timer when idle or unmounted. */
  const processBatchedChunks = useCallback(() => {
    if (!isMountedRef.current || chunkQueue.current.length === 0) {
      cleanup();
      return;
    }

    // Drain outside the updater so the updater itself stays pure.
    const textToRender = takeQueuedText();
    setUserMessages(prevMessages => {
      try {
        return appendToLastAssistantMessage(prevMessages, textToRender);
      } catch (error) {
        console.error('处理批量流式消息时出错:', error);
        return prevMessages;
      }
    });
  }, [setUserMessages, cleanup, takeQueuedText]);

  const handleChatStream = useCallback((message: WsMessage) => {
    if (!isMountedRef.current || message.type !== 'ChatStream' || !message.chunk) return;
    const chunk = message.chunk;

    if (enableBatching) {
      chunkQueue.current.push(chunk);
      if (renderIntervalRef.current === null) {
        renderIntervalRef.current = setInterval(processBatchedChunks, renderInterval);
      }
      return;
    }

    // Realtime mode: one state update per chunk.
    setUserMessages(prevMessages => {
      try {
        return appendToLastAssistantMessage(prevMessages, chunk);
      } catch (error) {
        console.error('处理 ChatStream 消息时出错:', error);
        return prevMessages;
      }
    });
  }, [setUserMessages, enableBatching, processBatchedChunks, renderInterval]);

  const handleChatStreamEnd = useCallback((message: WsMessage) => {
    if (!isMountedRef.current || message.type !== 'ChatStreamEnd') return;
    const finalMessage = message.message;

    if (enableBatching) {
      // Take the chunks the timer has not rendered yet BEFORE cleanup() empties the queue;
      // otherwise the tail of the stream is lost.
      const remainingText = takeQueuedText();
      cleanup();

      setUserMessages(prevMessages => {
        try {
          const lastMessage = prevMessages[prevMessages.length - 1];
          if (!lastMessage || lastMessage.role !== Assistant) return prevMessages;

          if (finalMessage) {
            // The server's final content and timestamp win over locally accumulated text.
            const updatedLastMessage = {
              ...lastMessage,
              content: finalMessage.content,
              timestamp: finalMessage.timestamp,
            };
            return [...prevMessages.slice(0, -1), updatedLastMessage];
          }

          // No final message: flush the remaining text, then drop a placeholder that never
          // received any text (same outcome as realtime mode).
          return removePlaceholderMessages(
            appendToLastAssistantMessage(prevMessages, remainingText),
          );
        } catch (error) {
          console.error('处理ChatStreamEnd消息时出错:', error);
          return prevMessages;
        }
      });
      return;
    }

    // Realtime mode: all chunks are already rendered.
    setUserMessages(prevMessages => {
      try {
        const lastMessage = prevMessages[prevMessages.length - 1];
        if (!lastMessage || lastMessage.role !== Assistant) return prevMessages;

        if (finalMessage) {
          return [...prevMessages.slice(0, -1), finalMessage as Message];
        }
        return removePlaceholderMessages(prevMessages);
      } catch (error) {
        console.error('处理 ChatStreamEnd 消息时出错:', error);
        return prevMessages;
      }
    });
  }, [setUserMessages, enableBatching, cleanup, takeQueuedText]);

  const handleChatResponse = useCallback((message: WsMessage) => {
    if (!isMountedRef.current || message.type !== 'ChatResponse') return;
    const responseMessage = message.message;
    if (!responseMessage) return;

    setUserMessages(prevMessages => {
      try {
        // Nothing to replace; writing index -1 would only add a stray "-1" property.
        if (prevMessages.length === 0) return prevMessages;
        // NOTE(review): unlike the other handlers this replaces the last message regardless of
        // its role; confirm the caller always appends an assistant placeholder first.
        return [...prevMessages.slice(0, -1), responseMessage as Message];
      } catch (error) {
        console.error('处理聊天响应时出错:', error);
        return prevMessages;
      }
    });
  }, [setUserMessages]);

  const handleError = useCallback((message: WsMessage) => {
    if (!isMountedRef.current || message.type !== 'Error') return;
    console.log(`WebSocket Error: ${message.code} - ${message.message}`);

    setUserMessages(prevMessages => {
      try {
        const lastMessage = prevMessages[prevMessages.length - 1];
        // Only an assistant reply that is still the placeholder becomes an error message.
        if (
          !lastMessage ||
          lastMessage.role !== Assistant ||
          lastMessage.content !== PLACEHOLDER_CONTENT
        ) {
          return prevMessages;
        }
        return [
          ...prevMessages.slice(0, -1),
          { ...lastMessage, content: getWebSocketErrorMessage(message.code, t) },
        ];
      } catch (error) {
        console.error('处理 Error 消息时出错:', error);
        return prevMessages;
      }
    });
  }, [setUserMessages, t]);

  /** Connects, subscribes every handler, and returns a function that undoes both. */
  const subscribeToWebSocket = useCallback(() => {
    const webSocketManager = getWebSocketManager();
    webSocketManager.connect();
    webSocketManager.subscribe('ChatStream', handleChatStream);
    webSocketManager.subscribe('ChatStreamEnd', handleChatStreamEnd);
    webSocketManager.subscribe('ChatResponse', handleChatResponse);
    webSocketManager.subscribe('Error', handleError);

    return () => {
      webSocketManager.unsubscribe('ChatStream', handleChatStream);
      webSocketManager.unsubscribe('ChatStreamEnd', handleChatStreamEnd);
      webSocketManager.unsubscribe('ChatResponse', handleChatResponse);
      webSocketManager.unsubscribe('Error', handleError);
      // Release batching resources as well.
      cleanup();
    };
  }, [handleChatStream, handleChatStreamEnd, handleChatResponse, handleError, cleanup]);

  // Stop the batch timer when the component using this hook unmounts.
  useEffect(() => {
    return () => {
      cleanup();
    };
  }, [cleanup]);

  return {
    subscribeToWebSocket,
    handleChatStream,
    handleChatStreamEnd,
    handleChatResponse,
    handleError,
    cleanup,
  };
};