summaryrefslogtreecommitdiff
path: root/opendc-web/opendc-web-api/opendc/util
diff options
context:
space:
mode:
authorFabian Mastenbroek <mail.fabianm@gmail.com>2021-05-15 13:09:06 +0200
committerFabian Mastenbroek <mail.fabianm@gmail.com>2021-05-18 15:46:40 +0200
commit2281d3265423d01e60f8cc088de5a5730bb8a910 (patch)
tree8dc81338cfd30845717f1b9025176d26c82fe930 /opendc-web/opendc-web-api/opendc/util
parent05d2318538eba71ac0555dc5ec146499d9cb0592 (diff)
api: Migrate to Flask Restful
This change updates the API to use Flask Restful instead of our own in-house REST library. This change reduces the maintenance effort and allows us to drastically simplify the API implementation needed for the OpenDC v2 API.
Diffstat (limited to 'opendc-web/opendc-web-api/opendc/util')
-rw-r--r--opendc-web/opendc-web-api/opendc/util/__init__.py0
-rw-r--r--opendc-web/opendc-web-api/opendc/util/auth.py253
-rw-r--r--opendc-web/opendc-web-api/opendc/util/database.py77
-rw-r--r--opendc-web/opendc-web-api/opendc/util/exceptions.py64
-rw-r--r--opendc-web/opendc-web-api/opendc/util/json.py12
-rw-r--r--opendc-web/opendc-web-api/opendc/util/parameter_checker.py85
-rw-r--r--opendc-web/opendc-web-api/opendc/util/path_parser.py36
-rw-r--r--opendc-web/opendc-web-api/opendc/util/rest.py109
8 files changed, 0 insertions, 636 deletions
diff --git a/opendc-web/opendc-web-api/opendc/util/__init__.py b/opendc-web/opendc-web-api/opendc/util/__init__.py
deleted file mode 100644
index e69de29b..00000000
--- a/opendc-web/opendc-web-api/opendc/util/__init__.py
+++ /dev/null
diff --git a/opendc-web/opendc-web-api/opendc/util/auth.py b/opendc-web/opendc-web-api/opendc/util/auth.py
deleted file mode 100644
index 810b582a..00000000
--- a/opendc-web/opendc-web-api/opendc/util/auth.py
+++ /dev/null
@@ -1,253 +0,0 @@
-# Copyright (c) 2021 AtLarge Research
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in all
-# copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-# SOFTWARE.
-import json
-import time
-from functools import wraps
-
-import urllib3
-from flask import request, _request_ctx_stack
-from jose import jwt, JWTError
-from werkzeug.local import LocalProxy
-
-current_user = LocalProxy(lambda: getattr(_request_ctx_stack.top, 'current_user', None))
-
-
-class AuthError(Exception):
- """
- This error is thrown when the request failed to authorize.
- """
-
- def __init__(self, error, status_code):
- Exception.__init__(self, error)
- self.error = error
- self.status_code = status_code
-
-
-class AuthManager:
- """
- This class handles the authorization of requests.
- """
-
- def __init__(self, alg, issuer, audience):
- self._alg = alg
- self._issuer = issuer
- self._audience = audience
-
- def require(self, f):
- """Determines if the Access Token is valid
- """
-
- @wraps(f)
- def decorated(*args, **kwargs):
- token = _get_token()
- try:
- header = jwt.get_unverified_header(token)
- except JWTError as e:
- raise AuthError({"code": "invalid_token",
- "description": str(e)}, 401)
-
- alg = header.get('alg', None)
- if alg != self._alg.algorithm:
- raise AuthError({"code": "invalid_header",
- "description": f"Signature algorithm of {alg} is not supported. Expected the ID token "
- f"to be signed with {self._alg.algorithm}"}, 401)
-
- kid = header.get('kid', None)
- try:
- secret_or_certificate = self._alg.get_key(key_id=kid)
- except TokenValidationError as e:
- raise AuthError({"code": "invalid_header",
- "description": str(e)}, 401)
- try:
- payload = jwt.decode(token,
- key=secret_or_certificate,
- algorithms=[self._alg.algorithm],
- audience=self._audience,
- issuer=self._issuer)
- _request_ctx_stack.top.current_user = payload
- return f(*args, **kwargs)
- except jwt.ExpiredSignatureError:
- raise AuthError({"code": "token_expired",
- "description": "token is expired"}, 401)
- except jwt.JWTClaimsError:
- raise AuthError({"code": "invalid_claims",
- "description":
- "incorrect claims,"
- "please check the audience and issuer"}, 401)
- except Exception as e:
- print(e)
- raise AuthError({"code": "invalid_header",
- "description":
- "Unable to parse authentication"
- " token."}, 401)
-
- return decorated
-
-
-def _get_token():
- """
- Obtain the Access Token from the Authorization Header
- """
- auth = request.headers.get("Authorization", None)
- if not auth:
- raise AuthError({"code": "authorization_header_missing",
- "description":
- "Authorization header is expected"}, 401)
-
- parts = auth.split()
-
- if parts[0].lower() != "bearer":
- raise AuthError({"code": "invalid_header",
- "description":
- "Authorization header must start with"
- " Bearer"}, 401)
- if len(parts) == 1:
- raise AuthError({"code": "invalid_header",
- "description": "Token not found"}, 401)
- if len(parts) > 2:
- raise AuthError({"code": "invalid_header",
- "description":
- "Authorization header must be"
- " Bearer token"}, 401)
-
- token = parts[1]
- return token
-
-
-class SymmetricJwtAlgorithm:
- """Verifier for HMAC signatures, which rely on shared secrets.
- Args:
- shared_secret (str): The shared secret used to decode the token.
- algorithm (str, optional): The expected signing algorithm. Defaults to "HS256".
- """
-
- def __init__(self, shared_secret, algorithm="HS256"):
- self.algorithm = algorithm
- self._shared_secret = shared_secret
-
- # pylint: disable=W0613
- def get_key(self, key_id=None):
- """
- Obtain the key for this algorithm.
- :param key_id: The identifier of the key.
- :return: The JWK key.
- """
- return self._shared_secret
-
-
-class AsymmetricJwtAlgorithm:
- """Verifier for RSA signatures, which rely on public key certificates.
- Args:
- jwks_url (str): The url where the JWK set is located.
- algorithm (str, optional): The expected signing algorithm. Defaults to "RS256".
- """
-
- def __init__(self, jwks_url, algorithm="RS256"):
- self.algorithm = algorithm
- self._fetcher = JwksFetcher(jwks_url)
-
- def get_key(self, key_id=None):
- """
- Obtain the key for this algorithm.
- :param key_id: The identifier of the key.
- :return: The JWK key.
- """
- return self._fetcher.get_key(key_id)
-
-
-class TokenValidationError(Exception):
- """
- Error thrown when the token cannot be validated
- """
-
-
-class JwksFetcher:
- """Class that fetches and holds a JSON web key set.
- This class makes use of an in-memory cache. For it to work properly, define this instance once and re-use it.
- Args:
- jwks_url (str): The url where the JWK set is located.
- cache_ttl (str, optional): The lifetime of the JWK set cache in seconds. Defaults to 600 seconds.
- """
- CACHE_TTL = 600 # 10 min cache lifetime
-
- def __init__(self, jwks_url, cache_ttl=CACHE_TTL):
- self._jwks_url = jwks_url
- self._http = urllib3.PoolManager()
- self._cache_value = {}
- self._cache_date = 0
- self._cache_ttl = cache_ttl
- self._cache_is_fresh = False
-
- def _fetch_jwks(self, force=False):
- """Attempts to obtain the JWK set from the cache, as long as it's still valid.
- When not, it will perform a network request to the jwks_url to obtain a fresh result
- and update the cache value with it.
- Args:
- force (bool, optional): whether to ignore the cache and force a network request or not. Defaults to False.
- """
- has_expired = self._cache_date + self._cache_ttl < time.time()
-
- if not force and not has_expired:
- # Return from cache
- self._cache_is_fresh = False
- return self._cache_value
-
- # Invalidate cache and fetch fresh data
- self._cache_value = {}
- response = self._http.request('GET', self._jwks_url)
-
- if response.status == 200:
- # Update cache
- jwks = json.loads(response.data.decode('utf-8'))
- self._cache_value = self._parse_jwks(jwks)
- self._cache_is_fresh = True
- self._cache_date = time.time()
- return self._cache_value
-
- @staticmethod
- def _parse_jwks(jwks):
- """Converts a JWK string representation into a binary certificate in PEM format.
- """
- keys = {}
-
- for key in jwks['keys']:
- keys[key["kid"]] = key
- return keys
-
- def get_key(self, key_id):
- """Obtains the JWK associated with the given key id.
- Args:
- key_id (str): The id of the key to fetch.
- Returns:
- the JWK associated with the given key id.
-
- Raises:
- TokenValidationError: when a key with that id cannot be found
- """
- keys = self._fetch_jwks()
-
- if keys and key_id in keys:
- return keys[key_id]
-
- if not self._cache_is_fresh:
- keys = self._fetch_jwks(force=True)
- if keys and key_id in keys:
- return keys[key_id]
- raise TokenValidationError(f"RSA Public Key with ID {key_id} was not found.")
diff --git a/opendc-web/opendc-web-api/opendc/util/database.py b/opendc-web/opendc-web-api/opendc/util/database.py
deleted file mode 100644
index dd26533d..00000000
--- a/opendc-web/opendc-web-api/opendc/util/database.py
+++ /dev/null
@@ -1,77 +0,0 @@
-import urllib.parse
-from datetime import datetime
-
-from pymongo import MongoClient
-
-DATETIME_STRING_FORMAT = '%Y-%m-%dT%H:%M:%S'
-CONNECTION_POOL = None
-
-
-class Database:
- """Object holding functionality for database access."""
- def __init__(self):
- self.opendc_db = None
-
- def initialize_database(self, user, password, database, host):
- """Initializes the database connection."""
-
- user = urllib.parse.quote_plus(user)
- password = urllib.parse.quote_plus(password)
- database = urllib.parse.quote_plus(database)
- host = urllib.parse.quote_plus(host)
-
- client = MongoClient('mongodb://%s:%s@%s/default_db?authSource=%s' % (user, password, host, database))
- self.opendc_db = client.opendc
-
- def fetch_one(self, query, collection):
- """Uses existing mongo connection to return a single (the first) document in a collection matching the given
- query as a JSON object.
-
- The query needs to be in json format, i.e.: `{'name': prefab_name}`.
- """
- return getattr(self.opendc_db, collection).find_one(query)
-
- def fetch_all(self, query, collection):
- """Uses existing mongo connection to return all documents matching a given query, as a list of JSON objects.
-
- The query needs to be in json format, i.e.: `{'name': prefab_name}`.
- """
- cursor = getattr(self.opendc_db, collection).find(query)
- return list(cursor)
-
- def insert(self, obj, collection):
- """Updates an existing object."""
- bson = getattr(self.opendc_db, collection).insert(obj)
-
- return bson
-
- def update(self, _id, obj, collection):
- """Updates an existing object."""
- return getattr(self.opendc_db, collection).update({'_id': _id}, obj)
-
- def delete_one(self, query, collection):
- """Deletes one object matching the given query.
-
- The query needs to be in json format, i.e.: `{'name': prefab_name}`.
- """
- getattr(self.opendc_db, collection).delete_one(query)
-
- def delete_all(self, query, collection):
- """Deletes all objects matching the given query.
-
- The query needs to be in json format, i.e.: `{'name': prefab_name}`.
- """
- getattr(self.opendc_db, collection).delete_many(query)
-
- @staticmethod
- def datetime_to_string(datetime_to_convert):
- """Return a database-compatible string representation of the given datetime object."""
- return datetime_to_convert.strftime(DATETIME_STRING_FORMAT)
-
- @staticmethod
- def string_to_datetime(string_to_convert):
- """Return a datetime corresponding to the given string representation."""
- return datetime.strptime(string_to_convert, DATETIME_STRING_FORMAT)
-
-
-DB = Database()
diff --git a/opendc-web/opendc-web-api/opendc/util/exceptions.py b/opendc-web/opendc-web-api/opendc/util/exceptions.py
deleted file mode 100644
index 7724a407..00000000
--- a/opendc-web/opendc-web-api/opendc/util/exceptions.py
+++ /dev/null
@@ -1,64 +0,0 @@
-class RequestInitializationError(Exception):
- """Raised when a Request cannot successfully be initialized"""
-
-
-class UnimplementedEndpointError(RequestInitializationError):
- """Raised when a Request path does not point to a module."""
-
-
-class MissingRequestParameterError(RequestInitializationError):
- """Raised when a Request does not contain one or more required parameters."""
-
-
-class UnsupportedMethodError(RequestInitializationError):
- """Raised when a Request does not use a supported REST method.
-
- The method must be in all-caps, supported by REST, and implemented by the module.
- """
-
-
-class AuthorizationTokenError(RequestInitializationError):
- """Raised when an authorization token is not correctly verified."""
-
-
-class ForeignKeyError(Exception):
- """Raised when a foreign key constraint is not met."""
-
-
-class RowNotFoundError(Exception):
- """Raised when a database row is not found."""
- def __init__(self, table_name):
- super(RowNotFoundError, self).__init__('Row in `{}` table not found.'.format(table_name))
-
- self.table_name = table_name
-
-
-class ParameterError(Exception):
- """Raised when a parameter is either missing or incorrectly typed."""
-
-
-class IncorrectParameterError(ParameterError):
- """Raised when a parameter is of the wrong type."""
- def __init__(self, parameter_name, parameter_location):
- super(IncorrectParameterError,
- self).__init__('Incorrectly typed `{}` {} parameter.'.format(parameter_name, parameter_location))
-
- self.parameter_name = parameter_name
- self.parameter_location = parameter_location
-
-
-class MissingParameterError(ParameterError):
- """Raised when a parameter is missing."""
- def __init__(self, parameter_name, parameter_location):
- super(MissingParameterError,
- self).__init__('Missing required `{}` {} parameter.'.format(parameter_name, parameter_location))
-
- self.parameter_name = parameter_name
- self.parameter_location = parameter_location
-
-
-class ClientError(Exception):
- """Raised when a 4xx response is to be returned."""
- def __init__(self, response):
- super(ClientError, self).__init__(str(response))
- self.response = response
diff --git a/opendc-web/opendc-web-api/opendc/util/json.py b/opendc-web/opendc-web-api/opendc/util/json.py
deleted file mode 100644
index 2ef4f965..00000000
--- a/opendc-web/opendc-web-api/opendc/util/json.py
+++ /dev/null
@@ -1,12 +0,0 @@
-import flask
-from bson.objectid import ObjectId
-
-
-class JSONEncoder(flask.json.JSONEncoder):
- """
- A customized JSON encoder to handle unsupported types.
- """
- def default(self, o):
- if isinstance(o, ObjectId):
- return str(o)
- return flask.json.JSONEncoder.default(self, o)
diff --git a/opendc-web/opendc-web-api/opendc/util/parameter_checker.py b/opendc-web/opendc-web-api/opendc/util/parameter_checker.py
deleted file mode 100644
index 14dd1dc0..00000000
--- a/opendc-web/opendc-web-api/opendc/util/parameter_checker.py
+++ /dev/null
@@ -1,85 +0,0 @@
-from opendc.util import exceptions
-from opendc.util.database import Database
-
-
-def _missing_parameter(params_required, params_actual, parent=''):
- """Recursively search for the first missing parameter."""
-
- for param_name in params_required:
-
- if param_name not in params_actual:
- return '{}.{}'.format(parent, param_name)
-
- param_required = params_required.get(param_name)
- param_actual = params_actual.get(param_name)
-
- if isinstance(param_required, dict):
-
- param_missing = _missing_parameter(param_required, param_actual, param_name)
-
- if param_missing is not None:
- return '{}.{}'.format(parent, param_missing)
-
- return None
-
-
-def _incorrect_parameter(params_required, params_actual, parent=''):
- """Recursively make sure each parameter is of the correct type."""
-
- for param_name in params_required:
-
- param_required = params_required.get(param_name)
- param_actual = params_actual.get(param_name)
-
- if isinstance(param_required, dict):
-
- param_incorrect = _incorrect_parameter(param_required, param_actual, param_name)
-
- if param_incorrect is not None:
- return '{}.{}'.format(parent, param_incorrect)
-
- else:
-
- if param_required == 'datetime':
- try:
- Database.string_to_datetime(param_actual)
- except:
- return '{}.{}'.format(parent, param_name)
-
- type_pairs = [
- ('int', (int,)),
- ('float', (float, int)),
- ('bool', (bool,)),
- ('string', (str, int)),
- ('list', (list,)),
- ]
-
- for str_type, actual_types in type_pairs:
- if param_required == str_type and all(not isinstance(param_actual, t)
- for t in actual_types):
- return '{}.{}'.format(parent, param_name)
-
- return None
-
-
-def _format_parameter(parameter):
- """Format the output of a parameter check."""
-
- parts = parameter.split('.')
- inner = ['["{}"]'.format(x) for x in parts[2:]]
- return parts[1] + ''.join(inner)
-
-
-def check(request, **kwargs):
- """Check if all required parameters are there."""
-
- for location, params_required in kwargs.items():
- params_actual = getattr(request, 'params_{}'.format(location))
-
- missing_parameter = _missing_parameter(params_required, params_actual)
- if missing_parameter is not None:
- raise exceptions.MissingParameterError(_format_parameter(missing_parameter), location)
-
- incorrect_parameter = _incorrect_parameter(params_required, params_actual)
- if incorrect_parameter is not None:
- raise exceptions.IncorrectParameterError(_format_parameter(incorrect_parameter), location)
diff --git a/opendc-web/opendc-web-api/opendc/util/path_parser.py b/opendc-web/opendc-web-api/opendc/util/path_parser.py
deleted file mode 100644
index c8452f20..00000000
--- a/opendc-web/opendc-web-api/opendc/util/path_parser.py
+++ /dev/null
@@ -1,36 +0,0 @@
-import json
-import os
-
-
-def parse(version, endpoint_path):
- """Map an HTTP endpoint path to an API path"""
-
- # Get possible paths
- with open(os.path.join(os.path.dirname(__file__), '..', 'api', '{}', 'paths.json').format(version)) as paths_file:
- paths = json.load(paths_file)
-
- # Find API path that matches endpoint_path
- endpoint_path_parts = endpoint_path.strip('/').split('/')
- paths_parts = [x.strip('/').split('/') for x in paths if len(x.strip('/').split('/')) == len(endpoint_path_parts)]
- path = None
-
- for path_parts in paths_parts:
- found = True
- for (endpoint_part, part) in zip(endpoint_path_parts, path_parts):
- if not part.startswith('{') and endpoint_part != part:
- found = False
- break
- if found:
- path = path_parts
-
- if path is None:
- return None
-
- # Extract path parameters
- parameters = {}
-
- for (name, value) in zip(path, endpoint_path_parts):
- if name.startswith('{'):
- parameters[name.strip('{}')] = value
-
- return '{}/{}'.format(version, '/'.join(path)), parameters
diff --git a/opendc-web/opendc-web-api/opendc/util/rest.py b/opendc-web/opendc-web-api/opendc/util/rest.py
deleted file mode 100644
index 63d063b3..00000000
--- a/opendc-web/opendc-web-api/opendc/util/rest.py
+++ /dev/null
@@ -1,109 +0,0 @@
-import importlib
-import json
-
-from opendc.util import exceptions, parameter_checker
-from opendc.util.exceptions import ClientError
-from opendc.util.auth import current_user
-
-
-class Request:
- """WebSocket message to REST request mapping."""
- def __init__(self, message=None):
- """"Initialize a Request from a socket message."""
-
- # Get the Request parameters from the message
-
- if message is None:
- return
-
- try:
- self.message = message
-
- self.id = message['id']
-
- self.path = message['path']
- self.method = message['method']
-
- self.params_body = message['parameters']['body']
- self.params_path = message['parameters']['path']
- self.params_query = message['parameters']['query']
-
- self.token = message['token']
-
- except KeyError as exception:
- raise exceptions.MissingRequestParameterError(exception)
-
- # Parse the path and import the appropriate module
-
- try:
- self.path = message['path'].strip('/')
-
- module_base = 'opendc.api.{}.endpoint'
- module_path = self.path.replace('{', '').replace('}', '').replace('/', '.')
-
- self.module = importlib.import_module(module_base.format(module_path))
- except ImportError as e:
- print(e)
- raise exceptions.UnimplementedEndpointError('Unimplemented endpoint: {}.'.format(self.path))
-
- # Check the method
-
- if self.method not in ['POST', 'GET', 'PUT', 'PATCH', 'DELETE']:
- raise exceptions.UnsupportedMethodError('Non-rest method: {}'.format(self.method))
-
- if not hasattr(self.module, self.method):
- raise exceptions.UnsupportedMethodError('Unimplemented method at endpoint {}: {}'.format(
- self.path, self.method))
-
- self.current_user = current_user
-
- def check_required_parameters(self, **kwargs):
- """Raise an error if a parameter is missing or of the wrong type."""
-
- try:
- parameter_checker.check(self, **kwargs)
- except exceptions.ParameterError as e:
- raise ClientError(Response(400, str(e)))
-
- def process(self):
- """Process the Request and return a Response."""
-
- method = getattr(self.module, self.method)
-
- try:
- response = method(self)
- except ClientError as e:
- e.response.id = self.id
- return e.response
-
- response.id = self.id
-
- return response
-
- def to_JSON(self):
- """Return a JSON representation of this Request"""
-
- self.message['id'] = 0
- self.message['token'] = None
-
- return json.dumps(self.message)
-
-
-class Response:
- """Response to websocket mapping"""
- def __init__(self, status_code, status_description, content=None):
- """Initialize a new Response."""
-
- self.id = 0
- self.status = {'code': status_code, 'description': status_description}
- self.content = content
-
- def to_JSON(self):
- """"Return a JSON representation of this Response"""
-
- data = {'id': self.id, 'status': self.status}
-
- if self.content is not None:
- data['content'] = self.content
-
- return json.dumps(data, default=str)