apns: document public API functions

This commit is contained in:
JJTech0130 2024-05-23 09:26:27 -04:00
parent 51b04f816d
commit 29c66bde66
No known key found for this signature in database
GPG Key ID: 23C92EBCCF8F93D6
4 changed files with 78 additions and 4 deletions

View File

@ -50,6 +50,13 @@ async def activate(
build: str = "10.6.4",
model: str = "windows1,1",
) -> Tuple[x509.Certificate, rsa.RSAPrivateKey]:
"""
Activate with Apple's Albert service, obtaining an activation certificate and private key.
By default, this will activate a Windows device with a random UDID, serial, version, build, and model.
Windows activations will not function for iMessage or FaceTime.
"""
if http_client is None:
# Do this here to ensure the client is not accidentally reused during tests
http_client = httpx.AsyncClient()

View File

@ -25,6 +25,16 @@ async def create_apns_connection(
sandbox: bool = False,
courier: typing.Optional[str] = None,
):
"""
This context manager will create a connection to the APNs server and yield a Connection object.
Args:
certificate (x509.Certificate): A valid activation certificate obtained from Albert.
private_key (rsa.RSAPrivateKey): The private key corresponding to the activation certificate.
token (bytes, optional): An optional base token to use for the connection. If not provided, the connection will be established and the base token will be set to the token provided in the ConnectAck command.
sandbox (bool, optional): A boolean indicating whether to connect to the APNs sandbox or production server.
courier (str, optional): An optional string indicating the courier server to connect to. If not provided, a random courier server will be selected based on the `sandbox` parameter.
"""
async with anyio.create_task_group() as tg:
conn = Connection(
tg, certificate, private_key, token, sandbox, courier
@ -71,11 +81,16 @@ class Connection:
logging.debug(f"Using courier: {courier}")
self.courier = courier
self._tg.start_soon(self.reconnect)
self._tg.start_soon(self._reconnect)
self._tg.start_soon(self._ping_task)
@property
async def base_token(self) -> bytes:
"""
`base_token` must be awaited to ensure a token is available
This may not complete until a connection has been established
"""
if self._base_token is None:
await self._connected.wait()
assert self._base_token is not None
@ -98,7 +113,7 @@ class Connection:
) # Explicitly disable the backlog since we don't want to receive old acks
@_util.exponential_backoff
async def reconnect(self):
async def _reconnect(self):
async with (
self._reconnect_lock
): # Prevent weird situations where multiple reconnects are happening at once
@ -161,9 +176,15 @@ class Connection:
await self._update_filter()
async def aclose(self):
"""
Closes the connection to the APNS server.
If the connection is open, it will be closed. This method is typically unnecessary if the connection is managed by `create_apns_connection`.
Note: The connection will be reopened if the task group is still open (the ping task is still running).
"""
if self._conn is not None:
await self._conn.aclose()
# Note: Will be reopened if task group is still running and ping task is still running
T = typing.TypeVar("T")
@ -191,7 +212,7 @@ class Connection:
await self._conn.send(command)
except Exception:
logging.warning("Error sending command, reconnecting")
await self.reconnect()
await self._reconnect()
await self._send(command)
async def _update_filter(self):
@ -207,6 +228,8 @@ class Connection:
@asynccontextmanager
async def _filter(self, topics: list[str]):
for topic in topics:
if topic not in protocol.KNOWN_TOPICS:
protocol.note_topic(topic)
self._filters[topic] = self._filters.get(topic, 0) + 1
await self._update_filter()
yield
@ -217,6 +240,14 @@ class Connection:
await self._update_filter()
async def mint_scoped_token(self, topic: str) -> bytes:
"""
Mint a "scoped token" for the given topic/bundle ID.
This token is equivalent to the token provided to `application:didRegisterForRemoteNotificationsWithDeviceToken:` in iOS,
for an app with the given bundle ID.
This token can be used with `expect_notification` or `notification_stream`, but it will only function on connections with the same base token as the connection that originally minted the token.
"""
topic_hash = sha1(topic.encode()).digest()
await self._send(
protocol.ScopedTokenCommand(token=await self.base_token, topic=topic_hash)
@ -234,6 +265,14 @@ class Connection:
protocol.SendMessageCommand, protocol.SendMessageCommand
] = filters.ALL,
):
"""
Create a stream of notifications for the given topic and token.
If the token is not provided, the base token will be used.
A custom `Filter` can be provided to filter out unwanted notifications.
Notifications will NOT be ack'd automatically, you must call `ack` on each notification you process.
"""
if token is None:
token = await self.base_token
async with (
@ -254,6 +293,9 @@ class Connection:
yield stream
async def ack(self, command: protocol.SendMessageCommand, status: int = 0):
"""
Acknowledge a notification.
"""
await self._send(
protocol.SendMessageAck(status=status, token=command.token, id=command.id)
)
@ -266,6 +308,13 @@ class Connection:
protocol.SendMessageCommand, protocol.SendMessageCommand
] = filters.ALL,
) -> protocol.SendMessageCommand:
"""
Wait for a notification that matches the given topic and token.
If the token is not provided, the base token will be used.
A custom `Filter` can be provided to filter out unwanted notifications.
This method WILL ack the notification automatically."""
async with self.notification_stream(topic, token, filter) as stream:
command = await stream.receive()
await self.ack(command)

View File

@ -13,6 +13,15 @@ KNOWN_TOPICS_LOOKUP = {sha1(topic.encode()).digest():topic for topic in KNOWN_TO
# fmt: on
def note_topic(topic: str):
"""
Add a topic to the KNOWN_TOPICS set, such that it can be recognized later.
This is mostly just a convenience, so that you do not have to work with SHA1 hashes directly.
"""
KNOWN_TOPICS.add(topic)
KNOWN_TOPICS_LOOKUP[sha1(topic.encode()).digest()] = topic
@dataclass
class Command:
@classmethod
@ -148,6 +157,14 @@ class SetStateCommand(Command):
@command
@dataclass
class SendMessageCommand(Command):
"""
The most common form of command, used to send a message, also represents incoming messages.
May also be called a "Notification" in the context of APNs.
Important note: The `topic` field may be a string or bytes, depending on if the topic is in the `KNOWN_TOPICS` set. This should not happen if you are using the proper `Connection` API.
"""
PacketType = Packet.Type.SendMessage
payload: bytes = fid(3)

View File

@ -28,6 +28,7 @@ async def test_lifecycle_2():
ASSETS_DIR = Path(__file__).parent / "assets"
# Not a part of pypush, this is public API
async def send_test_notification(device_token, payload=b"hello, world"):
async with httpx.AsyncClient(
cert=str(ASSETS_DIR / "dev.jjtech.pypush.tests.pem"), http2=True