From 46b05431bbd357e6712a5a2e6a70c466cbf4a008 Mon Sep 17 00:00:00 2001 From: Max Schmitt Date: Sat, 2 May 2026 16:45:26 -0700 Subject: [PATCH] fix(mcp): propagate abort signal to stop test execution Fixes https://github.com/microsoft/playwright/issues/38937 --- .../playwright/src/mcp/test/generatorTools.ts | 4 +- .../playwright/src/mcp/test/plannerTools.ts | 4 +- .../playwright/src/mcp/test/testBackend.ts | 6 +-- .../playwright/src/mcp/test/testContext.ts | 35 +++++++++++++++-- packages/playwright/src/mcp/test/testTool.ts | 2 +- packages/playwright/src/mcp/test/testTools.ts | 8 ++-- tests/mcp/test-run.spec.ts | 38 +++++++++++++++++++ 7 files changed, 81 insertions(+), 16 deletions(-) diff --git a/packages/playwright/src/mcp/test/generatorTools.ts b/packages/playwright/src/mcp/test/generatorTools.ts index 4552ae888f8ee..1bd75f3b31610 100644 --- a/packages/playwright/src/mcp/test/generatorTools.ts +++ b/packages/playwright/src/mcp/test/generatorTools.ts @@ -36,10 +36,10 @@ export const setupPage = defineTestTool({ type: 'readOnly', }, - handle: async (context, params) => { + handle: async (context, params, signal) => { const seed = await context.getOrCreateSeedFile(params.seedFile, params.project); context.generatorJournal = new GeneratorJournal(context.rootPath, params.plan, seed); - const { output, status } = await context.runSeedTest(seed.file, seed.projectName); + const { output, status } = await context.runSeedTest(seed.file, seed.projectName, signal); return { content: [{ type: 'text', text: output }], isError: status !== 'paused' }; }, }); diff --git a/packages/playwright/src/mcp/test/plannerTools.ts b/packages/playwright/src/mcp/test/plannerTools.ts index 800afd229e670..a0360c8f11c5b 100644 --- a/packages/playwright/src/mcp/test/plannerTools.ts +++ b/packages/playwright/src/mcp/test/plannerTools.ts @@ -34,9 +34,9 @@ export const setupPage = defineTestTool({ type: 'readOnly', }, - handle: async (context, params) => { + handle: async (context, params, signal) => { const seed = await context.getOrCreateSeedFile(params.seedFile, params.project); - const { output, status } = await context.runSeedTest(seed.file, seed.projectName); + const { output, status } = await context.runSeedTest(seed.file, seed.projectName, signal); return { content: [{ type: 'text', text: output }], isError: status !== 'paused' }; }, }); diff --git a/packages/playwright/src/mcp/test/testBackend.ts b/packages/playwright/src/mcp/test/testBackend.ts index 8b40a509047dc..a480860e9c3d5 100644 --- a/packages/playwright/src/mcp/test/testBackend.ts +++ b/packages/playwright/src/mcp/test/testBackend.ts @@ -58,12 +58,12 @@ export class TestServerBackend extends EventEmitter implements tools.ServerBacke this._context = new TestContext(clientInfo, this._configPath, this._options); } - async callTool(name: string, args: tools.CallToolRequest['params']['arguments']): Promise { + async callTool(name: string, args: tools.CallToolRequest['params']['arguments'], signal: AbortSignal): Promise { const tool = testServerBackendTools.find(tool => tool.schema.name === name); if (!tool) throw new Error(`Tool not found: ${name}. Available tools: ${testServerBackendTools.map(tool => tool.schema.name).join(', ')}`); try { - return await tool.handle(this._context!, tool.schema.inputSchema.parse(args || {})); + return await tool.handle(this._context!, tool.schema.inputSchema.parse(args || {}), signal); } catch (e) { return { content: [{ type: 'text', text: String(e) }], isError: true }; } @@ -83,7 +83,7 @@ function wrapBrowserTool(tool: tools.Tool): TestTool { ...tool.schema, inputSchema, }, - handle: async (context: TestContext, params: any) => { + handle: async (context: TestContext, params: any, _signal?: AbortSignal) => { const response = await context.sendMessageToPausedTest({ callTool: { name: tool.schema.name, arguments: params } }); return response.callTool!; }, diff --git a/packages/playwright/src/mcp/test/testContext.ts b/packages/playwright/src/mcp/test/testContext.ts index 5057e358226e3..957ac3144cf70 100644 --- a/packages/playwright/src/mcp/test/testContext.ts +++ b/packages/playwright/src/mcp/test/testContext.ts @@ -88,6 +88,7 @@ type TestRunnerAndScreen = { export class TestContext { private _clientInfo: tools.ClientInfo; private _testRunnerAndScreen: TestRunnerAndScreen | undefined; + private _testOpQueue: Promise = Promise.resolve(); readonly computedHeaded: boolean; private readonly _configLocation: ConfigLocation; readonly rootPath: string; @@ -105,6 +106,12 @@ export class TestContext { this.computedHeaded = !process.env.CI && !(os.platform() === 'linux' && !process.env.DISPLAY); } + private _enqueue(fn: () => Promise): Promise { + const result = this._testOpQueue.then(fn); + this._testOpQueue = result.then(() => {}, () => {}); + return result; + } + existingTestRunner(): testRunner.TestRunner | undefined { return this._testRunnerAndScreen?.testRunner; } @@ -176,7 +183,7 @@ export class TestContext { }; } - async runSeedTest(seedFile: string, projectName: string): Promise<{ output: string, status: testRunner.FullResultStatus | 'paused' }> { + async runSeedTest(seedFile: string, projectName: string, signal?: AbortSignal): Promise<{ output: string, status: testRunner.FullResultStatus | 'paused' }> { const result = await this.runTestsWithGlobalSetupAndPossiblePause({ headed: this.computedHeaded, locations: ['/' + escapeRegExp(seedFile) + '/'], @@ -186,7 +193,7 @@ export class TestContext { pauseAtEnd: true, disableConfigReporters: true, failOnLoadErrors: true, - }); + }, signal); if (result.status === 'passed') result.output += '\nError: seed test not found.'; else if (result.status !== 'paused') @@ -194,7 +201,11 @@ export class TestContext { return result; } - async runTestsWithGlobalSetupAndPossiblePause(params: testRunner.RunTestsParams): Promise<{ output: string, status: testRunner.FullResultStatus | 'paused' }> { + async runTestsWithGlobalSetupAndPossiblePause(params: testRunner.RunTestsParams, signal?: AbortSignal): Promise<{ output: string, status: testRunner.FullResultStatus | 'paused' }> { + return this._enqueue(() => this._runTestsImpl(params, signal)); + } + + private async _runTestsImpl(params: testRunner.RunTestsParams, signal?: AbortSignal): Promise<{ output: string, status: testRunner.FullResultStatus | 'paused' }> { const configDir = this._configLocation.configDir; const testRunnerAndScreen = await this.createTestRunner(); const { testRunner: runner, screen, claimStdio, releaseStdio } = testRunnerAndScreen; @@ -222,13 +233,29 @@ export class TestContext { } }; + const abortPromise: Promise<'interrupted'> = signal + ? new Promise(resolve => { + if (signal.aborted) + resolve('interrupted'); + else + signal.addEventListener('abort', () => resolve('interrupted'), { once: true }); + }) + : new Promise(() => {}); + try { const reporter = new MCPListReporter({ configDir, screen, includeTestId: true }); status = await Promise.race([ runner.runTests(reporter, params).then(result => result.status), testRunnerAndScreen.waitForTestPaused().then(() => 'paused' as const), + abortPromise, ]); + if (status === 'interrupted') { + await runner.stopTests(); + await cleanup(); + return { output: testRunnerAndScreen.output.join('\n'), status }; + } + if (status === 'paused') { const response = await testRunnerAndScreen.sendMessageToPausedTest!({ request: { initialize: { clientInfo: this._clientInfo } } }); if (response.error) @@ -248,7 +275,7 @@ export class TestContext { } async close() { - await this._cleanupTestRunner().catch(e => debug('pw:mcp:error')(e)); + await this._enqueue(() => this._cleanupTestRunner()).catch(e => debug('pw:mcp:error')(e)); } async sendMessageToPausedTest(request: BrowserMCPRequest): Promise { diff --git a/packages/playwright/src/mcp/test/testTool.ts b/packages/playwright/src/mcp/test/testTool.ts index a1d80b47b0d90..5d9240f12e94a 100644 --- a/packages/playwright/src/mcp/test/testTool.ts +++ b/packages/playwright/src/mcp/test/testTool.ts @@ -21,7 +21,7 @@ import type { tools } from 'playwright-core/lib/coreBundle'; export type TestTool = { schema: tools.ToolSchema; - handle: (context: TestContext, params: z.output) => Promise; + handle: (context: TestContext, params: z.output, signal?: AbortSignal) => Promise; }; export function defineTestTool(tool: TestTool): TestTool { diff --git a/packages/playwright/src/mcp/test/testTools.ts b/packages/playwright/src/mcp/test/testTools.ts index 7109c601f4c57..b6990495a5e32 100644 --- a/packages/playwright/src/mcp/test/testTools.ts +++ b/packages/playwright/src/mcp/test/testTools.ts @@ -47,12 +47,12 @@ export const runTests = defineTestTool({ type: 'readOnly', }, - handle: async (context, params) => { + handle: async (context, params, signal) => { const { output } = await context.runTestsWithGlobalSetupAndPossiblePause({ locations: params.locations ?? [], projects: params.projects, disableConfigReporters: true, - }); + }, signal); return { content: [{ type: 'text', text: output }] }; }, }); @@ -71,7 +71,7 @@ export const debugTest = defineTestTool({ type: 'readOnly', }, - handle: async (context, params) => { + handle: async (context, params, signal) => { const { output, status } = await context.runTestsWithGlobalSetupAndPossiblePause({ headed: context.computedHeaded, locations: [], // we can make this faster by passing the test's location, so we don't need to scan all tests to find the ID @@ -82,7 +82,7 @@ export const debugTest = defineTestTool({ pauseOnError: true, disableConfigReporters: true, actionTimeout: 5000, - }); + }, signal); return { content: [{ type: 'text', text: output }], isError: status !== 'paused' && status !== 'passed' }; }, }); diff --git a/tests/mcp/test-run.spec.ts b/tests/mcp/test-run.spec.ts index 291548066f06c..6aaf21a24e21c 100644 --- a/tests/mcp/test-run.spec.ts +++ b/tests/mcp/test-run.spec.ts @@ -123,6 +123,44 @@ Running 2 tests using 1 worker 2 passed (XXms)`); }); +test('test_run should stop when aborted', async ({ startClient }) => { + await writeFiles({ + 'slow.test.ts': ` + import { test } from '@playwright/test'; + test('slow', async () => { + await new Promise(resolve => setTimeout(resolve, 60_000)); + }); + `, + 'fast.test.ts': ` + import { test } from '@playwright/test'; + test('fast', async () => {}); + `, + }); + + const { client } = await startClient(); + const controller = new AbortController(); + + const runPromise = client.callTool( + { name: 'test_run', arguments: { locations: ['slow.test.ts'] } }, + undefined, + { signal: controller.signal }, + ); + + await new Promise(resolve => setTimeout(resolve, 2000)); + controller.abort(); + + // Per MCP spec, client can initiate a new call without waiting for the + // aborted one. Start the next run immediately to verify serialization. + const [, response] = await Promise.all([ + runPromise.catch(() => {}), + client.callTool({ + name: 'test_run', + arguments: { locations: ['fast.test.ts'] }, + }), + ]); + expect(response.content[0].text).toContain('1 passed'); +}); + test('test_run should include dependencies', async ({ startClient }) => { await writeFiles({ 'playwright.config.ts': `