Optimise Snode and Snode.Version

pull/1593/head
bemusementpark 9 months ago
parent a56e1d0b91
commit b93ec3be04

@ -166,8 +166,6 @@ class LokiAPIDatabase(context: Context, helper: SQLCipherOpenHelper) : Database(
const val RESET_SEQ_NO = "UPDATE $lastMessageServerIDTable SET $lastMessageServerID = 0;" const val RESET_SEQ_NO = "UPDATE $lastMessageServerIDTable SET $lastMessageServerID = 0;"
const val EMPTY_VERSION = "0.0.0"
// endregion // endregion
} }
@ -175,15 +173,7 @@ class LokiAPIDatabase(context: Context, helper: SQLCipherOpenHelper) : Database(
val database = databaseHelper.readableDatabase val database = databaseHelper.readableDatabase
return database.get(snodePoolTable, "${Companion.dummyKey} = ?", wrap("dummy_key")) { cursor -> return database.get(snodePoolTable, "${Companion.dummyKey} = ?", wrap("dummy_key")) { cursor ->
val snodePoolAsString = cursor.getString(cursor.getColumnIndexOrThrow(snodePool)) val snodePoolAsString = cursor.getString(cursor.getColumnIndexOrThrow(snodePool))
snodePoolAsString.split(", ").mapNotNull { snodeAsString -> snodePoolAsString.split(", ").mapNotNull(::Snode)
val components = snodeAsString.split("-")
val address = components[0]
val port = components.getOrNull(1)?.toIntOrNull() ?: return@mapNotNull null
val ed25519Key = components.getOrNull(2) ?: return@mapNotNull null
val x25519Key = components.getOrNull(3) ?: return@mapNotNull null
val version = components.getOrNull(4) ?: EMPTY_VERSION
Snode(address, port, Snode.KeySet(ed25519Key, x25519Key), version)
}
}?.toSet() ?: setOf() }?.toSet() ?: setOf()
} }
@ -231,18 +221,7 @@ class LokiAPIDatabase(context: Context, helper: SQLCipherOpenHelper) : Database(
val database = databaseHelper.readableDatabase val database = databaseHelper.readableDatabase
fun get(indexPath: String): Snode? { fun get(indexPath: String): Snode? {
return database.get(onionRequestPathTable, "${Companion.indexPath} = ?", wrap(indexPath)) { cursor -> return database.get(onionRequestPathTable, "${Companion.indexPath} = ?", wrap(indexPath)) { cursor ->
val snodeAsString = cursor.getString(cursor.getColumnIndexOrThrow(snode)) Snode(cursor.getString(cursor.getColumnIndexOrThrow(snode)))
val components = snodeAsString.split("-")
val address = components[0]
val port = components.getOrNull(1)?.toIntOrNull()
val ed25519Key = components.getOrNull(2)
val x25519Key = components.getOrNull(3)
val version = components.getOrNull(4) ?: EMPTY_VERSION
if (port != null && ed25519Key != null && x25519Key != null) {
Snode(address, port, Snode.KeySet(ed25519Key, x25519Key), version)
} else {
null
}
} }
} }
val result = mutableListOf<List<Snode>>() val result = mutableListOf<List<Snode>>()
@ -276,15 +255,7 @@ class LokiAPIDatabase(context: Context, helper: SQLCipherOpenHelper) : Database(
val database = databaseHelper.readableDatabase val database = databaseHelper.readableDatabase
return database.get(swarmTable, "${Companion.swarmPublicKey} = ?", wrap(publicKey)) { cursor -> return database.get(swarmTable, "${Companion.swarmPublicKey} = ?", wrap(publicKey)) { cursor ->
val swarmAsString = cursor.getString(cursor.getColumnIndexOrThrow(swarm)) val swarmAsString = cursor.getString(cursor.getColumnIndexOrThrow(swarm))
swarmAsString.split(", ").mapNotNull { targetAsString -> swarmAsString.split(", ").mapNotNull(::Snode)
val components = targetAsString.split("-")
val address = components[0]
val port = components.getOrNull(1)?.toIntOrNull() ?: return@mapNotNull null
val ed25519Key = components.getOrNull(2) ?: return@mapNotNull null
val x25519Key = components.getOrNull(3) ?: return@mapNotNull null
val version = components.getOrNull(4) ?: EMPTY_VERSION
Snode(address, port, Snode.KeySet(ed25519Key, x25519Key), version)
}
}?.toSet() }?.toSet()
} }

@ -18,6 +18,7 @@ import nl.komponents.kovenant.task
import org.session.libsession.messaging.MessagingModuleConfiguration import org.session.libsession.messaging.MessagingModuleConfiguration
import org.session.libsession.messaging.utilities.MessageWrapper import org.session.libsession.messaging.utilities.MessageWrapper
import org.session.libsession.messaging.utilities.SodiumUtilities.sodium import org.session.libsession.messaging.utilities.SodiumUtilities.sodium
import org.session.libsession.utilities.toByteArray
import org.session.libsignal.crypto.getRandomElement import org.session.libsignal.crypto.getRandomElement
import org.session.libsignal.database.LokiAPIDatabaseProtocol import org.session.libsignal.database.LokiAPIDatabaseProtocol
import org.session.libsignal.protos.SignalServiceProtos import org.session.libsignal.protos.SignalServiceProtos
@ -94,8 +95,6 @@ object SnodeAPI {
const val KEY_ED25519 = "pubkey_ed25519" const val KEY_ED25519 = "pubkey_ed25519"
const val KEY_VERSION = "storage_server_version" const val KEY_VERSION = "storage_server_version"
const val EMPTY_VERSION = "0.0.0"
// Error // Error
sealed class Error(val description: String) : Exception(description) { sealed class Error(val description: String) : Exception(description) {
object Generic : Error("An error occurred.") object Generic : Error("An error occurred.")
@ -191,7 +190,7 @@ object SnodeAPI {
val x25519Key = rawSnodeAsJSON?.get(KEY_X25519) as? String val x25519Key = rawSnodeAsJSON?.get(KEY_X25519) as? String
val version = (rawSnodeAsJSON?.get(KEY_VERSION) as? ArrayList<*>) val version = (rawSnodeAsJSON?.get(KEY_VERSION) as? ArrayList<*>)
?.filterIsInstance<Int>() // get the array as Integers ?.filterIsInstance<Int>() // get the array as Integers
?.joinToString(separator = ".") // turn it int a version string ?.let(Snode::Version) // turn it int a version
if (address != null && port != null && ed25519Key != null && x25519Key != null if (address != null && port != null && ed25519Key != null && x25519Key != null
&& address != "0.0.0.0" && version != null) { && address != "0.0.0.0" && version != null) {
@ -696,7 +695,7 @@ object SnodeAPI {
getSingleTargetSnode(publicKey).bind { snode -> getSingleTargetSnode(publicKey).bind { snode ->
retryIfNeeded(maxRetryCount) { retryIfNeeded(maxRetryCount) {
val signature = ByteArray(Sign.BYTES) val signature = ByteArray(Sign.BYTES)
val verificationData = (Snode.Method.DeleteMessage.rawValue + serverHashes.fold("") { a, v -> a + v }).toByteArray() val verificationData = sequenceOf(Snode.Method.DeleteMessage.rawValue).plus(serverHashes).toByteArray()
sodium.cryptoSignDetached(signature, verificationData, verificationData.size.toLong(), userED25519KeyPair.secretKey.asBytes) sodium.cryptoSignDetached(signature, verificationData, verificationData.size.toLong(), userED25519KeyPair.secretKey.asBytes)
val deleteMessageParams = mapOf( val deleteMessageParams = mapOf(
"pubkey" to userPublicKey, "pubkey" to userPublicKey,
@ -719,7 +718,7 @@ object SnodeAPI {
val signature = json["signature"] as String val signature = json["signature"] as String
val snodePublicKey = Key.fromHexString(hexSnodePublicKey) val snodePublicKey = Key.fromHexString(hexSnodePublicKey)
// The signature looks like ( PUBKEY_HEX || RMSG[0] || ... || RMSG[N] || DMSG[0] || ... || DMSG[M] ) // The signature looks like ( PUBKEY_HEX || RMSG[0] || ... || RMSG[N] || DMSG[0] || ... || DMSG[M] )
val message = (userPublicKey + serverHashes.fold("") { a, v -> a + v } + hashes.fold("") { a, v -> a + v }).toByteArray() val message = sequenceOf(userPublicKey).plus(serverHashes).plus(hashes).toByteArray()
sodium.cryptoSignVerifyDetached(Base64.decode(signature), message, message.size, snodePublicKey.asBytes) sodium.cryptoSignVerifyDetached(Base64.decode(signature), message, message.size, snodePublicKey.asBytes)
} }
} }
@ -733,11 +732,10 @@ object SnodeAPI {
} }
// Parsing // Parsing
private fun parseSnodes(rawResponse: Any): List<Snode> { private fun parseSnodes(rawResponse: Any): List<Snode> =
val json = rawResponse as? Map<*, *> (rawResponse as? Map<*, *>)
val rawSnodes = json?.get("snodes") as? List<*> ?.run { get("snodes") as? List<*> }
if (rawSnodes != null) { ?.mapNotNull { rawSnode ->
return rawSnodes.mapNotNull { rawSnode ->
val rawSnodeAsJSON = rawSnode as? Map<*, *> val rawSnodeAsJSON = rawSnode as? Map<*, *>
val address = rawSnodeAsJSON?.get("ip") as? String val address = rawSnodeAsJSON?.get("ip") as? String
val portAsString = rawSnodeAsJSON?.get("port") as? String val portAsString = rawSnodeAsJSON?.get("port") as? String
@ -746,17 +744,12 @@ object SnodeAPI {
val x25519Key = rawSnodeAsJSON?.get(KEY_X25519) as? String val x25519Key = rawSnodeAsJSON?.get(KEY_X25519) as? String
if (address != null && port != null && ed25519Key != null && x25519Key != null && address != "0.0.0.0") { if (address != null && port != null && ed25519Key != null && x25519Key != null && address != "0.0.0.0") {
Snode("https://$address", port, Snode.KeySet(ed25519Key, x25519Key), EMPTY_VERSION) Snode("https://$address", port, Snode.KeySet(ed25519Key, x25519Key), Snode.Version.ZERO)
} else { } else {
Log.d("Loki", "Failed to parse snode from: ${rawSnode?.prettifiedDescription()}.") Log.d("Loki", "Failed to parse snode from: ${rawSnode?.prettifiedDescription()}.")
null null
} }
} } ?: listOf<Snode>().also { Log.d("Loki", "Failed to parse snodes from: ${rawResponse.prettifiedDescription()}.") }
} else {
Log.d("Loki", "Failed to parse snodes from: ${rawResponse.prettifiedDescription()}.")
return listOf()
}
}
fun deleteAllMessages(): Promise<Map<String,Boolean>, Exception> { fun deleteAllMessages(): Promise<Map<String,Boolean>, Exception> {
return retryIfNeeded(maxRetryCount) { return retryIfNeeded(maxRetryCount) {
@ -796,8 +789,7 @@ object SnodeAPI {
getSingleTargetSnode(userPublicKey).bind { snode -> getSingleTargetSnode(userPublicKey).bind { snode ->
retryIfNeeded(maxRetryCount) { retryIfNeeded(maxRetryCount) {
// "expire" || expiry || messages[0] || ... || messages[N] // "expire" || expiry || messages[0] || ... || messages[N]
val verificationData = val verificationData = sequenceOf(Snode.Method.Expire.rawValue, "$updatedExpiryMsWithNetworkOffset").plus(serverHashes).toByteArray()
(Snode.Method.Expire.rawValue + updatedExpiryMsWithNetworkOffset + serverHashes.fold("") { a, v -> a + v }).toByteArray()
val signature = ByteArray(Sign.BYTES) val signature = ByteArray(Sign.BYTES)
sodium.cryptoSignDetached( sodium.cryptoSignDetached(
signature, signature,
@ -828,7 +820,7 @@ object SnodeAPI {
val signature = json["signature"] as String val signature = json["signature"] as String
val snodePublicKey = Key.fromHexString(hexSnodePublicKey) val snodePublicKey = Key.fromHexString(hexSnodePublicKey)
// The signature looks like ( PUBKEY_HEX || RMSG[0] || ... || RMSG[N] || DMSG[0] || ... || DMSG[M] ) // The signature looks like ( PUBKEY_HEX || RMSG[0] || ... || RMSG[N] || DMSG[0] || ... || DMSG[M] )
val message = (userPublicKey + serverHashes.fold("") { a, v -> a + v } + hashes.fold("") { a, v -> a + v }).toByteArray() val message = sequenceOf(userPublicKey).plus(serverHashes).plus(hashes).toByteArray()
if (sodium.cryptoSignVerifyDetached(Base64.decode(signature), message, message.size, snodePublicKey.asBytes)) { if (sodium.cryptoSignVerifyDetached(Base64.decode(signature), message, message.size, snodePublicKey.asBytes)) {
hashes to expiryApplied hashes to expiryApplied
} else listOf<String>() to 0L } else listOf<String>() to 0L
@ -922,7 +914,7 @@ object SnodeAPI {
val signature = json["signature"] as String val signature = json["signature"] as String
val snodePublicKey = Key.fromHexString(hexSnodePublicKey) val snodePublicKey = Key.fromHexString(hexSnodePublicKey)
// The signature looks like ( PUBKEY_HEX || TIMESTAMP || DELETEDHASH[0] || ... || DELETEDHASH[N] ) // The signature looks like ( PUBKEY_HEX || TIMESTAMP || DELETEDHASH[0] || ... || DELETEDHASH[N] )
val message = (userPublicKey + timestamp.toString() + hashes.joinToString(separator = "")).toByteArray() val message = sequenceOf(userPublicKey, "$timestamp").plus(hashes).toByteArray()
sodium.cryptoSignVerifyDetached(Base64.decode(signature), message, message.size, snodePublicKey.asBytes) sodium.cryptoSignVerifyDetached(Base64.decode(signature), message, message.size, snodePublicKey.asBytes)
} }
} }

@ -366,34 +366,6 @@ object Util {
val digitGroups = (Math.log10(sizeBytes.toDouble()) / Math.log10(1024.0)).toInt() val digitGroups = (Math.log10(sizeBytes.toDouble()) / Math.log10(1024.0)).toInt()
return DecimalFormat("#,##0.#").format(sizeBytes / Math.pow(1024.0, digitGroups.toDouble())) + " " + units[digitGroups] return DecimalFormat("#,##0.#").format(sizeBytes / Math.pow(1024.0, digitGroups.toDouble())) + " " + units[digitGroups]
} }
/**
* Compares two version strings (for example "1.8.0")
*
* @param version1 the first version string to compare.
* @param version2 the second version string to compare.
* @return an integer indicating the result of the comparison:
* - 0 if the versions are equal
* - a positive number if version1 is greater than version2
* - a negative number if version1 is less than version2
*/
@JvmStatic
fun compareVersions(version1: String, version2: String): Int {
val parts1 = version1.split(".").map { it.toIntOrNull() ?: 0 }
val parts2 = version2.split(".").map { it.toIntOrNull() ?: 0 }
val maxLength = maxOf(parts1.size, parts2.size)
val paddedParts1 = parts1 + List(maxLength - parts1.size) { 0 }
val paddedParts2 = parts2 + List(maxLength - parts2.size) { 0 }
for (i in 0 until maxLength) {
val compare = paddedParts1[i].compareTo(paddedParts2[i])
if (compare != 0) {
return compare
}
}
return 0
}
} }
fun <T, R> T.runIf(condition: Boolean, block: T.() -> R): R where T: R = if (condition) block() else this fun <T, R> T.runIf(condition: Boolean, block: T.() -> R): R where T: R = if (condition) block() else this
@ -440,3 +412,8 @@ fun <E, K: Any, V: Any> Iterable<E>.associateByNotNull(
inline fun <E, K> Iterable<E>.groupByNotNull(keySelector: (E) -> K?): Map<K, List<E>> = LinkedHashMap<K, MutableList<E>>().also { inline fun <E, K> Iterable<E>.groupByNotNull(keySelector: (E) -> K?): Map<K, List<E>> = LinkedHashMap<K, MutableList<E>>().also {
forEach { e -> keySelector(e)?.let { k -> it.getOrPut(k) { mutableListOf() } += e } } forEach { e -> keySelector(e)?.let { k -> it.getOrPut(k) { mutableListOf() } += e } }
} }
fun Sequence<String>.toByteArray(): ByteArray = ByteArrayOutputStream().use { output ->
forEach { it.byteInputStream().use { input -> input.copyTo(output) } }
output.toByteArray()
}

@ -1,9 +1,21 @@
package org.session.libsignal.utilities package org.session.libsignal.utilities
class Snode(val address: String, val port: Int, val publicKeySet: KeySet?, val version: String) { import android.annotation.SuppressLint
fun Snode(string: String): Snode? {
val components = string.split("-")
val address = components[0]
val port = components.getOrNull(1)?.toIntOrNull() ?: return null
val ed25519Key = components.getOrNull(2) ?: return null
val x25519Key = components.getOrNull(3) ?: return null
val version = components.getOrNull(4)?.let(Snode::Version) ?: Snode.Version.ZERO
return Snode(address, port, Snode.KeySet(ed25519Key, x25519Key), version)
}
class Snode(val address: String, val port: Int, val publicKeySet: KeySet?, val version: Version) {
val ip: String get() = address.removePrefix("https://") val ip: String get() = address.removePrefix("https://")
public enum class Method(val rawValue: String) { enum class Method(val rawValue: String) {
GetSwarm("get_snodes_for_pubkey"), GetSwarm("get_snodes_for_pubkey"),
Retrieve("retrieve"), Retrieve("retrieve"),
SendMessage("store"), SendMessage("store"),
@ -32,4 +44,42 @@ class Snode(val address: String, val port: Int, val publicKeySet: KeySet?, val v
} }
override fun toString(): String { return "$address:$port" } override fun toString(): String { return "$address:$port" }
companion object {
private val CACHE = mutableMapOf<String, Version>()
@SuppressLint("NotConstructor")
fun Version(value: String) = CACHE.getOrElse(value) {
Snode.Version(value)
}
}
@JvmInline
value class Version(val value: ULong) {
companion object {
val ZERO = Version(0UL)
private const val MASK_BITS = 16
private const val MASK = 0xFFFFUL
private fun Sequence<ULong>.foldToVersionAsULong() = take(4).foldIndexed(0UL) { i, acc, it ->
it and MASK shl (3 - i) * MASK_BITS or acc
}
}
constructor(parts: List<Int>): this(
parts.asSequence()
.map { it.toByte().toULong() }
.foldToVersionAsULong()
)
constructor(value: Int): this(value.toULong())
internal constructor(value: String): this(
value.splitToSequence(".")
.map { it.toULongOrNull() ?: 0UL }
.foldToVersionAsULong()
)
operator fun compareTo(other: Version): Int = value.compareTo(other.value)
}
} }

Loading…
Cancel
Save