From 9dce376780b8a42c00e5725b6bdc2d2334a50bf0 Mon Sep 17 00:00:00 2001 From: Moxie Marlinspike Date: Wed, 3 Sep 2014 02:51:20 -0700 Subject: [PATCH] Correctly handle formatting when "one time PreKey" is absent. --- .../test/SessionBuilderTest.java | 55 +++++++++++++++++++ .../libaxolotl/SessionBuilder.java | 53 ++++++++++-------- .../libaxolotl/SessionCipher.java | 11 ++-- .../protocol/PreKeyWhisperMessage.java | 47 +++++++++------- .../libaxolotl/state/SessionState.java | 42 +++++++++----- 5 files changed, 145 insertions(+), 63 deletions(-) diff --git a/libaxolotl/src/androidTest/java/org/whispersystems/test/SessionBuilderTest.java b/libaxolotl/src/androidTest/java/org/whispersystems/test/SessionBuilderTest.java index 477f51e061..0feb223e2a 100644 --- a/libaxolotl/src/androidTest/java/org/whispersystems/test/SessionBuilderTest.java +++ b/libaxolotl/src/androidTest/java/org/whispersystems/test/SessionBuilderTest.java @@ -3,6 +3,7 @@ package org.whispersystems.test; import android.test.AndroidTestCase; import org.whispersystems.libaxolotl.DuplicateMessageException; +import org.whispersystems.libaxolotl.IdentityKey; import org.whispersystems.libaxolotl.InvalidKeyException; import org.whispersystems.libaxolotl.InvalidKeyIdException; import org.whispersystems.libaxolotl.InvalidMessageException; @@ -568,6 +569,60 @@ public class SessionBuilderTest extends AndroidTestCase { bobSessionStore, bobPreKeyStore, bobSignedPreKeyStore, bobIdentityKeyStore); } + public void testOptionalOneTimePreKey() throws Exception { + SessionStore aliceSessionStore = new InMemorySessionStore(); + SignedPreKeyStore aliceSignedPreKeyStore = new InMemorySignedPreKeyStore(); + PreKeyStore alicePreKeyStore = new InMemoryPreKeyStore(); + IdentityKeyStore aliceIdentityKeyStore = new InMemoryIdentityKeyStore(); + SessionBuilder aliceSessionBuilder = new SessionBuilder(aliceSessionStore, alicePreKeyStore, + aliceSignedPreKeyStore, + aliceIdentityKeyStore, + BOB_RECIPIENT_ID, 1); + + SessionStore bobSessionStore = new InMemorySessionStore(); + PreKeyStore bobPreKeyStore = new InMemoryPreKeyStore(); + SignedPreKeyStore bobSignedPreKeyStore = new InMemorySignedPreKeyStore(); + IdentityKeyStore bobIdentityKeyStore = new InMemoryIdentityKeyStore(); + + ECKeyPair bobPreKeyPair = Curve.generateKeyPair(); + ECKeyPair bobSignedPreKeyPair = Curve.generateKeyPair(); + byte[] bobSignedPreKeySignature = Curve.calculateSignature(bobIdentityKeyStore.getIdentityKeyPair().getPrivateKey(), + bobSignedPreKeyPair.getPublicKey().serialize()); + + PreKeyBundle bobPreKey = new PreKeyBundle(bobIdentityKeyStore.getLocalRegistrationId(), 1, + 0, null, + 22, bobSignedPreKeyPair.getPublicKey(), + bobSignedPreKeySignature, + bobIdentityKeyStore.getIdentityKeyPair().getPublicKey()); + + aliceSessionBuilder.process(bobPreKey); + + assertTrue(aliceSessionStore.containsSession(BOB_RECIPIENT_ID, 1)); + assertTrue(!aliceSessionStore.loadSession(BOB_RECIPIENT_ID, 1).getSessionState().getNeedsRefresh()); + assertTrue(aliceSessionStore.loadSession(BOB_RECIPIENT_ID, 1).getSessionState().getSessionVersion() == 3); + + String originalMessage = "L'homme est condamné à être libre"; + SessionCipher aliceSessionCipher = new SessionCipher(aliceSessionStore, alicePreKeyStore, aliceSignedPreKeyStore, aliceIdentityKeyStore, BOB_RECIPIENT_ID, 1); + CiphertextMessage outgoingMessage = aliceSessionCipher.encrypt(originalMessage.getBytes()); + + assertTrue(outgoingMessage.getType() == CiphertextMessage.PREKEY_TYPE); + + PreKeyWhisperMessage incomingMessage = new PreKeyWhisperMessage(outgoingMessage.serialize()); + assertTrue(!incomingMessage.getPreKeyId().isPresent()); + + bobPreKeyStore.storePreKey(31337, new PreKeyRecord(bobPreKey.getPreKeyId(), bobPreKeyPair)); + bobSignedPreKeyStore.storeSignedPreKey(22, new SignedPreKeyRecord(22, System.currentTimeMillis(), bobSignedPreKeyPair, bobSignedPreKeySignature)); + + SessionCipher bobSessionCipher = new SessionCipher(bobSessionStore, bobPreKeyStore, bobSignedPreKeyStore, bobIdentityKeyStore, ALICE_RECIPIENT_ID, 1); + byte[] plaintext = bobSessionCipher.decrypt(incomingMessage); + + assertTrue(bobSessionStore.containsSession(ALICE_RECIPIENT_ID, 1)); + assertTrue(bobSessionStore.loadSession(ALICE_RECIPIENT_ID, 1).getSessionState().getSessionVersion() == 3); + assertTrue(bobSessionStore.loadSession(ALICE_RECIPIENT_ID, 1).getSessionState().getAliceBaseKey() != null); + assertTrue(originalMessage.equals(new String(plaintext))); + } + + private void runInteraction(SessionStore aliceSessionStore, PreKeyStore alicePreKeyStore, SignedPreKeyStore aliceSignedPreKeyStore, diff --git a/libaxolotl/src/main/java/org/whispersystems/libaxolotl/SessionBuilder.java b/libaxolotl/src/main/java/org/whispersystems/libaxolotl/SessionBuilder.java index bcf0ef962e..c5890760c3 100644 --- a/libaxolotl/src/main/java/org/whispersystems/libaxolotl/SessionBuilder.java +++ b/libaxolotl/src/main/java/org/whispersystems/libaxolotl/SessionBuilder.java @@ -88,12 +88,13 @@ public class SessionBuilder { * @throws org.whispersystems.libaxolotl.InvalidKeyException when the message is formatted incorrectly. * @throws org.whispersystems.libaxolotl.UntrustedIdentityException when the {@link IdentityKey} of the sender is untrusted. */ - /*package*/ int process(SessionRecord sessionRecord, PreKeyWhisperMessage message) + /*package*/ Optional process(SessionRecord sessionRecord, PreKeyWhisperMessage message) throws InvalidKeyIdException, InvalidKeyException, UntrustedIdentityException { int messageVersion = message.getMessageVersion(); IdentityKey theirIdentityKey = message.getIdentityKey(); - int unsignedPreKeyId; + + Optional unsignedPreKeyId; if (!identityKeyStore.isTrustedIdentity(recipientId, theirIdentityKey)) { throw new UntrustedIdentityException(); @@ -109,13 +110,13 @@ public class SessionBuilder { return unsignedPreKeyId; } - private int processV3(SessionRecord sessionRecord, PreKeyWhisperMessage message) + private Optional processV3(SessionRecord sessionRecord, PreKeyWhisperMessage message) throws UntrustedIdentityException, InvalidKeyIdException, InvalidKeyException { if (sessionRecord.hasSessionState(message.getMessageVersion(), message.getBaseKey().serialize())) { Log.w(TAG, "We've already setup a session for this V3 message, letting bundled message fall through..."); - return -1; + return Optional.absent(); } boolean simultaneousInitiate = sessionRecord.getSessionState().hasUnacknowledgedPreKeyMessage(); @@ -129,8 +130,8 @@ public class SessionBuilder { .setOurSignedPreKey(ourSignedPreKey) .setOurRatchetKey(ourSignedPreKey); - if (message.getPreKeyId() >= 0) { - parameters.setOurOneTimePreKey(Optional.of(preKeyStore.loadPreKey(message.getPreKeyId()).getKeyPair())); + if (message.getPreKeyId().isPresent()) { + parameters.setOurOneTimePreKey(Optional.of(preKeyStore.loadPreKey(message.getPreKeyId().get()).getKeyPair())); } else { parameters.setOurOneTimePreKey(Optional.absent()); } @@ -146,25 +147,28 @@ public class SessionBuilder { if (simultaneousInitiate) sessionRecord.getSessionState().setNeedsRefresh(true); - if (message.getPreKeyId() >= 0 && message.getPreKeyId() != Medium.MAX_VALUE) { + if (message.getPreKeyId().isPresent() && message.getPreKeyId().get() != Medium.MAX_VALUE) { return message.getPreKeyId(); } else { - return -1; + return Optional.absent(); } } - private int processV2(SessionRecord sessionRecord, PreKeyWhisperMessage message) + private Optional processV2(SessionRecord sessionRecord, PreKeyWhisperMessage message) throws UntrustedIdentityException, InvalidKeyIdException, InvalidKeyException { + if (!message.getPreKeyId().isPresent()) { + throw new InvalidKeyIdException("V2 message requires one time prekey id!"); + } - if (!preKeyStore.containsPreKey(message.getPreKeyId()) && + if (!preKeyStore.containsPreKey(message.getPreKeyId().get()) && sessionStore.containsSession(recipientId, deviceId)) { Log.w(TAG, "We've already processed the prekey part of this V2 session, letting bundled message fall through..."); - return -1; + return Optional.absent(); } - ECKeyPair ourPreKey = preKeyStore.loadPreKey(message.getPreKeyId()).getKeyPair(); + ECKeyPair ourPreKey = preKeyStore.loadPreKey(message.getPreKeyId().get()).getKeyPair(); boolean simultaneousInitiate = sessionRecord.getSessionState().hasUnacknowledgedPreKeyMessage(); BobAxolotlParameters.Builder parameters = BobAxolotlParameters.newBuilder(); @@ -186,10 +190,10 @@ public class SessionBuilder { if (simultaneousInitiate) sessionRecord.getSessionState().setNeedsRefresh(true); - if (message.getPreKeyId() != Medium.MAX_VALUE) { + if (message.getPreKeyId().get() != Medium.MAX_VALUE) { return message.getPreKeyId(); } else { - return -1; + return Optional.absent(); } } @@ -222,11 +226,14 @@ public class SessionBuilder { throw new InvalidKeyException("Both signed and unsigned prekeys are absent!"); } - boolean isExistingSession = sessionStore.containsSession(recipientId, deviceId); - SessionRecord sessionRecord = sessionStore.loadSession(recipientId, deviceId); - ECKeyPair ourBaseKey = Curve.generateKeyPair(); - ECPublicKey theirSignedPreKey = preKey.getSignedPreKey() != null ? preKey.getSignedPreKey() : - preKey.getPreKey(); + boolean supportsV3 = preKey.getSignedPreKey() != null; + boolean isExistingSession = sessionStore.containsSession(recipientId, deviceId); + SessionRecord sessionRecord = sessionStore.loadSession(recipientId, deviceId); + ECKeyPair ourBaseKey = Curve.generateKeyPair(); + ECPublicKey theirSignedPreKey = supportsV3 ? preKey.getSignedPreKey() : preKey.getPreKey(); + Optional theirOneTimePreKey = Optional.fromNullable(preKey.getPreKey()); + Optional theirOneTimePreKeyId = theirOneTimePreKey.isPresent() ? Optional.of(preKey.getPreKeyId()) : + Optional.absent(); AliceAxolotlParameters.Builder parameters = AliceAxolotlParameters.newBuilder(); @@ -235,18 +242,16 @@ public class SessionBuilder { .setTheirIdentityKey(preKey.getIdentityKey()) .setTheirSignedPreKey(theirSignedPreKey) .setTheirRatchetKey(theirSignedPreKey) - .setTheirOneTimePreKey(preKey.getSignedPreKey() != null ? - Optional.fromNullable(preKey.getPreKey()) : - Optional.absent()); + .setTheirOneTimePreKey(supportsV3 ? theirOneTimePreKey : Optional.absent()); if (isExistingSession) sessionRecord.archiveCurrentState(); else sessionRecord.reset(); RatchetingSession.initializeSession(sessionRecord.getSessionState(), - preKey.getSignedPreKey() == null ? 2 : 3, + supportsV3 ? 3 : 2, parameters.create()); - sessionRecord.getSessionState().setUnacknowledgedPreKeyMessage(preKey.getPreKeyId(), preKey.getSignedPreKeyId(), ourBaseKey.getPublicKey()); + sessionRecord.getSessionState().setUnacknowledgedPreKeyMessage(theirOneTimePreKeyId, preKey.getSignedPreKeyId(), ourBaseKey.getPublicKey()); sessionRecord.getSessionState().setLocalRegistrationId(identityKeyStore.getLocalRegistrationId()); sessionRecord.getSessionState().setRemoteRegistrationId(preKey.getRegistrationId()); diff --git a/libaxolotl/src/main/java/org/whispersystems/libaxolotl/SessionCipher.java b/libaxolotl/src/main/java/org/whispersystems/libaxolotl/SessionCipher.java index ceabbd18fd..e4fac57627 100644 --- a/libaxolotl/src/main/java/org/whispersystems/libaxolotl/SessionCipher.java +++ b/libaxolotl/src/main/java/org/whispersystems/libaxolotl/SessionCipher.java @@ -33,6 +33,7 @@ import org.whispersystems.libaxolotl.state.SessionStore; import org.whispersystems.libaxolotl.state.SignedPreKeyStore; import org.whispersystems.libaxolotl.util.ByteUtil; import org.whispersystems.libaxolotl.util.Pair; +import org.whispersystems.libaxolotl.util.guava.Optional; import java.security.InvalidAlgorithmParameterException; import java.security.NoSuchAlgorithmException; @@ -147,14 +148,14 @@ public class SessionCipher { InvalidKeyIdException, InvalidKeyException, UntrustedIdentityException { synchronized (SESSION_LOCK) { - SessionRecord sessionRecord = sessionStore.loadSession(recipientId, deviceId); - int unsignedPreKeyId = sessionBuilder.process(sessionRecord, ciphertext); - byte[] plaintext = decrypt(sessionRecord, ciphertext.getWhisperMessage()); + SessionRecord sessionRecord = sessionStore.loadSession(recipientId, deviceId); + Optional unsignedPreKeyId = sessionBuilder.process(sessionRecord, ciphertext); + byte[] plaintext = decrypt(sessionRecord, ciphertext.getWhisperMessage()); sessionStore.storeSession(recipientId, deviceId, sessionRecord); - if (unsignedPreKeyId >=0) { - preKeyStore.removePreKey(unsignedPreKeyId); + if (unsignedPreKeyId.isPresent()) { + preKeyStore.removePreKey(unsignedPreKeyId.get()); } return plaintext; diff --git a/libaxolotl/src/main/java/org/whispersystems/libaxolotl/protocol/PreKeyWhisperMessage.java b/libaxolotl/src/main/java/org/whispersystems/libaxolotl/protocol/PreKeyWhisperMessage.java index f379c92d21..fff6d02a8c 100644 --- a/libaxolotl/src/main/java/org/whispersystems/libaxolotl/protocol/PreKeyWhisperMessage.java +++ b/libaxolotl/src/main/java/org/whispersystems/libaxolotl/protocol/PreKeyWhisperMessage.java @@ -27,18 +27,19 @@ import org.whispersystems.libaxolotl.LegacyMessageException; import org.whispersystems.libaxolotl.ecc.Curve; import org.whispersystems.libaxolotl.ecc.ECPublicKey; import org.whispersystems.libaxolotl.util.ByteUtil; +import org.whispersystems.libaxolotl.util.guava.Optional; public class PreKeyWhisperMessage implements CiphertextMessage { - private final int version; - private final int registrationId; - private final int preKeyId; - private final int signedPreKeyId; - private final ECPublicKey baseKey; - private final IdentityKey identityKey; - private final WhisperMessage message; - private final byte[] serialized; + private final int version; + private final int registrationId; + private final Optional preKeyId; + private final int signedPreKeyId; + private final ECPublicKey baseKey; + private final IdentityKey identityKey; + private final WhisperMessage message; + private final byte[] serialized; public PreKeyWhisperMessage(byte[] serialized) throws InvalidMessageException, InvalidVersionException @@ -65,7 +66,7 @@ public class PreKeyWhisperMessage implements CiphertextMessage { this.serialized = serialized; this.registrationId = preKeyWhisperMessage.getRegistrationId(); - this.preKeyId = preKeyWhisperMessage.hasPreKeyId() ? preKeyWhisperMessage.getPreKeyId() : -1; + this.preKeyId = preKeyWhisperMessage.hasPreKeyId() ? Optional.of(preKeyWhisperMessage.getPreKeyId()) : Optional.absent(); this.signedPreKeyId = preKeyWhisperMessage.hasSignedPreKeyId() ? preKeyWhisperMessage.getSignedPreKeyId() : -1; this.baseKey = Curve.decodePoint(preKeyWhisperMessage.getBaseKey().toByteArray(), 0); this.identityKey = new IdentityKey(Curve.decodePoint(preKeyWhisperMessage.getIdentityKey().toByteArray(), 0)); @@ -75,8 +76,9 @@ public class PreKeyWhisperMessage implements CiphertextMessage { } } - public PreKeyWhisperMessage(int messageVersion, int registrationId, int preKeyId, int signedPreKeyId, - ECPublicKey baseKey, IdentityKey identityKey, WhisperMessage message) + public PreKeyWhisperMessage(int messageVersion, int registrationId, Optional preKeyId, + int signedPreKeyId, ECPublicKey baseKey, IdentityKey identityKey, + WhisperMessage message) { this.version = messageVersion; this.registrationId = registrationId; @@ -86,15 +88,20 @@ public class PreKeyWhisperMessage implements CiphertextMessage { this.identityKey = identityKey; this.message = message; + WhisperProtos.PreKeyWhisperMessage.Builder builder = + WhisperProtos.PreKeyWhisperMessage.newBuilder() + .setSignedPreKeyId(signedPreKeyId) + .setBaseKey(ByteString.copyFrom(baseKey.serialize())) + .setIdentityKey(ByteString.copyFrom(identityKey.serialize())) + .setMessage(ByteString.copyFrom(message.serialize())) + .setRegistrationId(registrationId); + + if (preKeyId.isPresent()) { + builder.setPreKeyId(preKeyId.get()); + } + byte[] versionBytes = {ByteUtil.intsToByteHighAndLow(this.version, CURRENT_VERSION)}; - byte[] messageBytes = WhisperProtos.PreKeyWhisperMessage.newBuilder() - .setPreKeyId(preKeyId) - .setSignedPreKeyId(signedPreKeyId) - .setBaseKey(ByteString.copyFrom(baseKey.serialize())) - .setIdentityKey(ByteString.copyFrom(identityKey.serialize())) - .setMessage(ByteString.copyFrom(message.serialize())) - .setRegistrationId(registrationId) - .build().toByteArray(); + byte[] messageBytes = builder.build().toByteArray(); this.serialized = ByteUtil.combine(versionBytes, messageBytes); } @@ -111,7 +118,7 @@ public class PreKeyWhisperMessage implements CiphertextMessage { return registrationId; } - public int getPreKeyId() { + public Optional getPreKeyId() { return preKeyId; } diff --git a/libaxolotl/src/main/java/org/whispersystems/libaxolotl/state/SessionState.java b/libaxolotl/src/main/java/org/whispersystems/libaxolotl/state/SessionState.java index f6bc90bc26..a700e89aa2 100644 --- a/libaxolotl/src/main/java/org/whispersystems/libaxolotl/state/SessionState.java +++ b/libaxolotl/src/main/java/org/whispersystems/libaxolotl/state/SessionState.java @@ -36,6 +36,7 @@ import org.whispersystems.libaxolotl.state.StorageProtos.SessionStructure.Chain; import org.whispersystems.libaxolotl.state.StorageProtos.SessionStructure.PendingKeyExchange; import org.whispersystems.libaxolotl.state.StorageProtos.SessionStructure.PendingPreKey; import org.whispersystems.libaxolotl.util.Pair; +import org.whispersystems.libaxolotl.util.guava.Optional; import java.util.Iterator; import java.util.LinkedList; @@ -415,15 +416,17 @@ public class SessionState { return sessionStructure.hasPendingKeyExchange(); } - public void setUnacknowledgedPreKeyMessage(int preKeyId, int signedPreKeyId, ECPublicKey baseKey) { - PendingPreKey pending = PendingPreKey.newBuilder() - .setPreKeyId(preKeyId) - .setSignedPreKeyId(signedPreKeyId) - .setBaseKey(ByteString.copyFrom(baseKey.serialize())) - .build(); + public void setUnacknowledgedPreKeyMessage(Optional preKeyId, int signedPreKeyId, ECPublicKey baseKey) { + PendingPreKey.Builder pending = PendingPreKey.newBuilder() + .setSignedPreKeyId(signedPreKeyId) + .setBaseKey(ByteString.copyFrom(baseKey.serialize())); + + if (preKeyId.isPresent()) { + pending.setPreKeyId(preKeyId.get()); + } this.sessionStructure = this.sessionStructure.toBuilder() - .setPendingPreKey(pending) + .setPendingPreKey(pending.build()) .build(); } @@ -433,8 +436,16 @@ public class SessionState { public UnacknowledgedPreKeyMessageItems getUnacknowledgedPreKeyMessageItems() { try { + Optional preKeyId; + + if (sessionStructure.getPendingPreKey().hasPreKeyId()) { + preKeyId = Optional.of(sessionStructure.getPendingPreKey().getPreKeyId()); + } else { + preKeyId = Optional.absent(); + } + return - new UnacknowledgedPreKeyMessageItems(sessionStructure.getPendingPreKey().getPreKeyId(), + new UnacknowledgedPreKeyMessageItems(preKeyId, sessionStructure.getPendingPreKey().getSignedPreKeyId(), Curve.decodePoint(sessionStructure.getPendingPreKey() .getBaseKey() @@ -475,18 +486,21 @@ public class SessionState { } public static class UnacknowledgedPreKeyMessageItems { - private final int preKeyId; - private final int signedPreKeyId; - private final ECPublicKey baseKey; - - public UnacknowledgedPreKeyMessageItems(int preKeyId, int signedPreKeyId, ECPublicKey baseKey) { + private final Optional preKeyId; + private final int signedPreKeyId; + private final ECPublicKey baseKey; + + public UnacknowledgedPreKeyMessageItems(Optional preKeyId, + int signedPreKeyId, + ECPublicKey baseKey) + { this.preKeyId = preKeyId; this.signedPreKeyId = signedPreKeyId; this.baseKey = baseKey; } - public int getPreKeyId() { + public Optional getPreKeyId() { return preKeyId; }