secure all required endpoints with JWT

This commit is contained in:
ІО-23 Шмуляр Олег 2024-12-29 18:21:37 +02:00
parent 729498db11
commit f78fa33397
Signed by: hasslesstech
GPG Key ID: 09745A46126DDD4C
2 changed files with 133 additions and 109 deletions

View File

@ -1,20 +1,30 @@
from flask import Flask, request, jsonify from flask import Flask, request, jsonify
from flask_jwt_extended import create_access_token, get_jwt_identity, jwt_required, JWTManager
import time import time
import json import json
import uuid import uuid
import datetime
import sys
import os
from hashlib import sha256
from marshmallow import Schema, fields from marshmallow import Schema, fields
from flask_sqlalchemy import SQLAlchemy from flask_sqlalchemy import SQLAlchemy
app = Flask(__name__) app = Flask(__name__)
app.config.from_pyfile('config.py', silent=True) app.config.from_pyfile('config.py', silent=True)
app.config["JWT_SECRET_KEY"] = os.getenv("JWT_SECRET_KEY")
#print(app.config["JWT_SECRET_KEY"], file = sys.stderr)
db = SQLAlchemy(app) db = SQLAlchemy(app)
jwt = JWTManager(app)
class UserModel(db.Model): class UserModel(db.Model):
__tablename__ = "user" __tablename__ = "user"
uuid = db.Column(db.String(32), unique=True, primary_key=True, nullable=False) uuid = db.Column(db.String(32), unique=True, primary_key=True, nullable=False)
name = db.Column(db.String(64), nullable=False) name = db.Column(db.String(64), nullable=False)
password = db.Column(db.String(64), nullable=False)
bal_uuid = db.Column(db.String(32), db.ForeignKey('balance.uuid')) bal_uuid = db.Column(db.String(32), db.ForeignKey('balance.uuid'))
class CategoryModel(db.Model): class CategoryModel(db.Model):
@ -39,6 +49,7 @@ class BalanceModel(db.Model):
class UserSchema(Schema): class UserSchema(Schema):
uuid = fields.Str() uuid = fields.Str()
name = fields.Str() name = fields.Str()
password = fields.Str()
bal_uuid = fields.Str() bal_uuid = fields.Str()
class CategorySchema(Schema): class CategorySchema(Schema):
@ -49,7 +60,7 @@ class RecordSchema(Schema):
uuid = fields.Str() uuid = fields.Str()
user_uuid = fields.Str() user_uuid = fields.Str()
cat_uuid = fields.Str() cat_uuid = fields.Str()
date = fields.Date() date = fields.Date('iso')
amount = fields.Integer() amount = fields.Integer()
class BalanceSchema(Schema): class BalanceSchema(Schema):
@ -88,39 +99,34 @@ def ep_reset():
return {}, 200 return {}, 200
@app.route("/users", methods = ["GET"]) @app.route("/users", methods = ["GET"])
@jwt_required()
def ep_users_get(): def ep_users_get():
result = db.session.query(UserModel).all() result = db.session.query(UserModel).all()
return users_schema.dumps(result) return users_schema.dumps(result)
@app.route("/user/<user_id>", methods = ["GET"])
def ep_user_get(user_id):
result = db.session.query(UserModel).filter(UserModel.uuid == user_id).all()
if len(result) == 1:
return user_schema.dumps(result[0]), 200
else:
return {}, 404
@app.route("/user", methods = ["POST"]) @app.route("/user", methods = ["POST"])
def ep_user_post(): def ep_user_post():
body = request.json name = request.json.get('name', None)
password = request.json.get('password', None)
if not body: pass_hashed = sha256(password.encode("UTF-8")).digest().hex()
return {}, 403
if 'uuid' in body:
return {}, 403
b = BalanceModel(uuid=uuid.uuid4().hex, value=0) b = BalanceModel(uuid=uuid.uuid4().hex, value=0)
body.update({'uuid': uuid.uuid4().hex})
body.update({'bal_uuid': b.uuid}) u_uuid = uuid.uuid4().hex
bal_uuid = b.uuid
try: try:
_ = user_schema.load(body) _ = user_schema.load({'uuid': u_uuid, 'name': name,
except ValidationError as e: 'password': pass_hashed, 'bal_uuid': bal_uuid})
except Exception as e:
print(e, file=sys.stderr)
return {}, 403 return {}, 403
u = UserModel(**body) u = UserModel(uuid=u_uuid, name=name,
password=pass_hashed, bal_uuid=bal_uuid)
at = create_access_token(identity = json.dumps({'uuid': u.uuid, 'bal_uuid': u.bal_uuid}))
try: try:
db.session.add(b) db.session.add(b)
@ -130,29 +136,43 @@ def ep_user_post():
db.session.rollback() db.session.rollback()
return {}, 403 return {}, 403
return jsonify(user_schema.load(body)), 200 return jsonify(access_token = at, uuid = u.uuid), 200
@app.route("/user/<user_id>", methods = ["DELETE"]) @app.route("/user", methods = ["GET"])
def ep_user_delete(user_id): def ep_user_get():
try: name = request.json.get('user', None)
result = db.session.query(UserModel).filter(UserModel.uuid == user_id).all() password = request.json.get('password', None)
except Exception as e: pass_hashed = sha256(password.encode("UTF-8")).digest().hex()
return {}, 403
if len(result) == 0: u = db.session.query(UserModel).filter(UserModel.name == name).one_or_none()
if not u:
return {}, 404 return {}, 404
if u.password != pass_hashed:
return {"message": "Wrong password."}, 401
at = create_access_token(identity = json.dumps({'uuid': u.uuid, 'bal_uuid': u.bal_uuid}))
return jsonify(access_token = at, uuid = u.uuid), 200
@app.route("/user", methods = ["DELETE"])
@jwt_required()
def ep_user_delete():
current_user = json.loads(get_jwt_identity())
try: try:
db.session.query(UserModel).filter(UserModel.uuid == user_id).delete() db.session.query(UserModel).filter(UserModel.uuid == current_user['uuid']).delete()
db.session.query(BalanceModel).filter(BalanceModel.uuid == result[0].bal_uuid).delete() db.session.query(BalanceModel).filter(BalanceModel.uuid == current_user['bal_uuid']).delete()
db.session.commit() db.session.commit()
except Exception as e: except Exception as e:
db.session.rollback() db.session.rollback()
return {}, 403 return {}, 403
return user_schema.dumps(result[0]), 200 return {"message": "Success."}, 200
@app.route("/category", methods = ["GET"]) @app.route("/category", methods = ["GET"])
@jwt_required()
def ep_category_get(): def ep_category_get():
body = request.json body = request.json
@ -167,6 +187,7 @@ def ep_category_get():
return {}, 403 return {}, 403
@app.route("/category", methods = ["POST"]) @app.route("/category", methods = ["POST"])
@jwt_required()
def ep_category_post(): def ep_category_post():
body = request.json body = request.json
@ -195,6 +216,7 @@ def ep_category_post():
return jsonify(category_schema.load(body)), 200 return jsonify(category_schema.load(body)), 200
@app.route("/category", methods = ["DELETE"]) @app.route("/category", methods = ["DELETE"])
@jwt_required()
def ep_category_delete(): def ep_category_delete():
body = request.json body = request.json
@ -220,15 +242,18 @@ def ep_category_delete():
return category_schema.dumps(result[0]), 200 return category_schema.dumps(result[0]), 200
@app.route("/record/<record_id>", methods = ["GET"]) @app.route("/record/<record_id>", methods = ["GET"])
@jwt_required()
def ep_record_get(record_id): def ep_record_get(record_id):
result = db.session.query(RecordModel).filter(RecordModel.uuid == record_id).all() current_user = json.loads(get_jwt_identity())
result = db.session.query(RecordModel).filter(RecordModel.uuid == record_id).one_or_none()
if len(result) == 1: if result and result.user_uuid == current_user['uuid']:
return user_schema.dumps(result[0]), 200 return user_schema.dumps(result), 200
else: else:
return {}, 404 return {}, 403
@app.route("/record", methods = ["GET"]) @app.route("/record", methods = ["GET"])
@jwt_required()
def ep_record_get_filtered(): def ep_record_get_filtered():
r = db.session.query(RecordModel) r = db.session.query(RecordModel)
@ -248,132 +273,130 @@ def ep_record_get_filtered():
r = r.filter(RecordModel.cat_uuid == request.json['cat_uuid']) r = r.filter(RecordModel.cat_uuid == request.json['cat_uuid'])
filtered = True filtered = True
if filtered: if filtered:
return records_schema.dumps(r.all()) return records_schema.dumps(r.all())
else: else:
return [], 403 return [], 403
@app.route("/record/<record_id>", methods = ["DELETE"]) @app.route("/record/<record_id>", methods = ["DELETE"])
@jwt_required()
def ep_record_del(record_id): def ep_record_del(record_id):
current_user = json.loads(get_jwt_identity())
try: try:
result = db.session.query(RecordModel).filter(RecordModel.uuid == record_id).all() result = db.session.query(RecordModel).filter(RecordModel.uuid == record_id).one_or_none()
except Exception as e: except Exception as e:
return {}, 403 return {}, 403
if len(result) == 0: if result and result.user_uuid == current_user['uuid']:
return {}, 404
db.session.query(RecordModel).filter(RecordModel.uuid == record_id).delete() db.session.query(RecordModel).filter(RecordModel.uuid == record_id).delete()
db.session.commit() db.session.commit()
else:
return {}, 401
return record_schema.dumps(result[0]), 200 return record_schema.dumps(result), 200
@app.route("/record", methods = ["POST"]) @app.route("/record", methods = ["POST"])
@jwt_required()
def ep_record_post(): def ep_record_post():
body = request.json current_user = json.loads(get_jwt_identity())
if not body: amount = request.json.get('amount', None)
category = request.json.get('category', None)
if not all([amount, category]):
return {}, 403 return {}, 403
if 'uuid' in body: u_uuid = uuid.uuid4().hex
return {}, 403
body.update({'uuid': uuid.uuid4().hex}) r_time = time.strftime("%Y-%m-%d")
# backward compatibility with lab2 DB model
if 'cat_id' in body:
body.update({'cat_uuid': body['cat_id']})
del body['cat_id']
if 'user_id' in body:
body.update({'user_uuid': body['user_id']})
del body['user_id']
try: try:
_ = record_schema.load(body) #_ = record_schema.load(amount=amount, cat_uuid=category, uuid=u_uuid,
# user_uuid=current_user['uuid'], date=datetime.datetime.now())
_ = record_schema.load({'amount': amount, "cat_uuid": category, "uuid": u_uuid,
'user_uuid': current_user['uuid'], 'date': r_time})
except Exception as e: except Exception as e:
return {}, 403 print(e, file = sys.stderr)
return {}, 400
r = RecordModel(**body) r = RecordModel(amount=amount, cat_uuid=category, uuid=u_uuid, user_uuid=current_user['uuid'], date=r_time)
b_id = db.session \
.query(UserModel) \
.filter(UserModel.uuid == body['user_uuid']) \
.all()[0] \
.bal_uuid
v = db.session \ v = db.session \
.query(BalanceModel) \ .query(BalanceModel) \
.filter(BalanceModel.uuid == b_id) \ .filter(BalanceModel.uuid == current_user['bal_uuid']) \
.all()[0] \ .one_or_none() \
.value .value
BalanceModel.metadata.tables.get("balance").update().where(BalanceModel.metadata.tables.get("balance").c.uuid == b_id).values(value = v-body['amount']) BalanceModel \
.metadata.tables.get("balance") \
.update() \
.where(BalanceModel.metadata.tables.get("balance").c.uuid == current_user['bal_uuid']) \
.values(value = v - amount)
try: try:
db.session.add(r) db.session.add(r)
db.session.commit() db.session.commit()
except Exception as e: except Exception as e:
print(e, file = sys.stderr)
db.session.rollback() db.session.rollback()
return {}, 403 return {}, 403
return {'uuid': r.uuid}, 200
return jsonify(record_schema.load(body)), 200
@app.route("/balance_up", methods = ["POST"]) @app.route("/balance_up", methods = ["POST"])
@jwt_required()
def ep_balance_up(): def ep_balance_up():
body = request.json current_user = json.loads(get_jwt_identity())
if 'user_id' in body: amount = request.json.get('amount', None)
body.update({'user_uuid': body['user_id']})
del body['user_id']
if 'user_uuid' not in body:
return {}, 403
try: try:
b_id = db.session \
.query(UserModel) \
.filter(UserModel.uuid == body['user_uuid']) \
.all()[0] \
.bal_uuid
v = db.session \ v = db.session \
.query(BalanceModel) \ .query(BalanceModel) \
.filter(BalanceModel.uuid == b_id) \ .filter(BalanceModel.uuid == current_user['bal_uuid']) \
.all()[0] \ .one_or_none() \
.value .value
BalanceModel.metadata.tables.get("balance").update().where(BalanceModel.metadata.tables.get("balance").c.uuid == b_id).values(value = v + body['amount']) BalanceModel.metadata.tables.get("balance").update().where(BalanceModel.metadata.tables.get("balance").c.uuid == current_user['bal_uuid']).values(value = v + amount)
except Exception as e: except Exception as e:
return {}, 403 print(e, file = sys.stderr)
return {}, 407
return {}, 200 return {}, 200
@app.route("/balance", methods = ["GET"]) @app.route("/balance", methods = ["GET"])
@jwt_required()
def ep_balance_get(): def ep_balance_get():
body = request.json current_user = json.loads(get_jwt_identity())
result = db.session.query(BalanceModel).filter(BalanceModel.uuid == current_user['bal_uuid']).one_or_none()
if 'user_id' in body: return balance_schema.dumps(result), 200
body.update({'user_uuid': body['user_id']})
del body['user_id']
if 'user_uuid' not in body: @jwt.expired_token_loader
return {}, 403 def expired_token_callback(jwt_header, jwt_payload):
return (
jsonify({"message": "The token has expired.", "error": "token_expired"}),
401,
)
try: @jwt.invalid_token_loader
b_id = db.session \ def invalid_token_callback(error):
.query(UserModel) \ return (
.filter(UserModel.uuid == body['user_uuid']) \ jsonify(
.all()[0] \ {"message": "Signature verification failed.", "error": "invalid_token"}
.bal_uuid ),
401,
)
result = db.session.query(BalanceModel).filter(BalanceModel.uuid == b_id).all() @jwt.unauthorized_loader
except Exception as e: def missing_token_callback(error):
return {}, 403 return (
jsonify(
if len(result) == 1: {
return user_schema.dumps(result[0]), 200 "description": "Request does not contain an access token.",
else: "error": "authorization_required",
return {}, 404 }
),
401,
)

View File

@ -14,6 +14,7 @@ services:
- ./app:/app/app - ./app:/app/app
env_file: env_file:
- db.env - db.env
- app.env
depends_on: depends_on:
- db - db