Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions apps/code/src/main/services/auth/service.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -171,4 +171,70 @@ describe("AuthService", () => {
"rotated-refresh-token",
);
});

it("preserves the selected project across logout and re-login for the same account", async () => {
vi.mocked(oauthService.startFlow)
.mockResolvedValueOnce({
success: true,
data: {
access_token: "initial-access-token",
refresh_token: "initial-refresh-token",
expires_in: 3600,
token_type: "Bearer",
scope: "",
scoped_teams: [42, 84],
scoped_organizations: ["org-1"],
},
})
.mockResolvedValueOnce({
success: true,
data: {
access_token: "second-access-token",
refresh_token: "second-refresh-token",
expires_in: 3600,
token_type: "Bearer",
scope: "",
scoped_teams: [42, 84],
scoped_organizations: ["org-1"],
},
});
vi.mocked(oauthService.refreshToken).mockResolvedValue({
success: true,
data: {
access_token: "refreshed-access-token",
refresh_token: "refreshed-refresh-token",
expires_in: 3600,
token_type: "Bearer",
scope: "",
scoped_teams: [42, 84],
scoped_organizations: ["org-1"],
},
});

vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
json: vi.fn().mockResolvedValue({ has_access: true }),
}) as unknown as typeof fetch,
);

await service.login("us");
await service.selectProject(84);
await service.logout();

expect(service.getState()).toMatchObject({
status: "anonymous",
cloudRegion: "us",
projectId: 84,
});

await service.login("us");

expect(service.getState()).toMatchObject({
status: "authenticated",
cloudRegion: "us",
projectId: 84,
availableProjectIds: [42, 84],
});
});
});
4 changes: 3 additions & 1 deletion apps/code/src/main/services/auth/service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -219,9 +219,11 @@ export class AuthService extends TypedEventEmitter<AuthServiceEvents> {
}

async logout(): Promise<AuthState> {
const { cloudRegion, projectId } = this.state;

this.authSessionRepository.clearCurrent();
this.session = null;
this.setAnonymousState();
this.setAnonymousState({ cloudRegion, projectId });
return this.getState();
}

Expand Down
25 changes: 11 additions & 14 deletions apps/code/src/renderer/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ import { ScopeReauthPrompt } from "@components/ScopeReauthPrompt";
import { UpdatePrompt } from "@components/UpdatePrompt";
import { AuthScreen } from "@features/auth/components/AuthScreen";
import { InviteCodeScreen } from "@features/auth/components/InviteCodeScreen";
import { useAuthStore } from "@features/auth/stores/authStore";
import { useAuthStateValue } from "@features/auth/hooks/authQueries";
import { useAuthSession } from "@features/auth/hooks/useAuthSession";
import { OnboardingFlow } from "@features/onboarding/components/OnboardingFlow";
import { useOnboardingStore } from "@features/onboarding/stores/onboardingStore";
import { Flex, Spinner, Text } from "@radix-ui/themes";
import { initializeConnectivityStore } from "@renderer/stores/connectivityStore";
import { useFocusStore } from "@renderer/stores/focusStore";
Expand All @@ -25,10 +27,14 @@ const log = logger.scope("app");

function App() {
const trpcReact = useTRPC();
const { isAuthenticated, hasCompletedOnboarding, hasCodeAccess } =
useAuthStore();
const { isBootstrapped } = useAuthSession();
const authState = useAuthStateValue((state) => state);
const hasCompletedOnboarding = useOnboardingStore(
(state) => state.hasCompletedOnboarding,
);
const isAuthenticated = authState.status === "authenticated";
const hasCodeAccess = authState.hasCodeAccess;
const isDarkMode = useThemeStore((state) => state.isDarkMode);
const [isLoading, setIsLoading] = useState(true);
const [showTransition, setShowTransition] = useState(false);
const wasInMainApp = useRef(isAuthenticated && hasCompletedOnboarding);

Expand Down Expand Up @@ -114,15 +120,6 @@ function App() {
}),
);

// Initialize auth state from main process
useEffect(() => {
const initialize = async () => {
await useAuthStore.getState().initializeOAuth();
setIsLoading(false);
};
void initialize();
}, []);

// Handle transition into main app — only show the dark overlay if dark mode is active
useEffect(() => {
const isInMainApp = isAuthenticated && hasCompletedOnboarding;
Expand All @@ -136,7 +133,7 @@ function App() {
setShowTransition(false);
};

if (isLoading) {
if (!isBootstrapped) {
return (
<Flex align="center" justify="center" minHeight="100vh">
<Flex align="center" gap="3">
Expand Down
148 changes: 55 additions & 93 deletions apps/code/src/renderer/components/ScopeReauthPrompt.test.tsx
Original file line number Diff line number Diff line change
@@ -1,65 +1,40 @@
import { Theme } from "@radix-ui/themes";
import { render, screen } from "@testing-library/react";
import userEvent from "@testing-library/user-event";
import type { ReactElement } from "react";
import { beforeEach, describe, expect, it, vi } from "vitest";
import { ScopeReauthPrompt } from "./ScopeReauthPrompt";

vi.mock("@renderer/trpc/client", () => ({
trpcClient: {
auth: {
getState: { query: vi.fn() },
onStateChanged: { subscribe: vi.fn(() => ({ unsubscribe: vi.fn() })) },
getValidAccessToken: {
query: vi.fn().mockResolvedValue({
accessToken: "token",
apiHost: "https://us.posthog.com",
}),
},
refreshAccessToken: {
mutate: vi.fn().mockResolvedValue({
accessToken: "token",
apiHost: "https://us.posthog.com",
}),
},
login: {
mutate: vi.fn().mockResolvedValue({
state: {
status: "authenticated",
bootstrapComplete: true,
cloudRegion: "us",
projectId: 1,
availableProjectIds: [1],
availableOrgIds: [],
hasCodeAccess: true,
needsScopeReauth: false,
},
}),
},
signup: { mutate: vi.fn() },
selectProject: { mutate: vi.fn() },
redeemInviteCode: { mutate: vi.fn() },
logout: {
mutate: vi.fn().mockResolvedValue({
status: "anonymous",
bootstrapComplete: true,
cloudRegion: null,
projectId: null,
availableProjectIds: [],
availableOrgIds: [],
hasCodeAccess: null,
needsScopeReauth: false,
}),
},
},
analytics: {
setUserId: { mutate: vi.fn().mockResolvedValue(undefined) },
resetUser: { mutate: vi.fn().mockResolvedValue(undefined) },
},
},
const authState = {
status: "anonymous" as const,
bootstrapComplete: true,
cloudRegion: null as "us" | "eu" | "dev" | null,
projectId: null,
availableProjectIds: [],
availableOrgIds: [],
hasCodeAccess: null,
needsScopeReauth: false,
};

const mockLoginMutateAsync = vi.fn();
const mockLogoutMutate = vi.fn(() => {
authState.needsScopeReauth = false;
authState.cloudRegion = null;
});

vi.mock("@features/auth/hooks/authQueries", () => ({
useAuthStateValue: (selector: (state: typeof authState) => unknown) =>
selector(authState),
}));

vi.mock("@utils/analytics", () => ({
identifyUser: vi.fn(),
resetUser: vi.fn(),
track: vi.fn(),
vi.mock("@features/auth/hooks/authMutations", () => ({
useLoginMutation: () => ({
mutateAsync: mockLoginMutateAsync,
isPending: false,
}),
useLogoutMutation: () => ({
mutate: mockLogoutMutate,
}),
}));

vi.mock("@utils/logger", () => ({
Expand All @@ -73,40 +48,18 @@ vi.mock("@utils/logger", () => ({
},
}));

vi.mock("@utils/queryClient", () => ({
queryClient: {
clear: vi.fn(),
setQueryData: vi.fn(),
removeQueries: vi.fn(),
},
}));

vi.mock("@stores/navigationStore", () => ({
useNavigationStore: {
getState: () => ({ navigateToTaskInput: vi.fn() }),
},
}));

import {
resetAuthStoreModuleStateForTest,
useAuthStore,
} from "@features/auth/stores/authStore";
import { Theme } from "@radix-ui/themes";
import type { ReactElement } from "react";
import { ScopeReauthPrompt } from "./ScopeReauthPrompt";

function renderWithTheme(ui: ReactElement) {
return render(<Theme>{ui}</Theme>);
}

describe("ScopeReauthPrompt", () => {
beforeEach(() => {
localStorage.clear();
resetAuthStoreModuleStateForTest();
useAuthStore.setState({
needsScopeReauth: false,
cloudRegion: null,
});
vi.clearAllMocks();
authState.status = "anonymous";
authState.cloudRegion = null;
authState.projectId = null;
authState.hasCodeAccess = null;
authState.needsScopeReauth = false;
});

it("does not render dialog when needsScopeReauth is false", () => {
Expand All @@ -117,25 +70,34 @@ describe("ScopeReauthPrompt", () => {
});

it("renders dialog when needsScopeReauth is true", () => {
useAuthStore.setState({ needsScopeReauth: true, cloudRegion: "us" });
authState.needsScopeReauth = true;
authState.cloudRegion = "us";

renderWithTheme(<ScopeReauthPrompt />);

expect(screen.getByText("Re-authentication required")).toBeInTheDocument();
});

it("disables Sign in button when cloudRegion is null", () => {
useAuthStore.setState({ needsScopeReauth: true, cloudRegion: null });
authState.needsScopeReauth = true;

renderWithTheme(<ScopeReauthPrompt />);

expect(screen.getByRole("button", { name: "Sign in" })).toBeDisabled();
});

it("enables Sign in button when cloudRegion is set", () => {
useAuthStore.setState({ needsScopeReauth: true, cloudRegion: "us" });
authState.needsScopeReauth = true;
authState.cloudRegion = "us";

renderWithTheme(<ScopeReauthPrompt />);

expect(screen.getByRole("button", { name: "Sign in" })).not.toBeDisabled();
});

it("shows Log out button as an escape hatch when cloudRegion is null", () => {
useAuthStore.setState({ needsScopeReauth: true, cloudRegion: null });
authState.needsScopeReauth = true;

renderWithTheme(<ScopeReauthPrompt />);

const logoutButton = screen.getByRole("button", { name: "Log out" });
Expand All @@ -145,14 +107,14 @@ describe("ScopeReauthPrompt", () => {

it("calls logout when Log out button is clicked", async () => {
const user = userEvent.setup();
useAuthStore.setState({ needsScopeReauth: true, cloudRegion: null });
authState.needsScopeReauth = true;

renderWithTheme(<ScopeReauthPrompt />);

await user.click(screen.getByRole("button", { name: "Log out" }));

const state = useAuthStore.getState();
expect(state.needsScopeReauth).toBe(false);
expect(state.isAuthenticated).toBe(false);
expect(state.cloudRegion).toBeNull();
expect(mockLogoutMutate).toHaveBeenCalledTimes(1);
expect(authState.needsScopeReauth).toBe(false);
expect(authState.cloudRegion).toBeNull();
});
});
Loading
Loading