From 69c60b0090aba04b5e9f6a8f977e31fd86d14b91 Mon Sep 17 00:00:00 2001 From: Morgan Pretty Date: Fri, 14 Mar 2025 10:42:05 +1100 Subject: [PATCH] Track current db tasks and cancel when suspending --- SessionUtilitiesKit/Database/Storage.swift | 29 ++++++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/SessionUtilitiesKit/Database/Storage.swift b/SessionUtilitiesKit/Database/Storage.swift index b81fa1edf..8a2da7c24 100644 --- a/SessionUtilitiesKit/Database/Storage.swift +++ b/SessionUtilitiesKit/Database/Storage.swift @@ -91,6 +91,9 @@ open class Storage { /// This property gets set the first time we successfully write to the database public private(set) var hasSuccessfullyWritten: Bool = false + /// This property keeps track of all current database tasks and can be used when suspending the database to explicitly + /// cancel any currently running tasks + @ThreadSafeObject private var currentTasks: Set> = [] // MARK: - Initialization @@ -483,7 +486,12 @@ open class Storage { guard !isSuspended else { return } isSuspended = true - Log.info(.storage, "Database access suspended.") + Log.info(.storage, "Database access suspended - cancelling \(currentTasks.count) running task(s).") + + /// Before triggering an `interrupt` (which will forcibly kill in-progress database queries) we want to try to cancel all + /// database tasks to give them a small chance to resolve cleanly before we take a brute-force approach + currentTasks.forEach { $0.cancel() } + _currentTasks.performUpdate { _ in [] } /// Interrupt any open transactions (if this function is called then we are expecting that all processes have finished running /// and don't actually want any more transactions to occur) @@ -660,6 +668,7 @@ open class Storage { let syncQueue = DispatchQueue(label: "com.session.performOperation.syncQueue") let semaphore: DispatchSemaphore = DispatchSemaphore(value: 0) var operationResult: Result? + var operationTask: Task<(), Never>? let logErrorIfNeeded: (Result) -> Result = { result in switch result { case .success: break @@ -673,6 +682,7 @@ open class Storage { func completeOperation(with result: Result) { syncQueue.sync { guard operationResult == nil else { return } + info.storage?.removeTask(operationTask) operationResult = result semaphore.signal() @@ -741,6 +751,10 @@ open class Storage { } } + /// Store the task in case we want to + info.storage?.addTask(task) + operationTask = task + /// For the `async` operation the returned value should be ignored so just return the `invalidQueryResult` error guard !info.isAsync else { return (.failure(StorageError.invalidQueryResult), task) @@ -782,13 +796,24 @@ open class Storage { holder.task = task } } - .handleEvents(receiveCancel: { holder.task?.cancel() }) + .handleEvents(receiveCancel: { [weak self] in + holder.task?.cancel() + self?.removeTask(holder.task) + }) .eraseToAnyPublisher() } } // MARK: - Functions + private func addTask(_ task: Task<(), Never>) { + _currentTasks.performUpdate { $0.inserting(task) } + } + + private func removeTask(_ task: Task<(), Never>?) { + _currentTasks.performUpdate { $0.removing(task) } + } + @discardableResult public func write( fileName file: String = #file, functionName funcN: String = #function,