diff --git a/advlabdb/adminModelViews.py b/advlabdb/adminModelViews.py index 7eaef29..0f8e850 100644 --- a/advlabdb/adminModelViews.py +++ b/advlabdb/adminModelViews.py @@ -348,7 +348,7 @@ class SemesterView(SecureAdminModelView): ] def customCreateModel(self, form): - return Semester.customInitFromOldSemester( + return Semester.initFromOldSemester( label=form.label.data, year=form.year.data, oldSemester=userActiveSemester(), diff --git a/advlabdb/database_import.py b/advlabdb/database_import.py index 02d8ed1..fd10ac4 100644 --- a/advlabdb/database_import.py +++ b/advlabdb/database_import.py @@ -202,7 +202,7 @@ def importFromFile(filePath): for i, studentNumber in enumerate(partStudents["student_number"]): studentNumber = int(studentNumber) - dbPartStudent = PartStudent.customInit( + dbPartStudent = PartStudent( student=dbStudents[studentNumber], part=dbParts[int(partStudents["part_id"][i])], group=dbGroups[int(partStudents["group_id"][i])], diff --git a/advlabdb/models.py b/advlabdb/models.py index b4f5716..ec4dfa2 100644 --- a/advlabdb/models.py +++ b/advlabdb/models.py @@ -74,15 +74,15 @@ class PartStudent(db.Model): ) def check(group, part): - if group and group.program != part.program: + if group is not None and group.program != part.program: raise DataBaseException( f"Group's program {group.program} and student part's program {part.program} do not match!" ) - def customInit(student, part, group=None): - PartStudent.check(group, part) + def __init__(self, *args, **kwargs): + PartStudent.check(kwargs.get("group"), kwargs["part"]) - return PartStudent(student=student, part=part, group=group) + super().__init__(*args, **kwargs) def customUpdate(self, group, final_part_mark): Part.check(group, self.part) @@ -442,7 +442,7 @@ class Semester(db.Model): __table_args__ = (db.UniqueConstraint(label, year),) - def customInitFromOldSemester(label, year, oldSemester, transferParts, transferAssistants): + def initFromOldSemester(label, year, oldSemester, transferParts, transferAssistants): semester = Semester(label=label, year=year) if transferParts: diff --git a/advlabdb/scripts/test/test_database.py b/advlabdb/scripts/test/test_database.py index d33bd68..0af700e 100644 --- a/advlabdb/scripts/test/test_database.py +++ b/advlabdb/scripts/test/test_database.py @@ -49,9 +49,9 @@ def main(): db.session.add(student2) db.session.add(student3) - ps1 = PartStudent.customInit(student=student1, part=part1) - ps2 = PartStudent.customInit(student=student2, part=part1) - ps3 = PartStudent.customInit(student=student3, part=part2) + ps1 = PartStudent(student=student1, part=part1) + ps2 = PartStudent(student=student2, part=part1) + ps3 = PartStudent(student=student3, part=part2) db.session.add(ps1) db.session.add(ps2)