Skip to content

Commit c44f0dd

Browse files
author
Vianpyro
committed
Refactor database connection handling to use context manager for cursor operations
1 parent 91aca10 commit c44f0dd

File tree

8 files changed

+82
-99
lines changed

8 files changed

+82
-99
lines changed

db.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from contextlib import contextmanager
2+
13
import pymysql.cursors
24
from flask import current_app
35

@@ -10,3 +12,13 @@ def get_db_connection():
1012
database=current_app.config["MYSQL_DB"],
1113
cursorclass=pymysql.cursors.DictCursor,
1214
)
15+
16+
17+
@contextmanager
18+
def database_cursor():
19+
db = get_db_connection()
20+
try:
21+
yield db.cursor()
22+
db.commit()
23+
finally:
24+
db.close()

routes/authentication.py

Lines changed: 41 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from flask import Blueprint, jsonify, request
77
from pymysql import MySQLError
88

9-
from db import get_db_connection
9+
from db import database_cursor, get_db_connection
1010
from jwt_helper import (
1111
TokenError,
1212
extract_token_from_header,
@@ -42,6 +42,27 @@ def validate_password(password):
4242
)
4343

4444

45+
def get_person_by_email(email):
46+
with database_cursor() as cursor:
47+
cursor.callproc("login_person", (email,))
48+
return cursor.fetchone()
49+
50+
51+
def verify_password(password, stored_password, salt):
52+
seasoned_password = password.encode("utf-8") + salt + PEPPER
53+
try:
54+
return ph.verify(stored_password, seasoned_password)
55+
except exceptions.VerifyMismatchError:
56+
return False
57+
58+
59+
def update_last_login(person_id):
60+
db = get_db_connection()
61+
with database_cursor() as cursor:
62+
cursor.callproc("update_last_login", (person_id,))
63+
db.commit()
64+
65+
4566
@authentication_blueprint.route("/register", methods=["POST"])
4667
def register():
4768
data = request.get_json()
@@ -59,7 +80,7 @@ def register():
5980
hashed_password, salt = hash_password_with_salt_and_pepper(password)
6081

6182
db = get_db_connection()
62-
with db.cursor() as cursor:
83+
with database_cursor() as cursor:
6384
try:
6485
cursor.callproc(
6586
"register_person", (name, email, hashed_password, salt, language_code)
@@ -74,7 +95,6 @@ def register():
7495
else:
7596
return jsonify(message="An error occurred during registration"), 500
7697

77-
db.close()
7898
return jsonify(message="User created successfully"), 201
7999

80100

@@ -87,30 +107,26 @@ def login():
87107
if not email or not password:
88108
return jsonify(message="Email and password are required"), 400
89109

90-
db = get_db_connection()
91-
with db.cursor() as cursor:
92-
cursor.callproc("login_person", (email,))
93-
person = cursor.fetchone()
94-
95-
if not person:
96-
return jsonify(message="Invalid credentials"), 401
97-
98-
person_id = person["person_id"]
99-
stored_password = person["hashed_password"]
100-
salt = person["salt"]
101-
seasoned_password = password.encode("utf-8") + salt + PEPPER
110+
person = get_person_by_email(email)
102111

103-
try:
104-
ph.verify(stored_password, seasoned_password)
105-
access_token = generate_access_token(person_id)
106-
refresh_token = generate_refresh_token(person_id)
107-
return jsonify(
108-
message="Login successful",
109-
access_token=access_token,
110-
refresh_token=refresh_token,
111-
)
112-
except exceptions.VerifyMismatchError:
112+
try:
113+
if not person or not verify_password(
114+
password, person["hashed_password"], person["salt"]
115+
):
113116
return jsonify(message="Invalid credentials"), 401
117+
except Exception as e:
118+
return jsonify(message="An error occurred", error=str(e)), 500
119+
120+
person_id = person["person_id"]
121+
access_token = generate_access_token(person_id)
122+
refresh_token = generate_refresh_token(person_id)
123+
update_last_login(person_id)
124+
125+
return jsonify(
126+
message="Login successful",
127+
access_token=access_token,
128+
refresh_token=refresh_token,
129+
)
114130

115131

116132
@authentication_blueprint.route("/refresh", methods=["POST"])

routes/comment.py

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,55 +1,45 @@
11
from flask import Blueprint, jsonify
22

3-
from db import get_db_connection
3+
from db import database_cursor, get_db_connection
44

55
comment_blueprint = Blueprint("comment", __name__)
66

77

88
@comment_blueprint.route("/all", methods=["GET"])
99
def get_all_comments():
10-
db = get_db_connection()
11-
with db.cursor() as cursor:
10+
with database_cursor() as cursor:
1211
cursor.callproc("get_all_comments")
1312
comments = cursor.fetchall()
14-
db.close()
1513
return jsonify(comments)
1614

1715

1816
@comment_blueprint.route("/<int:comment_id>", methods=["GET"])
1917
def get_comment_by_id(comment_id):
20-
db = get_db_connection()
21-
with db.cursor() as cursor:
18+
with database_cursor() as cursor:
2219
cursor.callproc("get_comment_by_id", (comment_id,))
2320
comment = cursor.fetchone()
24-
db.close()
2521
return jsonify(comment)
2622

2723

2824
@comment_blueprint.route("/person/<int:person_id>", methods=["GET"])
2925
def get_comments_by_person(person_id):
30-
db = get_db_connection()
31-
with db.cursor() as cursor:
26+
with database_cursor() as cursor:
3227
cursor.callproc("get_all_comments_by_person", (person_id,))
3328
comments = cursor.fetchall()
34-
db.close()
3529
return jsonify(comments)
3630

3731

3832
@comment_blueprint.route("/recipe/<int:recipe_id>", methods=["GET"])
3933
def get_comments_by_recipe(recipe_id):
40-
db = get_db_connection()
41-
with db.cursor() as cursor:
34+
with database_cursor() as cursor:
4235
cursor.callproc("get_all_comments_by_recipe", (recipe_id,))
4336
comments = cursor.fetchall()
44-
db.close()
4537
return jsonify(comments)
4638

4739

4840
@comment_blueprint.route("/count/recipe/<int:recipe_id>", methods=["GET"])
4941
def get_comment_count_by_recipe(recipe_id):
50-
db = get_db_connection()
51-
with db.cursor() as cursor:
42+
with database_cursor() as cursor:
5243
cursor.callproc("get_comment_count_by_recipe", (recipe_id,))
5344
count = cursor.fetchone()
54-
db.close()
5545
return jsonify(count)

routes/ingredient.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,21 @@
11
from flask import Blueprint, jsonify
22

3-
from db import get_db_connection
3+
from db import database_cursor, get_db_connection
44

55
ingredient_blueprint = Blueprint("ingredient", __name__)
66

77

88
@ingredient_blueprint.route("/all", methods=["GET"])
99
def get_all_ingredients():
10-
db = get_db_connection()
11-
with db.cursor() as cursor:
10+
with database_cursor() as cursor:
1211
cursor.callproc("get_all_ingredients")
1312
ingredients = cursor.fetchall()
14-
db.close()
1513
return jsonify(ingredients)
1614

1715

1816
@ingredient_blueprint.route("/<int:ingredient_id>", methods=["GET"])
1917
def get_ingredient_by_id(ingredient_id):
20-
db = get_db_connection()
21-
with db.cursor() as cursor:
18+
with database_cursor() as cursor:
2219
cursor.callproc("get_ingredient_by_id", (ingredient_id,))
2320
ingredient = cursor.fetchone()
24-
db.close()
2521
return jsonify(ingredient)

routes/language.py

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,23 @@
11
from flask import Blueprint, jsonify
22

3-
from db import get_db_connection
3+
from db import database_cursor, get_db_connection
44

55
language_blueprint = Blueprint("language", __name__)
66

77

88
@language_blueprint.route("/all", methods=["GET"])
99
def get_all_languages():
10-
db = get_db_connection()
11-
with db.cursor() as cursor:
10+
with database_cursor() as cursor:
1211
cursor.callproc("get_all_languages")
1312
languages = cursor.fetchall()
14-
db.close()
1513
return jsonify(languages)
1614

1715

1816
@language_blueprint.route("/<int:language_id>", methods=["GET"])
1917
def get_language_by_id(language_id):
20-
db = get_db_connection()
21-
with db.cursor() as cursor:
18+
with database_cursor() as cursor:
2219
cursor.callproc("get_language_by_id", (language_id,))
2320
language = cursor.fetchone()
24-
db.close()
2521
return jsonify(language)
2622

2723

@@ -30,29 +26,23 @@ def get_language_by_language_code(language_code):
3026
if len(language_code) != 2:
3127
return jsonify({"error": "Invalid language code"}), 400
3228

33-
db = get_db_connection()
34-
with db.cursor() as cursor:
29+
with database_cursor() as cursor:
3530
cursor.callproc("get_language_by_code", (language_code,))
3631
language = cursor.fetchone()
37-
db.close()
3832
return jsonify(language)
3933

4034

4135
@language_blueprint.route("/used", methods=["GET"])
4236
def get_used_languages():
43-
db = get_db_connection()
44-
with db.cursor() as cursor:
37+
with database_cursor() as cursor:
4538
cursor.callproc("get_languages_with_users")
4639
languages = cursor.fetchall()
47-
db.close()
4840
return jsonify(languages)
4941

5042

5143
@language_blueprint.route("/stats", methods=["GET"])
5244
def get_language_stats():
53-
db = get_db_connection()
54-
with db.cursor() as cursor:
45+
with database_cursor() as cursor:
5546
cursor.callproc("get_languages_usage_statistics")
5647
stats = cursor.fetchall()
57-
db.close()
5848
return jsonify(stats)

routes/person.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,21 @@
11
from flask import Blueprint, jsonify
22

3-
from db import get_db_connection
3+
from db import database_cursor, get_db_connection
44

55
person_blueprint = Blueprint("person", __name__)
66

77

88
@person_blueprint.route("/all", methods=["GET"])
99
def get_all_persons():
10-
db = get_db_connection()
11-
with db.cursor() as cursor:
10+
with database_cursor() as cursor:
1211
cursor.callproc("get_all_persons")
1312
persons = cursor.fetchall()
14-
db.close()
1513
return jsonify(persons)
1614

1715

1816
@person_blueprint.route("/<int:person_id>", methods=["GET"])
1917
def get_person_by_id(person_id):
20-
db = get_db_connection()
21-
with db.cursor() as cursor:
18+
with database_cursor() as cursor:
2219
cursor.callproc("get_person_by_id", (person_id,))
2320
person = cursor.fetchone()
24-
db.close()
2521
return jsonify(person)

routes/picture.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from flask import Blueprint, jsonify, request, send_from_directory
44

55
from config import PICTURE_FOLDER
6-
from db import get_db_connection
6+
from db import database_cursor, get_db_connection
77
from jwt_helper import token_required
88

99
picture_blueprint = Blueprint("picture", __name__)
@@ -23,34 +23,28 @@ def extract_file_extension(filename):
2323
def get_all_pictures():
2424
picture_type = request.args.get("type", None)
2525

26-
db = get_db_connection()
27-
with db.cursor() as cursor:
26+
with database_cursor() as cursor:
2827
cursor.callproc(
2928
"get_all_pictures" if picture_type is None else "get_pictures_by_type",
3029
(picture_type,),
3130
)
3231
pictures = cursor.fetchall()
33-
db.close()
3432
return jsonify(pictures)
3533

3634

3735
@picture_blueprint.route("/<int:picture_id>", methods=["GET"])
3836
def get_picture_by_id(picture_id):
39-
db = get_db_connection()
40-
with db.cursor() as cursor:
37+
with database_cursor() as cursor:
4138
cursor.callproc("get_picture_by_id", (picture_id,))
4239
picture = cursor.fetchone()
43-
db.close()
4440
return jsonify(picture)
4541

4642

4743
@picture_blueprint.route("/author/<int:author_id>", methods=["GET"])
4844
def get_pictures_by_author(author_id):
49-
db = get_db_connection()
50-
with db.cursor() as cursor:
45+
with database_cursor() as cursor:
5146
cursor.callproc("get_pictures_by_author", (author_id,))
5247
picture = cursor.fetchall()
53-
db.close()
5448
return jsonify(picture)
5549

5650

@@ -84,10 +78,9 @@ def upload_picture():
8478
return jsonify({"error": f"Invalid picture type: {picture_type}"}), 400
8579

8680
db = get_db_connection()
87-
with db.cursor() as cursor:
81+
with database_cursor() as cursor:
8882
cursor.callproc(procedure, (hexname, request.person_id))
8983
db.commit()
90-
db.close()
9184

9285
fullpath = os.path.normpath(os.path.join(PICTURE_FOLDER, hexname))
9386
if not fullpath.startswith(PICTURE_FOLDER):

0 commit comments

Comments
 (0)