diff --git a/CHANGELOG.md b/CHANGELOG.md index e36f5cb7b..6e54f7574 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -73,6 +73,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Data grid cell focus ring redraws when the user toggles Light or Dark mode mid-session, picking up the system's appearance-aware focus indicator color - Data grid keeps sortedIDs and cachedRowCount paired by calling updateCache() immediately after the SwiftUI bridge writes new sortedIDs to the coordinator, removing a window where the cached count and the sort permutation could disagree - Display formats memoized per tab on MainContentCoordinator keyed by schema version, smart-detection setting, and format-overrides version, so ValueDisplayDetector.detect runs once per result schema instead of on every SwiftUI body evaluation +- MCP HTTP router replaced with a route registry. `MCPRouter` now matches paths and methods against a list of `MCPRouteHandler` values; `/mcp` traffic and `/v1/integrations/exchange` traffic each live in their own handler file under `Core/MCP/Routes/`. OPTIONS preflight is handled once at the router level for every path +- `MCPAuthGuard` and `MCPConnectionBridge` route concurrent dedup through a shared `OnceTask` actor (`Core/Concurrency/OnceTask.swift`). Cleanup of in-flight slots happens in `defer` inside the actor, so a cancelled or thrown caller no longer leaves a stale entry behind. ### Removed (BREAKING) @@ -91,6 +93,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Persist group deletions before firing the sync notification, fixing a race that could re-upload deleted groups via iCloud. - Persist connection deletions before firing the sync notification, fixing the same race for deleted connections. - Refuse to generate SQL when the database dialect cannot be resolved, instead of silently emitting unquoted identifiers. +- MCP `execute_query`: strip trailing semicolons before appending `LIMIT/OFFSET`, fixing `syntax error at or near LIMIT` for queries like `select * from t;`. ## [0.36.0] - 2026-04-27 diff --git a/TablePro.xcodeproj/project.pbxproj b/TablePro.xcodeproj/project.pbxproj index a5749a98c..9c65d7b14 100644 --- a/TablePro.xcodeproj/project.pbxproj +++ b/TablePro.xcodeproj/project.pbxproj @@ -459,9 +459,9 @@ 5AF312BE2F36FF7500E86682 /* Exceptions for "TablePro" folder in "TablePro" target */ = { isa = PBXFileSystemSynchronizedBuildFileExceptionSet; membershipExceptions = ( - Info.plist, CLI/main.swift, CLI/MCPBridgeProxy.swift, + Info.plist, ); target = 5A1091C62EF17EDC0055EA7C /* TablePro */; }; diff --git a/TablePro.xcodeproj/xcshareddata/xcschemes/TablePro.xcscheme b/TablePro.xcodeproj/xcshareddata/xcschemes/TablePro.xcscheme index d2a10cf3d..f99c67cbd 100644 --- a/TablePro.xcodeproj/xcshareddata/xcschemes/TablePro.xcscheme +++ b/TablePro.xcodeproj/xcshareddata/xcschemes/TablePro.xcscheme @@ -25,7 +25,6 @@ 300 - ? String(condition.prefix(300)) + "…" : condition - filterDescription = preview - } else { - filterDescription = [parsed.filterColumn, parsed.filterOperation, parsed.filterValue] - .compactMap { $0 }.joined(separator: " ") - } - if !filterDescription.isEmpty { - let confirmed = await AlertHelper.confirmDestructive( - title: String(localized: "Apply Filter from Link"), - message: String( - format: String(localized: "An external link wants to apply a filter:\n\n%@"), - filterDescription - ), - confirmButton: String(localized: "Apply Filter"), - cancelButton: String(localized: "Cancel"), - window: NSApp.keyWindow - ) - guard confirmed else { return } - } - - NotificationCenter.default.post( - name: .applyURLFilter, - object: nil, - userInfo: [ - "connectionId": connectionId, - "column": parsed.filterColumn as Any, - "operation": parsed.filterOperation as Any, - "value": parsed.filterValue as Any, - "condition": parsed.filterCondition as Any - ] - ) - } - } - } - } - - private func waitForConnection(timeout: Duration) async { - await withCheckedContinuation { (continuation: CheckedContinuation) in - var didResume = false - var observer: NSObjectProtocol? - - func resumeOnce() { - guard !didResume else { return } - didResume = true - if let obs = observer { - NotificationCenter.default.removeObserver(obs) - } - continuation.resume() - } - - let timeoutTask = Task { - try? await Task.sleep(for: timeout) - resumeOnce() - } - observer = NotificationCenter.default.addObserver( - forName: .databaseDidConnect, - object: nil, - queue: .main - ) { _ in - timeoutTask.cancel() - resumeOnce() - } - } - } - - private func waitForNotification(_ name: Notification.Name, timeout: Duration) async { - await withCheckedContinuation { (continuation: CheckedContinuation) in - var didResume = false - var observer: NSObjectProtocol? - - func resumeOnce() { - guard !didResume else { return } - didResume = true - if let obs = observer { - NotificationCenter.default.removeObserver(obs) - } - continuation.resume() - } - - let timeoutTask = Task { - try? await Task.sleep(for: timeout) - resumeOnce() - } - observer = NotificationCenter.default.addObserver( - forName: name, object: nil, queue: .main - ) { _ in - timeoutTask.cancel() - resumeOnce() - } - } - } - - // MARK: - Session Lookup - - /// Finds any session (connected or still connecting) matching the parsed URL params. - private func findSessionByParams(_ parsed: ParsedConnectionURL) -> UUID? { - for (id, session) in DatabaseManager.shared.activeSessions { - let conn = session.connection - if conn.type == parsed.type - && conn.host == parsed.host - && conn.database == parsed.database - && (parsed.port == nil || conn.port == parsed.port || conn.port == parsed.type.defaultPort) - && (parsed.username.isEmpty || conn.username == parsed.username) - && (parsed.redisDatabase == nil || conn.redisDatabase == parsed.redisDatabase) { - return id - } - } - return nil - } - - /// Normalized key for deduplicating connection attempts by URL params. - static func paramKey(for parsed: ParsedConnectionURL) -> String { - let rdb = parsed.redisDatabase.map { "/redis:\($0)" } ?? "" - return "\(parsed.type.rawValue):\(parsed.username)@\(parsed.host):\(parsed.port ?? 0)/\(parsed.database)\(rdb)" - } - - func bringConnectionWindowToFront(_ connectionId: UUID) { - let windows = WindowLifecycleMonitor.shared.windows(for: connectionId) - if let window = windows.first { - window.makeKeyAndOrderFront(nil) - } else { - NSApp.windows.first { isMainWindow($0) && $0.isVisible }?.makeKeyAndOrderFront(nil) - } - } - - // MARK: - Connection Failure - - func handleConnectionFailure(_ error: Error) async { - closeOrphanedMainWindows() - - // User cancelled password prompt — no error dialog needed - if error is CancellationError { return } - - await Task.yield() - AlertHelper.showErrorSheet( - title: String(localized: "Connection Failed"), - message: error.localizedDescription, - window: NSApp.keyWindow - ) - } - - /// Closes main windows that have no active database session, then opens the welcome window if none remain. - private func closeOrphanedMainWindows() { - for window in NSApp.windows where isMainWindow(window) { - let hasActiveSession = DatabaseManager.shared.activeSessions.values.contains { - window.subtitle == $0.connection.name - || window.subtitle == "\($0.connection.name) — Preview" - } - if !hasActiveSession { window.close() } - } - if !NSApp.windows.contains(where: { isMainWindow($0) && $0.isVisible }) { - openWelcomeWindow() - } - } - - // MARK: - Transient Connection Builder - - private func buildTransientConnection(from parsed: ParsedConnectionURL) -> DatabaseConnection { - var sshConfig = SSHConfiguration() - if let sshHost = parsed.sshHost { - sshConfig.enabled = true - sshConfig.host = sshHost - sshConfig.port = parsed.sshPort ?? 22 - sshConfig.username = parsed.sshUsername ?? "" - if parsed.usePrivateKey == true { - sshConfig.authMethod = .privateKey - } - if parsed.useSSHAgent == true { - sshConfig.authMethod = .sshAgent - sshConfig.agentSocketPath = parsed.agentSocket ?? "" - } - } - - var sslConfig = SSLConfiguration() - if let sslMode = parsed.sslMode { - sslConfig.mode = sslMode - } - - var color: ConnectionColor = .none - if let hex = parsed.statusColor { - color = ConnectionURLParser.connectionColor(fromHex: hex) - } - - var tagId: UUID? - if let envName = parsed.envTag { - tagId = ConnectionURLParser.tagId(fromEnvName: envName) - } - - let resolvedSafeMode = parsed.safeModeLevel.flatMap(SafeModeLevel.from(urlInteger:)) ?? .silent - - var connection = DatabaseConnection( - name: parsed.connectionName ?? parsed.suggestedName, - host: parsed.host, - port: parsed.port ?? parsed.type.defaultPort, - database: parsed.database, - username: parsed.username, - type: parsed.type, - sshConfig: sshConfig, - sslConfig: sslConfig, - color: color, - tagId: tagId, - safeModeLevel: resolvedSafeMode, - mongoAuthSource: parsed.authSource, - mongoUseSrv: parsed.useSrv, - mongoAuthMechanism: parsed.mongoQueryParams["authMechanism"], - mongoReplicaSet: parsed.mongoQueryParams["replicaSet"], - redisDatabase: parsed.redisDatabase, - oracleServiceName: parsed.oracleServiceName - ) - - for (key, value) in parsed.mongoQueryParams where !value.isEmpty { - if key != "authMechanism" && key != "replicaSet" { - connection.additionalFields["mongoParam_\(key)"] = value - } - } - - return connection - } -} diff --git a/TablePro/AppDelegate+FileOpen.swift b/TablePro/AppDelegate+FileOpen.swift deleted file mode 100644 index 938cdeba1..000000000 --- a/TablePro/AppDelegate+FileOpen.swift +++ /dev/null @@ -1,354 +0,0 @@ -// -// AppDelegate+FileOpen.swift -// TablePro -// - -import AppKit -import os -import SwiftUI - -private let fileOpenLogger = Logger(subsystem: "com.TablePro", category: "FileOpen") - -extension AppDelegate { - // MARK: - Handoff - - func application(_ application: NSApplication, continue userActivity: NSUserActivity, - restorationHandler: @escaping ([any NSUserActivityRestoring]) -> Void) -> Bool { - handleHandoffActivity(userActivity) - return true - } - - private func handleHandoffActivity(_ activity: NSUserActivity) { - guard let connectionIdString = activity.userInfo?["connectionId"] as? String, - let connectionId = UUID(uuidString: connectionIdString) else { return } - - let connections = ConnectionStorage.shared.loadConnections() - guard let connection = connections.first(where: { $0.id == connectionId }) else { - fileOpenLogger.error("Handoff: no connection with ID '\(connectionIdString, privacy: .public)'") - return - } - - let tableName = activity.userInfo?["tableName"] as? String - - if DatabaseManager.shared.activeSessions[connectionId]?.driver != nil { - if let tableName { - let payload = EditorTabPayload(connectionId: connectionId, tabType: .table, tableName: tableName) - WindowManager.shared.openTab(payload: payload) - } else { - for window in NSApp.windows where isMainWindow(window) { - window.makeKeyAndOrderFront(nil) - return - } - } - return - } - - let initialPayload = EditorTabPayload(connectionId: connectionId) - WindowManager.shared.openTab(payload: initialPayload) - - Task { - do { - try await DatabaseManager.shared.connectToSession(connection) - for window in NSApp.windows where self.isWelcomeWindow(window) { - window.close() - } - if let tableName { - let payload = EditorTabPayload(connectionId: connectionId, tabType: .table, tableName: tableName) - WindowManager.shared.openTab(payload: payload) - } - } catch { - fileOpenLogger.error("Handoff connect failed: \(error.localizedDescription)") - } - } - } - - // MARK: - URL Classification - - private func isDatabaseURL(_ url: URL) -> Bool { - guard let scheme = url.scheme?.lowercased() else { return false } - let base = scheme - .replacingOccurrences(of: "+ssh", with: "") - .replacingOccurrences(of: "+srv", with: "") - let registeredSchemes = PluginManager.shared.allRegisteredURLSchemes - return registeredSchemes.contains(base) || registeredSchemes.contains(scheme) - } - - private func isDatabaseFile(_ url: URL) -> Bool { - PluginManager.shared.allRegisteredFileExtensions[url.pathExtension.lowercased()] != nil - } - - private func databaseTypeForFile(_ url: URL) -> DatabaseType? { - PluginManager.shared.allRegisteredFileExtensions[url.pathExtension.lowercased()] - } - - // MARK: - Main Dispatch - - func handleOpenURLs(_ urls: [URL]) { - let deeplinks = urls.filter { $0.scheme == "tablepro" } - if !deeplinks.isEmpty { - suppressWelcomeWindow() - Task { - for url in deeplinks { await self.handleDeeplink(url) } - self.endFileOpenSuppression() - } - } - - let plugins = urls.filter { $0.pathExtension == "tableplugin" } - if !plugins.isEmpty { - Task { - for url in plugins { await self.handlePluginInstall(url) } - } - } - - let databaseURLs = urls.filter { isDatabaseURL($0) } - if !databaseURLs.isEmpty { - suppressWelcomeWindow() - Task { - for url in databaseURLs { self.handleDatabaseURL(url) } - // endFileOpenSuppression is called here to match suppressWelcomeWindow above. - // Individual handlers no longer manage this flag. - self.endFileOpenSuppression() - } - } - - let databaseFiles = urls.filter { isDatabaseFile($0) } - if !databaseFiles.isEmpty { - suppressWelcomeWindow() - Task { - for url in databaseFiles { - guard let dbType = self.databaseTypeForFile(url) else { continue } - switch dbType { - case .sqlite: - self.handleSQLiteFile(url) - case .duckdb: - self.handleDuckDBFile(url) - default: - self.handleGenericDatabaseFile(url, type: dbType) - } - } - self.endFileOpenSuppression() - } - } - - // Connection share files - let connectionShareFiles = urls.filter { $0.pathExtension.lowercased() == "tablepro" } - for url in connectionShareFiles { - handleConnectionShareFile(url) - } - - let sqlFiles = urls.filter { $0.pathExtension.lowercased() == "sql" } - if !sqlFiles.isEmpty { - if DatabaseManager.shared.currentSession != nil { - suppressWelcomeWindow() - for window in NSApp.windows where isMainWindow(window) { - window.makeKeyAndOrderFront(nil) - } - for window in NSApp.windows where isWelcomeWindow(window) { - window.close() - } - NotificationCenter.default.post(name: .openSQLFiles, object: sqlFiles) - endFileOpenSuppression() - } else { - queuedFileURLs.append(contentsOf: sqlFiles) - openWelcomeWindow() - } - } - } - - // MARK: - Welcome Window Suppression - - func suppressWelcomeWindow() { - isHandlingFileOpen = true - fileOpenSuppressionCount += 1 - for window in NSApp.windows where isWelcomeWindow(window) { - window.orderOut(nil) - } - } - - // MARK: - Deeplink Handling - - private func handleDeeplink(_ url: URL) async { - guard let action = DeeplinkHandler.parse(url) else { return } - - switch action { - case .connect(let connectionId): - connectViaDeeplink(connectionId: connectionId) - - case .openTable(let connectionId, let table, let database, let schema): - connectViaDeeplink(connectionId: connectionId) { resolvedId in - EditorTabPayload(connectionId: resolvedId, tabType: .table, - tableName: table, databaseName: database, schemaName: schema) - } - - case .openQuery(let connectionId, let sql): - let maxDeeplinkSQLLength = 51_200 - let sqlLength = (sql as NSString).length - guard sqlLength <= maxDeeplinkSQLLength else { return } - guard let connection = DeeplinkHandler.resolveConnection(byId: connectionId) else { - showConnectionNotFoundAlert(connectionId: connectionId) - return - } - let preview: String - if sqlLength > 300 { - let hiddenCount = sqlLength - 300 - preview = String(sql.prefix(300)) - + String(format: String(localized: "\n\n… (%d more characters not shown)"), hiddenCount) - } else { - preview = sql - } - let confirmed = await AlertHelper.confirmDestructive( - title: String(localized: "Open Query from Link"), - message: String(format: String(localized: "An external link wants to open a query on connection \"%@\":\n\n%@"), connection.name, preview), - confirmButton: String(localized: "Open Query"), - cancelButton: String(localized: "Cancel"), - window: NSApp.keyWindow - ) - guard confirmed else { return } - connectViaDeeplink(connectionId: connectionId) { resolvedId in - EditorTabPayload(connectionId: resolvedId, tabType: .query, - initialQuery: sql) - } - - case .importConnection(let exportable): - openWelcomeWindow() - PendingActionStore.shared.deeplinkImport = exportable - NotificationCenter.default.post(name: .deeplinkImportRequested, object: exportable) - - case .pairIntegration(let request): - do { - try await MCPPairingService.shared.startPairing(request) - } catch let error as MCPError where error.isUserCancelled { - fileOpenLogger.info("Pairing cancelled by user") - } catch { - fileOpenLogger.error("Pairing failed: \(error.localizedDescription)") - AlertHelper.showErrorSheet( - title: String(localized: "Pairing Failed"), - message: error.localizedDescription, - window: NSApp.keyWindow - ) - } - - case .exchangePairing: - fileOpenLogger.warning("Exchange pairing received via URL scheme; ignored (use HTTP endpoint)") - - case .startMCP: - await MCPServerManager.shared.lazyStart() - } - } - - private func showConnectionNotFoundAlert(connectionId: UUID) { - fileOpenLogger.error("Deep link: no connection with ID '\(connectionId.uuidString, privacy: .public)'") - AlertHelper.showErrorSheet( - title: String(localized: "Connection Not Found"), - message: String(format: String(localized: "No saved connection with ID \"%@\"."), connectionId.uuidString), - window: NSApp.keyWindow - ) - } - - private func connectViaDeeplink( - connectionId: UUID, - makePayload: (@Sendable (UUID) -> EditorTabPayload)? = nil - ) { - guard let connection = DeeplinkHandler.resolveConnection(byId: connectionId) else { - showConnectionNotFoundAlert(connectionId: connectionId) - return - } - - if DatabaseManager.shared.activeSessions[connection.id]?.driver != nil { - if let payload = makePayload?(connection.id) { - if payload.tabType == .table, - let tableName = payload.tableName, - let coordinator = MainContentCoordinator.allActiveCoordinators() - .first(where: { $0.connectionId == connection.id }) { - if let window = coordinator.contentWindow { - window.makeKeyAndOrderFront(nil) - NSApp.activate(ignoringOtherApps: true) - } - Task { @MainActor in - if let dbName = payload.databaseName, !dbName.isEmpty, - let session = DatabaseManager.shared.session(for: connection.id), - dbName != session.activeDatabase { - await coordinator.switchDatabase(to: dbName) - } - coordinator.openTableTab(tableName) - } - } else { - WindowManager.shared.openTab(payload: payload) - } - } else { - for window in NSApp.windows where isMainWindow(window) { - window.makeKeyAndOrderFront(nil) - return - } - } - return - } - - let hadExistingMain = NSApp.windows.contains { isMainWindow($0) && $0.isVisible } - let savedTabbing = NSWindow.allowsAutomaticWindowTabbing - if hadExistingMain && !AppSettingsManager.shared.tabs.groupAllConnectionTabs { - NSWindow.allowsAutomaticWindowTabbing = false - } - - if makePayload == nil { - let deeplinkPayload = EditorTabPayload(connectionId: connection.id, intent: .restoreOrDefault) - WindowManager.shared.openTab(payload: deeplinkPayload) - } - NSWindow.allowsAutomaticWindowTabbing = savedTabbing - - Task { - do { - if let script = connection.preConnectScript, - !script.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty - { - let confirmed = await AlertHelper.confirmDestructive( - title: String(localized: "Pre-Connect Script"), - message: String(format: String(localized: "Connection \"%@\" has a script that will run before connecting:\n\n%@"), connection.name, script), - confirmButton: String(localized: "Run Script"), - cancelButton: String(localized: "Cancel"), - window: NSApp.keyWindow - ) - guard confirmed else { return } - } - - try await DatabaseManager.shared.connectToSession(connection) - for window in NSApp.windows where self.isWelcomeWindow(window) { - window.close() - } - if let payload = makePayload?(connection.id) { - WindowManager.shared.openTab(payload: payload) - } - } catch { - fileOpenLogger.error("Deep link connect failed: \(error.localizedDescription)") - await self.handleConnectionFailure(error) - } - } - } - - // MARK: - Connection Share Import - - private func handleConnectionShareFile(_ url: URL) { - openWelcomeWindow() - PendingActionStore.shared.connectionShareURL = url - NotificationCenter.default.post(name: .connectionShareFileOpened, object: url) - } - - // MARK: - Plugin Install - - private func handlePluginInstall(_ url: URL) async { - do { - let entry = try await PluginManager.shared.installPlugin(from: url) - fileOpenLogger.info("Installed plugin '\(entry.name)' from Finder") - - UserDefaults.standard.set(SettingsTab.plugins.rawValue, forKey: "selectedSettingsTab") - NotificationCenter.default.post(name: .openSettingsWindow, object: nil) - } catch { - fileOpenLogger.error("Plugin install failed: \(error.localizedDescription)") - AlertHelper.showErrorSheet( - title: String(localized: "Plugin Installation Failed"), - message: error.localizedDescription, - window: NSApp.keyWindow - ) - } - } -} diff --git a/TablePro/AppDelegate+WindowConfig.swift b/TablePro/AppDelegate+WindowConfig.swift deleted file mode 100644 index 3e9f8dde2..000000000 --- a/TablePro/AppDelegate+WindowConfig.swift +++ /dev/null @@ -1,406 +0,0 @@ -// -// AppDelegate+WindowConfig.swift -// TablePro -// - -import AppKit -import os -import SwiftUI - -private let windowLogger = Logger(subsystem: "com.TablePro", category: "WindowConfig") - -extension AppDelegate { - // MARK: - Dock Menu - - func applicationDockMenu(_ sender: NSApplication) -> NSMenu? { - let menu = NSMenu() - - let welcomeItem = NSMenuItem( - title: String(localized: "Show Welcome Window"), - action: #selector(showWelcomeFromDock), - keyEquivalent: "" - ) - welcomeItem.target = self - menu.addItem(welcomeItem) - - let connections = ConnectionStorage.shared.loadConnections() - if !connections.isEmpty { - let connectionsItem = NSMenuItem(title: String(localized: "Open Connection"), action: nil, keyEquivalent: "") - let submenu = NSMenu() - - for connection in connections { - let item = NSMenuItem( - title: connection.name, - action: #selector(connectFromDock(_:)), - keyEquivalent: "" - ) - item.target = self - item.representedObject = connection.id - let iconName = connection.type.iconName - let original = NSImage(systemSymbolName: iconName, accessibilityDescription: nil) - ?? NSImage(named: iconName) - if let original { - let resized = NSImage(size: NSSize(width: 16, height: 16), flipped: false) { rect in - original.draw(in: rect) - return true - } - item.image = resized - } - submenu.addItem(item) - } - - connectionsItem.submenu = submenu - menu.addItem(connectionsItem) - } - - return menu - } - - @objc func showWelcomeFromDock() { - openWelcomeWindow() - } - - @objc func newWindowForTab(_ sender: Any?) { - guard let keyWindow = NSApp.keyWindow, - let connectionId = MainActor.assumeIsolated({ - WindowLifecycleMonitor.shared.connectionId(forWindow: keyWindow) - }) - else { return } - - let payload = EditorTabPayload( - connectionId: connectionId, - intent: .newEmptyTab - ) - MainActor.assumeIsolated { - WindowManager.shared.openTab(payload: payload) - } - } - - @objc func connectFromDock(_ sender: NSMenuItem) { - guard let connectionId = sender.representedObject as? UUID else { return } - let connections = ConnectionStorage.shared.loadConnections() - guard let connection = connections.first(where: { $0.id == connectionId }) else { return } - - let payload = EditorTabPayload(connectionId: connection.id, intent: .restoreOrDefault) - WindowManager.shared.openTab(payload: payload) - - Task { - do { - try await DatabaseManager.shared.connectToSession(connection) - - for window in NSApp.windows where self.isWelcomeWindow(window) { - window.close() - } - } catch { - windowLogger.error("Dock connection failed for '\(connection.name)': \(error.localizedDescription)") - - for window in WindowLifecycleMonitor.shared.windows(for: connection.id) { - window.close() - } - if !NSApp.windows.contains(where: { self.isMainWindow($0) && $0.isVisible }) { - self.openWelcomeWindow() - } - } - } - } - - // MARK: - Reopen Handling - - func applicationShouldHandleReopen(_ sender: NSApplication, hasVisibleWindows flag: Bool) -> Bool { - if flag { - return true - } - - openWelcomeWindow() - return false - } - - // MARK: - Window Identification - - private enum WindowId { - static let main = "main" - static let welcome = "welcome" - static let connectionForm = "connection-form" - } - - func isMainWindow(_ window: NSWindow) -> Bool { - guard let rawValue = window.identifier?.rawValue else { return false } - return rawValue == WindowId.main || rawValue.hasPrefix("\(WindowId.main)-") - } - - func isWelcomeWindow(_ window: NSWindow) -> Bool { - guard let rawValue = window.identifier?.rawValue else { return false } - return rawValue == WindowId.welcome || rawValue.hasPrefix("\(WindowId.welcome)-") - } - - private func isConnectionFormWindow(_ window: NSWindow) -> Bool { - guard let rawValue = window.identifier?.rawValue else { return false } - return rawValue == WindowId.connectionForm || rawValue.hasPrefix("\(WindowId.connectionForm)-") - } - - @objc func handleFocusConnectionForm() { - if let window = NSApp.windows.first(where: { isConnectionFormWindow($0) }) { - window.makeKeyAndOrderFront(nil) - } - } - - // MARK: - Welcome Window - - /// Hide the Welcome window immediately when we know we're going to - /// auto-reconnect. Prevents a visible flash of the Welcome screen - /// before the main editor window appears. - func closeWelcomeWindowEagerly() { - for window in NSApp.windows where isWelcomeWindow(window) { - window.orderOut(nil) - } - } - - func openWelcomeWindow() { - for window in NSApp.windows where isWelcomeWindow(window) { - window.makeKeyAndOrderFront(nil) - return - } - - NotificationCenter.default.post(name: .openWelcomeWindow, object: nil) - } - - private func configureWelcomeWindowStyle(_ window: NSWindow) { - window.standardWindowButton(.miniaturizeButton)?.isHidden = true - window.standardWindowButton(.zoomButton)?.isHidden = true - window.styleMask.remove(.miniaturizable) - - window.collectionBehavior.remove(.fullScreenPrimary) - window.collectionBehavior.insert(.fullScreenNone) - - if window.styleMask.contains(.resizable) { - window.styleMask.remove(.resizable) - } - - let welcomeSize = NSSize(width: 700, height: 450) - if window.frame.size != welcomeSize { - window.setContentSize(welcomeSize) - window.center() - } - - window.isOpaque = false - window.backgroundColor = .clear - window.titlebarAppearsTransparent = true - - window.makeKeyAndOrderFront(nil) - - if let textField = window.contentView?.firstEditableTextField() { - window.makeFirstResponder(textField) - } - } - - private func configureConnectionFormWindowStyle(_ window: NSWindow) { - window.standardWindowButton(.miniaturizeButton)?.isEnabled = false - window.standardWindowButton(.zoomButton)?.isEnabled = false - window.styleMask.remove(.miniaturizable) - - window.collectionBehavior.remove(.fullScreenPrimary) - window.collectionBehavior.insert(.fullScreenNone) - } - - // MARK: - Welcome Window Suppression - - /// Called by connection handlers when the file-open connection attempt finishes - /// (success or failure). Decrements the suppression counter and resets the flag - /// when all outstanding file opens have completed. - func endFileOpenSuppression() { - fileOpenSuppressionCount = max(0, fileOpenSuppressionCount - 1) - if fileOpenSuppressionCount == 0 { - isHandlingFileOpen = false - } - } - - @discardableResult - private func closeWelcomeWindowIfMainExists() -> Bool { - let hasMainWindow = NSApp.windows.contains { isMainWindow($0) && $0.isVisible } - guard hasMainWindow else { return false } - for window in NSApp.windows where isWelcomeWindow(window) { - window.close() - } - return true - } - - // MARK: - Window Notifications - - @objc func windowDidBecomeKey(_ notification: Notification) { - guard let window = notification.object as? NSWindow else { return } - let windowId = ObjectIdentifier(window) - - if isWelcomeWindow(window) && isHandlingFileOpen { - // Only close welcome if a main window exists to take its place; - // otherwise just hide it so the user doesn't see a flash. - if let mainWin = NSApp.windows.first(where: { isMainWindow($0) }) { - window.close() - mainWin.makeKeyAndOrderFront(nil) - } else { - window.orderOut(nil) - } - return - } - - if isWelcomeWindow(window) && !configuredWindows.contains(windowId) { - configureWelcomeWindowStyle(window) - configuredWindows.insert(windowId) - } - - if isConnectionFormWindow(window) && !configuredWindows.contains(windowId) { - configureConnectionFormWindowStyle(window) - configuredWindows.insert(windowId) - } - - if isMainWindow(window) && isHandlingFileOpen { - closeWelcomeWindowIfMainExists() - } - - // Phase 5: removed legacy main-window tabbing block. `WindowManager.openTab` - // now performs the tab-group merge at creation time with the correct - // ordering, and pre-marks `configuredWindows` so this method is a no-op - // for main windows. The old block consumed `WindowOpener.pendingPayloads` - // and called `addTabbedWindow` mid-`windowDidBecomeKey`, which produced - // the 200–7000 ms grace-period delay we removed in Phase 2. - } - - @objc func windowWillClose(_ notification: Notification) { - let seq = MainContentCoordinator.nextSwitchSeq() - let t0 = Date() - guard let window = notification.object as? NSWindow else { return } - let isMain = isMainWindow(window) - - configuredWindows.remove(ObjectIdentifier(window)) - - if isMain { - let remainingMainWindows = NSApp.windows.filter { - $0 !== window && isMainWindow($0) && $0.isVisible - }.count - windowLogger.info("[close] AppDelegate.windowWillClose seq=\(seq) isMain=true remaining=\(remainingMainWindows)") - - if remainingMainWindows == 0 { - NotificationCenter.default.post(name: .mainWindowWillClose, object: nil) - openWelcomeWindow() - } - } - windowLogger.info("[close] AppDelegate.windowWillClose seq=\(seq) total ms=\(Int(Date().timeIntervalSince(t0) * 1_000))") - } - - @objc func windowDidChangeOcclusionState(_ notification: Notification) { - guard let window = notification.object as? NSWindow, - isHandlingFileOpen else { return } - - if isWelcomeWindow(window), - window.occlusionState.contains(.visible), - NSApp.windows.contains(where: { isMainWindow($0) && $0.isVisible }), - window.isVisible { - window.close() - } - } - - // MARK: - Auto-Reconnect - - func attemptAutoReconnectAll(connectionIds: [UUID]) { - let connections = ConnectionStorage.shared.loadConnections() - let validConnections = connectionIds.compactMap { id in - connections.first { $0.id == id } - } - - guard !validConnections.isEmpty else { - AppSettingsStorage.shared.saveLastOpenConnectionIds([]) - AppSettingsStorage.shared.saveLastConnectionId(nil) - closeRestoredMainWindows() - openWelcomeWindow() - return - } - - isAutoReconnecting = true - - Task { @MainActor [weak self] in - guard let self else { return } - defer { self.isAutoReconnecting = false } - - for connection in validConnections { - let payload = EditorTabPayload(connectionId: connection.id, intent: .restoreOrDefault) - WindowManager.shared.openTab(payload: payload) - - do { - try await DatabaseManager.shared.connectToSession(connection) - } catch is CancellationError { - for window in WindowLifecycleMonitor.shared.windows(for: connection.id) { - window.close() - } - continue - } catch { - windowLogger.error( - "Auto-reconnect failed for '\(connection.name)': \(error.localizedDescription)" - ) - for window in WindowLifecycleMonitor.shared.windows(for: connection.id) { - window.close() - } - continue - } - } - - for window in NSApp.windows where self.isWelcomeWindow(window) { - window.close() - } - - // If all connections failed, show the welcome window - if !NSApp.windows.contains(where: { self.isMainWindow($0) && $0.isVisible }) { - self.openWelcomeWindow() - } - } - } - - func attemptAutoReconnect(connectionId: UUID) { - let connections = ConnectionStorage.shared.loadConnections() - guard let connection = connections.first(where: { $0.id == connectionId }) else { - AppSettingsStorage.shared.saveLastConnectionId(nil) - closeRestoredMainWindows() - openWelcomeWindow() - return - } - - isAutoReconnecting = true - - Task { @MainActor [weak self] in - guard let self else { return } - let payload = EditorTabPayload(connectionId: connection.id, intent: .restoreOrDefault) - WindowManager.shared.openTab(payload: payload) - - defer { self.isAutoReconnecting = false } - do { - try await DatabaseManager.shared.connectToSession(connection) - - for window in NSApp.windows where self.isWelcomeWindow(window) { - window.close() - } - } catch is CancellationError { - for window in WindowLifecycleMonitor.shared.windows(for: connection.id) { - window.close() - } - if !NSApp.windows.contains(where: { self.isMainWindow($0) && $0.isVisible }) { - self.openWelcomeWindow() - } - } catch { - windowLogger.error("Auto-reconnect failed for '\(connection.name)': \(error.localizedDescription)") - - for window in WindowLifecycleMonitor.shared.windows(for: connection.id) { - window.close() - } - if !NSApp.windows.contains(where: { self.isMainWindow($0) && $0.isVisible }) { - self.openWelcomeWindow() - } - } - } - } - - func closeRestoredMainWindows() { - Task { @MainActor [weak self] in - for window in NSApp.windows where self?.isMainWindow(window) == true { - window.close() - } - } - } -} diff --git a/TablePro/AppDelegate.swift b/TablePro/AppDelegate.swift index fdc6ebc1d..bf3f1c360 100644 --- a/TablePro/AppDelegate.swift +++ b/TablePro/AppDelegate.swift @@ -7,44 +7,30 @@ import AppKit import os import SwiftUI -internal extension URL { - /// Returns the URL string with the password component replaced by `***` for safe logging. - var sanitizedForLogging: String { - guard var components = URLComponents(url: self, resolvingAgainstBaseURL: false), - components.password != nil else { - return absoluteString - } - components.password = "***" - return components.string ?? absoluteString - } -} - @MainActor class AppDelegate: NSObject, NSApplicationDelegate { private static let logger = Logger(subsystem: "com.TablePro", category: "AppDelegate") static let lifecycleLogger = Logger(subsystem: "com.TablePro", category: "NativeTabLifecycle") - var configuredWindows = Set() - var queuedFileURLs: [URL] = [] - var queuedURLEntries: [QueuedURLEntry] = [] - var isHandlingFileOpen = false - var fileOpenSuppressionCount = 0 - var isProcessingQueuedURLs = false - var isAutoReconnecting = false - var connectingURLConnectionIds = Set() - var connectingURLParamKeys = Set() - var connectingFilePaths = Set() - - // MARK: - NSApplicationDelegate + // MARK: - URL & File Open func application(_ application: NSApplication, open urls: [URL]) { - handleOpenURLs(urls) + AppLaunchCoordinator.shared.handleOpenURLs(urls) + } + + func application(_ application: NSApplication, continue userActivity: NSUserActivity, + restorationHandler: @escaping ([any NSUserActivityRestoring]) -> Void) -> Bool { + AppLaunchCoordinator.shared.handleHandoff(userActivity) + return true } + func applicationShouldHandleReopen(_ sender: NSApplication, hasVisibleWindows flag: Bool) -> Bool { + AppLaunchCoordinator.shared.handleReopen(hasVisibleWindows: flag) + } + + // MARK: - Lifecycle + func applicationDidFinishLaunching(_ notification: Notification) { - // Re-apply appearance now that NSApp exists. - // AppSettingsManager.shared may already be initialized (by @State in TableProApp), - // but NSApp was nil at that point so NSApp?.appearance was a no-op. let appearanceSettings = AppSettingsManager.shared.appearance ThemeEngine.shared.updateAppearanceAndTheme( mode: appearanceSettings.appearanceMode, @@ -90,54 +76,12 @@ class AppDelegate: NSObject, NSApplicationDelegate { _ = QueryHistoryStorage.shared } - if !isHandlingFileOpen { - let settings = AppSettingsStorage.shared.loadGeneral() - if settings.startupBehavior == .reopenLast { - let connectionIds = AppSettingsStorage.shared.loadLastOpenConnectionIds() - if !connectionIds.isEmpty { - closeWelcomeWindowEagerly() - attemptAutoReconnectAll(connectionIds: connectionIds) - } else if let lastConnectionId = AppSettingsStorage.shared.loadLastConnectionId() { - closeWelcomeWindowEagerly() - attemptAutoReconnect(connectionId: lastConnectionId) - } else { - Task { @MainActor [weak self] in - guard let self, !self.isHandlingFileOpen else { return } - let diskIds = await TabDiskActor.shared.connectionIdsWithSavedState() - guard !self.isHandlingFileOpen else { return } - if !diskIds.isEmpty { - self.closeWelcomeWindowEagerly() - self.attemptAutoReconnectAll(connectionIds: diskIds) - } else { - self.closeRestoredMainWindows() - } - } - } - } else { - closeRestoredMainWindows() - } - } - - // NOTE: These observers are not explicitly removed because AppDelegate - // lives for the entire app lifetime. NotificationCenter uses weak - // references for selector-based observers on macOS 10.11+. + AppLaunchCoordinator.shared.didFinishLaunching() - NotificationCenter.default.addObserver( - self, selector: #selector(windowDidBecomeKey(_:)), - name: NSWindow.didBecomeKeyNotification, object: nil - ) NotificationCenter.default.addObserver( self, selector: #selector(windowWillClose(_:)), name: NSWindow.willCloseNotification, object: nil ) - NotificationCenter.default.addObserver( - self, selector: #selector(windowDidChangeOcclusionState(_:)), - name: NSWindow.didChangeOcclusionStateNotification, object: nil - ) - NotificationCenter.default.addObserver( - self, selector: #selector(handleDatabaseDidConnect), - name: .databaseDidConnect, object: nil - ) NotificationCenter.default.addObserver( self, selector: #selector(handlePluginsRejected(_:)), name: .pluginsRejected, object: nil @@ -148,6 +92,45 @@ class AppDelegate: NSObject, NSApplicationDelegate { ) } + func applicationDidBecomeActive(_ notification: Notification) { + SyncCoordinator.shared.syncIfNeeded() + } + + func applicationShouldTerminate(_ sender: NSApplication) -> NSApplication.TerminateReply { + let hasUnsaved = MainContentCoordinator.hasAnyUnsavedChanges() + if hasUnsaved { + let alert = NSAlert() + alert.messageText = String(localized: "You have unsaved changes") + alert.informativeText = String(localized: "Some tabs have unsaved edits. Quitting will discard these changes.") + alert.alertStyle = .warning + alert.addButton(withTitle: String(localized: "Cancel")) + alert.addButton(withTitle: String(localized: "Quit Anyway")) + alert.buttons[1].hasDestructiveAction = true + let response = alert.runModal() + guard response == .alertSecondButtonReturn else { return .terminateCancel } + } + + Task { + await MCPServerManager.shared.stop() + NSApp.reply(toApplicationShouldTerminate: true) + } + return .terminateLater + } + + func applicationWillTerminate(_ notification: Notification) { + LinkedFolderWatcher.shared.stop() + TerminalProcessManager.registry.terminateAllSync() + SSHTunnelManager.shared.terminateAllProcessesSync() + } + + @objc func showHelp(_ sender: Any?) { + if let url = URL(string: "https://docs.tablepro.app") { + NSWorkspace.shared.open(url) + } + } + + // MARK: - Plugin Rejection Alert + @objc private func handlePluginsRejected(_ notification: Notification) { guard let rejected = notification.object as? [RejectedPlugin], !rejected.isEmpty else { return } @@ -184,40 +167,101 @@ class AppDelegate: NSObject, NSApplicationDelegate { } } - func applicationDidBecomeActive(_ notification: Notification) { - SyncCoordinator.shared.syncIfNeeded() + // MARK: - Window Notifications + + @objc func windowWillClose(_ notification: Notification) { + guard let window = notification.object as? NSWindow else { return } + + if AppLaunchCoordinator.isMainWindow(window) { + let remaining = NSApp.windows.filter { + $0 !== window && AppLaunchCoordinator.isMainWindow($0) && $0.isVisible + }.count + if remaining == 0 { + NotificationCenter.default.post(name: .mainWindowWillClose, object: nil) + WelcomeWindowFactory.openOrFront() + } + } } - func applicationShouldTerminate(_ sender: NSApplication) -> NSApplication.TerminateReply { - let hasUnsaved = MainContentCoordinator.hasAnyUnsavedChanges() - if hasUnsaved { - let alert = NSAlert() - alert.messageText = String(localized: "You have unsaved changes") - alert.informativeText = String(localized: "Some tabs have unsaved edits. Quitting will discard these changes.") - alert.alertStyle = .warning - alert.addButton(withTitle: String(localized: "Cancel")) - alert.addButton(withTitle: String(localized: "Quit Anyway")) - alert.buttons[1].hasDestructiveAction = true - let response = alert.runModal() - guard response == .alertSecondButtonReturn else { return .terminateCancel } + @objc func handleFocusConnectionForm() { + if let window = NSApp.windows.first(where: { AppLaunchCoordinator.isConnectionFormWindow($0) }) { + window.makeKeyAndOrderFront(nil) } + } - Task { - await MCPServerManager.shared.stop() - NSApp.reply(toApplicationShouldTerminate: true) + // MARK: - Dock Menu + + func applicationDockMenu(_ sender: NSApplication) -> NSMenu? { + let menu = NSMenu() + + let welcomeItem = NSMenuItem( + title: String(localized: "Show Welcome Window"), + action: #selector(showWelcomeFromDock), + keyEquivalent: "" + ) + welcomeItem.target = self + menu.addItem(welcomeItem) + + let connections = ConnectionStorage.shared.loadConnections() + if !connections.isEmpty { + let connectionsItem = NSMenuItem(title: String(localized: "Open Connection"), action: nil, keyEquivalent: "") + let submenu = NSMenu() + + for connection in connections { + let item = NSMenuItem( + title: connection.name, + action: #selector(connectFromDock(_:)), + keyEquivalent: "" + ) + item.target = self + item.representedObject = connection.id + let iconName = connection.type.iconName + let original = NSImage(systemSymbolName: iconName, accessibilityDescription: nil) + ?? NSImage(named: iconName) + if let original { + let resized = NSImage(size: NSSize(width: 16, height: 16), flipped: false) { rect in + original.draw(in: rect) + return true + } + item.image = resized + } + submenu.addItem(item) + } + + connectionsItem.submenu = submenu + menu.addItem(connectionsItem) } - return .terminateLater + + return menu } - func applicationWillTerminate(_ notification: Notification) { - LinkedFolderWatcher.shared.stop() - TerminalProcessManager.registry.terminateAllSync() - SSHTunnelManager.shared.terminateAllProcessesSync() + @objc func showWelcomeFromDock() { + WelcomeWindowFactory.openOrFront() } - @objc func showHelp(_ sender: Any?) { - if let url = URL(string: "https://docs.tablepro.app") { - NSWorkspace.shared.open(url) + @objc func newWindowForTab(_ sender: Any?) { + guard let keyWindow = NSApp.keyWindow, + let connectionId = MainActor.assumeIsolated({ + WindowLifecycleMonitor.shared.connectionId(forWindow: keyWindow) + }) + else { return } + + MainActor.assumeIsolated { + if let actions = MainContentCoordinator.allActiveCoordinators() + .first(where: { $0.connectionId == connectionId })?.commandActions { + actions.newTab() + } else { + WindowManager.shared.openTab( + payload: EditorTabPayload(connectionId: connectionId, intent: .newEmptyTab) + ) + } + } + } + + @objc func connectFromDock(_ sender: NSMenuItem) { + guard let connectionId = sender.representedObject as? UUID else { return } + Task { + await LaunchIntentRouter.shared.route(.openConnection(connectionId)) } } diff --git a/TablePro/Core/Concurrency/OnceTask.swift b/TablePro/Core/Concurrency/OnceTask.swift new file mode 100644 index 000000000..69b0c2ed9 --- /dev/null +++ b/TablePro/Core/Concurrency/OnceTask.swift @@ -0,0 +1,52 @@ +// +// OnceTask.swift +// TablePro +// + +import Foundation + +actor OnceTask { + private struct Entry { + let task: Task + let generation: Int + } + + private var inFlight: [Key: Entry] = [:] + private var nextGeneration: Int = 0 + + init() {} + + func execute( + key: Key, + work: @Sendable @escaping () async throws -> Value + ) async throws -> Value { + if let existing = inFlight[key] { + return try await existing.task.value + } + + nextGeneration += 1 + let generation = nextGeneration + let task = Task { + try await work() + } + inFlight[key] = Entry(task: task, generation: generation) + defer { + if inFlight[key]?.generation == generation { + inFlight.removeValue(forKey: key) + } + } + return try await task.value + } + + func cancel(key: Key) { + inFlight[key]?.task.cancel() + inFlight.removeValue(forKey: key) + } + + func cancelAll() { + for entry in inFlight.values { + entry.task.cancel() + } + inFlight.removeAll() + } +} diff --git a/TablePro/Core/Database/DatabaseManager+ConnectionState.swift b/TablePro/Core/Database/DatabaseManager+ConnectionState.swift new file mode 100644 index 000000000..765a3490f --- /dev/null +++ b/TablePro/Core/Database/DatabaseManager+ConnectionState.swift @@ -0,0 +1,20 @@ +import Foundation + +enum ConnectionState { + case live(DatabaseDriver, ConnectionSession) + case stored(DatabaseConnection) + case unknown +} + +extension DatabaseManager { + @MainActor + func connectionState(_ id: UUID) -> ConnectionState { + if let session = activeSessions[id], let driver = session.driver { + return .live(driver, session) + } + if let connection = ConnectionStorage.shared.loadConnections().first(where: { $0.id == id }) { + return .stored(connection) + } + return .unknown + } +} diff --git a/TablePro/Core/Database/DatabaseManager+EnsureConnected.swift b/TablePro/Core/Database/DatabaseManager+EnsureConnected.swift new file mode 100644 index 000000000..4abe669e2 --- /dev/null +++ b/TablePro/Core/Database/DatabaseManager+EnsureConnected.swift @@ -0,0 +1,15 @@ +// +// DatabaseManager+EnsureConnected.swift +// TablePro +// + +import Foundation + +extension DatabaseManager { + func ensureConnected(_ connection: DatabaseConnection) async throws { + if activeSessions[connection.id]?.driver != nil { return } + try await ensureConnectedDedup.execute(key: connection.id) { + try await self.connectToSession(connection) + } + } +} diff --git a/TablePro/Core/Database/DatabaseManager+Health.swift b/TablePro/Core/Database/DatabaseManager+Health.swift index 1cb7b128c..a6c6e7c8e 100644 --- a/TablePro/Core/Database/DatabaseManager+Health.swift +++ b/TablePro/Core/Database/DatabaseManager+Health.swift @@ -51,6 +51,7 @@ extension DatabaseManager { reconnectHandler: { [weak self] in guard let self else { return false } guard let session = await self.activeSessions[connectionId] else { return false } + await SchemaService.shared.invalidate(connectionId: connectionId) do { let result = try await self.trackOperation(sessionId: connectionId) { try await self.reconnectDriver(for: session) @@ -207,6 +208,8 @@ extension DatabaseManager { session.status = .connecting } + await SchemaService.shared.invalidate(connectionId: sessionId) + // Stop existing health monitor await stopHealthMonitor(for: sessionId) diff --git a/TablePro/Core/Database/DatabaseManager+Sessions.swift b/TablePro/Core/Database/DatabaseManager+Sessions.swift index 59f99a4ec..ecb77f56e 100644 --- a/TablePro/Core/Database/DatabaseManager+Sessions.swift +++ b/TablePro/Core/Database/DatabaseManager+Sessions.swift @@ -13,17 +13,12 @@ import TableProPluginKit // MARK: - Session Management extension DatabaseManager { - /// Connect to a database and create/switch to its session - /// If connection already has a session, switches to it instead func connectToSession(_ connection: DatabaseConnection) async throws { - // Check if session already exists and is connected if let existing = activeSessions[connection.id], existing.driver != nil { - // Session is fully connected, just switch to it switchToSession(connection.id) return } - // Resolve environment variable references in connection fields (Pro feature) let resolvedConnection: DatabaseConnection if LicenseManager.shared.isFeatureAvailable(.envVarReferences) { resolvedConnection = EnvVarResolver.resolveConnection(connection) @@ -31,7 +26,6 @@ extension DatabaseManager { resolvedConnection = connection } - // Create new session (or reuse a prepared one) if activeSessions[connection.id] == nil { var session = ConnectionSession(connection: connection) session.status = .connecting @@ -39,18 +33,15 @@ extension DatabaseManager { } currentSessionId = connection.id - // Create SSH tunnel if needed and build effective connection let effectiveConnection: DatabaseConnection do { effectiveConnection = try await buildEffectiveConnection(for: resolvedConnection) } catch { - // Remove failed session removeSessionEntry(for: connection.id) currentSessionId = nil throw error } - // Run pre-connect hook if configured (only on explicit connect, not auto-reconnect) if let script = resolvedConnection.preConnectScript, !script.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty { @@ -63,7 +54,6 @@ extension DatabaseManager { } } - // Resolve password override for prompt-for-password connections var passwordOverride: String? if connection.promptForPassword { if let cached = activeSessions[connection.id]?.cachedPassword { @@ -83,7 +73,6 @@ extension DatabaseManager { } } - // Create appropriate driver with effective connection let driver: DatabaseDriver do { driver = try await DatabaseDriverFactory.createDriver( @@ -92,7 +81,6 @@ extension DatabaseManager { awaitPlugins: true ) } catch { - // Close tunnel if SSH was established if connection.resolvedSSHConfig.enabled { Task { do { @@ -110,35 +98,32 @@ extension DatabaseManager { do { try await driver.connect() - // Apply query timeout from settings (best-effort — some PostgreSQL-compatible - // databases like Aurora DSQL don't support SET statement_timeout) let timeoutSeconds = AppSettingsManager.shared.general.queryTimeoutSeconds if timeoutSeconds > 0 { do { try await driver.applyQueryTimeout(timeoutSeconds) } catch { + // Best-effort: some PostgreSQL-compatible databases like Aurora DSQL + // don't support SET statement_timeout. Self.logger.warning( "Query timeout not supported for \(connection.name): \(error.localizedDescription)" ) } } - // Run startup commands before schema init await executeStartupCommands( resolvedConnection.startupCommands, on: driver, connectionName: connection.name ) - // Initialize schema for drivers that support schema switching if let schemaDriver = driver as? SchemaSwitchable { activeSessions[connection.id]?.currentSchema = schemaDriver.currentSchema } - // Run post-connect actions declared by the plugin await executePostConnectActions( for: connection, resolvedConnection: resolvedConnection, driver: driver ) - // Batch all session mutations into a single write to fire objectWillChange once + // Batch all session mutations into a single write to fire objectWillChange once. if var session = activeSessions[connection.id] { session.driver = driver session.status = driver.status @@ -149,13 +134,10 @@ extension DatabaseManager { setSession(session, for: connection.id) } - // Save as last connection for "Reopen Last Session" feature appSettingsStorage.saveLastConnectionId(connection.id) - // Post notification for reliable delivery NotificationCenter.default.post(name: .databaseDidConnect, object: nil) - // Start health monitoring if the plugin supports it let supportsHealth = PluginMetadataRegistry.shared.snapshot( forTypeId: connection.type.pluginTypeId )?.supportsHealthMonitor ?? true @@ -164,7 +146,6 @@ extension DatabaseManager { await startHealthMonitor(for: connection.id) } } catch { - // Close tunnel if connection failed if connection.resolvedSSHConfig.enabled { Task { do { @@ -175,12 +156,10 @@ extension DatabaseManager { } } - // Remove failed session completely so UI returns to Welcome window + // Remove failed session completely so UI returns to Welcome window. removeSessionEntry(for: connection.id) - // Clear current session if this was it if currentSessionId == connection.id { - // Switch to another session if available, otherwise clear if let nextSessionId = activeSessions.keys.first { currentSessionId = nextSessionId } else { @@ -268,6 +247,7 @@ extension DatabaseManager { session.currentSchema = nil } appSettingsStorage.saveLastSchema(nil, for: connectionId) + await SchemaService.shared.invalidate(connectionId: connectionId) await reconnectSession(connectionId) } else if pm?.capabilities.supportsSchemaSwitching == true, let schemaDriver = driver as? SchemaSwitchable { @@ -304,7 +284,6 @@ extension DatabaseManager { appSettingsStorage.saveLastSchema(schema, for: connectionId) } - /// Switch to an existing session func switchToSession(_ sessionId: UUID) { guard activeSessions[sessionId] != nil else { return } currentSessionId = sessionId @@ -313,7 +292,6 @@ extension DatabaseManager { } } - /// Disconnect a specific session func disconnectSession(_ sessionId: UUID) async { let lifecycleLogger = Logger(subsystem: "com.TablePro", category: "NativeTabLifecycle") guard let session = activeSessions[sessionId] else { @@ -327,7 +305,6 @@ extension DatabaseManager { "[close] disconnectSession start connId=\(sessionId, privacy: .public) name=\(session.connection.name, privacy: .public) hasSSH=\(session.connection.resolvedSSHConfig.enabled)" ) - // Close SSH tunnel if exists if session.connection.resolvedSSHConfig.enabled { let sshStart = Date() do { @@ -340,7 +317,6 @@ extension DatabaseManager { ) } - // Stop health monitoring let hmStart = Date() await stopHealthMonitor(for: sessionId) lifecycleLogger.info( @@ -354,18 +330,16 @@ extension DatabaseManager { ) removeSessionEntry(for: sessionId) - // Clean up shared schema cache for this connection + await SchemaService.shared.invalidate(connectionId: sessionId) + SchemaProviderRegistry.shared.clear(for: sessionId) - // Clean up shared sidebar state for this connection SharedSidebarState.removeConnection(sessionId) - // If this was the current session, switch to another or clear if currentSessionId == sessionId { if let nextSessionId = activeSessions.keys.first { switchToSession(nextSessionId) } else { - // No more sessions - clear current session and last connection ID currentSessionId = nil appSettingsStorage.saveLastConnectionId(nil) } @@ -375,7 +349,6 @@ extension DatabaseManager { ) } - /// Disconnect all sessions func disconnectAll() async { let monitorIds = Array(healthMonitors.keys) for sessionId in monitorIds { @@ -388,8 +361,7 @@ extension DatabaseManager { } } - /// Update session state (for preserving UI state). - /// Skips the write-back when no observable fields changed, avoiding spurious connectionStatusVersion bumps. + // Skips the write-back when no observable fields changed, avoiding spurious connectionStatusVersion bumps. func updateSession(_ sessionId: UUID, update: (inout ConnectionSession) -> Void) { guard var session = activeSessions[sessionId] else { return } let before = session @@ -400,14 +372,12 @@ extension DatabaseManager { setSession(session, for: sessionId) } - /// Write a session and bump its per-connection version counter. internal func setSession(_ session: ConnectionSession, for connectionId: UUID) { activeSessions[connectionId] = session connectionStatusVersions[connectionId, default: 0] &+= 1 NotificationCenter.default.post(name: .connectionStatusDidChange, object: connectionId) } - /// Remove a session and clean up its per-connection version counter. internal func removeSessionEntry(for connectionId: UUID) { activeSessions.removeValue(forKey: connectionId) connectionStatusVersions.removeValue(forKey: connectionId) @@ -415,12 +385,10 @@ extension DatabaseManager { } #if DEBUG - /// Test-only: inject a session for unit testing without real database connections internal func injectSession(_ session: ConnectionSession, for connectionId: UUID) { setSession(session, for: connectionId) } - /// Test-only: remove an injected session internal func removeSession(for connectionId: UUID) { removeSessionEntry(for: connectionId) } diff --git a/TablePro/Core/Database/DatabaseManager.swift b/TablePro/Core/Database/DatabaseManager.swift index 83bded437..d70f35448 100644 --- a/TablePro/Core/Database/DatabaseManager.swift +++ b/TablePro/Core/Database/DatabaseManager.swift @@ -60,6 +60,8 @@ final class DatabaseManager { /// and the wake-from-sleep handler fire for the same connection. @ObservationIgnored internal var recoveringConnectionIds = Set() + @ObservationIgnored internal let ensureConnectedDedup = OnceTask() + /// Current session (computed from currentSessionId) var currentSession: ConnectionSession? { guard let sessionId = currentSessionId else { return nil } diff --git a/TablePro/Core/MCP/MCPAuthGuard.swift b/TablePro/Core/MCP/MCPAuthGuard.swift deleted file mode 100644 index c4689f3d8..000000000 --- a/TablePro/Core/MCP/MCPAuthGuard.swift +++ /dev/null @@ -1,274 +0,0 @@ -// -// MCPAuthGuard.swift -// TablePro -// -// Enforces AIConnectionPolicy and SafeModeLevel for MCP requests. -// - -import AppKit -import Foundation -import os - -actor MCPAuthGuard { - private static let logger = Logger(subsystem: "com.TablePro", category: "MCPAuthGuard") - - /// Per-session approved connections (for askEachTime policy) - private var sessionApprovals: [String: Set] = [:] - - /// In-flight approval prompts keyed by (sessionId, connectionId) to dedupe concurrent requests. - private var inFlightApprovals: [ApprovalKey: Task] = [:] - - private struct ApprovalKey: Hashable { - let sessionId: String - let connectionId: UUID - } - - // MARK: - Connection Access Check - - func checkConnectionAccess(connectionId: UUID, sessionId: String) async throws { - let snapshot: ConnectionAccessSnapshot? = await MainActor.run { - let conns = ConnectionStorage.shared.loadConnections() - guard let conn = conns.first(where: { $0.id == connectionId }) else { - return nil - } - return ConnectionAccessSnapshot( - policy: conn.aiPolicy ?? AppSettingsManager.shared.ai.defaultConnectionPolicy, - externalAccess: conn.externalAccess, - name: conn.name, - databaseType: conn.type.rawValue - ) - } - - guard let snapshot else { - throw MCPError.forbidden( - String(localized: "Connection not found") - ) - } - - switch snapshot.policy { - case .alwaysAllow: - break - - case .never: - throw MCPError.forbidden( - String(localized: "AI access is disabled for this connection") - ) - - case .askEachTime: - if let approved = sessionApprovals[sessionId], approved.contains(connectionId) { - break - } - - let key = ApprovalKey(sessionId: sessionId, connectionId: connectionId) - let approvalTask: Task - if let existing = inFlightApprovals[key] { - approvalTask = existing - } else { - let connectionName = snapshot.name - let databaseType = snapshot.databaseType - approvalTask = Task { - try await self.promptUserApproval( - connectionName: connectionName, - databaseType: databaseType - ) - } - inFlightApprovals[key] = approvalTask - } - - let userApproved: Bool - do { - userApproved = try await approvalTask.value - inFlightApprovals.removeValue(forKey: key) - } catch { - inFlightApprovals.removeValue(forKey: key) - throw error - } - - if userApproved { - sessionApprovals[sessionId, default: []].insert(connectionId) - } else { - throw MCPError.forbidden( - String(localized: "User denied MCP access to this connection") - ) - } - } - - if snapshot.externalAccess == .blocked { - throw MCPError.forbidden( - String(localized: "External access is disabled for this connection") - ) - } - } - - func checkExternalAccessLevel( - connectionId: UUID, - requires required: ExternalAccessLevel - ) async throws { - let externalAccess: ExternalAccessLevel? = await MainActor.run { - ConnectionStorage.shared.loadConnections().first { $0.id == connectionId }?.externalAccess - } - - guard let externalAccess else { - throw MCPError.forbidden( - String(localized: "Connection not found") - ) - } - - guard externalAccess.satisfies(required) else { - throw MCPError.forbidden( - String(localized: "Connection is read-only for external clients") - ) - } - } - - func checkExternalWritePermission( - connectionId: UUID, - sql: String, - databaseType: DatabaseType - ) async throws { - guard QueryClassifier.isWriteQuery(sql, databaseType: databaseType) else { return } - - let externalAccess: ExternalAccessLevel? = await MainActor.run { - ConnectionStorage.shared.loadConnections().first { $0.id == connectionId }?.externalAccess - } - - guard let externalAccess else { - throw MCPError.forbidden( - String(localized: "Connection not found") - ) - } - - if externalAccess != .readWrite { - throw MCPError.forbidden( - String(localized: "Connection is read-only for external clients") - ) - } - } - - private struct ConnectionAccessSnapshot: Sendable { - let policy: AIConnectionPolicy - let externalAccess: ExternalAccessLevel - let name: String - let databaseType: String - } - - // MARK: - Query Permission Check - - func checkQueryPermission( - sql: String, - connectionId: UUID, - databaseType: DatabaseType, - safeModeLevel: SafeModeLevel - ) async throws { - try await checkExternalWritePermission( - connectionId: connectionId, - sql: sql, - databaseType: databaseType - ) - - let isWrite = QueryClassifier.isWriteQuery(sql, databaseType: databaseType) - let needsDialog = safeModeLevel != .silent && (isWrite || safeModeLevel == .alertFull || safeModeLevel == .safeModeFull) - - var window: NSWindow? - if needsDialog { - window = await MainActor.run { - NSApp.activate(ignoringOtherApps: true) - return NSApp.keyWindow ?? NSApp.mainWindow - } - } - - let permission = await SafeModeGuard.checkPermission( - level: safeModeLevel, - isWriteOperation: isWrite, - sql: sql, - operationDescription: String(localized: "MCP query execution"), - window: window, - databaseType: databaseType - ) - - if case .blocked(let reason) = permission { - throw MCPError.forbidden(reason) - } - } - - // MARK: - Query Logging - - func logQuery( - sql: String, - connectionId: UUID, - databaseName: String, - executionTime: TimeInterval, - rowCount: Int, - wasSuccessful: Bool, - errorMessage: String? - ) async { - let shouldLog = await MainActor.run { - AppSettingsManager.shared.mcp.logQueriesInHistory - } - guard shouldLog else { return } - - let entry = QueryHistoryEntry( - query: sql, - connectionId: connectionId, - databaseName: databaseName, - executionTime: executionTime, - rowCount: rowCount, - wasSuccessful: wasSuccessful, - errorMessage: errorMessage - ) - - _ = await QueryHistoryStorage.shared.addHistory(entry) - } - - // MARK: - User Approval (askEachTime) - - private func promptUserApproval(connectionName: String, databaseType: String) async throws -> Bool { - let approvalTask = Task { @MainActor in - NSApp.requestUserAttention(.criticalRequest) - NSApp.activate(ignoringOtherApps: true) - return await AlertHelper.confirmDestructive( - title: String(localized: "MCP Access Request"), - message: String( - format: String(localized: "An MCP client wants to access '%@' (%@). Allow?"), - connectionName, - databaseType - ), - confirmButton: String(localized: "Allow"), - cancelButton: String(localized: "Deny"), - window: nil - ) - } - - let approved = try await withThrowingTaskGroup(of: Bool.self) { group in - group.addTask { - await approvalTask.value - } - group.addTask { - try await Task.sleep(for: .seconds(30)) - approvalTask.cancel() - throw MCPError.timeout( - String(localized: "User approval timed out after 30 seconds") - ) - } - guard let result = try await group.next() else { - throw MCPError.internalError("No result from approval prompt") - } - approvalTask.cancel() - group.cancelAll() - return result - } - - if approved { - return true - } - throw MCPError.forbidden( - String(localized: "User denied MCP access to this connection") - ) - } - - // MARK: - Session Cleanup - - func clearSession(_ sessionId: String) { - sessionApprovals.removeValue(forKey: sessionId) - } -} diff --git a/TablePro/Core/MCP/MCPAuthPolicy.swift b/TablePro/Core/MCP/MCPAuthPolicy.swift new file mode 100644 index 000000000..2fa2eee60 --- /dev/null +++ b/TablePro/Core/MCP/MCPAuthPolicy.swift @@ -0,0 +1,314 @@ +import AppKit +import Foundation +import os + +typealias MCPToolName = String + +extension MCPToolName { + static let stateMutating: Set = [ + "execute_query", "confirm_destructive_operation", + "switch_database", "switch_schema", "export_data" + ] + static let requiresFullAccess: Set = ["confirm_destructive_operation"] + static let requiresReadWrite: Set = ["switch_database", "switch_schema", "export_data"] + static let writeQueryTools: Set = ["execute_query"] +} + +enum AuthDecision: Sendable { + case allowed + case requiresUserApproval(reason: String) + case denied(reason: String) +} + +actor MCPAuthPolicy { + private static let logger = Logger(subsystem: "com.TablePro", category: "MCPAuthPolicy") + + private var sessionApprovals: [String: Set] = [:] + private let approvalDedup = OnceTask() + + private struct ApprovalKey: Hashable, Sendable { + let sessionId: String + let connectionId: UUID + } + + private struct ConnectionSnapshot: Sendable { + let policy: AIConnectionPolicy + let externalAccess: ExternalAccessLevel + let name: String + let databaseType: String + let safeModeLevel: SafeModeLevel + } + + func authorize( + token: MCPAuthToken, + tool: MCPToolName, + connectionId: UUID?, + sql: String? = nil, + sessionId: String + ) async throws -> AuthDecision { + guard let connectionId else { + return decideTokenTier(token: token, tool: tool) + } + + guard let snapshot = await loadConnection(connectionId) else { + return .denied(reason: String(localized: "Connection not found")) + } + + if snapshot.policy == .never { + return .denied(reason: String(localized: "AI access is disabled for this connection")) + } + + if snapshot.externalAccess == .blocked { + return .denied(reason: String(localized: "External access is disabled for this connection")) + } + + if !token.connectionAccess.allows(connectionId) { + return .denied(reason: String(localized: "Token does not have access to this connection")) + } + + if case .denied(let reason) = decideTokenTier(token: token, tool: tool) { + return .denied(reason: reason) + } + + if let writeReason = denialForWriteIntent( + tool: tool, + sql: sql, + externalAccess: snapshot.externalAccess, + databaseType: snapshot.databaseType + ) { + return .denied(reason: writeReason) + } + + if snapshot.policy == .askEachTime, + !(sessionApprovals[sessionId]?.contains(connectionId) ?? false) + { + return .requiresUserApproval( + reason: String( + format: String(localized: "An MCP client wants to access '%@' (%@). Allow?"), + snapshot.name, + snapshot.databaseType + ) + ) + } + + return .allowed + } + + func resolveAndAuthorize( + token: MCPAuthToken, + tool: MCPToolName, + connectionId: UUID?, + sql: String? = nil, + sessionId: String + ) async throws { + let decision = try await authorize( + token: token, + tool: tool, + connectionId: connectionId, + sql: sql, + sessionId: sessionId + ) + + switch decision { + case .allowed: + return + + case .denied(let reason): + throw MCPError.forbidden(reason) + + case .requiresUserApproval(let reason): + guard let connectionId else { + throw MCPError.forbidden(reason) + } + let approved = try await runApprovalDedup( + sessionId: sessionId, + connectionId: connectionId, + reason: reason + ) + if approved { + recordApproval(sessionId: sessionId, connectionId: connectionId) + } else { + throw MCPError.forbidden( + String(localized: "User denied MCP access to this connection") + ) + } + } + } + + func recordApproval(sessionId: String, connectionId: UUID) { + sessionApprovals[sessionId, default: []].insert(connectionId) + } + + func clearSession(_ sessionId: String) { + sessionApprovals.removeValue(forKey: sessionId) + } + + func checkSafeModeDialog( + sql: String, + connectionId: UUID, + databaseType: DatabaseType, + safeModeLevel: SafeModeLevel + ) async throws { + let isWrite = QueryClassifier.isWriteQuery(sql, databaseType: databaseType) + let needsDialog = safeModeLevel != .silent + && (isWrite || safeModeLevel == .alertFull || safeModeLevel == .safeModeFull) + + let window: NSWindow? = needsDialog + ? await MainActor.run { + NSApp.activate(ignoringOtherApps: true) + return WindowLifecycleMonitor.shared.findWindow(for: connectionId) + ?? NSApp.mainWindow + } + : nil + + let permission = await SafeModeGuard.checkPermission( + level: safeModeLevel, + isWriteOperation: isWrite, + sql: sql, + operationDescription: String(localized: "MCP query execution"), + window: window, + databaseType: databaseType + ) + + if case .blocked(let reason) = permission { + throw MCPError.forbidden(reason) + } + } + + func logQuery( + sql: String, + connectionId: UUID, + databaseName: String, + executionTime: TimeInterval, + rowCount: Int, + wasSuccessful: Bool, + errorMessage: String? + ) async { + let shouldLog = await MainActor.run { + AppSettingsManager.shared.mcp.logQueriesInHistory + } + guard shouldLog else { return } + + let entry = QueryHistoryEntry( + query: sql, + connectionId: connectionId, + databaseName: databaseName, + executionTime: executionTime, + rowCount: rowCount, + wasSuccessful: wasSuccessful, + errorMessage: errorMessage + ) + + _ = await QueryHistoryStorage.shared.addHistory(entry) + } + + private func runApprovalDedup( + sessionId: String, + connectionId: UUID, + reason: String + ) async throws -> Bool { + let key = ApprovalKey(sessionId: sessionId, connectionId: connectionId) + return try await approvalDedup.execute(key: key) { + try await Self.promptApproval(reason: reason) + } + } + + private static func promptApproval(reason: String) async throws -> Bool { + try await withThrowingTaskGroup(of: Bool.self) { group in + defer { group.cancelAll() } + group.addTask { + await AlertHelper.runApprovalModal( + title: String(localized: "MCP Access Request"), + message: reason, + confirm: String(localized: "Allow"), + cancel: String(localized: "Deny") + ) + } + group.addTask { + try await Task.sleep(for: .seconds(30)) + throw MCPError.timeout( + String(localized: "User approval timed out after 30 seconds") + ) + } + guard let result = try await group.next() else { + throw MCPError.internalError("No result from approval prompt") + } + return result + } + } + + private func decideTokenTier(token: MCPAuthToken, tool: MCPToolName) -> AuthDecision { + let required = requiredPermission(for: tool) + if token.permissions.satisfies(required) { + return .allowed + } + return .denied( + reason: String( + format: String(localized: "Token '%@' with permission '%@' cannot access '%@'"), + token.name, + token.permissions.displayName, + tool + ) + ) + } + + private func requiredPermission(for tool: MCPToolName) -> TokenPermissions { + if MCPToolName.requiresFullAccess.contains(tool) { return .fullAccess } + if MCPToolName.requiresReadWrite.contains(tool) { return .readWrite } + return .readOnly + } + + private func denialForWriteIntent( + tool: MCPToolName, + sql: String?, + externalAccess: ExternalAccessLevel, + databaseType: String + ) -> String? { + if MCPToolName.requiresReadWrite.contains(tool) || MCPToolName.requiresFullAccess.contains(tool) { + if externalAccess != .readWrite { + return String(localized: "Connection is read-only for external clients") + } + return nil + } + + guard MCPToolName.writeQueryTools.contains(tool), let sql else { + return nil + } + + let dbType = DatabaseType(rawValue: databaseType) + guard QueryClassifier.isWriteQuery(sql, databaseType: dbType) else { + return nil + } + if externalAccess != .readWrite { + return String(localized: "Connection is read-only for external clients") + } + return nil + } + + private func loadConnection(_ connectionId: UUID) async -> ConnectionSnapshot? { + await MainActor.run { + let state = DatabaseManager.shared.connectionState(connectionId) + switch state { + case .live(_, let session): + let conn = session.connection + return ConnectionSnapshot( + policy: conn.aiPolicy ?? AppSettingsManager.shared.ai.defaultConnectionPolicy, + externalAccess: conn.externalAccess, + name: conn.name, + databaseType: conn.type.rawValue, + safeModeLevel: conn.safeModeLevel + ) + case .stored(let conn): + return ConnectionSnapshot( + policy: conn.aiPolicy ?? AppSettingsManager.shared.ai.defaultConnectionPolicy, + externalAccess: conn.externalAccess, + name: conn.name, + databaseType: conn.type.rawValue, + safeModeLevel: conn.safeModeLevel + ) + case .unknown: + return nil + } + } + } +} diff --git a/TablePro/Core/MCP/MCPConnectionBridge.swift b/TablePro/Core/MCP/MCPConnectionBridge.swift index 770daeb0c..97e647c56 100644 --- a/TablePro/Core/MCP/MCPConnectionBridge.swift +++ b/TablePro/Core/MCP/MCPConnectionBridge.swift @@ -11,8 +11,6 @@ import os actor MCPConnectionBridge { private static let logger = Logger(subsystem: "com.TablePro", category: "MCPConnectionBridge") - private var inFlightConnects: [UUID: Task] = [:] - // MARK: - Connection Management func listConnections() async -> JSONValue { @@ -72,9 +70,7 @@ actor MCPConnectionBridge { return .object(result) } - // Not connected yet -- create a new session via DatabaseManager. - // connectToSession is @MainActor; Swift hops automatically for async calls. - try await DatabaseManager.shared.connectToSession(connection) + try await DatabaseManager.shared.ensureConnected(connection) let (serverVersion, currentDatabase, currentSchema) = await MainActor.run { let session = DatabaseManager.shared.activeSessions[connectionId] @@ -173,8 +169,9 @@ actor MCPConnectionBridge { timeoutSeconds: Int ) async throws -> JSONValue { let (driver, databaseType) = try await resolveDriver(connectionId) - let isWrite = QueryClassifier.isWriteQuery(query, databaseType: databaseType) - let hasReturning = query.range(of: #"\bRETURNING\b"#, options: [.regularExpression, .caseInsensitive]) != nil + let normalizedQuery = Self.stripTrailingSemicolons(query) + let isWrite = QueryClassifier.isWriteQuery(normalizedQuery, databaseType: databaseType) + let hasReturning = normalizedQuery.range(of: #"\bRETURNING\b"#, options: [.regularExpression, .caseInsensitive]) != nil let shouldUseFetchRows = !isWrite || hasReturning let effectiveLimit = maxRows + 1 @@ -186,9 +183,9 @@ actor MCPConnectionBridge { try await withThrowingTaskGroup(of: QueryResult.self) { group in group.addTask { if shouldUseFetchRows { - try await driver.fetchRows(query: query, offset: 0, limit: effectiveLimit) + try await driver.fetchRows(query: normalizedQuery, offset: 0, limit: effectiveLimit) } else { - try await driver.execute(query: query) + try await driver.execute(query: normalizedQuery) } } group.addTask { @@ -237,15 +234,8 @@ actor MCPConnectionBridge { // MARK: - Schema Operations func listTables(connectionId: UUID, includeRowCounts: Bool) async throws -> JSONValue { - let provider = await MainActor.run { - SchemaProviderRegistry.shared.provider(for: connectionId) - } - var cachedTables: [TableInfo] = [] - if let provider { - let cached = await provider.getTables() - if !cached.isEmpty { - cachedTables = cached - } + let cachedTables = await MainActor.run { + SchemaService.shared.tables(for: connectionId) } let tables: [TableInfo] @@ -385,16 +375,8 @@ actor MCPConnectionBridge { // MARK: - Schema Resource (for resources/read) func fetchSchemaResource(connectionId: UUID) async throws -> JSONValue { - // Check SchemaProviderRegistry cache first - let provider = await MainActor.run { - SchemaProviderRegistry.shared.provider(for: connectionId) - } - var cachedTables: [TableInfo] = [] - if let provider { - let cached = await provider.getTables() - if !cached.isEmpty { - cachedTables = cached - } + let cachedTables = await MainActor.run { + SchemaService.shared.tables(for: connectionId) } let (driver, _) = try await resolveDriver(connectionId) @@ -486,33 +468,28 @@ actor MCPConnectionBridge { // MARK: - Private Helpers private func resolveDriver(_ connectionId: UUID) async throws -> (DatabaseDriver, DatabaseType) { - let connection: DatabaseConnection? = await MainActor.run { - if DatabaseManager.shared.activeSessions[connectionId]?.driver != nil { return nil } - return ConnectionStorage.shared.loadConnections().first { $0.id == connectionId } + let pending: DatabaseConnection? = await MainActor.run { + switch DatabaseManager.shared.connectionState(connectionId) { + case .live: return nil + case .stored(let connection): return connection + case .unknown: return nil + } } - if let connection { - try await connectIfNeeded(connection) + if let pending { + try await connectIfNeeded(pending) } return try await MainActor.run { - guard let session = DatabaseManager.shared.activeSessions[connectionId], - let driver = session.driver else { + switch DatabaseManager.shared.connectionState(connectionId) { + case .live(let driver, let session): + return (driver, session.connection.type) + case .stored, .unknown: throw MCPError.notConnected(connectionId) } - return (driver, session.connection.type) } } private func connectIfNeeded(_ connection: DatabaseConnection) async throws { - if let existing = inFlightConnects[connection.id] { - try await existing.value - return - } - let task = Task { [connection] in - try await DatabaseManager.shared.connectToSession(connection) - } - inFlightConnects[connection.id] = task - defer { inFlightConnects.removeValue(forKey: connection.id) } - try await task.value + try await DatabaseManager.shared.ensureConnected(connection) } private func resolveSession(_ connectionId: UUID) async throws -> ConnectionSession { @@ -533,4 +510,13 @@ actor MCPConnectionBridge { return connection } } + + static func stripTrailingSemicolons(_ query: String) -> String { + var result = query.trimmingCharacters(in: .whitespacesAndNewlines) + while result.hasSuffix(";") { + result = String(result.dropLast()) + .trimmingCharacters(in: .whitespacesAndNewlines) + } + return result + } } diff --git a/TablePro/Core/MCP/MCPHTTPParser.swift b/TablePro/Core/MCP/MCPHTTPParser.swift index 6fdc329ae..5662aa73e 100644 --- a/TablePro/Core/MCP/MCPHTTPParser.swift +++ b/TablePro/Core/MCP/MCPHTTPParser.swift @@ -13,6 +13,19 @@ struct HTTPRequest: Sendable { let path: String let headers: [String: String] let body: Data? + var remoteIP: String? + + init(method: Method, path: String, headers: [String: String], body: Data?, remoteIP: String? = nil) { + self.method = method + self.path = path + self.headers = headers + self.body = body + self.remoteIP = remoteIP + } + + func withRemoteIP(_ remoteIP: String?) -> HTTPRequest { + HTTPRequest(method: method, path: path, headers: headers, body: body, remoteIP: remoteIP) + } } enum HTTPParseError: Error, Sendable { diff --git a/TablePro/Core/MCP/MCPPairingService.swift b/TablePro/Core/MCP/MCPPairingService.swift index ae1652fbb..c486180f2 100644 --- a/TablePro/Core/MCP/MCPPairingService.swift +++ b/TablePro/Core/MCP/MCPPairingService.swift @@ -120,12 +120,13 @@ final class MCPPairingService { throw MCPError.internalError("Token store unavailable") } - let approval = try await PairingApprovalPresenter.present(request: request) + let approval = try await AlertHelper.runPairingApproval(request: request) + let connectionAccess: ConnectionAccess = approval.allowedConnectionIds.map { .limited($0) } ?? .all let result = await tokenStore.generate( name: request.clientName, permissions: approval.grantedPermissions, - allowedConnectionIds: approval.allowedConnectionIds, + connectionAccess: connectionAccess, expiresAt: approval.expiresAt ) diff --git a/TablePro/Core/MCP/MCPResourceHandler.swift b/TablePro/Core/MCP/MCPResourceHandler.swift index d2fec5f24..aa4b36226 100644 --- a/TablePro/Core/MCP/MCPResourceHandler.swift +++ b/TablePro/Core/MCP/MCPResourceHandler.swift @@ -5,11 +5,11 @@ final class MCPResourceHandler: Sendable { private static let logger = Logger(subsystem: "com.TablePro", category: "MCPResourceHandler") private let bridge: MCPConnectionBridge - private let authGuard: MCPAuthGuard + private let authPolicy: MCPAuthPolicy - init(bridge: MCPConnectionBridge, authGuard: MCPAuthGuard) { + init(bridge: MCPConnectionBridge, authPolicy: MCPAuthPolicy) { self.bridge = bridge - self.authGuard = authGuard + self.authPolicy = authPolicy } func handleResourceRead(uri: String, sessionId: String) async throws -> MCPResourceReadResult { @@ -65,7 +65,12 @@ final class MCPResourceHandler: Sendable { } private func handleSchemaResource(uri: String, connectionId: UUID, sessionId: String) async throws -> MCPResourceReadResult { - try await authGuard.checkConnectionAccess(connectionId: connectionId, sessionId: sessionId) + try await authPolicy.resolveAndAuthorize( + token: MCPToolHandler.anonymousFullAccessToken, + tool: "describe_table", + connectionId: connectionId, + sessionId: sessionId + ) let result = try await bridge.fetchSchemaResource(connectionId: connectionId) let jsonString = encodeJSON(result) return MCPResourceReadResult(contents: [ @@ -79,7 +84,12 @@ final class MCPResourceHandler: Sendable { queryItems: [URLQueryItem], sessionId: String ) async throws -> MCPResourceReadResult { - try await authGuard.checkConnectionAccess(connectionId: connectionId, sessionId: sessionId) + try await authPolicy.resolveAndAuthorize( + token: MCPToolHandler.anonymousFullAccessToken, + tool: "search_query_history", + connectionId: connectionId, + sessionId: sessionId + ) let limit = queryItems.first(where: { $0.name == "limit" }) .flatMap { $0.value } .flatMap { Int($0) } diff --git a/TablePro/Core/MCP/MCPRouteHandler.swift b/TablePro/Core/MCP/MCPRouteHandler.swift new file mode 100644 index 000000000..3b7a39c12 --- /dev/null +++ b/TablePro/Core/MCP/MCPRouteHandler.swift @@ -0,0 +1,7 @@ +import Foundation + +protocol MCPRouteHandler: Sendable { + var methods: [HTTPRequest.Method] { get } + var path: String { get } + func handle(_ request: HTTPRequest) async -> MCPRouter.RouteResult +} diff --git a/TablePro/Core/MCP/MCPRouter.swift b/TablePro/Core/MCP/MCPRouter.swift index a47305a02..1561e27dd 100644 --- a/TablePro/Core/MCP/MCPRouter.swift +++ b/TablePro/Core/MCP/MCPRouter.swift @@ -1,12 +1,6 @@ import Foundation -import os final class MCPRouter: Sendable { - private static let logger = Logger(subsystem: "com.TablePro", category: "MCPRouter") - - private let encoder: JSONEncoder - private let decoder: JSONDecoder - enum RouteResult: Sendable { case json(Data, sessionId: String?) case sseStream(sessionId: String) @@ -16,542 +10,40 @@ final class MCPRouter: Sendable { case httpErrorWithHeaders(status: Int, message: String, extraHeaders: [(String, String)]) } - init() { - let enc = JSONEncoder() - enc.outputFormatting = [.sortedKeys] - self.encoder = enc - self.decoder = JSONDecoder() + private let routes: [any MCPRouteHandler] + + init(routes: [any MCPRouteHandler]) { + self.routes = routes } - func route( - _ request: HTTPRequest, - server: MCPServer, - remoteIP: String?, - tokenStore: MCPTokenStore?, - rateLimiter: MCPRateLimiter? - ) async -> RouteResult { + func handle(_ request: HTTPRequest) async -> RouteResult { if request.path.hasPrefix("/.well-known/") { return .httpError(status: 404, message: "Not found") } - if request.path == "/v1/integrations/exchange" - || request.path.hasPrefix("/v1/integrations/exchange?") - { - return await handleIntegrationsExchange(request) + if request.method == .options { + return .noContent } - guard request.path == "/mcp" || request.path.hasPrefix("/mcp?") else { + guard let route = match(request) else { return .httpError(status: 404, message: "Not found") } - if let rateLimiter, let ip = remoteIP { - let lockoutCheck = await rateLimiter.isLockedOut(ip: ip) - if case .rateLimited(let retryAfter) = lockoutCheck { - let seconds = Int(retryAfter.components.seconds) - MCPAuditLogger.logRateLimited(ip: ip, retryAfterSeconds: seconds) - return .httpErrorWithHeaders( - status: 429, - message: "Too many failed attempts", - extraHeaders: [("Retry-After", "\(seconds)")] - ) - } - } - - let authResult = await authenticateRequest( - request, - remoteIP: remoteIP, - tokenStore: tokenStore, - rateLimiter: rateLimiter - ) - - switch authResult { - case .failure(let result): - return result - case .success(let token): - if token == nil { - if let origin = request.headers["origin"], !isAllowedOrigin(origin) { - return .httpError(status: 403, message: "Forbidden origin") - } - } - - switch request.method { - case .options: - return handleOptions() - case .post: - return await handlePost(request, server: server, authenticatedToken: token) - case .get: - return await handleGet(request, server: server) - case .delete: - return await handleDelete(request, server: server) - } - } - } - - private enum AuthResult { - case success(MCPAuthToken?) - case failure(RouteResult) - } - - private func authenticateRequest( - _ request: HTTPRequest, - remoteIP: String?, - tokenStore: MCPTokenStore?, - rateLimiter: MCPRateLimiter? - ) async -> AuthResult { - let authRequired = await MainActor.run { AppSettingsManager.shared.mcp.requireAuthentication } - - guard let authHeader = request.headers["authorization"] else { - guard !authRequired else { - MCPAuditLogger.logAuthFailure(reason: "Missing authorization header", ip: remoteIP ?? "localhost") - return .failure(.httpErrorWithHeaders( - status: 401, - message: "Authentication required", - extraHeaders: [("WWW-Authenticate", "Bearer realm=\"TablePro MCP\"")] - )) - } - return .success(nil) - } - - guard authHeader.lowercased().hasPrefix("bearer "), let tokenStore else { - let rateLimitResult = await recordAuthFailure(ip: remoteIP, rateLimiter: rateLimiter) - if case .rateLimited(let retryAfter) = rateLimitResult { - let seconds = Int(retryAfter.components.seconds) - MCPAuditLogger.logRateLimited(ip: remoteIP ?? "localhost", retryAfterSeconds: seconds) - return .failure(.httpErrorWithHeaders( - status: 429, - message: "Too many failed attempts", - extraHeaders: [("Retry-After", "\(seconds)")] - )) - } - MCPAuditLogger.logAuthFailure(reason: "Invalid authorization header format", ip: remoteIP ?? "localhost") - return .failure(.httpErrorWithHeaders( - status: 401, - message: "Invalid authorization header", - extraHeaders: [("WWW-Authenticate", "Bearer realm=\"TablePro MCP\"")] - )) - } - - let bearerToken = String(authHeader.dropFirst(7)) - - guard let token = await tokenStore.validate(bearerToken: bearerToken) else { - let rateLimitResult = await recordAuthFailure(ip: remoteIP, rateLimiter: rateLimiter) - if case .rateLimited(let retryAfter) = rateLimitResult { - let seconds = Int(retryAfter.components.seconds) - MCPAuditLogger.logRateLimited(ip: remoteIP ?? "localhost", retryAfterSeconds: seconds) - return .failure(.httpErrorWithHeaders( - status: 429, - message: "Too many failed attempts", - extraHeaders: [("Retry-After", "\(seconds)")] - )) - } - MCPAuditLogger.logAuthFailure(reason: "Invalid token", ip: remoteIP ?? "localhost") - return .failure(.httpErrorWithHeaders( - status: 401, - message: "Invalid or expired token", - extraHeaders: [("WWW-Authenticate", "Bearer realm=\"TablePro MCP\"")] - )) - } - - if let rateLimiter, let ip = remoteIP { - _ = await rateLimiter.checkAndRecord(ip: ip, success: true) - } - MCPAuditLogger.logAuthSuccess(tokenName: token.name, ip: remoteIP ?? "localhost") - return .success(token) - } - - @discardableResult - private func recordAuthFailure( - ip: String?, - rateLimiter: MCPRateLimiter? - ) async -> MCPRateLimiter.AuthRateResult? { - guard let rateLimiter, let ip else { return nil } - return await rateLimiter.checkAndRecord(ip: ip, success: false) - } - - private func isAllowedOrigin(_ origin: String) -> Bool { - guard let components = URLComponents(string: origin), - let host = components.host - else { - return false - } - let allowedHosts: Set = ["localhost", "127.0.0.1", "::1"] - return allowedHosts.contains(host) - } - - private func handleOptions() -> RouteResult { - .noContent - } - - private func handleGet(_ request: HTTPRequest, server: MCPServer) async -> RouteResult { - guard let sessionId = request.headers["mcp-session-id"] else { - return .httpError(status: 400, message: "Missing Mcp-Session-Id header") - } - - guard let session = await server.session(for: sessionId) else { - return .httpError(status: 404, message: "Session not found") - } - - await session.markActive() - return .sseStream(sessionId: session.id) - } - - private func handleDelete(_ request: HTTPRequest, server: MCPServer) async -> RouteResult { - guard let sessionId = request.headers["mcp-session-id"] else { - return .httpError(status: 400, message: "Missing Mcp-Session-Id header") - } - - guard await server.session(for: sessionId) != nil else { - return .httpError(status: 404, message: "Session not found") - } - - await server.removeSession(sessionId) - Self.logger.info("Session terminated via DELETE: \(sessionId)") - return .noContent - } - - private func handlePost( - _ request: HTTPRequest, - server: MCPServer, - authenticatedToken: MCPAuthToken? - ) async -> RouteResult { - if let accept = request.headers["accept"], !accept.contains("application/json") && !accept.contains("*/*") { - return .httpError(status: 406, message: "Accept header must include application/json") - } - - guard let body = request.body else { - return encodeError(MCPError.parseError, id: nil) - } - - let rpcRequest: JSONRPCRequest - do { - rpcRequest = try decoder.decode(JSONRPCRequest.self, from: body) - } catch { - return encodeError(MCPError.parseError, id: nil) - } - - guard rpcRequest.jsonrpc == "2.0" else { - return encodeError(MCPError.invalidRequest("jsonrpc must be \"2.0\""), id: rpcRequest.id) - } - - if let protocolVersion = request.headers["mcp-protocol-version"], - protocolVersion != "2025-03-26" - { - Self.logger.warning("Client mcp-protocol-version mismatch: \(protocolVersion)") - } - - let headerSessionId = request.headers["mcp-session-id"] - return await dispatchMethod( - rpcRequest, - headerSessionId: headerSessionId, - server: server, - authenticatedToken: authenticatedToken - ) - } - - private func dispatchMethod( - _ request: JSONRPCRequest, - headerSessionId: String?, - server: MCPServer, - authenticatedToken: MCPAuthToken? - ) async -> RouteResult { - if request.method == "initialize" { - return await handleInitialize(request, server: server, authenticatedToken: authenticatedToken) - } - - if request.method == "ping" { - return handlePing(request) - } - - guard let sessionId = headerSessionId else { - return .httpError(status: 400, message: "Missing Mcp-Session-Id header") - } - guard let session = await server.session(for: sessionId) else { - return .httpError(status: 404, message: "Session not found") - } - - await session.markActive() - - if request.method == "notifications/initialized" { - await session.setInitialized(true) - return .accepted - } - - if request.method == "notifications/cancelled" { - return await handleCancellation(request, session: session) - } - - guard await session.isInitialized else { - return encodeError( - MCPError.invalidRequest("Session not initialized. Send notifications/initialized first."), - id: request.id - ) - } - - switch request.method { - case "tools/list": - return handleToolsList(request, sessionId: sessionId) - - case "tools/call": - return await handleToolsCall( - request, - sessionId: sessionId, - server: server, - authenticatedToken: authenticatedToken - ) - - case "resources/list": - return handleResourcesList(request, sessionId: sessionId) - - case "resources/read": - return await handleResourcesRead(request, sessionId: sessionId, server: server) - - default: - return encodeError(MCPError.methodNotFound(request.method), id: request.id) - } - } - - private func handleInitialize( - _ request: JSONRPCRequest, - server: MCPServer, - authenticatedToken: MCPAuthToken? - ) async -> RouteResult { - guard let session = await server.createSession() else { - return encodeError(MCPError.internalError("Maximum sessions reached"), id: request.id) - } - - if let params = request.params, - let clientInfo = params["clientInfo"], - let name = clientInfo["name"]?.stringValue - { - let version = clientInfo["version"]?.stringValue - await session.setClientInfo(MCPClientInfo(name: name, version: version)) - } - - if let token = authenticatedToken { - await session.setAuthenticatedTokenId(token.id) - await session.setTokenName(token.name) - } - - let result = MCPInitializeResult( - protocolVersion: "2025-03-26", - capabilities: MCPServerCapabilities( - tools: .init(listChanged: false), - resources: .init(subscribe: false, listChanged: false) - ), - serverInfo: MCPServerInfo(name: "tablepro", version: "1.0.0") - ) - - return encodeResult(result, id: request.id, sessionId: session.id) - } - - private func handlePing(_ request: JSONRPCRequest) -> RouteResult { - guard let id = request.id else { - return .accepted - } - return encodeRawResult(.object([:]), id: id, sessionId: nil) - } - - private func handleCancellation( - _ request: JSONRPCRequest, - session: MCPSession - ) async -> RouteResult { - guard let params = request.params, - let requestIdValue = params["requestId"] - else { - return .accepted - } - - let cancelId: JSONRPCId? - switch requestIdValue { - case .string(let s): - cancelId = .string(s) - case .int(let i): - cancelId = .int(i) - default: - cancelId = nil - } - - if let cancelId, let task = await session.removeRunningTask(cancelId) { - task.cancel() - Self.logger.info("Cancelled request \(String(describing: cancelId)) in session \(session.id)") - } - - return .accepted - } - - private func handleToolsList(_ request: JSONRPCRequest, sessionId: String) -> RouteResult { - guard let id = request.id else { - return .accepted - } - - let tools = Self.toolDefinitions() - let result: JSONValue = .object(["tools": encodeToolDefinitions(tools)]) - return encodeRawResult(result, id: id, sessionId: sessionId) - } - - private func handleToolsCall( - _ request: JSONRPCRequest, - sessionId: String, - server: MCPServer, - authenticatedToken: MCPAuthToken? - ) async -> RouteResult { - guard let id = request.id else { - return encodeError(MCPError.invalidRequest("tools/call requires an id"), id: nil) - } - - guard let params = request.params, - let name = params["name"]?.stringValue - else { - return encodeError(MCPError.invalidParams("Missing tool name"), id: id) - } - - let arguments = params["arguments"] - - guard let handler = await server.toolCallHandler else { - return encodeError(MCPError.internalError("Server not fully initialized"), id: id) - } - - let session = await server.session(for: sessionId) - let toolTask = Task { - try await handler(name, arguments, sessionId, authenticatedToken) - } - if let session { - let cancelForwardingTask = Task { - await withTaskCancellationHandler { - _ = try? await toolTask.value - } onCancel: { - toolTask.cancel() - } - } - await session.addRunningTask(id, task: cancelForwardingTask) - } - - do { - let toolResult = try await toolTask.value - if let session { _ = await session.removeRunningTask(id) } - let resultData = try encoder.encode(toolResult) - guard let resultValue = try? decoder.decode(JSONValue.self, from: resultData) else { - return encodeError(MCPError.internalError("Failed to encode tool result"), id: id) - } - return encodeRawResult(resultValue, id: id, sessionId: sessionId) - } catch is CancellationError { - if let session { _ = await session.removeRunningTask(id) } - return encodeError(MCPError.timeout("Request was cancelled"), id: id) - } catch let mcpError as MCPError { - if let session { _ = await session.removeRunningTask(id) } - return encodeError(mcpError, id: id) - } catch { - if let session { _ = await session.removeRunningTask(id) } - return encodeError(MCPError.internalError(error.localizedDescription), id: id) - } - } - - private func handleResourcesList(_ request: JSONRPCRequest, sessionId: String) -> RouteResult { - guard let id = request.id else { - return .accepted - } - - let resources = Self.resourceDefinitions() - let result: JSONValue = .object(["resources": encodeResourceDefinitions(resources)]) - return encodeRawResult(result, id: id, sessionId: sessionId) - } - - private func handleResourcesRead( - _ request: JSONRPCRequest, - sessionId: String, - server: MCPServer - ) async -> RouteResult { - guard let id = request.id else { - return encodeError(MCPError.invalidRequest("resources/read requires an id"), id: nil) - } - - guard let params = request.params, - let uri = params["uri"]?.stringValue - else { - return encodeError(MCPError.invalidParams("Missing resource uri"), id: id) - } - - guard let handler = await server.resourceReadHandler else { - return encodeError(MCPError.internalError("Server not fully initialized"), id: id) - } - - do { - let readResult = try await handler(uri, sessionId) - let resultData = try encoder.encode(readResult) - guard let resultValue = try? decoder.decode(JSONValue.self, from: resultData) else { - return encodeError(MCPError.internalError("Failed to encode resource result"), id: id) - } - return encodeRawResult(resultValue, id: id, sessionId: sessionId) - } catch let mcpError as MCPError { - return encodeError(mcpError, id: id) - } catch { - return encodeError(MCPError.internalError(error.localizedDescription), id: id) - } + return await route.handle(request) } - private func encodeResult(_ result: T, id: JSONRPCId?, sessionId: String?) -> RouteResult { - guard let id else { - return .accepted - } - - do { - let resultData = try encoder.encode(result) - let resultValue = try decoder.decode(JSONValue.self, from: resultData) - let response = JSONRPCResponse(id: id, result: resultValue) - let data = try encoder.encode(response) - return .json(data, sessionId: sessionId) - } catch { - Self.logger.error("Failed to encode response: \(error.localizedDescription)") - return encodeError(MCPError.internalError("Encoding failed"), id: id) + private func match(_ request: HTTPRequest) -> (any MCPRouteHandler)? { + let normalizedPath = Self.canonicalPath(request.path) + return routes.first { route in + route.path == normalizedPath && route.methods.contains(request.method) } } - private func encodeRawResult(_ result: JSONValue, id: JSONRPCId, sessionId: String?) -> RouteResult { - do { - let response = JSONRPCResponse(id: id, result: result) - let data = try encoder.encode(response) - return .json(data, sessionId: sessionId) - } catch { - Self.logger.error("Failed to encode response: \(error.localizedDescription)") - return encodeError(MCPError.internalError("Encoding failed"), id: id) + private static func canonicalPath(_ path: String) -> String { + if let queryIndex = path.firstIndex(of: "?") { + return String(path[.. RouteResult { - let errorResponse = error.toJsonRpcError(id: id) - do { - let data = try encoder.encode(errorResponse) - return .json(data, sessionId: nil) - } catch { - Self.logger.error("Failed to encode error response") - return .httpError(status: 500, message: "Internal encoding error") - } - } - - private func encodeToolDefinitions(_ tools: [MCPToolDefinition]) -> JSONValue { - .array(tools.map { tool in - .object([ - "name": .string(tool.name), - "description": .string(tool.description), - "inputSchema": tool.inputSchema - ]) - }) - } - - private func encodeResourceDefinitions(_ resources: [MCPResourceDefinition]) -> JSONValue { - .array(resources.map { resource in - var dict: [String: JSONValue] = [ - "uri": .string(resource.uri), - "name": .string(resource.name) - ] - if let description = resource.description { - dict["description"] = .string(description) - } - if let mimeType = resource.mimeType { - dict["mimeType"] = .string(mimeType) - } - return .object(dict) - }) + return path } } @@ -967,81 +459,6 @@ extension MCPRouter { } } -extension MCPRouter { - private struct ExchangeRequestBody: Decodable { - let code: String - let codeVerifier: String - - enum CodingKeys: String, CodingKey { - case code - case codeVerifier = "code_verifier" - } - } - - private struct ExchangeResponseBody: Encodable { - let token: String - } - - func handleIntegrationsExchange(_ request: HTTPRequest) async -> RouteResult { - if request.method == .options { - return .noContent - } - guard request.method == .post else { - return .httpError(status: 405, message: "Method not allowed") - } - - guard let body = request.body else { - return .httpError(status: 400, message: "Missing request body") - } - - let parsed: ExchangeRequestBody - do { - parsed = try decoder.decode(ExchangeRequestBody.self, from: body) - } catch { - return .httpError(status: 400, message: "Invalid JSON body") - } - - guard !parsed.code.isEmpty, !parsed.codeVerifier.isEmpty else { - return .httpError(status: 400, message: "Missing code or code_verifier") - } - - let token: String - do { - token = try await MainActor.run { - try MCPPairingService.shared.exchange( - PairingExchange(code: parsed.code, verifier: parsed.codeVerifier) - ) - } - } catch let mcpError as MCPError { - return mapExchangeError(mcpError) - } catch { - Self.logger.error("Pairing exchange failed: \(error.localizedDescription)") - return .httpError(status: 500, message: "Internal error") - } - - do { - let data = try encoder.encode(ExchangeResponseBody(token: token)) - return .json(data, sessionId: nil) - } catch { - Self.logger.error("Failed to encode exchange response: \(error.localizedDescription)") - return .httpError(status: 500, message: "Internal error") - } - } - - private func mapExchangeError(_ error: MCPError) -> RouteResult { - switch error { - case .notFound: - return .httpError(status: 404, message: "Pairing code not found") - case .expired: - return .httpError(status: 410, message: "Pairing code expired") - case .forbidden: - return .httpError(status: 403, message: "Challenge mismatch") - default: - return .httpError(status: 500, message: "Internal error") - } - } -} - extension MCPRouter { static func resourceDefinitions() -> [MCPResourceDefinition] { [ diff --git a/TablePro/Core/MCP/MCPServer.swift b/TablePro/Core/MCP/MCPServer.swift index 5d2a3b1c3..a499d6581 100644 --- a/TablePro/Core/MCP/MCPServer.swift +++ b/TablePro/Core/MCP/MCPServer.swift @@ -27,7 +27,7 @@ actor MCPServer { private var sessions: [String: MCPSession] = [:] private var cleanupTask: Task? private let stateCallback: @Sendable (MCPServerState) -> Void - private var router: MCPRouter! + private var router: MCPRouter? private(set) var tokenStore: MCPTokenStore? private(set) var rateLimiter: MCPRateLimiter? @@ -38,7 +38,10 @@ actor MCPServer { init(stateCallback: @escaping @Sendable (MCPServerState) -> Void) { self.stateCallback = stateCallback - self.router = MCPRouter() + } + + func setRouter(_ router: MCPRouter) { + self.router = router } func setTokenStore(_ store: MCPTokenStore) { @@ -129,10 +132,18 @@ actor MCPServer { cleanupTask?.cancel() cleanupTask = nil + let sessionIds = Array(sessions.keys) for (_, session) in sessions { await session.cancelAllTasks() await session.cancelSSEConnection() } + + if let cleanupHandler = sessionCleanupHandler { + for id in sessionIds { + await cleanupHandler(id) + } + } + sessions.removeAll() if let currentListener = listener { @@ -280,7 +291,7 @@ actor MCPServer { } } - private static let corsHeaders: [(String, String)] = [ + static let corsHeaders: [(String, String)] = [ ("Access-Control-Allow-Origin", "http://localhost"), ("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS"), ("Access-Control-Allow-Headers", "Content-Type, Mcp-Session-Id, mcp-protocol-version, Authorization"), @@ -295,7 +306,13 @@ actor MCPServer { return "\(host)" }() - let result = await router.route(request, server: self, remoteIP: remoteIP, tokenStore: tokenStore, rateLimiter: rateLimiter) + guard let router else { + sendHTTPError(connection: connection, status: 503, message: "Server not configured") + return + } + + let routedRequest = request.withRemoteIP(remoteIP) + let result = await router.handle(routedRequest) switch result { case .json(let data, let sessionId): @@ -342,6 +359,7 @@ actor MCPServer { guard let session = sessions.removeValue(forKey: sessionId) else { return } await session.cancelAllTasks() await session.cancelSSEConnection() + try? await session.transition(to: .terminated(reason: .removed)) if let cleanupHandler = sessionCleanupHandler { await cleanupHandler(sessionId) @@ -371,6 +389,7 @@ actor MCPServer { if idle > .seconds(Self.idleTimeout) { await session.cancelAllTasks() await session.cancelSSEConnection() + try? await session.transition(to: .terminated(reason: .idleTimeout)) sessions.removeValue(forKey: id) if let cleanupHandler = sessionCleanupHandler { diff --git a/TablePro/Core/MCP/MCPServerManager.swift b/TablePro/Core/MCP/MCPServerManager.swift index f028ac932..a51f0da54 100644 --- a/TablePro/Core/MCP/MCPServerManager.swift +++ b/TablePro/Core/MCP/MCPServerManager.swift @@ -61,9 +61,9 @@ final class MCPServerManager { let rateLimiter = MCPRateLimiter() let bridge = MCPConnectionBridge() - let authGuard = MCPAuthGuard() - let toolHandler = MCPToolHandler(bridge: bridge, authGuard: authGuard) - let resourceHandler = MCPResourceHandler(bridge: bridge, authGuard: authGuard) + let authPolicy = MCPAuthPolicy() + let toolHandler = MCPToolHandler(bridge: bridge, authPolicy: authPolicy) + let resourceHandler = MCPResourceHandler(bridge: bridge, authPolicy: authPolicy) await newServer.setTokenStore(newTokenStore) await newServer.setRateLimiter(rateLimiter) @@ -75,9 +75,18 @@ final class MCPServerManager { try await resourceHandler.handleResourceRead(uri: uri, sessionId: sessionId) } await newServer.setSessionCleanupHandler { sessionId in - await authGuard.clearSession(sessionId) + await authPolicy.clearSession(sessionId) } + let protocolHandler = MCPProtocolHandler( + server: newServer, + tokenStore: newTokenStore, + rateLimiter: rateLimiter + ) + let exchangeHandler = IntegrationsExchangeHandler.live() + let router = MCPRouter(routes: [protocolHandler, exchangeHandler]) + await newServer.setRouter(router) + let bridgeResult = await newTokenStore.generate( name: "__stdio_bridge__", permissions: .fullAccess diff --git a/TablePro/Core/MCP/MCPSession.swift b/TablePro/Core/MCP/MCPSession.swift index 66933e08c..a52295cad 100644 --- a/TablePro/Core/MCP/MCPSession.swift +++ b/TablePro/Core/MCP/MCPSession.swift @@ -6,15 +6,23 @@ actor MCPSession { let createdAt: ContinuousClock.Instant var lastActivityAt: ContinuousClock.Instant - var isInitialized: Bool = false + private(set) var phase: MCPSessionPhase = .created var clientInfo: MCPClientInfo? var sseConnection: NWConnection? var runningTasks: [JSONRPCId: Task] = [:] private(set) var eventCounter: Int = 0 - private(set) var authenticatedTokenId: UUID? - private(set) var tokenName: String? private(set) var remoteAddress: String? + var authenticatedTokenId: UUID? { + if case .active(let tokenId, _) = phase { return tokenId } + return nil + } + + var tokenName: String? { + if case .active(_, let tokenName) = phase { return tokenName } + return nil + } + init() { self.id = UUID().uuidString let now = ContinuousClock.now @@ -33,20 +41,31 @@ actor MCPSession { runningTasks.removeAll() } - func setInitialized(_ value: Bool) { - isInitialized = value - } - - func setClientInfo(_ info: MCPClientInfo?) { - clientInfo = info + func transition(to next: MCPSessionPhase) throws { + guard isValidTransition(from: phase, to: next) else { + throw MCPError.invalidRequest( + "Invalid session phase transition from \(phase) to \(next)" + ) + } + phase = next } - func setAuthenticatedTokenId(_ id: UUID?) { - authenticatedTokenId = id + private func isValidTransition(from current: MCPSessionPhase, to next: MCPSessionPhase) -> Bool { + switch (current, next) { + case (.created, .initializing), + (.created, .active), + (.created, .terminated), + (.initializing, .active), + (.initializing, .terminated), + (.active, .terminated): + return true + default: + return false + } } - func setTokenName(_ name: String?) { - tokenName = name + func setClientInfo(_ info: MCPClientInfo?) { + clientInfo = info } func setRemoteAddress(_ address: String?) { diff --git a/TablePro/Core/MCP/MCPSessionPhase.swift b/TablePro/Core/MCP/MCPSessionPhase.swift new file mode 100644 index 000000000..fa502c9f5 --- /dev/null +++ b/TablePro/Core/MCP/MCPSessionPhase.swift @@ -0,0 +1,20 @@ +import Foundation + +enum MCPSessionTerminationReason: Sendable, Equatable { + case removed + case idleTimeout + case serverStopped + case clientDisconnected +} + +enum MCPSessionPhase: Sendable, Equatable { + case created + case initializing + case active(tokenId: UUID?, tokenName: String?) + case terminated(reason: MCPSessionTerminationReason) + + var isActive: Bool { + if case .active = self { return true } + return false + } +} diff --git a/TablePro/Core/MCP/MCPTokenStore.swift b/TablePro/Core/MCP/MCPTokenStore.swift index b786764a3..cff7e3d89 100644 --- a/TablePro/Core/MCP/MCPTokenStore.swift +++ b/TablePro/Core/MCP/MCPTokenStore.swift @@ -3,6 +3,25 @@ import Foundation import os import Security +enum ConnectionAccess: Sendable, Codable, Equatable { + case all + case limited(Set) + + var allowedIds: Set? { + switch self { + case .all: return nil + case .limited(let ids): return ids + } + } + + func allows(_ connectionId: UUID) -> Bool { + switch self { + case .all: return true + case .limited(let ids): return ids.contains(connectionId) + } + } +} + struct MCPAuthToken: Codable, Identifiable, Sendable { let id: UUID let name: String @@ -10,7 +29,7 @@ struct MCPAuthToken: Codable, Identifiable, Sendable { let tokenHash: String let salt: String let permissions: TokenPermissions - let allowedConnectionIds: Set? + let connectionAccess: ConnectionAccess let createdAt: Date var lastUsedAt: Date? let expiresAt: Date? @@ -22,6 +41,84 @@ struct MCPAuthToken: Codable, Identifiable, Sendable { } var isEffectivelyActive: Bool { isActive && !isExpired } + + init( + id: UUID, + name: String, + prefix: String, + tokenHash: String, + salt: String, + permissions: TokenPermissions, + connectionAccess: ConnectionAccess, + createdAt: Date, + lastUsedAt: Date?, + expiresAt: Date?, + isActive: Bool + ) { + self.id = id + self.name = name + self.prefix = prefix + self.tokenHash = tokenHash + self.salt = salt + self.permissions = permissions + self.connectionAccess = connectionAccess + self.createdAt = createdAt + self.lastUsedAt = lastUsedAt + self.expiresAt = expiresAt + self.isActive = isActive + } + + private enum CodingKeys: String, CodingKey { + case id + case name + case prefix + case tokenHash + case salt + case permissions + case connectionAccess + case allowedConnectionIds + case createdAt + case lastUsedAt + case expiresAt + case isActive + } + + init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + self.id = try container.decode(UUID.self, forKey: .id) + self.name = try container.decode(String.self, forKey: .name) + self.prefix = try container.decode(String.self, forKey: .prefix) + self.tokenHash = try container.decode(String.self, forKey: .tokenHash) + self.salt = try container.decode(String.self, forKey: .salt) + self.permissions = try container.decode(TokenPermissions.self, forKey: .permissions) + self.createdAt = try container.decode(Date.self, forKey: .createdAt) + self.lastUsedAt = try container.decodeIfPresent(Date.self, forKey: .lastUsedAt) + self.expiresAt = try container.decodeIfPresent(Date.self, forKey: .expiresAt) + self.isActive = try container.decode(Bool.self, forKey: .isActive) + + if let access = try container.decodeIfPresent(ConnectionAccess.self, forKey: .connectionAccess) { + self.connectionAccess = access + } else if let legacyIds = try container.decodeIfPresent(Set.self, forKey: .allowedConnectionIds) { + self.connectionAccess = .limited(legacyIds) + } else { + self.connectionAccess = .all + } + } + + func encode(to encoder: Encoder) throws { + var container = encoder.container(keyedBy: CodingKeys.self) + try container.encode(id, forKey: .id) + try container.encode(name, forKey: .name) + try container.encode(prefix, forKey: .prefix) + try container.encode(tokenHash, forKey: .tokenHash) + try container.encode(salt, forKey: .salt) + try container.encode(permissions, forKey: .permissions) + try container.encode(connectionAccess, forKey: .connectionAccess) + try container.encode(createdAt, forKey: .createdAt) + try container.encodeIfPresent(lastUsedAt, forKey: .lastUsedAt) + try container.encodeIfPresent(expiresAt, forKey: .expiresAt) + try container.encode(isActive, forKey: .isActive) + } } enum TokenPermissions: String, Codable, Sendable, CaseIterable, Identifiable { @@ -72,7 +169,7 @@ actor MCPTokenStore { func generate( name: String, permissions: TokenPermissions, - allowedConnectionIds: Set? = nil, + connectionAccess: ConnectionAccess = .all, expiresAt: Date? = nil ) -> (token: MCPAuthToken, plaintext: String) { let key = SymmetricKey(size: .bits256) @@ -93,7 +190,7 @@ actor MCPTokenStore { tokenHash: hash, salt: saltBase64, permissions: permissions, - allowedConnectionIds: allowedConnectionIds, + connectionAccess: connectionAccess, createdAt: Date.now, lastUsedAt: nil, expiresAt: expiresAt, diff --git a/TablePro/Core/MCP/MCPToolHandler+Integrations.swift b/TablePro/Core/MCP/MCPToolHandler+Integrations.swift index 4e240334d..4724b646d 100644 --- a/TablePro/Core/MCP/MCPToolHandler+Integrations.swift +++ b/TablePro/Core/MCP/MCPToolHandler+Integrations.swift @@ -9,16 +9,21 @@ import Foundation extension MCPToolHandler { // MARK: - list_recent_tabs - func handleListRecentTabs(_ args: JSONValue?, token: MCPAuthToken?) async throws -> MCPToolResult { + func handleListRecentTabs(_ args: JSONValue?, sessionId: String, token: MCPAuthToken?) async throws -> MCPToolResult { let limit = optionalInt(args, key: "limit", default: 20, clamp: 1...500) + if let token, !token.permissions.satisfies(.readOnly) { + throw MCPError.forbidden( + "Token '\(token.name)' with permission '\(token.permissions.displayName)' cannot access 'list_recent_tabs'" + ) + } + let snapshots = await MainActor.run { Self.collectTabSnapshots() } let blockedConnectionIds = await MainActor.run { Self.blockedExternalConnectionIds() } - let allowed = token?.allowedConnectionIds + let access = token?.connectionAccess ?? .all let filtered = snapshots.filter { snapshot in guard !blockedConnectionIds.contains(snapshot.connectionId) else { return false } - if let allowed, !allowed.contains(snapshot.connectionId) { return false } - return true + return access.allows(snapshot.connectionId) } let trimmed = Array(filtered.prefix(limit)) @@ -51,7 +56,7 @@ extension MCPToolHandler { // MARK: - search_query_history - func handleSearchQueryHistory(_ args: JSONValue?, token: MCPAuthToken?) async throws -> MCPToolResult { + func handleSearchQueryHistory(_ args: JSONValue?, sessionId: String, token: MCPAuthToken?) async throws -> MCPToolResult { let query = try requireString(args, key: "query") let connectionIdString = optionalString(args, key: "connection_id") let limit = optionalInt(args, key: "limit", default: 50, clamp: 1...500) @@ -62,6 +67,12 @@ extension MCPToolHandler { throw MCPError.invalidParams("'since' must be less than or equal to 'until'") } + if let token, !token.permissions.satisfies(.readOnly) { + throw MCPError.forbidden( + "Token '\(token.name)' with permission '\(token.permissions.displayName)' cannot access 'search_query_history'" + ) + } + let blockedConnectionIds = await MainActor.run { Self.blockedExternalConnectionIds() } let connectionId: UUID? @@ -69,7 +80,9 @@ extension MCPToolHandler { guard let parsed = UUID(uuidString: connectionIdString) else { throw MCPError.invalidParams("Invalid UUID for parameter: connection_id") } - if let token { try checkTokenConnectionAccess(token, connectionId: parsed) } + if let token, !token.connectionAccess.allows(parsed) { + throw MCPError.forbidden("Token does not have access to this connection") + } if blockedConnectionIds.contains(parsed) { throw MCPError.forbidden( String(localized: "External access is disabled for this connection") @@ -119,11 +132,15 @@ extension MCPToolHandler { // MARK: - open_connection_window - func handleOpenConnectionWindow(_ args: JSONValue?, token: MCPAuthToken?) async throws -> MCPToolResult { + func handleOpenConnectionWindow(_ args: JSONValue?, sessionId: String, token: MCPAuthToken?) async throws -> MCPToolResult { let connectionId = try requireUUID(args, key: "connection_id") - if let token { try checkTokenConnectionAccess(token, connectionId: connectionId) } try await ensureConnectionExists(connectionId) - try await authGuard.checkExternalAccessLevel(connectionId: connectionId, requires: .readOnly) + try await authPolicy.resolveAndAuthorize( + token: token ?? Self.anonymousFullAccessToken, + tool: "open_connection_window", + connectionId: connectionId, + sessionId: sessionId + ) let windowId = await MainActor.run { () -> UUID in let payload = EditorTabPayload( @@ -146,15 +163,19 @@ extension MCPToolHandler { // MARK: - open_table_tab - func handleOpenTableTab(_ args: JSONValue?, token: MCPAuthToken?) async throws -> MCPToolResult { + func handleOpenTableTab(_ args: JSONValue?, sessionId: String, token: MCPAuthToken?) async throws -> MCPToolResult { let connectionId = try requireUUID(args, key: "connection_id") let tableName = try requireString(args, key: "table_name") let databaseName = optionalString(args, key: "database_name") let schemaName = optionalString(args, key: "schema_name") - if let token { try checkTokenConnectionAccess(token, connectionId: connectionId) } try await ensureConnectionExists(connectionId) - try await authGuard.checkExternalAccessLevel(connectionId: connectionId, requires: .readOnly) + try await authPolicy.resolveAndAuthorize( + token: token ?? Self.anonymousFullAccessToken, + tool: "open_table_tab", + connectionId: connectionId, + sessionId: sessionId + ) let windowId = await MainActor.run { () -> UUID in let payload = EditorTabPayload( @@ -181,7 +202,7 @@ extension MCPToolHandler { // MARK: - focus_query_tab - func handleFocusQueryTab(_ args: JSONValue?, token: MCPAuthToken?) async throws -> MCPToolResult { + func handleFocusQueryTab(_ args: JSONValue?, sessionId: String, token: MCPAuthToken?) async throws -> MCPToolResult { let tabId = try requireUUID(args, key: "tab_id") let resolved = await MainActor.run { () -> (hasWindow: Bool, windowId: UUID?, connectionId: UUID?)? in @@ -198,8 +219,12 @@ extension MCPToolHandler { guard let connectionId = resolved.connectionId else { throw MCPError.notFound("connection") } - if let token { try checkTokenConnectionAccess(token, connectionId: connectionId) } - try await authGuard.checkExternalAccessLevel(connectionId: connectionId, requires: .readOnly) + try await authPolicy.resolveAndAuthorize( + token: token ?? Self.anonymousFullAccessToken, + tool: "focus_query_tab", + connectionId: connectionId, + sessionId: sessionId + ) let raised = await MainActor.run { () -> Bool in for snapshot in Self.collectTabSnapshots() where snapshot.tabId == tabId { @@ -237,8 +262,8 @@ extension MCPToolHandler { if scopedConnectionId != nil { return nil } - if let tokenAllowed = token?.allowedConnectionIds { - return tokenAllowed.subtracting(blockedConnectionIds) + if let access = token?.connectionAccess, case .limited(let allowed) = access { + return allowed.subtracting(blockedConnectionIds) } guard !blockedConnectionIds.isEmpty else { return nil } let allConnectionIds = await MainActor.run { @@ -290,6 +315,7 @@ extension MCPToolHandler { let connections = ConnectionStorage.shared.loadConnections() return Set(connections.filter { $0.externalAccess == .blocked }.map(\.id)) } + } struct TabSnapshot { diff --git a/TablePro/Core/MCP/MCPToolHandler.swift b/TablePro/Core/MCP/MCPToolHandler.swift index cdf135fb7..3e1c9a77d 100644 --- a/TablePro/Core/MCP/MCPToolHandler.swift +++ b/TablePro/Core/MCP/MCPToolHandler.swift @@ -4,12 +4,12 @@ import os final class MCPToolHandler: Sendable { private static let logger = Logger(subsystem: "com.TablePro", category: "MCPToolHandler") - private let bridge: MCPConnectionBridge - let authGuard: MCPAuthGuard + let bridge: MCPConnectionBridge + let authPolicy: MCPAuthPolicy - init(bridge: MCPConnectionBridge, authGuard: MCPAuthGuard) { + init(bridge: MCPConnectionBridge, authPolicy: MCPAuthPolicy) { self.bridge = bridge - self.authGuard = authGuard + self.authPolicy = authPolicy } func handleToolCall( @@ -18,10 +18,6 @@ final class MCPToolHandler: Sendable { sessionId: String, token: MCPAuthToken? = nil ) async throws -> MCPToolResult { - if let token { - try checkTokenToolPermission(token, toolName: name) - } - do { let result = try await dispatchTool( name: name, @@ -58,9 +54,9 @@ final class MCPToolHandler: Sendable { case "connect": return try await handleConnect(arguments, sessionId: sessionId, token: token) case "disconnect": - return try await handleDisconnect(arguments, token: token) + return try await handleDisconnect(arguments, sessionId: sessionId, token: token) case "get_connection_status": - return try await handleGetConnectionStatus(arguments, token: token) + return try await handleGetConnectionStatus(arguments, sessionId: sessionId, token: token) case "execute_query": return try await handleExecuteQuery(arguments, sessionId: sessionId, token: token) case "list_tables": @@ -82,15 +78,15 @@ final class MCPToolHandler: Sendable { case "switch_schema": return try await handleSwitchSchema(arguments, sessionId: sessionId, token: token) case "list_recent_tabs": - return try await handleListRecentTabs(arguments, token: token) + return try await handleListRecentTabs(arguments, sessionId: sessionId, token: token) case "search_query_history": - return try await handleSearchQueryHistory(arguments, token: token) + return try await handleSearchQueryHistory(arguments, sessionId: sessionId, token: token) case "open_connection_window": - return try await handleOpenConnectionWindow(arguments, token: token) + return try await handleOpenConnectionWindow(arguments, sessionId: sessionId, token: token) case "open_table_tab": - return try await handleOpenTableTab(arguments, token: token) + return try await handleOpenTableTab(arguments, sessionId: sessionId, token: token) case "focus_query_tab": - return try await handleFocusQueryTab(arguments, token: token) + return try await handleFocusQueryTab(arguments, sessionId: sessionId, token: token) default: throw MCPError.methodNotFound(name) } @@ -114,33 +110,35 @@ final class MCPToolHandler: Sendable { ) } - private func checkTokenToolPermission(_ token: MCPAuthToken, toolName: String) throws { - let required = minimumPermission(for: toolName) - guard token.permissions.satisfies(required) else { - throw MCPError.forbidden( - "Token '\(token.name)' with permission '\(token.permissions.displayName)' " - + "cannot access '\(toolName)'" - ) - } - } - - private func minimumPermission(for toolName: String) -> TokenPermissions { - switch toolName { - case "confirm_destructive_operation": - return .fullAccess - case "switch_database", "switch_schema", "export_data": - return .readWrite - default: - return .readOnly - } + private func authorize( + token: MCPAuthToken?, + tool: String, + connectionId: UUID?, + sql: String? = nil, + sessionId: String + ) async throws { + try await authPolicy.resolveAndAuthorize( + token: token ?? Self.anonymousFullAccessToken, + tool: tool, + connectionId: connectionId, + sql: sql, + sessionId: sessionId + ) } - func checkTokenConnectionAccess(_ token: MCPAuthToken, connectionId: UUID) throws { - guard let allowed = token.allowedConnectionIds else { return } - guard allowed.contains(connectionId) else { - throw MCPError.forbidden("Token does not have access to this connection") - } - } + static let anonymousFullAccessToken: MCPAuthToken = MCPAuthToken( + id: UUID(), + name: "__anonymous__", + prefix: "tp_anon", + tokenHash: "", + salt: "", + permissions: .fullAccess, + connectionAccess: .all, + createdAt: Date.now, + lastUsedAt: nil, + expiresAt: nil, + isActive: true + ) private func handleListConnections(token: MCPAuthToken?) async throws -> MCPToolResult { let result = await bridge.listConnections() @@ -149,7 +147,9 @@ final class MCPToolHandler: Sendable { } private func filterConnectionsByToken(_ value: JSONValue, token: MCPAuthToken?) -> JSONValue { - guard let allowed = token?.allowedConnectionIds else { return value } + guard let access = token?.connectionAccess, case .limited(let allowed) = access else { + return value + } guard case .object(var dict) = value, let entries = dict["connections"]?.arrayValue else { @@ -169,23 +169,22 @@ final class MCPToolHandler: Sendable { private func handleConnect(_ args: JSONValue?, sessionId: String, token: MCPAuthToken?) async throws -> MCPToolResult { let connectionId = try requireUUID(args, key: "connection_id") - if let token { try checkTokenConnectionAccess(token, connectionId: connectionId) } - try await authGuard.checkConnectionAccess(connectionId: connectionId, sessionId: sessionId) + try await authorize(token: token, tool: "connect", connectionId: connectionId, sessionId: sessionId) let result = try await bridge.connect(connectionId: connectionId) return MCPToolResult(content: [.text(encodeJSON(result))], isError: nil) } - private func handleDisconnect(_ args: JSONValue?, token: MCPAuthToken?) async throws -> MCPToolResult { + private func handleDisconnect(_ args: JSONValue?, sessionId: String, token: MCPAuthToken?) async throws -> MCPToolResult { let connectionId = try requireUUID(args, key: "connection_id") - if let token { try checkTokenConnectionAccess(token, connectionId: connectionId) } + try await authorize(token: token, tool: "disconnect", connectionId: connectionId, sessionId: sessionId) try await bridge.disconnect(connectionId: connectionId) let result: JSONValue = .object(["status": "disconnected"]) return MCPToolResult(content: [.text(encodeJSON(result))], isError: nil) } - private func handleGetConnectionStatus(_ args: JSONValue?, token: MCPAuthToken?) async throws -> MCPToolResult { + private func handleGetConnectionStatus(_ args: JSONValue?, sessionId: String, token: MCPAuthToken?) async throws -> MCPToolResult { let connectionId = try requireUUID(args, key: "connection_id") - if let token { try checkTokenConnectionAccess(token, connectionId: connectionId) } + try await authorize(token: token, tool: "get_connection_status", connectionId: connectionId, sessionId: sessionId) let result = try await bridge.getConnectionStatus(connectionId: connectionId) return MCPToolResult(content: [.text(encodeJSON(result))], isError: nil) } @@ -207,8 +206,13 @@ final class MCPToolHandler: Sendable { throw MCPError.invalidParams("Multi-statement queries are not supported. Send one statement at a time.") } - if let token { try checkTokenConnectionAccess(token, connectionId: connectionId) } - try await authGuard.checkConnectionAccess(connectionId: connectionId, sessionId: sessionId) + try await authorize( + token: token, + tool: "execute_query", + connectionId: connectionId, + sql: query, + sessionId: sessionId + ) let (databaseType, safeModeLevel, databaseName) = try await resolveConnectionMeta(connectionId) @@ -228,11 +232,21 @@ final class MCPToolHandler: Sendable { + "Use the confirm_destructive_operation tool instead." ) - case .write, .safe: - if let token { - try checkTokenQueryTierPermission(token, tier: tier) + case .write: + if let token, !token.permissions.satisfies(.readWrite) { + throw MCPError.forbidden( + "Token '\(token.name)' with '\(token.permissions.displayName)' permission cannot execute write queries" + ) } - try await authGuard.checkQueryPermission( + try await authPolicy.checkSafeModeDialog( + sql: query, + connectionId: connectionId, + databaseType: databaseType, + safeModeLevel: safeModeLevel + ) + + case .safe: + try await authPolicy.checkSafeModeDialog( sql: query, connectionId: connectionId, databaseType: databaseType, @@ -252,25 +266,6 @@ final class MCPToolHandler: Sendable { return MCPToolResult(content: [.text(encodeJSON(result))], isError: nil) } - private func checkTokenQueryTierPermission(_ token: MCPAuthToken, tier: QueryTier) throws { - switch tier { - case .safe: - return - case .write: - guard token.permissions.satisfies(.readWrite) else { - throw MCPError.forbidden( - "Token '\(token.name)' with '\(token.permissions.displayName)' permission cannot execute write queries" - ) - } - case .destructive: - guard token.permissions == .fullAccess else { - throw MCPError.forbidden( - "Token '\(token.name)' with '\(token.permissions.displayName)' permission cannot execute destructive queries" - ) - } - } - } - private func handleConfirmDestructiveOperation( _ args: JSONValue?, sessionId: String, @@ -292,8 +287,13 @@ final class MCPToolHandler: Sendable { ) } - if let token { try checkTokenConnectionAccess(token, connectionId: connectionId) } - try await authGuard.checkConnectionAccess(connectionId: connectionId, sessionId: sessionId) + try await authorize( + token: token, + tool: "confirm_destructive_operation", + connectionId: connectionId, + sql: query, + sessionId: sessionId + ) let (databaseType, safeModeLevel, databaseName) = try await resolveConnectionMeta(connectionId) @@ -305,7 +305,7 @@ final class MCPToolHandler: Sendable { ) } - try await authGuard.checkQueryPermission( + try await authPolicy.checkSafeModeDialog( sql: query, connectionId: connectionId, databaseType: databaseType, @@ -333,8 +333,7 @@ final class MCPToolHandler: Sendable { let database = optionalString(args, key: "database") let schema = optionalString(args, key: "schema") - if let token { try checkTokenConnectionAccess(token, connectionId: connectionId) } - try await authGuard.checkConnectionAccess(connectionId: connectionId, sessionId: sessionId) + try await authorize(token: token, tool: "list_tables", connectionId: connectionId, sessionId: sessionId) if let database { _ = try await bridge.switchDatabase(connectionId: connectionId, database: database) @@ -352,8 +351,7 @@ final class MCPToolHandler: Sendable { let table = try requireString(args, key: "table") let schema = optionalString(args, key: "schema") - if let token { try checkTokenConnectionAccess(token, connectionId: connectionId) } - try await authGuard.checkConnectionAccess(connectionId: connectionId, sessionId: sessionId) + try await authorize(token: token, tool: "describe_table", connectionId: connectionId, sessionId: sessionId) let result = try await bridge.describeTable(connectionId: connectionId, table: table, schema: schema) return MCPToolResult(content: [.text(encodeJSON(result))], isError: nil) @@ -361,8 +359,7 @@ final class MCPToolHandler: Sendable { private func handleListDatabases(_ args: JSONValue?, sessionId: String, token: MCPAuthToken?) async throws -> MCPToolResult { let connectionId = try requireUUID(args, key: "connection_id") - if let token { try checkTokenConnectionAccess(token, connectionId: connectionId) } - try await authGuard.checkConnectionAccess(connectionId: connectionId, sessionId: sessionId) + try await authorize(token: token, tool: "list_databases", connectionId: connectionId, sessionId: sessionId) let result = try await bridge.listDatabases(connectionId: connectionId) return MCPToolResult(content: [.text(encodeJSON(result))], isError: nil) } @@ -371,8 +368,7 @@ final class MCPToolHandler: Sendable { let connectionId = try requireUUID(args, key: "connection_id") let database = optionalString(args, key: "database") - if let token { try checkTokenConnectionAccess(token, connectionId: connectionId) } - try await authGuard.checkConnectionAccess(connectionId: connectionId, sessionId: sessionId) + try await authorize(token: token, tool: "list_schemas", connectionId: connectionId, sessionId: sessionId) if let database { _ = try await bridge.switchDatabase(connectionId: connectionId, database: database) @@ -387,8 +383,7 @@ final class MCPToolHandler: Sendable { let table = try requireString(args, key: "table") let schema = optionalString(args, key: "schema") - if let token { try checkTokenConnectionAccess(token, connectionId: connectionId) } - try await authGuard.checkConnectionAccess(connectionId: connectionId, sessionId: sessionId) + try await authorize(token: token, tool: "get_table_ddl", connectionId: connectionId, sessionId: sessionId) let result = try await bridge.getTableDDL(connectionId: connectionId, table: table, schema: schema) return MCPToolResult(content: [.text(encodeJSON(result))], isError: nil) @@ -420,15 +415,19 @@ final class MCPToolHandler: Sendable { _ = try Self.sandboxedDownloadsURL(for: outputPath) } - if let token { try checkTokenConnectionAccess(token, connectionId: connectionId) } - try await authGuard.checkConnectionAccess(connectionId: connectionId, sessionId: sessionId) - try await authGuard.checkExternalAccessLevel(connectionId: connectionId, requires: .readWrite) + try await authorize( + token: token, + tool: "export_data", + connectionId: connectionId, + sql: query, + sessionId: sessionId + ) let (databaseType, safeModeLevel, _) = try await resolveConnectionMeta(connectionId) var queries: [(label: String, sql: String)] = [] if let query { - try await authGuard.checkQueryPermission( + try await authPolicy.checkSafeModeDialog( sql: query, connectionId: connectionId, databaseType: databaseType, @@ -440,7 +439,7 @@ final class MCPToolHandler: Sendable { for table in tables { let quoted = Self.quoteQualifiedIdentifier(table, quoter: quoteIdentifier) let sql = "SELECT * FROM \(quoted) LIMIT \(maxRows)" - try await authGuard.checkQueryPermission( + try await authPolicy.checkSafeModeDialog( sql: sql, connectionId: connectionId, databaseType: databaseType, @@ -526,9 +525,7 @@ final class MCPToolHandler: Sendable { let connectionId = try requireUUID(args, key: "connection_id") let database = try requireString(args, key: "database") - if let token { try checkTokenConnectionAccess(token, connectionId: connectionId) } - try await authGuard.checkConnectionAccess(connectionId: connectionId, sessionId: sessionId) - try await authGuard.checkExternalAccessLevel(connectionId: connectionId, requires: .readWrite) + try await authorize(token: token, tool: "switch_database", connectionId: connectionId, sessionId: sessionId) let result = try await bridge.switchDatabase(connectionId: connectionId, database: database) return MCPToolResult(content: [.text(encodeJSON(result))], isError: nil) @@ -538,9 +535,7 @@ final class MCPToolHandler: Sendable { let connectionId = try requireUUID(args, key: "connection_id") let schema = try requireString(args, key: "schema") - if let token { try checkTokenConnectionAccess(token, connectionId: connectionId) } - try await authGuard.checkConnectionAccess(connectionId: connectionId, sessionId: sessionId) - try await authGuard.checkExternalAccessLevel(connectionId: connectionId, requires: .readWrite) + try await authorize(token: token, tool: "switch_schema", connectionId: connectionId, sessionId: sessionId) let result = try await bridge.switchSchema(connectionId: connectionId, schema: schema) return MCPToolResult(content: [.text(encodeJSON(result))], isError: nil) @@ -564,7 +559,7 @@ final class MCPToolHandler: Sendable { ) let elapsed = Date().timeIntervalSince(startTime) let rowCount = result["row_count"]?.intValue ?? 0 - await authGuard.logQuery( + await authPolicy.logQuery( sql: query, connectionId: connectionId, databaseName: databaseName, @@ -585,7 +580,7 @@ final class MCPToolHandler: Sendable { return result } catch { let elapsed = Date().timeIntervalSince(startTime) - await authGuard.logQuery( + await authPolicy.logQuery( sql: query, connectionId: connectionId, databaseName: databaseName, @@ -646,13 +641,14 @@ final class MCPToolHandler: Sendable { private func resolveConnectionMeta(_ connectionId: UUID) async throws -> (DatabaseType, SafeModeLevel, String) { try await MainActor.run { - if let session = DatabaseManager.shared.activeSessions[connectionId] { + switch DatabaseManager.shared.connectionState(connectionId) { + case .live(_, let session): return (session.connection.type, session.connection.safeModeLevel, session.activeDatabase) - } - if let conn = ConnectionStorage.shared.loadConnections().first(where: { $0.id == connectionId }) { + case .stored(let conn): return (conn.type, conn.safeModeLevel, conn.database) + case .unknown: + throw MCPError.notConnected(connectionId) } - throw MCPError.notConnected(connectionId) } } diff --git a/TablePro/Core/MCP/PairingTypes.swift b/TablePro/Core/MCP/PairingTypes.swift new file mode 100644 index 000000000..c54345d24 --- /dev/null +++ b/TablePro/Core/MCP/PairingTypes.swift @@ -0,0 +1,14 @@ +import Foundation + +struct PairingRequest: Sendable, Equatable { + let clientName: String + let challenge: String + let redirectURL: URL + let requestedScopes: String? + let requestedConnectionIds: Set? +} + +struct PairingExchange: Sendable, Equatable { + let code: String + let verifier: String +} diff --git a/TablePro/Core/MCP/Routes/IntegrationsExchangeHandler.swift b/TablePro/Core/MCP/Routes/IntegrationsExchangeHandler.swift new file mode 100644 index 000000000..794494a02 --- /dev/null +++ b/TablePro/Core/MCP/Routes/IntegrationsExchangeHandler.swift @@ -0,0 +1,94 @@ +import Foundation +import os + +struct IntegrationsExchangeHandler: MCPRouteHandler { + private static let logger = Logger(subsystem: "com.TablePro", category: "IntegrationsExchangeHandler") + + private let exchange: @Sendable (PairingExchange) async throws -> String + + private let encoder: JSONEncoder + private let decoder: JSONDecoder + + var methods: [HTTPRequest.Method] { [.post] } + var path: String { "/v1/integrations/exchange" } + + init(exchange: @escaping @Sendable (PairingExchange) async throws -> String) { + self.exchange = exchange + let enc = JSONEncoder() + enc.outputFormatting = [.sortedKeys] + self.encoder = enc + self.decoder = JSONDecoder() + } + + static func live() -> IntegrationsExchangeHandler { + IntegrationsExchangeHandler { request in + try await MainActor.run { + try MCPPairingService.shared.exchange(request) + } + } + } + + func handle(_ request: HTTPRequest) async -> MCPRouter.RouteResult { + guard let body = request.body else { + return .httpError(status: 400, message: "Missing request body") + } + + let parsed: ExchangeRequestBody + do { + parsed = try decoder.decode(ExchangeRequestBody.self, from: body) + } catch { + return .httpError(status: 400, message: "Invalid JSON body") + } + + guard !parsed.code.isEmpty, !parsed.codeVerifier.isEmpty else { + return .httpError(status: 400, message: "Missing code or code_verifier") + } + + let token: String + do { + token = try await exchange( + PairingExchange(code: parsed.code, verifier: parsed.codeVerifier) + ) + } catch let mcpError as MCPError { + return Self.mapExchangeError(mcpError) + } catch { + Self.logger.error("Pairing exchange failed: \(error.localizedDescription)") + return .httpError(status: 500, message: "Internal error") + } + + do { + let data = try encoder.encode(ExchangeResponseBody(token: token)) + return .json(data, sessionId: nil) + } catch { + Self.logger.error("Failed to encode exchange response: \(error.localizedDescription)") + return .httpError(status: 500, message: "Internal error") + } + } + + private static func mapExchangeError(_ error: MCPError) -> MCPRouter.RouteResult { + switch error { + case .notFound: + return .httpError(status: 404, message: "Pairing code not found") + case .expired: + return .httpError(status: 410, message: "Pairing code expired") + case .forbidden: + return .httpError(status: 403, message: "Challenge mismatch") + default: + return .httpError(status: 500, message: "Internal error") + } + } + + private struct ExchangeRequestBody: Decodable { + let code: String + let codeVerifier: String + + enum CodingKeys: String, CodingKey { + case code + case codeVerifier = "code_verifier" + } + } + + private struct ExchangeResponseBody: Encodable { + let token: String + } +} diff --git a/TablePro/Core/MCP/Routes/MCPProtocolHandler.swift b/TablePro/Core/MCP/Routes/MCPProtocolHandler.swift new file mode 100644 index 000000000..dac72ede6 --- /dev/null +++ b/TablePro/Core/MCP/Routes/MCPProtocolHandler.swift @@ -0,0 +1,533 @@ +import Foundation +import os + +final class MCPProtocolHandler: MCPRouteHandler, @unchecked Sendable { + private static let logger = Logger(subsystem: "com.TablePro", category: "MCPProtocolHandler") + + private weak var server: MCPServer? + private let tokenStore: MCPTokenStore? + private let rateLimiter: MCPRateLimiter? + + private let encoder: JSONEncoder + private let decoder: JSONDecoder + + var methods: [HTTPRequest.Method] { [.get, .post, .delete] } + var path: String { "/mcp" } + + init(server: MCPServer, tokenStore: MCPTokenStore?, rateLimiter: MCPRateLimiter?) { + self.server = server + self.tokenStore = tokenStore + self.rateLimiter = rateLimiter + let enc = JSONEncoder() + enc.outputFormatting = [.sortedKeys] + self.encoder = enc + self.decoder = JSONDecoder() + } + + func handle(_ request: HTTPRequest) async -> MCPRouter.RouteResult { + guard let server else { + return .httpError(status: 503, message: "Server unavailable") + } + + if let rateLimiter, let ip = request.remoteIP { + let lockoutCheck = await rateLimiter.isLockedOut(ip: ip) + if case .rateLimited(let retryAfter) = lockoutCheck { + let seconds = Int(retryAfter.components.seconds) + MCPAuditLogger.logRateLimited(ip: ip, retryAfterSeconds: seconds) + return .httpErrorWithHeaders( + status: 429, + message: "Too many failed attempts", + extraHeaders: [("Retry-After", "\(seconds)")] + ) + } + } + + let authResult = await authenticateRequest(request) + + switch authResult { + case .failure(let result): + return result + case .success(let token): + if token == nil { + if let origin = request.headers["origin"], !isAllowedOrigin(origin) { + return .httpError(status: 403, message: "Forbidden origin") + } + } + + switch request.method { + case .post: + return await handlePost(request, server: server, authenticatedToken: token) + case .get: + return await handleGet(request, server: server) + case .delete: + return await handleDelete(request, server: server) + case .options: + return .noContent + } + } + } + + private enum AuthResult { + case success(MCPAuthToken?) + case failure(MCPRouter.RouteResult) + } + + private func authenticateRequest(_ request: HTTPRequest) async -> AuthResult { + let remoteIP = request.remoteIP + let authRequired = await MainActor.run { AppSettingsManager.shared.mcp.requireAuthentication } + + guard let authHeader = request.headers["authorization"] else { + guard !authRequired else { + MCPAuditLogger.logAuthFailure(reason: "Missing authorization header", ip: remoteIP ?? "localhost") + return .failure(.httpErrorWithHeaders( + status: 401, + message: "Authentication required", + extraHeaders: [("WWW-Authenticate", "Bearer realm=\"TablePro MCP\"")] + )) + } + return .success(nil) + } + + guard authHeader.lowercased().hasPrefix("bearer "), let tokenStore else { + let rateLimitResult = await recordAuthFailure(ip: remoteIP) + if case .rateLimited(let retryAfter) = rateLimitResult { + let seconds = Int(retryAfter.components.seconds) + MCPAuditLogger.logRateLimited(ip: remoteIP ?? "localhost", retryAfterSeconds: seconds) + return .failure(.httpErrorWithHeaders( + status: 429, + message: "Too many failed attempts", + extraHeaders: [("Retry-After", "\(seconds)")] + )) + } + MCPAuditLogger.logAuthFailure(reason: "Invalid authorization header format", ip: remoteIP ?? "localhost") + return .failure(.httpErrorWithHeaders( + status: 401, + message: "Invalid authorization header", + extraHeaders: [("WWW-Authenticate", "Bearer realm=\"TablePro MCP\"")] + )) + } + + let bearerToken = String(authHeader.dropFirst(7)) + + guard let token = await tokenStore.validate(bearerToken: bearerToken) else { + let rateLimitResult = await recordAuthFailure(ip: remoteIP) + if case .rateLimited(let retryAfter) = rateLimitResult { + let seconds = Int(retryAfter.components.seconds) + MCPAuditLogger.logRateLimited(ip: remoteIP ?? "localhost", retryAfterSeconds: seconds) + return .failure(.httpErrorWithHeaders( + status: 429, + message: "Too many failed attempts", + extraHeaders: [("Retry-After", "\(seconds)")] + )) + } + MCPAuditLogger.logAuthFailure(reason: "Invalid token", ip: remoteIP ?? "localhost") + return .failure(.httpErrorWithHeaders( + status: 401, + message: "Invalid or expired token", + extraHeaders: [("WWW-Authenticate", "Bearer realm=\"TablePro MCP\"")] + )) + } + + if let rateLimiter, let ip = remoteIP { + _ = await rateLimiter.checkAndRecord(ip: ip, success: true) + } + MCPAuditLogger.logAuthSuccess(tokenName: token.name, ip: remoteIP ?? "localhost") + return .success(token) + } + + @discardableResult + private func recordAuthFailure(ip: String?) async -> MCPRateLimiter.AuthRateResult? { + guard let rateLimiter, let ip else { return nil } + return await rateLimiter.checkAndRecord(ip: ip, success: false) + } + + private func isAllowedOrigin(_ origin: String) -> Bool { + guard let components = URLComponents(string: origin), + let host = components.host + else { + return false + } + let allowedHosts: Set = ["localhost", "127.0.0.1", "::1"] + return allowedHosts.contains(host) + } + + private func handleGet(_ request: HTTPRequest, server: MCPServer) async -> MCPRouter.RouteResult { + guard let sessionId = request.headers["mcp-session-id"] else { + return .httpError(status: 400, message: "Missing Mcp-Session-Id header") + } + + guard let session = await server.session(for: sessionId) else { + return .httpError(status: 404, message: "Session not found") + } + + await session.markActive() + return .sseStream(sessionId: session.id) + } + + private func handleDelete(_ request: HTTPRequest, server: MCPServer) async -> MCPRouter.RouteResult { + guard let sessionId = request.headers["mcp-session-id"] else { + return .httpError(status: 400, message: "Missing Mcp-Session-Id header") + } + + guard await server.session(for: sessionId) != nil else { + return .httpError(status: 404, message: "Session not found") + } + + await server.removeSession(sessionId) + Self.logger.info("Session terminated via DELETE: \(sessionId)") + return .noContent + } + + private func handlePost( + _ request: HTTPRequest, + server: MCPServer, + authenticatedToken: MCPAuthToken? + ) async -> MCPRouter.RouteResult { + if let accept = request.headers["accept"], !accept.contains("application/json") && !accept.contains("*/*") { + return .httpError(status: 406, message: "Accept header must include application/json") + } + + guard let body = request.body else { + return encodeError(MCPError.parseError, id: nil) + } + + let rpcRequest: JSONRPCRequest + do { + rpcRequest = try decoder.decode(JSONRPCRequest.self, from: body) + } catch { + return encodeError(MCPError.parseError, id: nil) + } + + guard rpcRequest.jsonrpc == "2.0" else { + return encodeError(MCPError.invalidRequest("jsonrpc must be \"2.0\""), id: rpcRequest.id) + } + + if let protocolVersion = request.headers["mcp-protocol-version"], + protocolVersion != "2025-03-26" + { + Self.logger.warning("Client mcp-protocol-version mismatch: \(protocolVersion)") + } + + let headerSessionId = request.headers["mcp-session-id"] + return await dispatchMethod( + rpcRequest, + headerSessionId: headerSessionId, + server: server, + authenticatedToken: authenticatedToken + ) + } + + private func dispatchMethod( + _ request: JSONRPCRequest, + headerSessionId: String?, + server: MCPServer, + authenticatedToken: MCPAuthToken? + ) async -> MCPRouter.RouteResult { + if request.method == "initialize" { + return await handleInitialize(request, server: server) + } + + if request.method == "ping" { + return handlePing(request) + } + + guard let sessionId = headerSessionId else { + return .httpError(status: 400, message: "Missing Mcp-Session-Id header") + } + guard let session = await server.session(for: sessionId) else { + return .httpError(status: 404, message: "Session not found") + } + + await session.markActive() + + if request.method == "notifications/initialized" { + do { + try await session.transition(to: .active( + tokenId: authenticatedToken?.id, + tokenName: authenticatedToken?.name + )) + } catch { + return encodeError(MCPError.invalidRequest("Cannot initialize session in current phase"), id: request.id) + } + return .accepted + } + + if request.method == "notifications/cancelled" { + return await handleCancellation(request, session: session) + } + + guard await session.phase.isActive else { + return encodeError( + MCPError.invalidRequest("Session not initialized. Send notifications/initialized first."), + id: request.id + ) + } + + switch request.method { + case "tools/list": + return handleToolsList(request, sessionId: sessionId) + + case "tools/call": + return await handleToolsCall( + request, + sessionId: sessionId, + server: server, + authenticatedToken: authenticatedToken + ) + + case "resources/list": + return handleResourcesList(request, sessionId: sessionId) + + case "resources/read": + return await handleResourcesRead(request, sessionId: sessionId, server: server) + + default: + return encodeError(MCPError.methodNotFound(request.method), id: request.id) + } + } + + private func handleInitialize( + _ request: JSONRPCRequest, + server: MCPServer + ) async -> MCPRouter.RouteResult { + guard let session = await server.createSession() else { + return encodeError(MCPError.internalError("Maximum sessions reached"), id: request.id) + } + + if let params = request.params, + let clientInfo = params["clientInfo"], + let name = clientInfo["name"]?.stringValue + { + let version = clientInfo["version"]?.stringValue + await session.setClientInfo(MCPClientInfo(name: name, version: version)) + } + + do { + try await session.transition(to: .initializing) + } catch { + await server.removeSession(session.id) + return encodeError(MCPError.invalidRequest("Cannot initialize session"), id: request.id) + } + + let result = MCPInitializeResult( + protocolVersion: "2025-03-26", + capabilities: MCPServerCapabilities( + tools: .init(listChanged: false), + resources: .init(subscribe: false, listChanged: false) + ), + serverInfo: MCPServerInfo(name: "tablepro", version: "1.0.0") + ) + + return encodeResult(result, id: request.id, sessionId: session.id) + } + + private func handlePing(_ request: JSONRPCRequest) -> MCPRouter.RouteResult { + guard let id = request.id else { + return .accepted + } + return encodeRawResult(.object([:]), id: id, sessionId: nil) + } + + private func handleCancellation( + _ request: JSONRPCRequest, + session: MCPSession + ) async -> MCPRouter.RouteResult { + guard let params = request.params, + let requestIdValue = params["requestId"] + else { + return .accepted + } + + let cancelId: JSONRPCId? + switch requestIdValue { + case .string(let s): + cancelId = .string(s) + case .int(let i): + cancelId = .int(i) + default: + cancelId = nil + } + + if let cancelId, let task = await session.removeRunningTask(cancelId) { + task.cancel() + Self.logger.info("Cancelled request \(String(describing: cancelId)) in session \(session.id)") + } + + return .accepted + } + + private func handleToolsList(_ request: JSONRPCRequest, sessionId: String) -> MCPRouter.RouteResult { + guard let id = request.id else { + return .accepted + } + + let tools = MCPRouter.toolDefinitions() + let result: JSONValue = .object(["tools": encodeToolDefinitions(tools)]) + return encodeRawResult(result, id: id, sessionId: sessionId) + } + + private func handleToolsCall( + _ request: JSONRPCRequest, + sessionId: String, + server: MCPServer, + authenticatedToken: MCPAuthToken? + ) async -> MCPRouter.RouteResult { + guard let id = request.id else { + return encodeError(MCPError.invalidRequest("tools/call requires an id"), id: nil) + } + + guard let params = request.params, + let name = params["name"]?.stringValue + else { + return encodeError(MCPError.invalidParams("Missing tool name"), id: id) + } + + let arguments = params["arguments"] + + guard let handler = await server.toolCallHandler else { + return encodeError(MCPError.internalError("Server not fully initialized"), id: id) + } + + let session = await server.session(for: sessionId) + let toolTask = Task { + try await handler(name, arguments, sessionId, authenticatedToken) + } + if let session { + let cancelForwardingTask = Task { + await withTaskCancellationHandler { + _ = try? await toolTask.value + } onCancel: { + toolTask.cancel() + } + } + await session.addRunningTask(id, task: cancelForwardingTask) + } + + do { + let toolResult = try await toolTask.value + if let session { _ = await session.removeRunningTask(id) } + let resultData = try encoder.encode(toolResult) + guard let resultValue = try? decoder.decode(JSONValue.self, from: resultData) else { + return encodeError(MCPError.internalError("Failed to encode tool result"), id: id) + } + return encodeRawResult(resultValue, id: id, sessionId: sessionId) + } catch is CancellationError { + if let session { _ = await session.removeRunningTask(id) } + return encodeError(MCPError.timeout("Request was cancelled"), id: id) + } catch let mcpError as MCPError { + if let session { _ = await session.removeRunningTask(id) } + return encodeError(mcpError, id: id) + } catch { + if let session { _ = await session.removeRunningTask(id) } + return encodeError(MCPError.internalError(error.localizedDescription), id: id) + } + } + + private func handleResourcesList(_ request: JSONRPCRequest, sessionId: String) -> MCPRouter.RouteResult { + guard let id = request.id else { + return .accepted + } + + let resources = MCPRouter.resourceDefinitions() + let result: JSONValue = .object(["resources": encodeResourceDefinitions(resources)]) + return encodeRawResult(result, id: id, sessionId: sessionId) + } + + private func handleResourcesRead( + _ request: JSONRPCRequest, + sessionId: String, + server: MCPServer + ) async -> MCPRouter.RouteResult { + guard let id = request.id else { + return encodeError(MCPError.invalidRequest("resources/read requires an id"), id: nil) + } + + guard let params = request.params, + let uri = params["uri"]?.stringValue + else { + return encodeError(MCPError.invalidParams("Missing resource uri"), id: id) + } + + guard let handler = await server.resourceReadHandler else { + return encodeError(MCPError.internalError("Server not fully initialized"), id: id) + } + + do { + let readResult = try await handler(uri, sessionId) + let resultData = try encoder.encode(readResult) + guard let resultValue = try? decoder.decode(JSONValue.self, from: resultData) else { + return encodeError(MCPError.internalError("Failed to encode resource result"), id: id) + } + return encodeRawResult(resultValue, id: id, sessionId: sessionId) + } catch let mcpError as MCPError { + return encodeError(mcpError, id: id) + } catch { + return encodeError(MCPError.internalError(error.localizedDescription), id: id) + } + } + + private func encodeResult(_ result: T, id: JSONRPCId?, sessionId: String?) -> MCPRouter.RouteResult { + guard let id else { + return .accepted + } + + do { + let resultData = try encoder.encode(result) + let resultValue = try decoder.decode(JSONValue.self, from: resultData) + let response = JSONRPCResponse(id: id, result: resultValue) + let data = try encoder.encode(response) + return .json(data, sessionId: sessionId) + } catch { + Self.logger.error("Failed to encode response: \(error.localizedDescription)") + return encodeError(MCPError.internalError("Encoding failed"), id: id) + } + } + + private func encodeRawResult(_ result: JSONValue, id: JSONRPCId, sessionId: String?) -> MCPRouter.RouteResult { + do { + let response = JSONRPCResponse(id: id, result: result) + let data = try encoder.encode(response) + return .json(data, sessionId: sessionId) + } catch { + Self.logger.error("Failed to encode response: \(error.localizedDescription)") + return encodeError(MCPError.internalError("Encoding failed"), id: id) + } + } + + private func encodeError(_ error: MCPError, id: JSONRPCId?) -> MCPRouter.RouteResult { + let errorResponse = error.toJsonRpcError(id: id) + do { + let data = try encoder.encode(errorResponse) + return .json(data, sessionId: nil) + } catch { + Self.logger.error("Failed to encode error response") + return .httpError(status: 500, message: "Internal encoding error") + } + } + + private func encodeToolDefinitions(_ tools: [MCPToolDefinition]) -> JSONValue { + .array(tools.map { tool in + .object([ + "name": .string(tool.name), + "description": .string(tool.description), + "inputSchema": tool.inputSchema + ]) + }) + } + + private func encodeResourceDefinitions(_ resources: [MCPResourceDefinition]) -> JSONValue { + .array(resources.map { resource in + var dict: [String: JSONValue] = [ + "uri": .string(resource.uri), + "name": .string(resource.name) + ] + if let description = resource.description { + dict["description"] = .string(description) + } + if let mimeType = resource.mimeType { + dict["mimeType"] = .string(mimeType) + } + return .object(dict) + }) + } +} diff --git a/TablePro/Core/MCP/TokenPermissionFilter.swift b/TablePro/Core/MCP/TokenPermissionFilter.swift new file mode 100644 index 000000000..8c42600e6 --- /dev/null +++ b/TablePro/Core/MCP/TokenPermissionFilter.swift @@ -0,0 +1,47 @@ +import Foundation + +protocol ConnectionIdentifiable { + var connectionId: UUID { get } +} + +enum TokenPermissionFilter { + static let overfetchMultiplier = 3 + private static let maxRoundTrips = 2 + + static func filter(_ items: [T], by access: ConnectionAccess) -> [T] { + switch access { + case .all: + return items + case .limited(let ids): + return items.filter { ids.contains($0.connectionId) } + } + } + + static func fetchFiltered( + access: ConnectionAccess, + limit: Int, + fetch: (Int, Int) async throws -> [T] + ) async throws -> [T] { + if case .all = access { + let items = try await fetch(limit, 0) + return Array(items.prefix(limit)) + } + + guard limit > 0 else { return [] } + + let fetchLimit = limit * overfetchMultiplier + var collected: [T] = [] + var offset = 0 + + for _ in 0..= limit { break } + if raw.count < fetchLimit { break } + offset += fetchLimit + } + + return Array(collected.prefix(limit)) + } +} diff --git a/TablePro/Core/Services/Infrastructure/AppLaunchCoordinator.swift b/TablePro/Core/Services/Infrastructure/AppLaunchCoordinator.swift new file mode 100644 index 000000000..b49745142 --- /dev/null +++ b/TablePro/Core/Services/Infrastructure/AppLaunchCoordinator.swift @@ -0,0 +1,226 @@ +// +// AppLaunchCoordinator.swift +// TablePro +// + +import AppKit +import Foundation +import Observation +import os + +@MainActor +@Observable +internal final class AppLaunchCoordinator { + internal static let shared = AppLaunchCoordinator() + + private static let logger = Logger(subsystem: "com.TablePro", category: "AppLaunchCoordinator") + internal static let collectionWindow: Duration = .milliseconds(150) + + private(set) var phase: LaunchPhase = .launching + + private var pendingIntents: [LaunchIntent] = [] + private var deadlineTask: Task? + private var hasFinishedLaunching = false + + private init() {} + + // MARK: - App Lifecycle Hooks + + internal func didFinishLaunching() { + hasFinishedLaunching = true + let deadline = Date().addingTimeInterval(0.150) + phase = .collectingIntents(deadline: deadline) + deadlineTask = Task { [weak self] in + try? await Task.sleep(for: Self.collectionWindow) + await MainActor.run { + self?.transitionToRouting() + } + } + } + + internal func handleOpenURLs(_ urls: [URL]) { + let intents: [LaunchIntent] = urls.compactMap { url in + switch URLClassifier.classify(url) { + case .none: + Self.logger.warning("Unrecognized URL: \(url.sanitizedForLogging, privacy: .public)") + return nil + case .some(.failure(let error)): + Self.logger.error("URL parse failed: \(error.localizedDescription, privacy: .public) for \(url.sanitizedForLogging, privacy: .public)") + return nil + case .some(.success(let intent)): + return intent + } + } + deliver(intents) + } + + internal func handleHandoff(_ activity: NSUserActivity) { + guard let connectionIdString = activity.userInfo?["connectionId"] as? String, + let connectionId = UUID(uuidString: connectionIdString) else { return } + let table = activity.userInfo?["tableName"] as? String + + if let table { + deliver([.openTable( + connectionId: connectionId, + database: nil, + schema: nil, + table: table, + isView: false + )]) + } else { + deliver([.openConnection(connectionId)]) + } + } + + internal func handleReopen(hasVisibleWindows: Bool) -> Bool { + if hasVisibleWindows { return true } + showWelcomeWindow() + return false + } + + // MARK: - Phase Transitions + + private func deliver(_ intents: [LaunchIntent]) { + guard !intents.isEmpty else { return } + if phase.isAcceptingIntents { + pendingIntents.append(contentsOf: intents) + for window in NSApp.windows where Self.isWelcomeWindow(window) { + window.orderOut(nil) + } + } else { + Task { [weak self] in + guard let self else { return } + for intent in intents { + await LaunchIntentRouter.shared.route(intent) + } + } + } + } + + private func transitionToRouting() { + guard hasFinishedLaunching else { return } + phase = .routing + let intents = pendingIntents + pendingIntents.removeAll() + + Task { [weak self] in + guard let self else { return } + for intent in intents { + await LaunchIntentRouter.shared.route(intent) + } + self.runStartupBehaviorIfNeeded(skipping: intents) + self.phase = .ready + self.finalizeWindowsIfNoVisibleMain(intents: intents) + } + } + + private func runStartupBehaviorIfNeeded(skipping intents: [LaunchIntent]) { + guard intents.isEmpty else { + closeRestoredMainWindowsExcept(intents: intents) + return + } + let general = AppSettingsStorage.shared.loadGeneral() + guard general.startupBehavior == .reopenLast else { + closeRestoredMainWindowsExcept(intents: intents) + return + } + let openIds = AppSettingsStorage.shared.loadLastOpenConnectionIds() + if !openIds.isEmpty { + attemptAutoReconnect(connectionIds: openIds) + return + } + if let lastId = AppSettingsStorage.shared.loadLastConnectionId() { + attemptAutoReconnect(connectionIds: [lastId]) + return + } + Task { [weak self] in + let diskIds = await TabDiskActor.shared.connectionIdsWithSavedState() + if !diskIds.isEmpty { + self?.attemptAutoReconnect(connectionIds: diskIds) + } else { + self?.closeRestoredMainWindowsExcept(intents: []) + } + } + } + + private func finalizeWindowsIfNoVisibleMain(intents: [LaunchIntent]) { + guard intents.isEmpty else { return } + guard !NSApp.windows.contains(where: { Self.isMainWindow($0) && $0.isVisible }) else { return } + showWelcomeWindow() + } + + private func closeRestoredMainWindowsExcept(intents: [LaunchIntent]) { + let preserved = Set(intents.compactMap { $0.targetConnectionId }) + for window in NSApp.windows where Self.isMainWindow(window) { + if let id = WindowLifecycleMonitor.shared.connectionId(forWindow: window), + preserved.contains(id) { + continue + } + window.close() + } + } + + private func attemptAutoReconnect(connectionIds: [UUID]) { + let saved = ConnectionStorage.shared.loadConnections() + let valid = connectionIds.compactMap { id in + saved.first(where: { $0.id == id }) + } + guard !valid.isEmpty else { + AppSettingsStorage.shared.saveLastOpenConnectionIds([]) + AppSettingsStorage.shared.saveLastConnectionId(nil) + closeRestoredMainWindowsExcept(intents: []) + showWelcomeWindow() + return + } + for window in NSApp.windows where Self.isWelcomeWindow(window) { + window.orderOut(nil) + } + Task { [weak self] in + for connection in valid { + let payload = EditorTabPayload( + connectionId: connection.id, intent: .restoreOrDefault + ) + WindowManager.shared.openTab(payload: payload) + do { + try await DatabaseManager.shared.ensureConnected(connection) + } catch is CancellationError { + for window in WindowLifecycleMonitor.shared.windows(for: connection.id) { + window.close() + } + } catch { + Self.logger.error("Auto-reconnect failed for '\(connection.name, privacy: .public)': \(error.localizedDescription, privacy: .public)") + for window in WindowLifecycleMonitor.shared.windows(for: connection.id) { + window.close() + } + } + } + for window in NSApp.windows where Self.isWelcomeWindow(window) { + window.close() + } + if !NSApp.windows.contains(where: { Self.isMainWindow($0) && $0.isVisible }) { + self?.showWelcomeWindow() + } + } + } + + // MARK: - Window Identification + + internal static func isMainWindow(_ window: NSWindow) -> Bool { + guard let raw = window.identifier?.rawValue else { return false } + return raw == "main" || raw.hasPrefix("main-") + } + + internal static func isWelcomeWindow(_ window: NSWindow) -> Bool { + guard let raw = window.identifier?.rawValue else { return false } + return raw == "welcome" || raw.hasPrefix("welcome-") + } + + internal static func isConnectionFormWindow(_ window: NSWindow) -> Bool { + guard let raw = window.identifier?.rawValue else { return false } + return raw == "connection-form" || raw.hasPrefix("connection-form-") + } + + private func showWelcomeWindow() { + WelcomeWindowFactory.openOrFront() + } +} diff --git a/TablePro/Core/Services/Infrastructure/AppNotifications.swift b/TablePro/Core/Services/Infrastructure/AppNotifications.swift index ca04e402a..fa88c5b6a 100644 --- a/TablePro/Core/Services/Infrastructure/AppNotifications.swift +++ b/TablePro/Core/Services/Infrastructure/AppNotifications.swift @@ -19,8 +19,6 @@ extension Notification.Name { static let connectionUpdated = Notification.Name("connectionUpdated") static let connectionStatusDidChange = Notification.Name("connectionStatusDidChange") static let databaseDidConnect = Notification.Name("databaseDidConnect") - static let connectionShareFileOpened = Notification.Name("connectionShareFileOpened") - static let deeplinkImportRequested = Notification.Name("deeplinkImportRequested") static let exportConnections = Notification.Name("exportConnections") static let importConnections = Notification.Name("importConnections") static let importConnectionsFromApp = Notification.Name("importConnectionsFromApp") diff --git a/TablePro/Core/Services/Infrastructure/ConnectionFormWindowFactory.swift b/TablePro/Core/Services/Infrastructure/ConnectionFormWindowFactory.swift new file mode 100644 index 000000000..4c55f775e --- /dev/null +++ b/TablePro/Core/Services/Infrastructure/ConnectionFormWindowFactory.swift @@ -0,0 +1,60 @@ +// +// ConnectionFormWindowFactory.swift +// TablePro +// + +import AppKit +import SwiftUI + +@MainActor +internal enum ConnectionFormWindowFactory { + private static let baseIdentifier = "connection-form" + + internal static func openOrFront(connectionId: UUID? = nil) { + if let existing = existingWindow(for: connectionId) { + existing.makeKeyAndOrderFront(nil) + NSApp.activate(ignoringOtherApps: true) + return + } + let window = makeWindow(connectionId: connectionId) + window.makeKeyAndOrderFront(nil) + NSApp.activate(ignoringOtherApps: true) + } + + internal static func close(connectionId: UUID? = nil) { + existingWindow(for: connectionId)?.close() + } + + internal static func closeAll() { + for window in NSApp.windows where AppLaunchCoordinator.isConnectionFormWindow(window) { + window.close() + } + } + + private static func existingWindow(for connectionId: UUID?) -> NSWindow? { + let target = identifier(for: connectionId) + return NSApp.windows.first { $0.identifier?.rawValue == target } + } + + private static func identifier(for connectionId: UUID?) -> String { + if let connectionId { + return "\(baseIdentifier)-\(connectionId.uuidString)" + } + return baseIdentifier + } + + private static func makeWindow(connectionId: UUID?) -> NSWindow { + let hostingController = NSHostingController(rootView: ConnectionFormView(connectionId: connectionId)) + let window = NSWindow(contentViewController: hostingController) + window.identifier = NSUserInterfaceItemIdentifier(identifier(for: connectionId)) + window.title = String(localized: "New Connection") + window.styleMask = [.titled, .closable, .resizable] + window.standardWindowButton(.miniaturizeButton)?.isEnabled = false + window.standardWindowButton(.zoomButton)?.isEnabled = false + window.styleMask.remove(.miniaturizable) + window.collectionBehavior.insert(.fullScreenNone) + window.center() + window.isReleasedWhenClosed = false + return window + } +} diff --git a/TablePro/Core/Services/Infrastructure/DeeplinkHandler.swift b/TablePro/Core/Services/Infrastructure/DeeplinkHandler.swift deleted file mode 100644 index 8223851ca..000000000 --- a/TablePro/Core/Services/Infrastructure/DeeplinkHandler.swift +++ /dev/null @@ -1,300 +0,0 @@ -// -// DeeplinkHandler.swift -// TablePro -// - -import Foundation -import os - -struct PairingRequest: Sendable, Equatable { - let clientName: String - let challenge: String - let redirectURL: URL - let requestedScopes: String? - let requestedConnectionIds: Set? -} - -struct PairingExchange: Sendable, Equatable { - let code: String - let verifier: String -} - -enum DeeplinkAction { - case connect(connectionId: UUID) - case openTable(connectionId: UUID, tableName: String, databaseName: String?, schemaName: String?) - case openQuery(connectionId: UUID, sql: String) - case importConnection(ExportableConnection) - case pairIntegration(PairingRequest) - case exchangePairing(PairingExchange) - case startMCP -} - -@MainActor -enum DeeplinkHandler { - private static let logger = Logger(subsystem: "com.TablePro", category: "DeeplinkHandler") - - static func parse(_ url: URL) -> DeeplinkAction? { - guard url.scheme == "tablepro" else { return nil } - - let host = url.host(percentEncoded: false) - switch host { - case "connect": - return parseConnect(url) - case "import": - return parseImport(url) - case "integrations": - return parseIntegrations(url) - default: - logger.warning("Unknown deep link host: \(host ?? "nil", privacy: .public)") - return nil - } - } - - // MARK: - Connect parsing - - private static func parseConnect(_ url: URL) -> DeeplinkAction? { - let components = url.pathComponents.filter { $0 != "/" } - guard let firstRaw = components.first?.removingPercentEncoding, - !firstRaw.isEmpty else { return nil } - - guard let connectionId = UUID(uuidString: firstRaw) else { - logger.warning("Connect deep link missing valid UUID: \(firstRaw, privacy: .public)") - return nil - } - - if components.count >= 2, components[1] == "query" { - let queryItems = URLComponents(url: url, resolvingAgainstBaseURL: false)?.queryItems - guard let sql = queryItems?.first(where: { $0.name == "sql" })?.value, - !sql.isEmpty else { return nil } - return .openQuery(connectionId: connectionId, sql: sql) - } - - if components.count == 7, - components[1] == "database", - components[3] == "schema", - components[5] == "table", - let dbName = components[2].removingPercentEncoding, - let schemaName = components[4].removingPercentEncoding, - let tableName = components[6].removingPercentEncoding { - return .openTable(connectionId: connectionId, tableName: tableName, databaseName: dbName, schemaName: schemaName) - } - - if components.count == 5, - components[1] == "database", - components[3] == "table", - let dbName = components[2].removingPercentEncoding, - let tableName = components[4].removingPercentEncoding { - return .openTable(connectionId: connectionId, tableName: tableName, databaseName: dbName, schemaName: nil) - } - - if components.count >= 3, components[1] == "table", - let tableName = components[2].removingPercentEncoding { - return .openTable(connectionId: connectionId, tableName: tableName, databaseName: nil, schemaName: nil) - } - - if components.count == 1 { - return .connect(connectionId: connectionId) - } - - logger.warning("Unrecognized connect deep link path: \(url.path, privacy: .public)") - return nil - } - - // MARK: - Integrations parsing - - private static func parseIntegrations(_ url: URL) -> DeeplinkAction? { - let components = url.pathComponents.filter { $0 != "/" } - guard let action = components.first else { - logger.warning("Integrations deep link missing action") - return nil - } - - switch action { - case "pair": - return parsePair(url) - case "exchange": - return parseExchange(url) - case "start-mcp": - return .startMCP - default: - logger.warning("Unknown integrations action: \(action, privacy: .public)") - return nil - } - } - - private static func parsePair(_ url: URL) -> DeeplinkAction? { - guard let queryItems = URLComponents(url: url, resolvingAgainstBaseURL: false)?.queryItems - else { - logger.warning("Pair deep link missing query items") - return nil - } - - func value(_ key: String) -> String? { - queryItems.first(where: { $0.name == key })?.value - } - - guard let clientName = value("client"), !clientName.isEmpty, - let challenge = value("challenge"), !challenge.isEmpty, - let redirectRaw = value("redirect"), !redirectRaw.isEmpty, - let redirectURL = URL(string: redirectRaw) else { - logger.warning("Pair deep link missing required params") - return nil - } - - let scopes = value("scopes")?.nilIfEmpty - let connectionIds: Set? - if let csv = value("connection-ids")?.nilIfEmpty { - let parsed = csv.split(separator: ",").compactMap { UUID(uuidString: String($0)) } - connectionIds = parsed.isEmpty ? nil : Set(parsed) - } else { - connectionIds = nil - } - - return .pairIntegration( - PairingRequest( - clientName: clientName, - challenge: challenge, - redirectURL: redirectURL, - requestedScopes: scopes, - requestedConnectionIds: connectionIds - ) - ) - } - - private static func parseExchange(_ url: URL) -> DeeplinkAction? { - guard let queryItems = URLComponents(url: url, resolvingAgainstBaseURL: false)?.queryItems - else { - logger.warning("Exchange deep link missing query items") - return nil - } - - func value(_ key: String) -> String? { - queryItems.first(where: { $0.name == key })?.value - } - - guard let code = value("code"), !code.isEmpty, - let verifier = value("verifier"), !verifier.isEmpty else { - logger.warning("Exchange deep link missing code or verifier") - return nil - } - - return .exchangePairing(PairingExchange(code: code, verifier: verifier)) - } - - // MARK: - Import parsing - - private static func parseImport(_ url: URL) -> DeeplinkAction? { - guard let queryItems = URLComponents(url: url, resolvingAgainstBaseURL: false)?.queryItems - else { return nil } - - func value(_ key: String) -> String? { - queryItems.first(where: { $0.name == key })?.value - } - - guard let name = value("name"), !name.isEmpty, - let host = value("host"), !host.isEmpty, - let typeStr = value("type"), - let dbType = DatabaseType(validating: typeStr) - ?? PluginMetadataRegistry.shared.allRegisteredTypeIds() - .first(where: { $0.lowercased() == typeStr.lowercased() }) - .map({ DatabaseType(rawValue: $0) }) - else { - logger.warning("Import deep link missing required params") - return nil - } - - let port = value("port").flatMap(Int.init) ?? dbType.defaultPort - let username = value("username") ?? "" - let database = value("database") ?? "" - - let sshConfig: ExportableSSHConfig? - if value("ssh") == "1" { - let jumpHosts: [ExportableJumpHost]? - if let jumpJson = value("sshJumpHosts"), - let data = jumpJson.data(using: .utf8) { - jumpHosts = try? JSONDecoder().decode([ExportableJumpHost].self, from: data) - } else { - jumpHosts = nil - } - sshConfig = ExportableSSHConfig( - enabled: true, - host: value("sshHost") ?? "", - port: value("sshPort").flatMap(Int.init) ?? 22, - username: value("sshUsername") ?? "", - authMethod: value("sshAuthMethod") ?? "password", - privateKeyPath: value("sshPrivateKeyPath") ?? "", - useSSHConfig: value("sshUseSSHConfig") == "1", - agentSocketPath: value("sshAgentSocketPath") ?? "", - jumpHosts: jumpHosts, - totpMode: value("sshTotpMode"), - totpAlgorithm: value("sshTotpAlgorithm"), - totpDigits: value("sshTotpDigits").flatMap(Int.init), - totpPeriod: value("sshTotpPeriod").flatMap(Int.init) - ) - } else { - sshConfig = nil - } - - let sslConfig: ExportableSSLConfig? - if let sslMode = value("sslMode") { - sslConfig = ExportableSSLConfig( - mode: sslMode, - caCertificatePath: value("sslCaCertPath"), - clientCertificatePath: value("sslClientCertPath"), - clientKeyPath: value("sslClientKeyPath") - ) - } else { - sslConfig = nil - } - - var additionalFields: [String: String]? - let afItems = queryItems.filter { $0.name.hasPrefix("af_") } - if !afItems.isEmpty { - var fields: [String: String] = [:] - for item in afItems { - let fieldKey = String(item.name.dropFirst(3)) - if !fieldKey.isEmpty, let fieldValue = item.value { - fields[fieldKey] = fieldValue - } - } - if !fields.isEmpty { - additionalFields = fields - } - } - - let exportable = ExportableConnection( - name: name, - host: host, - port: port, - database: database, - username: username, - type: dbType.rawValue, - sshConfig: sshConfig, - sslConfig: sslConfig, - color: value("color"), - tagName: value("tagName"), - groupName: value("groupName"), - sshProfileId: nil, - safeModeLevel: value("safeModeLevel"), - aiPolicy: value("aiPolicy"), - additionalFields: additionalFields, - redisDatabase: value("redisDatabase").flatMap(Int.init), - startupCommands: value("startupCommands"), - localOnly: value("localOnly") == "1" ? true : nil - ) - - return .importConnection(exportable) - } - - // MARK: - Resolution - - static func resolveConnection(byId id: UUID) -> DatabaseConnection? { - ConnectionStorage.shared.loadConnections().first { $0.id == id } - } -} - -private extension String { - var nilIfEmpty: String? { - isEmpty ? nil : self - } -} diff --git a/TablePro/Core/Services/Infrastructure/DeeplinkParser.swift b/TablePro/Core/Services/Infrastructure/DeeplinkParser.swift new file mode 100644 index 000000000..d6bc15890 --- /dev/null +++ b/TablePro/Core/Services/Infrastructure/DeeplinkParser.swift @@ -0,0 +1,378 @@ +// +// DeeplinkParser.swift +// TablePro +// + +import Foundation +import TableProPluginKit + +internal enum DeeplinkError: Error, LocalizedError, Equatable { + case unknownScheme(String) + case unknownHost(String) + case malformedPath(String) + case missingRequiredParam(String) + case invalidUUID(String) + case sqlTooLong(Int, limit: Int) + case unsupportedDatabaseType(String) + + internal var errorDescription: String? { + switch self { + case .unknownScheme(let scheme): + return String(format: String(localized: "Unknown URL scheme: %@"), scheme) + case .unknownHost(let host): + return String(format: String(localized: "Unknown deep link host: %@"), host) + case .malformedPath(let path): + return String(format: String(localized: "Malformed deep link path: %@"), path) + case .missingRequiredParam(let name): + return String(format: String(localized: "Missing required parameter: %@"), name) + case .invalidUUID(let raw): + return String(format: String(localized: "Invalid UUID: %@"), raw) + case .sqlTooLong(let length, let limit): + return String( + format: String(localized: "SQL is too long: %d characters (limit %d)"), + length, limit + ) + case .unsupportedDatabaseType(let raw): + return String(format: String(localized: "Unsupported database type: %@"), raw) + } + } +} + +internal enum DeeplinkParser { + internal static let sqlLengthLimit = 51_200 + + internal static func parse(_ url: URL) -> Result { + guard url.scheme == "tablepro" else { + return .failure(.unknownScheme(url.scheme ?? "")) + } + let host = url.host(percentEncoded: false) ?? "" + switch host { + case "connect": + return parseConnect(url) + case "import": + return parseImport(url) + case "integrations": + return parseIntegrations(url) + default: + return .failure(.unknownHost(host)) + } + } + + private static func parseConnect(_ url: URL) -> Result { + let segments = pathSegments(url) + var cursor = PathCursor(segments: segments) + + guard let firstRaw = cursor.next() else { + return .failure(.malformedPath(url.path)) + } + guard let connectionId = UUID(uuidString: firstRaw) else { + return .failure(.invalidUUID(firstRaw)) + } + + guard let head = cursor.peek() else { + return .success(.openConnection(connectionId)) + } + + switch head { + case "table": + cursor.advance() + guard let table = cursor.next(), !table.isEmpty else { + return .failure(.malformedPath(url.path)) + } + guard cursor.atEnd else { return .failure(.malformedPath(url.path)) } + return .success(.openTable( + connectionId: connectionId, + database: nil, + schema: nil, + table: table, + isView: false + )) + + case "database": + cursor.advance() + guard let database = cursor.next(), !database.isEmpty else { + return .failure(.malformedPath(url.path)) + } + return parseDatabaseTail( + connectionId: connectionId, database: database, cursor: &cursor, fullPath: url.path + ) + + case "query": + cursor.advance() + guard cursor.atEnd else { return .failure(.malformedPath(url.path)) } + return parseQuery(url: url, connectionId: connectionId) + + default: + return .failure(.malformedPath(url.path)) + } + } + + private static func parseDatabaseTail( + connectionId: UUID, + database: String, + cursor: inout PathCursor, + fullPath: String + ) -> Result { + guard let next = cursor.next() else { + return .failure(.malformedPath(fullPath)) + } + switch next { + case "schema": + guard let schema = cursor.next(), !schema.isEmpty else { + return .failure(.malformedPath(fullPath)) + } + guard cursor.next() == "table", + let table = cursor.next(), !table.isEmpty else { + return .failure(.malformedPath(fullPath)) + } + guard cursor.atEnd else { return .failure(.malformedPath(fullPath)) } + return .success(.openTable( + connectionId: connectionId, + database: database, + schema: schema, + table: table, + isView: false + )) + + case "table": + guard let table = cursor.next(), !table.isEmpty else { + return .failure(.malformedPath(fullPath)) + } + guard cursor.atEnd else { return .failure(.malformedPath(fullPath)) } + return .success(.openTable( + connectionId: connectionId, + database: database, + schema: nil, + table: table, + isView: false + )) + + default: + return .failure(.malformedPath(fullPath)) + } + } + + private static func parseQuery(url: URL, connectionId: UUID) -> Result { + guard let queryItems = URLComponents(url: url, resolvingAgainstBaseURL: false)?.queryItems, + let rawSQL = queryItems.first(where: { $0.name == "sql" })?.value, + !rawSQL.isEmpty else { + return .failure(.missingRequiredParam("sql")) + } + let length = (rawSQL as NSString).length + guard length <= sqlLengthLimit else { + return .failure(.sqlTooLong(length, limit: sqlLengthLimit)) + } + return .success(.openQuery(connectionId: connectionId, sql: rawSQL)) + } + + private static func parseIntegrations(_ url: URL) -> Result { + let segments = pathSegments(url) + var cursor = PathCursor(segments: segments) + guard let action = cursor.next() else { + return .failure(.malformedPath(url.path)) + } + switch action { + case "pair": + return parsePair(url) + case "start-mcp": + return .success(.startMCPServer) + default: + return .failure(.malformedPath(url.path)) + } + } + + private static func parsePair(_ url: URL) -> Result { + guard let queryItems = URLComponents(url: url, resolvingAgainstBaseURL: false)?.queryItems + else { + return .failure(.missingRequiredParam("client")) + } + func value(_ key: String) -> String? { + queryItems.first(where: { $0.name == key })?.value + } + + guard let clientName = value("client"), !clientName.isEmpty else { + return .failure(.missingRequiredParam("client")) + } + guard let challenge = value("challenge"), !challenge.isEmpty else { + return .failure(.missingRequiredParam("challenge")) + } + guard let redirectRaw = value("redirect"), !redirectRaw.isEmpty, + let redirectURL = URL(string: redirectRaw) else { + return .failure(.missingRequiredParam("redirect")) + } + + let scopes = value("scopes")?.nilIfEmpty + let connectionIds: Set? + if let csv = value("connection-ids")?.nilIfEmpty { + let parsed = csv.split(separator: ",").compactMap { UUID(uuidString: String($0)) } + connectionIds = parsed.isEmpty ? nil : Set(parsed) + } else { + connectionIds = nil + } + + return .success(.pairIntegration( + PairingRequest( + clientName: clientName, + challenge: challenge, + redirectURL: redirectURL, + requestedScopes: scopes, + requestedConnectionIds: connectionIds + ) + )) + } + + private static func parseImport(_ url: URL) -> Result { + guard let queryItems = URLComponents(url: url, resolvingAgainstBaseURL: false)?.queryItems + else { + return .failure(.missingRequiredParam("name")) + } + func value(_ key: String) -> String? { + queryItems.first(where: { $0.name == key })?.value + } + + guard let name = value("name"), !name.isEmpty else { + return .failure(.missingRequiredParam("name")) + } + guard let host = value("host"), !host.isEmpty else { + return .failure(.missingRequiredParam("host")) + } + guard let typeStr = value("type") else { + return .failure(.missingRequiredParam("type")) + } + + let resolvedType: DatabaseType? + if let direct = DatabaseType(validating: typeStr) { + resolvedType = direct + } else if let pluginMatch = PluginMetadataRegistry.shared.allRegisteredTypeIds() + .first(where: { $0.lowercased() == typeStr.lowercased() }) { + resolvedType = DatabaseType(rawValue: pluginMatch) + } else { + resolvedType = nil + } + guard let dbType = resolvedType else { + return .failure(.unsupportedDatabaseType(typeStr)) + } + + let port = value("port").flatMap(Int.init) ?? dbType.defaultPort + let username = value("username") ?? "" + let database = value("database") ?? "" + + let sshConfig: ExportableSSHConfig? + if value("ssh") == "1" { + let jumpHosts: [ExportableJumpHost]? + if let jumpJson = value("sshJumpHosts"), + let data = jumpJson.data(using: .utf8) { + jumpHosts = try? JSONDecoder().decode([ExportableJumpHost].self, from: data) + } else { + jumpHosts = nil + } + sshConfig = ExportableSSHConfig( + enabled: true, + host: value("sshHost") ?? "", + port: value("sshPort").flatMap(Int.init) ?? 22, + username: value("sshUsername") ?? "", + authMethod: value("sshAuthMethod") ?? "password", + privateKeyPath: value("sshPrivateKeyPath") ?? "", + useSSHConfig: value("sshUseSSHConfig") == "1", + agentSocketPath: value("sshAgentSocketPath") ?? "", + jumpHosts: jumpHosts, + totpMode: value("sshTotpMode"), + totpAlgorithm: value("sshTotpAlgorithm"), + totpDigits: value("sshTotpDigits").flatMap(Int.init), + totpPeriod: value("sshTotpPeriod").flatMap(Int.init) + ) + } else { + sshConfig = nil + } + + let sslConfig: ExportableSSLConfig? + if let sslMode = value("sslMode") { + sslConfig = ExportableSSLConfig( + mode: sslMode, + caCertificatePath: value("sslCaCertPath"), + clientCertificatePath: value("sslClientCertPath"), + clientKeyPath: value("sslClientKeyPath") + ) + } else { + sslConfig = nil + } + + var additionalFields: [String: String]? + let afItems = queryItems.filter { $0.name.hasPrefix("af_") } + if !afItems.isEmpty { + var fields: [String: String] = [:] + for item in afItems { + let fieldKey = String(item.name.dropFirst(3)) + if !fieldKey.isEmpty, let fieldValue = item.value, !fieldValue.isEmpty { + fields[fieldKey] = fieldValue + } + } + if !fields.isEmpty { + additionalFields = fields + } + } + + let exportable = ExportableConnection( + name: name, + host: host, + port: port, + database: database, + username: username, + type: dbType.rawValue, + sshConfig: sshConfig, + sslConfig: sslConfig, + color: value("color"), + tagName: value("tagName"), + groupName: value("groupName"), + sshProfileId: nil, + safeModeLevel: value("safeModeLevel"), + aiPolicy: value("aiPolicy"), + additionalFields: additionalFields, + redisDatabase: value("redisDatabase").flatMap(Int.init), + startupCommands: value("startupCommands"), + localOnly: value("localOnly") == "1" ? true : nil + ) + + return .success(.importConnection(exportable)) + } + + private static func pathSegments(_ url: URL) -> [String] { + url.pathComponents + .filter { $0 != "/" } + .compactMap { $0.removingPercentEncoding } + } +} + +private struct PathCursor { + private let segments: [String] + private var index: Int = 0 + + init(segments: [String]) { + self.segments = segments + } + + var atEnd: Bool { + index >= segments.count + } + + func peek() -> String? { + guard index < segments.count else { return nil } + return segments[index] + } + + mutating func advance() { + index += 1 + } + + mutating func next() -> String? { + guard index < segments.count else { return nil } + defer { index += 1 } + return segments[index] + } +} + +private extension String { + var nilIfEmpty: String? { + isEmpty ? nil : self + } +} diff --git a/TablePro/Core/Services/Infrastructure/LaunchIntent.swift b/TablePro/Core/Services/Infrastructure/LaunchIntent.swift new file mode 100644 index 000000000..aef5d2f68 --- /dev/null +++ b/TablePro/Core/Services/Infrastructure/LaunchIntent.swift @@ -0,0 +1,38 @@ +// +// LaunchIntent.swift +// TablePro +// + +import Foundation + +internal enum LaunchIntent: @unchecked Sendable { + case openConnection(UUID) + case openTable(connectionId: UUID, database: String?, schema: String?, table: String, isView: Bool) + case openQuery(connectionId: UUID, sql: String) + case importConnection(ExportableConnection) + case openSQLFile(URL) + case openDatabaseFile(URL, DatabaseType) + case openConnectionShare(URL) + case pairIntegration(PairingRequest) + case startMCPServer + case openDatabaseURL(URL) + case installPlugin(URL) + + internal var targetConnectionId: UUID? { + switch self { + case .openConnection(let id), + .openTable(let id, _, _, _, _), + .openQuery(let id, _): + return id + case .openDatabaseURL, + .openDatabaseFile, + .openSQLFile, + .importConnection, + .openConnectionShare, + .pairIntegration, + .startMCPServer, + .installPlugin: + return nil + } + } +} diff --git a/TablePro/Core/Services/Infrastructure/LaunchIntentRouter.swift b/TablePro/Core/Services/Infrastructure/LaunchIntentRouter.swift new file mode 100644 index 000000000..7a60f1959 --- /dev/null +++ b/TablePro/Core/Services/Infrastructure/LaunchIntentRouter.swift @@ -0,0 +1,95 @@ +// +// LaunchIntentRouter.swift +// TablePro +// + +import AppKit +import Foundation +import os + +@MainActor +internal final class LaunchIntentRouter { + internal static let shared = LaunchIntentRouter() + + private static let logger = Logger(subsystem: "com.TablePro", category: "LaunchIntentRouter") + + private init() {} + + internal func route(_ intent: LaunchIntent) async { + do { + switch intent { + case .openConnection, + .openTable, + .openQuery, + .openDatabaseURL, + .openDatabaseFile, + .openSQLFile: + try await TabRouter.shared.route(intent) + + case .importConnection(let exportable): + WelcomeRouter.shared.routeImport(exportable) + + case .openConnectionShare(let url): + WelcomeRouter.shared.routeShare(url) + + case .pairIntegration(let request): + try await MCPPairingService.shared.startPairing(request) + + case .startMCPServer: + await MCPServerManager.shared.lazyStart() + + case .installPlugin(let url): + try await installPlugin(url) + } + } catch let error as TabRouterError where error == .userCancelled { + Self.logger.info("Intent cancelled by user") + } catch let error as MCPError where error.isUserCancelled { + Self.logger.info("Pairing cancelled by user") + } catch is CancellationError { + Self.logger.info("Intent cancelled") + } catch { + Self.logger.error("Intent failed: \(error.localizedDescription, privacy: .public)") + await presentError(error, for: intent) + } + } + + private func installPlugin(_ url: URL) async throws { + let entry = try await PluginManager.shared.installPlugin(from: url) + Self.logger.info("Installed plugin '\(entry.name, privacy: .public)' from Finder") + UserDefaults.standard.set(SettingsTab.plugins.rawValue, forKey: "selectedSettingsTab") + NotificationCenter.default.post(name: .openSettingsWindow, object: nil) + } + + private func presentError(_ error: Error, for intent: LaunchIntent) async { + let title: String + switch intent { + case .pairIntegration: + title = String(localized: "Pairing Failed") + case .installPlugin: + title = String(localized: "Plugin Installation Failed") + case .openConnection, .openTable, .openQuery, .openDatabaseURL, .openDatabaseFile: + title = String(localized: "Connection Failed") + case .openSQLFile: + title = String(localized: "Could Not Open File") + case .importConnection, .openConnectionShare, .startMCPServer: + title = String(localized: "Action Failed") + } + AlertHelper.showErrorSheet( + title: title, + message: error.localizedDescription, + window: NSApp.keyWindow + ) + } +} + +extension TabRouterError: Equatable { + internal static func == (lhs: TabRouterError, rhs: TabRouterError) -> Bool { + switch (lhs, rhs) { + case (.userCancelled, .userCancelled): return true + case (.connectionNotFound(let l), .connectionNotFound(let r)): return l == r + case (.malformedDatabaseURL(let l), .malformedDatabaseURL(let r)): return l == r + case (.unsupportedIntent(let l), .unsupportedIntent(let r)): return l == r + default: return false + } + } +} diff --git a/TablePro/Core/Services/Infrastructure/LaunchPhase.swift b/TablePro/Core/Services/Infrastructure/LaunchPhase.swift new file mode 100644 index 000000000..2bbf8e0c5 --- /dev/null +++ b/TablePro/Core/Services/Infrastructure/LaunchPhase.swift @@ -0,0 +1,27 @@ +// +// LaunchPhase.swift +// TablePro +// + +import Foundation + +internal enum LaunchPhase: Equatable, Sendable { + case launching + case collectingIntents(deadline: Date) + case routing + case ready + + internal var isAcceptingIntents: Bool { + switch self { + case .launching, .collectingIntents: + return true + case .routing, .ready: + return false + } + } + + internal var isReady: Bool { + if case .ready = self { return true } + return false + } +} diff --git a/TablePro/Core/Services/Infrastructure/MainSplitViewController.swift b/TablePro/Core/Services/Infrastructure/MainSplitViewController.swift index b0bdc60c7..c200decf4 100644 --- a/TablePro/Core/Services/Infrastructure/MainSplitViewController.swift +++ b/TablePro/Core/Services/Infrastructure/MainSplitViewController.swift @@ -45,7 +45,6 @@ internal final class MainSplitViewController: NSSplitViewController, InspectorVi // MARK: - Observers private var connectionStatusObserver: NSObjectProtocol? - private var newConnectionObserver: NSObjectProtocol? // MARK: - Init @@ -199,15 +198,6 @@ internal final class MainSplitViewController: NSSplitViewController, InspectorVi self?.handleConnectionStatusChange() } } - newConnectionObserver = NotificationCenter.default.addObserver( - forName: .newConnection, - object: nil, - queue: .main - ) { _ in - MainActor.assumeIsolated { - NotificationCenter.default.post(name: .openWelcomeWindow, object: nil) - } - } handleConnectionStatusChange() } @@ -216,10 +206,6 @@ internal final class MainSplitViewController: NSSplitViewController, InspectorVi NotificationCenter.default.removeObserver(observer) connectionStatusObserver = nil } - if let observer = newConnectionObserver { - NotificationCenter.default.removeObserver(observer) - newConnectionObserver = nil - } } // MARK: - Toolbar @@ -320,7 +306,6 @@ internal final class MainSplitViewController: NSSplitViewController, InspectorVi private func buildSidebarView() -> some View { if let currentSession, let sessionState { SidebarView( - tables: sessionTablesBinding, sidebarState: SharedSidebarState.forConnection(currentSession.connection.id), onDoubleClick: { [weak self] table in guard let coordinator = self?.sessionState?.coordinator else { return } @@ -358,7 +343,6 @@ internal final class MainSplitViewController: NSSplitViewController, InspectorVi connection: currentSession.connection, payload: payload, windowTitle: windowTitleBinding, - tables: sessionTablesBinding, sidebarState: SharedSidebarState.forConnection(currentSession.connection.id), pendingTruncates: sessionPendingTruncatesBinding, pendingDeletes: sessionPendingDeletesBinding, @@ -419,10 +403,6 @@ internal final class MainSplitViewController: NSSplitViewController, InspectorVi ) } - private var sessionTablesBinding: Binding<[TableInfo]> { - createSessionBinding(get: { $0.tables }, set: { $0.tables = $1 }, defaultValue: []) - } - private var sessionPendingTruncatesBinding: Binding> { createSessionBinding(get: { $0.pendingTruncates }, set: { $0.pendingTruncates = $1 }, defaultValue: []) } diff --git a/TablePro/Core/Services/Infrastructure/MainWindowToolbar.swift b/TablePro/Core/Services/Infrastructure/MainWindowToolbar.swift index 38efcee31..5b8b08948 100644 --- a/TablePro/Core/Services/Infrastructure/MainWindowToolbar.swift +++ b/TablePro/Core/Services/Infrastructure/MainWindowToolbar.swift @@ -258,23 +258,20 @@ internal final class MainWindowToolbar: NSObject, NSToolbarDelegate { private struct ConnectionToolbarButton: View { let coordinator: MainContentCoordinator - @State private var showSwitcher = false var body: some View { + @Bindable var state = coordinator.toolbarState Button { - showSwitcher.toggle() + state.showConnectionSwitcher.toggle() } label: { Label("Connection", systemImage: "network") } .help(String(localized: "Switch Connection (⌘⌥C)")) - .popover(isPresented: $showSwitcher) { + .popover(isPresented: $state.showConnectionSwitcher) { ConnectionSwitcherPopover { - showSwitcher = false + state.showConnectionSwitcher = false } } - .onReceive(NotificationCenter.default.publisher(for: .openConnectionSwitcher)) { _ in - showSwitcher = true - } } } diff --git a/TablePro/Core/Services/Infrastructure/PendingActionStore.swift b/TablePro/Core/Services/Infrastructure/PendingActionStore.swift deleted file mode 100644 index 707884e9c..000000000 --- a/TablePro/Core/Services/Infrastructure/PendingActionStore.swift +++ /dev/null @@ -1,28 +0,0 @@ -// -// PendingActionStore.swift -// TablePro -// - -import Foundation - -@MainActor @Observable -final class PendingActionStore { - static let shared = PendingActionStore() - - var connectionShareURL: URL? - var deeplinkImport: ExportableConnection? - - private init() {} - - func consumeConnectionShareURL() -> URL? { - let url = connectionShareURL - connectionShareURL = nil - return url - } - - func consumeDeeplinkImport() -> ExportableConnection? { - let value = deeplinkImport - deeplinkImport = nil - return value - } -} diff --git a/TablePro/Core/Services/Infrastructure/SessionStateFactory.swift b/TablePro/Core/Services/Infrastructure/SessionStateFactory.swift index 099912e81..ffabe045a 100644 --- a/TablePro/Core/Services/Infrastructure/SessionStateFactory.swift +++ b/TablePro/Core/Services/Infrastructure/SessionStateFactory.swift @@ -2,9 +2,6 @@ // SessionStateFactory.swift // TablePro // -// Factory for creating session state objects used by MainContentView. -// Extracted from MainContentView.init to enable testability. -// import Foundation import os @@ -22,23 +19,36 @@ enum SessionStateFactory { let coordinator: MainContentCoordinator } - /// Hand-off registry for SessionState created eagerly by `WindowManager.openTab`. - /// `WindowManager` creates the coordinator BEFORE `TabWindowController.init` so the - /// NSToolbar can be installed synchronously in init (eliminating the toolbar flash - /// caused by lazy install via `WindowAccessor → configureWindow` after the window - /// is already on-screen). `ContentView.init` consumes the same SessionState here so - /// only one coordinator exists per window — no duplicate-tab side effects. private static var pendingSessionStates: [UUID: SessionState] = [:] + private static var pendingExpirationTasks: [UUID: Task] = [:] + + private static let pendingEntryTTL: Duration = .seconds(5) static func registerPending(_ state: SessionState, for payloadId: UUID) { pendingSessionStates[payloadId] = state + pendingExpirationTasks[payloadId]?.cancel() + pendingExpirationTasks[payloadId] = Task { [payloadId] in + try? await Task.sleep(for: pendingEntryTTL) + guard !Task.isCancelled else { return } + await MainActor.run { + pendingExpirationTasks.removeValue(forKey: payloadId) + guard let abandoned = pendingSessionStates.removeValue(forKey: payloadId) else { + return + } + MainContentCoordinator.activeCoordinators.removeValue( + forKey: abandoned.coordinator.instanceId + ) + } + } } static func consumePending(for payloadId: UUID) -> SessionState? { - pendingSessionStates.removeValue(forKey: payloadId) + pendingExpirationTasks.removeValue(forKey: payloadId)?.cancel() + return pendingSessionStates.removeValue(forKey: payloadId) } static func removePending(for payloadId: UUID) { + pendingExpirationTasks.removeValue(forKey: payloadId)?.cancel() pendingSessionStates.removeValue(forKey: payloadId) } @@ -46,14 +56,16 @@ enum SessionStateFactory { connection: DatabaseConnection, payload: EditorTabPayload? ) -> SessionState { - let tabMgr = QueryTabManager() + let connectionId = connection.id + let tabMgr = QueryTabManager(globalTabsProvider: { + MainActor.assumeIsolated { MainContentCoordinator.allTabs(for: connectionId) } + }) let changeMgr = DataChangeManager() changeMgr.databaseType = connection.type let filterMgr = FilterStateManager() let colVisMgr = ColumnVisibilityManager() let toolbarSt = ConnectionToolbarState(connection: connection) - // Eagerly populate version + state from existing session to avoid flash if let session = DatabaseManager.shared.session(for: connection.id) { toolbarSt.updateConnectionState(from: session.status) if let driver = session.driver { @@ -65,7 +77,6 @@ enum SessionStateFactory { } toolbarSt.hasCompletedSetup = true - // Redis: set initial database name eagerly to avoid toolbar flash if connection.type.pluginTypeId == "Redis" { let dbIndex = connection.redisDatabase ?? Int(connection.database) ?? 0 toolbarSt.databaseName = String(dbIndex) @@ -136,7 +147,11 @@ enum SessionStateFactory { case .newEmptyTab: let allTabs = MainContentCoordinator.allTabs(for: connection.id) let title = QueryTabManager.nextQueryTitle(existingTabs: allTabs) - tabMgr.addTab(title: title, databaseName: payload.databaseName ?? connection.database) + tabMgr.addTab( + initialQuery: payload.initialQuery, + title: title, + databaseName: payload.databaseName ?? connection.database + ) case .restoreOrDefault: break } @@ -151,6 +166,13 @@ enum SessionStateFactory { toolbarState: toolbarSt ) + // Eagerly publish to the active-coordinator registry so concurrent + // window opens for the same connection both observe each other when + // computing globals like nextQueryTitle. Without this, two windows + // opened back-to-back can both compute "Query 1" before either has + // run onAppear. + coord.registerEagerly() + return SessionState( tabManager: tabMgr, changeManager: changeMgr, diff --git a/TablePro/Core/Services/Infrastructure/TabPersistenceCoordinator+AggregatedSave.swift b/TablePro/Core/Services/Infrastructure/TabPersistenceCoordinator+AggregatedSave.swift new file mode 100644 index 000000000..c4424943e --- /dev/null +++ b/TablePro/Core/Services/Infrastructure/TabPersistenceCoordinator+AggregatedSave.swift @@ -0,0 +1,33 @@ +// +// TabPersistenceCoordinator+AggregatedSave.swift +// TablePro +// + +import Foundation + +extension TabPersistenceCoordinator { + /// Save or clear persisted state based on tabs aggregated across all windows + /// for the connection. Prevents the per-window close path from clobbering + /// state when sibling windows still have open tabs. + func saveOrClearAggregated() { + let aggregatedTabs = MainContentCoordinator.aggregatedTabs(for: connectionId) + if aggregatedTabs.isEmpty { + clearSavedState() + } else { + let selectedId = MainContentCoordinator.aggregatedSelectedTabId(for: connectionId) + saveNow(tabs: aggregatedTabs, selectedTabId: selectedId) + } + } + + /// Synchronous variant for the window-close path, where the run loop may + /// not be available to service Tasks before the window tears down. + func saveOrClearAggregatedSync() { + let aggregatedTabs = MainContentCoordinator.aggregatedTabs(for: connectionId) + if aggregatedTabs.isEmpty { + saveNowSync(tabs: [], selectedTabId: nil) + } else { + let selectedId = MainContentCoordinator.aggregatedSelectedTabId(for: connectionId) + saveNowSync(tabs: aggregatedTabs, selectedTabId: selectedId) + } + } +} diff --git a/TablePro/Core/Services/Infrastructure/TabPersistenceCoordinator.swift b/TablePro/Core/Services/Infrastructure/TabPersistenceCoordinator.swift index 3af4eb7ee..2ce909d20 100644 --- a/TablePro/Core/Services/Infrastructure/TabPersistenceCoordinator.swift +++ b/TablePro/Core/Services/Infrastructure/TabPersistenceCoordinator.swift @@ -2,15 +2,11 @@ // TabPersistenceCoordinator.swift // TablePro // -// Explicit-save coordinator for tab state persistence. -// Replaces debounced/flag-based TabPersistenceService with direct save calls. -// import Foundation import Observation import os -/// Result of tab restoration from disk internal struct RestoreResult { let tabs: [QueryTab] let selectedTabId: UUID? @@ -22,22 +18,19 @@ internal struct RestoreResult { } } -/// Coordinator for persisting and restoring tab state. -/// All saves are explicit: no debounce timers, no onChange-driven saves, -/// no isDismissing/isRestoringTabs flag state machine. @MainActor @Observable internal final class TabPersistenceCoordinator { private static let logger = Logger(subsystem: "com.TablePro", category: "NativeTabLifecycle") let connectionId: UUID + @ObservationIgnored private var saveTask: Task? + init(connectionId: UUID) { self.connectionId = connectionId } // MARK: - Save - /// Save tab state to disk. Called explicitly at named business events - /// (tab switch, window close, quit, etc.). internal func saveNow(tabs: [QueryTab], selectedTabId: UUID?) { let nonPreviewTabs = tabs.filter { !$0.isPreview } guard !nonPreviewTabs.isEmpty else { @@ -45,43 +38,17 @@ internal final class TabPersistenceCoordinator { return } let persisted = nonPreviewTabs.map { convertToPersistedTab($0) } - let connId = connectionId let normalizedSelectedId = nonPreviewTabs.contains(where: { $0.id == selectedTabId }) ? selectedTabId : nonPreviewTabs.first?.id - Self.logger.debug("[persist] saveNow queued tabCount=\(nonPreviewTabs.count) connId=\(connId, privacy: .public)") - - Task { - let t0 = Date() - do { - try await TabDiskActor.shared.save(connectionId: connId, tabs: persisted, selectedTabId: normalizedSelectedId) - Self.logger.debug("[persist] saveNow written tabCount=\(persisted.count) connId=\(connId, privacy: .public) ms=\(Int(Date().timeIntervalSince(t0) * 1_000))") - } catch { - TabDiskActor.logSaveError(connectionId: connId, error: error) - } - } + scheduleSave(tabs: persisted, selectedTabId: normalizedSelectedId) } - /// Save pre-aggregated tabs for the quit path, where the caller has already - /// collected and converted tabs from all windows for this connection. - internal func saveNow(persistedTabs: [PersistedTab], selectedTabId: UUID?) { - let connId = connectionId - let selectedId = selectedTabId - - Task { - do { - try await TabDiskActor.shared.save(connectionId: connId, tabs: persistedTabs, selectedTabId: selectedId) - } catch { - TabDiskActor.logSaveError(connectionId: connId, error: error) - } - } - } - - /// Synchronous save for `applicationWillTerminate` where no run loop - /// remains to service async Tasks. Bypasses the actor and writes directly. internal func saveNowSync(tabs: [QueryTab], selectedTabId: UUID?) { let nonPreviewTabs = tabs.filter { !$0.isPreview } guard !nonPreviewTabs.isEmpty else { - TabDiskActor.saveSync(connectionId: connectionId, tabs: [], selectedTabId: nil) + saveTask?.cancel() + saveTask = nil + TabDiskActor.clearSync(connectionId: connectionId) return } let persisted = nonPreviewTabs.map { convertToPersistedTab($0) } @@ -92,17 +59,40 @@ internal final class TabPersistenceCoordinator { // MARK: - Clear - /// Clear all saved state for this connection (user closed all tabs). internal func clearSavedState() { + saveTask?.cancel() + saveTask = nil let connId = connectionId Task { await TabDiskActor.shared.clear(connectionId: connId) } } + // MARK: - Private save scheduling + + private func scheduleSave(tabs: [PersistedTab], selectedTabId: UUID?) { + saveTask?.cancel() + let connId = connectionId + let tabsCopy = tabs + let selectedId = selectedTabId + Self.logger.debug("[persist] saveNow queued tabCount=\(tabsCopy.count) connId=\(connId, privacy: .public)") + + saveTask = Task { + guard !Task.isCancelled else { return } + let t0 = Date() + do { + try await TabDiskActor.shared.save(connectionId: connId, tabs: tabsCopy, selectedTabId: selectedId) + Self.logger.debug("[persist] saveNow written tabCount=\(tabsCopy.count) connId=\(connId, privacy: .public) ms=\(Int(Date().timeIntervalSince(t0) * 1_000))") + } catch is CancellationError { + return + } catch { + Self.logger.fault("Failed to save tab state for connection \(connId, privacy: .public): \(error.localizedDescription, privacy: .public)") + } + } + } + // MARK: - Restore - /// Restore tabs from disk. Called once at window creation. internal func restoreFromDisk() async -> RestoreResult { guard let state = await TabDiskActor.shared.load(connectionId: connectionId) else { return RestoreResult(tabs: [], selectedTabId: nil, source: .none) diff --git a/TablePro/Core/Services/Infrastructure/TabRouter.swift b/TablePro/Core/Services/Infrastructure/TabRouter.swift new file mode 100644 index 000000000..198b47f8e --- /dev/null +++ b/TablePro/Core/Services/Infrastructure/TabRouter.swift @@ -0,0 +1,426 @@ +// +// TabRouter.swift +// TablePro +// + +import AppKit +import Foundation +import os + +internal enum TabRouterError: Error, LocalizedError { + case connectionNotFound(UUID) + case malformedDatabaseURL(URL) + case userCancelled + case unsupportedIntent(String) + + internal var errorDescription: String? { + switch self { + case .connectionNotFound(let id): + return String( + format: String(localized: "No saved connection with ID \"%@\"."), id.uuidString + ) + case .malformedDatabaseURL(let url): + return String( + format: String(localized: "Could not parse database URL: %@"), url.sanitizedForLogging + ) + case .userCancelled: + return String(localized: "Cancelled by user.") + case .unsupportedIntent(let detail): + return String(format: String(localized: "Unsupported intent: %@"), detail) + } + } +} + +@MainActor +internal final class TabRouter { + internal static let shared = TabRouter() + + private static let logger = Logger(subsystem: "com.TablePro", category: "TabRouter") + + private init() {} + + internal func route(_ intent: LaunchIntent) async throws { + switch intent { + case .openConnection(let id): + try await openConnection(id: id) + + case .openTable(let id, let database, let schema, let table, let isView): + try await openTable( + connectionId: id, transientConnection: nil, + database: database, schema: schema, table: table, isView: isView + ) + + case .openQuery(let id, let sql): + try await openQuery(connectionId: id, sql: sql) + + case .openDatabaseURL(let url): + try await openDatabaseURL(url) + + case .openDatabaseFile(let url, let type): + try await openDatabaseFile(url, type: type) + + case .openSQLFile(let url): + try await openSQLFile(url) + + default: + throw TabRouterError.unsupportedIntent(String(describing: intent)) + } + } + + // MARK: - Connection + + private func openConnection(id: UUID) async throws { + guard let connection = ConnectionStorage.shared.loadConnections().first(where: { $0.id == id }) else { + throw TabRouterError.connectionNotFound(id) + } + if let existing = WindowLifecycleMonitor.shared.findWindow(for: id) { + existing.makeKeyAndOrderFront(nil) + NSApp.activate(ignoringOtherApps: true) + try await DatabaseManager.shared.ensureConnected(connection) + closeWelcomeWindows() + return + } + try await runPreConnectScriptIfNeeded(connection) + let payload = EditorTabPayload(connectionId: connection.id, intent: .restoreOrDefault) + WindowManager.shared.openTab(payload: payload) + NSApp.activate(ignoringOtherApps: true) + try await DatabaseManager.shared.ensureConnected(connection) + closeWelcomeWindows() + } + + // MARK: - Table + + private func openTable( + connectionId: UUID, transientConnection: DatabaseConnection? = nil, + database: String?, schema: String?, table: String, isView: Bool + ) async throws { + let connection: DatabaseConnection + if let transientConnection { + connection = transientConnection + } else if let stored = ConnectionStorage.shared.loadConnections().first(where: { $0.id == connectionId }) { + connection = stored + } else { + throw TabRouterError.connectionNotFound(connectionId) + } + try await runPreConnectScriptIfNeeded(connection) + try await DatabaseManager.shared.ensureConnected(connection) + + if let schema { + await switchSchemaOrDatabase(connectionId: connectionId, target: schema) + } else if let database { + await switchSchemaOrDatabase(connectionId: connectionId, target: database) + } + + if focusExistingTableTab(connectionId: connectionId, database: database, schema: schema, table: table) { + NSApp.activate(ignoringOtherApps: true) + closeWelcomeWindows() + return + } + + let payload = EditorTabPayload( + connectionId: connectionId, + tabType: .table, + tableName: table, + databaseName: database, + schemaName: schema, + isView: isView + ) + WindowManager.shared.openTab(payload: payload) + NSApp.activate(ignoringOtherApps: true) + closeWelcomeWindows() + } + + private func focusExistingTableTab( + connectionId: UUID, database: String?, schema: String?, table: String + ) -> Bool { + for coordinator in MainContentCoordinator.allActiveCoordinators() + where coordinator.connectionId == connectionId { + guard let match = coordinator.tabManager.tabs.first(where: { tab in + guard tab.tabType == .table, + tab.tableContext.tableName == table else { return false } + let databaseMatches = database.map { db in + tab.tableContext.databaseName == db + } ?? true + let schemaMatches = schema.map { sch in + tab.tableContext.schemaName.map { $0 == sch } ?? false + } ?? true + return databaseMatches && schemaMatches + }) else { continue } + coordinator.tabManager.selectedTabId = match.id + if let windowId = coordinator.windowId, + let window = WindowLifecycleMonitor.shared.window(for: windowId) { + window.makeKeyAndOrderFront(nil) + } + return true + } + return false + } + + // MARK: - Query + + private func openQuery(connectionId: UUID, sql: String) async throws { + guard let connection = ConnectionStorage.shared.loadConnections().first(where: { $0.id == connectionId }) else { + throw TabRouterError.connectionNotFound(connectionId) + } + + let preview = previewForSQL(sql) + let confirmed = await AlertHelper.runApprovalModal( + title: String(localized: "Open Query from Link"), + message: String( + format: String(localized: "An external link wants to open a query on \"%@\":\n\n%@"), + connection.name, preview + ), + confirm: String(localized: "Open Query"), + cancel: String(localized: "Cancel") + ) + guard confirmed else { throw TabRouterError.userCancelled } + + try await runPreConnectScriptIfNeeded(connection) + try await DatabaseManager.shared.ensureConnected(connection) + + if focusExistingQueryTab(connectionId: connectionId, sql: sql) { + NSApp.activate(ignoringOtherApps: true) + closeWelcomeWindows() + return + } + + let payload = EditorTabPayload( + connectionId: connectionId, + tabType: .query, + initialQuery: sql + ) + WindowManager.shared.openTab(payload: payload) + NSApp.activate(ignoringOtherApps: true) + closeWelcomeWindows() + } + + private func focusExistingQueryTab(connectionId: UUID, sql: String) -> Bool { + for coordinator in MainContentCoordinator.allActiveCoordinators() + where coordinator.connectionId == connectionId { + let match = coordinator.tabManager.tabs.first { tab in + tab.tabType == .query && tab.content.query == sql + } + guard let match else { continue } + coordinator.tabManager.selectedTabId = match.id + if let windowId = coordinator.windowId, + let window = WindowLifecycleMonitor.shared.window(for: windowId) { + window.makeKeyAndOrderFront(nil) + } + return true + } + return false + } + + private func previewForSQL(_ sql: String) -> String { + let nsSQL = sql as NSString + guard nsSQL.length > 300 else { return sql } + let head = nsSQL.substring(to: 300) + let hidden = nsSQL.length - 300 + return head + String(format: String(localized: "\n\n… (%d more characters not shown)"), hidden) + } + + // MARK: - Database URL + + private func openDatabaseURL(_ url: URL) async throws { + guard case .success(let parsed) = ConnectionURLParser.parse(url.absoluteString) else { + throw TabRouterError.malformedDatabaseURL(url) + } + + let connections = ConnectionStorage.shared.loadConnections() + let matched = connections.first { conn in + conn.type == parsed.type + && conn.host == parsed.host + && (parsed.port == nil || conn.port == parsed.port) + && conn.database == parsed.database + && (parsed.username.isEmpty || conn.username == parsed.username) + } + + let connection: DatabaseConnection + let isTransient: Bool + if let matched { + connection = matched + isTransient = false + } else { + connection = TransientConnectionFactory.build(from: parsed) + isTransient = true + } + + if !parsed.password.isEmpty { + ConnectionStorage.shared.savePassword(parsed.password, for: connection.id) + } + if let sshPass = parsed.sshPassword, !sshPass.isEmpty { + ConnectionStorage.shared.saveSSHPassword(sshPass, for: connection.id) + } + + do { + if let table = parsed.tableName { + try await openTable( + connectionId: connection.id, + transientConnection: isTransient ? connection : nil, + database: parsed.database.isEmpty ? nil : parsed.database, + schema: parsed.schema, + table: table, + isView: parsed.isView + ) + if parsed.filterColumn != nil || parsed.filterCondition != nil { + try await applyFilterFromParsedURL(parsed: parsed, connectionId: connection.id) + } + return + } + + try await runPreConnectScriptIfNeeded(connection) + let payload = EditorTabPayload(connectionId: connection.id, intent: .restoreOrDefault) + WindowManager.shared.openTab(payload: payload) + NSApp.activate(ignoringOtherApps: true) + try await DatabaseManager.shared.ensureConnected(connection) + closeWelcomeWindows() + + if let schema = parsed.schema { + await switchSchemaOrDatabase(connectionId: connection.id, target: schema) + } + } catch { + if isTransient { + ConnectionStorage.shared.deletePassword(for: connection.id) + ConnectionStorage.shared.deleteSSHPassword(for: connection.id) + } + throw error + } + } + + // MARK: - Database File + + private func openDatabaseFile(_ url: URL, type: DatabaseType) async throws { + let filePath = url.path(percentEncoded: false) + let connectionName = url.deletingPathExtension().lastPathComponent + + for (sessionId, session) in DatabaseManager.shared.activeSessions + where session.connection.type == type + && session.connection.database == filePath + && session.driver != nil { + bringConnectionWindowToFront(sessionId) + return + } + + let connection = DatabaseConnection( + name: connectionName, + host: "", + port: 0, + database: filePath, + username: "", + type: type + ) + + let payload = EditorTabPayload(connectionId: connection.id, intent: .restoreOrDefault) + WindowManager.shared.openTab(payload: payload) + NSApp.activate(ignoringOtherApps: true) + try await DatabaseManager.shared.ensureConnected(connection) + closeWelcomeWindows() + } + + // MARK: - SQL File + + private func openSQLFile(_ url: URL) async throws { + if let existing = WindowLifecycleMonitor.shared.window(forSourceFile: url) { + existing.makeKeyAndOrderFront(nil) + NSApp.activate(ignoringOtherApps: true) + return + } + + if let session = DatabaseManager.shared.currentSession { + let content = await Task.detached(priority: .userInitiated) { () -> String? in + try? String(contentsOf: url, encoding: .utf8) + }.value + guard let content else { + Self.logger.error("Failed to read SQL file: \(url.lastPathComponent, privacy: .public)") + return + } + let payload = EditorTabPayload( + connectionId: session.connection.id, + tabType: .query, + initialQuery: content, + sourceFileURL: url + ) + WindowManager.shared.openTab(payload: payload) + NSApp.activate(ignoringOtherApps: true) + } else { + WelcomeRouter.shared.enqueueSQLFile(url) + } + } + + // MARK: - Helpers + + internal func bringConnectionWindowToFront(_ connectionId: UUID) { + let windows = WindowLifecycleMonitor.shared.windows(for: connectionId) + if let window = windows.first { + window.makeKeyAndOrderFront(nil) + } else { + NSApp.windows.first { AppLaunchCoordinator.isMainWindow($0) && $0.isVisible }?.makeKeyAndOrderFront(nil) + } + NSApp.activate(ignoringOtherApps: true) + } + + private func switchSchemaOrDatabase(connectionId: UUID, target: String) async { + guard let coordinator = MainContentCoordinator.allActiveCoordinators() + .first(where: { $0.connectionId == connectionId }) else { return } + if PluginManager.shared.supportsSchemaSwitching(for: coordinator.connection.type) { + await coordinator.switchSchema(to: target) + } else { + await coordinator.switchDatabase(to: target) + } + } + + private func runPreConnectScriptIfNeeded(_ connection: DatabaseConnection) async throws { + guard let script = connection.preConnectScript, + !script.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty else { return } + let confirmed = await AlertHelper.confirmDestructive( + title: String(localized: "Pre-Connect Script"), + message: String( + format: String(localized: "Connection \"%@\" has a script that will run before connecting:\n\n%@"), + connection.name, script + ), + confirmButton: String(localized: "Run Script"), + cancelButton: String(localized: "Cancel"), + window: NSApp.keyWindow + ) + guard confirmed else { throw TabRouterError.userCancelled } + } + + private func applyFilterFromParsedURL(parsed: ParsedConnectionURL, connectionId: UUID) async throws { + let description: String + if let condition = parsed.filterCondition, !condition.isEmpty { + description = (condition as NSString).length > 300 + ? String(condition.prefix(300)) + "…" : condition + } else { + description = [parsed.filterColumn, parsed.filterOperation, parsed.filterValue] + .compactMap { $0 }.joined(separator: " ") + } + if !description.isEmpty { + let confirmed = await AlertHelper.confirmDestructive( + title: String(localized: "Apply Filter from Link"), + message: String( + format: String(localized: "An external link wants to apply a filter:\n\n%@"), + description + ), + confirmButton: String(localized: "Apply Filter"), + cancelButton: String(localized: "Cancel"), + window: NSApp.keyWindow + ) + guard confirmed else { throw TabRouterError.userCancelled } + } + + guard let coordinator = MainContentCoordinator.allActiveCoordinators() + .first(where: { $0.connectionId == connectionId }) else { return } + coordinator.applyURLFilter( + condition: parsed.filterCondition, + column: parsed.filterColumn, + operation: parsed.filterOperation, + value: parsed.filterValue + ) + } + + private func closeWelcomeWindows() { + for window in NSApp.windows where AppLaunchCoordinator.isWelcomeWindow(window) { + window.close() + } + } +} diff --git a/TablePro/Core/Services/Infrastructure/URLClassifier.swift b/TablePro/Core/Services/Infrastructure/URLClassifier.swift new file mode 100644 index 000000000..114cf13c8 --- /dev/null +++ b/TablePro/Core/Services/Infrastructure/URLClassifier.swift @@ -0,0 +1,48 @@ +// +// URLClassifier.swift +// TablePro +// + +import Foundation + +@MainActor +internal enum URLClassifier { + internal static func classify(_ url: URL) -> Result? { + if url.scheme == "tablepro" { + return DeeplinkParser.parse(url) + } + if url.isFileURL { + return classifyFile(url) + } + if isDatabaseURL(url) { + return .success(.openDatabaseURL(url)) + } + return nil + } + + private static func classifyFile(_ url: URL) -> Result? { + let ext = url.pathExtension.lowercased() + if ext == "tableplugin" { + return .success(.installPlugin(url)) + } + if ext == "tablepro" { + return .success(.openConnectionShare(url)) + } + if ext == "sql" { + return .success(.openSQLFile(url)) + } + if let dbType = PluginManager.shared.allRegisteredFileExtensions[ext] { + return .success(.openDatabaseFile(url, dbType)) + } + return nil + } + + private static func isDatabaseURL(_ url: URL) -> Bool { + guard let scheme = url.scheme?.lowercased() else { return false } + let base = scheme + .replacingOccurrences(of: "+ssh", with: "") + .replacingOccurrences(of: "+srv", with: "") + let registered = PluginManager.shared.allRegisteredURLSchemes + return registered.contains(base) || registered.contains(scheme) + } +} diff --git a/TablePro/Core/Services/Infrastructure/WelcomeRouter.swift b/TablePro/Core/Services/Infrastructure/WelcomeRouter.swift new file mode 100644 index 000000000..f0997e8c9 --- /dev/null +++ b/TablePro/Core/Services/Infrastructure/WelcomeRouter.swift @@ -0,0 +1,70 @@ +// +// WelcomeRouter.swift +// TablePro +// + +import AppKit +import Foundation +import Observation + +@MainActor +@Observable +internal final class WelcomeRouter { + internal static let shared = WelcomeRouter() + + private(set) var pendingImport: ExportableConnection? + private(set) var pendingConnectionShare: URL? + private(set) var pendingSQLFiles: [URL] = [] + + private init() { + NotificationCenter.default.addObserver( + forName: .databaseDidConnect, object: nil, queue: .main + ) { _ in + MainActor.assumeIsolated { + WelcomeRouter.shared.drainPendingSQLFiles() + } + } + } + + private func drainPendingSQLFiles() { + let urls = consumePendingSQLFiles() + guard !urls.isEmpty else { return } + NotificationCenter.default.post(name: .openSQLFiles, object: urls) + } + + internal func routeImport(_ exportable: ExportableConnection) { + pendingImport = exportable + showWelcomeWindow() + } + + internal func routeShare(_ url: URL) { + pendingConnectionShare = url + showWelcomeWindow() + } + + internal func enqueueSQLFile(_ url: URL) { + pendingSQLFiles.append(url) + } + + internal func consumePendingImport() -> ExportableConnection? { + let value = pendingImport + pendingImport = nil + return value + } + + internal func consumePendingShare() -> URL? { + let value = pendingConnectionShare + pendingConnectionShare = nil + return value + } + + internal func consumePendingSQLFiles() -> [URL] { + let value = pendingSQLFiles + pendingSQLFiles.removeAll() + return value + } + + private func showWelcomeWindow() { + WelcomeWindowFactory.openOrFront() + } +} diff --git a/TablePro/Core/Services/Infrastructure/WelcomeWindowFactory.swift b/TablePro/Core/Services/Infrastructure/WelcomeWindowFactory.swift new file mode 100644 index 000000000..e64d00315 --- /dev/null +++ b/TablePro/Core/Services/Infrastructure/WelcomeWindowFactory.swift @@ -0,0 +1,55 @@ +// +// WelcomeWindowFactory.swift +// TablePro +// + +import AppKit +import SwiftUI + +@MainActor +internal enum WelcomeWindowFactory { + private static let identifier = NSUserInterfaceItemIdentifier("welcome") + private static let contentSize = NSSize(width: 700, height: 450) + + internal static func openOrFront() { + if let existing = existingWindow() { + existing.makeKeyAndOrderFront(nil) + NSApp.activate(ignoringOtherApps: true) + return + } + let window = makeWindow() + window.makeKeyAndOrderFront(nil) + NSApp.activate(ignoringOtherApps: true) + } + + internal static func close() { + existingWindow()?.close() + } + + internal static func orderOut() { + existingWindow()?.orderOut(nil) + } + + private static func existingWindow() -> NSWindow? { + NSApp.windows.first { AppLaunchCoordinator.isWelcomeWindow($0) } + } + + private static func makeWindow() -> NSWindow { + let hostingController = NSHostingController(rootView: WelcomeWindowView()) + let window = NSWindow(contentViewController: hostingController) + window.identifier = identifier + window.title = String(localized: "Welcome to TablePro") + window.styleMask = [.titled, .closable, .fullSizeContentView] + window.titleVisibility = .hidden + window.titlebarAppearsTransparent = true + window.isOpaque = false + window.backgroundColor = .clear + window.standardWindowButton(.miniaturizeButton)?.isHidden = true + window.standardWindowButton(.zoomButton)?.isHidden = true + window.collectionBehavior.insert(.fullScreenNone) + window.setContentSize(contentSize) + window.center() + window.isReleasedWhenClosed = false + return window + } +} diff --git a/TablePro/Core/Services/Infrastructure/WindowManager.swift b/TablePro/Core/Services/Infrastructure/WindowManager.swift index d4a18eedd..27af3c145 100644 --- a/TablePro/Core/Services/Infrastructure/WindowManager.swift +++ b/TablePro/Core/Services/Infrastructure/WindowManager.swift @@ -2,17 +2,6 @@ // WindowManager.swift // TablePro // -// Imperative AppKit window management for main editor tabs. -// Phase 1 scope: create TabWindowController, install into tab group with -// correct ordering (orderFront before addTabbedWindow — avoids the synchronous -// full-tree layout that slowed the earlier prototype 4–5×), retain strong -// reference, release on willClose. -// -// In later phases WindowManager will also absorb the lookup API currently -// on WindowLifecycleMonitor (windows(for:), previewWindow(for:), etc.). -// In Phase 1, WindowLifecycleMonitor keeps that responsibility — this -// manager only owns window creation + controller lifetime. -// import AppKit import os @@ -24,9 +13,6 @@ internal final class WindowManager { internal static let shared = WindowManager() - /// Strong refs keyed by NSWindow identity. Because - /// `NSWindow.isReleasedWhenClosed = false` on our windows, this is the - /// only owner — dropping the entry deallocates controller + window. private var controllers: [ObjectIdentifier: TabWindowController] = [:] private var closeObservers: [ObjectIdentifier: NSObjectProtocol] = [:] @@ -34,25 +20,12 @@ internal final class WindowManager { // MARK: - Open - /// Creates and shows a new main-editor window hosting ContentView(payload:). - /// If a sibling window with the same tabbingIdentifier is already visible, - /// the new window joins its tab group. internal func openTab(payload: EditorTabPayload) { let t0 = Date() Self.lifecycleLogger.info( "[open] WindowManager.openTab start payloadId=\(payload.id, privacy: .public) connId=\(payload.connectionId, privacy: .public) intent=\(String(describing: payload.intent), privacy: .public) skipAutoExecute=\(payload.skipAutoExecute)" ) - // Eagerly create SessionState (coordinator + tab manager + toolbar state) - // BEFORE constructing the controller. This lets `TabWindowController.init` - // install the NSToolbar synchronously — so the window's first paint - // already has it, eliminating the toolbar-flash that occurs when the - // toolbar is installed later via `configureWindow` (which runs only - // after the window is on-screen). - // - // The same SessionState is handed off to ContentView via - // `SessionStateFactory.consumePending` so only ONE coordinator exists - // per window — no duplicate tabs. let resolvedConnection = DatabaseManager.shared.activeSessions[payload.connectionId]?.connection let preCreatedSessionState: SessionStateFactory.SessionState? if let resolvedConnection { @@ -60,9 +33,6 @@ internal final class WindowManager { SessionStateFactory.registerPending(state, for: payload.id) preCreatedSessionState = state } else { - // Connection not ready yet (welcome → connect race). Fall back to - // lazy SessionState creation inside ContentView.init + lazy toolbar - // install via configureWindow. preCreatedSessionState = nil } @@ -71,31 +41,14 @@ internal final class WindowManager { Self.lifecycleLogger.error( "[open] WindowManager.openTab failed: controller has no window payloadId=\(payload.id, privacy: .public)" ) - // Clean up the pending state we registered above so it doesn't leak. SessionStateFactory.removePending(for: payload.id) return } retain(controller: controller, window: window) - // Pre-mark so AppDelegate.windowDidBecomeKey skips its tabbing-merge - // block (we do the merge here, at creation, with the correct ordering). - if let appDelegate = NSApp.delegate as? AppDelegate { - appDelegate.configuredWindows.insert(ObjectIdentifier(window)) - } - - // --- Tab-group merge, correctly ordered --- - // - // The earlier prototype called `addTabbedWindow(window, …)` before - // the window was visible. AppKit responded by synchronously flushing - // the NSHostingView's SwiftUI layout (NavigationSplitView + editor + - // TreeSitterClient warmup) on the main thread — observed cost - // 800–960 ms per open. - // - // Ordering `orderFront(nil)` first makes the window visible and lets - // SwiftUI render asynchronously via its normal display cycle. Then - // `addTabbedWindow` re-parents an already-visible window into the - // tab group, which is a cheap AppKit-level operation. + // orderFront before addTabbedWindow avoids a synchronous full-tree + // SwiftUI layout pass that adds 700-900ms per open. let tabbingId = window.tabbingIdentifier ?? "" let groupAll = AppSettingsManager.shared.tabs.groupAllConnectionTabs let sibling = findSibling( @@ -103,13 +56,7 @@ internal final class WindowManager { ) if let sibling { - // Tab-merge: `addTabbedWindow(_:ordered:)` both adds the window to - // the group AND orders it — calling orderFront separately beforehand - // triggers a redundant layout pass on NSHostingView (observed cost - // 700-900ms vs. 75ms standalone). Let addTabbedWindow do both at once. if groupAll { - // groupAll mode: retag every visible main window with the unified - // identifier so addTabbedWindow is willing to merge. let otherMains = NSApp.windows.filter { $0 !== window && Self.isMainWindow($0) && $0.isVisible } @@ -119,29 +66,19 @@ internal final class WindowManager { } let target = sibling.tabbedWindows?.last ?? sibling target.addTabbedWindow(window, ordered: .above) - // `addTabbedWindow(_:ordered:)` only inserts — it doesn't select - // the new tab in the group. `makeKeyAndOrderFront` brings this - // window to the front of the group AND makes it key, which is - // what the user expects on Cmd+T. window.makeKeyAndOrderFront(nil) Self.lifecycleLogger.info( "[open] WindowManager joined existing tab group payloadId=\(payload.id, privacy: .public) tabbingId=\(tabbingId, privacy: .public)" ) } else { - // Standalone case: center the frame BEFORE showing so the window - // doesn't flash at the default (0,0) position before jumping. - // `makeKeyAndOrderFront` is the standard AppKit idiom for this. window.center() window.makeKeyAndOrderFront(nil) - // Ensure the app is active when opening from a background context - // (e.g. Welcome window's Connect button races with welcome close). NSApp.activate(ignoringOtherApps: true) Self.lifecycleLogger.info( "[open] WindowManager standalone window payloadId=\(payload.id, privacy: .public) tabbingId=\(tabbingId, privacy: .public)" ) } - Self.lifecycleLogger.info( "[open] WindowManager.openTab done payloadId=\(payload.id, privacy: .public) elapsedMs=\(Int(Date().timeIntervalSince(t0) * 1_000))" ) @@ -177,10 +114,6 @@ internal final class WindowManager { return raw == "main" || raw.hasPrefix("main-") } - /// Tabbing identifier for a connection. Per-connection by default; - /// shared "com.TablePro.main" when the user enables Group All Connection - /// Tabs in Settings → Tabs. Used by `TabWindowController.init` and by - /// AppDelegate's pre-Phase-1 fallback in `windowDidBecomeKey`. internal static func tabbingIdentifier(for connectionId: UUID) -> String { if AppSettingsManager.shared.tabs.groupAllConnectionTabs { return "com.TablePro.main" diff --git a/TablePro/Core/Services/Infrastructure/WindowOpener.swift b/TablePro/Core/Services/Infrastructure/WindowOpener.swift deleted file mode 100644 index e8a746911..000000000 --- a/TablePro/Core/Services/Infrastructure/WindowOpener.swift +++ /dev/null @@ -1,49 +0,0 @@ -// -// WindowOpener.swift -// TablePro -// -// Bridges SwiftUI's `OpenWindowAction` to imperative call sites for the -// remaining SwiftUI scenes (Welcome, Connection Form, Settings). The main -// editor windows no longer use this — they go through `WindowManager.openTab` -// directly. -// - -import os -import SwiftUI - -@MainActor -internal final class WindowOpener { - private static let logger = Logger(subsystem: "com.TablePro", category: "WindowOpener") - - internal static let shared = WindowOpener() - - private var readyContinuations: [CheckedContinuation] = [] - - /// Set on appear by `OpenWindowHandler` (TableProApp). Used to open the - /// welcome window, connection form, and settings from imperative code. - /// Safe to store — `OpenWindowAction` is app-scoped, not view-scoped. - internal var openWindow: OpenWindowAction? { - didSet { - if openWindow != nil { - for continuation in readyContinuations { - continuation.resume() - } - readyContinuations.removeAll() - } - } - } - - /// Suspends until `openWindow` is set. Returns immediately if available. - /// Used by Dock-menu / URL-scheme cold-launch paths that fire before any - /// SwiftUI view has appeared. - internal func waitUntilReady() async { - if openWindow != nil { return } - await withCheckedContinuation { continuation in - if openWindow != nil { - continuation.resume() - } else { - readyContinuations.append(continuation) - } - } - } -} diff --git a/TablePro/Core/Services/Query/SchemaService.swift b/TablePro/Core/Services/Query/SchemaService.swift new file mode 100644 index 000000000..8df9bf0fc --- /dev/null +++ b/TablePro/Core/Services/Query/SchemaService.swift @@ -0,0 +1,87 @@ +// +// SchemaService.swift +// TablePro +// + +import Foundation +import os + +@MainActor +@Observable +final class SchemaService { + static let shared = SchemaService() + + private(set) var states: [UUID: SchemaState] = [:] + + @ObservationIgnored private var lastLoadDates: [UUID: Date] = [:] + @ObservationIgnored private let loadDedup = OnceTask() + @ObservationIgnored private static let logger = Logger(subsystem: "com.TablePro", category: "SchemaService") + + init() {} + + func state(for connectionId: UUID) -> SchemaState { + states[connectionId] ?? .idle + } + + func tables(for connectionId: UUID) -> [TableInfo] { + if case .loaded(let tables) = state(for: connectionId) { + return tables + } + return [] + } + + func load(connectionId: UUID, driver: DatabaseDriver, connection: DatabaseConnection) async { + switch state(for: connectionId) { + case .loaded: + return + case .idle, .loading, .failed: + await runLoad(connectionId: connectionId, driver: driver, connection: connection) + } + } + + func reload(connectionId: UUID, driver: DatabaseDriver, connection: DatabaseConnection) async { + await runLoad(connectionId: connectionId, driver: driver, connection: connection) + } + + func reloadIfStale( + connectionId: UUID, + driver: DatabaseDriver, + connection: DatabaseConnection, + staleness: TimeInterval + ) async { + guard let lastLoad = lastLoadDates[connectionId] else { + await reload(connectionId: connectionId, driver: driver, connection: connection) + return + } + guard Date().timeIntervalSince(lastLoad) > staleness else { return } + await reload(connectionId: connectionId, driver: driver, connection: connection) + } + + func invalidate(connectionId: UUID) async { + await loadDedup.cancel(key: connectionId) + states.removeValue(forKey: connectionId) + lastLoadDates.removeValue(forKey: connectionId) + } + + private func runLoad( + connectionId: UUID, + driver: DatabaseDriver, + connection: DatabaseConnection + ) async { + states[connectionId] = .loading + do { + let tables = try await loadDedup.execute(key: connectionId) { + try await driver.fetchTables() + } + states[connectionId] = .loaded(tables) + lastLoadDates[connectionId] = Date() + } catch is CancellationError { + return + } catch { + Self.logger.warning( + "[schema] load failed connId=\(connectionId, privacy: .public) error=\(error.localizedDescription, privacy: .public)" + ) + states[connectionId] = .failed(error.localizedDescription) + } + } +} diff --git a/TablePro/Core/Services/Query/SchemaState.swift b/TablePro/Core/Services/Query/SchemaState.swift new file mode 100644 index 000000000..c61cbe46d --- /dev/null +++ b/TablePro/Core/Services/Query/SchemaState.swift @@ -0,0 +1,13 @@ +// +// SchemaState.swift +// TablePro +// + +import Foundation + +enum SchemaState: Equatable, Sendable { + case idle + case loading + case loaded([TableInfo]) + case failed(String) +} diff --git a/TablePro/Core/Storage/TabDiskActor.swift b/TablePro/Core/Storage/TabDiskActor.swift index 8f3382948..25e81f7a1 100644 --- a/TablePro/Core/Storage/TabDiskActor.swift +++ b/TablePro/Core/Storage/TabDiskActor.swift @@ -10,16 +10,11 @@ import Foundation import os -/// Persisted tab state for a connection internal struct TabDiskState: Codable { let tabs: [PersistedTab] let selectedTabId: UUID? } -/// Actor that serializes all tab-state disk I/O. -/// -/// Data is stored as individual JSON files per connection in: -/// `~/Library/Application Support/TablePro/TabState/` internal actor TabDiskActor { internal static let shared = TabDiskActor() @@ -52,7 +47,6 @@ internal actor TabDiskActor { // MARK: - Public API - /// Save tab state for a connection. Throws on encoding or disk write failure. internal func save(connectionId: UUID, tabs: [PersistedTab], selectedTabId: UUID?) throws { let state = TabDiskState(tabs: tabs, selectedTabId: selectedTabId) let data = try encoder.encode(state) @@ -60,12 +54,6 @@ internal actor TabDiskActor { try data.write(to: fileURL, options: .atomic) } - /// Log a save error from callers that handle errors externally. - nonisolated static func logSaveError(connectionId: UUID, error: Error) { - logger.error("Failed to save tab state for \(connectionId): \(error.localizedDescription)") - } - - /// Load tab state for a connection. Returns nil if the file is missing or corrupt. internal func load(connectionId: UUID) -> TabDiskState? { let fileURL = tabStateFileURL(for: connectionId) @@ -82,7 +70,6 @@ internal actor TabDiskActor { } } - /// Delete the tab state file for a connection. internal func clear(connectionId: UUID) { let fileURL = tabStateFileURL(for: connectionId) @@ -95,7 +82,6 @@ internal actor TabDiskActor { } } - /// List all connection IDs that have saved tab state on disk. internal func connectionIdsWithSavedState() -> [UUID] { let fm = FileManager.default guard let files = try? fm.contentsOfDirectory( @@ -104,10 +90,18 @@ internal actor TabDiskActor { ) else { return [] } - return files.compactMap { url -> UUID? in - guard url.pathExtension == "json" else { return nil } - return UUID(uuidString: url.deletingPathExtension().lastPathComponent) + var validIds: [UUID] = [] + for url in files where url.pathExtension == "json" { + guard let id = UUID(uuidString: url.deletingPathExtension().lastPathComponent) else { continue } + if let data = try? Data(contentsOf: url), + let state = try? decoder.decode(TabDiskState.self, from: data), + !state.tabs.isEmpty { + validIds.append(id) + } else { + try? fm.removeItem(at: url) + } } + return validIds } // MARK: - Static Path Helpers @@ -127,9 +121,6 @@ internal actor TabDiskActor { // MARK: - Synchronous Save (quit-time only) - /// Synchronous file write for `applicationWillTerminate`, where no run loop - /// remains to execute an async Task. Safe because the process is single-threaded - /// at termination — no concurrent actor access is possible. nonisolated internal static func saveSync( connectionId: UUID, tabs: [PersistedTab], @@ -145,7 +136,17 @@ internal actor TabDiskActor { let fileURL = tabStateFileURL(for: connectionId) try data.write(to: fileURL, options: .atomic) } catch { - logger.error("saveSync failed for \(connectionId): \(error.localizedDescription)") + logger.fault("saveSync failed for \(connectionId): \(error.localizedDescription)") + } + } + + nonisolated internal static func clearSync(connectionId: UUID) { + let fileURL = tabStateFileURL(for: connectionId) + guard FileManager.default.fileExists(atPath: fileURL.path) else { return } + do { + try FileManager.default.removeItem(at: fileURL) + } catch { + logger.fault("clearSync failed for \(connectionId): \(error.localizedDescription)") } } @@ -157,9 +158,6 @@ internal actor TabDiskActor { // MARK: - Migration from UserDefaults - /// One-time migration: reads existing tab state from UserDefaults, - /// writes it to file storage, then clears the old UserDefaults keys. - /// This is a static method to avoid actor-isolation issues during init. private static func performMigrationIfNeeded(tabStateDirectory: URL) { let defaults = UserDefaults.standard diff --git a/TablePro/Core/Utilities/Connection/TransientConnectionFactory.swift b/TablePro/Core/Utilities/Connection/TransientConnectionFactory.swift new file mode 100644 index 000000000..ea2fa15a2 --- /dev/null +++ b/TablePro/Core/Utilities/Connection/TransientConnectionFactory.swift @@ -0,0 +1,71 @@ +// +// TransientConnectionFactory.swift +// TablePro +// + +import Foundation + +@MainActor +internal enum TransientConnectionFactory { + internal static func build(from parsed: ParsedConnectionURL) -> DatabaseConnection { + var sshConfig = SSHConfiguration() + if let sshHost = parsed.sshHost { + sshConfig.enabled = true + sshConfig.host = sshHost + sshConfig.port = parsed.sshPort ?? 22 + sshConfig.username = parsed.sshUsername ?? "" + if parsed.usePrivateKey == true { + sshConfig.authMethod = .privateKey + } + if parsed.useSSHAgent == true { + sshConfig.authMethod = .sshAgent + sshConfig.agentSocketPath = parsed.agentSocket ?? "" + } + } + + var sslConfig = SSLConfiguration() + if let sslMode = parsed.sslMode { + sslConfig.mode = sslMode + } + + var color: ConnectionColor = .none + if let hex = parsed.statusColor { + color = ConnectionURLParser.connectionColor(fromHex: hex) + } + + var tagId: UUID? + if let envName = parsed.envTag { + tagId = ConnectionURLParser.tagId(fromEnvName: envName) + } + + let resolvedSafeMode = parsed.safeModeLevel.flatMap(SafeModeLevel.from(urlInteger:)) ?? .silent + + var connection = DatabaseConnection( + name: parsed.connectionName ?? parsed.suggestedName, + host: parsed.host, + port: parsed.port ?? parsed.type.defaultPort, + database: parsed.database, + username: parsed.username, + type: parsed.type, + sshConfig: sshConfig, + sslConfig: sslConfig, + color: color, + tagId: tagId, + safeModeLevel: resolvedSafeMode, + mongoAuthSource: parsed.authSource, + mongoUseSrv: parsed.useSrv, + mongoAuthMechanism: parsed.mongoQueryParams["authMechanism"], + mongoReplicaSet: parsed.mongoQueryParams["replicaSet"], + redisDatabase: parsed.redisDatabase, + oracleServiceName: parsed.oracleServiceName + ) + + for (key, value) in parsed.mongoQueryParams where !value.isEmpty { + if key != "authMechanism" && key != "replicaSet" { + connection.additionalFields["mongoParam_\(key)"] = value + } + } + + return connection + } +} diff --git a/TablePro/Core/Utilities/UI/AlertHelper.swift b/TablePro/Core/Utilities/UI/AlertHelper.swift index 07341e5e8..c1a43c226 100644 --- a/TablePro/Core/Utilities/UI/AlertHelper.swift +++ b/TablePro/Core/Utilities/UI/AlertHelper.swift @@ -2,31 +2,18 @@ // AlertHelper.swift // TablePro // -// Created by TablePro on 1/19/26. -// import AppKit +import SwiftUI -/// Centralized helper for creating and displaying NSAlert dialogs -/// Provides consistent styling and behavior across the application @MainActor final class AlertHelper { - /// Tries multiple sources to find a presentable window, minimizing runModal() fallback usage. static func resolveWindow(_ window: NSWindow?) -> NSWindow? { window ?? NSApp.keyWindow ?? NSApp.mainWindow ?? NSApp.windows.first { $0.isVisible } } // MARK: - Destructive Confirmations - /// Shows a destructive confirmation dialog (warning style) - /// Uses async sheet presentation when window is available, falls back to modal - /// - Parameters: - /// - title: Alert title - /// - message: Detailed message - /// - confirmButton: Label for destructive action button (default: "OK") - /// - cancelButton: Label for cancel button (default: "Cancel") - /// - window: Parent window to attach sheet to (optional) - /// - Returns: true if user confirmed, false if cancelled static func confirmDestructive( title: String, message: String, @@ -41,32 +28,18 @@ final class AlertHelper { alert.addButton(withTitle: confirmButton) alert.addButton(withTitle: cancelButton) - // Use sheet presentation when window is available (non-blocking, Swift 6 friendly) if let window = resolveWindow(window) { return await withCheckedContinuation { continuation in alert.beginSheetModal(for: window) { response in continuation.resume(returning: response == .alertFirstButtonReturn) } } - } else { - // Fallback to modal when no window available - let response = alert.runModal() - return response == .alertFirstButtonReturn } + return alert.runModal() == .alertFirstButtonReturn } // MARK: - Critical Confirmations - /// Shows a critical confirmation dialog (critical style) - /// Uses async sheet presentation when window is available, falls back to modal - /// Used for dangerous operations like DROP, TRUNCATE, DELETE without WHERE - /// - Parameters: - /// - title: Alert title - /// - message: Detailed message - /// - confirmButton: Label for dangerous action button (default: "Execute") - /// - cancelButton: Label for cancel button (default: "Cancel") - /// - window: Parent window to attach sheet to (optional) - /// - Returns: true if user confirmed, false if cancelled static func confirmCritical( title: String, message: String, @@ -81,33 +54,79 @@ final class AlertHelper { alert.addButton(withTitle: confirmButton) alert.addButton(withTitle: cancelButton) - // Use sheet presentation when window is available (non-blocking, Swift 6 friendly) if let window = resolveWindow(window) { return await withCheckedContinuation { continuation in alert.beginSheetModal(for: window) { response in continuation.resume(returning: response == .alertFirstButtonReturn) } } - } else { - // Fallback to modal when no window available - let response = alert.runModal() - return response == .alertFirstButtonReturn + } + return alert.runModal() == .alertFirstButtonReturn + } + + // MARK: - Cross-Process Approval + + static func runApprovalModal( + title: String, + message: String, + confirm: String, + cancel: String + ) async -> Bool { + NSApp.activate(ignoringOtherApps: true) + let alert = NSAlert() + alert.messageText = title + alert.informativeText = message + alert.alertStyle = .warning + alert.addButton(withTitle: confirm) + alert.addButton(withTitle: cancel) + return alert.runModal() == .alertFirstButtonReturn + } + + static func runPairingApproval(request: PairingRequest) async throws -> PairingApproval { + try await withCheckedThrowingContinuation { continuation in + var deliver: ((Result) -> Void)? + let host = NSHostingController( + rootView: PairingApprovalSheet( + request: request, + onComplete: { result in deliver?(result) } + ) + ) + host.view.frame = NSRect(x: 0, y: 0, width: 520, height: 560) + + let parent = resolveWindow(nil) + let sheetWindow = NSWindow(contentViewController: host) + sheetWindow.styleMask = [.titled] + sheetWindow.title = String(localized: "Approve Integration") + sheetWindow.isReleasedWhenClosed = false + + var resolved = false + deliver = { result in + guard !resolved else { return } + resolved = true + if let parent { + parent.endSheet(sheetWindow) + } else { + sheetWindow.close() + } + continuation.resume(with: result) + } + + if let parent { + parent.beginSheet(sheetWindow, completionHandler: nil) + } else { + NSApp.activate(ignoringOtherApps: true) + sheetWindow.center() + sheetWindow.makeKeyAndOrderFront(nil) + } } } // MARK: - Save Changes Confirmation - /// Result of a standard macOS save-changes confirmation dialog enum SaveConfirmationResult { case save, dontSave, cancel } - /// Shows the standard macOS "save changes before closing?" dialog. - /// Button layout matches NSDocument convention: Save (default) | Cancel | Don't Save (Cmd+D). - /// - Parameters: - /// - message: Detailed message explaining what has unsaved changes - /// - window: Parent window to attach sheet to (optional) - /// - Returns: The user's choice static func confirmSaveChanges( message: String, window: NSWindow? = nil @@ -117,17 +136,15 @@ final class AlertHelper { alert.informativeText = message alert.alertStyle = .warning - // Button order follows macOS convention (rightmost → leftmost): - // [Don't Save] [Cancel] [Save] - alert.addButton(withTitle: String(localized: "Save")) // alertFirstButtonReturn (default) - alert.addButton(withTitle: String(localized: "Cancel")) // alertSecondButtonReturn - let dontSaveButton = alert.addButton(withTitle: String(localized: "Don't Save")) // alertThirdButtonReturn + // Button order follows NSDocument convention: Save | Cancel | Don't Save (Cmd+D) + alert.addButton(withTitle: String(localized: "Save")) + alert.addButton(withTitle: String(localized: "Cancel")) + let dontSaveButton = alert.addButton(withTitle: String(localized: "Don't Save")) dontSaveButton.hasDestructiveAction = true dontSaveButton.keyEquivalent = "d" dontSaveButton.keyEquivalentModifierMask = .command let response: NSApplication.ModalResponse - if let window = resolveWindow(window) { response = await withCheckedContinuation { continuation in alert.beginSheetModal(for: window) { resp in @@ -139,27 +156,14 @@ final class AlertHelper { } switch response { - case .alertFirstButtonReturn: - return .save - case .alertThirdButtonReturn: - return .dontSave - default: - return .cancel + case .alertFirstButtonReturn: return .save + case .alertThirdButtonReturn: return .dontSave + default: return .cancel } } // MARK: - Three-Way Confirmations - /// Shows a three-option confirmation dialog - /// Uses async sheet presentation when window is available, falls back to modal - /// - Parameters: - /// - title: Alert title - /// - message: Detailed message - /// - first: Label for first button - /// - second: Label for second button - /// - third: Label for third button - /// - window: Parent window to attach sheet to (optional) - /// - Returns: 0 for first button, 1 for second, 2 for third static func confirmThreeWay( title: String, message: String, @@ -177,8 +181,6 @@ final class AlertHelper { alert.addButton(withTitle: third) let response: NSApplication.ModalResponse - - // Use sheet presentation when window is available (non-blocking, Swift 6 friendly) if let window = resolveWindow(window) { response = await withCheckedContinuation { continuation in alert.beginSheetModal(for: window) { resp in @@ -186,29 +188,19 @@ final class AlertHelper { } } } else { - // Fallback to modal when no window available response = alert.runModal() } switch response { - case .alertFirstButtonReturn: - return 0 - case .alertSecondButtonReturn: - return 1 - case .alertThirdButtonReturn: - return 2 - default: - return 2 // Default to third option (usually cancel) + case .alertFirstButtonReturn: return 0 + case .alertSecondButtonReturn: return 1 + case .alertThirdButtonReturn: return 2 + default: return 2 } } - // MARK: - Error Sheets + // MARK: - Error / Info Sheets - /// Shows an error message as a non-blocking sheet - /// - Parameters: - /// - title: Error title - /// - message: Error details - /// - window: Parent window to attach sheet to (optional, falls back to modal) static func showErrorSheet( title: String, message: String, @@ -221,22 +213,12 @@ final class AlertHelper { alert.addButton(withTitle: String(localized: "OK")) if let window = resolveWindow(window) { - alert.beginSheetModal(for: window) { _ in - // Sheet dismissed, no action needed - } + alert.beginSheetModal(for: window) { _ in } } else { - // Fallback to modal if no window available alert.runModal() } } - // MARK: - Info Sheets - - /// Shows an informational message as a non-blocking sheet - /// - Parameters: - /// - title: Info title - /// - message: Info details - /// - window: Parent window to attach sheet to (optional, falls back to modal) static func showInfoSheet( title: String, message: String, @@ -249,23 +231,14 @@ final class AlertHelper { alert.addButton(withTitle: String(localized: "OK")) if let window = resolveWindow(window) { - alert.beginSheetModal(for: window) { _ in - // Sheet dismissed, no action needed - } + alert.beginSheetModal(for: window) { _ in } } else { - // Fallback to modal if no window available alert.runModal() } } // MARK: - Query Error with AI Option - /// Shows a query error dialog with an option to ask AI to fix it - /// - Parameters: - /// - title: Error title - /// - message: Error details - /// - window: Parent window to attach sheet to (optional) - /// - Returns: true if "Ask AI to Fix" was clicked static func showQueryErrorWithAIOption( title: String, message: String, @@ -284,9 +257,7 @@ final class AlertHelper { continuation.resume(returning: response == .alertSecondButtonReturn) } } - } else { - let response = alert.runModal() - return response == .alertSecondButtonReturn } + return alert.runModal() == .alertSecondButtonReturn } } diff --git a/TablePro/Extensions/URL+SanitizedLogging.swift b/TablePro/Extensions/URL+SanitizedLogging.swift new file mode 100644 index 000000000..18f1fc93c --- /dev/null +++ b/TablePro/Extensions/URL+SanitizedLogging.swift @@ -0,0 +1,17 @@ +// +// URL+SanitizedLogging.swift +// TablePro +// + +import Foundation + +internal extension URL { + var sanitizedForLogging: String { + guard var components = URLComponents(url: self, resolvingAgainstBaseURL: false), + components.password != nil else { + return absoluteString + } + components.password = "***" + return components.string ?? absoluteString + } +} diff --git a/TablePro/Models/Connection/ConnectionSession.swift b/TablePro/Models/Connection/ConnectionSession.swift index 806907039..734b2bc95 100644 --- a/TablePro/Models/Connection/ConnectionSession.swift +++ b/TablePro/Models/Connection/ConnectionSession.swift @@ -18,7 +18,6 @@ struct ConnectionSession: Identifiable { var lastError: String? // Per-connection state - var tables: [TableInfo] = [] var selectedTables: Set = [] var pendingTruncates: Set = [] var pendingDeletes: Set = [] @@ -26,6 +25,11 @@ struct ConnectionSession: Identifiable { var currentSchema: String? var currentDatabase: String? + @MainActor + var tables: [TableInfo] { + SchemaService.shared.tables(for: id) + } + /// In-memory password for prompt-for-password connections. Never persisted to disk. var cachedPassword: String? @@ -63,7 +67,6 @@ struct ConnectionSession: Identifiable { /// to release memory held by stale table metadata. /// Note: `cachedPassword` is intentionally NOT cleared — auto-reconnect needs it after disconnect. mutating func clearCachedData() { - tables = [] selectedTables = [] pendingTruncates = [] pendingDeletes = [] @@ -80,12 +83,12 @@ struct ConnectionSession: Identifiable { /// Compares fields used by ContentView's body to avoid unnecessary SwiftUI re-renders. /// Excludes: driver (protocol, non-comparable), - /// lastActiveAt (volatile), lastError, effectiveConnection. + /// lastActiveAt (volatile), lastError, effectiveConnection, + /// tables (owned by SchemaService and observed independently). func isContentViewEquivalent(to other: ConnectionSession) -> Bool { id == other.id && status == other.status && connection == other.connection - && tables == other.tables && pendingTruncates == other.pendingTruncates && pendingDeletes == other.pendingDeletes && tableOperationOptions == other.tableOperationOptions diff --git a/TablePro/Models/Connection/ConnectionToolbarState.swift b/TablePro/Models/Connection/ConnectionToolbarState.swift index 9faaf3b75..484b38b82 100644 --- a/TablePro/Models/Connection/ConnectionToolbarState.swift +++ b/TablePro/Models/Connection/ConnectionToolbarState.swift @@ -195,6 +195,9 @@ final class ConnectionToolbarState { /// Whether the SQL review popover is showing var showSQLReviewPopover: Bool = false + /// Whether the connection switcher popover is showing + var showConnectionSwitcher: Bool = false + /// SQL statements to display in the review popover var previewStatements: [String] = [] diff --git a/TablePro/Models/Query/QueryResult.swift b/TablePro/Models/Query/QueryResult.swift index e7fa54063..92434a0fc 100644 --- a/TablePro/Models/Query/QueryResult.swift +++ b/TablePro/Models/Query/QueryResult.swift @@ -72,7 +72,7 @@ enum DatabaseError: Error, LocalizedError { } /// Information about a database table -struct TableInfo: Identifiable, Hashable { +struct TableInfo: Identifiable, Hashable, Sendable { var id: String { "\(name)_\(type.rawValue)" } let name: String let type: TableType diff --git a/TablePro/Models/Query/QueryTabManager.swift b/TablePro/Models/Query/QueryTabManager.swift index 98c1be0f2..0f6d387c5 100644 --- a/TablePro/Models/Query/QueryTabManager.swift +++ b/TablePro/Models/Query/QueryTabManager.swift @@ -26,6 +26,12 @@ final class QueryTabManager { @ObservationIgnored private var _tabIndexMap: [UUID: Int] = [:] @ObservationIgnored private var _tabIndexMapDirty = true + @ObservationIgnored private let globalTabsProvider: () -> [QueryTab] + + init(globalTabsProvider: @escaping () -> [QueryTab] = { [] }) { + self.globalTabsProvider = globalTabsProvider + } + private func rebuildTabIndexMapIfNeeded() { guard _tabIndexMapDirty else { return } _tabIndexMap = Dictionary(uniqueKeysWithValues: tabs.enumerated().map { ($1.id, $0) }) @@ -50,11 +56,6 @@ final class QueryTabManager { return (tabs[index], index) } - init() { - tabs = [] - selectedTabId = nil - } - // MARK: - Tab Naming /// Next "Query N" title based on existing tabs across all windows. @@ -69,6 +70,10 @@ final class QueryTabManager { return "Query \(maxNumber + 1)" } + private func nextTitle() -> String { + Self.nextQueryTitle(existingTabs: globalTabsProvider() + tabs) + } + // MARK: - Tab Management func addTab(initialQuery: String? = nil, title: String? = nil, databaseName: String = "", sourceFileURL: URL? = nil) { @@ -81,7 +86,7 @@ final class QueryTabManager { return } - let tabTitle = title ?? Self.nextQueryTitle(existingTabs: tabs) + let tabTitle = title ?? nextTitle() var newTab = QueryTab(title: tabTitle, tabType: .query) if let query = initialQuery { @@ -181,6 +186,13 @@ final class QueryTabManager { databaseName: String = "", quoteIdentifier: ((String) -> String)? = nil ) throws { + if let existing = tabs.first(where: { + $0.tabType == .table && $0.tableContext.tableName == tableName && $0.tableContext.databaseName == databaseName + }) { + selectedTabId = existing.id + return + } + let pageSize = AppSettingsManager.shared.dataGrid.defaultPageSize let query = try QueryTab.buildBaseTableQuery( tableName: tableName, databaseType: databaseType, quoteIdentifier: quoteIdentifier diff --git a/TablePro/Resources/Localizable.xcstrings b/TablePro/Resources/Localizable.xcstrings index 1dfd89061..bdb760dc2 100644 --- a/TablePro/Resources/Localizable.xcstrings +++ b/TablePro/Resources/Localizable.xcstrings @@ -2668,6 +2668,9 @@ } } } + }, + "1 day" : { + }, "1 of %lld conflicts" : { "localizations" : { @@ -3456,6 +3459,9 @@ } } } + }, + "Access Level" : { + }, "Account" : { "localizations" : { @@ -3801,6 +3807,9 @@ } } } + }, + "Activity Log" : { + }, "Actual" : { "localizations" : { @@ -4301,6 +4310,9 @@ } } } + }, + "Administration" : { + }, "Advanced" : { @@ -4681,6 +4693,9 @@ } } } + }, + "All categories" : { + }, "All columns" : { "localizations" : { @@ -4841,6 +4856,9 @@ } } } + }, + "All time" : { + }, "All Time" : { "localizations" : { @@ -4863,6 +4881,9 @@ } } } + }, + "All tokens" : { + }, "Allow" : { "localizations" : { @@ -4885,6 +4906,9 @@ } } } + }, + "Allow %@ to access TablePro?" : { + }, "Allow AI Access" : { "localizations" : { @@ -4910,6 +4934,9 @@ }, "Allow remote connections" : { + }, + "Allowed Connections" : { + }, "Also handles" : { "localizations" : { @@ -5050,6 +5077,9 @@ }, "Always Show" : { + }, + "An external app is asking for an API token. Review the permissions before approving." : { + }, "An external link wants to add a database connection:\n\nName: %@\n%@" : { "extractionState" : "stale", @@ -5102,7 +5132,18 @@ } } }, + "An external link wants to open a query on \"%@\":\n\n%@" : { + "localizations" : { + "en" : { + "stringUnit" : { + "state" : "new", + "value" : "An external link wants to open a query on \"%1$@\":\n\n%2$@" + } + } + } + }, "An external link wants to open a query on connection \"%@\":\n\n%@" : { + "extractionState" : "stale", "localizations" : { "en" : { "stringUnit" : { @@ -5630,6 +5671,12 @@ } } } + }, + "Approve" : { + + }, + "Approve Integration" : { + }, "Are you sure you want to cancel the running query for this session?" : { "localizations" : { @@ -6676,6 +6723,9 @@ } } } + }, + "Blocked" : { + }, "Blue" : { "localizations" : { @@ -7190,6 +7240,9 @@ } } } + }, + "Cancelled by user." : { + }, "Cannot connect to Ollama at %@. Is Ollama running?" : { "localizations" : { @@ -9675,6 +9728,9 @@ } } } + }, + "Connection is read-only for external clients" : { + }, "Connection lost" : { "localizations" : { @@ -9719,8 +9775,12 @@ } } } + }, + "Connection not found" : { + }, "Connection Not Found" : { + "extractionState" : "stale", "localizations" : { "tr" : { "stringUnit" : { @@ -9922,6 +9982,9 @@ } } } + }, + "Connection: %@" : { + }, "Connections" : { "localizations" : { @@ -10077,6 +10140,9 @@ } } } + }, + "Controls how external clients (Raycast, Cursor, Claude Desktop) access this connection. Tokens cannot exceed this level even with full-access scope." : { + }, "Conversation History" : { "extractionState" : "stale", @@ -10836,6 +10902,12 @@ }, "Could not generate SQL for changes." : { + }, + "Could Not Open File" : { + + }, + "Could not parse database URL: %@" : { + }, "Could not reach the license server. Check your internet connection and try again." : { "localizations" : { @@ -13512,6 +13584,9 @@ } } } + }, + "Denied" : { + }, "Deny" : { "localizations" : { @@ -17213,6 +17288,12 @@ } } } + }, + "External Access" : { + + }, + "External access is disabled for this connection" : { + }, "Extra Large" : { "extractionState" : "stale", @@ -19409,6 +19490,9 @@ }, "Full Access" : { + }, + "Full access including destructive DDL after explicit confirmation." : { + }, "Function" : { "localizations" : { @@ -21825,6 +21909,9 @@ } } } + }, + "Integrations" : { + }, "Interactive Data Grid" : { "localizations" : { @@ -22162,6 +22249,9 @@ } } } + }, + "Invalid UUID: %@" : { + }, "Invisibles" : { "localizations" : { @@ -22847,6 +22937,15 @@ } } } + }, + "Last 7 days" : { + + }, + "Last 24 hours" : { + + }, + "Last 30 days" : { + }, "Last query execution summary" : { "localizations" : { @@ -24080,6 +24179,9 @@ } } } + }, + "Malformed deep link path: %@" : { + }, "Manage Connections" : { "localizations" : { @@ -24481,9 +24583,6 @@ } } } - }, - "MCP" : { - }, "MCP Access Request" : { "localizations" : { @@ -24860,6 +24959,9 @@ } } } + }, + "Missing required parameter: %@" : { + }, "Missing value for parameter: %@" : { @@ -26114,6 +26216,9 @@ } } } + }, + "No activity yet" : { + }, "No AI provider configured. Go to Settings > AI to add one." : { "localizations" : { @@ -26494,6 +26599,16 @@ } } }, + "No free port in range %d-%d" : { + "localizations" : { + "en" : { + "stringUnit" : { + "state" : "new", + "value" : "No free port in range %1$d-%2$d" + } + } + } + }, "No iCloud" : { "localizations" : { "tr" : { @@ -27119,6 +27234,7 @@ } }, "No saved connection named \"%@\"." : { + "extractionState" : "stale", "localizations" : { "tr" : { "stringUnit" : { @@ -27139,6 +27255,9 @@ } } } + }, + "No saved connection with ID \"%@\"." : { + }, "No saved connections" : { @@ -29209,6 +29328,9 @@ } } } + }, + "Pairing Failed" : { + }, "Panel State" : { "localizations" : { @@ -32119,6 +32241,12 @@ } } } + }, + "Range" : { + + }, + "Rate limited" : { + }, "Rate limited. Please try again later." : { "localizations" : { @@ -32280,6 +32408,12 @@ } } } + }, + "Read schema and run any non-destructive query, including INSERT, UPDATE, and DELETE." : { + + }, + "Read schema and run SELECT queries." : { + }, "Read-only" : { "extractionState" : "stale", @@ -32370,6 +32504,9 @@ } } } + }, + "Read-Write" : { + }, "Reading connections..." : { "localizations" : { @@ -33674,6 +33811,9 @@ } } } + }, + "Resource" : { + }, "Restart TablePro for the language change to take full effect." : { "localizations" : { @@ -37710,6 +37850,9 @@ } } } + }, + "SQL dialect for %@ is not available. The plugin may not be installed or loaded." : { + }, "SQL Editor" : { "localizations" : { @@ -37799,6 +37942,16 @@ } } }, + "SQL is too long: %d characters (limit %d)" : { + "localizations" : { + "en" : { + "stringUnit" : { + "state" : "new", + "value" : "SQL is too long: %1$d characters (limit %2$d)" + } + } + } + }, "SQL Preview" : { "localizations" : { "tr" : { @@ -42195,12 +42348,28 @@ }, "Token" : { + }, + "Token '%@' with permission '%@' cannot access '%@'" : { + "localizations" : { + "en" : { + "stringUnit" : { + "state" : "new", + "value" : "Token '%1$@' with permission '%2$@' cannot access '%3$@'" + } + } + } + }, + "Token does not have access to this connection" : { + }, "Token Name" : { }, "Too many submissions. Please try again later." : { + }, + "Tool" : { + }, "Toolbar" : { "localizations" : { @@ -43066,6 +43235,9 @@ } } } + }, + "Unknown deep link host: %@" : { + }, "Unknown error" : { "localizations" : { @@ -43110,6 +43282,9 @@ } } } + }, + "Unknown URL scheme: %@" : { + }, "Unlicensed" : { "localizations" : { @@ -43243,6 +43418,9 @@ } } } + }, + "Unsupported database type: %@" : { + }, "Unsupported encryption version %d" : { "localizations" : { @@ -43310,6 +43488,9 @@ } } } + }, + "Unsupported intent: %@" : { + }, "Unsupported MongoDB method: %@" : { "extractionState" : "stale", diff --git a/TablePro/TableProApp.swift b/TablePro/TableProApp.swift index b5d18fdf0..753d4d0d8 100644 --- a/TablePro/TableProApp.swift +++ b/TablePro/TableProApp.swift @@ -196,7 +196,7 @@ struct AppMenuCommands: Commands { // File menu CommandGroup(replacing: .newItem) { Button("Manage Connections") { - NotificationCenter.default.post(name: .newConnection, object: nil) + WelcomeWindowFactory.openOrFront() } .optionalKeyboardShortcut(shortcut(for: .manageConnections)) } @@ -384,7 +384,7 @@ struct AppMenuCommands: Commands { .disabled(!(actions?.isConnected ?? false)) Button("Switch Connection...") { - NotificationCenter.default.post(name: .openConnectionSwitcher, object: nil) + actions?.openConnectionSwitcher() } .optionalKeyboardShortcut(shortcut(for: .switchConnection)) .disabled(!(actions?.isConnected ?? false)) @@ -626,33 +626,15 @@ struct TableProApp: App { } var body: some Scene { - // Welcome Window - opens on launch (must be first Window scene so SwiftUI - // restores it by default when clicking the dock icon) - Window("Welcome to TablePro", id: "welcome") { - WelcomeWindowView() - .background(OpenWindowHandler()) // Handle window notifications from startup - } - .windowStyle(.hiddenTitleBar) - .windowResizability(.contentSize) - .defaultSize(width: 700, height: 450) - - // Connection Form Window - opens when creating/editing a connection - WindowGroup(id: "connection-form", for: UUID?.self) { $connectionId in - ConnectionFormView(connectionId: connectionId ?? nil) - } - .windowResizability(.contentSize) - - // NOTE (prototype): main windows are now created imperatively via - // MainWindowFactory → NSWindow + NSHostingController. The retired - // `WindowGroup(id:"main", for: EditorTabPayload.self)` caused SwiftUI to - // re-instantiate ContentView for every historical payload on every scene - // phase diff (5-7 phantom inits per open). AppKit-native windows avoid - // that and eliminate the 68-437ms openWindow() latency. - - // Settings Window - opens with Cmd+, + // All app windows are created imperatively via NSWindow + NSHostingController + // factories (MainWindow via WindowManager, Welcome via WelcomeWindowFactory, + // ConnectionForm via ConnectionFormWindowFactory). Declaring them as SwiftUI + // Scenes auto-opens the first Scene on launch and races with cold-launch + // intent routing. Settings { SettingsView() .environment(updaterBridge) + .background(SettingsNotificationBridge()) } .commands { @@ -668,9 +650,6 @@ struct TableProApp: App { // MARK: - Notification Names extension Notification.Name { - // Connection lifecycle - static let newConnection = Notification.Name("newConnection") - static let openConnectionSwitcher = Notification.Name("openConnectionSwitcher") // Multi-listener broadcasts (Sidebar + Coordinator + StructureView) static let refreshData = Notification.Name("refreshData") @@ -687,12 +666,6 @@ extension Notification.Name { // Window lifecycle notifications static let mainWindowWillClose = Notification.Name("mainWindowWillClose") - static let openMainWindow = Notification.Name("openMainWindow") - static let openWelcomeWindow = Notification.Name("openWelcomeWindow") - - // Database URL handling notifications - static let switchSchemaFromURL = Notification.Name("switchSchemaFromURL") - static let applyURLFilter = Notification.Name("applyURLFilter") } // MARK: - Check for Updates @@ -738,32 +711,18 @@ private struct MCPServerMenuItem: View { } } -// MARK: - Open Window Handler +// MARK: - Settings Notification Bridge -/// Helper view that listens for window open notifications -private struct OpenWindowHandler: View { - @Environment(\.openWindow) - private var openWindow +/// Forwards `.openSettingsWindow` notifications to SwiftUI's `openSettings` +/// action. Lives inside the Settings scene because `\.openSettings` is only +/// available there. +private struct SettingsNotificationBridge: View { @Environment(\.openSettings) private var openSettings var body: some View { Color.clear .frame(width: 0, height: 0) - .onAppear { - // Store openWindow action for imperative access (e.g., from MainContentCommandActions) - WindowOpener.shared.openWindow = openWindow - } - .onReceive(NotificationCenter.default.publisher(for: .openWelcomeWindow)) { _ in - openWindow(id: "welcome") - } - .onReceive(NotificationCenter.default.publisher(for: .openMainWindow)) { notification in - if let payload = notification.object as? EditorTabPayload { - WindowManager.shared.openTab(payload: payload) - } else if let connectionId = notification.object as? UUID { - WindowManager.shared.openTab(payload: EditorTabPayload(connectionId: connectionId)) - } - } .onReceive(NotificationCenter.default.publisher(for: .openSettingsWindow)) { _ in openSettings() } diff --git a/TablePro/ViewModels/SidebarViewModel.swift b/TablePro/ViewModels/SidebarViewModel.swift index 4d3e1aee8..268c57d78 100644 --- a/TablePro/ViewModels/SidebarViewModel.swift +++ b/TablePro/ViewModels/SidebarViewModel.swift @@ -41,7 +41,6 @@ final class SidebarViewModel { // MARK: - Binding Storage - private var tablesBinding: Binding<[TableInfo]> private var selectedTablesBinding: Binding> private var pendingTruncatesBinding: Binding> private var pendingDeletesBinding: Binding> @@ -54,11 +53,6 @@ final class SidebarViewModel { // MARK: - Convenience Accessors - var tables: [TableInfo] { - get { tablesBinding.wrappedValue } - set { tablesBinding.wrappedValue = newValue } - } - var selectedTables: Set { get { selectedTablesBinding.wrappedValue } set { selectedTablesBinding.wrappedValue = newValue } @@ -82,7 +76,6 @@ final class SidebarViewModel { // MARK: - Initialization init( - tables: Binding<[TableInfo]>, selectedTables: Binding>, pendingTruncates: Binding>, pendingDeletes: Binding>, @@ -90,7 +83,6 @@ final class SidebarViewModel { databaseType: DatabaseType, connectionId: UUID ) { - self.tablesBinding = tables self.selectedTablesBinding = selectedTables self.pendingTruncatesBinding = pendingTruncates self.pendingDeletesBinding = pendingDeletes diff --git a/TablePro/ViewModels/WelcomeViewModel.swift b/TablePro/ViewModels/WelcomeViewModel.swift index 01d391815..91a25c345 100644 --- a/TablePro/ViewModels/WelcomeViewModel.swift +++ b/TablePro/ViewModels/WelcomeViewModel.swift @@ -33,7 +33,6 @@ final class WelcomeViewModel { private let storage = ConnectionStorage.shared private let groupStorage = GroupStorage.shared - private let dbManager = DatabaseManager.shared // MARK: - State @@ -78,14 +77,12 @@ final class WelcomeViewModel { // MARK: - Notification Observers - @ObservationIgnored private var openWindow: OpenWindowAction? @ObservationIgnored private var connectionUpdatedObserver: NSObjectProtocol? - @ObservationIgnored private var shareFileObserver: NSObjectProtocol? @ObservationIgnored private var exportObserver: NSObjectProtocol? @ObservationIgnored private var importObserver: NSObjectProtocol? @ObservationIgnored private var linkedFoldersObserver: NSObjectProtocol? @ObservationIgnored private var importFromAppObserver: NSObjectProtocol? - @ObservationIgnored private var deeplinkImportObserver: NSObjectProtocol? + @ObservationIgnored private var welcomeRouterTask: Task? // MARK: - Computed Properties @@ -146,8 +143,7 @@ final class WelcomeViewModel { // MARK: - Setup & Teardown - func setUp(openWindow: OpenWindowAction) { - self.openWindow = openWindow + func setUp() { guard connectionUpdatedObserver == nil else { return } if expandedGroupIds.isEmpty { @@ -168,16 +164,6 @@ final class WelcomeViewModel { } } - shareFileObserver = NotificationCenter.default.addObserver( - forName: .connectionShareFileOpened, object: nil, queue: .main - ) { [weak self] notification in - Task { @MainActor [weak self] in - guard let url = notification.object as? URL else { return } - _ = PendingActionStore.shared.consumeConnectionShareURL() - self?.activeSheet = .importFile(url) - } - } - exportObserver = NotificationCenter.default.addObserver( forName: .exportConnections, object: nil, queue: .main ) { [weak self] _ in @@ -211,35 +197,74 @@ final class WelcomeViewModel { } } - deeplinkImportObserver = NotificationCenter.default.addObserver( - forName: .deeplinkImportRequested, object: nil, queue: .main - ) { [weak self] notification in - Task { @MainActor [weak self] in - guard let self else { return } - let exportable = (notification.object as? ExportableConnection) - ?? PendingActionStore.shared.consumeDeeplinkImport() - guard let exportable else { return } - PendingActionStore.shared.deeplinkImport = nil - self.activeSheet = .deeplinkImport(exportable) - } - } - loadConnections() linkedConnections = LinkedFolderWatcher.shared.linkedConnections - if let pendingURL = PendingActionStore.shared.consumeConnectionShareURL() { + consumePendingRouterActions() + startWelcomeRouterObservation() + } + + private func consumePendingRouterActions() { + if let pendingURL = WelcomeRouter.shared.consumePendingShare() { activeSheet = .importFile(pendingURL) + return } - - if let pendingImport = PendingActionStore.shared.consumeDeeplinkImport() { + if let pendingImport = WelcomeRouter.shared.consumePendingImport() { activeSheet = .deeplinkImport(pendingImport) } } + private func startWelcomeRouterObservation() { + welcomeRouterTask?.cancel() + welcomeRouterTask = Task { @MainActor [weak self] in + while !Task.isCancelled { + let didChange = await Self.awaitWelcomeRouterChange() + guard didChange else { return } + self?.consumePendingRouterActions() + } + } + } + + private static func awaitWelcomeRouterChange() async -> Bool { + let box = ContinuationBox() + return await withTaskCancellationHandler { + await withCheckedContinuation { continuation in + box.set(continuation) + withObservationTracking({ + _ = WelcomeRouter.shared.pendingImport + _ = WelcomeRouter.shared.pendingConnectionShare + }, onChange: { + box.resume(with: true) + }) + } + } onCancel: { + box.resume(with: false) + } + } + + private final class ContinuationBox: @unchecked Sendable { + private var continuation: CheckedContinuation? + private let lock = NSLock() + + func set(_ continuation: CheckedContinuation) { + lock.lock() + defer { lock.unlock() } + self.continuation = continuation + } + + func resume(with value: Bool) { + lock.lock() + let pending = continuation + continuation = nil + lock.unlock() + pending?.resume(returning: value) + } + } + deinit { - [connectionUpdatedObserver, shareFileObserver, exportObserver, - importObserver, importFromAppObserver, linkedFoldersObserver, - deeplinkImportObserver].forEach { + welcomeRouterTask?.cancel() + [connectionUpdatedObserver, exportObserver, importObserver, + importFromAppObserver, linkedFoldersObserver].forEach { if let observer = $0 { NotificationCenter.default.removeObserver(observer) } @@ -261,26 +286,13 @@ final class WelcomeViewModel { // MARK: - Connection Actions func connectToDatabase(_ connection: DatabaseConnection) { - guard let openWindow else { return } - if WindowOpener.shared.openWindow == nil { - WindowOpener.shared.openWindow = openWindow - } - // Close welcome BEFORE opening the new editor window. Otherwise the - // welcome window (still key + visible) reasserts itself during the - // new window's `makeKeyAndOrderFront` — the new window briefly - // becomes key, immediately resigns, welcome retakes key, and the - // app is left with no key window after welcome closes → menu - // @FocusedValue nil → Cmd+T/1...9 disabled. - NSApplication.shared.closeWindows(withId: "welcome") - WindowManager.shared.openTab(payload: EditorTabPayload(connectionId: connection.id, intent: .restoreOrDefault)) - + WelcomeWindowFactory.close() Task { do { - try await dbManager.connectToSession(connection) + try await TabRouter.shared.route(.openConnection(connection.id)) } catch is CancellationError { - // User cancelled password prompt — return to welcome closeConnectionWindows(for: connection.id) - self.openWindow?(id: "welcome") + WelcomeWindowFactory.openOrFront() } catch { if case PluginError.pluginNotInstalled = error { Self.logger.info("Plugin not installed for \(connection.type.rawValue), prompting install") @@ -295,21 +307,13 @@ final class WelcomeViewModel { } func connectAfterInstall(_ connection: DatabaseConnection) { - guard let openWindow else { return } - if WindowOpener.shared.openWindow == nil { - WindowOpener.shared.openWindow = openWindow - } - // Close welcome before opening editor — see connectToDatabase above - // for the welcome-reasserts-key race that disabled menu shortcuts. - NSApplication.shared.closeWindows(withId: "welcome") - WindowManager.shared.openTab(payload: EditorTabPayload(connectionId: connection.id, intent: .restoreOrDefault)) - + WelcomeWindowFactory.close() Task { do { - try await dbManager.connectToSession(connection) + try await TabRouter.shared.route(.openConnection(connection.id)) } catch is CancellationError { closeConnectionWindows(for: connection.id) - self.openWindow?(id: "welcome") + WelcomeWindowFactory.openOrFront() } catch { Self.logger.error( "Failed to connect after plugin install: \(error.localizedDescription, privacy: .public)") @@ -340,8 +344,7 @@ final class WelcomeViewModel { func duplicateConnection(_ connection: DatabaseConnection) { let duplicate = storage.duplicateConnection(connection) loadConnections() - openWindow?(id: "connection-form", value: duplicate.id as UUID?) - focusConnectionFormWindow() + ConnectionFormWindowFactory.openOrFront(connectionId: duplicate.id) } // MARK: - Delete @@ -589,17 +592,15 @@ final class WelcomeViewModel { // MARK: - Private Helpers private func handleConnectionFailure(error: Error, connectionId: UUID) { - guard let openWindow else { return } closeConnectionWindows(for: connectionId) connectionError = error.localizedDescription showConnectionError = true - openWindow(id: "welcome") + WelcomeWindowFactory.openOrFront() } private func handleMissingPlugin(connection: DatabaseConnection) { - guard let openWindow else { return } closeConnectionWindows(for: connection.id) - openWindow(id: "welcome") + WelcomeWindowFactory.openOrFront() pluginInstallConnection = connection } diff --git a/TablePro/Views/Connection/ConnectionFormView+Footer.swift b/TablePro/Views/Connection/ConnectionFormView+Footer.swift index f21e90e5e..bc8ce94af 100644 --- a/TablePro/Views/Connection/ConnectionFormView+Footer.swift +++ b/TablePro/Views/Connection/ConnectionFormView+Footer.swift @@ -52,7 +52,7 @@ extension ConnectionFormView { // Cancel Button("Cancel") { - NSApplication.shared.closeWindows(withId: "connection-form") + ConnectionFormWindowFactory.closeAll() } if isNew { @@ -74,7 +74,7 @@ extension ConnectionFormView { } .background(Color(nsColor: .windowBackgroundColor)) .onExitCommand { - NSApplication.shared.closeWindows(withId: "connection-form") + ConnectionFormWindowFactory.closeAll() } .onChange(of: host) { _, _ in testSucceeded = false } .onChange(of: port) { _, _ in testSucceeded = false } diff --git a/TablePro/Views/Connection/ConnectionFormView+Helpers.swift b/TablePro/Views/Connection/ConnectionFormView+Helpers.swift index 96ff405e3..40a3054e0 100644 --- a/TablePro/Views/Connection/ConnectionFormView+Helpers.swift +++ b/TablePro/Views/Connection/ConnectionFormView+Helpers.swift @@ -279,7 +279,7 @@ extension ConnectionFormView { if !connectionToSave.localOnly { SyncChangeTracker.shared.markDirty(.connection, id: connectionToSave.id.uuidString) } - NSApplication.shared.closeWindows(withId: "connection-form") + ConnectionFormWindowFactory.closeAll() NotificationCenter.default.post(name: .connectionUpdated, object: nil) if connect { connectToDatabase(connectionToSave) @@ -292,7 +292,7 @@ extension ConnectionFormView { SyncChangeTracker.shared.markDirty(.connection, id: connectionToSave.id.uuidString) } } - NSApplication.shared.closeWindows(withId: "connection-form") + ConnectionFormWindowFactory.closeAll() NotificationCenter.default.post(name: .connectionUpdated, object: nil) } } @@ -301,23 +301,15 @@ extension ConnectionFormView { guard let id = connectionId, let connection = storage.loadConnections().first(where: { $0.id == id }) else { return } storage.deleteConnection(connection) - NSApplication.shared.closeWindows(withId: "connection-form") + ConnectionFormWindowFactory.closeAll() NotificationCenter.default.post(name: .connectionUpdated, object: nil) } func connectToDatabase(_ connection: DatabaseConnection) { - if WindowOpener.shared.openWindow == nil { - WindowOpener.shared.openWindow = openWindow - } - // Close welcome BEFORE opening the editor window so it can't reassert - // key status during the new window's `makeKeyAndOrderFront`. See - // WelcomeViewModel.connectToDatabase for the diagnosed race. - NSApplication.shared.closeWindows(withId: "welcome") - WindowManager.shared.openTab(payload: EditorTabPayload(connectionId: connection.id, intent: .restoreOrDefault)) - + WelcomeWindowFactory.close() Task { do { - try await dbManager.connectToSession(connection) + try await TabRouter.shared.route(.openConnection(connection.id)) } catch { handleConnectError(error, connection: connection) } @@ -330,7 +322,7 @@ extension ConnectionFormView { return } closeConnectionWindows(for: connection.id) - openWindow(id: "welcome") + WelcomeWindowFactory.openOrFront() guard !(error is CancellationError) else { return } Self.logger.error("Failed to connect: \(error.localizedDescription, privacy: .public)") AlertHelper.showErrorSheet( @@ -341,7 +333,7 @@ extension ConnectionFormView { func handleMissingPlugin(connection: DatabaseConnection) { closeConnectionWindows(for: connection.id) - openWindow(id: "welcome") + WelcomeWindowFactory.openOrFront() pluginInstallConnection = connection } @@ -352,17 +344,10 @@ extension ConnectionFormView { } func connectAfterInstall(_ connection: DatabaseConnection) { - if WindowOpener.shared.openWindow == nil { - WindowOpener.shared.openWindow = openWindow - } - // Close welcome before opening editor — see connectToDatabase above - // for the welcome-reasserts-key race that disabled menu shortcuts. - NSApplication.shared.closeWindows(withId: "welcome") - WindowManager.shared.openTab(payload: EditorTabPayload(connectionId: connection.id, intent: .restoreOrDefault)) - + WelcomeWindowFactory.close() Task { do { - try await dbManager.connectToSession(connection) + try await TabRouter.shared.route(.openConnection(connection.id)) } catch { handleConnectError(error, connection: connection) } diff --git a/TablePro/Views/Connection/ConnectionFormView.swift b/TablePro/Views/Connection/ConnectionFormView.swift index a3d6140e3..75af777c4 100644 --- a/TablePro/Views/Connection/ConnectionFormView.swift +++ b/TablePro/Views/Connection/ConnectionFormView.swift @@ -12,13 +12,11 @@ import UniformTypeIdentifiers struct ConnectionFormView: View { static let logger = Logger(subsystem: "com.TablePro", category: "ConnectionFormView") - @Environment(\.openWindow) var openWindow // Connection ID: nil = new connection, UUID = edit existing let connectionId: UUID? let storage = ConnectionStorage.shared - let dbManager = DatabaseManager.shared var isNew: Bool { connectionId == nil } diff --git a/TablePro/Views/Connection/WelcomeContextMenus.swift b/TablePro/Views/Connection/WelcomeContextMenus.swift index 7ee94bced..d826d48d2 100644 --- a/TablePro/Views/Connection/WelcomeContextMenus.swift +++ b/TablePro/Views/Connection/WelcomeContextMenus.swift @@ -91,7 +91,7 @@ extension WelcomeWindowView { Divider() Button { - openWindow(id: "connection-form", value: connection.id as UUID?) + ConnectionFormWindowFactory.openOrFront(connectionId: connection.id) vm.focusConnectionFormWindow() } label: { Label(String(localized: "Edit"), systemImage: "pencil") @@ -228,7 +228,7 @@ extension WelcomeWindowView { @ViewBuilder var newConnectionContextMenu: some View { - Button(action: { openWindow(id: "connection-form") }) { + Button(action: { ConnectionFormWindowFactory.openOrFront() }) { Label("New Connection...", systemImage: "plus") } diff --git a/TablePro/Views/Connection/WelcomeWindowView.swift b/TablePro/Views/Connection/WelcomeWindowView.swift index 9c1ee3560..e52349874 100644 --- a/TablePro/Views/Connection/WelcomeWindowView.swift +++ b/TablePro/Views/Connection/WelcomeWindowView.swift @@ -16,7 +16,6 @@ struct WelcomeWindowView: View { @State var vm = WelcomeViewModel() @FocusState private var focus: FocusField? - @Environment(\.openWindow) var openWindow var body: some View { ZStack { @@ -36,7 +35,7 @@ struct WelcomeWindowView: View { .ignoresSafeArea() .frame(minWidth: 600, idealWidth: 700, minHeight: 400, idealHeight: 450) .onAppear { - vm.setUp(openWindow: openWindow) + vm.setUp() focus = .search } .alert( @@ -171,7 +170,7 @@ struct WelcomeWindowView: View { HStack(spacing: 0) { WelcomeLeftPanel( onActivateLicense: { vm.activeSheet = .activation }, - onCreateConnection: { openWindow(id: "connection-form") } + onCreateConnection: { ConnectionFormWindowFactory.openOrFront() } ) Divider() rightPanel @@ -184,7 +183,7 @@ struct WelcomeWindowView: View { private var rightPanel: some View { VStack(spacing: 0) { HStack(spacing: 8) { - Button(action: { openWindow(id: "connection-form") }) { + Button(action: { ConnectionFormWindowFactory.openOrFront() }) { Image(systemName: "plus") .font(.callout.weight(.medium)) .foregroundStyle(.secondary) @@ -418,7 +417,7 @@ struct WelcomeWindowView: View { .font(.callout) .foregroundStyle(.tertiary) - Button(action: { openWindow(id: "connection-form") }) { + Button(action: { ConnectionFormWindowFactory.openOrFront() }) { Label("New Connection", systemImage: "plus") } .controlSize(.large) diff --git a/TablePro/Views/Main/Child/MainEditorContentView.swift b/TablePro/Views/Main/Child/MainEditorContentView.swift index a1c8a2214..68c63347e 100644 --- a/TablePro/Views/Main/Child/MainEditorContentView.swift +++ b/TablePro/Views/Main/Child/MainEditorContentView.swift @@ -263,7 +263,7 @@ struct MainEditorContentView: View { parameters: parameterBinding(for: tab), isParameterPanelVisible: parameterVisibilityBinding(for: tab), onExecute: { coordinator.runQuery() }, - schemaProvider: coordinator.schemaProvider, + schemaProvider: SchemaProviderRegistry.shared.getOrCreate(for: coordinator.connection.id), databaseType: coordinator.connection.type, connectionId: coordinator.connection.id, connectionAIPolicy: coordinator.connection.aiPolicy ?? AppSettingsManager.shared.ai.defaultConnectionPolicy, diff --git a/TablePro/Views/Main/Extensions/MainContentCoordinator+Navigation.swift b/TablePro/Views/Main/Extensions/MainContentCoordinator+Navigation.swift index 97aa94f0e..d085311ae 100644 --- a/TablePro/Views/Main/Extensions/MainContentCoordinator+Navigation.swift +++ b/TablePro/Views/Main/Extensions/MainContentCoordinator+Navigation.swift @@ -49,7 +49,7 @@ extension MainContentCoordinator { // During database switch, update the existing tab in-place instead of // opening a new native window tab. - if sidebarLoadingState == .loading { + if case .loading = SchemaService.shared.state(for: connectionId) { if tabManager.tabs.isEmpty { do { try tabManager.addTableTab( @@ -65,14 +65,18 @@ extension MainContentCoordinator { } // Check if another native window tab already has this table open — switch to it - if let keyWindow = NSApp.keyWindow { - let ownWindows = Set(WindowLifecycleMonitor.shared.windows(for: connectionId).map { ObjectIdentifier($0) }) - let tabbedWindows = keyWindow.tabbedWindows ?? [keyWindow] - for window in tabbedWindows - where window.title == tableName && ownWindows.contains(ObjectIdentifier(window)) { - window.makeKeyAndOrderFront(nil) - return + for sibling in MainContentCoordinator.allActiveCoordinators() + where sibling !== self && sibling.connectionId == connectionId { + let hasMatch = sibling.tabManager.tabs.contains { tab in + tab.tabType == .table + && tab.tableContext.tableName == tableName + && tab.tableContext.databaseName == currentDatabase } + guard hasMatch, + let windowId = sibling.windowId, + let window = WindowLifecycleMonitor.shared.window(for: windowId) else { continue } + window.makeKeyAndOrderFront(nil) + return } // If no tabs exist (empty state), add a table tab directly. @@ -401,7 +405,6 @@ extension MainContentCoordinator { /// Switch to a different database (called from database switcher) func switchDatabase(to database: String) async { - sidebarLoadingState = .loading filterStateManager.clearAll() let previousDatabase = toolbarState.databaseName toolbarState.databaseName = database @@ -414,14 +417,11 @@ extension MainContentCoordinator { tableRowsStore.tearDown() tabManager.tabs = [] tabManager.selectedTabId = nil - DatabaseManager.shared.updateSession(connectionId) { session in - session.tables = [] - } + await SchemaService.shared.invalidate(connectionId: connectionId) await refreshTables() } catch { toolbarState.databaseName = previousDatabase - sidebarLoadingState = .error(error.localizedDescription) navigationLogger.error("Failed to switch database: \(error.localizedDescription, privacy: .public)") AlertHelper.showErrorSheet( @@ -436,7 +436,6 @@ extension MainContentCoordinator { func switchSchema(to schema: String) async { guard PluginManager.shared.supportsSchemaSwitching(for: connection.type) else { return } - sidebarLoadingState = .loading filterStateManager.clearAll() let previousSchema = toolbarState.databaseName toolbarState.databaseName = schema @@ -449,9 +448,7 @@ extension MainContentCoordinator { tableRowsStore.tearDown() tabManager.tabs = [] tabManager.selectedTabId = nil - DatabaseManager.shared.updateSession(connectionId) { session in - session.tables = [] - } + await SchemaService.shared.invalidate(connectionId: connectionId) await refreshTables() } catch { diff --git a/TablePro/Views/Main/Extensions/MainContentCoordinator+URLFilter.swift b/TablePro/Views/Main/Extensions/MainContentCoordinator+URLFilter.swift index dcbb54f4a..bc2f35b55 100644 --- a/TablePro/Views/Main/Extensions/MainContentCoordinator+URLFilter.swift +++ b/TablePro/Views/Main/Extensions/MainContentCoordinator+URLFilter.swift @@ -6,57 +6,7 @@ import Foundation extension MainContentCoordinator { - func setupURLNotificationObservers() -> [NSObjectProtocol] { - let connId = connectionId - let observer1 = NotificationCenter.default.addObserver( - forName: .applyURLFilter, - object: nil, - queue: .main - ) { [weak self] notification in - guard let userInfo = notification.userInfo, - let targetId = userInfo["connectionId"] as? UUID, - targetId == connId else { return } - - let condition = userInfo["condition"] as? String - let column = userInfo["column"] as? String - let operation = userInfo["operation"] as? String - let value = userInfo["value"] as? String - Task { [weak self] in - self?.applyURLFilterValues( - condition: condition, column: column, - operation: operation, value: value - ) - } - } - - let observer2 = NotificationCenter.default.addObserver( - forName: .switchSchemaFromURL, - object: nil, - queue: .main - ) { [weak self] notification in - guard let userInfo = notification.userInfo, - let targetId = userInfo["connectionId"] as? UUID, - targetId == connId, - let schema = userInfo["schema"] as? String else { return } - - Task { [weak self] in - guard let self else { return } - - if PluginManager.shared.supportsSchemaSwitching(for: self.connection.type) { - await self.switchSchema(to: schema) - } else { - await self.switchDatabase(to: schema) - } - } - } - - return [observer1, observer2] - } - - private func applyURLFilterValues( - condition: String?, column: String?, - operation: String?, value: String? - ) { + func applyURLFilter(condition: String?, column: String?, operation: String?, value: String?) { if let condition, !condition.isEmpty { let filter = TableFilter( id: UUID(), @@ -74,7 +24,6 @@ extension MainContentCoordinator { guard let column, !column.isEmpty else { return } let filterOp = mapTablePlusOperation(operation ?? "Equal") - let filter = TableFilter( id: UUID(), columnName: column, diff --git a/TablePro/Views/Main/Extensions/MainContentCoordinator+WindowLifecycle.swift b/TablePro/Views/Main/Extensions/MainContentCoordinator+WindowLifecycle.swift index 07f0244ab..cb11e10ab 100644 --- a/TablePro/Views/Main/Extensions/MainContentCoordinator+WindowLifecycle.swift +++ b/TablePro/Views/Main/Extensions/MainContentCoordinator+WindowLifecycle.swift @@ -102,19 +102,11 @@ extension MainContentCoordinator { "[close] coordinator.handleWindowWillClose connId=\(self.connectionId, privacy: .public) tabs=\(self.tabManager.tabs.count)" ) - // Persist remaining non-preview tabs synchronously. saveNowSync writes - // directly without spawning a Task — required here because the window - // is closing and we cannot rely on async tasks being serviced. - let persistableTabs = tabManager.tabs.filter { !$0.isPreview } - if persistableTabs.isEmpty { - // Empty → clear saved state so next open shows a default empty window. - persistence.saveNowSync(tabs: [], selectedTabId: nil) - } else { - let normalizedSelectedId = - persistableTabs.contains(where: { $0.id == tabManager.selectedTabId }) - ? tabManager.selectedTabId : persistableTabs.first?.id - persistence.saveNowSync(tabs: persistableTabs, selectedTabId: normalizedSelectedId) - } + // Persist tabs aggregated across all windows for this connection. + // Writing this window's tabs in isolation can clobber sibling windows' + // state on disk — for example, closing an empty window would erase the + // saved tabs of an open sibling window. + persistence.saveOrClearAggregatedSync() // Cancel the pending eviction task before teardown drops it. evictionTask?.cancel() diff --git a/TablePro/Views/Main/Extensions/MainContentView+EventHandlers.swift b/TablePro/Views/Main/Extensions/MainContentView+EventHandlers.swift index 0ce5ace93..986262d1b 100644 --- a/TablePro/Views/Main/Extensions/MainContentView+EventHandlers.swift +++ b/TablePro/Views/Main/Extensions/MainContentView+EventHandlers.swift @@ -58,20 +58,9 @@ extension MainContentView { coordinator.promotePreviewTab() } - let persistableTabs = tabManager.tabs.filter { !$0.isPreview } - if persistableTabs.isEmpty { - coordinator.persistence.clearSavedState() - } else { - let normalizedSelectedId = - persistableTabs.contains(where: { $0.id == tabManager.selectedTabId }) - ? tabManager.selectedTabId : persistableTabs.first?.id - coordinator.persistence.saveNow( - tabs: persistableTabs, - selectedTabId: normalizedSelectedId - ) - } + coordinator.persistence.saveOrClearAggregated() MainContentView.lifecycleLogger.debug( - "[switch] handleStructureChange tabCount=\(tabManager.tabs.count) persistableCount=\(persistableTabs.count) ms=\(Int(Date().timeIntervalSince(t0) * 1_000))" + "[switch] handleStructureChange tabCount=\(tabManager.tabs.count) ms=\(Int(Date().timeIntervalSince(t0) * 1_000))" ) } diff --git a/TablePro/Views/Main/Extensions/MainContentView+Modifiers.swift b/TablePro/Views/Main/Extensions/MainContentView+Modifiers.swift index 6dc477e02..1da3d0631 100644 --- a/TablePro/Views/Main/Extensions/MainContentView+Modifiers.swift +++ b/TablePro/Views/Main/Extensions/MainContentView+Modifiers.swift @@ -54,7 +54,6 @@ struct FocusedCommandActionsModifier: ViewModifier { connection: DatabaseConnection.preview, payload: nil, windowTitle: .constant("SQL Query"), - tables: .constant([]), sidebarState: SharedSidebarState(), pendingTruncates: .constant([]), pendingDeletes: .constant([]), diff --git a/TablePro/Views/Main/Extensions/MainContentView+Setup.swift b/TablePro/Views/Main/Extensions/MainContentView+Setup.swift index a884c2c02..5eebb9fdc 100644 --- a/TablePro/Views/Main/Extensions/MainContentView+Setup.swift +++ b/TablePro/Views/Main/Extensions/MainContentView+Setup.swift @@ -95,11 +95,6 @@ extension MainContentView { private func handleRestoreOrDefault() async { if WindowLifecycleMonitor.shared.hasOtherWindows(for: connection.id, excluding: windowId) { - if tabManager.tabs.isEmpty { - let allTabs = MainContentCoordinator.allTabs(for: connection.id) - let title = QueryTabManager.nextQueryTitle(existingTabs: allTabs) - tabManager.addTab(title: title, databaseName: connection.database) - } MainContentView.lifecycleLogger.info( "[open] handleRestoreOrDefault short-circuit (other windows exist) windowId=\(windowId, privacy: .public)" ) @@ -111,7 +106,8 @@ extension MainContentView { MainContentView.lifecycleLogger.info( "[open] restoreFromDisk done windowId=\(windowId, privacy: .public) tabsRestored=\(result.tabs.count) source=\(String(describing: result.source), privacy: .public) elapsedMs=\(Int(Date().timeIntervalSince(restoreStart) * 1_000))" ) - if !result.tabs.isEmpty { + guard !result.tabs.isEmpty else { return } + do { var restoredTabs = result.tabs for i in restoredTabs.indices where restoredTabs[i].tabType == .table { if let tableName = restoredTabs[i].tableContext.tableName { @@ -141,17 +137,13 @@ extension MainContentView { if !remainingTabs.isEmpty { let selectedWasFirst = firstTab.id == selectedId - Task { @MainActor in - for tab in remainingTabs { - let restorePayload = EditorTabPayload( - from: tab, connectionId: connection.id, skipAutoExecute: true) - WindowManager.shared.openTab(payload: restorePayload) - } - // Bring the first window to front only if it had the selected tab. - // Otherwise let the last restored window stay focused. - if selectedWasFirst { - viewWindow?.makeKeyAndOrderFront(nil) - } + for tab in remainingTabs { + let restorePayload = EditorTabPayload( + from: tab, connectionId: connection.id, skipAutoExecute: true) + WindowManager.shared.openTab(payload: restorePayload) + } + if selectedWasFirst { + viewWindow?.makeKeyAndOrderFront(nil) } } diff --git a/TablePro/Views/Main/MainContentCommandActions.swift b/TablePro/Views/Main/MainContentCommandActions.swift index 694bbf9d6..f22f2bcff 100644 --- a/TablePro/Views/Main/MainContentCommandActions.swift +++ b/TablePro/Views/Main/MainContentCommandActions.swift @@ -316,15 +316,8 @@ final class MainContentCommandActions { // MARK: - Tab Operations (Group A — Called Directly) func newTab(initialQuery: String? = nil) { - // If no tabs exist (empty state), add directly to this window - if coordinator?.tabManager.tabs.isEmpty == true { - coordinator?.tabManager.addTab(initialQuery: initialQuery, databaseName: connection.database) - return - } - // Open a new native macOS window tab with a query editor let payload = EditorTabPayload( connectionId: connection.id, - tabType: .query, initialQuery: initialQuery, intent: .newEmptyTab ) @@ -487,11 +480,11 @@ final class MainContentCommandActions { // MARK: - Tab Navigation (Group A — Called Directly) func selectTab(number: Int) { - // Switch to the nth native window tab guard let keyWindow = NSApp.keyWindow, let tabbedWindows = keyWindow.tabbedWindows, - number > 0, number <= tabbedWindows.count else { return } - tabbedWindows[number - 1].makeKeyAndOrderFront(nil) + tabbedWindows.indices.contains(number - 1) else { return } + let target = tabbedWindows[number - 1] + target.makeKeyAndOrderFront(nil) } // MARK: - Filter Operations (Group A — Called Directly) @@ -704,6 +697,10 @@ final class MainContentCommandActions { coordinator?.activeSheet = .quickSwitcher } + func openConnectionSwitcher() { + coordinator?.toolbarState.showConnectionSwitcher = true + } + // MARK: - Undo/Redo (Group A — Called Directly) func undoChange() { @@ -762,9 +759,11 @@ final class MainContentCommandActions { if let driver = DatabaseManager.shared.driver(for: self.connection.id) { coordinator?.toolbarState.databaseVersion = driver.serverVersion } - if coordinator?.sidebarLoadingState != .loading { - await coordinator?.refreshTables() + if case .loading = SchemaService.shared.state(for: self.connection.id) { + coordinator?.initRedisKeyTreeIfNeeded() + return } + await coordinator?.refreshTables() coordinator?.initRedisKeyTreeIfNeeded() } } @@ -791,32 +790,9 @@ final class MainContentCommandActions { private func handleOpenSQLFiles(_ notification: Notification) { guard let urls = notification.object as? [URL] else { return } - Task { for url in urls { - if let existingWindow = WindowLifecycleMonitor.shared.window(forSourceFile: url) { - existingWindow.makeKeyAndOrderFront(nil) - continue - } - - let content = await Task.detached(priority: .userInitiated) { () -> String? in - do { - return try String(contentsOf: url, encoding: .utf8) - } catch { - Self.logger.error("Failed to read \(url.lastPathComponent, privacy: .public): \(error.localizedDescription, privacy: .public)") - return nil - } - }.value - - if let content { - let payload = EditorTabPayload( - connectionId: connection.id, - tabType: .query, - initialQuery: content, - sourceFileURL: url - ) - WindowManager.shared.openTab(payload: payload) - } + try? await TabRouter.shared.route(.openSQLFile(url)) } } } diff --git a/TablePro/Views/Main/MainContentCoordinator.swift b/TablePro/Views/Main/MainContentCoordinator.swift index 82c946750..cf8b99a1f 100644 --- a/TablePro/Views/Main/MainContentCoordinator.swift +++ b/TablePro/Views/Main/MainContentCoordinator.swift @@ -38,14 +38,6 @@ struct DisplayFormatsCacheEntry { let formats: [ValueDisplayFormat?] } -/// Sidebar table loading state — single source of truth for sidebar UI -enum SidebarLoadingState: Equatable { - case idle - case loading - case loaded - case error(String) -} - /// Represents which sheet is currently active in MainContentView. /// Uses a single `.sheet(item:)` modifier instead of multiple `.sheet(isPresented:)`. enum ActiveSheet: Identifiable { @@ -141,15 +133,12 @@ final class MainContentCoordinator { // MARK: - Published State - var schemaProvider: SQLSchemaProvider var cursorPositions: [CursorPosition] = [] var tableMetadata: TableMetadata? - // Removed: showErrorAlert and errorAlertMessage - errors now display inline var activeSheet: ActiveSheet? var importFileURL: URL? var exportPreselectedTableNames: Set? var needsLazyLoad = false - var sidebarLoadingState: SidebarLoadingState = .idle /// Cache for async-sorted query tab rows (large datasets sorted on background thread) @ObservationIgnored var querySortCache: [UUID: QuerySortCacheEntry] = [:] @@ -171,10 +160,8 @@ final class MainContentCoordinator { @ObservationIgnored private var changeManagerUpdateTask: Task? @ObservationIgnored private var activeSortTasks: [UUID: Task] = [:] @ObservationIgnored private var terminationObserver: NSObjectProtocol? - @ObservationIgnored private var urlFilterObservers: [NSObjectProtocol] = [] @ObservationIgnored private var pluginDriverObserver: NSObjectProtocol? @ObservationIgnored private var fileWatcher: DatabaseFileWatcher? - @ObservationIgnored private var lastSchemaRefreshDate = Date.distantPast /// Set during handleTabChange to suppress redundant column-change reconfiguration @ObservationIgnored internal var isHandlingTabSwitch = false @@ -241,22 +228,32 @@ final class MainContentCoordinator { set { _isAppTerminating.withLock { $0 = newValue } } } + /// Stable instance identity. Used to key the registry so a recycled + /// `ObjectIdentifier` from a freshly-allocated coordinator can never + /// remove a different instance's entry from a delayed cleanup Task. + let instanceId = UUID() + /// Registry of active coordinators for aggregated quit-time persistence. - /// Keyed by ObjectIdentifier of each coordinator instance. - static var activeCoordinators: [ObjectIdentifier: MainContentCoordinator] = [:] + /// Keyed by `instanceId` (UUID) — never by `ObjectIdentifier`, which can + /// be recycled across allocations. + static var activeCoordinators: [UUID: MainContentCoordinator] = [:] /// Register this coordinator so quit-time persistence can aggregate tabs. + /// Idempotent — repeated registration is a no-op. + func registerEagerly() { + Self.activeCoordinators[instanceId] = self + } + private func registerForPersistence() { - Self.activeCoordinators[ObjectIdentifier(self)] = self + Self.activeCoordinators[instanceId] = self } - /// Unregister this coordinator from quit-time aggregation. private func unregisterFromPersistence() { - Self.activeCoordinators.removeValue(forKey: ObjectIdentifier(self)) + Self.activeCoordinators.removeValue(forKey: instanceId) } /// Collect non-preview tabs for persistence. - private static func aggregatedTabs(for connectionId: UUID) -> [QueryTab] { + static func aggregatedTabs(for connectionId: UUID) -> [QueryTab] { let coordinators = activeCoordinators.values .filter { $0.connectionId == connectionId } @@ -282,7 +279,7 @@ final class MainContentCoordinator { } /// Get selected tab ID from any coordinator for a given connectionId. - private static func aggregatedSelectedTabId(for connectionId: UUID) -> UUID? { + static func aggregatedSelectedTabId(for connectionId: UUID) -> UUID? { activeCoordinators.values .first { $0.connectionId == connectionId && $0.tabManager.selectedTabId != nil }? .tabManager.selectedTabId @@ -355,9 +352,8 @@ final class MainContentCoordinator { ) self.persistence = TabPersistenceCoordinator(connectionId: connection.id) - self.schemaProvider = SchemaProviderRegistry.shared.getOrCreate(for: connection.id) + _ = SchemaProviderRegistry.shared.getOrCreate(for: connection.id) SchemaProviderRegistry.shared.retain(for: connection.id) - urlFilterObservers = setupURLNotificationObservers() changeManager.undoManagerProvider = { [weak self] in self?.contentWindow?.undoManager } changeManager.onUndoApplied = { [weak self] result in self?.handleUndoResult(result) } @@ -414,16 +410,6 @@ final class MainContentCoordinator { ) } - /// Transition sidebar from `.idle` to `.loaded` when tables already exist - /// (e.g. populated by another window's `refreshTables()`). - func healSidebarLoadingStateIfNeeded() { - guard sidebarLoadingState == .idle else { return } - let tables = DatabaseManager.shared.session(for: connectionId)?.tables ?? [] - if !tables.isEmpty { - sidebarLoadingState = .loaded - } - } - /// Start watching the database file for external changes (SQLite, DuckDB). private func startFileWatcherIfNeeded() { guard PluginManager.shared.connectionMode(for: connection.type) == .fileBased else { return } @@ -432,8 +418,9 @@ final class MainContentCoordinator { let watcher = DatabaseFileWatcher() watcher.watch(filePath: filePath, connectionId: connectionId) { [weak self] in - guard let self, self.sidebarLoadingState != .loading else { return } - Task { await self.refreshTablesIfStale() } + guard let self else { return } + if case .loading = SchemaService.shared.state(for: self.connectionId) { return } + Task { await self.refreshTables() } } fileWatcher = watcher } @@ -441,9 +428,14 @@ final class MainContentCoordinator { /// Refresh schema only if not recently refreshed (avoids redundant work /// when both the file watcher and window focus trigger close together). func refreshTablesIfStale() async { - guard Date().timeIntervalSince(lastSchemaRefreshDate) > 2 else { return } - lastSchemaRefreshDate = Date() - await refreshTables() + guard let driver = DatabaseManager.shared.driver(for: connectionId) else { return } + await SchemaService.shared.reloadIfStale( + connectionId: connectionId, + driver: driver, + connection: connection, + staleness: 2 + ) + await reconcilePostSchemaLoad() } func showAIChatPanel() { @@ -473,45 +465,44 @@ final class MainContentCoordinator { } func refreshTables() async { - lastSchemaRefreshDate = Date() - sidebarLoadingState = .loading - guard let driver = DatabaseManager.shared.driver(for: connectionId) else { - sidebarLoadingState = .error(String(localized: "Not connected")) - return - } - do { - let tables = try await driver.fetchTables() - .sorted { $0.name.localizedCaseInsensitiveCompare($1.name) == .orderedAscending } - DatabaseManager.shared.updateSession(connectionId) { $0.tables = tables } + guard let driver = DatabaseManager.shared.driver(for: connectionId) else { return } + await SchemaService.shared.reload( + connectionId: connectionId, + driver: driver, + connection: connection + ) + await reconcilePostSchemaLoad() + } + + /// Push the SchemaService table list into the autocomplete provider and prune sidebar + /// state for tables that no longer exist. + private func reconcilePostSchemaLoad() async { + guard case .loaded(let tables) = SchemaService.shared.state(for: connectionId) else { return } + if let driver = DatabaseManager.shared.driver(for: connectionId), + let provider = SchemaProviderRegistry.shared.provider(for: connectionId) { let currentDb = DatabaseManager.shared.session(for: connectionId)?.activeDatabase - await schemaProvider.resetForDatabase(currentDb, tables: tables, driver: driver) - - // Clean up stale selections and pending operations for tables that no longer exist - if let vm = sidebarViewModel { - let validNames = Set(tables.map(\.name)) - let staleSelections = vm.selectedTables.filter { !validNames.contains($0.name) } - if !staleSelections.isEmpty { - vm.selectedTables.subtract(staleSelections) - } - let stalePendingDeletes = vm.pendingDeletes.subtracting(validNames) - if !stalePendingDeletes.isEmpty { - vm.pendingDeletes.subtract(stalePendingDeletes) - for name in stalePendingDeletes { - vm.tableOperationOptions.removeValue(forKey: name) - } - } - let stalePendingTruncates = vm.pendingTruncates.subtracting(validNames) - if !stalePendingTruncates.isEmpty { - vm.pendingTruncates.subtract(stalePendingTruncates) - for name in stalePendingTruncates { - vm.tableOperationOptions.removeValue(forKey: name) - } - } - } + await provider.resetForDatabase(currentDb, tables: tables, driver: driver) + } - sidebarLoadingState = .loaded - } catch { - sidebarLoadingState = .error(error.localizedDescription) + guard let vm = sidebarViewModel else { return } + let validNames = Set(tables.map(\.name)) + let staleSelections = vm.selectedTables.filter { !validNames.contains($0.name) } + if !staleSelections.isEmpty { + vm.selectedTables.subtract(staleSelections) + } + let stalePendingDeletes = vm.pendingDeletes.subtracting(validNames) + if !stalePendingDeletes.isEmpty { + vm.pendingDeletes.subtract(stalePendingDeletes) + for name in stalePendingDeletes { + vm.tableOperationOptions.removeValue(forKey: name) + } + } + let stalePendingTruncates = vm.pendingTruncates.subtracting(validNames) + if !stalePendingTruncates.isEmpty { + vm.pendingTruncates.subtract(stalePendingTruncates) + for name in stalePendingTruncates { + vm.tableOperationOptions.removeValue(forKey: name) + } } } @@ -525,10 +516,6 @@ final class MainContentCoordinator { _didTeardown.withLock { $0 = true } unregisterFromPersistence() - for observer in urlFilterObservers { - NotificationCenter.default.removeObserver(observer) - } - urlFilterObservers.removeAll() if let observer = terminationObserver { NotificationCenter.default.removeObserver(observer) terminationObserver = nil @@ -592,7 +579,7 @@ final class MainContentCoordinator { // Never-activated coordinators are throwaway instances created by SwiftUI // during body re-evaluation — @State only keeps the first, rest are discarded guard _didActivate.withLock({ $0 }) else { - let id = ObjectIdentifier(self) + let id = instanceId Task { @MainActor in Self.activeCoordinators.removeValue(forKey: id) } @@ -635,20 +622,8 @@ final class MainContentCoordinator { } } - /// Load schema only if the shared provider hasn't loaded yet + /// Load schema if not already loaded by another window for this connection. func loadSchemaIfNeeded() async { - let alreadyLoaded = await schemaProvider.isSchemaLoaded() - if alreadyLoaded { - let cachedTables = await schemaProvider.getTables() - let sessionTables = DatabaseManager.shared.session(for: connectionId)?.tables ?? [] - if sessionTables.isEmpty && !cachedTables.isEmpty { - DatabaseManager.shared.updateSession(connectionId) { $0.tables = cachedTables } - } - if sidebarLoadingState == .idle { - sidebarLoadingState = .loaded - } - return - } await loadSchema() } @@ -672,18 +647,12 @@ final class MainContentCoordinator { func loadSchema() async { guard let driver = DatabaseManager.shared.driver(for: connectionId) else { return } - sidebarLoadingState = .loading - await schemaProvider.loadSchema(using: driver, connection: connection) - let fetchedTables = await schemaProvider.getTables() - if !fetchedTables.isEmpty { - let sessionTables = DatabaseManager.shared.session(for: connectionId)?.tables ?? [] - if sessionTables != fetchedTables { - DatabaseManager.shared.updateSession(connectionId) { $0.tables = fetchedTables } - } - sidebarLoadingState = .loaded - } else { - sidebarLoadingState = .idle - } + await SchemaService.shared.load( + connectionId: connectionId, + driver: driver, + connection: connection + ) + await reconcilePostSchemaLoad() } func loadTableMetadata(tableName: String) async { diff --git a/TablePro/Views/Main/MainContentView.swift b/TablePro/Views/Main/MainContentView.swift index 76ce35776..8467e0aaf 100644 --- a/TablePro/Views/Main/MainContentView.swift +++ b/TablePro/Views/Main/MainContentView.swift @@ -30,13 +30,17 @@ struct MainContentView: View { // Shared state from parent @Binding var windowTitle: String - @Binding var tables: [TableInfo] + @Bindable var schemaService = SchemaService.shared var sidebarState: SharedSidebarState @Binding var pendingTruncates: Set @Binding var pendingDeletes: Set @Binding var tableOperationOptions: [String: TableOperationOptions] var rightPanelState: RightPanelState + private var tables: [TableInfo] { + schemaService.tables(for: connection.id) + } + // MARK: - State Objects let tabManager: QueryTabManager @@ -66,7 +70,6 @@ struct MainContentView: View { connection: DatabaseConnection, payload: EditorTabPayload?, windowTitle: Binding, - tables: Binding<[TableInfo]>, sidebarState: SharedSidebarState, pendingTruncates: Binding>, pendingDeletes: Binding>, @@ -81,7 +84,6 @@ struct MainContentView: View { self.connection = connection self.payload = payload self._windowTitle = windowTitle - self._tables = tables self.sidebarState = sidebarState self._pendingTruncates = pendingTruncates self._pendingDeletes = pendingDeletes @@ -202,7 +204,7 @@ struct MainContentView: View { case .quickSwitcher: QuickSwitcherSheet( isPresented: dismissBinding, - schemaProvider: coordinator.schemaProvider, + schemaProvider: SchemaProviderRegistry.shared.getOrCreate(for: connection.id), connectionId: connection.id, databaseType: connection.type, onSelect: { item in @@ -262,7 +264,7 @@ struct MainContentView: View { setupCommandActions() updateToolbarPendingState() updateInspectorContext() - rightPanelState.aiViewModel.schemaProvider = coordinator.schemaProvider + rightPanelState.aiViewModel.schemaProvider = SchemaProviderRegistry.shared.getOrCreate(for: connection.id) coordinator.aiViewModel = rightPanelState.aiViewModel coordinator.rightPanelState = rightPanelState diff --git a/TablePro/Views/Settings/Sections/MCPSection.swift b/TablePro/Views/Settings/Sections/MCPSection.swift index fbcd8f3d0..0c8902385 100644 --- a/TablePro/Views/Settings/Sections/MCPSection.swift +++ b/TablePro/Views/Settings/Sections/MCPSection.swift @@ -188,10 +188,11 @@ struct MCPSection: View { private func handleGenerate(name: String, permissions: TokenPermissions, connectionIds: Set?, expiresAt: Date?) { Task { guard let store = manager.tokenStore else { return } + let access: ConnectionAccess = connectionIds.map { .limited($0) } ?? .all let result = await store.generate( name: name, permissions: permissions, - allowedConnectionIds: connectionIds, + connectionAccess: access, expiresAt: expiresAt ) revealedToken = result.token diff --git a/TablePro/Views/Settings/Sections/PairingApprovalSheet.swift b/TablePro/Views/Settings/Sections/PairingApprovalSheet.swift index 8750ed913..e82d00e7f 100644 --- a/TablePro/Views/Settings/Sections/PairingApprovalSheet.swift +++ b/TablePro/Views/Settings/Sections/PairingApprovalSheet.swift @@ -3,7 +3,6 @@ // TablePro // -import AppKit import SwiftUI struct PairingApproval: Sendable { @@ -12,48 +11,6 @@ struct PairingApproval: Sendable { let expiresAt: Date? } -@MainActor -enum PairingApprovalPresenter { - static func present(request: PairingRequest) async throws -> PairingApproval { - try await withCheckedThrowingContinuation { continuation in - var deliver: ((Result) -> Void)? - let host = NSHostingController( - rootView: PairingApprovalSheet( - request: request, - onComplete: { result in deliver?(result) } - ) - ) - host.view.frame = NSRect(x: 0, y: 0, width: 520, height: 560) - - let parent = AlertHelper.resolveWindow(nil) - let sheetWindow = NSWindow(contentViewController: host) - sheetWindow.styleMask = [.titled] - sheetWindow.title = String(localized: "Approve Integration") - sheetWindow.isReleasedWhenClosed = false - - var resolved = false - deliver = { result in - guard !resolved else { return } - resolved = true - if let parent { - parent.endSheet(sheetWindow) - } else { - sheetWindow.close() - } - continuation.resume(with: result) - } - - if let parent { - parent.beginSheet(sheetWindow, completionHandler: nil) - } else { - NSApp.activate(ignoringOtherApps: true) - sheetWindow.center() - sheetWindow.makeKeyAndOrderFront(nil) - } - } - } -} - struct PairingApprovalSheet: View { let request: PairingRequest let onComplete: (Result) -> Void diff --git a/TablePro/Views/Sidebar/SidebarView.swift b/TablePro/Views/Sidebar/SidebarView.swift index c554b71e1..1fe9406e1 100644 --- a/TablePro/Views/Sidebar/SidebarView.swift +++ b/TablePro/Views/Sidebar/SidebarView.swift @@ -13,8 +13,8 @@ import SwiftUI struct SidebarView: View { @State private var viewModel: SidebarViewModel + @Bindable private var schemaService = SchemaService.shared - @Binding var tables: [TableInfo] var sidebarState: SharedSidebarState @Binding var pendingTruncates: Set @Binding var pendingDeletes: Set @@ -23,6 +23,10 @@ struct SidebarView: View { var connectionId: UUID private weak var coordinator: MainContentCoordinator? + private var tables: [TableInfo] { + schemaService.tables(for: connectionId) + } + private var filteredTables: [TableInfo] { guard !viewModel.searchText.isEmpty else { return tables } return tables.filter { $0.name.localizedCaseInsensitiveContains(viewModel.searchText) } @@ -36,7 +40,6 @@ struct SidebarView: View { } init( - tables: Binding<[TableInfo]>, sidebarState: SharedSidebarState, onDoubleClick: ((TableInfo) -> Void)? = nil, pendingTruncates: Binding>, @@ -46,7 +49,6 @@ struct SidebarView: View { connectionId: UUID, coordinator: MainContentCoordinator? = nil ) { - _tables = tables self.sidebarState = sidebarState self.onDoubleClick = onDoubleClick _pendingTruncates = pendingTruncates @@ -56,7 +58,6 @@ struct SidebarView: View { set: { sidebarState.selectedTables = $0 } ) let vm = SidebarViewModel( - tables: tables, selectedTables: selectedBinding, pendingTruncates: pendingTruncates, pendingDeletes: pendingDeletes, @@ -93,7 +94,6 @@ struct SidebarView: View { } .onAppear { coordinator?.sidebarViewModel = viewModel - coordinator?.healSidebarLoadingStateIfNeeded() // Update toolbar version if driver connected before this window's observer was set up if let driver = DatabaseManager.shared.driver(for: connectionId), coordinator?.toolbarState.databaseVersion == nil { @@ -122,23 +122,16 @@ struct SidebarView: View { @ViewBuilder private var tablesContent: some View { - let rawState = coordinator?.sidebarLoadingState ?? .idle - let effectiveState: SidebarLoadingState = { - if case .error = rawState { return rawState } - if !tables.isEmpty { return .loaded } - if case .loading = rawState { return .loading } - return rawState - }() - switch effectiveState { - case .loading: + switch schemaService.state(for: connectionId) { + case .loading where tables.isEmpty: loadingState - case .error(let message): + case .failed(let message): errorState(message: message) case .loaded where !viewModel.searchText.isEmpty && filteredTables.isEmpty: noMatchState - case .loaded where tables.isEmpty: + case .loaded(let allTables) where allTables.isEmpty: emptyState - case .loaded: + case .loaded, .loading: tableList case .idle: emptyState @@ -266,7 +259,6 @@ struct SidebarView: View { #Preview { SidebarView( - tables: .constant([]), sidebarState: SharedSidebarState(), pendingTruncates: .constant([]), pendingDeletes: .constant([]), diff --git a/TablePro/Views/Toolbar/ConnectionSwitcherPopover.swift b/TablePro/Views/Toolbar/ConnectionSwitcherPopover.swift index ea57967af..a9676ce64 100644 --- a/TablePro/Views/Toolbar/ConnectionSwitcherPopover.swift +++ b/TablePro/Views/Toolbar/ConnectionSwitcherPopover.swift @@ -13,11 +13,8 @@ import TableProPluginKit /// Popover content for quick connection switching struct ConnectionSwitcherPopover: View { @State private var savedConnections: [DatabaseConnection] = [] - @State private var isConnecting: UUID? @State private var selectedIndex: Int = 0 - @Environment(\.openWindow) private var openWindow - /// Callback when the popover should dismiss var onDismiss: (() -> Void)? @@ -94,7 +91,6 @@ struct ConnectionSwitcherPopover: View { connection: connection, isActive: false, isConnected: false, - isConnecting: isConnecting == connection.id, isHighlighted: itemIndex == selectedIndex ) } @@ -126,7 +122,7 @@ struct ConnectionSwitcherPopover: View { // Manage connections button Button { onDismiss?() - NotificationCenter.default.post(name: .openWelcomeWindow, object: nil) + WelcomeWindowFactory.openOrFront() } label: { HStack { Image(systemName: "gear") @@ -203,7 +199,6 @@ struct ConnectionSwitcherPopover: View { connection: DatabaseConnection, isActive: Bool, isConnected: Bool, - isConnecting: Bool = false, isHighlighted: Bool = false ) -> some View { HStack(spacing: 8) { @@ -228,10 +223,7 @@ struct ConnectionSwitcherPopover: View { Spacer() // Status indicator - if isConnecting { - ProgressView() - .controlSize(.small) - } else if isActive { + if isActive { Image(systemName: "checkmark.circle.fill") .foregroundStyle(isHighlighted ? Color(nsColor: .alternateSelectedControlTextColor) : Color(nsColor: .systemGreen)) .font(.system(size: 14)) @@ -287,23 +279,18 @@ struct ConnectionSwitcherPopover: View { } private func switchToSession(_ sessionId: UUID) { - onDismiss?() - // Try to bring existing window for this connection to front - if let existingWindow = findWindow(for: sessionId) { - existingWindow.makeKeyAndOrderFront(nil) - } else { - openWindowForDifferentConnection(EditorTabPayload(connectionId: sessionId)) - } + openConnection(sessionId) } private func connectToSaved(_ connection: DatabaseConnection) { - isConnecting = connection.id + openConnection(connection.id) + } + + private func openConnection(_ id: UUID) { onDismiss?() - // Open a new window, then connect — window shows "Connecting..." until ready - openWindowForDifferentConnection(EditorTabPayload(connectionId: connection.id)) Task { do { - try await DatabaseManager.shared.connectToSession(connection) + try await TabRouter.shared.route(.openConnection(id)) } catch { await MainActor.run { AlertHelper.showErrorSheet( @@ -313,32 +300,6 @@ struct ConnectionSwitcherPopover: View { ) } } - await MainActor.run { - isConnecting = nil - } - } - } - - /// Find an existing visible window for the given connection ID - private func findWindow(for connectionId: UUID) -> NSWindow? { - WindowLifecycleMonitor.shared.findWindow(for: connectionId) - } - - /// Open a new window for a different connection, ensuring it doesn't - /// merge as a tab with the current connection's window group - /// (unless the user opted to group all connections in one window). - private func openWindowForDifferentConnection(_ payload: EditorTabPayload) { - if AppSettingsManager.shared.tabs.groupAllConnectionTabs { - WindowManager.shared.openTab(payload: payload) - } else { - // Temporarily disable tab merging so the new window opens independently - let currentWindow = NSApp.keyWindow - let previousMode = currentWindow?.tabbingMode ?? .preferred - currentWindow?.tabbingMode = .disallowed - WindowManager.shared.openTab(payload: payload) - DispatchQueue.main.async { - currentWindow?.tabbingMode = previousMode - } } } } diff --git a/TablePro/Views/Toolbar/TableProToolbarView.swift b/TablePro/Views/Toolbar/TableProToolbarView.swift index df95683d6..1d13af062 100644 --- a/TablePro/Views/Toolbar/TableProToolbarView.swift +++ b/TablePro/Views/Toolbar/TableProToolbarView.swift @@ -59,7 +59,6 @@ struct ToolbarPrincipalContent: View { struct TableProToolbar: ViewModifier { @Bindable var state: ConnectionToolbarState @FocusedValue(\.commandActions) private var actions: MainContentCommandActions? - @State private var showConnectionSwitcher = false func body(content: Content) -> some View { content @@ -68,14 +67,14 @@ struct TableProToolbar: ViewModifier { ToolbarItem(placement: .navigation) { Button { - showConnectionSwitcher.toggle() + state.showConnectionSwitcher.toggle() } label: { Label("Connection", systemImage: "network") } .help(String(localized: "Switch Connection (⌘⌥C)")) - .popover(isPresented: $showConnectionSwitcher) { + .popover(isPresented: $state.showConnectionSwitcher) { ConnectionSwitcherPopover { - showConnectionSwitcher = false + state.showConnectionSwitcher = false } } } @@ -229,9 +228,6 @@ struct TableProToolbar: ViewModifier { } } } - .onReceive(NotificationCenter.default.publisher(for: .openConnectionSwitcher)) { _ in - showConnectionSwitcher = true - } } } diff --git a/TableProTests/Core/Concurrency/OnceTaskTests.swift b/TableProTests/Core/Concurrency/OnceTaskTests.swift new file mode 100644 index 000000000..33576a7ef --- /dev/null +++ b/TableProTests/Core/Concurrency/OnceTaskTests.swift @@ -0,0 +1,186 @@ +// +// OnceTaskTests.swift +// TableProTests +// + +import Foundation +@testable import TablePro +import XCTest + +final class OnceTaskTests: XCTestCase { + actor Counter { + private(set) var value: Int = 0 + + func increment() { + value += 1 + } + } + + private struct TestError: Error, Equatable { + let tag: String + } + + func testConcurrentSameKeyRunsWorkOnce() async throws { + let dedup = OnceTask() + let counter = Counter() + + async let first = dedup.execute(key: "k") { + await counter.increment() + try await Task.sleep(for: .milliseconds(50)) + return 42 + } + async let second = dedup.execute(key: "k") { + await counter.increment() + try await Task.sleep(for: .milliseconds(50)) + return 99 + } + + let results = try await [first, second] + let invocations = await counter.value + + XCTAssertEqual(invocations, 1, "Work block must run exactly once for concurrent same-key callers") + XCTAssertEqual(results[0], results[1], "Concurrent callers must observe the same value") + XCTAssertEqual(results[0], 42, "Both callers must receive the value produced by the first work block") + } + + func testConcurrentDifferentKeysRunWorkSeparately() async throws { + let dedup = OnceTask() + let counter = Counter() + + async let alpha = dedup.execute(key: "alpha") { + await counter.increment() + try await Task.sleep(for: .milliseconds(20)) + return "alpha-value" + } + async let beta = dedup.execute(key: "beta") { + await counter.increment() + try await Task.sleep(for: .milliseconds(20)) + return "beta-value" + } + + let alphaValue = try await alpha + let betaValue = try await beta + let invocations = await counter.value + + XCTAssertEqual(invocations, 2, "Distinct keys must each run their own work block") + XCTAssertEqual(alphaValue, "alpha-value") + XCTAssertEqual(betaValue, "beta-value") + } + + func testThrowingWorkPropagatesAndClearsInFlight() async throws { + let dedup = OnceTask() + let counter = Counter() + + do { + _ = try await dedup.execute(key: "k") { + await counter.increment() + throw TestError(tag: "first") + } + XCTFail("Expected throw from first execute") + } catch let error as TestError { + XCTAssertEqual(error.tag, "first") + } + + let secondValue = try await dedup.execute(key: "k") { + await counter.increment() + return 7 + } + + XCTAssertEqual(secondValue, 7, "After a throw, the next execute must rerun the work") + let invocations = await counter.value + XCTAssertEqual(invocations, 2, "Both work blocks must have run (throw cleared the in-flight slot)") + } + + func testCancelKeyClearsInFlightAndAllowsRerun() async throws { + let dedup = OnceTask() + let counter = Counter() + let started = expectation(description: "work started") + started.assertForOverFulfill = false + + let inFlight = Task { + try await dedup.execute(key: "k") { + await counter.increment() + started.fulfill() + try await Task.sleep(for: .seconds(5)) + return 1 + } + } + + await fulfillment(of: [started], timeout: 2.0) + await dedup.cancel(key: "k") + + do { + _ = try await inFlight.value + XCTFail("Expected CancellationError from cancelled in-flight call") + } catch is CancellationError { + // expected + } catch { + XCTFail("Expected CancellationError, got \(error)") + } + + let rerunValue = try await dedup.execute(key: "k") { + await counter.increment() + return 11 + } + + XCTAssertEqual(rerunValue, 11, "After cancel, a fresh execute must run the work again") + let invocations = await counter.value + XCTAssertEqual(invocations, 2) + } + + func testSequentialSameKeyRunsWorkAgain() async throws { + let dedup = OnceTask() + let counter = Counter() + + let first = try await dedup.execute(key: "k") { + await counter.increment() + return 1 + } + let second = try await dedup.execute(key: "k") { + await counter.increment() + return 2 + } + + XCTAssertEqual(first, 1) + XCTAssertEqual(second, 2) + let invocations = await counter.value + XCTAssertEqual(invocations, 2, "Sequential calls (after first completes) must each run the work") + } + + func testCancelAllCancelsEveryInFlight() async throws { + let dedup = OnceTask() + let firstStarted = expectation(description: "first started") + let secondStarted = expectation(description: "second started") + firstStarted.assertForOverFulfill = false + secondStarted.assertForOverFulfill = false + + let firstTask = Task { + try await dedup.execute(key: "a") { + firstStarted.fulfill() + try await Task.sleep(for: .seconds(5)) + return 1 + } + } + let secondTask = Task { + try await dedup.execute(key: "b") { + secondStarted.fulfill() + try await Task.sleep(for: .seconds(5)) + return 2 + } + } + + await fulfillment(of: [firstStarted, secondStarted], timeout: 2.0) + await dedup.cancelAll() + + for task in [firstTask, secondTask] { + do { + _ = try await task.value + XCTFail("Expected CancellationError from cancelAll") + } catch is CancellationError { + // expected + } catch { + XCTFail("Expected CancellationError, got \(error)") + } + } + } +} diff --git a/TableProTests/Core/Database/MultiConnectionTests.swift b/TableProTests/Core/Database/MultiConnectionTests.swift index 60a2b39bc..6ad99ad1e 100644 --- a/TableProTests/Core/Database/MultiConnectionTests.swift +++ b/TableProTests/Core/Database/MultiConnectionTests.swift @@ -95,13 +95,12 @@ struct DatabaseManagerMultiSessionTests { DatabaseManager.shared.removeSession(for: id2) } - let table = TestFixtures.makeTableInfo(name: "users") DatabaseManager.shared.updateSession(id1) { session in - session.tables = [table] + session.pendingTruncates = ["users"] } - #expect(DatabaseManager.shared.session(for: id1)?.tables.count == 1) - #expect(DatabaseManager.shared.session(for: id2)?.tables.isEmpty == true) + #expect(DatabaseManager.shared.session(for: id1)?.pendingTruncates == ["users"]) + #expect(DatabaseManager.shared.session(for: id2)?.pendingTruncates.isEmpty == true) } @Test("Removing one session does not affect the other") @@ -133,7 +132,7 @@ struct DatabaseManagerMultiSessionTests { let countBefore = DatabaseManager.shared.activeSessions.count DatabaseManager.shared.updateSession(unknownId) { session in - session.tables = [TestFixtures.makeTableInfo(name: "ghost")] + session.pendingTruncates = ["ghost"] } #expect(DatabaseManager.shared.activeSessions.count == countBefore) @@ -228,35 +227,22 @@ struct CoordinatorConnectionIsolationTests { #expect(coordinator2.connectionId == id2) } - @Test("sidebarLoadingState is per-coordinator and does not bleed across instances") - func sidebarLoadingStateIsPerCoordinator() { - let conn1 = TestFixtures.makeConnection(id: UUID(), name: "Conn1", database: "db_a", type: .mysql) - let conn2 = TestFixtures.makeConnection(id: UUID(), name: "Conn2", database: "db_b", type: .mysql) - - let coordinator1 = MainContentCoordinator( - connection: conn1, - tabManager: QueryTabManager(), - changeManager: DataChangeManager(), - filterStateManager: FilterStateManager(), - columnVisibilityManager: ColumnVisibilityManager(), - toolbarState: ConnectionToolbarState() - ) - defer { coordinator1.teardown() } - - let coordinator2 = MainContentCoordinator( - connection: conn2, - tabManager: QueryTabManager(), - changeManager: DataChangeManager(), - filterStateManager: FilterStateManager(), - columnVisibilityManager: ColumnVisibilityManager(), - toolbarState: ConnectionToolbarState() - ) - defer { coordinator2.teardown() } + @Test("Schema state is per-connection in SchemaService") + func schemaStateIsPerConnection() async { + let id1 = UUID() + let id2 = UUID() - coordinator1.sidebarLoadingState = .loading + await SchemaService.shared.invalidate(connectionId: id1) + await SchemaService.shared.invalidate(connectionId: id2) + defer { + Task { + await SchemaService.shared.invalidate(connectionId: id1) + await SchemaService.shared.invalidate(connectionId: id2) + } + } - #expect(coordinator1.sidebarLoadingState == .loading) - #expect(coordinator2.sidebarLoadingState == .idle) + #expect(SchemaService.shared.state(for: id1) == .idle) + #expect(SchemaService.shared.state(for: id2) == .idle) } @Test("openTableTab uses coordinator's connection database for the added tab") diff --git a/TableProTests/Core/MCP/MCPRouterTests.swift b/TableProTests/Core/MCP/MCPRouterTests.swift new file mode 100644 index 000000000..f2c392eae --- /dev/null +++ b/TableProTests/Core/MCP/MCPRouterTests.swift @@ -0,0 +1,167 @@ +// +// MCPRouterTests.swift +// TableProTests +// + +import Foundation +@testable import TablePro +import Testing + +@Suite("MCP Router") +struct MCPRouterTests { + private final class StubHandler: MCPRouteHandler, @unchecked Sendable { + let methods: [HTTPRequest.Method] + let path: String + private let result: MCPRouter.RouteResult + private(set) var invocationCount: Int = 0 + private(set) var lastRequest: HTTPRequest? + + init(methods: [HTTPRequest.Method], path: String, result: MCPRouter.RouteResult = .accepted) { + self.methods = methods + self.path = path + self.result = result + } + + func handle(_ request: HTTPRequest) async -> MCPRouter.RouteResult { + invocationCount += 1 + lastRequest = request + return result + } + } + + private func makeRequest( + method: HTTPRequest.Method, + path: String, + body: Data? = nil + ) -> HTTPRequest { + HTTPRequest(method: method, path: path, headers: [:], body: body, remoteIP: nil) + } + + @Test("OPTIONS preflight returns noContent regardless of path") + func optionsPreflightAlwaysNoContent() async { + let mcpHandler = StubHandler(methods: [.post], path: "/mcp", result: .accepted) + let router = MCPRouter(routes: [mcpHandler]) + + let optionsAtMcp = makeRequest(method: .options, path: "/mcp") + let result1 = await router.handle(optionsAtMcp) + guard case .noContent = result1 else { + Issue.record("Expected .noContent for OPTIONS /mcp, got \(result1)") + return + } + + let optionsAtUnknown = makeRequest(method: .options, path: "/unknown/path") + let result2 = await router.handle(optionsAtUnknown) + guard case .noContent = result2 else { + Issue.record("Expected .noContent for OPTIONS /unknown, got \(result2)") + return + } + + #expect(mcpHandler.invocationCount == 0) + } + + @Test("POST /mcp dispatches to MCP protocol handler") + func postMcpDispatchesToProtocolHandler() async { + let mcpHandler = StubHandler(methods: [.get, .post, .delete], path: "/mcp", result: .accepted) + let exchangeHandler = StubHandler(methods: [.post], path: "/v1/integrations/exchange", result: .accepted) + let router = MCPRouter(routes: [mcpHandler, exchangeHandler]) + + let request = makeRequest(method: .post, path: "/mcp") + _ = await router.handle(request) + + #expect(mcpHandler.invocationCount == 1) + #expect(exchangeHandler.invocationCount == 0) + } + + @Test("POST /v1/integrations/exchange dispatches to exchange handler") + func postExchangeDispatchesToExchangeHandler() async { + let mcpHandler = StubHandler(methods: [.get, .post, .delete], path: "/mcp", result: .accepted) + let exchangeHandler = StubHandler(methods: [.post], path: "/v1/integrations/exchange", result: .accepted) + let router = MCPRouter(routes: [mcpHandler, exchangeHandler]) + + let request = makeRequest(method: .post, path: "/v1/integrations/exchange") + _ = await router.handle(request) + + #expect(exchangeHandler.invocationCount == 1) + #expect(mcpHandler.invocationCount == 0) + } + + @Test("Path with query string still matches canonical route") + func queryStringMatchesCanonicalPath() async { + let mcpHandler = StubHandler(methods: [.post], path: "/mcp", result: .accepted) + let router = MCPRouter(routes: [mcpHandler]) + + let request = makeRequest(method: .post, path: "/mcp?session=abc") + _ = await router.handle(request) + + #expect(mcpHandler.invocationCount == 1) + } + + @Test("Unknown path returns 404 httpError") + func unknownPathReturnsNotFound() async { + let mcpHandler = StubHandler(methods: [.post], path: "/mcp", result: .accepted) + let router = MCPRouter(routes: [mcpHandler]) + + let request = makeRequest(method: .post, path: "/totally/unknown") + let result = await router.handle(request) + + guard case .httpError(let status, _) = result else { + Issue.record("Expected .httpError, got \(result)") + return + } + #expect(status == 404) + #expect(mcpHandler.invocationCount == 0) + } + + @Test("Method mismatch on registered path returns 404") + func methodMismatchReturnsNotFound() async { + let exchangeHandler = StubHandler(methods: [.post], path: "/v1/integrations/exchange", result: .accepted) + let router = MCPRouter(routes: [exchangeHandler]) + + let request = makeRequest(method: .get, path: "/v1/integrations/exchange") + let result = await router.handle(request) + + guard case .httpError(let status, _) = result else { + Issue.record("Expected .httpError, got \(result)") + return + } + #expect(status == 404) + #expect(exchangeHandler.invocationCount == 0) + } + + @Test(".well-known requests return 404 immediately") + func wellKnownReturnsNotFound() async { + let mcpHandler = StubHandler(methods: [.get], path: "/.well-known/oauth", result: .accepted) + let router = MCPRouter(routes: [mcpHandler]) + + let request = makeRequest(method: .get, path: "/.well-known/oauth") + let result = await router.handle(request) + + guard case .httpError(let status, _) = result else { + Issue.record("Expected .httpError, got \(result)") + return + } + #expect(status == 404) + #expect(mcpHandler.invocationCount == 0) + } + + @Test("Handler receives the original request") + func handlerReceivesOriginalRequest() async { + let mcpHandler = StubHandler(methods: [.post], path: "/mcp", result: .accepted) + let router = MCPRouter(routes: [mcpHandler]) + + let body = Data("{\"hello\":\"world\"}".utf8) + let request = HTTPRequest( + method: .post, + path: "/mcp", + headers: ["content-type": "application/json"], + body: body, + remoteIP: "10.0.0.1" + ) + _ = await router.handle(request) + + #expect(mcpHandler.lastRequest?.path == "/mcp") + #expect(mcpHandler.lastRequest?.method == .post) + #expect(mcpHandler.lastRequest?.body == body) + #expect(mcpHandler.lastRequest?.remoteIP == "10.0.0.1") + } +} diff --git a/TableProTests/Models/ConnectionSessionTests.swift b/TableProTests/Models/ConnectionSessionTests.swift index 9e3e50e67..0c843fef4 100644 --- a/TableProTests/Models/ConnectionSessionTests.swift +++ b/TableProTests/Models/ConnectionSessionTests.swift @@ -19,7 +19,6 @@ struct ConnectionSessionEquivalenceTests { id: UUID = UUID(), database: String = "testdb", type: DatabaseType = .mysql, - tables: [TableInfo] = [], status: ConnectionStatus = .connected ) -> ConnectionSession { let connection = DatabaseConnection( @@ -30,7 +29,6 @@ struct ConnectionSessionEquivalenceTests { ) var session = ConnectionSession(connection: connection) session.status = status - session.tables = tables return session } @@ -71,16 +69,15 @@ struct ConnectionSessionEquivalenceTests { #expect(!a.isContentViewEquivalent(to: b)) } - @Test("Returns false when tables change") - func falseWhenTablesChange() { + @Test("Tables are excluded from equivalence (owned by SchemaService)") + @MainActor + func tablesAreExcludedFromEquivalence() async { let id = UUID() - var a = makeSession(id: id) - var b = makeSession(id: id) - - a.tables = [TestFixtures.makeTableInfo(name: "users")] - b.tables = [TestFixtures.makeTableInfo(name: "orders")] + let a = makeSession(id: id) + let b = makeSession(id: id) - #expect(!a.isContentViewEquivalent(to: b)) + await SchemaService.shared.invalidate(connectionId: id) + #expect(a.isContentViewEquivalent(to: b)) } @Test("Returns false when status changes") @@ -162,14 +159,6 @@ struct ConnectionSessionStateTests { #expect(!session.isConnected) } - @Test("clearCachedData clears tables") - func clearCachedDataClearsTables() { - var session = makeSession() - session.tables = [TestFixtures.makeTableInfo(name: "users")] - session.clearCachedData() - #expect(session.tables.isEmpty) - } - @Test("clearCachedData clears selectedTables") func clearCachedDataClearsSelectedTables() { var session = makeSession() @@ -207,7 +196,7 @@ struct ConnectionSessionStateTests { let connection = TestFixtures.makeConnection(name: "Production") var session = ConnectionSession(connection: connection) session.status = .connected - session.tables = [TestFixtures.makeTableInfo(name: "users")] + session.selectedTables = [TestFixtures.makeTableInfo(name: "users")] session.clearCachedData() #expect(session.status == .connected) #expect(session.connection.id == connection.id)