1 change: 1 addition & 0 deletions src/api/providers/__tests__/deepinfra.spec.ts
@@ -26,6 +26,7 @@ vitest.mock("../fetchers/modelCache", () => ({
getModels: vitest.fn().mockResolvedValue({
[deepInfraDefaultModelId]: deepInfraDefaultModelInfo,
}),
getModelsFromCache: vitest.fn().mockReturnValue(undefined),
}))

import OpenAI from "openai"
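For context on why each of these specs gains a `getModelsFromCache` stub: `RouterProvider.getModel()` now calls it synchronously (see router-provider.ts below), so any spec that mocks the modelCache module needs both exports. A minimal sketch of the resulting mock shape, using the deepinfra spec's fixtures (the other specs differ only in their model fixtures):

```typescript
// Sketch of the shared mock shape; deepInfraDefaultModelId / deepInfraDefaultModelInfo
// come from the real provider module imported by the spec.
vitest.mock("../fetchers/modelCache", () => ({
	// Async path used when the handler fetches models.
	getModels: vitest.fn().mockResolvedValue({
		[deepInfraDefaultModelId]: deepInfraDefaultModelInfo,
	}),
	// New synchronous path used by RouterProvider.getModel(); returning undefined
	// makes the handler fall back to its default model info in these tests.
	getModelsFromCache: vitest.fn().mockReturnValue(undefined),
}))
```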
1 change: 1 addition & 0 deletions src/api/providers/__tests__/glama.spec.ts
@@ -32,6 +32,7 @@ vitest.mock("../fetchers/modelCache", () => ({
},
})
}),
getModelsFromCache: vitest.fn().mockReturnValue(undefined),
}))

// Mock OpenAI client
// Mock OpenAI client
1 change: 1 addition & 0 deletions src/api/providers/__tests__/lite-llm.spec.ts
@@ -42,6 +42,7 @@ vi.mock("../fetchers/modelCache", () => ({
"gpt-4-turbo": { ...litellmDefaultModelInfo, maxTokens: 8192 },
})
}),
getModelsFromCache: vi.fn().mockReturnValue(undefined),
}))

describe("LiteLLMHandler", () => {
1 change: 1 addition & 0 deletions src/api/providers/__tests__/unbound.spec.ts
@@ -71,6 +71,7 @@ vitest.mock("../fetchers/modelCache", () => ({
},
})
}),
getModelsFromCache: vitest.fn().mockReturnValue(undefined),
}))

// Mock OpenAI client
// Mock OpenAI client
1 change: 1 addition & 0 deletions src/api/providers/__tests__/vercel-ai-gateway.spec.ts
@@ -51,6 +51,7 @@ vitest.mock("../fetchers/modelCache", () => ({
},
})
}),
getModelsFromCache: vitest.fn().mockReturnValue(undefined),
}))

vitest.mock("../../transform/caching/vercel-ai-gateway", () => ({
2 changes: 1 addition & 1 deletion src/api/providers/fetchers/lmstudio.ts
@@ -19,7 +19,7 @@ export const forceFullModelDetailsLoad = async (baseUrl: string, modelId: string
const client = new LMStudioClient({ baseUrl: lmsUrl })
await client.llm.model(modelId)
// Flush and refresh cache to get updated model details
await flushModels("lmstudio", true)
await flushModels({ provider: "lmstudio", baseUrl }, true)

// Mark this model as having full details loaded.
modelsWithLoadedDetails.add(modelId)
12 changes: 6 additions & 6 deletions src/api/providers/fetchers/modelCache.ts
@@ -272,20 +272,20 @@ export async function initializeModelCacheRefresh(): Promise<void> {
/**
* Flush models memory cache for a specific router.
*
* @param router - The router to flush models for.
* @param options - The options for fetching models, including provider, apiKey, and baseUrl
* @param refresh - If true, immediately fetch fresh data from API
*/
export const flushModels = async (router: RouterName, refresh: boolean = false): Promise<void> => {
export const flushModels = async (options: GetModelsOptions, refresh: boolean = false): Promise<void> => {
const { provider } = options
if (refresh) {
// Don't delete memory cache - let refreshModels atomically replace it
// This prevents a race condition where getModels() might be called
// before refresh completes, avoiding a gap in cache availability
refreshModels({ provider: router } as GetModelsOptions).catch((error) => {
console.error(`[flushModels] Refresh failed for ${router}:`, error)
})
// Await the refresh to ensure the cache is updated before returning
await refreshModels(options)
} else {
// Only delete memory cache when not refreshing
memoryCache.del(router)
memoryCache.del(provider)
}
}

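Call sites now pass the full options object so a refresh can re-fetch with the caller's credentials instead of a bare router name. A minimal sketch of the two modes, mirroring the litellm and flushRouterModels call sites later in this diff (variable names taken from those handlers):

```typescript
// Refresh: awaits refreshModels(options) so the cache is repopulated before returning.
await flushModels({ provider: "litellm", apiKey: litellmApiKey, baseUrl: litellmBaseUrl }, true)

// Plain invalidation: only the memory-cache entry keyed by `provider` is dropped.
// The cast mirrors the flushRouterModels handler, which has no credentials to pass.
await flushModels({ provider: routerNameFlush } as GetModelsOptions)
```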
21 changes: 17 additions & 4 deletions src/api/providers/router-provider.ts
@@ -5,7 +5,7 @@ import type { ModelInfo } from "@roo-code/types"
import { ApiHandlerOptions, RouterName, ModelRecord } from "../../shared/api"

import { BaseProvider } from "./base-provider"
import { getModels } from "./fetchers/modelCache"
import { getModels, getModelsFromCache } from "./fetchers/modelCache"

import { DEFAULT_HEADERS } from "./constants"

@@ -63,9 +63,22 @@ export abstract class RouterProvider extends BaseProvider {
override getModel(): { id: string; info: ModelInfo } {
const id = this.modelId ?? this.defaultModelId

return this.models[id]
? { id, info: this.models[id] }
: { id: this.defaultModelId, info: this.defaultModelInfo }
// First check instance models (populated by fetchModel)
if (this.models[id]) {
return { id, info: this.models[id] }
}

// Fall back to global cache (synchronous disk/memory cache)
// This ensures models are available before fetchModel() is called
const cachedModels = getModelsFromCache(this.name)
if (cachedModels?.[id]) {
// Also populate instance models for future calls
this.models = cachedModels
return { id, info: cachedModels[id] }
}

// Last resort: return default model
return { id: this.defaultModelId, info: this.defaultModelInfo }
}

protected supportsTemperature(modelId: string): boolean {
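The practical effect is that getModel() can resolve real model info synchronously from the global memory/disk cache before any async fetch has run. A rough sketch of the new resolution order, assuming a concrete RouterProvider subclass such as LiteLLMHandler and the fetchModel() step referenced in the comments above (constructor arguments elided):

```typescript
// Resolution order in RouterProvider.getModel() after this change:
//   1. this.models[id]: instance map populated by fetchModel()
//   2. getModelsFromCache(this.name): synchronous global cache, also copied into this.models
//   3. defaultModelId / defaultModelInfo: last-resort fallback
const handler = new LiteLLMHandler(options) // hypothetical setup; options elided

// May already return cached info for the configured model id,
// even though fetchModel() has not completed yet.
const { id, info } = handler.getModel()

// Later, the async path populates the instance map from the network.
await handler.fetchModel()
```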
2 changes: 2 additions & 0 deletions src/core/webview/__tests__/ClineProvider.spec.ts
@@ -237,6 +237,7 @@ vi.mock("../../../integrations/misc/extract-text", () => ({
vi.mock("../../../api/providers/fetchers/modelCache", () => ({
getModels: vi.fn().mockResolvedValue({}),
flushModels: vi.fn(),
getModelsFromCache: vi.fn().mockReturnValue(undefined),
}))

vi.mock("../../../shared/modes", () => ({
@@ -308,6 +309,7 @@ vi.mock("../../../integrations/misc/extract-text", () => ({
vi.mock("../../../api/providers/fetchers/modelCache", () => ({
getModels: vi.fn().mockResolvedValue({}),
flushModels: vi.fn(),
getModelsFromCache: vi.fn().mockReturnValue(undefined),
}))

vi.mock("../diff/strategies/multi-search-replace", () => ({
@@ -151,6 +151,7 @@ vi.mock("../../prompts/system", () => ({
vi.mock("../../../api/providers/fetchers/modelCache", () => ({
getModels: vi.fn().mockResolvedValue({}),
flushModels: vi.fn(),
getModelsFromCache: vi.fn().mockReturnValue(undefined),
}))

vi.mock("../../../integrations/misc/extract-text", () => ({
@@ -29,6 +29,7 @@ vi.mock("../../task-persistence", () => ({
vi.mock("../../../api/providers/fetchers/modelCache", () => ({
getModels: vi.fn(),
flushModels: vi.fn(),
getModelsFromCache: vi.fn().mockReturnValue(undefined),
}))

vi.mock("../checkpointRestoreHandler", () => ({
@@ -179,8 +179,11 @@ describe("webviewMessageHandler - requestRouterModels provider filter", () => {
} as any,
)

// flushModels should have been called for litellm with refresh=true
expect(flushModelsMock).toHaveBeenCalledWith("litellm", true)
// flushModels should have been called for litellm with refresh=true and credentials
expect(flushModelsMock).toHaveBeenCalledWith(
{ provider: "litellm", apiKey: "test-api-key", baseUrl: "http://localhost:4000" },
true,
)

// getModels should have been called with the provided credentials
const litellmCalls = getModelsMock.mock.calls.filter((c: any[]) => c[0]?.provider === "litellm")
43 changes: 24 additions & 19 deletions src/core/webview/webviewMessageHandler.ts
@@ -777,7 +777,9 @@ export const webviewMessageHandler = async (
break
case "flushRouterModels":
const routerNameFlush: RouterName = toRouterName(message.text)
await flushModels(routerNameFlush, true)
// Note: flushRouterModels is a generic flush without credentials
// For providers that need credentials, use their specific handlers
await flushModels({ provider: routerNameFlush } as GetModelsOptions, true)
break
case "requestRouterModels":
const { apiConfiguration } = await provider.getState()
@@ -871,7 +873,7 @@
// If explicit credentials are provided in message.values (from Refresh Models button),
// flush the cache first to ensure we fetch fresh data with the new credentials
if (message?.values?.litellmApiKey || message?.values?.litellmBaseUrl) {
await flushModels("litellm", true)
await flushModels({ provider: "litellm", apiKey: litellmApiKey, baseUrl: litellmBaseUrl }, true)
}

candidates.push({
@@ -925,14 +927,15 @@
// Specific handler for Ollama models only.
const { apiConfiguration: ollamaApiConfig } = await provider.getState()
try {
// Flush cache and refresh to ensure fresh models.
await flushModels("ollama", true)

const ollamaModels = await getModels({
provider: "ollama",
const ollamaOptions = {
provider: "ollama" as const,
baseUrl: ollamaApiConfig.ollamaBaseUrl,
apiKey: ollamaApiConfig.ollamaApiKey,
})
}
// Flush cache and refresh to ensure fresh models.
await flushModels(ollamaOptions, true)

const ollamaModels = await getModels(ollamaOptions)

if (Object.keys(ollamaModels).length > 0) {
provider.postMessageToWebview({ type: "ollamaModels", ollamaModels: ollamaModels })
@@ -947,13 +950,14 @@
// Specific handler for LM Studio models only.
const { apiConfiguration: lmStudioApiConfig } = await provider.getState()
try {
const lmStudioOptions = {
provider: "lmstudio" as const,
baseUrl: lmStudioApiConfig.lmStudioBaseUrl,
}
// Flush cache and refresh to ensure fresh models.
await flushModels("lmstudio", true)
await flushModels(lmStudioOptions, true)

const lmStudioModels = await getModels({
provider: "lmstudio",
baseUrl: lmStudioApiConfig.lmStudioBaseUrl,
})
const lmStudioModels = await getModels(lmStudioOptions)

if (Object.keys(lmStudioModels).length > 0) {
provider.postMessageToWebview({
@@ -970,16 +974,17 @@
case "requestRooModels": {
// Specific handler for Roo models only - flushes cache to ensure fresh auth token is used
try {
// Flush cache and refresh to ensure fresh models with current auth state
await flushModels("roo", true)

const rooModels = await getModels({
provider: "roo",
const rooOptions = {
provider: "roo" as const,
baseUrl: process.env.ROO_CODE_PROVIDER_URL ?? "https://api.roocode.com/proxy",
apiKey: CloudService.hasInstance()
? CloudService.instance.authService?.getSessionToken()
: undefined,
})
}
// Flush cache and refresh to ensure fresh models with current auth state
await flushModels(rooOptions, true)

const rooModels = await getModels(rooOptions)

// Always send a response, even if no models are returned
provider.postMessageToWebview({
9 changes: 8 additions & 1 deletion src/extension.ts
@@ -146,7 +146,14 @@ export async function activate(context: vscode.ExtensionContext) {
const handleRooModelsCache = async () => {
try {
// Flush and refresh cache on auth state changes
await flushModels("roo", true)
const rooOptions = {
provider: "roo" as const,
baseUrl: process.env.ROO_CODE_PROVIDER_URL ?? "https://api.roocode.com/proxy",
apiKey: CloudService.hasInstance()
? CloudService.instance.authService?.getSessionToken()
: undefined,
}
await flushModels(rooOptions, true)

if (data.state === "active-session") {
cloudLogger(`[authStateChangedHandler] Refreshed Roo models cache for active session`)