Skip to content

Commit f784059

Browse files
authored
[Fix] [Transform-V2] Fix embedding output columns vector dimension (#9646)
1 parent b7f61bb commit f784059

File tree

2 files changed

+70
-2
lines changed

2 files changed

+70
-2
lines changed

seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/EmbeddingTransform.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
package org.apache.seatunnel.transform.nlpmodel.embedding;
1919

20+
import org.apache.seatunnel.shade.com.google.common.annotations.VisibleForTesting;
21+
2022
import org.apache.seatunnel.api.configuration.ReadonlyConfig;
2123
import org.apache.seatunnel.api.table.catalog.CatalogTable;
2224
import org.apache.seatunnel.api.table.catalog.Column;
@@ -52,7 +54,7 @@ public class EmbeddingTransform extends MultipleFieldOutputTransform {
5254
private final ReadonlyConfig config;
5355
private List<String> fieldNames;
5456
private List<Integer> fieldOriginalIndexes;
55-
private Model model;
57+
private transient Model model;
5658
private Integer dimension;
5759

5860
public EmbeddingTransform(
@@ -212,7 +214,9 @@ protected Object[] getOutputFieldValues(SeaTunnelRowAccessor inputRow) {
212214
}
213215

214216
@Override
215-
protected Column[] getOutputColumns() {
217+
@VisibleForTesting
218+
public Column[] getOutputColumns() {
219+
tryOpen();
216220
Column[] columns = new Column[fieldNames.size()];
217221
for (int i = 0; i < fieldNames.size(); i++) {
218222
columns[i] =
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.seatunnel.transform.embedding;
19+
20+
import org.apache.seatunnel.shade.com.fasterxml.jackson.core.JsonProcessingException;
21+
import org.apache.seatunnel.shade.com.fasterxml.jackson.core.type.TypeReference;
22+
import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.ObjectMapper;
23+
24+
import org.apache.seatunnel.api.configuration.ReadonlyConfig;
25+
import org.apache.seatunnel.api.table.catalog.CatalogTable;
26+
import org.apache.seatunnel.api.table.catalog.CatalogTableUtil;
27+
import org.apache.seatunnel.api.table.catalog.Column;
28+
import org.apache.seatunnel.transform.nlpmodel.embedding.EmbeddingTransform;
29+
30+
import org.junit.jupiter.api.Assertions;
31+
import org.junit.jupiter.api.Test;
32+
33+
import java.util.Map;
34+
35+
public class EmbeddingTransformTest {
36+
37+
@Test
38+
void testOutputColumns() throws JsonProcessingException {
39+
ObjectMapper objectMapper = new ObjectMapper();
40+
41+
String sourceConfig =
42+
"{\"path\":\"/seatunnel/test_csv_data.csv\",\"bucket\":\"s3a://ltchen\",\"fs.s3a.endpoint\":\"tos-s3-cn-beijing.volces.com\",\"fs.s3a.aws.credentials.provider\":\"org.apache.hadoop.fs.s3a.SimpleAWSCredentialsProvider\",\"file_format_type\":\"csv\",\"access_key\":\"xxx\",\"secret_key\":\"xxx\",\"csv_use_header_line\":true,\"field_delimiter\":\",\",\"schema\":{\"fields\":{\"id\":\"int\",\"code\":\"int\",\"data\":\"string\",\"success\":\"boolean\"},\"primaryKey\":{\"name\":\"id\",\"columnNames\":[\"id\"]}},\"plugin_name\":\"S3File\"}";
43+
Map<String, Object> sourceConfigMap =
44+
objectMapper.readValue(sourceConfig, new TypeReference<Map<String, Object>>() {});
45+
ReadonlyConfig readonlyConfig = ReadonlyConfig.fromMap(sourceConfigMap);
46+
CatalogTable inputCatalogTable = CatalogTableUtil.buildWithConfig("S3File", readonlyConfig);
47+
48+
int dimension = 1024;
49+
String embeddingConfig =
50+
"{\"model_provider\":\"AMAZON\",\"model\":\"amazon.titan-embed-text-v2:0\",\"aws_region\": \"us-east-1\", \"api_key\":\"xxx\",\"secret_key\":\"xxx\",\"api_path\": \"https://aws.amazon.com/bedrock/amazon-models\", \"dimension\": "
51+
+ dimension
52+
+ ",\"vectorization_fields\":{\"data_vector\":\"data\"},\"plugin_name\":\"Embedding\"}";
53+
Map<String, Object> embeddingConfigMap =
54+
objectMapper.readValue(
55+
embeddingConfig, new TypeReference<Map<String, Object>>() {});
56+
ReadonlyConfig config = ReadonlyConfig.fromMap(embeddingConfigMap);
57+
EmbeddingTransform embeddingTransform = new EmbeddingTransform(config, inputCatalogTable);
58+
59+
Column[] columns = embeddingTransform.getOutputColumns();
60+
for (Column column : columns) {
61+
Assertions.assertEquals(dimension, column.getScale());
62+
}
63+
}
64+
}

0 commit comments

Comments
 (0)