summaryrefslogtreecommitdiff
path: root/opendc/util
diff options
context:
space:
mode:
authorleonoverweel <l.overweel@gmail.com>2017-01-24 12:05:15 +0100
committerleonoverweel <l.overweel@gmail.com>2017-01-24 12:05:15 +0100
commit86a50a4f6df9ece982743a3b7ca510846d248909 (patch)
tree79edc0478908b7fee9e5dca2088e562c7a62038b /opendc/util
Initial commit
Diffstat (limited to 'opendc/util')
-rw-r--r--opendc/util/__init__.py0
-rw-r--r--opendc/util/database.py82
-rw-r--r--opendc/util/exceptions.py61
-rw-r--r--opendc/util/parameter_checker.py90
-rw-r--r--opendc/util/rest.py137
5 files changed, 370 insertions, 0 deletions
diff --git a/opendc/util/__init__.py b/opendc/util/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/opendc/util/__init__.py
diff --git a/opendc/util/database.py b/opendc/util/database.py
new file mode 100644
index 00000000..16fff5f0
--- /dev/null
+++ b/opendc/util/database.py
@@ -0,0 +1,82 @@
+from datetime import datetime
+import json
+import sqlite3
+
+# Get keys from config file
+with open('/var/www/opendc.ewi.tudelft.nl/web-server/config/keys.json') as file:
+ KEYS = json.load(file)
+
+DATETIME_STRING_FORMAT = '%Y-%m-%dT%H:%M:%S'
+
+def execute(statement, t):
+ """Open a database connection and execute the statement."""
+
+ # Connect to the database
+ connection = sqlite3.connect(KEYS['DATABASE_LOCATION'])
+ cursor = connection.cursor()
+
+ # Turn on foreign key checks
+ cursor.execute('pragma foreign_keys=ON')
+
+ # Execute the statement
+ cursor.execute(statement, t)
+
+ # Get the id
+ database_id = cursor.execute('SELECT last_insert_rowid()').fetchone()[0]
+
+ # Disconnect from the database
+ connection.commit()
+ connection.close()
+
+ # Return the id
+ return database_id
+
+def fetchone(statement, t=None):
+ """Open a database connection and return the first row matched by the SELECT statement."""
+
+ # Connect to the database
+ connection = sqlite3.connect(KEYS['DATABASE_LOCATION'])
+ cursor = connection.cursor()
+
+ # Execute the SELECT statement
+
+ if t is not None:
+ cursor.execute(statement, t)
+ else:
+ cursor.execute(statement)
+
+ value = cursor.fetchone()
+
+ # Disconnect from the database and return
+ connection.close()
+ return value
+
+def fetchall(statement, t=None):
+ """Open a database connection and return all rows matched by the SELECT statement."""
+
+ # Connect to the database
+ connection = sqlite3.connect(KEYS['DATABASE_LOCATION'])
+ cursor = connection.cursor()
+
+ # Execute the SELECT statement
+
+ if t is not None:
+ cursor.execute(statement, t)
+ else:
+ cursor.execute(statement)
+
+ values = cursor.fetchall()
+
+ # Disconnect from the database and return
+ connection.close()
+ return values
+
+def datetime_to_string(datetime_to_convert):
+ """Return a database-compatible string representation of the given datetime object."""
+
+ return datetime_to_convert.strftime(DATETIME_STRING_FORMAT)
+
+def string_to_datetime(string_to_convert):
+ """Return a datetime corresponding to the given string representation."""
+
+ return datetime.strptime(string_to_convert, DATETIME_STRING_FORMAT)
diff --git a/opendc/util/exceptions.py b/opendc/util/exceptions.py
new file mode 100644
index 00000000..56a04ab9
--- /dev/null
+++ b/opendc/util/exceptions.py
@@ -0,0 +1,61 @@
+class RequestInitializationError(Exception):
+ """Raised when a Request cannot successfully be initialized"""
+
+class UnimplementedEndpointError(RequestInitializationError):
+ """Raised when a Request path does not point to a module."""
+
+class MissingRequestParameterError(RequestInitializationError):
+ """Raised when a Request does not contain one or more required parameters."""
+
+class UnsupportedMethodError(RequestInitializationError):
+ """Raised when a Request does not use a supported REST method.
+
+ The method must be in all-caps, supported by REST, and implemented by the module.
+ """
+
+class AuthorizationTokenError(RequestInitializationError):
+ """Raised when an authorization token is not correctly verified."""
+
+class ForeignKeyError(Exception):
+ """Raised when a foreign key constraint is not met."""
+
+class RowNotFoundError(Exception):
+ """Raised when a database row is not found."""
+
+ def __init__(self, table_name):
+ super(RowNotFoundError, self).__init__(
+ 'Row in `{}` table not found.'.format(table_name)
+ )
+
+ self.table_name = table_name
+
+class ParameterError(Exception):
+ """Raised when a paramter is either missing or incorrectly typed."""
+
+class IncorrectParameterError(ParameterError):
+ """Raised when a parameter is of the wrong type."""
+
+ def __init__(self, parameter_name, parameter_location):
+ super(IncorrectParameterError, self).__init__(
+ 'Incorrectly typed `{}` {} parameter.'.format(
+ parameter_name,
+ parameter_location
+ )
+ )
+
+ self.parameter_name = parameter_name
+ self.parameter_location = parameter_location
+
+class MissingParameterError(ParameterError):
+ """Raised when a parameter is missing."""
+
+ def __init__(self, parameter_name, parameter_location):
+ super(MissingParameterError, self).__init__(
+ 'Missing required `{}` {} parameter.'.format(
+ parameter_name,
+ parameter_location
+ )
+ )
+
+ self.parameter_name = parameter_name
+ self.parameter_location = parameter_location
diff --git a/opendc/util/parameter_checker.py b/opendc/util/parameter_checker.py
new file mode 100644
index 00000000..32cd6777
--- /dev/null
+++ b/opendc/util/parameter_checker.py
@@ -0,0 +1,90 @@
+from opendc.util import database, exceptions
+
+def _missing_parameter(params_required, params_actual, parent=''):
+ """Recursively search for the first missing parameter."""
+
+ for param_name in params_required:
+
+ if not param_name in params_actual:
+ return '{}.{}'.format(parent, param_name)
+
+ param_required = params_required.get(param_name)
+ param_actual = params_actual.get(param_name)
+
+ if isinstance(param_required, dict):
+
+ param_missing = _missing_parameter(
+ param_required,
+ param_actual,
+ param_name
+ )
+
+ if param_missing is not None:
+ return '{}.{}'.format(parent, param_missing)
+
+ return None
+
+def _incorrect_parameter(params_required, params_actual, parent=''):
+ """Recursively make sure each parameter is of the correct type."""
+
+ for param_name in params_required:
+
+ param_required = params_required.get(param_name)
+ param_actual = params_actual.get(param_name)
+
+ if isinstance(param_required, dict):
+
+ param_incorrect = _incorrect_parameter(
+ param_required,
+ param_actual,
+ param_name
+ )
+
+ if param_incorrect is not None:
+ return '{}.{}'.format(parent, param_incorrect)
+
+ else:
+
+ if param_required == 'datetime':
+ try:
+ database.string_to_datetime(param_actual)
+ except:
+ return '{}.{}'.format(parent, param_name)
+
+ if param_required == 'int' and not isinstance(param_actual, int):
+ return '{}.{}'.format(parent, param_name)
+
+ if param_required == 'string' and not isinstance(param_actual, basestring):
+ return '{}.{}'.format(parent, param_name)
+
+ if param_required.startswith('list') and not isinstance(param_actual, list):
+ return '{}.{}'.format(parent, param_name)
+
+def _format_parameter(parameter):
+ """Format the output of a parameter check."""
+
+ parts = parameter.split('.')
+ inner = ['["{}"]'.format(x) for x in parts[2:]]
+ return parts[1] + ''.join(inner)
+
+def check(request, **kwargs):
+ """Return True if all required parameters are there."""
+
+ for location, params_required in kwargs.iteritems():
+
+ params_actual = getattr(request, 'params_{}'.format(location))
+
+ missing_parameter = _missing_parameter(params_required, params_actual)
+ if missing_parameter is not None:
+ raise exceptions.MissingParameterError(
+ _format_parameter(missing_parameter),
+ location
+ )
+
+ incorrect_parameter = _incorrect_parameter(params_required, params_actual)
+ if incorrect_parameter is not None:
+ raise exceptions.IncorrectParameterError(
+ _format_parameter(incorrect_parameter),
+ location
+ )
+
diff --git a/opendc/util/rest.py b/opendc/util/rest.py
new file mode 100644
index 00000000..a52b0082
--- /dev/null
+++ b/opendc/util/rest.py
@@ -0,0 +1,137 @@
+import importlib
+import json
+import os
+import sys
+
+from oauth2client import client, crypt
+
+from opendc.util import exceptions, parameter_checker
+
+with open('/var/www/opendc.ewi.tudelft.nl/web-server/config/keys.json') as file:
+ KEYS = json.load(file)
+
+class Request(object):
+ """WebSocket message to REST request mapping."""
+
+ def __init__(self, message):
+ """"Initialize a Request from a socket message."""
+
+ # Get the Request parameters from the message
+
+ try:
+ self.message = message
+
+ self.id = message['id']
+
+ self.path = message['path']
+ self.method = message['method']
+
+ self.params_body = message['parameters']['body']
+ self.params_path = message['parameters']['path']
+ self.params_query = message['parameters']['query']
+
+ self.token = message['token']
+
+ except KeyError as exception:
+ raise exceptions.MissingRequestParameterError(exception)
+
+ # Parse the path and import the appropriate module
+
+ try:
+ self.path = message['path'].encode('ascii', 'ignore').strip('/')
+
+ module_base = 'opendc.api.{}.endpoint'
+ module_path = self.path.translate(None, '{}').replace('/', '.')
+
+ self.module = importlib.import_module(module_base.format(module_path))
+
+ except UnicodeError as e:
+ raise exceptions.UnimplementedEndpointError('Non-ASCII path')
+
+ except ImportError:
+ raise exceptions.UnimplementedEndpointError(
+ 'Unimplemented endpoint: {}.'.format(self.path)
+ )
+
+ # Check the method
+
+ if not self.method in ['POST', 'GET', 'PUT', 'PATCH', 'DELETE']:
+ raise exceptions.UnsupportedMethodError('Non-rest method: {}'.format(self.method))
+
+ if not hasattr(self.module, self.method):
+ raise exceptions.UnsupportedMethodError('Unimplemented method at endpoint {}: {}'.format(self.path, self.method))
+
+ # Verify the user
+
+ try:
+ self.google_id = self._verify_token(self.token)
+
+ except crypt.AppIdentityError as e:
+ raise exceptions.AuthorizationTokenError(e.message)
+
+ def _verify_token(self, token):
+ """Return the ID of the signed-in user.
+
+ Or throw an Exception if the token is invalid.
+ """
+
+ try:
+ idinfo = client.verify_id_token(token, KEYS['OAUTH_CLIENT_ID'])
+ except Exception as e:
+ raise crypt.AppIdentityError('Exception caught trying to verify ID token: {}'.format(e))
+
+ if idinfo['aud'] != KEYS['OAUTH_CLIENT_ID']:
+ raise crypt.AppIdentityError('Unrecognized client.')
+
+ if idinfo['iss'] not in ['accounts.google.com', 'https://accounts.google.com']:
+ raise crypt.AppIdentityError('Wrong issuer.')
+
+ return idinfo['sub']
+
+ def check_required_parameters(self, **kwargs):
+ """Raise an error if a parameter is missing or of the wrong type."""
+
+ parameter_checker.check(self, **kwargs)
+
+ def process(self):
+ """Process the Request and return a Response."""
+
+ method = getattr(self.module, self.method)
+
+ response = method(self)
+ response.id = self.id
+
+ return response
+
+ def to_JSON(self):
+ """Return a JSON representation of this Request"""
+
+ self.message['id'] = 0
+ self.message['token'] = None
+
+ return json.dumps(self.message)
+
+class Response(object):
+ """Response to websocket mapping"""
+
+ def __init__(self, status_code, status_description, content=None):
+ """Initialize a new Response."""
+
+ self.status = {
+ 'code': status_code,
+ 'description': status_description
+ }
+ self.content = content
+
+ def to_JSON(self):
+ """"Return a JSON representation of this Response"""
+
+ data = {
+ 'id': self.id,
+ 'status': self.status
+ }
+
+ if self.content is not None:
+ data['content'] = self.content
+
+ return json.dumps(data)