Skip to content

Commit

Permalink
Add parent message to threaded replies and reactions
Browse files Browse the repository at this point in the history
  • Loading branch information
hloeung committed Feb 25, 2024
1 parent 6251d20 commit 358e5b8
Showing 1 changed file with 83 additions and 15 deletions.
98 changes: 83 additions & 15 deletions bridge/matrix/matrix.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/42wim/matterbridge/bridge/helper"
"github.com/42wim/matterircd/bridge"
"github.com/davecgh/go-spew/spew"
lru "github.com/hashicorp/golang-lru"
prefixed "github.com/matterbridge/logrus-prefixed-formatter"
"github.com/sirupsen/logrus"
"github.com/spf13/viper"
Expand All @@ -32,6 +33,9 @@ type Matrix struct {
channels map[id.RoomID]*Channel
users map[id.UserID]*User
sync.RWMutex

msgParentCache *lru.Cache
msgLastSentCache *lru.Cache
}

var logger *logrus.Entry
Expand All @@ -45,6 +49,8 @@ func New(v *viper.Viper, cred bridge.Credentials, eventChan chan *bridge.Event,
dmChannels: make(map[id.RoomID][]id.UserID),
users: make(map[id.UserID]*User),
}
m.msgParentCache, _ = lru.New(100)
m.msgLastSentCache, _ = lru.New(10)

ourlog := logrus.New()
ourlog.SetFormatter(&prefixed.TextFormatter{
Expand Down Expand Up @@ -285,23 +291,30 @@ func (m *Matrix) handleMessageEvent(source mautrix.EventSource, ev *event.Event)
}

var text string
var parentID string
var parentID id.EventID

switch {
case ev.Type.String() == "m.text" || ev.Type.String() == "m.room.message":
msgEventContent, _ := ev.Content.Parsed.(*event.MessageEventContent)
text = msgEventContent.Body
if msgEventContent.RelatesTo != nil {
parentID = msgEventContent.RelatesTo.EventID.String()
parentID = msgEventContent.RelatesTo.EventID
}
default:
logger.Warnf("handleMessageEvent unsupported event type %s", ev.Type.String())
}

if !m.v.GetBool("matrix.hidereplies") && parentID.String() != "" {
message, err := m.addParentMsg(ev.RoomID, parentID, text, m.v.GetInt("matrix.shortenrepliesto"), "@", m.v.GetBool("matrix.unicode"))
if err != nil {
logger.Errorf("Unable to get parent post for %#v", ev)
}
text = message
}

m.RLock()
_, ok := m.dmChannels[ev.RoomID]
m.RUnlock()

if ok {
event := &bridge.Event{ //nolint:gocritic
Type: "direct_message",
Expand All @@ -313,7 +326,7 @@ func (m *Matrix) handleMessageEvent(source mautrix.EventSource, ev *event.Event)
// Files: m.getFilesFromData(data),
MessageID: string(ev.ID),
// Event: rmsg.Event,
ParentID: parentID,
ParentID: parentID.String(),
},
}

Expand All @@ -331,7 +344,7 @@ func (m *Matrix) handleMessageEvent(source mautrix.EventSource, ev *event.Event)
// Files: m.getFilesFromData(data),
MessageID: string(ev.ID),
// Event: rmsg.Event,
ParentID: parentID,
ParentID: parentID.String(),
},
}

Expand All @@ -348,22 +361,30 @@ func (m *Matrix) handleReactionEvent(source mautrix.EventSource, ev *event.Event
return
}

var text string
var reaction string
var parentID string
var parentID id.EventID

switch {
case ev.Type.String() == "m.reaction":
reactionEventContent, _ := ev.Content.Parsed.(*event.ReactionEventContent)
reaction = reactionEventContent.RelatesTo.Key
parentID = reactionEventContent.RelatesTo.EventID.String()
parentID = reactionEventContent.RelatesTo.EventID
default:
logger.Warnf("handleEvent unsupported event type %s", ev.Type.String())
}

if !m.v.GetBool("matrix.hidereplies") {
message, err := m.addParentMsg(ev.RoomID, parentID, text, m.v.GetInt("matrix.shortenrepliesto"), "@", m.v.GetBool("matrix.unicode"))
if err != nil {
logger.Errorf("Unable to get parent post for %#v", ev)
}
text = message
}

m.RLock()
_, ok := m.dmChannels[ev.RoomID]
m.RUnlock()

channelType := ""
if ok {
channelType = "D"
Expand All @@ -377,8 +398,8 @@ func (m *Matrix) handleReactionEvent(source mautrix.EventSource, ev *event.Event
Sender: ghost,
Reaction: reaction,
ChannelType: channelType,
Message: "",
ParentID: parentID,
Message: text,
ParentID: parentID.String(),
},
}

Expand Down Expand Up @@ -498,6 +519,7 @@ func (m *Matrix) MsgChannelThread(channelID, parentID, text string) (string, err

logger.Trace("msgchannelthread: error,resp ", err, resp)

m.msgLastSentCache.Add(resp.EventID.String(), fmt.Sprintf("%s: %s", id.RoomAlias(channelID), text))
return resp.EventID.String(), nil
}

Expand Down Expand Up @@ -592,7 +614,8 @@ func (m *Matrix) GetChannelUsers(channelID string) ([]*bridge.UserInfo, error) {
return nil, err
}

logger.Tracef("GetChannelUsers %s %d", channelID, len(resp.Joined))
logger.Debugf("GetChannelUsers %s %d", channelID, len(resp.Joined))
logger.Tracef("GetChannelUsers %s", spew.Sdump(resp.Joined))

for user := range resp.Joined {
users = append(users, m.createUser(user))
Expand All @@ -606,6 +629,7 @@ func (m *Matrix) GetUsers() []*bridge.UserInfo {

logger.Trace("GetUsers ", m.users)
logger.Trace("GetUsers ", spew.Sdump(m.users))
logger.Debugf("GetUsers %d", len(m.users))

m.RLock()
for userID := range m.users {
Expand All @@ -625,6 +649,8 @@ func (m *Matrix) GetChannels() []*bridge.ChannelInfo {
m.RLock()
defer m.RUnlock()

logger.Tracef("GetChannels %s", spew.Sdump(m.channels))

for roomID, channel := range m.channels {
channel.RLock()

Expand All @@ -644,6 +670,8 @@ func (m *Matrix) GetChannels() []*bridge.ChannelInfo {
Private: false,
})

logger.Debugf("GetChannels %s (%s)", channel.Alias.String(), roomID.String())

channel.RUnlock()
}

Expand Down Expand Up @@ -761,11 +789,42 @@ func isValidNick(s string) bool {
return true
}

func (m *Matrix) addParentMsg(roomID id.RoomID, parentID id.EventID, msg string, newLen int, uncounted string, unicode bool) (string, error) {
var replyMessage string

// Search and use cached reply if it exists.
// None found, so we'll need to create one and save it for future uses.
if v, ok := m.msgParentCache.Get(parentID); !ok {
resp, err := m.mc.GetEvent(roomID, parentID)
// Retry once on failure.
if err != nil {
resp, err = m.mc.GetEvent(roomID, parentID)
}
if err != nil {
return msg, err
}

body := ""
if val, ok2 := resp.Content.Raw["body"].(string); ok2 {
body = val
}

parentUser := m.GetUser(resp.Sender.String())
parentMessage := maybeShorten(body, newLen, uncounted, unicode)
replyMessage = fmt.Sprintf(" (re @%s: %s)", parentUser.Nick, parentMessage)
logger.Debugf("Created reply for parent post %s:%s", parentID.String(), replyMessage)

m.msgParentCache.Add(parentID, replyMessage)
} else if replyMessage, ok = v.(string); ok {
logger.Debugf("Found saved reply for parent post %s, using:%s", parentID, replyMessage)
}

return strings.TrimRight(msg, "\n") + replyMessage, nil
}

// maybeShorten returns a prefix of msg that is approximately newLen
// characters long, followed by "...". Words that start with uncounted
// are included in the result but are not reckoned against newLen.
//
//nolint:unused
func maybeShorten(msg string, newLen int, uncounted string, unicode bool) string {
if newLen == 0 || len(msg) < newLen {
return msg
Expand Down Expand Up @@ -840,7 +899,7 @@ func (m *Matrix) SearchUsers(query string) ([]*bridge.UserInfo, error) {
return brusers, nil
}

func (m *Matrix) GetPostThread(channelID string) interface{} {
func (m *Matrix) GetPostThread(postID string) interface{} {
return nil
}

Expand Down Expand Up @@ -873,5 +932,14 @@ func (m *Matrix) RemoveReaction(msgID, emoji string) error {
}

func (m *Matrix) GetLastSentMsgs() []string {
return []string{}
data := make([]string, 0)

for _, k := range m.msgLastSentCache.Keys() {
if v, ok := m.msgLastSentCache.Get(k); ok {
msg, _ := v.(string)
data = append(data, fmt.Sprintf("[@@%s] %s", k, msg))
}
}

return data
}

0 comments on commit 358e5b8

Please sign in to comment.