Skip to content

Commit 683f9a9

Browse files
authored
feat(ai_agent): introduce support for RAG embeddings (#228)
* updated junits with content retriever * feat(rag-aiAgent) Add support for RAG embeddings to AI Agent
1 parent 1e0249d commit 683f9a9

File tree

6 files changed

+544
-4
lines changed

6 files changed

+544
-4
lines changed

src/main/java/io/kestra/plugin/ai/AIUtils.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ private AIUtils() {
2323
// utility class pattern
2424
}
2525

26-
public static Map<ToolSpecification, ToolExecutor> buildTools(RunContext runContext, Map<String, Object> additionalVariables, List<ToolProvider> toolProviders) throws IllegalVariableEvaluationException {
26+
public static Map<ToolSpecification, ToolExecutor> buildTools(RunContext runContext, Map<String, Object> additionalVariables, List<ToolProvider> toolProviders) throws Exception {
2727
if (toolProviders.isEmpty()) {
2828
return Collections.emptyMap();
2929
}

src/main/java/io/kestra/plugin/ai/domain/ContentRetrieverProvider.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
import lombok.NoArgsConstructor;
1212
import lombok.experimental.SuperBuilder;
1313

14+
import java.io.IOException;
15+
1416
@Plugin
1517
@SuperBuilder(toBuilder = true)
1618
@Getter
@@ -19,5 +21,5 @@
1921
// AND concrete subclasses must be annotated by @JsonDeserialize() to avoid StackOverflow.
2022
@JsonDeserialize(using = PluginDeserializer.class)
2123
public abstract class ContentRetrieverProvider extends AdditionalPlugin {
22-
public abstract ContentRetriever contentRetriever(RunContext runContext) throws IllegalVariableEvaluationException;
24+
public abstract ContentRetriever contentRetriever(RunContext runContext) throws IllegalVariableEvaluationException, IOException;
2325
}

src/main/java/io/kestra/plugin/ai/domain/ToolProvider.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
// AND concrete subclasses must be annotated by @JsonDeserialize() to avoid StackOverflow.
3030
@JsonDeserialize(using = PluginDeserializer.class)
3131
public abstract class ToolProvider extends AdditionalPlugin {
32-
public abstract Map<ToolSpecification, ToolExecutor> tool(RunContext runContext, Map<String, Object> additionalVariables) throws IllegalVariableEvaluationException;
32+
public abstract Map<ToolSpecification, ToolExecutor> tool(RunContext runContext, Map<String, Object> additionalVariables) throws Exception;
3333

3434
public void close(RunContext runContext) {
3535
// by default: no-op
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
package io.kestra.plugin.ai.retriever;
2+
3+
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
4+
import dev.langchain4j.rag.content.retriever.ContentRetriever;
5+
import dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever;
6+
import io.kestra.core.exceptions.IllegalVariableEvaluationException;
7+
import io.kestra.core.models.annotations.Example;
8+
import io.kestra.core.models.annotations.Plugin;
9+
import io.kestra.core.models.annotations.PluginProperty;
10+
import io.kestra.core.models.property.Property;
11+
import io.kestra.core.runners.RunContext;
12+
import io.kestra.plugin.ai.domain.ContentRetrieverProvider;
13+
import io.kestra.plugin.ai.domain.EmbeddingStoreProvider;
14+
import io.kestra.plugin.ai.domain.ModelProvider;
15+
import io.swagger.v3.oas.annotations.media.Schema;
16+
import jakarta.validation.constraints.NotNull;
17+
import lombok.AllArgsConstructor;
18+
import lombok.Builder;
19+
import lombok.Getter;
20+
import lombok.NoArgsConstructor;
21+
import lombok.experimental.SuperBuilder;
22+
23+
import java.io.IOException;
24+
25+
@Getter
26+
@SuperBuilder
27+
@NoArgsConstructor
28+
@AllArgsConstructor
29+
@JsonDeserialize
30+
@Schema(
31+
title = "Embedding store content retriever for RAG (Retrieval Augmented Generation)",
32+
description = "Retrieves relevant content from an embedding store based on semantic similarity to the query."
33+
)
34+
@Plugin(
35+
examples = {
36+
@Example(
37+
title = "Use RAG with AIAgent using an embedding store content retriever. This example ingests documents into a KV embedding store and then uses an AI agent with the EmbeddingStoreRetriever to answer questions grounded in the ingested data.",
38+
code = """
39+
id: agent_with_rag
40+
namespace: company.ai
41+
42+
tasks:
43+
- id: ingest
44+
type: io.kestra.plugin.ai.rag.IngestDocument
45+
provider:
46+
type: io.kestra.plugin.ai.provider.GoogleGemini
47+
modelName: gemini-embedding-exp-03-07
48+
googleApiKey: "{{ kv('GEMINI_API_KEY') }}"
49+
embeddings:
50+
type: io.kestra.plugin.ai.embeddings.KestraKVStore
51+
drop: true
52+
fromDocuments:
53+
- content: Paris is the capital of France with a population of over 2.1 million people
54+
- content: The Eiffel Tower is the most famous landmark in Paris at 330 meters tall
55+
56+
- id: agent
57+
type: io.kestra.plugin.ai.agent.AIAgent
58+
provider:
59+
type: io.kestra.plugin.ai.provider.GoogleGemini
60+
modelName: gemini-2.0-flash
61+
googleApiKey: "{{ kv('GEMINI_API_KEY') }}"
62+
contentRetrievers:
63+
- type: io.kestra.plugin.ai.retriever.EmbeddingStoreRetriever
64+
embeddings:
65+
type: io.kestra.plugin.ai.embeddings.KestraKVStore
66+
embeddingProvider:
67+
type: io.kestra.plugin.ai.provider.GoogleGemini
68+
modelName: gemini-embedding-exp-03-07
69+
googleApiKey: "{{ kv('GEMINI_API_KEY') }}"
70+
maxResults: 3
71+
minScore: 0.0
72+
prompt: What is the capital of France and how many people live there?
73+
"""
74+
),
75+
@Example(
76+
title = "Use multiple embedding stores simultaneously. This demonstrates the power of the content retriever approach - you can retrieve from multiple embedding stores and other sources in a single task.",
77+
code = """
78+
id: multi_store_rag
79+
namespace: company.ai
80+
81+
tasks:
82+
- id: agent
83+
type: io.kestra.plugin.ai.agent.AIAgent
84+
provider:
85+
type: io.kestra.plugin.ai.provider.GoogleGemini
86+
modelName: gemini-2.0-flash
87+
googleApiKey: "{{ kv('GEMINI_API_KEY') }}"
88+
contentRetrievers:
89+
- type: io.kestra.plugin.ai.retriever.EmbeddingStoreRetriever
90+
embeddings:
91+
type: io.kestra.plugin.ai.embeddings.Pinecone
92+
pineconeApiKey: "{{ kv('PINECONE_API_KEY') }}"
93+
index: technical-docs
94+
embeddingProvider:
95+
type: io.kestra.plugin.ai.provider.OpenAI
96+
apiKey: "{{ kv('OPENAI_API_KEY') }}"
97+
modelName: text-embedding-3-small
98+
- type: io.kestra.plugin.ai.retriever.EmbeddingStoreRetriever
99+
embeddings:
100+
type: io.kestra.plugin.ai.embeddings.Qdrant
101+
host: localhost
102+
port: 6333
103+
collectionName: business-docs
104+
embeddingProvider:
105+
type: io.kestra.plugin.ai.provider.GoogleGemini
106+
modelName: gemini-embedding-exp-03-07
107+
googleApiKey: "{{ kv('GEMINI_API_KEY') }}"
108+
- type: io.kestra.plugin.ai.retriever.TavilyWebSearch
109+
tavilyApiKey: "{{ kv('TAVILY_API_KEY') }}"
110+
prompt: What are the latest trends in data orchestration?
111+
"""
112+
)
113+
}
114+
)
115+
public class EmbeddingStoreRetriever extends ContentRetrieverProvider {
116+
117+
@Schema(
118+
title = "Embedding store",
119+
description = "The embedding store to retrieve relevant content from"
120+
)
121+
@NotNull
122+
@PluginProperty
123+
private EmbeddingStoreProvider embeddings;
124+
125+
@Schema(
126+
title = "Embedding model provider",
127+
description = "Provider used to generate embeddings for the query. Must support embedding generation."
128+
)
129+
@NotNull
130+
@PluginProperty
131+
private ModelProvider embeddingProvider;
132+
133+
@Schema(title = "Maximum number of results to return from the embedding store")
134+
@NotNull
135+
@Builder.Default
136+
private Property<Integer> maxResults = Property.ofValue(3);
137+
138+
@Schema(
139+
title = "Minimum similarity score",
140+
description = "Only results with a similarity score ≥ minScore are returned. Range: 0.0 to 1.0 inclusive."
141+
)
142+
@NotNull
143+
@Builder.Default
144+
private Property<Double> minScore = Property.ofValue(0.0);
145+
146+
@Override
147+
public ContentRetriever contentRetriever(RunContext runContext) throws IllegalVariableEvaluationException, IOException {
148+
var embeddingModel = embeddingProvider.embeddingModel(runContext);
149+
150+
return EmbeddingStoreContentRetriever.builder()
151+
.embeddingModel(embeddingModel)
152+
.embeddingStore(embeddings.embeddingStore(runContext, embeddingModel.dimension(), false))
153+
.maxResults(runContext.render(this.maxResults).as(Integer.class).orElse(3))
154+
.minScore(runContext.render(this.minScore).as(Double.class).orElse(0.0))
155+
.build();
156+
}
157+
}

src/main/java/io/kestra/plugin/ai/tool/AIAgent.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ public class AIAgent extends ToolProvider {
131131
private transient List<ToolProvider> toolProviders;
132132

133133
@Override
134-
public Map<ToolSpecification, ToolExecutor> tool(RunContext runContext, Map<String, Object> additionalVariables) throws IllegalVariableEvaluationException {
134+
public Map<ToolSpecification, ToolExecutor> tool(RunContext runContext, Map<String, Object> additionalVariables) throws Exception {
135135
toolProviders = ListUtils.emptyOnNull(tools);
136136

137137
AiServices<AgentTool> agent = AiServices.builder(AgentTool.class)

0 commit comments

Comments
 (0)