mirror of https://github.com/oxen-io/session-ios
				
				
				
			
			You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
	
	
		
			286 lines
		
	
	
		
			9.8 KiB
		
	
	
	
		
			Swift
		
	
			
		
		
	
	
			286 lines
		
	
	
		
			9.8 KiB
		
	
	
	
		
			Swift
		
	
// Copyright © 2022 Rangeproof Pty Ltd. All rights reserved.
 | 
						|
 | 
						|
import Foundation
 | 
						|
import GRDB
 | 
						|
 | 
						|
// MARK: - Migration Safe Functions
 | 
						|
 | 
						|
public extension MutablePersistableRecord where Self: TableRecord & EncodableRecord & Codable {
    /// Inserts this record through a `MigrationSafeMutableRecord` wrapper, which only encodes
    /// values whose key matches a column that currently exists in the record's table.
    ///
    /// Nothing is returned; call `migrationSafeInserted` when the record as updated by the
    /// insertion callbacks is needed.
    ///
    /// - Parameters:
    ///   - db: The database connection to insert into.
    ///   - conflictResolution: Optional conflict resolution handed straight to `insert`.
    /// - Throws: Any error raised while creating the wrapper or performing the insert.
    func migrationSafeInsert(
        _ db: Database,
        onConflict conflictResolution: Database.ConflictResolution? = nil
    ) throws {
        var wrapper = try MigrationSafeMutableRecord(db, originalRecord: self)
        try wrapper.insert(db, onConflict: conflictResolution)
    }

    /// Inserts this record through a `MigrationSafeMutableRecord` wrapper and returns the
    /// wrapped record as it stands after the insert.
    ///
    /// - Parameters:
    ///   - db: The database connection to insert into.
    ///   - conflictResolution: Optional conflict resolution handed straight to `inserted`.
    /// - Returns: The `originalRecord` held by the inserted wrapper.
    /// - Throws: Any error raised while creating the wrapper or performing the insert.
    func migrationSafeInserted(
        _ db: Database,
        onConflict conflictResolution: Database.ConflictResolution? = nil
    ) throws -> Self {
        return try MigrationSafeMutableRecord(db, originalRecord: self)
            .inserted(db, onConflict: conflictResolution)
            .originalRecord
    }

    /// Saves this record through a `MigrationSafeMutableRecord` wrapper, which only encodes
    /// values whose key matches a column that currently exists in the record's table.
    ///
    /// - Parameters:
    ///   - db: The database connection to save into.
    ///   - conflictResolution: Optional conflict resolution handed straight to `save`.
    /// - Throws: Any error raised while creating the wrapper or performing the save.
    func migrationSafeSave(
        _ db: Database,
        onConflict conflictResolution: Database.ConflictResolution? = nil
    ) throws {
        var wrapper = try MigrationSafeMutableRecord(db, originalRecord: self)
        try wrapper.save(db, onConflict: conflictResolution)
    }

    /// Saves this record through a `MigrationSafeMutableRecord` wrapper and returns the
    /// wrapped record as it stands after the save.
    ///
    /// - Parameters:
    ///   - db: The database connection to save into.
    ///   - conflictResolution: Optional conflict resolution handed straight to `saved`.
    /// - Returns: The `originalRecord` held by the saved wrapper.
    /// - Throws: Any error raised while creating the wrapper or performing the save.
    func migrationSafeSaved(
        _ db: Database,
        onConflict conflictResolution: Database.ConflictResolution? = nil
    ) throws -> Self {
        return try MigrationSafeMutableRecord(db, originalRecord: self)
            .saved(db, onConflict: conflictResolution)
            .originalRecord
    }

    /// Upserts this record through a `MigrationSafeMutableRecord` wrapper, which only encodes
    /// values whose key matches a column that currently exists in the record's table.
    ///
    /// - Parameter db: The database connection to upsert into.
    /// - Throws: Any error raised while creating the wrapper or performing the upsert.
    func migrationSafeUpsert(_ db: Database) throws {
        var wrapper = try MigrationSafeMutableRecord(db, originalRecord: self)
        try wrapper.upsert(db)
    }
}
 | 
						|
 | 
						|
public extension MutablePersistableRecord where Self: FetchableRecord {
    /// Upserts a mutable copy of this record and returns that copy.
    ///
    /// - Parameter db: The database connection to upsert into.
    /// - Returns: The copy of the record on which `upsert` was called.
    /// - Throws: Any error raised by `upsert`.
    @discardableResult
    func upserted(_ db: Database) throws -> Self {
        // `upsert` needs a mutable value, whereas this method is non-mutating
        var copy = self
        try copy.upsert(db)

        return copy
    }
}
 | 
						|
 | 
						|
// MARK: - MigrationSafeMutableRecord
 | 
						|
 | 
						|
/// A `MigrationSafeMutableRecord` whose wrapped type is additionally constrained to
/// `PersistableRecord`; it adds no members of its own.
private class MigrationSafeRecord<T>: MigrationSafeMutableRecord<T> where T: PersistableRecord & Encodable {
}
 | 
						|
 | 
						|
/// Wraps a record so it can be persisted against whatever columns its table currently has.
///
/// On creation the wrapper reads the names of the columns that exist in `T`'s table. When it
/// is encoded it routes `originalRecord` through a `FilteredEncoder`, so keyed values whose
/// key is not one of those columns are dropped. Every persistence callback is forwarded to
/// the wrapped record.
private class MigrationSafeMutableRecord<T: MutablePersistableRecord & Encodable>: MutablePersistableRecord & Encodable {
    public static var databaseTableName: String { T.databaseTableName }

    /// Forward the wrapped type's conflict policy so that persisting through the wrapper with
    /// no explicit `onConflict` resolves conflicts the same way persisting `T` directly would.
    ///
    /// NOTE(review): other static customisation points of `T` (e.g. encoding strategies) are
    /// not forwarded here — confirm whether any wrapped record relies on them.
    static var persistenceConflictPolicy: PersistenceConflictPolicy { T.persistenceConflictPolicy }

    /// The record being wrapped; the forwarded persistence callbacks operate on this value.
    fileprivate var originalRecord: T

    /// Names of the columns found in the table when the wrapper was created.
    private let availableColumnNames: [String]

    /// Creates a wrapper around `originalRecord`.
    ///
    /// - Parameters:
    ///   - db: The database connection used to look up the table's columns.
    ///   - originalRecord: The record to wrap.
    /// - Throws: Any error raised while reading the table's columns.
    init(_ db: Database, originalRecord: T) throws {
        // Check the current columns in the database so that any properties on the object which
        // don't exist in the database can be filtered out when encoding
        self.originalRecord = originalRecord
        self.availableColumnNames = try db.columns(in: Self.databaseTableName).map(\.name)
    }

    /// Encodes the wrapped record through a `FilteredEncoder` limited to the available columns.
    func encode(to encoder: Encoder) throws {
        let filteredEncoder: FilteredEncoder = FilteredEncoder(
            originalEncoder: encoder,
            availableKeys: availableColumnNames
        )
        try originalRecord.encode(to: filteredEncoder)
    }

    // MARK: - Persistence Callbacks
    //
    // Each callback simply forwards to the wrapped record so its own hooks still run.

    func willInsert(_ db: Database) throws {
        try originalRecord.willInsert(db)
    }

    func aroundInsert(_ db: Database, insert: () throws -> InsertionSuccess) throws {
        try originalRecord.aroundInsert(db, insert: insert)
    }

    func didInsert(_ inserted: InsertionSuccess) {
        originalRecord.didInsert(inserted)
    }

    func willUpdate(_ db: Database, columns: Set<String>) throws {
        try originalRecord.willUpdate(db, columns: columns)
    }

    func aroundUpdate(_ db: Database, columns: Set<String>, update: () throws -> PersistenceSuccess) throws {
        try originalRecord.aroundUpdate(db, columns: columns, update: update)
    }

    func didUpdate(_ updated: PersistenceSuccess) {
        originalRecord.didUpdate(updated)
    }

    func willSave(_ db: Database) throws {
        try originalRecord.willSave(db)
    }

    func aroundSave(_ db: Database, save: () throws -> PersistenceSuccess) throws {
        try originalRecord.aroundSave(db, save: save)
    }

    func didSave(_ saved: PersistenceSuccess) {
        originalRecord.didSave(saved)
    }

    func willDelete(_ db: Database) throws {
        try originalRecord.willDelete(db)
    }

    func aroundDelete(_ db: Database, delete: () throws -> Bool) throws {
        try originalRecord.aroundDelete(db, delete: delete)
    }

    func didDelete(deleted: Bool) {
        originalRecord.didDelete(deleted: deleted)
    }
}
 | 
						|
 | 
						|
// MARK: - FilteredEncoder
 | 
						|
 | 
						|
/// An `Encoder` that wraps another encoder and hands out keyed containers which only accept
/// the keys listed in `availableKeys`.
private class FilteredEncoder: Encoder {
    /// The encoder that ultimately receives the values.
    let originalEncoder: Encoder

    /// The key names that keyed containers are allowed to write.
    let availableKeys: [String]

    var codingPath: [CodingKey] {
        return originalEncoder.codingPath
    }

    var userInfo: [CodingUserInfoKey: Any] {
        return originalEncoder.userInfo
    }

    init(originalEncoder: Encoder, availableKeys: [String]) {
        self.availableKeys = availableKeys
        self.originalEncoder = originalEncoder
    }

    /// Returns the wrapped encoder's keyed container boxed in a `FilteredKeyedEncodingContainer`.
    func container<Key: CodingKey>(keyedBy type: Key.Type) -> KeyedEncodingContainer<Key> {
        return KeyedEncodingContainer(
            FilteredKeyedEncodingContainer(
                availableKeys: availableKeys,
                originalContainer: originalEncoder.container(keyedBy: type)
            )
        )
    }

    /// Unkeyed containers have no keys to filter, so the wrapped encoder's one is returned as-is.
    func unkeyedContainer() -> UnkeyedEncodingContainer {
        return originalEncoder.unkeyedContainer()
    }

    /// Single-value containers have no keys to filter, so the wrapped encoder's one is returned as-is.
    func singleValueContainer() -> SingleValueEncodingContainer {
        return originalEncoder.singleValueContainer()
    }
}
 | 
						|
 | 
						|
// MARK: - FilteredKeyedEncodingContainer
 | 
						|
 | 
						|
/// A keyed encoding container that forwards values to `originalContainer` only when the key's
/// `stringValue` is listed in `availableKeys`; values for any other key are silently dropped.
///
/// `encodeIfPresent` is forwarded explicitly (rather than relying on the protocol's default
/// implementation) so that whatever handling the wrapped container applies to `nil` values is
/// preserved for the available keys.
/// NOTE(review): GRDB's record container is understood to write an explicit NULL from
/// `encodeIfPresent` — confirm against the GRDB version in use.
private final class FilteredKeyedEncodingContainer<Key: CodingKey>: KeyedEncodingContainerProtocol {
    let codingPath: [CodingKey]
    let availableKeys: [String]
    var originalContainer: KeyedEncodingContainer<Key>

    init(availableKeys: [String], originalContainer: KeyedEncodingContainer<Key>) {
        self.availableKeys = availableKeys
        self.codingPath = originalContainer.codingPath
        self.originalContainer = originalContainer
    }

    /// Whether `key` names one of the available keys (exact, case-sensitive string match).
    private func isAvailable(_ key: Key) -> Bool {
        return availableKeys.contains(key.stringValue)
    }

    // MARK: - Required values

    func encodeNil(forKey key: Key) throws {
        guard isAvailable(key) else { return }
        try originalContainer.encodeNil(forKey: key)
    }

    func encode(_ value: Bool, forKey key: Key) throws {
        guard isAvailable(key) else { return }
        try originalContainer.encode(value, forKey: key)
    }

    func encode(_ value: String, forKey key: Key) throws {
        guard isAvailable(key) else { return }
        try originalContainer.encode(value, forKey: key)
    }

    func encode(_ value: Double, forKey key: Key) throws {
        guard isAvailable(key) else { return }
        try originalContainer.encode(value, forKey: key)
    }

    func encode(_ value: Float, forKey key: Key) throws {
        guard isAvailable(key) else { return }
        try originalContainer.encode(value, forKey: key)
    }

    func encode(_ value: Int, forKey key: Key) throws {
        guard isAvailable(key) else { return }
        try originalContainer.encode(value, forKey: key)
    }

    func encode(_ value: Int8, forKey key: Key) throws {
        guard isAvailable(key) else { return }
        try originalContainer.encode(value, forKey: key)
    }

    func encode(_ value: Int16, forKey key: Key) throws {
        guard isAvailable(key) else { return }
        try originalContainer.encode(value, forKey: key)
    }

    func encode(_ value: Int32, forKey key: Key) throws {
        guard isAvailable(key) else { return }
        try originalContainer.encode(value, forKey: key)
    }

    func encode(_ value: Int64, forKey key: Key) throws {
        guard isAvailable(key) else { return }
        try originalContainer.encode(value, forKey: key)
    }

    func encode(_ value: UInt, forKey key: Key) throws {
        guard isAvailable(key) else { return }
        try originalContainer.encode(value, forKey: key)
    }

    func encode(_ value: UInt8, forKey key: Key) throws {
        guard isAvailable(key) else { return }
        try originalContainer.encode(value, forKey: key)
    }

    func encode(_ value: UInt16, forKey key: Key) throws {
        guard isAvailable(key) else { return }
        try originalContainer.encode(value, forKey: key)
    }

    func encode(_ value: UInt32, forKey key: Key) throws {
        guard isAvailable(key) else { return }
        try originalContainer.encode(value, forKey: key)
    }

    func encode(_ value: UInt64, forKey key: Key) throws {
        guard isAvailable(key) else { return }
        try originalContainer.encode(value, forKey: key)
    }

    func encode<T>(_ value: T, forKey key: Key) throws where T: Encodable {
        guard isAvailable(key) else { return }
        try originalContainer.encode(value, forKey: key)
    }

    // MARK: - Optional values

    func encodeIfPresent(_ value: Bool?, forKey key: Key) throws {
        guard isAvailable(key) else { return }
        try originalContainer.encodeIfPresent(value, forKey: key)
    }

    func encodeIfPresent(_ value: String?, forKey key: Key) throws {
        guard isAvailable(key) else { return }
        try originalContainer.encodeIfPresent(value, forKey: key)
    }

    func encodeIfPresent(_ value: Double?, forKey key: Key) throws {
        guard isAvailable(key) else { return }
        try originalContainer.encodeIfPresent(value, forKey: key)
    }

    func encodeIfPresent(_ value: Float?, forKey key: Key) throws {
        guard isAvailable(key) else { return }
        try originalContainer.encodeIfPresent(value, forKey: key)
    }

    func encodeIfPresent(_ value: Int?, forKey key: Key) throws {
        guard isAvailable(key) else { return }
        try originalContainer.encodeIfPresent(value, forKey: key)
    }

    func encodeIfPresent(_ value: Int8?, forKey key: Key) throws {
        guard isAvailable(key) else { return }
        try originalContainer.encodeIfPresent(value, forKey: key)
    }

    func encodeIfPresent(_ value: Int16?, forKey key: Key) throws {
        guard isAvailable(key) else { return }
        try originalContainer.encodeIfPresent(value, forKey: key)
    }

    func encodeIfPresent(_ value: Int32?, forKey key: Key) throws {
        guard isAvailable(key) else { return }
        try originalContainer.encodeIfPresent(value, forKey: key)
    }

    func encodeIfPresent(_ value: Int64?, forKey key: Key) throws {
        guard isAvailable(key) else { return }
        try originalContainer.encodeIfPresent(value, forKey: key)
    }

    func encodeIfPresent(_ value: UInt?, forKey key: Key) throws {
        guard isAvailable(key) else { return }
        try originalContainer.encodeIfPresent(value, forKey: key)
    }

    func encodeIfPresent(_ value: UInt8?, forKey key: Key) throws {
        guard isAvailable(key) else { return }
        try originalContainer.encodeIfPresent(value, forKey: key)
    }

    func encodeIfPresent(_ value: UInt16?, forKey key: Key) throws {
        guard isAvailable(key) else { return }
        try originalContainer.encodeIfPresent(value, forKey: key)
    }

    func encodeIfPresent(_ value: UInt32?, forKey key: Key) throws {
        guard isAvailable(key) else { return }
        try originalContainer.encodeIfPresent(value, forKey: key)
    }

    func encodeIfPresent(_ value: UInt64?, forKey key: Key) throws {
        guard isAvailable(key) else { return }
        try originalContainer.encodeIfPresent(value, forKey: key)
    }

    func encodeIfPresent<T>(_ value: T?, forKey key: Key) throws where T: Encodable {
        guard isAvailable(key) else { return }
        try originalContainer.encodeIfPresent(value, forKey: key)
    }

    // MARK: - Nested containers
    //
    // These are handed straight to the wrapped container without checking `availableKeys`.

    func nestedContainer<NestedKey>(keyedBy keyType: NestedKey.Type, forKey key: Key) -> KeyedEncodingContainer<NestedKey> where NestedKey: CodingKey {
        return originalContainer.nestedContainer(keyedBy: keyType, forKey: key)
    }

    func nestedUnkeyedContainer(forKey key: Key) -> UnkeyedEncodingContainer {
        return originalContainer.nestedUnkeyedContainer(forKey: key)
    }

    func superEncoder() -> Encoder {
        return originalContainer.superEncoder()
    }

    func superEncoder(forKey key: Key) -> Encoder {
        return originalContainer.superEncoder(forKey: key)
    }
}
 |