Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🐛 fix: fix the missing user id in chat compeletition and fix remove unstarred topic not working #2677

Merged
merged 3 commits into from
May 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/app/api/chat/[provider]/route.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ describe('POST handler', () => {
accessCode: 'test-access-code',
apiKey: 'test-api-key',
azureApiVersion: 'v1',
userId: 'abc',
});

const mockParams = { provider: 'test-provider' };
Expand All @@ -176,7 +177,7 @@ describe('POST handler', () => {
const response = await POST(request as unknown as Request, { params: mockParams });

expect(response).toEqual(mockChatResponse);
expect(AgentRuntime.prototype.chat).toHaveBeenCalledWith(mockChatPayload);
expect(AgentRuntime.prototype.chat).toHaveBeenCalledWith(mockChatPayload, { user: 'abc' });
});

it('should return an error response when chat completion fails', async () => {
Expand Down
15 changes: 7 additions & 8 deletions src/app/api/chat/[provider]/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,16 @@

const tracePayload = getTracePayload(req);

let traceOptions = {};
// If user enable trace
if (tracePayload?.enabled) {
return await agentRuntime.chat(
data,
createTraceOptions(data, {
provider,
trace: tracePayload,
}),
);
traceOptions = createTraceOptions(data, {
provider,
trace: tracePayload,
});

Check warning on line 34 in src/app/api/chat/[provider]/route.ts

View check run for this annotation

Codecov / codecov/patch

src/app/api/chat/[provider]/route.ts#L31-L34

Added lines #L31 - L34 were not covered by tests
}
return await agentRuntime.chat(data);

return await agentRuntime.chat(data, { user: jwtPayload.userId, ...traceOptions });
} catch (e) {
const {
errorType = ChatErrorType.InternalServerError,
Expand Down
6 changes: 6 additions & 0 deletions src/const/auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,11 @@ export interface JWTPayload {
awsAccessKeyId?: string;
awsRegion?: string;
awsSecretAccessKey?: string;
/**
* user id
* in client db mode it's a uuid
* in server db mode it's a user id
*/
userId?: string;
}
/* eslint-enable */
4 changes: 4 additions & 0 deletions src/libs/agent-runtime/types/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,10 @@ export interface ChatCompetitionOptions {
callback?: ChatStreamCallbacks;
headers?: Record<string, any>;
signal?: AbortSignal;
/**
* userId for the chat completion
*/
user?: string;
}

export interface ChatCompletionFunctions {
Expand Down
17 changes: 10 additions & 7 deletions src/libs/agent-runtime/utils/openaiCompatibleFactory/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -195,18 +195,21 @@ describe('LobeOpenAICompatibleFactory', () => {
});

describe('handlePayload option', () => {
it('should modify request payload correctly', async () => {
it('should add user in payload correctly', async () => {
const mockCreateMethod = vi.spyOn(instance['client'].chat.completions, 'create');

await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'mistralai/mistral-7b-instruct:free',
temperature: 0,
});
await instance.chat(
{
messages: [{ content: 'Hello', role: 'user' }],
model: 'mistralai/mistral-7b-instruct:free',
temperature: 0,
},
{ user: 'abc' },
);

expect(mockCreateMethod).toHaveBeenCalledWith(
expect.objectContaining({
// 根据实际的 handlePayload 函数,添加断言
user: 'abc',
}),
expect.anything(),
);
Expand Down
13 changes: 8 additions & 5 deletions src/libs/agent-runtime/utils/openaiCompatibleFactory/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,14 @@ export const LobeOpenAICompatibleFactory = ({
stream: payload.stream ?? true,
} as OpenAI.ChatCompletionCreateParamsStreaming);

const response = await this.client.chat.completions.create(postPayload, {
// https://github.com/lobehub/lobe-chat/pull/318
headers: { Accept: '*/*' },
signal: options?.signal,
});
const response = await this.client.chat.completions.create(
{ ...postPayload, user: options?.user },
{
// https://github.com/lobehub/lobe-chat/pull/318
headers: { Accept: '*/*' },
signal: options?.signal,
},
);

if (postPayload.stream) {
const [prod, useForDebug] = response.tee();
Expand Down
9 changes: 7 additions & 2 deletions src/services/_auth.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import { JWTPayload, LOBE_CHAT_AUTH_HEADER } from '@/const/auth';
import { ModelProvider } from '@/libs/agent-runtime';
import { useUserStore } from '@/store/user';
import { keyVaultsConfigSelectors, settingsSelectors } from '@/store/user/selectors';
import {
keyVaultsConfigSelectors,
settingsSelectors,
userProfileSelectors,
} from '@/store/user/selectors';
import { GlobalLLMProviderKey } from '@/types/user/settings';
import { createJWT } from '@/utils/jwt';

Expand Down Expand Up @@ -48,8 +52,9 @@ export const getProviderAuthPayload = (provider: string) => {

const createAuthTokenWithPayload = async (payload = {}) => {
const accessCode = settingsSelectors.password(useUserStore.getState());
const userId = userProfileSelectors.userId(useUserStore.getState());

return await createJWT<JWTPayload>({ accessCode, ...payload });
return await createJWT<JWTPayload>({ accessCode, userId, ...payload });
};

interface AuthParams {
Expand Down
18 changes: 12 additions & 6 deletions src/store/chat/slices/topic/action.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -384,11 +384,14 @@ describe('topic action', () => {
// Set up mock state with unstarred topics
await act(async () => {
useChatStore.setState({
topics: [
{ id: 'topic-1', favorite: false },
{ id: 'topic-2', favorite: true },
{ id: 'topic-3', favorite: false },
] as ChatTopic[],
activeId: 'abc',
topicMaps: {
abc: [
{ id: 'topic-1', favorite: false },
{ id: 'topic-2', favorite: true },
{ id: 'topic-3', favorite: false },
] as ChatTopic[],
},
});
});
const refreshTopicSpy = vi.spyOn(result.current, 'refreshTopic');
Expand Down Expand Up @@ -431,7 +434,10 @@ describe('topic action', () => {
});

// Mock the `updateTopicTitleInSummary` and `refreshTopic` for spying
const updateTopicTitleInSummarySpy = vi.spyOn(result.current, 'updateTopicTitleInSummary');
const updateTopicTitleInSummarySpy = vi.spyOn(
result.current,
'internal_updateTopicTitleInSummary',
);
const refreshTopicSpy = vi.spyOn(result.current, 'refreshTopic');

// Mock the `chatService.fetchPresetTaskResult` to simulate the AI response
Expand Down
35 changes: 17 additions & 18 deletions src/store/chat/slices/topic/action.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
// DON'T REMOVE THE FIRST LINE
import isEqual from 'fast-deep-equal';
import { t } from 'i18next';
import { produce } from 'immer';
import useSWR, { SWRResponse, mutate } from 'swr';
import { StateCreator } from 'zustand/vanilla';

Expand Down Expand Up @@ -37,19 +36,19 @@
removeAllTopics: () => Promise<void>;
removeSessionTopics: () => Promise<void>;
removeTopic: (id: string) => Promise<void>;
removeUnstarredTopic: () => void;
removeUnstarredTopic: () => Promise<void>;
saveToTopic: () => Promise<string | undefined>;
createTopic: () => Promise<string | undefined>;

autoRenameTopicTitle: (id: string) => Promise<void>;
duplicateTopic: (id: string) => Promise<void>;
summaryTopicTitle: (topicId: string, messages: ChatMessage[]) => Promise<void>;
switchTopic: (id?: string, skipRefreshMessage?: boolean) => Promise<void>;
updateTopicTitleInSummary: (id: string, title: string) => void;
updateTopicTitle: (id: string, title: string) => Promise<void>;
useFetchTopics: (sessionId: string) => SWRResponse<ChatTopic[]>;
useSearchTopics: (keywords?: string, sessionId?: string) => SWRResponse<ChatTopic[]>;

internal_updateTopicTitleInSummary: (id: string, title: string) => void;
internal_updateTopicLoading: (id: string, loading: boolean) => void;
internal_createTopic: (params: CreateTopicParams) => Promise<string>;
internal_updateTopic: (id: string, data: Partial<ChatTopic>) => Promise<void>;
Expand Down Expand Up @@ -133,18 +132,18 @@
},
// update
summaryTopicTitle: async (topicId, messages) => {
const { updateTopicTitleInSummary, internal_updateTopicLoading } = get();
const { internal_updateTopicTitleInSummary, internal_updateTopicLoading } = get();
const topic = topicSelectors.getTopicById(topicId)(get());
if (!topic) return;

updateTopicTitleInSummary(topicId, LOADING_FLAT);
internal_updateTopicTitleInSummary(topicId, LOADING_FLAT);

let output = '';

// 自动总结话题标题
await chatService.fetchPresetTaskResult({
onError: () => {
updateTopicTitleInSummary(topicId, topic.title);
internal_updateTopicTitleInSummary(topicId, topic.title);

Check warning on line 146 in src/store/chat/slices/topic/action.ts

View check run for this annotation

Codecov / codecov/patch

src/store/chat/slices/topic/action.ts#L146

Added line #L146 was not covered by tests
},
onFinish: async (text) => {
await get().internal_updateTopic(topicId, { title: text });
Expand All @@ -159,7 +158,7 @@
}
}

updateTopicTitleInSummary(topicId, output);
internal_updateTopicTitleInSummary(topicId, output);

Check warning on line 161 in src/store/chat/slices/topic/action.ts

View check run for this annotation

Codecov / codecov/patch

src/store/chat/slices/topic/action.ts#L161

Added line #L161 was not covered by tests
},
params: await chainSummaryTitle(messages),
trace: get().getCurrentTracePayload({ traceName: TraceNameMap.SummaryTopicTitle, topicId }),
Expand Down Expand Up @@ -264,15 +263,11 @@
},

// Internal process method of the topics
updateTopicTitleInSummary: (id, title) => {
const topics = produce(get().topics, (draftState) => {
const topic = draftState.find((i) => i.id === id);

if (!topic) return;
topic.title = title;
});

set({ topics }, false, n(`updateTopicTitleInSummary`, { id, title }));
internal_updateTopicTitleInSummary: (id, title) => {
get().internal_dispatchTopic(
{ type: 'updateTopic', id, value: { title } },
'updateTopicTitleInSummary',
);
},
refreshTopic: async () => {
return mutate([SWR_USE_FETCH_TOPIC, get().activeId]);
Expand Down Expand Up @@ -317,8 +312,12 @@
},

internal_dispatchTopic: (payload, action) => {
const nextTopics = topicReducer(get().topics, payload);
const nextTopics = topicReducer(topicSelectors.currentTopics(get()), payload);
const nextMap = { ...get().topicMaps, [get().activeId]: nextTopics };

// no need to update map if is the same
if (isEqual(nextMap, get().topicMaps)) return;

set({ topics: nextTopics }, false, action ?? n(`dispatchTopic/${payload.type}`));
set({ topicMaps: nextMap }, false, action ?? n(`dispatchTopic/${payload.type}`));
},
});
2 changes: 0 additions & 2 deletions src/store/chat/slices/topic/initialState.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ export interface ChatTopicState {
topicMaps: Record<string, ChatTopic[]>;
topicRenamingId?: string;
topicSearchKeywords: string;
topics: ChatTopic[];
/**
* whether topics have fetched
*/
Expand All @@ -23,6 +22,5 @@ export const initialTopicState: ChatTopicState = {
topicLoadingIds: [],
topicMaps: {},
topicSearchKeywords: '',
topics: [],
topicsInit: false,
};
2 changes: 1 addition & 1 deletion src/store/chat/slices/topic/selectors.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ describe('topicSelectors', () => {

describe('currentUnFavTopics', () => {
it('should return all unfavorited topics', () => {
const state = merge(initialStore, { topics: topicMaps.test });
const state = merge(initialStore, { topicMaps, activeId: 'test' });
const topics = topicSelectors.currentUnFavTopics(state);
expect(topics).toEqual([topicMaps.test[1]]);
});
Expand Down
3 changes: 2 additions & 1 deletion src/store/chat/slices/topic/selectors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ const searchTopics = (s: ChatStore): ChatTopic[] => s.searchTopics;
const displayTopics = (s: ChatStore): ChatTopic[] | undefined =>
s.isSearchingTopic ? searchTopics(s) : currentTopics(s);

const currentUnFavTopics = (s: ChatStore): ChatTopic[] => s.topics.filter((s) => !s.favorite);
const currentUnFavTopics = (s: ChatStore): ChatTopic[] =>
currentTopics(s)?.filter((s) => !s.favorite) || [];

const currentTopicLength = (s: ChatStore): number => currentTopics(s)?.length || 0;

Expand Down
Loading