Skip to content

Commit ccfff46

Browse files
committed
Add tests for the file streaming interface
1 parent fa7efb8 commit ccfff46

File tree

2 files changed

+103
-3
lines changed

2 files changed

+103
-3
lines changed

index.js

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,12 @@ class Replicate {
315315
* @yields {ServerSentEvent} Each streamed event from the prediction
316316
*/
317317
async *stream(ref, options) {
318-
const { wait, signal, useFileOutput = this.useFileOutput, ...data } = options;
318+
const {
319+
wait,
320+
signal,
321+
useFileOutput = this.useFileOutput,
322+
...data
323+
} = options;
319324

320325
const identifier = ModelVersionIdentifier.parse(ref);
321326

index.test.ts

Lines changed: 97 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1906,8 +1906,12 @@ describe("Replicate client", () => {
19061906
// Continue with tests for other methods
19071907

19081908
describe("createReadableStream", () => {
1909-
function createStream(body: string | ReadableStream, status = 200) {
1910-
const streamEndpoint = "https://stream.replicate.com/fake_stream";
1909+
function createStream(
1910+
body: string | ReadableStream,
1911+
status = 200,
1912+
streamEndpoint = "https://stream.replicate.com/fake_stream",
1913+
options: { useFileOutput?: boolean } = {}
1914+
) {
19111915
const fetch = jest.fn((url) => {
19121916
if (url !== streamEndpoint) {
19131917
throw new Error(`Unmocked call to fetch() with url: ${url}`);
@@ -1917,6 +1921,7 @@ describe("Replicate client", () => {
19171921
return createReadableStream({
19181922
url: streamEndpoint,
19191923
fetch: fetch as any,
1924+
options,
19201925
});
19211926
}
19221927

@@ -2193,5 +2198,95 @@ describe("Replicate client", () => {
21932198
);
21942199
expect(await iterator.next()).toEqual({ done: true });
21952200
});
2201+
2202+
describe("file streams", () => {
2203+
test("emits FileOutput objects", async () => {
2204+
const stream = createStream(
2205+
`
2206+
event: output
2207+
id: EVENT_1
2208+
data: 
2209+
2210+
event: output
2211+
id: EVENT_2
2212+
data: https://delivery.replicate.com/my_file.png
2213+
2214+
event: done
2215+
id: EVENT_3
2216+
data: {}
2217+
2218+
`.replace(/^[ ]+/gm, ""),
2219+
200,
2220+
"https://stream.replicate.com/v1/files/abcd"
2221+
);
2222+
2223+
const iterator = stream[Symbol.asyncIterator]();
2224+
const { value: event1 } = await iterator.next();
2225+
expect(event1.data).toBeInstanceOf(ReadableStream);
2226+
expect(event1.data.url().href).toEqual(
2227+
""
2228+
);
2229+
2230+
const { value: event2 } = await iterator.next();
2231+
expect(event2.data).toBeInstanceOf(ReadableStream);
2232+
expect(event2.data.url().href).toEqual(
2233+
"https://delivery.replicate.com/my_file.png"
2234+
);
2235+
2236+
expect(await iterator.next()).toEqual({
2237+
done: false,
2238+
value: { event: "done", id: "EVENT_3", data: "{}" },
2239+
});
2240+
2241+
expect(await iterator.next()).toEqual({ done: true });
2242+
});
2243+
2244+
test("emits strings when useFileOutput is false", async () => {
2245+
const stream = createStream(
2246+
`
2247+
event: output
2248+
id: EVENT_1
2249+
data: 
2250+
2251+
event: output
2252+
id: EVENT_2
2253+
data: https://delivery.replicate.com/my_file.png
2254+
2255+
event: done
2256+
id: EVENT_3
2257+
data: {}
2258+
2259+
`.replace(/^[ ]+/gm, ""),
2260+
200,
2261+
"https://stream.replicate.com/v1/files/abcd",
2262+
{ useFileOutput: false }
2263+
);
2264+
2265+
const iterator = stream[Symbol.asyncIterator]();
2266+
2267+
expect(await iterator.next()).toEqual({
2268+
done: false,
2269+
value: {
2270+
event: "output",
2271+
id: "EVENT_1",
2272+
data: "",
2273+
},
2274+
});
2275+
expect(await iterator.next()).toEqual({
2276+
done: false,
2277+
value: {
2278+
event: "output",
2279+
id: "EVENT_2",
2280+
data: "https://delivery.replicate.com/my_file.png",
2281+
},
2282+
});
2283+
expect(await iterator.next()).toEqual({
2284+
done: false,
2285+
value: { event: "done", id: "EVENT_3", data: "{}" },
2286+
});
2287+
2288+
expect(await iterator.next()).toEqual({ done: true });
2289+
});
2290+
});
21962291
});
21972292
});

0 commit comments

Comments
 (0)