Fixed a few TODOs and cleaned up some logic

• Refactored the internal logic for the Dependencies class to resolve some multithreading issues
• Fixed a crash caused by using the wrong config when updating userGroups
• Fixed an issue where a group conversation you had been kicked from wasn't showing the "kicked" copy as it's snippet
• Fixed a copy TODO
• Removed some duplicate code from the Edit Group screen
pull/894/head
Morgan Pretty 6 months ago
parent c1b11a6196
commit 46533cbe1a

@ -273,7 +273,7 @@ class EditGroupViewModel: SessionTableViewModel, NavigatableStateHolder, Editabl
identifier: "Invite by id",
label: "Invite by id"
),
onTap: { [weak self] in self?.inviteById() }
onTap: { [weak self] in self?.inviteById(currentGroupName: state.group.name) }
)
)
].compactMap { $0 }
@ -438,42 +438,10 @@ class EditGroupViewModel: SessionTableViewModel, NavigatableStateHolder, Editabl
throw UserListError.error("groupAddMemberMaximum".localized())
}
/// Show a toast that we have sent the invitations
self?.showToast(
text: (selectedMemberInfo.count == 1 ?
"groupInviteSending".localized() :
"groupInviteSending".localized()
),
backgroundColor: .backgroundSecondary
self?.addMembers(
currentGroupName: currentGroupName,
memberInfo: selectedMemberInfo.map { ($0.profileId, $0.profile) }
)
/// Actually trigger the sending
MessageSender
.addGroupMembers(
groupSessionId: threadId,
members: selectedMemberInfo.map { ($0.profileId, $0.profile) },
allowAccessToHistoricMessages: dependencies[feature: .updatedGroupsAllowHistoricAccessOnInvite],
using: dependencies
)
.sinkUntilComplete(
receiveCompletion: { result in
switch result {
case .finished: break
case .failure:
viewModel?.showToast(
text: GroupInviteMemberJob.failureMessage(
groupName: currentGroupName,
memberIds: selectedMemberInfo.map { $0.profileId },
profileInfo: selectedMemberInfo
.reduce(into: [:]) { result, next in
result[next.profileId] = next.profile
}
),
backgroundColor: .backgroundSecondary
)
}
}
)
}
case .standard: // Assume it's a legacy group
@ -506,7 +474,7 @@ class EditGroupViewModel: SessionTableViewModel, NavigatableStateHolder, Editabl
)
}
private func inviteById() {
private func inviteById(currentGroupName: String) {
// Convenience functions to avoid duplicate code
func showError(_ errorString: String) {
let modal: ConfirmationModal = ConfirmationModal(
@ -520,25 +488,6 @@ class EditGroupViewModel: SessionTableViewModel, NavigatableStateHolder, Editabl
)
self.transitionToScreen(modal, transitionType: .present)
}
func inviteMember(_ accountId: String, _ modal: UIViewController) {
guard !currentMemberIds.contains(accountId) else {
// FIXME: Localise this
return showError("This Account ID or ONS belongs to an existing member")
}
MessageSender.addGroupMembers(
groupSessionId: threadId,
members: [(accountId, nil)],
allowAccessToHistoricMessages: dependencies[feature: .updatedGroupsAllowHistoricAccessOnInvite],
using: dependencies
).sinkUntilComplete()
modal.dismiss(animated: true) { [weak self] in
self?.showToast(
text: "groupInviteSending".localized(),
backgroundColor: .backgroundSecondary
)
}
}
let currentMemberIds: Set<String> = (tableData
.first(where: { $0.model == .members })?
@ -573,16 +522,22 @@ class EditGroupViewModel: SessionTableViewModel, NavigatableStateHolder, Editabl
cancelStyle: .alert_text,
dismissOnConfirm: false,
onConfirm: { [weak self, dependencies] modal in
// FIXME: Consolidate this with the logic in `NewDMVC`
switch Result(catching: { try SessionId(from: self?.inviteByIdValue) }) {
case .success(let sessionId) where sessionId.prefix == .standard: inviteMember(sessionId.hexString, modal)
case .success: return showError("accountIdErrorInvalid".localized())
switch (self?.inviteByIdValue, try? SessionId(from: self?.inviteByIdValue)) {
case (_, .some(let sessionId)) where sessionId.prefix == .standard:
guard !currentMemberIds.contains(sessionId.hexString) else {
return showError("This Account ID or ONS belongs to an existing member")
}
case .failure:
guard let inviteByIdValue: String = self?.inviteByIdValue else {
return showError("accountIdErrorInvalid".localized())
modal.dismiss(animated: true) {
self?.addMembers(
currentGroupName: currentGroupName,
memberInfo: [(sessionId.hexString, nil)]
)
}
case (.none, _), (_, .some): return showError("accountIdErrorInvalid".localized())
case (.some(let inviteByIdValue), _):
// This could be an ONS name
let viewController = ModalActivityIndicatorViewController() { modalActivityIndicator in
SnodeAPI
@ -605,8 +560,17 @@ class EditGroupViewModel: SessionTableViewModel, NavigatableStateHolder, Editabl
}
},
receiveValue: { sessionIdHexString in
guard !currentMemberIds.contains(sessionIdHexString) else {
return showError("This Account ID or ONS belongs to an existing member")
}
modalActivityIndicator.dismiss {
inviteMember(sessionIdHexString, modal)
modal.dismiss(animated: true) {
self?.addMembers(
currentGroupName: currentGroupName,
memberInfo: [(sessionIdHexString, nil)]
)
}
}
}
)
@ -621,6 +585,48 @@ class EditGroupViewModel: SessionTableViewModel, NavigatableStateHolder, Editabl
)
}
private func addMembers(
currentGroupName: String,
memberInfo: [(id: String, profile: Profile?)]
) {
/// Show a toast that we have sent the invitations
self.showToast(
text: (memberInfo.count == 1 ?
"groupInviteSending".localized() :
"groupInviteSending".localized()
),
backgroundColor: .backgroundSecondary
)
/// Actually trigger the sending
MessageSender
.addGroupMembers(
groupSessionId: threadId,
members: memberInfo,
allowAccessToHistoricMessages: dependencies[feature: .updatedGroupsAllowHistoricAccessOnInvite],
using: dependencies
)
.sinkUntilComplete(
receiveCompletion: { [weak self] result in
switch result {
case .finished: break
case .failure:
self?.showToast(
text: GroupInviteMemberJob.failureMessage(
groupName: currentGroupName,
memberIds: memberInfo.map { $0.id },
profileInfo: memberInfo
.reduce(into: [:]) { result, next in
result[next.id] = next.profile
}
),
backgroundColor: .backgroundSecondary
)
}
}
)
}
private func resendInvitation(memberId: String) {
MessageSender.resendInvitation(
groupSessionId: threadId,

@ -218,6 +218,10 @@ public class ConversationViewModel: OWSAudioPlayerDelegate, NavigatableStateHold
).populatingCurrentUserBlindedIds(
currentUserBlinded15SessionIdForThisThread: initialData?.blinded15SessionId?.hexString,
currentUserBlinded25SessionIdForThisThread: initialData?.blinded25SessionId?.hexString,
wasKickedFromGroup: (
threadVariant == .group &&
LibSession.wasKickedFromGroup(groupSessionId: SessionId(.group, hex: threadId), using: dependencies)
),
using: dependencies
)
)
@ -288,6 +292,13 @@ public class ConversationViewModel: OWSAudioPlayerDelegate, NavigatableStateHold
db,
currentUserBlinded15SessionIdForThisThread: self?.threadData.currentUserBlinded15SessionId,
currentUserBlinded25SessionIdForThisThread: self?.threadData.currentUserBlinded25SessionId,
wasKickedFromGroup: (
viewModel.threadVariant == .group &&
LibSession.wasKickedFromGroup(
groupSessionId: SessionId(.group, hex: viewModel.threadId),
using: dependencies
)
),
using: dependencies
)
}

@ -978,7 +978,7 @@ class ThreadSettingsViewModel: SessionTableViewModel, NavigatableStateHolder, Ob
self?.updatedName != current &&
self?.updatedName?.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty == false
},
cancelTitle: "Reset",//.localized(),
cancelTitle: "remove".localized(),
cancelStyle: .danger,
cancelEnabled: .bool(current?.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty == false),
hasCloseButton: true,

@ -374,6 +374,13 @@ public class HomeViewModel: NavigatableStateHolder {
currentUserBlinded25SessionIdForThisThread: groupedOldData[viewModel.threadId]?
.first?
.currentUserBlinded25SessionId,
wasKickedFromGroup: (
viewModel.threadVariant == .group &&
LibSession.wasKickedFromGroup(
groupSessionId: SessionId(.group, hex: viewModel.threadId),
using: dependencies
)
),
using: dependencies
)
}

@ -166,6 +166,13 @@ class MessageRequestsViewModel: SessionTableViewModel, NavigatableStateHolder, O
.first?
.id
.currentUserBlinded25SessionId,
wasKickedFromGroup: (
viewModel.threadVariant == .group &&
LibSession.wasKickedFromGroup(
groupSessionId: SessionId(.group, hex: viewModel.threadId),
using: dependencies
)
),
using: dependencies
),
accessibility: Accessibility(

@ -73,7 +73,7 @@ public class NotificationPresenter: NSObject, UNUserNotificationCenterDelegate,
)
// While batch processing, some of the necessary changes have not been commited.
let rawMessageText = interaction.previewText(db, using: dependencies)
let rawMessageText: String = Interaction.notificationPreviewText(db, interaction: interaction, using: dependencies)
// iOS strips anything that looks like a printf formatting character from
// the notification body, so if we want to dispay a literal "%" in a notification

@ -588,9 +588,17 @@ public final class FullConversationCell: UITableViewCell, SwipeActionOptimisticC
cellViewModel: SessionThreadViewModel,
textColor: UIColor,
using dependencies: Dependencies
) -> NSMutableAttributedString {
) -> NSAttributedString {
guard cellViewModel.wasKickedFromGroup != true else {
return NSAttributedString(
string: "groupRemovedYou"
.put(key: "group_name", value: cellViewModel.displayName)
.localizedDeformatted()
)
}
// If we don't have an interaction then do nothing
guard cellViewModel.interactionId != nil else { return NSMutableAttributedString() }
guard cellViewModel.interactionId != nil else { return NSAttributedString() }
let result = NSMutableAttributedString()
@ -635,21 +643,24 @@ public final class FullConversationCell: UITableViewCell, SwipeActionOptimisticC
}
let previewText: String = {
if cellViewModel.interactionVariant == .infoGroupCurrentUserErrorLeaving {
return "groupLeaveErrorFailed"
.put(key: "group_name", value: cellViewModel.displayName)
.localized()
switch cellViewModel.interactionVariant {
case .infoGroupCurrentUserErrorLeaving:
return "groupLeaveErrorFailed"
.put(key: "group_name", value: cellViewModel.displayName)
.localized()
default:
return Interaction.previewText(
variant: (cellViewModel.interactionVariant ?? .standardIncoming),
body: cellViewModel.interactionBody,
threadContactDisplayName: cellViewModel.threadContactName(),
authorDisplayName: cellViewModel.authorName(for: cellViewModel.threadVariant),
attachmentDescriptionInfo: cellViewModel.interactionAttachmentDescriptionInfo,
attachmentCount: cellViewModel.interactionAttachmentCount,
isOpenGroupInvitation: (cellViewModel.interactionIsOpenGroupInvitation == true),
using: dependencies
)
}
return Interaction.previewText(
variant: (cellViewModel.interactionVariant ?? .standardIncoming),
body: cellViewModel.interactionBody,
threadContactDisplayName: cellViewModel.threadContactName(),
authorDisplayName: cellViewModel.authorName(for: cellViewModel.threadVariant),
attachmentDescriptionInfo: cellViewModel.interactionAttachmentDescriptionInfo,
attachmentCount: cellViewModel.interactionAttachmentCount,
isOpenGroupInvitation: (cellViewModel.interactionIsOpenGroupInvitation == true),
using: dependencies
)
}()
result.append(NSAttributedString(

@ -1025,20 +1025,23 @@ public extension Interaction {
}
}
/// Use the `Interaction.previewText` method directly where possible rather than this method as it
/// makes it's own database queries
func previewText(_ db: Database, using dependencies: Dependencies) -> String {
switch variant {
/// Use the `Interaction.previewText` method directly where possible rather than this one to avoid database queries
static func notificationPreviewText(
_ db: Database,
interaction: Interaction,
using dependencies: Dependencies
) -> String {
switch interaction.variant {
case .standardIncoming, .standardOutgoing:
return Interaction.previewText(
variant: self.variant,
body: self.body,
attachmentDescriptionInfo: try? attachments
variant: interaction.variant,
body: interaction.body,
attachmentDescriptionInfo: try? interaction.attachments
.select(.id, .variant, .contentType, .sourceFilename)
.asRequest(of: Attachment.DescriptionInfo.self)
.fetchOne(db),
attachmentCount: try? attachments.fetchCount(db),
isOpenGroupInvitation: linkPreview
attachmentCount: try? interaction.attachments.fetchCount(db),
isOpenGroupInvitation: interaction.linkPreview
.filter(LinkPreview.Columns.variant == LinkPreview.Variant.openGroupInvitation)
.isNotEmpty(db),
using: dependencies
@ -1048,15 +1051,15 @@ public extension Interaction {
// Note: These should only occur in 'contact' threads so the `threadId`
// is the contact id
return Interaction.previewText(
variant: self.variant,
body: self.body,
authorDisplayName: Profile.displayName(db, id: threadId, using: dependencies),
variant: interaction.variant,
body: interaction.body,
authorDisplayName: Profile.displayName(db, id: interaction.threadId, using: dependencies),
using: dependencies
)
default: return Interaction.previewText(
variant: self.variant,
body: self.body,
variant: interaction.variant,
body: interaction.body,
using: dependencies
)
}

@ -700,8 +700,9 @@ extension MessageReceiver {
/// that if the user doesn't delete the group and links a new device, the group will have the same name as on the current device
if !LibSession.wasKickedFromGroup(groupSessionId: groupSessionId, using: dependencies) {
dependencies.mutate(cache: .libSession) { cache in
let config: LibSession.Config? = cache.config(for: .groupInfo, sessionId: groupSessionId)
let groupName: String? = try? LibSession.groupName(in: config)
let groupInfoConfig: LibSession.Config? = cache.config(for: .groupInfo, sessionId: groupSessionId)
let userGroupsConfig: LibSession.Config? = cache.config(for: .userGroups, sessionId: userSessionId)
let groupName: String? = try? LibSession.groupName(in: groupInfoConfig)
switch groupName {
case .none: Log.warn(.messageReceiver, "Failed to update group name before being kicked.")
@ -713,7 +714,7 @@ extension MessageReceiver {
name: name
)
],
in: config,
in: userGroupsConfig,
using: dependencies
)
}

@ -83,6 +83,7 @@ public struct SessionThreadViewModel: FetchableRecordWithRowId, Decodable, Equat
case currentUserBlinded15SessionId
case currentUserBlinded25SessionId
case recentReactionEmoji
case wasKickedFromGroup
}
public var differenceIdentifier: String { threadId }
@ -131,11 +132,7 @@ public struct SessionThreadViewModel: FetchableRecordWithRowId, Decodable, Equat
)
case .group:
let groupSessionId: SessionId = SessionId(.group, hex: threadId)
guard !LibSession.wasKickedFromGroup(groupSessionId: groupSessionId, using: dependencies) else {
return false
}
guard wasKickedFromGroup != true else { return false }
guard threadIsMessageRequest == false else { return true }
return (
@ -190,6 +187,7 @@ public struct SessionThreadViewModel: FetchableRecordWithRowId, Decodable, Equat
public let currentUserBlinded15SessionId: String?
public let currentUserBlinded25SessionId: String?
public let recentReactionEmoji: [String]?
public let wasKickedFromGroup: Bool?
// UI specific logic
@ -471,6 +469,7 @@ public extension SessionThreadViewModel {
self.currentUserBlinded15SessionId = nil
self.currentUserBlinded25SessionId = nil
self.recentReactionEmoji = nil
self.wasKickedFromGroup = false
}
}
@ -535,14 +534,16 @@ public extension SessionThreadViewModel {
currentUserSessionId: self.currentUserSessionId,
currentUserBlinded15SessionId: self.currentUserBlinded15SessionId,
currentUserBlinded25SessionId: self.currentUserBlinded25SessionId,
recentReactionEmoji: (recentReactionEmoji ?? self.recentReactionEmoji)
recentReactionEmoji: (recentReactionEmoji ?? self.recentReactionEmoji),
wasKickedFromGroup: self.wasKickedFromGroup
)
}
func populatingCurrentUserBlindedIds(
_ db: Database? = nil,
currentUserBlinded15SessionIdForThisThread: String? = nil,
currentUserBlinded25SessionIdForThisThread: String? = nil,
currentUserBlinded15SessionIdForThisThread: String?,
currentUserBlinded25SessionIdForThisThread: String?,
wasKickedFromGroup: Bool,
using dependencies: Dependencies
) -> SessionThreadViewModel {
return SessionThreadViewModel(
@ -618,7 +619,8 @@ public extension SessionThreadViewModel {
using: dependencies
)?.hexString
),
recentReactionEmoji: self.recentReactionEmoji
recentReactionEmoji: self.recentReactionEmoji,
wasKickedFromGroup: wasKickedFromGroup
)
}
}

@ -66,7 +66,8 @@ public class NSENotificationPresenter: NotificationsManagerType {
.localized()
}
let snippet: String = (interaction.previewText(db, using: dependencies)
let snippet: String = (Interaction
.notificationPreviewText(db, interaction: interaction, using: dependencies)
.filteredForDisplay
.nullIfEmpty?
.replacingMentions(for: thread.id, using: dependencies))

@ -1,4 +1,4 @@
// Copyright © 2022 Rangeproof Pty Ltd. All rights reserved.
// Copyright © 2024 Rangeproof Pty Ltd. All rights reserved.
//
// stringlint:disable
@ -6,50 +6,18 @@ import Foundation
import Combine
public class Dependencies {
static let userInfoKey: CodingUserInfoKey = CodingUserInfoKey(rawValue: "io.oxen.dependencies.codingOptions")!
static let userInfoKey: CodingUserInfoKey = CodingUserInfoKey(rawValue: "session.dependencies.codingOptions")!
private static var _isRTLRetriever: Atomic<(Bool, () -> Bool)> = Atomic((false, { false }))
private static var singletonInstances: Atomic<[String: Any]> = Atomic([:])
private static var cacheInstances: Atomic<[String: Atomic<MutableCacheType>]> = Atomic([:])
private static var userDefaultsInstances: Atomic<[String: (any UserDefaultsType)]> = Atomic([:])
private static var featureInstances: Atomic<[String: (any FeatureType)]> = Atomic([:])
private var featureChangeSubject: PassthroughSubject<(String, String?, Any?), Never> = PassthroughSubject()
private let featureChangeSubject: PassthroughSubject<(String, String?, Any?), Never> = PassthroughSubject()
private var storage: Atomic<DependencyStorage> = Atomic(DependencyStorage())
// MARK: - Subscript Access
public subscript<S>(singleton singleton: SingletonConfig<S>) -> S {
guard let value: S = (Dependencies.singletonInstances.wrappedValue[singleton.identifier] as? S) else {
let value: S = singleton.createInstance(self)
Dependencies.singletonInstances.mutate { $0[singleton.identifier] = value }
return value
}
return value
}
public subscript<M, I>(cache cache: CacheConfig<M, I>) -> I {
getValueSettingIfNull(cache: cache)
}
public subscript(defaults defaults: UserDefaultsConfig) -> UserDefaultsType {
guard let value: UserDefaultsType = Dependencies.userDefaultsInstances.wrappedValue[defaults.identifier] else {
let value: UserDefaultsType = defaults.createInstance(self)
Dependencies.userDefaultsInstances.mutate { $0[defaults.identifier] = value }
return value
}
return value
}
public subscript<T: FeatureOption>(feature feature: FeatureConfig<T>) -> T {
guard let value: Feature<T> = (Dependencies.featureInstances.wrappedValue[feature.identifier] as? Feature<T>) else {
let value: Feature<T> = feature.createInstance(self)
Dependencies.featureInstances.mutate { $0[feature.identifier] = value }
return value.currentValue(using: self)
}
return value.currentValue(using: self)
}
public subscript<S>(singleton singleton: SingletonConfig<S>) -> S { getOrCreate(singleton) }
public subscript<M, I>(cache cache: CacheConfig<M, I>) -> I { getOrCreate(cache).immutable(cache: cache, using: self) }
public subscript(defaults defaults: UserDefaultsConfig) -> UserDefaultsType { getOrCreate(defaults) }
public subscript<T: FeatureOption>(feature feature: FeatureConfig<T>) -> T { getOrCreate(feature).currentValue(using: self) }
// MARK: - Global Values, Timing and Async Handling
@ -85,20 +53,15 @@ public class Dependencies {
cache: CacheConfig<M, I>,
_ mutation: (inout M) -> R
) -> R {
/// The cast from `Atomic<MutableCacheType>` to `Atomic<M>` always fails so we need to do some
/// stuffing around to ensure we have the right types - since we call `createInstance` multiple times in
/// the below code we first call `getValueSettingIfNull` to ensure we have a proper instance stored
/// in `Dependencies.cacheInstances` so that we can be reliably certail we aren't accessing some
/// random instance that will go out of memory as soon as the mutation is completed
getValueSettingIfNull(cache: cache)
let cacheWrapper: Atomic<MutableCacheType> = (
Dependencies.cacheInstances.wrappedValue[cache.identifier] ??
Atomic(cache.mutableInstance(cache.createInstance(self))) // Should never be called
)
return cacheWrapper.mutate { erasedValue in
var value: M = ((erasedValue as? M) ?? cache.createInstance(self))
return getOrCreate(cache).mutate { erasedValue in
guard var value: M = (erasedValue as? M) else {
/// This code path should never happen (and is essentially invalid if it does) but in order to avoid neeing to return
/// a nullable type or force-casting this is how we need to do things)
Log.critical("Failed to convert erased cache value for '\(cache.identifier)' to expected type: \(M.self)")
var fallbackValue: M = cache.createInstance(self)
return mutation(&fallbackValue)
}
return mutation(&value)
}
}
@ -107,20 +70,15 @@ public class Dependencies {
cache: CacheConfig<M, I>,
_ mutation: (inout M) throws -> R
) throws -> R {
/// The cast from `Atomic<MutableCacheType>` to `Atomic<M>` always fails so we need to do some
/// stuffing around to ensure we have the right types - since we call `createInstance` multiple times in
/// the below code we first call `getValueSettingIfNull` to ensure we have a proper instance stored
/// in `Dependencies.cacheInstances` so that we can be reliably certail we aren't accessing some
/// random instance that will go out of memory as soon as the mutation is completed
getValueSettingIfNull(cache: cache)
let cacheWrapper: Atomic<MutableCacheType> = (
Dependencies.cacheInstances.wrappedValue[cache.identifier] ??
Atomic(cache.mutableInstance(cache.createInstance(self))) // Should never be called
)
return try cacheWrapper.mutate { erasedValue in
var value: M = ((erasedValue as? M) ?? cache.createInstance(self))
return try getOrCreate(cache).mutate { erasedValue in
guard var value: M = (erasedValue as? M) else {
/// This code path should never happen (and is essentially invalid if it does) but in order to avoid neeing to return
/// a nullable type or force-casting this is how we need to do things)
Log.critical("Failed to convert erased cache value for '\(cache.identifier)' to expected type: \(M.self)")
var fallbackValue: M = cache.createInstance(self)
return try mutation(&fallbackValue)
}
return try mutation(&value)
}
}
@ -138,41 +96,29 @@ public class Dependencies {
public func popRandomElement<T>(_ elements: inout Set<T>) -> T? {
return elements.popRandomElement()
}
// MARK: - Instance upserting
@discardableResult private func getValueSettingIfNull<M, I>(cache: CacheConfig<M, I>) -> I {
guard let value: M = (Dependencies.cacheInstances.wrappedValue[cache.identifier]?.wrappedValue as? M) else {
let value: M = cache.createInstance(self)
let mutableInstance: MutableCacheType = cache.mutableInstance(value)
Dependencies.cacheInstances.mutate { $0[cache.identifier] = Atomic(mutableInstance) }
return cache.immutableInstance(value)
}
return cache.immutableInstance(value)
}
// MARK: - Instance replacing
public func warmCache<M, I>(cache: CacheConfig<M, I>) {
_ = getValueSettingIfNull(cache: cache)
_ = getOrCreate(cache)
}
public func set<S>(singleton: SingletonConfig<S>, to instance: S) {
Dependencies.singletonInstances.mutate {
$0[singleton.identifier] = instance
threadSafeChange(for: singleton.identifier) {
setValue(instance, typedStorage: .singleton(instance), key: singleton.identifier)
}
}
public func set<M, I>(cache: CacheConfig<M, I>, to instance: M) {
Dependencies.cacheInstances.mutate {
$0[cache.identifier] = Atomic(cache.mutableInstance(instance))
threadSafeChange(for: cache.identifier) {
let value: Atomic<MutableCacheType> = Atomic(cache.mutableInstance(instance))
setValue(value, typedStorage: .cache(value), key: cache.identifier)
}
}
public func remove<M, I>(cache: CacheConfig<M, I>) {
Dependencies.cacheInstances.mutate {
$0[cache.identifier] = nil
threadSafeChange(for: cache.identifier) {
removeValue(cache.identifier)
}
}
@ -181,6 +127,17 @@ public class Dependencies {
}
}
// MARK: - Cache Management
private extension Atomic<MutableCacheType> {
func immutable<M, I>(cache: CacheConfig<M, I>, using dependencies: Dependencies) -> I {
return cache.immutableInstance(
(self.wrappedValue as? M) ??
cache.createInstance(dependencies)
)
}
}
// MARK: - Feature Management
public extension Dependencies {
@ -215,30 +172,26 @@ public extension Dependencies {
}
func set<T: FeatureOption>(feature: FeatureConfig<T>, to updatedFeature: T?) {
let value: Feature<T> = {
guard let value: Feature<T> = (Dependencies.featureInstances.wrappedValue[feature.identifier] as? Feature<T>) else {
let value: Feature<T> = feature.createInstance(self)
Dependencies.featureInstances.mutate { $0[feature.identifier] = value }
return value
}
return value
}()
threadSafeChange(for: feature.identifier) {
/// Update the cached & in-memory values
let instance: Feature<T> = (
getValue(feature.identifier) ??
feature.createInstance(self)
)
instance.setValue(to: updatedFeature, using: self)
setValue(instance, typedStorage: .feature(instance), key: feature.identifier)
}
value.setValue(to: updatedFeature, using: self)
/// Notify observers
featureChangeSubject.send((feature.identifier, feature.groupIdentifier, updatedFeature))
}
func reset<T: FeatureOption>(feature: FeatureConfig<T>) {
/// Reset the cached value
switch Dependencies.featureInstances.wrappedValue[feature.identifier] as? Feature<T> {
case .none: break
case .some(let value): value.setValue(to: nil, using: self)
}
/// Reset the in-memory value
Dependencies.featureInstances.mutate {
$0[feature.identifier] = nil
threadSafeChange(for: feature.identifier) {
/// Reset the cached and in-memory values
let instance: Feature<T>? = getValue(feature.identifier)
instance?.setValue(to: nil, using: self)
removeValue(feature.identifier)
}
/// Notify observers
@ -309,6 +262,171 @@ public extension Dependencies {
}
}
// MARK: - DependenciesError
public enum DependenciesError: Error {
case missingDependencies
}
// MARK: - Storage Management
private extension Dependencies {
struct DependencyStorage {
var initializationTracker: [String: DispatchGroup] = [:]
var instances: [String: Value] = [:]
enum Value {
case singleton(Any)
case cache(Atomic<MutableCacheType>)
case userDefaults(UserDefaultsType)
case feature(any FeatureType)
func value<T>(as type: T.Type) -> T? {
switch self {
case .singleton(let value): return value as? T
case .cache(let value): return value as? T
case .userDefaults(let value): return value as? T
case .feature(let value): return value as? T
}
}
}
}
private func getOrCreate<S>(_ singleton: SingletonConfig<S>) -> S {
return getOrCreateInstance(
identifier: singleton.identifier,
constructor: .singleton { singleton.createInstance(self) }
)
}
private func getOrCreate<M, I>(_ cache: CacheConfig<M, I>) -> Atomic<MutableCacheType> {
return getOrCreateInstance(
identifier: cache.identifier,
constructor: .cache { Atomic(cache.mutableInstance(cache.createInstance(self))) }
)
}
private func getOrCreate(_ defaults: UserDefaultsConfig) -> UserDefaultsType {
return getOrCreateInstance(
identifier: defaults.identifier,
constructor: .userDefaults { defaults.createInstance(self) }
)
}
private func getOrCreate<T: FeatureOption>(_ feature: FeatureConfig<T>) -> Feature<T> {
return getOrCreateInstance(
identifier: feature.identifier,
constructor: .feature { feature.createInstance(self) }
)
}
// MARK: - Instance upserting
/// Retrieves the current instance or, if one doesn't exist, uses the `StorageHelper.Info<Value>` to create a new instance
/// and store it
private func getOrCreateInstance<Value>(
identifier: String,
constructor: DependencyStorage.Constructor<Value>
) -> Value {
/// If we already have an instance then just return that
if let existingValue: Value = getValue(identifier) {
return existingValue
}
return threadSafeChange(for: identifier) {
/// Now that we are within a synchronized group, check to make sure an instance wasn't created while we were waiting to
/// enter the group
if let existingValue: Value = getValue(identifier) {
return existingValue
}
let result: (typedStorage: DependencyStorage.Value, value: Value) = constructor.create()
setValue(result.value, typedStorage: result.typedStorage, key: identifier)
return result.value
}
}
/// Convenience method to retrieve the existing dependency instance from memory in a thread-safe way
private func getValue<T>(_ key: String) -> T? {
guard let typedValue: DependencyStorage.Value = storage.wrappedValue.instances[key] else { return nil }
guard let result: T = typedValue.value(as: T.self) else {
/// If there is a value stored for the key, but it's not the right type then something has gone wrong, and we should log
Log.critical("Failed to convert stored dependency '\(key)' to expected type: \(T.self)")
return nil
}
return result
}
/// Convenience method to store a dependency instance in memory in a thread-safe way
@discardableResult private func setValue<T>(_ value: T, typedStorage: DependencyStorage.Value, key: String) -> T {
storage.mutate { $0.instances[key] = typedStorage }
return value
}
/// Convenience method to remove a dependency instance from memory in a thread-safe way
private func removeValue(_ key: String) {
storage.mutate { $0.instances.removeValue(forKey: key) }
}
/// This function creates a `DispatchGroup` for the given identifier which allows us to block instance creation on a per-identifier basis
/// and avoid situations where multithreading could result in multiple instances of the same dependency being created concurrently
///
/// **Note:** This `DispatchGroup` is an additional mechanism on top of the `Atomic<T>` because the interface is a little simpler
/// and we don't need to wrap every instance within `Atomic<T>` this way
@discardableResult private func threadSafeChange<T>(for identifier: String, change: () -> T) -> T {
let group: DispatchGroup = storage.mutate { storage in
if let existing = storage.initializationTracker[identifier] {
return existing
}
let group = DispatchGroup()
storage.initializationTracker[identifier] = group
return group
}
group.enter()
defer { group.leave() }
return change()
}
}
// MARK: - DSL
private extension Dependencies.DependencyStorage {
struct Constructor<T> {
let create: () -> (typedStorage: Dependencies.DependencyStorage.Value, value: T)
static func singleton(_ constructor: @escaping () -> T) -> Constructor<T> {
return Constructor {
let instance: T = constructor()
return (.singleton(instance), instance)
}
}
static func cache(_ constructor: @escaping () -> T) -> Constructor<T> where T: Atomic<MutableCacheType> {
return Constructor {
let instance: T = constructor()
return (.cache(instance), instance)
}
}
static func userDefaults(_ constructor: @escaping () -> T) -> Constructor<T> where T == UserDefaultsType {
return Constructor {
let instance: T = constructor()
return (.userDefaults(instance), instance)
}
}
static func feature(_ constructor: @escaping () -> T) -> Constructor<T> where T: FeatureType {
return Constructor {
let instance: T = constructor()
return (.feature(instance), instance)
}
}
}
}

Loading…
Cancel
Save