memowake-front/hooks/useWebSocketStreamHandler.ts
2025-08-07 01:16:18 +08:00

274 lines
8.8 KiB
TypeScript
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import { useCallback, useEffect, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import { getWebSocketManager, WsMessage, getWebSocketErrorMessage } from '@/lib/websocket-util';
import { Message, Assistant } from '@/types/ask';
import { Dispatch, SetStateAction } from 'react';
/**
 * Options accepted by `useWebSocketStreamHandler`.
 */
interface UseWebSocketStreamHandlerOptions {
  /** State setter for the chat message list that incoming stream events are applied to. */
  setUserMessages: Dispatch<SetStateAction<Message[]>>;

  /** Whether the owning component is mounted; the hook ignores incoming events while this is false. */
  isMounted: boolean;

  /**
   * When true, streamed chunks are queued and flushed on a timer instead of
   * triggering one state update per chunk. Defaults to `false`.
   */
  enableBatching?: boolean;

  /** Flush interval in milliseconds used when batching is enabled. Defaults to `50`. */
  renderInterval?: number;
}
/**
 * Sentinel content of an assistant message that has not received any streamed
 * text yet. The first appended chunk replaces it instead of being concatenated.
 */
const KEEP_SEARCHING_PLACEHOLDER = 'keepSearchIng';

/**
 * Appends `text` to the last message of `messages` when that message is from the assistant.
 *
 * - String content: the sentinel placeholder is replaced by `text`; any other string gets `text` appended.
 * - Array (multi-part) content: `text` is appended to the first part whose `type` is `'text'`.
 *
 * Pure: `messages` is never mutated. The same array reference is returned when
 * there is nothing to update (empty list, last message not from the assistant,
 * or no text part), so React can skip the re-render.
 *
 * @param messages - Current message list.
 * @param text - Text to append.
 * @returns A new list with the updated last message, or `messages` itself when unchanged.
 */
const appendTextToLastAssistantMessage = (messages: Message[], text: string): Message[] => {
  const lastMessage = messages[messages.length - 1];
  if (!lastMessage || lastMessage.role !== Assistant) {
    return messages;
  }

  const head = messages.slice(0, -1);
  const { content } = lastMessage;

  if (typeof content === 'string') {
    const nextText = content === KEEP_SEARCHING_PLACEHOLDER ? text : content + text;
    return [...head, { ...lastMessage, content: nextText }];
  }

  if (Array.isArray(content)) {
    const textPartIndex = content.findIndex(p => p.type === 'text');
    if (textPartIndex === -1) {
      return messages;
    }
    const nextParts = [...content];
    nextParts[textPartIndex] = {
      ...nextParts[textPartIndex],
      text: (nextParts[textPartIndex].text || '') + text
    };
    return [...head, { ...lastMessage, content: nextParts }];
  }

  return messages;
};

/**
 * React hook that applies chat events from the shared WebSocket connection
 * (`ChatStream`, `ChatStreamEnd`, `ChatResponse`, `Error`) to the caller's message list.
 *
 * Streamed chunks are appended to the last message when it is an assistant message.
 * With `enableBatching`, chunks are queued and flushed every `renderInterval` ms in a
 * single state update; otherwise each chunk triggers its own update.
 *
 * @param options.setUserMessages - State setter for the message list.
 * @param options.isMounted - While false, incoming events are ignored.
 * @param options.enableBatching - Queue chunks and flush them on a timer (default `false`).
 * @param options.renderInterval - Flush interval in ms when batching (default `50`).
 * @returns `subscribeToWebSocket` (connects, subscribes the handlers and returns an
 *   unsubscribe function), the individual event handlers, and `cleanup`, which stops
 *   the batching timer and discards queued chunks.
 */
export const useWebSocketStreamHandler = ({
  setUserMessages,
  isMounted,
  enableBatching = false,
  renderInterval = 50
}: UseWebSocketStreamHandlerOptions) => {
  const { t } = useTranslation();

  // Handlers read the mount flag through a ref so they stay referentially stable
  // while still observing the latest value.
  const isMountedRef = useRef(isMounted);

  // Batching state: chunks waiting to be flushed, and the timer that flushes them.
  const chunkQueue = useRef<string[]>([]);
  const renderIntervalRef = useRef<ReturnType<typeof setInterval> | null>(null);

  useEffect(() => {
    isMountedRef.current = isMounted;
  }, [isMounted]);

  /** Stops the flush timer (if running) and discards any queued chunks. */
  const cleanup = useCallback(() => {
    if (renderIntervalRef.current) {
      clearInterval(renderIntervalRef.current);
      renderIntervalRef.current = null;
    }
    chunkQueue.current = [];
  }, []);

  /**
   * Timer callback for batching mode. Flushes every queued chunk in one state update,
   * or stops the timer once the queue is empty (the next chunk restarts it).
   */
  const processBatchedChunks = useCallback(() => {
    if (!isMountedRef.current || chunkQueue.current.length === 0) {
      cleanup();
      return;
    }

    // Drain the queue outside the updater so the updater itself stays pure
    // (React may invoke state updaters more than once).
    const textToRender = chunkQueue.current.join('');
    chunkQueue.current = [];

    setUserMessages(prevMessages => {
      try {
        return appendTextToLastAssistantMessage(prevMessages, textToRender);
      } catch (error) {
        console.error('处理批量流式消息时出错:', error);
        return prevMessages;
      }
    });
  }, [setUserMessages, cleanup]);

  /** Handles a `ChatStream` event by appending its chunk (immediately or via the batch queue). */
  const handleChatStream = useCallback((message: WsMessage) => {
    if (!isMountedRef.current || message.type !== 'ChatStream' || !message.chunk) return;
    const chunk = message.chunk;

    if (enableBatching) {
      chunkQueue.current.push(chunk);
      // Start the flush timer lazily; processBatchedChunks stops it when the queue drains.
      if (!renderIntervalRef.current) {
        renderIntervalRef.current = setInterval(processBatchedChunks, renderInterval);
      }
      return;
    }

    setUserMessages(prevMessages => {
      try {
        return appendTextToLastAssistantMessage(prevMessages, chunk);
      } catch (error) {
        console.error('处理 ChatStream 消息时出错:', error);
        return prevMessages;
      }
    });
  }, [setUserMessages, enableBatching, processBatchedChunks, renderInterval]);

  /**
   * Handles a `ChatStreamEnd` event.
   *
   * Batching mode: flushes any still-queued chunks. If the event carries a final
   * message, that message's content and timestamp replace those of the last
   * assistant message; the other fields of the local message are kept.
   *
   * Real-time mode: if the event carries a final message, it replaces the last
   * assistant message wholesale. Otherwise every message still holding the sentinel
   * placeholder is removed.
   */
  const handleChatStreamEnd = useCallback((message: WsMessage) => {
    if (!isMountedRef.current || message.type !== 'ChatStreamEnd') return;
    const finalMessage = message.message;

    if (enableBatching) {
      // Snapshot the queue BEFORE cleanup(): cleanup() empties it, so reading it
      // afterwards would silently drop the last not-yet-flushed batch of chunks.
      const remainingText = chunkQueue.current.join('');
      cleanup();

      setUserMessages(prevMessages => {
        try {
          const flushed = appendTextToLastAssistantMessage(prevMessages, remainingText);
          if (!finalMessage) return flushed;

          const lastMessage = flushed[flushed.length - 1];
          if (!lastMessage || lastMessage.role !== Assistant) return flushed;

          const updatedLastMessage = {
            ...lastMessage,
            content: finalMessage.content,
            timestamp: finalMessage.timestamp,
          };
          return [...flushed.slice(0, -1), updatedLastMessage];
        } catch (error) {
          console.error('处理ChatStreamEnd消息时出错:', error);
          return prevMessages;
        }
      });
      return;
    }

    setUserMessages(prevMessages => {
      try {
        const lastMessage = prevMessages[prevMessages.length - 1];
        if (!lastMessage || lastMessage.role !== Assistant) {
          return prevMessages;
        }
        if (finalMessage) {
          return [...prevMessages.slice(0, -1), finalMessage as Message];
        }
        return prevMessages.filter(m =>
          !(typeof m.content === 'string' && m.content === KEEP_SEARCHING_PLACEHOLDER)
        );
      } catch (error) {
        console.error('处理 ChatStreamEnd 消息时出错:', error);
        return prevMessages;
      }
    });
  }, [setUserMessages, enableBatching, cleanup]);

  /** Handles a non-streamed `ChatResponse` by replacing the last message with the received one. */
  const handleChatResponse = useCallback((message: WsMessage) => {
    if (!isMountedRef.current || message.type !== 'ChatResponse') return;
    const responseMessage = message.message;
    if (!responseMessage) return;

    setUserMessages(prevMessages => {
      try {
        // On an empty list `length - 1` is -1; assigning there would create a stray
        // "-1" property instead of an element, so leave the list unchanged.
        if (prevMessages.length === 0) return prevMessages;
        return [...prevMessages.slice(0, -1), responseMessage as Message];
      } catch (error) {
        console.error('处理聊天响应时出错:', error);
        return prevMessages;
      }
    });
  }, [setUserMessages]);

  /**
   * Handles an `Error` event. If the last message is an assistant message still holding
   * the sentinel placeholder, its content is replaced with a localized error text.
   */
  const handleError = useCallback((message: WsMessage) => {
    if (!isMountedRef.current || message.type !== 'Error') return;
    console.error(`WebSocket Error: ${message.code} - ${message.message}`);

    setUserMessages(prevMessages => {
      try {
        const lastMessage = prevMessages[prevMessages.length - 1];
        if (!lastMessage ||
          lastMessage.role !== Assistant ||
          typeof lastMessage.content !== 'string' ||
          lastMessage.content !== KEEP_SEARCHING_PLACEHOLDER) {
          return prevMessages;
        }
        const updatedLastMessage = {
          ...lastMessage,
          content: getWebSocketErrorMessage(message.code, t)
        };
        return [...prevMessages.slice(0, -1), updatedLastMessage];
      } catch (error) {
        console.error('处理 Error 消息时出错:', error);
        return prevMessages;
      }
    });
  }, [setUserMessages, t]);

  /**
   * Connects the shared WebSocket manager and subscribes all handlers.
   * @returns A teardown function that unsubscribes the handlers and releases batching resources.
   */
  const subscribeToWebSocket = useCallback(() => {
    const webSocketManager = getWebSocketManager();
    webSocketManager.connect();

    webSocketManager.subscribe('ChatStream', handleChatStream);
    webSocketManager.subscribe('ChatStreamEnd', handleChatStreamEnd);
    webSocketManager.subscribe('ChatResponse', handleChatResponse);
    webSocketManager.subscribe('Error', handleError);

    return () => {
      webSocketManager.unsubscribe('ChatStream', handleChatStream);
      webSocketManager.unsubscribe('ChatStreamEnd', handleChatStreamEnd);
      webSocketManager.unsubscribe('ChatResponse', handleChatResponse);
      webSocketManager.unsubscribe('Error', handleError);
      cleanup();
    };
  }, [handleChatStream, handleChatStreamEnd, handleChatResponse, handleError, cleanup]);

  // Stop the flush timer when the component using this hook unmounts.
  useEffect(() => {
    return () => {
      cleanup();
    };
  }, [cleanup]);

  return {
    subscribeToWebSocket,
    handleChatStream,
    handleChatStreamEnd,
    handleChatResponse,
    handleError,
    cleanup
  };
};