274 lines
8.8 KiB
TypeScript
274 lines
8.8 KiB
TypeScript
import { useCallback, useEffect, useRef } from 'react';
|
||
import { useTranslation } from 'react-i18next';
|
||
import { getWebSocketManager, WsMessage, getWebSocketErrorMessage } from '@/lib/websocket-util';
|
||
import { Message, Assistant } from '@/types/ask';
|
||
import { Dispatch, SetStateAction } from 'react';
|
||
|
||
/** Options accepted by `useWebSocketStreamHandler`. */
interface UseWebSocketStreamHandlerOptions {
  /** Whether the owning component is still mounted; incoming socket messages are ignored once this is false. */
  isMounted: boolean;
  /** State setter for the message list that streamed updates are written into. */
  setUserMessages: Dispatch<SetStateAction<Message[]>>;
  /** Flush interval in milliseconds used when batching is enabled (defaults to 50). */
  renderInterval?: number;
  /** When true, stream chunks are buffered and flushed on a timer instead of rendered one by one. */
  enableBatching?: boolean;
}
|
||
|
||
/**
 * Sentinel content placed on an assistant message before any streamed text exists.
 * Handlers below overwrite it with the first chunk, replace it with an error text,
 * or remove it when a stream ends without a final payload.
 */
const SEARCHING_PLACEHOLDER = 'keepSearchIng';

/**
 * Returns a copy of `message` with `chunk` appended to its text.
 *
 * - String content: the placeholder sentinel is discarded, then `chunk` is appended.
 * - Array (multi-part) content: `chunk` is appended to the first part whose `type` is `'text'`;
 *   if there is no such part the message is returned unchanged.
 * - Any other content shape is returned unchanged.
 */
function appendChunkToMessage(message: Message, chunk: string): Message {
  const { content } = message;

  if (typeof content === 'string') {
    const base = content === SEARCHING_PLACEHOLDER ? '' : content;
    return { ...message, content: base + chunk };
  }

  if (Array.isArray(content)) {
    const textPartIndex = content.findIndex(p => p.type === 'text');
    if (textPartIndex === -1) return message;

    const updatedContent = [...content];
    updatedContent[textPartIndex] = {
      ...updatedContent[textPartIndex],
      text: (updatedContent[textPartIndex].text || '') + chunk,
    };
    return { ...message, content: updatedContent };
  }

  return message;
}

/**
 * Applies `update` to the last message when it is an assistant message.
 * Returns `messages` itself (same reference, so no re-render) when the list is empty
 * or the last message has another role; otherwise returns a new array.
 */
function updateLastAssistantMessage(
  messages: Message[],
  update: (lastMessage: Message) => Message,
): Message[] {
  const lastMessage = messages[messages.length - 1];
  if (!lastMessage || lastMessage.role !== Assistant) return messages;
  return [...messages.slice(0, -1), update(lastMessage)];
}

/**
 * Wires WebSocket chat-stream messages (`ChatStream`, `ChatStreamEnd`, `ChatResponse`, `Error`)
 * into a message list held in React state.
 *
 * Stream chunks are appended to the last assistant message, either immediately or, when
 * `enableBatching` is set, buffered and flushed every `renderInterval` ms.
 *
 * @returns `subscribeToWebSocket` (connects, subscribes all handlers and returns an unsubscribe
 *   function), the individual handlers, and `cleanup` (stops the flush timer and drops queued chunks).
 */
export const useWebSocketStreamHandler = ({
  setUserMessages,
  isMounted,
  enableBatching = false,
  renderInterval = 50,
}: UseWebSocketStreamHandlerOptions) => {
  const { t } = useTranslation();
  const isMountedRef = useRef(isMounted);

  // Batching state: chunks waiting to be flushed, and the timer that flushes them.
  const chunkQueue = useRef<string[]>([]);
  const renderIntervalRef = useRef<ReturnType<typeof setInterval> | null>(null);

  // Mirror the latest mount flag into a ref so long-lived callbacks (timer, socket handlers)
  // always read the current value rather than the one captured when they were created.
  useEffect(() => {
    isMountedRef.current = isMounted;
  }, [isMounted]);

  /** Stops the flush timer (if running) and discards any queued chunks. */
  const cleanup = useCallback(() => {
    if (renderIntervalRef.current) {
      clearInterval(renderIntervalRef.current);
      renderIntervalRef.current = null;
    }
    chunkQueue.current = [];
  }, []);

  /** Timer tick: flushes all queued chunks into the last assistant message in one state update. */
  const processBatchedChunks = useCallback(() => {
    if (!isMountedRef.current) {
      cleanup();
      return;
    }

    if (chunkQueue.current.length === 0) {
      // Nothing arrived since the last tick: stop the timer; the next chunk restarts it.
      cleanup();
      return;
    }

    const textToRender = chunkQueue.current.join('');
    chunkQueue.current = [];

    setUserMessages(prevMessages => {
      try {
        return updateLastAssistantMessage(prevMessages, last => appendChunkToMessage(last, textToRender));
      } catch (error) {
        console.error('处理批量流式消息时出错:', error);
        return prevMessages;
      }
    });
  }, [setUserMessages, cleanup]);

  /** Handles a `ChatStream` frame: appends its chunk now, or queues it when batching. */
  const handleChatStream = useCallback((message: WsMessage) => {
    if (!isMountedRef.current || message.type !== 'ChatStream' || !message.chunk) return;
    const chunk = message.chunk;

    if (enableBatching) {
      chunkQueue.current.push(chunk);
      if (!renderIntervalRef.current) {
        renderIntervalRef.current = setInterval(processBatchedChunks, renderInterval);
      }
      return;
    }

    setUserMessages(prevMessages => {
      try {
        return updateLastAssistantMessage(prevMessages, last => appendChunkToMessage(last, chunk));
      } catch (error) {
        console.error('处理 ChatStream 消息时出错:', error);
        return prevMessages;
      }
    });
  }, [setUserMessages, enableBatching, processBatchedChunks, renderInterval]);

  /** Handles a `ChatStreamEnd` frame: finalizes the last assistant message. */
  const handleChatStreamEnd = useCallback((message: WsMessage) => {
    if (!isMountedRef.current || message.type !== 'ChatStreamEnd') return;
    const finalMessage = message.message;

    if (enableBatching) {
      // Take the not-yet-flushed chunks BEFORE cleanup(): cleanup() empties the queue,
      // so reading it afterwards would always yield '' and lose the trailing text.
      const remainingText = chunkQueue.current.join('');
      cleanup();

      setUserMessages(prevMessages => {
        try {
          return updateLastAssistantMessage(prevMessages, last =>
            finalMessage
              ? { ...last, content: finalMessage.content, timestamp: finalMessage.timestamp }
              : appendChunkToMessage(last, remainingText),
          );
        } catch (error) {
          console.error('处理ChatStreamEnd消息时出错:', error);
          return prevMessages;
        }
      });
      return;
    }

    setUserMessages(prevMessages => {
      try {
        const lastMessage = prevMessages[prevMessages.length - 1];
        if (!lastMessage || lastMessage.role !== Assistant) return prevMessages;

        if (finalMessage) {
          // The server sent the authoritative final message: replace the streamed one wholesale.
          return [...prevMessages.slice(0, -1), finalMessage as Message];
        }

        // No final payload: drop any message still holding only the placeholder.
        return prevMessages.filter(
          m => !(typeof m.content === 'string' && m.content === SEARCHING_PLACEHOLDER),
        );
      } catch (error) {
        console.error('处理 ChatStreamEnd 消息时出错:', error);
        return prevMessages;
      }
    });
  }, [setUserMessages, enableBatching, cleanup]);

  /** Handles a non-streamed `ChatResponse`: replaces the last message with the response. */
  const handleChatResponse = useCallback((message: WsMessage) => {
    if (!isMountedRef.current || message.type !== 'ChatResponse') return;
    const responseMessage = message.message;
    if (!responseMessage) return;

    setUserMessages(prevMessages => {
      try {
        // With an empty list, assigning to index -1 would only add a stray "-1" property.
        if (prevMessages.length === 0) return prevMessages;
        return [...prevMessages.slice(0, -1), responseMessage as Message];
      } catch (error) {
        console.error('处理聊天响应时出错:', error);
        return prevMessages;
      }
    });
  }, [setUserMessages]);

  /** Handles an `Error` frame: shows a localized error in place of a still-pending placeholder. */
  const handleError = useCallback((message: WsMessage) => {
    if (!isMountedRef.current || message.type !== 'Error') return;

    console.log(`WebSocket Error: ${message.code} - ${message.message}`);
    const errorCode = message.code;

    setUserMessages(prevMessages => {
      try {
        const lastMessage = prevMessages[prevMessages.length - 1];

        // Only a placeholder that never received streamed text is replaced; partial answers are kept.
        if (
          !lastMessage ||
          lastMessage.role !== Assistant ||
          typeof lastMessage.content !== 'string' ||
          lastMessage.content !== SEARCHING_PLACEHOLDER
        ) {
          return prevMessages;
        }

        return [
          ...prevMessages.slice(0, -1),
          { ...lastMessage, content: getWebSocketErrorMessage(errorCode, t) },
        ];
      } catch (error) {
        console.error('处理 Error 消息时出错:', error);
        return prevMessages;
      }
    });
  }, [setUserMessages, t]);

  /**
   * Connects the shared WebSocket manager and subscribes all handlers.
   * @returns A teardown function that unsubscribes the handlers and releases batching resources.
   */
  const subscribeToWebSocket = useCallback(() => {
    const webSocketManager = getWebSocketManager();
    webSocketManager.connect();

    webSocketManager.subscribe('ChatStream', handleChatStream);
    webSocketManager.subscribe('ChatStreamEnd', handleChatStreamEnd);
    webSocketManager.subscribe('ChatResponse', handleChatResponse);
    webSocketManager.subscribe('Error', handleError);

    return () => {
      webSocketManager.unsubscribe('ChatStream', handleChatStream);
      webSocketManager.unsubscribe('ChatStreamEnd', handleChatStreamEnd);
      webSocketManager.unsubscribe('ChatResponse', handleChatResponse);
      webSocketManager.unsubscribe('Error', handleError);

      cleanup();
    };
  }, [handleChatStream, handleChatStreamEnd, handleChatResponse, handleError, cleanup]);

  // Release the flush timer and queued chunks when the component unmounts.
  useEffect(() => {
    return () => {
      cleanup();
    };
  }, [cleanup]);

  return {
    subscribeToWebSocket,
    handleChatStream,
    handleChatStreamEnd,
    handleChatResponse,
    handleError,
    cleanup,
  };
};
|