diff --git a/advlabdb/database_import.py b/advlabdb/database_import.py index 4675286..bc18861 100644 --- a/advlabdb/database_import.py +++ b/advlabdb/database_import.py @@ -23,7 +23,7 @@ from .models import ( ) relative_db_dir = Path(environ["RELATIVE_DB_DIR"]) -relative_db_path = relative_db_dir / "adblab.db" +relative_db_path = relative_db_dir / "advlab.db" relative_db_bk_dir = relative_db_dir / "backups" relative_db_bk_dir.mkdir(exist_ok=True) @@ -33,12 +33,35 @@ def now(): return datetime.now().strftime("%d_%m_%Y_%H_%M_%S") +def is_null(entry): + return entry == "NULL" or entry == "" + + +def nullable(entry): + if is_null(entry): + return None + + return entry + + +def not_nullable(entry): + if is_null(entry): + raise DataBaseImportException("Unnullable entry is NULL!") + + return entry + + def importFromFile(filePath): if filePath[-4:] != ".txt": raise DataBaseImportException( "The import file has to be a text file with txt extension (.txt at the end of the filename)!" ) + if has_request_context(): + show = flash + else: + show = print + semesters = {} parts = {} students = {} @@ -48,12 +71,7 @@ def importFromFile(filePath): groupExperiments = {} appointments = {} - with open(filePath, "r") as f: # encoding="iso-8859-15" - if has_request_context(): - show = flash - else: - show = print - + with open(filePath, "r") as f: show("Reading file...") expectingTable = True @@ -64,39 +82,39 @@ def importFromFile(filePath): for line in f: line = line[:-1] - if not line: + if line == "": expectingTable = True continue if expectingTable: - if line[0] == "#": - expectingTable = False - tableName = line[1:] - - if tableName == "Semester": - activeDict = semesters - elif tableName == "Part": - activeDict = parts - elif tableName == "Student": - activeDict = students - elif tableName == "Group": - activeDict = groups - elif tableName == "PartStudent": - activeDict = partStudents - elif tableName == "Experiment": - activeDict = experiments - elif tableName == "GroupExperiment": - activeDict = groupExperiments - elif tableName == "Appointment": - activeDict = appointments - else: - raise DataBaseImportException(f"{tableName} is not a valid table name!") - - readHeader = True - continue - else: + if line[0] != "#": raise DataBaseImportException(f"Expected a Table name starting with # but got this line: {line}") + expectingTable = False + tableName = line[1:] + + if tableName == "Semester": + activeDict = semesters + elif tableName == "Part": + activeDict = parts + elif tableName == "Student": + activeDict = students + elif tableName == "Group": + activeDict = groups + elif tableName == "PartStudent": + activeDict = partStudents + elif tableName == "Experiment": + activeDict = experiments + elif tableName == "GroupExperiment": + activeDict = groupExperiments + elif tableName == "Appointment": + activeDict = appointments + else: + raise DataBaseImportException(f"{tableName} is not a valid table name!") + + readHeader = True + continue + cells = line.split("\t") if readHeader: @@ -109,24 +127,23 @@ def importFromFile(filePath): continue cellsLen = len(cells) - if cellsLen == len(activeDict["_header"]): - for i in range(cellsLen): - activeDict[activeDict["_header"][i]].append(cells[i]) - else: + if cellsLen != len(activeDict["_header"]): raise DataBaseImportException( f"The number of header cells is not equal to the number of row cells in row {cells}!" ) - db.session.rollback() - with db.session.begin(): + for i in range(cellsLen): + activeDict[activeDict["_header"][i]].append(cells[i]) + + try: # Semester show("Semester...") - if len(semesters["label"]) * len(semesters["year"]) != 1: + if len(semesters["label"]) != 1: raise DataBaseImportException("Only one semester is allowed in an import file!") - semesterLabel = semesters["label"][0] - semesterYear = int(semesters["year"][0]) + semesterLabel = not_nullable(semesters["label"][0]) + semesterYear = int(not_nullable(semesters["year"][0])) dbSemester = Semester.query.filter(Semester.label == semesterLabel, Semester.year == semesterYear).first() if not dbSemester: @@ -139,9 +156,9 @@ def importFromFile(filePath): dbParts = {} for i, id in enumerate(parts["id"]): - id = int(id) - partNumber = int(parts["number"][i]) - partProgramLabel = parts["program_label"][i] + id = int(not_nullable(id)) + partNumber = int(not_nullable(parts["number"][i])) + partProgramLabel = not_nullable(parts["program_label"][i]) dbPart = Part.query.filter( Part.number == partNumber, Part.program.has(Program.label == partProgramLabel), @@ -160,26 +177,26 @@ def importFromFile(filePath): dbStudents = {} for i, studentNumber in enumerate(students["student_number"]): - studentNumber = int(studentNumber) + studentNumber = int(not_nullable(studentNumber)) dbStudent = Student.query.filter(Student.student_number == studentNumber).first() if not dbStudent: dbStudent = Student( student_number=studentNumber, - first_name=students["first_name"][i], - last_name=students["last_name"][i], - uni_email=students["uni_email"][i], - contact_email=students["contact_email"][i] or None, - bachelor_thesis=students["bachelor_thesis"][i] or None, - bachelor_thesis_work_group=students["bachelor_thesis_work_group"][i] or None, - note=students["note"][i] or None, + first_name=not_nullable(students["first_name"][i]), + last_name=not_nullable(students["last_name"][i]), + uni_email=not_nullable(students["uni_email"][i]), + contact_email=nullable(students["contact_email"][i]), + bachelor_thesis=nullable(students["bachelor_thesis"][i]), + bachelor_thesis_work_group=nullable(students["bachelor_thesis_work_group"][i]), + note=nullable(students["note"][i]), ) db.session.add(dbStudent) else: - dbStudent.contact_email = students["contact_email"][i] or None - dbStudent.bachelor_thesis = students["bachelor_thesis"][i] or None - dbStudent.bachelor_thesis_work_group = students["bachelor_thesis_work_group"][i] or None - dbStudent.note = students["note"][i] or None + dbStudent.contact_email = nullable(students["contact_email"][i]) + dbStudent.bachelor_thesis = nullable(students["bachelor_thesis"][i]) + dbStudent.bachelor_thesis_work_group = nullable(students["bachelor_thesis_work_group"][i]) + dbStudent.note = nullable(students["note"][i]) dbStudents[studentNumber] = dbStudent @@ -188,10 +205,10 @@ def importFromFile(filePath): dbGroups = {} for i, id in enumerate(groups["id"]): - id = int(id) + id = int(not_nullable(id)) dbGroup = Group( - number=int(groups["number"][i]), - program=Program.query.filter(Program.label == groups["program_label"][i]).first(), + number=int(not_nullable(groups["number"][i])), + program=Program.query.filter(Program.label == not_nullable(groups["program_label"][i])).first(), semester=dbSemester, ) db.session.add(dbGroup) @@ -201,11 +218,11 @@ def importFromFile(filePath): show("PartStudent...") for i, studentNumber in enumerate(partStudents["student_number"]): - studentNumber = int(studentNumber) + studentNumber = int(not_nullable(studentNumber)) dbPartStudent = PartStudent( student=dbStudents[studentNumber], - part=dbParts[int(partStudents["part_id"][i])], - group=dbGroups[int(partStudents["group_id"][i])], + part=dbParts[int(not_nullable(partStudents["part_id"][i]))], + group=dbGroups[int(not_nullable(partStudents["group_id"][i]))], ) db.session.add(dbPartStudent) @@ -214,14 +231,17 @@ def importFromFile(filePath): dbSemesterExperiments = {} for i, id in enumerate(experiments["id"]): - id = int(id) - experimentNumber = int(experiments["number"][i]) - experimentProgram = Program.query.filter(Program.label == experiments["program_label"][i]).first() + id = int(not_nullable(id)) + experimentNumber = int(not_nullable(experiments["number"][i])) + experimentProgram = Program.query.filter( + Program.label == not_nullable(experiments["program_label"][i]) + ).first() dbExperiment = Experiment.query.filter( Experiment.number == experimentNumber, Experiment.program == experimentProgram ).first() if not dbExperiment: + # TODO: Check if experimentProgram is None raise DataBaseImportException( f"Experiment with number {experimentNumber} and program {experimentProgram.repr()} does not exist in the database. Please make sure that experiments are created from the web interface." ) @@ -242,10 +262,10 @@ def importFromFile(filePath): dbGroupExperiments = {} for i, id in enumerate(groupExperiments["id"]): - id = int(id) + id = int(not_nullable(id)) dbGroupExperiment = GroupExperiment( - semester_experiment=dbSemesterExperiments[int(groupExperiments["experiment_id"][i])], - group=dbGroups[int(groupExperiments["group_id"][i])], + semester_experiment=dbSemesterExperiments[int(not_nullable(groupExperiments["experiment_id"][i]))], + group=dbGroups[int(not_nullable(groupExperiments["group_id"][i]))], ) db.session.add(dbGroupExperiment) dbGroupExperiments[id] = dbGroupExperiment @@ -254,18 +274,19 @@ def importFromFile(filePath): show("Appointment...") for i, date in enumerate(appointments["date"]): - assistantEmail = appointments["assistant_email"][i] + date = not_nullable(date) + assistantEmail = not_nullable(appointments["assistant_email"][i]) assistant = Assistant.query.filter(Assistant.user.has(User.email == assistantEmail)).first() if assistant is None: raise DataBaseImportException( - f"Assistant with email {email} does not exist in the database! Please make sure that you create assistants in the web interface." + f"Assistant with email {assistantEmail} does not exist in the database! Please make sure that you create assistants in the web interface." ) dbAppointment = Appointment( date=datetime.strptime(date, "%d.%m.%Y").date(), - special=bool(int(appointments["special"][i])), - group_experiment=dbGroupExperiments[int(appointments["group_experiment_id"][i])], + special=bool(int(not_nullable(appointments["special"][i]))), + group_experiment=dbGroupExperiments[int(not_nullable(appointments["group_experiment_id"][i]))], assistant=assistant, ) db.session.add(dbAppointment) @@ -274,9 +295,13 @@ def importFromFile(filePath): dest = relative_db_bk_dir / f"before_{dbSemester.repr()}_import_{now()}.db" copy2(relative_db_path, dest) - show(f"Made a backup of the database before the import at {dest}") + show(f"Made a backup of the database before commiting the import at {dest}") - # Auto commit from the transaction context + db.session.commit() + except Exception as ex: + db.session.rollback() + + raise ex dest = relative_db_bk_dir / f"after_{dbSemester.repr()}_import_{now()}.db" copy2(relative_db_path, dest)