Skip to content

Commit ed8a851

Browse files
Web application.
Model downloaded from the Internet if not existing in the local disk.
1 parent f8e7447 commit ed8a851

35 files changed

+118
-11331
lines changed

common/languages.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
LANGUAGES = ["Assembly", "C", "C++", "C#", "CSS", "Go", "HTML", "Java", "JavaScript", "Kotlin",
2+
"Matlab", "Perl", "PHP", "Python", "R", "Ruby", "Scala", "SQL", "Swift", "TypeScript",
3+
"Unix Shell"]

common/model.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,39 @@
1+
from io import BytesIO
12
from typing import Dict
3+
import requests
4+
from zipfile import ZipFile
5+
26
import numpy as np
37
import tensorflow as tf
8+
from keras import Model
9+
10+
MODEL_PATH = '../common/model/BRNN'
11+
MODEL_URL = 'https://reflection.uniovi.es/bigcode/download/2024/plangrec/BRNN.zip'
12+
13+
14+
def download_and_load_model(model_path: str, model_url: str) -> Model:
15+
import os
16+
import sys
17+
if not os.path.exists(model_path):
18+
print(f"Model not found in '{model_path}'.")
19+
print(f"Downloading model from '{model_url}'. It may take some minutes...")
20+
response = requests.get(model_url)
21+
if response.status_code == 200:
22+
# Extract the zip file content
23+
with ZipFile(BytesIO(response.content)) as zip_file:
24+
# Create the directory for extraction if it doesn't exist
25+
os.makedirs(model_path, exist_ok=True)
26+
# Extract all contents to the specified path
27+
zip_file.extractall(model_path)
28+
print(f"Model successfully downloaded and extracted to '{model_path}'.\n")
29+
else:
30+
print(f"Failed to download file. Status code: {response.status_code}")
31+
sys.exit(-1)
32+
return tf.keras.models.load_model(model_path)
33+
434

535
# Loads the model into memory at startup to go faster upon prediction
6-
model = tf.keras.models.load_model('../model/BRNN')
36+
model = download_and_load_model(MODEL_PATH, MODEL_URL)
737

838

939
def parse_line(line, allow_short_lines: bool):
@@ -46,7 +76,7 @@ def soft_voting(predictions):
4676

4777

4878
def predict(source_code: str) -> Dict[str, float]:
49-
from configuration import LANGUAGES
79+
from languages import LANGUAGES
5080
global model
5181
lines = source_code.split('\n')
5282
result = {}
@@ -60,4 +90,3 @@ def predict(source_code: str) -> Dict[str, float]:
6090
result[LANGUAGES[i]] = round(p*100, 2)
6191
return result
6292

63-

desktop-app/configuration.py renamed to desktop-app/example_code.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,3 @@
1-
from tkinter import Tk, END, E, W
2-
from tkinter.ttk import Button, Combobox, Entry, LabelFrame, Treeview, Style
3-
4-
from model import predict
5-
6-
LANGUAGES = ["Assembly", "C", "C++", "C#", "CSS", "Go", "HTML", "Java", "JavaScript", "Kotlin",
7-
"Matlab", "Perl", "PHP", "Python", "R", "Ruby", "Scala", "SQL", "Swift", "TypeScript",
8-
"Unix Shell"]
9-
101
PROGRAM_EXAMPLES = {
112
"Assembly": "ADD CX, [BX+SI*2+10]",
123
"C": "int *numbers = malloc(sizeof(int));",

desktop-app/main.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
from tkinter import Tk, END, E, W, ttk, StringVar, BooleanVar, NORMAL, DISABLED, Text, Event
22
from tkinter.ttk import Button, Combobox, LabelFrame, Treeview, Style
3-
4-
from configuration import PROGRAM_EXAMPLES, LANGUAGES
5-
from model import predict
3+
from example_code import PROGRAM_EXAMPLES
4+
from languages import LANGUAGES
65

76

87
def sort_number(table: Treeview, column_name: str, ascending: bool) -> None:
@@ -27,6 +26,16 @@ def sort_string(table: Treeview, column_name: str, asc: bool) -> None:
2726
table.heading(column_name, command=lambda: sort_string(table, column_name, not asc))
2827

2928

29+
def import_model() -> None:
30+
"""Imports the model component from a sibling directory"""
31+
import os
32+
import sys
33+
if not os.path.exists("../common/model.py"):
34+
print("You need to include the common directory as a sibling directory of 'web-api'.", file=sys.stderr)
35+
sys.exit(-1)
36+
sys.path.append('../common')
37+
38+
3039
def main() -> None:
3140
"""Creates the main window and runs the application"""
3241
# Main window
@@ -76,9 +85,10 @@ def select_lang(_: Event) -> None:
7685
cmb.bind("<<ComboboxSelected>>", select_lang)
7786

7887
def get_prediction() -> None:
88+
import model
7989
for item in results_tree.get_children():
8090
results_tree.delete(item)
81-
predictions = sorted(predict(text.get("1.0", END)).items(), key=lambda x: x[1])
91+
predictions = sorted(model.predict(text.get("1.0", END)).items(), key=lambda x: x[1])
8292
for data in predictions:
8393
results_tree.insert(parent="", index=0, values=data)
8494

@@ -136,4 +146,5 @@ def text_changed() -> None:
136146

137147

138148
if __name__ == "__main__":
149+
import_model()
139150
main()

web-api/main.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
import sys
22
import logging
3+
4+
from flask_cors import CORS
5+
36
from services.prediction import brnn_predict_blueprint
47

58
from flask import Flask, request, jsonify, Blueprint
@@ -23,8 +26,21 @@ def register_imported_blueprints(app: Flask) -> int:
2326
return count
2427

2528

29+
def import_model() -> None:
30+
"""Imports the model component from a sibling directory"""
31+
import os
32+
import sys
33+
if not os.path.exists("../common/model.py"):
34+
print("You need to include the common directory as a sibling directory of 'web-api'.", file=sys.stderr)
35+
sys.exit(-1)
36+
sys.path.append('../common')
37+
import model # loads the model into memory
38+
39+
2640
def main() -> None:
41+
import_model()
2742
app = Flask(__name__)
43+
CORS(app)
2844
register_imported_blueprints(app)
2945
app.run(debug=True) # debug must be False to trace programs in PyCharm
3046

web-api/services/prediction.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
11

22
from flask import Blueprint, jsonify, request, Response
3-
4-
from model import predict
5-
63
brnn_predict_blueprint = Blueprint(f"brnn_predict", __name__)
74

85

@@ -11,7 +8,8 @@ def get_predict() -> Response:
118
"""One mandatory 'source code' parameter must be passed
129
http://127.0.0.1:5000/BRNN/predict?source_code=code
1310
"""
11+
import model
1412
source_code = request.args.get('source_code')
15-
predictions = predict(source_code)
13+
predictions = model.predict(source_code)
1614
return jsonify(predictions)
1715

0 commit comments

Comments
 (0)