diff --git a/go.mod b/go.mod index c7e1ef4..bd0c601 100644 --- a/go.mod +++ b/go.mod @@ -11,7 +11,7 @@ require ( github.com/golang-migrate/migrate/v4 v4.19.1 github.com/gotd/log v0.1.0 github.com/gotd/log/logzap v0.1.1 - github.com/gotd/td v0.156.2 + github.com/gotd/td v0.156.3 github.com/jackc/pgx/v5 v5.10.0 github.com/quasilyte/go-ruleguard/dsl v0.3.23 github.com/riverqueue/river v0.39.0 diff --git a/go.sum b/go.sum index de961d8..e917ac1 100644 --- a/go.sum +++ b/go.sum @@ -98,8 +98,8 @@ github.com/gotd/log/logzap v0.1.1 h1:O6l7d8HUbODe+UMcrM47eXYDwdJ6RNmpQejLjrlcEIQ github.com/gotd/log/logzap v0.1.1/go.mod h1:5ObZkITbfhbsBOLzBkzmMk9QxXc0eNQpimau7zRL+Y8= github.com/gotd/neo v0.1.5 h1:oj0iQfMbGClP8xI59x7fE/uHoTJD7NZH9oV1WNuPukQ= github.com/gotd/neo v0.1.5/go.mod h1:9A2a4bn9zL6FADufBdt7tZt+WMhvZoc5gWXihOPoiBQ= -github.com/gotd/td v0.156.2 h1:07bZ5YnuKjStRHJ1d8IyU19KZMLs1vZslT9dQeUAcaw= -github.com/gotd/td v0.156.2/go.mod h1:pVVlJYiMUMinSR/5uDfCSUoB3DyqxftanMWYLP16riY= +github.com/gotd/td v0.156.3 h1:Y4JxcYPUS386HSuVw646EIpJBVJ9kY33lBfZZJDfF0o= +github.com/gotd/td v0.156.3/go.mod h1:pVVlJYiMUMinSR/5uDfCSUoB3DyqxftanMWYLP16riY= github.com/grafana/otel-profiling-go v0.5.3 h1:BEwmU7KI2/J57RBe/kA0fgdeN1E0Ps1KSj33vIF5KXg= github.com/grafana/otel-profiling-go v0.5.3/go.mod h1:cqLIDgNXlnzknJ0WLiEe+JPjZk2MZ4ftMdqRJRWj1ZM= github.com/grafana/pyroscope-go v1.3.1 h1:Eb9h55+vtLezn/DQ4iXz+SJrOz8CNghDk9xx8XQ4tc0= diff --git a/internal/mtproto/conn.go b/internal/mtproto/conn.go index 70fd6a5..953b4ae 100644 --- a/internal/mtproto/conn.go +++ b/internal/mtproto/conn.go @@ -2,6 +2,7 @@ package mtproto import ( "context" + "encoding/hex" "github.com/go-faster/errors" @@ -77,6 +78,9 @@ func (s *Server) serveConn(ctx context.Context, conn transport.Conn) error { } // Unknown, non-zero key: ask the client to re-run key exchange. + log.For(s.log).Warn(ctx, "Auth key not found; sending -404", + log.String("key_id", hex.EncodeToString(authKeyID[:]))) + if err := s.sendProtoError(ctx, conn, codec.CodeAuthKeyNotFound); err != nil { return errors.Wrap(err, "send AuthKeyNotFound") } @@ -90,10 +94,48 @@ func (s *Server) serveConn(ctx context.Context, conn transport.Conn) error { c := newBufferedConn(conn) c.Push(b) - key, err := s.exchange(ctx, exchangeConn{Conn: c}) + key, err := s.exchange(ctx, c) if err != nil { + // The client sent a frame encrypted with an existing auth key + // instead of an unencrypted exchange message: it is reusing an + // already-established key, not performing key exchange. Resolve the + // key and handle the frame as a normal RPC. Replying with -404 here + // would tell clients like Telegram Desktop to discard a still-valid + // temporary key, triggering a reconnect/key-exchange storm. + var encErr *exchange.UnexpectedEncryptedError + if errors.As(err, &encErr) { + _, ok, lookupErr := s.registry.getSession(ctx, encErr.AuthKeyID) + if lookupErr != nil { + return errors.Wrap(lookupErr, "lookup session") + } + + if ok { + var fb bin.Buffer + fb.ResetTo(encErr.Frame) + + if err := s.rpcHandle(ctx, conn, &fb); err != nil { + return errors.Wrap(err, "handle") + } + + continue + } + + // Genuinely unknown key: ask the client to re-run key exchange. + log.For(s.log).Warn(ctx, "Auth key not found during exchange; sending -404", + log.String("key_id", hex.EncodeToString(encErr.AuthKeyID[:]))) + + if err := s.sendProtoError(ctx, conn, codec.CodeAuthKeyNotFound); err != nil { + return errors.Wrap(err, "send AuthKeyNotFound") + } + + continue + } + var exchangeErr *exchange.ServerExchangeError if errors.As(err, &exchangeErr) { + log.For(s.log).Warn(ctx, "Key exchange failed; sending proto error", + log.Int("code", int(exchangeErr.Code)), log.Error(err)) + if sendErr := s.sendProtoError(ctx, c, exchangeErr.Code); sendErr != nil { return errors.Wrapf(sendErr, "send proto error %v", exchangeErr.Code) } @@ -101,6 +143,8 @@ func (s *Server) serveConn(ctx context.Context, conn transport.Conn) error { return nil } + log.For(s.log).Warn(ctx, "Key exchange failed", log.Error(err)) + return errors.Wrap(err, "key exchange failed") } diff --git a/internal/mtproto/exchange.go b/internal/mtproto/exchange.go index 744e477..e04268b 100644 --- a/internal/mtproto/exchange.go +++ b/internal/mtproto/exchange.go @@ -3,49 +3,12 @@ package mtproto import ( "context" - "github.com/go-faster/errors" - "github.com/gotd/log" - "github.com/gotd/td/bin" "github.com/gotd/td/crypto" "github.com/gotd/td/exchange" - "github.com/gotd/td/proto/codec" "github.com/gotd/td/transport" ) -// exchangeConn rejects frames carrying a non-zero auth key id during the key -// exchange flow, which expects only unencrypted messages. -type exchangeConn struct { - transport.Conn -} - -func (e exchangeConn) Recv(ctx context.Context, b *bin.Buffer) error { - for { - if err := e.Conn.Recv(ctx, b); err != nil { - return err - } - - var authKeyID [8]byte - if err := b.PeekN(authKeyID[:], len(authKeyID)); err != nil { - return errors.Wrap(err, "peek auth key id") - } - - if authKeyID != ([8]byte{}) { - var buf bin.Buffer - - buf.PutInt32(-codec.CodeAuthKeyNotFound) - - if err := e.Send(ctx, &buf); err != nil { - return errors.Wrap(err, "send") - } - - continue - } - - return nil - } -} - // exchange runs the server side of the MTProto key exchange. func (s *Server) exchange(ctx context.Context, conn transport.Conn) (crypto.AuthKey, error) { r, err := exchange.NewExchanger(conn, s.dcID).