summaryrefslogtreecommitdiff
path: root/web-server/opendc/util
diff options
context:
space:
mode:
authorGeorgios Andreadis <info@gandreadis.com>2020-06-29 16:05:23 +0200
committerFabian Mastenbroek <mail.fabianm@gmail.com>2020-08-24 16:18:36 +0200
commit4f9a40abdc7836345113c047f27fcc96800cb3f5 (patch)
treee443d14e34a884b1a4d9c549f81d51202eddd5f7 /web-server/opendc/util
parentcd5f7bf3a72913e1602cb4c575e61ac7d5519be0 (diff)
Prepare web-server repository for monorepo
This change prepares the web-server Git repository for the monorepo residing at https://github.com/atlarge-research.com/opendc. To accomodate for this, we move all files into a web-server subdirectory.
Diffstat (limited to 'web-server/opendc/util')
-rw-r--r--web-server/opendc/util/__init__.py0
-rw-r--r--web-server/opendc/util/database.py93
-rw-r--r--web-server/opendc/util/exceptions.py65
-rw-r--r--web-server/opendc/util/parameter_checker.py78
-rw-r--r--web-server/opendc/util/path_parser.py39
-rw-r--r--web-server/opendc/util/rest.py141
6 files changed, 416 insertions, 0 deletions
diff --git a/web-server/opendc/util/__init__.py b/web-server/opendc/util/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/web-server/opendc/util/__init__.py
diff --git a/web-server/opendc/util/database.py b/web-server/opendc/util/database.py
new file mode 100644
index 00000000..50bc93a8
--- /dev/null
+++ b/web-server/opendc/util/database.py
@@ -0,0 +1,93 @@
+import json
+import urllib.parse
+from datetime import datetime
+
+from bson.json_util import dumps
+from pymongo import MongoClient
+
+DATETIME_STRING_FORMAT = '%Y-%m-%dT%H:%M:%S'
+CONNECTION_POOL = None
+
+
+class Database:
+ def __init__(self):
+ self.opendc_db = None
+
+ def init_database(self, user, password, database, host):
+ user = urllib.parse.quote_plus(user) # TODO: replace this with environment variable
+ password = urllib.parse.quote_plus(password) # TODO: same as above
+ database = urllib.parse.quote_plus(database)
+ host = urllib.parse.quote_plus(host)
+
+ client = MongoClient('mongodb://%s:%s@%s/default_db?authSource=%s' % (user, password, host, database))
+ self.opendc_db = client.opendc
+
+ def fetch_one(self, query, collection):
+ """Uses existing mongo connection to return a single (the first) document in a collection matching the given
+ query as a JSON object.
+
+ The query needs to be in json format, i.e.: `{'name': prefab_name}`.
+ """
+ bson = getattr(self.opendc_db, collection).find_one(query)
+
+ return self.convert_bson_to_json(bson)
+
+ def fetch_all(self, query, collection):
+ """Uses existing mongo connection to return all documents matching a given query, as a list of JSON objects.
+
+ The query needs to be in json format, i.e.: `{'name': prefab_name}`.
+ """
+ results = []
+ cursor = getattr(self.opendc_db, collection).find(query)
+ for doc in cursor:
+ results.append(self.convert_bson_to_json(doc))
+ return results
+
+ def insert(self, obj, collection):
+ """Updates an existing object."""
+ bson = getattr(self.opendc_db, collection).insert(obj)
+
+ return self.convert_bson_to_json(bson)
+
+ def update(self, _id, obj, collection):
+ """Updates an existing object."""
+ bson = getattr(self.opendc_db, collection).update({'_id': _id}, obj)
+
+ return self.convert_bson_to_json(bson)
+
+ def delete_one(self, query, collection):
+ """Deletes one object matching the given query.
+
+ The query needs to be in json format, i.e.: `{'name': prefab_name}`.
+ """
+ bson = getattr(self.opendc_db, collection).delete_one(query)
+
+ return self.convert_bson_to_json(bson)
+
+ def delete_all(self, query, collection):
+ """Deletes all objects matching the given query.
+
+ The query needs to be in json format, i.e.: `{'name': prefab_name}`.
+ """
+ bson = getattr(self.opendc_db, collection).delete_many(query)
+
+ return self.convert_bson_to_json(bson)
+
+ @staticmethod
+ def convert_bson_to_json(bson):
+ """Converts a BSON representation to JSON and returns the JSON representation."""
+ json_string = dumps(bson)
+ return json.loads(json_string)
+
+ @staticmethod
+ def datetime_to_string(datetime_to_convert):
+ """Return a database-compatible string representation of the given datetime object."""
+ return datetime_to_convert.strftime(DATETIME_STRING_FORMAT)
+
+ @staticmethod
+ def string_to_datetime(string_to_convert):
+ """Return a datetime corresponding to the given string representation."""
+ return datetime.strptime(string_to_convert, DATETIME_STRING_FORMAT)
+
+
+DB = Database()
diff --git a/web-server/opendc/util/exceptions.py b/web-server/opendc/util/exceptions.py
new file mode 100644
index 00000000..2563c419
--- /dev/null
+++ b/web-server/opendc/util/exceptions.py
@@ -0,0 +1,65 @@
+class RequestInitializationError(Exception):
+ """Raised when a Request cannot successfully be initialized"""
+
+
+class UnimplementedEndpointError(RequestInitializationError):
+ """Raised when a Request path does not point to a module."""
+
+
+class MissingRequestParameterError(RequestInitializationError):
+ """Raised when a Request does not contain one or more required parameters."""
+
+
+class UnsupportedMethodError(RequestInitializationError):
+ """Raised when a Request does not use a supported REST method.
+
+ The method must be in all-caps, supported by REST, and implemented by the module.
+ """
+
+
+class AuthorizationTokenError(RequestInitializationError):
+ """Raised when an authorization token is not correctly verified."""
+
+
+class ForeignKeyError(Exception):
+ """Raised when a foreign key constraint is not met."""
+
+
+class RowNotFoundError(Exception):
+ """Raised when a database row is not found."""
+ def __init__(self, table_name):
+ super(RowNotFoundError, self).__init__('Row in `{}` table not found.'.format(table_name))
+
+ self.table_name = table_name
+
+
+class ParameterError(Exception):
+ """Raised when a parameter is either missing or incorrectly typed."""
+
+
+class IncorrectParameterError(ParameterError):
+ """Raised when a parameter is of the wrong type."""
+ def __init__(self, parameter_name, parameter_location):
+ super(IncorrectParameterError,
+ self).__init__('Incorrectly typed `{}` {} parameter.'.format(parameter_name, parameter_location))
+
+ self.parameter_name = parameter_name
+ self.parameter_location = parameter_location
+
+
+class MissingParameterError(ParameterError):
+ """Raised when a parameter is missing."""
+ def __init__(self, parameter_name, parameter_location):
+ super(MissingParameterError,
+ self).__init__('Missing required `{}` {} parameter.'.format(parameter_name, parameter_location))
+
+ self.parameter_name = parameter_name
+ self.parameter_location = parameter_location
+
+
+class ClientError(Exception):
+ """Raised when a 4xx response is to be returned."""
+
+ def __init__(self, response):
+ super(ClientError, self).__init__(str(response))
+ self.response = response
diff --git a/web-server/opendc/util/parameter_checker.py b/web-server/opendc/util/parameter_checker.py
new file mode 100644
index 00000000..f55e780e
--- /dev/null
+++ b/web-server/opendc/util/parameter_checker.py
@@ -0,0 +1,78 @@
+from opendc.util import database, exceptions
+
+
+def _missing_parameter(params_required, params_actual, parent=''):
+ """Recursively search for the first missing parameter."""
+
+ for param_name in params_required:
+
+ if param_name not in params_actual:
+ return '{}.{}'.format(parent, param_name)
+
+ param_required = params_required.get(param_name)
+ param_actual = params_actual.get(param_name)
+
+ if isinstance(param_required, dict):
+
+ param_missing = _missing_parameter(param_required, param_actual, param_name)
+
+ if param_missing is not None:
+ return '{}.{}'.format(parent, param_missing)
+
+ return None
+
+
+def _incorrect_parameter(params_required, params_actual, parent=''):
+ """Recursively make sure each parameter is of the correct type."""
+
+ for param_name in params_required:
+
+ param_required = params_required.get(param_name)
+ param_actual = params_actual.get(param_name)
+
+ if isinstance(param_required, dict):
+
+ param_incorrect = _incorrect_parameter(param_required, param_actual, param_name)
+
+ if param_incorrect is not None:
+ return '{}.{}'.format(parent, param_incorrect)
+
+ else:
+
+ if param_required == 'datetime':
+ try:
+ database.string_to_datetime(param_actual)
+ except:
+ return '{}.{}'.format(parent, param_name)
+
+ if param_required == 'int' and not isinstance(param_actual, int):
+ return '{}.{}'.format(parent, param_name)
+
+ if param_required == 'string' and not isinstance(param_actual, str) and not isinstance(param_actual, int):
+ return '{}.{}'.format(parent, param_name)
+
+ if param_required.startswith('list') and not isinstance(param_actual, list):
+ return '{}.{}'.format(parent, param_name)
+
+
+def _format_parameter(parameter):
+ """Format the output of a parameter check."""
+
+ parts = parameter.split('.')
+ inner = ['["{}"]'.format(x) for x in parts[2:]]
+ return parts[1] + ''.join(inner)
+
+
+def check(request, **kwargs):
+ """Return True if all required parameters are there."""
+
+ for location, params_required in kwargs.items():
+ params_actual = getattr(request, 'params_{}'.format(location))
+
+ missing_parameter = _missing_parameter(params_required, params_actual)
+ if missing_parameter is not None:
+ raise exceptions.MissingParameterError(_format_parameter(missing_parameter), location)
+
+ incorrect_parameter = _incorrect_parameter(params_required, params_actual)
+ if incorrect_parameter is not None:
+ raise exceptions.IncorrectParameterError(_format_parameter(incorrect_parameter), location)
diff --git a/web-server/opendc/util/path_parser.py b/web-server/opendc/util/path_parser.py
new file mode 100644
index 00000000..a8bbdeba
--- /dev/null
+++ b/web-server/opendc/util/path_parser.py
@@ -0,0 +1,39 @@
+import json
+import os
+
+
+def parse(version, endpoint_path):
+ """Map an HTTP endpoint path to an API path"""
+
+ # Get possible paths
+ with open(os.path.join(os.path.dirname(__file__), '..', 'api', '{}', 'paths.json').format(version)) as paths_file:
+ paths = json.load(paths_file)
+
+ # Find API path that matches endpoint_path
+ endpoint_path_parts = endpoint_path.strip('/').split('/')
+ paths_parts = [x.strip('/').split('/') for x in paths if len(x.strip('/').split('/')) == len(endpoint_path_parts)]
+ path = None
+
+ for path_parts in paths_parts:
+ found = True
+ for (endpoint_part, part) in zip(endpoint_path_parts, path_parts):
+ if not part.startswith('{') and endpoint_part != part:
+ found = False
+ break
+ if found:
+ path = path_parts
+
+ if path is None:
+ return None
+
+ # Extract path parameters
+ parameters = {}
+
+ for (name, value) in zip(path, endpoint_path_parts):
+ if name.startswith('{'):
+ try:
+ parameters[name.strip('{}')] = int(value)
+ except:
+ parameters[name.strip('{}')] = value
+
+ return '{}/{}'.format(version, '/'.join(path)), parameters
diff --git a/web-server/opendc/util/rest.py b/web-server/opendc/util/rest.py
new file mode 100644
index 00000000..dc5478de
--- /dev/null
+++ b/web-server/opendc/util/rest.py
@@ -0,0 +1,141 @@
+import importlib
+import json
+import os
+import sys
+
+from oauth2client import client, crypt
+
+from opendc.util import exceptions, parameter_checker
+from opendc.util.exceptions import ClientError
+
+
+class Request(object):
+ """WebSocket message to REST request mapping."""
+ def __init__(self, message=None):
+ """"Initialize a Request from a socket message."""
+
+ # Get the Request parameters from the message
+
+ if message is None:
+ return
+
+ try:
+ self.message = message
+
+ self.id = message['id']
+
+ self.path = message['path']
+ self.method = message['method']
+
+ self.params_body = message['parameters']['body']
+ self.params_path = message['parameters']['path']
+ self.params_query = message['parameters']['query']
+
+ self.token = message['token']
+
+ except KeyError as exception:
+ raise exceptions.MissingRequestParameterError(exception)
+
+ # Parse the path and import the appropriate module
+
+ try:
+ self.path = message['path'].strip('/')
+
+ module_base = 'opendc.api.{}.endpoint'
+ module_path = self.path.replace('{', '').replace('}', '').replace('/', '.')
+
+ self.module = importlib.import_module(module_base.format(module_path))
+ except ImportError as e:
+ print(e)
+ raise exceptions.UnimplementedEndpointError('Unimplemented endpoint: {}.'.format(self.path))
+
+ # Check the method
+
+ if self.method not in ['POST', 'GET', 'PUT', 'PATCH', 'DELETE']:
+ raise exceptions.UnsupportedMethodError('Non-rest method: {}'.format(self.method))
+
+ if not hasattr(self.module, self.method):
+ raise exceptions.UnsupportedMethodError('Unimplemented method at endpoint {}: {}'.format(
+ self.path, self.method))
+
+ # Verify the user
+
+ if "OPENDC_FLASK_TESTING" in os.environ:
+ self.google_id = 'test'
+ return
+
+ try:
+ self.google_id = self._verify_token(self.token)
+ except crypt.AppIdentityError as e:
+ raise exceptions.AuthorizationTokenError(e)
+
+ def check_required_parameters(self, **kwargs):
+ """Raise an error if a parameter is missing or of the wrong type."""
+
+ try:
+ parameter_checker.check(self, **kwargs)
+ except exceptions.ParameterError as e:
+ raise ClientError(Response(400, str(e)))
+
+ def process(self):
+ """Process the Request and return a Response."""
+
+ method = getattr(self.module, self.method)
+
+ try:
+ response = method(self)
+ except ClientError as e:
+ e.response.id = self.id
+ return e.response
+
+ response.id = self.id
+
+ return response
+
+ def to_JSON(self):
+ """Return a JSON representation of this Request"""
+
+ self.message['id'] = 0
+ self.message['token'] = None
+
+ return json.dumps(self.message)
+
+ @staticmethod
+ def _verify_token(token):
+ """Return the ID of the signed-in user.
+
+ Or throw an Exception if the token is invalid.
+ """
+
+ try:
+ id_info = client.verify_id_token(token, os.environ['OPENDC_OAUTH_CLIENT_ID'])
+ except Exception as e:
+ print(e)
+ raise crypt.AppIdentityError('Exception caught trying to verify ID token: {}'.format(e))
+
+ if id_info['aud'] != os.environ['OPENDC_OAUTH_CLIENT_ID']:
+ raise crypt.AppIdentityError('Unrecognized client.')
+
+ if id_info['iss'] not in ['accounts.google.com', 'https://accounts.google.com']:
+ raise crypt.AppIdentityError('Wrong issuer.')
+
+ return id_info['sub']
+
+
+class Response(object):
+ """Response to websocket mapping"""
+ def __init__(self, status_code, status_description, content=None):
+ """Initialize a new Response."""
+
+ self.status = {'code': status_code, 'description': status_description}
+ self.content = content
+
+ def to_JSON(self):
+ """"Return a JSON representation of this Response"""
+
+ data = {'id': self.id, 'status': self.status}
+
+ if self.content is not None:
+ data['content'] = self.content
+
+ return json.dumps(data)