diff options
| author | leonoverweel <l.overweel@gmail.com> | 2017-01-24 12:05:15 +0100 |
|---|---|---|
| committer | leonoverweel <l.overweel@gmail.com> | 2017-01-24 12:05:15 +0100 |
| commit | 86a50a4f6df9ece982743a3b7ca510846d248909 (patch) | |
| tree | 79edc0478908b7fee9e5dca2088e562c7a62038b /opendc/util | |
Initial commit
Diffstat (limited to 'opendc/util')
| -rw-r--r-- | opendc/util/__init__.py | 0 | ||||
| -rw-r--r-- | opendc/util/database.py | 82 | ||||
| -rw-r--r-- | opendc/util/exceptions.py | 61 | ||||
| -rw-r--r-- | opendc/util/parameter_checker.py | 90 | ||||
| -rw-r--r-- | opendc/util/rest.py | 137 |
5 files changed, 370 insertions, 0 deletions
diff --git a/opendc/util/__init__.py b/opendc/util/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/opendc/util/__init__.py diff --git a/opendc/util/database.py b/opendc/util/database.py new file mode 100644 index 00000000..16fff5f0 --- /dev/null +++ b/opendc/util/database.py @@ -0,0 +1,82 @@ +from datetime import datetime +import json +import sqlite3 + +# Get keys from config file +with open('/var/www/opendc.ewi.tudelft.nl/web-server/config/keys.json') as file: + KEYS = json.load(file) + +DATETIME_STRING_FORMAT = '%Y-%m-%dT%H:%M:%S' + +def execute(statement, t): + """Open a database connection and execute the statement.""" + + # Connect to the database + connection = sqlite3.connect(KEYS['DATABASE_LOCATION']) + cursor = connection.cursor() + + # Turn on foreign key checks + cursor.execute('pragma foreign_keys=ON') + + # Execute the statement + cursor.execute(statement, t) + + # Get the id + database_id = cursor.execute('SELECT last_insert_rowid()').fetchone()[0] + + # Disconnect from the database + connection.commit() + connection.close() + + # Return the id + return database_id + +def fetchone(statement, t=None): + """Open a database connection and return the first row matched by the SELECT statement.""" + + # Connect to the database + connection = sqlite3.connect(KEYS['DATABASE_LOCATION']) + cursor = connection.cursor() + + # Execute the SELECT statement + + if t is not None: + cursor.execute(statement, t) + else: + cursor.execute(statement) + + value = cursor.fetchone() + + # Disconnect from the database and return + connection.close() + return value + +def fetchall(statement, t=None): + """Open a database connection and return all rows matched by the SELECT statement.""" + + # Connect to the database + connection = sqlite3.connect(KEYS['DATABASE_LOCATION']) + cursor = connection.cursor() + + # Execute the SELECT statement + + if t is not None: + cursor.execute(statement, t) + else: + cursor.execute(statement) + + values = cursor.fetchall() + + # Disconnect from the database and return + connection.close() + return values + +def datetime_to_string(datetime_to_convert): + """Return a database-compatible string representation of the given datetime object.""" + + return datetime_to_convert.strftime(DATETIME_STRING_FORMAT) + +def string_to_datetime(string_to_convert): + """Return a datetime corresponding to the given string representation.""" + + return datetime.strptime(string_to_convert, DATETIME_STRING_FORMAT) diff --git a/opendc/util/exceptions.py b/opendc/util/exceptions.py new file mode 100644 index 00000000..56a04ab9 --- /dev/null +++ b/opendc/util/exceptions.py @@ -0,0 +1,61 @@ +class RequestInitializationError(Exception): + """Raised when a Request cannot successfully be initialized""" + +class UnimplementedEndpointError(RequestInitializationError): + """Raised when a Request path does not point to a module.""" + +class MissingRequestParameterError(RequestInitializationError): + """Raised when a Request does not contain one or more required parameters.""" + +class UnsupportedMethodError(RequestInitializationError): + """Raised when a Request does not use a supported REST method. + + The method must be in all-caps, supported by REST, and implemented by the module. + """ + +class AuthorizationTokenError(RequestInitializationError): + """Raised when an authorization token is not correctly verified.""" + +class ForeignKeyError(Exception): + """Raised when a foreign key constraint is not met.""" + +class RowNotFoundError(Exception): + """Raised when a database row is not found.""" + + def __init__(self, table_name): + super(RowNotFoundError, self).__init__( + 'Row in `{}` table not found.'.format(table_name) + ) + + self.table_name = table_name + +class ParameterError(Exception): + """Raised when a paramter is either missing or incorrectly typed.""" + +class IncorrectParameterError(ParameterError): + """Raised when a parameter is of the wrong type.""" + + def __init__(self, parameter_name, parameter_location): + super(IncorrectParameterError, self).__init__( + 'Incorrectly typed `{}` {} parameter.'.format( + parameter_name, + parameter_location + ) + ) + + self.parameter_name = parameter_name + self.parameter_location = parameter_location + +class MissingParameterError(ParameterError): + """Raised when a parameter is missing.""" + + def __init__(self, parameter_name, parameter_location): + super(MissingParameterError, self).__init__( + 'Missing required `{}` {} parameter.'.format( + parameter_name, + parameter_location + ) + ) + + self.parameter_name = parameter_name + self.parameter_location = parameter_location diff --git a/opendc/util/parameter_checker.py b/opendc/util/parameter_checker.py new file mode 100644 index 00000000..32cd6777 --- /dev/null +++ b/opendc/util/parameter_checker.py @@ -0,0 +1,90 @@ +from opendc.util import database, exceptions + +def _missing_parameter(params_required, params_actual, parent=''): + """Recursively search for the first missing parameter.""" + + for param_name in params_required: + + if not param_name in params_actual: + return '{}.{}'.format(parent, param_name) + + param_required = params_required.get(param_name) + param_actual = params_actual.get(param_name) + + if isinstance(param_required, dict): + + param_missing = _missing_parameter( + param_required, + param_actual, + param_name + ) + + if param_missing is not None: + return '{}.{}'.format(parent, param_missing) + + return None + +def _incorrect_parameter(params_required, params_actual, parent=''): + """Recursively make sure each parameter is of the correct type.""" + + for param_name in params_required: + + param_required = params_required.get(param_name) + param_actual = params_actual.get(param_name) + + if isinstance(param_required, dict): + + param_incorrect = _incorrect_parameter( + param_required, + param_actual, + param_name + ) + + if param_incorrect is not None: + return '{}.{}'.format(parent, param_incorrect) + + else: + + if param_required == 'datetime': + try: + database.string_to_datetime(param_actual) + except: + return '{}.{}'.format(parent, param_name) + + if param_required == 'int' and not isinstance(param_actual, int): + return '{}.{}'.format(parent, param_name) + + if param_required == 'string' and not isinstance(param_actual, basestring): + return '{}.{}'.format(parent, param_name) + + if param_required.startswith('list') and not isinstance(param_actual, list): + return '{}.{}'.format(parent, param_name) + +def _format_parameter(parameter): + """Format the output of a parameter check.""" + + parts = parameter.split('.') + inner = ['["{}"]'.format(x) for x in parts[2:]] + return parts[1] + ''.join(inner) + +def check(request, **kwargs): + """Return True if all required parameters are there.""" + + for location, params_required in kwargs.iteritems(): + + params_actual = getattr(request, 'params_{}'.format(location)) + + missing_parameter = _missing_parameter(params_required, params_actual) + if missing_parameter is not None: + raise exceptions.MissingParameterError( + _format_parameter(missing_parameter), + location + ) + + incorrect_parameter = _incorrect_parameter(params_required, params_actual) + if incorrect_parameter is not None: + raise exceptions.IncorrectParameterError( + _format_parameter(incorrect_parameter), + location + ) + diff --git a/opendc/util/rest.py b/opendc/util/rest.py new file mode 100644 index 00000000..a52b0082 --- /dev/null +++ b/opendc/util/rest.py @@ -0,0 +1,137 @@ +import importlib +import json +import os +import sys + +from oauth2client import client, crypt + +from opendc.util import exceptions, parameter_checker + +with open('/var/www/opendc.ewi.tudelft.nl/web-server/config/keys.json') as file: + KEYS = json.load(file) + +class Request(object): + """WebSocket message to REST request mapping.""" + + def __init__(self, message): + """"Initialize a Request from a socket message.""" + + # Get the Request parameters from the message + + try: + self.message = message + + self.id = message['id'] + + self.path = message['path'] + self.method = message['method'] + + self.params_body = message['parameters']['body'] + self.params_path = message['parameters']['path'] + self.params_query = message['parameters']['query'] + + self.token = message['token'] + + except KeyError as exception: + raise exceptions.MissingRequestParameterError(exception) + + # Parse the path and import the appropriate module + + try: + self.path = message['path'].encode('ascii', 'ignore').strip('/') + + module_base = 'opendc.api.{}.endpoint' + module_path = self.path.translate(None, '{}').replace('/', '.') + + self.module = importlib.import_module(module_base.format(module_path)) + + except UnicodeError as e: + raise exceptions.UnimplementedEndpointError('Non-ASCII path') + + except ImportError: + raise exceptions.UnimplementedEndpointError( + 'Unimplemented endpoint: {}.'.format(self.path) + ) + + # Check the method + + if not self.method in ['POST', 'GET', 'PUT', 'PATCH', 'DELETE']: + raise exceptions.UnsupportedMethodError('Non-rest method: {}'.format(self.method)) + + if not hasattr(self.module, self.method): + raise exceptions.UnsupportedMethodError('Unimplemented method at endpoint {}: {}'.format(self.path, self.method)) + + # Verify the user + + try: + self.google_id = self._verify_token(self.token) + + except crypt.AppIdentityError as e: + raise exceptions.AuthorizationTokenError(e.message) + + def _verify_token(self, token): + """Return the ID of the signed-in user. + + Or throw an Exception if the token is invalid. + """ + + try: + idinfo = client.verify_id_token(token, KEYS['OAUTH_CLIENT_ID']) + except Exception as e: + raise crypt.AppIdentityError('Exception caught trying to verify ID token: {}'.format(e)) + + if idinfo['aud'] != KEYS['OAUTH_CLIENT_ID']: + raise crypt.AppIdentityError('Unrecognized client.') + + if idinfo['iss'] not in ['accounts.google.com', 'https://accounts.google.com']: + raise crypt.AppIdentityError('Wrong issuer.') + + return idinfo['sub'] + + def check_required_parameters(self, **kwargs): + """Raise an error if a parameter is missing or of the wrong type.""" + + parameter_checker.check(self, **kwargs) + + def process(self): + """Process the Request and return a Response.""" + + method = getattr(self.module, self.method) + + response = method(self) + response.id = self.id + + return response + + def to_JSON(self): + """Return a JSON representation of this Request""" + + self.message['id'] = 0 + self.message['token'] = None + + return json.dumps(self.message) + +class Response(object): + """Response to websocket mapping""" + + def __init__(self, status_code, status_description, content=None): + """Initialize a new Response.""" + + self.status = { + 'code': status_code, + 'description': status_description + } + self.content = content + + def to_JSON(self): + """"Return a JSON representation of this Response""" + + data = { + 'id': self.id, + 'status': self.status + } + + if self.content is not None: + data['content'] = self.content + + return json.dumps(data) |
