diff --git a/pypush/apns/albert.py b/pypush/apns/albert.py index 3706807..24a2ad6 100644 --- a/pypush/apns/albert.py +++ b/pypush/apns/albert.py @@ -50,6 +50,13 @@ async def activate( build: str = "10.6.4", model: str = "windows1,1", ) -> Tuple[x509.Certificate, rsa.RSAPrivateKey]: + """ + Activate with Apple's Albert service, obtaining an activation certificate and private key. + + By default, this will activate a Windows device with a random UDID, serial, version, build, and model. + + Windows activations will not function for iMessage or FaceTime. + """ if http_client is None: # Do this here to ensure the client is not accidentally reused during tests http_client = httpx.AsyncClient() diff --git a/pypush/apns/lifecycle.py b/pypush/apns/lifecycle.py index 8b97a9c..77917e2 100644 --- a/pypush/apns/lifecycle.py +++ b/pypush/apns/lifecycle.py @@ -25,6 +25,16 @@ async def create_apns_connection( sandbox: bool = False, courier: typing.Optional[str] = None, ): + """ + This context manager will create a connection to the APNs server and yield a Connection object. + + Args: + certificate (x509.Certificate): A valid activation certificate obtained from Albert. + private_key (rsa.RSAPrivateKey): The private key corresponding to the activation certificate. + token (bytes, optional): An optional base token to use for the connection. If not provided, the connection will be established and the base token will be set to the token provided in the ConnectAck command. + sandbox (bool, optional): A boolean indicating whether to connect to the APNs sandbox or production server. + courier (str, optional): An optional string indicating the courier server to connect to. If not provided, a random courier server will be selected based on the `sandbox` parameter. + """ async with anyio.create_task_group() as tg: conn = Connection( tg, certificate, private_key, token, sandbox, courier @@ -71,11 +81,16 @@ class Connection: logging.debug(f"Using courier: {courier}") self.courier = courier - self._tg.start_soon(self.reconnect) + self._tg.start_soon(self._reconnect) self._tg.start_soon(self._ping_task) @property async def base_token(self) -> bytes: + """ + `base_token` must be awaited to ensure a token is available + + This may not complete until a connection has been established + """ if self._base_token is None: await self._connected.wait() assert self._base_token is not None @@ -98,7 +113,7 @@ class Connection: ) # Explicitly disable the backlog since we don't want to receive old acks @_util.exponential_backoff - async def reconnect(self): + async def _reconnect(self): async with ( self._reconnect_lock ): # Prevent weird situations where multiple reconnects are happening at once @@ -161,9 +176,15 @@ class Connection: await self._update_filter() async def aclose(self): + """ + Closes the connection to the APNS server. + + If the connection is open, it will be closed. This method is typically unnecessary if the connection is managed by `create_apns_connection`. + + Note: The connection will be reopened if the task group is still open (the ping task is still running). + """ if self._conn is not None: await self._conn.aclose() - # Note: Will be reopened if task group is still running and ping task is still running T = typing.TypeVar("T") @@ -191,7 +212,7 @@ class Connection: await self._conn.send(command) except Exception: logging.warning("Error sending command, reconnecting") - await self.reconnect() + await self._reconnect() await self._send(command) async def _update_filter(self): @@ -207,6 +228,8 @@ class Connection: @asynccontextmanager async def _filter(self, topics: list[str]): for topic in topics: + if topic not in protocol.KNOWN_TOPICS: + protocol.note_topic(topic) self._filters[topic] = self._filters.get(topic, 0) + 1 await self._update_filter() yield @@ -217,6 +240,14 @@ class Connection: await self._update_filter() async def mint_scoped_token(self, topic: str) -> bytes: + """ + Mint a "scoped token" for the given topic/bundle ID. + + This token is equivalent to the token provided to `application:didRegisterForRemoteNotificationsWithDeviceToken:` in iOS, + for an app with the given bundle ID. + + This token can be used with `expect_notification` or `notification_stream`, but it will only function on connections with the same base token as the connection that originally minted the token. + """ topic_hash = sha1(topic.encode()).digest() await self._send( protocol.ScopedTokenCommand(token=await self.base_token, topic=topic_hash) @@ -234,6 +265,14 @@ class Connection: protocol.SendMessageCommand, protocol.SendMessageCommand ] = filters.ALL, ): + """ + Create a stream of notifications for the given topic and token. + If the token is not provided, the base token will be used. + + A custom `Filter` can be provided to filter out unwanted notifications. + + Notifications will NOT be ack'd automatically, you must call `ack` on each notification you process. + """ if token is None: token = await self.base_token async with ( @@ -254,6 +293,9 @@ class Connection: yield stream async def ack(self, command: protocol.SendMessageCommand, status: int = 0): + """ + Acknowledge a notification. + """ await self._send( protocol.SendMessageAck(status=status, token=command.token, id=command.id) ) @@ -266,6 +308,13 @@ class Connection: protocol.SendMessageCommand, protocol.SendMessageCommand ] = filters.ALL, ) -> protocol.SendMessageCommand: + """ + Wait for a notification that matches the given topic and token. + If the token is not provided, the base token will be used. + + A custom `Filter` can be provided to filter out unwanted notifications. + + This method WILL ack the notification automatically.""" async with self.notification_stream(topic, token, filter) as stream: command = await stream.receive() await self.ack(command) diff --git a/pypush/apns/protocol.py b/pypush/apns/protocol.py index 147119c..75ce939 100644 --- a/pypush/apns/protocol.py +++ b/pypush/apns/protocol.py @@ -13,6 +13,15 @@ KNOWN_TOPICS_LOOKUP = {sha1(topic.encode()).digest():topic for topic in KNOWN_TO # fmt: on +def note_topic(topic: str): + """ + Add a topic to the KNOWN_TOPICS set, such that it can be recognized later. + This is mostly just a convenience, so that you do not have to work with SHA1 hashes directly. + """ + KNOWN_TOPICS.add(topic) + KNOWN_TOPICS_LOOKUP[sha1(topic.encode()).digest()] = topic + + @dataclass class Command: @classmethod @@ -148,6 +157,14 @@ class SetStateCommand(Command): @command @dataclass class SendMessageCommand(Command): + """ + The most common form of command, used to send a message, also represents incoming messages. + + May also be called a "Notification" in the context of APNs. + + Important note: The `topic` field may be a string or bytes, depending on if the topic is in the `KNOWN_TOPICS` set. This should not happen if you are using the proper `Connection` API. + """ + PacketType = Packet.Type.SendMessage payload: bytes = fid(3) diff --git a/tests/test_apns.py b/tests/test_apns.py index 501e862..46313c2 100644 --- a/tests/test_apns.py +++ b/tests/test_apns.py @@ -28,6 +28,7 @@ async def test_lifecycle_2(): ASSETS_DIR = Path(__file__).parent / "assets" +# Not a part of pypush, this is public API async def send_test_notification(device_token, payload=b"hello, world"): async with httpx.AsyncClient( cert=str(ASSETS_DIR / "dev.jjtech.pypush.tests.pem"), http2=True