Skip to content
Open
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import org.assertj.core.api.Condition;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import static org.assertj.core.api.Assertions.assertThat;
Expand Down Expand Up @@ -39,7 +38,7 @@ public Builder hasHeader(String headerKey) {
validators.add(new StoreResponseValidator() {
@Override
public void validate(StoreResponse resp) {
assertThat(Arrays.asList(resp.getResponseHeaderNames()).contains(headerKey)).isTrue();
assertThat(resp.getHeaderValue(headerKey)).isNotNull();
}
});
return this;
Expand All @@ -49,9 +48,7 @@ public Builder withHeader(String headerKey, String headerValue) {
validators.add(new StoreResponseValidator() {
@Override
public void validate(StoreResponse resp) {
assertThat(Arrays.asList(resp.getResponseHeaderNames())).asList().contains(headerKey);
int index = Arrays.asList(resp.getResponseHeaderNames()).indexOf(headerKey);
assertThat(resp.getResponseHeaderValues()[index]).isEqualTo(headerValue);
assertThat(resp.getHeaderValue(headerKey)).isEqualTo(headerValue);
}
});
return this;
Expand All @@ -62,9 +59,8 @@ public Builder withHeaderValueCondition(String headerKey, Condition<String> cond
validators.add(new StoreResponseValidator() {
@Override
public void validate(StoreResponse resp) {
assertThat(Arrays.asList(resp.getResponseHeaderNames())).asList().contains(headerKey);
int index = Arrays.asList(resp.getResponseHeaderNames()).indexOf(headerKey);
String value = resp.getResponseHeaderValues()[index];
String value = resp.getHeaderValue(headerKey);
assertThat(value).isNotNull();
condition.matches(value);
}
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,12 @@ public class RxDocumentServiceResponse {
private CosmosDiagnostics cosmosDiagnostics;

public RxDocumentServiceResponse(DiagnosticsClientContext diagnosticsClientContext, StoreResponse response) {
String[] headerNames = response.getResponseHeaderNames();
String[] headerValues = response.getResponseHeaderValues();

this.headersMap = new HashMap<>(headerNames.length);
this.headersMap = new HashMap<>(response.getResponseHeaders());

// Gets status code.
this.statusCode = response.getStatus();

// Extracts headers.
for (int i = 0; i < headerNames.length; i++) {
this.headersMap.put(headerNames[i], headerValues[i]);
}

this.storeResponse = response;
this.diagnosticsClientContext = diagnosticsClientContext;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,12 @@
import java.time.Instant;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.atomic.AtomicReference;

Expand All @@ -69,13 +71,13 @@ public class RxGatewayStoreModel implements RxStoreModel, HttpTransportSerialize
private static final boolean leakDetectionDebuggingEnabled = ResourceLeakDetector.getLevel().ordinal() >=
ResourceLeakDetector.Level.ADVANCED.ordinal();
private static final boolean HTTP_CONNECTION_WITHOUT_TLS_ALLOWED = Configs.isHttpConnectionWithoutTLSAllowed();
private static final List<String> headersNeedToBeEscaped = Arrays.asList(
private static final Set<String> headersNeedToBeEscaped = new HashSet<>(Arrays.asList(
HttpConstants.HttpHeaders.PARTITION_KEY,
HttpConstants.HttpHeaders.POST_TRIGGER_EXCLUDE,
HttpConstants.HttpHeaders.POST_TRIGGER_INCLUDE,
HttpConstants.HttpHeaders.PRE_TRIGGER_EXCLUDE,
HttpConstants.HttpHeaders.PRE_TRIGGER_INCLUDE
);
));

private final DiagnosticsClientContext clientContext;
private final Logger logger = LoggerFactory.getLogger(RxGatewayStoreModel.class);
Expand Down Expand Up @@ -211,7 +213,7 @@ public StoreResponse unwrapToStoreResponse(
String endpoint,
RxDocumentServiceRequest request,
int statusCode,
HttpHeaders headers,
Map<String, String> headers,
ByteBuf retainedContent) {

checkNotNull(headers, "Argument 'headers' must not be null.");
Expand All @@ -238,7 +240,7 @@ public StoreResponse unwrapToStoreResponse(
return new StoreResponse(
endpoint,
statusCode,
HttpUtils.unescape(headers.toLowerCaseMap()),
headers,
new ByteBufInputStream(retainedContent, true),
size);
} else {
Expand All @@ -248,7 +250,7 @@ public StoreResponse unwrapToStoreResponse(
return new StoreResponse(
endpoint,
statusCode,
HttpUtils.unescape(headers.toLowerCaseMap()),
headers,
null,
0);
}
Expand Down Expand Up @@ -437,8 +439,9 @@ private Mono<RxDocumentServiceResponse> toDocumentServiceResponse(Mono<HttpRespo
.publishOn(CosmosSchedulers.TRANSPORT_RESPONSE_BOUNDED_ELASTIC)
.flatMap(httpResponse -> {

// header key/value pairs
HttpHeaders httpResponseHeaders = httpResponse.headers();
// Build lowercase header map directly from transport headers.
// For HTTP/2, keys are already lowercase (no toLowerCase overhead).
Map<String, String> responseHeaders = HttpUtils.unescape(httpResponse.headersAsLowerCaseMap());
int httpResponseStatus = httpResponse.statusCode();

// Track the retained ByteBuf so we can release it as a safety net in doFinally
Expand Down Expand Up @@ -503,7 +506,7 @@ private Mono<RxDocumentServiceResponse> toDocumentServiceResponse(Mono<HttpRespo
}
StoreResponse rsp = request
.getEffectiveHttpTransportSerializer(this)
.unwrapToStoreResponse(httpRequest.uri().toString(), request, httpResponseStatus, httpResponseHeaders, content);
.unwrapToStoreResponse(httpRequest.uri().toString(), request, httpResponseStatus, responseHeaders, content);

// Only clear retainedBufRef AFTER StoreResponse successfully takes ownership.
// If unwrapToStoreResponse throws, retainedBufRef remains set so doFinally
Expand Down Expand Up @@ -707,7 +710,7 @@ private Mono<RxDocumentServiceResponse> toDocumentServiceResponse(Mono<HttpRespo

private void validateOrThrow(RxDocumentServiceRequest request,
HttpResponseStatus status,
HttpHeaders headers,
Map<String, String> headers,
ByteBuf retainedBodyAsByteBuf) {

int statusCode = status.code();
Expand All @@ -729,7 +732,7 @@ private void validateOrThrow(RxDocumentServiceRequest request,
String.format("%s, StatusCode: %s", cosmosError.getMessage(), statusCodeString),
cosmosError.getPartitionedQueryExecutionInfo());

CosmosException dce = BridgeInternal.createCosmosException(request.requestContext.resourcePhysicalAddress, statusCode, cosmosError, headers.toLowerCaseMap());
CosmosException dce = BridgeInternal.createCosmosException(request.requestContext.resourcePhysicalAddress, statusCode, cosmosError, headers);
BridgeInternal.setRequestHeaders(dce, request.getHeaders());
throw dce;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
package com.azure.cosmos.implementation;

import com.azure.cosmos.ConsistencyLevel;
import com.azure.cosmos.implementation.directconnectivity.HttpUtils;
import com.azure.cosmos.implementation.directconnectivity.StoreResponse;
import com.azure.cosmos.implementation.directconnectivity.rntbd.RntbdConstants;
import com.azure.cosmos.implementation.directconnectivity.rntbd.RntbdFramer;
Expand Down Expand Up @@ -99,7 +100,7 @@ public StoreResponse unwrapToStoreResponse(
String endpoint,
RxDocumentServiceRequest request,
int statusCode,
HttpHeaders headers,
Map<String, String> headers,
ByteBuf content) {

if (content == null) {
Expand Down Expand Up @@ -141,7 +142,7 @@ public StoreResponse unwrapToStoreResponse(
endpoint,
request,
response.getStatus().code(),
new HttpHeaders(response.getHeaders().asMap(request.getActivityId())),
HttpUtils.unescape(response.getHeaders().asMap(request.getActivityId())),
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔴 Blocking · Correctness: HttpUtils.unescape() always throws UnsupportedOperationException on the ImmutableMap returned by asMap()

response.getHeaders().asMap(activityId) returns a com.azure.cosmos.implementation.guava25.collect.ImmutableMap. This PR now passes that directly to HttpUtils.unescape():

// ThinClientStoreModel.java line 145
HttpUtils.unescape(response.getHeaders().asMap(request.getActivityId()))

HttpUtils.unescape(Map) calls headers.computeIfPresent(...), and the vendored ImmutableMap overrides computeIfPresent to always throw unconditionally:

// guava25/collect/ImmutableMap.java:583-586
`@Deprecated`
`@Override`
public final V computeIfPresent(K key, BiFunction<...> remappingFunction) {
    throw new UnsupportedOperationException();   // always, no key check
}

So every call through ThinClientStoreModel.unwrapToStoreResponse that successfully decodes an RntbdResponse will throw UnsupportedOperationException before producing a StoreResponse. ThinClient mode is broken for all responses.

Before this PR, the old code wrapped in HttpHeaders first and then called unescape on the mutable map from toLowerCaseMap():

// Old (safe)
new HttpHeaders(response.getHeaders().asMap(request.getActivityId()))
// ... parent called: HttpUtils.unescape(headers.toLowerCaseMap())  ← new mutable HashMap

Fix: Wrap in a mutable HashMap before calling unescape:

HttpUtils.unescape(new HashMap<>(response.getHeaders().asMap(request.getActivityId()))),

This preserves the original semantics and avoids the UnsupportedOperationException.


⚠️ AI-generated review — may be incorrect. Agree? → resolve the conversation. Disagree? → reply with your reasoning.

payloadBuf
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import com.azure.cosmos.implementation.ISessionToken;
import com.azure.cosmos.implementation.InternalServerErrorException;
import com.azure.cosmos.implementation.OperationType;
import com.azure.cosmos.implementation.RMResources;
import com.azure.cosmos.implementation.ResourceType;
import com.azure.cosmos.implementation.RxDocumentServiceRequest;
import com.azure.cosmos.implementation.RxDocumentServiceResponse;
Expand Down Expand Up @@ -174,19 +173,7 @@ private RxDocumentServiceResponse completeResponse(
StoreResponse storeResponse,
RxDocumentServiceRequest request) throws InternalServerErrorException {

if (storeResponse.getResponseHeaderNames().length != storeResponse.getResponseHeaderValues().length) {
throw new InternalServerErrorException(
Exceptions.getInternalServerErrorMessage(RMResources.InvalidBackendResponse),
HttpConstants.SubStatusCodes.INVALID_BACKEND_RESPONSE);
}

Map<String, String> headers = new HashMap<>(storeResponse.getResponseHeaderNames().length);
for (int idx = 0; idx < storeResponse.getResponseHeaderNames().length; idx++) {
String name = storeResponse.getResponseHeaderNames()[idx];
String value = storeResponse.getResponseHeaderValues()[idx];

headers.put(name, value);
}
Map<String, String> headers = new HashMap<>(storeResponse.getResponseHeaders());

this.updateResponseHeader(request, headers);
this.captureSessionToken(request, headers);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;

Expand All @@ -30,8 +30,7 @@
public class StoreResponse {
private static final Logger logger = LoggerFactory.getLogger(StoreResponse.class.getSimpleName());
final private int status;
final private String[] responseHeaderNames;
final private String[] responseHeaderValues;
final private Map<String, String> responseHeaders;
private int requestPayloadLength;
private RequestTimeline requestTimeline;
private RntbdChannelAcquisitionTimeline channelAcquisitionTimeline;
Expand All @@ -56,22 +55,14 @@ public StoreResponse(
checkArgument((contentStream == null) == (responsePayloadLength == 0),
"Parameter 'contentStream' must be consistent with 'responsePayloadLength'.");
requestTimeline = RequestTimeline.empty();
responseHeaderNames = new String[headerMap.size()];
responseHeaderValues = new String[headerMap.size()];
this.responseHeaders = toLowerCasedMap(headerMap);
this.endpoint = endpoint != null ? endpoint : "";

int i = 0;
for (Map.Entry<String, String> headerEntry : headerMap.entrySet()) {
responseHeaderNames[i] = headerEntry.getKey();
responseHeaderValues[i] = headerEntry.getValue();
i++;
}

this.status = status;
replicaStatusList = new HashMap<>();
if (contentStream != null) {
try {
this.responsePayload = new JsonNodeStorePayload(contentStream, responsePayloadLength, headerMap);
this.responsePayload = new JsonNodeStorePayload(contentStream, responsePayloadLength, this.responseHeaders);
} finally {
try {
contentStream.close();
Expand All @@ -91,37 +82,34 @@ private StoreResponse(
String endpoint,
int status,
Map<String, String> headerMap,
JsonNodeStorePayload responsePayload) {
JsonNodeStorePayload responsePayload,
boolean keysAlreadyLowerCased) {

checkNotNull(endpoint, "Parameter 'endpoint' must not be null.");

requestTimeline = RequestTimeline.empty();
responseHeaderNames = new String[headerMap.size()];
responseHeaderValues = new String[headerMap.size()];
this.responseHeaders = keysAlreadyLowerCased ? headerMap : toLowerCasedMap(headerMap);
this.endpoint = endpoint;

int i = 0;
for (Map.Entry<String, String> headerEntry : headerMap.entrySet()) {
responseHeaderNames[i] = headerEntry.getKey();
responseHeaderValues[i] = headerEntry.getValue();
i++;
}

this.status = status;
replicaStatusList = new HashMap<>();
this.responsePayload = responsePayload;
}

public int getStatus() {
return status;
private static Map<String, String> toLowerCasedMap(Map<String, String> map) {
Map<String, String> result = new HashMap<>(map.size());
for (Map.Entry<String, String> entry : map.entrySet()) {
result.put(entry.getKey().toLowerCase(Locale.ROOT), entry.getValue());
}
return result;
}

public String[] getResponseHeaderNames() {
return responseHeaderNames;
public int getStatus() {
return status;
}

public String[] getResponseHeaderValues() {
return responseHeaderValues;
public Map<String, String> getResponseHeaders() {
return responseHeaders;
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 Recommendation · Code Quality: getResponseHeaders() exposes the live mutable internal map

responseHeaders is a final HashMap, but final only prevents reassignment — the map's contents remain mutable. Returning the live reference means any caller can silently modify StoreResponse internal state:

storeResponse.getResponseHeaders().put("x-ms-foo", "injected"); // modifies StoreResponse internals

All current production callers are well-behaved (they either take a defensive copy or read-only), but this is a latent footgun. Hardening with Collections.unmodifiableMap is zero-cost at runtime for read paths and doesn't break anything — setHeaderValue() mutates this.responseHeaders directly, not through the returned reference:

public Map(String, String) getResponseHeaders() {
    return Collections.unmodifiableMap(responseHeaders);
}

⚠️ AI-generated review — may be incorrect. Agree? → resolve the conversation. Disagree? → reply with your reasoning.

}

public void setRntbdRequestLength(int rntbdRequestLength) {
Expand Down Expand Up @@ -191,31 +179,19 @@ public String getCorrelatedActivityId() {
}

public String getHeaderValue(String attribute) {
if (this.responseHeaderValues == null || this.responseHeaderNames.length != this.responseHeaderValues.length) {
if (this.responseHeaders == null) {
return null;
}

for (int i = 0; i < responseHeaderNames.length; i++) {
if (responseHeaderNames[i].equalsIgnoreCase(attribute)) {
return responseHeaderValues[i];
}
}

return null;
return responseHeaders.get(attribute.toLowerCase(Locale.ROOT));
}

//NOTE: only used for testing purpose to change the response header value
void setHeaderValue(String headerName, String value) {
if (this.responseHeaderValues == null || this.responseHeaderNames.length != this.responseHeaderValues.length) {
if (this.responseHeaders == null) {
return;
}

for (int i = 0; i < responseHeaderNames.length; i++) {
if (responseHeaderNames[i].equalsIgnoreCase(headerName)) {
responseHeaderValues[i] = value;
break;
}
}
this.responseHeaders.put(headerName.toLowerCase(Locale.ROOT), value);
}

public double getRequestCharge() {
Expand Down Expand Up @@ -310,23 +286,20 @@ public void setFaultInjectionRuleEvaluationResults(List<String> results) {

public StoreResponse withRemappedStatusCode(int newStatusCode, double additionalRequestCharge) {
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 Recommendation · Test Coverage: withRemappedStatusCode has no unit test

This method was substantially refactored (parallel arrays → Map, new keysAlreadyLowerCased=true constructor path, containsKey guard for REQUEST_CHARGE), but StoreResponseTest covers none of it. Consider adding tests for:

  1. Status code is updated — returned response has newStatusCode, not the original
  2. REQUEST_CHARGE is accumulated — if header is present, additionalRequestCharge is added
  3. REQUEST_CHARGE absent — if header was never set, no charge header is introduced (guards against regression of the prior unconditional-insert bug)
  4. Headers are copied, not aliased — mutating the original StoreResponse headers after calling withRemappedStatusCode does not affect the returned response

Example skeleton:

`@Test`(groups = { "unit" })
public void withRemappedStatusCode_updatesChargeAndStatus() {
    HashMap(String, String) headers = new HashMap<>();
    headers.put(HttpConstants.HttpHeaders.REQUEST_CHARGE.toLowerCase(Locale.ROOT), "5.0");
    StoreResponse original = new StoreResponse(null, 200, headers, null, 0);

    StoreResponse remapped = original.withRemappedStatusCode(201, 2.0);

    assertThat(remapped.getStatus()).isEqualTo(201);
    assertThat(Double.parseDouble(remapped.getHeaderValue(HttpConstants.HttpHeaders.REQUEST_CHARGE))).isEqualTo(7.0);
    // original is unmodified
    assertThat(original.getStatus()).isEqualTo(200);
}

⚠️ AI-generated review — may be incorrect. Agree? → resolve the conversation. Disagree? → reply with your reasoning.


Map<String, String> headers = new HashMap<>();
for (int i = 0; i < this.responseHeaderNames.length; i++) {
String headerName = this.responseHeaderNames[i];
if (headerName.equalsIgnoreCase(HttpConstants.HttpHeaders.REQUEST_CHARGE)) {
double currentRequestCharge = this.getRequestCharge();
double newRequestCharge = currentRequestCharge + additionalRequestCharge;
headers.put(headerName, String.valueOf(newRequestCharge));
} else {
headers.put(headerName, this.responseHeaderValues[i]);
}
Map<String, String> headers = new HashMap<>(this.responseHeaders);
String requestChargeKey = HttpConstants.HttpHeaders.REQUEST_CHARGE.toLowerCase(Locale.ROOT);
if (headers.containsKey(requestChargeKey)) {
double currentRequestCharge = this.getRequestCharge();
double newRequestCharge = currentRequestCharge + additionalRequestCharge;
headers.put(requestChargeKey, String.valueOf(newRequestCharge));
}

return new StoreResponse(
this.endpoint,
newStatusCode,
headers,
this.responsePayload);
this.responsePayload,
true);
}

public String getEndpoint() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import com.azure.cosmos.implementation.InternalServerErrorException;
import com.azure.cosmos.implementation.RMResources;
import com.azure.cosmos.implementation.RequestChargeTracker;
import com.azure.cosmos.implementation.Strings;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -141,13 +140,8 @@ private static void setRequestCharge(StoreResponse response, CosmosException cos
Double.toString(totalRequestCharge));
}
// Set total charge as final charge for the response.
else if (response.getResponseHeaderNames() != null) {
for (int i = 0; i < response.getResponseHeaderNames().length; ++i) {
if (Strings.areEqualIgnoreCase(response.getResponseHeaderNames()[i], HttpConstants.HttpHeaders.REQUEST_CHARGE)) {
response.getResponseHeaderValues()[i] = Double.toString(totalRequestCharge);
break;
}
}
else if (response.getResponseHeaders() != null) {
response.setHeaderValue(HttpConstants.HttpHeaders.REQUEST_CHARGE, Double.toString(totalRequestCharge));
}
}
}
Loading
Loading