-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
8cc5b97
commit b1bcea4
Showing
35 changed files
with
1,110 additions
and
25 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
"""init users | ||
Revision ID: e8f17ffb729b | ||
Revises: | ||
Create Date: 2024-09-27 17:18:16.719229 | ||
""" | ||
from typing import Sequence, Union | ||
|
||
from alembic import op | ||
import sqlalchemy as sa | ||
|
||
|
||
# revision identifiers, used by Alembic. | ||
revision: str = 'e8f17ffb729b' | ||
down_revision: Union[str, None] = None | ||
branch_labels: Union[str, Sequence[str], None] = None | ||
depends_on: Union[str, Sequence[str], None] = None | ||
|
||
|
||
def upgrade() -> None: | ||
# ### commands auto generated by Alembic - please adjust! ### | ||
op.create_table('users', | ||
sa.Column('username', sa.String(), nullable=False), | ||
sa.Column('password', sa.String(), nullable=False), | ||
sa.Column('gauth', sa.String(), nullable=False), | ||
sa.Column('id', sa.UUID(), nullable=False), | ||
sa.Column('created_at', sa.DateTime(), nullable=False), | ||
sa.Column('updated_at', sa.DateTime(), nullable=False), | ||
sa.PrimaryKeyConstraint('id'), | ||
sa.UniqueConstraint('id'), | ||
sa.UniqueConstraint('username') | ||
) | ||
# ### end Alembic commands ### | ||
|
||
|
||
def downgrade() -> None: | ||
# ### commands auto generated by Alembic - please adjust! ### | ||
op.drop_table('users') | ||
# ### end Alembic commands ### |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
import uvicorn | ||
|
||
from src.core.config import settings | ||
|
||
if __name__ == "__main__": | ||
uvicorn.run( | ||
host="0.0.0.0", | ||
port=settings.port, | ||
app="src.core.fastapi:app", | ||
reload=True if settings.ENVIRONMENT != "production" else False, | ||
workers=1, | ||
log_level="debug", | ||
access_log=True, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,49 @@ | ||
fastapi | ||
uvicorn | ||
sqlalchemy | ||
pydantic | ||
alembic | ||
python-dotenv | ||
PyJWT | ||
passlib[bcrypt] | ||
httpx # Optional for external API calls | ||
alembic==1.13.3 | ||
annotated-types==0.7.0 | ||
anyio==4.6.0 | ||
asyncpg==0.29.0 | ||
bcrypt==4.2.0 | ||
black==24.8.0 | ||
certifi==2024.8.30 | ||
click==8.1.7 | ||
docstring-to-markdown==0.15 | ||
ecdsa==0.19.0 | ||
fastapi==0.115.0 | ||
greenlet==3.1.1 | ||
h11==0.14.0 | ||
httpcore==1.0.5 | ||
httpx==0.27.2 | ||
idna==3.10 | ||
jedi==0.19.1 | ||
Mako==1.3.5 | ||
MarkupSafe==2.1.5 | ||
mypy-extensions==1.0.0 | ||
packaging==24.1 | ||
parso==0.8.4 | ||
passlib==1.7.4 | ||
pathspec==0.12.1 | ||
platformdirs==4.3.6 | ||
pluggy==1.5.0 | ||
pyasn1==0.6.1 | ||
pydantic==2.9.2 | ||
pydantic-settings==2.5.2 | ||
pydantic_core==2.23.4 | ||
PyJWT==2.9.0 | ||
pyotp==2.9.0 | ||
pypng==0.20220715.0 | ||
python-dotenv==1.0.1 | ||
python-jose==3.3.0 | ||
python-lsp-black==2.0.0 | ||
python-lsp-jsonrpc==1.1.2 | ||
python-lsp-server==1.12.0 | ||
python-multipart==0.0.10 | ||
qrcode==7.4.2 | ||
redis==5.0.8 | ||
rsa==4.9 | ||
six==1.16.0 | ||
sniffio==1.3.1 | ||
SQLAlchemy==2.0.35 | ||
starlette==0.38.6 | ||
typing_extensions==4.12.2 | ||
ujson==5.10.0 | ||
uvicorn==0.30.6 |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,169 @@ | ||
import asyncio | ||
import base64 | ||
import io | ||
|
||
import qrcode | ||
from pyotp import random_base32, totp | ||
|
||
from src.core.config import settings | ||
|
||
from src.core.database import DBManager | ||
from src.core.exceptions import ( | ||
BadRequestException, | ||
CustomException, | ||
UnauthorizedException, | ||
) | ||
from src.core.redis.client import RedisManager | ||
from src.repository.jwt import JWTHandler | ||
from src.repository.password import PasswordHandler | ||
from src.crud.users import UserCRUD | ||
from src.schema.auth import Token | ||
from src.schema.user import UserOut, UserOutRegister | ||
|
||
|
||
class AuthController: | ||
user_crud = UserCRUD() | ||
password_handler = PasswordHandler | ||
jwt_handler = JWTHandler | ||
|
||
def __init__( | ||
self, | ||
db_session: DBManager, | ||
redis_session: RedisManager | None = None, | ||
user_crud: UserCRUD | None = None, | ||
): | ||
self.user_crud = user_crud or self.user_crud | ||
self.db_session = db_session | ||
self.redis_session = redis_session | ||
|
||
async def register(self, password: str, username: str) -> UserOutRegister: | ||
user = await self.user_crud.get_by_username( | ||
self.db_session, | ||
username=username, | ||
) | ||
|
||
if user: | ||
raise BadRequestException("User already exists with this username") | ||
|
||
password = self.password_handler.hash(password) | ||
user = await self.user_crud.create( | ||
db_session=self.db_session, | ||
username=username, | ||
password=password, | ||
gauth=str(random_base32()), | ||
) | ||
|
||
assert user is not None | ||
provisioning_uri = totp.TOTP(user.gauth).provisioning_uri() | ||
buffered = io.BytesIO() | ||
qrcode.make(provisioning_uri).save(buffered) | ||
return UserOutRegister( | ||
username=user.username, | ||
updated_at=user.updated_at, | ||
created_at=user.created_at, | ||
gauth=user.gauth, | ||
qr_img=base64.b64encode(buffered.getvalue()).decode(), | ||
) | ||
|
||
async def login(self, username: str, password: str, existing_session_id: str)-> Token: | ||
if not self.redis_session: | ||
raise CustomException("Database connection is not initialized") | ||
|
||
user = await self.user_crud.get_by_username(self.db_session, username=username) | ||
if (not user) or (not self.password_handler.verify(user.password, password)): | ||
raise BadRequestException("Invalid credentials") | ||
|
||
refresh_token = self.jwt_handler.encode_refresh_token( | ||
payload={"sub": "refresh_token", "verify": str(user.id)} | ||
) | ||
|
||
await self.redis_session.set( | ||
name=refresh_token, value=user.id, ex=self.jwt_handler.refresh_token_expire | ||
) | ||
|
||
token = Token( | ||
access_token=None, | ||
refresh_token=refresh_token, | ||
) | ||
# if user was verified | ||
session_id_redis = await self.redis_session.get(existing_session_id) | ||
if session_id_redis == str(user.id): | ||
token.access_token = self.jwt_handler.encode(payload={"user_id": str(user.id)}) | ||
return token | ||
|
||
return token | ||
|
||
async def logout(self, refresh_token) -> None: | ||
if not refresh_token: | ||
raise BadRequestException | ||
if not self.redis_session: | ||
raise CustomException("Database connection is not initialized") | ||
await self.redis_session.delete(refresh_token) | ||
return None | ||
|
||
async def me(self, user_id) -> UserOut: | ||
user = await self.user_crud.get_by_id(self.db_session, user_id) | ||
if not user: | ||
raise BadRequestException("Invalid credentials") | ||
return UserOut( | ||
username=user.username, | ||
updated_at=user.updated_at, | ||
created_at=user.created_at, | ||
) | ||
|
||
async def verify( | ||
self, | ||
refresh_token: str, | ||
session_id: str, | ||
code: str, | ||
settings=settings, | ||
) -> None: | ||
if not self.redis_session: | ||
raise CustomException("Database connection is not initialized") | ||
session_id_redis, user_id = await asyncio.gather( | ||
self.redis_session.get(session_id), self.redis_session.get(refresh_token) | ||
) | ||
if not user_id or len(str(user_id)) < 5: | ||
raise UnauthorizedException("Invalid Refresh Token") | ||
elif session_id_redis != user_id: | ||
user = await self.user_crud.get_by_id(self.db_session, user_id=user_id) | ||
assert user is not None | ||
if not totp.TOTP(user.gauth).verify(code): | ||
raise BadRequestException("Invalid Code") | ||
await self.redis_session.set( | ||
session_id, value=user_id, ex=(settings.SESSION_EXPIRE_MINUTES) * 60 | ||
) | ||
else: | ||
raise BadRequestException("Already Verified") | ||
return None | ||
|
||
async def refresh_token(self, old_refresh_token: str, session_id: str) -> Token | str: | ||
if not self.redis_session: | ||
raise CustomException("Database connection is not initialized") | ||
|
||
user_id, ttl, session_id = await asyncio.gather( | ||
self.redis_session.get(old_refresh_token), | ||
self.redis_session.ttl(old_refresh_token), | ||
self.redis_session.get(session_id), | ||
) | ||
if not user_id or len(str(user_id)) < 5: | ||
return UnauthorizedException("Invalid Refresh Token") | ||
|
||
if session_id != user_id: | ||
raise UnauthorizedException( | ||
"Please verify using 2 step authentication first" | ||
) | ||
|
||
access_token = self.jwt_handler.encode(payload={"user_id": str(user_id)}) | ||
refresh_token = self.jwt_handler.encode_refresh_token( | ||
payload={"sub": "refresh_token", "verify": str(user_id)} | ||
) | ||
|
||
await asyncio.gather( | ||
self.redis_session.set(refresh_token, user_id, ex=ttl), | ||
self.redis_session.delete(old_refresh_token), | ||
) | ||
return Token( | ||
access_token=access_token, | ||
refresh_token=refresh_token, | ||
) |
Oops, something went wrong.