diff options
| author | Georgios Andreadis <info@gandreadis.com> | 2020-06-29 16:05:23 +0200 |
|---|---|---|
| committer | Fabian Mastenbroek <mail.fabianm@gmail.com> | 2020-08-24 16:18:36 +0200 |
| commit | 4f9a40abdc7836345113c047f27fcc96800cb3f5 (patch) | |
| tree | e443d14e34a884b1a4d9c549f81d51202eddd5f7 /web-server/opendc/util | |
| parent | cd5f7bf3a72913e1602cb4c575e61ac7d5519be0 (diff) | |
Prepare web-server repository for monorepo
This change prepares the web-server Git repository for the monorepo residing at
https://github.com/atlarge-research.com/opendc. To accomodate for this, we
move all files into a web-server subdirectory.
Diffstat (limited to 'web-server/opendc/util')
| -rw-r--r-- | web-server/opendc/util/__init__.py | 0 | ||||
| -rw-r--r-- | web-server/opendc/util/database.py | 93 | ||||
| -rw-r--r-- | web-server/opendc/util/exceptions.py | 65 | ||||
| -rw-r--r-- | web-server/opendc/util/parameter_checker.py | 78 | ||||
| -rw-r--r-- | web-server/opendc/util/path_parser.py | 39 | ||||
| -rw-r--r-- | web-server/opendc/util/rest.py | 141 |
6 files changed, 416 insertions, 0 deletions
diff --git a/web-server/opendc/util/__init__.py b/web-server/opendc/util/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/web-server/opendc/util/__init__.py diff --git a/web-server/opendc/util/database.py b/web-server/opendc/util/database.py new file mode 100644 index 00000000..50bc93a8 --- /dev/null +++ b/web-server/opendc/util/database.py @@ -0,0 +1,93 @@ +import json +import urllib.parse +from datetime import datetime + +from bson.json_util import dumps +from pymongo import MongoClient + +DATETIME_STRING_FORMAT = '%Y-%m-%dT%H:%M:%S' +CONNECTION_POOL = None + + +class Database: + def __init__(self): + self.opendc_db = None + + def init_database(self, user, password, database, host): + user = urllib.parse.quote_plus(user) # TODO: replace this with environment variable + password = urllib.parse.quote_plus(password) # TODO: same as above + database = urllib.parse.quote_plus(database) + host = urllib.parse.quote_plus(host) + + client = MongoClient('mongodb://%s:%s@%s/default_db?authSource=%s' % (user, password, host, database)) + self.opendc_db = client.opendc + + def fetch_one(self, query, collection): + """Uses existing mongo connection to return a single (the first) document in a collection matching the given + query as a JSON object. + + The query needs to be in json format, i.e.: `{'name': prefab_name}`. + """ + bson = getattr(self.opendc_db, collection).find_one(query) + + return self.convert_bson_to_json(bson) + + def fetch_all(self, query, collection): + """Uses existing mongo connection to return all documents matching a given query, as a list of JSON objects. + + The query needs to be in json format, i.e.: `{'name': prefab_name}`. + """ + results = [] + cursor = getattr(self.opendc_db, collection).find(query) + for doc in cursor: + results.append(self.convert_bson_to_json(doc)) + return results + + def insert(self, obj, collection): + """Updates an existing object.""" + bson = getattr(self.opendc_db, collection).insert(obj) + + return self.convert_bson_to_json(bson) + + def update(self, _id, obj, collection): + """Updates an existing object.""" + bson = getattr(self.opendc_db, collection).update({'_id': _id}, obj) + + return self.convert_bson_to_json(bson) + + def delete_one(self, query, collection): + """Deletes one object matching the given query. + + The query needs to be in json format, i.e.: `{'name': prefab_name}`. + """ + bson = getattr(self.opendc_db, collection).delete_one(query) + + return self.convert_bson_to_json(bson) + + def delete_all(self, query, collection): + """Deletes all objects matching the given query. + + The query needs to be in json format, i.e.: `{'name': prefab_name}`. + """ + bson = getattr(self.opendc_db, collection).delete_many(query) + + return self.convert_bson_to_json(bson) + + @staticmethod + def convert_bson_to_json(bson): + """Converts a BSON representation to JSON and returns the JSON representation.""" + json_string = dumps(bson) + return json.loads(json_string) + + @staticmethod + def datetime_to_string(datetime_to_convert): + """Return a database-compatible string representation of the given datetime object.""" + return datetime_to_convert.strftime(DATETIME_STRING_FORMAT) + + @staticmethod + def string_to_datetime(string_to_convert): + """Return a datetime corresponding to the given string representation.""" + return datetime.strptime(string_to_convert, DATETIME_STRING_FORMAT) + + +DB = Database() diff --git a/web-server/opendc/util/exceptions.py b/web-server/opendc/util/exceptions.py new file mode 100644 index 00000000..2563c419 --- /dev/null +++ b/web-server/opendc/util/exceptions.py @@ -0,0 +1,65 @@ +class RequestInitializationError(Exception): + """Raised when a Request cannot successfully be initialized""" + + +class UnimplementedEndpointError(RequestInitializationError): + """Raised when a Request path does not point to a module.""" + + +class MissingRequestParameterError(RequestInitializationError): + """Raised when a Request does not contain one or more required parameters.""" + + +class UnsupportedMethodError(RequestInitializationError): + """Raised when a Request does not use a supported REST method. + + The method must be in all-caps, supported by REST, and implemented by the module. + """ + + +class AuthorizationTokenError(RequestInitializationError): + """Raised when an authorization token is not correctly verified.""" + + +class ForeignKeyError(Exception): + """Raised when a foreign key constraint is not met.""" + + +class RowNotFoundError(Exception): + """Raised when a database row is not found.""" + def __init__(self, table_name): + super(RowNotFoundError, self).__init__('Row in `{}` table not found.'.format(table_name)) + + self.table_name = table_name + + +class ParameterError(Exception): + """Raised when a parameter is either missing or incorrectly typed.""" + + +class IncorrectParameterError(ParameterError): + """Raised when a parameter is of the wrong type.""" + def __init__(self, parameter_name, parameter_location): + super(IncorrectParameterError, + self).__init__('Incorrectly typed `{}` {} parameter.'.format(parameter_name, parameter_location)) + + self.parameter_name = parameter_name + self.parameter_location = parameter_location + + +class MissingParameterError(ParameterError): + """Raised when a parameter is missing.""" + def __init__(self, parameter_name, parameter_location): + super(MissingParameterError, + self).__init__('Missing required `{}` {} parameter.'.format(parameter_name, parameter_location)) + + self.parameter_name = parameter_name + self.parameter_location = parameter_location + + +class ClientError(Exception): + """Raised when a 4xx response is to be returned.""" + + def __init__(self, response): + super(ClientError, self).__init__(str(response)) + self.response = response diff --git a/web-server/opendc/util/parameter_checker.py b/web-server/opendc/util/parameter_checker.py new file mode 100644 index 00000000..f55e780e --- /dev/null +++ b/web-server/opendc/util/parameter_checker.py @@ -0,0 +1,78 @@ +from opendc.util import database, exceptions + + +def _missing_parameter(params_required, params_actual, parent=''): + """Recursively search for the first missing parameter.""" + + for param_name in params_required: + + if param_name not in params_actual: + return '{}.{}'.format(parent, param_name) + + param_required = params_required.get(param_name) + param_actual = params_actual.get(param_name) + + if isinstance(param_required, dict): + + param_missing = _missing_parameter(param_required, param_actual, param_name) + + if param_missing is not None: + return '{}.{}'.format(parent, param_missing) + + return None + + +def _incorrect_parameter(params_required, params_actual, parent=''): + """Recursively make sure each parameter is of the correct type.""" + + for param_name in params_required: + + param_required = params_required.get(param_name) + param_actual = params_actual.get(param_name) + + if isinstance(param_required, dict): + + param_incorrect = _incorrect_parameter(param_required, param_actual, param_name) + + if param_incorrect is not None: + return '{}.{}'.format(parent, param_incorrect) + + else: + + if param_required == 'datetime': + try: + database.string_to_datetime(param_actual) + except: + return '{}.{}'.format(parent, param_name) + + if param_required == 'int' and not isinstance(param_actual, int): + return '{}.{}'.format(parent, param_name) + + if param_required == 'string' and not isinstance(param_actual, str) and not isinstance(param_actual, int): + return '{}.{}'.format(parent, param_name) + + if param_required.startswith('list') and not isinstance(param_actual, list): + return '{}.{}'.format(parent, param_name) + + +def _format_parameter(parameter): + """Format the output of a parameter check.""" + + parts = parameter.split('.') + inner = ['["{}"]'.format(x) for x in parts[2:]] + return parts[1] + ''.join(inner) + + +def check(request, **kwargs): + """Return True if all required parameters are there.""" + + for location, params_required in kwargs.items(): + params_actual = getattr(request, 'params_{}'.format(location)) + + missing_parameter = _missing_parameter(params_required, params_actual) + if missing_parameter is not None: + raise exceptions.MissingParameterError(_format_parameter(missing_parameter), location) + + incorrect_parameter = _incorrect_parameter(params_required, params_actual) + if incorrect_parameter is not None: + raise exceptions.IncorrectParameterError(_format_parameter(incorrect_parameter), location) diff --git a/web-server/opendc/util/path_parser.py b/web-server/opendc/util/path_parser.py new file mode 100644 index 00000000..a8bbdeba --- /dev/null +++ b/web-server/opendc/util/path_parser.py @@ -0,0 +1,39 @@ +import json +import os + + +def parse(version, endpoint_path): + """Map an HTTP endpoint path to an API path""" + + # Get possible paths + with open(os.path.join(os.path.dirname(__file__), '..', 'api', '{}', 'paths.json').format(version)) as paths_file: + paths = json.load(paths_file) + + # Find API path that matches endpoint_path + endpoint_path_parts = endpoint_path.strip('/').split('/') + paths_parts = [x.strip('/').split('/') for x in paths if len(x.strip('/').split('/')) == len(endpoint_path_parts)] + path = None + + for path_parts in paths_parts: + found = True + for (endpoint_part, part) in zip(endpoint_path_parts, path_parts): + if not part.startswith('{') and endpoint_part != part: + found = False + break + if found: + path = path_parts + + if path is None: + return None + + # Extract path parameters + parameters = {} + + for (name, value) in zip(path, endpoint_path_parts): + if name.startswith('{'): + try: + parameters[name.strip('{}')] = int(value) + except: + parameters[name.strip('{}')] = value + + return '{}/{}'.format(version, '/'.join(path)), parameters diff --git a/web-server/opendc/util/rest.py b/web-server/opendc/util/rest.py new file mode 100644 index 00000000..dc5478de --- /dev/null +++ b/web-server/opendc/util/rest.py @@ -0,0 +1,141 @@ +import importlib +import json +import os +import sys + +from oauth2client import client, crypt + +from opendc.util import exceptions, parameter_checker +from opendc.util.exceptions import ClientError + + +class Request(object): + """WebSocket message to REST request mapping.""" + def __init__(self, message=None): + """"Initialize a Request from a socket message.""" + + # Get the Request parameters from the message + + if message is None: + return + + try: + self.message = message + + self.id = message['id'] + + self.path = message['path'] + self.method = message['method'] + + self.params_body = message['parameters']['body'] + self.params_path = message['parameters']['path'] + self.params_query = message['parameters']['query'] + + self.token = message['token'] + + except KeyError as exception: + raise exceptions.MissingRequestParameterError(exception) + + # Parse the path and import the appropriate module + + try: + self.path = message['path'].strip('/') + + module_base = 'opendc.api.{}.endpoint' + module_path = self.path.replace('{', '').replace('}', '').replace('/', '.') + + self.module = importlib.import_module(module_base.format(module_path)) + except ImportError as e: + print(e) + raise exceptions.UnimplementedEndpointError('Unimplemented endpoint: {}.'.format(self.path)) + + # Check the method + + if self.method not in ['POST', 'GET', 'PUT', 'PATCH', 'DELETE']: + raise exceptions.UnsupportedMethodError('Non-rest method: {}'.format(self.method)) + + if not hasattr(self.module, self.method): + raise exceptions.UnsupportedMethodError('Unimplemented method at endpoint {}: {}'.format( + self.path, self.method)) + + # Verify the user + + if "OPENDC_FLASK_TESTING" in os.environ: + self.google_id = 'test' + return + + try: + self.google_id = self._verify_token(self.token) + except crypt.AppIdentityError as e: + raise exceptions.AuthorizationTokenError(e) + + def check_required_parameters(self, **kwargs): + """Raise an error if a parameter is missing or of the wrong type.""" + + try: + parameter_checker.check(self, **kwargs) + except exceptions.ParameterError as e: + raise ClientError(Response(400, str(e))) + + def process(self): + """Process the Request and return a Response.""" + + method = getattr(self.module, self.method) + + try: + response = method(self) + except ClientError as e: + e.response.id = self.id + return e.response + + response.id = self.id + + return response + + def to_JSON(self): + """Return a JSON representation of this Request""" + + self.message['id'] = 0 + self.message['token'] = None + + return json.dumps(self.message) + + @staticmethod + def _verify_token(token): + """Return the ID of the signed-in user. + + Or throw an Exception if the token is invalid. + """ + + try: + id_info = client.verify_id_token(token, os.environ['OPENDC_OAUTH_CLIENT_ID']) + except Exception as e: + print(e) + raise crypt.AppIdentityError('Exception caught trying to verify ID token: {}'.format(e)) + + if id_info['aud'] != os.environ['OPENDC_OAUTH_CLIENT_ID']: + raise crypt.AppIdentityError('Unrecognized client.') + + if id_info['iss'] not in ['accounts.google.com', 'https://accounts.google.com']: + raise crypt.AppIdentityError('Wrong issuer.') + + return id_info['sub'] + + +class Response(object): + """Response to websocket mapping""" + def __init__(self, status_code, status_description, content=None): + """Initialize a new Response.""" + + self.status = {'code': status_code, 'description': status_description} + self.content = content + + def to_JSON(self): + """"Return a JSON representation of this Response""" + + data = {'id': self.id, 'status': self.status} + + if self.content is not None: + data['content'] = self.content + + return json.dumps(data) |
