Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions docs/en/transform-v2/embedding.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ The `Embedding` transform plugin leverages embedding models to convert text data
transformation can be applied to various fields. The plugin supports multiple model providers and can be integrated with
different API endpoints.

> **Important Note:** The current embedding precision only supports float32 format.

## Options

| Name | Type | Required | Default Value | Description |
Expand All @@ -27,6 +29,13 @@ different API endpoints.
| custom_request_headers | map | no | | Custom headers for the request to the model. |
| custom_request_body | map | no | | Custom body for the request. Supports placeholders like `${model}`, `${input}`. |

## Precision Support

**Important:** The current version of the Embedding plugin only supports **float32** precision for vector data.

- All generated embedding vectors will be stored in float32 format
- If your model or API returns other precision formats (such as float64), the plugin will automatically convert them to float32

### model_provider

The providers for generating embeddings include common options such as `AMAZON`, `DOUBAO`, `QIANFAN`, and `OPENAI`. Additionally,
Expand Down
9 changes: 9 additions & 0 deletions docs/zh/transform-v2/embedding.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

`Embedding` 转换插件利用 embedding 模型将文本数据转换为向量化表示。此转换可以应用于各种字段。该插件支持多种模型提供商,并且可以与不同的API集成。

> **重要提示:** 当前 embedding 精确度仅支持 float32

## 配置选项

| 名称 | 类型 | 是否必填 | 默认值 | 描述 |
Expand All @@ -25,6 +27,13 @@
| custom_request_headers | map | 否 | | 发送到模型的请求的自定义头信息。 |
| custom_request_body | map | 否 | | 请求体的自定义配置。支持占位符如 `${model}`、`${input}`。 |

## 精度支持

**重要:** 当前版本的 Embedding 插件仅支持 **float32** 精度的向量数据。

- 所有生成的 embedding 向量将以 float32 格式存储
- 如果您的模型或API返回其他精度格式(如 float64),插件会自动转换为 float32

### embedding_model_provider

用于生成 embedding 的模型提供商。常见选项包括 `AMAZON`、 `DOUBAO`、`QIANFAN`、`OPENAI` 等,同时可选择 `CUSTOM` 实现自定义 embedding
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,23 +42,23 @@ protected AbstractModel(Integer singleVectorizedInputNumber) {
public List<ByteBuffer> vectorization(Object[] fields) throws IOException {
List<ByteBuffer> result = new ArrayList<>();

List<List<Double>> vectors = batchProcess(fields, singleVectorizedInputNumber);
for (List<Double> vector : vectors) {
result.add(BufferUtils.toByteBuffer(vector.toArray(new Double[0])));
List<List<Float>> vectors = batchProcess(fields, singleVectorizedInputNumber);
for (List<Float> vector : vectors) {
result.add(BufferUtils.toByteBuffer(vector.toArray(new Float[0])));
}
return result;
}

protected abstract List<List<Double>> vector(Object[] fields) throws IOException;
protected abstract List<List<Float>> vector(Object[] fields) throws IOException;

public List<List<Double>> batchProcess(Object[] array, int batchSize) throws IOException {
List<List<Double>> merged = new ArrayList<>();
public List<List<Float>> batchProcess(Object[] array, int batchSize) throws IOException {
List<List<Float>> merged = new ArrayList<>();
if (array == null || array.length == 0) {
return merged;
}
for (int i = 0; i < array.length; i += batchSize) {
Object[] batch = ArrayUtils.subarray(array, i, i + batchSize);
List<List<Double>> vector = vector(batch);
List<List<Float>> vector = vector(batch);
merged.addAll(vector);
}
if (array.length != merged.size()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ public static BedrockRuntimeClient createBedrockClient(
}

@Override
protected List<List<Double>> vector(Object[] fields) throws IOException {
protected List<List<Float>> vector(Object[] fields) throws IOException {
if (fields == null || fields.length == 0) {
return new ArrayList<>();
}
Expand Down Expand Up @@ -247,26 +247,26 @@ public ObjectNode createRequestForBatchInput(Object[] inputs) {
return requestBody;
}

private List<List<Double>> parseSingleResponse(String responseBody) throws IOException {
private List<List<Float>> parseSingleResponse(String responseBody) throws IOException {
try {
JsonNode responseJson = OBJECT_MAPPER.readTree(responseBody);
List<List<Double>> result = new ArrayList<>();
List<List<Float>> result = new ArrayList<>();

if (modelId.startsWith("amazon.titan")) {
JsonNode embedding = responseJson.get("embedding");
if (embedding != null && embedding.isArray()) {
List<Double> vector = new ArrayList<>();
List<Float> vector = new ArrayList<>();
for (JsonNode value : embedding) {
vector.add(value.asDouble());
vector.add(value.floatValue());
}
result.add(vector);
}
} else if (modelId.startsWith("cohere.")) {
JsonNode embeddings = responseJson.get("embeddings");
if (embeddings != null && embeddings.isArray() && !embeddings.isEmpty()) {
List<Double> vector = new ArrayList<>();
List<Float> vector = new ArrayList<>();
for (JsonNode value : embeddings.get(0)) {
vector.add(value.asDouble());
vector.add(value.floatValue());
}
result.add(vector);
}
Expand All @@ -278,26 +278,26 @@ private List<List<Double>> parseSingleResponse(String responseBody) throws IOExc
}
}

private List<List<Double>> parseBatchResponse(String responseBody) throws IOException {
private List<List<Float>> parseBatchResponse(String responseBody) throws IOException {
try {
JsonNode responseJson = OBJECT_MAPPER.readTree(responseBody);
List<List<Double>> result = new ArrayList<>();
List<List<Float>> result = new ArrayList<>();
JsonNode embeddings = responseJson.get("embeddings");
if (embeddings != null && embeddings.isArray()) {
if (modelId.startsWith("amazon.titan")) {
for (JsonNode embedding : embeddings) {
List<Double> vector = new ArrayList<>();
List<Float> vector = new ArrayList<>();
for (JsonNode value : embedding) {
vector.add(value.asDouble());
vector.add(value.floatValue());
}
result.add(vector);
}

} else if (modelId.startsWith("cohere.")) {
for (JsonNode embedding : embeddings) {
List<Double> vector = new ArrayList<>();
List<Float> vector = new ArrayList<>();
for (JsonNode value : embedding) {
vector.add(value.asDouble());
vector.add(value.floatValue());
}
result.add(vector);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ public CustomModel(
}

@Override
protected List<List<Double>> vector(Object[] fields) throws IOException {
protected List<List<Float>> vector(Object[] fields) throws IOException {
return vectorGeneration(fields);
}

Expand All @@ -76,7 +76,7 @@ public Integer dimension() throws IOException {
return vectorGeneration(new Object[] {DIMENSION_EXAMPLE}).size();
}

private List<List<Double>> vectorGeneration(Object[] fields) throws IOException {
private List<List<Float>> vectorGeneration(Object[] fields) throws IOException {
HttpPost post = new HttpPost(apiPath);
// Construct a request with custom parameters
for (Map.Entry<String, String> entry : header.entrySet()) {
Expand All @@ -96,7 +96,7 @@ private List<List<Double>> vectorGeneration(Object[] fields) throws IOException
}

return OBJECT_MAPPER.convertValue(
parseResponse(responseStr), new TypeReference<List<List<Double>>>() {});
parseResponse(responseStr), new TypeReference<List<List<Float>>>() {});
}

@VisibleForTesting
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public DoubaoModel(String apiKey, String model, String apiPath, Integer vectoriz
}

@Override
protected List<List<Double>> vector(Object[] fields) throws IOException {
protected List<List<Float>> vector(Object[] fields) throws IOException {
return vectorGeneration(fields);
}

Expand All @@ -63,7 +63,7 @@ public Integer dimension() throws IOException {
return vectorGeneration(new Object[] {DIMENSION_EXAMPLE}).size();
}

private List<List<Double>> vectorGeneration(Object[] fields) throws IOException {
private List<List<Float>> vectorGeneration(Object[] fields) throws IOException {
HttpPost post = new HttpPost(apiPath);
post.setHeader("Authorization", "Bearer " + apiKey);
post.setHeader("Content-Type", "application/json");
Expand All @@ -82,14 +82,14 @@ private List<List<Double>> vectorGeneration(Object[] fields) throws IOException
}

JsonNode data = OBJECT_MAPPER.readTree(responseStr).get("data");
List<List<Double>> embeddings = new ArrayList<>();
List<List<Float>> embeddings = new ArrayList<>();

if (data.isArray()) {
for (JsonNode node : data) {
JsonNode embeddingNode = node.get("embedding");
List<Double> embedding =
List<Float> embedding =
OBJECT_MAPPER.readValue(
embeddingNode.traverse(), new TypeReference<List<Double>>() {});
embeddingNode.traverse(), new TypeReference<List<Float>>() {});
embeddings.add(embedding);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ public OpenAIModel(String apiKey, String model, String apiPath, Integer vectoriz
}

@Override
protected List<List<Double>> vector(Object[] fields) throws IOException {
protected List<List<Float>> vector(Object[] fields) throws IOException {
if (fields.length > 1) {
throw new IllegalArgumentException("OpenAI model only supports single input");
}
Expand All @@ -65,7 +65,7 @@ public Integer dimension() throws IOException {
return vectorGeneration(new Object[] {DIMENSION_EXAMPLE}).size();
}

private List<List<Double>> vectorGeneration(Object[] fields) throws IOException {
private List<List<Float>> vectorGeneration(Object[] fields) throws IOException {
HttpPost post = new HttpPost(apiPath);
post.setHeader("Authorization", "Bearer " + apiKey);
post.setHeader("Content-Type", "application/json");
Expand All @@ -84,14 +84,14 @@ private List<List<Double>> vectorGeneration(Object[] fields) throws IOException
}

JsonNode data = OBJECT_MAPPER.readTree(responseStr).get("data");
List<List<Double>> embeddings = new ArrayList<>();
List<List<Float>> embeddings = new ArrayList<>();

if (data.isArray()) {
for (JsonNode node : data) {
JsonNode embeddingNode = node.get("embedding");
List<Double> embedding =
List<Float> embedding =
OBJECT_MAPPER.readValue(
embeddingNode.traverse(), new TypeReference<List<Double>>() {});
embeddingNode.traverse(), new TypeReference<List<Float>>() {});
embeddings.add(embedding);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ private String getAccessToken() throws IOException {
}

@Override
public List<List<Double>> vector(Object[] fields) throws IOException {
public List<List<Float>> vector(Object[] fields) throws IOException {
return vectorGeneration(fields);
}

Expand All @@ -109,7 +109,7 @@ public Integer dimension() throws IOException {
return vectorGeneration(new Object[] {DIMENSION_EXAMPLE}).get(0).size();
}

private List<List<Double>> vectorGeneration(Object[] fields) throws IOException {
private List<List<Float>> vectorGeneration(Object[] fields) throws IOException {
String formattedApiPath =
String.format(
(apiPath.endsWith("/") ? apiPath : apiPath + "/") + "%s?access_token=%s",
Expand Down Expand Up @@ -143,14 +143,14 @@ private List<List<Double>> vectorGeneration(Object[] fields) throws IOException
"Failed to get vector from qianfan, response: " + result.get("error_msg"));
}

List<List<Double>> embeddings = new ArrayList<>();
List<List<Float>> embeddings = new ArrayList<>();
JsonNode data = result.get("data");
if (data.isArray()) {
for (JsonNode node : data) {
List<Double> embedding =
List<Float> embedding =
OBJECT_MAPPER.readValue(
node.get("embedding").traverse(),
new TypeReference<List<Double>>() {});
new TypeReference<List<Float>>() {});
embeddings.add(embedding);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ public ZhipuModel(
}

@Override
public List<List<Double>> vector(Object[] fields) throws IOException {
public List<List<Float>> vector(Object[] fields) throws IOException {
return vectorGeneration(fields);
}

Expand All @@ -75,7 +75,7 @@ public Integer dimension() throws IOException {
return dimension;
}

private List<List<Double>> vectorGeneration(Object[] fields) throws IOException {
private List<List<Float>> vectorGeneration(Object[] fields) throws IOException {

if (fields == null || fields.length > MAX_INPUT_SIZE) {
throw new IOException(
Expand All @@ -98,14 +98,14 @@ private List<List<Double>> vectorGeneration(Object[] fields) throws IOException
throw new IOException("Failed to get vector from zhipu, response: " + responseStr);
}
JsonNode data = OBJECT_MAPPER.readTree(responseStr).get("data");
List<List<Double>> embeddings = new ArrayList<>();
List<List<Float>> embeddings = new ArrayList<>();

if (data.isArray()) {
for (JsonNode node : data) {
JsonNode embeddingNode = node.get("embedding");
List<Double> embedding =
List<Float> embedding =
OBJECT_MAPPER.readValue(
embeddingNode.traverse(), new TypeReference<List<Double>>() {});
embeddingNode.traverse(), new TypeReference<List<Float>>() {});
embeddings.add(embedding);
}
}
Expand Down
Loading