Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.AGENT_ID_FIELD;
import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.BINARY_DATA_FIELD;
import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.CHECKPOINT_ID_FIELD;
import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.CREATED_TIME_FIELD;
import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.INFER_FIELD;
import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.LAST_UPDATED_TIME_FIELD;
Expand Down Expand Up @@ -68,6 +69,9 @@ public class MLWorkingMemory implements ToXContentObject, Writeable {
private Instant lastUpdateTime;
private String ownerId;

// Checkpoint field
private String checkpointId;

@Builder
public MLWorkingMemory(
String memoryContainerId,
Expand All @@ -82,7 +86,8 @@ public MLWorkingMemory(
Map<String, String> tags,
Instant createdTime,
Instant lastUpdateTime,
String ownerId
String ownerId,
String checkpointId
) {
// MAX_MESSAGES_PER_REQUEST limit removed for performance testing

Expand All @@ -100,6 +105,7 @@ public MLWorkingMemory(
this.createdTime = createdTime;
this.lastUpdateTime = lastUpdateTime;
this.ownerId = ownerId;
this.checkpointId = checkpointId;
}

public MLWorkingMemory(StreamInput in) throws IOException {
Expand Down Expand Up @@ -131,6 +137,7 @@ public MLWorkingMemory(StreamInput in) throws IOException {
this.createdTime = in.readOptionalInstant();
this.lastUpdateTime = in.readOptionalInstant();
this.ownerId = in.readOptionalString();
this.checkpointId = in.readOptionalString();
}

@Override
Expand Down Expand Up @@ -177,6 +184,7 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalInstant(createdTime);
out.writeOptionalInstant(lastUpdateTime);
out.writeOptionalString(ownerId);
out.writeOptionalString(checkpointId);
}

@Override
Expand Down Expand Up @@ -225,6 +233,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (ownerId != null) {
builder.field(OWNER_ID_FIELD, ownerId);
}
if (checkpointId != null) {
builder.field(CHECKPOINT_ID_FIELD, checkpointId);
}
builder.endObject();
return builder;
}
Expand All @@ -243,6 +254,7 @@ public static MLWorkingMemory parse(XContentParser parser) throws IOException {
Instant createdTime = null;
Instant lastUpdateTime = null;
String ownerId = null;
String checkpointId = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand Down Expand Up @@ -293,6 +305,9 @@ public static MLWorkingMemory parse(XContentParser parser) throws IOException {
case OWNER_ID_FIELD:
ownerId = parser.text();
break;
case CHECKPOINT_ID_FIELD:
checkpointId = parser.text();
break;
default:
parser.skipChildren();
break;
Expand All @@ -314,6 +329,7 @@ public static MLWorkingMemory parse(XContentParser parser) throws IOException {
.createdTime(createdTime)
.lastUpdateTime(lastUpdateTime)
.ownerId(ownerId)
.checkpointId(checkpointId)
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ public class MemoryContainerConstants {
public static final String TEXT_FIELD = "text";
public static final String UPDATE_CONTENT_FIELD = "update_content";

// Checkpoint field
public static final String CHECKPOINT_ID_FIELD = "checkpoint_id";

// KNN index settings
public static final String KNN_ENGINE = "lucene";
public static final String KNN_SPACE_TYPE = "cosinesimil";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.AGENT_ID_FIELD;
import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.BINARY_DATA_FIELD;
import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.CHECKPOINT_ID_FIELD;
import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.CREATED_TIME_FIELD;
import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.INFER_FIELD;
import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.LAST_UPDATED_TIME_FIELD;
Expand Down Expand Up @@ -69,6 +70,9 @@ public class MLAddMemoriesInput implements ToXContentObject, Writeable {
private Map<String, Object> parameters;
private String ownerId;

// Checkpoint field
private String checkpointId;

public MLAddMemoriesInput(
String memoryContainerId,
PayloadType payloadType,
Expand All @@ -81,7 +85,8 @@ public MLAddMemoriesInput(
Map<String, String> metadata,
Map<String, String> tags,
Map<String, Object> parameters,
String ownerId
String ownerId,
String checkpointId
) {
// MAX_MESSAGES_PER_REQUEST limit removed for performance testing

Expand All @@ -100,6 +105,7 @@ public MLAddMemoriesInput(
this.parameters.putAll(parameters);
}
this.ownerId = ownerId;
this.checkpointId = checkpointId;
validate();
}

Expand Down Expand Up @@ -144,6 +150,7 @@ public MLAddMemoriesInput(StreamInput in) throws IOException {
this.parameters = in.readMap();
}
this.ownerId = in.readOptionalString();
this.checkpointId = in.readOptionalString();
}

@Override
Expand Down Expand Up @@ -193,6 +200,7 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeBoolean(false);
}
out.writeOptionalString(ownerId);
out.writeOptionalString(checkpointId);
}

@Override
Expand Down Expand Up @@ -239,6 +247,9 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
if (ownerId != null) {
builder.field(OWNER_ID_FIELD, ownerId);
}
if (checkpointId != null) {
builder.field(CHECKPOINT_ID_FIELD, checkpointId);
}
if (withTimeStamp) {
Instant now = Instant.now();
builder.field(CREATED_TIME_FIELD, now.toEpochMilli());
Expand All @@ -260,6 +271,7 @@ public static MLAddMemoriesInput parse(XContentParser parser, String memoryConta
Map<String, String> tags = null;
Map<String, Object> parameters = null;
String ownerId = null;
String checkpointId = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand Down Expand Up @@ -307,6 +319,9 @@ public static MLAddMemoriesInput parse(XContentParser parser, String memoryConta
case OWNER_ID_FIELD:
ownerId = parser.text();
break;
case CHECKPOINT_ID_FIELD:
checkpointId = parser.text();
break;
default:
parser.skipChildren();
break;
Expand All @@ -327,6 +342,7 @@ public static MLAddMemoriesInput parse(XContentParser parser, String memoryConta
.tags(tags)
.parameters(parameters)
.ownerId(ownerId)
.checkpointId(checkpointId)
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
"message_id": {
"type": "integer"
},
"checkpoint_id": {
"type": "keyword"
},
"binary_data": {
"type": "binary"
},
Expand Down
Loading
Loading