Skip to content

Commit ebc2ffa

Browse files
authored
fix(agent): always fetch latest custom tool from DB when customToolId is present (#3208)
* fix(agent): always fetch latest custom tool from DB when customToolId is present * test(agent): use generic test data for customToolId resolution tests * fix(agent): mock buildAuthHeaders in tests for CI compatibility * remove inline mocks in favor of sim/testing ones
1 parent c380e59 commit ebc2ffa

File tree

13 files changed

+452
-158
lines changed

13 files changed

+452
-158
lines changed

apps/sim/app/api/auth/oauth/utils.test.ts

Lines changed: 42 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,10 @@
44
* @vitest-environment node
55
*/
66

7-
import { loggerMock } from '@sim/testing'
7+
import { databaseMock, loggerMock } from '@sim/testing'
88
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
99

10-
vi.mock('@sim/db', () => ({
11-
db: {
12-
select: vi.fn().mockReturnThis(),
13-
from: vi.fn().mockReturnThis(),
14-
where: vi.fn().mockReturnThis(),
15-
limit: vi.fn().mockReturnValue([]),
16-
update: vi.fn().mockReturnThis(),
17-
set: vi.fn().mockReturnThis(),
18-
orderBy: vi.fn().mockReturnThis(),
19-
},
20-
}))
10+
vi.mock('@sim/db', () => databaseMock)
2111

2212
vi.mock('@/lib/oauth/oauth', () => ({
2313
refreshOAuthToken: vi.fn(),
@@ -34,13 +24,36 @@ import {
3424
refreshTokenIfNeeded,
3525
} from '@/app/api/auth/oauth/utils'
3626

37-
const mockDbTyped = db as any
27+
const mockDb = db as any
3828
const mockRefreshOAuthToken = refreshOAuthToken as any
3929

30+
/**
31+
* Creates a chainable mock for db.select() calls.
32+
* Returns a nested chain: select() -> from() -> where() -> limit() / orderBy()
33+
*/
34+
function mockSelectChain(limitResult: unknown[]) {
35+
const mockLimit = vi.fn().mockReturnValue(limitResult)
36+
const mockOrderBy = vi.fn().mockReturnValue(limitResult)
37+
const mockWhere = vi.fn().mockReturnValue({ limit: mockLimit, orderBy: mockOrderBy })
38+
const mockFrom = vi.fn().mockReturnValue({ where: mockWhere })
39+
mockDb.select.mockReturnValueOnce({ from: mockFrom })
40+
return { mockFrom, mockWhere, mockLimit }
41+
}
42+
43+
/**
44+
* Creates a chainable mock for db.update() calls.
45+
* Returns a nested chain: update() -> set() -> where()
46+
*/
47+
function mockUpdateChain() {
48+
const mockWhere = vi.fn().mockResolvedValue({})
49+
const mockSet = vi.fn().mockReturnValue({ where: mockWhere })
50+
mockDb.update.mockReturnValueOnce({ set: mockSet })
51+
return { mockSet, mockWhere }
52+
}
53+
4054
describe('OAuth Utils', () => {
4155
beforeEach(() => {
4256
vi.clearAllMocks()
43-
mockDbTyped.limit.mockReturnValue([])
4457
})
4558

4659
afterEach(() => {
@@ -50,20 +63,20 @@ describe('OAuth Utils', () => {
5063
describe('getCredential', () => {
5164
it('should return credential when found', async () => {
5265
const mockCredential = { id: 'credential-id', userId: 'test-user-id' }
53-
mockDbTyped.limit.mockReturnValueOnce([mockCredential])
66+
const { mockFrom, mockWhere, mockLimit } = mockSelectChain([mockCredential])
5467

5568
const credential = await getCredential('request-id', 'credential-id', 'test-user-id')
5669

57-
expect(mockDbTyped.select).toHaveBeenCalled()
58-
expect(mockDbTyped.from).toHaveBeenCalled()
59-
expect(mockDbTyped.where).toHaveBeenCalled()
60-
expect(mockDbTyped.limit).toHaveBeenCalledWith(1)
70+
expect(mockDb.select).toHaveBeenCalled()
71+
expect(mockFrom).toHaveBeenCalled()
72+
expect(mockWhere).toHaveBeenCalled()
73+
expect(mockLimit).toHaveBeenCalledWith(1)
6174

6275
expect(credential).toEqual(mockCredential)
6376
})
6477

6578
it('should return undefined when credential is not found', async () => {
66-
mockDbTyped.limit.mockReturnValueOnce([])
79+
mockSelectChain([])
6780

6881
const credential = await getCredential('request-id', 'nonexistent-id', 'test-user-id')
6982

@@ -102,11 +115,12 @@ describe('OAuth Utils', () => {
102115
refreshToken: 'new-refresh-token',
103116
})
104117

118+
mockUpdateChain()
119+
105120
const result = await refreshTokenIfNeeded('request-id', mockCredential, 'credential-id')
106121

107122
expect(mockRefreshOAuthToken).toHaveBeenCalledWith('google', 'refresh-token')
108-
expect(mockDbTyped.update).toHaveBeenCalled()
109-
expect(mockDbTyped.set).toHaveBeenCalled()
123+
expect(mockDb.update).toHaveBeenCalled()
110124
expect(result).toEqual({ accessToken: 'new-token', refreshed: true })
111125
})
112126

@@ -152,7 +166,7 @@ describe('OAuth Utils', () => {
152166
providerId: 'google',
153167
userId: 'test-user-id',
154168
}
155-
mockDbTyped.limit.mockReturnValueOnce([mockCredential])
169+
mockSelectChain([mockCredential])
156170

157171
const token = await refreshAccessTokenIfNeeded('credential-id', 'test-user-id', 'request-id')
158172

@@ -169,7 +183,8 @@ describe('OAuth Utils', () => {
169183
providerId: 'google',
170184
userId: 'test-user-id',
171185
}
172-
mockDbTyped.limit.mockReturnValueOnce([mockCredential])
186+
mockSelectChain([mockCredential])
187+
mockUpdateChain()
173188

174189
mockRefreshOAuthToken.mockResolvedValueOnce({
175190
accessToken: 'new-token',
@@ -180,13 +195,12 @@ describe('OAuth Utils', () => {
180195
const token = await refreshAccessTokenIfNeeded('credential-id', 'test-user-id', 'request-id')
181196

182197
expect(mockRefreshOAuthToken).toHaveBeenCalledWith('google', 'refresh-token')
183-
expect(mockDbTyped.update).toHaveBeenCalled()
184-
expect(mockDbTyped.set).toHaveBeenCalled()
198+
expect(mockDb.update).toHaveBeenCalled()
185199
expect(token).toBe('new-token')
186200
})
187201

188202
it('should return null if credential not found', async () => {
189-
mockDbTyped.limit.mockReturnValueOnce([])
203+
mockSelectChain([])
190204

191205
const token = await refreshAccessTokenIfNeeded('nonexistent-id', 'test-user-id', 'request-id')
192206

@@ -202,7 +216,7 @@ describe('OAuth Utils', () => {
202216
providerId: 'google',
203217
userId: 'test-user-id',
204218
}
205-
mockDbTyped.limit.mockReturnValueOnce([mockCredential])
219+
mockSelectChain([mockCredential])
206220

207221
mockRefreshOAuthToken.mockResolvedValueOnce(null)
208222

apps/sim/app/api/knowledge/search/utils.test.ts

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,12 @@
44
*
55
* @vitest-environment node
66
*/
7-
import { createEnvMock, createMockLogger } from '@sim/testing'
7+
import { createEnvMock, databaseMock, loggerMock } from '@sim/testing'
88
import { beforeEach, describe, expect, it, vi } from 'vitest'
99

10-
const loggerMock = vi.hoisted(() => ({
11-
createLogger: () => createMockLogger(),
12-
}))
13-
1410
vi.mock('drizzle-orm')
1511
vi.mock('@sim/logger', () => loggerMock)
16-
vi.mock('@sim/db')
12+
vi.mock('@sim/db', () => databaseMock)
1713
vi.mock('@/lib/knowledge/documents/utils', () => ({
1814
retryWithExponentialBackoff: (fn: any) => fn(),
1915
}))

apps/sim/app/api/schedules/[id]/route.test.ts

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,14 @@
33
*
44
* @vitest-environment node
55
*/
6-
import { loggerMock } from '@sim/testing'
6+
import { databaseMock, loggerMock } from '@sim/testing'
77
import { NextRequest } from 'next/server'
88
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
99

10-
const { mockGetSession, mockAuthorizeWorkflowByWorkspacePermission, mockDbSelect, mockDbUpdate } =
11-
vi.hoisted(() => ({
12-
mockGetSession: vi.fn(),
13-
mockAuthorizeWorkflowByWorkspacePermission: vi.fn(),
14-
mockDbSelect: vi.fn(),
15-
mockDbUpdate: vi.fn(),
16-
}))
10+
const { mockGetSession, mockAuthorizeWorkflowByWorkspacePermission } = vi.hoisted(() => ({
11+
mockGetSession: vi.fn(),
12+
mockAuthorizeWorkflowByWorkspacePermission: vi.fn(),
13+
}))
1714

1815
vi.mock('@/lib/auth', () => ({
1916
getSession: mockGetSession,
@@ -23,12 +20,7 @@ vi.mock('@/lib/workflows/utils', () => ({
2320
authorizeWorkflowByWorkspacePermission: mockAuthorizeWorkflowByWorkspacePermission,
2421
}))
2522

26-
vi.mock('@sim/db', () => ({
27-
db: {
28-
select: mockDbSelect,
29-
update: mockDbUpdate,
30-
},
31-
}))
23+
vi.mock('@sim/db', () => databaseMock)
3224

3325
vi.mock('@sim/db/schema', () => ({
3426
workflow: { id: 'id', userId: 'userId', workspaceId: 'workspaceId' },
@@ -59,6 +51,9 @@ function createParams(id: string): { params: Promise<{ id: string }> } {
5951
return { params: Promise.resolve({ id }) }
6052
}
6153

54+
const mockDbSelect = databaseMock.db.select as ReturnType<typeof vi.fn>
55+
const mockDbUpdate = databaseMock.db.update as ReturnType<typeof vi.fn>
56+
6257
function mockDbChain(selectResults: unknown[][]) {
6358
let selectCallIndex = 0
6459
mockDbSelect.mockImplementation(() => ({

apps/sim/app/api/schedules/route.test.ts

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,14 @@
33
*
44
* @vitest-environment node
55
*/
6-
import { loggerMock } from '@sim/testing'
6+
import { databaseMock, loggerMock } from '@sim/testing'
77
import { NextRequest } from 'next/server'
88
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
99

10-
const { mockGetSession, mockAuthorizeWorkflowByWorkspacePermission, mockDbSelect } = vi.hoisted(
11-
() => ({
12-
mockGetSession: vi.fn(),
13-
mockAuthorizeWorkflowByWorkspacePermission: vi.fn(),
14-
mockDbSelect: vi.fn(),
15-
})
16-
)
10+
const { mockGetSession, mockAuthorizeWorkflowByWorkspacePermission } = vi.hoisted(() => ({
11+
mockGetSession: vi.fn(),
12+
mockAuthorizeWorkflowByWorkspacePermission: vi.fn(),
13+
}))
1714

1815
vi.mock('@/lib/auth', () => ({
1916
getSession: mockGetSession,
@@ -23,11 +20,7 @@ vi.mock('@/lib/workflows/utils', () => ({
2320
authorizeWorkflowByWorkspacePermission: mockAuthorizeWorkflowByWorkspacePermission,
2421
}))
2522

26-
vi.mock('@sim/db', () => ({
27-
db: {
28-
select: mockDbSelect,
29-
},
30-
}))
23+
vi.mock('@sim/db', () => databaseMock)
3124

3225
vi.mock('@sim/db/schema', () => ({
3326
workflow: { id: 'id', userId: 'userId', workspaceId: 'workspaceId' },
@@ -62,6 +55,8 @@ function createRequest(url: string): NextRequest {
6255
return new NextRequest(new URL(url), { method: 'GET' })
6356
}
6457

58+
const mockDbSelect = databaseMock.db.select as ReturnType<typeof vi.fn>
59+
6560
function mockDbChain(results: any[]) {
6661
let callIndex = 0
6762
mockDbSelect.mockImplementation(() => ({

apps/sim/app/api/workflows/[id]/route.test.ts

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
* @vitest-environment node
66
*/
77

8-
import { loggerMock } from '@sim/testing'
8+
import { loggerMock, setupGlobalFetchMock } from '@sim/testing'
99
import { NextRequest } from 'next/server'
1010
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
1111

@@ -284,9 +284,7 @@ describe('Workflow By ID API Route', () => {
284284
where: vi.fn().mockResolvedValue([{ id: 'workflow-123' }]),
285285
})
286286

287-
global.fetch = vi.fn().mockResolvedValue({
288-
ok: true,
289-
})
287+
setupGlobalFetchMock({ ok: true })
290288

291289
const req = new NextRequest('http://localhost:3000/api/workflows/workflow-123', {
292290
method: 'DELETE',
@@ -331,9 +329,7 @@ describe('Workflow By ID API Route', () => {
331329
where: vi.fn().mockResolvedValue([{ id: 'workflow-123' }]),
332330
})
333331

334-
global.fetch = vi.fn().mockResolvedValue({
335-
ok: true,
336-
})
332+
setupGlobalFetchMock({ ok: true })
337333

338334
const req = new NextRequest('http://localhost:3000/api/workflows/workflow-123', {
339335
method: 'DELETE',

0 commit comments

Comments
 (0)