From 18257799c75b3855a666c07e22584fd2cac31e4e Mon Sep 17 00:00:00 2001 From: Mo8it Date: Mon, 12 Jul 2021 13:06:44 +0200 Subject: [PATCH] Moved get_query and get_count_query to parent classs --- advlabdb/customClasses.py | 14 ++++++ advlabdb/modelViews.py | 103 ++++++++------------------------------ 2 files changed, 36 insertions(+), 81 deletions(-) diff --git a/advlabdb/customClasses.py b/advlabdb/customClasses.py index 2d23d06..971d1ff 100644 --- a/advlabdb/customClasses.py +++ b/advlabdb/customClasses.py @@ -29,9 +29,23 @@ class SecureModelView(ModelView): create_template = "admin_create.html" edit_template = "admin_edit.html" + queryFilter = None + def is_accessible(self): return adminViewIsAccessible() def inaccessible_callback(self, name, **kwargs): # Redirect to login page if user doesn't have access return redirect(url_for("security.login", next=request.url)) + + def get_query(self): + if self.queryFilter: + return super().get_query().filter(self.queryFilter()) + else: + return super().get_query() + + def get_count_query(self): + if self.queryFilter: + return super().get_count_query().filter(self.queryFilter()) + else: + return super().get_count_query() diff --git a/advlabdb/modelViews.py b/advlabdb/modelViews.py index 35acdd5..de8762d 100644 --- a/advlabdb/modelViews.py +++ b/advlabdb/modelViews.py @@ -188,15 +188,7 @@ class PartView(SecureModelView): column_details_list = ["label", "semester", "part_students", "groups"] form_columns = ["label", "semester"] - def get_query(self): - return super().get_query().filter(Part.id.in_([part.id for part in userActiveSemester().parts])) - - def get_count_query(self): - return ( - self.session.query(func.count("*")) - .select_from(self.model) - .filter(Part.id.in_([part.id for part in userActiveSemester().parts])) - ) + queryFilter = lambda self: Part.id.in_([part.id for part in userActiveSemester().parts]) class StudentView(SecureModelView): @@ -225,18 +217,14 @@ class StudentView(SecureModelView): ] -def partQueryFactory(): - return Part.query.filter(Part.id.in_([part.id for part in userActiveSemester().parts])) +partQueryFactory = lambda: Part.query.filter(Part.id.in_([part.id for part in userActiveSemester().parts])) - -def groupQueryFactory(): - return Group.query.filter(Group.part_id.in_([part.id for part in userActiveSemester().parts])) +groupQueryFactory = lambda: Group.query.filter(Group.part_id.in_([part.id for part in userActiveSemester().parts])) class PartStudentView(SecureModelView): class CreateForm(Form): - def studentQueryFactory(): - return Student.query + studentQueryFactory = lambda: Student.query student = QuerySelectField( "Student", query_factory=studentQueryFactory, validators=[DataRequired()], allow_blank=True, blank_text="-" @@ -255,6 +243,8 @@ class PartStudentView(SecureModelView): column_filters = ["part", "student", "group"] + queryFilter = lambda self: PartStudent.part_id.in_([part.id for part in userActiveSemester().parts]) + partGroupPartMismatchException = "Student's part and group's part do not match!" def create_form(self, obj=None): @@ -271,21 +261,12 @@ class PartStudentView(SecureModelView): else: return super().handle_view_exception(exc) - def get_query(self): - return super().get_query().filter(PartStudent.part_id.in_([part.id for part in userActiveSemester().parts])) - - def get_count_query(self): - return ( - self.session.query(func.count("*")) - .select_from(self.model) - .filter(PartStudent.part_id.in_([part.id for part in userActiveSemester().parts])) - ) - class GroupView(SecureModelView): class CreateForm(Form): - def partStudentsQueryFactory(): - return PartStudent.query.filter(PartStudent.part_id.in_([part.id for part in userActiveSemester().parts])) + partStudentsQueryFactory = lambda: PartStudent.query.filter( + PartStudent.part_id.in_([part.id for part in userActiveSemester().parts]) + ) part = QuerySelectField( "Part", query_factory=partQueryFactory, validators=[DataRequired()], allow_blank=True, blank_text="-" @@ -300,6 +281,8 @@ class GroupView(SecureModelView): column_list = ["number", "part", "part_students", "group_experiments"] column_filters = ["number", "part"] + queryFilter = lambda self: Group.part_id.in_([part.id for part in userActiveSemester().parts]) + partStudentPartPartMismatchException = "Group's part and student's part do not match!" def create_model(self, form): @@ -339,16 +322,6 @@ class GroupView(SecureModelView): form = self.CreateForm return form(get_form_data(), obj=obj) - def get_query(self): - return super().get_query().filter(Group.part_id.in_([part.id for part in userActiveSemester().parts])) - - def get_count_query(self): - return ( - self.session.query(func.count("*")) - .select_from(self.model) - .filter(Group.part_id.in_([part.id for part in userActiveSemester().parts])) - ) - class ExperimentView(SecureModelView): can_view_details = True @@ -372,15 +345,7 @@ class ExperimentView(SecureModelView): class SemesterExperimentView(SecureModelView): column_list = ["experiment", "semester", "assistants"] - def get_query(self): - return super().get_query().filter(SemesterExperiment.semester == userActiveSemester()) - - def get_count_query(self): - return ( - self.session.query(func.count("*")) - .select_from(self.model) - .filter(SemesterExperiment.semester == userActiveSemester()) - ) + queryFilter = lambda self: SemesterExperiment.semester == userActiveSemester() class AssistantView(SecureModelView): @@ -400,13 +365,13 @@ class AssistantView(SecureModelView): class GroupExperimentView(SecureModelView): class CreateForm(Form): - def semesterExperimentQueryFactory(): - return SemesterExperiment.query.filter(SemesterExperiment.semester == userActiveSemester()) + semesterExperimentQueryFactory = lambda: SemesterExperiment.query.filter( + SemesterExperiment.semester == userActiveSemester() + ) - def assistantQueryFactory(): - return Assistant.query.filter( - Assistant.user_id.in_([user.id for user in User.query.filter(User.active == True)]) - ) + assistantQueryFactory = lambda: Assistant.query.filter( + Assistant.user_id.in_([user.id for user in User.query.filter(User.active == True)]) + ) group = QuerySelectField( "Group", query_factory=groupQueryFactory, validators=[DataRequired()], allow_blank=True, blank_text="-" @@ -445,6 +410,10 @@ class GroupExperimentView(SecureModelView): column_list = ["group", "semester_experiment", "appointments", "experiment_marks"] column_filters = ["group", "semester_experiment.experiment", "appointments"] + queryFilter = lambda self: GroupExperiment.group_id.in_( + [g.id for g in Group.query.filter(Group.part_id.in_([part.id for part in userActiveSemester().parts]))] + ) + def create_model(self, form): try: model = GroupExperiment.checkAndInit( @@ -492,34 +461,6 @@ class GroupExperimentView(SecureModelView): self.after_model_change(form, model, True) return model - def get_query(self): - return ( - super() - .get_query() - .filter( - GroupExperiment.group_id.in_( - [ - g.id - for g in Group.query.filter(Group.part_id.in_([part.id for part in userActiveSemester().parts])) - ] - ) - ) - ) - - def get_count_query(self): - return ( - self.session.query(func.count("*")) - .select_from(self.model) - .filter( - GroupExperiment.group_id.in_( - [ - g.id - for g in Group.query.filter(Group.part_id.in_([part.id for part in userActiveSemester().parts])) - ] - ) - ) - ) - class AppointmentView(SecureModelView): column_list = ["date", "special", "group_experiment", "assistant"]