diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml index c0a569c..16a5fde 100644 --- a/.github/workflows/deploy.yml +++ b/.github/workflows/deploy.yml @@ -71,6 +71,7 @@ jobs: build-docker: runs-on: ubuntu-latest + if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository steps: - uses: actions/checkout@v4 diff --git a/README.md b/README.md index 53d430f..ca09a91 100644 --- a/README.md +++ b/README.md @@ -283,6 +283,15 @@ cd data ## Login +The bridge offers two login flows: + +- **QR Code** (recommended): scan the QR code with LINE on your mobile device. + On a first login, LINE may show a PIN prompt on mobile; enter the PIN shown + by the bridge. If a saved LINE certificate is still valid, the PIN step is + skipped. +- **Email and Password**: enter the email/password configured on your LINE + account. This remains available as a secondary login path. + ### Via Beeper Desktop Settings 1. Open Beeper Desktop Settings @@ -304,9 +313,10 @@ There are two common reasons login can fail: ### 1. No email is set on your LINE account -This bridge uses the email from your account information. If your -account is older, you signed in using a phone number, or you signed in -with Google, you may not have an email set for your LINE account. +The QR code flow does not require a LINE email address. The email/password +flow uses the email from your account information. If your account is older, +or you signed in with Google or Apple, you may not have an email set for your +LINE account. **How to set an email for your LINE account:** diff --git a/pkg/connector/client.go b/pkg/connector/client.go index a72d873..a9ffcd3 100644 --- a/pkg/connector/client.go +++ b/pkg/connector/client.go @@ -28,6 +28,7 @@ type LineClient struct { reqSeqMu sync.Mutex sentReqSeqs map[int]time.Time + tokenMu sync.Mutex lastReqSeq int // cacheMu protects peerKeys, blockedUsers, contactCache, mediaFlowCache, @@ -98,7 +99,7 @@ func (lc *LineClient) shouldUseE2EEMediaFlow(chatMid string, contentType int) bo } lc.cacheMu.Unlock() - client := line.NewClient(lc.AccessToken) + client := lc.newLineClient() resp, err := client.DetermineMediaMessageFlow(chatMid) if err != nil { lc.UserLogin.Bridge.Log.Warn().Err(err).Str("chat_mid", chatMid). @@ -140,7 +141,31 @@ var _ bridgev2.ReadReceiptHandlingNetworkAPI = (*LineClient)(nil) var _ bridgev2.BackfillingNetworkAPI = (*LineClient)(nil) var _ bridgev2.ReactionHandlingNetworkAPI = (*LineClient)(nil) +func (lc *LineClient) accessToken() string { + lc.tokenMu.Lock() + defer lc.tokenMu.Unlock() + return lc.AccessToken +} + +func (lc *LineClient) setTokens(accessToken, refreshToken string, updateRefreshToken bool) (string, string) { + lc.tokenMu.Lock() + defer lc.tokenMu.Unlock() + + lc.AccessToken = accessToken + if updateRefreshToken { + lc.RefreshToken = refreshToken + } + return lc.AccessToken, lc.RefreshToken +} + +func (lc *LineClient) newLineClient() *line.Client { + return line.NewClient(lc.accessToken()) +} + func (lc *LineClient) refreshAndSave(ctx context.Context) error { + lc.tokenMu.Lock() + defer lc.tokenMu.Unlock() + if lc.RefreshToken == "" { return fmt.Errorf("no refresh token available") } @@ -155,7 +180,6 @@ func (lc *LineClient) refreshAndSave(ctx context.Context) error { if res.RefreshToken != "" { lc.RefreshToken = res.RefreshToken } - // Rotating the main access token invalidates any OBS token derived from it, // so drop the cached one — the next OBS call will mint a fresh one. line.InvalidateOBSTokenCache() @@ -174,7 +198,10 @@ func (lc *LineClient) refreshAndSave(ctx context.Context) error { } func (lc *LineClient) isRefreshRequired(err error) bool { - return strings.Contains(err.Error(), "\"code\":119") || strings.Contains(err.Error(), "Access token refresh required") + msg := err.Error() + return strings.Contains(msg, "\"code\":119") || + strings.Contains(msg, "Access token refresh required") || + (strings.Contains(msg, "\"code\":10051") && strings.Contains(msg, "Authentication Failed")) } func (lc *LineClient) isLoggedOut(err error) bool { @@ -188,9 +215,13 @@ func (lc *LineClient) recoverToken(ctx context.Context) error { if err := lc.refreshAndSave(ctx); err == nil { lc.UserLogin.Bridge.Log.Info().Msg("Token recovered via refresh") return nil + } else { + lc.UserLogin.Bridge.Log.Info().Err(err).Msg("Refresh failed, attempting re-login with stored credentials...") + if errLogin := lc.tryLogin(ctx); errLogin != nil { + return fmt.Errorf("refresh failed: %w; re-login failed: %v", err, errLogin) + } + return nil } - lc.UserLogin.Bridge.Log.Info().Msg("Refresh failed, attempting re-login with stored credentials...") - return lc.tryLogin(ctx) } func (lc *LineClient) Connect(ctx context.Context) { @@ -219,8 +250,8 @@ func (lc *LineClient) Connect(ctx context.Context) { lc.Mid = meta.Mid } } - if lc.AccessToken == "" { - if err := lc.tryLogin(ctx); err != nil { + if lc.accessToken() == "" { + if err := lc.recoverToken(ctx); err != nil { lc.UserLogin.BridgeState.Send(status.BridgeState{ StateEvent: status.StateBadCredentials, Error: "line-login-failed", @@ -242,7 +273,7 @@ func (lc *LineClient) Connect(ctx context.Context) { return } - lc.UserLogin.Bridge.Log.Info().Int("token_len", len(lc.AccessToken)).Msg("LINE client connected; notifying bridge") + lc.UserLogin.Bridge.Log.Info().Int("token_len", len(lc.accessToken())).Msg("LINE client connected; notifying bridge") lc.UserLogin.BridgeState.Send(status.BridgeState{ StateEvent: status.StateConnected, }) @@ -262,7 +293,7 @@ func (lc *LineClient) Connect(ctx context.Context) { } // Storage key is optional for runtime decrypt/encrypt; try it for file support - client := line.NewClient(lc.AccessToken) + client := lc.newLineClient() ei3, err := client.GetEncryptedIdentityV3() if err != nil { lc.UserLogin.Bridge.Log.Warn().Err(err).Msg("Failed to fetch EncryptedIdentityV3") @@ -279,7 +310,7 @@ func (lc *LineClient) Connect(ctx context.Context) { // Fetch initial blocked contacts list before starting sync loops. blockedMIDs, err := func() ([]string, error) { - client := line.NewClient(lc.AccessToken) + client := lc.newLineClient() return client.GetBlockedContactIds() }() if err != nil { @@ -346,7 +377,7 @@ func (lc *LineClient) tryLogin(ctx context.Context) error { lc.UserLogin.Bridge.Log.Info().Msg("Waiting for PIN verification on mobile device...") waitClient := line.NewClient("") - waitRes, err := waitClient.WaitForLogin(res.Verifier, res.NoE2EE) + waitRes, err := waitClient.WaitForLogin(res.Verifier, res.NoE2EE, res.LoginKeyID) if err != nil { return fmt.Errorf("PIN verification failed: %w", err) } @@ -357,15 +388,22 @@ func (lc *LineClient) tryLogin(ctx context.Context) error { res = waitRes client = waitClient } - lc.AccessToken = client.AccessToken + accessToken := client.AccessToken + if accessToken == "" { + accessToken = res.AuthToken + } + refreshToken := "" + updateRefreshToken := false if res.TokenV3IssueResult != nil { if res.TokenV3IssueResult.AccessToken != "" { - lc.AccessToken = res.TokenV3IssueResult.AccessToken + accessToken = res.TokenV3IssueResult.AccessToken } if res.TokenV3IssueResult.RefreshToken != "" { - lc.RefreshToken = res.TokenV3IssueResult.RefreshToken + refreshToken = res.TokenV3IssueResult.RefreshToken + updateRefreshToken = true } } + accessToken, refreshToken = lc.setTokens(accessToken, refreshToken, updateRefreshToken) // Re-login replaces the main access token, which invalidates any cached // OBS token derived from the previous one. @@ -380,8 +418,8 @@ func (lc *LineClient) tryLogin(ctx context.Context) error { // Save the new tokens and updated certificate to metadata if meta, ok := lc.UserLogin.Metadata.(*UserLoginMetadata); ok { - meta.AccessToken = lc.AccessToken - meta.RefreshToken = lc.RefreshToken + meta.AccessToken = accessToken + meta.RefreshToken = refreshToken if res.Certificate != "" { meta.Certificate = res.Certificate } @@ -395,7 +433,7 @@ func (lc *LineClient) tryLogin(ctx context.Context) error { } func (lc *LineClient) ensureValidToken(ctx context.Context) error { - client := line.NewClient(lc.AccessToken) + client := lc.newLineClient() _, err := client.GetProfile() if err == nil { return nil @@ -427,7 +465,7 @@ func (lc *LineClient) Disconnect() { lc.wg.Wait() } -func (lc *LineClient) IsLoggedIn() bool { return lc.AccessToken != "" } +func (lc *LineClient) IsLoggedIn() bool { return lc.accessToken() != "" } func (lc *LineClient) GetUserID() networkid.UserID { return makeUserID(lc.Mid) diff --git a/pkg/connector/connector.go b/pkg/connector/connector.go index 4a9f213..d616a98 100644 --- a/pkg/connector/connector.go +++ b/pkg/connector/connector.go @@ -18,6 +18,13 @@ import ( "github.com/highesttt/matrix-line-messenger/pkg/e2ee" "github.com/highesttt/matrix-line-messenger/pkg/line" + "github.com/highesttt/matrix-line-messenger/pkg/line/secret" +) + +const ( + lineQRLoginFlowID = "qr" + lineEmailLoginFlowID = "dev.highest.matrix.line.email_login" + lineEmailLoginFlowAlias = "email" ) type LineConnector struct { @@ -125,15 +132,230 @@ func (lc *LineConnector) LoadUserLogin(ctx context.Context, login *bridgev2.User } func (lc *LineConnector) GetLoginFlows() []bridgev2.LoginFlow { - return []bridgev2.LoginFlow{{ - Name: "Login", - Description: "Login with your LINE Email and Password", - ID: "dev.highest.matrix.line.email_login", - }} + return []bridgev2.LoginFlow{ + { + Name: "QR Code", + Description: "Login by scanning a QR code with the LINE mobile app", + ID: lineQRLoginFlowID, + }, + { + Name: "Email and Password", + Description: "Login with your LINE email and password", + ID: lineEmailLoginFlowID, + }, + } } func (lc *LineConnector) CreateLogin(ctx context.Context, user *bridgev2.User, flowID string) (bridgev2.LoginProcess, error) { - return &LineEmailLogin{User: user}, nil + switch flowID { + case "", lineQRLoginFlowID: + return &LineQRLogin{User: user, Certificate: storedCertificateForUser(user)}, nil + case lineEmailLoginFlowID, lineEmailLoginFlowAlias: + return &LineEmailLogin{User: user}, nil + default: + return nil, bridgev2.ErrInvalidLoginFlowID + } +} + +func storedCertificateForUser(user *bridgev2.User) string { + if user == nil { + return "" + } + for _, login := range user.GetUserLogins() { + meta, ok := login.Metadata.(*UserLoginMetadata) + if ok && meta.Certificate != "" { + return meta.Certificate + } + } + return "" +} + +type LineQRLogin struct { + User *bridgev2.User + client *line.Client + AuthSessionID string + Certificate string + LoginKeyID int + LongPollingMaxCount int + LongPollingIntervalSeconds int + QRScanned bool + PINRequired bool + + pollErr chan error + pollResult chan struct{} + pollCancel context.CancelFunc + pollMu sync.Mutex +} + +var _ bridgev2.LoginProcessDisplayAndWait = (*LineQRLogin)(nil) + +func (lq *LineQRLogin) Start(ctx context.Context) (*bridgev2.LoginStep, error) { + lq.client = line.NewClient("") + sessionID, err := lq.client.CreateQRSession() + if err != nil { + return nil, fmt.Errorf("failed to create QR session: %w", err) + } + qrCode, err := lq.client.CreateQRCode(sessionID) + if err != nil { + return nil, fmt.Errorf("failed to create QR code: %w", err) + } + secretRes, err := secret.GenerateSecret() + if err != nil { + return nil, fmt.Errorf("failed to generate QR login e2ee secret: %w", err) + } + lq.LoginKeyID = secretRes.LoginKeyID + callbackURL, err := line.QRCodeCallbackURLWithE2EESecret(qrCode.CallbackURL, secretRes.PublicKeyBase64) + if err != nil { + return nil, err + } + + lq.AuthSessionID = sessionID + lq.LongPollingMaxCount = qrCode.LongPollingMaxCount + lq.LongPollingIntervalSeconds = qrCode.LongPollingIntervalSeconds + lq.startPoll(func(ctx context.Context) error { + return lq.client.CheckQRCodeVerifiedContext(ctx, sessionID) + }) + + return &bridgev2.LoginStep{ + Type: bridgev2.LoginStepTypeDisplayAndWait, + StepID: "dev.highest.matrix.line.qr", + Instructions: "Scan this QR code with the LINE mobile app.", + DisplayAndWaitParams: &bridgev2.LoginDisplayAndWaitParams{ + Type: bridgev2.LoginDisplayTypeQR, + Data: callbackURL, + }, + }, nil +} + +func (lq *LineQRLogin) Wait(ctx context.Context) (*bridgev2.LoginStep, error) { + if lq.client == nil || lq.AuthSessionID == "" { + return nil, fmt.Errorf("QR login has not been started") + } + + if lq.PINRequired { + if err := lq.waitForPoll(ctx); err != nil { + return nil, fmt.Errorf("PIN verification failed: %w", err) + } + res, err := lq.client.QRCodeLoginV2(lq.AuthSessionID) + if err != nil { + return nil, fmt.Errorf("failed to complete QR login: %w", err) + } + return finishLineLoginWithLoginKey(ctx, lq.User, "", "", res, lq.LoginKeyID) + } + + if !lq.QRScanned { + if err := lq.waitForPoll(ctx); err != nil { + return nil, fmt.Errorf("QR verification failed: %w", err) + } + lq.QRScanned = true + + if lq.Certificate != "" { + if err := lq.client.VerifyCertificate(lq.AuthSessionID, lq.Certificate); err == nil { + res, err := lq.client.QRCodeLoginV2(lq.AuthSessionID) + if err != nil { + return nil, fmt.Errorf("failed to complete QR login: %w", err) + } + return finishLineLoginWithLoginKey(ctx, lq.User, "", "", res, lq.LoginKeyID) + } else if !line.IsQRLoginCertificateRejected(err) { + return nil, fmt.Errorf("failed to verify QR login certificate: %w", err) + } + } + + pin, err := lq.client.CreatePinCode(lq.AuthSessionID) + if err != nil { + return nil, fmt.Errorf("failed to create QR login PIN: %w", err) + } + lq.PINRequired = true + lq.startPoll(func(ctx context.Context) error { + return lq.client.CheckPinCodeVerifiedContext(ctx, lq.AuthSessionID) + }) + + return linePINStep("dev.highest.matrix.line.qr_pin", pin), nil + } + + return nil, fmt.Errorf("no pending QR login continuation") +} + +func (lq *LineQRLogin) startPoll(fn func(context.Context) error) { + lq.Cancel() + ctx, cancel := context.WithCancel(context.Background()) + pollErr := make(chan error, 1) + pollResult := make(chan struct{}, 1) + + lq.pollMu.Lock() + lq.pollCancel = cancel + lq.pollErr = pollErr + lq.pollResult = pollResult + lq.pollMu.Unlock() + + go func() { + if err := lq.pollWithRetry(ctx, fn); err != nil { + pollErr <- err + } else { + pollResult <- struct{}{} + } + }() +} + +func (lq *LineQRLogin) pollWithRetry(ctx context.Context, fn func(context.Context) error) error { + maxCount := lq.LongPollingMaxCount + if maxCount <= 0 { + maxCount = 1 + } + interval := time.Duration(lq.LongPollingIntervalSeconds) * time.Second + + var err error + for attempt := 0; attempt < maxCount; attempt++ { + err = fn(ctx) + if err == nil { + return nil + } + if ctxErr := ctx.Err(); ctxErr != nil { + return ctxErr + } + if attempt == maxCount-1 { + break + } + if interval <= 0 { + continue + } + timer := time.NewTimer(interval) + select { + case <-timer.C: + case <-ctx.Done(): + if !timer.Stop() { + <-timer.C + } + return ctx.Err() + } + } + return err +} + +func (lq *LineQRLogin) waitForPoll(ctx context.Context) error { + lq.pollMu.Lock() + pollResult := lq.pollResult + pollErr := lq.pollErr + lq.pollMu.Unlock() + + select { + case <-pollResult: + return nil + case err := <-pollErr: + return err + case <-ctx.Done(): + lq.Cancel() + return ctx.Err() + } +} + +func (lq *LineQRLogin) Cancel() { + lq.pollMu.Lock() + defer lq.pollMu.Unlock() + if lq.pollCancel != nil { + lq.pollCancel() + lq.pollCancel = nil + } } type LineEmailLogin struct { @@ -141,6 +363,7 @@ type LineEmailLogin struct { Email string Password string Verifier string + LoginKeyID int AwaitingPIN bool NoE2EE bool // True when login fell back to non-E2EE (LSOFF account) @@ -280,6 +503,7 @@ func (ll *LineEmailLogin) handleLoginResponse(ctx context.Context, res *line.Log if (res.Type == 3 || res.Type == 0) && res.Verifier != "" { ll.Verifier = res.Verifier + ll.LoginKeyID = res.LoginKeyID ll.NoE2EE = res.NoE2EE ll.AwaitingPIN = false instructions := "Please open the LINE app on your mobile device to complete the login." @@ -298,7 +522,7 @@ func (ll *LineEmailLogin) handleLoginResponse(ctx context.Context, res *line.Log ll.pollErr = make(chan error, 1) go func() { client := line.NewClient("") - res, err := client.WaitForLogin(ll.Verifier, ll.NoE2EE) + res, err := client.WaitForLogin(ll.Verifier, ll.NoE2EE, ll.LoginKeyID) if err != nil { ll.pollErr <- err } else { @@ -333,6 +557,18 @@ func (ll *LineEmailLogin) handleLoginResponse(ctx context.Context, res *line.Log } func (ll *LineEmailLogin) finishLogin(ctx context.Context, res *line.LoginResult) (*bridgev2.LoginStep, error) { + return finishLineLogin(ctx, ll.User, ll.Email, ll.Password, res) +} + +func finishLineLogin(ctx context.Context, user *bridgev2.User, email, password string, res *line.LoginResult) (*bridgev2.LoginStep, error) { + loginKeyID := 0 + if res != nil { + loginKeyID = res.LoginKeyID + } + return finishLineLoginWithLoginKey(ctx, user, email, password, res, loginKeyID) +} + +func finishLineLoginWithLoginKey(ctx context.Context, user *bridgev2.User, email, password string, res *line.LoginResult, loginKeyID int) (*bridgev2.LoginStep, error) { if res == nil { return nil, fmt.Errorf("login result missing") } @@ -360,13 +596,13 @@ func (ll *LineEmailLogin) finishLogin(ctx context.Context, res *line.LoginResult displayName = "LINE User" } - meta := &UserLoginMetadata{AccessToken: token, RefreshToken: refreshToken, Email: ll.Email, Password: ll.Password, Certificate: res.Certificate, Mid: res.Mid} + meta := &UserLoginMetadata{AccessToken: token, RefreshToken: refreshToken, Email: email, Password: password, Certificate: res.Certificate, Mid: profile.Mid} - ll.fetchLoginKeys(res, meta, client) + fetchLoginKeys(user, res, meta, client, loginKeyID) detectedLineID := networkid.UserLoginID(profile.Mid) - ul, err := ll.User.NewLogin(ctx, &database.UserLogin{ + ul, err := user.NewLogin(ctx, &database.UserLogin{ ID: detectedLineID, RemoteName: displayName, Metadata: meta, @@ -376,6 +612,7 @@ func (ll *LineEmailLogin) finishLogin(ctx context.Context, res *line.LoginResult UserLogin: login, AccessToken: token, RefreshToken: refreshToken, + Mid: profile.Mid, HTTPClient: &http.Client{Timeout: 10 * time.Second}, } return nil @@ -396,7 +633,18 @@ func (ll *LineEmailLogin) finishLogin(ctx context.Context, res *line.LoginResult }, nil } -func (ll *LineEmailLogin) fetchLoginKeys(res *line.LoginResult, meta *UserLoginMetadata, client *line.Client) { +func linePINStep(stepID, pin string) *bridgev2.LoginStep { + return &bridgev2.LoginStep{ + Type: bridgev2.LoginStepTypeDisplayAndWait, + StepID: stepID, + Instructions: fmt.Sprintf("Please open the LINE app on your mobile device and enter this PIN code: %s", pin), + DisplayAndWaitParams: &bridgev2.LoginDisplayAndWaitParams{ + Type: bridgev2.LoginDisplayTypeNothing, + }, + } +} + +func fetchLoginKeys(user *bridgev2.User, res *line.LoginResult, meta *UserLoginMetadata, client *line.Client, loginKeyID int) { if res.EncryptedKeyChain == "" || res.E2EEPublicKey == "" { return } @@ -406,26 +654,26 @@ func (ll *LineEmailLogin) fetchLoginKeys(res *line.LoginResult, meta *UserLoginM meta.E2EEKeyID = res.E2EEKeyID mgr, err := e2ee.NewManager() if err != nil { - ll.User.Bridge.Log.Warn().Err(err).Msg("Login: failed to create E2EE manager") + user.Bridge.Log.Warn().Err(err).Msg("Login: failed to create E2EE manager") return } ei3, err := client.GetEncryptedIdentityV3() if err != nil { - ll.User.Bridge.Log.Warn().Err(err).Msg("Login: failed to get EncryptedIdentityV3") + user.Bridge.Log.Warn().Err(err).Msg("Login: failed to get EncryptedIdentityV3") return } if err := mgr.InitStorage(ei3.WrappedNonce, ei3.KDFParameter1, ei3.KDFParameter2); err != nil { - ll.User.Bridge.Log.Warn().Err(err).Msg("Login: InitStorage failed") + user.Bridge.Log.Warn().Err(err).Msg("Login: InitStorage failed") return } - exported, err := mgr.InitFromLoginKeyChain(res.E2EEPublicKey, res.EncryptedKeyChain) + exported, err := mgr.InitFromLoginKeyChainWithKey(loginKeyID, res.E2EEPublicKey, res.EncryptedKeyChain) if err != nil { - ll.User.Bridge.Log.Warn().Err(err).Msg("Login: InitFromLoginKeyChain failed") + user.Bridge.Log.Warn().Err(err).Msg("Login: InitFromLoginKeyChain failed") return } meta.ExportedKeyMap = exported - _ = mgr.SaveSecureDataToFile(string(ll.User.MXID), map[string]any{"exportedKeyMap": exported}) - ll.User.Bridge.Log.Info().Int("keys", len(exported)).Msg("Login: E2EE keys exported successfully") + _ = mgr.SaveSecureDataToFile(string(user.MXID), map[string]any{"exportedKeyMap": exported}) + user.Bridge.Log.Info().Int("keys", len(exported)).Msg("Login: E2EE keys exported successfully") } func (ll *LineEmailLogin) Cancel() {} diff --git a/pkg/connector/creategroup.go b/pkg/connector/creategroup.go index eb6ddbc..1e09360 100644 --- a/pkg/connector/creategroup.go +++ b/pkg/connector/creategroup.go @@ -26,7 +26,7 @@ func (lc *LineClient) CreateGroup(ctx context.Context, params *bridgev2.GroupCre name = params.Name.Name } - client := line.NewClient(lc.AccessToken) + client := lc.newLineClient() var chat *line.Chat var err error chatType := 1 // ROOM: members join automatically. @@ -34,7 +34,7 @@ func (lc *LineClient) CreateGroup(ctx context.Context, params *bridgev2.GroupCre chat, err = client.CreateChat(participantMids, lineName, chatType) if err != nil && (lc.isRefreshRequired(err) || lc.isLoggedOut(err)) { if errRecover := lc.recoverToken(ctx); errRecover == nil { - client = line.NewClient(lc.AccessToken) + client = lc.newLineClient() chat, err = client.CreateChat(participantMids, lineName, chatType) } } @@ -150,7 +150,7 @@ func (lc *LineClient) registerGroupKey(ctx context.Context, chatMid string, memb } members = otherMembers - client := line.NewClient(lc.AccessToken) + client := lc.newLineClient() // Fetch current E2EE public keys for all other members as a batch. If the batch // call fails (e.g. server 500 for a specific member), fall back to fetching @@ -163,7 +163,7 @@ func (lc *LineClient) registerGroupKey(ctx context.Context, chatMid string, memb if err != nil { if lc.isRefreshRequired(err) || lc.isLoggedOut(err) { if errRecover := lc.recoverToken(ctx); errRecover == nil { - client = line.NewClient(lc.AccessToken) + client = lc.newLineClient() pubKeys, err = client.GetLastE2EEPublicKeys(pubKeysReq) } } @@ -184,7 +184,7 @@ func (lc *LineClient) registerGroupKey(ctx context.Context, chatMid string, memb } if lc.isRefreshRequired(nErr) || lc.isLoggedOut(nErr) { if errRecover := lc.recoverToken(ctx); errRecover == nil { - client = line.NewClient(lc.AccessToken) + client = lc.newLineClient() res, nErr = client.NegotiateE2EEPublicKey(mid) } } @@ -259,7 +259,7 @@ func (lc *LineClient) registerGroupKey(ctx context.Context, chatMid string, memb if err := client.RegisterE2EEGroupKey(1, chatMid, apiMembers, keyIds, encryptedKeys); err != nil { if lc.isRefreshRequired(err) || lc.isLoggedOut(err) { if errRecover := lc.recoverToken(ctx); errRecover == nil { - client = line.NewClient(lc.AccessToken) + client = lc.newLineClient() err = client.RegisterE2EEGroupKey(1, chatMid, apiMembers, keyIds, encryptedKeys) } } diff --git a/pkg/connector/e2ee_keys.go b/pkg/connector/e2ee_keys.go index 5f453b5..10e9362 100644 --- a/pkg/connector/e2ee_keys.go +++ b/pkg/connector/e2ee_keys.go @@ -21,7 +21,7 @@ func (lc *LineClient) fetchAndUnwrapGroupKey(ctx context.Context, chatMid string return fmt.Errorf("E2EE manager not initialized") } - client := line.NewClient(lc.AccessToken) + client := lc.newLineClient() fetch := func() (*line.E2EEGroupSharedKey, error) { if groupKeyID > 0 { return client.GetE2EEGroupSharedKey(chatMid, groupKeyID) @@ -44,7 +44,7 @@ func (lc *LineClient) fetchAndUnwrapGroupKey(ctx context.Context, chatMid string // Token recovery for other error types if err != nil && !line.IsNoUsableE2EEGroupKey(err) && (lc.isRefreshRequired(err) || lc.isLoggedOut(err)) { if errRecover := lc.recoverToken(ctx); errRecover == nil { - client = line.NewClient(lc.AccessToken) + client = lc.newLineClient() sharedKey, err = fetch() } else { return fmt.Errorf("failed to recover token before fetching group key: %w", errRecover) @@ -86,7 +86,7 @@ func (lc *LineClient) fetchAndUnwrapGroupKey(ctx context.Context, chatMid string return nil } -func (lc *LineClient) ensurePeerKey(_ context.Context, mid string) (int, string, error) { +func (lc *LineClient) ensurePeerKey(ctx context.Context, mid string) (int, string, error) { lc.cacheMu.Lock() if lc.peerKeys == nil { lc.peerKeys = make(map[string]peerKeyInfo) @@ -107,8 +107,14 @@ func (lc *LineClient) ensurePeerKey(_ context.Context, mid string) (int, string, return cached.raw, cached.pub, nil } } - client := line.NewClient(lc.AccessToken) + client := lc.newLineClient() res, err := client.NegotiateE2EEPublicKey(mid) + if err != nil && (lc.isRefreshRequired(err) || lc.isLoggedOut(err)) { + if errRecover := lc.recoverToken(ctx); errRecover == nil { + client = lc.newLineClient() + res, err = client.NegotiateE2EEPublicKey(mid) + } + } if err != nil { // Cache negative result so we don't keep hitting the API if line.IsNoUsableE2EEPublicKey(err) { @@ -165,12 +171,12 @@ func (lc *LineClient) clearGroupNoE2EE(chatMid string) { // Invitees are included because group key registration must happen before they accept, // otherwise the key won't be available when they start sending messages. func (lc *LineClient) getChatMemberMIDs(ctx context.Context, chatMid string) ([]string, error) { - client := line.NewClient(lc.AccessToken) + client := lc.newLineClient() chats, err := client.GetChats([]string{chatMid}, true, true) if err != nil { if lc.isRefreshRequired(err) || lc.isLoggedOut(err) { if errRecover := lc.recoverToken(ctx); errRecover == nil { - client = line.NewClient(lc.AccessToken) + client = lc.newLineClient() chats, err = client.GetChats([]string{chatMid}, true, true) } } @@ -305,7 +311,7 @@ func (lc *LineClient) getGroupMemberMIDsViaMatrix(ctx context.Context, chatMid s return mids, nil } -func (lc *LineClient) ensurePeerKeyByID(_ context.Context, mid string, keyID int) (int, string, error) { +func (lc *LineClient) ensurePeerKeyByID(ctx context.Context, mid string, keyID int) (int, string, error) { lc.cacheMu.Lock() if lc.peerKeys == nil { lc.peerKeys = make(map[string]peerKeyInfo) @@ -319,9 +325,15 @@ func (lc *LineClient) ensurePeerKeyByID(_ context.Context, mid string, keyID int return cached.raw, cached.pub, nil } - client := line.NewClient(lc.AccessToken) + client := lc.newLineClient() // keyVersion 1 res, err := client.GetE2EEPublicKey(mid, 1, keyID) + if err != nil && (lc.isRefreshRequired(err) || lc.isLoggedOut(err)) { + if errRecover := lc.recoverToken(ctx); errRecover == nil { + client = lc.newLineClient() + res, err = client.GetE2EEPublicKey(mid, 1, keyID) + } + } if err != nil { return 0, "", err } diff --git a/pkg/connector/handle_message.go b/pkg/connector/handle_message.go index 2c7fb29..75a9aec 100644 --- a/pkg/connector/handle_message.go +++ b/pkg/connector/handle_message.go @@ -30,7 +30,7 @@ func (lc *LineClient) newMessageHandler() *handlers.Handler { RecoverToken: lc.recoverToken, IsRefreshRequired: lc.isRefreshRequired, IsLoggedOut: lc.isLoggedOut, - NewClient: func() *line.Client { return line.NewClient(lc.AccessToken) }, + NewClient: func() *line.Client { return lc.newLineClient() }, DecryptMedia: lc.decryptImageData, } } diff --git a/pkg/connector/handlers/handler.go b/pkg/connector/handlers/handler.go index 12b1a66..46e65ee 100644 --- a/pkg/connector/handlers/handler.go +++ b/pkg/connector/handlers/handler.go @@ -10,6 +10,8 @@ import ( "github.com/highesttt/matrix-line-messenger/pkg/line" ) +const obsDownloadUnauthorizedMsg = "OBS download failed (401)" + // Handler provides dependencies needed by content type conversion functions. type Handler struct { Log zerolog.Logger @@ -33,6 +35,19 @@ func (h *Handler) tryRecoverClient(ctx context.Context, err error) (*line.Client if err == nil { return nil, false } + isOBSAuthFailure := strings.Contains(err.Error(), obsDownloadUnauthorizedMsg) + if isOBSAuthFailure { + h.Log.Warn(). + Err(err). + Str("action", "clearing_encrypted_access_token_cache"). + Msg("Recovering stale OBS access token") + // isOBSAuthFailure means the OBS token derived from the main LINE access + // token is stale. Use line.ClearEncryptedAccessTokenCache() and retry + // with h.NewClient() instead of invoking the broader 401 + // refresh/logout recovery. + line.ClearEncryptedAccessTokenCache() + return h.NewClient(), true + } if !strings.Contains(err.Error(), "401") && !h.IsRefreshRequired(err) && !h.IsLoggedOut(err) { return nil, false } diff --git a/pkg/connector/reaction.go b/pkg/connector/reaction.go index b571861..88cb2df 100644 --- a/pkg/connector/reaction.go +++ b/pkg/connector/reaction.go @@ -365,7 +365,7 @@ func (lc *LineClient) HandleMatrixReaction(ctx context.Context, msg *bridgev2.Ma return nil, err } - client := line.NewClient(lc.AccessToken) + client := lc.newLineClient() reqSeq := lc.nextReqSeq() if err = client.React(int64(reqSeq), targetID, ref.reactionType()); err != nil { if line.IsInvalidPaidReactionType(err) { @@ -391,7 +391,7 @@ func (lc *LineClient) HandleMatrixReactionRemove(ctx context.Context, msg *bridg if err != nil { return err } - client := line.NewClient(lc.AccessToken) + client := lc.newLineClient() reqSeq := lc.nextReqSeq() err = client.CancelReaction(int64(reqSeq), targetID) if line.IsNotAMemberError(err) { diff --git a/pkg/connector/send_message.go b/pkg/connector/send_message.go index ec42b8e..aec8b9d 100644 --- a/pkg/connector/send_message.go +++ b/pkg/connector/send_message.go @@ -31,7 +31,7 @@ type mentionEntry struct { var mentionLinkRegex = regexp.MustCompile(`]*href="https://matrix\.to/#/([^"]+)"[^>]*>([^<]+)`) func (lc *LineClient) HandleMatrixMessage(ctx context.Context, msg *bridgev2.MatrixMessage) (*bridgev2.MatrixMessageResponse, error) { - client := line.NewClient(lc.AccessToken) + client := lc.newLineClient() portalMid := string(msg.Portal.ID) fromMid := lc.midOrFallback() @@ -793,7 +793,7 @@ func contentTypeForMsgType(msgType event.MessageType) int { } func (lc *LineClient) HandleMatrixMessageRemove(ctx context.Context, msg *bridgev2.MatrixMessageRemove) error { - client := line.NewClient(lc.AccessToken) + client := lc.newLineClient() reqSeq := int(time.Now().UnixMilli() % 1_000_000_000) lc.trackReqSeq(reqSeq) @@ -810,7 +810,7 @@ func (lc *LineClient) HandleMatrixMessageRemove(ctx context.Context, msg *bridge } func (lc *LineClient) HandleMatrixLeaveRoom(ctx context.Context, portal *bridgev2.Portal) error { - client := line.NewClient(lc.AccessToken) + client := lc.newLineClient() reqSeq := int(time.Now().UnixMilli() % 1_000_000_000) @@ -833,7 +833,7 @@ var _ bridgev2.MessageRequestAcceptingNetworkAPI = (*LineClient)(nil) // HandleMatrixAcceptMessageRequest is called when the user accepts a Request in Beeper (a // pending LINE group invitation). It accepts the invitation on the LINE side, joining the chat. func (lc *LineClient) HandleMatrixAcceptMessageRequest(ctx context.Context, msg *bridgev2.MatrixAcceptMessageRequest) error { - client := line.NewClient(lc.AccessToken) + client := lc.newLineClient() reqSeq := int64(time.Now().UnixMilli() % 1_000_000_000) return client.AcceptChatInvitation(reqSeq, string(msg.Portal.ID)) } diff --git a/pkg/connector/sync.go b/pkg/connector/sync.go index 6c782d6..28e8249 100644 --- a/pkg/connector/sync.go +++ b/pkg/connector/sync.go @@ -26,7 +26,7 @@ import ( func (lc *LineClient) syncDMChats(ctx context.Context) { defer lc.wg.Done() - client := line.NewClient(lc.AccessToken) + client := lc.newLineClient() opts := line.MessageBoxesOptions{ ActiveOnly: true, MessageBoxCountLimit: 100, @@ -37,7 +37,7 @@ func (lc *LineClient) syncDMChats(ctx context.Context) { res, err := client.GetMessageBoxes(opts) if err != nil && (lc.isRefreshRequired(err) || lc.isLoggedOut(err)) { if errRecover := lc.recoverToken(ctx); errRecover == nil { - client = line.NewClient(lc.AccessToken) + client = lc.newLineClient() res, err = client.GetMessageBoxes(opts) } } @@ -146,11 +146,11 @@ func (lc *LineClient) FetchMessages(ctx context.Context, params bridgev2.FetchMe limit = 50 } - client := line.NewClient(lc.AccessToken) + client := lc.newLineClient() msgs, err := client.GetRecentMessagesV2(chatMID, limit) if err != nil && (lc.isRefreshRequired(err) || lc.isLoggedOut(err)) { if errRecover := lc.recoverToken(ctx); errRecover == nil { - client = line.NewClient(lc.AccessToken) + client = lc.newLineClient() msgs, err = client.GetRecentMessagesV2(chatMID, limit) } } @@ -212,7 +212,7 @@ func (lc *LineClient) FetchMessages(ctx context.Context, params bridgev2.FetchMe func (lc *LineClient) prefetchMessages(ctx context.Context) { defer lc.wg.Done() - client := line.NewClient(lc.AccessToken) + client := lc.newLineClient() opts := line.MessageBoxesOptions{ ActiveOnly: true, MessageBoxCountLimit: 100, @@ -223,7 +223,7 @@ func (lc *LineClient) prefetchMessages(ctx context.Context) { res, err := client.GetMessageBoxes(opts) if err != nil && (lc.isRefreshRequired(err) || lc.isLoggedOut(err)) { if errRecover := lc.recoverToken(ctx); errRecover == nil { - client = line.NewClient(lc.AccessToken) + client = lc.newLineClient() res, err = client.GetMessageBoxes(opts) } } @@ -243,7 +243,7 @@ func (lc *LineClient) prefetchMessages(ctx context.Context) { // notifies for any not-yet-bridged messages; the silent backfill path used on // unblock goes through FetchMessages instead. func (lc *LineClient) backfillRecentMessages(ctx context.Context, chatMID string, limit int) { - client := line.NewClient(lc.AccessToken) + client := lc.newLineClient() msgs, err := client.GetRecentMessagesV2(chatMID, limit) if err != nil { lc.UserLogin.Bridge.Log.Warn().Err(err).Str("chat_mid", chatMID).Msg("Failed to fetch recent messages") @@ -272,11 +272,11 @@ func (lc *LineClient) backfillRecentMessages(ctx context.Context, chatMID string func (lc *LineClient) syncChats(ctx context.Context) { defer lc.wg.Done() - client := line.NewClient(lc.AccessToken) + client := lc.newLineClient() midsResp, err := client.GetAllChatMids(true, true) if err != nil && (lc.isRefreshRequired(err) || lc.isLoggedOut(err)) { if errRecover := lc.recoverToken(ctx); errRecover == nil { - client = line.NewClient(lc.AccessToken) + client = lc.newLineClient() midsResp, err = client.GetAllChatMids(true, true) } } @@ -300,7 +300,7 @@ func (lc *LineClient) syncChats(ctx context.Context) { chatsResp, err := client.GetChats(batch, true, true) if err != nil && (lc.isRefreshRequired(err) || lc.isLoggedOut(err)) { if errRecover := lc.recoverToken(ctx); errRecover == nil { - client = line.NewClient(lc.AccessToken) + client = lc.newLineClient() chatsResp, err = client.GetChats(batch, true, true) } } @@ -564,11 +564,11 @@ func (lc *LineClient) cacheGroupMembersFromRecentMessages(ctx context.Context, c if len(lc.getCachedGroupMembers(chatMid)) > 1 { return } - client := line.NewClient(lc.AccessToken) + client := lc.newLineClient() msgs, err := client.GetRecentMessagesV2(chatMid, 50) if err != nil && (lc.isRefreshRequired(err) || lc.isLoggedOut(err)) { if errRecover := lc.recoverToken(ctx); errRecover == nil { - client = line.NewClient(lc.AccessToken) + client = lc.newLineClient() msgs, err = client.GetRecentMessagesV2(chatMid, 50) } } @@ -663,13 +663,13 @@ func (lc *LineClient) pollLoop(ctx context.Context) { defer lc.wg.Done() var localRev int64 = 0 - client := line.NewClient(lc.AccessToken) + client := lc.newLineClient() lc.UserLogin.Bridge.Log.Info().Msg("Starting LINE SSE loop...") rev, err := client.GetLastOpRevision() if err != nil && (lc.isRefreshRequired(err) || lc.isLoggedOut(err)) { if errRecover := lc.recoverToken(ctx); errRecover == nil { - client = line.NewClient(lc.AccessToken) + client = lc.newLineClient() rev, err = client.GetLastOpRevision() } else { lc.UserLogin.Bridge.Log.Warn().Err(errRecover).Msg("Failed to recover token for getLastOpRevision") @@ -751,7 +751,7 @@ func (lc *LineClient) pollLoop(ctx context.Context) { }) return } - client = line.NewClient(lc.AccessToken) + client = lc.newLineClient() } } time.Sleep(3 * time.Second) @@ -1233,11 +1233,11 @@ func (lc *LineClient) clearReactionDedupEntries(msgID string, removeOnly bool) { func (lc *LineClient) syncSingleChat(ctx context.Context, op line.Operation) { chatMid := op.Param1 - client := line.NewClient(lc.AccessToken) + client := lc.newLineClient() chatsResp, err := client.GetChats([]string{chatMid}, true, true) if err != nil && (lc.isRefreshRequired(err) || lc.isLoggedOut(err)) { if errRecover := lc.recoverToken(ctx); errRecover == nil { - client = line.NewClient(lc.AccessToken) + client = lc.newLineClient() chatsResp, err = client.GetChats([]string{chatMid}, true, true) } } @@ -1294,11 +1294,11 @@ func (lc *LineClient) syncSingleChat(ctx context.Context, op line.Operation) { // checkChatMembership calls GetAllChatMids to verify whether the bridge user // is a member or invitee of the given chat. func (lc *LineClient) checkChatMembership(ctx context.Context, chatMid string) (isMember, isInvitee bool) { - client := line.NewClient(lc.AccessToken) + client := lc.newLineClient() midsResp, err := client.GetAllChatMids(true, true) if err != nil && (lc.isRefreshRequired(err) || lc.isLoggedOut(err)) { if errRecover := lc.recoverToken(ctx); errRecover == nil { - client = line.NewClient(lc.AccessToken) + client = lc.newLineClient() midsResp, err = client.GetAllChatMids(true, true) } } @@ -1369,11 +1369,11 @@ func (lc *LineClient) handleMemberJoin(chatMid, joinerMid string) { } func (lc *LineClient) handleInvite(ctx context.Context, chatMid string, opType OperationType) { - client := line.NewClient(lc.AccessToken) + client := lc.newLineClient() chatsResp, err := client.GetChats([]string{chatMid}, true, true) if err != nil && (lc.isRefreshRequired(err) || lc.isLoggedOut(err)) { if errRecover := lc.recoverToken(ctx); errRecover == nil { - client = line.NewClient(lc.AccessToken) + client = lc.newLineClient() chatsResp, err = client.GetChats([]string{chatMid}, true, true) } } @@ -1417,11 +1417,11 @@ func (lc *LineClient) handleInvite(ctx context.Context, chatMid string, opType O } func (lc *LineClient) handleInviteForSelf(ctx context.Context, chatMid string) { - client := line.NewClient(lc.AccessToken) + client := lc.newLineClient() chatsResp, err := client.GetChats([]string{chatMid}, true, true) if err != nil && (lc.isRefreshRequired(err) || lc.isLoggedOut(err)) { if errRecover := lc.recoverToken(ctx); errRecover == nil { - client = line.NewClient(lc.AccessToken) + client = lc.newLineClient() chatsResp, err = client.GetChats([]string{chatMid}, true, true) } } diff --git a/pkg/connector/userinfo.go b/pkg/connector/userinfo.go index a6d4388..4255eb8 100644 --- a/pkg/connector/userinfo.go +++ b/pkg/connector/userinfo.go @@ -37,7 +37,7 @@ func (lc *LineClient) HandleMatrixReadReceipt(ctx context.Context, read *bridgev return nil } - client := line.NewClient(lc.AccessToken) + client := lc.newLineClient() return client.SendChatChecked(string(read.Portal.ID), targetID) } @@ -140,11 +140,11 @@ func (lc *LineClient) GetChatInfo(ctx context.Context, portal *bridgev2.Portal) mid := string(portal.ID) lowerMid := strings.ToLower(mid) if strings.HasPrefix(lowerMid, "c") || strings.HasPrefix(lowerMid, "r") { - client := line.NewClient(lc.AccessToken) + client := lc.newLineClient() res, err := client.GetChats([]string{mid}, true, true) if err != nil && (lc.isRefreshRequired(err) || lc.isLoggedOut(err)) { if errRecover := lc.recoverToken(ctx); errRecover == nil { - client = line.NewClient(lc.AccessToken) + client = lc.newLineClient() res, err = client.GetChats([]string{mid}, true, true) } } @@ -207,11 +207,11 @@ func (lc *LineClient) getContact(ctx context.Context, mid string) line.Contact { // Use GetProfile for our own user data if mid == lc.Mid || mid == string(lc.UserLogin.ID) { - client := line.NewClient(lc.AccessToken) + client := lc.newLineClient() profile, err := client.GetProfile() if err != nil && (lc.isRefreshRequired(err) || lc.isLoggedOut(err)) { if errRecover := lc.recoverToken(ctx); errRecover == nil { - client = line.NewClient(lc.AccessToken) + client = lc.newLineClient() profile, err = client.GetProfile() } } @@ -223,11 +223,11 @@ func (lc *LineClient) getContact(ctx context.Context, mid string) line.Contact { return line.Contact{Mid: mid, DisplayName: mid} } - client := line.NewClient(lc.AccessToken) + client := lc.newLineClient() res, err := client.GetContactsV2([]string{mid}) if err != nil && (lc.isRefreshRequired(err) || lc.isLoggedOut(err)) { if errRecover := lc.recoverToken(ctx); errRecover == nil { - client = line.NewClient(lc.AccessToken) + client = lc.newLineClient() res, err = client.GetContactsV2([]string{mid}) } } @@ -243,7 +243,7 @@ func (lc *LineClient) getContact(ctx context.Context, mid string) line.Contact { buddy, err := client.GetBuddyProfile(mid) if err != nil && (lc.isRefreshRequired(err) || lc.isLoggedOut(err)) { if errRecover := lc.recoverToken(ctx); errRecover == nil { - client = line.NewClient(lc.AccessToken) + client = lc.newLineClient() buddy, err = client.GetBuddyProfile(mid) } } @@ -314,7 +314,7 @@ func (lc *LineClient) SearchUsers(ctx context.Context, query string) ([]*bridgev // Try by LINE user ID first lowerQuery := strings.ToLower(strings.TrimSpace(query)) if lowerQuery != "" { - client := line.NewClient(lc.AccessToken) + client := lc.newLineClient() contact, err := client.FindContactByUserid(lowerQuery) if err == nil && contact != nil && contact.Mid != "" { if r := lc.midToResolveIdentifier(ctx, contact.Mid); r != nil { @@ -327,12 +327,12 @@ func (lc *LineClient) SearchUsers(ctx context.Context, query string) ([]*bridgev } // Search contacts by display name - client := line.NewClient(lc.AccessToken) + client := lc.newLineClient() allMids, err := client.GetAllContactIds() if err != nil { if lc.isRefreshRequired(err) || lc.isLoggedOut(err) { if errRecover := lc.recoverToken(ctx); errRecover == nil { - client = line.NewClient(lc.AccessToken) + client = lc.newLineClient() allMids, err = client.GetAllContactIds() } } @@ -373,12 +373,12 @@ func (lc *LineClient) SearchUsers(ctx context.Context, query string) ([]*bridgev var _ bridgev2.UserSearchingNetworkAPI = (*LineClient)(nil) func (lc *LineClient) GetContactList(ctx context.Context) ([]*bridgev2.ResolveIdentifierResponse, error) { - client := line.NewClient(lc.AccessToken) + client := lc.newLineClient() allMids, err := client.GetAllContactIds() if err != nil { if lc.isRefreshRequired(err) || lc.isLoggedOut(err) { if errRecover := lc.recoverToken(ctx); errRecover == nil { - client = line.NewClient(lc.AccessToken) + client = lc.newLineClient() allMids, err = client.GetAllContactIds() } } diff --git a/pkg/e2ee/manager.go b/pkg/e2ee/manager.go index de520b3..b3c0b6d 100644 --- a/pkg/e2ee/manager.go +++ b/pkg/e2ee/manager.go @@ -196,7 +196,11 @@ func (m *Manager) SaveSecureDataToFile(id string, data map[string]any) error { } func (m *Manager) InitFromLoginKeyChain(serverPubB64, encryptedKeyChainB64 string) (map[string]string, error) { - keys, err := m.runner.LoginUnwrapKeyChain(serverPubB64, encryptedKeyChainB64) + return m.InitFromLoginKeyChainWithKey(0, serverPubB64, encryptedKeyChainB64) +} + +func (m *Manager) InitFromLoginKeyChainWithKey(loginKeyID int, serverPubB64, encryptedKeyChainB64 string) (map[string]string, error) { + keys, err := m.runner.LoginUnwrapKeyChainWithKey(loginKeyID, serverPubB64, encryptedKeyChainB64) if err != nil { return nil, err } diff --git a/pkg/line/client.go b/pkg/line/client.go index d65f06e..3af8039 100644 --- a/pkg/line/client.go +++ b/pkg/line/client.go @@ -12,6 +12,7 @@ import ( "log" "net/http" "net/url" + "strconv" "strings" "time" @@ -22,6 +23,7 @@ import ( const ( BaseURL = "https://line-chrome-gw.line-apps.com/api/talk/thrift/Talk" + QRLoginBaseURL = "https://line-chrome-gw.line-apps.com/api/talk/thrift/LoginQrCode" ShopBaseURL = "https://line-chrome-gw.line-apps.com/api/shop/thrift/ShopService" ExtensionVersion = "3.7.2" UserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/143.0.0.0 Safari/537.36" @@ -119,6 +121,7 @@ func (c *Client) Login(email, pass, certificate string) (*LoginResult, error) { } res.NoE2EE = noE2EE + res.LoginKeyID = secretRes.LoginKeyID // Prefer the V3 token if present, otherwise fall back to legacy authToken. // LINE returns only the V3 token when re-authenticating with a stored @@ -141,11 +144,11 @@ func isLoginNotSupported(err error) bool { return strings.Contains(msg, "\"code\":89") || strings.Contains(msg, "not supported") } -func (c *Client) WaitForLogin(verifier string, noE2EE bool) (*LoginResult, error) { +func (c *Client) WaitForLogin(verifier string, noE2EE bool, loginKeyID int) (*LoginResult, error) { if noE2EE { return c.waitForLoginJQ(verifier) } - return c.waitForLoginLF1(verifier) + return c.waitForLoginLF1(verifier, loginKeyID) } // waitForLoginJQ polls the JQ endpoint for LSOFF accounts (no E2EE). @@ -213,7 +216,7 @@ func (c *Client) waitForLoginJQ(verifier string) (*LoginResult, error) { } // waitForLoginLF1 polls the LF1 endpoint for LSON accounts (E2EE). -func (c *Client) waitForLoginLF1(verifier string) (*LoginResult, error) { +func (c *Client) waitForLoginLF1(verifier string, loginKeyID int) (*LoginResult, error) { url := "https://line-chrome-gw.line-apps.com/api/talk/long-polling/LF1" req, err := http.NewRequest("GET", url, nil) @@ -261,7 +264,7 @@ func (c *Client) waitForLoginLF1(verifier string) (*LoginResult, error) { // LSON path: confirm E2EE handshake first, then finalize with verifier if meta.EncryptedKeyChain != "" && meta.PublicKey != "" { - if err := c.ConfirmE2EELogin(verifier, meta.PublicKey, meta.EncryptedKeyChain); err != nil { + if err := c.ConfirmE2EELoginWithKey(loginKeyID, verifier, meta.PublicKey, meta.EncryptedKeyChain); err != nil { log.Printf("[LINE] ConfirmE2EELogin failed: %v", err) } else { if res, err := c.LoginV2WithVerifier(verifier); err != nil { @@ -271,6 +274,7 @@ func (c *Client) waitForLoginLF1(verifier string) (*LoginResult, error) { res.E2EEPublicKey = meta.PublicKey res.E2EEVersion = meta.E2EEVersion res.E2EEKeyID = meta.KeyID + res.LoginKeyID = loginKeyID return res, nil } } @@ -281,6 +285,7 @@ func (c *Client) waitForLoginLF1(verifier string) (*LoginResult, error) { return &LoginResult{ AuthToken: meta.AuthToken, Certificate: meta.Certificate, + LoginKeyID: loginKeyID, }, nil } @@ -289,6 +294,7 @@ func (c *Client) waitForLoginLF1(verifier string) (*LoginResult, error) { if res, err := c.LoginV2WithVerifier(verifier); err != nil { return nil, fmt.Errorf("login finalization failed: %w", err) } else { + res.LoginKeyID = loginKeyID return res, nil } } @@ -376,12 +382,18 @@ func (c *Client) callRPCWithBaseURL(baseURL, service, method string, args ...int // ConfirmE2EELogin completes the E2EE handshake after LF1 by hashing the encrypted key // chain and posting it alongside the verifier. func (c *Client) ConfirmE2EELogin(verifier, serverPublicKeyB64, encryptedKeyChainB64 string) error { + return c.ConfirmE2EELoginWithKey(0, verifier, serverPublicKeyB64, encryptedKeyChainB64) +} + +// ConfirmE2EELoginWithKey completes the E2EE handshake using a specific login +// key generated for this login attempt. +func (c *Client) ConfirmE2EELoginWithKey(loginKeyID int, verifier, serverPublicKeyB64, encryptedKeyChainB64 string) error { runner, err := gen.GetRunner() if err != nil { return fmt.Errorf("failed to init runner: %w", err) } - hash, err := runner.GenerateConfirmHash(serverPublicKeyB64, encryptedKeyChainB64) + hash, err := runner.GenerateConfirmHashWithKey(loginKeyID, serverPublicKeyB64, encryptedKeyChainB64) if err != nil { return fmt.Errorf("failed to derive confirm hash: %w", err) } @@ -412,10 +424,25 @@ func (c *Client) ConfirmE2EELogin(verifier, serverPublicKeyB64, encryptedKeyChai return nil } +type hmacPostOptions struct { + includeLineApplication bool + sessionID string + longPollingTimeout string + ctx context.Context +} + // postWithHMAC is a small helper for non-standard RPC endpoints that still expect // the same headers and HMAC signature as the Talk endpoints. func (c *Client) postWithHMAC(fullURL string, body []byte) ([]byte, error) { - req, err := http.NewRequest("POST", fullURL, bytes.NewBuffer(body)) + return c.postWithHMACOptions(fullURL, body, hmacPostOptions{includeLineApplication: true}) +} + +func (c *Client) postWithHMACOptions(fullURL string, body []byte, opts hmacPostOptions) ([]byte, error) { + ctx := opts.ctx + if ctx == nil { + ctx = context.Background() + } + req, err := http.NewRequestWithContext(ctx, "POST", fullURL, bytes.NewBuffer(body)) if err != nil { return nil, fmt.Errorf("failed to create request: %w", err) } @@ -423,8 +450,16 @@ func (c *Client) postWithHMAC(fullURL string, body []byte) ([]byte, error) { req.Header.Set("Content-Type", "application/json") req.Header.Set("User-Agent", UserAgent) req.Header.Set("x-line-chrome-version", ExtensionVersion) - req.Header.Set("x-line-application", "CHROMEOS\t3.7.2\tChrome_OS") + if opts.includeLineApplication { + req.Header.Set("x-line-application", "CHROMEOS\t3.7.2\tChrome_OS") + } req.Header.Set("x-lal", "en_US") + if opts.sessionID != "" { + req.Header.Set("X-Line-Session-ID", opts.sessionID) + } + if opts.longPollingTimeout != "" { + req.Header.Set("X-LST", opts.longPollingTimeout) + } if c.AccessToken != "" { req.Header.Set("x-line-access", c.AccessToken) req.Header.Set("Cookie", fmt.Sprintf("lct=%s", c.AccessToken)) @@ -442,7 +477,21 @@ func (c *Client) postWithHMAC(fullURL string, body []byte) ([]byte, error) { } req.Header.Set("x-hmac", signature) - resp, err := c.HTTPClient.Do(req) + httpClient := c.HTTPClient + if httpClient == nil { + httpClient = http.DefaultClient + } + if opts.longPollingTimeout != "" { + if timeoutMillis, err := strconv.Atoi(opts.longPollingTimeout); err == nil && timeoutMillis > 0 { + timeout := time.Duration(timeoutMillis)*time.Millisecond + 10*time.Second + if httpClient.Timeout == 0 || httpClient.Timeout < timeout { + copied := *httpClient + copied.Timeout = timeout + httpClient = &copied + } + } + } + resp, err := httpClient.Do(req) if err != nil { return nil, fmt.Errorf("request failed: %w", err) } @@ -474,6 +523,22 @@ func (c *Client) RefreshAccessToken(refreshToken string) (*TokenV3IssueResult, e return nil, err } + var wrapper struct { + Code int `json:"code"` + Message string `json:"message"` + Data *TokenV3IssueResult `json:"data"` + } + if err := json.Unmarshal(respBytes, &wrapper); err == nil && wrapper.Data != nil { + if wrapper.Code != 0 { + return nil, fmt.Errorf("tokenRefresh failed: %s", wrapper.Message) + } + if wrapper.Data.AccessToken == "" { + return nil, fmt.Errorf("tokenRefresh returned empty access token") + } + c.AccessToken = wrapper.Data.AccessToken + return wrapper.Data, nil + } + var res TokenV3IssueResult if err := json.Unmarshal(respBytes, &res); err != nil { return nil, fmt.Errorf("failed to parse refresh response: %w", err) diff --git a/pkg/line/methods.go b/pkg/line/methods.go index b976573..35a61d5 100644 --- a/pkg/line/methods.go +++ b/pkg/line/methods.go @@ -1,8 +1,10 @@ package line import ( + "context" "encoding/base64" "encoding/json" + "errors" "fmt" "strconv" "strings" @@ -18,6 +20,28 @@ var ( const obsTokenBuffer = 30 * time.Second +var ErrQRLoginCertificateRejected = errors.New("qr login certificate rejected") + +type QRLoginCertificateRejectedError struct { + Code int + Message string +} + +func (e *QRLoginCertificateRejectedError) Error() string { + if e.Message == "" { + return fmt.Sprintf("%v: code %d", ErrQRLoginCertificateRejected, e.Code) + } + return fmt.Sprintf("%v: %s (code %d)", ErrQRLoginCertificateRejected, e.Message, e.Code) +} + +func (e *QRLoginCertificateRejectedError) Unwrap() error { + return ErrQRLoginCertificateRejected +} + +func IsQRLoginCertificateRejected(err error) bool { + return errors.Is(err, ErrQRLoginCertificateRejected) +} + // InvalidateOBSTokenCache clears the cached OBS access token. The OBS token is // derived from the main LINE access token; when the latter is rotated (refresh // or re-login) any previously-issued OBS token is invalidated server-side, but @@ -30,6 +54,10 @@ func InvalidateOBSTokenCache() { obsTokenMu.Unlock() } +func ClearEncryptedAccessTokenCache() { + InvalidateOBSTokenCache() +} + // LoginV2 performs the loginV2 RPC call to authenticate a user func (c *Client) LoginV2(email, password, certificate, secret string) ([]byte, error) { return c.LoginV2WithType(2, email, password, certificate, secret) @@ -98,6 +126,187 @@ func (c *Client) LoginV2WithVerifier(verifier string) (*LoginResult, error) { return &wrapper.Data, nil } +func (c *Client) CreateQRSession() (string, error) { + var wrapper struct { + Code int `json:"code"` + Message string `json:"message"` + Data struct { + AuthSessionID string `json:"authSessionId"` + } `json:"data"` + } + if err := c.callQRRPC("SecondaryQrCodeLoginService", "createSession", "", "", &wrapper, struct{}{}); err != nil { + return "", err + } + if wrapper.Code != 0 { + return "", fmt.Errorf("createSession failed: %s", wrapper.Message) + } + return wrapper.Data.AuthSessionID, nil +} + +func (c *Client) CreateQRCode(authSessionID string) (*QRCodeResponse, error) { + var wrapper struct { + Code int `json:"code"` + Message string `json:"message"` + Data QRCodeResponse `json:"data"` + } + req := struct { + AuthSessionID string `json:"authSessionId"` + }{AuthSessionID: authSessionID} + if err := c.callQRRPC("SecondaryQrCodeLoginService", "createQrCode", "", "", &wrapper, req); err != nil { + return nil, err + } + if wrapper.Code != 0 { + return nil, fmt.Errorf("createQrCode failed: %s", wrapper.Message) + } + return &wrapper.Data, nil +} + +func (c *Client) CheckQRCodeVerified(authSessionID string) error { + return c.CheckQRCodeVerifiedContext(context.Background(), authSessionID) +} + +func (c *Client) CheckQRCodeVerifiedContext(ctx context.Context, authSessionID string) error { + return c.checkQRPermitNotice(ctx, "checkQrCodeVerified", authSessionID, "150000") +} + +func (c *Client) VerifyCertificate(authSessionID, certificate string) error { + var wrapper struct { + Code int `json:"code"` + Message string `json:"message"` + } + req := struct { + AuthSessionID string `json:"authSessionId"` + Certificate string `json:"certificate"` + }{AuthSessionID: authSessionID, Certificate: certificate} + if err := c.callQRRPC("SecondaryQrCodeLoginService", "verifyCertificate", "", "", &wrapper, req); err != nil { + return err + } + if wrapper.Code != 0 { + return &QRLoginCertificateRejectedError{Code: wrapper.Code, Message: wrapper.Message} + } + return nil +} + +func (c *Client) CreatePinCode(authSessionID string) (string, error) { + var wrapper struct { + Code int `json:"code"` + Message string `json:"message"` + Data struct { + PinCode string `json:"pinCode"` + } `json:"data"` + } + req := struct { + AuthSessionID string `json:"authSessionId"` + }{AuthSessionID: authSessionID} + if err := c.callQRRPC("SecondaryQrCodeLoginService", "createPinCode", "", "", &wrapper, req); err != nil { + return "", err + } + if wrapper.Code != 0 { + return "", fmt.Errorf("createPinCode failed: %s", wrapper.Message) + } + return wrapper.Data.PinCode, nil +} + +func (c *Client) CheckPinCodeVerified(authSessionID string) error { + return c.CheckPinCodeVerifiedContext(context.Background(), authSessionID) +} + +func (c *Client) CheckPinCodeVerifiedContext(ctx context.Context, authSessionID string) error { + return c.checkQRPermitNotice(ctx, "checkPinCodeVerified", authSessionID, "110000") +} + +func (c *Client) QRCodeLoginV2(authSessionID string) (*LoginResult, error) { + var wrapper struct { + Code int `json:"code"` + Message string `json:"message"` + Data struct { + LoginResult + LastBindTimestamp string `json:"lastBindTimestamp"` + MetaData struct { + EncryptedKeyChain string `json:"encryptedKeyChain"` + E2EEVersion string `json:"e2eeVersion"` + KeyID string `json:"keyId"` + PublicKey string `json:"publicKey"` + } `json:"metaData"` + } `json:"data"` + } + req := struct { + SystemName string `json:"systemName"` + ModelName string `json:"modelName"` + AutoLoginIsRequired bool `json:"autoLoginIsRequired"` + AuthSessionID string `json:"authSessionId"` + }{ + SystemName: "CHROMEOS", + ModelName: "CHROME", + AutoLoginIsRequired: false, + AuthSessionID: authSessionID, + } + if err := c.callQRRPC("SecondaryQrCodeLoginService", "qrCodeLoginV2", "", "", &wrapper, req); err != nil { + return nil, err + } + if wrapper.Code != 0 { + return nil, fmt.Errorf("qrCodeLoginV2 failed: %s", wrapper.Message) + } + + res := wrapper.Data.LoginResult + res.LastPrimaryBindTime = wrapper.Data.LastBindTimestamp + res.EncryptedKeyChain = wrapper.Data.MetaData.EncryptedKeyChain + res.E2EEPublicKey = wrapper.Data.MetaData.PublicKey + res.E2EEVersion = wrapper.Data.MetaData.E2EEVersion + res.E2EEKeyID = wrapper.Data.MetaData.KeyID + if res.TokenV3IssueResult != nil && res.TokenV3IssueResult.AccessToken != "" { + res.AuthToken = res.TokenV3IssueResult.AccessToken + c.AccessToken = res.TokenV3IssueResult.AccessToken + } else if res.AuthToken != "" { + c.AccessToken = res.AuthToken + } + return &res, nil +} + +func (c *Client) checkQRPermitNotice(ctx context.Context, method, authSessionID, longPollingTimeout string) error { + var wrapper struct { + Code int `json:"code"` + Message string `json:"message"` + } + req := struct { + AuthSessionID string `json:"authSessionId"` + }{AuthSessionID: authSessionID} + if err := c.callQRRPCContext(ctx, "SecondaryQrCodeLoginPermitNoticeService", method, authSessionID, longPollingTimeout, &wrapper, req); err != nil { + return err + } + if wrapper.Code != 0 { + return fmt.Errorf("%s failed: %s", method, wrapper.Message) + } + return nil +} + +func (c *Client) callQRRPC(service, method, authSessionID, longPollingTimeout string, out interface{}, args ...interface{}) error { + return c.callQRRPCContext(context.Background(), service, method, authSessionID, longPollingTimeout, out, args...) +} + +func (c *Client) callQRRPCContext(ctx context.Context, service, method, authSessionID, longPollingTimeout string, out interface{}, args ...interface{}) error { + url := fmt.Sprintf("%s/%s/%s", QRLoginBaseURL, service, method) + + bodyBytes, err := json.Marshal(args) + if err != nil { + return fmt.Errorf("failed to marshal QR args: %w", err) + } + + respBytes, err := c.postWithHMACOptions(url, bodyBytes, hmacPostOptions{ + includeLineApplication: false, + sessionID: authSessionID, + longPollingTimeout: longPollingTimeout, + ctx: ctx, + }) + if err != nil { + return err + } + if err := json.Unmarshal(respBytes, out); err != nil { + return fmt.Errorf("failed to parse %s response: %w", method, err) + } + return nil +} + // GetProfile fetches the user's profile information func (c *Client) GetProfile() (*Profile, error) { resp, err := c.callRPC("TalkService", "getProfile", 2) diff --git a/pkg/line/qr.go b/pkg/line/qr.go new file mode 100644 index 0000000..9db00ab --- /dev/null +++ b/pkg/line/qr.go @@ -0,0 +1,21 @@ +package line + +import ( + "fmt" + "net/url" +) + +// QRCodeCallbackURLWithE2EESecret returns the Chrome-style QR callback URL. +// The QR service returns a bare callbackUrl, but the extension appends the +// login Curve25519 public key so mobile can encrypt the login keychain. +func QRCodeCallbackURLWithE2EESecret(callbackURL, publicKeyBase64 string) (string, error) { + parsed, err := url.Parse(callbackURL) + if err != nil { + return "", fmt.Errorf("failed to parse QR callback URL: %w", err) + } + query := parsed.Query() + query.Set("secret", publicKeyBase64) + query.Set("e2eeVersion", "1") + parsed.RawQuery = query.Encode() + return parsed.String(), nil +} diff --git a/pkg/line/secret/secret.go b/pkg/line/secret/secret.go index e39d7d2..481e3ee 100644 --- a/pkg/line/secret/secret.go +++ b/pkg/line/secret/secret.go @@ -7,28 +7,31 @@ import ( ) type SecretResult struct { - Secret string `json:"secret"` - Pin string `json:"pin"` - PublicKeyHex string `json:"publicKeyHex"` + Secret string `json:"secret"` + Pin string `json:"pin"` + PublicKeyHex string `json:"publicKeyHex"` + PublicKeyBase64 string `json:"publicKeyBase64"` + LoginKeyID int `json:"loginKeyId"` } -// GenerateSecret performs the E2EE handshake logic using the WASM module via Node.js. +// GenerateSecret creates the LINE login E2EE secret and returns the matching +// local login key ID for the follow-up keychain unwrap. func GenerateSecret() (*SecretResult, error) { runner, err := gen.GetRunner() - if err != nil { return nil, fmt.Errorf("failed to get runner: %w", err) } - res, err := runner.GenerateE2EESecret() + res, err := runner.GenerateE2EESecret() if err != nil { return nil, fmt.Errorf("failed to generate secret: %w", err) } return &SecretResult{ - Secret: res.Secret, - Pin: res.Pin, - PublicKeyHex: res.PublicKeyHex, + Secret: res.Secret, + Pin: res.Pin, + PublicKeyHex: res.PublicKeyHex, + PublicKeyBase64: res.PublicKeyBase64, + LoginKeyID: res.LoginKeyID, }, nil - } diff --git a/pkg/line/structs.go b/pkg/line/structs.go index d81e1d7..9195b6d 100644 --- a/pkg/line/structs.go +++ b/pkg/line/structs.go @@ -126,6 +126,7 @@ type LoginResult struct { E2EEPublicKey string `json:"publicKey,omitempty"` E2EEVersion string `json:"e2eeVersion,omitempty"` E2EEKeyID string `json:"keyId,omitempty"` + LoginKeyID int `json:"-"` NoE2EE bool `json:"-"` // True when login fell back to non-E2EE (LSOFF) } @@ -138,6 +139,12 @@ type TokenV3IssueResult struct { TokenIssueTimeEpochSec string `json:"tokenIssueTimeEpochSec"` } +type QRCodeResponse struct { + CallbackURL string `json:"callbackUrl"` + LongPollingMaxCount int `json:"longPollingMaxCount"` + LongPollingIntervalSeconds int `json:"longPollingIntervalSec"` +} + type RefreshApiRetryPolicy struct { InitialDelayInMillis string `json:"initialDelayInMillis"` MaxDelayInMillis string `json:"maxDelayInMillis"` diff --git a/pkg/runner.go b/pkg/runner.go index 77da7b0..cfedc4c 100644 --- a/pkg/runner.go +++ b/pkg/runner.go @@ -22,6 +22,7 @@ type Runner struct { skPtr uint32 // SecureKey from loadToken storageKey uint32 // AesKey ptr (after StorageInit) loginCurveKey uint32 // Curve25519Key ptr (after GenerateE2EESecret) + loginKeyStore map[int]uint32 // internal ID -> login Curve25519Key ptr keyStore map[int]uint32 // internal ID -> E2EEKey ptr channelStore map[int]uint32 // internal ID -> E2EEChannel ptr nextID int @@ -36,9 +37,11 @@ type Runner struct { } type SecretResult struct { - Secret string `json:"secret"` - Pin string `json:"pin"` - PublicKeyHex string `json:"publicKeyHex"` + Secret string `json:"secret"` + Pin string `json:"pin"` + PublicKeyHex string `json:"publicKeyHex"` + PublicKeyBase64 string `json:"publicKeyBase64"` + LoginKeyID int `json:"loginKeyId"` } type UnwrappedKey struct { @@ -88,6 +91,7 @@ func GetRunner() (*Runner, error) { token: token, clientVersion: clientVersion, skPtr: skPtr, + loginKeyStore: make(map[int]uint32), keyStore: make(map[int]uint32), channelStore: make(map[int]uint32), nextID: 1, @@ -130,6 +134,28 @@ func (r *Runner) putChannel(ptr uint32) int { return id } +func (r *Runner) putLoginKey(ptr uint32) int { + id := r.nextID + r.nextID++ + r.loginKeyStore[id] = ptr + return id +} + +func (r *Runner) getLoginKeyLocked(loginKeyID int) (uint32, error) { + loginCurveKey := r.loginCurveKey + if loginKeyID != 0 { + var ok bool + loginCurveKey, ok = r.loginKeyStore[loginKeyID] + if !ok { + return 0, fmt.Errorf("unknown login key: %d", loginKeyID) + } + } + if loginCurveKey == 0 { + return 0, fmt.Errorf("login key not initialized") + } + return loginCurveKey, nil +} + func (r *Runner) getChannel(id int) (uint32, error) { ptr, ok := r.channelStore[id] if !ok { @@ -324,8 +350,14 @@ func (r *Runner) StorageEncrypt(plaintext string) (string, error) { return base64.StdEncoding.EncodeToString(ctBytes), nil } -// LoginUnwrapKeyChain unwraps the encrypted key chain from LF1 using the login curve key. +// LoginUnwrapKeyChain unwraps the encrypted key chain from LF1 using the latest login curve key. func (r *Runner) LoginUnwrapKeyChain(serverPubB64, encryptedKeyChainB64 string) ([]UnwrappedKey, error) { + return r.LoginUnwrapKeyChainWithKey(0, serverPubB64, encryptedKeyChainB64) +} + +// LoginUnwrapKeyChainWithKey unwraps the encrypted key chain from LF1 using a +// specific login curve key. Passing 0 preserves the legacy latest-key behavior. +func (r *Runner) LoginUnwrapKeyChainWithKey(loginKeyID int, serverPubB64, encryptedKeyChainB64 string) ([]UnwrappedKey, error) { normalizedServerPub, err := normalizeServerPublicKeyB64(serverPubB64) if err != nil { return nil, err @@ -334,8 +366,9 @@ func (r *Runner) LoginUnwrapKeyChain(serverPubB64, encryptedKeyChainB64 string) r.mu.Lock() defer r.mu.Unlock() - if r.loginCurveKey == 0 { - return nil, fmt.Errorf("login key not initialized") + loginCurveKey, err := r.getLoginKeyLocked(loginKeyID) + if err != nil { + return nil, err } serverPubBytes, err := base64.StdEncoding.DecodeString(normalizedServerPub) @@ -343,7 +376,7 @@ func (r *Runner) LoginUnwrapKeyChain(serverPubB64, encryptedKeyChainB64 string) return nil, fmt.Errorf("invalid server public key: %w", err) } - chanPtr, err := r.rt.Curve25519KeyCreateChannel(r.loginCurveKey, serverPubBytes) + chanPtr, err := r.rt.Curve25519KeyCreateChannel(loginCurveKey, serverPubBytes) if err != nil { return nil, err } @@ -664,6 +697,7 @@ func (r *Runner) GenerateE2EESecret() (*SecretResult, error) { return nil, err } r.loginCurveKey = ckPtr + loginKeyID := r.putLoginKey(ckPtr) pubBytes, err := r.rt.Curve25519KeyGetPublicKey(ckPtr) if err != nil { @@ -681,15 +715,23 @@ func (r *Runner) GenerateE2EESecret() (*SecretResult, error) { } return &SecretResult{ - Secret: secret, - Pin: pin, - PublicKeyHex: hex.EncodeToString(pubBytes), + Secret: secret, + Pin: pin, + PublicKeyHex: hex.EncodeToString(pubBytes), + PublicKeyBase64: base64.StdEncoding.EncodeToString(pubBytes), + LoginKeyID: loginKeyID, }, nil } -// GenerateConfirmHash derives the hash key chain for confirmE2EELogin. -// Must be called after GenerateE2EESecret. +// GenerateConfirmHash derives the hash key chain for confirmE2EELogin using +// the latest login key. func (r *Runner) GenerateConfirmHash(serverPublicKeyB64, encryptedKeyChainB64 string) (string, error) { + return r.GenerateConfirmHashWithKey(0, serverPublicKeyB64, encryptedKeyChainB64) +} + +// GenerateConfirmHashWithKey derives the hash key chain for confirmE2EELogin +// using the same login key that generated the login secret. +func (r *Runner) GenerateConfirmHashWithKey(loginKeyID int, serverPublicKeyB64, encryptedKeyChainB64 string) (string, error) { normalizedServerPub, err := normalizeServerPublicKeyB64(serverPublicKeyB64) if err != nil { return "", err @@ -698,8 +740,9 @@ func (r *Runner) GenerateConfirmHash(serverPublicKeyB64, encryptedKeyChainB64 st r.mu.Lock() defer r.mu.Unlock() - if r.loginCurveKey == 0 { - return "", fmt.Errorf("login key not initialized") + loginCurveKey, err := r.getLoginKeyLocked(loginKeyID) + if err != nil { + return "", err } serverPubBytes, err := base64.StdEncoding.DecodeString(normalizedServerPub) @@ -707,7 +750,7 @@ func (r *Runner) GenerateConfirmHash(serverPublicKeyB64, encryptedKeyChainB64 st return "", fmt.Errorf("invalid server public key: %w", err) } - chanPtr, err := r.rt.Curve25519KeyCreateChannel(r.loginCurveKey, serverPubBytes) + chanPtr, err := r.rt.Curve25519KeyCreateChannel(loginCurveKey, serverPubBytes) if err != nil { return "", err }