aboutsummaryrefslogtreecommitdiff
path: root/vendor/github.com/bwmarrin/discordgo/state.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/bwmarrin/discordgo/state.go')
-rw-r--r--vendor/github.com/bwmarrin/discordgo/state.go289
1 files changed, 231 insertions, 58 deletions
diff --git a/vendor/github.com/bwmarrin/discordgo/state.go b/vendor/github.com/bwmarrin/discordgo/state.go
index 2eeabd8..e75be89 100644
--- a/vendor/github.com/bwmarrin/discordgo/state.go
+++ b/vendor/github.com/bwmarrin/discordgo/state.go
@@ -38,13 +38,15 @@ type State struct {
Ready
// MaxMessageCount represents how many messages per channel the state will store.
- MaxMessageCount int
- TrackChannels bool
- TrackEmojis bool
- TrackMembers bool
- TrackRoles bool
- TrackVoice bool
- TrackPresences bool
+ MaxMessageCount int
+ TrackChannels bool
+ TrackThreads bool
+ TrackEmojis bool
+ TrackMembers bool
+ TrackThreadMembers bool
+ TrackRoles bool
+ TrackVoice bool
+ TrackPresences bool
guildMap map[string]*Guild
channelMap map[string]*Channel
@@ -58,15 +60,17 @@ func NewState() *State {
PrivateChannels: []*Channel{},
Guilds: []*Guild{},
},
- TrackChannels: true,
- TrackEmojis: true,
- TrackMembers: true,
- TrackRoles: true,
- TrackVoice: true,
- TrackPresences: true,
- guildMap: make(map[string]*Guild),
- channelMap: make(map[string]*Channel),
- memberMap: make(map[string]map[string]*Member),
+ TrackChannels: true,
+ TrackThreads: true,
+ TrackEmojis: true,
+ TrackMembers: true,
+ TrackThreadMembers: true,
+ TrackRoles: true,
+ TrackVoice: true,
+ TrackPresences: true,
+ guildMap: make(map[string]*Guild),
+ channelMap: make(map[string]*Channel),
+ memberMap: make(map[string]map[string]*Member),
}
}
@@ -93,6 +97,11 @@ func (s *State) GuildAdd(guild *Guild) error {
s.channelMap[c.ID] = c
}
+ // Add all the threads to the state in case of thread sync list.
+ for _, t := range guild.Threads {
+ s.channelMap[t.ID] = t
+ }
+
// If this guild contains a new member slice, we must regenerate the member map so the pointers stay valid
if guild.Members != nil {
s.createMemberMap(guild)
@@ -122,6 +131,9 @@ func (s *State) GuildAdd(guild *Guild) error {
if guild.Channels == nil {
guild.Channels = g.Channels
}
+ if guild.Threads == nil {
+ guild.Threads = g.Threads
+ }
if guild.VoiceStates == nil {
guild.VoiceStates = g.VoiceStates
}
@@ -180,21 +192,12 @@ func (s *State) Guild(guildID string) (*Guild, error) {
return nil, ErrStateNotFound
}
-// PresenceAdd adds a presence to the current world state, or
-// updates it if it already exists.
-func (s *State) PresenceAdd(guildID string, presence *Presence) error {
- if s == nil {
- return ErrNilState
- }
-
- guild, err := s.Guild(guildID)
- if err != nil {
- return err
+func (s *State) presenceAdd(guildID string, presence *Presence) error {
+ guild, ok := s.guildMap[guildID]
+ if !ok {
+ return ErrStateNotFound
}
- s.Lock()
- defer s.Unlock()
-
for i, p := range guild.Presences {
if p.User.ID == presence.User.ID {
//guild.Presences[i] = presence
@@ -233,6 +236,19 @@ func (s *State) PresenceAdd(guildID string, presence *Presence) error {
return nil
}
+// PresenceAdd adds a presence to the current world state, or
+// updates it if it already exists.
+func (s *State) PresenceAdd(guildID string, presence *Presence) error {
+ if s == nil {
+ return ErrNilState
+ }
+
+ s.Lock()
+ defer s.Unlock()
+
+ return s.presenceAdd(guildID, presence)
+}
+
// PresenceRemove removes a presence from the current world state.
func (s *State) PresenceRemove(guildID string, presence *Presence) error {
if s == nil {
@@ -279,21 +295,12 @@ func (s *State) Presence(guildID, userID string) (*Presence, error) {
// TODO: Consider moving Guild state update methods onto *Guild.
-// MemberAdd adds a member to the current world state, or
-// updates it if it already exists.
-func (s *State) MemberAdd(member *Member) error {
- if s == nil {
- return ErrNilState
- }
-
- guild, err := s.Guild(member.GuildID)
- if err != nil {
- return err
+func (s *State) memberAdd(member *Member) error {
+ guild, ok := s.guildMap[member.GuildID]
+ if !ok {
+ return ErrStateNotFound
}
- s.Lock()
- defer s.Unlock()
-
members, ok := s.memberMap[member.GuildID]
if !ok {
return ErrStateNotFound
@@ -306,15 +313,27 @@ func (s *State) MemberAdd(member *Member) error {
} else {
// We are about to replace `m` in the state with `member`, but first we need to
// make sure we preserve any fields that the `member` doesn't contain from `m`.
- if member.JoinedAt == "" {
+ if member.JoinedAt.IsZero() {
member.JoinedAt = m.JoinedAt
}
*m = *member
}
-
return nil
}
+// MemberAdd adds a member to the current world state, or
+// updates it if it already exists.
+func (s *State) MemberAdd(member *Member) error {
+ if s == nil {
+ return ErrNilState
+ }
+
+ s.Lock()
+ defer s.Unlock()
+
+ return s.memberAdd(member)
+}
+
// MemberRemove removes a member from current world state.
func (s *State) MemberRemove(member *Member) error {
if s == nil {
@@ -465,6 +484,9 @@ func (s *State) ChannelAdd(channel *Channel) error {
if channel.PermissionOverwrites == nil {
channel.PermissionOverwrites = c.PermissionOverwrites
}
+ if channel.ThreadMetadata == nil {
+ channel.ThreadMetadata = c.ThreadMetadata
+ }
*c = *channel
return nil
@@ -472,12 +494,18 @@ func (s *State) ChannelAdd(channel *Channel) error {
if channel.Type == ChannelTypeDM || channel.Type == ChannelTypeGroupDM {
s.PrivateChannels = append(s.PrivateChannels, channel)
- } else {
- guild, ok := s.guildMap[channel.GuildID]
- if !ok {
- return ErrStateNotFound
- }
+ s.channelMap[channel.ID] = channel
+ return nil
+ }
+ guild, ok := s.guildMap[channel.GuildID]
+ if !ok {
+ return ErrStateNotFound
+ }
+
+ if channel.IsThread() {
+ guild.Threads = append(guild.Threads, channel)
+ } else {
guild.Channels = append(guild.Channels, channel)
}
@@ -507,15 +535,26 @@ func (s *State) ChannelRemove(channel *Channel) error {
break
}
}
- } else {
- guild, err := s.Guild(channel.GuildID)
- if err != nil {
- return err
- }
+ delete(s.channelMap, channel.ID)
+ return nil
+ }
- s.Lock()
- defer s.Unlock()
+ guild, err := s.Guild(channel.GuildID)
+ if err != nil {
+ return err
+ }
+
+ s.Lock()
+ defer s.Unlock()
+ if channel.IsThread() {
+ for i, t := range guild.Threads {
+ if t.ID == channel.ID {
+ guild.Threads = append(guild.Threads[:i], guild.Threads[i+1:]...)
+ break
+ }
+ }
+ } else {
for i, c := range guild.Channels {
if c.ID == channel.ID {
guild.Channels = append(guild.Channels[:i], guild.Channels[i+1:]...)
@@ -529,6 +568,99 @@ func (s *State) ChannelRemove(channel *Channel) error {
return nil
}
+// ThreadListSync syncs guild threads with provided ones.
+func (s *State) ThreadListSync(tls *ThreadListSync) error {
+ guild, err := s.Guild(tls.GuildID)
+ if err != nil {
+ return err
+ }
+
+ s.Lock()
+ defer s.Unlock()
+
+ // This algorithm filters out archived or
+ // threads which are children of channels in channelIDs
+ // and then it adds all synced threads to guild threads and cache
+ index := 0
+outer:
+ for _, t := range guild.Threads {
+ if !t.ThreadMetadata.Archived && tls.ChannelIDs != nil {
+ for _, v := range tls.ChannelIDs {
+ if t.ParentID == v {
+ delete(s.channelMap, t.ID)
+ continue outer
+ }
+ }
+ guild.Threads[index] = t
+ index++
+ } else {
+ delete(s.channelMap, t.ID)
+ }
+ }
+ guild.Threads = guild.Threads[:index]
+ for _, t := range tls.Threads {
+ s.channelMap[t.ID] = t
+ guild.Threads = append(guild.Threads, t)
+ }
+
+ for _, m := range tls.Members {
+ if c, ok := s.channelMap[m.ID]; ok {
+ c.Member = m
+ }
+ }
+
+ return nil
+}
+
+// ThreadMembersUpdate updates thread members list
+func (s *State) ThreadMembersUpdate(tmu *ThreadMembersUpdate) error {
+ thread, err := s.Channel(tmu.ID)
+ if err != nil {
+ return err
+ }
+ s.Lock()
+ defer s.Unlock()
+
+ for idx, member := range thread.Members {
+ for _, removedMember := range tmu.RemovedMembers {
+ if member.ID == removedMember {
+ thread.Members = append(thread.Members[:idx], thread.Members[idx+1:]...)
+ break
+ }
+ }
+ }
+
+ for _, addedMember := range tmu.AddedMembers {
+ thread.Members = append(thread.Members, addedMember.ThreadMember)
+ if addedMember.Member != nil {
+ err = s.memberAdd(addedMember.Member)
+ if err != nil {
+ return err
+ }
+ }
+ if addedMember.Presence != nil {
+ err = s.presenceAdd(tmu.GuildID, addedMember.Presence)
+ if err != nil {
+ return err
+ }
+ }
+ }
+ thread.MemberCount = tmu.MemberCount
+
+ return nil
+}
+
+// ThreadMemberUpdate sets or updates member data for the current user.
+func (s *State) ThreadMemberUpdate(mu *ThreadMemberUpdate) error {
+ thread, err := s.Channel(mu.ID)
+ if err != nil {
+ return err
+ }
+
+ thread.Member = mu.ThreadMember
+ return nil
+}
+
// GuildChannel gets a channel by ID from a guild.
// This method is Deprecated, use Channel(channelID)
func (s *State) GuildChannel(guildID, channelID string) (*Channel, error) {
@@ -637,7 +769,7 @@ func (s *State) MessageAdd(message *Message) error {
if message.Content != "" {
m.Content = message.Content
}
- if message.EditedTimestamp != "" {
+ if message.EditedTimestamp != nil {
m.EditedTimestamp = message.EditedTimestamp
}
if message.Mentions != nil {
@@ -649,12 +781,15 @@ func (s *State) MessageAdd(message *Message) error {
if message.Attachments != nil {
m.Attachments = message.Attachments
}
- if message.Timestamp != "" {
+ if !message.Timestamp.IsZero() {
m.Timestamp = message.Timestamp
}
if message.Author != nil {
m.Author = message.Author
}
+ if message.Components != nil {
+ m.Components = message.Components
+ }
return nil
}
@@ -665,6 +800,7 @@ func (s *State) MessageAdd(message *Message) error {
if len(c.Messages) > s.MaxMessageCount {
c.Messages = c.Messages[len(c.Messages)-s.MaxMessageCount:]
}
+
return nil
}
@@ -690,6 +826,7 @@ func (s *State) messageRemoveByID(channelID, messageID string) error {
for i, m := range c.Messages {
if m.ID == messageID {
c.Messages = append(c.Messages[:i], c.Messages[i+1:]...)
+
return nil
}
}
@@ -833,6 +970,13 @@ func (s *State) OnInterface(se *Session, i interface{}) (err error) {
case *GuildUpdate:
err = s.GuildAdd(t.Guild)
case *GuildDelete:
+ var old *Guild
+ old, err = s.Guild(t.ID)
+ if err == nil {
+ oldCopy := *old
+ t.BeforeDelete = &oldCopy
+ }
+
err = s.GuildRemove(t.Guild)
case *GuildMemberAdd:
// Updates the MemberCount of the guild.
@@ -903,6 +1047,35 @@ func (s *State) OnInterface(se *Session, i interface{}) (err error) {
if s.TrackChannels {
err = s.ChannelRemove(t.Channel)
}
+ case *ThreadCreate:
+ if s.TrackThreads {
+ err = s.ChannelAdd(t.Channel)
+ }
+ case *ThreadUpdate:
+ if s.TrackThreads {
+ old, err := s.Channel(t.ID)
+ if err == nil {
+ oldCopy := *old
+ t.BeforeUpdate = &oldCopy
+ }
+ err = s.ChannelAdd(t.Channel)
+ }
+ case *ThreadDelete:
+ if s.TrackThreads {
+ err = s.ChannelRemove(t.Channel)
+ }
+ case *ThreadMemberUpdate:
+ if s.TrackThreads {
+ err = s.ThreadMemberUpdate(t)
+ }
+ case *ThreadMembersUpdate:
+ if s.TrackThreadMembers {
+ err = s.ThreadMembersUpdate(t)
+ }
+ case *ThreadListSync:
+ if s.TrackThreads {
+ err = s.ThreadListSync(t)
+ }
case *MessageCreate:
if s.MaxMessageCount != 0 {
err = s.MessageAdd(t.Message)