From ce88a8ea5c2da7374a162c62dbe5f93e5f8c8f71 Mon Sep 17 00:00:00 2001 From: JJTech Date: Sun, 19 May 2024 11:54:28 -0700 Subject: [PATCH] Rewrite: APNs: Scoped App Tokens (#101) + Adds a lot of new API surface around filters + Adds CI type checking and linting --- .github/workflows/pyright.yml | 18 ++ .github/workflows/ruff.yml | 8 + pyproject.toml | 6 +- pypush/apns/__init__.py | 6 +- pypush/apns/_protocol.py | 14 +- pypush/apns/_util.py | 50 +++++- pypush/apns/albert.py | 6 +- pypush/apns/filters.py | 44 +++++ pypush/apns/lifecycle.py | 202 ++++++++++++++++++----- pypush/apns/protocol.py | 51 ++++-- pypush/apns/transport.py | 16 +- pypush/cli/__init__.py | 38 ++++- pypush/cli/_frida.py | 3 +- pypush/cli/proxy.py | 8 +- pypush/cli/pushclient.py | 0 tests/assets/dev.jjtech.pypush.tests.pem | 75 +++++++++ tests/test_apns.py | 63 ++++--- 17 files changed, 496 insertions(+), 112 deletions(-) create mode 100644 .github/workflows/pyright.yml create mode 100644 .github/workflows/ruff.yml create mode 100644 pypush/apns/filters.py delete mode 100644 pypush/cli/pushclient.py create mode 100644 tests/assets/dev.jjtech.pypush.tests.pem diff --git a/.github/workflows/pyright.yml b/.github/workflows/pyright.yml new file mode 100644 index 0000000..bfad777 --- /dev/null +++ b/.github/workflows/pyright.yml @@ -0,0 +1,18 @@ +name: Pyright +on: [push, pull_request] +jobs: + pyright: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v4 + with: + cache: 'pip' + + - run: | + python -m venv .venv + source .venv/bin/activate + pip install -e '.[test,cli]' + + - run: echo "$PWD/.venv/bin" >> $GITHUB_PATH + - uses: jakebailey/pyright-action@v2 diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml new file mode 100644 index 0000000..b268138 --- /dev/null +++ b/.github/workflows/ruff.yml @@ -0,0 +1,8 @@ +name: Ruff +on: [push, pull_request] +jobs: + ruff: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: chartboost/ruff-action@v1 diff --git a/pyproject.toml b/pyproject.toml index 05602b2..287497e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,4 +34,8 @@ version_file = "pypush/_version.py" [tool.pytest.ini_options] minversion = "6.0" addopts = ["-ra", "-q"] -testpaths = ["tests"] \ No newline at end of file +testpaths = ["tests"] + +[tool.ruff.lint] +select = ["E", "F", "B", "SIM", "I"] +ignore = ["E501", "B010"] \ No newline at end of file diff --git a/pypush/apns/__init__.py b/pypush/apns/__init__.py index ff6398a..3c954b8 100644 --- a/pypush/apns/__init__.py +++ b/pypush/apns/__init__.py @@ -1,5 +1,5 @@ -__all__ = ["protocol", "create_apns_connection", "activate"] +__all__ = ["protocol", "create_apns_connection", "activate", "filters"] -from . import protocol -from .lifecycle import create_apns_connection +from . import filters, protocol from .albert import activate +from .lifecycle import create_apns_connection diff --git a/pypush/apns/_protocol.py b/pypush/apns/_protocol.py index 140d9e3..bd6c4b5 100644 --- a/pypush/apns/_protocol.py +++ b/pypush/apns/_protocol.py @@ -3,7 +3,7 @@ from __future__ import annotations import logging from dataclasses import MISSING, field from dataclasses import fields as dataclass_fields -from typing import Any, TypeVar, get_origin, get_args, Union +from typing import Any, TypeVar, Union, get_args, get_origin from pypush.apns.transport import Packet @@ -67,14 +67,14 @@ def command(cls: T) -> T: ) # Check for extra fields - for field in packet.fields: - if field.id not in [ + for current_field in packet.fields: + if current_field.id not in [ f.metadata["packet_id"] for f in dataclass_fields(cls) if f.metadata is not None and "packet_id" in f.metadata ]: logging.warning( - f"Unexpected field with packet ID {field.id} in packet {packet}" + f"Unexpected field with packet ID {current_field.id} in packet {packet}" ) return cls(**field_values) @@ -122,15 +122,15 @@ def fid( :param byte_len: The length of the field in bytes (for int fields) :param default: The default value of the field """ - if not default == MISSING and not default_factory == MISSING: + if default != MISSING and default_factory != MISSING: raise ValueError("Cannot specify both default and default_factory") - if not default == MISSING: + if default != MISSING: return field( metadata={"packet_id": packet_id, "packet_bytes": byte_len}, default=default, repr=repr, ) - if not default_factory == MISSING: + if default_factory != MISSING: return field( metadata={"packet_id": packet_id, "packet_bytes": byte_len}, default_factory=default_factory, diff --git a/pypush/apns/_util.py b/pypush/apns/_util.py index 09e9574..3564892 100644 --- a/pypush/apns/_util.py +++ b/pypush/apns/_util.py @@ -3,25 +3,40 @@ from contextlib import asynccontextmanager from typing import Generic, TypeVar import anyio -from anyio.abc import ObjectSendStream +from anyio.abc import ObjectReceiveStream, ObjectSendStream + +from . import filters T = TypeVar("T") class BroadcastStream(Generic[T]): - def __init__(self): + def __init__(self, backlog: int = 50): self.streams: list[ObjectSendStream[T]] = [] + self.backlog: list[T] = [] + self._backlog_size = backlog async def broadcast(self, packet): + logging.debug(f"Broadcasting {packet} to {len(self.streams)} streams") for stream in self.streams: try: await stream.send(packet) except anyio.BrokenResourceError: - self.streams.remove(stream) + logging.error("Broken resource error") + # self.streams.remove(stream) + # If we have a backlog, add the packet to it + if len(self.backlog) >= self._backlog_size: + self.backlog.pop(0) + self.backlog.append(packet) @asynccontextmanager - async def open_stream(self): - send, recv = anyio.create_memory_object_stream[T]() + async def open_stream(self, backlog: bool = True): + # 1000 seems like a reasonable number, if more than 1000 messages come in before someone deals with them it will + # start stalling the APNs connection itself + send, recv = anyio.create_memory_object_stream[T](max_buffer_size=1000) + if backlog: + for packet in self.backlog: + await send.send(packet) self.streams.append(send) async with recv: yield recv @@ -29,6 +44,31 @@ class BroadcastStream(Generic[T]): await send.aclose() +W = TypeVar("W") +F = TypeVar("F") + + +class FilteredStream(ObjectReceiveStream[F]): + """ + A stream that filters out unwanted items + + filter should return None if the item should be filtered out, otherwise it should return the item or a modified version of it + """ + + def __init__(self, source: ObjectReceiveStream[W], filter: filters.Filter[W, F]): + self.source = source + self.filter = filter + + async def receive(self) -> F: + async for item in self.source: + if (filtered := self.filter(item)) is not None: + return filtered + raise anyio.EndOfStream + + async def aclose(self): + await self.source.aclose() + + def exponential_backoff(f): async def wrapper(*args, **kwargs): backoff = 1 diff --git a/pypush/apns/albert.py b/pypush/apns/albert.py index 024e449..3706807 100644 --- a/pypush/apns/albert.py +++ b/pypush/apns/albert.py @@ -4,7 +4,7 @@ import plistlib import re import uuid from base64 import b64decode -from typing import Tuple, Optional +from typing import Optional, Tuple import httpx from cryptography import x509 @@ -96,10 +96,10 @@ async def activate( try: protocol = re.search("(.*)", resp.text).group(1) # type: ignore - except AttributeError: + except AttributeError as e: # Search for error text between and error = re.search("(.*)", resp.text).group(1) # type: ignore - raise Exception(f"Failed to get certificate from Albert: {error}") + raise Exception(f"Failed to get certificate from Albert: {error}") from e protocol = plistlib.loads(protocol.encode("utf-8")) diff --git a/pypush/apns/filters.py b/pypush/apns/filters.py new file mode 100644 index 0000000..63bb784 --- /dev/null +++ b/pypush/apns/filters.py @@ -0,0 +1,44 @@ +import logging +from typing import Callable, Optional, Type, TypeVar + +from pypush.apns import protocol + +T1 = TypeVar("T1") +T2 = TypeVar("T2") +Filter = Callable[[T1], Optional[T2]] + +# Chain with proper types so that subsequent filters only need to take output type of previous filter +T_IN = TypeVar("T_IN", bound=protocol.Command) +T_MIDDLE = TypeVar("T_MIDDLE", bound=protocol.Command) +T_OUT = TypeVar("T_OUT", bound=protocol.Command) + + +def chain(first: Filter[T_IN, T_MIDDLE], second: Filter[T_MIDDLE, T_OUT]): + def filter(command: T_IN) -> Optional[T_OUT]: + logging.debug(f"Filtering {command} with {first} and {second}") + filtered = first(command) + if filtered is None: + return None + return second(filtered) + + return filter + + +T = TypeVar("T", bound=protocol.Command) + + +def cmd(type: Type[T]) -> Filter[protocol.Command, T]: + def filter(command: protocol.Command) -> Optional[T]: + if isinstance(command, type): + return command + return None + + return filter + + +def ALL(c): + return c + + +def NONE(_): + return None diff --git a/pypush/apns/lifecycle.py b/pypush/apns/lifecycle.py index 49b4fcf..23d3f94 100644 --- a/pypush/apns/lifecycle.py +++ b/pypush/apns/lifecycle.py @@ -6,6 +6,7 @@ import random import time import typing from contextlib import asynccontextmanager +from hashlib import sha1 import anyio from anyio.abc import TaskGroup @@ -13,7 +14,7 @@ from cryptography import x509 from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import padding, rsa -from . import protocol, transport, _util +from . import _util, filters, protocol, transport @asynccontextmanager @@ -21,13 +22,18 @@ async def create_apns_connection( certificate: x509.Certificate, private_key: rsa.RSAPrivateKey, token: typing.Optional[bytes] = None, + sandbox: bool = False, courier: typing.Optional[str] = None, ): async with anyio.create_task_group() as tg: - conn = Connection(tg, certificate, private_key, token, courier) + conn = Connection( + tg, certificate, private_key, token, sandbox, courier + ) # Await connected for first time here, so that base token is set yield conn tg.cancel_scope.cancel() # Cancel the task group when the context manager exits - await conn.aclose() # Make sure to close the connection after the task group is cancelled + await ( + conn.aclose() + ) # Make sure to close the connection after the task group is cancelled class Connection: @@ -37,26 +43,44 @@ class Connection: certificate: x509.Certificate, private_key: rsa.RSAPrivateKey, token: typing.Optional[bytes] = None, + sandbox: bool = False, courier: typing.Optional[str] = None, ): - self.certificate = certificate self.private_key = private_key - self.base_token = token + self._base_token = token + + self._filters: dict[str, int] = {} # topic -> use count + + self._connected = anyio.Event() # Only use for base_token property self._conn = None self._tg = task_group self._broadcast = _util.BroadcastStream[protocol.Command]() self._reconnect_lock = anyio.Lock() + self._send_lock = anyio.Lock() + self.sandbox = sandbox if courier is None: # Pick a random courier server from 1 to 50 - courier = f"{random.randint(1, 50)}-courier.push.apple.com" + courier = ( + f"{random.randint(1, 50)}-courier.push.apple.com" + if not sandbox + else f"{random.randint(1, 10)}-courier.sandbox.push.apple.com" + ) + logging.debug(f"Using courier: {courier}") self.courier = courier self._tg.start_soon(self.reconnect) self._tg.start_soon(self._ping_task) + @property + async def base_token(self) -> bytes: + if self._base_token is None: + await self._connected.wait() + assert self._base_token is not None + return self._base_token + async def _receive_task(self): assert self._conn is not None async for command in self._conn: @@ -68,8 +92,10 @@ class Connection: while True: await anyio.sleep(30) logging.debug("Sending keepalive") - await self.send(protocol.KeepAliveCommand()) - await self.receive(protocol.KeepAliveAck) + await self._send(protocol.KeepAliveCommand()) + await self._receive( + filters.cmd(protocol.KeepAliveAck), backlog=False + ) # Explicitly disable the backlog since we don't want to receive old acks @_util.exponential_backoff async def reconnect(self): @@ -77,8 +103,11 @@ class Connection: if self._conn is not None: logging.warning("Closing existing connection") await self._conn.aclose() - self._conn = protocol.CommandStream( - await transport.create_courier_connection(courier=self.courier) + + self._broadcast.backlog = [] # Clear the backlog + + conn = protocol.CommandStream( + await transport.create_courier_connection(self.sandbox, self.courier) ) cert = self.certificate.public_bytes(serialization.Encoding.DER) nonce = ( @@ -89,53 +118,150 @@ class Connection: signature = b"\x01\x01" + self.private_key.sign( nonce, padding.PKCS1v15(), hashes.SHA1() ) - await self._conn.send( + await conn.send( protocol.ConnectCommand( - push_token=self.base_token, + push_token=self._base_token, state=1, - flags=69, + flags=65, # 69 certificate=cert, nonce=nonce, signature=signature, ) ) + + # Don't set self._conn until we've sent the connect command + self._conn = conn + self._tg.start_soon(self._receive_task) - ack = await self.receive(protocol.ConnectAck) + ack = await self._receive( + filters.chain( + filters.cmd(protocol.ConnectAck), + lambda c: ( + c + if ( + c.token == self._base_token + if self._base_token is not None + else True + ) + else None + ), + ) + ) logging.debug(f"Connected with ack: {ack}") assert ack.status == 0 - if self.base_token is None: - self.base_token = ack.token + if self._base_token is None: + self._base_token = ack.token else: - assert ack.token == self.base_token + assert ack.token == self._base_token + if not self._connected.is_set(): + self._connected.set() + + await self._update_filter() async def aclose(self): if self._conn is not None: await self._conn.aclose() # Note: Will be reopened if task group is still running and ping task is still running - T = typing.TypeVar("T", bound=protocol.Command) + T = typing.TypeVar("T") - async def receive_stream( - self, filter: typing.Type[T], max: int = -1 - ) -> typing.AsyncIterator[T]: - async with self._broadcast.open_stream() as stream: + @asynccontextmanager + async def _receive_stream( + self, + filter: filters.Filter[protocol.Command, T] = lambda c: c, + backlog: bool = True, + ): + async with self._broadcast.open_stream(backlog) as stream: + yield _util.FilteredStream(stream, filter) + + async def _receive( + self, filter: filters.Filter[protocol.Command, T], backlog: bool = True + ): + async with self._receive_stream(filter, backlog) as stream: async for command in stream: - if isinstance(command, filter): - yield command - max -= 1 - if max == 0: - break + return command + raise ValueError("Did not receive expected command") - async def receive(self, filter: typing.Type[T]) -> T: - async for command in self.receive_stream(filter, 1): - return command - raise ValueError("No matching command received") - - async def send(self, command: protocol.Command): + async def _send(self, command: protocol.Command): try: - assert self._conn is not None - await self._conn.send(command) - except Exception as e: - logging.warning(f"Error sending command, reconnecting") + async with self._send_lock: + assert self._conn is not None + await self._conn.send(command) + except Exception: + logging.warning("Error sending command, reconnecting") await self.reconnect() - await self.send(command) + await self._send(command) + + async def _update_filter(self): + await self._send( + protocol.FilterCommand( + token=await self.base_token, + enabled_topic_hashes=[ + sha1(topic.encode()).digest() for topic in self._filters + ], + ) + ) + + @asynccontextmanager + async def _filter(self, topics: list[str]): + for topic in topics: + self._filters[topic] = self._filters.get(topic, 0) + 1 + await self._update_filter() + yield + for topic in topics: + self._filters[topic] -= 1 + if self._filters[topic] == 0: + del self._filters[topic] + await self._update_filter() + + async def mint_scoped_token(self, topic: str) -> bytes: + topic_hash = sha1(topic.encode()).digest() + await self._send( + protocol.ScopedTokenCommand(token=await self.base_token, topic=topic_hash) + ) + ack = await self._receive(filters.cmd(protocol.ScopedTokenAck)) + assert ack.status == 0 + return ack.scoped_token + + @asynccontextmanager + async def notification_stream( + self, + topic: str, + token: typing.Optional[bytes] = None, + filter: filters.Filter[ + protocol.SendMessageCommand, protocol.SendMessageCommand + ] = filters.ALL, + ): + if token is None: + token = await self.base_token + async with self._filter([topic]), self._receive_stream( + filters.chain( + filters.chain( + filters.chain( + filters.cmd(protocol.SendMessageCommand), + lambda c: c if c.token == token else None, + ), + lambda c: (c if c.topic == topic else None), + ), + filter, + ) + ) as stream: + yield stream + + async def ack(self, command: protocol.SendMessageCommand, status: int = 0): + await self._send( + protocol.SendMessageAck(status=status, token=command.token, id=command.id) + ) + + async def expect_notification( + self, + topic: str, + token: typing.Optional[bytes] = None, + filter: filters.Filter[ + protocol.SendMessageCommand, protocol.SendMessageCommand + ] = filters.ALL, + ) -> protocol.SendMessageCommand: + async with self.notification_stream(topic, token, filter) as stream: + command = await stream.receive() + await self.ack(command) + return command diff --git a/pypush/apns/protocol.py b/pypush/apns/protocol.py index ea0f7d3..147119c 100644 --- a/pypush/apns/protocol.py +++ b/pypush/apns/protocol.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from hashlib import sha1 from typing import Optional, Union -from anyio.abc import ByteStream, ObjectStream +from anyio.abc import ObjectStream from pypush.apns._protocol import command, fid from pypush.apns.transport import Packet @@ -87,12 +87,7 @@ class FilterCommand(Command): def _lookup_hashes(self, hashes: Optional[list[bytes]]): return ( - [ - KNOWN_TOPICS_LOOKUP[hash] if hash in KNOWN_TOPICS_LOOKUP else hash - for hash in hashes - ] - if hashes - else [] + [KNOWN_TOPICS_LOOKUP.get(hash, hash) for hash in hashes] if hashes else [] ) @property @@ -140,6 +135,7 @@ class KeepAliveAck(Command): PacketType = Packet.Type.KeepAliveAck unknown: Optional[int] = fid(1) + @command @dataclass class SetStateCommand(Command): @@ -182,7 +178,7 @@ class SendMessageCommand(Command): ) and not (self._token_topic_1 is not None and self._token_topic_2 is not None): raise ValueError("topic, token, and outgoing must be set.") - if self.outgoing == True: + if self.outgoing is True: assert self.topic and self.token self._token_topic_1 = ( sha1(self.topic.encode()).digest() @@ -190,7 +186,7 @@ class SendMessageCommand(Command): else self.topic ) self._token_topic_2 = self.token - elif self.outgoing == False: + elif self.outgoing is False: assert self.topic and self.token self._token_topic_1 = self.token self._token_topic_2 = ( @@ -201,18 +197,14 @@ class SendMessageCommand(Command): else: assert self._token_topic_1 and self._token_topic_2 if len(self._token_topic_1) == 20: # SHA1 hash, topic - self.topic = ( - KNOWN_TOPICS_LOOKUP[self._token_topic_1] - if self._token_topic_1 in KNOWN_TOPICS_LOOKUP - else self._token_topic_1 + self.topic = KNOWN_TOPICS_LOOKUP.get( + self._token_topic_1, self._token_topic_1 ) self.token = self._token_topic_2 self.outgoing = True else: - self.topic = ( - KNOWN_TOPICS_LOOKUP[self._token_topic_2] - if self._token_topic_2 in KNOWN_TOPICS_LOOKUP - else self._token_topic_2 + self.topic = KNOWN_TOPICS_LOOKUP.get( + self._token_topic_2, self._token_topic_2 ) self.token = self._token_topic_1 self.outgoing = False @@ -229,6 +221,27 @@ class SendMessageAck(Command): unknown6: Optional[bytes] = fid(6, default=None) +@command +@dataclass +class ScopedTokenCommand(Command): + PacketType = Packet.Type.ScopedToken + + token: bytes = fid(1) + topic: bytes = fid(2) + app_id: Optional[bytes] = fid(3, default=None) + + +@command +@dataclass +class ScopedTokenAck(Command): + PacketType = Packet.Type.ScopedTokenAck + + status: int = fid(1) + scoped_token: bytes = fid(2) + topic: bytes = fid(3) + app_id: Optional[bytes] = fid(4, default=None) + + @dataclass class UnknownCommand(Command): id: Packet.Type @@ -240,7 +253,7 @@ class UnknownCommand(Command): def to_packet(self) -> Packet: return Packet(id=self.id, fields=self.fields) - + def __repr__(self): if self.id.value in [29, 30, 32]: return f"UnknownCommand(id={self.id}, fields=[SUPPRESSED])" @@ -259,6 +272,8 @@ def command_from_packet(packet: Packet) -> Command: Packet.Type.SetState: SetStateCommand, Packet.Type.SendMessage: SendMessageCommand, Packet.Type.SendMessageAck: SendMessageAck, + Packet.Type.ScopedToken: ScopedTokenCommand, + Packet.Type.ScopedTokenAck: ScopedTokenAck, # Add other mappings here... } command_class = command_classes.get(packet.id, None) diff --git a/pypush/apns/transport.py b/pypush/apns/transport.py index 864f3eb..d0de86b 100644 --- a/pypush/apns/transport.py +++ b/pypush/apns/transport.py @@ -30,6 +30,8 @@ class Packet: KeepAlive = 12 KeepAliveAck = 13 NoStorage = 14 + ScopedToken = 17 + ScopedTokenAck = 18 SetState = 20 UNKNOWN = "Unknown" @@ -38,20 +40,19 @@ class Packet: obj = object.__new__(cls) obj._value_ = value return obj - + @classmethod def _missing_(cls, value): # Handle unknown values instance = cls.UNKNOWN instance._value_ = value # Assign the unknown value return instance - + def __str__(self): if self is Packet.Type.UNKNOWN: return f"Unknown({self._value_})" return self.name - id: Type fields: list[Field] @@ -60,18 +61,25 @@ class Packet: async def create_courier_connection( + sandbox: bool = False, courier: str = "1-courier.push.apple.com", ) -> PacketStream: context = ssl.create_default_context() context.set_alpn_protocols(ALPN) + sni = "courier.sandbox.push.apple.com" if sandbox else "courier.push.apple.com" + # TODO: Verify courier certificate context.check_hostname = False context.verify_mode = ssl.CERT_NONE return PacketStream( await anyio.connect_tcp( - courier, COURIER_PORT, ssl_context=context, tls_standard_compatible=False + courier, + COURIER_PORT, + ssl_context=context, + tls_standard_compatible=False, + tls_hostname=sni, ) ) diff --git a/pypush/cli/__init__.py b/pypush/cli/__init__.py index 83e70a0..1495dd0 100644 --- a/pypush/cli/__init__.py +++ b/pypush/cli/__init__.py @@ -1,12 +1,17 @@ +import contextlib import logging +from asyncio import CancelledError +import anyio import typer from rich.logging import RichHandler from typing_extensions import Annotated +from pypush import apns + from . import proxy as _proxy -logging.basicConfig(level=logging.DEBUG, handlers=[RichHandler()], format="%(message)s") +logging.basicConfig(level=logging.INFO, handlers=[RichHandler()], format="%(message)s") app = typer.Typer() @@ -22,12 +27,12 @@ def proxy( Attach requires SIP to be disabled and to be running as root """ - - _proxy.main(attach) + with contextlib.suppress(CancelledError): + _proxy.main(attach) @app.command() -def client( +def notifications( topic: Annotated[str, typer.Argument(help="app topic to listen on")], sandbox: Annotated[ bool, typer.Option("--sandbox/--production", help="APNs courier to use") @@ -36,8 +41,29 @@ def client( """ Connect to the APNs courier and listen for app notifications on the given topic """ - typer.echo("Running APNs client") - raise NotImplementedError("Not implemented yet") + logging.getLogger("httpx").setLevel(logging.WARNING) + with contextlib.suppress(CancelledError): + anyio.run(notifications_async, topic, sandbox) + + +async def notifications_async(topic: str, sandbox: bool): + async with apns.create_apns_connection( + *await apns.activate(), + courier="1-courier.sandbox.push.apple.com" + if sandbox + else "1-courier.push.apple.com", + ) as connection: + token = await connection.mint_scoped_token(topic) + + async with connection.notification_stream(topic, token) as stream: + logging.info( + f"Listening for notifications on topic {topic} ({'sandbox' if sandbox else 'production'})" + ) + logging.info(f"Token: {token.hex()}") + + async for notification in stream: + await connection.ack(notification) + logging.info(notification.payload.decode()) def main(): diff --git a/pypush/cli/_frida.py b/pypush/cli/_frida.py index dc30ce5..3a71ae4 100644 --- a/pypush/cli/_frida.py +++ b/pypush/cli/_frida.py @@ -1,6 +1,7 @@ -import frida import logging +import frida + def attach_to_apsd() -> frida.core.Session: frida.kill("apsd") diff --git a/pypush/cli/proxy.py b/pypush/cli/proxy.py index b801d43..8c43bc4 100644 --- a/pypush/cli/proxy.py +++ b/pypush/cli/proxy.py @@ -2,7 +2,6 @@ import datetime import logging import ssl import tempfile -from typing import Optional import anyio import anyio.abc @@ -12,11 +11,10 @@ from cryptography import x509 from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.primitives.hashes import SHA256 -from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat +from cryptography.hazmat.primitives.serialization import Encoding # from pypush import apns -from pypush.apns import transport -from pypush.apns import protocol +from pypush.apns import protocol, transport from . import _frida @@ -71,7 +69,7 @@ async def handle(client: TLSStream): else "1-courier.sandbox.push.apple.com" ) name = f"prod-{connection_cnt}" if not sandbox else f"sandbox-{connection_cnt}" - async with await transport.create_courier_connection(forward) as conn: + async with await transport.create_courier_connection(sandbox, forward) as conn: logging.debug("Connected to courier") async with anyio.create_task_group() as tg: tg.start_soon(forward_packets, client_pkt, conn, f"client-{name}") diff --git a/pypush/cli/pushclient.py b/pypush/cli/pushclient.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/assets/dev.jjtech.pypush.tests.pem b/tests/assets/dev.jjtech.pypush.tests.pem new file mode 100644 index 0000000..0188045 --- /dev/null +++ b/tests/assets/dev.jjtech.pypush.tests.pem @@ -0,0 +1,75 @@ +Bag Attributes + friendlyName: Apple Sandbox Push Services: dev.jjtech.pypush.tests + localKeyID: 0A C9 4D 65 F1 39 44 73 5F A8 05 BC B9 00 47 14 2C 12 9A F3 +subject=UID=dev.jjtech.pypush.tests, CN=Apple Sandbox Push Services: dev.jjtech.pypush.tests, OU=C4492JYJR3, C=US +issuer=CN=Apple Worldwide Developer Relations Certification Authority, OU=G4, O=Apple Inc., C=US +-----BEGIN CERTIFICATE----- +MIIGnzCCBYegAwIBAgIQRLQgelpeA0ozi3PDbx2ZmTANBgkqhkiG9w0BAQsFADB1 +MUQwQgYDVQQDDDtBcHBsZSBXb3JsZHdpZGUgRGV2ZWxvcGVyIFJlbGF0aW9ucyBD +ZXJ0aWZpY2F0aW9uIEF1dGhvcml0eTELMAkGA1UECwwCRzQxEzARBgNVBAoMCkFw +cGxlIEluYy4xCzAJBgNVBAYTAlVTMB4XDTI0MDUxNjAwMTUwM1oXDTI1MDYxNTAw +MTUwMlowgYoxJzAlBgoJkiaJk/IsZAEBDBdkZXYuamp0ZWNoLnB5cHVzaC50ZXN0 +czE9MDsGA1UEAww0QXBwbGUgU2FuZGJveCBQdXNoIFNlcnZpY2VzOiBkZXYuamp0 +ZWNoLnB5cHVzaC50ZXN0czETMBEGA1UECwwKQzQ0OTJKWUpSMzELMAkGA1UEBhMC +VVMwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQD3BvhGnrBtXpVLVvdi +HFHYeu58MKBD/vyw3A+a4PXnCXskSdEZDydBXJnKa1OeIqn/7TG5/6iiWGR+pcYa +XK6kCka8fxpuWgk4/H7C2EN9Atv/XgJit3RSUFdKVN1dvG5cDX5yvFcu7xSt8J+Y +RHuqM2YGwor1bZNUCi46n144dntB9rEV2ZgLwrHc2ofo/STbdstGKMJHkhg0GVcI +0IzGderz1Ga1UXB8yhr+CvQthjcm74G+aQJZfuMsGwXI06wbKOJQPtCPdAD0taBW +rdHivETxRw3WhPzmiwQLUruOmXEo5+bgl1NhnPCLJn374LWaxQEzpnW2HhP6p8mC +TzZhAgMBAAGjggMTMIIDDzAMBgNVHRMBAf8EAjAAMB8GA1UdIwQYMBaAFFvZ+h3n +mhoLo5l2IlCGPpHIW3eoMHAGCCsGAQUFBwEBBGQwYjAtBggrBgEFBQcwAoYhaHR0 +cDovL2NlcnRzLmFwcGxlLmNvbS93d2RyZzQuZGVyMDEGCCsGAQUFBzABhiVodHRw +Oi8vb2NzcC5hcHBsZS5jb20vb2NzcDAzLXd3ZHJnNDAzMIIBHgYDVR0gBIIBFTCC +AREwggENBgkqhkiG92NkBQEwgf8wgcMGCCsGAQUFBwICMIG2DIGzUmVsaWFuY2Ug +b24gdGhpcyBjZXJ0aWZpY2F0ZSBieSBhbnkgcGFydHkgYXNzdW1lcyBhY2NlcHRh +bmNlIG9mIHRoZSB0aGVuIGFwcGxpY2FibGUgc3RhbmRhcmQgdGVybXMgYW5kIGNv +bmRpdGlvbnMgb2YgdXNlLCBjZXJ0aWZpY2F0ZSBwb2xpY3kgYW5kIGNlcnRpZmlj +YXRpb24gcHJhY3RpY2Ugc3RhdGVtZW50cy4wNwYIKwYBBQUHAgEWK2h0dHBzOi8v +d3d3LmFwcGxlLmNvbS9jZXJ0aWZpY2F0ZWF1dGhvcml0eS8wEwYDVR0lBAwwCgYI +KwYBBQUHAwIwMgYDVR0fBCswKTAnoCWgI4YhaHR0cDovL2NybC5hcHBsZS5jb20v +d3dkcmc0LTMuY3JsMB0GA1UdDgQWBBQKyU1l8TlEc1+oBby5AEcULBKa8zAOBgNV +HQ8BAf8EBAMCB4Awgb8GCiqGSIb3Y2QGAwYEgbAwga0MF2Rldi5qanRlY2gucHlw +dXNoLnRlc3RzMAcMBXRvcGljDBxkZXYuamp0ZWNoLnB5cHVzaC50ZXN0cy52b2lw +MAYMBHZvaXAMJGRldi5qanRlY2gucHlwdXNoLnRlc3RzLmNvbXBsaWNhdGlvbjAO +DAxjb21wbGljYXRpb24MIGRldi5qanRlY2gucHlwdXNoLnRlc3RzLnZvaXAtcHR0 +MAsMCS52b2lwLXB0dDAQBgoqhkiG92NkBgMBBAIFADANBgkqhkiG9w0BAQsFAAOC +AQEAwQac2q1BMnAH1vdZgfDunc+b7SKO6rJIG6w/wl4211YyNBBS5oabQnQDfB8y +8iOeWnoWXry60gI2fwWN/rRaQn4QCy72jNeTGz/T/s2jwoGj89114JjcBhRAHvQl +/HN4QjSt5rWVRcxTE4cKKbJIqVCm7Uq9VROgbxXrmsZsRnyk1ASvLGboibtGbmty +wmXZWns5NXNDbv1wP+PF5HSFXtDWodPYnhvzJe0s9lRvo4yGAt1KL5mNaZM3kKp0 +74kdzKK/iT7954EQK4ZWPQbDnS1A+/BzHQjK0rWTwjDQkbKvNE9bb+KJbNHH3+DX +5s0ybZYoG5meGKUplwu7A2bfFw== +-----END CERTIFICATE----- +Bag Attributes + friendlyName: Apple Sandbox Push Services: dev.jjtech.pypush.tests Private Key + localKeyID: 0A C9 4D 65 F1 39 44 73 5F A8 05 BC B9 00 47 14 2C 12 9A F3 +Key Attributes: +-----BEGIN PRIVATE KEY----- +MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQD3BvhGnrBtXpVL +VvdiHFHYeu58MKBD/vyw3A+a4PXnCXskSdEZDydBXJnKa1OeIqn/7TG5/6iiWGR+ +pcYaXK6kCka8fxpuWgk4/H7C2EN9Atv/XgJit3RSUFdKVN1dvG5cDX5yvFcu7xSt +8J+YRHuqM2YGwor1bZNUCi46n144dntB9rEV2ZgLwrHc2ofo/STbdstGKMJHkhg0 +GVcI0IzGderz1Ga1UXB8yhr+CvQthjcm74G+aQJZfuMsGwXI06wbKOJQPtCPdAD0 +taBWrdHivETxRw3WhPzmiwQLUruOmXEo5+bgl1NhnPCLJn374LWaxQEzpnW2HhP6 +p8mCTzZhAgMBAAECggEBAKADb8eu+3GdFvAagVyYI5wq5Vik1uu0vFKD+cfFeQQT +bCTxe/TTkAYSybwJEb0Zjy0spE1rgfzHbTFsiIqDBs1TqsZnPuPEhrzXMfVcyTqt +I3yjlMAFPeAkEqcfmdUiPgp64zHHNmI8lBSoDXlAwypY6PnwArtAI3MItTFcElhX +gWB44xVGuJRjRP4UVqXg0ML/Ic2yuYT9DRsDRilYhm8RGRSHkdZKdzCicMZcLtC7 +bs6/evmIrk9V5AzF6YiXlfT0dOp6yy9mFwhLljXF3Z2/LdrOTAmhLPQRMbUrJrcW +ZPd0kMybGIlEoprQEA/6nZkdtIiDo2OJtufCs8g+nJECgYEA/+v4uTJzEI1igKOB +myJtADECZAsJUaJaKSAM7VHn1hNOKgNLhUHOuroWvIWEhEomWeMvCbZIG42eOwNW +BXGtG7ruT79E6655dljU6E/029FaxONqXXCTD9ZPh031R293KcydMwgBJJ0pvFJE +14HWmMRAG0auPygMRhXubXU1ndMCgYEA9xpNWrl9poTjsZDNqvu60nYcq0W1escw +ovmb87uxZ5u8fC8T1F3AVMYj4v0dTyA4F0mZenY+nri/hJBuanWVxa5Liu0fGnBr +tEa2rzCMaajoDTNMKSygFz6CIMZbbZhozy0+9DHcRcC6b2UtIgB/+/ZQtrTvQ8Ea +i6viarkq1nsCgYBznYAM8mynEqhoYvV/RyslBf8FgTLhjU3b/F26rODmhmwucLSi +a9tf4ge5fTwjo3f17btnUND8mZrdICGxbex9dZKJtmgFbRn0TCdLGCwPTmIKRo7b +zaqyYeglwSNI9WNJH+X4kuopR1L+f9AX59ExzJ8Fc4XuhEIfO3MuQeBJ/wKBgQDa +8AgH0X/+EZJ42rcPvxiprxL5wbrpPSHf1M+T5gJqrXcUhNXJ/QMTWbekP+Y/HGn2 +YDTHZ4tWMJUoTJw4YVTBoQu33R8I2wDi6yCkGpzeZVStlXzuomZ6Ed1UUsvhT//V +SN6VmLP1ba0CVB/oF49OXNDpAWlZm/f8NuBW9Rd6jwKBgQDi495IOjLJ8SvWRJLT +c9AUmO7IVgipWvr51cF9IYxkzXIVIQIh1usy2NsrBxshAD+FbbWFVBfoptdKBZVK +J8u+Ou4gTxs8SdGKGZWZpUMEKJbPsq8lE2aU3mBXiWcFRxYpu+n7nKap0Lla/xBD +v77FY1M3FxGR6rNqPJQ9rRLFbA== +-----END PRIVATE KEY----- diff --git a/tests/test_apns.py b/tests/test_apns.py index 3b24508..501e862 100644 --- a/tests/test_apns.py +++ b/tests/test_apns.py @@ -1,18 +1,13 @@ -import pytest -from pypush import apns -import asyncio - -# from aioapns import * -import uuid -import anyio - -# from pypush.apns import _util -# from pypush.apns import albert, lifecycle, protocol -from pypush import apns - import logging +import uuid +from pathlib import Path + +import httpx +import pytest from rich.logging import RichHandler +from pypush import apns + logging.basicConfig(level=logging.DEBUG, handlers=[RichHandler()], format="%(message)s") @@ -26,17 +21,43 @@ async def test_activate(): @pytest.mark.asyncio async def test_lifecycle_2(): - async with apns.create_apns_connection( - certificate, key, courier="localhost" - ) as connection: - await connection.receive( - apns.protocol.ConnectAck - ) # Just wait until the initial connection is established. Don't do this in real code plz. + async with apns.create_apns_connection(certificate, key) as _: + pass + + +ASSETS_DIR = Path(__file__).parent / "assets" + + +async def send_test_notification(device_token, payload=b"hello, world"): + async with httpx.AsyncClient( + cert=str(ASSETS_DIR / "dev.jjtech.pypush.tests.pem"), http2=True + ) as client: + # Use the certificate and key from above + response = await client.post( + f"https://api.sandbox.push.apple.com/3/device/{device_token}", + content=payload, + headers={ + "apns-topic": "dev.jjtech.pypush.tests", + "apns-push-type": "alert", + "apns-priority": "10", + }, + ) + assert response.status_code == 200 @pytest.mark.asyncio -async def test_shorthand(): +async def test_scoped_token(): async with apns.create_apns_connection( - *await apns.activate(), courier="localhost" + *await apns.activate(), sandbox=True ) as connection: - await connection.receive(apns.protocol.ConnectAck) + token = await connection.mint_scoped_token("dev.jjtech.pypush.tests") + + test_message = f"test-message-{uuid.uuid4().hex}" + + await send_test_notification(token.hex(), test_message.encode()) + + await connection.expect_notification( + "dev.jjtech.pypush.tests", + token, + lambda c: c if c.payload == test_message.encode() else None, + )