-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapp.py
122 lines (90 loc) · 3.07 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
from gevent import monkey
monkey.patch_all()
from gevent.event import Event
from gevent.timeout import Timeout
from sqlalchemy import create_engine
from sqlalchemy.pool import StaticPool
from sqlalchemy.orm import scoped_session, sessionmaker
from gevent.event import Event
from gevent.timeout import Timeout
from flask import Flask, request, send_from_directory
from flask_cors import CORS
from flask_socketio import SocketIO
from database.database import Token, Base
from utils import get_config
# Flask application and its Socket.IO wrapper (gevent async mode).
app = Flask(__name__)
CORS(app) # Add CORS support to allow cross-origin requests
# NOTE(review): this is a placeholder secret committed in source. Presumably it
# is meant to be replaced per deployment; confirm and load it from config.
app.config['SECRET_KEY'] = 'your-secret-key'
# Socket.IO connections are accepted from any origin ("*").
socketio = SocketIO(app, cors_allowed_origins="*", async_mode='gevent')
# Runtime settings are read from the MAIN section of the project configuration.
config = get_config()
HOST = config["MAIN"]["HOST"]
PORT = int(config["MAIN"]["PORT"])
DB_PATH = config["MAIN"]["DB_PATH"]
# Upper bound on how long send_message waits for the client's reply.
TIMEOUT = int(config["MAIN"]["TIMEOUT"])
# Set the directory to serve files from
FILES_PATH = 'resources'
# NOTE(review): check_same_thread is a sqlite3 connect argument, so DB_PATH is
# presumably a SQLite URL; confirm against the configuration file.
engine = create_engine(DB_PATH,
                       connect_args={"check_same_thread": False},
                       poolclass=StaticPool)
# Create any tables declared on Base that do not exist yet.
Base.metadata.create_all(engine)
session_factory = sessionmaker(bind=engine)
Session = scoped_session(session_factory)
# Latest message received from each client, keyed by Socket.IO session id.
received_messages = {}
# Event that send_message waits on for each client, keyed by session id.
events_by_sid = {}
# The server runs locally and serves a single client at a time
sid = None
def check_auth(token=None):
    """Return whether *token* matches a row in the Token table.

    Args:
        token: Token string supplied by the caller, or None when absent.

    Returns:
        False when no token was supplied or no Token row has that value,
        True when a matching row exists.
    """
    if token is None:
        return False
    # Validate token from the database
    session = Session()
    try:
        with session.begin():
            token_record = (
                session.query(Token).filter(Token.token == token).first()
            )
        # Only the presence of a row matters; no attribute is read after commit.
        return token_record is not None
    finally:
        # Close and discard the scoped session so it does not outlive this
        # call, including when the query raises.
        Session.remove()
@app.route('/resources/<path:path>')
def serve_files(path):
    """Serve the file at *path* from the FILES_PATH directory."""
    response = send_from_directory(FILES_PATH, path)
    return response
@app.route('/gpt/send_message', methods=['POST'])
def send_message():
    """Forward a message to the connected Socket.IO client and return its reply.

    Reads the ``message`` and ``token`` form fields of the POST request.

    Returns:
        ``('Unauthorized', 401)`` when the token fails ``check_auth``;
        ``('No client connected', 503)`` when no client is registered;
        the client's reply when one arrives within ``TIMEOUT``;
        otherwise the timeout error string (sent with the default status,
        as before).
    """
    token = request.form.get('token')
    if not check_auth(token):
        return 'Unauthorized', 401

    # Snapshot the global once: handle_disconnect may reset it while we wait.
    client_sid = sid
    if client_sid is None:
        # This used to be an assert, which is stripped under -O and otherwise
        # surfaces as an unhandled AssertionError.
        return 'No client connected', 503

    message = request.form.get('message')

    # Register the event before emitting, so a reply that arrives immediately
    # still finds a waiter to wake up.
    event = Event()
    events_by_sid[client_sid] = event
    # Discard any late reply left over from an earlier request.
    received_messages.pop(client_sid, None)
    try:
        # Address the registered client only rather than every connection.
        socketio.emit('message', message, to=client_sid)
        # wait() returns False when the timeout elapses before set() is called.
        if not event.wait(timeout=TIMEOUT):
            return 'Error: Timeout occurred while waiting for a response'
        return received_messages.pop(client_sid, 'Error: No message received')
    finally:
        # Remove our event unless a newer request has already replaced it.
        if events_by_sid.get(client_sid) is event:
            del events_by_sid[client_sid]
@socketio.on('connect')
def handle_connect():
    """Authenticate a connecting Socket.IO client and register it as current.

    The token is read from the ``token`` query-string argument of the
    connection request.

    Returns:
        False to reject the connection when the token fails ``check_auth``;
        None (connection accepted) otherwise.
    """
    global sid
    token = request.args.get('token')
    if not check_auth(token):
        # Socket.IO refuses a connection only when the handler returns False
        # (or raises ConnectionRefusedError). The former ('Unauthorized', 401)
        # tuple was truthy, so unauthorized clients stayed connected.
        return False
    sid = request.sid
    print('Client connected:', request.sid)
@socketio.on('disconnect')
def handle_disconnect():
    """Forget a client that has disconnected.

    The global ``sid`` is cleared only when the departing client is the one
    currently registered, so another client going away does not unregister
    the active one. State stored for the departing session id is dropped.
    """
    global sid
    departed_sid = request.sid
    if sid == departed_sid:
        sid = None
    # Drop anything still stored for this client so the dicts do not grow.
    events_by_sid.pop(departed_sid, None)
    received_messages.pop(departed_sid, None)
    print('Client disconnected:', departed_sid)
@socketio.on('message')
def handle_message(txt):
    """Hand a client's message to the request that is waiting for it.

    Args:
        txt: Payload of the ``message`` event sent by the client.

    The payload is stored under the sender's session id and the matching
    event is set. A message from a session nobody is waiting on is logged
    and dropped, so ``received_messages`` does not accumulate entries that
    will never be read.
    """
    # Format rather than concatenate: '+' raised TypeError for a non-string
    # payload before the message could be stored.
    print(f'received message: {txt}')
    event = events_by_sid.get(request.sid)
    if event is None:
        return
    received_messages[request.sid] = txt
    event.set()
if __name__ == '__main__':
    # Announce the bind address, then hand control to the Socket.IO server.
    print(f"Starting server at address: {HOST}:{PORT}")
    socketio.run(app, port=PORT, host=HOST)