Updated the rest of the SnodeAPI static variables to be atomic to prevent threading crashes

pull/672/head
Morgan Pretty 3 years ago
parent 8109a326cf
commit 6099bd94d8

@ -202,7 +202,7 @@ public final class MessageSender {
recipient: message.recipient!,
data: base64EncodedData,
ttl: message.ttl,
timestampMs: UInt64(messageSendTimestamp + SnodeAPI.clockOffset)
timestampMs: UInt64(messageSendTimestamp + SnodeAPI.clockOffset.wrappedValue)
)
SnodeAPI

@ -108,7 +108,7 @@ public enum OnionRequestAPI: OnionRequestAPIType {
else {
SNLog("Populating guard snode cache.")
// Sync on LokiAPI.workQueue
var unusedSnodes = SnodeAPI.snodePool.subtracting(reusableGuardSnodes)
var unusedSnodes = SnodeAPI.snodePool.wrappedValue.subtracting(reusableGuardSnodes)
let reusableGuardSnodeCount = UInt(reusableGuardSnodes.count)
guard unusedSnodes.count >= (targetGuardSnodeCount - reusableGuardSnodeCount) else {
@ -156,7 +156,7 @@ public enum OnionRequestAPI: OnionRequestAPIType {
let reusableGuardSnodes = reusablePaths.map { $0[0] }
let promise: Promise<[[Snode]]> = getGuardSnodes(reusing: reusableGuardSnodes)
.map2 { guardSnodes -> [[Snode]] in
var unusedSnodes = SnodeAPI.snodePool
var unusedSnodes = SnodeAPI.snodePool.wrappedValue
.subtracting(guardSnodes)
.subtracting(reusablePaths.flatMap { $0 })
let reusableGuardSnodeCount = UInt(reusableGuardSnodes.count)
@ -285,7 +285,7 @@ public enum OnionRequestAPI: OnionRequestAPIType {
var path = oldPaths[pathIndex]
guard let snodeIndex = path.firstIndex(of: snode) else { return }
path.remove(at: snodeIndex)
let unusedSnodes = SnodeAPI.snodePool.subtracting(oldPaths.flatMap { $0 })
let unusedSnodes = SnodeAPI.snodePool.wrappedValue.subtracting(oldPaths.flatMap { $0 })
guard !unusedSnodes.isEmpty else { throw OnionRequestAPIError.insufficientSnodes }
// randomElement() uses the system's default random generator, which is cryptographically secure
path.append(unusedSnodes.randomElement()!)
@ -672,7 +672,7 @@ public enum OnionRequestAPI: OnionRequestAPIType {
if let timestamp = body["t"] as? Int64 {
let offset = timestamp - Int64(floor(Date().timeIntervalSince1970 * 1000))
SnodeAPI.clockOffset = offset
SnodeAPI.clockOffset.mutate { $0 = offset }
}
guard 200...299 ~= statusCode else {

@ -9,20 +9,20 @@ import SessionUtilitiesKit
public final class SnodeAPI {
private static let sodium = Sodium()
private static var hasLoadedSnodePool = false
private static var loadedSwarms: Set<String> = []
private static var getSnodePoolPromise: Promise<Set<Snode>>?
private static var hasLoadedSnodePool: Atomic<Bool> = Atomic(false)
private static var loadedSwarms: Atomic<Set<String>> = Atomic([])
private static var getSnodePoolPromise: Atomic<Promise<Set<Snode>>?> = Atomic(nil)
/// - Note: Should only be accessed from `Threading.workQueue` to avoid race conditions.
internal static var snodeFailureCount: [Snode: UInt] = [:]
internal static var snodeFailureCount: Atomic<[Snode: UInt]> = Atomic([:])
/// - Note: Should only be accessed from `Threading.workQueue` to avoid race conditions.
internal static var snodePool: Set<Snode> = []
internal static var snodePool: Atomic<Set<Snode>> = Atomic([])
/// The offset between the user's clock and the Service Node's clock. Used in cases where the
/// user's clock is incorrect.
///
/// - Note: Should only be accessed from `Threading.workQueue` to avoid race conditions.
public static var clockOffset: Int64 = 0
public static var clockOffset: Atomic<Int64> = Atomic(0)
/// - Note: Should only be accessed from `Threading.workQueue` to avoid race conditions.
public static var swarmCache: Atomic<[String: Set<Snode>]> = Atomic([:])
@ -48,20 +48,20 @@ public final class SnodeAPI {
// MARK: Snode Pool Interaction
private static var hasInsufficientSnodes: Bool { snodePool.count < minSnodePoolCount }
private static var hasInsufficientSnodes: Bool { snodePool.wrappedValue.count < minSnodePoolCount }
private static func loadSnodePoolIfNeeded() {
guard !hasLoadedSnodePool else { return }
guard !hasLoadedSnodePool.wrappedValue else { return }
Storage.shared.read { db in
snodePool = ((try? Snode.fetchSet(db)) ?? Set())
snodePool.mutate { $0 = ((try? Snode.fetchSet(db)) ?? Set()) }
}
hasLoadedSnodePool = true
hasLoadedSnodePool.mutate { $0 = true }
}
private static func setSnodePool(to newValue: Set<Snode>, db: Database? = nil) {
snodePool = newValue
snodePool.mutate { $0 = newValue }
if let db: Database = db {
_ = try? Snode.deleteAll(db)
@ -79,13 +79,13 @@ public final class SnodeAPI {
#if DEBUG
dispatchPrecondition(condition: .onQueue(Threading.workQueue))
#endif
var snodePool = SnodeAPI.snodePool
var snodePool = SnodeAPI.snodePool.wrappedValue
snodePool.remove(snode)
setSnodePool(to: snodePool)
}
@objc public static func clearSnodePool() {
snodePool.removeAll()
snodePool.mutate { $0.removeAll() }
Threading.workQueue.async {
setSnodePool(to: [])
@ -94,14 +94,14 @@ public final class SnodeAPI {
// MARK: Swarm Interaction
private static func loadSwarmIfNeeded(for publicKey: String) {
guard !loadedSwarms.contains(publicKey) else { return }
guard !loadedSwarms.wrappedValue.contains(publicKey) else { return }
let updatedCacheForKey: Set<Snode> = Storage.shared
.read { db in try Snode.fetchSet(db, publicKey: publicKey) }
.defaulting(to: [])
swarmCache.mutate { $0[publicKey] = updatedCacheForKey }
loadedSwarms.insert(publicKey)
loadedSwarms.mutate { $0.insert(publicKey) }
}
private static func setSwarm(to newValue: Set<Snode>, for publicKey: String, persist: Bool = true) {
@ -232,7 +232,7 @@ public final class SnodeAPI {
}
private static func getSnodePoolFromSnode() -> Promise<Set<Snode>> {
var snodePool = SnodeAPI.snodePool
var snodePool = SnodeAPI.snodePool.wrappedValue
var snodes: Set<Snode> = []
(0..<3).forEach { _ in
guard let snode = snodePool.randomElement() else { return }
@ -301,13 +301,13 @@ public final class SnodeAPI {
let hasSnodePoolExpired = given(Storage.shared[.lastSnodePoolRefreshDate]) {
now.timeIntervalSince($0) > 2 * 60 * 60
}.defaulting(to: true)
let snodePool: Set<Snode> = SnodeAPI.snodePool
let snodePool: Set<Snode> = SnodeAPI.snodePool.wrappedValue
guard hasInsufficientSnodes || hasSnodePoolExpired else {
return Promise.value(snodePool)
}
if let getSnodePoolPromise = getSnodePoolPromise { return getSnodePoolPromise }
if let getSnodePoolPromise = getSnodePoolPromise.wrappedValue { return getSnodePoolPromise }
let promise: Promise<Set<Snode>>
if snodePool.count < minSnodePoolCount {
@ -319,7 +319,7 @@ public final class SnodeAPI {
}
}
getSnodePoolPromise = promise
getSnodePoolPromise.mutate { $0 = promise }
promise.map2 { snodePool -> Set<Snode> in
guard !snodePool.isEmpty else { throw SnodeAPIError.snodePoolUpdatingFailed }
@ -342,10 +342,10 @@ public final class SnodeAPI {
return promise
}
promise.done2 { _ in
getSnodePoolPromise = nil
getSnodePoolPromise.mutate { $0 = nil }
}
promise.catch2 { _ in
getSnodePoolPromise = nil
getSnodePoolPromise.mutate { $0 = nil }
}
return promise
@ -545,7 +545,7 @@ public final class SnodeAPI {
let lastHash = SnodeReceivedMessageInfo.fetchLastNotExpired(for: snode, namespace: namespace, associatedWith: publicKey)?.hash ?? ""
// Construct signature
let timestamp = UInt64(Int64(floor(Date().timeIntervalSince1970 * 1000)) + SnodeAPI.clockOffset)
let timestamp = UInt64(Int64(floor(Date().timeIntervalSince1970 * 1000)) + SnodeAPI.clockOffset.wrappedValue)
let ed25519PublicKey = userED25519KeyPair.publicKey.toHexString()
let namespaceVerificationString = (namespace == defaultNamespace ? "" : String(namespace))
@ -644,7 +644,7 @@ public final class SnodeAPI {
}
// Construct signature
let timestamp = UInt64(Int64(floor(Date().timeIntervalSince1970 * 1000)) + SnodeAPI.clockOffset)
let timestamp = UInt64(Int64(floor(Date().timeIntervalSince1970 * 1000)) + SnodeAPI.clockOffset.wrappedValue)
let ed25519PublicKey = userED25519KeyPair.publicKey.toHexString()
guard
@ -1026,9 +1026,9 @@ public final class SnodeAPI {
dispatchPrecondition(condition: .onQueue(Threading.workQueue))
#endif
func handleBadSnode() {
let oldFailureCount = SnodeAPI.snodeFailureCount[snode] ?? 0
let oldFailureCount = (SnodeAPI.snodeFailureCount.wrappedValue[snode] ?? 0)
let newFailureCount = oldFailureCount + 1
SnodeAPI.snodeFailureCount[snode] = newFailureCount
SnodeAPI.snodeFailureCount.mutate { $0[snode] = newFailureCount }
SNLog("Couldn't reach snode at: \(snode); setting failure count to \(newFailureCount).")
if newFailureCount >= SnodeAPI.snodeFailureThreshold {
SNLog("Failure threshold reached for: \(snode); dropping it.")
@ -1036,8 +1036,8 @@ public final class SnodeAPI {
SnodeAPI.dropSnodeFromSwarmIfNeeded(snode, publicKey: publicKey)
}
SnodeAPI.dropSnodeFromSnodePool(snode)
SNLog("Snode pool count: \(snodePool.count).")
SnodeAPI.snodeFailureCount[snode] = 0
SNLog("Snode pool count: \(snodePool.wrappedValue.count).")
SnodeAPI.snodeFailureCount.mutate { $0[snode] = 0 }
}
}

Loading…
Cancel
Save