From 66bfe530b4fd5cb9755337ce718df26ee81f1b77 Mon Sep 17 00:00:00 2001 From: Jordan Date: Thu, 17 Mar 2022 09:43:29 -0700 Subject: misc: go get -u ./... ; go mod vendor --- vendor/github.com/bwmarrin/discordgo/state.go | 289 ++++++++++++++++++++------ 1 file changed, 231 insertions(+), 58 deletions(-) (limited to 'vendor/github.com/bwmarrin/discordgo/state.go') diff --git a/vendor/github.com/bwmarrin/discordgo/state.go b/vendor/github.com/bwmarrin/discordgo/state.go index 2eeabd8..e75be89 100644 --- a/vendor/github.com/bwmarrin/discordgo/state.go +++ b/vendor/github.com/bwmarrin/discordgo/state.go @@ -38,13 +38,15 @@ type State struct { Ready // MaxMessageCount represents how many messages per channel the state will store. - MaxMessageCount int - TrackChannels bool - TrackEmojis bool - TrackMembers bool - TrackRoles bool - TrackVoice bool - TrackPresences bool + MaxMessageCount int + TrackChannels bool + TrackThreads bool + TrackEmojis bool + TrackMembers bool + TrackThreadMembers bool + TrackRoles bool + TrackVoice bool + TrackPresences bool guildMap map[string]*Guild channelMap map[string]*Channel @@ -58,15 +60,17 @@ func NewState() *State { PrivateChannels: []*Channel{}, Guilds: []*Guild{}, }, - TrackChannels: true, - TrackEmojis: true, - TrackMembers: true, - TrackRoles: true, - TrackVoice: true, - TrackPresences: true, - guildMap: make(map[string]*Guild), - channelMap: make(map[string]*Channel), - memberMap: make(map[string]map[string]*Member), + TrackChannels: true, + TrackThreads: true, + TrackEmojis: true, + TrackMembers: true, + TrackThreadMembers: true, + TrackRoles: true, + TrackVoice: true, + TrackPresences: true, + guildMap: make(map[string]*Guild), + channelMap: make(map[string]*Channel), + memberMap: make(map[string]map[string]*Member), } } @@ -93,6 +97,11 @@ func (s *State) GuildAdd(guild *Guild) error { s.channelMap[c.ID] = c } + // Add all the threads to the state in case of thread sync list. + for _, t := range guild.Threads { + s.channelMap[t.ID] = t + } + // If this guild contains a new member slice, we must regenerate the member map so the pointers stay valid if guild.Members != nil { s.createMemberMap(guild) @@ -122,6 +131,9 @@ func (s *State) GuildAdd(guild *Guild) error { if guild.Channels == nil { guild.Channels = g.Channels } + if guild.Threads == nil { + guild.Threads = g.Threads + } if guild.VoiceStates == nil { guild.VoiceStates = g.VoiceStates } @@ -180,21 +192,12 @@ func (s *State) Guild(guildID string) (*Guild, error) { return nil, ErrStateNotFound } -// PresenceAdd adds a presence to the current world state, or -// updates it if it already exists. -func (s *State) PresenceAdd(guildID string, presence *Presence) error { - if s == nil { - return ErrNilState - } - - guild, err := s.Guild(guildID) - if err != nil { - return err +func (s *State) presenceAdd(guildID string, presence *Presence) error { + guild, ok := s.guildMap[guildID] + if !ok { + return ErrStateNotFound } - s.Lock() - defer s.Unlock() - for i, p := range guild.Presences { if p.User.ID == presence.User.ID { //guild.Presences[i] = presence @@ -233,6 +236,19 @@ func (s *State) PresenceAdd(guildID string, presence *Presence) error { return nil } +// PresenceAdd adds a presence to the current world state, or +// updates it if it already exists. +func (s *State) PresenceAdd(guildID string, presence *Presence) error { + if s == nil { + return ErrNilState + } + + s.Lock() + defer s.Unlock() + + return s.presenceAdd(guildID, presence) +} + // PresenceRemove removes a presence from the current world state. func (s *State) PresenceRemove(guildID string, presence *Presence) error { if s == nil { @@ -279,21 +295,12 @@ func (s *State) Presence(guildID, userID string) (*Presence, error) { // TODO: Consider moving Guild state update methods onto *Guild. -// MemberAdd adds a member to the current world state, or -// updates it if it already exists. -func (s *State) MemberAdd(member *Member) error { - if s == nil { - return ErrNilState - } - - guild, err := s.Guild(member.GuildID) - if err != nil { - return err +func (s *State) memberAdd(member *Member) error { + guild, ok := s.guildMap[member.GuildID] + if !ok { + return ErrStateNotFound } - s.Lock() - defer s.Unlock() - members, ok := s.memberMap[member.GuildID] if !ok { return ErrStateNotFound @@ -306,15 +313,27 @@ func (s *State) MemberAdd(member *Member) error { } else { // We are about to replace `m` in the state with `member`, but first we need to // make sure we preserve any fields that the `member` doesn't contain from `m`. - if member.JoinedAt == "" { + if member.JoinedAt.IsZero() { member.JoinedAt = m.JoinedAt } *m = *member } - return nil } +// MemberAdd adds a member to the current world state, or +// updates it if it already exists. +func (s *State) MemberAdd(member *Member) error { + if s == nil { + return ErrNilState + } + + s.Lock() + defer s.Unlock() + + return s.memberAdd(member) +} + // MemberRemove removes a member from current world state. func (s *State) MemberRemove(member *Member) error { if s == nil { @@ -465,6 +484,9 @@ func (s *State) ChannelAdd(channel *Channel) error { if channel.PermissionOverwrites == nil { channel.PermissionOverwrites = c.PermissionOverwrites } + if channel.ThreadMetadata == nil { + channel.ThreadMetadata = c.ThreadMetadata + } *c = *channel return nil @@ -472,12 +494,18 @@ func (s *State) ChannelAdd(channel *Channel) error { if channel.Type == ChannelTypeDM || channel.Type == ChannelTypeGroupDM { s.PrivateChannels = append(s.PrivateChannels, channel) - } else { - guild, ok := s.guildMap[channel.GuildID] - if !ok { - return ErrStateNotFound - } + s.channelMap[channel.ID] = channel + return nil + } + guild, ok := s.guildMap[channel.GuildID] + if !ok { + return ErrStateNotFound + } + + if channel.IsThread() { + guild.Threads = append(guild.Threads, channel) + } else { guild.Channels = append(guild.Channels, channel) } @@ -507,15 +535,26 @@ func (s *State) ChannelRemove(channel *Channel) error { break } } - } else { - guild, err := s.Guild(channel.GuildID) - if err != nil { - return err - } + delete(s.channelMap, channel.ID) + return nil + } - s.Lock() - defer s.Unlock() + guild, err := s.Guild(channel.GuildID) + if err != nil { + return err + } + + s.Lock() + defer s.Unlock() + if channel.IsThread() { + for i, t := range guild.Threads { + if t.ID == channel.ID { + guild.Threads = append(guild.Threads[:i], guild.Threads[i+1:]...) + break + } + } + } else { for i, c := range guild.Channels { if c.ID == channel.ID { guild.Channels = append(guild.Channels[:i], guild.Channels[i+1:]...) @@ -529,6 +568,99 @@ func (s *State) ChannelRemove(channel *Channel) error { return nil } +// ThreadListSync syncs guild threads with provided ones. +func (s *State) ThreadListSync(tls *ThreadListSync) error { + guild, err := s.Guild(tls.GuildID) + if err != nil { + return err + } + + s.Lock() + defer s.Unlock() + + // This algorithm filters out archived or + // threads which are children of channels in channelIDs + // and then it adds all synced threads to guild threads and cache + index := 0 +outer: + for _, t := range guild.Threads { + if !t.ThreadMetadata.Archived && tls.ChannelIDs != nil { + for _, v := range tls.ChannelIDs { + if t.ParentID == v { + delete(s.channelMap, t.ID) + continue outer + } + } + guild.Threads[index] = t + index++ + } else { + delete(s.channelMap, t.ID) + } + } + guild.Threads = guild.Threads[:index] + for _, t := range tls.Threads { + s.channelMap[t.ID] = t + guild.Threads = append(guild.Threads, t) + } + + for _, m := range tls.Members { + if c, ok := s.channelMap[m.ID]; ok { + c.Member = m + } + } + + return nil +} + +// ThreadMembersUpdate updates thread members list +func (s *State) ThreadMembersUpdate(tmu *ThreadMembersUpdate) error { + thread, err := s.Channel(tmu.ID) + if err != nil { + return err + } + s.Lock() + defer s.Unlock() + + for idx, member := range thread.Members { + for _, removedMember := range tmu.RemovedMembers { + if member.ID == removedMember { + thread.Members = append(thread.Members[:idx], thread.Members[idx+1:]...) + break + } + } + } + + for _, addedMember := range tmu.AddedMembers { + thread.Members = append(thread.Members, addedMember.ThreadMember) + if addedMember.Member != nil { + err = s.memberAdd(addedMember.Member) + if err != nil { + return err + } + } + if addedMember.Presence != nil { + err = s.presenceAdd(tmu.GuildID, addedMember.Presence) + if err != nil { + return err + } + } + } + thread.MemberCount = tmu.MemberCount + + return nil +} + +// ThreadMemberUpdate sets or updates member data for the current user. +func (s *State) ThreadMemberUpdate(mu *ThreadMemberUpdate) error { + thread, err := s.Channel(mu.ID) + if err != nil { + return err + } + + thread.Member = mu.ThreadMember + return nil +} + // GuildChannel gets a channel by ID from a guild. // This method is Deprecated, use Channel(channelID) func (s *State) GuildChannel(guildID, channelID string) (*Channel, error) { @@ -637,7 +769,7 @@ func (s *State) MessageAdd(message *Message) error { if message.Content != "" { m.Content = message.Content } - if message.EditedTimestamp != "" { + if message.EditedTimestamp != nil { m.EditedTimestamp = message.EditedTimestamp } if message.Mentions != nil { @@ -649,12 +781,15 @@ func (s *State) MessageAdd(message *Message) error { if message.Attachments != nil { m.Attachments = message.Attachments } - if message.Timestamp != "" { + if !message.Timestamp.IsZero() { m.Timestamp = message.Timestamp } if message.Author != nil { m.Author = message.Author } + if message.Components != nil { + m.Components = message.Components + } return nil } @@ -665,6 +800,7 @@ func (s *State) MessageAdd(message *Message) error { if len(c.Messages) > s.MaxMessageCount { c.Messages = c.Messages[len(c.Messages)-s.MaxMessageCount:] } + return nil } @@ -690,6 +826,7 @@ func (s *State) messageRemoveByID(channelID, messageID string) error { for i, m := range c.Messages { if m.ID == messageID { c.Messages = append(c.Messages[:i], c.Messages[i+1:]...) + return nil } } @@ -833,6 +970,13 @@ func (s *State) OnInterface(se *Session, i interface{}) (err error) { case *GuildUpdate: err = s.GuildAdd(t.Guild) case *GuildDelete: + var old *Guild + old, err = s.Guild(t.ID) + if err == nil { + oldCopy := *old + t.BeforeDelete = &oldCopy + } + err = s.GuildRemove(t.Guild) case *GuildMemberAdd: // Updates the MemberCount of the guild. @@ -903,6 +1047,35 @@ func (s *State) OnInterface(se *Session, i interface{}) (err error) { if s.TrackChannels { err = s.ChannelRemove(t.Channel) } + case *ThreadCreate: + if s.TrackThreads { + err = s.ChannelAdd(t.Channel) + } + case *ThreadUpdate: + if s.TrackThreads { + old, err := s.Channel(t.ID) + if err == nil { + oldCopy := *old + t.BeforeUpdate = &oldCopy + } + err = s.ChannelAdd(t.Channel) + } + case *ThreadDelete: + if s.TrackThreads { + err = s.ChannelRemove(t.Channel) + } + case *ThreadMemberUpdate: + if s.TrackThreads { + err = s.ThreadMemberUpdate(t) + } + case *ThreadMembersUpdate: + if s.TrackThreadMembers { + err = s.ThreadMembersUpdate(t) + } + case *ThreadListSync: + if s.TrackThreads { + err = s.ThreadListSync(t) + } case *MessageCreate: if s.MaxMessageCount != 0 { err = s.MessageAdd(t.Message) -- cgit v1.2.3-54-g00ecf