Skip to content

Commit

Permalink
implement authentication with JWT
Browse files Browse the repository at this point in the history
  • Loading branch information
sepehrsh79 committed Sep 27, 2024
1 parent 8cc5b97 commit b1bcea4
Show file tree
Hide file tree
Showing 35 changed files with 1,110 additions and 25 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,6 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.idea/

tmux.sh
53 changes: 38 additions & 15 deletions alembic/env.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
import asyncio
import json
from logging.config import fileConfig

from sqlalchemy import engine_from_config
from sqlalchemy import pool
from sqlalchemy import MetaData, pool
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import async_engine_from_config

from alembic import context
from src.core.config import settings
from src.core.database.base import SQLBase
from src.core.database.parser import JSONDecoder, JSONEncoder
from src.models import * # noqa

# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
Expand All @@ -14,16 +21,19 @@
if config.config_file_name is not None:
fileConfig(config.config_file_name)


config.set_main_option("sqlalchemy.url", settings.DATABASE_URL)
# add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
target_metadata = None
target_metadata = SQLBase.metadata

# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
rollback = int(context.get_x_argument(as_dictionary=True).get("rollback", "0"))


def run_migrations_offline() -> None:
Expand All @@ -41,7 +51,7 @@ def run_migrations_offline() -> None:
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url,
target_metadata=target_metadata,
target_metadata=target_metadata, # type: ignore
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
Expand All @@ -50,26 +60,39 @@ def run_migrations_offline() -> None:
context.run_migrations()


def run_migrations_online() -> None:
"""Run migrations in 'online' mode.
def do_run_migrations(connection: Connection) -> None:
context.configure(connection=connection, target_metadata=target_metadata) # type: ignore

with context.begin_transaction() as _: # noqa
context.run_migrations()
live_meta = MetaData()
live_meta.reflect(connection)
if rollback:
connection.rollback()

In this scenario we need to create an Engine
and associate a connection with the context.

async def run_async_migrations() -> None:
"""In this scenario we need to create an Engine
and associate a connection with the context.
"""
connectable = engine_from_config(
connectable = async_engine_from_config(
config.get_section(config.config_ini_section, {}),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
json_serializer=lambda x: json.dumps(x, cls=JSONEncoder),
json_deserializer=lambda x: json.loads(x, cls=JSONDecoder),
)

with connectable.connect() as connection:
context.configure(
connection=connection, target_metadata=target_metadata
)
async with connectable.connect() as connection:
await connection.run_sync(do_run_migrations)

await connectable.dispose()


def run_migrations_online() -> None:
"""Run migrations in 'online' mode."""

with context.begin_transaction():
context.run_migrations()
asyncio.run(run_async_migrations())


if context.is_offline_mode():
Expand Down
40 changes: 40 additions & 0 deletions alembic/versions/e8f17ffb729b_init_users.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""init users
Revision ID: e8f17ffb729b
Revises:
Create Date: 2024-09-27 17:18:16.719229
"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = 'e8f17ffb729b'
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('users',
sa.Column('username', sa.String(), nullable=False),
sa.Column('password', sa.String(), nullable=False),
sa.Column('gauth', sa.String(), nullable=False),
sa.Column('id', sa.UUID(), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('id'),
sa.UniqueConstraint('username')
)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('users')
# ### end Alembic commands ###
14 changes: 14 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import uvicorn

from src.core.config import settings

if __name__ == "__main__":
uvicorn.run(
host="0.0.0.0",
port=settings.port,
app="src.core.fastapi:app",
reload=True if settings.ENVIRONMENT != "production" else False,
workers=1,
log_level="debug",
access_log=True,
)
58 changes: 49 additions & 9 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,49 @@
fastapi
uvicorn
sqlalchemy
pydantic
alembic
python-dotenv
PyJWT
passlib[bcrypt]
httpx # Optional for external API calls
alembic==1.13.3
annotated-types==0.7.0
anyio==4.6.0
asyncpg==0.29.0
bcrypt==4.2.0
black==24.8.0
certifi==2024.8.30
click==8.1.7
docstring-to-markdown==0.15
ecdsa==0.19.0
fastapi==0.115.0
greenlet==3.1.1
h11==0.14.0
httpcore==1.0.5
httpx==0.27.2
idna==3.10
jedi==0.19.1
Mako==1.3.5
MarkupSafe==2.1.5
mypy-extensions==1.0.0
packaging==24.1
parso==0.8.4
passlib==1.7.4
pathspec==0.12.1
platformdirs==4.3.6
pluggy==1.5.0
pyasn1==0.6.1
pydantic==2.9.2
pydantic-settings==2.5.2
pydantic_core==2.23.4
PyJWT==2.9.0
pyotp==2.9.0
pypng==0.20220715.0
python-dotenv==1.0.1
python-jose==3.3.0
python-lsp-black==2.0.0
python-lsp-jsonrpc==1.1.2
python-lsp-server==1.12.0
python-multipart==0.0.10
qrcode==7.4.2
redis==5.0.8
rsa==4.9
six==1.16.0
sniffio==1.3.1
SQLAlchemy==2.0.35
starlette==0.38.6
typing_extensions==4.12.2
ujson==5.10.0
uvicorn==0.30.6
File renamed without changes.
169 changes: 169 additions & 0 deletions src/controllers/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
import asyncio
import base64
import io

import qrcode
from pyotp import random_base32, totp

from src.core.config import settings

from src.core.database import DBManager
from src.core.exceptions import (
BadRequestException,
CustomException,
UnauthorizedException,
)
from src.core.redis.client import RedisManager
from src.repository.jwt import JWTHandler
from src.repository.password import PasswordHandler
from src.crud.users import UserCRUD
from src.schema.auth import Token
from src.schema.user import UserOut, UserOutRegister


class AuthController:
user_crud = UserCRUD()
password_handler = PasswordHandler
jwt_handler = JWTHandler

def __init__(
self,
db_session: DBManager,
redis_session: RedisManager | None = None,
user_crud: UserCRUD | None = None,
):
self.user_crud = user_crud or self.user_crud
self.db_session = db_session
self.redis_session = redis_session

async def register(self, password: str, username: str) -> UserOutRegister:
user = await self.user_crud.get_by_username(
self.db_session,
username=username,
)

if user:
raise BadRequestException("User already exists with this username")

password = self.password_handler.hash(password)
user = await self.user_crud.create(
db_session=self.db_session,
username=username,
password=password,
gauth=str(random_base32()),
)

assert user is not None
provisioning_uri = totp.TOTP(user.gauth).provisioning_uri()
buffered = io.BytesIO()
qrcode.make(provisioning_uri).save(buffered)
return UserOutRegister(
username=user.username,
updated_at=user.updated_at,
created_at=user.created_at,
gauth=user.gauth,
qr_img=base64.b64encode(buffered.getvalue()).decode(),
)

async def login(self, username: str, password: str, existing_session_id: str)-> Token:
if not self.redis_session:
raise CustomException("Database connection is not initialized")

user = await self.user_crud.get_by_username(self.db_session, username=username)
if (not user) or (not self.password_handler.verify(user.password, password)):
raise BadRequestException("Invalid credentials")

refresh_token = self.jwt_handler.encode_refresh_token(
payload={"sub": "refresh_token", "verify": str(user.id)}
)

await self.redis_session.set(
name=refresh_token, value=user.id, ex=self.jwt_handler.refresh_token_expire
)

token = Token(
access_token=None,
refresh_token=refresh_token,
)
# if user was verified
session_id_redis = await self.redis_session.get(existing_session_id)
if session_id_redis == str(user.id):
token.access_token = self.jwt_handler.encode(payload={"user_id": str(user.id)})
return token

return token

async def logout(self, refresh_token) -> None:
if not refresh_token:
raise BadRequestException
if not self.redis_session:
raise CustomException("Database connection is not initialized")
await self.redis_session.delete(refresh_token)
return None

async def me(self, user_id) -> UserOut:
user = await self.user_crud.get_by_id(self.db_session, user_id)
if not user:
raise BadRequestException("Invalid credentials")
return UserOut(
username=user.username,
updated_at=user.updated_at,
created_at=user.created_at,
)

async def verify(
self,
refresh_token: str,
session_id: str,
code: str,
settings=settings,
) -> None:
if not self.redis_session:
raise CustomException("Database connection is not initialized")
session_id_redis, user_id = await asyncio.gather(
self.redis_session.get(session_id), self.redis_session.get(refresh_token)
)
if not user_id or len(str(user_id)) < 5:
raise UnauthorizedException("Invalid Refresh Token")
elif session_id_redis != user_id:
user = await self.user_crud.get_by_id(self.db_session, user_id=user_id)
assert user is not None
if not totp.TOTP(user.gauth).verify(code):
raise BadRequestException("Invalid Code")
await self.redis_session.set(
session_id, value=user_id, ex=(settings.SESSION_EXPIRE_MINUTES) * 60
)
else:
raise BadRequestException("Already Verified")
return None

async def refresh_token(self, old_refresh_token: str, session_id: str) -> Token | str:
if not self.redis_session:
raise CustomException("Database connection is not initialized")

user_id, ttl, session_id = await asyncio.gather(
self.redis_session.get(old_refresh_token),
self.redis_session.ttl(old_refresh_token),
self.redis_session.get(session_id),
)
if not user_id or len(str(user_id)) < 5:
return UnauthorizedException("Invalid Refresh Token")

if session_id != user_id:
raise UnauthorizedException(
"Please verify using 2 step authentication first"
)

access_token = self.jwt_handler.encode(payload={"user_id": str(user_id)})
refresh_token = self.jwt_handler.encode_refresh_token(
payload={"sub": "refresh_token", "verify": str(user_id)}
)

await asyncio.gather(
self.redis_session.set(refresh_token, user_id, ex=ttl),
self.redis_session.delete(old_refresh_token),
)
return Token(
access_token=access_token,
refresh_token=refresh_token,
)
Loading

0 comments on commit b1bcea4

Please sign in to comment.