Source code for wcraas_storage.wcraas_storage
# -*- coding: utf-8 -*-
"""The WCraaS Storage module is responsible for providing storage services for the platform."""
import asyncio
import json
import logging
import aio_pika
from aio_pika import connect_robust, IncomingMessage, ExchangeType
from aio_pika.patterns import RPC
from aio_pika.pool import Pool
from typing import Dict, List
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase
from wcraas_common import AMQPConfig, WcraasWorker
from wcraas_common.decorator import is_rpc, consume
from wcraas_storage.config import MongoConfig
class StorageWorker(WcraasWorker):
    """
    Worker that persists consumed AMQP messages into MongoDB collections.

    Each key of ``mapping`` is used as the name of a queue to consume from and each
    value as the MongoDB collection in which the message bodies are stored. The worker
    also exposes a ``list_collections`` RPC describing the collections of the database.
    """

    def __init__(
        self,
        amqp: AMQPConfig,
        mongo: MongoConfig,
        mapping: Dict[str, str],
        loglevel: int,
        *args,
        **kwargs,
    ):
        """
        Set up logging, the MongoDB database handle and the queue/collection mapping.

        :param amqp: AMQP configuration, forwarded to the base worker.
        :type amqp: AMQPConfig
        :param mongo: MongoDB configuration; ``host``, ``port`` and ``db`` are read from it.
        :type mongo: MongoConfig
        :param mapping: Queue name to collection name mapping.
        :type mapping: dict
        :param loglevel: Level applied to the ``wcraas.storage`` logger and forwarded to the base worker.
        :type loglevel: integer

        Extra positional and keyword arguments are accepted but not used here.
        """
        super().__init__(amqp, loglevel)
        self.mongo = mongo
        self.logger = logging.getLogger("wcraas.storage")
        self.logger.setLevel(loglevel)
        self._db = AsyncIOMotorClient(mongo.host, mongo.port)[mongo.db]
        # Set by `run` during teardown.
        self._close = asyncio.Event()
        self.mapping = mapping

    def _discover_callable(self, *args, **kwargs):
        """
        Override the common module's callable check to exclude AsyncIOMotorDatabase.

        AsyncIOMotorDatabase instances return (and create) a new collection reference
        when their `__getattr__` is invoked and will thus always pass `_discover`.
        """
        for attr in super()._discover_callable(*args, **kwargs):
            if isinstance(attr, AsyncIOMotorDatabase):
                continue
            yield attr

    def get_queue_by_collection(self, collection: str) -> str:
        """
        Return the queue that corresponds to the given collection.

        When several queues map to the same collection, the first one in mapping
        order is returned.

        :param collection: The collection with which to determine the queue.
        :type collection: string
        :raises KeyError: If no queue is mapped to the given collection.
        """
        for queue, mapped_collection in self.mapping.items():
            if mapped_collection == collection:
                return queue
        # Carry the offending name so the failure can be diagnosed from the error alone.
        raise KeyError(collection)

    async def store(self, message: IncomingMessage) -> None:
        """
        AMQP consumer function, that inserts an `IncomingMessage`'s json-loaded body in a MongoDB collection based on the source exchange.

        Failures (unmapped exchange, invalid JSON, database errors) are logged and not
        re-raised, so they never propagate out of the ``message.process()`` context.

        :param message: The message that triggered the consume callback.
        :type message: aio_pika.IncomingMessage
        """
        async with message.process():
            try:
                result = await self._db[self.mapping[message.exchange]].insert_one(
                    json.loads(message.body)
                )
                self.logger.info(result)
            except Exception as err:
                # Best-effort consumer: keep the traceback in the log instead of only the message.
                self.logger.exception(err)

    @is_rpc("list_collections")
    async def list_collections(self) -> Dict[str, List[Dict[str, str]]]:
        """
        AMQP function that lists available collections in selected MongoDB.

        Each entry carries the collection's ``name`` and ``type``, the ``queue`` mapped
        to it and an estimated document ``count``.

        :raises KeyError: If a collection of the database has no queue in the mapping.
        """
        return {
            "data": [
                {
                    "name": collection["name"],
                    "type": collection["type"],
                    "queue": self.get_queue_by_collection(collection["name"]),
                    "count": await self._db[collection["name"]].estimated_document_count(),
                }
                for collection in (await self._db.list_collections())
            ]
        }

    async def start(self) -> None:
        """
        Asynchronous runtime for the worker, responsible of managing and maintaining async context open.

        Registers `store` as the consumer of every queue in the mapping and then starts
        serving RPC calls.
        """
        async with self._amqp_pool.acquire() as sub_channel:
            await sub_channel.set_qos(prefetch_count=1)
            # Not using the common module's `start_consume` because the consumer function is the same for
            # multiple consume queues and dynamically creating new function names in the class in order to
            # decorate them (which would enable using `start_consume`) would be too much of a hack.
            for queue_name in self.mapping:
                await self.register_consumer(sub_channel, self.store, queue_name)
                self.logger.warning(f"Registered {queue_name} ...")
            # NOTE(review): kept inside the channel context so the channel stays acquired
            # while RPC is served -- confirm against the upstream layout.
            await self.start_rpc()

    def run(self) -> None:
        """
        Helper function implementing the synchronous boilerplate for initialization and teardown.

        Runs `start` on the event loop until it completes or the user interrupts it,
        then signals closure and finalizes pending asynchronous generators.
        """
        loop = asyncio.get_event_loop()
        try:
            loop.run_until_complete(self.start())
        except KeyboardInterrupt:
            self.logger.info("[x] Received ^C ! Exiting ...")
        finally:
            self._close.set()
            # `shutdown_asyncgens` is a coroutine: merely calling it does nothing, so it
            # has to be driven by the loop for the generators to actually be closed.
            loop.run_until_complete(loop.shutdown_asyncgens())