Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion backend/api/context.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
from datetime import datetime, timezone
from typing import Any

import strawberry
Expand Down Expand Up @@ -107,6 +108,7 @@ async def get_context(
lastname=lastname,
title="User",
avatar_url=picture,
last_online=datetime.now(timezone.utc),
)
session.add(new_user)
await session.commit()
Expand Down Expand Up @@ -143,8 +145,9 @@ async def get_context(
await session.refresh(db_user)

if db_user:
db_user.last_online = datetime.now(timezone.utc)
session.add(db_user)
try:

await _update_user_root_locations(
session,
db_user,
Expand Down
3 changes: 3 additions & 0 deletions backend/api/decorators/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from api.decorators.pagination import apply_pagination, paginated_query

__all__ = ["apply_pagination", "paginated_query"]
40 changes: 40 additions & 0 deletions backend/api/decorators/pagination.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from functools import wraps
from typing import Any, Callable, TypeVar

from sqlalchemy import Select

T = TypeVar("T")


def apply_pagination(
    query: Select[Any],
    limit: int | None = None,
    offset: int | None = None,
) -> Select[Any]:
    """Return *query* with OFFSET/LIMIT applied for each value that is given.

    A value of ``None`` means "no clause"; an explicit ``0`` is applied
    as-is. When both are ``None`` the statement is returned unchanged.
    """
    paginated = query
    if offset is not None:
        paginated = paginated.offset(offset)
    if limit is not None:
        paginated = paginated.limit(limit)
    return paginated


def paginated_query(
    limit_param: str = "limit",
    offset_param: str = "offset",
):
    """Decorator factory that paginates the ``Select`` an async resolver returns.

    The wrapped coroutine's *keyword* arguments named ``limit_param`` and
    ``offset_param`` are forwarded to :func:`apply_pagination` whenever the
    coroutine returns a SQLAlchemy ``Select``; any other return value is
    passed through untouched.

    NOTE(review): only keyword arguments are inspected — limit/offset
    supplied positionally are silently ignored.
    """

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        @wraps(func)
        async def wrapper(*args: Any, **kwargs: Any) -> Any:
            outcome = await func(*args, **kwargs)
            if not isinstance(outcome, Select):
                return outcome
            return apply_pagination(
                outcome,
                limit=kwargs.get(limit_param),
                offset=kwargs.get(offset_param),
            )

        return wrapper

    return decorator
2 changes: 2 additions & 0 deletions backend/api/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ class CreatePatientInput:
)
properties: list[PropertyValueInput] | None = None
state: PatientState | None = None
description: str | None = None


@strawberry.input
Expand All @@ -101,6 +102,7 @@ class UpdatePatientInput:
team_ids: list[strawberry.ID] | None = strawberry.UNSET
properties: list[PropertyValueInput] | None = None
checksum: str | None = None
description: str | None = None


@strawberry.input
Expand Down
2 changes: 2 additions & 0 deletions backend/api/resolvers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import strawberry

from .audit import AuditQuery
from .location import LocationMutation, LocationQuery, LocationSubscription
from .patient import PatientMutation, PatientQuery, PatientSubscription
from .property import PropertyDefinitionMutation, PropertyDefinitionQuery
Expand All @@ -14,6 +15,7 @@ class Query(
LocationQuery,
PropertyDefinitionQuery,
UserQuery,
AuditQuery,
):
pass

Expand Down
89 changes: 89 additions & 0 deletions backend/api/resolvers/audit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import logging
from datetime import datetime
from typing import Any

import strawberry
from api.audit import AuditLogger
from api.context import Info
from api.types.audit import AuditLogType
from config import INFLUXDB_BUCKET, INFLUXDB_ORG, LOGGER

logger = logging.getLogger(LOGGER)


@strawberry.type
class AuditQuery:
    @strawberry.field
    async def audit_logs(
        self,
        info: Info,
        case_id: strawberry.ID,
        limit: int | None = None,
        offset: int | None = None,
    ) -> list[AuditLogType]:
        """Return audit-log entries for *case_id*, newest first.

        Reads the ``activity`` measurement from InfluxDB. Best-effort:
        returns an empty list when the client is unavailable or the query
        fails, rather than surfacing a GraphQL error.
        """
        client = AuditLogger._get_client()
        if not client:
            logger.warning("InfluxDB client not available for audit log query")
            return []

        try:
            query_api = client.query_api()

            # Flux has no SQL-style LIMIT/OFFSET keywords; pagination is the
            # |> limit(n:, offset:) pipe. Compare against None so that an
            # explicit 0 is honoured instead of being dropped as falsy.
            # int() also guarantees these interpolations are plain integers.
            if limit is not None:
                offset_part = (
                    f", offset: {int(offset)}" if offset is not None else ""
                )
                pagination = f"|> limit(n: {int(limit)}{offset_part})"
            elif offset is not None:
                # limit() requires n; cap at int32 max when only offset given.
                pagination = f"|> limit(n: 2147483647, offset: {int(offset)})"
            else:
                pagination = ""

            # case_id comes from the client: escape it so it cannot break out
            # of the Flux string literal (query-injection hardening).
            safe_case_id = str(case_id).replace("\\", "\\\\").replace('"', '\\"')

            query = f'''
                from(bucket: "{INFLUXDB_BUCKET}")
                |> range(start: 0)
                |> filter(fn: (r) => r["_measurement"] == "activity")
                |> filter(fn: (r) => r["case_id"] == "{safe_case_id}")
                |> sort(columns: ["_time"], desc: true)
                {pagination}
            '''

            result = query_api.query(org=INFLUXDB_ORG, query=query)

            audit_logs: list[AuditLogType] = []
            # One logical event is spread over several records (one per
            # field) sharing the same tags/time; dedupe on the full
            # (case_id, activity, timestamp) triple.
            seen: set[tuple[str, str, datetime]] = set()

            for table in result:
                # Guard: tags below are read from the records themselves.
                if not table.records:
                    continue

                fields: dict[str, Any] = {}
                timestamp: datetime | None = None

                for record in table.records:
                    if timestamp is None:
                        timestamp = record.get_time()

                    field = record.get_field()
                    if field == "context":
                        fields["context"] = record.get_value()
                    elif field == "count":
                        fields["count"] = record.get_value()

                # Tags are constant within a Flux table; read them once
                # rather than from whichever record the loop ended on.
                tags = table.records[0].values
                case_id_value = tags.get("case_id", "")
                activity = tags.get("activity", "")
                user_id = tags.get("user_id")

                if timestamp and case_id_value and activity:
                    key = (case_id_value, activity, timestamp)
                    if key not in seen:
                        seen.add(key)
                        audit_logs.append(
                            AuditLogType(
                                case_id=case_id_value,
                                activity=activity,
                                user_id=user_id,
                                timestamp=timestamp,
                                context=fields.get("context"),
                            )
                        )

            return sorted(audit_logs, key=lambda x: x.timestamp, reverse=True)
        except Exception as e:
            # Lazy %-args: message only built when the record is emitted.
            logger.error("Error querying audit logs: %s", e)
            return []
67 changes: 58 additions & 9 deletions backend/api/resolvers/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import logging
from collections.abc import AsyncGenerator
from datetime import datetime, timezone
from typing import Generic, TypeVar

import strawberry
Expand All @@ -10,11 +12,30 @@
notify_entity_update,
)
from api.services.subscription import create_redis_subscription
from database import models
from sqlalchemy import update
from sqlalchemy.ext.asyncio import AsyncSession

logger = logging.getLogger(__name__)

ModelType = TypeVar("ModelType")


async def update_user_last_online(db: AsyncSession, user_id: str | None) -> None:
    """Best-effort: stamp ``models.User.last_online`` with the current UTC time.

    No-op when *user_id* is falsy. Failures are logged and rolled back but
    never propagated — losing a presence update must not break the caller.
    """
    if not user_id:
        return
    try:
        await db.execute(
            update(models.User)
            .where(models.User.id == user_id)
            .values(last_online=datetime.now(timezone.utc))
        )
        await db.commit()
    except Exception as e:
        # Lazy %-formatting: the message is only built if the log is emitted.
        logger.warning("Failed to update last_online for user %s: %s", user_id, e)
        await db.rollback()


class BaseQueryResolver(Generic[ModelType]):
def __init__(self, model: type[ModelType]):
self.model = model
Expand Down Expand Up @@ -102,9 +123,18 @@ class BaseSubscriptionResolver:
async def entity_created(
info: Info, entity_name: str
) -> AsyncGenerator[strawberry.ID, None]:
async for entity_id in create_redis_subscription(
f"{entity_name}_created"
):
if info.context.user:
await update_user_last_online(info.context.db, info.context.user.id)
channel = f"{entity_name}_created"
logger.info(
f"[SUBSCRIPTION] Initializing entity_created subscription: "
f"entity_name={entity_name}, channel={channel}"
)
async for entity_id in create_redis_subscription(channel):
logger.info(
f"[SUBSCRIPTION] BaseSubscriptionResolver received entity_created event: "
f"entity_name={entity_name}, entity_id={entity_id}, channel={channel}"
)
yield entity_id

@staticmethod
Expand All @@ -113,16 +143,35 @@ async def entity_updated(
entity_name: str,
entity_id: strawberry.ID | None = None,
) -> AsyncGenerator[strawberry.ID, None]:
async for updated_id in create_redis_subscription(
f"{entity_name}_updated", entity_id
):
if info.context.user:
await update_user_last_online(info.context.db, info.context.user.id)
channel = f"{entity_name}_updated"
logger.info(
f"[SUBSCRIPTION] Initializing entity_updated subscription: "
f"entity_name={entity_name}, entity_id={entity_id}, channel={channel}"
)
async for updated_id in create_redis_subscription(channel, str(entity_id) if entity_id else None):
logger.info(
f"[SUBSCRIPTION] BaseSubscriptionResolver received entity_updated event: "
f"entity_name={entity_name}, updated_id={updated_id}, "
f"filter_entity_id={entity_id}, channel={channel}"
)
yield updated_id

@staticmethod
async def entity_deleted(
info: Info, entity_name: str
) -> AsyncGenerator[strawberry.ID, None]:
async for entity_id in create_redis_subscription(
f"{entity_name}_deleted"
):
if info.context.user:
await update_user_last_online(info.context.db, info.context.user.id)
channel = f"{entity_name}_deleted"
logger.info(
f"[SUBSCRIPTION] Initializing entity_deleted subscription: "
f"entity_name={entity_name}, channel={channel}"
)
async for entity_id in create_redis_subscription(channel):
logger.info(
f"[SUBSCRIPTION] BaseSubscriptionResolver received entity_deleted event: "
f"entity_name={entity_name}, entity_id={entity_id}, channel={channel}"
)
yield entity_id
5 changes: 5 additions & 0 deletions backend/api/resolvers/location.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import strawberry
from api.audit import audit_log
from api.context import Info
from api.decorators.pagination import apply_pagination
from api.inputs import CreateLocationNodeInput, LocationType, UpdateLocationNodeInput
from api.resolvers.base import BaseMutationResolver, BaseSubscriptionResolver
from api.services.authorization import AuthorizationService
Expand Down Expand Up @@ -67,6 +68,8 @@ async def location_nodes(
parent_id: strawberry.ID | None = None,
recursive: bool = False,
order_by_name: bool = False,
limit: int | None = None,
offset: int | None = None,
) -> list[LocationNodeType]:
db = info.context.db

Expand Down Expand Up @@ -118,6 +121,8 @@ async def location_nodes(
if order_by_name:
query = query.order_by(models.LocationNode.title)

query = apply_pagination(query, limit=limit, offset=offset)

result = await db.execute(query)
return result.scalars().all()

Expand Down
Loading
Loading