Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions docs/en/transforms/embedding.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,58 @@ vectorization_fields {
}
```

**Multi-field Mixing Multimodal Vectorization:**
> Note: Currently, only the `DOUBAO` provider supports multimodal data processing.
```hocon
vectorization_fields {
# Multi-field text
multi_field_text_vector = [product_name, description]

# Multi-field image
multi_field_image_vector = [
{
field = product_image_url
modality = jpeg
format = url
},
{
field = thumbnail_image
modality = png
format = url
}
]

# Multi-field video
multi_field_video_vector = [
{
field = product_video_url
modality = mp4
format = url
},
{
field = promotional_video
modality = mov
format = url
}
]

# Multi-field mix multimodal
multi_field_mix_vector = [
product_name,
{
field = product_image_url
modality = jpeg
format = url
},
{
field = product_video_url
modality = mp4
format = url
}
]
}
```

**Field Specification Formats:**

**Supported Modality Types:**
Expand Down
52 changes: 52 additions & 0 deletions docs/zh/transforms/embedding.md
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,58 @@ vectorization_fields {
}
```

**多字段混合多模态向量化:**
> 注意:目前,仅 `DOUBAO` 提供商支持多模态数据处理。
```hocon
vectorization_fields {
# 多字段文本
multi_field_text_vector = [product_name, description]

# 多字段图片
multi_field_image_vector = [
{
field = product_image_url
modality = jpeg
format = url
},
{
field = thumbnail_image
modality = png
format = url
}
]

# 多字段视频
multi_field_video_vector = [
{
field = product_video_url
modality = mp4
format = url
},
{
field = promotional_video
modality = mov
format = url
}
]

# 多字段混合多模态
multi_field_mix_vector = [
product_name,
{
field = product_image_url
modality = jpeg
format = url
},
{
field = product_video_url
modality = mp4
format = url
}
]
}
```

**字段规范格式:**

**支持的模态类型:**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,48 @@ transform {
}

product_name_vector = product_name

multi_field_text_vector = [product_name, description]

multi_field_image_vector = [
{
field = product_image_url
modality = jpeg
format = url
},
{
field = thumbnail_image
modality = png
format = url
}
]

multi_field_video_vector = [
{
field = product_video_url
modality = mp4
format = url
},
{
field = promotional_video
modality = mov
format = url
}
]

multi_field_mix_vector = [
product_name,
{
field = product_image_url
modality = jpeg
format = url
},
{
field = product_video_url
modality = mp4
format = url
}
]
}

plugin_output = "multimodal_embedding_output"
Expand Down Expand Up @@ -219,6 +261,42 @@ sink {
}
]
},
{
field_name = multi_field_text_vector
field_type = float_vector
field_value = [
{
rule_type = NOT_NULL
}
]
},
{
field_name = multi_field_image_vector
field_type = float_vector
field_value = [
{
rule_type = NOT_NULL
}
]
},
{
field_name = multi_field_video_vector
field_type = float_vector
field_value = [
{
rule_type = NOT_NULL
}
]
},
{
field_name = multi_field_mix_vector
field_type = float_vector
field_value = [
{
rule_type = NOT_NULL
}
]
},
{
field_name = category
field_type = string
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,22 +51,22 @@
import java.net.URISyntaxException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;

@Slf4j
public class EmbeddingTransform extends MultipleFieldOutputTransform {

private final ReadonlyConfig config;
private List<Integer> fieldOriginalIndexes;
private transient Model model;
private Integer dimension;
private boolean isMultimodalFields = false;
private Map<Integer, FieldSpec> fieldSpecMap;
private Map<VectorFieldSpec, List<Integer>> fieldSpecMap;
private List<String> fieldNames;

private final Map<String, TreeMap<Long, byte[]>> binaryFileCache = new ConcurrentHashMap<>();
Expand Down Expand Up @@ -197,30 +197,35 @@ public void open() {
}

private void initOutputFields(SeaTunnelRowType inputRowType, ReadonlyConfig config) {
Map<Integer, FieldSpec> fieldSpecMap = new HashMap<>();
List<String> fieldNames = new ArrayList<>();
Map<String, Object> fieldsConfig =
config.get(EmbeddingTransformConfig.VECTORIZATION_FIELDS);
if (fieldsConfig == null || fieldsConfig.isEmpty()) {
throw new IllegalArgumentException("vectorization_fields configuration is required");
}

for (Map.Entry<String, Object> field : fieldsConfig.entrySet()) {
FieldSpec fieldSpec = new FieldSpec(field);
log.info("Field spec: {}", fieldSpec.toString());
String srcField = fieldSpec.getFieldName();
int srcFieldIndex;
try {
srcFieldIndex = inputRowType.indexOf(srcField);
} catch (IllegalArgumentException e) {
throw TransformCommonError.cannotFindInputFieldError(getPluginName(), srcField);
}
if (fieldSpec.isMultimodalField()) {
isMultimodalFields = true;
List<String> fieldNames = new ArrayList<>();
Map<VectorFieldSpec, List<Integer>> fieldSpecMap = new LinkedHashMap<>();
for (Map.Entry<String, Object> fieldConfig : fieldsConfig.entrySet()) {
VectorFieldSpec vectorFieldSpec = new VectorFieldSpec(fieldConfig);
log.info("Vector field spec: {}", vectorFieldSpec);
List<String> srcFieldNames =
vectorFieldSpec.getSrcFieldSpecs().stream()
.map(SrcFieldSpec::getFieldName)
.collect(Collectors.toList());
List<Integer> srcFieldIndexes = new ArrayList<>();
for (String srcFieldName : srcFieldNames) {
try {
srcFieldIndexes.add(inputRowType.indexOf(srcFieldName));
} catch (IllegalArgumentException e) {
throw TransformCommonError.cannotFindInputFieldsError(
getPluginName(), srcFieldNames);
}
}
fieldSpecMap.put(srcFieldIndex, fieldSpec);
fieldNames.add(field.getKey());
fieldSpecMap.put(vectorFieldSpec, srcFieldIndexes);
fieldNames.add(vectorFieldSpec.getFieldName());
Comment on lines +208 to +225
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isMultimodalFields = vectorFieldSpec.isMultimodalField();

There is a logical issue; currently, only the last value can be obtained.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your time, let me fix it

}
this.isMultimodalFields =
fieldSpecMap.keySet().stream().anyMatch(VectorFieldSpec::isMultimodalField);
this.fieldSpecMap = fieldSpecMap;
this.fieldNames = fieldNames;
}
Expand All @@ -232,19 +237,28 @@ protected Object[] getOutputFieldValues(SeaTunnelRowAccessor inputRow) {
if (MetadataUtil.isBinaryFormat(inputRow)) {
return vectorizationBinaryRow(inputRow);
}
Set<Integer> fieldOriginalIndexes = fieldSpecMap.keySet();
Object[] fieldValues = new Object[fieldOriginalIndexes.size()];
List<ByteBuffer> vectorization;

Set<VectorFieldSpec> vectorFieldSpecs = fieldSpecMap.keySet();
Object[] fieldValues = new Object[vectorFieldSpecs.size()];
int i = 0;

for (Integer fieldOriginalIndex : fieldOriginalIndexes) {
FieldSpec fieldSpec = fieldSpecMap.get(fieldOriginalIndex);
Object value = inputRow.getField(fieldOriginalIndex);
for (VectorFieldSpec vectorFieldSpec : vectorFieldSpecs) {
List<SrcFieldSpec> srcFieldSpecs = vectorFieldSpec.getSrcFieldSpecs();
List<Integer> srcFieldIndexes = fieldSpecMap.get(vectorFieldSpec);
List<SrcField> srcFields = new ArrayList<>();
for (int j = 0; j < srcFieldSpecs.size(); j++) {
srcFields.add(
new SrcField(
srcFieldSpecs.get(j),
inputRow.getField(srcFieldIndexes.get(j))));
}
fieldValues[i++] =
isMultimodalFields ? new MultimodalFieldValue(fieldSpec, value) : value;
isMultimodalFields
? new MultimodalFieldValue(srcFields)
: srcFields.get(0).getFieldValue();
}

vectorization = model.vectorization(fieldValues);
List<ByteBuffer> vectorization = model.vectorization(fieldValues);
return vectorization.toArray();
} catch (Exception e) {
throw new RuntimeException("Failed to data vectorization", e);
Expand Down Expand Up @@ -282,32 +296,34 @@ public boolean isMultimodalFields() {

/** Process a row in binary format: [data, relativePath, partIndex] */
private Object[] vectorizationBinaryRow(SeaTunnelRowAccessor inputRow) throws Exception {

byte[] completeData = processBinaryRow(inputRow);
if (completeData == null) {
return null;
}
Set<Integer> fieldOriginalIndexes = fieldSpecMap.keySet();
Object[] fieldValues = new Object[fieldOriginalIndexes.size()];

Set<VectorFieldSpec> vectorFieldSpecs = fieldSpecMap.keySet();
Object[] fieldValues = new Object[vectorFieldSpecs.size()];
int i = 0;

for (Integer fieldOriginalIndex : fieldOriginalIndexes) {
FieldSpec fieldSpec = fieldSpecMap.get(fieldOriginalIndex);
if (fieldSpec.isBinary()) {
fieldValues[i++] = new MultimodalFieldValue(fieldSpec, completeData);
} else {
log.warn(
"Non-binary field {} configured in binary format data",
fieldSpec.getFieldName());
fieldValues[i++] = null;
for (VectorFieldSpec vectorFieldSpec : vectorFieldSpecs) {
List<SrcFieldSpec> srcFieldSpecs = vectorFieldSpec.getSrcFieldSpecs();
List<SrcField> srcFields = new ArrayList<>();
for (SrcFieldSpec srcFieldSpec : srcFieldSpecs) {
if (srcFieldSpec.isBinary()) {
srcFields.add(new SrcField(srcFieldSpec, completeData));
} else {
log.warn(
"Non-binary field {} configured in binary format data",
srcFieldSpec.getFieldName());
}
}
fieldValues[i++] = srcFields.isEmpty() ? null : new MultimodalFieldValue(srcFields);
}

try {
return model.vectorization(fieldValues).toArray();
} catch (Exception e) {
throw new RuntimeException(
"Failed to vectorize binary data for file: " + inputRow.toString(), e);
throw new RuntimeException("Failed to vectorize binary data for file: " + inputRow, e);
}
}

Expand Down
Loading
Loading