Rewrite: APNs: Scoped App Tokens (#101)

+ Adds a lot of new API surface around filters
+ Adds CI type checking and linting
This commit is contained in:
JJTech 2024-05-19 11:54:28 -07:00 committed by GitHub
parent b1c30a98ff
commit ce88a8ea5c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 496 additions and 112 deletions

18
.github/workflows/pyright.yml vendored Normal file
View File

@ -0,0 +1,18 @@
name: Pyright
on: [push, pull_request]
jobs:
pyright:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
cache: 'pip'
- run: |
python -m venv .venv
source .venv/bin/activate
pip install -e '.[test,cli]'
- run: echo "$PWD/.venv/bin" >> $GITHUB_PATH
- uses: jakebailey/pyright-action@v2

8
.github/workflows/ruff.yml vendored Normal file
View File

@ -0,0 +1,8 @@
name: Ruff
on: [push, pull_request]
jobs:
ruff:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: chartboost/ruff-action@v1

View File

@ -34,4 +34,8 @@ version_file = "pypush/_version.py"
[tool.pytest.ini_options]
minversion = "6.0"
addopts = ["-ra", "-q"]
testpaths = ["tests"]
testpaths = ["tests"]
[tool.ruff.lint]
select = ["E", "F", "B", "SIM", "I"]
ignore = ["E501", "B010"]

View File

@ -1,5 +1,5 @@
__all__ = ["protocol", "create_apns_connection", "activate"]
__all__ = ["protocol", "create_apns_connection", "activate", "filters"]
from . import protocol
from .lifecycle import create_apns_connection
from . import filters, protocol
from .albert import activate
from .lifecycle import create_apns_connection

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import logging
from dataclasses import MISSING, field
from dataclasses import fields as dataclass_fields
from typing import Any, TypeVar, get_origin, get_args, Union
from typing import Any, TypeVar, Union, get_args, get_origin
from pypush.apns.transport import Packet
@ -67,14 +67,14 @@ def command(cls: T) -> T:
)
# Check for extra fields
for field in packet.fields:
if field.id not in [
for current_field in packet.fields:
if current_field.id not in [
f.metadata["packet_id"]
for f in dataclass_fields(cls)
if f.metadata is not None and "packet_id" in f.metadata
]:
logging.warning(
f"Unexpected field with packet ID {field.id} in packet {packet}"
f"Unexpected field with packet ID {current_field.id} in packet {packet}"
)
return cls(**field_values)
@ -122,15 +122,15 @@ def fid(
:param byte_len: The length of the field in bytes (for int fields)
:param default: The default value of the field
"""
if not default == MISSING and not default_factory == MISSING:
if default != MISSING and default_factory != MISSING:
raise ValueError("Cannot specify both default and default_factory")
if not default == MISSING:
if default != MISSING:
return field(
metadata={"packet_id": packet_id, "packet_bytes": byte_len},
default=default,
repr=repr,
)
if not default_factory == MISSING:
if default_factory != MISSING:
return field(
metadata={"packet_id": packet_id, "packet_bytes": byte_len},
default_factory=default_factory,

View File

@ -3,25 +3,40 @@ from contextlib import asynccontextmanager
from typing import Generic, TypeVar
import anyio
from anyio.abc import ObjectSendStream
from anyio.abc import ObjectReceiveStream, ObjectSendStream
from . import filters
T = TypeVar("T")
class BroadcastStream(Generic[T]):
def __init__(self):
def __init__(self, backlog: int = 50):
self.streams: list[ObjectSendStream[T]] = []
self.backlog: list[T] = []
self._backlog_size = backlog
async def broadcast(self, packet):
logging.debug(f"Broadcasting {packet} to {len(self.streams)} streams")
for stream in self.streams:
try:
await stream.send(packet)
except anyio.BrokenResourceError:
self.streams.remove(stream)
logging.error("Broken resource error")
# self.streams.remove(stream)
# If we have a backlog, add the packet to it
if len(self.backlog) >= self._backlog_size:
self.backlog.pop(0)
self.backlog.append(packet)
@asynccontextmanager
async def open_stream(self):
send, recv = anyio.create_memory_object_stream[T]()
async def open_stream(self, backlog: bool = True):
# 1000 seems like a reasonable number, if more than 1000 messages come in before someone deals with them it will
# start stalling the APNs connection itself
send, recv = anyio.create_memory_object_stream[T](max_buffer_size=1000)
if backlog:
for packet in self.backlog:
await send.send(packet)
self.streams.append(send)
async with recv:
yield recv
@ -29,6 +44,31 @@ class BroadcastStream(Generic[T]):
await send.aclose()
W = TypeVar("W")
F = TypeVar("F")
class FilteredStream(ObjectReceiveStream[F]):
"""
A stream that filters out unwanted items
filter should return None if the item should be filtered out, otherwise it should return the item or a modified version of it
"""
def __init__(self, source: ObjectReceiveStream[W], filter: filters.Filter[W, F]):
self.source = source
self.filter = filter
async def receive(self) -> F:
async for item in self.source:
if (filtered := self.filter(item)) is not None:
return filtered
raise anyio.EndOfStream
async def aclose(self):
await self.source.aclose()
def exponential_backoff(f):
async def wrapper(*args, **kwargs):
backoff = 1

View File

@ -4,7 +4,7 @@ import plistlib
import re
import uuid
from base64 import b64decode
from typing import Tuple, Optional
from typing import Optional, Tuple
import httpx
from cryptography import x509
@ -96,10 +96,10 @@ async def activate(
try:
protocol = re.search("<Protocol>(.*)</Protocol>", resp.text).group(1) # type: ignore
except AttributeError:
except AttributeError as e:
# Search for error text between <b> and </b>
error = re.search("<b>(.*)</b>", resp.text).group(1) # type: ignore
raise Exception(f"Failed to get certificate from Albert: {error}")
raise Exception(f"Failed to get certificate from Albert: {error}") from e
protocol = plistlib.loads(protocol.encode("utf-8"))

44
pypush/apns/filters.py Normal file
View File

@ -0,0 +1,44 @@
import logging
from typing import Callable, Optional, Type, TypeVar
from pypush.apns import protocol
T1 = TypeVar("T1")
T2 = TypeVar("T2")
Filter = Callable[[T1], Optional[T2]]
# Chain with proper types so that subsequent filters only need to take output type of previous filter
T_IN = TypeVar("T_IN", bound=protocol.Command)
T_MIDDLE = TypeVar("T_MIDDLE", bound=protocol.Command)
T_OUT = TypeVar("T_OUT", bound=protocol.Command)
def chain(first: Filter[T_IN, T_MIDDLE], second: Filter[T_MIDDLE, T_OUT]):
def filter(command: T_IN) -> Optional[T_OUT]:
logging.debug(f"Filtering {command} with {first} and {second}")
filtered = first(command)
if filtered is None:
return None
return second(filtered)
return filter
T = TypeVar("T", bound=protocol.Command)
def cmd(type: Type[T]) -> Filter[protocol.Command, T]:
def filter(command: protocol.Command) -> Optional[T]:
if isinstance(command, type):
return command
return None
return filter
def ALL(c):
return c
def NONE(_):
return None

View File

@ -6,6 +6,7 @@ import random
import time
import typing
from contextlib import asynccontextmanager
from hashlib import sha1
import anyio
from anyio.abc import TaskGroup
@ -13,7 +14,7 @@ from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import padding, rsa
from . import protocol, transport, _util
from . import _util, filters, protocol, transport
@asynccontextmanager
@ -21,13 +22,18 @@ async def create_apns_connection(
certificate: x509.Certificate,
private_key: rsa.RSAPrivateKey,
token: typing.Optional[bytes] = None,
sandbox: bool = False,
courier: typing.Optional[str] = None,
):
async with anyio.create_task_group() as tg:
conn = Connection(tg, certificate, private_key, token, courier)
conn = Connection(
tg, certificate, private_key, token, sandbox, courier
) # Await connected for first time here, so that base token is set
yield conn
tg.cancel_scope.cancel() # Cancel the task group when the context manager exits
await conn.aclose() # Make sure to close the connection after the task group is cancelled
await (
conn.aclose()
) # Make sure to close the connection after the task group is cancelled
class Connection:
@ -37,26 +43,44 @@ class Connection:
certificate: x509.Certificate,
private_key: rsa.RSAPrivateKey,
token: typing.Optional[bytes] = None,
sandbox: bool = False,
courier: typing.Optional[str] = None,
):
self.certificate = certificate
self.private_key = private_key
self.base_token = token
self._base_token = token
self._filters: dict[str, int] = {} # topic -> use count
self._connected = anyio.Event() # Only use for base_token property
self._conn = None
self._tg = task_group
self._broadcast = _util.BroadcastStream[protocol.Command]()
self._reconnect_lock = anyio.Lock()
self._send_lock = anyio.Lock()
self.sandbox = sandbox
if courier is None:
# Pick a random courier server from 1 to 50
courier = f"{random.randint(1, 50)}-courier.push.apple.com"
courier = (
f"{random.randint(1, 50)}-courier.push.apple.com"
if not sandbox
else f"{random.randint(1, 10)}-courier.sandbox.push.apple.com"
)
logging.debug(f"Using courier: {courier}")
self.courier = courier
self._tg.start_soon(self.reconnect)
self._tg.start_soon(self._ping_task)
@property
async def base_token(self) -> bytes:
if self._base_token is None:
await self._connected.wait()
assert self._base_token is not None
return self._base_token
async def _receive_task(self):
assert self._conn is not None
async for command in self._conn:
@ -68,8 +92,10 @@ class Connection:
while True:
await anyio.sleep(30)
logging.debug("Sending keepalive")
await self.send(protocol.KeepAliveCommand())
await self.receive(protocol.KeepAliveAck)
await self._send(protocol.KeepAliveCommand())
await self._receive(
filters.cmd(protocol.KeepAliveAck), backlog=False
) # Explicitly disable the backlog since we don't want to receive old acks
@_util.exponential_backoff
async def reconnect(self):
@ -77,8 +103,11 @@ class Connection:
if self._conn is not None:
logging.warning("Closing existing connection")
await self._conn.aclose()
self._conn = protocol.CommandStream(
await transport.create_courier_connection(courier=self.courier)
self._broadcast.backlog = [] # Clear the backlog
conn = protocol.CommandStream(
await transport.create_courier_connection(self.sandbox, self.courier)
)
cert = self.certificate.public_bytes(serialization.Encoding.DER)
nonce = (
@ -89,53 +118,150 @@ class Connection:
signature = b"\x01\x01" + self.private_key.sign(
nonce, padding.PKCS1v15(), hashes.SHA1()
)
await self._conn.send(
await conn.send(
protocol.ConnectCommand(
push_token=self.base_token,
push_token=self._base_token,
state=1,
flags=69,
flags=65, # 69
certificate=cert,
nonce=nonce,
signature=signature,
)
)
# Don't set self._conn until we've sent the connect command
self._conn = conn
self._tg.start_soon(self._receive_task)
ack = await self.receive(protocol.ConnectAck)
ack = await self._receive(
filters.chain(
filters.cmd(protocol.ConnectAck),
lambda c: (
c
if (
c.token == self._base_token
if self._base_token is not None
else True
)
else None
),
)
)
logging.debug(f"Connected with ack: {ack}")
assert ack.status == 0
if self.base_token is None:
self.base_token = ack.token
if self._base_token is None:
self._base_token = ack.token
else:
assert ack.token == self.base_token
assert ack.token == self._base_token
if not self._connected.is_set():
self._connected.set()
await self._update_filter()
async def aclose(self):
if self._conn is not None:
await self._conn.aclose()
# Note: Will be reopened if task group is still running and ping task is still running
T = typing.TypeVar("T", bound=protocol.Command)
T = typing.TypeVar("T")
async def receive_stream(
self, filter: typing.Type[T], max: int = -1
) -> typing.AsyncIterator[T]:
async with self._broadcast.open_stream() as stream:
@asynccontextmanager
async def _receive_stream(
self,
filter: filters.Filter[protocol.Command, T] = lambda c: c,
backlog: bool = True,
):
async with self._broadcast.open_stream(backlog) as stream:
yield _util.FilteredStream(stream, filter)
async def _receive(
self, filter: filters.Filter[protocol.Command, T], backlog: bool = True
):
async with self._receive_stream(filter, backlog) as stream:
async for command in stream:
if isinstance(command, filter):
yield command
max -= 1
if max == 0:
break
return command
raise ValueError("Did not receive expected command")
async def receive(self, filter: typing.Type[T]) -> T:
async for command in self.receive_stream(filter, 1):
return command
raise ValueError("No matching command received")
async def send(self, command: protocol.Command):
async def _send(self, command: protocol.Command):
try:
assert self._conn is not None
await self._conn.send(command)
except Exception as e:
logging.warning(f"Error sending command, reconnecting")
async with self._send_lock:
assert self._conn is not None
await self._conn.send(command)
except Exception:
logging.warning("Error sending command, reconnecting")
await self.reconnect()
await self.send(command)
await self._send(command)
async def _update_filter(self):
await self._send(
protocol.FilterCommand(
token=await self.base_token,
enabled_topic_hashes=[
sha1(topic.encode()).digest() for topic in self._filters
],
)
)
@asynccontextmanager
async def _filter(self, topics: list[str]):
for topic in topics:
self._filters[topic] = self._filters.get(topic, 0) + 1
await self._update_filter()
yield
for topic in topics:
self._filters[topic] -= 1
if self._filters[topic] == 0:
del self._filters[topic]
await self._update_filter()
async def mint_scoped_token(self, topic: str) -> bytes:
topic_hash = sha1(topic.encode()).digest()
await self._send(
protocol.ScopedTokenCommand(token=await self.base_token, topic=topic_hash)
)
ack = await self._receive(filters.cmd(protocol.ScopedTokenAck))
assert ack.status == 0
return ack.scoped_token
@asynccontextmanager
async def notification_stream(
self,
topic: str,
token: typing.Optional[bytes] = None,
filter: filters.Filter[
protocol.SendMessageCommand, protocol.SendMessageCommand
] = filters.ALL,
):
if token is None:
token = await self.base_token
async with self._filter([topic]), self._receive_stream(
filters.chain(
filters.chain(
filters.chain(
filters.cmd(protocol.SendMessageCommand),
lambda c: c if c.token == token else None,
),
lambda c: (c if c.topic == topic else None),
),
filter,
)
) as stream:
yield stream
async def ack(self, command: protocol.SendMessageCommand, status: int = 0):
await self._send(
protocol.SendMessageAck(status=status, token=command.token, id=command.id)
)
async def expect_notification(
self,
topic: str,
token: typing.Optional[bytes] = None,
filter: filters.Filter[
protocol.SendMessageCommand, protocol.SendMessageCommand
] = filters.ALL,
) -> protocol.SendMessageCommand:
async with self.notification_stream(topic, token, filter) as stream:
command = await stream.receive()
await self.ack(command)
return command

View File

@ -2,7 +2,7 @@ from dataclasses import dataclass
from hashlib import sha1
from typing import Optional, Union
from anyio.abc import ByteStream, ObjectStream
from anyio.abc import ObjectStream
from pypush.apns._protocol import command, fid
from pypush.apns.transport import Packet
@ -87,12 +87,7 @@ class FilterCommand(Command):
def _lookup_hashes(self, hashes: Optional[list[bytes]]):
return (
[
KNOWN_TOPICS_LOOKUP[hash] if hash in KNOWN_TOPICS_LOOKUP else hash
for hash in hashes
]
if hashes
else []
[KNOWN_TOPICS_LOOKUP.get(hash, hash) for hash in hashes] if hashes else []
)
@property
@ -140,6 +135,7 @@ class KeepAliveAck(Command):
PacketType = Packet.Type.KeepAliveAck
unknown: Optional[int] = fid(1)
@command
@dataclass
class SetStateCommand(Command):
@ -182,7 +178,7 @@ class SendMessageCommand(Command):
) and not (self._token_topic_1 is not None and self._token_topic_2 is not None):
raise ValueError("topic, token, and outgoing must be set.")
if self.outgoing == True:
if self.outgoing is True:
assert self.topic and self.token
self._token_topic_1 = (
sha1(self.topic.encode()).digest()
@ -190,7 +186,7 @@ class SendMessageCommand(Command):
else self.topic
)
self._token_topic_2 = self.token
elif self.outgoing == False:
elif self.outgoing is False:
assert self.topic and self.token
self._token_topic_1 = self.token
self._token_topic_2 = (
@ -201,18 +197,14 @@ class SendMessageCommand(Command):
else:
assert self._token_topic_1 and self._token_topic_2
if len(self._token_topic_1) == 20: # SHA1 hash, topic
self.topic = (
KNOWN_TOPICS_LOOKUP[self._token_topic_1]
if self._token_topic_1 in KNOWN_TOPICS_LOOKUP
else self._token_topic_1
self.topic = KNOWN_TOPICS_LOOKUP.get(
self._token_topic_1, self._token_topic_1
)
self.token = self._token_topic_2
self.outgoing = True
else:
self.topic = (
KNOWN_TOPICS_LOOKUP[self._token_topic_2]
if self._token_topic_2 in KNOWN_TOPICS_LOOKUP
else self._token_topic_2
self.topic = KNOWN_TOPICS_LOOKUP.get(
self._token_topic_2, self._token_topic_2
)
self.token = self._token_topic_1
self.outgoing = False
@ -229,6 +221,27 @@ class SendMessageAck(Command):
unknown6: Optional[bytes] = fid(6, default=None)
@command
@dataclass
class ScopedTokenCommand(Command):
PacketType = Packet.Type.ScopedToken
token: bytes = fid(1)
topic: bytes = fid(2)
app_id: Optional[bytes] = fid(3, default=None)
@command
@dataclass
class ScopedTokenAck(Command):
PacketType = Packet.Type.ScopedTokenAck
status: int = fid(1)
scoped_token: bytes = fid(2)
topic: bytes = fid(3)
app_id: Optional[bytes] = fid(4, default=None)
@dataclass
class UnknownCommand(Command):
id: Packet.Type
@ -240,7 +253,7 @@ class UnknownCommand(Command):
def to_packet(self) -> Packet:
return Packet(id=self.id, fields=self.fields)
def __repr__(self):
if self.id.value in [29, 30, 32]:
return f"UnknownCommand(id={self.id}, fields=[SUPPRESSED])"
@ -259,6 +272,8 @@ def command_from_packet(packet: Packet) -> Command:
Packet.Type.SetState: SetStateCommand,
Packet.Type.SendMessage: SendMessageCommand,
Packet.Type.SendMessageAck: SendMessageAck,
Packet.Type.ScopedToken: ScopedTokenCommand,
Packet.Type.ScopedTokenAck: ScopedTokenAck,
# Add other mappings here...
}
command_class = command_classes.get(packet.id, None)

View File

@ -30,6 +30,8 @@ class Packet:
KeepAlive = 12
KeepAliveAck = 13
NoStorage = 14
ScopedToken = 17
ScopedTokenAck = 18
SetState = 20
UNKNOWN = "Unknown"
@ -38,20 +40,19 @@ class Packet:
obj = object.__new__(cls)
obj._value_ = value
return obj
@classmethod
def _missing_(cls, value):
# Handle unknown values
instance = cls.UNKNOWN
instance._value_ = value # Assign the unknown value
return instance
def __str__(self):
if self is Packet.Type.UNKNOWN:
return f"Unknown({self._value_})"
return self.name
id: Type
fields: list[Field]
@ -60,18 +61,25 @@ class Packet:
async def create_courier_connection(
sandbox: bool = False,
courier: str = "1-courier.push.apple.com",
) -> PacketStream:
context = ssl.create_default_context()
context.set_alpn_protocols(ALPN)
sni = "courier.sandbox.push.apple.com" if sandbox else "courier.push.apple.com"
# TODO: Verify courier certificate
context.check_hostname = False
context.verify_mode = ssl.CERT_NONE
return PacketStream(
await anyio.connect_tcp(
courier, COURIER_PORT, ssl_context=context, tls_standard_compatible=False
courier,
COURIER_PORT,
ssl_context=context,
tls_standard_compatible=False,
tls_hostname=sni,
)
)

View File

@ -1,12 +1,17 @@
import contextlib
import logging
from asyncio import CancelledError
import anyio
import typer
from rich.logging import RichHandler
from typing_extensions import Annotated
from pypush import apns
from . import proxy as _proxy
logging.basicConfig(level=logging.DEBUG, handlers=[RichHandler()], format="%(message)s")
logging.basicConfig(level=logging.INFO, handlers=[RichHandler()], format="%(message)s")
app = typer.Typer()
@ -22,12 +27,12 @@ def proxy(
Attach requires SIP to be disabled and to be running as root
"""
_proxy.main(attach)
with contextlib.suppress(CancelledError):
_proxy.main(attach)
@app.command()
def client(
def notifications(
topic: Annotated[str, typer.Argument(help="app topic to listen on")],
sandbox: Annotated[
bool, typer.Option("--sandbox/--production", help="APNs courier to use")
@ -36,8 +41,29 @@ def client(
"""
Connect to the APNs courier and listen for app notifications on the given topic
"""
typer.echo("Running APNs client")
raise NotImplementedError("Not implemented yet")
logging.getLogger("httpx").setLevel(logging.WARNING)
with contextlib.suppress(CancelledError):
anyio.run(notifications_async, topic, sandbox)
async def notifications_async(topic: str, sandbox: bool):
async with apns.create_apns_connection(
*await apns.activate(),
courier="1-courier.sandbox.push.apple.com"
if sandbox
else "1-courier.push.apple.com",
) as connection:
token = await connection.mint_scoped_token(topic)
async with connection.notification_stream(topic, token) as stream:
logging.info(
f"Listening for notifications on topic {topic} ({'sandbox' if sandbox else 'production'})"
)
logging.info(f"Token: {token.hex()}")
async for notification in stream:
await connection.ack(notification)
logging.info(notification.payload.decode())
def main():

View File

@ -1,6 +1,7 @@
import frida
import logging
import frida
def attach_to_apsd() -> frida.core.Session:
frida.kill("apsd")

View File

@ -2,7 +2,6 @@ import datetime
import logging
import ssl
import tempfile
from typing import Optional
import anyio
import anyio.abc
@ -12,11 +11,10 @@ from cryptography import x509
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.hashes import SHA256
from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat
from cryptography.hazmat.primitives.serialization import Encoding
# from pypush import apns
from pypush.apns import transport
from pypush.apns import protocol
from pypush.apns import protocol, transport
from . import _frida
@ -71,7 +69,7 @@ async def handle(client: TLSStream):
else "1-courier.sandbox.push.apple.com"
)
name = f"prod-{connection_cnt}" if not sandbox else f"sandbox-{connection_cnt}"
async with await transport.create_courier_connection(forward) as conn:
async with await transport.create_courier_connection(sandbox, forward) as conn:
logging.debug("Connected to courier")
async with anyio.create_task_group() as tg:
tg.start_soon(forward_packets, client_pkt, conn, f"client-{name}")

View File

@ -0,0 +1,75 @@
Bag Attributes
friendlyName: Apple Sandbox Push Services: dev.jjtech.pypush.tests
localKeyID: 0A C9 4D 65 F1 39 44 73 5F A8 05 BC B9 00 47 14 2C 12 9A F3
subject=UID=dev.jjtech.pypush.tests, CN=Apple Sandbox Push Services: dev.jjtech.pypush.tests, OU=C4492JYJR3, C=US
issuer=CN=Apple Worldwide Developer Relations Certification Authority, OU=G4, O=Apple Inc., C=US
-----BEGIN CERTIFICATE-----
MIIGnzCCBYegAwIBAgIQRLQgelpeA0ozi3PDbx2ZmTANBgkqhkiG9w0BAQsFADB1
MUQwQgYDVQQDDDtBcHBsZSBXb3JsZHdpZGUgRGV2ZWxvcGVyIFJlbGF0aW9ucyBD
ZXJ0aWZpY2F0aW9uIEF1dGhvcml0eTELMAkGA1UECwwCRzQxEzARBgNVBAoMCkFw
cGxlIEluYy4xCzAJBgNVBAYTAlVTMB4XDTI0MDUxNjAwMTUwM1oXDTI1MDYxNTAw
MTUwMlowgYoxJzAlBgoJkiaJk/IsZAEBDBdkZXYuamp0ZWNoLnB5cHVzaC50ZXN0
czE9MDsGA1UEAww0QXBwbGUgU2FuZGJveCBQdXNoIFNlcnZpY2VzOiBkZXYuamp0
ZWNoLnB5cHVzaC50ZXN0czETMBEGA1UECwwKQzQ0OTJKWUpSMzELMAkGA1UEBhMC
VVMwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQD3BvhGnrBtXpVLVvdi
HFHYeu58MKBD/vyw3A+a4PXnCXskSdEZDydBXJnKa1OeIqn/7TG5/6iiWGR+pcYa
XK6kCka8fxpuWgk4/H7C2EN9Atv/XgJit3RSUFdKVN1dvG5cDX5yvFcu7xSt8J+Y
RHuqM2YGwor1bZNUCi46n144dntB9rEV2ZgLwrHc2ofo/STbdstGKMJHkhg0GVcI
0IzGderz1Ga1UXB8yhr+CvQthjcm74G+aQJZfuMsGwXI06wbKOJQPtCPdAD0taBW
rdHivETxRw3WhPzmiwQLUruOmXEo5+bgl1NhnPCLJn374LWaxQEzpnW2HhP6p8mC
TzZhAgMBAAGjggMTMIIDDzAMBgNVHRMBAf8EAjAAMB8GA1UdIwQYMBaAFFvZ+h3n
mhoLo5l2IlCGPpHIW3eoMHAGCCsGAQUFBwEBBGQwYjAtBggrBgEFBQcwAoYhaHR0
cDovL2NlcnRzLmFwcGxlLmNvbS93d2RyZzQuZGVyMDEGCCsGAQUFBzABhiVodHRw
Oi8vb2NzcC5hcHBsZS5jb20vb2NzcDAzLXd3ZHJnNDAzMIIBHgYDVR0gBIIBFTCC
AREwggENBgkqhkiG92NkBQEwgf8wgcMGCCsGAQUFBwICMIG2DIGzUmVsaWFuY2Ug
b24gdGhpcyBjZXJ0aWZpY2F0ZSBieSBhbnkgcGFydHkgYXNzdW1lcyBhY2NlcHRh
bmNlIG9mIHRoZSB0aGVuIGFwcGxpY2FibGUgc3RhbmRhcmQgdGVybXMgYW5kIGNv
bmRpdGlvbnMgb2YgdXNlLCBjZXJ0aWZpY2F0ZSBwb2xpY3kgYW5kIGNlcnRpZmlj
YXRpb24gcHJhY3RpY2Ugc3RhdGVtZW50cy4wNwYIKwYBBQUHAgEWK2h0dHBzOi8v
d3d3LmFwcGxlLmNvbS9jZXJ0aWZpY2F0ZWF1dGhvcml0eS8wEwYDVR0lBAwwCgYI
KwYBBQUHAwIwMgYDVR0fBCswKTAnoCWgI4YhaHR0cDovL2NybC5hcHBsZS5jb20v
d3dkcmc0LTMuY3JsMB0GA1UdDgQWBBQKyU1l8TlEc1+oBby5AEcULBKa8zAOBgNV
HQ8BAf8EBAMCB4Awgb8GCiqGSIb3Y2QGAwYEgbAwga0MF2Rldi5qanRlY2gucHlw
dXNoLnRlc3RzMAcMBXRvcGljDBxkZXYuamp0ZWNoLnB5cHVzaC50ZXN0cy52b2lw
MAYMBHZvaXAMJGRldi5qanRlY2gucHlwdXNoLnRlc3RzLmNvbXBsaWNhdGlvbjAO
DAxjb21wbGljYXRpb24MIGRldi5qanRlY2gucHlwdXNoLnRlc3RzLnZvaXAtcHR0
MAsMCS52b2lwLXB0dDAQBgoqhkiG92NkBgMBBAIFADANBgkqhkiG9w0BAQsFAAOC
AQEAwQac2q1BMnAH1vdZgfDunc+b7SKO6rJIG6w/wl4211YyNBBS5oabQnQDfB8y
8iOeWnoWXry60gI2fwWN/rRaQn4QCy72jNeTGz/T/s2jwoGj89114JjcBhRAHvQl
/HN4QjSt5rWVRcxTE4cKKbJIqVCm7Uq9VROgbxXrmsZsRnyk1ASvLGboibtGbmty
wmXZWns5NXNDbv1wP+PF5HSFXtDWodPYnhvzJe0s9lRvo4yGAt1KL5mNaZM3kKp0
74kdzKK/iT7954EQK4ZWPQbDnS1A+/BzHQjK0rWTwjDQkbKvNE9bb+KJbNHH3+DX
5s0ybZYoG5meGKUplwu7A2bfFw==
-----END CERTIFICATE-----
Bag Attributes
friendlyName: Apple Sandbox Push Services: dev.jjtech.pypush.tests Private Key
localKeyID: 0A C9 4D 65 F1 39 44 73 5F A8 05 BC B9 00 47 14 2C 12 9A F3
Key Attributes: <No Attributes>
-----BEGIN PRIVATE KEY-----
MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQD3BvhGnrBtXpVL
VvdiHFHYeu58MKBD/vyw3A+a4PXnCXskSdEZDydBXJnKa1OeIqn/7TG5/6iiWGR+
pcYaXK6kCka8fxpuWgk4/H7C2EN9Atv/XgJit3RSUFdKVN1dvG5cDX5yvFcu7xSt
8J+YRHuqM2YGwor1bZNUCi46n144dntB9rEV2ZgLwrHc2ofo/STbdstGKMJHkhg0
GVcI0IzGderz1Ga1UXB8yhr+CvQthjcm74G+aQJZfuMsGwXI06wbKOJQPtCPdAD0
taBWrdHivETxRw3WhPzmiwQLUruOmXEo5+bgl1NhnPCLJn374LWaxQEzpnW2HhP6
p8mCTzZhAgMBAAECggEBAKADb8eu+3GdFvAagVyYI5wq5Vik1uu0vFKD+cfFeQQT
bCTxe/TTkAYSybwJEb0Zjy0spE1rgfzHbTFsiIqDBs1TqsZnPuPEhrzXMfVcyTqt
I3yjlMAFPeAkEqcfmdUiPgp64zHHNmI8lBSoDXlAwypY6PnwArtAI3MItTFcElhX
gWB44xVGuJRjRP4UVqXg0ML/Ic2yuYT9DRsDRilYhm8RGRSHkdZKdzCicMZcLtC7
bs6/evmIrk9V5AzF6YiXlfT0dOp6yy9mFwhLljXF3Z2/LdrOTAmhLPQRMbUrJrcW
ZPd0kMybGIlEoprQEA/6nZkdtIiDo2OJtufCs8g+nJECgYEA/+v4uTJzEI1igKOB
myJtADECZAsJUaJaKSAM7VHn1hNOKgNLhUHOuroWvIWEhEomWeMvCbZIG42eOwNW
BXGtG7ruT79E6655dljU6E/029FaxONqXXCTD9ZPh031R293KcydMwgBJJ0pvFJE
14HWmMRAG0auPygMRhXubXU1ndMCgYEA9xpNWrl9poTjsZDNqvu60nYcq0W1escw
ovmb87uxZ5u8fC8T1F3AVMYj4v0dTyA4F0mZenY+nri/hJBuanWVxa5Liu0fGnBr
tEa2rzCMaajoDTNMKSygFz6CIMZbbZhozy0+9DHcRcC6b2UtIgB/+/ZQtrTvQ8Ea
i6viarkq1nsCgYBznYAM8mynEqhoYvV/RyslBf8FgTLhjU3b/F26rODmhmwucLSi
a9tf4ge5fTwjo3f17btnUND8mZrdICGxbex9dZKJtmgFbRn0TCdLGCwPTmIKRo7b
zaqyYeglwSNI9WNJH+X4kuopR1L+f9AX59ExzJ8Fc4XuhEIfO3MuQeBJ/wKBgQDa
8AgH0X/+EZJ42rcPvxiprxL5wbrpPSHf1M+T5gJqrXcUhNXJ/QMTWbekP+Y/HGn2
YDTHZ4tWMJUoTJw4YVTBoQu33R8I2wDi6yCkGpzeZVStlXzuomZ6Ed1UUsvhT//V
SN6VmLP1ba0CVB/oF49OXNDpAWlZm/f8NuBW9Rd6jwKBgQDi495IOjLJ8SvWRJLT
c9AUmO7IVgipWvr51cF9IYxkzXIVIQIh1usy2NsrBxshAD+FbbWFVBfoptdKBZVK
J8u+Ou4gTxs8SdGKGZWZpUMEKJbPsq8lE2aU3mBXiWcFRxYpu+n7nKap0Lla/xBD
v77FY1M3FxGR6rNqPJQ9rRLFbA==
-----END PRIVATE KEY-----

View File

@ -1,18 +1,13 @@
import pytest
from pypush import apns
import asyncio
# from aioapns import *
import uuid
import anyio
# from pypush.apns import _util
# from pypush.apns import albert, lifecycle, protocol
from pypush import apns
import logging
import uuid
from pathlib import Path
import httpx
import pytest
from rich.logging import RichHandler
from pypush import apns
logging.basicConfig(level=logging.DEBUG, handlers=[RichHandler()], format="%(message)s")
@ -26,17 +21,43 @@ async def test_activate():
@pytest.mark.asyncio
async def test_lifecycle_2():
async with apns.create_apns_connection(
certificate, key, courier="localhost"
) as connection:
await connection.receive(
apns.protocol.ConnectAck
) # Just wait until the initial connection is established. Don't do this in real code plz.
async with apns.create_apns_connection(certificate, key) as _:
pass
ASSETS_DIR = Path(__file__).parent / "assets"
async def send_test_notification(device_token, payload=b"hello, world"):
async with httpx.AsyncClient(
cert=str(ASSETS_DIR / "dev.jjtech.pypush.tests.pem"), http2=True
) as client:
# Use the certificate and key from above
response = await client.post(
f"https://api.sandbox.push.apple.com/3/device/{device_token}",
content=payload,
headers={
"apns-topic": "dev.jjtech.pypush.tests",
"apns-push-type": "alert",
"apns-priority": "10",
},
)
assert response.status_code == 200
@pytest.mark.asyncio
async def test_shorthand():
async def test_scoped_token():
async with apns.create_apns_connection(
*await apns.activate(), courier="localhost"
*await apns.activate(), sandbox=True
) as connection:
await connection.receive(apns.protocol.ConnectAck)
token = await connection.mint_scoped_token("dev.jjtech.pypush.tests")
test_message = f"test-message-{uuid.uuid4().hex}"
await send_test_notification(token.hex(), test_message.encode())
await connection.expect_notification(
"dev.jjtech.pypush.tests",
token,
lambda c: c if c.payload == test_message.encode() else None,
)