Skip to content

Commit b7d455e

Browse files
authored
Validate pre-key key-id ranges
1 parent ac23b8e commit b7d455e

16 files changed

Lines changed: 180 additions & 37 deletions

service/src/main/java/org/whispersystems/textsecuregcm/entities/ECPreKey.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,19 @@
88
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
99
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
1010
import io.swagger.v3.oas.annotations.media.Schema;
11+
import jakarta.validation.constraints.Max;
12+
import jakarta.validation.constraints.Min;
1113
import org.signal.libsignal.protocol.ecc.ECPublicKey;
14+
import org.whispersystems.textsecuregcm.storage.KeyIdUtil;
1215
import org.whispersystems.textsecuregcm.util.ECPublicKeyAdapter;
1316

1417
public record ECPreKey(
1518
@Schema(description="""
1619
An arbitrary ID for this key, which will be provided by peers using this key to encrypt messages so the private key can be looked up.
1720
Should not be zero. Should be less than 2^24.
1821
""")
22+
@Max(KeyIdUtil.MAX_KEY_ID)
23+
@Min(KeyIdUtil.MIN_KEY_ID)
1924
long keyId,
2025

2126
@JsonSerialize(using = ECPublicKeyAdapter.Serializer.class)

service/src/main/java/org/whispersystems/textsecuregcm/entities/ECSignedPreKey.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@
88
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
99
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
1010
import io.swagger.v3.oas.annotations.media.Schema;
11+
import jakarta.validation.constraints.Max;
12+
import jakarta.validation.constraints.Min;
1113
import org.signal.libsignal.protocol.ecc.ECPublicKey;
14+
import org.whispersystems.textsecuregcm.storage.KeyIdUtil;
1215
import org.whispersystems.textsecuregcm.util.ByteArrayAdapter;
1316
import org.whispersystems.textsecuregcm.util.ECPublicKeyAdapter;
1417
import java.util.Arrays;
@@ -19,6 +22,8 @@ public record ECSignedPreKey(
1922
An arbitrary ID for this key, which will be provided by peers using this key to encrypt messages so the private key can be looked up.
2023
Should not be zero. Should be less than 2^24.
2124
""")
25+
@Max(KeyIdUtil.MAX_KEY_ID)
26+
@Min(KeyIdUtil.MIN_KEY_ID)
2227
long keyId,
2328

2429
@JsonSerialize(using = ECPublicKeyAdapter.Serializer.class)

service/src/main/java/org/whispersystems/textsecuregcm/entities/KEMSignedPreKey.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@
88
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
99
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
1010
import io.swagger.v3.oas.annotations.media.Schema;
11+
import jakarta.validation.constraints.Max;
12+
import jakarta.validation.constraints.Min;
1113
import org.signal.libsignal.protocol.kem.KEMPublicKey;
14+
import org.whispersystems.textsecuregcm.storage.KeyIdUtil;
1215
import org.whispersystems.textsecuregcm.util.ByteArrayAdapter;
1316
import org.whispersystems.textsecuregcm.util.KEMPublicKeyAdapter;
1417
import java.util.Arrays;
@@ -20,6 +23,8 @@ public record KEMSignedPreKey(
2023
Should not be zero. Should be less than 2^24. The owner of this key must be able to determine from the key ID whether this represents
2124
a single-use or last-resort key, but another party should *not* be able to tell.
2225
""")
26+
@Max(KeyIdUtil.MAX_KEY_ID)
27+
@Min(KeyIdUtil.MIN_KEY_ID)
2328
long keyId,
2429

2530
@JsonSerialize(using = KEMPublicKeyAdapter.Serializer.class)

service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcHelper.java

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
2020
import org.whispersystems.textsecuregcm.storage.Account;
2121
import org.whispersystems.textsecuregcm.storage.Device;
22+
import org.whispersystems.textsecuregcm.storage.KeyIdUtil;
2223
import org.whispersystems.textsecuregcm.storage.KeysManager;
2324

2425
class KeysGrpcHelper {
@@ -67,19 +68,20 @@ static Optional<AccountPreKeyBundles> getPreKeys(final Account targetAccount,
6768
preKeysByDeviceId.forEach((deviceId, devicePreKeys) -> {
6869
final Device device = targetAccount.getDevice(deviceId).orElseThrow();
6970

71+
7072
final DevicePreKeyBundle.Builder builder = DevicePreKeyBundle.newBuilder()
7173
.setEcSignedPreKey(EcSignedPreKey.newBuilder()
72-
.setKeyId(devicePreKeys.ecSignedPreKey().keyId())
74+
.setKeyId(KeyIdUtil.toUnsignedInt(devicePreKeys.ecSignedPreKey().keyId()))
7375
.setPublicKey(ByteString.copyFrom(devicePreKeys.ecSignedPreKey().serializedPublicKey()))
7476
.setSignature(ByteString.copyFrom(devicePreKeys.ecSignedPreKey().signature())))
7577
.setKemOneTimePreKey(KemSignedPreKey.newBuilder()
76-
.setKeyId(devicePreKeys.kemSignedPreKey().keyId())
78+
.setKeyId(KeyIdUtil.toUnsignedInt(devicePreKeys.kemSignedPreKey().keyId()))
7779
.setPublicKey(ByteString.copyFrom(devicePreKeys.kemSignedPreKey().serializedPublicKey()))
7880
.setSignature(ByteString.copyFrom(devicePreKeys.kemSignedPreKey().signature())))
7981
.setRegistrationId(device.getRegistrationId(targetServiceIdentifier.identityType()));
8082

8183
devicePreKeys.ecPreKey().ifPresent(ecPreKey -> builder.setEcOneTimePreKey(EcPreKey.newBuilder()
82-
.setKeyId(ecPreKey.keyId())
84+
.setKeyId(KeyIdUtil.toUnsignedInt(ecPreKey.keyId()))
8385
.setPublicKey(ByteString.copyFrom(ecPreKey.serializedPublicKey()))));
8486

8587
preKeyBundlesBuilder.putDevicePreKeys(deviceId, builder.build());
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
/*
2+
* Copyright 2026 Signal Messenger, LLC
3+
* SPDX-License-Identifier: AGPL-3.0-only
4+
*/
5+
package org.whispersystems.textsecuregcm.storage;
6+
7+
public class KeyIdUtil {
8+
public static final long MAX_KEY_ID = (1L << 32) - 1;
9+
public static final long MIN_KEY_ID = 0;
10+
private KeyIdUtil(){}
11+
12+
public static boolean keyIdValid(final long keyId) {
13+
return keyId <= MAX_KEY_ID && keyId >= MIN_KEY_ID;
14+
}
15+
16+
/// Convert a long keyId (a 32-bit unsigned int) into an int representation.
17+
///
18+
/// The inverse of [Integer#toUnsignedLong].
19+
///
20+
/// @param keyId A key ID which must be in the range [0, 2^32)
21+
/// @throws IllegalArgumentException If `keyId` is not within the range
22+
/// @return A 32-bit unsigned integer where the top bit is stored in the sign bit
23+
public static int toUnsignedInt(final long keyId) {
24+
if (!keyIdValid(keyId)) {
25+
throw new IllegalArgumentException("Invalid keyId " + keyId);
26+
}
27+
return (int) keyId;
28+
}
29+
}

service/src/main/java/org/whispersystems/textsecuregcm/storage/PagedSingleUseKEMPreKeyStore.java

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
99

10+
import io.micrometer.core.instrument.Counter;
1011
import io.micrometer.core.instrument.DistributionSummary;
1112
import io.micrometer.core.instrument.Metrics;
1213
import io.micrometer.core.instrument.Timer;
@@ -71,6 +72,7 @@ public class PagedSingleUseKEMPreKeyStore {
7172
private final Timer deleteForDeviceTimer = Metrics.timer(name(getClass(), "deleteForDevice"));
7273
private final Timer deleteForAccountTimer = Metrics.timer(name(getClass(), "deleteForAccount"));
7374

75+
private final Counter outOfRangeKeysDiscarded = Metrics.counter(name(getClass(), "outOfRangeKeysDiscarded"));
7476
final DistributionSummary availableKeyCountDistributionSummary = DistributionSummary
7577
.builder(name(getClass(), "availableKeyCount"))
7678
.register(Metrics.globalRegistry);
@@ -159,7 +161,14 @@ public CompletableFuture<Void> store(
159161
*/
160162
public CompletableFuture<Optional<KEMSignedPreKey>> take(final UUID identifier, final byte deviceId) {
161163
final Timer.Sample sample = Timer.start();
164+
return takeHelper(identifier, deviceId)
165+
.whenComplete((maybeKey, throwable) ->
166+
sample.stop(Metrics.timer(
167+
takeKeyTimerName,
168+
KEY_PRESENT_TAG_NAME, String.valueOf(maybeKey != null && maybeKey.isPresent()))));
169+
}
162170

171+
private CompletableFuture<Optional<KEMSignedPreKey>> takeHelper(final UUID identifier, final byte deviceId) {
163172
return dynamoDbAsyncClient.updateItem(UpdateItemRequest.builder()
164173
.tableName(tableName)
165174
.key(Map.of(
@@ -196,10 +205,15 @@ public CompletableFuture<Optional<KEMSignedPreKey>> take(final UUID identifier,
196205
.exceptionally(ExceptionUtils.exceptionallyHandler(
197206
ConditionalCheckFailedException.class,
198207
e -> Optional.empty()))
199-
.whenComplete((maybeKey, throwable) ->
200-
sample.stop(Metrics.timer(
201-
takeKeyTimerName,
202-
KEY_PRESENT_TAG_NAME, String.valueOf(maybeKey != null && maybeKey.isPresent()))));
208+
.thenCompose(maybeKey -> {
209+
if (!maybeKey.map(KEMSignedPreKey::keyId).map(KeyIdUtil::keyIdValid).orElse(true)) {
210+
// At some point we did not validate that keyIds fit in an unsigned 32-bit integer, which clients require.
211+
// This keyId was invalid, so just recursively fetch the next key
212+
outOfRangeKeysDiscarded.increment();
213+
return takeHelper(identifier, deviceId);
214+
}
215+
return CompletableFuture.completedFuture(maybeKey);
216+
});
203217
}
204218

205219
/**

service/src/main/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStore.java

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -107,11 +107,14 @@ public CompletableFuture<Optional<K>> find(final UUID identifier, final byte dev
107107
.build())
108108
.thenApply(response -> response.hasItem() ? Optional.of(getPreKeyFromItem(response.item())) : Optional.empty());
109109

110-
findFuture.whenComplete((maybeSignedPreKey, throwable) ->
111-
sample.stop(Metrics.timer(findKeyTimerName,
112-
"keyPresent", String.valueOf(maybeSignedPreKey != null && maybeSignedPreKey.isPresent()))));
113-
114-
return findFuture;
110+
return findFuture.whenComplete((maybeSignedPreKey, throwable) -> {
111+
if (throwable == null && maybeSignedPreKey.map(k -> !KeyIdUtil.keyIdValid(k.keyId())).orElse(false)) {
112+
throw new IllegalStateException("Encountered an impossible invalid repeated use pre-key id of " + maybeSignedPreKey.get().keyId());
113+
}
114+
115+
sample.stop(Metrics.timer(findKeyTimerName,
116+
"keyPresent", String.valueOf(maybeSignedPreKey != null && maybeSignedPreKey.isPresent())));
117+
});
115118
}
116119

117120
protected static Map<String, AttributeValue> getPrimaryKey(final UUID identifier, final byte deviceId) {

service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUseECPreKeyStore.java

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@
4747
* may fall back to using the device's repeated-use ("last-resort") signed pre-key instead.
4848
*/
4949
public class SingleUseECPreKeyStore {
50-
5150
private final DynamoDbAsyncClient dynamoDbAsyncClient;
5251
private final String tableName;
5352

@@ -58,12 +57,14 @@ public class SingleUseECPreKeyStore {
5857
private final Timer deleteForAccountTimer = Metrics.timer(name(getClass(), "deleteForAccount"));
5958

6059
private final Counter noKeyCountAvailableCounter = Metrics.counter(name(getClass(), "noKeyCountAvailable"));
61-
60+
private final Counter outOfRangeKeysDiscarded =
61+
Metrics.counter(name(getClass(), "outOfRangeKeysDiscarded"));
6262
final DistributionSummary keysConsideredForTakeDistributionSummary = DistributionSummary
6363
.builder(name(getClass(), "keysConsideredForTake"))
6464
.distributionStatisticExpiry(Duration.ofMinutes(10))
6565
.register(Metrics.globalRegistry);
6666

67+
6768
final DistributionSummary availableKeyCountDistributionSummary = DistributionSummary
6869
.builder(name(getClass(), "availableKeyCount"))
6970
.distributionStatisticExpiry(Duration.ofMinutes(10))
@@ -135,7 +136,7 @@ private CompletableFuture<Void> store(final UUID identifier, final byte deviceId
135136
public CompletableFuture<Optional<ECPreKey>> take(final UUID identifier, final byte deviceId) {
136137
final Timer.Sample sample = Timer.start();
137138
final AttributeValue partitionKey = getPartitionKey(identifier);
138-
final AtomicInteger keysConsidered = new AtomicInteger(0);
139+
final AtomicInteger deletionAttempts = new AtomicInteger(0);
139140

140141
return Flux.from(dynamoDbAsyncClient.queryPaginator(QueryRequest.builder()
141142
.tableName(tableName)
@@ -156,16 +157,26 @@ public CompletableFuture<Optional<ECPreKey>> take(final UUID identifier, final b
156157
KEY_DEVICE_ID_KEY_ID, item.get(KEY_DEVICE_ID_KEY_ID)))
157158
.returnValues(ReturnValue.ALL_OLD)
158159
.build())
159-
.flatMap(deleteItemRequest -> Mono.fromFuture(() -> dynamoDbAsyncClient.deleteItem(deleteItemRequest)), 1)
160-
.doOnNext(deleteItemResponse -> keysConsidered.incrementAndGet())
160+
.concatMap(deleteItemRequest -> Mono.fromFuture(() -> dynamoDbAsyncClient.deleteItem(deleteItemRequest)))
161+
.doOnNext(_ -> deletionAttempts.incrementAndGet())
161162
.filter(DeleteItemResponse::hasAttributes)
163+
.filter(item -> {
164+
final long keyId = getKeyIdFromItem(item.attributes());
165+
final boolean keyIdValid = KeyIdUtil.keyIdValid(keyId);
166+
if (!keyIdValid) {
167+
outOfRangeKeysDiscarded.increment();
168+
}
169+
// At some point we did not validate that keyIds fit in an unsigned 32-bit integer, which clients require.
170+
// If this keyId is invalid, we'll skip it and fetch the next key
171+
return keyIdValid;
172+
})
162173
.next()
163174
.map(deleteItemResponse -> getPreKeyFromItem(deleteItemResponse.attributes()))
164175
.toFuture()
165176
.thenApply(Optional::ofNullable)
166177
.whenComplete((maybeKey, throwable) -> {
167178
sample.stop(Metrics.timer(takeKeyTimerName, KEY_PRESENT_TAG_NAME, String.valueOf(maybeKey != null && maybeKey.isPresent())));
168-
keysConsideredForTakeDistributionSummary.record(keysConsidered.get());
179+
keysConsideredForTakeDistributionSummary.record(deletionAttempts.get());
169180
});
170181
}
171182

@@ -310,7 +321,7 @@ KEY_DEVICE_ID_KEY_ID, getSortKey(deviceId, preKey.keyId()),
310321
}
311322

312323
private ECPreKey getPreKeyFromItem(final Map<String, AttributeValue> item) {
313-
final long keyId = item.get(KEY_DEVICE_ID_KEY_ID).b().asByteBuffer().getLong(8);
324+
final long keyId = getKeyIdFromItem(item);
314325
final byte[] publicKey = AttributeValues.extractByteArray(item.get(ATTR_PUBLIC_KEY), PARSE_BYTE_ARRAY_COUNTER_NAME);
315326

316327
try {
@@ -320,4 +331,8 @@ private ECPreKey getPreKeyFromItem(final Map<String, AttributeValue> item) {
320331
throw new IllegalArgumentException(e);
321332
}
322333
}
334+
335+
private static long getKeyIdFromItem(final Map<String, AttributeValue> item) {
336+
return item.get(KEY_DEVICE_ID_KEY_ID).b().asByteBuffer().getLong(8);
337+
}
323338
}

service/src/main/proto/org/signal/chat/common.proto

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ message EcPreKey {
4141
// A locally-unique identifier for this key, which will be provided by
4242
// peers using this key to encrypt messages so the private key can be looked
4343
// up.
44-
uint64 key_id = 1;
44+
uint32 key_id = 1;
4545

4646
// The public key, serialized in libsignal's elliptic-curve public key format.
4747
bytes public_key = 2 [(require.nonEmpty) = true];
@@ -51,7 +51,7 @@ message EcSignedPreKey {
5151
// A locally-unique identifier for this key, which will be provided by
5252
// peers using this key to encrypt messages so the private key can be looked
5353
// up.
54-
uint64 key_id = 1;
54+
uint32 key_id = 1;
5555

5656
// The public key, serialized in libsignal's elliptic-curve public key format.
5757
bytes public_key = 2 [(require.nonEmpty) = true];
@@ -64,7 +64,7 @@ message EcSignedPreKey {
6464
message KemSignedPreKey {
6565
// An locally-unique identifier for this key, which will be provided by peers
6666
// using this key to encrypt messages so the private key can be looked up.
67-
uint64 key_id = 1;
67+
uint32 key_id = 1;
6868

6969
// The public key, serialized in libsignal's Kyber1024 public key format.
7070
bytes public_key = 2 [(require.nonEmpty) = true];

service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcServiceTest.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
import org.whispersystems.textsecuregcm.storage.Account;
6767
import org.whispersystems.textsecuregcm.storage.AccountsManager;
6868
import org.whispersystems.textsecuregcm.storage.Device;
69+
import org.whispersystems.textsecuregcm.storage.KeyIdUtil;
6970
import org.whispersystems.textsecuregcm.storage.KeysManager;
7071
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
7172
import org.whispersystems.textsecuregcm.tests.util.DevicesHelper;
@@ -508,22 +509,22 @@ private void assertGetKeysFailure(Status code, GetPreKeysAnonymousRequest reques
508509

509510
private static EcPreKey toGrpcEcPreKey(final ECPreKey preKey) {
510511
return EcPreKey.newBuilder()
511-
.setKeyId(preKey.keyId())
512+
.setKeyId(KeyIdUtil.toUnsignedInt(preKey.keyId()))
512513
.setPublicKey(ByteString.copyFrom(preKey.publicKey().serialize()))
513514
.build();
514515
}
515516

516517
private static EcSignedPreKey toGrpcEcSignedPreKey(final ECSignedPreKey preKey) {
517518
return EcSignedPreKey.newBuilder()
518-
.setKeyId(preKey.keyId())
519+
.setKeyId(KeyIdUtil.toUnsignedInt(preKey.keyId()))
519520
.setPublicKey(ByteString.copyFrom(preKey.publicKey().serialize()))
520521
.setSignature(ByteString.copyFrom(preKey.signature()))
521522
.build();
522523
}
523524

524525
private static KemSignedPreKey toGrpcKemSignedPreKey(final KEMSignedPreKey preKey) {
525526
return KemSignedPreKey.newBuilder()
526-
.setKeyId(preKey.keyId())
527+
.setKeyId(KeyIdUtil.toUnsignedInt(preKey.keyId()))
527528
.setPublicKey(ByteString.copyFrom(preKey.publicKey().serialize()))
528529
.setSignature(ByteString.copyFrom(preKey.signature()))
529530
.build();

0 commit comments

Comments
 (0)