From 83ed32f1195841d2d6c057c3e00086a3147879a2 Mon Sep 17 00:00:00 2001 From: memoclaw Date: Sun, 12 Apr 2026 19:23:34 +0800 Subject: [PATCH] feat(ai): add instance AI providers and transcription (#5829) Co-authored-by: memoclaw <265580040+memoclaw@users.noreply.github.com> --- internal/ai/ai.go | 26 ++ internal/ai/errors.go | 10 + internal/ai/openai/client.go | 59 +++ internal/ai/openai/transcription.go | 145 +++++++ internal/ai/openai/transcription_test.go | 65 +++ internal/ai/resolver.go | 16 + internal/ai/transcription.go | 29 ++ proto/api/v1/ai_service.proto | 63 +++ proto/api/v1/instance_service.proto | 34 ++ proto/gen/api/v1/ai_service.pb.go | 382 ++++++++++++++++ proto/gen/api/v1/ai_service.pb.gw.go | 157 +++++++ proto/gen/api/v1/ai_service_grpc.pb.go | 123 ++++++ .../api/v1/apiv1connect/ai_service.connect.go | 110 +++++ proto/gen/api/v1/instance_service.pb.go | 387 ++++++++++++++--- proto/gen/openapi.yaml | 127 ++++++ proto/gen/store/instance_setting.pb.go | 329 ++++++++++++-- proto/store/instance_setting.proto | 27 ++ server/router/api/v1/ai_service.go | 198 +++++++++ server/router/api/v1/connect_handler.go | 1 + server/router/api/v1/connect_services.go | 10 + server/router/api/v1/instance_service.go | 171 +++++++- server/router/api/v1/test/ai_service_test.go | 185 ++++++++ .../api/v1/test/instance_service_test.go | 134 ++++++ server/router/api/v1/v1.go | 4 + store/instance_setting.go | 28 ++ store/test/instance_setting_test.go | 49 +++ web/src/components/Settings/AISection.tsx | 408 ++++++++++++++++++ web/src/connect.ts | 2 + web/src/contexts/InstanceContext.tsx | 14 +- web/src/locales/en.json | 26 ++ web/src/pages/Setting.tsx | 9 +- web/src/types/proto/api/v1/ai_service_pb.ts | 166 +++++++ .../types/proto/api/v1/instance_service_pb.ts | 139 +++++- 33 files changed, 3522 insertions(+), 111 deletions(-) create mode 100644 internal/ai/ai.go create mode 100644 internal/ai/errors.go create mode 100644 internal/ai/openai/client.go create mode 100644 
internal/ai/openai/transcription.go create mode 100644 internal/ai/openai/transcription_test.go create mode 100644 internal/ai/resolver.go create mode 100644 internal/ai/transcription.go create mode 100644 proto/api/v1/ai_service.proto create mode 100644 proto/gen/api/v1/ai_service.pb.go create mode 100644 proto/gen/api/v1/ai_service.pb.gw.go create mode 100644 proto/gen/api/v1/ai_service_grpc.pb.go create mode 100644 proto/gen/api/v1/apiv1connect/ai_service.connect.go create mode 100644 server/router/api/v1/ai_service.go create mode 100644 server/router/api/v1/test/ai_service_test.go create mode 100644 web/src/components/Settings/AISection.tsx create mode 100644 web/src/types/proto/api/v1/ai_service_pb.ts diff --git a/internal/ai/ai.go b/internal/ai/ai.go new file mode 100644 index 000000000..1a3b1ca2a --- /dev/null +++ b/internal/ai/ai.go @@ -0,0 +1,26 @@ +package ai + +// ProviderType identifies an AI provider implementation. +type ProviderType string + +const ( + // ProviderOpenAI is OpenAI's hosted API. + ProviderOpenAI ProviderType = "OPENAI" + // ProviderOpenAICompatible is an OpenAI-compatible API endpoint. + ProviderOpenAICompatible ProviderType = "OPENAI_COMPATIBLE" + // ProviderAnthropic is Anthropic's API. + ProviderAnthropic ProviderType = "ANTHROPIC" + // ProviderGemini is Google's Gemini API. + ProviderGemini ProviderType = "GEMINI" +) + +// ProviderConfig configures a callable AI provider connection. +type ProviderConfig struct { + ID string + Title string + Type ProviderType + Endpoint string + APIKey string + Models []string + DefaultModel string +} diff --git a/internal/ai/errors.go b/internal/ai/errors.go new file mode 100644 index 000000000..f6ff00ff6 --- /dev/null +++ b/internal/ai/errors.go @@ -0,0 +1,10 @@ +package ai + +import "github.com/pkg/errors" + +var ( + // ErrProviderNotFound indicates that a requested provider ID does not exist. 
+	ErrProviderNotFound = errors.New("AI provider not found")
+	// ErrCapabilityUnsupported indicates that the provider does not support the requested capability.
+	ErrCapabilityUnsupported = errors.New("AI provider capability unsupported")
+)
diff --git a/internal/ai/openai/client.go b/internal/ai/openai/client.go
new file mode 100644
index 000000000..b1ad49dea
--- /dev/null
+++ b/internal/ai/openai/client.go
@@ -0,0 +1,69 @@
+package openai
+
+import (
+	"net/http"
+	"net/url"
+	"strings"
+	"time"
+
+	"github.com/pkg/errors"
+
+	"github.com/usememos/memos/internal/ai"
+)
+
+const defaultEndpoint = "https://api.openai.com/v1"
+
+// Transcriber transcribes audio with OpenAI-compatible transcription APIs.
+type Transcriber struct {
+	endpoint   string
+	apiKey     string
+	httpClient *http.Client
+}
+
+// NewTranscriber creates a new OpenAI-compatible transcriber.
+//
+// The endpoint is trimmed and falls back to defaultEndpoint when empty. It
+// must be an absolute http or https URL, and an API key is required. Options
+// are applied in order after the defaults have been set.
+func NewTranscriber(config ai.ProviderConfig, options ...Option) (*Transcriber, error) {
+	endpoint := strings.TrimSpace(config.Endpoint)
+	if endpoint == "" {
+		endpoint = defaultEndpoint
+	}
+	parsed, err := url.ParseRequestURI(endpoint)
+	if err != nil {
+		return nil, errors.Wrap(err, "invalid OpenAI endpoint")
+	}
+	// ParseRequestURI also accepts path-only values such as "/v1", which would
+	// otherwise only fail once the first request is sent.
+	if (parsed.Scheme != "http" && parsed.Scheme != "https") || parsed.Host == "" {
+		return nil, errors.Errorf("invalid OpenAI endpoint %q: expected an absolute http or https URL", endpoint)
+	}
+	if config.APIKey == "" {
+		return nil, errors.New("OpenAI API key is required")
+	}
+
+	transcriber := &Transcriber{
+		endpoint: endpoint,
+		apiKey:   config.APIKey,
+		httpClient: &http.Client{
+			Timeout: 2 * time.Minute,
+		},
+	}
+	for _, option := range options {
+		option(transcriber)
+	}
+	return transcriber, nil
+}
+
+// Option configures a Transcriber.
+type Option func(*Transcriber)
+
+// WithHTTPClient sets the HTTP client used by the transcriber.
+func WithHTTPClient(client *http.Client) Option {
+	return func(t *Transcriber) {
+		if client != nil {
+			t.httpClient = client
+		}
+	}
+}
diff --git a/internal/ai/openai/transcription.go b/internal/ai/openai/transcription.go
new file mode 100644
index 000000000..79a9adad5
--- /dev/null
+++ b/internal/ai/openai/transcription.go
@@ -0,0 +1,181 @@
+package openai
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"io"
+	"mime"
+	"mime/multipart"
+	"net/http"
+	"net/textproto"
+	"strings"
+
+	"github.com/pkg/errors"
+
+	"github.com/usememos/memos/internal/ai"
+)
+
+const (
+	// maxResponseBytes caps how much of a provider response is read into memory.
+	maxResponseBytes = 10 << 20
+	// maxErrorMessageBytes caps how much of a raw error body is echoed in an error.
+	maxErrorMessageBytes = 1 << 10
+)
+
+// transcriptionResponse is the JSON body decoded from a successful call.
+type transcriptionResponse struct {
+	Text     string  `json:"text"`
+	Language string  `json:"language"`
+	Duration float64 `json:"duration"`
+}
+
+// errorResponse is the JSON error envelope decoded from a failed call.
+type errorResponse struct {
+	Error struct {
+		Message string `json:"message"`
+		Type    string `json:"type"`
+		// Code is kept raw so that a non-string value does not make decoding
+		// of the whole envelope fail.
+		Code json.RawMessage `json:"code"`
+	} `json:"error"`
+}
+
+// Transcribe sends the audio to the /audio/transcriptions endpoint and returns
+// the decoded transcription.
+//
+// Model and Audio are required; Prompt and Language are sent only when set.
+// The multipart request body is assembled in memory before it is sent. A
+// non-2xx status is returned as an error carrying the provider's message, and
+// a response larger than maxResponseBytes is rejected.
+func (t *Transcriber) Transcribe(ctx context.Context, request ai.TranscribeRequest) (*ai.TranscribeResponse, error) {
+	if strings.TrimSpace(request.Model) == "" {
+		return nil, errors.New("model is required")
+	}
+	if request.Audio == nil {
+		return nil, errors.New("audio is required")
+	}
+
+	body := &bytes.Buffer{}
+	writer := multipart.NewWriter(body)
+	if err := writeAudioFilePart(writer, request); err != nil {
+		return nil, err
+	}
+	if err := writer.WriteField("model", request.Model); err != nil {
+		return nil, errors.Wrap(err, "failed to write model field")
+	}
+	if err := writer.WriteField("response_format", "json"); err != nil {
+		return nil, errors.Wrap(err, "failed to write response format field")
+	}
+	if request.Prompt != "" {
+		if err := writer.WriteField("prompt", request.Prompt); err != nil {
+			return nil, errors.Wrap(err, "failed to write prompt field")
+		}
+	}
+	if request.Language != "" {
+		if err := writer.WriteField("language", request.Language); err != nil {
+			return nil, errors.Wrap(err, "failed to write language field")
+		}
+	}
+	if err := writer.Close(); err != nil {
+		return nil, errors.Wrap(err, "failed to close multipart writer")
+	}
+
+	httpRequest, err := http.NewRequestWithContext(ctx, http.MethodPost, strings.TrimRight(t.endpoint, "/")+"/audio/transcriptions", body)
+	if err != nil {
+		return nil, errors.Wrap(err, "failed to create transcription request")
+	}
+	httpRequest.Header.Set("Authorization", "Bearer "+t.apiKey)
+	httpRequest.Header.Set("Content-Type", writer.FormDataContentType())
+
+	httpResponse, err := t.httpClient.Do(httpRequest)
+	if err != nil {
+		return nil, errors.Wrap(err, "failed to send transcription request")
+	}
+	defer httpResponse.Body.Close()
+
+	// Read one byte past the cap so an oversized body can be told apart from
+	// one that is exactly at the limit.
+	responseBody, err := io.ReadAll(io.LimitReader(httpResponse.Body, maxResponseBytes+1))
+	if err != nil {
+		return nil, errors.Wrap(err, "failed to read transcription response")
+	}
+	if httpResponse.StatusCode < http.StatusOK || httpResponse.StatusCode >= http.StatusMultipleChoices {
+		return nil, errors.Errorf("transcription request failed with status %d: %s", httpResponse.StatusCode, extractErrorMessage(responseBody))
+	}
+	if len(responseBody) > maxResponseBytes {
+		return nil, errors.Errorf("transcription response exceeds %d bytes", maxResponseBytes)
+	}
+
+	var response transcriptionResponse
+	if err := json.Unmarshal(responseBody, &response); err != nil {
+		return nil, errors.Wrap(err, "failed to unmarshal transcription response")
+	}
+	return &ai.TranscribeResponse{
+		Text:     response.Text,
+		Language: response.Language,
+		Duration: response.Duration,
+	}, nil
+}
+
+// writeAudioFilePart writes request.Audio to writer as the "file" form part.
+//
+// A blank filename becomes "audio" and a blank content type becomes
+// "application/octet-stream"; any other content type is parsed and reduced to
+// its bare media type, dropping parameters.
+func writeAudioFilePart(writer *multipart.Writer, request ai.TranscribeRequest) error {
+	filename := strings.TrimSpace(request.Filename)
+	if filename == "" {
+		filename = "audio"
+	}
+	contentType := strings.TrimSpace(request.ContentType)
+	if contentType == "" {
+		contentType = "application/octet-stream"
+	} else {
+		mediaType, _, err := mime.ParseMediaType(contentType)
+		if err != nil {
+			return errors.Wrap(err, "invalid audio content type")
+		}
+		contentType = mediaType
+	}
+
+	header := make(textproto.MIMEHeader)
+	header.Set("Content-Disposition", mime.FormatMediaType("form-data", map[string]string{
+		"name":     "file",
+		"filename": sanitizeFilename(filename),
+	}))
+	header.Set("Content-Type", contentType)
+	part, err := writer.CreatePart(header)
+	if err != nil {
+		return errors.Wrap(err, "failed to create audio file part")
+	}
+	if _, err := io.Copy(part, request.Audio); err != nil {
+		return errors.Wrap(err, "failed to write audio file part")
+	}
+	return nil
+}
+
+// extractErrorMessage returns the message from a JSON error envelope, or the
+// trimmed raw body cut to maxErrorMessageBytes when the body is not one.
+func extractErrorMessage(responseBody []byte) string {
+	var response errorResponse
+	if err := json.Unmarshal(responseBody, &response); err == nil && response.Error.Message != "" {
+		return response.Error.Message
+	}
+	message := strings.TrimSpace(string(responseBody))
+	if len(message) > maxErrorMessageBytes {
+		// Drop bytes that are not valid UTF-8, including a rune split by the cut.
+		message = strings.ToValidUTF8(message[:maxErrorMessageBytes], "") + "..."
+	}
+	return message
+}
+
+// sanitizeFilename replaces CR and LF with underscores and falls back to
+// "audio" when nothing but whitespace remains.
+func sanitizeFilename(filename string) string {
+	filename = strings.NewReplacer("\r", "_", "\n", "_").Replace(filename)
+	if strings.TrimSpace(filename) == "" {
+		return "audio"
+	}
+	return filename
+}
diff --git a/internal/ai/openai/transcription_test.go b/internal/ai/openai/transcription_test.go
new file mode 100644
index 000000000..c436b7cd1
--- /dev/null
+++ b/internal/ai/openai/transcription_test.go
@@ -0,0 +1,65 @@
+package openai
+
+import (
+	"context"
+	"encoding/json"
+	"net/http"
+	"net/http/httptest"
+	"strings"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/require"
+
+	"github.com/usememos/memos/internal/ai"
+)
+
+func TestTranscribe(t *testing.T) {
+	t.Parallel()
+
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		require.Equal(t, http.MethodPost, r.Method)
+		require.Equal(t, "/audio/transcriptions", r.URL.Path)
+		require.Equal(t, "Bearer test-key", r.Header.Get("Authorization"))
+		require.NoError(t, r.ParseMultipartForm(10<<20))
+		require.Equal(t, "gpt-4o-transcribe", r.FormValue("model"))
+		require.Equal(t, "json", r.FormValue("response_format"))
+		require.Equal(t, "domain words", r.FormValue("prompt"))
+		require.Equal(t, "en", r.FormValue("language"))
+
+		file, header, err := r.FormFile("file")
+		require.NoError(t, err)
+		defer file.Close()
+		require.Equal(t, "voice.wav", header.Filename)
+		require.Equal(t, "audio/wav", header.Header.Get("Content-Type"))
+
+		w.Header().Set("Content-Type", "application/json")
+		require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
+			"text":     "hello world",
+			"language": "en",
+			"duration": 1.5,
+		}))
+	}))
+	defer server.Close()
+
+	transcriber, err := NewTranscriber(ai.ProviderConfig{
+		Endpoint: server.URL,
+		APIKey:   "test-key",
+	})
+	require.NoError(t, err)
+
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+	defer cancel()
+	response, err := transcriber.Transcribe(ctx, ai.TranscribeRequest{
+		Model:       "gpt-4o-transcribe",
+		Filename:    "voice.wav",
+		ContentType: "audio/wav",
+		Audio:       strings.NewReader("RIFF"),
+		Prompt:      "domain words",
+		Language:    "en",
+	})
+	require.NoError(t, err)
+	require.Equal(t, "hello world", response.Text)
+	require.Equal(t, "en", response.Language)
+	require.Equal(t, 1.5, response.Duration)
+}
diff --git a/internal/ai/resolver.go b/internal/ai/resolver.go
new file mode 100644
index 000000000..902344fd0
--- /dev/null
+++ b/internal/ai/resolver.go
@@ -0,0 +1,16 @@
+package ai
+
+import "github.com/pkg/errors"
+
+// FindProvider returns the provider with the given ID.
+func FindProvider(providers []ProviderConfig, providerID string) (*ProviderConfig, error) {
+	if providerID == "" {
+		return nil, errors.Wrap(ErrProviderNotFound, "provider ID is required")
+	}
+	for _, provider := range providers {
+		if provider.ID == providerID {
+			return &provider, nil
+		}
+	}
+	return nil, errors.Wrapf(ErrProviderNotFound, "provider ID %q", providerID)
+}
diff --git a/internal/ai/transcription.go b/internal/ai/transcription.go
new file mode 100644
index 000000000..544504915
--- /dev/null
+++ b/internal/ai/transcription.go
@@ -0,0 +1,29 @@
+package ai
+
+import (
+	"context"
+	"io"
+)
+
+// Transcriber transcribes audio into text.
+type Transcriber interface { + Transcribe(ctx context.Context, request TranscribeRequest) (*TranscribeResponse, error) +} + +// TranscribeRequest contains an audio transcription request. +type TranscribeRequest struct { + Model string + Filename string + ContentType string + Audio io.Reader + Size int64 + Prompt string + Language string +} + +// TranscribeResponse contains an audio transcription response. +type TranscribeResponse struct { + Text string + Language string + Duration float64 +} diff --git a/proto/api/v1/ai_service.proto b/proto/api/v1/ai_service.proto new file mode 100644 index 000000000..3e908e809 --- /dev/null +++ b/proto/api/v1/ai_service.proto @@ -0,0 +1,63 @@ +syntax = "proto3"; + +package memos.api.v1; + +import "google/api/annotations.proto"; +import "google/api/client.proto"; +import "google/api/field_behavior.proto"; + +option go_package = "gen/api/v1"; + +service AIService { + // Transcribe transcribes an audio file using an instance AI provider. + rpc Transcribe(TranscribeRequest) returns (TranscribeResponse) { + option (google.api.http) = { + post: "/api/v1/ai:transcribe" + body: "*" + }; + option (google.api.method_signature) = "provider_id,config,audio"; + } +} + +message TranscribeRequest { + // Required. The instance AI provider ID to use. + string provider_id = 1 [(google.api.field_behavior) = REQUIRED]; + + // Required. Transcription options. + TranscriptionConfig config = 2 [(google.api.field_behavior) = REQUIRED]; + + // Required. Audio input. + TranscriptionAudio audio = 3 [(google.api.field_behavior) = REQUIRED]; +} + +message TranscriptionConfig { + // Optional. The model to use. If empty, the provider's default model is used. + string model = 1 [(google.api.field_behavior) = OPTIONAL]; + + // Optional. A prompt to improve transcription quality. + string prompt = 2 [(google.api.field_behavior) = OPTIONAL]; + + // Optional. The language of the input audio. 
+ string language = 3 [(google.api.field_behavior) = OPTIONAL]; +} + +message TranscriptionAudio { + oneof source { + // Inline audio bytes. + bytes content = 1 [(google.api.field_behavior) = INPUT_ONLY]; + + // URI for audio content. Reserved for future use. + string uri = 2; + } + + // Optional. The uploaded filename. + string filename = 3 [(google.api.field_behavior) = OPTIONAL]; + + // Optional. The MIME type of the input audio. + string content_type = 4 [(google.api.field_behavior) = OPTIONAL]; +} + +message TranscribeResponse { + // The transcribed text. + string text = 1; +} diff --git a/proto/api/v1/instance_service.proto b/proto/api/v1/instance_service.proto index 1f500b3e3..1be5977a0 100644 --- a/proto/api/v1/instance_service.proto +++ b/proto/api/v1/instance_service.proto @@ -72,6 +72,7 @@ message InstanceSetting { MemoRelatedSetting memo_related_setting = 4; TagsSetting tags_setting = 5; NotificationSetting notification_setting = 6; + AISetting ai_setting = 7; } // Enumeration of instance setting keys. @@ -87,6 +88,8 @@ message InstanceSetting { TAGS = 4; // NOTIFICATION is the key for notification transport settings. NOTIFICATION = 5; + // AI is the key for AI provider settings. + AI = 6; } // General instance settings configuration. @@ -201,6 +204,37 @@ message InstanceSetting { bool use_ssl = 10; } } + + // AI provider configuration settings. + message AISetting { + // providers is the list of AI provider configurations available instance-wide. + repeated AIProviderConfig providers = 1; + } + + // AIProviderConfig represents one callable AI provider connection. + message AIProviderConfig { + string id = 1; + string title = 2; + AIProviderType type = 3; + string endpoint = 4; + // api_key is write-only and is never returned by GetInstanceSetting. + string api_key = 5 [(google.api.field_behavior) = INPUT_ONLY]; + repeated string models = 6; + string default_model = 7; + // api_key_set indicates whether an API key is stored for this provider. 
+ bool api_key_set = 8 [(google.api.field_behavior) = OUTPUT_ONLY]; + // api_key_hint is a masked hint for the stored API key. + string api_key_hint = 9 [(google.api.field_behavior) = OUTPUT_ONLY]; + } + + // AIProviderType is the provider implementation type. + enum AIProviderType { + AI_PROVIDER_TYPE_UNSPECIFIED = 0; + OPENAI = 1; + OPENAI_COMPATIBLE = 2; + ANTHROPIC = 3; + GEMINI = 4; + } } // Request message for GetInstanceSetting method. diff --git a/proto/gen/api/v1/ai_service.pb.go b/proto/gen/api/v1/ai_service.pb.go new file mode 100644 index 000000000..3af97e744 --- /dev/null +++ b/proto/gen/api/v1/ai_service.pb.go @@ -0,0 +1,382 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.11 +// protoc (unknown) +// source: api/v1/ai_service.proto + +package apiv1 + +import ( + _ "google.golang.org/genproto/googleapis/api/annotations" + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type TranscribeRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Required. The instance AI provider ID to use. + ProviderId string `protobuf:"bytes,1,opt,name=provider_id,json=providerId,proto3" json:"provider_id,omitempty"` + // Required. Transcription options. + Config *TranscriptionConfig `protobuf:"bytes,2,opt,name=config,proto3" json:"config,omitempty"` + // Required. Audio input. 
+ Audio *TranscriptionAudio `protobuf:"bytes,3,opt,name=audio,proto3" json:"audio,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *TranscribeRequest) Reset() { + *x = TranscribeRequest{} + mi := &file_api_v1_ai_service_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *TranscribeRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TranscribeRequest) ProtoMessage() {} + +func (x *TranscribeRequest) ProtoReflect() protoreflect.Message { + mi := &file_api_v1_ai_service_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TranscribeRequest.ProtoReflect.Descriptor instead. +func (*TranscribeRequest) Descriptor() ([]byte, []int) { + return file_api_v1_ai_service_proto_rawDescGZIP(), []int{0} +} + +func (x *TranscribeRequest) GetProviderId() string { + if x != nil { + return x.ProviderId + } + return "" +} + +func (x *TranscribeRequest) GetConfig() *TranscriptionConfig { + if x != nil { + return x.Config + } + return nil +} + +func (x *TranscribeRequest) GetAudio() *TranscriptionAudio { + if x != nil { + return x.Audio + } + return nil +} + +type TranscriptionConfig struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Optional. The model to use. If empty, the provider's default model is used. + Model string `protobuf:"bytes,1,opt,name=model,proto3" json:"model,omitempty"` + // Optional. A prompt to improve transcription quality. + Prompt string `protobuf:"bytes,2,opt,name=prompt,proto3" json:"prompt,omitempty"` + // Optional. The language of the input audio. 
+ Language string `protobuf:"bytes,3,opt,name=language,proto3" json:"language,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *TranscriptionConfig) Reset() { + *x = TranscriptionConfig{} + mi := &file_api_v1_ai_service_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *TranscriptionConfig) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TranscriptionConfig) ProtoMessage() {} + +func (x *TranscriptionConfig) ProtoReflect() protoreflect.Message { + mi := &file_api_v1_ai_service_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TranscriptionConfig.ProtoReflect.Descriptor instead. +func (*TranscriptionConfig) Descriptor() ([]byte, []int) { + return file_api_v1_ai_service_proto_rawDescGZIP(), []int{1} +} + +func (x *TranscriptionConfig) GetModel() string { + if x != nil { + return x.Model + } + return "" +} + +func (x *TranscriptionConfig) GetPrompt() string { + if x != nil { + return x.Prompt + } + return "" +} + +func (x *TranscriptionConfig) GetLanguage() string { + if x != nil { + return x.Language + } + return "" +} + +type TranscriptionAudio struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Types that are valid to be assigned to Source: + // + // *TranscriptionAudio_Content + // *TranscriptionAudio_Uri + Source isTranscriptionAudio_Source `protobuf_oneof:"source"` + // Optional. The uploaded filename. + Filename string `protobuf:"bytes,3,opt,name=filename,proto3" json:"filename,omitempty"` + // Optional. The MIME type of the input audio. 
+ ContentType string `protobuf:"bytes,4,opt,name=content_type,json=contentType,proto3" json:"content_type,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *TranscriptionAudio) Reset() { + *x = TranscriptionAudio{} + mi := &file_api_v1_ai_service_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *TranscriptionAudio) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TranscriptionAudio) ProtoMessage() {} + +func (x *TranscriptionAudio) ProtoReflect() protoreflect.Message { + mi := &file_api_v1_ai_service_proto_msgTypes[2] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TranscriptionAudio.ProtoReflect.Descriptor instead. +func (*TranscriptionAudio) Descriptor() ([]byte, []int) { + return file_api_v1_ai_service_proto_rawDescGZIP(), []int{2} +} + +func (x *TranscriptionAudio) GetSource() isTranscriptionAudio_Source { + if x != nil { + return x.Source + } + return nil +} + +func (x *TranscriptionAudio) GetContent() []byte { + if x != nil { + if x, ok := x.Source.(*TranscriptionAudio_Content); ok { + return x.Content + } + } + return nil +} + +func (x *TranscriptionAudio) GetUri() string { + if x != nil { + if x, ok := x.Source.(*TranscriptionAudio_Uri); ok { + return x.Uri + } + } + return "" +} + +func (x *TranscriptionAudio) GetFilename() string { + if x != nil { + return x.Filename + } + return "" +} + +func (x *TranscriptionAudio) GetContentType() string { + if x != nil { + return x.ContentType + } + return "" +} + +type isTranscriptionAudio_Source interface { + isTranscriptionAudio_Source() +} + +type TranscriptionAudio_Content struct { + // Inline audio bytes. 
+ Content []byte `protobuf:"bytes,1,opt,name=content,proto3,oneof"` +} + +type TranscriptionAudio_Uri struct { + // URI for audio content. Reserved for future use. + Uri string `protobuf:"bytes,2,opt,name=uri,proto3,oneof"` +} + +func (*TranscriptionAudio_Content) isTranscriptionAudio_Source() {} + +func (*TranscriptionAudio_Uri) isTranscriptionAudio_Source() {} + +type TranscribeResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + // The transcribed text. + Text string `protobuf:"bytes,1,opt,name=text,proto3" json:"text,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *TranscribeResponse) Reset() { + *x = TranscribeResponse{} + mi := &file_api_v1_ai_service_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *TranscribeResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TranscribeResponse) ProtoMessage() {} + +func (x *TranscribeResponse) ProtoReflect() protoreflect.Message { + mi := &file_api_v1_ai_service_proto_msgTypes[3] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TranscribeResponse.ProtoReflect.Descriptor instead. 
+func (*TranscribeResponse) Descriptor() ([]byte, []int) { + return file_api_v1_ai_service_proto_rawDescGZIP(), []int{3} +} + +func (x *TranscribeResponse) GetText() string { + if x != nil { + return x.Text + } + return "" +} + +var File_api_v1_ai_service_proto protoreflect.FileDescriptor + +const file_api_v1_ai_service_proto_rawDesc = "" + + "\n" + + "\x17api/v1/ai_service.proto\x12\fmemos.api.v1\x1a\x1cgoogle/api/annotations.proto\x1a\x17google/api/client.proto\x1a\x1fgoogle/api/field_behavior.proto\"\xb6\x01\n" + + "\x11TranscribeRequest\x12$\n" + + "\vprovider_id\x18\x01 \x01(\tB\x03\xe0A\x02R\n" + + "providerId\x12>\n" + + "\x06config\x18\x02 \x01(\v2!.memos.api.v1.TranscriptionConfigB\x03\xe0A\x02R\x06config\x12;\n" + + "\x05audio\x18\x03 \x01(\v2 .memos.api.v1.TranscriptionAudioB\x03\xe0A\x02R\x05audio\"n\n" + + "\x13TranscriptionConfig\x12\x19\n" + + "\x05model\x18\x01 \x01(\tB\x03\xe0A\x01R\x05model\x12\x1b\n" + + "\x06prompt\x18\x02 \x01(\tB\x03\xe0A\x01R\x06prompt\x12\x1f\n" + + "\blanguage\x18\x03 \x01(\tB\x03\xe0A\x01R\blanguage\"\x9c\x01\n" + + "\x12TranscriptionAudio\x12\x1f\n" + + "\acontent\x18\x01 \x01(\fB\x03\xe0A\x04H\x00R\acontent\x12\x12\n" + + "\x03uri\x18\x02 \x01(\tH\x00R\x03uri\x12\x1f\n" + + "\bfilename\x18\x03 \x01(\tB\x03\xe0A\x01R\bfilename\x12&\n" + + "\fcontent_type\x18\x04 \x01(\tB\x03\xe0A\x01R\vcontentTypeB\b\n" + + "\x06source\"(\n" + + "\x12TranscribeResponse\x12\x12\n" + + "\x04text\x18\x01 \x01(\tR\x04text2\x9a\x01\n" + + "\tAIService\x12\x8c\x01\n" + + "\n" + + "Transcribe\x12\x1f.memos.api.v1.TranscribeRequest\x1a .memos.api.v1.TranscribeResponse\";\xdaA\x18provider_id,config,audio\x82\xd3\xe4\x93\x02\x1a:\x01*\"\x15/api/v1/ai:transcribeB\xa6\x01\n" + + "\x10com.memos.api.v1B\x0eAiServiceProtoP\x01Z0github.com/usememos/memos/proto/gen/api/v1;apiv1\xa2\x02\x03MAX\xaa\x02\fMemos.Api.V1\xca\x02\fMemos\\Api\\V1\xe2\x02\x18Memos\\Api\\V1\\GPBMetadata\xea\x02\x0eMemos::Api::V1b\x06proto3" + +var ( + 
file_api_v1_ai_service_proto_rawDescOnce sync.Once + file_api_v1_ai_service_proto_rawDescData []byte +) + +func file_api_v1_ai_service_proto_rawDescGZIP() []byte { + file_api_v1_ai_service_proto_rawDescOnce.Do(func() { + file_api_v1_ai_service_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_api_v1_ai_service_proto_rawDesc), len(file_api_v1_ai_service_proto_rawDesc))) + }) + return file_api_v1_ai_service_proto_rawDescData +} + +var file_api_v1_ai_service_proto_msgTypes = make([]protoimpl.MessageInfo, 4) +var file_api_v1_ai_service_proto_goTypes = []any{ + (*TranscribeRequest)(nil), // 0: memos.api.v1.TranscribeRequest + (*TranscriptionConfig)(nil), // 1: memos.api.v1.TranscriptionConfig + (*TranscriptionAudio)(nil), // 2: memos.api.v1.TranscriptionAudio + (*TranscribeResponse)(nil), // 3: memos.api.v1.TranscribeResponse +} +var file_api_v1_ai_service_proto_depIdxs = []int32{ + 1, // 0: memos.api.v1.TranscribeRequest.config:type_name -> memos.api.v1.TranscriptionConfig + 2, // 1: memos.api.v1.TranscribeRequest.audio:type_name -> memos.api.v1.TranscriptionAudio + 0, // 2: memos.api.v1.AIService.Transcribe:input_type -> memos.api.v1.TranscribeRequest + 3, // 3: memos.api.v1.AIService.Transcribe:output_type -> memos.api.v1.TranscribeResponse + 3, // [3:4] is the sub-list for method output_type + 2, // [2:3] is the sub-list for method input_type + 2, // [2:2] is the sub-list for extension type_name + 2, // [2:2] is the sub-list for extension extendee + 0, // [0:2] is the sub-list for field type_name +} + +func init() { file_api_v1_ai_service_proto_init() } +func file_api_v1_ai_service_proto_init() { + if File_api_v1_ai_service_proto != nil { + return + } + file_api_v1_ai_service_proto_msgTypes[2].OneofWrappers = []any{ + (*TranscriptionAudio_Content)(nil), + (*TranscriptionAudio_Uri)(nil), + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: 
unsafe.Slice(unsafe.StringData(file_api_v1_ai_service_proto_rawDesc), len(file_api_v1_ai_service_proto_rawDesc)), + NumEnums: 0, + NumMessages: 4, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_api_v1_ai_service_proto_goTypes, + DependencyIndexes: file_api_v1_ai_service_proto_depIdxs, + MessageInfos: file_api_v1_ai_service_proto_msgTypes, + }.Build() + File_api_v1_ai_service_proto = out.File + file_api_v1_ai_service_proto_goTypes = nil + file_api_v1_ai_service_proto_depIdxs = nil +} diff --git a/proto/gen/api/v1/ai_service.pb.gw.go b/proto/gen/api/v1/ai_service.pb.gw.go new file mode 100644 index 000000000..c38e17301 --- /dev/null +++ b/proto/gen/api/v1/ai_service.pb.gw.go @@ -0,0 +1,157 @@ +// Code generated by protoc-gen-grpc-gateway. DO NOT EDIT. +// source: api/v1/ai_service.proto + +/* +Package apiv1 is a reverse proxy. + +It translates gRPC into RESTful JSON APIs. +*/ +package apiv1 + +import ( + "context" + "errors" + "io" + "net/http" + + "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" + "github.com/grpc-ecosystem/grpc-gateway/v2/utilities" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" +) + +// Suppress "imported and not used" errors +var ( + _ codes.Code + _ io.Reader + _ status.Status + _ = errors.New + _ = runtime.String + _ = utilities.NewDoubleArray + _ = metadata.Join +) + +func request_AIService_Transcribe_0(ctx context.Context, marshaler runtime.Marshaler, client AIServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var ( + protoReq TranscribeRequest + metadata runtime.ServerMetadata + ) + if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil && !errors.Is(err, io.EOF) { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + if req.Body != nil { + _, _ = io.Copy(io.Discard, 
req.Body) + } + msg, err := client.Transcribe(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) + return msg, metadata, err +} + +func local_request_AIService_Transcribe_0(ctx context.Context, marshaler runtime.Marshaler, server AIServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var ( + protoReq TranscribeRequest + metadata runtime.ServerMetadata + ) + if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil && !errors.Is(err, io.EOF) { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + msg, err := server.Transcribe(ctx, &protoReq) + return msg, metadata, err +} + +// RegisterAIServiceHandlerServer registers the http handlers for service AIService to "mux". +// UnaryRPC :call AIServiceServer directly. +// StreamingRPC :currently unsupported pending https://github.com/grpc/grpc-go/issues/906. +// Note that using this registration option will cause many gRPC library features to stop working. Consider using RegisterAIServiceHandlerFromEndpoint instead. +// GRPC interceptors will not work for this type of registration. To use interceptors, you must use the "runtime.WithMiddlewares" option in the "runtime.NewServeMux" call. 
+func RegisterAIServiceHandlerServer(ctx context.Context, mux *runtime.ServeMux, server AIServiceServer) error { + mux.Handle(http.MethodPost, pattern_AIService_Transcribe_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + var stream runtime.ServerTransportStream + ctx = grpc.NewContextWithServerTransportStream(ctx, &stream) + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + annotatedContext, err := runtime.AnnotateIncomingContext(ctx, mux, req, "/memos.api.v1.AIService/Transcribe", runtime.WithHTTPPathPattern("/api/v1/ai:transcribe")) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := local_request_AIService_Transcribe_0(annotatedContext, inboundMarshaler, server, req, pathParams) + md.HeaderMD, md.TrailerMD = metadata.Join(md.HeaderMD, stream.Header()), metadata.Join(md.TrailerMD, stream.Trailer()) + annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) + if err != nil { + runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + return + } + forward_AIService_Transcribe_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + }) + + return nil +} + +// RegisterAIServiceHandlerFromEndpoint is same as RegisterAIServiceHandler but +// automatically dials to "endpoint" and closes the connection when "ctx" gets done. +func RegisterAIServiceHandlerFromEndpoint(ctx context.Context, mux *runtime.ServeMux, endpoint string, opts []grpc.DialOption) (err error) { + conn, err := grpc.NewClient(endpoint, opts...) 
+ if err != nil { + return err + } + defer func() { + if err != nil { + if cerr := conn.Close(); cerr != nil { + grpclog.Errorf("Failed to close conn to %s: %v", endpoint, cerr) + } + return + } + go func() { + <-ctx.Done() + if cerr := conn.Close(); cerr != nil { + grpclog.Errorf("Failed to close conn to %s: %v", endpoint, cerr) + } + }() + }() + return RegisterAIServiceHandler(ctx, mux, conn) +} + +// RegisterAIServiceHandler registers the http handlers for service AIService to "mux". +// The handlers forward requests to the grpc endpoint over "conn". +func RegisterAIServiceHandler(ctx context.Context, mux *runtime.ServeMux, conn *grpc.ClientConn) error { + return RegisterAIServiceHandlerClient(ctx, mux, NewAIServiceClient(conn)) +} + +// RegisterAIServiceHandlerClient registers the http handlers for service AIService +// to "mux". The handlers forward requests to the grpc endpoint over the given implementation of "AIServiceClient". +// Note: the gRPC framework executes interceptors within the gRPC handler. If the passed in "AIServiceClient" +// doesn't go through the normal gRPC flow (creating a gRPC client etc.) then it will be up to the passed in +// "AIServiceClient" to call the correct interceptors. This client ignores the HTTP middlewares. 
+func RegisterAIServiceHandlerClient(ctx context.Context, mux *runtime.ServeMux, client AIServiceClient) error { + mux.Handle(http.MethodPost, pattern_AIService_Transcribe_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + annotatedContext, err := runtime.AnnotateContext(ctx, mux, req, "/memos.api.v1.AIService/Transcribe", runtime.WithHTTPPathPattern("/api/v1/ai:transcribe")) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := request_AIService_Transcribe_0(annotatedContext, inboundMarshaler, client, req, pathParams) + annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) + if err != nil { + runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + return + } + forward_AIService_Transcribe_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + }) + return nil +} + +var ( + pattern_AIService_Transcribe_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"api", "v1", "ai"}, "transcribe")) +) + +var ( + forward_AIService_Transcribe_0 = runtime.ForwardResponseMessage +) diff --git a/proto/gen/api/v1/ai_service_grpc.pb.go b/proto/gen/api/v1/ai_service_grpc.pb.go new file mode 100644 index 000000000..d5f350096 --- /dev/null +++ b/proto/gen/api/v1/ai_service_grpc.pb.go @@ -0,0 +1,123 @@ +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. 
+// versions: +// - protoc-gen-go-grpc v1.6.1 +// - protoc (unknown) +// source: api/v1/ai_service.proto + +package apiv1 + +import ( + context "context" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +// Requires gRPC-Go v1.64.0 or later. +const _ = grpc.SupportPackageIsVersion9 + +const ( + AIService_Transcribe_FullMethodName = "/memos.api.v1.AIService/Transcribe" +) + +// AIServiceClient is the client API for AIService service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +type AIServiceClient interface { + // Transcribe transcribes an audio file using an instance AI provider. + Transcribe(ctx context.Context, in *TranscribeRequest, opts ...grpc.CallOption) (*TranscribeResponse, error) +} + +type aIServiceClient struct { + cc grpc.ClientConnInterface +} + +func NewAIServiceClient(cc grpc.ClientConnInterface) AIServiceClient { + return &aIServiceClient{cc} +} + +func (c *aIServiceClient) Transcribe(ctx context.Context, in *TranscribeRequest, opts ...grpc.CallOption) (*TranscribeResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(TranscribeResponse) + err := c.cc.Invoke(ctx, AIService_Transcribe_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +// AIServiceServer is the server API for AIService service. +// All implementations must embed UnimplementedAIServiceServer +// for forward compatibility. +type AIServiceServer interface { + // Transcribe transcribes an audio file using an instance AI provider. 
+ Transcribe(context.Context, *TranscribeRequest) (*TranscribeResponse, error) + mustEmbedUnimplementedAIServiceServer() +} + +// UnimplementedAIServiceServer must be embedded to have +// forward compatible implementations. +// +// NOTE: this should be embedded by value instead of pointer to avoid a nil +// pointer dereference when methods are called. +type UnimplementedAIServiceServer struct{} + +func (UnimplementedAIServiceServer) Transcribe(context.Context, *TranscribeRequest) (*TranscribeResponse, error) { + return nil, status.Error(codes.Unimplemented, "method Transcribe not implemented") +} +func (UnimplementedAIServiceServer) mustEmbedUnimplementedAIServiceServer() {} +func (UnimplementedAIServiceServer) testEmbeddedByValue() {} + +// UnsafeAIServiceServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to AIServiceServer will +// result in compilation errors. +type UnsafeAIServiceServer interface { + mustEmbedUnimplementedAIServiceServer() +} + +func RegisterAIServiceServer(s grpc.ServiceRegistrar, srv AIServiceServer) { + // If the following call panics, it indicates UnimplementedAIServiceServer was + // embedded by pointer and is nil. This will cause panics if an + // unimplemented method is ever invoked, so we test this at initialization + // time to prevent it from happening at runtime later due to I/O. 
+ if t, ok := srv.(interface{ testEmbeddedByValue() }); ok { + t.testEmbeddedByValue() + } + s.RegisterService(&AIService_ServiceDesc, srv) +} + +func _AIService_Transcribe_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(TranscribeRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(AIServiceServer).Transcribe(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: AIService_Transcribe_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(AIServiceServer).Transcribe(ctx, req.(*TranscribeRequest)) + } + return interceptor(ctx, in, info, handler) +} + +// AIService_ServiceDesc is the grpc.ServiceDesc for AIService service. +// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var AIService_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "memos.api.v1.AIService", + HandlerType: (*AIServiceServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "Transcribe", + Handler: _AIService_Transcribe_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "api/v1/ai_service.proto", +} diff --git a/proto/gen/api/v1/apiv1connect/ai_service.connect.go b/proto/gen/api/v1/apiv1connect/ai_service.connect.go new file mode 100644 index 000000000..5d992f9bd --- /dev/null +++ b/proto/gen/api/v1/apiv1connect/ai_service.connect.go @@ -0,0 +1,110 @@ +// Code generated by protoc-gen-connect-go. DO NOT EDIT. +// +// Source: api/v1/ai_service.proto + +package apiv1connect + +import ( + connect "connectrpc.com/connect" + context "context" + errors "errors" + v1 "github.com/usememos/memos/proto/gen/api/v1" + http "net/http" + strings "strings" +) + +// This is a compile-time assertion to ensure that this generated file and the connect package are +// compatible. 
If you get a compiler error that this constant is not defined, this code was +// generated with a version of connect newer than the one compiled into your binary. You can fix the +// problem by either regenerating this code with an older version of connect or updating the connect +// version compiled into your binary. +const _ = connect.IsAtLeastVersion1_13_0 + +const ( + // AIServiceName is the fully-qualified name of the AIService service. + AIServiceName = "memos.api.v1.AIService" +) + +// These constants are the fully-qualified names of the RPCs defined in this package. They're +// exposed at runtime as Spec.Procedure and as the final two segments of the HTTP route. +// +// Note that these are different from the fully-qualified method names used by +// google.golang.org/protobuf/reflect/protoreflect. To convert from these constants to +// reflection-formatted method names, remove the leading slash and convert the remaining slash to a +// period. +const ( + // AIServiceTranscribeProcedure is the fully-qualified name of the AIService's Transcribe RPC. + AIServiceTranscribeProcedure = "/memos.api.v1.AIService/Transcribe" +) + +// AIServiceClient is a client for the memos.api.v1.AIService service. +type AIServiceClient interface { + // Transcribe transcribes an audio file using an instance AI provider. + Transcribe(context.Context, *connect.Request[v1.TranscribeRequest]) (*connect.Response[v1.TranscribeResponse], error) +} + +// NewAIServiceClient constructs a client for the memos.api.v1.AIService service. By default, it +// uses the Connect protocol with the binary Protobuf Codec, asks for gzipped responses, and sends +// uncompressed requests. To use the gRPC or gRPC-Web protocols, supply the connect.WithGRPC() or +// connect.WithGRPCWeb() options. +// +// The URL supplied here should be the base URL for the Connect or gRPC server (for example, +// http://api.acme.com or https://acme.com/grpc). 
+func NewAIServiceClient(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) AIServiceClient { + baseURL = strings.TrimRight(baseURL, "/") + aIServiceMethods := v1.File_api_v1_ai_service_proto.Services().ByName("AIService").Methods() + return &aIServiceClient{ + transcribe: connect.NewClient[v1.TranscribeRequest, v1.TranscribeResponse]( + httpClient, + baseURL+AIServiceTranscribeProcedure, + connect.WithSchema(aIServiceMethods.ByName("Transcribe")), + connect.WithClientOptions(opts...), + ), + } +} + +// aIServiceClient implements AIServiceClient. +type aIServiceClient struct { + transcribe *connect.Client[v1.TranscribeRequest, v1.TranscribeResponse] +} + +// Transcribe calls memos.api.v1.AIService.Transcribe. +func (c *aIServiceClient) Transcribe(ctx context.Context, req *connect.Request[v1.TranscribeRequest]) (*connect.Response[v1.TranscribeResponse], error) { + return c.transcribe.CallUnary(ctx, req) +} + +// AIServiceHandler is an implementation of the memos.api.v1.AIService service. +type AIServiceHandler interface { + // Transcribe transcribes an audio file using an instance AI provider. + Transcribe(context.Context, *connect.Request[v1.TranscribeRequest]) (*connect.Response[v1.TranscribeResponse], error) +} + +// NewAIServiceHandler builds an HTTP handler from the service implementation. It returns the path +// on which to mount the handler and the handler itself. +// +// By default, handlers support the Connect, gRPC, and gRPC-Web protocols with the binary Protobuf +// and JSON codecs. They also support gzip compression. 
+func NewAIServiceHandler(svc AIServiceHandler, opts ...connect.HandlerOption) (string, http.Handler) { + aIServiceMethods := v1.File_api_v1_ai_service_proto.Services().ByName("AIService").Methods() + aIServiceTranscribeHandler := connect.NewUnaryHandler( + AIServiceTranscribeProcedure, + svc.Transcribe, + connect.WithSchema(aIServiceMethods.ByName("Transcribe")), + connect.WithHandlerOptions(opts...), + ) + return "/memos.api.v1.AIService/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case AIServiceTranscribeProcedure: + aIServiceTranscribeHandler.ServeHTTP(w, r) + default: + http.NotFound(w, r) + } + }) +} + +// UnimplementedAIServiceHandler returns CodeUnimplemented from all methods. +type UnimplementedAIServiceHandler struct{} + +func (UnimplementedAIServiceHandler) Transcribe(context.Context, *connect.Request[v1.TranscribeRequest]) (*connect.Response[v1.TranscribeResponse], error) { + return nil, connect.NewError(connect.CodeUnimplemented, errors.New("memos.api.v1.AIService.Transcribe is not implemented")) +} diff --git a/proto/gen/api/v1/instance_service.pb.go b/proto/gen/api/v1/instance_service.pb.go index be77651a0..e85fc8ec6 100644 --- a/proto/gen/api/v1/instance_service.pb.go +++ b/proto/gen/api/v1/instance_service.pb.go @@ -39,6 +39,8 @@ const ( InstanceSetting_TAGS InstanceSetting_Key = 4 // NOTIFICATION is the key for notification transport settings. InstanceSetting_NOTIFICATION InstanceSetting_Key = 5 + // AI is the key for AI provider settings. + InstanceSetting_AI InstanceSetting_Key = 6 ) // Enum value maps for InstanceSetting_Key. 
@@ -50,6 +52,7 @@ var ( 3: "MEMO_RELATED", 4: "TAGS", 5: "NOTIFICATION", + 6: "AI", } InstanceSetting_Key_value = map[string]int32{ "KEY_UNSPECIFIED": 0, @@ -58,6 +61,7 @@ var ( "MEMO_RELATED": 3, "TAGS": 4, "NOTIFICATION": 5, + "AI": 6, } ) @@ -88,6 +92,62 @@ func (InstanceSetting_Key) EnumDescriptor() ([]byte, []int) { return file_api_v1_instance_service_proto_rawDescGZIP(), []int{2, 0} } +// AIProviderType is the provider implementation type. +type InstanceSetting_AIProviderType int32 + +const ( + InstanceSetting_AI_PROVIDER_TYPE_UNSPECIFIED InstanceSetting_AIProviderType = 0 + InstanceSetting_OPENAI InstanceSetting_AIProviderType = 1 + InstanceSetting_OPENAI_COMPATIBLE InstanceSetting_AIProviderType = 2 + InstanceSetting_ANTHROPIC InstanceSetting_AIProviderType = 3 + InstanceSetting_GEMINI InstanceSetting_AIProviderType = 4 +) + +// Enum value maps for InstanceSetting_AIProviderType. +var ( + InstanceSetting_AIProviderType_name = map[int32]string{ + 0: "AI_PROVIDER_TYPE_UNSPECIFIED", + 1: "OPENAI", + 2: "OPENAI_COMPATIBLE", + 3: "ANTHROPIC", + 4: "GEMINI", + } + InstanceSetting_AIProviderType_value = map[string]int32{ + "AI_PROVIDER_TYPE_UNSPECIFIED": 0, + "OPENAI": 1, + "OPENAI_COMPATIBLE": 2, + "ANTHROPIC": 3, + "GEMINI": 4, + } +) + +func (x InstanceSetting_AIProviderType) Enum() *InstanceSetting_AIProviderType { + p := new(InstanceSetting_AIProviderType) + *p = x + return p +} + +func (x InstanceSetting_AIProviderType) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (InstanceSetting_AIProviderType) Descriptor() protoreflect.EnumDescriptor { + return file_api_v1_instance_service_proto_enumTypes[1].Descriptor() +} + +func (InstanceSetting_AIProviderType) Type() protoreflect.EnumType { + return &file_api_v1_instance_service_proto_enumTypes[1] +} + +func (x InstanceSetting_AIProviderType) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use 
InstanceSetting_AIProviderType.Descriptor instead. +func (InstanceSetting_AIProviderType) EnumDescriptor() ([]byte, []int) { + return file_api_v1_instance_service_proto_rawDescGZIP(), []int{2, 1} +} + // Storage type enumeration for different storage backends. type InstanceSetting_StorageSetting_StorageType int32 @@ -128,11 +188,11 @@ func (x InstanceSetting_StorageSetting_StorageType) String() string { } func (InstanceSetting_StorageSetting_StorageType) Descriptor() protoreflect.EnumDescriptor { - return file_api_v1_instance_service_proto_enumTypes[1].Descriptor() + return file_api_v1_instance_service_proto_enumTypes[2].Descriptor() } func (InstanceSetting_StorageSetting_StorageType) Type() protoreflect.EnumType { - return &file_api_v1_instance_service_proto_enumTypes[1] + return &file_api_v1_instance_service_proto_enumTypes[2] } func (x InstanceSetting_StorageSetting_StorageType) Number() protoreflect.EnumNumber { @@ -268,6 +328,7 @@ type InstanceSetting struct { // *InstanceSetting_MemoRelatedSetting_ // *InstanceSetting_TagsSetting_ // *InstanceSetting_NotificationSetting_ + // *InstanceSetting_AiSetting Value isInstanceSetting_Value `protobuf_oneof:"value"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache @@ -362,6 +423,15 @@ func (x *InstanceSetting) GetNotificationSetting() *InstanceSetting_Notification return nil } +func (x *InstanceSetting) GetAiSetting() *InstanceSetting_AISetting { + if x != nil { + if x, ok := x.Value.(*InstanceSetting_AiSetting); ok { + return x.AiSetting + } + } + return nil +} + type isInstanceSetting_Value interface { isInstanceSetting_Value() } @@ -386,6 +456,10 @@ type InstanceSetting_NotificationSetting_ struct { NotificationSetting *InstanceSetting_NotificationSetting `protobuf:"bytes,6,opt,name=notification_setting,json=notificationSetting,proto3,oneof"` } +type InstanceSetting_AiSetting struct { + AiSetting *InstanceSetting_AISetting `protobuf:"bytes,7,opt,name=ai_setting,json=aiSetting,proto3,oneof"` +} + 
func (*InstanceSetting_GeneralSetting_) isInstanceSetting_Value() {} func (*InstanceSetting_StorageSetting_) isInstanceSetting_Value() {} @@ -396,6 +470,8 @@ func (*InstanceSetting_TagsSetting_) isInstanceSetting_Value() {} func (*InstanceSetting_NotificationSetting_) isInstanceSetting_Value() {} +func (*InstanceSetting_AiSetting) isInstanceSetting_Value() {} + // Request message for GetInstanceSetting method. type GetInstanceSettingRequest struct { state protoimpl.MessageState `protogen:"open.v1"` @@ -906,6 +982,164 @@ func (x *InstanceSetting_NotificationSetting) GetEmail() *InstanceSetting_Notifi return nil } +// AI provider configuration settings. +type InstanceSetting_AISetting struct { + state protoimpl.MessageState `protogen:"open.v1"` + // providers is the list of AI provider configurations available instance-wide. + Providers []*InstanceSetting_AIProviderConfig `protobuf:"bytes,1,rep,name=providers,proto3" json:"providers,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *InstanceSetting_AISetting) Reset() { + *x = InstanceSetting_AISetting{} + mi := &file_api_v1_instance_service_proto_msgTypes[11] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *InstanceSetting_AISetting) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*InstanceSetting_AISetting) ProtoMessage() {} + +func (x *InstanceSetting_AISetting) ProtoReflect() protoreflect.Message { + mi := &file_api_v1_instance_service_proto_msgTypes[11] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use InstanceSetting_AISetting.ProtoReflect.Descriptor instead. 
+func (*InstanceSetting_AISetting) Descriptor() ([]byte, []int) { + return file_api_v1_instance_service_proto_rawDescGZIP(), []int{2, 6} +} + +func (x *InstanceSetting_AISetting) GetProviders() []*InstanceSetting_AIProviderConfig { + if x != nil { + return x.Providers + } + return nil +} + +// AIProviderConfig represents one callable AI provider connection. +type InstanceSetting_AIProviderConfig struct { + state protoimpl.MessageState `protogen:"open.v1"` + Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"` + Title string `protobuf:"bytes,2,opt,name=title,proto3" json:"title,omitempty"` + Type InstanceSetting_AIProviderType `protobuf:"varint,3,opt,name=type,proto3,enum=memos.api.v1.InstanceSetting_AIProviderType" json:"type,omitempty"` + Endpoint string `protobuf:"bytes,4,opt,name=endpoint,proto3" json:"endpoint,omitempty"` + // api_key is write-only and is never returned by GetInstanceSetting. + ApiKey string `protobuf:"bytes,5,opt,name=api_key,json=apiKey,proto3" json:"api_key,omitempty"` + Models []string `protobuf:"bytes,6,rep,name=models,proto3" json:"models,omitempty"` + DefaultModel string `protobuf:"bytes,7,opt,name=default_model,json=defaultModel,proto3" json:"default_model,omitempty"` + // api_key_set indicates whether an API key is stored for this provider. + ApiKeySet bool `protobuf:"varint,8,opt,name=api_key_set,json=apiKeySet,proto3" json:"api_key_set,omitempty"` + // api_key_hint is a masked hint for the stored API key. 
+ ApiKeyHint string `protobuf:"bytes,9,opt,name=api_key_hint,json=apiKeyHint,proto3" json:"api_key_hint,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *InstanceSetting_AIProviderConfig) Reset() { + *x = InstanceSetting_AIProviderConfig{} + mi := &file_api_v1_instance_service_proto_msgTypes[12] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *InstanceSetting_AIProviderConfig) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*InstanceSetting_AIProviderConfig) ProtoMessage() {} + +func (x *InstanceSetting_AIProviderConfig) ProtoReflect() protoreflect.Message { + mi := &file_api_v1_instance_service_proto_msgTypes[12] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use InstanceSetting_AIProviderConfig.ProtoReflect.Descriptor instead. 
+func (*InstanceSetting_AIProviderConfig) Descriptor() ([]byte, []int) { + return file_api_v1_instance_service_proto_rawDescGZIP(), []int{2, 7} +} + +func (x *InstanceSetting_AIProviderConfig) GetId() string { + if x != nil { + return x.Id + } + return "" +} + +func (x *InstanceSetting_AIProviderConfig) GetTitle() string { + if x != nil { + return x.Title + } + return "" +} + +func (x *InstanceSetting_AIProviderConfig) GetType() InstanceSetting_AIProviderType { + if x != nil { + return x.Type + } + return InstanceSetting_AI_PROVIDER_TYPE_UNSPECIFIED +} + +func (x *InstanceSetting_AIProviderConfig) GetEndpoint() string { + if x != nil { + return x.Endpoint + } + return "" +} + +func (x *InstanceSetting_AIProviderConfig) GetApiKey() string { + if x != nil { + return x.ApiKey + } + return "" +} + +func (x *InstanceSetting_AIProviderConfig) GetModels() []string { + if x != nil { + return x.Models + } + return nil +} + +func (x *InstanceSetting_AIProviderConfig) GetDefaultModel() string { + if x != nil { + return x.DefaultModel + } + return "" +} + +func (x *InstanceSetting_AIProviderConfig) GetApiKeySet() bool { + if x != nil { + return x.ApiKeySet + } + return false +} + +func (x *InstanceSetting_AIProviderConfig) GetApiKeyHint() string { + if x != nil { + return x.ApiKeyHint + } + return "" +} + // Custom profile configuration for instance branding. 
type InstanceSetting_GeneralSetting_CustomProfile struct { state protoimpl.MessageState `protogen:"open.v1"` @@ -918,7 +1152,7 @@ type InstanceSetting_GeneralSetting_CustomProfile struct { func (x *InstanceSetting_GeneralSetting_CustomProfile) Reset() { *x = InstanceSetting_GeneralSetting_CustomProfile{} - mi := &file_api_v1_instance_service_proto_msgTypes[11] + mi := &file_api_v1_instance_service_proto_msgTypes[13] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -930,7 +1164,7 @@ func (x *InstanceSetting_GeneralSetting_CustomProfile) String() string { func (*InstanceSetting_GeneralSetting_CustomProfile) ProtoMessage() {} func (x *InstanceSetting_GeneralSetting_CustomProfile) ProtoReflect() protoreflect.Message { - mi := &file_api_v1_instance_service_proto_msgTypes[11] + mi := &file_api_v1_instance_service_proto_msgTypes[13] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -983,7 +1217,7 @@ type InstanceSetting_StorageSetting_S3Config struct { func (x *InstanceSetting_StorageSetting_S3Config) Reset() { *x = InstanceSetting_StorageSetting_S3Config{} - mi := &file_api_v1_instance_service_proto_msgTypes[12] + mi := &file_api_v1_instance_service_proto_msgTypes[14] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -995,7 +1229,7 @@ func (x *InstanceSetting_StorageSetting_S3Config) String() string { func (*InstanceSetting_StorageSetting_S3Config) ProtoMessage() {} func (x *InstanceSetting_StorageSetting_S3Config) ProtoReflect() protoreflect.Message { - mi := &file_api_v1_instance_service_proto_msgTypes[12] + mi := &file_api_v1_instance_service_proto_msgTypes[14] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1072,7 +1306,7 @@ type InstanceSetting_NotificationSetting_EmailSetting struct { func (x *InstanceSetting_NotificationSetting_EmailSetting) Reset() { *x = 
InstanceSetting_NotificationSetting_EmailSetting{} - mi := &file_api_v1_instance_service_proto_msgTypes[14] + mi := &file_api_v1_instance_service_proto_msgTypes[16] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1084,7 +1318,7 @@ func (x *InstanceSetting_NotificationSetting_EmailSetting) String() string { func (*InstanceSetting_NotificationSetting_EmailSetting) ProtoMessage() {} func (x *InstanceSetting_NotificationSetting_EmailSetting) ProtoReflect() protoreflect.Message { - mi := &file_api_v1_instance_service_proto_msgTypes[14] + mi := &file_api_v1_instance_service_proto_msgTypes[16] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1180,14 +1414,16 @@ const file_api_v1_instance_service_proto_rawDesc = "" + "\x04demo\x18\x03 \x01(\bR\x04demo\x12!\n" + "\finstance_url\x18\x06 \x01(\tR\vinstanceUrl\x12(\n" + "\x05admin\x18\a \x01(\v2\x12.memos.api.v1.UserR\x05admin\"\x1b\n" + - "\x19GetInstanceProfileRequest\"\x83\x16\n" + + "\x19GetInstanceProfileRequest\"\xe2\x1a\n" + "\x0fInstanceSetting\x12\x17\n" + "\x04name\x18\x01 \x01(\tB\x03\xe0A\bR\x04name\x12W\n" + "\x0fgeneral_setting\x18\x02 \x01(\v2,.memos.api.v1.InstanceSetting.GeneralSettingH\x00R\x0egeneralSetting\x12W\n" + "\x0fstorage_setting\x18\x03 \x01(\v2,.memos.api.v1.InstanceSetting.StorageSettingH\x00R\x0estorageSetting\x12d\n" + "\x14memo_related_setting\x18\x04 \x01(\v20.memos.api.v1.InstanceSetting.MemoRelatedSettingH\x00R\x12memoRelatedSetting\x12N\n" + "\ftags_setting\x18\x05 \x01(\v2).memos.api.v1.InstanceSetting.TagsSettingH\x00R\vtagsSetting\x12f\n" + - "\x14notification_setting\x18\x06 \x01(\v21.memos.api.v1.InstanceSetting.NotificationSettingH\x00R\x13notificationSetting\x1a\xca\x04\n" + + "\x14notification_setting\x18\x06 \x01(\v21.memos.api.v1.InstanceSetting.NotificationSettingH\x00R\x13notificationSetting\x12H\n" + + "\n" + + "ai_setting\x18\a 
\x01(\v2'.memos.api.v1.InstanceSetting.AISettingH\x00R\taiSetting\x1a\xca\x04\n" + "\x0eGeneralSetting\x12<\n" + "\x1adisallow_user_registration\x18\x02 \x01(\bR\x18disallowUserRegistration\x124\n" + "\x16disallow_password_auth\x18\x03 \x01(\bR\x14disallowPasswordAuth\x12+\n" + @@ -1245,14 +1481,36 @@ const file_api_v1_instance_service_proto_rawDesc = "" + "\breply_to\x18\b \x01(\tR\areplyTo\x12\x17\n" + "\ause_tls\x18\t \x01(\bR\x06useTls\x12\x17\n" + "\ause_ssl\x18\n" + - " \x01(\bR\x06useSsl\"b\n" + + " \x01(\bR\x06useSsl\x1aY\n" + + "\tAISetting\x12L\n" + + "\tproviders\x18\x01 \x03(\v2..memos.api.v1.InstanceSetting.AIProviderConfigR\tproviders\x1a\xbd\x02\n" + + "\x10AIProviderConfig\x12\x0e\n" + + "\x02id\x18\x01 \x01(\tR\x02id\x12\x14\n" + + "\x05title\x18\x02 \x01(\tR\x05title\x12@\n" + + "\x04type\x18\x03 \x01(\x0e2,.memos.api.v1.InstanceSetting.AIProviderTypeR\x04type\x12\x1a\n" + + "\bendpoint\x18\x04 \x01(\tR\bendpoint\x12\x1c\n" + + "\aapi_key\x18\x05 \x01(\tB\x03\xe0A\x04R\x06apiKey\x12\x16\n" + + "\x06models\x18\x06 \x03(\tR\x06models\x12#\n" + + "\rdefault_model\x18\a \x01(\tR\fdefaultModel\x12#\n" + + "\vapi_key_set\x18\b \x01(\bB\x03\xe0A\x03R\tapiKeySet\x12%\n" + + "\fapi_key_hint\x18\t \x01(\tB\x03\xe0A\x03R\n" + + "apiKeyHint\"j\n" + "\x03Key\x12\x13\n" + "\x0fKEY_UNSPECIFIED\x10\x00\x12\v\n" + "\aGENERAL\x10\x01\x12\v\n" + "\aSTORAGE\x10\x02\x12\x10\n" + "\fMEMO_RELATED\x10\x03\x12\b\n" + "\x04TAGS\x10\x04\x12\x10\n" + - "\fNOTIFICATION\x10\x05:a\xeaA^\n" + + "\fNOTIFICATION\x10\x05\x12\x06\n" + + "\x02AI\x10\x06\"p\n" + + "\x0eAIProviderType\x12 \n" + + "\x1cAI_PROVIDER_TYPE_UNSPECIFIED\x10\x00\x12\n" + + "\n" + + "\x06OPENAI\x10\x01\x12\x15\n" + + "\x11OPENAI_COMPATIBLE\x10\x02\x12\r\n" + + "\tANTHROPIC\x10\x03\x12\n" + + "\n" + + "\x06GEMINI\x10\x04:a\xeaA^\n" + "\x1cmemos.api.v1/InstanceSetting\x12\x1binstance/settings/{setting}*\x10instanceSettings2\x0finstanceSettingB\a\n" + "\x05value\"U\n" + "\x19GetInstanceSettingRequest\x128\n" + @@ 
-1280,57 +1538,63 @@ func file_api_v1_instance_service_proto_rawDescGZIP() []byte { return file_api_v1_instance_service_proto_rawDescData } -var file_api_v1_instance_service_proto_enumTypes = make([]protoimpl.EnumInfo, 2) -var file_api_v1_instance_service_proto_msgTypes = make([]protoimpl.MessageInfo, 15) +var file_api_v1_instance_service_proto_enumTypes = make([]protoimpl.EnumInfo, 3) +var file_api_v1_instance_service_proto_msgTypes = make([]protoimpl.MessageInfo, 17) var file_api_v1_instance_service_proto_goTypes = []any{ (InstanceSetting_Key)(0), // 0: memos.api.v1.InstanceSetting.Key - (InstanceSetting_StorageSetting_StorageType)(0), // 1: memos.api.v1.InstanceSetting.StorageSetting.StorageType - (*InstanceProfile)(nil), // 2: memos.api.v1.InstanceProfile - (*GetInstanceProfileRequest)(nil), // 3: memos.api.v1.GetInstanceProfileRequest - (*InstanceSetting)(nil), // 4: memos.api.v1.InstanceSetting - (*GetInstanceSettingRequest)(nil), // 5: memos.api.v1.GetInstanceSettingRequest - (*UpdateInstanceSettingRequest)(nil), // 6: memos.api.v1.UpdateInstanceSettingRequest - (*InstanceSetting_GeneralSetting)(nil), // 7: memos.api.v1.InstanceSetting.GeneralSetting - (*InstanceSetting_StorageSetting)(nil), // 8: memos.api.v1.InstanceSetting.StorageSetting - (*InstanceSetting_MemoRelatedSetting)(nil), // 9: memos.api.v1.InstanceSetting.MemoRelatedSetting - (*InstanceSetting_TagMetadata)(nil), // 10: memos.api.v1.InstanceSetting.TagMetadata - (*InstanceSetting_TagsSetting)(nil), // 11: memos.api.v1.InstanceSetting.TagsSetting - (*InstanceSetting_NotificationSetting)(nil), // 12: memos.api.v1.InstanceSetting.NotificationSetting - (*InstanceSetting_GeneralSetting_CustomProfile)(nil), // 13: memos.api.v1.InstanceSetting.GeneralSetting.CustomProfile - (*InstanceSetting_StorageSetting_S3Config)(nil), // 14: memos.api.v1.InstanceSetting.StorageSetting.S3Config - nil, // 15: memos.api.v1.InstanceSetting.TagsSetting.TagsEntry - 
(*InstanceSetting_NotificationSetting_EmailSetting)(nil), // 16: memos.api.v1.InstanceSetting.NotificationSetting.EmailSetting - (*User)(nil), // 17: memos.api.v1.User - (*fieldmaskpb.FieldMask)(nil), // 18: google.protobuf.FieldMask - (*color.Color)(nil), // 19: google.type.Color + (InstanceSetting_AIProviderType)(0), // 1: memos.api.v1.InstanceSetting.AIProviderType + (InstanceSetting_StorageSetting_StorageType)(0), // 2: memos.api.v1.InstanceSetting.StorageSetting.StorageType + (*InstanceProfile)(nil), // 3: memos.api.v1.InstanceProfile + (*GetInstanceProfileRequest)(nil), // 4: memos.api.v1.GetInstanceProfileRequest + (*InstanceSetting)(nil), // 5: memos.api.v1.InstanceSetting + (*GetInstanceSettingRequest)(nil), // 6: memos.api.v1.GetInstanceSettingRequest + (*UpdateInstanceSettingRequest)(nil), // 7: memos.api.v1.UpdateInstanceSettingRequest + (*InstanceSetting_GeneralSetting)(nil), // 8: memos.api.v1.InstanceSetting.GeneralSetting + (*InstanceSetting_StorageSetting)(nil), // 9: memos.api.v1.InstanceSetting.StorageSetting + (*InstanceSetting_MemoRelatedSetting)(nil), // 10: memos.api.v1.InstanceSetting.MemoRelatedSetting + (*InstanceSetting_TagMetadata)(nil), // 11: memos.api.v1.InstanceSetting.TagMetadata + (*InstanceSetting_TagsSetting)(nil), // 12: memos.api.v1.InstanceSetting.TagsSetting + (*InstanceSetting_NotificationSetting)(nil), // 13: memos.api.v1.InstanceSetting.NotificationSetting + (*InstanceSetting_AISetting)(nil), // 14: memos.api.v1.InstanceSetting.AISetting + (*InstanceSetting_AIProviderConfig)(nil), // 15: memos.api.v1.InstanceSetting.AIProviderConfig + (*InstanceSetting_GeneralSetting_CustomProfile)(nil), // 16: memos.api.v1.InstanceSetting.GeneralSetting.CustomProfile + (*InstanceSetting_StorageSetting_S3Config)(nil), // 17: memos.api.v1.InstanceSetting.StorageSetting.S3Config + nil, // 18: memos.api.v1.InstanceSetting.TagsSetting.TagsEntry + (*InstanceSetting_NotificationSetting_EmailSetting)(nil), // 19: 
memos.api.v1.InstanceSetting.NotificationSetting.EmailSetting + (*User)(nil), // 20: memos.api.v1.User + (*fieldmaskpb.FieldMask)(nil), // 21: google.protobuf.FieldMask + (*color.Color)(nil), // 22: google.type.Color } var file_api_v1_instance_service_proto_depIdxs = []int32{ - 17, // 0: memos.api.v1.InstanceProfile.admin:type_name -> memos.api.v1.User - 7, // 1: memos.api.v1.InstanceSetting.general_setting:type_name -> memos.api.v1.InstanceSetting.GeneralSetting - 8, // 2: memos.api.v1.InstanceSetting.storage_setting:type_name -> memos.api.v1.InstanceSetting.StorageSetting - 9, // 3: memos.api.v1.InstanceSetting.memo_related_setting:type_name -> memos.api.v1.InstanceSetting.MemoRelatedSetting - 11, // 4: memos.api.v1.InstanceSetting.tags_setting:type_name -> memos.api.v1.InstanceSetting.TagsSetting - 12, // 5: memos.api.v1.InstanceSetting.notification_setting:type_name -> memos.api.v1.InstanceSetting.NotificationSetting - 4, // 6: memos.api.v1.UpdateInstanceSettingRequest.setting:type_name -> memos.api.v1.InstanceSetting - 18, // 7: memos.api.v1.UpdateInstanceSettingRequest.update_mask:type_name -> google.protobuf.FieldMask - 13, // 8: memos.api.v1.InstanceSetting.GeneralSetting.custom_profile:type_name -> memos.api.v1.InstanceSetting.GeneralSetting.CustomProfile - 1, // 9: memos.api.v1.InstanceSetting.StorageSetting.storage_type:type_name -> memos.api.v1.InstanceSetting.StorageSetting.StorageType - 14, // 10: memos.api.v1.InstanceSetting.StorageSetting.s3_config:type_name -> memos.api.v1.InstanceSetting.StorageSetting.S3Config - 19, // 11: memos.api.v1.InstanceSetting.TagMetadata.background_color:type_name -> google.type.Color - 15, // 12: memos.api.v1.InstanceSetting.TagsSetting.tags:type_name -> memos.api.v1.InstanceSetting.TagsSetting.TagsEntry - 16, // 13: memos.api.v1.InstanceSetting.NotificationSetting.email:type_name -> memos.api.v1.InstanceSetting.NotificationSetting.EmailSetting - 10, // 14: 
memos.api.v1.InstanceSetting.TagsSetting.TagsEntry.value:type_name -> memos.api.v1.InstanceSetting.TagMetadata - 3, // 15: memos.api.v1.InstanceService.GetInstanceProfile:input_type -> memos.api.v1.GetInstanceProfileRequest - 5, // 16: memos.api.v1.InstanceService.GetInstanceSetting:input_type -> memos.api.v1.GetInstanceSettingRequest - 6, // 17: memos.api.v1.InstanceService.UpdateInstanceSetting:input_type -> memos.api.v1.UpdateInstanceSettingRequest - 2, // 18: memos.api.v1.InstanceService.GetInstanceProfile:output_type -> memos.api.v1.InstanceProfile - 4, // 19: memos.api.v1.InstanceService.GetInstanceSetting:output_type -> memos.api.v1.InstanceSetting - 4, // 20: memos.api.v1.InstanceService.UpdateInstanceSetting:output_type -> memos.api.v1.InstanceSetting - 18, // [18:21] is the sub-list for method output_type - 15, // [15:18] is the sub-list for method input_type - 15, // [15:15] is the sub-list for extension type_name - 15, // [15:15] is the sub-list for extension extendee - 0, // [0:15] is the sub-list for field type_name + 20, // 0: memos.api.v1.InstanceProfile.admin:type_name -> memos.api.v1.User + 8, // 1: memos.api.v1.InstanceSetting.general_setting:type_name -> memos.api.v1.InstanceSetting.GeneralSetting + 9, // 2: memos.api.v1.InstanceSetting.storage_setting:type_name -> memos.api.v1.InstanceSetting.StorageSetting + 10, // 3: memos.api.v1.InstanceSetting.memo_related_setting:type_name -> memos.api.v1.InstanceSetting.MemoRelatedSetting + 12, // 4: memos.api.v1.InstanceSetting.tags_setting:type_name -> memos.api.v1.InstanceSetting.TagsSetting + 13, // 5: memos.api.v1.InstanceSetting.notification_setting:type_name -> memos.api.v1.InstanceSetting.NotificationSetting + 14, // 6: memos.api.v1.InstanceSetting.ai_setting:type_name -> memos.api.v1.InstanceSetting.AISetting + 5, // 7: memos.api.v1.UpdateInstanceSettingRequest.setting:type_name -> memos.api.v1.InstanceSetting + 21, // 8: memos.api.v1.UpdateInstanceSettingRequest.update_mask:type_name -> 
google.protobuf.FieldMask + 16, // 9: memos.api.v1.InstanceSetting.GeneralSetting.custom_profile:type_name -> memos.api.v1.InstanceSetting.GeneralSetting.CustomProfile + 2, // 10: memos.api.v1.InstanceSetting.StorageSetting.storage_type:type_name -> memos.api.v1.InstanceSetting.StorageSetting.StorageType + 17, // 11: memos.api.v1.InstanceSetting.StorageSetting.s3_config:type_name -> memos.api.v1.InstanceSetting.StorageSetting.S3Config + 22, // 12: memos.api.v1.InstanceSetting.TagMetadata.background_color:type_name -> google.type.Color + 18, // 13: memos.api.v1.InstanceSetting.TagsSetting.tags:type_name -> memos.api.v1.InstanceSetting.TagsSetting.TagsEntry + 19, // 14: memos.api.v1.InstanceSetting.NotificationSetting.email:type_name -> memos.api.v1.InstanceSetting.NotificationSetting.EmailSetting + 15, // 15: memos.api.v1.InstanceSetting.AISetting.providers:type_name -> memos.api.v1.InstanceSetting.AIProviderConfig + 1, // 16: memos.api.v1.InstanceSetting.AIProviderConfig.type:type_name -> memos.api.v1.InstanceSetting.AIProviderType + 11, // 17: memos.api.v1.InstanceSetting.TagsSetting.TagsEntry.value:type_name -> memos.api.v1.InstanceSetting.TagMetadata + 4, // 18: memos.api.v1.InstanceService.GetInstanceProfile:input_type -> memos.api.v1.GetInstanceProfileRequest + 6, // 19: memos.api.v1.InstanceService.GetInstanceSetting:input_type -> memos.api.v1.GetInstanceSettingRequest + 7, // 20: memos.api.v1.InstanceService.UpdateInstanceSetting:input_type -> memos.api.v1.UpdateInstanceSettingRequest + 3, // 21: memos.api.v1.InstanceService.GetInstanceProfile:output_type -> memos.api.v1.InstanceProfile + 5, // 22: memos.api.v1.InstanceService.GetInstanceSetting:output_type -> memos.api.v1.InstanceSetting + 5, // 23: memos.api.v1.InstanceService.UpdateInstanceSetting:output_type -> memos.api.v1.InstanceSetting + 21, // [21:24] is the sub-list for method output_type + 18, // [18:21] is the sub-list for method input_type + 18, // [18:18] is the sub-list for extension type_name 
+ 18, // [18:18] is the sub-list for extension extendee + 0, // [0:18] is the sub-list for field type_name } func init() { file_api_v1_instance_service_proto_init() } @@ -1345,14 +1609,15 @@ func file_api_v1_instance_service_proto_init() { (*InstanceSetting_MemoRelatedSetting_)(nil), (*InstanceSetting_TagsSetting_)(nil), (*InstanceSetting_NotificationSetting_)(nil), + (*InstanceSetting_AiSetting)(nil), } type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_api_v1_instance_service_proto_rawDesc), len(file_api_v1_instance_service_proto_rawDesc)), - NumEnums: 2, - NumMessages: 15, + NumEnums: 3, + NumMessages: 17, NumExtensions: 0, NumServices: 1, }, diff --git a/proto/gen/openapi.yaml b/proto/gen/openapi.yaml index d5373dd7d..3d8d32e83 100644 --- a/proto/gen/openapi.yaml +++ b/proto/gen/openapi.yaml @@ -6,6 +6,31 @@ info: title: "" version: 0.0.1 paths: + /api/v1/ai:transcribe: + post: + tags: + - AIService + description: Transcribe transcribes an audio file using an instance AI provider. + operationId: AIService_Transcribe + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/TranscribeRequest' + required: true + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: '#/components/schemas/TranscribeResponse' + default: + description: Default error response + content: + application/json: + schema: + $ref: '#/components/schemas/Status' /api/v1/attachments: get: tags: @@ -2380,7 +2405,55 @@ components: $ref: '#/components/schemas/InstanceSetting_TagsSetting' notificationSetting: $ref: '#/components/schemas/InstanceSetting_NotificationSetting' + aiSetting: + $ref: '#/components/schemas/InstanceSetting_AISetting' description: An instance setting resource. 
+ InstanceSetting_AIProviderConfig: + type: object + properties: + id: + type: string + title: + type: string + type: + enum: + - AI_PROVIDER_TYPE_UNSPECIFIED + - OPENAI + - OPENAI_COMPATIBLE + - ANTHROPIC + - GEMINI + type: string + format: enum + endpoint: + type: string + apiKey: + writeOnly: true + type: string + description: api_key is write-only and is never returned by GetInstanceSetting. + models: + type: array + items: + type: string + defaultModel: + type: string + apiKeySet: + readOnly: true + type: boolean + description: api_key_set indicates whether an API key is stored for this provider. + apiKeyHint: + readOnly: true + type: string + description: api_key_hint is a masked hint for the stored API key. + description: AIProviderConfig represents one callable AI provider connection. + InstanceSetting_AISetting: + type: object + properties: + providers: + type: array + items: + $ref: '#/components/schemas/InstanceSetting_AIProviderConfig' + description: providers is the list of AI provider configurations available instance-wide. + description: AI provider configuration settings. InstanceSetting_GeneralSetting: type: object properties: @@ -3144,6 +3217,59 @@ components: description: |- S3 configuration for cloud storage backend. Reference: https://developers.cloudflare.com/r2/examples/aws/aws-sdk-go/ + TranscribeRequest: + required: + - providerId + - config + - audio + type: object + properties: + providerId: + type: string + description: Required. The instance AI provider ID to use. + config: + allOf: + - $ref: '#/components/schemas/TranscriptionConfig' + description: Required. Transcription options. + audio: + allOf: + - $ref: '#/components/schemas/TranscriptionAudio' + description: Required. Audio input. + TranscribeResponse: + type: object + properties: + text: + type: string + description: The transcribed text. + TranscriptionAudio: + type: object + properties: + content: + writeOnly: true + type: string + description: Inline audio bytes. 
+ format: bytes + uri: + type: string + description: URI for audio content. Reserved for future use. + filename: + type: string + description: Optional. The uploaded filename. + contentType: + type: string + description: Optional. The MIME type of the input audio. + TranscriptionConfig: + type: object + properties: + model: + type: string + description: Optional. The model to use. If empty, the provider's default model is used. + prompt: + type: string + description: Optional. A prompt to improve transcription quality. + language: + type: string + description: Optional. The language of the input audio. UpsertMemoReactionRequest: required: - name @@ -3419,6 +3545,7 @@ components: format: date-time description: UserWebhook represents a webhook owned by a user. tags: + - name: AIService - name: AttachmentService - name: AuthService - name: IdentityProviderService diff --git a/proto/gen/store/instance_setting.pb.go b/proto/gen/store/instance_setting.pb.go index b12889b40..195d4f5f3 100644 --- a/proto/gen/store/instance_setting.pb.go +++ b/proto/gen/store/instance_setting.pb.go @@ -38,6 +38,8 @@ const ( InstanceSettingKey_TAGS InstanceSettingKey = 5 // NOTIFICATION is the key for notification transport settings. InstanceSettingKey_NOTIFICATION InstanceSettingKey = 6 + // AI is the key for AI provider settings. + InstanceSettingKey_AI InstanceSettingKey = 7 ) // Enum value maps for InstanceSettingKey. 
@@ -50,6 +52,7 @@ var ( 4: "MEMO_RELATED", 5: "TAGS", 6: "NOTIFICATION", + 7: "AI", } InstanceSettingKey_value = map[string]int32{ "INSTANCE_SETTING_KEY_UNSPECIFIED": 0, @@ -59,6 +62,7 @@ var ( "MEMO_RELATED": 4, "TAGS": 5, "NOTIFICATION": 6, + "AI": 7, } ) @@ -89,6 +93,61 @@ func (InstanceSettingKey) EnumDescriptor() ([]byte, []int) { return file_store_instance_setting_proto_rawDescGZIP(), []int{0} } +type AIProviderType int32 + +const ( + AIProviderType_AI_PROVIDER_TYPE_UNSPECIFIED AIProviderType = 0 + AIProviderType_OPENAI AIProviderType = 1 + AIProviderType_OPENAI_COMPATIBLE AIProviderType = 2 + AIProviderType_ANTHROPIC AIProviderType = 3 + AIProviderType_GEMINI AIProviderType = 4 +) + +// Enum value maps for AIProviderType. +var ( + AIProviderType_name = map[int32]string{ + 0: "AI_PROVIDER_TYPE_UNSPECIFIED", + 1: "OPENAI", + 2: "OPENAI_COMPATIBLE", + 3: "ANTHROPIC", + 4: "GEMINI", + } + AIProviderType_value = map[string]int32{ + "AI_PROVIDER_TYPE_UNSPECIFIED": 0, + "OPENAI": 1, + "OPENAI_COMPATIBLE": 2, + "ANTHROPIC": 3, + "GEMINI": 4, + } +) + +func (x AIProviderType) Enum() *AIProviderType { + p := new(AIProviderType) + *p = x + return p +} + +func (x AIProviderType) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (AIProviderType) Descriptor() protoreflect.EnumDescriptor { + return file_store_instance_setting_proto_enumTypes[1].Descriptor() +} + +func (AIProviderType) Type() protoreflect.EnumType { + return &file_store_instance_setting_proto_enumTypes[1] +} + +func (x AIProviderType) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use AIProviderType.Descriptor instead. 
+func (AIProviderType) EnumDescriptor() ([]byte, []int) { + return file_store_instance_setting_proto_rawDescGZIP(), []int{1} +} + type InstanceStorageSetting_StorageType int32 const ( @@ -128,11 +187,11 @@ func (x InstanceStorageSetting_StorageType) String() string { } func (InstanceStorageSetting_StorageType) Descriptor() protoreflect.EnumDescriptor { - return file_store_instance_setting_proto_enumTypes[1].Descriptor() + return file_store_instance_setting_proto_enumTypes[2].Descriptor() } func (InstanceStorageSetting_StorageType) Type() protoreflect.EnumType { - return &file_store_instance_setting_proto_enumTypes[1] + return &file_store_instance_setting_proto_enumTypes[2] } func (x InstanceStorageSetting_StorageType) Number() protoreflect.EnumNumber { @@ -155,6 +214,7 @@ type InstanceSetting struct { // *InstanceSetting_MemoRelatedSetting // *InstanceSetting_TagsSetting // *InstanceSetting_NotificationSetting + // *InstanceSetting_AiSetting Value isInstanceSetting_Value `protobuf_oneof:"value"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache @@ -258,6 +318,15 @@ func (x *InstanceSetting) GetNotificationSetting() *InstanceNotificationSetting return nil } +func (x *InstanceSetting) GetAiSetting() *InstanceAISetting { + if x != nil { + if x, ok := x.Value.(*InstanceSetting_AiSetting); ok { + return x.AiSetting + } + } + return nil +} + type isInstanceSetting_Value interface { isInstanceSetting_Value() } @@ -286,6 +355,10 @@ type InstanceSetting_NotificationSetting struct { NotificationSetting *InstanceNotificationSetting `protobuf:"bytes,7,opt,name=notification_setting,json=notificationSetting,proto3,oneof"` } +type InstanceSetting_AiSetting struct { + AiSetting *InstanceAISetting `protobuf:"bytes,8,opt,name=ai_setting,json=aiSetting,proto3,oneof"` +} + func (*InstanceSetting_BasicSetting) isInstanceSetting_Value() {} func (*InstanceSetting_GeneralSetting) isInstanceSetting_Value() {} @@ -298,6 +371,8 @@ func (*InstanceSetting_TagsSetting) 
isInstanceSetting_Value() {} func (*InstanceSetting_NotificationSetting) isInstanceSetting_Value() {} +func (*InstanceSetting_AiSetting) isInstanceSetting_Value() {} + type InstanceBasicSetting struct { state protoimpl.MessageState `protogen:"open.v1"` // The secret key for instance. Mainly used for session management. @@ -899,6 +974,144 @@ func (x *InstanceNotificationSetting) GetEmail() *InstanceNotificationSetting_Em return nil } +type InstanceAISetting struct { + state protoimpl.MessageState `protogen:"open.v1"` + // providers is the list of AI provider configurations available instance-wide. + Providers []*AIProviderConfig `protobuf:"bytes,1,rep,name=providers,proto3" json:"providers,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *InstanceAISetting) Reset() { + *x = InstanceAISetting{} + mi := &file_store_instance_setting_proto_msgTypes[10] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *InstanceAISetting) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*InstanceAISetting) ProtoMessage() {} + +func (x *InstanceAISetting) ProtoReflect() protoreflect.Message { + mi := &file_store_instance_setting_proto_msgTypes[10] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use InstanceAISetting.ProtoReflect.Descriptor instead. 
+func (*InstanceAISetting) Descriptor() ([]byte, []int) { + return file_store_instance_setting_proto_rawDescGZIP(), []int{10} +} + +func (x *InstanceAISetting) GetProviders() []*AIProviderConfig { + if x != nil { + return x.Providers + } + return nil +} + +type AIProviderConfig struct { + state protoimpl.MessageState `protogen:"open.v1"` + Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"` + Title string `protobuf:"bytes,2,opt,name=title,proto3" json:"title,omitempty"` + Type AIProviderType `protobuf:"varint,3,opt,name=type,proto3,enum=memos.store.AIProviderType" json:"type,omitempty"` + Endpoint string `protobuf:"bytes,4,opt,name=endpoint,proto3" json:"endpoint,omitempty"` + // api_key is write-only at the API layer and is required by the server to call providers. + ApiKey string `protobuf:"bytes,5,opt,name=api_key,json=apiKey,proto3" json:"api_key,omitempty"` + Models []string `protobuf:"bytes,6,rep,name=models,proto3" json:"models,omitempty"` + DefaultModel string `protobuf:"bytes,7,opt,name=default_model,json=defaultModel,proto3" json:"default_model,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *AIProviderConfig) Reset() { + *x = AIProviderConfig{} + mi := &file_store_instance_setting_proto_msgTypes[11] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *AIProviderConfig) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*AIProviderConfig) ProtoMessage() {} + +func (x *AIProviderConfig) ProtoReflect() protoreflect.Message { + mi := &file_store_instance_setting_proto_msgTypes[11] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use AIProviderConfig.ProtoReflect.Descriptor instead. 
+func (*AIProviderConfig) Descriptor() ([]byte, []int) { + return file_store_instance_setting_proto_rawDescGZIP(), []int{11} +} + +func (x *AIProviderConfig) GetId() string { + if x != nil { + return x.Id + } + return "" +} + +func (x *AIProviderConfig) GetTitle() string { + if x != nil { + return x.Title + } + return "" +} + +func (x *AIProviderConfig) GetType() AIProviderType { + if x != nil { + return x.Type + } + return AIProviderType_AI_PROVIDER_TYPE_UNSPECIFIED +} + +func (x *AIProviderConfig) GetEndpoint() string { + if x != nil { + return x.Endpoint + } + return "" +} + +func (x *AIProviderConfig) GetApiKey() string { + if x != nil { + return x.ApiKey + } + return "" +} + +func (x *AIProviderConfig) GetModels() []string { + if x != nil { + return x.Models + } + return nil +} + +func (x *AIProviderConfig) GetDefaultModel() string { + if x != nil { + return x.DefaultModel + } + return "" +} + type InstanceNotificationSetting_EmailSetting struct { state protoimpl.MessageState `protogen:"open.v1"` Enabled bool `protobuf:"varint,1,opt,name=enabled,proto3" json:"enabled,omitempty"` @@ -917,7 +1130,7 @@ type InstanceNotificationSetting_EmailSetting struct { func (x *InstanceNotificationSetting_EmailSetting) Reset() { *x = InstanceNotificationSetting_EmailSetting{} - mi := &file_store_instance_setting_proto_msgTypes[11] + mi := &file_store_instance_setting_proto_msgTypes[13] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -929,7 +1142,7 @@ func (x *InstanceNotificationSetting_EmailSetting) String() string { func (*InstanceNotificationSetting_EmailSetting) ProtoMessage() {} func (x *InstanceNotificationSetting_EmailSetting) ProtoReflect() protoreflect.Message { - mi := &file_store_instance_setting_proto_msgTypes[11] + mi := &file_store_instance_setting_proto_msgTypes[13] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1019,7 +1232,7 @@ var File_store_instance_setting_proto 
protoreflect.FileDescriptor const file_store_instance_setting_proto_rawDesc = "" + "\n" + - "\x1cstore/instance_setting.proto\x12\vmemos.store\x1a\x17google/type/color.proto\"\xba\x04\n" + + "\x1cstore/instance_setting.proto\x12\vmemos.store\x1a\x17google/type/color.proto\"\xfb\x04\n" + "\x0fInstanceSetting\x121\n" + "\x03key\x18\x01 \x01(\x0e2\x1f.memos.store.InstanceSettingKeyR\x03key\x12H\n" + "\rbasic_setting\x18\x02 \x01(\v2!.memos.store.InstanceBasicSettingH\x00R\fbasicSetting\x12N\n" + @@ -1027,7 +1240,9 @@ const file_store_instance_setting_proto_rawDesc = "" + "\x0fstorage_setting\x18\x04 \x01(\v2#.memos.store.InstanceStorageSettingH\x00R\x0estorageSetting\x12[\n" + "\x14memo_related_setting\x18\x05 \x01(\v2'.memos.store.InstanceMemoRelatedSettingH\x00R\x12memoRelatedSetting\x12E\n" + "\ftags_setting\x18\x06 \x01(\v2 .memos.store.InstanceTagsSettingH\x00R\vtagsSetting\x12]\n" + - "\x14notification_setting\x18\a \x01(\v2(.memos.store.InstanceNotificationSettingH\x00R\x13notificationSettingB\a\n" + + "\x14notification_setting\x18\a \x01(\v2(.memos.store.InstanceNotificationSettingH\x00R\x13notificationSetting\x12?\n" + + "\n" + + "ai_setting\x18\b \x01(\v2\x1e.memos.store.InstanceAISettingH\x00R\taiSettingB\a\n" + "\x05value\"\\\n" + "\x14InstanceBasicSetting\x12\x1d\n" + "\n" + @@ -1090,7 +1305,17 @@ const file_store_instance_setting_proto_rawDesc = "" + "\breply_to\x18\b \x01(\tR\areplyTo\x12\x17\n" + "\ause_tls\x18\t \x01(\bR\x06useTls\x12\x17\n" + "\ause_ssl\x18\n" + - " \x01(\bR\x06useSsl*\x8d\x01\n" + + " \x01(\bR\x06useSsl\"P\n" + + "\x11InstanceAISetting\x12;\n" + + "\tproviders\x18\x01 \x03(\v2\x1d.memos.store.AIProviderConfigR\tproviders\"\xdb\x01\n" + + "\x10AIProviderConfig\x12\x0e\n" + + "\x02id\x18\x01 \x01(\tR\x02id\x12\x14\n" + + "\x05title\x18\x02 \x01(\tR\x05title\x12/\n" + + "\x04type\x18\x03 \x01(\x0e2\x1b.memos.store.AIProviderTypeR\x04type\x12\x1a\n" + + "\bendpoint\x18\x04 \x01(\tR\bendpoint\x12\x17\n" + + "\aapi_key\x18\x05 
\x01(\tR\x06apiKey\x12\x16\n" + + "\x06models\x18\x06 \x03(\tR\x06models\x12#\n" + + "\rdefault_model\x18\a \x01(\tR\fdefaultModel*\x95\x01\n" + "\x12InstanceSettingKey\x12$\n" + " INSTANCE_SETTING_KEY_UNSPECIFIED\x10\x00\x12\t\n" + "\x05BASIC\x10\x01\x12\v\n" + @@ -1098,7 +1323,16 @@ const file_store_instance_setting_proto_rawDesc = "" + "\aSTORAGE\x10\x03\x12\x10\n" + "\fMEMO_RELATED\x10\x04\x12\b\n" + "\x04TAGS\x10\x05\x12\x10\n" + - "\fNOTIFICATION\x10\x06B\x9f\x01\n" + + "\fNOTIFICATION\x10\x06\x12\x06\n" + + "\x02AI\x10\a*p\n" + + "\x0eAIProviderType\x12 \n" + + "\x1cAI_PROVIDER_TYPE_UNSPECIFIED\x10\x00\x12\n" + + "\n" + + "\x06OPENAI\x10\x01\x12\x15\n" + + "\x11OPENAI_COMPATIBLE\x10\x02\x12\r\n" + + "\tANTHROPIC\x10\x03\x12\n" + + "\n" + + "\x06GEMINI\x10\x04B\x9f\x01\n" + "\x0fcom.memos.storeB\x14InstanceSettingProtoP\x01Z)github.com/usememos/memos/proto/gen/store\xa2\x02\x03MSX\xaa\x02\vMemos.Store\xca\x02\vMemos\\Store\xe2\x02\x17Memos\\Store\\GPBMetadata\xea\x02\fMemos::Storeb\x06proto3" var ( @@ -1113,45 +1347,51 @@ func file_store_instance_setting_proto_rawDescGZIP() []byte { return file_store_instance_setting_proto_rawDescData } -var file_store_instance_setting_proto_enumTypes = make([]protoimpl.EnumInfo, 2) -var file_store_instance_setting_proto_msgTypes = make([]protoimpl.MessageInfo, 12) +var file_store_instance_setting_proto_enumTypes = make([]protoimpl.EnumInfo, 3) +var file_store_instance_setting_proto_msgTypes = make([]protoimpl.MessageInfo, 14) var file_store_instance_setting_proto_goTypes = []any{ (InstanceSettingKey)(0), // 0: memos.store.InstanceSettingKey - (InstanceStorageSetting_StorageType)(0), // 1: memos.store.InstanceStorageSetting.StorageType - (*InstanceSetting)(nil), // 2: memos.store.InstanceSetting - (*InstanceBasicSetting)(nil), // 3: memos.store.InstanceBasicSetting - (*InstanceGeneralSetting)(nil), // 4: memos.store.InstanceGeneralSetting - (*InstanceCustomProfile)(nil), // 5: memos.store.InstanceCustomProfile - 
(*InstanceStorageSetting)(nil), // 6: memos.store.InstanceStorageSetting - (*StorageS3Config)(nil), // 7: memos.store.StorageS3Config - (*InstanceMemoRelatedSetting)(nil), // 8: memos.store.InstanceMemoRelatedSetting - (*InstanceTagMetadata)(nil), // 9: memos.store.InstanceTagMetadata - (*InstanceTagsSetting)(nil), // 10: memos.store.InstanceTagsSetting - (*InstanceNotificationSetting)(nil), // 11: memos.store.InstanceNotificationSetting - nil, // 12: memos.store.InstanceTagsSetting.TagsEntry - (*InstanceNotificationSetting_EmailSetting)(nil), // 13: memos.store.InstanceNotificationSetting.EmailSetting - (*color.Color)(nil), // 14: google.type.Color + (AIProviderType)(0), // 1: memos.store.AIProviderType + (InstanceStorageSetting_StorageType)(0), // 2: memos.store.InstanceStorageSetting.StorageType + (*InstanceSetting)(nil), // 3: memos.store.InstanceSetting + (*InstanceBasicSetting)(nil), // 4: memos.store.InstanceBasicSetting + (*InstanceGeneralSetting)(nil), // 5: memos.store.InstanceGeneralSetting + (*InstanceCustomProfile)(nil), // 6: memos.store.InstanceCustomProfile + (*InstanceStorageSetting)(nil), // 7: memos.store.InstanceStorageSetting + (*StorageS3Config)(nil), // 8: memos.store.StorageS3Config + (*InstanceMemoRelatedSetting)(nil), // 9: memos.store.InstanceMemoRelatedSetting + (*InstanceTagMetadata)(nil), // 10: memos.store.InstanceTagMetadata + (*InstanceTagsSetting)(nil), // 11: memos.store.InstanceTagsSetting + (*InstanceNotificationSetting)(nil), // 12: memos.store.InstanceNotificationSetting + (*InstanceAISetting)(nil), // 13: memos.store.InstanceAISetting + (*AIProviderConfig)(nil), // 14: memos.store.AIProviderConfig + nil, // 15: memos.store.InstanceTagsSetting.TagsEntry + (*InstanceNotificationSetting_EmailSetting)(nil), // 16: memos.store.InstanceNotificationSetting.EmailSetting + (*color.Color)(nil), // 17: google.type.Color } var file_store_instance_setting_proto_depIdxs = []int32{ 0, // 0: memos.store.InstanceSetting.key:type_name -> 
memos.store.InstanceSettingKey - 3, // 1: memos.store.InstanceSetting.basic_setting:type_name -> memos.store.InstanceBasicSetting - 4, // 2: memos.store.InstanceSetting.general_setting:type_name -> memos.store.InstanceGeneralSetting - 6, // 3: memos.store.InstanceSetting.storage_setting:type_name -> memos.store.InstanceStorageSetting - 8, // 4: memos.store.InstanceSetting.memo_related_setting:type_name -> memos.store.InstanceMemoRelatedSetting - 10, // 5: memos.store.InstanceSetting.tags_setting:type_name -> memos.store.InstanceTagsSetting - 11, // 6: memos.store.InstanceSetting.notification_setting:type_name -> memos.store.InstanceNotificationSetting - 5, // 7: memos.store.InstanceGeneralSetting.custom_profile:type_name -> memos.store.InstanceCustomProfile - 1, // 8: memos.store.InstanceStorageSetting.storage_type:type_name -> memos.store.InstanceStorageSetting.StorageType - 7, // 9: memos.store.InstanceStorageSetting.s3_config:type_name -> memos.store.StorageS3Config - 14, // 10: memos.store.InstanceTagMetadata.background_color:type_name -> google.type.Color - 12, // 11: memos.store.InstanceTagsSetting.tags:type_name -> memos.store.InstanceTagsSetting.TagsEntry - 13, // 12: memos.store.InstanceNotificationSetting.email:type_name -> memos.store.InstanceNotificationSetting.EmailSetting - 9, // 13: memos.store.InstanceTagsSetting.TagsEntry.value:type_name -> memos.store.InstanceTagMetadata - 14, // [14:14] is the sub-list for method output_type - 14, // [14:14] is the sub-list for method input_type - 14, // [14:14] is the sub-list for extension type_name - 14, // [14:14] is the sub-list for extension extendee - 0, // [0:14] is the sub-list for field type_name + 4, // 1: memos.store.InstanceSetting.basic_setting:type_name -> memos.store.InstanceBasicSetting + 5, // 2: memos.store.InstanceSetting.general_setting:type_name -> memos.store.InstanceGeneralSetting + 7, // 3: memos.store.InstanceSetting.storage_setting:type_name -> memos.store.InstanceStorageSetting + 9, // 
4: memos.store.InstanceSetting.memo_related_setting:type_name -> memos.store.InstanceMemoRelatedSetting + 11, // 5: memos.store.InstanceSetting.tags_setting:type_name -> memos.store.InstanceTagsSetting + 12, // 6: memos.store.InstanceSetting.notification_setting:type_name -> memos.store.InstanceNotificationSetting + 13, // 7: memos.store.InstanceSetting.ai_setting:type_name -> memos.store.InstanceAISetting + 6, // 8: memos.store.InstanceGeneralSetting.custom_profile:type_name -> memos.store.InstanceCustomProfile + 2, // 9: memos.store.InstanceStorageSetting.storage_type:type_name -> memos.store.InstanceStorageSetting.StorageType + 8, // 10: memos.store.InstanceStorageSetting.s3_config:type_name -> memos.store.StorageS3Config + 17, // 11: memos.store.InstanceTagMetadata.background_color:type_name -> google.type.Color + 15, // 12: memos.store.InstanceTagsSetting.tags:type_name -> memos.store.InstanceTagsSetting.TagsEntry + 16, // 13: memos.store.InstanceNotificationSetting.email:type_name -> memos.store.InstanceNotificationSetting.EmailSetting + 14, // 14: memos.store.InstanceAISetting.providers:type_name -> memos.store.AIProviderConfig + 1, // 15: memos.store.AIProviderConfig.type:type_name -> memos.store.AIProviderType + 10, // 16: memos.store.InstanceTagsSetting.TagsEntry.value:type_name -> memos.store.InstanceTagMetadata + 17, // [17:17] is the sub-list for method output_type + 17, // [17:17] is the sub-list for method input_type + 17, // [17:17] is the sub-list for extension type_name + 17, // [17:17] is the sub-list for extension extendee + 0, // [0:17] is the sub-list for field type_name } func init() { file_store_instance_setting_proto_init() } @@ -1166,14 +1406,15 @@ func file_store_instance_setting_proto_init() { (*InstanceSetting_MemoRelatedSetting)(nil), (*InstanceSetting_TagsSetting)(nil), (*InstanceSetting_NotificationSetting)(nil), + (*InstanceSetting_AiSetting)(nil), } type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ 
GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_store_instance_setting_proto_rawDesc), len(file_store_instance_setting_proto_rawDesc)), - NumEnums: 2, - NumMessages: 12, + NumEnums: 3, + NumMessages: 14, NumExtensions: 0, NumServices: 0, }, diff --git a/proto/store/instance_setting.proto b/proto/store/instance_setting.proto index 2c7848d3c..521b4fcc1 100644 --- a/proto/store/instance_setting.proto +++ b/proto/store/instance_setting.proto @@ -20,6 +20,8 @@ enum InstanceSettingKey { TAGS = 5; // NOTIFICATION is the key for notification transport settings. NOTIFICATION = 6; + // AI is the key for AI provider settings. + AI = 7; } message InstanceSetting { @@ -31,6 +33,7 @@ message InstanceSetting { InstanceMemoRelatedSetting memo_related_setting = 5; InstanceTagsSetting tags_setting = 6; InstanceNotificationSetting notification_setting = 7; + InstanceAISetting ai_setting = 8; } } @@ -142,3 +145,27 @@ message InstanceNotificationSetting { bool use_ssl = 10; } } + +message InstanceAISetting { + // providers is the list of AI provider configurations available instance-wide. + repeated AIProviderConfig providers = 1; +} + +message AIProviderConfig { + string id = 1; + string title = 2; + AIProviderType type = 3; + string endpoint = 4; + // api_key is write-only at the API layer and is required by the server to call providers. 
+ string api_key = 5; + repeated string models = 6; + string default_model = 7; +} + +enum AIProviderType { + AI_PROVIDER_TYPE_UNSPECIFIED = 0; + OPENAI = 1; + OPENAI_COMPATIBLE = 2; + ANTHROPIC = 3; + GEMINI = 4; +} diff --git a/server/router/api/v1/ai_service.go b/server/router/api/v1/ai_service.go new file mode 100644 index 000000000..eafb99154 --- /dev/null +++ b/server/router/api/v1/ai_service.go @@ -0,0 +1,198 @@ +package v1 + +import ( + "bytes" + "context" + "mime" + "net/http" + "strings" + + "github.com/pkg/errors" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "github.com/usememos/memos/internal/ai" + "github.com/usememos/memos/internal/ai/openai" + v1pb "github.com/usememos/memos/proto/gen/api/v1" + storepb "github.com/usememos/memos/proto/gen/store" +) + +const ( + maxTranscriptionAudioSizeBytes = 25 * MebiByte + maxTranscriptionPromptLength = 4096 + maxTranscriptionLanguageLength = 32 + maxTranscriptionFilenameLength = 255 +) + +var supportedTranscriptionContentTypes = map[string]bool{ + "audio/mpeg": true, + "audio/mp4": true, + "audio/mpga": true, + "audio/wav": true, + "audio/x-wav": true, + "audio/webm": true, + "audio/x-m4a": true, + "video/mp4": true, + "video/mpeg": true, + "video/webm": true, +} + +// Transcribe transcribes an audio file using an instance AI provider. 
+func (s *APIV1Service) Transcribe(ctx context.Context, request *v1pb.TranscribeRequest) (*v1pb.TranscribeResponse, error) { + user, err := s.fetchCurrentUser(ctx) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err) + } + if user == nil { + return nil, status.Errorf(codes.Unauthenticated, "user not authenticated") + } + + if strings.TrimSpace(request.ProviderId) == "" { + return nil, status.Errorf(codes.InvalidArgument, "provider_id is required") + } + if request.Config == nil { + return nil, status.Errorf(codes.InvalidArgument, "config is required") + } + prompt := strings.TrimSpace(request.Config.GetPrompt()) + if len(prompt) > maxTranscriptionPromptLength { + return nil, status.Errorf(codes.InvalidArgument, "prompt is too long; maximum length is %d characters", maxTranscriptionPromptLength) + } + language := strings.TrimSpace(request.Config.GetLanguage()) + if len(language) > maxTranscriptionLanguageLength { + return nil, status.Errorf(codes.InvalidArgument, "language is too long; maximum length is %d characters", maxTranscriptionLanguageLength) + } + if request.Audio == nil { + return nil, status.Errorf(codes.InvalidArgument, "audio is required") + } + if request.Audio.GetUri() != "" { + return nil, status.Errorf(codes.InvalidArgument, "audio uri is not supported") + } + content := request.Audio.GetContent() + if len(content) == 0 { + return nil, status.Errorf(codes.InvalidArgument, "audio content is required") + } + if len(content) > maxTranscriptionAudioSizeBytes { + return nil, status.Errorf(codes.InvalidArgument, "audio file is too large; maximum size is 25 MiB") + } + filename := strings.TrimSpace(request.Audio.GetFilename()) + if len(filename) > maxTranscriptionFilenameLength { + return nil, status.Errorf(codes.InvalidArgument, "filename is too long; maximum length is %d characters", maxTranscriptionFilenameLength) + } + contentType := strings.TrimSpace(request.Audio.GetContentType()) + if contentType == "" 
{ + contentType = http.DetectContentType(content) + } + if !isSupportedTranscriptionContentType(contentType) { + return nil, status.Errorf(codes.InvalidArgument, "audio content type %q is not supported", contentType) + } + + provider, model, err := s.resolveAIProviderForTranscription(ctx, request.ProviderId, request.Config.GetModel()) + if err != nil { + return nil, err + } + transcriber, err := newAITranscriber(provider) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "failed to create AI transcriber: %v", err) + } + + transcription, err := transcriber.Transcribe(ctx, ai.TranscribeRequest{ + Model: model, + Filename: filename, + ContentType: contentType, + Audio: bytes.NewReader(content), + Size: int64(len(content)), + Prompt: prompt, + Language: language, + }) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to transcribe audio: %v", err) + } + return &v1pb.TranscribeResponse{ + Text: transcription.Text, + }, nil +} + +func (s *APIV1Service) resolveAIProviderForTranscription(ctx context.Context, providerID string, model string) (ai.ProviderConfig, string, error) { + setting, err := s.Store.GetInstanceAISetting(ctx) + if err != nil { + return ai.ProviderConfig{}, "", status.Errorf(codes.Internal, "failed to get AI setting: %v", err) + } + + providers := make([]ai.ProviderConfig, 0, len(setting.GetProviders())) + for _, provider := range setting.GetProviders() { + if provider == nil { + continue + } + providers = append(providers, convertAIProviderConfigFromStore(provider)) + } + + provider, err := ai.FindProvider(providers, providerID) + if err != nil { + return ai.ProviderConfig{}, "", status.Errorf(codes.NotFound, "AI provider not found") + } + selectedModel := strings.TrimSpace(model) + if selectedModel == "" { + selectedModel = provider.DefaultModel + } + if selectedModel == "" { + return ai.ProviderConfig{}, "", status.Errorf(codes.InvalidArgument, "model is required") + } + if !containsString(provider.Models, 
selectedModel) { + return ai.ProviderConfig{}, "", status.Errorf(codes.InvalidArgument, "model %q is not configured for provider %q", selectedModel, provider.ID) + } + return *provider, selectedModel, nil +} + +func convertAIProviderConfigFromStore(provider *storepb.AIProviderConfig) ai.ProviderConfig { + return ai.ProviderConfig{ + ID: provider.GetId(), + Title: provider.GetTitle(), + Type: convertAIProviderTypeFromStore(provider.GetType()), + Endpoint: provider.GetEndpoint(), + APIKey: provider.GetApiKey(), + Models: provider.GetModels(), + DefaultModel: provider.GetDefaultModel(), + } +} + +func convertAIProviderTypeFromStore(providerType storepb.AIProviderType) ai.ProviderType { + switch providerType { + case storepb.AIProviderType_OPENAI: + return ai.ProviderOpenAI + case storepb.AIProviderType_OPENAI_COMPATIBLE: + return ai.ProviderOpenAICompatible + case storepb.AIProviderType_ANTHROPIC: + return ai.ProviderAnthropic + case storepb.AIProviderType_GEMINI: + return ai.ProviderGemini + default: + return "" + } +} + +func newAITranscriber(provider ai.ProviderConfig) (ai.Transcriber, error) { + switch provider.Type { + case ai.ProviderOpenAI, ai.ProviderOpenAICompatible: + return openai.NewTranscriber(provider) + default: + return nil, errors.Wrapf(ai.ErrCapabilityUnsupported, "provider type %q", provider.Type) + } +} + +func containsString(values []string, target string) bool { + for _, value := range values { + if value == target { + return true + } + } + return false +} + +func isSupportedTranscriptionContentType(contentType string) bool { + mediaType, _, err := mime.ParseMediaType(strings.TrimSpace(contentType)) + if err != nil { + return false + } + mediaType = strings.ToLower(mediaType) + return supportedTranscriptionContentTypes[mediaType] +} diff --git a/server/router/api/v1/connect_handler.go b/server/router/api/v1/connect_handler.go index 349b8fdc9..95d90d7bc 100644 --- a/server/router/api/v1/connect_handler.go +++ 
b/server/router/api/v1/connect_handler.go @@ -39,6 +39,7 @@ func (s *ConnectServiceHandler) RegisterConnectHandlers(mux *http.ServeMux, opts wrap(apiv1connect.NewUserServiceHandler(s, opts...)), wrap(apiv1connect.NewMemoServiceHandler(s, opts...)), wrap(apiv1connect.NewAttachmentServiceHandler(s, opts...)), + wrap(apiv1connect.NewAIServiceHandler(s, opts...)), wrap(apiv1connect.NewShortcutServiceHandler(s, opts...)), wrap(apiv1connect.NewIdentityProviderServiceHandler(s, opts...)), } diff --git a/server/router/api/v1/connect_services.go b/server/router/api/v1/connect_services.go index f7ee403f8..9a840bb81 100644 --- a/server/router/api/v1/connect_services.go +++ b/server/router/api/v1/connect_services.go @@ -435,6 +435,16 @@ func (s *ConnectServiceHandler) BatchDeleteAttachments(ctx context.Context, req return connect.NewResponse(resp), nil } +// AIService + +func (s *ConnectServiceHandler) Transcribe(ctx context.Context, req *connect.Request[v1pb.TranscribeRequest]) (*connect.Response[v1pb.TranscribeResponse], error) { + resp, err := s.APIV1Service.Transcribe(ctx, req.Msg) + if err != nil { + return nil, convertGRPCError(err) + } + return connect.NewResponse(resp), nil +} + // ShortcutService func (s *ConnectServiceHandler) ListShortcuts(ctx context.Context, req *connect.Request[v1pb.ListShortcutsRequest]) (*connect.Response[v1pb.ListShortcutsResponse], error) { diff --git a/server/router/api/v1/instance_service.go b/server/router/api/v1/instance_service.go index f6e6d519c..d21ccaa12 100644 --- a/server/router/api/v1/instance_service.go +++ b/server/router/api/v1/instance_service.go @@ -5,8 +5,10 @@ import ( "fmt" "math" "regexp" + "slices" "strings" + "github.com/lithammer/shortuuid/v4" "github.com/pkg/errors" colorpb "google.golang.org/genproto/googleapis/type/color" "google.golang.org/grpc/codes" @@ -54,6 +56,8 @@ func (s *APIV1Service) GetInstanceSetting(ctx context.Context, request *v1pb.Get _, err = s.Store.GetInstanceTagsSetting(ctx) case 
storepb.InstanceSettingKey_NOTIFICATION: _, err = s.Store.GetInstanceNotificationSetting(ctx) + case storepb.InstanceSettingKey_AI: + _, err = s.Store.GetInstanceAISetting(ctx) default: return nil, status.Errorf(codes.InvalidArgument, "unsupported instance setting key: %v", instanceSettingKey) } @@ -71,9 +75,10 @@ func (s *APIV1Service) GetInstanceSetting(ctx context.Context, request *v1pb.Get return nil, status.Errorf(codes.NotFound, "instance setting not found") } - // Storage and notification settings contain credentials; restrict to admins only. + // Storage, notification, and AI settings contain credentials; restrict to admins only. if instanceSetting.Key == storepb.InstanceSettingKey_STORAGE || - instanceSetting.Key == storepb.InstanceSettingKey_NOTIFICATION { + instanceSetting.Key == storepb.InstanceSettingKey_NOTIFICATION || + instanceSetting.Key == storepb.InstanceSettingKey_AI { user, err := s.fetchCurrentUser(ctx) if err != nil { return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err) @@ -127,6 +132,10 @@ func (s *APIV1Service) UpdateInstanceSetting(ctx context.Context, request *v1pb. storage.S3Config.AccessKeySecret = existing.S3Config.AccessKeySecret } } + case storepb.InstanceSettingKey_AI: + if err := s.prepareInstanceAISettingForUpdate(ctx, updateSetting.GetAiSetting()); err != nil { + return nil, status.Errorf(codes.InvalidArgument, "invalid AI setting: %v", err) + } default: // No credential preservation needed for other setting types. 
} @@ -164,6 +173,10 @@ func convertInstanceSettingFromStore(setting *storepb.InstanceSetting) *v1pb.Ins instanceSetting.Value = &v1pb.InstanceSetting_NotificationSetting_{ NotificationSetting: convertInstanceNotificationSettingFromStore(setting.GetNotificationSetting()), } + case *storepb.InstanceSetting_AiSetting: + instanceSetting.Value = &v1pb.InstanceSetting_AiSetting{ + AiSetting: convertInstanceAISettingFromStore(setting.GetAiSetting()), + } default: // Leave Value unset for unsupported setting variants. } @@ -199,6 +212,10 @@ func convertInstanceSettingToStore(setting *v1pb.InstanceSetting) *storepb.Insta instanceSetting.Value = &storepb.InstanceSetting_NotificationSetting{ NotificationSetting: convertInstanceNotificationSettingToStore(setting.GetNotificationSetting()), } + case storepb.InstanceSettingKey_AI: + instanceSetting.Value = &storepb.InstanceSetting_AiSetting{ + AiSetting: convertInstanceAISettingToStore(setting.GetAiSetting()), + } default: // Keep the default GeneralSetting value } @@ -398,6 +415,58 @@ func convertInstanceNotificationSettingToStore(setting *v1pb.InstanceSetting_Not return notificationSetting } +func convertInstanceAISettingFromStore(setting *storepb.InstanceAISetting) *v1pb.InstanceSetting_AISetting { + if setting == nil { + return nil + } + + aiSetting := &v1pb.InstanceSetting_AISetting{ + Providers: make([]*v1pb.InstanceSetting_AIProviderConfig, 0, len(setting.Providers)), + } + for _, provider := range setting.Providers { + if provider == nil { + continue + } + apiKey := provider.GetApiKey() + aiSetting.Providers = append(aiSetting.Providers, &v1pb.InstanceSetting_AIProviderConfig{ + Id: provider.GetId(), + Title: provider.GetTitle(), + Type: v1pb.InstanceSetting_AIProviderType(provider.GetType()), + Endpoint: provider.GetEndpoint(), + Models: provider.GetModels(), + DefaultModel: provider.GetDefaultModel(), + ApiKeySet: apiKey != "", + ApiKeyHint: maskAPIKey(apiKey), + }) + } + return aiSetting +} + +func 
convertInstanceAISettingToStore(setting *v1pb.InstanceSetting_AISetting) *storepb.InstanceAISetting { + if setting == nil { + return nil + } + + aiSetting := &storepb.InstanceAISetting{ + Providers: make([]*storepb.AIProviderConfig, 0, len(setting.Providers)), + } + for _, provider := range setting.Providers { + if provider == nil { + continue + } + aiSetting.Providers = append(aiSetting.Providers, &storepb.AIProviderConfig{ + Id: provider.GetId(), + Title: provider.GetTitle(), + Type: storepb.AIProviderType(provider.GetType()), + Endpoint: provider.GetEndpoint(), + ApiKey: provider.GetApiKey(), + Models: provider.GetModels(), + DefaultModel: provider.GetDefaultModel(), + }) + } + return aiSetting +} + func validateInstanceSetting(setting *v1pb.InstanceSetting) error { key, err := ExtractInstanceSettingKeyFromName(setting.Name) if err != nil { @@ -409,6 +478,104 @@ func validateInstanceSetting(setting *v1pb.InstanceSetting) error { return validateInstanceTagsSetting(setting.GetTagsSetting()) } +func (s *APIV1Service) prepareInstanceAISettingForUpdate(ctx context.Context, setting *storepb.InstanceAISetting) error { + if setting == nil { + return errors.New("AI setting is required") + } + + existing, err := s.Store.GetInstanceAISetting(ctx) + if err != nil { + return errors.Wrap(err, "failed to get existing AI setting") + } + existingProviders := map[string]*storepb.AIProviderConfig{} + if existing != nil { + for _, provider := range existing.Providers { + if provider != nil && provider.Id != "" { + existingProviders[provider.Id] = provider + } + } + } + + seenIDs := map[string]bool{} + for _, provider := range setting.Providers { + if provider == nil { + return errors.New("provider cannot be nil") + } + + provider.Id = strings.TrimSpace(provider.Id) + if provider.Id == "" { + provider.Id = shortuuid.New() + } + if seenIDs[provider.Id] { + return errors.Errorf("duplicate provider ID %q", provider.Id) + } + seenIDs[provider.Id] = true + + provider.Title = 
strings.TrimSpace(provider.Title) + if provider.Title == "" { + return errors.New("provider title is required") + } + if provider.Type == storepb.AIProviderType_AI_PROVIDER_TYPE_UNSPECIFIED { + return errors.Errorf("provider %q type is required", provider.Id) + } + + provider.Endpoint = strings.TrimSpace(provider.Endpoint) + if provider.Type == storepb.AIProviderType_OPENAI && provider.Endpoint == "" { + provider.Endpoint = "https://api.openai.com/v1" + } + if provider.Type == storepb.AIProviderType_OPENAI_COMPATIBLE && provider.Endpoint == "" { + return errors.Errorf("provider %q endpoint is required", provider.Id) + } + + provider.Models = normalizeAIModels(provider.Models) + if len(provider.Models) == 0 { + return errors.Errorf("provider %q must define at least one model", provider.Id) + } + provider.DefaultModel = strings.TrimSpace(provider.DefaultModel) + if provider.DefaultModel == "" { + provider.DefaultModel = provider.Models[0] + } + if !slices.Contains(provider.Models, provider.DefaultModel) { + return errors.Errorf("provider %q default model %q must be included in models", provider.Id, provider.DefaultModel) + } + + if provider.ApiKey == "" { + if existingProvider, ok := existingProviders[provider.Id]; ok { + provider.ApiKey = existingProvider.ApiKey + } + } + if provider.ApiKey == "" { + return errors.Errorf("provider %q API key is required", provider.Id) + } + } + return nil +} + +func normalizeAIModels(models []string) []string { + normalized := []string{} + seen := map[string]bool{} + for _, model := range models { + model = strings.TrimSpace(model) + if model == "" || seen[model] { + continue + } + seen[model] = true + normalized = append(normalized, model) + } + return normalized +} + +func maskAPIKey(apiKey string) string { + if apiKey == "" { + return "" + } + if len(apiKey) <= 8 { + return "..." + } + prefixLength := min(4, len(apiKey)) + return apiKey[:prefixLength] + "..." 
+ apiKey[len(apiKey)-4:] +} + func validateInstanceTagsSetting(setting *v1pb.InstanceSetting_TagsSetting) error { if setting == nil { return errors.New("tags setting is required") diff --git a/server/router/api/v1/test/ai_service_test.go b/server/router/api/v1/test/ai_service_test.go new file mode 100644 index 000000000..79adf11a3 --- /dev/null +++ b/server/router/api/v1/test/ai_service_test.go @@ -0,0 +1,185 @@ +package test + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" + + v1pb "github.com/usememos/memos/proto/gen/api/v1" + storepb "github.com/usememos/memos/proto/gen/store" +) + +func TestTranscribe(t *testing.T) { + ctx := context.Background() + + t.Run("requires authentication", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + _, err := ts.Service.Transcribe(ctx, &v1pb.TranscribeRequest{ + ProviderId: "openai-main", + Config: &v1pb.TranscriptionConfig{ + Model: "gpt-4o-transcribe", + }, + Audio: &v1pb.TranscriptionAudio{ + Source: &v1pb.TranscriptionAudio_Content{Content: []byte("RIFF")}, + Filename: "voice.wav", + ContentType: "audio/wav", + }, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "user not authenticated") + }) + + t.Run("transcribes audio file with configured provider", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + user, err := ts.CreateRegularUser(ctx, "alice") + require.NoError(t, err) + userCtx := ts.CreateUserContext(ctx, user.ID) + + openAIServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "/audio/transcriptions", r.URL.Path) + require.Equal(t, "Bearer sk-test", r.Header.Get("Authorization")) + require.NoError(t, r.ParseMultipartForm(10<<20)) + require.Equal(t, "gpt-4o-transcribe", r.FormValue("model")) + require.Equal(t, "names: Alice", r.FormValue("prompt")) + + file, header, err := r.FormFile("file") + require.NoError(t, err) + 
defer file.Close() + require.Equal(t, "voice.wav", header.Filename) + + w.Header().Set("Content-Type", "application/json") + require.NoError(t, json.NewEncoder(w).Encode(map[string]string{ + "text": "transcribed text", + })) + })) + defer openAIServer.Close() + + _, err = ts.Store.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{ + Key: storepb.InstanceSettingKey_AI, + Value: &storepb.InstanceSetting_AiSetting{ + AiSetting: &storepb.InstanceAISetting{ + Providers: []*storepb.AIProviderConfig{ + { + Id: "openai-main", + Title: "OpenAI", + Type: storepb.AIProviderType_OPENAI_COMPATIBLE, + Endpoint: openAIServer.URL, + ApiKey: "sk-test", + Models: []string{"gpt-4o-transcribe"}, + DefaultModel: "gpt-4o-transcribe", + }, + }, + }, + }, + }) + require.NoError(t, err) + + resp, err := ts.Service.Transcribe(userCtx, &v1pb.TranscribeRequest{ + ProviderId: "openai-main", + Config: &v1pb.TranscriptionConfig{ + Prompt: "names: Alice", + }, + Audio: &v1pb.TranscriptionAudio{ + Source: &v1pb.TranscriptionAudio_Content{Content: []byte("RIFF")}, + Filename: "voice.wav", + ContentType: "audio/wav", + }, + }) + require.NoError(t, err) + require.Equal(t, "transcribed text", resp.Text) + }) + + t.Run("rejects unconfigured model", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + user, err := ts.CreateRegularUser(ctx, "bob") + require.NoError(t, err) + userCtx := ts.CreateUserContext(ctx, user.ID) + + _, err = ts.Store.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{ + Key: storepb.InstanceSettingKey_AI, + Value: &storepb.InstanceSetting_AiSetting{ + AiSetting: &storepb.InstanceAISetting{ + Providers: []*storepb.AIProviderConfig{ + { + Id: "openai-main", + Title: "OpenAI", + Type: storepb.AIProviderType_OPENAI_COMPATIBLE, + Endpoint: "https://example.com/v1", + ApiKey: "sk-test", + Models: []string{"gpt-4o-transcribe"}, + DefaultModel: "gpt-4o-transcribe", + }, + }, + }, + }, + }) + require.NoError(t, err) + + _, err = ts.Service.Transcribe(userCtx, 
&v1pb.TranscribeRequest{ + ProviderId: "openai-main", + Config: &v1pb.TranscriptionConfig{ + Model: "other-model", + }, + Audio: &v1pb.TranscriptionAudio{ + Source: &v1pb.TranscriptionAudio_Content{Content: []byte("RIFF")}, + Filename: "voice.wav", + ContentType: "audio/wav", + }, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "not configured") + }) + + t.Run("rejects non-audio content before provider call", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + user, err := ts.CreateRegularUser(ctx, "charlie") + require.NoError(t, err) + userCtx := ts.CreateUserContext(ctx, user.ID) + + _, err = ts.Store.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{ + Key: storepb.InstanceSettingKey_AI, + Value: &storepb.InstanceSetting_AiSetting{ + AiSetting: &storepb.InstanceAISetting{ + Providers: []*storepb.AIProviderConfig{ + { + Id: "openai-main", + Title: "OpenAI", + Type: storepb.AIProviderType_OPENAI_COMPATIBLE, + Endpoint: "https://example.com/v1", + ApiKey: "sk-test", + Models: []string{"gpt-4o-transcribe"}, + DefaultModel: "gpt-4o-transcribe", + }, + }, + }, + }, + }) + require.NoError(t, err) + + _, err = ts.Service.Transcribe(userCtx, &v1pb.TranscribeRequest{ + ProviderId: "openai-main", + Config: &v1pb.TranscriptionConfig{ + Model: "gpt-4o-transcribe", + }, + Audio: &v1pb.TranscriptionAudio{ + Source: &v1pb.TranscriptionAudio_Content{Content: []byte("not audio")}, + Filename: "notes.txt", + ContentType: "text/plain", + }, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "not supported") + }) +} diff --git a/server/router/api/v1/test/instance_service_test.go b/server/router/api/v1/test/instance_service_test.go index eac97b58e..fb795b012 100644 --- a/server/router/api/v1/test/instance_service_test.go +++ b/server/router/api/v1/test/instance_service_test.go @@ -238,6 +238,34 @@ func TestGetInstanceSetting(t *testing.T) { "SmtpPassword must never be returned in responses") }) + t.Run("GetInstanceSetting - AI 
setting requires admin", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + admin, err := ts.CreateHostUser(ctx, "admin") + require.NoError(t, err) + adminCtx := ts.CreateUserContext(ctx, admin.ID) + + regularUser, err := ts.CreateRegularUser(ctx, "user") + require.NoError(t, err) + userCtx := ts.CreateUserContext(ctx, regularUser.ID) + + req := &v1pb.GetInstanceSettingRequest{Name: "instance/settings/AI"} + + _, err = ts.Service.GetInstanceSetting(ctx, req) + require.Error(t, err) + require.Contains(t, err.Error(), "not authenticated") + + _, err = ts.Service.GetInstanceSetting(userCtx, req) + require.Error(t, err) + require.Contains(t, err.Error(), "permission denied") + + resp, err := ts.Service.GetInstanceSetting(adminCtx, req) + require.NoError(t, err) + require.NotNil(t, resp.GetAiSetting()) + require.Empty(t, resp.GetAiSetting().GetProviders()) + }) + t.Run("GetInstanceSetting - invalid setting name", func(t *testing.T) { // Create test service for this specific test ts := NewTestService(t) @@ -258,6 +286,41 @@ func TestGetInstanceSetting(t *testing.T) { func TestUpdateInstanceSetting(t *testing.T) { ctx := context.Background() + t.Run("UpdateInstanceSetting - AI setting requires admin", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + regularUser, err := ts.CreateRegularUser(ctx, "user") + require.NoError(t, err) + userCtx := ts.CreateUserContext(ctx, regularUser.ID) + + setting := &v1pb.InstanceSetting{ + Name: "instance/settings/AI", + Value: &v1pb.InstanceSetting_AiSetting{ + AiSetting: &v1pb.InstanceSetting_AISetting{ + Providers: []*v1pb.InstanceSetting_AIProviderConfig{ + { + Id: "openai-main", + Title: "OpenAI", + Type: v1pb.InstanceSetting_OPENAI, + ApiKey: "sk-test", + Models: []string{"gpt-4o-transcribe"}, + DefaultModel: "gpt-4o-transcribe", + }, + }, + }, + }, + } + + _, err = ts.Service.UpdateInstanceSetting(ctx, &v1pb.UpdateInstanceSettingRequest{Setting: setting}) + require.Error(t, err) + 
require.Contains(t, err.Error(), "not authenticated") + + _, err = ts.Service.UpdateInstanceSetting(userCtx, &v1pb.UpdateInstanceSettingRequest{Setting: setting}) + require.Error(t, err) + require.Contains(t, err.Error(), "permission denied") + }) + t.Run("UpdateInstanceSetting - tags setting", func(t *testing.T) { ts := NewTestService(t) defer ts.Cleanup() @@ -490,4 +553,75 @@ func TestUpdateInstanceSetting(t *testing.T) { "existing AccessKeySecret must be preserved when an empty value is sent") require.Equal(t, "s3-v2.example.com", stored.GetS3Config().GetEndpoint()) }) + + t.Run("UpdateInstanceSetting - AI provider keys are write-only and preserved on empty", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + hostUser, err := ts.CreateHostUser(ctx, "admin") + require.NoError(t, err) + adminCtx := ts.CreateUserContext(ctx, hostUser.ID) + + _, err = ts.Service.UpdateInstanceSetting(adminCtx, &v1pb.UpdateInstanceSettingRequest{ + Setting: &v1pb.InstanceSetting{ + Name: "instance/settings/AI", + Value: &v1pb.InstanceSetting_AiSetting{ + AiSetting: &v1pb.InstanceSetting_AISetting{ + Providers: []*v1pb.InstanceSetting_AIProviderConfig{ + { + Id: "openai-main", + Title: "OpenAI", + Type: v1pb.InstanceSetting_OPENAI, + ApiKey: "sk-original", + Models: []string{"gpt-5.4", "gpt-5.4-mini"}, + DefaultModel: "gpt-5.4", + }, + }, + }, + }, + }, + }) + require.NoError(t, err) + + resp, err := ts.Service.GetInstanceSetting(adminCtx, &v1pb.GetInstanceSettingRequest{ + Name: "instance/settings/AI", + }) + require.NoError(t, err) + require.Len(t, resp.GetAiSetting().GetProviders(), 1) + provider := resp.GetAiSetting().GetProviders()[0] + require.Empty(t, provider.GetApiKey(), "AI provider API key must never be returned in responses") + require.True(t, provider.GetApiKeySet()) + require.Equal(t, "sk-o...inal", provider.GetApiKeyHint()) + require.Equal(t, "https://api.openai.com/v1", provider.GetEndpoint()) + + _, err = ts.Service.UpdateInstanceSetting(adminCtx, 
&v1pb.UpdateInstanceSettingRequest{ + Setting: &v1pb.InstanceSetting{ + Name: "instance/settings/AI", + Value: &v1pb.InstanceSetting_AiSetting{ + AiSetting: &v1pb.InstanceSetting_AISetting{ + Providers: []*v1pb.InstanceSetting_AIProviderConfig{ + { + Id: "openai-main", + Title: "OpenAI primary", + Type: v1pb.InstanceSetting_OPENAI, + ApiKey: "", + Models: []string{"gpt-5.4-mini", "gpt-5.4-mini", "gpt-5.4"}, + DefaultModel: "", + }, + }, + }, + }, + }, + }) + require.NoError(t, err) + + stored, err := ts.Store.GetInstanceAISetting(ctx) + require.NoError(t, err) + require.Len(t, stored.GetProviders(), 1) + require.Equal(t, "sk-original", stored.GetProviders()[0].GetApiKey(), + "existing AI provider API key must be preserved when an empty value is sent") + require.Equal(t, "OpenAI primary", stored.GetProviders()[0].GetTitle()) + require.Equal(t, []string{"gpt-5.4-mini", "gpt-5.4"}, stored.GetProviders()[0].GetModels()) + require.Equal(t, "gpt-5.4-mini", stored.GetProviders()[0].GetDefaultModel()) + }) } diff --git a/server/router/api/v1/v1.go b/server/router/api/v1/v1.go index 69acb2fbc..53839548c 100644 --- a/server/router/api/v1/v1.go +++ b/server/router/api/v1/v1.go @@ -23,6 +23,7 @@ type APIV1Service struct { v1pb.UnimplementedUserServiceServer v1pb.UnimplementedMemoServiceServer v1pb.UnimplementedAttachmentServiceServer + v1pb.UnimplementedAIServiceServer v1pb.UnimplementedShortcutServiceServer v1pb.UnimplementedIdentityProviderServiceServer @@ -104,6 +105,9 @@ func (s *APIV1Service) RegisterGateway(ctx context.Context, echoServer *echo.Ech if err := v1pb.RegisterAttachmentServiceHandlerServer(ctx, gwMux, s); err != nil { return err } + if err := v1pb.RegisterAIServiceHandlerServer(ctx, gwMux, s); err != nil { + return err + } if err := v1pb.RegisterShortcutServiceHandlerServer(ctx, gwMux, s); err != nil { return err } diff --git a/store/instance_setting.go b/store/instance_setting.go index d66b31f9b..445de4b74 100644 --- a/store/instance_setting.go +++ 
b/store/instance_setting.go @@ -41,6 +41,8 @@ func (s *Store) UpsertInstanceSetting(ctx context.Context, upsert *storepb.Insta valueBytes, err = protojson.Marshal(upsert.GetTagsSetting()) } else if upsert.Key == storepb.InstanceSettingKey_NOTIFICATION { valueBytes, err = protojson.Marshal(upsert.GetNotificationSetting()) + } else if upsert.Key == storepb.InstanceSettingKey_AI { + valueBytes, err = protojson.Marshal(upsert.GetAiSetting()) } else { return nil, errors.Errorf("unsupported instance setting key: %v", upsert.Key) } @@ -216,6 +218,26 @@ func (s *Store) GetInstanceNotificationSetting(ctx context.Context) (*storepb.In return instanceNotificationSetting, nil } +// GetInstanceAISetting gets the AI provider settings for the instance. +func (s *Store) GetInstanceAISetting(ctx context.Context) (*storepb.InstanceAISetting, error) { + instanceSetting, err := s.GetInstanceSetting(ctx, &FindInstanceSetting{ + Name: storepb.InstanceSettingKey_AI.String(), + }) + if err != nil { + return nil, errors.Wrap(err, "failed to get instance AI setting") + } + + instanceAISetting := &storepb.InstanceAISetting{} + if instanceSetting != nil { + instanceAISetting = instanceSetting.GetAiSetting() + } + s.instanceSettingCache.Set(ctx, storepb.InstanceSettingKey_AI.String(), &storepb.InstanceSetting{ + Key: storepb.InstanceSettingKey_AI, + Value: &storepb.InstanceSetting_AiSetting{AiSetting: instanceAISetting}, + }) + return instanceAISetting, nil +} + const ( defaultInstanceStorageType = storepb.InstanceStorageSetting_LOCAL defaultInstanceUploadSizeLimitMb = 30 @@ -291,6 +313,12 @@ func convertInstanceSettingFromRaw(instanceSettingRaw *InstanceSetting) (*storep return nil, err } instanceSetting.Value = &storepb.InstanceSetting_NotificationSetting{NotificationSetting: notificationSetting} + case storepb.InstanceSettingKey_AI.String(): + aiSetting := &storepb.InstanceAISetting{} + if err := protojsonUnmarshaler.Unmarshal([]byte(instanceSettingRaw.Value), aiSetting); err != nil { + 
return nil, err + } + instanceSetting.Value = &storepb.InstanceSetting_AiSetting{AiSetting: aiSetting} default: // Skip unsupported instance setting key. return nil, nil diff --git a/store/test/instance_setting_test.go b/store/test/instance_setting_test.go index bf63b1fe3..0fd56bff7 100644 --- a/store/test/instance_setting_test.go +++ b/store/test/instance_setting_test.go @@ -326,6 +326,55 @@ func TestInstanceSettingNotificationSetting(t *testing.T) { ts.Close() } +func TestInstanceSettingAISetting(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + + aiSetting, err := ts.GetInstanceAISetting(ctx) + require.NoError(t, err) + require.NotNil(t, aiSetting) + require.Empty(t, aiSetting.Providers) + + _, err = ts.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{ + Key: storepb.InstanceSettingKey_AI, + Value: &storepb.InstanceSetting_AiSetting{ + AiSetting: &storepb.InstanceAISetting{ + Providers: []*storepb.AIProviderConfig{ + { + Id: "openai-main", + Title: "OpenAI", + Type: storepb.AIProviderType_OPENAI, + Endpoint: "https://api.openai.com/v1", + ApiKey: "sk-test", + Models: []string{"gpt-5.4", "gpt-5.4-mini"}, + DefaultModel: "gpt-5.4", + }, + { + Id: "company-gateway", + Title: "Company Gateway", + Type: storepb.AIProviderType_OPENAI_COMPATIBLE, + Endpoint: "https://llm.example.com/v1", + ApiKey: "gw-test", + Models: []string{"qwen-plus"}, + DefaultModel: "qwen-plus", + }, + }, + }, + }, + }) + require.NoError(t, err) + + aiSetting, err = ts.GetInstanceAISetting(ctx) + require.NoError(t, err) + require.Len(t, aiSetting.Providers, 2) + require.Equal(t, "openai-main", aiSetting.Providers[0].Id) + require.Equal(t, "sk-test", aiSetting.Providers[0].ApiKey) + require.Equal(t, "company-gateway", aiSetting.Providers[1].Id) + + ts.Close() +} + func TestInstanceSettingListAll(t *testing.T) { t.Parallel() ctx := context.Background() diff --git a/web/src/components/Settings/AISection.tsx b/web/src/components/Settings/AISection.tsx 
new file mode 100644 index 000000000..f6acf85d8 --- /dev/null +++ b/web/src/components/Settings/AISection.tsx @@ -0,0 +1,408 @@ +import { create } from "@bufbuild/protobuf"; +import { isEqual } from "lodash-es"; +import { MoreVerticalIcon, PlusIcon } from "lucide-react"; +import { useEffect, useMemo, useState } from "react"; +import { toast } from "react-hot-toast"; +import ConfirmDialog from "@/components/ConfirmDialog"; +import { Button } from "@/components/ui/button"; +import { Dialog, DialogContent, DialogDescription, DialogFooter, DialogHeader, DialogTitle } from "@/components/ui/dialog"; +import { DropdownMenu, DropdownMenuContent, DropdownMenuItem, DropdownMenuTrigger } from "@/components/ui/dropdown-menu"; +import { Input } from "@/components/ui/input"; +import { Label } from "@/components/ui/label"; +import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "@/components/ui/select"; +import { Textarea } from "@/components/ui/textarea"; +import { useInstance } from "@/contexts/InstanceContext"; +import { handleError } from "@/lib/error"; +import { + InstanceSetting_AIProviderConfig, + InstanceSetting_AIProviderConfigSchema, + InstanceSetting_AIProviderType, + InstanceSetting_AISettingSchema, + InstanceSetting_Key, + InstanceSettingSchema, +} from "@/types/proto/api/v1/instance_service_pb"; +import { useTranslate } from "@/utils/i18n"; +import SettingGroup from "./SettingGroup"; +import SettingSection from "./SettingSection"; +import SettingTable from "./SettingTable"; + +type LocalAIProvider = { + id: string; + title: string; + type: InstanceSetting_AIProviderType; + endpoint: string; + apiKey: string; + apiKeySet: boolean; + apiKeyHint: string; + models: string[]; + defaultModel: string; +}; + +const providerTypeOptions = [ + InstanceSetting_AIProviderType.OPENAI, + InstanceSetting_AIProviderType.OPENAI_COMPATIBLE, + InstanceSetting_AIProviderType.ANTHROPIC, + InstanceSetting_AIProviderType.GEMINI, +]; + +const createProviderID = () => 
{ + if (typeof crypto !== "undefined" && "randomUUID" in crypto) { + return crypto.randomUUID(); + } + return `ai-${Date.now()}-${Math.random().toString(36).slice(2, 8)}`; +}; + +const getProviderTypeLabel = (type: InstanceSetting_AIProviderType) => { + return InstanceSetting_AIProviderType[type] ?? "UNKNOWN"; +}; + +const toLocalProvider = (provider: InstanceSetting_AIProviderConfig): LocalAIProvider => ({ + id: provider.id, + title: provider.title, + type: provider.type, + endpoint: provider.endpoint, + apiKey: "", + apiKeySet: provider.apiKeySet, + apiKeyHint: provider.apiKeyHint, + models: [...provider.models], + defaultModel: provider.defaultModel, +}); + +const normalizeModels = (value: string) => { + const models = value + .split(/\r?\n/) + .map((model) => model.trim()) + .filter(Boolean); + return Array.from(new Set(models)); +}; + +const newProvider = (): LocalAIProvider => ({ + id: createProviderID(), + title: "", + type: InstanceSetting_AIProviderType.OPENAI, + endpoint: "", + apiKey: "", + apiKeySet: false, + apiKeyHint: "", + models: [], + defaultModel: "", +}); + +const toProviderConfig = (provider: LocalAIProvider) => + create(InstanceSetting_AIProviderConfigSchema, { + id: provider.id, + title: provider.title.trim(), + type: provider.type, + endpoint: provider.endpoint.trim(), + apiKey: provider.apiKey, + models: provider.models, + defaultModel: provider.defaultModel.trim(), + }); + +const AISection = () => { + const t = useTranslate(); + const { aiSetting: originalSetting, updateSetting, fetchSetting } = useInstance(); + const [providers, setProviders] = useState(() => originalSetting.providers.map(toLocalProvider)); + const [editingProvider, setEditingProvider] = useState(); + const [deleteTarget, setDeleteTarget] = useState(); + + useEffect(() => { + setProviders(originalSetting.providers.map(toLocalProvider)); + }, [originalSetting.providers]); + + const originalProviders = useMemo(() => originalSetting.providers.map(toLocalProvider), 
[originalSetting.providers]); + const hasChanges = !isEqual(providers, originalProviders); + + const handleCreateProvider = () => { + setEditingProvider(newProvider()); + }; + + const handleEditProvider = (provider: LocalAIProvider) => { + setEditingProvider({ ...provider, apiKey: "" }); + }; + + const handleSaveProvider = (provider: LocalAIProvider) => { + const title = provider.title.trim(); + const endpoint = provider.endpoint.trim(); + const models = provider.models.map((model) => model.trim()).filter(Boolean); + const defaultModel = provider.defaultModel.trim() || models[0] || ""; + + if (!title) { + toast.error(t("setting.ai.provider-title-required")); + return; + } + if (provider.type === InstanceSetting_AIProviderType.OPENAI_COMPATIBLE && !endpoint) { + toast.error(t("setting.ai.endpoint-required")); + return; + } + if (!provider.apiKeySet && !provider.apiKey.trim()) { + toast.error(t("setting.ai.api-key-required")); + return; + } + if (models.length === 0) { + toast.error(t("setting.ai.models-required")); + return; + } + if (defaultModel && !models.includes(defaultModel)) { + toast.error(t("setting.ai.default-model-required")); + return; + } + + const normalizedProvider = { + ...provider, + title, + endpoint, + models, + defaultModel, + }; + setProviders((prev) => { + const exists = prev.some((item) => item.id === normalizedProvider.id); + if (!exists) { + return [...prev, normalizedProvider]; + } + return prev.map((item) => (item.id === normalizedProvider.id ? 
normalizedProvider : item)); + }); + setEditingProvider(undefined); + }; + + const handleDeleteProvider = () => { + if (!deleteTarget) return; + setProviders((prev) => prev.filter((provider) => provider.id !== deleteTarget.id)); + setDeleteTarget(undefined); + }; + + const handleSaveSetting = async () => { + try { + await updateSetting( + create(InstanceSettingSchema, { + name: `instance/settings/${InstanceSetting_Key[InstanceSetting_Key.AI]}`, + value: { + case: "aiSetting", + value: create(InstanceSetting_AISettingSchema, { + providers: providers.map(toProviderConfig), + }), + }, + }), + ); + await fetchSetting(InstanceSetting_Key.AI); + toast.success(t("message.update-succeed")); + } catch (error: unknown) { + handleError(error, toast.error, { + context: "Update AI providers", + }); + } + }; + + return ( + + + {t("setting.ai.add-provider")} + + } + > + + ( +
+ {provider.title} + {provider.id} +
+ ), + }, + { + key: "type", + header: t("setting.ai.provider-type"), + render: (_, provider: LocalAIProvider) => {getProviderTypeLabel(provider.type)}, + }, + { + key: "models", + header: t("setting.ai.models"), + render: (_, provider: LocalAIProvider) => ( +
+ {provider.defaultModel || provider.models[0] || "-"} + {t("setting.ai.model-count", { count: provider.models.length })} +
+ ), + }, + { + key: "apiKeySet", + header: t("setting.ai.api-key"), + render: (_, provider: LocalAIProvider) => ( + {provider.apiKeySet ? provider.apiKeyHint || t("setting.ai.configured") : "-"} + ), + }, + { + key: "actions", + header: "", + className: "text-right", + render: (_, provider: LocalAIProvider) => ( + + + + + + handleEditProvider(provider)}>{t("common.edit")} + setDeleteTarget(provider)} className="text-destructive focus:text-destructive"> + {t("common.delete")} + + + + ), + }, + ]} + data={providers} + emptyMessage={t("setting.ai.no-providers")} + getRowKey={(provider) => provider.id} + /> +
+ +
+ +
+ + !open && setEditingProvider(undefined)} + onSave={handleSaveProvider} + /> + + !open && setDeleteTarget(undefined)} + title={deleteTarget ? t("setting.ai.delete-provider", { title: deleteTarget.title }) : ""} + confirmLabel={t("common.delete")} + cancelLabel={t("common.cancel")} + onConfirm={handleDeleteProvider} + confirmVariant="destructive" + /> +
+ ); +}; + +interface AIProviderDialogProps { + provider?: LocalAIProvider; + onOpenChange: (open: boolean) => void; + onSave: (provider: LocalAIProvider) => void; +} + +const AIProviderDialog = ({ provider, onOpenChange, onSave }: AIProviderDialogProps) => { + const t = useTranslate(); + const [draft, setDraft] = useState(() => provider ?? newProvider()); + const [modelsText, setModelsText] = useState(""); + + useEffect(() => { + const next = provider ?? newProvider(); + setDraft(next); + setModelsText(next.models.join("\n")); + }, [provider]); + + const updateDraft = (partial: Partial) => { + setDraft((prev) => ({ ...prev, ...partial })); + }; + + const handleSave = () => { + onSave({ + ...draft, + models: normalizeModels(modelsText), + }); + }; + + return ( + + + + {provider?.apiKeySet ? t("setting.ai.edit-provider") : t("setting.ai.add-provider")} + {t("setting.ai.dialog-description")} + + +
+
+ + updateDraft({ title: e.target.value })} placeholder="OpenAI" /> +
+ +
+ + +
+ +
+ + updateDraft({ endpoint: e.target.value })} + placeholder={draft.type === InstanceSetting_AIProviderType.OPENAI ? "https://api.openai.com/v1" : "https://example.com/v1"} + /> +
+ +
+ + updateDraft({ apiKey: e.target.value })} + placeholder={draft.apiKeySet ? t("setting.ai.keep-api-key") : ""} + /> + {draft.apiKeySet && ( +

{t("setting.ai.current-key", { key: draft.apiKeyHint || "-" })}

+ )} +
+ +
+ +