From b93ec3be04f3370ac5b7d285295a0e142038fe82 Mon Sep 17 00:00:00 2001 From: bemusementpark Date: Fri, 2 Aug 2024 12:22:25 +0930 Subject: [PATCH] Optimise Snode and Snode.Version --- .../securesms/database/LokiAPIDatabase.kt | 35 ++---------- .../org/session/libsession/snode/SnodeAPI.kt | 34 +++++------- .../org/session/libsession/utilities/Util.kt | 33 ++---------- .../org/session/libsignal/utilities/Snode.kt | 54 ++++++++++++++++++- 4 files changed, 73 insertions(+), 83 deletions(-) diff --git a/app/src/main/java/org/thoughtcrime/securesms/database/LokiAPIDatabase.kt b/app/src/main/java/org/thoughtcrime/securesms/database/LokiAPIDatabase.kt index f1f999242c..ff22ae14e3 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/database/LokiAPIDatabase.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/database/LokiAPIDatabase.kt @@ -166,8 +166,6 @@ class LokiAPIDatabase(context: Context, helper: SQLCipherOpenHelper) : Database( const val RESET_SEQ_NO = "UPDATE $lastMessageServerIDTable SET $lastMessageServerID = 0;" - const val EMPTY_VERSION = "0.0.0" - // endregion } @@ -175,15 +173,7 @@ class LokiAPIDatabase(context: Context, helper: SQLCipherOpenHelper) : Database( val database = databaseHelper.readableDatabase return database.get(snodePoolTable, "${Companion.dummyKey} = ?", wrap("dummy_key")) { cursor -> val snodePoolAsString = cursor.getString(cursor.getColumnIndexOrThrow(snodePool)) - snodePoolAsString.split(", ").mapNotNull { snodeAsString -> - val components = snodeAsString.split("-") - val address = components[0] - val port = components.getOrNull(1)?.toIntOrNull() ?: return@mapNotNull null - val ed25519Key = components.getOrNull(2) ?: return@mapNotNull null - val x25519Key = components.getOrNull(3) ?: return@mapNotNull null - val version = components.getOrNull(4) ?: EMPTY_VERSION - Snode(address, port, Snode.KeySet(ed25519Key, x25519Key), version) - } + snodePoolAsString.split(", ").mapNotNull(::Snode) }?.toSet() ?: setOf() } @@ -231,18 +221,7 @@ class LokiAPIDatabase(context: Context, helper: SQLCipherOpenHelper) : Database( val database = databaseHelper.readableDatabase fun get(indexPath: String): Snode? { return database.get(onionRequestPathTable, "${Companion.indexPath} = ?", wrap(indexPath)) { cursor -> - val snodeAsString = cursor.getString(cursor.getColumnIndexOrThrow(snode)) - val components = snodeAsString.split("-") - val address = components[0] - val port = components.getOrNull(1)?.toIntOrNull() - val ed25519Key = components.getOrNull(2) - val x25519Key = components.getOrNull(3) - val version = components.getOrNull(4) ?: EMPTY_VERSION - if (port != null && ed25519Key != null && x25519Key != null) { - Snode(address, port, Snode.KeySet(ed25519Key, x25519Key), version) - } else { - null - } + Snode(cursor.getString(cursor.getColumnIndexOrThrow(snode))) } } val result = mutableListOf>() @@ -276,15 +255,7 @@ class LokiAPIDatabase(context: Context, helper: SQLCipherOpenHelper) : Database( val database = databaseHelper.readableDatabase return database.get(swarmTable, "${Companion.swarmPublicKey} = ?", wrap(publicKey)) { cursor -> val swarmAsString = cursor.getString(cursor.getColumnIndexOrThrow(swarm)) - swarmAsString.split(", ").mapNotNull { targetAsString -> - val components = targetAsString.split("-") - val address = components[0] - val port = components.getOrNull(1)?.toIntOrNull() ?: return@mapNotNull null - val ed25519Key = components.getOrNull(2) ?: return@mapNotNull null - val x25519Key = components.getOrNull(3) ?: return@mapNotNull null - val version = components.getOrNull(4) ?: EMPTY_VERSION - Snode(address, port, Snode.KeySet(ed25519Key, x25519Key), version) - } + swarmAsString.split(", ").mapNotNull(::Snode) }?.toSet() } diff --git a/libsession/src/main/java/org/session/libsession/snode/SnodeAPI.kt b/libsession/src/main/java/org/session/libsession/snode/SnodeAPI.kt index 034ef0e7d2..9ceefe8386 100644 --- a/libsession/src/main/java/org/session/libsession/snode/SnodeAPI.kt +++ b/libsession/src/main/java/org/session/libsession/snode/SnodeAPI.kt @@ -18,6 +18,7 @@ import nl.komponents.kovenant.task import org.session.libsession.messaging.MessagingModuleConfiguration import org.session.libsession.messaging.utilities.MessageWrapper import org.session.libsession.messaging.utilities.SodiumUtilities.sodium +import org.session.libsession.utilities.toByteArray import org.session.libsignal.crypto.getRandomElement import org.session.libsignal.database.LokiAPIDatabaseProtocol import org.session.libsignal.protos.SignalServiceProtos @@ -94,8 +95,6 @@ object SnodeAPI { const val KEY_ED25519 = "pubkey_ed25519" const val KEY_VERSION = "storage_server_version" - const val EMPTY_VERSION = "0.0.0" - // Error sealed class Error(val description: String) : Exception(description) { object Generic : Error("An error occurred.") @@ -191,7 +190,7 @@ object SnodeAPI { val x25519Key = rawSnodeAsJSON?.get(KEY_X25519) as? String val version = (rawSnodeAsJSON?.get(KEY_VERSION) as? ArrayList<*>) ?.filterIsInstance() // get the array as Integers - ?.joinToString(separator = ".") // turn it int a version string + ?.let(Snode::Version) // turn it int a version if (address != null && port != null && ed25519Key != null && x25519Key != null && address != "0.0.0.0" && version != null) { @@ -696,7 +695,7 @@ object SnodeAPI { getSingleTargetSnode(publicKey).bind { snode -> retryIfNeeded(maxRetryCount) { val signature = ByteArray(Sign.BYTES) - val verificationData = (Snode.Method.DeleteMessage.rawValue + serverHashes.fold("") { a, v -> a + v }).toByteArray() + val verificationData = sequenceOf(Snode.Method.DeleteMessage.rawValue).plus(serverHashes).toByteArray() sodium.cryptoSignDetached(signature, verificationData, verificationData.size.toLong(), userED25519KeyPair.secretKey.asBytes) val deleteMessageParams = mapOf( "pubkey" to userPublicKey, @@ -719,7 +718,7 @@ object SnodeAPI { val signature = json["signature"] as String val snodePublicKey = Key.fromHexString(hexSnodePublicKey) // The signature looks like ( PUBKEY_HEX || RMSG[0] || ... || RMSG[N] || DMSG[0] || ... || DMSG[M] ) - val message = (userPublicKey + serverHashes.fold("") { a, v -> a + v } + hashes.fold("") { a, v -> a + v }).toByteArray() + val message = sequenceOf(userPublicKey).plus(serverHashes).plus(hashes).toByteArray() sodium.cryptoSignVerifyDetached(Base64.decode(signature), message, message.size, snodePublicKey.asBytes) } } @@ -733,11 +732,10 @@ object SnodeAPI { } // Parsing - private fun parseSnodes(rawResponse: Any): List { - val json = rawResponse as? Map<*, *> - val rawSnodes = json?.get("snodes") as? List<*> - if (rawSnodes != null) { - return rawSnodes.mapNotNull { rawSnode -> + private fun parseSnodes(rawResponse: Any): List = + (rawResponse as? Map<*, *>) + ?.run { get("snodes") as? List<*> } + ?.mapNotNull { rawSnode -> val rawSnodeAsJSON = rawSnode as? Map<*, *> val address = rawSnodeAsJSON?.get("ip") as? String val portAsString = rawSnodeAsJSON?.get("port") as? String @@ -746,17 +744,12 @@ object SnodeAPI { val x25519Key = rawSnodeAsJSON?.get(KEY_X25519) as? String if (address != null && port != null && ed25519Key != null && x25519Key != null && address != "0.0.0.0") { - Snode("https://$address", port, Snode.KeySet(ed25519Key, x25519Key), EMPTY_VERSION) + Snode("https://$address", port, Snode.KeySet(ed25519Key, x25519Key), Snode.Version.ZERO) } else { Log.d("Loki", "Failed to parse snode from: ${rawSnode?.prettifiedDescription()}.") null } - } - } else { - Log.d("Loki", "Failed to parse snodes from: ${rawResponse.prettifiedDescription()}.") - return listOf() - } - } + } ?: listOf().also { Log.d("Loki", "Failed to parse snodes from: ${rawResponse.prettifiedDescription()}.") } fun deleteAllMessages(): Promise, Exception> { return retryIfNeeded(maxRetryCount) { @@ -796,8 +789,7 @@ object SnodeAPI { getSingleTargetSnode(userPublicKey).bind { snode -> retryIfNeeded(maxRetryCount) { // "expire" || expiry || messages[0] || ... || messages[N] - val verificationData = - (Snode.Method.Expire.rawValue + updatedExpiryMsWithNetworkOffset + serverHashes.fold("") { a, v -> a + v }).toByteArray() + val verificationData = sequenceOf(Snode.Method.Expire.rawValue, "$updatedExpiryMsWithNetworkOffset").plus(serverHashes).toByteArray() val signature = ByteArray(Sign.BYTES) sodium.cryptoSignDetached( signature, @@ -828,7 +820,7 @@ object SnodeAPI { val signature = json["signature"] as String val snodePublicKey = Key.fromHexString(hexSnodePublicKey) // The signature looks like ( PUBKEY_HEX || RMSG[0] || ... || RMSG[N] || DMSG[0] || ... || DMSG[M] ) - val message = (userPublicKey + serverHashes.fold("") { a, v -> a + v } + hashes.fold("") { a, v -> a + v }).toByteArray() + val message = sequenceOf(userPublicKey).plus(serverHashes).plus(hashes).toByteArray() if (sodium.cryptoSignVerifyDetached(Base64.decode(signature), message, message.size, snodePublicKey.asBytes)) { hashes to expiryApplied } else listOf() to 0L @@ -922,7 +914,7 @@ object SnodeAPI { val signature = json["signature"] as String val snodePublicKey = Key.fromHexString(hexSnodePublicKey) // The signature looks like ( PUBKEY_HEX || TIMESTAMP || DELETEDHASH[0] || ... || DELETEDHASH[N] ) - val message = (userPublicKey + timestamp.toString() + hashes.joinToString(separator = "")).toByteArray() + val message = sequenceOf(userPublicKey, "$timestamp").plus(hashes).toByteArray() sodium.cryptoSignVerifyDetached(Base64.decode(signature), message, message.size, snodePublicKey.asBytes) } } diff --git a/libsession/src/main/java/org/session/libsession/utilities/Util.kt b/libsession/src/main/java/org/session/libsession/utilities/Util.kt index 929f53e305..d47754b7ed 100644 --- a/libsession/src/main/java/org/session/libsession/utilities/Util.kt +++ b/libsession/src/main/java/org/session/libsession/utilities/Util.kt @@ -366,34 +366,6 @@ object Util { val digitGroups = (Math.log10(sizeBytes.toDouble()) / Math.log10(1024.0)).toInt() return DecimalFormat("#,##0.#").format(sizeBytes / Math.pow(1024.0, digitGroups.toDouble())) + " " + units[digitGroups] } - - /** - * Compares two version strings (for example "1.8.0") - * - * @param version1 the first version string to compare. - * @param version2 the second version string to compare. - * @return an integer indicating the result of the comparison: - * - 0 if the versions are equal - * - a positive number if version1 is greater than version2 - * - a negative number if version1 is less than version2 - */ - @JvmStatic - fun compareVersions(version1: String, version2: String): Int { - val parts1 = version1.split(".").map { it.toIntOrNull() ?: 0 } - val parts2 = version2.split(".").map { it.toIntOrNull() ?: 0 } - - val maxLength = maxOf(parts1.size, parts2.size) - val paddedParts1 = parts1 + List(maxLength - parts1.size) { 0 } - val paddedParts2 = parts2 + List(maxLength - parts2.size) { 0 } - - for (i in 0 until maxLength) { - val compare = paddedParts1[i].compareTo(paddedParts2[i]) - if (compare != 0) { - return compare - } - } - return 0 - } } fun T.runIf(condition: Boolean, block: T.() -> R): R where T: R = if (condition) block() else this @@ -440,3 +412,8 @@ fun Iterable.associateByNotNull( inline fun Iterable.groupByNotNull(keySelector: (E) -> K?): Map> = LinkedHashMap>().also { forEach { e -> keySelector(e)?.let { k -> it.getOrPut(k) { mutableListOf() } += e } } } + +fun Sequence.toByteArray(): ByteArray = ByteArrayOutputStream().use { output -> + forEach { it.byteInputStream().use { input -> input.copyTo(output) } } + output.toByteArray() +} diff --git a/libsignal/src/main/java/org/session/libsignal/utilities/Snode.kt b/libsignal/src/main/java/org/session/libsignal/utilities/Snode.kt index f6b11754ad..cc123a8527 100644 --- a/libsignal/src/main/java/org/session/libsignal/utilities/Snode.kt +++ b/libsignal/src/main/java/org/session/libsignal/utilities/Snode.kt @@ -1,9 +1,21 @@ package org.session.libsignal.utilities -class Snode(val address: String, val port: Int, val publicKeySet: KeySet?, val version: String) { +import android.annotation.SuppressLint + +fun Snode(string: String): Snode? { + val components = string.split("-") + val address = components[0] + val port = components.getOrNull(1)?.toIntOrNull() ?: return null + val ed25519Key = components.getOrNull(2) ?: return null + val x25519Key = components.getOrNull(3) ?: return null + val version = components.getOrNull(4)?.let(Snode::Version) ?: Snode.Version.ZERO + return Snode(address, port, Snode.KeySet(ed25519Key, x25519Key), version) +} + +class Snode(val address: String, val port: Int, val publicKeySet: KeySet?, val version: Version) { val ip: String get() = address.removePrefix("https://") - public enum class Method(val rawValue: String) { + enum class Method(val rawValue: String) { GetSwarm("get_snodes_for_pubkey"), Retrieve("retrieve"), SendMessage("store"), @@ -32,4 +44,42 @@ class Snode(val address: String, val port: Int, val publicKeySet: KeySet?, val v } override fun toString(): String { return "$address:$port" } + + companion object { + private val CACHE = mutableMapOf() + + @SuppressLint("NotConstructor") + fun Version(value: String) = CACHE.getOrElse(value) { + Snode.Version(value) + } + } + + @JvmInline + value class Version(val value: ULong) { + companion object { + val ZERO = Version(0UL) + private const val MASK_BITS = 16 + private const val MASK = 0xFFFFUL + + private fun Sequence.foldToVersionAsULong() = take(4).foldIndexed(0UL) { i, acc, it -> + it and MASK shl (3 - i) * MASK_BITS or acc + } + } + + constructor(parts: List): this( + parts.asSequence() + .map { it.toByte().toULong() } + .foldToVersionAsULong() + ) + + constructor(value: Int): this(value.toULong()) + + internal constructor(value: String): this( + value.splitToSequence(".") + .map { it.toULongOrNull() ?: 0UL } + .foldToVersionAsULong() + ) + + operator fun compareTo(other: Version): Int = value.compareTo(other.value) + } }