Track current db tasks and cancel when suspending

pull/1061/head
Morgan Pretty 2 months ago
parent 846aa695c2
commit 69c60b0090

@ -91,6 +91,9 @@ open class Storage {
/// This property gets set the first time we successfully write to the database /// This property gets set the first time we successfully write to the database
public private(set) var hasSuccessfullyWritten: Bool = false public private(set) var hasSuccessfullyWritten: Bool = false
/// This property keeps track of all current database tasks and can be used when suspending the database to explicitly
/// cancel any currently running tasks
@ThreadSafeObject private var currentTasks: Set<Task<(), Never>> = []
// MARK: - Initialization // MARK: - Initialization
@ -483,7 +486,12 @@ open class Storage {
guard !isSuspended else { return } guard !isSuspended else { return }
isSuspended = true isSuspended = true
Log.info(.storage, "Database access suspended.") Log.info(.storage, "Database access suspended - cancelling \(currentTasks.count) running task(s).")
/// Before triggering an `interrupt` (which will forcibly kill in-progress database queries) we want to try to cancel all
/// database tasks to give them a small chance to resolve cleanly before we take a brute-force approach
currentTasks.forEach { $0.cancel() }
_currentTasks.performUpdate { _ in [] }
/// Interrupt any open transactions (if this function is called then we are expecting that all processes have finished running /// Interrupt any open transactions (if this function is called then we are expecting that all processes have finished running
/// and don't actually want any more transactions to occur) /// and don't actually want any more transactions to occur)
@ -660,6 +668,7 @@ open class Storage {
let syncQueue = DispatchQueue(label: "com.session.performOperation.syncQueue") let syncQueue = DispatchQueue(label: "com.session.performOperation.syncQueue")
let semaphore: DispatchSemaphore = DispatchSemaphore(value: 0) let semaphore: DispatchSemaphore = DispatchSemaphore(value: 0)
var operationResult: Result<T, Error>? var operationResult: Result<T, Error>?
var operationTask: Task<(), Never>?
let logErrorIfNeeded: (Result<T, Error>) -> Result<T, Error> = { result in let logErrorIfNeeded: (Result<T, Error>) -> Result<T, Error> = { result in
switch result { switch result {
case .success: break case .success: break
@ -673,6 +682,7 @@ open class Storage {
func completeOperation(with result: Result<T, Error>) { func completeOperation(with result: Result<T, Error>) {
syncQueue.sync { syncQueue.sync {
guard operationResult == nil else { return } guard operationResult == nil else { return }
info.storage?.removeTask(operationTask)
operationResult = result operationResult = result
semaphore.signal() semaphore.signal()
@ -741,6 +751,10 @@ open class Storage {
} }
} }
/// Store the task in case we want to
info.storage?.addTask(task)
operationTask = task
/// For the `async` operation the returned value should be ignored so just return the `invalidQueryResult` error /// For the `async` operation the returned value should be ignored so just return the `invalidQueryResult` error
guard !info.isAsync else { guard !info.isAsync else {
return (.failure(StorageError.invalidQueryResult), task) return (.failure(StorageError.invalidQueryResult), task)
@ -782,13 +796,24 @@ open class Storage {
holder.task = task holder.task = task
} }
} }
.handleEvents(receiveCancel: { holder.task?.cancel() }) .handleEvents(receiveCancel: { [weak self] in
holder.task?.cancel()
self?.removeTask(holder.task)
})
.eraseToAnyPublisher() .eraseToAnyPublisher()
} }
} }
// MARK: - Functions // MARK: - Functions
private func addTask(_ task: Task<(), Never>) {
_currentTasks.performUpdate { $0.inserting(task) }
}
private func removeTask(_ task: Task<(), Never>?) {
_currentTasks.performUpdate { $0.removing(task) }
}
@discardableResult public func write<T>( @discardableResult public func write<T>(
fileName file: String = #file, fileName file: String = #file,
functionName funcN: String = #function, functionName funcN: String = #function,

Loading…
Cancel
Save