Compare commits

...

1 Commits

Author SHA1 Message Date
Innei 24da7363c3 🐛 fix: avoid premature keep-mounted chat items 2026-04-24 16:21:43 +08:00
3 changed files with 265 additions and 32 deletions
@@ -2,7 +2,7 @@
import isEqual from 'fast-deep-equal';
import { type ReactElement, type ReactNode } from 'react';
import { memo, useCallback, useEffect, useMemo, useRef } from 'react';
import { memo, useCallback, useEffect, useRef } from 'react';
import { type VListHandle } from 'virtua';
import { VList } from 'virtua';
import { useShallow } from 'zustand/react/shallow';
@@ -18,6 +18,7 @@ import {
CONVERSATION_SPACER_TRANSITION_MS,
useConversationSpacer,
} from '../hooks/useConversationSpacer';
import { useKeepMountedIndices } from '../hooks/useKeepMountedIndices';
import { useScrollToUserMessage } from '../hooks/useScrollToUserMessage';
import { useSelectionMessageIds } from '../hooks/useSelectionMessageIds';
import AutoScroll from './AutoScroll';
@@ -60,6 +61,31 @@ const VirtualizedList = memo<VirtualizedListProps>(({ dataSource, itemContent })
const setActiveIndex = useConversationStore((s) => s.setActiveIndex);
const activeIndex = useConversationStore(virtuaListSelectors.activeIndex);
// Keep a streaming item mounted only after it has entered the viewport once.
// Pre-mounting a new offscreen assistant item via keepMounted changes virtua's
// measurement/scroll scheduling and can block the send-time user-message pin.
const streamingMessageIds = useConversationStore(
useShallow((s) => {
const ids: string[] = [];
for (const id of dataSource) {
if (messageStateSelectors.isMessageGenerating(id)(s)) ids.push(id);
}
return ids;
}),
);
// Also keep items that host the active text selection — unmounting a node
// containing a Selection endpoint would silently drop the user's highlight.
const selectionMessageIds = useSelectionMessageIds();
const getVirtua = useCallback(() => virtuaRef.current, []);
const { keepMountedIndices, trackVisibleStreamingItems } = useKeepMountedIndices({
dataSource,
getVirtua,
selectionMessageIds,
streamingMessageIds,
});
// Check if at bottom based on scroll position
const checkAtBottom = useCallback(() => {
const ref = virtuaRef.current;
@@ -83,6 +109,7 @@ const VirtualizedList = memo<VirtualizedListProps>(({ dataSource, itemContent })
typeof activeFromFindRaw === 'number' && activeFromFindRaw >= 0 ? activeFromFindRaw : null;
if (activeFromFind !== activeIndex) setActiveIndex(activeFromFind);
trackVisibleStreamingItems();
setScrollState({ isScrolling: true });
@@ -105,11 +132,19 @@ const VirtualizedList = memo<VirtualizedListProps>(({ dataSource, itemContent })
scrollEndTimerRef.current = setTimeout(() => {
setScrollState({ isScrolling: false });
}, 150);
}, [activeIndex, checkAtBottom, handleScrollOffset, setActiveIndex, setScrollState]);
}, [
activeIndex,
checkAtBottom,
handleScrollOffset,
setActiveIndex,
setScrollState,
trackVisibleStreamingItems,
]);
const handleScrollEnd = useCallback(() => {
trackVisibleStreamingItems();
setScrollState({ isScrolling: false });
}, [setScrollState]);
}, [setScrollState, trackVisibleStreamingItems]);
// Register scroll methods to store on mount
useEffect(() => {
@@ -153,35 +188,6 @@ const VirtualizedList = memo<VirtualizedListProps>(({ dataSource, itemContent })
const secondLastMessage = displayMessages.at(-2);
const isSecondLastMessageFromUser = secondLastMessage?.role === 'user';
// Keep currently-streaming items mounted so vlist recycling never triggers
// Markdown animation replay when the user scrolls them back into view.
const streamingIndices = useConversationStore(
useShallow((s) => {
const indices: number[] = [];
for (let i = 0; i < dataSource.length; i++) {
const id = dataSource[i];
if (!id) continue;
if (messageStateSelectors.isMessageGenerating(id)(s)) indices.push(i);
}
return indices;
}),
);
// Also keep items that host the active text selection — unmounting a node
// containing a Selection endpoint would silently drop the user's highlight.
const selectionMessageIds = useSelectionMessageIds();
const keepMountedIndices = useMemo(() => {
if (selectionMessageIds.size === 0) return streamingIndices;
const merged = new Set<number>(streamingIndices);
for (let i = 0; i < dataSource.length; i++) {
const id = dataSource[i];
if (id && selectionMessageIds.has(id)) merged.add(i);
}
if (merged.size === streamingIndices.length) return streamingIndices;
return [...merged].sort((a, b) => a - b);
}, [dataSource, streamingIndices, selectionMessageIds]);
// Auto scroll to user message when user sends a new message
// Only scroll when 2 new messages are added and second-to-last is from user
useScrollToUserMessage({
@@ -0,0 +1,60 @@
import { describe, expect, it } from 'vitest';
import {
getKeepMountedIndices,
getVisibleStreamingMessageIds,
prunePinnedStreamingMessageIds,
} from './useKeepMountedIndices';
describe('useKeepMountedIndices helpers', () => {
it('should not pass keepMounted when no message needs pinning', () => {
expect(
getKeepMountedIndices({
dataSource: ['user-1', 'assistant-1'],
pinnedStreamingMessageIds: new Set(),
selectionMessageIds: new Set(),
}),
).toBeUndefined();
});
it('should keep only streaming messages that have already been pinned', () => {
expect(
getKeepMountedIndices({
dataSource: ['user-1', 'assistant-1', 'user-2', 'assistant-2'],
pinnedStreamingMessageIds: new Set(['assistant-1']),
selectionMessageIds: new Set(),
}),
).toEqual([1]);
});
it('should keep selected messages independently from streaming pins', () => {
expect(
getKeepMountedIndices({
dataSource: ['user-1', 'assistant-1', 'user-2'],
pinnedStreamingMessageIds: new Set(['assistant-1']),
selectionMessageIds: new Set(['user-2']),
}),
).toEqual([1, 2]);
});
it('should only mark streaming messages inside the current viewport range as visible', () => {
expect(
getVisibleStreamingMessageIds({
dataSource: ['user-1', 'assistant-1', 'user-2', 'assistant-2'],
endIndex: 1,
startIndex: 0,
streamingMessageIds: new Set(['assistant-1', 'assistant-2']),
}),
).toEqual(new Set(['assistant-1']));
});
it('should prune pins once messages stop streaming or leave the data source', () => {
expect(
prunePinnedStreamingMessageIds({
dataSource: ['user-1', 'assistant-1', 'user-2'],
pinnedStreamingMessageIds: new Set(['assistant-1', 'assistant-2']),
streamingMessageIds: new Set(['assistant-2']),
}),
).toEqual(new Set());
});
});
@@ -0,0 +1,167 @@
import { useCallback, useEffect, useMemo, useState } from 'react';
interface VirtuaViewport {
findItemIndex: (offset: number) => number;
scrollOffset: number;
viewportSize: number;
}
interface GetKeepMountedIndicesOptions {
dataSource: readonly string[];
pinnedStreamingMessageIds: ReadonlySet<string>;
selectionMessageIds: ReadonlySet<string>;
}
interface GetVisibleStreamingMessageIdsOptions {
dataSource: readonly string[];
endIndex: number | null;
startIndex: number | null;
streamingMessageIds: ReadonlySet<string>;
}
interface PrunePinnedStreamingMessageIdsOptions {
dataSource: readonly string[];
pinnedStreamingMessageIds: ReadonlySet<string>;
streamingMessageIds: ReadonlySet<string>;
}
interface UseKeepMountedIndicesOptions {
dataSource: readonly string[];
getVirtua: () => VirtuaViewport | null;
selectionMessageIds: ReadonlySet<string>;
streamingMessageIds: readonly string[];
}
const setsEqual = (a: ReadonlySet<string>, b: ReadonlySet<string>) => {
if (a.size !== b.size) return false;
for (const value of a) {
if (!b.has(value)) return false;
}
return true;
};
export const getKeepMountedIndices = ({
dataSource,
pinnedStreamingMessageIds,
selectionMessageIds,
}: GetKeepMountedIndicesOptions) => {
const keepMountedIndices: number[] = [];
for (const [index, id] of dataSource.entries()) {
if (pinnedStreamingMessageIds.has(id) || selectionMessageIds.has(id)) {
keepMountedIndices.push(index);
}
}
return keepMountedIndices.length === 0 ? undefined : keepMountedIndices;
};
export const getVisibleStreamingMessageIds = ({
dataSource,
endIndex,
startIndex,
streamingMessageIds,
}: GetVisibleStreamingMessageIdsOptions) => {
const visibleStreamingMessageIds = new Set<string>();
if (streamingMessageIds.size === 0 || startIndex === null || endIndex === null) {
return visibleStreamingMessageIds;
}
const start = Math.max(0, Math.min(startIndex, endIndex));
const end = Math.min(dataSource.length - 1, Math.max(startIndex, endIndex));
if (end < start) return visibleStreamingMessageIds;
for (const [index, id] of dataSource.entries()) {
if (index < start) continue;
if (index > end) break;
if (streamingMessageIds.has(id)) visibleStreamingMessageIds.add(id);
}
return visibleStreamingMessageIds;
};
export const prunePinnedStreamingMessageIds = ({
dataSource,
pinnedStreamingMessageIds,
streamingMessageIds,
}: PrunePinnedStreamingMessageIdsOptions) => {
if (pinnedStreamingMessageIds.size === 0) return pinnedStreamingMessageIds;
const dataSourceIds = new Set(dataSource);
const nextPinnedIds = new Set<string>();
for (const id of pinnedStreamingMessageIds) {
if (dataSourceIds.has(id) && streamingMessageIds.has(id)) nextPinnedIds.add(id);
}
return setsEqual(pinnedStreamingMessageIds, nextPinnedIds)
? pinnedStreamingMessageIds
: nextPinnedIds;
};
export const useKeepMountedIndices = ({
dataSource,
getVirtua,
selectionMessageIds,
streamingMessageIds,
}: UseKeepMountedIndicesOptions) => {
const streamingMessageIdSet = useMemo(() => new Set(streamingMessageIds), [streamingMessageIds]);
const [pinnedStreamingMessageIds, setPinnedStreamingMessageIds] = useState<ReadonlySet<string>>(
() => new Set(),
);
useEffect(() => {
setPinnedStreamingMessageIds((prev) =>
prunePinnedStreamingMessageIds({
dataSource,
pinnedStreamingMessageIds: prev,
streamingMessageIds: streamingMessageIdSet,
}),
);
}, [dataSource, streamingMessageIdSet]);
const trackVisibleStreamingItems = useCallback(() => {
const virtua = getVirtua();
if (!virtua || streamingMessageIdSet.size === 0) return;
const visibleStreamingMessageIds = getVisibleStreamingMessageIds({
dataSource,
endIndex: virtua.findItemIndex(virtua.scrollOffset + virtua.viewportSize),
startIndex: virtua.findItemIndex(virtua.scrollOffset),
streamingMessageIds: streamingMessageIdSet,
});
if (visibleStreamingMessageIds.size === 0) return;
setPinnedStreamingMessageIds((prev) => {
const nextPinnedIds = new Set(prev);
for (const id of visibleStreamingMessageIds) {
nextPinnedIds.add(id);
}
return setsEqual(prev, nextPinnedIds) ? prev : nextPinnedIds;
});
}, [dataSource, getVirtua, streamingMessageIdSet]);
useEffect(() => {
trackVisibleStreamingItems();
}, [trackVisibleStreamingItems]);
const keepMountedIndices = useMemo(
() =>
getKeepMountedIndices({
dataSource,
pinnedStreamingMessageIds,
selectionMessageIds,
}),
[dataSource, pinnedStreamingMessageIds, selectionMessageIds],
);
return {
keepMountedIndices,
trackVisibleStreamingItems,
};
};