Skip to content

Commit 1195fd0

Browse files
authored
Merge pull request #110 from will-lp1/persist-tool-state
Persist tool state
2 parents 03e88dd + 4f06a4f commit 1195fd0

File tree

5 files changed

+227
-73
lines changed

5 files changed

+227
-73
lines changed

apps/saru/app/api/messages/route.ts

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import { NextResponse } from 'next/server';
22
import { auth } from "@/lib/auth";
33
import { headers } from 'next/headers';
4-
import { getMessagesByChatId, getChatById } from '@/lib/db/queries';
4+
import { getMessagesByChatId, getChatById, getMessageById, updateToolMetadata } from '@/lib/db/queries';
55

66
export async function GET(request: Request) {
77
try {
@@ -45,4 +45,54 @@ export async function GET(request: Request) {
4545
console.error('Error fetching messages:', error);
4646
return NextResponse.json({ error: 'Error fetching messages' }, { status: 500 });
4747
}
48-
}
48+
}
49+
50+
export async function PATCH(request: Request) {
51+
try {
52+
const readonlyHeaders = await headers();
53+
const requestHeaders = new Headers(readonlyHeaders);
54+
const session = await auth.api.getSession({ headers: requestHeaders });
55+
56+
if (!session?.user?.id) {
57+
console.error('Session error in PATCH /api/messages');
58+
return NextResponse.json({ error: 'Authentication error' }, { status: 401 });
59+
}
60+
const userId = session.user.id;
61+
62+
const body = await request.json();
63+
const { messageId, chatId, toolCallId, applied, rejected } = body;
64+
65+
if (!messageId || !chatId || !toolCallId) {
66+
return NextResponse.json(
67+
{ error: 'messageId, chatId, and toolCallId are required' },
68+
{ status: 400 }
69+
);
70+
}
71+
72+
// Verify the chat belongs to the user
73+
const chat = await getChatById({ id: chatId });
74+
if (!chat) {
75+
return NextResponse.json({ error: 'Chat not found' }, { status: 404 });
76+
}
77+
if (chat.userId !== userId) {
78+
return NextResponse.json({ error: 'Unauthorized' }, { status: 403 });
79+
}
80+
81+
// Verify the message exists and belongs to the chat
82+
const message = await getMessageById({ id: messageId });
83+
if (!message) {
84+
return NextResponse.json({ error: 'Message not found' }, { status: 404 });
85+
}
86+
if (message.chatId !== chatId) {
87+
return NextResponse.json({ error: 'Message does not belong to this chat' }, { status: 403 });
88+
}
89+
90+
// Update the tool metadata
91+
await updateToolMetadata({ messageId, toolCallId, applied, rejected });
92+
93+
return NextResponse.json({ success: true });
94+
} catch (error) {
95+
console.error('Error updating message:', error);
96+
return NextResponse.json({ error: 'Error updating message' }, { status: 500 });
97+
}
98+
}

apps/saru/components/chat/message.tsx

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,8 @@ const PurePreviewMessage = ({
232232
type={actionType}
233233
result={{ ...result, toolCallId: p.toolCallId } as DocumentToolResultProps['result']}
234234
isReadonly={isReadonly}
235+
chatId={chatId}
236+
messageId={message.id}
235237
/>
236238
);
237239
}

apps/saru/components/document/document-tool.tsx

Lines changed: 46 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -106,24 +106,25 @@ export interface DocumentToolResultProps {
106106
content?: string;
107107
message?: string;
108108
toolCallId?: string;
109+
applied?: boolean;
110+
rejected?: boolean;
109111
};
110112
isReadonly: boolean;
113+
chatId?: string;
114+
messageId?: string;
111115
}
112116

113117
function PureDocumentToolResult({
114118
type,
115119
result,
116120
isReadonly,
121+
chatId,
122+
messageId,
117123
}: DocumentToolResultProps) {
118124
const { document, setDocument } = useDocument();
119125
const [isSaving, setIsSaving] = useState(false);
120-
const [isApplied, setIsApplied] = useState(() => {
121-
if (type === 'update' && result.id && document.documentId === result.id) {
122-
return document.content === result.newContent;
123-
}
124-
return false;
125-
});
126-
const [isRejected, setIsRejected] = useState(false);
126+
const [isApplied, setIsApplied] = useState(() => result.applied ?? false);
127+
const [isRejected, setIsRejected] = useState(() => result.rejected ?? false);
127128

128129
const isUpdateProposal =
129130
type === 'update' &&
@@ -132,7 +133,7 @@ function PureDocumentToolResult({
132133
result.originalContent !== result.newContent;
133134

134135
useEffect(() => {
135-
if (isUpdateProposal && result.id && result.newContent) {
136+
if (isUpdateProposal && result.id && result.newContent && !isApplied && !isRejected) {
136137
const event = new CustomEvent('preview-document-update', {
137138
detail: {
138139
documentId: result.id,
@@ -142,7 +143,7 @@ function PureDocumentToolResult({
142143
});
143144
window.dispatchEvent(event);
144145
}
145-
}, [isUpdateProposal, result.id, result.newContent, result.originalContent]);
146+
}, [isUpdateProposal, result.id, result.newContent, result.originalContent, isApplied, isRejected]);
146147

147148
const handleApplyUpdate = useCallback(() => {
148149
if (type !== 'update' || !result.newContent || !result.id || isSaving) return;
@@ -171,6 +172,24 @@ function PureDocumentToolResult({
171172
detail: { documentId: result.id, newContent: result.newContent, transient: false },
172173
}));
173174
setIsApplied(true);
175+
176+
// Persist applied state to chat metadata
177+
if (chatId && messageId && result.toolCallId) {
178+
fetch('/api/messages', {
179+
method: 'PATCH',
180+
headers: { 'Content-Type': 'application/json' },
181+
body: JSON.stringify({
182+
messageId,
183+
chatId,
184+
toolCallId: result.toolCallId,
185+
applied: true,
186+
rejected: false,
187+
}),
188+
}).catch(err => {
189+
console.error('[DocumentToolResult] Failed to persist applied state:', err);
190+
});
191+
}
192+
174193
fetch('/api/document', {
175194
method: 'POST',
176195
headers: { 'Content-Type': 'application/json' },
@@ -196,6 +215,23 @@ function PureDocumentToolResult({
196215

197216
setIsRejected(true);
198217

218+
// Persist rejected state to chat metadata
219+
if (chatId && messageId && result.toolCallId) {
220+
fetch('/api/messages', {
221+
method: 'PATCH',
222+
headers: { 'Content-Type': 'application/json' },
223+
body: JSON.stringify({
224+
messageId,
225+
chatId,
226+
toolCallId: result.toolCallId,
227+
applied: false,
228+
rejected: true,
229+
}),
230+
}).catch(err => {
231+
console.error('[DocumentToolResult] Failed to persist rejected state:', err);
232+
});
233+
}
234+
199235
window.dispatchEvent(new CustomEvent('tool-result', {
200236
detail: {
201237
toolCallId: result.toolCallId,
@@ -219,7 +255,7 @@ function PureDocumentToolResult({
219255
});
220256
window.dispatchEvent(event);
221257
toast.info('Update proposal rejected.');
222-
}, [result.id, result.title, result.originalContent, result.newContent, result.toolCallId, type]);
258+
}, [result.id, result.title, result.originalContent, result.newContent, result.toolCallId, type, chatId, messageId]);
223259

224260
if (result.error) {
225261
return (

apps/saru/components/document/version-rail.tsx

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -69,13 +69,19 @@ export function VersionRail({ versions, currentIndex, onIndexChange, baseDocumen
6969
});
7070
},[versions]);
7171
React.useEffect(() => {
72-
if (currentIndex === versions.length - 1) {
73-
setSelectedIndex(currentIndex);
74-
} else {
75-
setSelectedIndex(currentIndex);
76-
}
77-
}, [currentIndex, versions.length]);
78-
72+
setSelectedIndex(currentIndex);
73+
}, [currentIndex]);
74+
75+
React.useEffect(() => {
76+
return () => {
77+
window.dispatchEvent(
78+
new CustomEvent('cancel-document-update', {
79+
detail: { documentId: baseDocumentId },
80+
})
81+
);
82+
};
83+
}, [baseDocumentId]);
84+
7985
const isViewingHistory = selectedIndex !== null && selectedIndex < versions.length - 1;
8086
if (isLoading || versions.length <= 1) {
8187
return <div className="w-full border-b bg-background h-1 group-hover:h-12 transition-all duration-200" />;
@@ -191,24 +197,22 @@ export function VersionRail({ versions, currentIndex, onIndexChange, baseDocumen
191197

192198
const handlePointerLeave = () => {
193199
setPressStart(null);
194-
if (hoverIndex !== null) {
195-
if (isViewingHistory && selectedIndex !== null && selectedIndex < versions.length - 1) {
196-
const v = versions[selectedIndex];
197-
window.dispatchEvent(
198-
new CustomEvent('preview-document-update', {
199-
detail: { documentId: baseDocumentId, newContent: v.content },
200-
})
201-
);
202-
} else {
203-
window.dispatchEvent(
204-
new CustomEvent('cancel-document-update', {
205-
detail: { documentId: baseDocumentId },
206-
})
207-
);
208-
}
209-
setHoverIndex(null);
210-
}
200+
setHoverIndex(null);
211201

202+
if (selectedIndex !== null && selectedIndex < versions.length - 1) {
203+
const v = versions[selectedIndex];
204+
window.dispatchEvent(
205+
new CustomEvent('preview-document-update', {
206+
detail: { documentId: baseDocumentId, newContent: v.content },
207+
})
208+
);
209+
} else {
210+
window.dispatchEvent(
211+
new CustomEvent('cancel-document-update', {
212+
detail: { documentId: baseDocumentId },
213+
})
214+
);
215+
}
212216
};
213217

214218
const Tooltip = ({ active, payload }: { active?: boolean; payload?: Array<{ payload: { ts: string; additions: number; deletions: number } }> }) => {

0 commit comments

Comments
 (0)