diff --git a/Session/Closed Groups/EditGroupViewModel.swift b/Session/Closed Groups/EditGroupViewModel.swift index d10ec4c72..40e96ded2 100644 --- a/Session/Closed Groups/EditGroupViewModel.swift +++ b/Session/Closed Groups/EditGroupViewModel.swift @@ -273,7 +273,7 @@ class EditGroupViewModel: SessionTableViewModel, NavigatableStateHolder, Editabl identifier: "Invite by id", label: "Invite by id" ), - onTap: { [weak self] in self?.inviteById() } + onTap: { [weak self] in self?.inviteById(currentGroupName: state.group.name) } ) ) ].compactMap { $0 } @@ -438,42 +438,10 @@ class EditGroupViewModel: SessionTableViewModel, NavigatableStateHolder, Editabl throw UserListError.error("groupAddMemberMaximum".localized()) } - /// Show a toast that we have sent the invitations - self?.showToast( - text: (selectedMemberInfo.count == 1 ? - "groupInviteSending".localized() : - "groupInviteSending".localized() - ), - backgroundColor: .backgroundSecondary + self?.addMembers( + currentGroupName: currentGroupName, + memberInfo: selectedMemberInfo.map { ($0.profileId, $0.profile) } ) - - /// Actually trigger the sending - MessageSender - .addGroupMembers( - groupSessionId: threadId, - members: selectedMemberInfo.map { ($0.profileId, $0.profile) }, - allowAccessToHistoricMessages: dependencies[feature: .updatedGroupsAllowHistoricAccessOnInvite], - using: dependencies - ) - .sinkUntilComplete( - receiveCompletion: { result in - switch result { - case .finished: break - case .failure: - viewModel?.showToast( - text: GroupInviteMemberJob.failureMessage( - groupName: currentGroupName, - memberIds: selectedMemberInfo.map { $0.profileId }, - profileInfo: selectedMemberInfo - .reduce(into: [:]) { result, next in - result[next.profileId] = next.profile - } - ), - backgroundColor: .backgroundSecondary - ) - } - } - ) } case .standard: // Assume it's a legacy group @@ -506,7 +474,7 @@ class EditGroupViewModel: SessionTableViewModel, NavigatableStateHolder, Editabl ) } - private func inviteById() { + private func inviteById(currentGroupName: String) { // Convenience functions to avoid duplicate code func showError(_ errorString: String) { let modal: ConfirmationModal = ConfirmationModal( @@ -520,25 +488,6 @@ class EditGroupViewModel: SessionTableViewModel, NavigatableStateHolder, Editabl ) self.transitionToScreen(modal, transitionType: .present) } - func inviteMember(_ accountId: String, _ modal: UIViewController) { - guard !currentMemberIds.contains(accountId) else { - // FIXME: Localise this - return showError("This Account ID or ONS belongs to an existing member") - } - - MessageSender.addGroupMembers( - groupSessionId: threadId, - members: [(accountId, nil)], - allowAccessToHistoricMessages: dependencies[feature: .updatedGroupsAllowHistoricAccessOnInvite], - using: dependencies - ).sinkUntilComplete() - modal.dismiss(animated: true) { [weak self] in - self?.showToast( - text: "groupInviteSending".localized(), - backgroundColor: .backgroundSecondary - ) - } - } let currentMemberIds: Set = (tableData .first(where: { $0.model == .members })? @@ -573,16 +522,22 @@ class EditGroupViewModel: SessionTableViewModel, NavigatableStateHolder, Editabl cancelStyle: .alert_text, dismissOnConfirm: false, onConfirm: { [weak self, dependencies] modal in - // FIXME: Consolidate this with the logic in `NewDMVC` - switch Result(catching: { try SessionId(from: self?.inviteByIdValue) }) { - case .success(let sessionId) where sessionId.prefix == .standard: inviteMember(sessionId.hexString, modal) - case .success: return showError("accountIdErrorInvalid".localized()) + switch (self?.inviteByIdValue, try? SessionId(from: self?.inviteByIdValue)) { + case (_, .some(let sessionId)) where sessionId.prefix == .standard: + guard !currentMemberIds.contains(sessionId.hexString) else { + return showError("This Account ID or ONS belongs to an existing member") + } - case .failure: - guard let inviteByIdValue: String = self?.inviteByIdValue else { - return showError("accountIdErrorInvalid".localized()) + modal.dismiss(animated: true) { + self?.addMembers( + currentGroupName: currentGroupName, + memberInfo: [(sessionId.hexString, nil)] + ) } + + case (.none, _), (_, .some): return showError("accountIdErrorInvalid".localized()) + case (.some(let inviteByIdValue), _): // This could be an ONS name let viewController = ModalActivityIndicatorViewController() { modalActivityIndicator in SnodeAPI @@ -605,8 +560,17 @@ class EditGroupViewModel: SessionTableViewModel, NavigatableStateHolder, Editabl } }, receiveValue: { sessionIdHexString in + guard !currentMemberIds.contains(sessionIdHexString) else { + return showError("This Account ID or ONS belongs to an existing member") + } + modalActivityIndicator.dismiss { - inviteMember(sessionIdHexString, modal) + modal.dismiss(animated: true) { + self?.addMembers( + currentGroupName: currentGroupName, + memberInfo: [(sessionIdHexString, nil)] + ) + } } } ) @@ -621,6 +585,48 @@ class EditGroupViewModel: SessionTableViewModel, NavigatableStateHolder, Editabl ) } + private func addMembers( + currentGroupName: String, + memberInfo: [(id: String, profile: Profile?)] + ) { + /// Show a toast that we have sent the invitations + self.showToast( + text: (memberInfo.count == 1 ? + "groupInviteSending".localized() : + "groupInviteSending".localized() + ), + backgroundColor: .backgroundSecondary + ) + + /// Actually trigger the sending + MessageSender + .addGroupMembers( + groupSessionId: threadId, + members: memberInfo, + allowAccessToHistoricMessages: dependencies[feature: .updatedGroupsAllowHistoricAccessOnInvite], + using: dependencies + ) + .sinkUntilComplete( + receiveCompletion: { [weak self] result in + switch result { + case .finished: break + case .failure: + self?.showToast( + text: GroupInviteMemberJob.failureMessage( + groupName: currentGroupName, + memberIds: memberInfo.map { $0.id }, + profileInfo: memberInfo + .reduce(into: [:]) { result, next in + result[next.id] = next.profile + } + ), + backgroundColor: .backgroundSecondary + ) + } + } + ) + } + private func resendInvitation(memberId: String) { MessageSender.resendInvitation( groupSessionId: threadId, diff --git a/Session/Conversations/ConversationViewModel.swift b/Session/Conversations/ConversationViewModel.swift index 840785c8d..d384d1768 100644 --- a/Session/Conversations/ConversationViewModel.swift +++ b/Session/Conversations/ConversationViewModel.swift @@ -218,6 +218,10 @@ public class ConversationViewModel: OWSAudioPlayerDelegate, NavigatableStateHold ).populatingCurrentUserBlindedIds( currentUserBlinded15SessionIdForThisThread: initialData?.blinded15SessionId?.hexString, currentUserBlinded25SessionIdForThisThread: initialData?.blinded25SessionId?.hexString, + wasKickedFromGroup: ( + threadVariant == .group && + LibSession.wasKickedFromGroup(groupSessionId: SessionId(.group, hex: threadId), using: dependencies) + ), using: dependencies ) ) @@ -288,6 +292,13 @@ public class ConversationViewModel: OWSAudioPlayerDelegate, NavigatableStateHold db, currentUserBlinded15SessionIdForThisThread: self?.threadData.currentUserBlinded15SessionId, currentUserBlinded25SessionIdForThisThread: self?.threadData.currentUserBlinded25SessionId, + wasKickedFromGroup: ( + viewModel.threadVariant == .group && + LibSession.wasKickedFromGroup( + groupSessionId: SessionId(.group, hex: viewModel.threadId), + using: dependencies + ) + ), using: dependencies ) } diff --git a/Session/Conversations/Settings/ThreadSettingsViewModel.swift b/Session/Conversations/Settings/ThreadSettingsViewModel.swift index 03bc4f354..85403511d 100644 --- a/Session/Conversations/Settings/ThreadSettingsViewModel.swift +++ b/Session/Conversations/Settings/ThreadSettingsViewModel.swift @@ -978,7 +978,7 @@ class ThreadSettingsViewModel: SessionTableViewModel, NavigatableStateHolder, Ob self?.updatedName != current && self?.updatedName?.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty == false }, - cancelTitle: "Reset",//.localized(), + cancelTitle: "remove".localized(), cancelStyle: .danger, cancelEnabled: .bool(current?.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty == false), hasCloseButton: true, diff --git a/Session/Home/HomeViewModel.swift b/Session/Home/HomeViewModel.swift index de59be678..dedd064fc 100644 --- a/Session/Home/HomeViewModel.swift +++ b/Session/Home/HomeViewModel.swift @@ -374,6 +374,13 @@ public class HomeViewModel: NavigatableStateHolder { currentUserBlinded25SessionIdForThisThread: groupedOldData[viewModel.threadId]? .first? .currentUserBlinded25SessionId, + wasKickedFromGroup: ( + viewModel.threadVariant == .group && + LibSession.wasKickedFromGroup( + groupSessionId: SessionId(.group, hex: viewModel.threadId), + using: dependencies + ) + ), using: dependencies ) } diff --git a/Session/Home/Message Requests/MessageRequestsViewModel.swift b/Session/Home/Message Requests/MessageRequestsViewModel.swift index 2e783dffd..c72d4eb7f 100644 --- a/Session/Home/Message Requests/MessageRequestsViewModel.swift +++ b/Session/Home/Message Requests/MessageRequestsViewModel.swift @@ -166,6 +166,13 @@ class MessageRequestsViewModel: SessionTableViewModel, NavigatableStateHolder, O .first? .id .currentUserBlinded25SessionId, + wasKickedFromGroup: ( + viewModel.threadVariant == .group && + LibSession.wasKickedFromGroup( + groupSessionId: SessionId(.group, hex: viewModel.threadId), + using: dependencies + ) + ), using: dependencies ), accessibility: Accessibility( diff --git a/Session/Notifications/NotificationPresenter.swift b/Session/Notifications/NotificationPresenter.swift index 363dc96cc..898601618 100644 --- a/Session/Notifications/NotificationPresenter.swift +++ b/Session/Notifications/NotificationPresenter.swift @@ -73,7 +73,7 @@ public class NotificationPresenter: NSObject, UNUserNotificationCenterDelegate, ) // While batch processing, some of the necessary changes have not been commited. - let rawMessageText = interaction.previewText(db, using: dependencies) + let rawMessageText: String = Interaction.notificationPreviewText(db, interaction: interaction, using: dependencies) // iOS strips anything that looks like a printf formatting character from // the notification body, so if we want to dispay a literal "%" in a notification diff --git a/Session/Shared/FullConversationCell.swift b/Session/Shared/FullConversationCell.swift index 48ee0d207..f50399009 100644 --- a/Session/Shared/FullConversationCell.swift +++ b/Session/Shared/FullConversationCell.swift @@ -588,9 +588,17 @@ public final class FullConversationCell: UITableViewCell, SwipeActionOptimisticC cellViewModel: SessionThreadViewModel, textColor: UIColor, using dependencies: Dependencies - ) -> NSMutableAttributedString { + ) -> NSAttributedString { + guard cellViewModel.wasKickedFromGroup != true else { + return NSAttributedString( + string: "groupRemovedYou" + .put(key: "group_name", value: cellViewModel.displayName) + .localizedDeformatted() + ) + } + // If we don't have an interaction then do nothing - guard cellViewModel.interactionId != nil else { return NSMutableAttributedString() } + guard cellViewModel.interactionId != nil else { return NSAttributedString() } let result = NSMutableAttributedString() @@ -635,21 +643,24 @@ public final class FullConversationCell: UITableViewCell, SwipeActionOptimisticC } let previewText: String = { - if cellViewModel.interactionVariant == .infoGroupCurrentUserErrorLeaving { - return "groupLeaveErrorFailed" - .put(key: "group_name", value: cellViewModel.displayName) - .localized() + switch cellViewModel.interactionVariant { + case .infoGroupCurrentUserErrorLeaving: + return "groupLeaveErrorFailed" + .put(key: "group_name", value: cellViewModel.displayName) + .localized() + + default: + return Interaction.previewText( + variant: (cellViewModel.interactionVariant ?? .standardIncoming), + body: cellViewModel.interactionBody, + threadContactDisplayName: cellViewModel.threadContactName(), + authorDisplayName: cellViewModel.authorName(for: cellViewModel.threadVariant), + attachmentDescriptionInfo: cellViewModel.interactionAttachmentDescriptionInfo, + attachmentCount: cellViewModel.interactionAttachmentCount, + isOpenGroupInvitation: (cellViewModel.interactionIsOpenGroupInvitation == true), + using: dependencies + ) } - return Interaction.previewText( - variant: (cellViewModel.interactionVariant ?? .standardIncoming), - body: cellViewModel.interactionBody, - threadContactDisplayName: cellViewModel.threadContactName(), - authorDisplayName: cellViewModel.authorName(for: cellViewModel.threadVariant), - attachmentDescriptionInfo: cellViewModel.interactionAttachmentDescriptionInfo, - attachmentCount: cellViewModel.interactionAttachmentCount, - isOpenGroupInvitation: (cellViewModel.interactionIsOpenGroupInvitation == true), - using: dependencies - ) }() result.append(NSAttributedString( diff --git a/SessionMessagingKit/Database/Models/Interaction.swift b/SessionMessagingKit/Database/Models/Interaction.swift index 3066260a9..90244cdd3 100644 --- a/SessionMessagingKit/Database/Models/Interaction.swift +++ b/SessionMessagingKit/Database/Models/Interaction.swift @@ -1025,20 +1025,23 @@ public extension Interaction { } } - /// Use the `Interaction.previewText` method directly where possible rather than this method as it - /// makes it's own database queries - func previewText(_ db: Database, using dependencies: Dependencies) -> String { - switch variant { + /// Use the `Interaction.previewText` method directly where possible rather than this one to avoid database queries + static func notificationPreviewText( + _ db: Database, + interaction: Interaction, + using dependencies: Dependencies + ) -> String { + switch interaction.variant { case .standardIncoming, .standardOutgoing: return Interaction.previewText( - variant: self.variant, - body: self.body, - attachmentDescriptionInfo: try? attachments + variant: interaction.variant, + body: interaction.body, + attachmentDescriptionInfo: try? interaction.attachments .select(.id, .variant, .contentType, .sourceFilename) .asRequest(of: Attachment.DescriptionInfo.self) .fetchOne(db), - attachmentCount: try? attachments.fetchCount(db), - isOpenGroupInvitation: linkPreview + attachmentCount: try? interaction.attachments.fetchCount(db), + isOpenGroupInvitation: interaction.linkPreview .filter(LinkPreview.Columns.variant == LinkPreview.Variant.openGroupInvitation) .isNotEmpty(db), using: dependencies @@ -1048,15 +1051,15 @@ public extension Interaction { // Note: These should only occur in 'contact' threads so the `threadId` // is the contact id return Interaction.previewText( - variant: self.variant, - body: self.body, - authorDisplayName: Profile.displayName(db, id: threadId, using: dependencies), + variant: interaction.variant, + body: interaction.body, + authorDisplayName: Profile.displayName(db, id: interaction.threadId, using: dependencies), using: dependencies ) default: return Interaction.previewText( - variant: self.variant, - body: self.body, + variant: interaction.variant, + body: interaction.body, using: dependencies ) } diff --git a/SessionMessagingKit/Sending & Receiving/Message Handling/MessageReceiver+Groups.swift b/SessionMessagingKit/Sending & Receiving/Message Handling/MessageReceiver+Groups.swift index 371c5dff1..41d609ab6 100644 --- a/SessionMessagingKit/Sending & Receiving/Message Handling/MessageReceiver+Groups.swift +++ b/SessionMessagingKit/Sending & Receiving/Message Handling/MessageReceiver+Groups.swift @@ -700,8 +700,9 @@ extension MessageReceiver { /// that if the user doesn't delete the group and links a new device, the group will have the same name as on the current device if !LibSession.wasKickedFromGroup(groupSessionId: groupSessionId, using: dependencies) { dependencies.mutate(cache: .libSession) { cache in - let config: LibSession.Config? = cache.config(for: .groupInfo, sessionId: groupSessionId) - let groupName: String? = try? LibSession.groupName(in: config) + let groupInfoConfig: LibSession.Config? = cache.config(for: .groupInfo, sessionId: groupSessionId) + let userGroupsConfig: LibSession.Config? = cache.config(for: .userGroups, sessionId: userSessionId) + let groupName: String? = try? LibSession.groupName(in: groupInfoConfig) switch groupName { case .none: Log.warn(.messageReceiver, "Failed to update group name before being kicked.") @@ -713,7 +714,7 @@ extension MessageReceiver { name: name ) ], - in: config, + in: userGroupsConfig, using: dependencies ) } diff --git a/SessionMessagingKit/Shared Models/SessionThreadViewModel.swift b/SessionMessagingKit/Shared Models/SessionThreadViewModel.swift index a778db810..e2431448d 100644 --- a/SessionMessagingKit/Shared Models/SessionThreadViewModel.swift +++ b/SessionMessagingKit/Shared Models/SessionThreadViewModel.swift @@ -83,6 +83,7 @@ public struct SessionThreadViewModel: FetchableRecordWithRowId, Decodable, Equat case currentUserBlinded15SessionId case currentUserBlinded25SessionId case recentReactionEmoji + case wasKickedFromGroup } public var differenceIdentifier: String { threadId } @@ -131,11 +132,7 @@ public struct SessionThreadViewModel: FetchableRecordWithRowId, Decodable, Equat ) case .group: - let groupSessionId: SessionId = SessionId(.group, hex: threadId) - - guard !LibSession.wasKickedFromGroup(groupSessionId: groupSessionId, using: dependencies) else { - return false - } + guard wasKickedFromGroup != true else { return false } guard threadIsMessageRequest == false else { return true } return ( @@ -190,6 +187,7 @@ public struct SessionThreadViewModel: FetchableRecordWithRowId, Decodable, Equat public let currentUserBlinded15SessionId: String? public let currentUserBlinded25SessionId: String? public let recentReactionEmoji: [String]? + public let wasKickedFromGroup: Bool? // UI specific logic @@ -471,6 +469,7 @@ public extension SessionThreadViewModel { self.currentUserBlinded15SessionId = nil self.currentUserBlinded25SessionId = nil self.recentReactionEmoji = nil + self.wasKickedFromGroup = false } } @@ -535,14 +534,16 @@ public extension SessionThreadViewModel { currentUserSessionId: self.currentUserSessionId, currentUserBlinded15SessionId: self.currentUserBlinded15SessionId, currentUserBlinded25SessionId: self.currentUserBlinded25SessionId, - recentReactionEmoji: (recentReactionEmoji ?? self.recentReactionEmoji) + recentReactionEmoji: (recentReactionEmoji ?? self.recentReactionEmoji), + wasKickedFromGroup: self.wasKickedFromGroup ) } func populatingCurrentUserBlindedIds( _ db: Database? = nil, - currentUserBlinded15SessionIdForThisThread: String? = nil, - currentUserBlinded25SessionIdForThisThread: String? = nil, + currentUserBlinded15SessionIdForThisThread: String?, + currentUserBlinded25SessionIdForThisThread: String?, + wasKickedFromGroup: Bool, using dependencies: Dependencies ) -> SessionThreadViewModel { return SessionThreadViewModel( @@ -618,7 +619,8 @@ public extension SessionThreadViewModel { using: dependencies )?.hexString ), - recentReactionEmoji: self.recentReactionEmoji + recentReactionEmoji: self.recentReactionEmoji, + wasKickedFromGroup: wasKickedFromGroup ) } } diff --git a/SessionNotificationServiceExtension/NSENotificationPresenter.swift b/SessionNotificationServiceExtension/NSENotificationPresenter.swift index da13a1982..8ba3ac15f 100644 --- a/SessionNotificationServiceExtension/NSENotificationPresenter.swift +++ b/SessionNotificationServiceExtension/NSENotificationPresenter.swift @@ -66,7 +66,8 @@ public class NSENotificationPresenter: NotificationsManagerType { .localized() } - let snippet: String = (interaction.previewText(db, using: dependencies) + let snippet: String = (Interaction + .notificationPreviewText(db, interaction: interaction, using: dependencies) .filteredForDisplay .nullIfEmpty? .replacingMentions(for: thread.id, using: dependencies)) diff --git a/SessionUtilitiesKit/Dependency Injection/Dependencies.swift b/SessionUtilitiesKit/Dependency Injection/Dependencies.swift index 8caa7384f..38ac52b61 100644 --- a/SessionUtilitiesKit/Dependency Injection/Dependencies.swift +++ b/SessionUtilitiesKit/Dependency Injection/Dependencies.swift @@ -1,4 +1,4 @@ -// Copyright © 2022 Rangeproof Pty Ltd. All rights reserved. +// Copyright © 2024 Rangeproof Pty Ltd. All rights reserved. // // stringlint:disable @@ -6,50 +6,18 @@ import Foundation import Combine public class Dependencies { - static let userInfoKey: CodingUserInfoKey = CodingUserInfoKey(rawValue: "io.oxen.dependencies.codingOptions")! + static let userInfoKey: CodingUserInfoKey = CodingUserInfoKey(rawValue: "session.dependencies.codingOptions")! private static var _isRTLRetriever: Atomic<(Bool, () -> Bool)> = Atomic((false, { false })) - private static var singletonInstances: Atomic<[String: Any]> = Atomic([:]) - private static var cacheInstances: Atomic<[String: Atomic]> = Atomic([:]) - private static var userDefaultsInstances: Atomic<[String: (any UserDefaultsType)]> = Atomic([:]) - private static var featureInstances: Atomic<[String: (any FeatureType)]> = Atomic([:]) - private var featureChangeSubject: PassthroughSubject<(String, String?, Any?), Never> = PassthroughSubject() + private let featureChangeSubject: PassthroughSubject<(String, String?, Any?), Never> = PassthroughSubject() + private var storage: Atomic = Atomic(DependencyStorage()) // MARK: - Subscript Access - public subscript(singleton singleton: SingletonConfig) -> S { - guard let value: S = (Dependencies.singletonInstances.wrappedValue[singleton.identifier] as? S) else { - let value: S = singleton.createInstance(self) - Dependencies.singletonInstances.mutate { $0[singleton.identifier] = value } - return value - } - - return value - } - - public subscript(cache cache: CacheConfig) -> I { - getValueSettingIfNull(cache: cache) - } - - public subscript(defaults defaults: UserDefaultsConfig) -> UserDefaultsType { - guard let value: UserDefaultsType = Dependencies.userDefaultsInstances.wrappedValue[defaults.identifier] else { - let value: UserDefaultsType = defaults.createInstance(self) - Dependencies.userDefaultsInstances.mutate { $0[defaults.identifier] = value } - return value - } - - return value - } - - public subscript(feature feature: FeatureConfig) -> T { - guard let value: Feature = (Dependencies.featureInstances.wrappedValue[feature.identifier] as? Feature) else { - let value: Feature = feature.createInstance(self) - Dependencies.featureInstances.mutate { $0[feature.identifier] = value } - return value.currentValue(using: self) - } - - return value.currentValue(using: self) - } + public subscript(singleton singleton: SingletonConfig) -> S { getOrCreate(singleton) } + public subscript(cache cache: CacheConfig) -> I { getOrCreate(cache).immutable(cache: cache, using: self) } + public subscript(defaults defaults: UserDefaultsConfig) -> UserDefaultsType { getOrCreate(defaults) } + public subscript(feature feature: FeatureConfig) -> T { getOrCreate(feature).currentValue(using: self) } // MARK: - Global Values, Timing and Async Handling @@ -85,20 +53,15 @@ public class Dependencies { cache: CacheConfig, _ mutation: (inout M) -> R ) -> R { - /// The cast from `Atomic` to `Atomic` always fails so we need to do some - /// stuffing around to ensure we have the right types - since we call `createInstance` multiple times in - /// the below code we first call `getValueSettingIfNull` to ensure we have a proper instance stored - /// in `Dependencies.cacheInstances` so that we can be reliably certail we aren't accessing some - /// random instance that will go out of memory as soon as the mutation is completed - getValueSettingIfNull(cache: cache) - - let cacheWrapper: Atomic = ( - Dependencies.cacheInstances.wrappedValue[cache.identifier] ?? - Atomic(cache.mutableInstance(cache.createInstance(self))) // Should never be called - ) - - return cacheWrapper.mutate { erasedValue in - var value: M = ((erasedValue as? M) ?? cache.createInstance(self)) + return getOrCreate(cache).mutate { erasedValue in + guard var value: M = (erasedValue as? M) else { + /// This code path should never happen (and is essentially invalid if it does) but in order to avoid neeing to return + /// a nullable type or force-casting this is how we need to do things) + Log.critical("Failed to convert erased cache value for '\(cache.identifier)' to expected type: \(M.self)") + var fallbackValue: M = cache.createInstance(self) + return mutation(&fallbackValue) + } + return mutation(&value) } } @@ -107,20 +70,15 @@ public class Dependencies { cache: CacheConfig, _ mutation: (inout M) throws -> R ) throws -> R { - /// The cast from `Atomic` to `Atomic` always fails so we need to do some - /// stuffing around to ensure we have the right types - since we call `createInstance` multiple times in - /// the below code we first call `getValueSettingIfNull` to ensure we have a proper instance stored - /// in `Dependencies.cacheInstances` so that we can be reliably certail we aren't accessing some - /// random instance that will go out of memory as soon as the mutation is completed - getValueSettingIfNull(cache: cache) - - let cacheWrapper: Atomic = ( - Dependencies.cacheInstances.wrappedValue[cache.identifier] ?? - Atomic(cache.mutableInstance(cache.createInstance(self))) // Should never be called - ) - - return try cacheWrapper.mutate { erasedValue in - var value: M = ((erasedValue as? M) ?? cache.createInstance(self)) + return try getOrCreate(cache).mutate { erasedValue in + guard var value: M = (erasedValue as? M) else { + /// This code path should never happen (and is essentially invalid if it does) but in order to avoid neeing to return + /// a nullable type or force-casting this is how we need to do things) + Log.critical("Failed to convert erased cache value for '\(cache.identifier)' to expected type: \(M.self)") + var fallbackValue: M = cache.createInstance(self) + return try mutation(&fallbackValue) + } + return try mutation(&value) } } @@ -138,41 +96,29 @@ public class Dependencies { public func popRandomElement(_ elements: inout Set) -> T? { return elements.popRandomElement() } - - // MARK: - Instance upserting - - @discardableResult private func getValueSettingIfNull(cache: CacheConfig) -> I { - guard let value: M = (Dependencies.cacheInstances.wrappedValue[cache.identifier]?.wrappedValue as? M) else { - let value: M = cache.createInstance(self) - let mutableInstance: MutableCacheType = cache.mutableInstance(value) - Dependencies.cacheInstances.mutate { $0[cache.identifier] = Atomic(mutableInstance) } - return cache.immutableInstance(value) - } - - return cache.immutableInstance(value) - } // MARK: - Instance replacing public func warmCache(cache: CacheConfig) { - _ = getValueSettingIfNull(cache: cache) + _ = getOrCreate(cache) } public func set(singleton: SingletonConfig, to instance: S) { - Dependencies.singletonInstances.mutate { - $0[singleton.identifier] = instance + threadSafeChange(for: singleton.identifier) { + setValue(instance, typedStorage: .singleton(instance), key: singleton.identifier) } } public func set(cache: CacheConfig, to instance: M) { - Dependencies.cacheInstances.mutate { - $0[cache.identifier] = Atomic(cache.mutableInstance(instance)) + threadSafeChange(for: cache.identifier) { + let value: Atomic = Atomic(cache.mutableInstance(instance)) + setValue(value, typedStorage: .cache(value), key: cache.identifier) } } public func remove(cache: CacheConfig) { - Dependencies.cacheInstances.mutate { - $0[cache.identifier] = nil + threadSafeChange(for: cache.identifier) { + removeValue(cache.identifier) } } @@ -181,6 +127,17 @@ public class Dependencies { } } +// MARK: - Cache Management + +private extension Atomic { + func immutable(cache: CacheConfig, using dependencies: Dependencies) -> I { + return cache.immutableInstance( + (self.wrappedValue as? M) ?? + cache.createInstance(dependencies) + ) + } +} + // MARK: - Feature Management public extension Dependencies { @@ -215,30 +172,26 @@ public extension Dependencies { } func set(feature: FeatureConfig, to updatedFeature: T?) { - let value: Feature = { - guard let value: Feature = (Dependencies.featureInstances.wrappedValue[feature.identifier] as? Feature) else { - let value: Feature = feature.createInstance(self) - Dependencies.featureInstances.mutate { $0[feature.identifier] = value } - return value - } - - return value - }() + threadSafeChange(for: feature.identifier) { + /// Update the cached & in-memory values + let instance: Feature = ( + getValue(feature.identifier) ?? + feature.createInstance(self) + ) + instance.setValue(to: updatedFeature, using: self) + setValue(instance, typedStorage: .feature(instance), key: feature.identifier) + } - value.setValue(to: updatedFeature, using: self) + /// Notify observers featureChangeSubject.send((feature.identifier, feature.groupIdentifier, updatedFeature)) } func reset(feature: FeatureConfig) { - /// Reset the cached value - switch Dependencies.featureInstances.wrappedValue[feature.identifier] as? Feature { - case .none: break - case .some(let value): value.setValue(to: nil, using: self) - } - - /// Reset the in-memory value - Dependencies.featureInstances.mutate { - $0[feature.identifier] = nil + threadSafeChange(for: feature.identifier) { + /// Reset the cached and in-memory values + let instance: Feature? = getValue(feature.identifier) + instance?.setValue(to: nil, using: self) + removeValue(feature.identifier) } /// Notify observers @@ -309,6 +262,171 @@ public extension Dependencies { } } +// MARK: - DependenciesError + public enum DependenciesError: Error { case missingDependencies } + +// MARK: - Storage Management + +private extension Dependencies { + struct DependencyStorage { + var initializationTracker: [String: DispatchGroup] = [:] + var instances: [String: Value] = [:] + + enum Value { + case singleton(Any) + case cache(Atomic) + case userDefaults(UserDefaultsType) + case feature(any FeatureType) + + func value(as type: T.Type) -> T? { + switch self { + case .singleton(let value): return value as? T + case .cache(let value): return value as? T + case .userDefaults(let value): return value as? T + case .feature(let value): return value as? T + } + } + } + } + + private func getOrCreate(_ singleton: SingletonConfig) -> S { + return getOrCreateInstance( + identifier: singleton.identifier, + constructor: .singleton { singleton.createInstance(self) } + ) + } + + private func getOrCreate(_ cache: CacheConfig) -> Atomic { + return getOrCreateInstance( + identifier: cache.identifier, + constructor: .cache { Atomic(cache.mutableInstance(cache.createInstance(self))) } + ) + } + + private func getOrCreate(_ defaults: UserDefaultsConfig) -> UserDefaultsType { + return getOrCreateInstance( + identifier: defaults.identifier, + constructor: .userDefaults { defaults.createInstance(self) } + ) + } + + private func getOrCreate(_ feature: FeatureConfig) -> Feature { + return getOrCreateInstance( + identifier: feature.identifier, + constructor: .feature { feature.createInstance(self) } + ) + } + + // MARK: - Instance upserting + + /// Retrieves the current instance or, if one doesn't exist, uses the `StorageHelper.Info` to create a new instance + /// and store it + private func getOrCreateInstance( + identifier: String, + constructor: DependencyStorage.Constructor + ) -> Value { + /// If we already have an instance then just return that + if let existingValue: Value = getValue(identifier) { + return existingValue + } + + return threadSafeChange(for: identifier) { + /// Now that we are within a synchronized group, check to make sure an instance wasn't created while we were waiting to + /// enter the group + if let existingValue: Value = getValue(identifier) { + return existingValue + } + + let result: (typedStorage: DependencyStorage.Value, value: Value) = constructor.create() + setValue(result.value, typedStorage: result.typedStorage, key: identifier) + return result.value + } + } + + /// Convenience method to retrieve the existing dependency instance from memory in a thread-safe way + private func getValue(_ key: String) -> T? { + guard let typedValue: DependencyStorage.Value = storage.wrappedValue.instances[key] else { return nil } + guard let result: T = typedValue.value(as: T.self) else { + /// If there is a value stored for the key, but it's not the right type then something has gone wrong, and we should log + Log.critical("Failed to convert stored dependency '\(key)' to expected type: \(T.self)") + return nil + } + + return result + } + + /// Convenience method to store a dependency instance in memory in a thread-safe way + @discardableResult private func setValue(_ value: T, typedStorage: DependencyStorage.Value, key: String) -> T { + storage.mutate { $0.instances[key] = typedStorage } + return value + } + + /// Convenience method to remove a dependency instance from memory in a thread-safe way + private func removeValue(_ key: String) { + storage.mutate { $0.instances.removeValue(forKey: key) } + } + + /// This function creates a `DispatchGroup` for the given identifier which allows us to block instance creation on a per-identifier basis + /// and avoid situations where multithreading could result in multiple instances of the same dependency being created concurrently + /// + /// **Note:** This `DispatchGroup` is an additional mechanism on top of the `Atomic` because the interface is a little simpler + /// and we don't need to wrap every instance within `Atomic` this way + @discardableResult private func threadSafeChange(for identifier: String, change: () -> T) -> T { + let group: DispatchGroup = storage.mutate { storage in + if let existing = storage.initializationTracker[identifier] { + return existing + } + + let group = DispatchGroup() + storage.initializationTracker[identifier] = group + return group + } + group.enter() + defer { group.leave() } + + return change() + } +} + +// MARK: - DSL + +private extension Dependencies.DependencyStorage { + struct Constructor { + let create: () -> (typedStorage: Dependencies.DependencyStorage.Value, value: T) + + static func singleton(_ constructor: @escaping () -> T) -> Constructor { + return Constructor { + let instance: T = constructor() + + return (.singleton(instance), instance) + } + } + + static func cache(_ constructor: @escaping () -> T) -> Constructor where T: Atomic { + return Constructor { + let instance: T = constructor() + + return (.cache(instance), instance) + } + } + + static func userDefaults(_ constructor: @escaping () -> T) -> Constructor where T == UserDefaultsType { + return Constructor { + let instance: T = constructor() + + return (.userDefaults(instance), instance) + } + } + + static func feature(_ constructor: @escaping () -> T) -> Constructor where T: FeatureType { + return Constructor { + let instance: T = constructor() + + return (.feature(instance), instance) + } + } + } +}