diff --git a/push_notifications/apns_async.py b/push_notifications/apns_async.py index a0710d85..390499c0 100644 --- a/push_notifications/apns_async.py +++ b/push_notifications/apns_async.py @@ -1,5 +1,6 @@ import asyncio import time + from dataclasses import asdict, dataclass from typing import Awaitable, Callable, Dict, Optional, Union @@ -111,132 +112,100 @@ def asDict(self) -> dict[str, any]: } -class APNsService: - __slots__ = ("client",) - - def __init__( - self, - application_id: str = None, - creds: Credentials = None, - topic: str = None, - err_func: ErrFunc = None, - ): - try: - loop = asyncio.get_event_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - self.client = self._create_client( - creds=creds, application_id=application_id, topic=topic, err_func=err_func - ) +def _create_notification_request_from_args( + registration_id: str, + alert: Union[str, Alert], + badge: int = None, + sound: str = None, + extra: dict = {}, + expiration: int = None, + thread_id: str = None, + loc_key: str = None, + priority: int = None, + collapse_id: str = None, + aps_kwargs: dict = {}, + message_kwargs: dict = {}, + notification_request_kwargs: dict = {}, +): + if alert is None: + alert = Alert(body="") - def send_message( - self, - request: NotificationRequest, - ): - loop = asyncio.get_event_loop() - routine = self.client.send_notification(request) - res = loop.run_until_complete(routine) - return res - - def _create_notification_request_from_args( - self, - registration_id: str, - alert: Union[str, Alert], - badge: int = None, - sound: str = None, - extra: dict = {}, - expiration: int = None, - thread_id: str = None, - loc_key: str = None, - priority: int = None, - collapse_id: str = None, - aps_kwargs: dict = {}, - message_kwargs: dict = {}, - notification_request_kwargs: dict = {}, - ): - if alert is None: - alert = Alert(body="") - - if loc_key: - if isinstance(alert, str): - alert = Alert(body=alert) - alert.loc_key = loc_key - - if isinstance(alert, Alert): - alert = alert.asDict() - - notification_request_kwargs_out = notification_request_kwargs.copy() - - if expiration is not None: - notification_request_kwargs_out["time_to_live"] = expiration - int( - time.time() - ) - if priority is not None: - notification_request_kwargs_out["priority"] = priority - - if collapse_id is not None: - notification_request_kwargs_out["collapse_key"] = collapse_id - - request = NotificationRequest( - device_token=registration_id, - message={ - "aps": { - "alert": alert, - "badge": badge, - "sound": sound, - "thread-id": thread_id, - **aps_kwargs, - }, - **extra, - **message_kwargs, - }, - **notification_request_kwargs_out, - ) + if loc_key: + if isinstance(alert, str): + alert = Alert(body=alert) + alert.loc_key = loc_key - return request - - def _create_client( - self, - creds: Credentials = None, - application_id: str = None, - topic=None, - err_func: ErrFunc = None, - ) -> APNs: - use_sandbox = get_manager().get_apns_use_sandbox(application_id) - if topic is None: - topic = get_manager().get_apns_topic(application_id) - if creds is None: - creds = self._get_credentials(application_id) - - client = APNs( - **asdict(creds), - topic=topic, # Bundle ID - use_sandbox=use_sandbox, - err_func=err_func, + if isinstance(alert, Alert): + alert = alert.asDict() + + notification_request_kwargs_out = notification_request_kwargs.copy() + + if expiration is not None: + notification_request_kwargs_out["time_to_live"] = expiration - int( + time.time() ) - return client - - def _get_credentials(self, application_id): - if not get_manager().has_auth_token_creds(application_id): - # TLS certificate authentication - cert = get_manager().get_apns_certificate(application_id) - return CertificateCredentials( - client_cert=cert, - ) - else: - # Token authentication - keyPath, keyId, teamId = get_manager().get_apns_auth_creds(application_id) - # No use getting a lifetime because this credential is - # ephemeral, but if you're looking at this to see how to - # create a credential, you could also pass the lifetime and - # algorithm. Neither of those settings are exposed in the - # settings API at the moment. - return TokenCredentials(key=keyPath, key_id=keyId, team_id=teamId) + if priority is not None: + notification_request_kwargs_out["priority"] = priority + + if collapse_id is not None: + notification_request_kwargs_out["collapse_key"] = collapse_id + + request = NotificationRequest( + device_token=registration_id, + message={ + "aps": { + "alert": alert, + "badge": badge, + "sound": sound, + "thread-id": thread_id, + **aps_kwargs, + }, + **extra, + **message_kwargs, + }, + **notification_request_kwargs_out, + ) + return request -# Public interface + +def _create_client( + creds: Credentials = None, + application_id: str = None, + topic=None, + err_func: ErrFunc = None, +) -> APNs: + use_sandbox = get_manager().get_apns_use_sandbox(application_id) + if topic is None: + topic = get_manager().get_apns_topic(application_id) + if creds is None: + creds = _get_credentials(application_id) + + client = APNs( + **asdict(creds), + topic=topic, # Bundle ID + use_sandbox=use_sandbox, + err_func=err_func, + ) + return client + + +def _get_credentials(application_id): + if not get_manager().has_auth_token_creds(application_id): + # TLS certificate authentication + cert = get_manager().get_apns_certificate(application_id) + return CertificateCredentials( + client_cert=cert, + ) + else: + # Token authentication + keyPath, keyId, teamId = get_manager().get_apns_auth_creds(application_id) + # No use getting a lifetime because this credential is + # ephemeral, but if you're looking at this to see how to + # create a credential, you could also pass the lifetime and + # algorithm. Neither of those settings are exposed in the + # settings API at the moment. + return TokenCredentials(key=keyPath, key_id=keyId, team_id=teamId) def apns_send_message( @@ -270,33 +239,28 @@ def apns_send_message( :param application_id: The application_id to use :param creds: The credentials to use """ + results = apns_send_bulk_message( + registration_ids=[registration_id], + alert=alert, + application_id=application_id, + creds=creds, + topic=topic, + badge=badge, + sound=sound, + extra=extra, + expiration=expiration, + thread_id=thread_id, + loc_key=loc_key, + priority=priority, + collapse_id=collapse_id, + err_func=err_func, + ) - try: - apns_service = APNsService( - application_id=application_id, creds=creds, topic=topic, err_func=err_func - ) - - request = apns_service._create_notification_request_from_args( - registration_id, - alert, - badge=badge, - sound=sound, - extra=extra, - expiration=expiration, - thread_id=thread_id, - loc_key=loc_key, - priority=priority, - collapse_id=collapse_id, - ) - res = apns_service.send_message(request) - if not res.is_successful: - if res.description == "Unregistered": - models.APNSDevice.objects.filter( - registration_id=registration_id - ).update(active=False) - raise APNSServerError(status=res.description) - except ConnectionError as e: - raise APNSServerError(status=e.__class__.__name__) + for result in results.values(): + if result == "Success": + return {"results": [result]} + else: + return {"results": [{"error": result}]} def apns_send_bulk_message( @@ -328,17 +292,17 @@ def apns_send_bulk_message( :param application_id: The application_id to use :param creds: The credentials to use """ - - topic = get_manager().get_apns_topic(application_id) - results: Dict[str, str] = {} - inactive_tokens = [] - apns_service = APNsService( - application_id=application_id, creds=creds, topic=topic, err_func=err_func - ) - for registration_id in registration_ids: - request = apns_service._create_notification_request_from_args( - registration_id, - alert, + try: + topic = get_manager().get_apns_topic(application_id) + results: Dict[str, str] = {} + inactive_tokens = [] + + responses = asyncio.run(_send_bulk_request( + registration_ids=registration_ids, + alert=alert, + application_id=application_id, + creds=creds, + topic=topic, badge=badge, sound=sound, extra=extra, @@ -347,17 +311,79 @@ def apns_send_bulk_message( loc_key=loc_key, priority=priority, collapse_id=collapse_id, - ) + err_func=err_func, + )) + + results = {} + for registration_id, result in responses: + results[registration_id] = ( + "Success" if result.is_successful else result.description + ) + if not result.is_successful and result.description in ["Unregistered", "BadDeviceToken", + "DeviceTokenNotForTopic"]: + inactive_tokens.append(registration_id) + + if len(inactive_tokens) > 0: + models.APNSDevice.objects.filter(registration_id__in=inactive_tokens).update( + active=False + ) + + return results + + except ConnectionError as e: + raise APNSServerError(status=e.__class__.__name__) - result = apns_service.send_message(request) - results[registration_id] = ( - "Success" if result.is_successful else result.description - ) - if not result.is_successful and result.description == "Unregistered": - inactive_tokens.append(registration_id) - if len(inactive_tokens) > 0: - models.APNSDevice.objects.filter(registration_id__in=inactive_tokens).update( - active=False +async def _send_bulk_request( + registration_ids: list[str], + alert: Union[str, Alert], + application_id: str = None, + creds: Credentials = None, + topic: str = None, + badge: int = None, + sound: str = None, + extra: dict = {}, + expiration: int = None, + thread_id: str = None, + loc_key: str = None, + priority: int = None, + collapse_id: str = None, + err_func: ErrFunc = None, +): + client = _create_client( + creds=creds, application_id=application_id, topic=topic, err_func=err_func + ) + + requests = [_create_notification_request_from_args( + registration_id, + alert, + badge=badge, + sound=sound, + extra=extra, + expiration=expiration, + thread_id=thread_id, + loc_key=loc_key, + priority=priority, + collapse_id=collapse_id, + ) for registration_id in registration_ids] + + send_requests = [_send_request(client, request) for request in requests] + return await asyncio.gather(*send_requests) + + +async def _send_request(apns, request): + try: + res = await asyncio.wait_for(apns.send_notification(request), timeout=1) + return request.device_token, res + except asyncio.TimeoutError: + return request.device_token, NotificationResult( + notification_id=request.notification_id, + status="failed", + description="TimeoutError" + ) + except: + return request.device_token, NotificationResult( + notification_id=request.notification_id, + status="failed", + description="CommunicationError" ) - return results