Skip to content

Commit 842f934

Browse files
feat(73): added support for prompts (#87)
* feat(73): added support for prompts * refactors * make enable & disable backward-compatible * register: display any prompts available * prompt should be mindful of sse or streamable http proxy based on server transport * Add telemetry for prompts * fix command orders & tests * enhance comments & doc --------- Co-authored-by: Raghav <[email protected]>
1 parent 0e86097 commit 842f934

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

52 files changed

+2128
-134
lines changed

client/admin_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ func TestInitServer(t *testing.T) {
1818

1919
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
2020
// Verify request method and path
21-
if r.Method != "POST" {
21+
if r.Method != http.MethodPost {
2222
t.Errorf("Expected POST method, got %s", r.Method)
2323
}
2424
if r.URL.Path != "/init" {

client/client_test.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -85,12 +85,12 @@ func TestNewRequest(t *testing.T) {
8585

8686
t.Run("request with access token", func(t *testing.T) {
8787
body := strings.NewReader("test body")
88-
req, err := client.newRequest("POST", "https://api.example.com/test", body)
88+
req, err := client.newRequest(http.MethodPost, "https://api.example.com/test", body)
8989
if err != nil {
9090
t.Fatalf("Unexpected error: %v", err)
9191
}
9292

93-
if req.Method != "POST" {
93+
if req.Method != http.MethodPost {
9494
t.Errorf("Expected method POST, got %s", req.Method)
9595
}
9696

@@ -106,7 +106,7 @@ func TestNewRequest(t *testing.T) {
106106

107107
t.Run("request without access token", func(t *testing.T) {
108108
clientNoToken := NewClient("https://api.example.com", "", &http.Client{})
109-
req, err := clientNoToken.newRequest("GET", "https://api.example.com/test", nil)
109+
req, err := clientNoToken.newRequest(http.MethodGet, "https://api.example.com/test", nil)
110110
if err != nil {
111111
t.Fatalf("Unexpected error: %v", err)
112112
}
@@ -118,7 +118,7 @@ func TestNewRequest(t *testing.T) {
118118
})
119119

120120
t.Run("request with nil body", func(t *testing.T) {
121-
req, err := client.newRequest("GET", "https://api.example.com/test", nil)
121+
req, err := client.newRequest(http.MethodGet, "https://api.example.com/test", nil)
122122
if err != nil {
123123
t.Fatalf("Unexpected error: %v", err)
124124
}
@@ -135,7 +135,7 @@ func TestNewRequestWithInvalidURL(t *testing.T) {
135135
client := NewClient("https://api.example.com", "token", &http.Client{})
136136

137137
// Test with invalid URL
138-
req, err := client.newRequest("GET", "://invalid-url", nil)
138+
req, err := client.newRequest(http.MethodGet, "://invalid-url", nil)
139139
if err == nil {
140140
t.Error("Expected error for invalid URL, got nil")
141141
}
@@ -150,7 +150,7 @@ func TestClientIntegration(t *testing.T) {
150150
// Test that client can make actual HTTP requests
151151
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
152152
// Verify request method and headers
153-
if r.Method != "GET" {
153+
if r.Method != http.MethodGet {
154154
t.Errorf("Expected GET method, got %s", r.Method)
155155
}
156156

@@ -174,7 +174,7 @@ func TestClientIntegration(t *testing.T) {
174174
t.Fatalf("Failed to construct endpoint: %v", err)
175175
}
176176

177-
req, err := client.newRequest("GET", endpoint, nil)
177+
req, err := client.newRequest(http.MethodGet, endpoint, nil)
178178
if err != nil {
179179
t.Fatalf("Failed to create request: %v", err)
180180
}

client/mcp_clients_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ func TestListMcpClients(t *testing.T) {
2929

3030
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
3131
// Verify request method and path
32-
if r.Method != "GET" {
32+
if r.Method != http.MethodGet {
3333
t.Errorf("Expected GET method, got %s", r.Method)
3434
}
3535
if !strings.HasSuffix(r.URL.Path, "/clients") {
@@ -204,7 +204,7 @@ func TestCreateMcpClient(t *testing.T) {
204204

205205
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
206206
// Verify request method and path
207-
if r.Method != "POST" {
207+
if r.Method != http.MethodPost {
208208
t.Errorf("Expected POST method, got %s", r.Method)
209209
}
210210
if !strings.HasSuffix(r.URL.Path, "/clients") {

client/mcp_prompts.go

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
package client
2+
3+
import (
4+
"bytes"
5+
"encoding/json"
6+
"fmt"
7+
"io"
8+
"net/http"
9+
"net/url"
10+
11+
"github.com/mcpjungle/mcpjungle/internal/model"
12+
"github.com/mcpjungle/mcpjungle/pkg/types"
13+
)
14+
15+
// ListPrompts retrieves all prompts or prompts filtered by server name
16+
func (c *Client) ListPrompts(serverName string) ([]model.Prompt, error) {
17+
u, err := c.constructAPIEndpoint("/prompts")
18+
if err != nil {
19+
return nil, fmt.Errorf("failed to construct API endpoint: %w", err)
20+
}
21+
22+
// Add server filter if specified
23+
if serverName != "" {
24+
parsed, _ := url.Parse(u)
25+
q := parsed.Query()
26+
q.Set("server", serverName)
27+
parsed.RawQuery = q.Encode()
28+
u = parsed.String()
29+
}
30+
31+
req, err := c.newRequest(http.MethodGet, u, nil)
32+
if err != nil {
33+
return nil, fmt.Errorf("failed to create request: %w", err)
34+
}
35+
36+
resp, err := c.httpClient.Do(req)
37+
if err != nil {
38+
return nil, fmt.Errorf("failed to send request: %w", err)
39+
}
40+
defer resp.Body.Close()
41+
42+
if resp.StatusCode != http.StatusOK {
43+
return nil, c.parseErrorResponse(resp)
44+
}
45+
46+
var prompts []model.Prompt
47+
if err := json.NewDecoder(resp.Body).Decode(&prompts); err != nil {
48+
return nil, fmt.Errorf("failed to decode response: %w", err)
49+
}
50+
51+
return prompts, nil
52+
}
53+
54+
// GetPrompt retrieves a specific prompt by name
55+
func (c *Client) GetPrompt(name string) (*model.Prompt, error) {
56+
u, err := c.constructAPIEndpoint("/prompt")
57+
if err != nil {
58+
return nil, fmt.Errorf("failed to construct API endpoint: %w", err)
59+
}
60+
61+
// Add name as query parameter
62+
parsed, _ := url.Parse(u)
63+
q := parsed.Query()
64+
q.Set("name", name)
65+
parsed.RawQuery = q.Encode()
66+
u = parsed.String()
67+
68+
req, err := c.newRequest(http.MethodGet, u, nil)
69+
if err != nil {
70+
return nil, fmt.Errorf("failed to create request: %w", err)
71+
}
72+
73+
resp, err := c.httpClient.Do(req)
74+
if err != nil {
75+
return nil, fmt.Errorf("failed to send request: %w", err)
76+
}
77+
defer resp.Body.Close()
78+
79+
if resp.StatusCode != http.StatusOK {
80+
return nil, c.parseErrorResponse(resp)
81+
}
82+
83+
var prompt model.Prompt
84+
if err := json.NewDecoder(resp.Body).Decode(&prompt); err != nil {
85+
return nil, fmt.Errorf("failed to decode response: %w", err)
86+
}
87+
88+
return &prompt, nil
89+
}
90+
91+
// GetPromptWithArgs retrieves a prompt with arguments and returns the rendered template
92+
func (c *Client) GetPromptWithArgs(name string, arguments map[string]string) (*types.PromptResult, error) {
93+
u, err := c.constructAPIEndpoint("/prompts/render")
94+
if err != nil {
95+
return nil, fmt.Errorf("failed to construct API endpoint: %w", err)
96+
}
97+
98+
request := types.PromptGetRequest{
99+
Name: name,
100+
Arguments: arguments,
101+
}
102+
103+
body, err := json.Marshal(request)
104+
if err != nil {
105+
return nil, fmt.Errorf("failed to marshal request: %w", err)
106+
}
107+
108+
req, err := c.newRequest(http.MethodPost, u, bytes.NewBuffer(body))
109+
if err != nil {
110+
return nil, fmt.Errorf("failed to create request: %w", err)
111+
}
112+
req.Header.Set("Content-Type", "application/json")
113+
114+
resp, err := c.httpClient.Do(req)
115+
if err != nil {
116+
return nil, fmt.Errorf("failed to send request: %w", err)
117+
}
118+
defer resp.Body.Close()
119+
120+
if resp.StatusCode != http.StatusOK {
121+
return nil, c.parseErrorResponse(resp)
122+
}
123+
124+
var result types.PromptResult
125+
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
126+
return nil, fmt.Errorf("failed to decode response: %w", err)
127+
}
128+
129+
return &result, nil
130+
}
131+
132+
// EnablePrompts enables one or more prompts
133+
func (c *Client) EnablePrompts(entity string) ([]string, error) {
134+
return c.setPromptsEnabled(entity, true)
135+
}
136+
137+
// DisablePrompts disables one or more prompts
138+
func (c *Client) DisablePrompts(entity string) ([]string, error) {
139+
return c.setPromptsEnabled(entity, false)
140+
}
141+
142+
// setPromptsEnabled is a helper function to enable or disable prompts
143+
func (c *Client) setPromptsEnabled(entity string, enabled bool) ([]string, error) {
144+
var endpoint string
145+
if enabled {
146+
endpoint = "/prompts/enable"
147+
} else {
148+
endpoint = "/prompts/disable"
149+
}
150+
151+
u, err := c.constructAPIEndpoint(endpoint)
152+
if err != nil {
153+
return nil, fmt.Errorf("failed to construct API endpoint: %w", err)
154+
}
155+
156+
// Add entity as query parameter
157+
parsed, _ := url.Parse(u)
158+
q := parsed.Query()
159+
q.Set("entity", entity)
160+
parsed.RawQuery = q.Encode()
161+
u = parsed.String()
162+
163+
req, err := c.newRequest(http.MethodPost, u, nil)
164+
if err != nil {
165+
return nil, fmt.Errorf("failed to create request: %w", err)
166+
}
167+
168+
resp, err := c.httpClient.Do(req)
169+
if err != nil {
170+
return nil, fmt.Errorf("failed to send request: %w", err)
171+
}
172+
defer resp.Body.Close()
173+
174+
if resp.StatusCode != http.StatusOK {
175+
body, _ := io.ReadAll(resp.Body)
176+
action := "enable"
177+
if !enabled {
178+
action = "disable"
179+
}
180+
return nil, fmt.Errorf("failed to %s prompts: status %d, message: %s", action, resp.StatusCode, body)
181+
}
182+
183+
var promptNames []string
184+
if err := json.NewDecoder(resp.Body).Decode(&promptNames); err != nil {
185+
return nil, fmt.Errorf("failed to decode response: %w", err)
186+
}
187+
188+
return promptNames, nil
189+
}

0 commit comments

Comments
 (0)