diff --git a/.changeset/dull-beds-wash.md b/.changeset/dull-beds-wash.md new file mode 100644 index 000000000..ecfcf3fd6 --- /dev/null +++ b/.changeset/dull-beds-wash.md @@ -0,0 +1,5 @@ +--- +"claude-dev": patch +--- + +getOllamaModels protobus migration diff --git a/proto/build-proto.js b/proto/build-proto.js index 7bfc74dd5..975d5a745 100755 --- a/proto/build-proto.js +++ b/proto/build-proto.js @@ -84,6 +84,7 @@ async function generateMethodRegistrations() { path.join(ROOT_DIR, "src", "core", "controller", "checkpoints"), path.join(ROOT_DIR, "src", "core", "controller", "file"), path.join(ROOT_DIR, "src", "core", "controller", "mcp"), + path.join(ROOT_DIR, "src", "core", "controller", "models"), path.join(ROOT_DIR, "src", "core", "controller", "task"), path.join(ROOT_DIR, "src", "core", "controller", "web-content"), // Add more service directories here as needed diff --git a/proto/common.proto b/proto/common.proto index 6272cd093..313aada70 100644 --- a/proto/common.proto +++ b/proto/common.proto @@ -54,3 +54,7 @@ message BooleanRequest { message Boolean { bool value = 1; } + +message StringArray { + repeated string values = 1; +} diff --git a/proto/models.proto b/proto/models.proto new file mode 100644 index 000000000..4052eb712 --- /dev/null +++ b/proto/models.proto @@ -0,0 +1,13 @@ +syntax = "proto3"; + +package cline; +option java_package = "bot.cline.proto"; +option java_multiple_files = true; + +import "common.proto"; + +// Service for model-related operations +service ModelsService { + // Fetches available models from Ollama + rpc getOllamaModels(StringRequest) returns (StringArray); +} diff --git a/src/core/controller/grpc-handler.ts b/src/core/controller/grpc-handler.ts index 7fc7da2b0..9911a0806 100644 --- a/src/core/controller/grpc-handler.ts +++ b/src/core/controller/grpc-handler.ts @@ -6,6 +6,7 @@ import { handleTaskServiceRequest } from "./task" import { handleCheckpointsServiceRequest } from "./checkpoints" import { handleMcpServiceRequest } from "./mcp" import { handleWebContentServiceRequest } from "./web-content" +import { handleModelsServiceRequest } from "./models" /** * Handles gRPC requests from the webview @@ -68,6 +69,11 @@ export class GrpcHandler { message: await handleWebContentServiceRequest(this.controller, method, message), request_id: requestId, } + case "cline.ModelsService": + return { + message: await handleModelsServiceRequest(this.controller, method, message), + request_id: requestId, + } default: throw new Error(`Unknown service: ${service}`) } diff --git a/src/core/controller/index.ts b/src/core/controller/index.ts index 3762e4b78..11ae63183 100644 --- a/src/core/controller/index.ts +++ b/src/core/controller/index.ts @@ -342,13 +342,6 @@ export class Controller { case "resetState": await this.resetState() break - case "requestOllamaModels": - const ollamaModels = await this.getOllamaModels(message.text) - this.postMessageToWebview({ - type: "ollamaModels", - ollamaModels, - }) - break case "requestLmStudioModels": const lmStudioModels = await this.getLmStudioModels(message.text) this.postMessageToWebview({ @@ -960,25 +953,6 @@ export class Controller { } } - // Ollama - - async getOllamaModels(baseUrl?: string) { - try { - if (!baseUrl) { - baseUrl = "http://localhost:11434" - } - if (!URL.canParse(baseUrl)) { - return [] - } - const response = await axios.get(`${baseUrl}/api/tags`) - const modelsArray = response.data?.models?.map((model: any) => model.name) || [] - const models = [...new Set(modelsArray)] - return models - } catch (error) { - return [] - } - } - // LM Studio async getLmStudioModels(baseUrl?: string) { diff --git a/src/core/controller/models/getOllamaModels.ts b/src/core/controller/models/getOllamaModels.ts new file mode 100644 index 000000000..79b08d147 --- /dev/null +++ b/src/core/controller/models/getOllamaModels.ts @@ -0,0 +1,27 @@ +import { Controller } from ".." +import { StringArray, StringRequest } from "../../../shared/proto/common" +import axios from "axios" + +/** + * Fetches available models from Ollama + * @param controller The controller instance + * @param request The request containing the base URL (optional) + * @returns Array of model names + */ +export async function getOllamaModels(controller: Controller, request: StringRequest): Promise { + try { + let baseUrl = request.value || "http://localhost:11434" + + if (!URL.canParse(baseUrl)) { + return StringArray.create({ values: [] }) + } + + const response = await axios.get(`${baseUrl}/api/tags`) + const modelsArray = response.data?.models?.map((model: any) => model.name) || [] + const models = [...new Set(modelsArray)] + + return StringArray.create({ values: models }) + } catch (error) { + return StringArray.create({ values: [] }) + } +} diff --git a/src/core/controller/models/index.ts b/src/core/controller/models/index.ts new file mode 100644 index 000000000..d58a396cf --- /dev/null +++ b/src/core/controller/models/index.ts @@ -0,0 +1,15 @@ +import { createServiceRegistry, ServiceMethodHandler } from "../grpc-service" +import { registerAllMethods } from "./methods" + +// Create models service registry +const modelsService = createServiceRegistry("models") + +// Export the method handler type and registration function +export type ModelsMethodHandler = ServiceMethodHandler +export const registerMethod = modelsService.registerMethod + +// Export the request handler +export const handleModelsServiceRequest = modelsService.handleRequest + +// Register all models methods +registerAllMethods() diff --git a/src/core/controller/models/methods.ts b/src/core/controller/models/methods.ts new file mode 100644 index 000000000..8251f5b65 --- /dev/null +++ b/src/core/controller/models/methods.ts @@ -0,0 +1,12 @@ +// AUTO-GENERATED FILE - DO NOT MODIFY DIRECTLY +// Generated by proto/build-proto.js + +// Import all method implementations +import { registerMethod } from "./index" +import { getOllamaModels } from "./getOllamaModels" + +// Register all models service methods +export function registerAllMethods(): void { + // Register each method with the registry + registerMethod("getOllamaModels", getOllamaModels) +} diff --git a/src/shared/WebviewMessage.ts b/src/shared/WebviewMessage.ts index 2ec6c1589..de649e9a3 100644 --- a/src/shared/WebviewMessage.ts +++ b/src/shared/WebviewMessage.ts @@ -20,7 +20,6 @@ export interface WebviewMessage { | "showTaskWithId" | "exportTaskWithId" | "resetState" - | "requestOllamaModels" | "requestLmStudioModels" | "openInBrowser" | "openMention" diff --git a/src/shared/proto/common.ts b/src/shared/proto/common.ts index ffc280253..2bdac92f0 100644 --- a/src/shared/proto/common.ts +++ b/src/shared/proto/common.ts @@ -58,6 +58,10 @@ export interface Boolean { value: boolean } +export interface StringArray { + values: string[] +} + function createBaseMetadata(): Metadata { return {} } @@ -820,6 +824,66 @@ export const Boolean: MessageFns = { }, } +function createBaseStringArray(): StringArray { + return { values: [] } +} + +export const StringArray: MessageFns = { + encode(message: StringArray, writer: BinaryWriter = new BinaryWriter()): BinaryWriter { + for (const v of message.values) { + writer.uint32(10).string(v!) + } + return writer + }, + + decode(input: BinaryReader | Uint8Array, length?: number): StringArray { + const reader = input instanceof BinaryReader ? input : new BinaryReader(input) + let end = length === undefined ? reader.len : reader.pos + length + const message = createBaseStringArray() + while (reader.pos < end) { + const tag = reader.uint32() + switch (tag >>> 3) { + case 1: { + if (tag !== 10) { + break + } + + message.values.push(reader.string()) + continue + } + } + if ((tag & 7) === 4 || tag === 0) { + break + } + reader.skip(tag & 7) + } + return message + }, + + fromJSON(object: any): StringArray { + return { + values: globalThis.Array.isArray(object?.values) ? object.values.map((e: any) => globalThis.String(e)) : [], + } + }, + + toJSON(message: StringArray): unknown { + const obj: any = {} + if (message.values?.length) { + obj.values = message.values + } + return obj + }, + + create, I>>(base?: I): StringArray { + return StringArray.fromPartial(base ?? ({} as any)) + }, + fromPartial, I>>(object: I): StringArray { + const message = createBaseStringArray() + message.values = object.values?.map((e) => e) || [] + return message + }, +} + function bytesFromBase64(b64: string): Uint8Array { return Uint8Array.from(globalThis.Buffer.from(b64, "base64")) } diff --git a/src/shared/proto/models.ts b/src/shared/proto/models.ts new file mode 100644 index 000000000..71daac69a --- /dev/null +++ b/src/shared/proto/models.ts @@ -0,0 +1,28 @@ +// Code generated by protoc-gen-ts_proto. DO NOT EDIT. +// versions: +// protoc-gen-ts_proto v2.7.0 +// protoc v3.19.1 +// source: models.proto + +/* eslint-disable */ +import { StringArray, StringRequest } from "./common" + +export const protobufPackage = "cline" + +/** Service for model-related operations */ +export type ModelsServiceDefinition = typeof ModelsServiceDefinition +export const ModelsServiceDefinition = { + name: "ModelsService", + fullName: "cline.ModelsService", + methods: { + /** Fetches available models from Ollama */ + getOllamaModels: { + name: "getOllamaModels", + requestType: StringRequest, + requestStream: false, + responseType: StringArray, + responseStream: false, + options: {}, + }, + }, +} as const diff --git a/webview-ui/src/components/settings/ApiOptions.tsx b/webview-ui/src/components/settings/ApiOptions.tsx index 494f3e361..ba753e875 100644 --- a/webview-ui/src/components/settings/ApiOptions.tsx +++ b/webview-ui/src/components/settings/ApiOptions.tsx @@ -55,6 +55,7 @@ import { import { ExtensionMessage } from "@shared/ExtensionMessage" import { useExtensionState } from "@/context/ExtensionStateContext" import { vscode } from "@/utils/vscode" +import { ModelsServiceClient } from "@/services/grpc-client" import { getAsVar, VSC_DESCRIPTION_FOREGROUND } from "@/utils/vscStyles" import VSCodeButtonLink from "@/components/common/VSCodeButtonLink" import OpenRouterModelPicker, { ModelDescriptionMarkdown, OPENROUTER_MODEL_PICKER_Z_INDEX } from "./OpenRouterModelPicker" @@ -182,12 +183,19 @@ const ApiOptions = ({ }, [apiConfiguration]) // Poll ollama/lmstudio models - const requestLocalModels = useCallback(() => { + const requestLocalModels = useCallback(async () => { if (selectedProvider === "ollama") { - vscode.postMessage({ - type: "requestOllamaModels", - text: apiConfiguration?.ollamaBaseUrl, - }) + try { + const response = await ModelsServiceClient.getOllamaModels({ + value: apiConfiguration?.ollamaBaseUrl || "", + }) + if (response && response.values) { + setOllamaModels(response.values) + } + } catch (error) { + console.error("Failed to fetch Ollama models:", error) + setOllamaModels([]) + } } else if (selectedProvider === "lmstudio") { vscode.postMessage({ type: "requestLmStudioModels", @@ -209,9 +217,7 @@ const ApiOptions = ({ const handleMessage = useCallback((event: MessageEvent) => { const message: ExtensionMessage = event.data - if (message.type === "ollamaModels" && message.ollamaModels) { - setOllamaModels(message.ollamaModels) - } else if (message.type === "lmStudioModels" && message.lmStudioModels) { + if (message.type === "lmStudioModels" && message.lmStudioModels) { setLmStudioModels(message.lmStudioModels) } else if (message.type === "vsCodeLmModels" && message.vsCodeLmModels) { setVsCodeLmModels(message.vsCodeLmModels) diff --git a/webview-ui/src/services/grpc-client.ts b/webview-ui/src/services/grpc-client.ts index 44169a348..92ccdddb9 100644 --- a/webview-ui/src/services/grpc-client.ts +++ b/webview-ui/src/services/grpc-client.ts @@ -6,6 +6,7 @@ import { CheckpointsServiceDefinition } from "@shared/proto/checkpoints" import { EmptyRequest } from "@shared/proto/common" import { FileServiceDefinition } from "@shared/proto/file" import { McpServiceDefinition } from "@shared/proto/mcp" +import { ModelsServiceDefinition } from "@shared/proto/models" import { TaskServiceDefinition } from "@shared/proto/task" import { WebContentServiceDefinition } from "@shared/proto/web_content" // Generic type for any protobuf service definition @@ -102,6 +103,7 @@ const BrowserServiceClient = createGrpcClient(BrowserServiceDefinition) const CheckpointsServiceClient = createGrpcClient(CheckpointsServiceDefinition) const FileServiceClient = createGrpcClient(FileServiceDefinition) const McpServiceClient = createGrpcClient(McpServiceDefinition) +const ModelsServiceClient = createGrpcClient(ModelsServiceDefinition) const TaskServiceClient = createGrpcClient(TaskServiceDefinition) const WebContentServiceClient = createGrpcClient(WebContentServiceDefinition) @@ -112,5 +114,6 @@ export { FileServiceClient, TaskServiceClient, McpServiceClient, + ModelsServiceClient, WebContentServiceClient, }