Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions messaging/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,27 @@ func (m *Message) Clone() *Message {
}
}

// IsReplyable reports whether this message may be the target of a reply. Only
// user-facing messages are replyable; this is a whitelist so that content types
// added later (and non-conversational ones like system messages) are treated as
// non-replyable until explicitly allowed. Deleted messages remain replyable —
// the tombstone is still a real message in the thread.
func (m *Message) IsReplyable() bool {
if len(m.Content) == 0 {
return false
}
switch m.Content[0].Type.(type) {
case *messagingpb.Content_Text,
*messagingpb.Content_Cash,
*messagingpb.Content_Media,
*messagingpb.Content_Reply,
*messagingpb.Content_Deleted:
return true
default:
return false
}
}

// ToProto projects the stored message onto a messagingpb.Message.
func (m *Message) ToProto() *messagingpb.Message {
content := make([]*messagingpb.Content, len(m.Content))
Expand Down
26 changes: 25 additions & 1 deletion messaging/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,16 @@ func (s *Server) SendMessage(ctx context.Context, req *messagingpb.SendMessageRe

log := s.log.With(zap.String("user_id", model.UserIDString(userID)))

switch req.Content[0].Type.(type) {
var repliedMessageID *messagingpb.MessageId
switch content := req.Content[0].Type.(type) {
case *messagingpb.Content_Text:
case *messagingpb.Content_Reply:
switch content.Reply.Content[0].Type.(type) {
case *messagingpb.Content_Text:
default:
return &messagingpb.SendMessageResponse{Result: messagingpb.SendMessageResponse_DENIED}, nil
}
repliedMessageID = content.Reply.RepliedMessageId
default:
return &messagingpb.SendMessageResponse{Result: messagingpb.SendMessageResponse_DENIED}, nil
}
Expand All @@ -139,6 +147,22 @@ func (s *Server) SendMessage(ctx context.Context, req *messagingpb.SendMessageRe
return &messagingpb.SendMessageResponse{Result: messagingpb.SendMessageResponse_DENIED}, nil
}

// The replied-to message must exist in this chat and be repliable. Checked
// after membership so non-members can't probe which message IDs exist.
if repliedMessageID != nil {
repliedMessage, err := s.messages.GetMessage(ctx, req.ChatId, repliedMessageID)
switch {
case errors.Is(err, ErrMessageNotFound):
return &messagingpb.SendMessageResponse{Result: messagingpb.SendMessageResponse_DENIED}, nil
case err != nil:
log.With(zap.Error(err)).Warn("Failure getting replied-to message")
return nil, status.Error(codes.Internal, "")
}
if !repliedMessage.IsReplyable() {
return &messagingpb.SendMessageResponse{Result: messagingpb.SendMessageResponse_DENIED}, nil
}
}

msg, err := s.sender.Send(ctx, req.ChatId, userID, req.Content, req.ClientMessageId, true)
if err != nil {
return nil, err
Expand Down
67 changes: 67 additions & 0 deletions messaging/tests/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ func RunServerTests(t *testing.T, badges badge.Store, chats chat.Store, messages
testServer_NonMember_Denied,
testServer_AdvancePointer,
testServer_SendMessage_Broadcast,
testServer_SendReply,
testServer_NotifyIsTyping,
} {
tf(t, chats, messages, profiles, badges)
Expand Down Expand Up @@ -153,6 +154,72 @@ func testServer_SendAndGet(t *testing.T, chats chat.Store, messages messaging.St
require.Len(t, listResp.Messages.Messages, 1)
}

func testServer_SendReply(t *testing.T, chats chat.Store, messages messaging.Store, profiles profile.Store, badges badge.Store) {
e := newServerEnv(t, badges, chats, messages, profiles)

// Seed a message to reply to.
original, err := e.send(e.keysA, "original", generateClientID())
require.NoError(t, err)

// A text reply to that message is accepted and round-trips its content.
replyReq := &messagingpb.SendMessageRequest{
ChatId: e.chatID,
Content: replyContent(original.Message.MessageId.Value, "replying"),
ClientMessageId: generateClientID(),
}
require.NoError(t, e.keysB.Auth(replyReq, &replyReq.Auth))
replyResp, err := e.client.SendMessage(e.ctx, replyReq)
require.NoError(t, err)
require.Equal(t, messagingpb.SendMessageResponse_OK, replyResp.Result)

reply := replyResp.Message.Content[0].GetReply()
require.NotNil(t, reply)
require.Equal(t, original.Message.MessageId.Value, reply.RepliedMessageId.Value)
require.Equal(t, "replying", reply.Content[0].GetText().Text)

// A reply wrapping unsupported content (e.g. a nested reply) is denied.
deniedReq := &messagingpb.SendMessageRequest{
ChatId: e.chatID,
Content: []*messagingpb.Content{{
Type: &messagingpb.Content_Reply{
Reply: &messagingpb.ReplyContent{
RepliedMessageId: original.Message.MessageId,
Content: replyContent(original.Message.MessageId.Value, "nested"),
},
},
}},
ClientMessageId: generateClientID(),
}
require.NoError(t, e.keysB.Auth(deniedReq, &deniedReq.Auth))
deniedResp, err := e.client.SendMessage(e.ctx, deniedReq)
require.NoError(t, err)
require.Equal(t, messagingpb.SendMessageResponse_DENIED, deniedResp.Result)

// Replying to a message that does not exist is denied.
missingReq := &messagingpb.SendMessageRequest{
ChatId: e.chatID,
Content: replyContent(original.Message.MessageId.Value+999, "ghost"),
ClientMessageId: generateClientID(),
}
require.NoError(t, e.keysB.Auth(missingReq, &missingReq.Auth))
missingResp, err := e.client.SendMessage(e.ctx, missingReq)
require.NoError(t, err)
require.Equal(t, messagingpb.SendMessageResponse_DENIED, missingResp.Result)

// Replying to a non-replyable (system) message is denied.
systemMsg, _, err := messages.PutMessage(e.ctx, e.chatID, nil, systemContent("joined"), at(100), generateClientID(), false)
require.NoError(t, err)
systemReplyReq := &messagingpb.SendMessageRequest{
ChatId: e.chatID,
Content: replyContent(systemMsg.ID.Value, "to a system message"),
ClientMessageId: generateClientID(),
}
require.NoError(t, e.keysB.Auth(systemReplyReq, &systemReplyReq.Auth))
systemReplyResp, err := e.client.SendMessage(e.ctx, systemReplyReq)
require.NoError(t, err)
require.Equal(t, messagingpb.SendMessageResponse_DENIED, systemReplyResp.Result)
}

func testServer_SendMessage_Idempotent(t *testing.T, chats chat.Store, messages messaging.Store, profiles profile.Store, badges badge.Store) {
e := newServerEnv(t, badges, chats, messages, profiles)

Expand Down
19 changes: 19 additions & 0 deletions messaging/tests/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,25 @@ func textContent(text string) []*messagingpb.Content {
}}
}

func systemContent(text string) []*messagingpb.Content {
return []*messagingpb.Content{{
Type: &messagingpb.Content_System{
System: &messagingpb.SystemContent{FallbackText: text},
},
}}
}

func replyContent(repliedMessageID uint64, text string) []*messagingpb.Content {
return []*messagingpb.Content{{
Type: &messagingpb.Content_Reply{
Reply: &messagingpb.ReplyContent{
RepliedMessageId: &messagingpb.MessageId{Value: repliedMessageID},
Content: textContent(text),
},
},
}}
}

func messageText(m *messaging.Message) string {
return m.Content[0].GetText().Text
}
Expand Down
10 changes: 10 additions & 0 deletions push/pushes.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,16 @@ func SendContactDmPush(ctx context.Context, pusher Pusher, badges badge.Store, o
switch content := message.Content[0].Type.(type) {
case *messagingpb.Content_Text:
body = content.Text.Text
case *messagingpb.Content_Reply:
// Push the reply's wrapped content. Only text replies are supported today.
if len(content.Reply.Content) == 0 {
return nil
}
textContent, ok := content.Reply.Content[0].Type.(*messagingpb.Content_Text)
if !ok {
return nil
}
body = textContent.Text.Text
case *messagingpb.Content_Cash:
currencyName, err := resolveCurrencyName(ctx, ocpData, content.Cash.Amount.Mint)
if err != nil {
Expand Down
Loading