Skip to content

Commit

Permalink
add backend utilities for milestonegroup statistics computation
Browse files Browse the repository at this point in the history
  • Loading branch information
MaHaWo committed Dec 3, 2024
1 parent 7f3d9fb commit 49e026c
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 74 deletions.
33 changes: 17 additions & 16 deletions mondey_backend/src/mondey_backend/models/milestones.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import datetime

from pydantic import BaseModel
from sqlalchemy.orm import Mapped
from sqlmodel import Field
from sqlmodel import SQLModel
Expand Down Expand Up @@ -194,27 +193,29 @@ class MilestoneAnswerSessionPublic(SQLModel):
answers: dict[int, MilestoneAnswerPublic]


class MilestoneAgeScore(BaseModel):
milestone_id: int
age_months: int
class Statistics(SQLModel):
avg_score: float
stddev_score: float
age_months: int
expected_score: float


class MilestoneAgeScores(BaseModel):
scores: list[MilestoneAgeScore]
class MilestoneAgeScores(SQLModel, table=True):
milestone_id: int = Field(primary_key=True, default=None)
scores: list[Statistics]
expected_age: int
created_at: datetime.datetime = Field(
sa_column_kwargs={
"server_default": text("CURRENT_TIMESTAMP"),
}
)


class MilestoneGroupStatistics(SQLModel):
session_id: int = Field(
default=None, foreign_key="milestoneanswersession.id", primary_key=True
)
group_id: int = Field(
default=None, foreign_key="milestonegroup.id", primary_key=True
class MilestoneGroupAgeScores(SQLModel, table=True):
milestonegroup_id: int = Field(primary_key=True, default=None)
scores: list[Statistics]
created_at: datetime.datetime = Field(
sa_column_kwargs={
"server_default": text("CURRENT_TIMESTAMP"),
}
)
child_id: int = Field(default=None, foreign_key="child.id", primary_key=True)
age_months: int
avg_score: float
stddev_score: float
90 changes: 32 additions & 58 deletions mondey_backend/src/mondey_backend/routers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@
from ..models.milestones import AgeInterval
from ..models.milestones import Milestone
from ..models.milestones import MilestoneAdmin
from ..models.milestones import MilestoneAgeScore
from ..models.milestones import MilestoneAgeScores
from ..models.milestones import MilestoneAnswer
from ..models.milestones import MilestoneAnswerSession
from ..models.milestones import MilestoneGroup
from ..models.milestones import MilestoneGroupAdmin
from ..models.milestones import MilestoneGroupStatistics
from ..models.milestones import MilestoneGroupAgeScores
from ..models.milestones import MilestoneGroupText
from ..models.milestones import MilestoneText
from ..models.milestones import Statistics
from ..models.questions import ChildQuestion
from ..models.questions import ChildQuestionAdmin
from ..models.questions import ChildQuestionText
Expand Down Expand Up @@ -202,8 +202,6 @@ def get_db_child(
def _get_answer_session_child_ages_in_months(session: SessionDep) -> dict[int, int]:
answer_sessions = session.exec(select(MilestoneAnswerSession)).all()

print(answer_sessions)

return {
answer_session.id: get_child_age_in_months( # type: ignore
get(session, Child, answer_session.child_id), answer_session.created_at
Expand All @@ -217,18 +215,7 @@ def _get_expected_age_from_scores(scores: np.ndarray) -> int:
return np.argmax(scores >= 3.0)


def _calculate_statistics_for(
data: Sequence[int | float], **statfuncs
) -> dict[str, float | np.ndarray | tuple]:
result = {}
for name, func in statfuncs.items():
with np.errstate(invalid="ignore"):
stat = func(data)
result[name] = stat
return result


def _get_score_statistics_by_age(
def _get_statistics_by_age(
answers: Sequence[MilestoneAnswer], child_ages: dict[int, int]
) -> tuple[np.ndarray, np.ndarray]:
max_age_months = 72
Expand Down Expand Up @@ -276,67 +263,54 @@ def calculate_milestone_statistics_by_age(
col(MilestoneAnswer.milestone_id) == milestone_id
)
).all()
avg, stddev = _get_score_statistics_by_age(answers, child_ages)
avg, stddev = _get_statistics_by_age(answers, child_ages)
expected_age = _get_expected_age_from_scores(avg)

return MilestoneAgeScores(
milestone_id=milestone_id,
expected_age=expected_age,
created_at=datetime.datetime.now(),
scores=[
MilestoneAgeScore(
milestone_id=milestone_id,
Statistics(
age_months=age,
avg_score=avg[age],
stddev_score=stddev[age],
expected_score=(
4 if age >= expected_age else 1
), # FIXME: don't know what this is supposed to mean
), # TODO: placeholder algorithm? What does the underlying model actually look like?
)
for age in range(0, len(avg))
],
)


def calculate_milestonegroup_statistics(
def calculate_milestonegroup_statistics_by_age(
session: SessionDep,
mid: int,
age: int,
age_lower: int,
age_upper: int,
) -> MilestoneGroupStatistics:
milestonegroup = get(session, MilestoneGroup, mid)
answers = []
for milestone in milestonegroup.milestones:
# we want something that is relevant for the age of the child at hand. Hence we filter by age here. Is this what they want?
# FIXME: 11-25-2024: I think this is not what we want and it should be filtered by the age of the child at the time of the answer session?
# this however should already be handled by the answersession itself?
# dazed and confused....
# At any rate the above comment is obsolete.
m_answers = [
answer.answer
for answer in session.exec(
select(MilestoneAnswer)
.where(col(MilestoneAnswer.milestone_id) == milestone.id)
.where(age_lower <= milestone.expected_age_months <= age_upper)
).all()
]
answers.extend(m_answers)

answers = np.array(answers) + 1 # convert 0-3 answer index to 1-4 score

result = _calculate_statistics_for(
answers,
mean=np.mean,
std=lambda a: np.std(a, correction=1),
)
milestonegroup_id,
answers: Sequence[MilestoneAnswer] | None = None,
) -> MilestoneGroupAgeScores:
child_ages = _get_answer_session_child_ages_in_months(session)

mg_score = MilestoneGroupStatistics(
age_months=age,
group_id=milestonegroup.id,
avg_score=np.nan_to_num(result["mean"]),
stddev_score=np.nan_to_num(result["std"]),
)
if answers is None:
answers = session.exec(
select(MilestoneAnswer).where(
col(MilestoneAnswer.milestone_group_id) == milestonegroup_id
)
).all()

return mg_score
avg, stddev = _get_statistics_by_age(answers, child_ages)
return MilestoneGroupAgeScores(
milestonegroup_id=milestonegroup_id,
scores=[
Statistics(
age_months=age,
avg_score=avg[age],
stddev_score=stddev[age],
)
for age in range(0, len(avg))
],
created_at=datetime.datetime.now(),
)


def child_image_path(child_id: int | None) -> pathlib.Path:
Expand Down

0 comments on commit 49e026c

Please sign in to comment.