Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions dj_gui_api_server/DJConnector.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,18 @@
import datajoint as dj

class DJConnector():
"""
Attempt to authenticate against database with given username and address

Parameters:
database_address (string): Address of database
username (string): Username of user
password (string): Password of user

Returns:
dict(result=True): If successful
dict(result=False, error=<error-message>): If failed
"""
@staticmethod
def attempt_login(database_address, username, password):
dj.config['database.host'] = database_address
Expand All @@ -11,5 +23,30 @@ def attempt_login(database_address, username, password):
try:
dj.conn(reset=True)
return dict(result=True)
except Exception as e:
return dict(result=False, error=e)

"""
List all schemas under the database

Parameters:
database_address (string): Address of database
username (string): Username of user
password (string): Password of user

Returns:
dict(result=True, schemas=(list(str))): If successful
dict(result=False, error=<error-message>): If failed
"""
@staticmethod
def list_schemas(database_address, username, password):
dj.config['database.host'] = database_address
dj.config['database.user'] = username
dj.config['database.password'] = password

# Attempt to connect return true if successful, false is failed
try:
schemas = dj.list_schemas()
return dict(result=True, schemas=schemas)
except Exception as e:
return dict(result=False, error=e)
56 changes: 55 additions & 1 deletion dj_gui_api_server/DJGUIAPIServer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,42 @@
import jwt
app = Flask(__name__)

"""
Protected route function decrator

Parameters:
function: function to decreate, typically routes

Returns:
Return function output if jwt authecation is successful, otherwise return error message
"""
def protected_route(function):
def wrapper():
try:
jwt_payload = jwt.decode(request.headers.get('Authorization')[7:], os.environ['PUBLIC_KEY'], algorithm='RS256')
return function(jwt_payload)
except Exception as e:
return dict(error=str(e))
return wrapper

"""
Route to check if the server is alive or not
"""
@app.route('/api')
def hello_world():
return 'Hello, World!'

"""
# Login route which uses datajoint login

Parameters:
(html:POST:body): json with keys {databaseAddress: string, username: string, password: string}

Returns:
dict(jwt=<JWT_TOKEN>): If sucessfully authenticated against the database
or
dict(error=<error_message>): With error message of why it failed
"""
@app.route('/api/login', methods=['POST'])
def login():
# Check if request.json has the correct fields
Expand All @@ -20,10 +52,32 @@ def login():
attempt_connection_result = DJConnector.attempt_login(request.json['databaseAddress'], request.json['username'], request.json['password'])
if attempt_connection_result['result']:
# Generate JWT key and send it back
encoded_jwt = jwt.encode(request.json, os.environ['PRIVATE_KEY'].encode(), algorithm='RS256')
encoded_jwt = jwt.encode(request.json, os.environ['PRIVATE_KEY'], algorithm='RS256')
return dict(jwt=encoded_jwt.decode())
else:
return dict(error=str(attempt_connection_result['error']))

"""
# API route for fetching schema

Parameters:
(html:POST:body): json with keys {}

Returns:
dict(schemas=<schemas>): If sucessfuly send back a list of schemas names
or
dict(error=<error_message>): With error message of why it failed
"""
@app.route('/api/list_schemas', methods=['GET'])
@protected_route
def list_schemas(jwt_payload):
print(jwt_payload, flush=True)
# Get all the schemas
result = DJConnector.list_schemas(jwt_payload['databaseAddress'], jwt_payload['username'], jwt_payload['password'])
if result['result']:
return dict(schemas=result['schemas'])
else:
return dict(error=result['error'])

if __name__ == '__main__':
app.run(host='0.0.0.0', port=5000)