Skip to content

Commit 93e7a7b

Browse files
authored
feat(rag-aiAgent) Add support for RAG embeddings to AI Agent (#227)
1 parent ea44492 commit 93e7a7b

File tree

2 files changed

+199
-9
lines changed

2 files changed

+199
-9
lines changed

src/main/java/io/kestra/plugin/ai/agent/AIAgent.java

Lines changed: 83 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
import dev.langchain4j.exception.ToolArgumentsException;
55
import dev.langchain4j.exception.ToolExecutionException;
66
import dev.langchain4j.rag.DefaultRetrievalAugmentor;
7+
import dev.langchain4j.rag.RetrievalAugmentor;
78
import dev.langchain4j.rag.content.retriever.ContentRetriever;
9+
import dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever;
810
import dev.langchain4j.rag.query.router.DefaultQueryRouter;
911
import dev.langchain4j.rag.query.router.QueryRouter;
1012
import dev.langchain4j.service.AiServices;
@@ -25,6 +27,7 @@
2527
import io.kestra.plugin.ai.AIUtils;
2628
import io.kestra.plugin.ai.domain.*;
2729
import io.kestra.plugin.ai.provider.TimingChatModelListener;
30+
import io.kestra.plugin.ai.rag.ChatCompletion;
2831
import io.swagger.v3.oas.annotations.media.Schema;
2932
import jakarta.validation.constraints.NotNull;
3033
import lombok.*;
@@ -35,6 +38,8 @@
3538
import java.util.HashMap;
3639
import java.util.List;
3740
import java.util.Map;
41+
import java.util.Optional;
42+
import java.util.stream.Collectors;
3843

3944
import static io.kestra.core.utils.Rethrow.throwFunction;
4045

@@ -646,6 +651,26 @@ public class AIAgent extends Task implements RunnableTask<AIOutput>, OutputFiles
646651
@Builder.Default
647652
private ChatConfiguration configuration = ChatConfiguration.empty();
648653

654+
@Schema(
655+
title = "Embedding store",
656+
description = "Optional when at least one entry is provided in `contentRetrievers`."
657+
)
658+
@PluginProperty
659+
private EmbeddingStoreProvider embeddings;
660+
661+
@Schema(
662+
title = "Embedding model provider",
663+
description = "Optional. If not set, the embedding model is created from `chatProvider`. Ensure the chosen chat provider supports embeddings."
664+
)
665+
@PluginProperty
666+
private ModelProvider embeddingProvider;
667+
668+
@Schema(title = "Content retriever configuration")
669+
@NotNull
670+
@PluginProperty
671+
@Builder.Default
672+
private ChatCompletion.ContentRetrieverConfiguration contentRetrieverConfiguration = ChatCompletion.ContentRetrieverConfiguration.builder().build();
673+
649674
@Schema(title = "Tools that the LLM may use to augment its response")
650675
private List<ToolProvider> tools;
651676

@@ -689,20 +714,20 @@ public AIOutput run(RunContext runContext) throws Exception {
689714
throw new ToolExecutionException(error);
690715
});
691716

717+
// Attach retrieval augmentor (either explicit, embedding-based, or tools-based)
718+
attachRetrievalAugmentor(agent, runContext);
719+
692720
if (memory != null) {
693721
agent.chatMemory(memory.chatMemory(runContext));
694722
}
695723

696-
List<ContentRetriever> toolContentRetrievers = runContext.render(contentRetrievers).asList(ContentRetrieverProvider.class).stream()
697-
.map(throwFunction(provider -> provider.contentRetriever(runContext)))
698-
.toList();
699-
if (!toolContentRetrievers.isEmpty()) {
700-
QueryRouter queryRouter = new DefaultQueryRouter(toolContentRetrievers.toArray(new ContentRetriever[0]));
701-
702-
// Create a query router that will route each query to the content retrievers
724+
List<ContentRetriever> retrievers = buildToolRetrievers(runContext);
725+
if (!retrievers.isEmpty()) {
726+
QueryRouter router = new DefaultQueryRouter(retrievers.toArray(new ContentRetriever[0]));
703727
agent.retrievalAugmentor(DefaultRetrievalAugmentor.builder()
704-
.queryRouter(queryRouter)
705-
.build());
728+
.queryRouter(router)
729+
.build()
730+
);
706731
}
707732

708733
Result<AiMessage> completion = agent.build().invoke(rPrompt);
@@ -735,6 +760,55 @@ private Map<String, URI> gatherOutputFiles(RunContext runContext) throws Excepti
735760
return outputFiles;
736761
}
737762

763+
private void attachRetrievalAugmentor(AiServices<Agent> agent, RunContext runContext) throws Exception {
764+
765+
Optional<ContentRetriever> embeddingRetriever = buildEmbeddingRetriever(runContext);
766+
List<ContentRetriever> toolRetrievers = buildToolRetrievers(runContext);
767+
// No retrievers at all nothing to attach
768+
if (toolRetrievers.isEmpty() && embeddingRetriever.isEmpty()) {
769+
return;
770+
}
771+
// Case 1: Only embedding retriever
772+
if (toolRetrievers.isEmpty()) {
773+
agent.retrievalAugmentor(DefaultRetrievalAugmentor.builder()
774+
.contentRetriever(embeddingRetriever.orElseThrow())
775+
.build()
776+
);
777+
return;
778+
}
779+
// Case 2: Tools exist add embedding retriever first if present
780+
embeddingRetriever.ifPresent(toolRetrievers::addFirst);
781+
QueryRouter router = new DefaultQueryRouter(toolRetrievers.toArray(new ContentRetriever[0]));
782+
agent.retrievalAugmentor(
783+
DefaultRetrievalAugmentor.builder()
784+
.queryRouter(router)
785+
.build()
786+
);
787+
}
788+
789+
private List<ContentRetriever> buildToolRetrievers(RunContext runContext) throws Exception {
790+
return runContext.render(contentRetrievers)
791+
.asList(ContentRetrieverProvider.class)
792+
.stream()
793+
.map(throwFunction(provider -> provider.contentRetriever(runContext)))
794+
.toList();
795+
}
796+
797+
private Optional<ContentRetriever> buildEmbeddingRetriever(final RunContext runContext) throws Exception {
798+
if (embeddings == null) return Optional.empty();
799+
var model = Optional.ofNullable(embeddingProvider).orElse(provider).embeddingModel(runContext);
800+
return Optional.of(
801+
EmbeddingStoreContentRetriever.builder()
802+
.embeddingModel(model)
803+
.embeddingStore(
804+
embeddings.embeddingStore(runContext, model.dimension(), false)
805+
)
806+
.maxResults(contentRetrieverConfiguration.getMaxResults())
807+
.minScore(contentRetrieverConfiguration.getMinScore())
808+
.build()
809+
);
810+
}
811+
738812
interface Agent {
739813
Result<AiMessage> invoke(String userMessage);
740814
}

src/test/java/io/kestra/plugin/ai/agent/AIAgentTest.java

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,14 @@
99
import io.kestra.core.utils.IdUtils;
1010
import io.kestra.plugin.ai.domain.ChatConfiguration;
1111
import io.kestra.plugin.ai.memory.KestraKVStore;
12+
import io.kestra.plugin.ai.provider.GoogleGemini;
1213
import io.kestra.plugin.ai.provider.OpenAI;
14+
import io.kestra.plugin.ai.rag.IngestDocument;
1315
import io.kestra.plugin.ai.tool.DockerMcpClient;
1416
import io.kestra.plugin.ai.tool.StdioMcpClient;
1517
import jakarta.inject.Inject;
1618
import org.junit.jupiter.api.Test;
19+
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
1720

1821
import java.util.List;
1922
import java.util.Map;
@@ -22,6 +25,7 @@
2225

2326
@KestraTest
2427
class AIAgentTest {
28+
private final String GOOGLE_API_KEY = System.getenv("GOOGLE_API_KEY");
2529
@Inject
2630
private TestRunContextFactory runContextFactory;
2731

@@ -170,4 +174,116 @@ void withOutputFiles() throws Exception {
170174
assertThat(content).isEqualTo("Hello World");
171175
}
172176
}
177+
178+
@EnabledIfEnvironmentVariable(named = "GOOGLE_API_KEY", matches = ".*")
179+
@Test
180+
void withEmbeddingRetriever() throws Exception {
181+
RunContext runContext = runContextFactory.of("namespace", Map.of(
182+
"modelName", "gemini-2.0-flash",
183+
"apiKey", GOOGLE_API_KEY
184+
));
185+
// First ingest some documents
186+
var ingest = IngestDocument.builder()
187+
.provider(
188+
GoogleGemini.builder()
189+
.type(GoogleGemini.class.getName())
190+
.modelName(Property.ofValue("gemini-embedding-exp-03-07"))
191+
.apiKey(Property.ofExpression("{{ apiKey }}"))
192+
.build()
193+
)
194+
.embeddings(io.kestra.plugin.ai.embeddings.KestraKVStore.builder().build())
195+
.fromDocuments(List.of(
196+
IngestDocument.InlineDocument.builder()
197+
.content(Property.ofValue("Paris is the capital of France with a population of over 2.1 million people"))
198+
.build(),
199+
IngestDocument.InlineDocument.builder()
200+
.content(Property.ofValue("The Eiffel Tower is the most famous landmark in Paris at 330 meters tall"))
201+
.build()
202+
))
203+
.build();
204+
205+
IngestDocument.Output ingestOutput = ingest.run(runContext);
206+
assertThat(ingestOutput.getIngestedDocuments()).isEqualTo(2);
207+
208+
var agent = AIAgent.builder()
209+
.provider(
210+
GoogleGemini.builder()
211+
.type(GoogleGemini.class.getName())
212+
.modelName(Property.ofExpression("{{ modelName }}"))
213+
.apiKey(Property.ofExpression("{{ apiKey }}"))
214+
.build()
215+
)
216+
.embeddingProvider(
217+
GoogleGemini.builder()
218+
.type(GoogleGemini.class.getName())
219+
.modelName(Property.ofValue("gemini-embedding-exp-03-07"))
220+
.apiKey(Property.ofExpression("{{ apiKey }}"))
221+
.build()
222+
).embeddings(io.kestra.plugin.ai.embeddings.KestraKVStore.builder().build())
223+
.prompt(Property.ofValue("What is the capital of France and how many people live there?"))
224+
.configuration(ChatConfiguration.builder().temperature(Property.ofValue(0.1)).seed(Property.ofValue(123456789)).build())
225+
.build();
226+
227+
var output = agent.run(runContext);
228+
assertThat(output.getTextOutput()).isNotNull();
229+
assertThat(output.getTextOutput()).contains("Paris");
230+
}
231+
232+
@EnabledIfEnvironmentVariable(named = "GOOGLE_API_KEY", matches = ".*")
233+
@Test
234+
void withEmbeddingRetriever_andWithTool() throws Exception {
235+
RunContext runContext = runContextFactory.of("namespace", Map.of(
236+
"modelName", "gemini-2.0-flash",
237+
"apiKey", GOOGLE_API_KEY
238+
));
239+
// First ingest some documents
240+
var ingest = IngestDocument.builder()
241+
.provider(
242+
GoogleGemini.builder()
243+
.type(GoogleGemini.class.getName())
244+
.modelName(Property.ofValue("gemini-embedding-exp-03-07"))
245+
.apiKey(Property.ofExpression("{{ apiKey }}"))
246+
.build()
247+
)
248+
.embeddings(io.kestra.plugin.ai.embeddings.KestraKVStore.builder().build())
249+
.fromDocuments(List.of(
250+
IngestDocument.InlineDocument.builder()
251+
.content(Property.ofValue("Paris is the capital of France with a population of over 2.1 million people"))
252+
.build(),
253+
IngestDocument.InlineDocument.builder()
254+
.content(Property.ofValue("The Eiffel Tower is the most famous landmark in Paris at 330 meters tall"))
255+
.build()
256+
))
257+
.build();
258+
259+
IngestDocument.Output ingestOutput = ingest.run(runContext);
260+
assertThat(ingestOutput.getIngestedDocuments()).isEqualTo(2);
261+
262+
var agent = AIAgent.builder()
263+
.provider(
264+
GoogleGemini.builder()
265+
.type(GoogleGemini.class.getName())
266+
.modelName(Property.ofExpression("{{ modelName }}"))
267+
.apiKey(Property.ofExpression("{{ apiKey }}"))
268+
.build()
269+
)
270+
.embeddingProvider(
271+
GoogleGemini.builder()
272+
.type(GoogleGemini.class.getName())
273+
.modelName(Property.ofValue("gemini-embedding-exp-03-07"))
274+
.apiKey(Property.ofExpression("{{ apiKey }}"))
275+
.build()
276+
).embeddings(io.kestra.plugin.ai.embeddings.KestraKVStore.builder().build())
277+
.tools(
278+
List.of(StdioMcpClient.builder().command(Property.ofValue(List.of("docker", "run", "--rm", "-i", "mcp/everything"))).build())
279+
)
280+
.prompt(Property.ofValue("What is the capital of France and how many people live there?"))
281+
.configuration(ChatConfiguration.builder().temperature(Property.ofValue(0.1)).seed(Property.ofValue(123456789)).build())
282+
.build();
283+
284+
var output = agent.run(runContext);
285+
assertThat(output.getTextOutput()).isNotNull();
286+
assertThat(output.getTextOutput()).contains("Paris");
287+
}
288+
173289
}

0 commit comments

Comments
 (0)