Skip to content

Commit

Permalink
Feat: check project token (#624)
Browse files Browse the repository at this point in the history
* feat(api): project token management

* tests + alembic revision

Signed-off-by: inimaz <93inigo93@gmail.com>

* temp

Signed-off-by: inimaz <93inigo93@gmail.com>

* feat(api): check project-token when calling experiments endpoints

* fix sql return + small refactor of project_has_access in project token service

* feat: create emission and create run need project token in header

* add unit test for when there is no token in header

Signed-off-by: inimaz <93inigo93@gmail.com>

* fix: x-api-token instead of project-token

---------

Signed-off-by: inimaz <93inigo93@gmail.com>
  • Loading branch information
inimaz authored Aug 2, 2024
1 parent 049b1bf commit 6d1a466
Show file tree
Hide file tree
Showing 9 changed files with 457 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@

from carbonserver.api.domain.project_tokens import ProjectTokens
from carbonserver.api.infra.api_key_service import generate_api_key
from carbonserver.api.infra.database.sql_models import Emission as SqlModelEmission
from carbonserver.api.infra.database.sql_models import Experiment as SqlModelExperiment
from carbonserver.api.infra.database.sql_models import (
ProjectToken as SqlModelProjectToken,
)
from carbonserver.api.infra.database.sql_models import Run as SqlModelRun
from carbonserver.api.schemas import ProjectToken, ProjectTokenCreate


Expand Down Expand Up @@ -58,6 +61,73 @@ def list_project_tokens(self, project_id: str):
for project_token in db_project_tokens
]

def get_project_token_by_project_id_and_token(self, project_id: str, token: str):
with self.session_factory() as session:
db_project_token = (
session.query(SqlModelProjectToken)
.filter(
SqlModelProjectToken.project_id == project_id
and SqlModelProjectToken.token == token
)
.first()
)
return (
self.map_sql_to_schema(db_project_token) if db_project_token else None
)

def get_project_token_by_experiment_id_and_token(
self, experiment_id: str, token: str
):
with self.session_factory() as session:
db_project_token = (
session.query(SqlModelProjectToken)
.filter(SqlModelProjectToken.token == token)
.join(
SqlModelExperiment,
SqlModelProjectToken.project_id == SqlModelExperiment.project_id,
)
.filter(SqlModelExperiment.id == experiment_id)
.first()
)
return (
self.map_sql_to_schema(db_project_token) if db_project_token else None
)

def get_project_token_by_run_id_and_token(self, run_id: str, token: str):
with self.session_factory() as session:
db_project_token = (
session.query(SqlModelProjectToken)
.filter(SqlModelProjectToken.token == token)
.join(
SqlModelExperiment,
SqlModelProjectToken.project_id == SqlModelExperiment.project_id,
)
.join(SqlModelRun, SqlModelExperiment.id == SqlModelRun.experiment_id)
.filter(SqlModelRun.id == run_id)
.first()
)
return (
self.map_sql_to_schema(db_project_token) if db_project_token else None
)

def get_project_token_by_emission_id_and_token(self, emission_id: str, token: str):
with self.session_factory() as session:
db_project_token = (
session.query(SqlModelProjectToken)
.filter(SqlModelProjectToken.token == token)
.join(
SqlModelExperiment,
SqlModelProjectToken.project_id == SqlModelExperiment.project_id,
)
.join(SqlModelRun, SqlModelExperiment.id == SqlModelRun.experiment_id)
.join(SqlModelEmission, SqlModelRun.id == SqlModelEmission.run_id)
.filter(SqlModelEmission.id == emission_id)
.first()
)
return (
self.map_sql_to_schema(db_project_token) if db_project_token else None
)

@staticmethod
def map_sql_to_schema(project_token: SqlModelProjectToken) -> ProjectToken:
"""Convert a models.ProjectToken to a schemas.ProjectToken
Expand Down
14 changes: 12 additions & 2 deletions carbonserver/carbonserver/api/routers/emissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,16 @@

from container import ServerContainer
from dependency_injector.wiring import Provide, inject
from fastapi import APIRouter, Depends, Query
from fastapi import APIRouter, Depends, Header, Query
from fastapi_pagination import Page, paginate
from fastapi_pagination.default import Page as BasePage
from fastapi_pagination.default import Params as BaseParams
from starlette import status

from carbonserver.api.dependencies import get_token_header
from carbonserver.api.schemas import Emission, EmissionCreate
from carbonserver.api.schemas import AccessLevel, Emission, EmissionCreate
from carbonserver.api.services.emissions_service import EmissionService
from carbonserver.api.services.project_token_service import ProjectTokenService

# T, Params and Page are needed to override default pagination of get_emissions_from_run
T = TypeVar("T")
Expand Down Expand Up @@ -45,7 +46,16 @@ def add_emission(
emission_service: EmissionService = Depends(
Provide[ServerContainer.emission_service]
),
project_token_service: ProjectTokenService = Depends(
Provide[ServerContainer.project_token_service]
),
x_api_token: str = Header(None), # Capture the x-api-token from the headers
) -> UUID:
project_token_service.project_token_has_access(
AccessLevel.WRITE.value,
run_id=emission.run_id,
project_token=x_api_token,
)
return emission_service.add_emission(emission)


Expand Down
14 changes: 12 additions & 2 deletions carbonserver/carbonserver/api/routers/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
import dateutil.relativedelta
from container import ServerContainer
from dependency_injector.wiring import Provide, inject
from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, Header
from starlette import status

from carbonserver.api.dependencies import get_token_header
from carbonserver.api.errors import EmptyResultException
from carbonserver.api.schemas import Empty, Run, RunCreate, RunReport
from carbonserver.api.schemas import AccessLevel, Empty, Run, RunCreate, RunReport
from carbonserver.api.services.project_token_service import ProjectTokenService
from carbonserver.api.services.run_service import RunService
from carbonserver.api.usecases.run.experiment_sum_by_run import (
ExperimentSumsByRunUsecase,
Expand All @@ -34,7 +35,16 @@
def add_run(
run: RunCreate,
run_service: RunService = Depends(Provide[ServerContainer.run_service]),
project_token_service: ProjectTokenService = Depends(
Provide[ServerContainer.project_token_service]
),
x_api_token: str = Header(None), # Capture the x-api-token from the headers
) -> Run:
project_token_service.project_token_has_access(
AccessLevel.WRITE.value,
experiment_id=run.experiment_id,
project_token=x_api_token,
)
return run_service.add_run(run)


Expand Down
102 changes: 101 additions & 1 deletion carbonserver/carbonserver/api/services/project_token_service.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from fastapi import HTTPException

from carbonserver.api.infra.repositories.repository_projects_tokens import (
SqlAlchemyRepository as ProjectTokensSqlRepository,
)
from carbonserver.api.schemas import ProjectTokenCreate
from carbonserver.api.schemas import AccessLevel, ProjectToken, ProjectTokenCreate


class ProjectTokenService:
Expand All @@ -16,3 +18,101 @@ def delete_project_token(self, project_id, token_id):

def list_tokens_from_project(self, project_id):
return self._repository.list_project_tokens(project_id)

def project_token_has_access(
self,
desired_access: int,
project_token: str,
project_id=None,
experiment_id=None,
run_id=None,
emission_id=None,
):
"""
Check if the project token has access to the project_id, experiment_id, run_id or emission_id with the desired_access.
"""
if not project_token:
raise HTTPException(
status_code=403,
detail="Not allowed to perform this action. Missing project token",
)
if project_id:
self._project_token_has_access_to_project_id(
desired_access, project_id, project_token
)
elif experiment_id:
self._project_token_has_access_to_experiment_id(
desired_access, experiment_id, project_token
)
elif run_id:
self._project_token_has_access_to_run_id(
desired_access, run_id, project_token
)
elif emission_id:
self._project_token_has_access_to_emission_id(
desired_access, emission_id, project_token
)
else:
raise HTTPException(
status_code=400,
detail="Not allowed to perform this action. Missing project_id, experiment_id, run_id or emission_id",
)

def _project_token_has_access_to_project_id(
self, desired_access: int, project_id, project_token: str
):
# Verify that the project token is valid and has access to do the action
full_project_token = self._repository.get_project_token_by_project_id_and_token(
project_id, project_token
)
self._has_access(desired_access, full_project_token)

def _project_token_has_access_to_experiment_id(
self, desired_access: int, experiment_id, project_token: str
):
"""
Check if the project token has access to the experiment_id with the desired_access.
Example: desired_access = AccessLevel.READ.value but the project_token has AccessLevel.WRITE.value ==> has_access = False because WRITE access is not READ
Example2: desired_access = AccessLevel.WRITE.value and the project_token has AccessLevel.READ_WRITE.value ==> has_access = TRUE because READ_WRITE access contains WRITE access
"""
# Verify that the project token is valid and has access to do the action

full_project_token = (
self._repository.get_project_token_by_experiment_id_and_token(
experiment_id, project_token
)
)
self._has_access(desired_access, full_project_token)

def _project_token_has_access_to_run_id(
self, desired_access: int, run_id, project_token: str
):
# Verify that the project token is valid and has access to do the action
full_project_token = self._repository.get_project_token_by_run_id_and_token(
run_id, project_token
)
self._has_access(desired_access, full_project_token)

def _project_token_has_access_to_emission_id(
self, desired_access: int, emission_id, project_token: str
):
# Verify that the project token is valid and has access to do the action
full_project_token = (
self._repository.get_project_token_by_emission_id_and_token(
emission_id, project_token
)
)
self._has_access(desired_access, full_project_token)

def _has_access(self, desired_access: int, full_project_token: ProjectToken | None):
if full_project_token:
has_access = (
desired_access == full_project_token.access
or full_project_token.access == AccessLevel.READ_WRITE.value
)
else:
has_access = False
if not has_access:
raise HTTPException(
status_code=403, detail="Not allowed to perform this action"
)
38 changes: 36 additions & 2 deletions carbonserver/tests/api/integration/test_api_black_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

experiment_id = project_id = user_id = api_key = org_id = None
org_name = org_description = org_new_id = None
project_token_id = PROJECT_TOKEN = None
emission_id = None
USER_PASSWORD = "Secret1!îstring"
USER_EMAIL = "user@integration.test"
Expand Down Expand Up @@ -243,6 +244,23 @@ def test_api20_experiment_list():
assert is_key_value_exist(r.json(), "id", experiment_id)


def test_api21_create_api_project_token():
# This project token is needed to create emissions/runs
global PROJECT_TOKEN
global project_token_id
assert project_id is not None
payload = {
"name": "Project token for test_api_black_box",
"access": 2,
}
r = requests.post(
url=URL + f"/projects/{project_id}/api-tokens", json=payload, timeout=2
)
tc.assertEqual(r.status_code, 201)
PROJECT_TOKEN = r.json()["token"]
project_token_id = r.json()["id"]


def send_run(experiment_id: str):
assert experiment_id is not None
payload = {
Expand All @@ -262,7 +280,12 @@ def send_run(experiment_id: str):
"ram_total_size": 16948.22,
"tracking_mode": "Machine",
}
r = requests.post(url=URL + "/runs/", json=payload, timeout=2)
r = requests.post(
url=URL + "/runs/",
json=payload,
timeout=2,
headers={"x-api-token": PROJECT_TOKEN},
)
tc.assertEqual(r.status_code, 201)
return r.json()

Expand Down Expand Up @@ -330,7 +353,12 @@ def add_emission(run_id: str):
"ram_energy": default_emission["ram_energy"],
"energy_consumed": default_emission["energy_consumed"],
}
r = requests.post(url=URL + "/emissions/", json=payload, timeout=2)
r = requests.post(
url=URL + "/emissions/",
json=payload,
timeout=2,
headers={"x-api-token": PROJECT_TOKEN},
)
tc.assertEqual(r.status_code, 201)
return r.json()

Expand Down Expand Up @@ -524,3 +552,9 @@ def test_api33_project_read_last_run():
assert len(r.json()) > 0
assert r.json()["id"] == run_id_2
assert r.json()["experiment_id"] == experiment_id


def test_api34_project_api_token_delete():
url = f"{URL}/projects/{project_id}/api-tokens/{project_token_id}"
r = requests.delete(url, timeout=2)
tc.assertEqual(r.status_code, 204)
Loading

0 comments on commit 6d1a466

Please sign in to comment.