diff --git a/.dockerignore b/.dockerignore index 696735b..edc6e56 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1,2 +1,7 @@ # Ignore generated files -**/*.pyc \ No newline at end of file +**/*.pyc +config.py +**/__pycache__ +orcidflask/saml +.DS_Store +.git diff --git a/.gitignore b/.gitignore index 40d224f..0da24da 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,3 @@ -config.py +./config.py __pycache__ orcidflask/__pycache__ diff --git a/Dockerfile b/Dockerfile index 9903efb..7a53c5c 100644 --- a/Dockerfile +++ b/Dockerfile @@ -9,11 +9,14 @@ COPY *.py ./ COPY requirements.txt . COPY migrations ./migrations COPY orcidflask/*.py ./orcidflask/ -COPY orcidflask/templates ./orcidflask/templates/ +COPY orcidflask/registration ./orcidflask/registration +COPY orcidflask/api ./orcidflask/api +COPY orcidflask/db ./orcidflask/db + RUN pip install -r requirements.txt ENV FLASK_APP=orcidflask ENV ORCIDFLASK_SETTINGS=/opt/orcid_integration/config.py -CMD [ "gunicorn", "-b", "0.0.0.0:8080", "orcidflask:app" ] \ No newline at end of file +CMD [ "gunicorn", "-b", "0.0.0.0:8080", "orcidflask:create_app()" ] \ No newline at end of file diff --git a/example.docker-compose.yml b/example.docker-compose.yml index 617d286..d41c8f2 100644 --- a/example.docker-compose.yml +++ b/example.docker-compose.yml @@ -1,17 +1,22 @@ version: "2" services: db: - image: postgres + image: postgres:16.2 environment: - POSTGRES_USER=${POSTGRES_USER} - POSTGRES_PASSWORD=${POSTGRES_PASSWORD} - - POSTGRES_DB=${POSTGRES_DB} + - POSTGRES_DB=${POSTGRES_DB} + user: "1000:1000" volumes: - ./data:/var/lib/postgresql/data + restart: always flask-app: + # Use the tagged image in production + #image: ghcr.io/gwu-libraries/orcid-integration-flask-app:1.1 build: context: . dockerfile: Dockerfile + image: flask-app ports: - 8080:8080 links: @@ -28,11 +33,35 @@ services: volumes: - ./orcidflask/saml:/opt/orcid_integration/orcidflask/saml - ./config.py:/opt/orcid_integration/config.py - - ./orcidflask/db:/opt/orcid_integration/orcidflask/db + - ./certs/db-encrypt.key:/opt/orcid_integration/db-encrypt.key + # Uncomment for use in development #- .:/opt/orcid_integration restart: always + token-api: + # Use tagged image in production + #image: ghcr.io/gwu-libraries/orcid-integration-flask-app:1.1 + image: flask-app + ports: + - 8081:8081 + links: + - db:db + environment: + - POSTGRES_USER=${POSTGRES_USER} + - POSTGRES_PASSWORD=${POSTGRES_PASSWORD} + - POSTGRES_DB=${POSTGRES_DB} + - POSTGRES_DB_HOST=${POSTGRES_DB_HOST} + - POSTGRES_PORT=${POSTGRES_PORT} + - DB_ENCRYPTION_FILE=${DB_ENCRYPTION_FILE} + volumes: + - ./orcidflask/saml:/opt/orcid_integration/orcidflask/saml + - ./config.py:/opt/orcid_integration/config.py + - ./certs/db-encrypt.key:/opt/orcid_integration/db-encrypt.key + # Uncomment for development + #- .:/opt/orcid_integration + command: gunicorn -b 0.0.0.0:8081 "orcidflask:create_app('api')" + restart: always nginx-proxy: - image: nginxproxy/nginx-proxy:1.5 + image: nginxproxy/nginx-proxy:1.6.2 environment: - LOG_JSON=true ports: @@ -42,5 +71,5 @@ services: - /var/run/docker.sock:/tmp/docker.sock:ro # Note that the nginxproxy image require cert & key to reside in the same directory # And to follow certain naming conventions - - /etc/ssl/certs:/etc/nginx/certs + - ./certs:/etc/nginx/certs restart: always diff --git a/example.env b/example.env index 01d7715..b0de346 100644 --- a/example.env +++ b/example.env @@ -3,7 +3,7 @@ POSTGRES_PASSWORD=orcidpass POSTGRES_DB=orcidig POSTGRES_DB_HOST=db POSTGRES_PORT=5432 -DB_ENCRYPTION_FILE=/opt/orcid_integration/orcidflask/db/db-encrypt.key +DB_ENCRYPTION_FILE=/opt/orcid_integration/db-encrypt.key # Values are sandbox or prod ORCID_SERVER=sandbox VIRTUAL_HOST= \ No newline at end of file diff --git a/migrations/versions/aa4b644dbbe2_adding_api.py b/migrations/versions/aa4b644dbbe2_adding_api.py new file mode 100644 index 0000000..9cb7fd6 --- /dev/null +++ b/migrations/versions/aa4b644dbbe2_adding_api.py @@ -0,0 +1,35 @@ +"""Adding API + +Revision ID: aa4b644dbbe2 +Revises: ac9a61050c66 +Create Date: 2025-04-30 13:19:30.743197 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'aa4b644dbbe2' +down_revision = 'ac9a61050c66' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('api_key', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('api_key', sa.String(length=36), nullable=False), + sa.Column('timestamp', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True), + sa.Column('userId', sa.String(length=80), nullable=False), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('api_key') + ) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('api_key') + # ### end Alembic commands ### diff --git a/migrations/versions/ac9a61050c66_initial_migration.py b/migrations/versions/ac9a61050c66_initial_migration.py index 113afe1..8424062 100644 --- a/migrations/versions/ac9a61050c66_initial_migration.py +++ b/migrations/versions/ac9a61050c66_initial_migration.py @@ -7,7 +7,7 @@ """ from alembic import op import sqlalchemy as sa -from orcidflask.models import EncryptedValue +from orcidflask.db.models import EncryptedValue # revision identifiers, used by Alembic. diff --git a/orcid_utils.py b/orcid_utils.py index 61f7072..5137b0e 100644 --- a/orcid_utils.py +++ b/orcid_utils.py @@ -11,7 +11,7 @@ def prepare_token_payload(code: str): 'client_secret': app.config['CLIENT_SECRET'], 'grant_type': 'authorization_code', 'code': code, - 'redirect_uri': url_for('orcid_redirect', _external=True, _scheme='https')} + 'redirect_uri': url_for('registration.orcid_redirect', _external=True, _scheme='https')} def extract_saml_user_data(session, populate=True): ''' diff --git a/orcidflask/__init__.py b/orcidflask/__init__.py index 089cb01..c99ad8b 100644 --- a/orcidflask/__init__.py +++ b/orcidflask/__init__.py @@ -1,60 +1,84 @@ from flask import Flask -from flask_sqlalchemy import SQLAlchemy +from flask.cli import with_appcontext from flask_migrate import Migrate import os import click from orcid_utils import load_encryption_key, new_encryption_key +from orcidflask.registration.views import registration +from orcidflask.api.views import api import json +from orcidflask.db import db +from datetime import datetime as dt -app = Flask(__name__) -# load default configs from default_settings.py -app.config.from_object('orcidflask.default_settings') -# load sensitive config settings -app.config.from_envvar('ORCIDFLASK_SETTINGS') -# Set the ORCID URL based on the setting in default_settings.py -if os.getenv('ORCID_SERVER') == 'sandbox': - base_url = 'https://sandbox.orcid.org' -else: - base_url = 'https://orcid.org' -# Personal attributes from SAML metadata definitions -app.config['orcid_auth_url'] = base_url + '/oauth/authorize?client_id={orcid_client_id}&response_type=code&scope={scopes}&redirect_uri={redirect_uri}&family_names={lastname}&given_names={firstname}&email={emailaddress}' -app.config['orcid_register_url'] = base_url + '/oauth/authorize?client_id={orcid_client_id}&response_type=code&scope={scopes}&redirect_uri={redirect_uri}&family_names={lastname}&given_names={firstname}&email={emailaddress}&show_login=false' -app.config['orcid_token_url'] = base_url + '/oauth/token' -app.config['SAML_PATH'] = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'saml') -app.config["SESSION_COOKIE_DOMAIN"] = app.config["SERVER_NAME"] -app.secret_key = app.config['SECRET_KEY'] -postgres_user = os.getenv('POSTGRES_USER') -postgres_pwd = os.getenv('POSTGRES_PASSWORD') -postgres_db_host = os.getenv('POSTGRES_DB_HOST') -postgres_port = os.getenv('POSTGRES_PORT') -postgres_db = os.getenv('POSTGRES_DB') -app.config['SQLALCHEMY_DATABASE_URI'] = f'postgresql://{postgres_user}:{postgres_pwd}@{postgres_db_host}:{postgres_port}/{postgres_db}' -app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False -db_key_file = os.getenv('DB_ENCRYPTION_FILE') -app.config['db_encryption_key'] = load_encryption_key(db_key_file) -db = SQLAlchemy(app) -migrate = Migrate(app, db) +migrate = Migrate() -import orcidflask.views -from orcidflask.models import Token +def create_app(blueprint: str='registration'): + ''' + Application factory. The argument should be either "registration" or "api", depending on the version of the application to be started. For invocation, see https://flask.palletsprojects.com/en/stable/cli/ + ''' + app = Flask(__name__) + # load default configs from default_settings.py + app.config.from_object('orcidflask.default_settings') + # load sensitive config settings + app.config.from_envvar('ORCIDFLASK_SETTINGS') + # Set the ORCID URL based on the setting in default_settings.py + if os.getenv('ORCID_SERVER') == 'sandbox': + base_url = 'https://sandbox.orcid.org' + else: + base_url = 'https://orcid.org' + # Personal attributes from SAML metadata definitions + app.config['orcid_auth_url'] = base_url + '/oauth/authorize?client_id={orcid_client_id}&response_type=code&scope={scopes}&redirect_uri={redirect_uri}&family_names={lastname}&given_names={firstname}&email={emailaddress}' + app.config['orcid_register_url'] = base_url + '/oauth/authorize?client_id={orcid_client_id}&response_type=code&scope={scopes}&redirect_uri={redirect_uri}&family_names={lastname}&given_names={firstname}&email={emailaddress}&show_login=false' + app.config['orcid_token_url'] = base_url + '/oauth/token' + app.config['SAML_PATH'] = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'saml') + app.config["SESSION_COOKIE_DOMAIN"] = app.config["SERVER_NAME"] + app.secret_key = app.config['SECRET_KEY'] + postgres_user = os.getenv('POSTGRES_USER') + postgres_pwd = os.getenv('POSTGRES_PASSWORD') + postgres_db_host = os.getenv('POSTGRES_DB_HOST') + postgres_port = os.getenv('POSTGRES_PORT') + postgres_db = os.getenv('POSTGRES_DB') + app.config['SQLALCHEMY_DATABASE_URI'] = f'postgresql://{postgres_user}:{postgres_pwd}@{postgres_db_host}:{postgres_port}/{postgres_db}' + app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False + db_key_file = os.getenv('DB_ENCRYPTION_FILE') + if not os.getenv('TESTING'): + app.config['db_encryption_key'] = load_encryption_key(db_key_file) + db.init_app(app) + migrate.init_app(app, db) + + if blueprint == 'registration': + app.register_blueprint(registration) + else: + app.register_blueprint(api) -@app.cli.command('create-secret-key') + app.cli.add_command(reset_db) + app.cli.add_command(create_secret_key) + app.cli.add_command(serialize_db) + app.cli.add_command(create_api_key) + return app + +from orcidflask.db.models import Token, APIKey, generate_key + +@click.command('create-secret-key') @click.argument('file') +@with_appcontext def create_secret_key(file): ''' Creates a new database encryption key and saves to the provided file path. Will not overwrite the existing file, if it exists. ''' - store_encryption_key(file) + new_encryption_key(file) -@app.cli.command('reset-db') +@click.command('reset-db') +@with_appcontext def reset_db(): ''' Resets the associated database by dropping all tables. Warning: for development purposes only. Do not run on a production instance without first backing up the database, as this command will result in the loss of all data. ''' db.drop_all() -@app.cli.command('serialize-db') +@click.command('serialize-db') @click.argument('file', type=click.File('w')) +@with_appcontext def serialize_db(file): ''' Serializes the database as a JSON dump. Argument should be the path to a file, preferably in a volume mapped to the container, such as /opt/orcid_integration/data @@ -64,4 +88,15 @@ def serialize_db(file): # convert to Python dicts records = [record.to_dict() for record in records] json.dump(records, file) - \ No newline at end of file + + +@click.command('create-api-key') +@click.argument('userid') +@with_appcontext +def create_api_key(userid: str): + '''userId should be an email address identifying the user for whom the key is being created.''' + api_key_str = generate_key() + api_key = APIKey(userId=userid, timestamp=dt.now(), api_key=api_key_str) + db.session.add(api_key) + db.session.commit() + print(f'API key created for user {userid} is {api_key_str}. Please pass this key as a request header when making an API call: Authorization: Apikey YOUR_API_KEY') diff --git a/orcidflask/api/__init__.py b/orcidflask/api/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/orcidflask/api/views.py b/orcidflask/api/views.py new file mode 100644 index 0000000..e1a1a85 --- /dev/null +++ b/orcidflask/api/views.py @@ -0,0 +1,32 @@ +from flask import request, Blueprint, jsonify +from orcidflask.db.models import Token, APIKey +import re + +auth_pattern = re.compile(r'Apikey ([A-Za-z0-9\-]{36})') + +api = Blueprint('api', __name__, url_prefix='/api') + +def is_valid(api_key): + return APIKey.check_api_key(api_key) + +@api.route('/get-token') +def get_token(): + '''GET request should include Authorization: Apikey header (with a valid API key) and an orcid URL parameter with the ORCiD of the user whose token is to be retrieved.''' + api_key = auth_pattern.match(request.headers.get('Authorization', '')) + if not api_key: + return {'message': 'Please provide a valid API key in the Authorization header of your request.'}, 403 + api_key = api_key.group(1) + if not is_valid(api_key): + return {'message': 'API key has not been registered. Please have the application administrator create an API key for you.'}, 403 + orcid = request.args.get('orcid') + if not orcid: + return {'message': 'Please provide a valid ORCiD as a URL parameter, e.g., "?orcid=0000-0000-0000-0000"'}, 422 + access_token = Token.query.filter_by(orcid=orcid).order_by(Token.timestamp.desc()).first() + if not access_token: + return jsonify({'error': f'Entry not found for ORCiD {orcid}. Has the user registered with the GW ORCiD integration app?'}) + return jsonify({'orcid': orcid, + 'access_token': access_token.to_dict()['access_token']}) + + + + diff --git a/orcidflask/db/.gitignore b/orcidflask/db/.gitignore deleted file mode 100644 index 5e7d273..0000000 --- a/orcidflask/db/.gitignore +++ /dev/null @@ -1,4 +0,0 @@ -# Ignore everything in this directory -* -# Except this file -!.gitignore diff --git a/orcidflask/db/__init__.py b/orcidflask/db/__init__.py new file mode 100644 index 0000000..04a1a91 --- /dev/null +++ b/orcidflask/db/__init__.py @@ -0,0 +1,4 @@ +from flask import Flask +from flask_sqlalchemy import SQLAlchemy + +db = SQLAlchemy() \ No newline at end of file diff --git a/orcidflask/models.py b/orcidflask/db/models.py similarity index 66% rename from orcidflask/models.py rename to orcidflask/db/models.py index eeabb00..6a45fff 100644 --- a/orcidflask/models.py +++ b/orcidflask/db/models.py @@ -1,13 +1,15 @@ -from orcidflask import db, app +from . import db from sqlalchemy.sql import func from sqlalchemy import TypeDecorator from cryptography.fernet import Fernet +from flask import current_app +from uuid import uuid1 def fernet_encrypt(data): ''' Encrypts data using the Fernet algorithm with the key set in the app's config object ''' - fernet = Fernet(app.config['db_encryption_key']) + fernet = Fernet(current_app.config['db_encryption_key']) return fernet.encrypt(data.encode()) @@ -15,7 +17,7 @@ def fernet_decrypt(data): ''' Decrypts data using the Fernet algorithm with the key set in the app's config object ''' - fernet = Fernet(app.config['db_encryption_key']) + fernet = Fernet(current_app.config['db_encryption_key']) return fernet.decrypt(data).decode() class EncryptedValue(TypeDecorator): @@ -26,6 +28,13 @@ def process_bind_param(self, value, dialect): def process_result_value(self, value, dialect): return fernet_decrypt(value) + + +def generate_key(): + ''' + Generates a unique indentifier for use as an API key (using the system time) + ''' + return str(uuid1()) class Token(db.Model): @@ -51,3 +60,15 @@ def to_dict(self): # Convert timestamp to string record['timestamp'] = record['timestamp'].isoformat() return record + +class APIKey(db.Model): + id = db.Column(db.Integer, primary_key=True) + api_key = db.Column(db.String(36), unique=True, nullable=False) + timestamp = db.Column(db.DateTime(timezone=True), server_default=func.now()) + # Email address of the user for whom the API key was created + userId = db.Column(db.String(80), unique=False, nullable=False) + + @classmethod + def check_api_key(cls, api_key: str): + '''Checks whether the given key exists in the database.''' + return cls.query.filter_by(api_key=api_key).first() \ No newline at end of file diff --git a/orcidflask/registration/.DS_Store b/orcidflask/registration/.DS_Store new file mode 100644 index 0000000..9f9809a Binary files /dev/null and b/orcidflask/registration/.DS_Store differ diff --git a/orcidflask/registration/__init__.py b/orcidflask/registration/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/orcidflask/static/css/main.css b/orcidflask/registration/static/css/main.css similarity index 100% rename from orcidflask/static/css/main.css rename to orcidflask/registration/static/css/main.css diff --git a/orcidflask/templates/base.html b/orcidflask/registration/templates/base.html similarity index 72% rename from orcidflask/templates/base.html rename to orcidflask/registration/templates/base.html index 5eca9b9..6557e85 100644 --- a/orcidflask/templates/base.html +++ b/orcidflask/registration/templates/base.html @@ -4,7 +4,7 @@ {% block title %} {% endblock %} - Orcid Integration - +
diff --git a/orcidflask/templates/index.html b/orcidflask/registration/templates/index.html similarity index 100% rename from orcidflask/templates/index.html rename to orcidflask/registration/templates/index.html diff --git a/orcidflask/templates/oauth_error.html b/orcidflask/registration/templates/oauth_error.html similarity index 100% rename from orcidflask/templates/oauth_error.html rename to orcidflask/registration/templates/oauth_error.html diff --git a/orcidflask/templates/orcid-success.html b/orcidflask/registration/templates/orcid-success.html similarity index 100% rename from orcidflask/templates/orcid-success.html rename to orcidflask/registration/templates/orcid-success.html diff --git a/orcidflask/templates/orcid_denied.html b/orcidflask/registration/templates/orcid_denied.html similarity index 100% rename from orcidflask/templates/orcid_denied.html rename to orcidflask/registration/templates/orcid_denied.html diff --git a/orcidflask/templates/orcid_login.html b/orcidflask/registration/templates/orcid_login.html similarity index 100% rename from orcidflask/templates/orcid_login.html rename to orcidflask/registration/templates/orcid_login.html diff --git a/orcidflask/templates/orcid_success.html b/orcidflask/registration/templates/orcid_success.html similarity index 100% rename from orcidflask/templates/orcid_success.html rename to orcidflask/registration/templates/orcid_success.html diff --git a/orcidflask/views.py b/orcidflask/registration/views.py similarity index 73% rename from orcidflask/views.py rename to orcidflask/registration/views.py index 91cafb6..7cd5207 100644 --- a/orcidflask/views.py +++ b/orcidflask/registration/views.py @@ -1,13 +1,17 @@ -from flask import request, url_for, redirect, session, render_template, make_response -from orcidflask import app, db -from orcidflask.models import Token +from flask import request, url_for, redirect, session, render_template, make_response, Blueprint +from orcidflask.db import db +from flask import current_app +from orcidflask.db.models import Token from saml_utils import * from orcid_utils import * from onelogin.saml2.utils import OneLogin_Saml2_Utils import requests from requests.exceptions import HTTPError -@app.route('/', methods=['GET', 'POST']) +registration = Blueprint('registration', __name__, template_folder='templates', static_folder='static', url_prefix='/') + +@registration.route('/', methods=['GET', 'POST']) +#@app.route('/', methods=['GET', 'POST']) def index(): ''' Route handles the SSO process @@ -24,7 +28,9 @@ def index(): # Initiating the SSO process if 'sso' in request.args: # Redirect to ORCID login upon successful SSO - return redirect(auth.login(return_to=url_for('orcid_login', scopes='/read-limited /activities/update', register=register, _external=True, _scheme='https'))) + #return redirect(auth.login(return_to=url_for('orcid_login', scopes='/read-limited /activities/update', register=register, _external=True, _scheme='https'))) + + return redirect(auth.login(return_to=url_for('registration.orcid_login', scopes='/read-limited /activities/update', register=register, _external=True, _scheme='https'))) # Initiating the SLO process elif 'slo' in request.args: metadata = get_metadata_from_session(session) @@ -52,6 +58,7 @@ def index(): # Get the reason for auth failure if exists elif auth.get_settings().is_debug_active(): error_reason = auth.get_last_error_reason() + current_app.logger.error(error_reason) # Handle logout elif 'sls' in request.args: @@ -72,19 +79,22 @@ def index(): # Redirect for login if no params provided else: # Remove the scopes param in order to solicit scopes from users - return redirect(auth.login(return_to=url_for('orcid_login', scopes='/read-limited /activities/update', register=register, _external=True, _scheme='https'))) + return redirect(auth.login(return_to=url_for('registration.orcid_login', scopes='/read-limited /activities/update', register=register, _external=True, _scheme='https'))) + #return redirect(auth.login(return_to=url_for('orcid_login', scopes='/read-limited /activities/update', register=register, _external=True, _scheme='https'))) # Redirect from logout process - return redirect(app.config['SLO_REDIRECT']) + return redirect(current_app.config['SLO_REDIRECT']) -@app.route('/attrs/') +@registration.route('/attrs/') +#@app.route('/attrs/') def attrs(): attributes, paint_logout = get_attributes(session) return render_template('attrs.html', paint_logout=paint_logout, attributes=attributes) -@app.route('/metadata/') +@registration.route('/metadata/') +#@app.route('/metadata/') def metadata(): auth, auth_req = init_saml_auth(request) settings = auth.get_settings() @@ -99,7 +109,8 @@ def metadata(): return resp -@app.route('/orcid', methods=('GET', 'POST')) +@registration.route('/orcid', methods=['GET', 'POST']) +#@app.route('/orcid', methods=('GET', 'POST')) def orcid_login(): ''' Should render homepage and if behind SSO, retrieve netID from SAML and store in a session variable. @@ -109,21 +120,23 @@ def orcid_login(): register = request.args.get('register') # If no SAML attributes, redirect for SSO if not session.get('samlNameId'): - return redirect(url_for('index', _external=True, _scheme='https')) + #return redirect(url_for('registration.index', _external=True, _scheme='https')) + return redirect(url_for('registration.index', _external=True, _scheme='https')) # If the scopes param is part of the request, we're not using the form elif scopes or request.method == 'POST': # Get the scopes from the form is not part of the URL if not scopes: scopes = ' '.join(request.form.keys()) # Get user data from SAML for registration form - saml_user_data = extract_saml_user_data(session, populate=app.config['PREFILL_REGISTRATION']) + saml_user_data = extract_saml_user_data(session, populate=current_app.config['PREFILL_REGISTRATION']) if register == 'True': - orcid_auth_url = app.config['orcid_register_url'] + orcid_auth_url = current_app.config['orcid_register_url'] else: - orcid_auth_url = app.config['orcid_auth_url'] - return redirect(orcid_auth_url.format(orcid_client_id=app.config['CLIENT_ID'], + orcid_auth_url = current_app.config['orcid_auth_url'] + return redirect(orcid_auth_url.format(orcid_client_id=current_app.config['CLIENT_ID'], scopes=scopes, - redirect_uri=url_for('orcid_redirect', + redirect_uri=url_for('registration.orcid_redirect', + #redirect_uri=url_for('orcid_redirect', _scheme='https', _external=True), **saml_user_data)) @@ -131,24 +144,24 @@ def orcid_login(): else: return render_template('orcid_login.html') -@app.route('/orcid-redirect') +@registration.route('/orcid-redirect') def orcid_redirect(): ''' Redirect route that retrieves the one-time code from ORCID after user logs in and approves. ''' # Redirect here for access denied page if request.args.get('error') == 'access_denied': - return redirect(app.config['ORCID_FAILURE_URL']) + return redirect(current_app.config['ORCID_FAILURE_URL']) elif request.args.get('error'): - app.logger.error(f'OAuth Error {request.args.get("error")};') + current_app.logger.error(f'OAuth Error {request.args.get("error")};') return render_template('oauth_error.html') orcid_code = request.args.get('code') headers = {'Accept': 'application/json', 'Content-Type': 'application/x-www-form-urlencoded'} try: - response = requests.post(app.config['orcid_token_url'], + response = requests.post(current_app.config['orcid_token_url'], headers=headers, data=prepare_token_payload(orcid_code)) response.raise_for_status() @@ -173,4 +186,4 @@ def orcid_redirect(): # return success page - testing only #return render_template('orcid_success.html', saml_id=saml_id, orcid_auth={k: v for k,v in orcid_auth.items() if not k.endswith('token')}) - return redirect(app.config['ORCID_SUCCESS_URL']) \ No newline at end of file + return redirect(current_app.config['ORCID_SUCCESS_URL']) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 16a1cef..41de51d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,14 @@ -cryptography==38.0.1 -Flask==2.3.3 -Flask-SQLAlchemy==2.5.1 +cryptography==43.0.1 +Flask==3.1.0 +Flask-SQLAlchemy==3.1.1 Flask-Migrate==4.0.7 -gunicorn==20.1.0 -psycopg2-binary==2.9.3 +gunicorn==23.0.0 +lxml==5.3.0 # Need to pin version -- see below +psycopg2-binary==2.9.8 python3-saml==1.16.0 -requests==2.25.1 +pytest==8.3.5 +requests==2.30.0 +responses==0.25.7 six==1.15.0 -SQLAlchemy==1.4.40 +SQLAlchemy==2.0.40 +xmlsec==1.3.14 # Need to pin version to avoid conflict with lxml diff --git a/tests.docker-compose.yml b/tests.docker-compose.yml new file mode 100644 index 0000000..cc06039 --- /dev/null +++ b/tests.docker-compose.yml @@ -0,0 +1,31 @@ +version: "2" +services: + db: + image: postgres:16.2 + environment: + - POSTGRES_USER=testuser + - POSTGRES_PASSWORD=testpass + - POSTGRES_DB=testdb + flask-app: + build: + context: . + dockerfile: Dockerfile + ports: + - 8080:8080 + links: + - db:db + environment: + - ORCIDFLASK_SETTINGS=/opt/orcid_integration/tests/config.py + - POSTGRES_USER=testuser + - POSTGRES_PASSWORD=testpass + - POSTGRES_DB=testdb + - POSTGRES_DB_HOST=db + - POSTGRES_PORT=5432 + - TESTING=true + tty: true + #command: "pytest -s" + command: "bash" + volumes: + - ./tests:/opt/orcid_integration/tests + - ./:/opt/orcid_integration + \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/_test_saml.py b/tests/_test_saml.py new file mode 100644 index 0000000..b383a33 --- /dev/null +++ b/tests/_test_saml.py @@ -0,0 +1,40 @@ +import pytest +from onelogin.saml2.utils import OneLogin_Saml2_Utils +from onelogin.saml2.xml_utils import OneLogin_Saml2_XML +from onelogin.saml2.auth import OneLogin_Saml2_Auth +from onelogin.saml2.constants import OneLogin_Saml2_Constants +from flask import session +from freezegun import freeze_time +from lxml import etree +#from signxml import XMLSigner, namespaces +import uuid +import xmlsec +#from signxml.algorithms import CanonicalizationMethod + +@pytest.fixture() +def saml_response(): + with open('tests/test_saml_response.txt') as f: + return f.read() + +@pytest.fixture() +def saml_relay_state(): + return "http://localhost/orcid?scopes=/read-limited+/activities/update®ister=False" + +@pytest.fixture() +def saml_request(saml_relay_state, saml_response): + return { + 'https': 'off', + 'http_host': 'http://localhost:4000/api/saml', + 'server_port': None, + 'script_name': '', + 'get_data': {'acs': None}, + # Uncomment if using ADFS as IdP, https://github.com/onelogin/python-saml/pull/144 + 'lowercase_urlencoding': True, + 'post_data': {'SAMLResponse': saml_response, + 'RelayState': saml_relay_state} + } + +@freeze_time('2025-04-15 13:25:00') +def test_acs(client, saml_response, saml_relay_state): + response = client.post(query_string='acs', base_url='http://localhost:8000/', data=saml_response, content_type='application/x-www-form-urlencoded') + print(response.text) \ No newline at end of file diff --git a/tests/config.py b/tests/config.py new file mode 100644 index 0000000..82f4119 --- /dev/null +++ b/tests/config.py @@ -0,0 +1,9 @@ +SECRET_KEY = b'random string' +CLIENT_ID = 'APP-1234567890' +CLIENT_SECRET = 'xxxxxxxxxxxxxxxxxxxxxxxxxx' +SERVER_NAME = 'localhost' +SLO_REDIRECT = 'http://localhost/orcid' +ORCID_SUCCESS_URL = 'http://localhost/orcid-connected' +ORCID_FAILURE_URL = 'https://localhost/orcid-disconnected' +# Set to true to prepopulate ORCID's registration for with values from the SAML IdP +PREFILL_REGISTRATION = True \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..c755204 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,105 @@ +import pytest +import os +from orcid_utils import create_encryption_key +import json +from orcidflask import create_app +from orcidflask.db import db +from orcidflask.db.models import generate_key + +@pytest.fixture(scope='module') +def encryption_key(): + return create_encryption_key() + + +@pytest.fixture(scope='module') +def saml_settings(): + return { + 'strict': True, + 'debug': True, + 'sp': { + 'entityId': 'http://localhost/metadata/', + 'assertionConsumerService': { + 'url': 'http://localhost:8000/?acs', + 'binding': 'urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST' + }, + 'singleLogoutService': { + 'url': 'http://localhost:8000/?sls', + 'binding': 'urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect' + }, + 'NameIDFormat': 'urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified', + 'x509cert': 'MIICcDCCAdmgAwIBAgIBADANBgkqhkiG9w0BAQ0FADBVMQswCQYDVQQGEwJ1czEdMBsGA1UECAwURGlzdHJpY3Qgb2YgQ29sdW1iaWExDDAKBgNVBAoMA0dXVTEZMBcGA1UEAwwQaHR0cDovL2xvY2FsaG9zdDAeFw0yNTAzMTkxNTQ5MDdaFw0yNjAzMTkxNTQ5MDdaMFUxCzAJBgNVBAYTAnVzMR0wGwYDVQQIDBREaXN0cmljdCBvZiBDb2x1bWJpYTEMMAoGA1UECgwDR1dVMRkwFwYDVQQDDBBodHRwOi8vbG9jYWxob3N0MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQC3/iGSSDB3YFu5Gvko8Hzyph+2FQ3nrHSKcMPjER5a5GX5I881VBLkfxLPRNLDlKPgTBou1jlVdmC09tEJyMdlK9V8NlgPIdKR09p4yopfFEZcDBJfjk8XL6qXMiRXkcV6UP68QRasNZ2hApqSIGm1UMwyE1teqtEZKHVoPZExBwIDAQABo1AwTjAdBgNVHQ4EFgQU/stj8DDeO15sEGHOeO+KP7j+XU4wHwYDVR0jBBgwFoAU/stj8DDeO15sEGHOeO+KP7j+XU4wDAYDVR0TBAUwAwEB/zANBgkqhkiG9w0BAQ0FAAOBgQA/sXiknflk7r0nsCg36mNijrUgneevdS2O9vXjSJaMCVDXgXBcXvoVJKkOlFoKzpWiX74TGe1kqcH+mphkt50PSZ5qpa7rtWYSdlKs9QMq4l3orud7vxllD9R4Wi9NNP4z4DNFwwCfkxn9lz3Kq+NRsqWy3JGOC/9nRXg35z+CAA==', + 'privateKey': 'MIICdQIBADANBgkqhkiG9w0BAQEFAASCAl8wggJbAgEAAoGBALf+IZJIMHdgW7ka+SjwfPKmH7YVDeesdIpww+MRHlrkZfkjzzVUEuR/Es9E0sOUo+BMGi7WOVV2YLT20QnIx2Ur1Xw2WA8h0pHT2njKil8URlwMEl+OTxcvqpcyJFeRxXpQ/rxBFqw1naECmpIgabVQzDITW16q0RkodWg9kTEHAgMBAAECgYBBoJO46abf7a7Jx6U3xQ/MPRTyjW/4QrsO5kn4pBJ/uRfmVa+DBgn3FpxO8e17dXk+d+ae7iplIWQ9KAxHwSXdhXoiZ89dh4iufNHj7WLsAjJhuCCeeqtaIXw5gDrSo9ulHKxWKqj9V6pPL0reWcg38D50EzW5mqlapEHxgk+GAQJBAN1fgpjBS8BGYk6ENNddWsL8pY48WlKJn64ZDJyGyr8Vt/0buMj9LtOzbh5WDXIXw8wEvgx6LyUdPyqCbpXWFdcCQQDUxcpIwwk4IIh5PhOYr6FCwxlgjFrLK76Bf7k7gPUYRMZxIhACLyE1UxT/qt++5O+sfA9uVCqcn9Fup5EANPhRAkBrRfM1LsYUgIb24V3x1w06W8+mI1zpjkNQzFauKytoeY/VGW/sBbSBZfvAu5Z8aUO6Q7oMtdDOvWN0qAwKk9m1AkAzc8D+52sLT5Kw/vnuKkpswpEYb9hk2ScwWZqJcR3TyI3UPdBxNsRpCLZDPSbuGp56r2Vr4J6NUXhrscm2qxiBAkBbYjXCh9VM3O9mXfRJT+QyORQB38GBvBAbzdh2PTOtNatSvF3eUo2pQlrxF77GzNKTx99dq0KVrBKZuL23bIyP' + }, + 'idp': { + 'entityId': 'https://saml.example.com/entityid', + 'singleSignOnService': { + 'url': 'http://localhost:4000/api/saml/sso', + 'binding': 'urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect' + }, + 'singleLogoutService': { + 'url': 'http://localhost:4000/api/saml', + 'responseUrl': 'http://localhost:4000/api/saml', + 'binding': 'urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect' + }, + 'x509cert': 'MIIEBTCCAu2gAwIBAgIUVuhF7lX6n9Z83MUAIXuHJxa7TaIwDQYJKoZIhvcNAQELBQAwgZAxCzAJBgNVBAYTAlVTMR0wGwYDVQQIDBREaXN0cmljdCBvZiBDb2x1bWJpYTETMBEGA1UEBwwKV2FzaGluZ3RvbjEMMAoGA1UECgwDR1dVMQwwCgYDVQQLDANMQUkxEjAQBgNVBAMMCWxvY2FsaG9zdDEdMBsGCSqGSIb3DQEJARYOZHNtaXRoQGd3dS5lZHUwIBcNMjUwNDE1MTI1MjUyWhgPMzAyNDA4MTYxMjUyNTJaMIGQMQswCQYDVQQGEwJVUzEdMBsGA1UECAwURGlzdHJpY3Qgb2YgQ29sdW1iaWExEzARBgNVBAcMCldhc2hpbmd0b24xDDAKBgNVBAoMA0dXVTEMMAoGA1UECwwDTEFJMRIwEAYDVQQDDAlsb2NhbGhvc3QxHTAbBgkqhkiG9w0BCQEWDmRzbWl0aEBnd3UuZWR1MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAuuU+IEOV4YD8M9DJ8yCBtcO9VfeFiS3Aj9vwNyPHQBXvXqgeMGvswZfj5ZJjpNEYAvELJuDV/19EeDBy5YBXq9WNdOg+V9d07HMnkb2nycZ3IKsOu20hYJJFHTnFJJhpx0snV86w2bLYrFDJS9nW8mBXfYEkTH/wSiWNfBXzeRKDwKEsch1k7xNzsXu+AyACXfhvNlWb6DrxFcKK2x/cUGNItPCJbAMqXc6WHmVkVb2GZCUP0jiwr594l0NlH8iyum4yrg6XzuCsNtqDVpL4i7O9ic76PLozHWdJcUN0Fs6bKiCaXT9g5dRptSwFpLnEDaB6MzeZcbSOdGRw711m2wIDAQABo1MwUTAdBgNVHQ4EFgQUbhztZb1kfwFOybEYRnwlYDCOOzQwHwYDVR0jBBgwFoAUbhztZb1kfwFOybEYRnwlYDCOOzQwDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAnTf1c4LedIsBEylIqYChEZUkUQ6BgPsMt+ZBfl/BMi2mrowJPWZ2Wx6N8h7ZycQ7D2fsEpQ4hAMqcX8RnFuqSwY5awE/MrRpaiR/8o5EvEB2ZePAybybrU22/aTLq93AtAyyuJTZ6DMp1aCuMkJmoLGwr+JasiZUSIYIGxwAj52FJsHgwL7rT9ME06AcUoAwK4KZE9FGvX3TQ+ogz6+1QAddDy/h8fiiAkuh7Xf1W2SrLGpkI7KRTjbz6b4fSY5j+S7nTCqERQxdgTR4g7qMuysxp5s6EEd1ReoIOanrUNHFHg9iSJY2myhgWbT09m9BD03zrVbSW8kYdGsObytsbQ==' + } + } + +@pytest.fixture(scope='module') +def saml_settings_path(tmpdir_factory, saml_settings): + settings_dir = tmpdir_factory.mktemp('saml') + settings_file = settings_dir.join('settings.json') + with open(settings_file, 'w') as f: + json.dump(saml_settings, f) + return str(settings_dir) + + + +@pytest.fixture(scope='module') +def test_app(saml_settings_path, encryption_key): + app = create_app() + app.config.update({'db_encryption_key': encryption_key, + 'SAML_PATH': saml_settings_path, + 'TESTING': True, + 'PRESERVE_CONTEXT_ON_EXCEPTION': False + }) + yield app + + +@pytest.fixture(scope='module') +def client(test_app): + return test_app.test_client() + +@pytest.fixture(scope='module') +def runner(test_app): + return test_app.test_cli_runner() + +@pytest.fixture(scope='module') +def database(test_app): + + with test_app.app_context(): + db.drop_all() + db.create_all() + + yield + + db.session.remove() + db.drop_all() + +@pytest.fixture(scope='module') +def test_app_api(encryption_key): + app = create_app('api') + app.config.update({'db_encryption_key': encryption_key, + 'SAML_PATH': saml_settings_path, + 'TESTING': True, + 'PRESERVE_CONTEXT_ON_EXCEPTION': False + }) + yield app + +@pytest.fixture(scope='module') +def client_api(test_app_api): + return test_app_api.test_client() + +@pytest.fixture(scope='module') +def api_key(): + return generate_key() \ No newline at end of file diff --git a/tests/test_api.py b/tests/test_api.py new file mode 100644 index 0000000..f08c5c6 --- /dev/null +++ b/tests/test_api.py @@ -0,0 +1,50 @@ +import pytest +from orcidflask.db.models import Token, APIKey, generate_key +from orcidflask.db import db +from datetime import datetime as dt +from uuid import uuid1 + + +@pytest.fixture() +def orcid(): + return '0000-0000-0000-0000' + +@pytest.fixture() +def user_id(): + return 'dsmith@gwu.edu' + +@pytest.fixture() +def orcid_registration(test_app, orcid, user_id, database): + with test_app.app_context(): + db.session.add(Token(userId=user_id, access_token='test-token', refresh_token='test-refresh-token', + expires_in=1, token_scope='test-scope', orcid=orcid)) + db.session.commit() + db.session.close() + + +@pytest.fixture() +def api_registration(test_app, user_id, api_key, database): + with test_app.app_context(): + db.session.add(APIKey(userId=user_id, timestamp=dt.now(), api_key=api_key)) + db.session.commit() + db.session.close() + +def test_get_token(client_api, orcid, api_key, orcid_registration, api_registration, database): + headers = {'Authorization': f'Apikey {api_key}'} + resp = client_api.get('api/get-token', headers=headers, query_string={'orcid': orcid}) + assert resp.json['orcid'] == orcid + assert resp.json['access_token'] == 'test-token' + + +def test_get_token_bad_key(client_api, orcid): + headers = {'Authorization': f'Apikey {uuid1()}'} + resp = client_api.get('api/get-token', headers=headers, query_string={'orcid': orcid}) + assert resp.status_code == 403 + assert resp.json['message'] == 'API key has not been registered. Please have the application administrator create an API key for you.' + +def test_get_token_bad_orcid(client_api, api_key): + bad_orcid = '0000-1111-0000-1111' + headers = {'Authorization': f'Apikey {api_key}'} + resp = client_api.get('api/get-token', headers=headers, query_string={'orcid': bad_orcid}) + assert resp.status_code == 200 + assert 'error' in resp.json and resp.json['error'] == f'Entry not found for ORCiD {bad_orcid}. Has the user registered with the GW ORCiD integration app?' \ No newline at end of file diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000..4fb8c95 --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,36 @@ +import pytest +from orcidflask.db.models import Token +from orcidflask.db import db +from sqlalchemy import inspect +import json + +@pytest.fixture() +def sample_data(test_app): + with test_app.app_context(): + db.session.add(Token(userId = 'testId2', access_token = 'test-token', refresh_token = 'test-refresh-token', + expires_in = 1, token_scope = 'test-scope', orcid = 'test-orcid')) + db.session.commit() + db.session.close() + +@pytest.fixture() +def token_file(tmp_path_factory): + return tmp_path_factory.mktemp('tmp') / 'tokens.json' + +def test_token_dump(test_app, sample_data, runner, token_file, database): + # Create the Token table + with test_app.app_context(): + runner.invoke(args=['serialize-db', token_file]) + with open(token_file) as f: + token_dump = json.load(f) + assert len(token_dump) == 1 + assert token_dump[0]['access_token'] == 'test-token' + +''' This test causes pytest to hang!! + def test_drop_table(runner, test_app): + # Run the command that drops the table + # Confirm that the table has been dropped + with test_app.app_context(): + inspector = inspect(db.engines[None]) + runner.invoke(args='reset-db') + assert not inspector.has_table('token') +''' diff --git a/tests/test_orcid_oauth.py b/tests/test_orcid_oauth.py new file mode 100644 index 0000000..b9b766c --- /dev/null +++ b/tests/test_orcid_oauth.py @@ -0,0 +1,69 @@ +import pytest +from flask import url_for +from urllib.parse import urlparse, parse_qs +import responses +from responses import matchers +import requests +from orcidflask.db.models import Token +import datetime + +@pytest.fixture() +def user_attributes(): + return {'samlUserdata': { + 'emailaddress': ['test@example.com'], + 'firstname': ['Test'], + 'lastname': ['User'] + }, + 'samlNameId': 'testId' + } + +@pytest.fixture() +def auth_code(): + return 'test_code' + +@pytest.fixture() +def redirect_url(test_app): + with test_app.app_context(): + return url_for('registration.orcid_redirect', _external=True, _scheme='https') + +def test_orcid_login(client, user_attributes, test_app): + with client.session_transaction() as session: + session['samlUserdata'] = user_attributes.get('samlUserdata') + session['samlNameId'] = user_attributes.get('samlNameId') + response = client.get('/orcid', query_string={'scopes': '/read-limited /activities/update', 'register': 'True'}) + assert response.status_code == 302 + redirect = urlparse(response.location) + query = parse_qs(redirect.query) + if test_app.config['PREFILL_REGISTRATION']: + assert redirect.path == urlparse(test_app.config['orcid_register_url']).path + assert query['family_names'] == user_attributes['samlUserdata']['lastname'] + assert query['given_names'] == user_attributes['samlUserdata']['firstname'] + assert query['email'] == user_attributes['samlUserdata']['emailaddress'] + else: + assert redirect.path == urlparse(test_app.config['orcid_auth_url']).path + assert query['client_id'][0] == test_app.config['CLIENT_ID'] + +@responses.activate +def test_orcid_redirect(client, test_app, auth_code, redirect_url, user_attributes, database): + orcid_resp_mock = responses.Response( + method='POST', + url=test_app.config['orcid_token_url'], + json={ 'access_token': 'test-access-token', + 'refresh_token': 'test-refresh-token', + 'expires_in': 631138518, + 'scope': '/read-limited /activities/update', + 'orcid': '0000-0000-0000-0000' }, + match=[matchers.urlencoded_params_matcher({ 'client_id': test_app.config['CLIENT_ID'], + 'client_secret': test_app.config['CLIENT_SECRET'], + 'grant_type': 'authorization_code', + 'code': auth_code, + 'redirect_uri': redirect_url}) + ], + ) + responses.add(orcid_resp_mock) + orcid_resp = client.get('/orcid-redirect', query_string={'code': auth_code}) + assert orcid_resp.status_code == 302 + db_state = [record.to_dict() for record in Token.query.all()] + assert db_state[0]['userId'] == user_attributes.get('samlNameId') + assert db_state[0]['access_token'] == 'test-access-token' + assert db_state[0]['orcid'] == '0000-0000-0000-0000'