Skip to content

Commit 5a12196

Browse files
author
Nekokatt
authored
Merge pull request #167 from nekokatt/task/161-typing-event-user
Implemented `user` member on typing events.
2 parents 0f0d2b9 + 2e64418 commit 5a12196

3 files changed

Lines changed: 163 additions & 55 deletions

File tree

hikari/events/typing_events.py

Lines changed: 86 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@
3535

3636
from hikari import channels
3737
from hikari import intents
38+
from hikari import users
39+
from hikari.api import special_endpoints
3840
from hikari.events import base_events
3941
from hikari.events import shard_events
4042
from hikari.utilities import attr_extensions
@@ -45,7 +47,6 @@
4547
from hikari import guilds
4648
from hikari import snowflakes
4749
from hikari import traits
48-
from hikari import users
4950
from hikari.api import shard as gateway_shard
5051

5152

@@ -97,6 +98,18 @@ def channel(self) -> typing.Optional[channels.TextChannel]:
9798
The channel, if known.
9899
"""
99100

101+
@property
102+
@abc.abstractmethod
103+
def user(self) -> typing.Optional[users.User]:
104+
"""Get the cached user that is typing, if known.
105+
106+
Returns
107+
-------
108+
typing.Optional[hikari.users.User]
109+
The user, if known.
110+
"""
111+
112+
@abc.abstractmethod
100113
async def fetch_channel(self) -> channels.TextChannel:
101114
"""Perform an API call to fetch an up-to-date image of this channel.
102115
@@ -105,10 +118,8 @@ async def fetch_channel(self) -> channels.TextChannel:
105118
hikari.channels.TextChannel
106119
The channel.
107120
"""
108-
channel = await self.app.rest.fetch_channel(self.channel_id)
109-
assert isinstance(channel, channels.TextChannel)
110-
return channel
111121

122+
@abc.abstractmethod
112123
async def fetch_user(self) -> users.User:
113124
"""Perform an API call to fetch an up-to-date image of this user.
114125
@@ -117,7 +128,17 @@ async def fetch_user(self) -> users.User:
117128
hikari.users.User
118129
The user.
119130
"""
120-
return await self.app.rest.fetch_user(self.user_id)
131+
132+
def trigger_typing(self) -> special_endpoints.TypingIndicator:
133+
"""Return a typing indicator for this channel that can be awaited.
134+
135+
Returns
136+
-------
137+
hikari.api.special_endpoints.TypingIndicator
138+
A typing indicator context manager and awaitable to trigger typing
139+
in a channel with.
140+
"""
141+
return self.app.rest.trigger_typing(self.channel_id)
121142

122143

123144
@base_events.requires_intents(intents.Intents.GUILD_MESSAGE_TYPING)
@@ -135,9 +156,6 @@ class GuildTypingEvent(TypingEvent):
135156
channel_id: snowflakes.Snowflake = attr.ib()
136157
# <<inherited docstring from TypingEvent>>.
137158

138-
user_id: snowflakes.Snowflake = attr.ib(repr=True)
139-
# <<inherited docstring from TypingEvent>>.
140-
141159
timestamp: datetime.datetime = attr.ib(repr=False)
142160
# <<inherited docstring from TypingEvent>>.
143161

@@ -150,19 +168,32 @@ class GuildTypingEvent(TypingEvent):
150168
The ID of the guild that relates to this event.
151169
"""
152170

153-
member: guilds.Member = attr.ib(repr=False)
171+
user: guilds.Member = attr.ib(repr=False)
154172
"""Member object of the user who triggered this typing event.
155173
174+
Unlike on `PrivateTypingEvent` instances, Discord will always send
175+
this field in any payload.
176+
156177
Returns
157178
-------
158179
hikari.guilds.Member
159180
Member of the user who triggered this typing event.
160181
"""
161182

162183
@property
163-
def channel(self) -> typing.Optional[channels.GuildTextChannel]:
164-
# <<inherited docstring from TypingEvent>>.
165-
return typing.cast("channels.GuildTextChannel", self.app.cache.get_guild_channel(self.channel_id))
184+
def channel(self) -> typing.Union[channels.GuildTextChannel, channels.GuildNewsChannel]:
185+
"""Get the cached channel object this typing event occurred in.
186+
187+
Returns
188+
-------
189+
typing.Union[hikari.channels.GuildTextChannel, hikari.channels.GuildNewsChannel]
190+
The channel.
191+
"""
192+
channel = self.app.cache.get_guild_channel(self.channel_id)
193+
assert isinstance(
194+
channel, (channels.GuildTextChannel, channels.GuildNewsChannel)
195+
), f"expected GuildTextChannel or GuildNewsChannel from cache, got {channel}"
196+
return channel
166197

167198
@property
168199
def guild(self) -> typing.Optional[guilds.GatewayGuild]:
@@ -177,10 +208,24 @@ def guild(self) -> typing.Optional[guilds.GatewayGuild]:
177208
"""
178209
return self.app.cache.get_available_guild(self.guild_id) or self.app.cache.get_unavailable_guild(self.guild_id)
179210

180-
if typing.TYPE_CHECKING:
211+
@property
212+
def user_id(self) -> snowflakes.Snowflake:
213+
# <<inherited docstring from TypingEvent>>.
214+
return self.user.id
181215

182-
async def fetch_channel(self) -> channels.GuildTextChannel:
183-
...
216+
async def fetch_channel(self) -> typing.Union[channels.GuildTextChannel, channels.GuildNewsChannel]:
217+
"""Perform an API call to fetch an up-to-date image of this channel.
218+
219+
Returns
220+
-------
221+
typing.Union[hikari.channels.GuildTextChannel, hikari.channels.GuildNewsChannel]
222+
The channel.
223+
"""
224+
channel = await self.app.rest.fetch_channel(self.channel_id)
225+
assert isinstance(
226+
channel, (channels.GuildTextChannel, channels.GuildNewsChannel)
227+
), f"expected GuildTextChannel or GuildNewsChannel from API, got {channel}"
228+
return channel
184229

185230
async def fetch_guild(self) -> guilds.Guild:
186231
"""Perform an API call to fetch an up-to-date image of this guild.
@@ -202,7 +247,7 @@ async def fetch_guild_preview(self) -> guilds.GuildPreview:
202247
"""
203248
return await self.app.rest.fetch_guild_preview(self.guild_id)
204249

205-
async def fetch_member(self) -> guilds.Member:
250+
async def fetch_user(self) -> guilds.Member:
206251
"""Perform an API call to fetch an up-to-date image of this member.
207252
208253
Returns
@@ -232,7 +277,6 @@ class PrivateTypingEvent(TypingEvent):
232277
# <<inherited docstring from TypingEvent>>.
233278

234279
timestamp: datetime.datetime = attr.ib(repr=False)
235-
236280
# <<inherited docstring from TypingEvent>>.
237281

238282
@property
@@ -246,7 +290,29 @@ def channel(self) -> typing.Optional[channels.DMChannel]:
246290
"""
247291
return self.app.cache.get_dm(self.user_id)
248292

249-
if typing.TYPE_CHECKING:
293+
@property
294+
def user(self) -> typing.Optional[users.User]:
295+
# <<inherited docstring from TypingEvent>>.
296+
return self.app.cache.get_user(self.user_id)
250297

251-
async def fetch_channel(self) -> channels.DMChannel:
252-
...
298+
async def fetch_channel(self) -> channels.DMChannel:
299+
"""Perform an API call to fetch an up-to-date image of this channel.
300+
301+
Returns
302+
-------
303+
hikari.channels.DMChannel
304+
The channel.
305+
"""
306+
channel = await self.app.rest.fetch_channel(self.channel_id)
307+
assert isinstance(channel, channels.DMChannel), f"expected DMChannel from API, got {channel}"
308+
return channel
309+
310+
async def fetch_user(self) -> users.User:
311+
"""Perform an API call to fetch an up-to-date image of the user.
312+
313+
Returns
314+
-------
315+
hikari.users.User
316+
The user.
317+
"""
318+
return await self.app.rest.fetch_user(self.user_id)

hikari/impl/event_factory.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,9 +141,8 @@ def deserialize_typing_start_event(
141141
shard=shard,
142142
channel_id=channel_id,
143143
guild_id=guild_id,
144-
user_id=user_id,
145144
timestamp=timestamp,
146-
member=member,
145+
user=member,
147146
)
148147

149148
return typing_events.PrivateTypingEvent(

tests/hikari/events/test_typing_events.py

Lines changed: 76 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -25,54 +25,49 @@
2525
from hikari import channels
2626
from hikari import users
2727
from hikari.events import typing_events
28+
from tests.hikari import hikari_test_helpers
2829

2930

3031
@pytest.mark.asyncio
3132
class TestTypingEvent:
3233
@pytest.fixture()
3334
def event(self):
34-
class StubEvent(typing_events.TypingEvent):
35-
channel_id = 123
36-
user_id = 456
37-
timestamp = None
38-
shard = None
39-
app = mock.Mock(rest=mock.AsyncMock())
40-
channel = object()
41-
guild = object()
42-
43-
return StubEvent()
44-
45-
async def test_fetch_channel(self, event):
46-
mock_channel = mock.Mock(spec_set=channels.TextChannel)
47-
event.app.rest.fetch_channel = mock.AsyncMock(return_value=mock_channel)
48-
assert await event.fetch_channel() is mock_channel
49-
50-
event.app.rest.fetch_channel.assert_awaited_once_with(123)
51-
52-
async def test_fetch_user(self, event):
53-
mock_user = mock.Mock(spec_set=users.User)
54-
event.app.rest.fetch_user = mock.AsyncMock(return_value=mock_user)
35+
cls = hikari_test_helpers.mock_class_namespace(
36+
typing_events.TypingEvent,
37+
channel_id=123,
38+
user_id=456,
39+
timestamp=object(),
40+
shard=object(),
41+
channel=object(),
42+
)
5543

56-
assert await event.fetch_user() is mock_user
44+
return cls()
5745

58-
event.app.rest.fetch_user.assert_awaited_once_with(456)
46+
async def test_trigger_typing(self, event):
47+
event.app.rest.trigger_typing = mock.Mock()
48+
result = event.trigger_typing()
49+
event.app.rest.trigger_typing.assert_called_once_with(123)
50+
assert result is event.app.rest.trigger_typing.return_value
5951

6052

6153
@pytest.mark.asyncio
6254
class TestGuildTypingEvent:
6355
@pytest.fixture()
6456
def event(self):
65-
return typing_events.GuildTypingEvent(
66-
app=mock.AsyncMock(cache=mock.Mock()),
67-
shard=None,
57+
cls = hikari_test_helpers.mock_class_namespace(typing_events.GuildTypingEvent)
58+
59+
return cls(
6860
channel_id=123,
69-
user_id=456,
61+
timestamp=object(),
62+
shard=object(),
63+
app=mock.Mock(rest=mock.AsyncMock()),
7064
guild_id=789,
71-
timestamp=None,
72-
member=None,
65+
user=mock.Mock(id=456),
7366
)
7467

75-
def test_channel(self, event):
68+
@pytest.mark.parametrize("guild_channel_impl", [channels.GuildNewsChannel, channels.GuildTextChannel])
69+
async def test_channel(self, event, guild_channel_impl):
70+
event.app.cache.get_guild_channel = mock.Mock(return_value=mock.Mock(spec_set=guild_channel_impl))
7671
result = event.channel
7772

7873
assert result is event.app.cache.get_guild_channel.return_value
@@ -93,10 +88,16 @@ def test_guild_when_unavailable(self, event):
9388
event.app.cache.get_unavailable_guild.assert_called_once_with(789)
9489
event.app.cache.get_available_guild.assert_called_once_with(789)
9590

96-
async def test_fetch_channel(self, event):
97-
await event.fetch_member()
91+
def test_user_id(self, event):
92+
assert event.user_id == event.user.id
93+
assert event.user_id == 456
9894

99-
event.app.rest.fetch_member.assert_awaited_once_with(789, 456)
95+
@pytest.mark.parametrize("guild_channel_impl", [channels.GuildNewsChannel, channels.GuildTextChannel])
96+
async def test_fetch_channel(self, event, guild_channel_impl):
97+
event.app.rest.fetch_channel = mock.AsyncMock(return_value=mock.Mock(spec_set=guild_channel_impl))
98+
await event.fetch_channel()
99+
100+
event.app.rest.fetch_channel.assert_awaited_once_with(123)
100101

101102
async def test_fetch_guild(self, event):
102103
await event.fetch_guild()
@@ -107,3 +108,45 @@ async def test_fetch_guild_preview(self, event):
107108
await event.fetch_guild_preview()
108109

109110
event.app.rest.fetch_guild_preview.assert_awaited_once_with(789)
111+
112+
async def test_fetch_user(self, event):
113+
await event.fetch_user()
114+
115+
event.app.rest.fetch_member.assert_awaited_once_with(789, 456)
116+
117+
118+
@pytest.mark.asyncio
119+
class TestPrivateTypingEvent:
120+
@pytest.fixture()
121+
def event(self):
122+
cls = hikari_test_helpers.mock_class_namespace(typing_events.PrivateTypingEvent)
123+
124+
return cls(
125+
channel_id=123,
126+
timestamp=object(),
127+
shard=object(),
128+
app=mock.Mock(rest=mock.AsyncMock()),
129+
user_id=456,
130+
)
131+
132+
async def test_channel(self, event):
133+
event.app.cache.get_dm = mock.Mock(return_value=mock.Mock(spec_set=channels.DMChannel))
134+
result = event.channel
135+
assert result is event.app.cache.get_dm.return_value
136+
event.app.cache.get_dm.assert_called_once_with(456)
137+
138+
def test_user(self, event):
139+
event.app.cache.get_user = mock.Mock(return_value=mock.Mock(spec_set=users.User))
140+
141+
assert event.user is event.app.cache.get_user.return_value
142+
143+
async def test_fetch_channel(self, event):
144+
event.app.rest.fetch_channel = mock.AsyncMock(return_value=mock.Mock(spec_set=channels.DMChannel))
145+
await event.fetch_channel()
146+
147+
event.app.rest.fetch_channel.assert_awaited_once_with(123)
148+
149+
async def test_fetch_user(self, event):
150+
await event.fetch_user()
151+
152+
event.app.rest.fetch_user.assert_awaited_once_with(456)

0 commit comments

Comments
 (0)