From 101704c8eac17c7f34508d8db6c53bc972061cdb Mon Sep 17 00:00:00 2001 From: boojack Date: Mon, 13 Apr 2026 22:09:24 +0800 Subject: [PATCH] feat(ai): add BYOK audio transcription (#5832) --- internal/ai/ai.go | 16 +- internal/ai/models.go | 22 +++ proto/api/v1/ai_service.proto | 7 +- proto/api/v1/instance_service.proto | 6 +- proto/gen/api/v1/ai_service.pb.go | 22 +-- proto/gen/api/v1/instance_service.pb.go | 46 ++--- proto/gen/openapi.yaml | 11 -- proto/gen/store/instance_setting.pb.go | 44 +---- proto/store/instance_setting.proto | 6 +- server/router/api/v1/ai_service.go | 43 ++--- server/router/api/v1/instance_service.go | 78 +++------ server/router/api/v1/test/ai_service_test.go | 159 +++++++++--------- .../api/v1/test/instance_service_test.go | 78 ++------- store/test/instance_setting_test.go | 26 ++- .../components/AudioRecorderPanel.tsx | 59 +++++-- .../MemoEditor/hooks/useAudioRecorder.ts | 32 ++-- web/src/components/MemoEditor/index.tsx | 123 ++++++++++++-- .../MemoEditor/services/errorService.ts | 4 + .../components/MemoEditor/services/index.ts | 1 + .../services/transcriptionService.ts | 26 +++ .../components/MemoEditor/types/components.ts | 3 + web/src/components/Settings/AISection.tsx | 112 +++++------- web/src/contexts/InstanceContext.tsx | 26 ++- web/src/locales/en.json | 25 ++- web/src/locales/zh-Hans.json | 33 ++++ web/src/locales/zh-Hant.json | 33 ++++ web/src/pages/Setting.tsx | 4 +- web/src/types/proto/api/v1/ai_service_pb.ts | 13 +- .../types/proto/api/v1/instance_service_pb.ts | 26 +-- 29 files changed, 564 insertions(+), 520 deletions(-) create mode 100644 internal/ai/models.go create mode 100644 web/src/components/MemoEditor/services/transcriptionService.ts diff --git a/internal/ai/ai.go b/internal/ai/ai.go index 948cb487f..7e16f5b43 100644 --- a/internal/ai/ai.go +++ b/internal/ai/ai.go @@ -6,21 +6,15 @@ type ProviderType string const ( // ProviderOpenAI is OpenAI's hosted API. ProviderOpenAI ProviderType = "OPENAI" - // ProviderOpenAICompatible is an OpenAI-compatible API endpoint. - ProviderOpenAICompatible ProviderType = "OPENAI_COMPATIBLE" // ProviderGemini is Google's Gemini API. ProviderGemini ProviderType = "GEMINI" - // ProviderAnthropic is Anthropic's API. - ProviderAnthropic ProviderType = "ANTHROPIC" ) // ProviderConfig configures a callable AI provider connection. type ProviderConfig struct { - ID string - Title string - Type ProviderType - Endpoint string - APIKey string - Models []string - DefaultModel string + ID string + Title string + Type ProviderType + Endpoint string + APIKey string } diff --git a/internal/ai/models.go b/internal/ai/models.go new file mode 100644 index 000000000..b0855c321 --- /dev/null +++ b/internal/ai/models.go @@ -0,0 +1,22 @@ +package ai + +import "github.com/pkg/errors" + +const ( + // DefaultOpenAITranscriptionModel is the built-in OpenAI transcription model. + DefaultOpenAITranscriptionModel = "gpt-4o-transcribe" + // DefaultGeminiTranscriptionModel is the built-in Gemini transcription model. + DefaultGeminiTranscriptionModel = "gemini-2.5-flash" +) + +// DefaultTranscriptionModel returns the built-in transcription model for a provider. +func DefaultTranscriptionModel(providerType ProviderType) (string, error) { + switch providerType { + case ProviderOpenAI: + return DefaultOpenAITranscriptionModel, nil + case ProviderGemini: + return DefaultGeminiTranscriptionModel, nil + default: + return "", errors.Wrapf(ErrCapabilityUnsupported, "provider type %q", providerType) + } +} diff --git a/proto/api/v1/ai_service.proto b/proto/api/v1/ai_service.proto index 3e908e809..82c9386eb 100644 --- a/proto/api/v1/ai_service.proto +++ b/proto/api/v1/ai_service.proto @@ -31,14 +31,11 @@ message TranscribeRequest { } message TranscriptionConfig { - // Optional. The model to use. If empty, the provider's default model is used. - string model = 1 [(google.api.field_behavior) = OPTIONAL]; - // Optional. A prompt to improve transcription quality. - string prompt = 2 [(google.api.field_behavior) = OPTIONAL]; + string prompt = 1 [(google.api.field_behavior) = OPTIONAL]; // Optional. The language of the input audio. - string language = 3 [(google.api.field_behavior) = OPTIONAL]; + string language = 2 [(google.api.field_behavior) = OPTIONAL]; } message TranscriptionAudio { diff --git a/proto/api/v1/instance_service.proto b/proto/api/v1/instance_service.proto index 9d73a8968..91bbcf632 100644 --- a/proto/api/v1/instance_service.proto +++ b/proto/api/v1/instance_service.proto @@ -219,8 +219,6 @@ message InstanceSetting { string endpoint = 4; // api_key is write-only and is never returned by GetInstanceSetting. string api_key = 5 [(google.api.field_behavior) = INPUT_ONLY]; - repeated string models = 6; - string default_model = 7; // api_key_set indicates whether an API key is stored for this provider. bool api_key_set = 8 [(google.api.field_behavior) = OUTPUT_ONLY]; // api_key_hint is a masked hint for the stored API key. @@ -231,9 +229,7 @@ message InstanceSetting { enum AIProviderType { AI_PROVIDER_TYPE_UNSPECIFIED = 0; OPENAI = 1; - OPENAI_COMPATIBLE = 2; - GEMINI = 3; - ANTHROPIC = 4; + GEMINI = 2; } } diff --git a/proto/gen/api/v1/ai_service.pb.go b/proto/gen/api/v1/ai_service.pb.go index 3af97e744..6656c0953 100644 --- a/proto/gen/api/v1/ai_service.pb.go +++ b/proto/gen/api/v1/ai_service.pb.go @@ -87,12 +87,10 @@ func (x *TranscribeRequest) GetAudio() *TranscriptionAudio { type TranscriptionConfig struct { state protoimpl.MessageState `protogen:"open.v1"` - // Optional. The model to use. If empty, the provider's default model is used. - Model string `protobuf:"bytes,1,opt,name=model,proto3" json:"model,omitempty"` // Optional. A prompt to improve transcription quality. - Prompt string `protobuf:"bytes,2,opt,name=prompt,proto3" json:"prompt,omitempty"` + Prompt string `protobuf:"bytes,1,opt,name=prompt,proto3" json:"prompt,omitempty"` // Optional. The language of the input audio. - Language string `protobuf:"bytes,3,opt,name=language,proto3" json:"language,omitempty"` + Language string `protobuf:"bytes,2,opt,name=language,proto3" json:"language,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -127,13 +125,6 @@ func (*TranscriptionConfig) Descriptor() ([]byte, []int) { return file_api_v1_ai_service_proto_rawDescGZIP(), []int{1} } -func (x *TranscriptionConfig) GetModel() string { - if x != nil { - return x.Model - } - return "" -} - func (x *TranscriptionConfig) GetPrompt() string { if x != nil { return x.Prompt @@ -304,11 +295,10 @@ const file_api_v1_ai_service_proto_rawDesc = "" + "\vprovider_id\x18\x01 \x01(\tB\x03\xe0A\x02R\n" + "providerId\x12>\n" + "\x06config\x18\x02 \x01(\v2!.memos.api.v1.TranscriptionConfigB\x03\xe0A\x02R\x06config\x12;\n" + - "\x05audio\x18\x03 \x01(\v2 .memos.api.v1.TranscriptionAudioB\x03\xe0A\x02R\x05audio\"n\n" + - "\x13TranscriptionConfig\x12\x19\n" + - "\x05model\x18\x01 \x01(\tB\x03\xe0A\x01R\x05model\x12\x1b\n" + - "\x06prompt\x18\x02 \x01(\tB\x03\xe0A\x01R\x06prompt\x12\x1f\n" + - "\blanguage\x18\x03 \x01(\tB\x03\xe0A\x01R\blanguage\"\x9c\x01\n" + + "\x05audio\x18\x03 \x01(\v2 .memos.api.v1.TranscriptionAudioB\x03\xe0A\x02R\x05audio\"S\n" + + "\x13TranscriptionConfig\x12\x1b\n" + + "\x06prompt\x18\x01 \x01(\tB\x03\xe0A\x01R\x06prompt\x12\x1f\n" + + "\blanguage\x18\x02 \x01(\tB\x03\xe0A\x01R\blanguage\"\x9c\x01\n" + "\x12TranscriptionAudio\x12\x1f\n" + "\acontent\x18\x01 \x01(\fB\x03\xe0A\x04H\x00R\acontent\x12\x12\n" + "\x03uri\x18\x02 \x01(\tH\x00R\x03uri\x12\x1f\n" + diff --git a/proto/gen/api/v1/instance_service.pb.go b/proto/gen/api/v1/instance_service.pb.go index 3da7dc902..d308d1ec2 100644 --- a/proto/gen/api/v1/instance_service.pb.go +++ b/proto/gen/api/v1/instance_service.pb.go @@ -98,9 +98,7 @@ type InstanceSetting_AIProviderType int32 const ( InstanceSetting_AI_PROVIDER_TYPE_UNSPECIFIED InstanceSetting_AIProviderType = 0 InstanceSetting_OPENAI InstanceSetting_AIProviderType = 1 - InstanceSetting_OPENAI_COMPATIBLE InstanceSetting_AIProviderType = 2 - InstanceSetting_GEMINI InstanceSetting_AIProviderType = 3 - InstanceSetting_ANTHROPIC InstanceSetting_AIProviderType = 4 + InstanceSetting_GEMINI InstanceSetting_AIProviderType = 2 ) // Enum value maps for InstanceSetting_AIProviderType. @@ -108,16 +106,12 @@ var ( InstanceSetting_AIProviderType_name = map[int32]string{ 0: "AI_PROVIDER_TYPE_UNSPECIFIED", 1: "OPENAI", - 2: "OPENAI_COMPATIBLE", - 3: "GEMINI", - 4: "ANTHROPIC", + 2: "GEMINI", } InstanceSetting_AIProviderType_value = map[string]int32{ "AI_PROVIDER_TYPE_UNSPECIFIED": 0, "OPENAI": 1, - "OPENAI_COMPATIBLE": 2, - "GEMINI": 3, - "ANTHROPIC": 4, + "GEMINI": 2, } ) @@ -1036,9 +1030,7 @@ type InstanceSetting_AIProviderConfig struct { Type InstanceSetting_AIProviderType `protobuf:"varint,3,opt,name=type,proto3,enum=memos.api.v1.InstanceSetting_AIProviderType" json:"type,omitempty"` Endpoint string `protobuf:"bytes,4,opt,name=endpoint,proto3" json:"endpoint,omitempty"` // api_key is write-only and is never returned by GetInstanceSetting. - ApiKey string `protobuf:"bytes,5,opt,name=api_key,json=apiKey,proto3" json:"api_key,omitempty"` - Models []string `protobuf:"bytes,6,rep,name=models,proto3" json:"models,omitempty"` - DefaultModel string `protobuf:"bytes,7,opt,name=default_model,json=defaultModel,proto3" json:"default_model,omitempty"` + ApiKey string `protobuf:"bytes,5,opt,name=api_key,json=apiKey,proto3" json:"api_key,omitempty"` // api_key_set indicates whether an API key is stored for this provider. ApiKeySet bool `protobuf:"varint,8,opt,name=api_key_set,json=apiKeySet,proto3" json:"api_key_set,omitempty"` // api_key_hint is a masked hint for the stored API key. @@ -1112,20 +1104,6 @@ func (x *InstanceSetting_AIProviderConfig) GetApiKey() string { return "" } -func (x *InstanceSetting_AIProviderConfig) GetModels() []string { - if x != nil { - return x.Models - } - return nil -} - -func (x *InstanceSetting_AIProviderConfig) GetDefaultModel() string { - if x != nil { - return x.DefaultModel - } - return "" -} - func (x *InstanceSetting_AIProviderConfig) GetApiKeySet() bool { if x != nil { return x.ApiKeySet @@ -1414,7 +1392,7 @@ const file_api_v1_instance_service_proto_rawDesc = "" + "\x04demo\x18\x03 \x01(\bR\x04demo\x12!\n" + "\finstance_url\x18\x06 \x01(\tR\vinstanceUrl\x12(\n" + "\x05admin\x18\a \x01(\v2\x12.memos.api.v1.UserR\x05admin\"\x1b\n" + - "\x19GetInstanceProfileRequest\"\xe2\x1a\n" + + "\x19GetInstanceProfileRequest\"\xff\x19\n" + "\x0fInstanceSetting\x12\x17\n" + "\x04name\x18\x01 \x01(\tB\x03\xe0A\bR\x04name\x12W\n" + "\x0fgeneral_setting\x18\x02 \x01(\v2,.memos.api.v1.InstanceSetting.GeneralSettingH\x00R\x0egeneralSetting\x12W\n" + @@ -1483,15 +1461,13 @@ const file_api_v1_instance_service_proto_rawDesc = "" + "\ause_ssl\x18\n" + " \x01(\bR\x06useSsl\x1aY\n" + "\tAISetting\x12L\n" + - "\tproviders\x18\x01 \x03(\v2..memos.api.v1.InstanceSetting.AIProviderConfigR\tproviders\x1a\xbd\x02\n" + + "\tproviders\x18\x01 \x03(\v2..memos.api.v1.InstanceSetting.AIProviderConfigR\tproviders\x1a\x80\x02\n" + "\x10AIProviderConfig\x12\x0e\n" + "\x02id\x18\x01 \x01(\tR\x02id\x12\x14\n" + "\x05title\x18\x02 \x01(\tR\x05title\x12@\n" + "\x04type\x18\x03 \x01(\x0e2,.memos.api.v1.InstanceSetting.AIProviderTypeR\x04type\x12\x1a\n" + "\bendpoint\x18\x04 \x01(\tR\bendpoint\x12\x1c\n" + - "\aapi_key\x18\x05 \x01(\tB\x03\xe0A\x04R\x06apiKey\x12\x16\n" + - "\x06models\x18\x06 \x03(\tR\x06models\x12#\n" + - "\rdefault_model\x18\a \x01(\tR\fdefaultModel\x12#\n" + + "\aapi_key\x18\x05 \x01(\tB\x03\xe0A\x04R\x06apiKey\x12#\n" + "\vapi_key_set\x18\b \x01(\bB\x03\xe0A\x03R\tapiKeySet\x12%\n" + "\fapi_key_hint\x18\t \x01(\tB\x03\xe0A\x03R\n" + "apiKeyHint\"j\n" + @@ -1502,15 +1478,13 @@ const file_api_v1_instance_service_proto_rawDesc = "" + "\fMEMO_RELATED\x10\x03\x12\b\n" + "\x04TAGS\x10\x04\x12\x10\n" + "\fNOTIFICATION\x10\x05\x12\x06\n" + - "\x02AI\x10\x06\"p\n" + + "\x02AI\x10\x06\"J\n" + "\x0eAIProviderType\x12 \n" + "\x1cAI_PROVIDER_TYPE_UNSPECIFIED\x10\x00\x12\n" + "\n" + - "\x06OPENAI\x10\x01\x12\x15\n" + - "\x11OPENAI_COMPATIBLE\x10\x02\x12\n" + + "\x06OPENAI\x10\x01\x12\n" + "\n" + - "\x06GEMINI\x10\x03\x12\r\n" + - "\tANTHROPIC\x10\x04:a\xeaA^\n" + + "\x06GEMINI\x10\x02:a\xeaA^\n" + "\x1cmemos.api.v1/InstanceSetting\x12\x1binstance/settings/{setting}*\x10instanceSettings2\x0finstanceSettingB\a\n" + "\x05value\"U\n" + "\x19GetInstanceSettingRequest\x128\n" + diff --git a/proto/gen/openapi.yaml b/proto/gen/openapi.yaml index b4148efa7..6d3b444c0 100644 --- a/proto/gen/openapi.yaml +++ b/proto/gen/openapi.yaml @@ -2419,9 +2419,7 @@ components: enum: - AI_PROVIDER_TYPE_UNSPECIFIED - OPENAI - - OPENAI_COMPATIBLE - GEMINI - - ANTHROPIC type: string format: enum endpoint: @@ -2430,12 +2428,6 @@ components: writeOnly: true type: string description: api_key is write-only and is never returned by GetInstanceSetting. - models: - type: array - items: - type: string - defaultModel: - type: string apiKeySet: readOnly: true type: boolean @@ -3261,9 +3253,6 @@ components: TranscriptionConfig: type: object properties: - model: - type: string - description: Optional. The model to use. If empty, the provider's default model is used. prompt: type: string description: Optional. A prompt to improve transcription quality. diff --git a/proto/gen/store/instance_setting.pb.go b/proto/gen/store/instance_setting.pb.go index b826b364a..8cf5ff605 100644 --- a/proto/gen/store/instance_setting.pb.go +++ b/proto/gen/store/instance_setting.pb.go @@ -98,9 +98,7 @@ type AIProviderType int32 const ( AIProviderType_AI_PROVIDER_TYPE_UNSPECIFIED AIProviderType = 0 AIProviderType_OPENAI AIProviderType = 1 - AIProviderType_OPENAI_COMPATIBLE AIProviderType = 2 - AIProviderType_GEMINI AIProviderType = 3 - AIProviderType_ANTHROPIC AIProviderType = 4 + AIProviderType_GEMINI AIProviderType = 2 ) // Enum value maps for AIProviderType. @@ -108,16 +106,12 @@ var ( AIProviderType_name = map[int32]string{ 0: "AI_PROVIDER_TYPE_UNSPECIFIED", 1: "OPENAI", - 2: "OPENAI_COMPATIBLE", - 3: "GEMINI", - 4: "ANTHROPIC", + 2: "GEMINI", } AIProviderType_value = map[string]int32{ "AI_PROVIDER_TYPE_UNSPECIFIED": 0, "OPENAI": 1, - "OPENAI_COMPATIBLE": 2, - "GEMINI": 3, - "ANTHROPIC": 4, + "GEMINI": 2, } ) @@ -1026,9 +1020,7 @@ type AIProviderConfig struct { Type AIProviderType `protobuf:"varint,3,opt,name=type,proto3,enum=memos.store.AIProviderType" json:"type,omitempty"` Endpoint string `protobuf:"bytes,4,opt,name=endpoint,proto3" json:"endpoint,omitempty"` // api_key is write-only at the API layer and is required by the server to call providers. - ApiKey string `protobuf:"bytes,5,opt,name=api_key,json=apiKey,proto3" json:"api_key,omitempty"` - Models []string `protobuf:"bytes,6,rep,name=models,proto3" json:"models,omitempty"` - DefaultModel string `protobuf:"bytes,7,opt,name=default_model,json=defaultModel,proto3" json:"default_model,omitempty"` + ApiKey string `protobuf:"bytes,5,opt,name=api_key,json=apiKey,proto3" json:"api_key,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -1098,20 +1090,6 @@ func (x *AIProviderConfig) GetApiKey() string { return "" } -func (x *AIProviderConfig) GetModels() []string { - if x != nil { - return x.Models - } - return nil -} - -func (x *AIProviderConfig) GetDefaultModel() string { - if x != nil { - return x.DefaultModel - } - return "" -} - type InstanceNotificationSetting_EmailSetting struct { state protoimpl.MessageState `protogen:"open.v1"` Enabled bool `protobuf:"varint,1,opt,name=enabled,proto3" json:"enabled,omitempty"` @@ -1307,15 +1285,13 @@ const file_store_instance_setting_proto_rawDesc = "" + "\ause_ssl\x18\n" + " \x01(\bR\x06useSsl\"P\n" + "\x11InstanceAISetting\x12;\n" + - "\tproviders\x18\x01 \x03(\v2\x1d.memos.store.AIProviderConfigR\tproviders\"\xdb\x01\n" + + "\tproviders\x18\x01 \x03(\v2\x1d.memos.store.AIProviderConfigR\tproviders\"\x9e\x01\n" + "\x10AIProviderConfig\x12\x0e\n" + "\x02id\x18\x01 \x01(\tR\x02id\x12\x14\n" + "\x05title\x18\x02 \x01(\tR\x05title\x12/\n" + "\x04type\x18\x03 \x01(\x0e2\x1b.memos.store.AIProviderTypeR\x04type\x12\x1a\n" + "\bendpoint\x18\x04 \x01(\tR\bendpoint\x12\x17\n" + - "\aapi_key\x18\x05 \x01(\tR\x06apiKey\x12\x16\n" + - "\x06models\x18\x06 \x03(\tR\x06models\x12#\n" + - "\rdefault_model\x18\a \x01(\tR\fdefaultModel*\x95\x01\n" + + "\aapi_key\x18\x05 \x01(\tR\x06apiKey*\x95\x01\n" + "\x12InstanceSettingKey\x12$\n" + " INSTANCE_SETTING_KEY_UNSPECIFIED\x10\x00\x12\t\n" + "\x05BASIC\x10\x01\x12\v\n" + @@ -1324,15 +1300,13 @@ const file_store_instance_setting_proto_rawDesc = "" + "\fMEMO_RELATED\x10\x04\x12\b\n" + "\x04TAGS\x10\x05\x12\x10\n" + "\fNOTIFICATION\x10\x06\x12\x06\n" + - "\x02AI\x10\a*p\n" + + "\x02AI\x10\a*J\n" + "\x0eAIProviderType\x12 \n" + "\x1cAI_PROVIDER_TYPE_UNSPECIFIED\x10\x00\x12\n" + "\n" + - "\x06OPENAI\x10\x01\x12\x15\n" + - "\x11OPENAI_COMPATIBLE\x10\x02\x12\n" + + "\x06OPENAI\x10\x01\x12\n" + "\n" + - "\x06GEMINI\x10\x03\x12\r\n" + - "\tANTHROPIC\x10\x04B\x9f\x01\n" + + "\x06GEMINI\x10\x02B\x9f\x01\n" + "\x0fcom.memos.storeB\x14InstanceSettingProtoP\x01Z)github.com/usememos/memos/proto/gen/store\xa2\x02\x03MSX\xaa\x02\vMemos.Store\xca\x02\vMemos\\Store\xe2\x02\x17Memos\\Store\\GPBMetadata\xea\x02\fMemos::Storeb\x06proto3" var ( diff --git a/proto/store/instance_setting.proto b/proto/store/instance_setting.proto index 11a622d48..f1010b7ab 100644 --- a/proto/store/instance_setting.proto +++ b/proto/store/instance_setting.proto @@ -158,14 +158,10 @@ message AIProviderConfig { string endpoint = 4; // api_key is write-only at the API layer and is required by the server to call providers. string api_key = 5; - repeated string models = 6; - string default_model = 7; } enum AIProviderType { AI_PROVIDER_TYPE_UNSPECIFIED = 0; OPENAI = 1; - OPENAI_COMPATIBLE = 2; - GEMINI = 3; - ANTHROPIC = 4; + GEMINI = 2; } diff --git a/server/router/api/v1/ai_service.go b/server/router/api/v1/ai_service.go index 5a1e4ff05..14d80b88d 100644 --- a/server/router/api/v1/ai_service.go +++ b/server/router/api/v1/ai_service.go @@ -93,7 +93,7 @@ func (s *APIV1Service) Transcribe(ctx context.Context, request *v1pb.TranscribeR return nil, status.Errorf(codes.InvalidArgument, "audio content type %q is not supported", contentType) } - provider, model, err := s.resolveAIProviderForTranscription(ctx, request.ProviderId, request.Config.GetModel()) + provider, model, err := s.resolveAIProviderForTranscription(ctx, request.ProviderId) if err != nil { return nil, err } @@ -119,7 +119,7 @@ func (s *APIV1Service) Transcribe(ctx context.Context, request *v1pb.TranscribeR }, nil } -func (s *APIV1Service) resolveAIProviderForTranscription(ctx context.Context, providerID string, model string) (ai.ProviderConfig, string, error) { +func (s *APIV1Service) resolveAIProviderForTranscription(ctx context.Context, providerID string) (ai.ProviderConfig, string, error) { setting, err := s.Store.GetInstanceAISetting(ctx) if err != nil { return ai.ProviderConfig{}, "", status.Errorf(codes.Internal, "failed to get AI setting: %v", err) @@ -137,28 +137,20 @@ func (s *APIV1Service) resolveAIProviderForTranscription(ctx context.Context, pr if err != nil { return ai.ProviderConfig{}, "", status.Errorf(codes.NotFound, "AI provider not found") } - selectedModel := strings.TrimSpace(model) - if selectedModel == "" { - selectedModel = provider.DefaultModel - } - if selectedModel == "" { - return ai.ProviderConfig{}, "", status.Errorf(codes.InvalidArgument, "model is required") - } - if !containsString(provider.Models, selectedModel) { - return ai.ProviderConfig{}, "", status.Errorf(codes.InvalidArgument, "model %q is not configured for provider %q", selectedModel, provider.ID) + selectedModel, err := ai.DefaultTranscriptionModel(provider.Type) + if err != nil { + return ai.ProviderConfig{}, "", status.Errorf(codes.InvalidArgument, "%v", err) } return *provider, selectedModel, nil } func convertAIProviderConfigFromStore(provider *storepb.AIProviderConfig) ai.ProviderConfig { return ai.ProviderConfig{ - ID: provider.GetId(), - Title: provider.GetTitle(), - Type: convertAIProviderTypeFromStore(provider.GetType()), - Endpoint: provider.GetEndpoint(), - APIKey: provider.GetApiKey(), - Models: provider.GetModels(), - DefaultModel: provider.GetDefaultModel(), + ID: provider.GetId(), + Title: provider.GetTitle(), + Type: convertAIProviderTypeFromStore(provider.GetType()), + Endpoint: provider.GetEndpoint(), + APIKey: provider.GetApiKey(), } } @@ -166,12 +158,8 @@ func convertAIProviderTypeFromStore(providerType storepb.AIProviderType) ai.Prov switch providerType { case storepb.AIProviderType_OPENAI: return ai.ProviderOpenAI - case storepb.AIProviderType_OPENAI_COMPATIBLE: - return ai.ProviderOpenAICompatible case storepb.AIProviderType_GEMINI: return ai.ProviderGemini - case storepb.AIProviderType_ANTHROPIC: - return ai.ProviderAnthropic default: return "" } @@ -179,7 +167,7 @@ func convertAIProviderTypeFromStore(providerType storepb.AIProviderType) ai.Prov func newAITranscriber(provider ai.ProviderConfig) (ai.Transcriber, error) { switch provider.Type { - case ai.ProviderOpenAI, ai.ProviderOpenAICompatible: + case ai.ProviderOpenAI: return openai.NewTranscriber(provider) case ai.ProviderGemini: return gemini.NewTranscriber(provider) @@ -188,15 +176,6 @@ func newAITranscriber(provider ai.ProviderConfig) (ai.Transcriber, error) { } } -func containsString(values []string, target string) bool { - for _, value := range values { - if value == target { - return true - } - } - return false -} - func isSupportedTranscriptionContentType(contentType string) bool { mediaType, _, err := mime.ParseMediaType(strings.TrimSpace(contentType)) if err != nil { diff --git a/server/router/api/v1/instance_service.go b/server/router/api/v1/instance_service.go index cc1b9af22..f4abf54f7 100644 --- a/server/router/api/v1/instance_service.go +++ b/server/router/api/v1/instance_service.go @@ -5,7 +5,6 @@ import ( "fmt" "math" "regexp" - "slices" "strings" "github.com/lithammer/shortuuid/v4" @@ -75,10 +74,9 @@ func (s *APIV1Service) GetInstanceSetting(ctx context.Context, request *v1pb.Get return nil, status.Errorf(codes.NotFound, "instance setting not found") } - // Storage, notification, and AI settings contain credentials; restrict to admins only. + // Storage and notification settings contain credentials; restrict to admins only. if instanceSetting.Key == storepb.InstanceSettingKey_STORAGE || - instanceSetting.Key == storepb.InstanceSettingKey_NOTIFICATION || - instanceSetting.Key == storepb.InstanceSettingKey_AI { + instanceSetting.Key == storepb.InstanceSettingKey_NOTIFICATION { user, err := s.fetchCurrentUser(ctx) if err != nil { return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err) @@ -90,6 +88,15 @@ func (s *APIV1Service) GetInstanceSetting(ctx context.Context, request *v1pb.Get return nil, status.Errorf(codes.PermissionDenied, "permission denied") } } + if instanceSetting.Key == storepb.InstanceSettingKey_AI { + user, err := s.fetchCurrentUser(ctx) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err) + } + if user == nil { + return nil, status.Errorf(codes.Unauthenticated, "user not authenticated") + } + } return convertInstanceSettingFromStore(instanceSetting), nil } @@ -429,14 +436,12 @@ func convertInstanceAISettingFromStore(setting *storepb.InstanceAISetting) *v1pb } apiKey := provider.GetApiKey() aiSetting.Providers = append(aiSetting.Providers, &v1pb.InstanceSetting_AIProviderConfig{ - Id: provider.GetId(), - Title: provider.GetTitle(), - Type: v1pb.InstanceSetting_AIProviderType(provider.GetType()), - Endpoint: provider.GetEndpoint(), - Models: provider.GetModels(), - DefaultModel: provider.GetDefaultModel(), - ApiKeySet: apiKey != "", - ApiKeyHint: maskAPIKey(apiKey), + Id: provider.GetId(), + Title: provider.GetTitle(), + Type: v1pb.InstanceSetting_AIProviderType(provider.GetType()), + Endpoint: provider.GetEndpoint(), + ApiKeySet: apiKey != "", + ApiKeyHint: maskAPIKey(apiKey), }) } return aiSetting @@ -455,13 +460,11 @@ func convertInstanceAISettingToStore(setting *v1pb.InstanceSetting_AISetting) *s continue } aiSetting.Providers = append(aiSetting.Providers, &storepb.AIProviderConfig{ - Id: provider.GetId(), - Title: provider.GetTitle(), - Type: storepb.AIProviderType(provider.GetType()), - Endpoint: provider.GetEndpoint(), - ApiKey: provider.GetApiKey(), - Models: provider.GetModels(), - DefaultModel: provider.GetDefaultModel(), + Id: provider.GetId(), + Title: provider.GetTitle(), + Type: storepb.AIProviderType(provider.GetType()), + Endpoint: provider.GetEndpoint(), + ApiKey: provider.GetApiKey(), }) } return aiSetting @@ -515,31 +518,16 @@ func (s *APIV1Service) prepareInstanceAISettingForUpdate(ctx context.Context, se if provider.Title == "" { return errors.New("provider title is required") } - if provider.Type == storepb.AIProviderType_AI_PROVIDER_TYPE_UNSPECIFIED { - return errors.Errorf("provider %q type is required", provider.Id) + if provider.Type != storepb.AIProviderType_OPENAI && provider.Type != storepb.AIProviderType_GEMINI { + return errors.Errorf("provider %q has unsupported type", provider.Id) } provider.Endpoint = strings.TrimSpace(provider.Endpoint) if provider.Type == storepb.AIProviderType_OPENAI && provider.Endpoint == "" { provider.Endpoint = "https://api.openai.com/v1" } - if provider.Type == storepb.AIProviderType_ANTHROPIC && provider.Endpoint == "" { - provider.Endpoint = "https://api.anthropic.com/v1" - } - if provider.Type == storepb.AIProviderType_OPENAI_COMPATIBLE && provider.Endpoint == "" { - return errors.Errorf("provider %q endpoint is required", provider.Id) - } - - provider.Models = normalizeAIModels(provider.Models) - if len(provider.Models) == 0 { - return errors.Errorf("provider %q must define at least one model", provider.Id) - } - provider.DefaultModel = strings.TrimSpace(provider.DefaultModel) - if provider.DefaultModel == "" { - provider.DefaultModel = provider.Models[0] - } - if !slices.Contains(provider.Models, provider.DefaultModel) { - return errors.Errorf("provider %q default model %q must be included in models", provider.Id, provider.DefaultModel) + if provider.Type == storepb.AIProviderType_GEMINI && provider.Endpoint == "" { + provider.Endpoint = "https://generativelanguage.googleapis.com/v1beta" } if provider.ApiKey == "" { @@ -554,20 +542,6 @@ func (s *APIV1Service) prepareInstanceAISettingForUpdate(ctx context.Context, se return nil } -func normalizeAIModels(models []string) []string { - normalized := []string{} - seen := map[string]bool{} - for _, model := range models { - model = strings.TrimSpace(model) - if model == "" || seen[model] { - continue - } - seen[model] = true - normalized = append(normalized, model) - } - return normalized -} - func maskAPIKey(apiKey string) string { if apiKey == "" { return "" diff --git a/server/router/api/v1/test/ai_service_test.go b/server/router/api/v1/test/ai_service_test.go index 53a96803e..aac7d0644 100644 --- a/server/router/api/v1/test/ai_service_test.go +++ b/server/router/api/v1/test/ai_service_test.go @@ -22,9 +22,7 @@ func TestTranscribe(t *testing.T) { _, err := ts.Service.Transcribe(ctx, &v1pb.TranscribeRequest{ ProviderId: "openai-main", - Config: &v1pb.TranscriptionConfig{ - Model: "gpt-4o-transcribe", - }, + Config: &v1pb.TranscriptionConfig{}, Audio: &v1pb.TranscriptionAudio{ Source: &v1pb.TranscriptionAudio_Content{Content: []byte("RIFF")}, Filename: "voice.wav", @@ -68,13 +66,11 @@ func TestTranscribe(t *testing.T) { AiSetting: &storepb.InstanceAISetting{ Providers: []*storepb.AIProviderConfig{ { - Id: "openai-main", - Title: "OpenAI", - Type: storepb.AIProviderType_OPENAI_COMPATIBLE, - Endpoint: openAIServer.URL, - ApiKey: "sk-test", - Models: []string{"gpt-4o-transcribe"}, - DefaultModel: "gpt-4o-transcribe", + Id: "openai-main", + Title: "OpenAI", + Type: storepb.AIProviderType_OPENAI, + Endpoint: openAIServer.URL, + ApiKey: "sk-test", }, }, }, @@ -97,29 +93,16 @@ func TestTranscribe(t *testing.T) { require.Equal(t, "transcribed text", resp.Text) }) - t.Run("transcribes audio file with Gemini provider", func(t *testing.T) { + t.Run("returns provider error without rewriting it", func(t *testing.T) { ts := NewTestService(t) defer ts.Cleanup() - user, err := ts.CreateRegularUser(ctx, "gemini-user") + user, err := ts.CreateRegularUser(ctx, "notfound-user") require.NoError(t, err) userCtx := ts.CreateUserContext(ctx, user.ID) - geminiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - require.Equal(t, "/v1beta/models/gemini-2.5-flash:generateContent", r.URL.Path) - require.Equal(t, "gemini-key", r.Header.Get("x-goog-api-key")) - w.Header().Set("Content-Type", "application/json") - require.NoError(t, json.NewEncoder(w).Encode(map[string]any{ - "candidates": []map[string]any{ - { - "content": map[string]any{ - "parts": []map[string]string{{"text": "gemini transcript"}}, - }, - }, - }, - })) - })) - defer geminiServer.Close() + openAIServer := httptest.NewServer(http.NotFoundHandler()) + defer openAIServer.Close() _, err = ts.Store.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{ Key: storepb.InstanceSettingKey_AI, @@ -127,13 +110,11 @@ func TestTranscribe(t *testing.T) { AiSetting: &storepb.InstanceAISetting{ Providers: []*storepb.AIProviderConfig{ { - Id: "gemini-main", - Title: "Gemini", - Type: storepb.AIProviderType_GEMINI, - Endpoint: geminiServer.URL + "/v1beta", - ApiKey: "gemini-key", - Models: []string{"gemini-2.5-flash"}, - DefaultModel: "gemini-2.5-flash", + Id: "openai-main", + Title: "OpenAI", + Type: storepb.AIProviderType_OPENAI, + Endpoint: openAIServer.URL, + ApiKey: "sk-test", }, }, }, @@ -141,40 +122,54 @@ func TestTranscribe(t *testing.T) { }) require.NoError(t, err) - resp, err := ts.Service.Transcribe(userCtx, &v1pb.TranscribeRequest{ - ProviderId: "gemini-main", + _, err = ts.Service.Transcribe(userCtx, &v1pb.TranscribeRequest{ + ProviderId: "openai-main", Config: &v1pb.TranscriptionConfig{}, Audio: &v1pb.TranscriptionAudio{ - Source: &v1pb.TranscriptionAudio_Content{Content: []byte("mp3 bytes")}, - Filename: "voice.mp3", - ContentType: "audio/mp3", + Source: &v1pb.TranscriptionAudio_Content{Content: []byte("RIFF")}, + Filename: "voice.wav", + ContentType: "audio/wav", }, }) - require.NoError(t, err) - require.Equal(t, "gemini transcript", resp.Text) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to transcribe audio") }) - t.Run("rejects Anthropic transcription as unsupported", func(t *testing.T) { + t.Run("transcribes audio file with Gemini provider", func(t *testing.T) { ts := NewTestService(t) defer ts.Cleanup() - user, err := ts.CreateRegularUser(ctx, "anthropic-user") + user, err := ts.CreateRegularUser(ctx, "gemini-user") require.NoError(t, err) userCtx := ts.CreateUserContext(ctx, user.ID) + geminiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "/v1beta/models/gemini-2.5-flash:generateContent", r.URL.Path) + require.Equal(t, "gemini-key", r.Header.Get("x-goog-api-key")) + w.Header().Set("Content-Type", "application/json") + require.NoError(t, json.NewEncoder(w).Encode(map[string]any{ + "candidates": []map[string]any{ + { + "content": map[string]any{ + "parts": []map[string]string{{"text": "gemini transcript"}}, + }, + }, + }, + })) + })) + defer geminiServer.Close() + _, err = ts.Store.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{ Key: storepb.InstanceSettingKey_AI, Value: &storepb.InstanceSetting_AiSetting{ AiSetting: &storepb.InstanceAISetting{ Providers: []*storepb.AIProviderConfig{ { - Id: "anthropic-main", - Title: "Anthropic", - Type: storepb.AIProviderType_ANTHROPIC, - Endpoint: "https://api.anthropic.com/v1", - ApiKey: "sk-ant-test", - Models: []string{"claude-sonnet-4-5"}, - DefaultModel: "claude-sonnet-4-5", + Id: "gemini-main", + Title: "Gemini", + Type: storepb.AIProviderType_GEMINI, + Endpoint: geminiServer.URL + "/v1beta", + ApiKey: "gemini-key", }, }, }, @@ -182,20 +177,20 @@ func TestTranscribe(t *testing.T) { }) require.NoError(t, err) - _, err = ts.Service.Transcribe(userCtx, &v1pb.TranscribeRequest{ - ProviderId: "anthropic-main", + resp, err := ts.Service.Transcribe(userCtx, &v1pb.TranscribeRequest{ + ProviderId: "gemini-main", Config: &v1pb.TranscriptionConfig{}, Audio: &v1pb.TranscriptionAudio{ - Source: &v1pb.TranscriptionAudio_Content{Content: []byte("RIFF")}, - Filename: "voice.wav", - ContentType: "audio/wav", + Source: &v1pb.TranscriptionAudio_Content{Content: []byte("mp3 bytes")}, + Filename: "voice.mp3", + ContentType: "audio/mp3", }, }) - require.Error(t, err) - require.Contains(t, err.Error(), "capability unsupported") + require.NoError(t, err) + require.Equal(t, "gemini transcript", resp.Text) }) - t.Run("rejects unconfigured model", func(t *testing.T) { + t.Run("uses built-in transcription model", func(t *testing.T) { ts := NewTestService(t) defer ts.Cleanup() @@ -203,19 +198,27 @@ func TestTranscribe(t *testing.T) { require.NoError(t, err) userCtx := ts.CreateUserContext(ctx, user.ID) + openAIServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.NoError(t, r.ParseMultipartForm(10<<20)) + require.Equal(t, "gpt-4o-transcribe", r.FormValue("model")) + w.Header().Set("Content-Type", "application/json") + require.NoError(t, json.NewEncoder(w).Encode(map[string]string{ + "text": "built-in model", + })) + })) + defer openAIServer.Close() + _, err = ts.Store.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{ Key: storepb.InstanceSettingKey_AI, Value: &storepb.InstanceSetting_AiSetting{ AiSetting: &storepb.InstanceAISetting{ Providers: []*storepb.AIProviderConfig{ { - Id: "openai-main", - Title: "OpenAI", - Type: storepb.AIProviderType_OPENAI_COMPATIBLE, - Endpoint: "https://example.com/v1", - ApiKey: "sk-test", - Models: []string{"gpt-4o-transcribe"}, - DefaultModel: "gpt-4o-transcribe", + Id: "openai-main", + Title: "OpenAI", + Type: storepb.AIProviderType_OPENAI, + Endpoint: openAIServer.URL, + ApiKey: "sk-test", }, }, }, @@ -223,19 +226,17 @@ func TestTranscribe(t *testing.T) { }) require.NoError(t, err) - _, err = ts.Service.Transcribe(userCtx, &v1pb.TranscribeRequest{ + resp, err := ts.Service.Transcribe(userCtx, &v1pb.TranscribeRequest{ ProviderId: "openai-main", - Config: &v1pb.TranscriptionConfig{ - Model: "other-model", - }, + Config: &v1pb.TranscriptionConfig{}, Audio: &v1pb.TranscriptionAudio{ Source: &v1pb.TranscriptionAudio_Content{Content: []byte("RIFF")}, Filename: "voice.wav", ContentType: "audio/wav", }, }) - require.Error(t, err) - require.Contains(t, err.Error(), "not configured") + require.NoError(t, err) + require.Equal(t, "built-in model", resp.Text) }) t.Run("rejects non-audio content before provider call", func(t *testing.T) { @@ -252,13 +253,11 @@ func TestTranscribe(t *testing.T) { AiSetting: &storepb.InstanceAISetting{ Providers: []*storepb.AIProviderConfig{ { - Id: "openai-main", - Title: "OpenAI", - Type: storepb.AIProviderType_OPENAI_COMPATIBLE, - Endpoint: "https://example.com/v1", - ApiKey: "sk-test", - Models: []string{"gpt-4o-transcribe"}, - DefaultModel: "gpt-4o-transcribe", + Id: "openai-main", + Title: "OpenAI", + Type: storepb.AIProviderType_OPENAI, + Endpoint: "https://example.com/v1", + ApiKey: "sk-test", }, }, }, @@ -268,9 +267,7 @@ func TestTranscribe(t *testing.T) { _, err = ts.Service.Transcribe(userCtx, &v1pb.TranscribeRequest{ ProviderId: "openai-main", - Config: &v1pb.TranscriptionConfig{ - Model: "gpt-4o-transcribe", - }, + Config: &v1pb.TranscriptionConfig{}, Audio: &v1pb.TranscriptionAudio{ Source: &v1pb.TranscriptionAudio_Content{Content: []byte("not audio")}, Filename: "notes.txt", diff --git a/server/router/api/v1/test/instance_service_test.go b/server/router/api/v1/test/instance_service_test.go index 58693544b..7160dff2f 100644 --- a/server/router/api/v1/test/instance_service_test.go +++ b/server/router/api/v1/test/instance_service_test.go @@ -238,7 +238,7 @@ func TestGetInstanceSetting(t *testing.T) { "SmtpPassword must never be returned in responses") }) - t.Run("GetInstanceSetting - AI setting requires admin", func(t *testing.T) { + t.Run("GetInstanceSetting - AI setting requires authenticated user", func(t *testing.T) { ts := NewTestService(t) defer ts.Cleanup() @@ -256,11 +256,12 @@ func TestGetInstanceSetting(t *testing.T) { require.Error(t, err) require.Contains(t, err.Error(), "not authenticated") - _, err = ts.Service.GetInstanceSetting(userCtx, req) - require.Error(t, err) - require.Contains(t, err.Error(), "permission denied") + resp, err := ts.Service.GetInstanceSetting(userCtx, req) + require.NoError(t, err) + require.NotNil(t, resp.GetAiSetting()) + require.Empty(t, resp.GetAiSetting().GetProviders()) - resp, err := ts.Service.GetInstanceSetting(adminCtx, req) + resp, err = ts.Service.GetInstanceSetting(adminCtx, req) require.NoError(t, err) require.NotNil(t, resp.GetAiSetting()) require.Empty(t, resp.GetAiSetting().GetProviders()) @@ -300,12 +301,10 @@ func TestUpdateInstanceSetting(t *testing.T) { AiSetting: &v1pb.InstanceSetting_AISetting{ Providers: []*v1pb.InstanceSetting_AIProviderConfig{ { - Id: "openai-main", - Title: "OpenAI", - Type: v1pb.InstanceSetting_OPENAI, - ApiKey: "sk-test", - Models: []string{"gpt-4o-transcribe"}, - DefaultModel: "gpt-4o-transcribe", + Id: "openai-main", + Title: "OpenAI", + Type: v1pb.InstanceSetting_OPENAI, + ApiKey: "sk-test", }, }, }, @@ -569,12 +568,10 @@ func TestUpdateInstanceSetting(t *testing.T) { AiSetting: &v1pb.InstanceSetting_AISetting{ Providers: []*v1pb.InstanceSetting_AIProviderConfig{ { - Id: "openai-main", - Title: "OpenAI", - Type: v1pb.InstanceSetting_OPENAI, - ApiKey: "sk-original", - Models: []string{"gpt-5.4", "gpt-5.4-mini"}, - DefaultModel: "gpt-5.4", + Id: "openai-main", + Title: "OpenAI", + Type: v1pb.InstanceSetting_OPENAI, + ApiKey: "sk-original", }, }, }, @@ -601,12 +598,10 @@ func TestUpdateInstanceSetting(t *testing.T) { AiSetting: &v1pb.InstanceSetting_AISetting{ Providers: []*v1pb.InstanceSetting_AIProviderConfig{ { - Id: "openai-main", - Title: "OpenAI primary", - Type: v1pb.InstanceSetting_OPENAI, - ApiKey: "", - Models: []string{"gpt-5.4-mini", "gpt-5.4-mini", "gpt-5.4"}, - DefaultModel: "", + Id: "openai-main", + Title: "OpenAI primary", + Type: v1pb.InstanceSetting_OPENAI, + ApiKey: "", }, }, }, @@ -621,42 +616,5 @@ func TestUpdateInstanceSetting(t *testing.T) { require.Equal(t, "sk-original", stored.GetProviders()[0].GetApiKey(), "existing AI provider API key must be preserved when an empty value is sent") require.Equal(t, "OpenAI primary", stored.GetProviders()[0].GetTitle()) - require.Equal(t, []string{"gpt-5.4-mini", "gpt-5.4"}, stored.GetProviders()[0].GetModels()) - require.Equal(t, "gpt-5.4-mini", stored.GetProviders()[0].GetDefaultModel()) - }) - - t.Run("UpdateInstanceSetting - Anthropic provider gets default endpoint", func(t *testing.T) { - ts := NewTestService(t) - defer ts.Cleanup() - - hostUser, err := ts.CreateHostUser(ctx, "admin") - require.NoError(t, err) - adminCtx := ts.CreateUserContext(ctx, hostUser.ID) - - _, err = ts.Service.UpdateInstanceSetting(adminCtx, &v1pb.UpdateInstanceSettingRequest{ - Setting: &v1pb.InstanceSetting{ - Name: "instance/settings/AI", - Value: &v1pb.InstanceSetting_AiSetting{ - AiSetting: &v1pb.InstanceSetting_AISetting{ - Providers: []*v1pb.InstanceSetting_AIProviderConfig{ - { - Id: "anthropic-main", - Title: "Anthropic", - Type: v1pb.InstanceSetting_ANTHROPIC, - ApiKey: "sk-ant-test", - Models: []string{"claude-sonnet-4-5"}, - DefaultModel: "claude-sonnet-4-5", - }, - }, - }, - }, - }, - }) - require.NoError(t, err) - - stored, err := ts.Store.GetInstanceAISetting(ctx) - require.NoError(t, err) - require.Len(t, stored.GetProviders(), 1) - require.Equal(t, "https://api.anthropic.com/v1", stored.GetProviders()[0].GetEndpoint()) }) } diff --git a/store/test/instance_setting_test.go b/store/test/instance_setting_test.go index 0fd56bff7..dd450cd3d 100644 --- a/store/test/instance_setting_test.go +++ b/store/test/instance_setting_test.go @@ -342,22 +342,18 @@ func TestInstanceSettingAISetting(t *testing.T) { AiSetting: &storepb.InstanceAISetting{ Providers: []*storepb.AIProviderConfig{ { - Id: "openai-main", - Title: "OpenAI", - Type: storepb.AIProviderType_OPENAI, - Endpoint: "https://api.openai.com/v1", - ApiKey: "sk-test", - Models: []string{"gpt-5.4", "gpt-5.4-mini"}, - DefaultModel: "gpt-5.4", + Id: "openai-main", + Title: "OpenAI", + Type: storepb.AIProviderType_OPENAI, + Endpoint: "https://api.openai.com/v1", + ApiKey: "sk-test", }, { - Id: "company-gateway", - Title: "Company Gateway", - Type: storepb.AIProviderType_OPENAI_COMPATIBLE, - Endpoint: "https://llm.example.com/v1", - ApiKey: "gw-test", - Models: []string{"qwen-plus"}, - DefaultModel: "qwen-plus", + Id: "gemini-main", + Title: "Gemini", + Type: storepb.AIProviderType_GEMINI, + Endpoint: "https://generativelanguage.googleapis.com/v1beta", + ApiKey: "gemini-test", }, }, }, @@ -370,7 +366,7 @@ func TestInstanceSettingAISetting(t *testing.T) { require.Len(t, aiSetting.Providers, 2) require.Equal(t, "openai-main", aiSetting.Providers[0].Id) require.Equal(t, "sk-test", aiSetting.Providers[0].ApiKey) - require.Equal(t, "company-gateway", aiSetting.Providers[1].Id) + require.Equal(t, "gemini-main", aiSetting.Providers[1].Id) ts.Close() } diff --git a/web/src/components/MemoEditor/components/AudioRecorderPanel.tsx b/web/src/components/MemoEditor/components/AudioRecorderPanel.tsx index d7943f399..87016fcfd 100644 --- a/web/src/components/MemoEditor/components/AudioRecorderPanel.tsx +++ b/web/src/components/MemoEditor/components/AudioRecorderPanel.tsx @@ -1,21 +1,35 @@ -import { LoaderCircleIcon, XIcon } from "lucide-react"; +import { AudioWaveformIcon, LoaderCircleIcon, SquareIcon, XIcon } from "lucide-react"; import type { FC } from "react"; import { formatAudioTime } from "@/components/MemoMetadata/Attachment/attachmentHelpers"; import { Button } from "@/components/ui/button"; +import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip"; import { cn } from "@/lib/utils"; import { useTranslate } from "@/utils/i18n"; import { useAudioWaveform } from "../hooks/useAudioWaveform"; import type { AudioRecorderPanelProps } from "../types/components"; import { VoiceWaveform } from "./VoiceWaveform"; -export const AudioRecorderPanel: FC = ({ audioRecorder, mediaStream, onStop, onCancel }) => { +export const AudioRecorderPanel: FC = ({ + audioRecorder, + mediaStream, + onStop, + onCancel, + onTranscribe, + canTranscribe = false, + isTranscribing = false, +}) => { const t = useTranslate(); const { status, elapsedSeconds } = audioRecorder; const isRequestingPermission = status === "requesting_permission"; const isRecording = status === "recording"; + const isTranscribeDisabled = !canTranscribe || isRequestingPermission || isTranscribing; const waveformLevels = useAudioWaveform(mediaStream, isRecording && mediaStream !== null); - const srStatusText = isRequestingPermission ? t("editor.audio-recorder.requesting-permission") : t("editor.audio-recorder.recording"); + const srStatusText = isTranscribing + ? t("editor.audio-recorder.transcribing") + : isRequestingPermission + ? t("editor.audio-recorder.requesting-permission") + : t("editor.audio-recorder.recording"); return (
= ({ audioRecorder, )} >
- {isRequestingPermission ? : null} + {isRequestingPermission || isTranscribing ? ( + + ) : null} {srStatusText} - {formatAudioTime(elapsedSeconds)} + + {isTranscribing ? t("editor.audio-recorder.transcribing") : formatAudioTime(elapsedSeconds)} +
@@ -36,22 +54,43 @@ export const AudioRecorderPanel: FC = ({ audioRecorder, type="button" variant="ghost" size="icon" - className="size-7 shrink-0 rounded-full text-muted-foreground hover:bg-accent hover:text-foreground" + className="rounded-full" onClick={onCancel} + disabled={isTranscribing} aria-label={t("common.cancel")} > - + + + + + + + + +

{canTranscribe ? t("editor.audio-recorder.transcribe") : t("editor.audio-recorder.configure-ai-provider")}

+
+
diff --git a/web/src/components/MemoEditor/hooks/useAudioRecorder.ts b/web/src/components/MemoEditor/hooks/useAudioRecorder.ts index 7df3b79c5..63722e59e 100644 --- a/web/src/components/MemoEditor/hooks/useAudioRecorder.ts +++ b/web/src/components/MemoEditor/hooks/useAudioRecorder.ts @@ -3,6 +3,7 @@ import type { LocalFile } from "../types/attachment"; import { useBlobUrls } from "./useBlobUrls"; const FALLBACK_AUDIO_MIME_TYPE = "audio/webm"; +export type AudioRecordingCompleteMode = "attach" | "transcribe"; interface AudioRecorderActions { setAudioRecorderSupport: (value: boolean) => void; @@ -10,7 +11,8 @@ interface AudioRecorderActions { setAudioRecorderStatus: (value: "idle" | "requesting_permission" | "recording" | "error" | "unsupported") => void; setAudioRecorderElapsed: (value: number) => void; setAudioRecorderError: (value?: string) => void; - onRecordingComplete: (localFile: LocalFile) => void; + onRecordingComplete: (localFile: LocalFile, mode: AudioRecordingCompleteMode) => void; + onRecordingEmpty?: (mode: AudioRecordingCompleteMode) => void; } const AUDIO_MIME_TYPE_CANDIDATES = ["audio/webm;codecs=opus", "audio/webm", "audio/mp4", "audio/ogg;codecs=opus"] as const; @@ -55,6 +57,7 @@ export const useAudioRecorder = (actions: AudioRecorderActions) => { const startedAtRef = useRef(null); const elapsedTimerRef = useRef(null); const recorderMimeTypeRef = useRef(FALLBACK_AUDIO_MIME_TYPE); + const completionModeRef = useRef("attach"); const startRequestIdRef = useRef(0); const { createBlobUrl } = useBlobUrls(); @@ -153,10 +156,13 @@ export const useAudioRecorder = (actions: AudioRecorderActions) => { const durationSeconds = startedAtRef.current ? Math.max(0, Math.round((Date.now() - startedAtRef.current) / 1000)) : 0; const blob = new Blob(chunksRef.current, { type: recorderMimeTypeRef.current }); + const completionMode = completionModeRef.current; + completionModeRef.current = "attach"; if (blob.size === 0) { actions.setAudioRecorderElapsed(0); actions.setAudioRecorderError(undefined); actions.setAudioRecorderStatus("idle"); + actions.onRecordingEmpty?.(completionMode); resetRecorderRefs(); return; } @@ -164,14 +170,17 @@ export const useAudioRecorder = (actions: AudioRecorderActions) => { const file = createRecordedFile(blob, recorderMimeTypeRef.current); const previewUrl = createBlobUrl(file); - actions.onRecordingComplete({ - file, - previewUrl, - origin: "audio_recording", - audioMeta: { - durationSeconds, + actions.onRecordingComplete( + { + file, + previewUrl, + origin: "audio_recording", + audioMeta: { + durationSeconds, + }, }, - }); + completionMode, + ); actions.setAudioRecorderElapsed(0); actions.setAudioRecorderError(undefined); actions.setAudioRecorderStatus("idle"); @@ -203,17 +212,20 @@ export const useAudioRecorder = (actions: AudioRecorderActions) => { } }; - const stopRecording = () => { + const stopRecording = (mode: AudioRecordingCompleteMode = "attach") => { if (!mediaRecorderRef.current || mediaRecorderRef.current.state === "inactive") { - return; + return false; } + completionModeRef.current = mode; cleanupTimer(); mediaRecorderRef.current.stop(); + return true; }; const resetRecording = () => { startRequestIdRef.current += 1; + completionModeRef.current = "attach"; resetRecorderRefs(); actions.setAudioRecorderElapsed(0); actions.setAudioRecorderError(undefined); diff --git a/web/src/components/MemoEditor/index.tsx b/web/src/components/MemoEditor/index.tsx index 37f7e7554..746b7df23 100644 --- a/web/src/components/MemoEditor/index.tsx +++ b/web/src/components/MemoEditor/index.tsx @@ -1,12 +1,14 @@ import { useQueryClient } from "@tanstack/react-query"; -import { useEffect, useMemo, useRef, useState } from "react"; +import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { toast } from "react-hot-toast"; import { useAuth } from "@/contexts/AuthContext"; +import { useInstance } from "@/contexts/InstanceContext"; import useCurrentUser from "@/hooks/useCurrentUser"; import { memoKeys } from "@/hooks/useMemoQueries"; import { userKeys } from "@/hooks/useUserQueries"; import { handleError } from "@/lib/error"; import { cn } from "@/lib/utils"; +import { InstanceSetting_AIProviderType, InstanceSetting_Key } from "@/types/proto/api/v1/instance_service_pb"; import { useTranslate } from "@/utils/i18n"; import { convertVisibilityFromString } from "@/utils/memo"; import { @@ -21,9 +23,15 @@ import { import { FOCUS_MODE_STYLES } from "./constants"; import type { EditorRefActions } from "./Editor"; import { useAudioRecorder, useAutoSave, useFocusMode, useKeyboard, useMemoInit } from "./hooks"; -import { cacheService, errorService, memoService, validationService } from "./services"; +import { cacheService, errorService, memoService, transcriptionService, validationService } from "./services"; import { EditorProvider, useEditorContext } from "./state"; import type { MemoEditorProps } from "./types"; +import type { LocalFile } from "./types/attachment"; + +const TRANSCRIPTION_PROVIDER_TYPES: InstanceSetting_AIProviderType[] = [ + InstanceSetting_AIProviderType.OPENAI, + InstanceSetting_AIProviderType.GEMINI, +]; const MemoEditor = (props: MemoEditorProps) => ( @@ -47,9 +55,15 @@ const MemoEditorImpl: React.FC = ({ const editorRef = useRef(null); const { state, actions, dispatch } = useEditorContext(); const { userGeneralSetting } = useAuth(); + const { aiSetting, fetchSetting } = useInstance(); const [isAudioRecorderOpen, setIsAudioRecorderOpen] = useState(false); + const [isTranscribingAudio, setIsTranscribingAudio] = useState(false); const memoName = memo?.name; + const transcriptionProvider = useMemo( + () => aiSetting.providers.find((provider) => provider.apiKeySet && TRANSCRIPTION_PROVIDER_TYPES.includes(provider.type)), + [aiSetting.providers], + ); // Get default visibility from user settings const defaultVisibility = userGeneralSetting?.memoVisibility ? convertVisibilityFromString(userGeneralSetting.memoVisibility) : undefined; @@ -62,6 +76,62 @@ const MemoEditorImpl: React.FC = ({ // Focus mode management with body scroll lock useFocusMode(state.ui.isFocusMode); + useEffect(() => { + if (!currentUser) { + return; + } + + void fetchSetting(InstanceSetting_Key.AI).catch(() => undefined); + }, [currentUser, fetchSetting]); + + const insertTranscribedText = useCallback((text: string) => { + const editor = editorRef.current; + if (!editor) { + return; + } + + const content = editor.getContent(); + const cursor = editor.getCursorPosition(); + const beforeCursor = content.slice(0, cursor); + const afterCursor = content.slice(cursor); + const prefix = beforeCursor.length === 0 || beforeCursor.endsWith("\n\n") ? "" : beforeCursor.endsWith("\n") ? "\n" : "\n\n"; + const suffix = afterCursor.length === 0 || afterCursor.startsWith("\n\n") ? "" : afterCursor.startsWith("\n") ? "\n" : "\n\n"; + + editor.insertText(text, prefix, suffix); + editor.scrollToCursor(); + }, []); + + const handleTranscribeRecordedAudio = useCallback( + async (localFile: LocalFile) => { + if (!transcriptionProvider) { + dispatch(actions.addLocalFile(localFile)); + setIsTranscribingAudio(false); + setIsAudioRecorderOpen(false); + return; + } + + try { + const text = (await transcriptionService.transcribeFile(localFile.file, transcriptionProvider)).trim(); + if (!text) { + dispatch(actions.addLocalFile(localFile)); + toast.error(t("editor.audio-recorder.transcribe-empty")); + return; + } + + insertTranscribedText(text); + toast.success(t("editor.audio-recorder.transcribe-success")); + } catch (error) { + console.error(error); + toast.error(errorService.getErrorMessage(error) || t("editor.audio-recorder.transcribe-error")); + dispatch(actions.addLocalFile(localFile)); + } finally { + setIsTranscribingAudio(false); + setIsAudioRecorderOpen(false); + } + }, + [actions, dispatch, insertTranscribedText, t, transcriptionProvider], + ); + const audioRecorderActions = useMemo( () => ({ setAudioRecorderSupport: (value: boolean) => dispatch(actions.setAudioRecorderSupport(value)), @@ -70,12 +140,24 @@ const MemoEditorImpl: React.FC = ({ dispatch(actions.setAudioRecorderStatus(value)), setAudioRecorderElapsed: (value: number) => dispatch(actions.setAudioRecorderElapsed(value)), setAudioRecorderError: (value?: string) => dispatch(actions.setAudioRecorderError(value)), - onRecordingComplete: (localFile: (typeof state.localFiles)[number]) => { + onRecordingComplete: (localFile: LocalFile, mode: "attach" | "transcribe") => { + if (mode === "transcribe") { + void handleTranscribeRecordedAudio(localFile); + return; + } + dispatch(actions.addLocalFile(localFile)); setIsAudioRecorderOpen(false); }, + onRecordingEmpty: (mode: "attach" | "transcribe") => { + if (mode === "transcribe") { + setIsTranscribingAudio(false); + toast.error(t("editor.audio-recorder.transcribe-empty")); + } + setIsAudioRecorderOpen(false); + }, }), - [actions, dispatch, state.localFiles], + [actions, dispatch, handleTranscribeRecordedAudio, t], ); const audioRecorder = useAudioRecorder(audioRecorderActions); @@ -109,10 +191,23 @@ const MemoEditorImpl: React.FC = ({ }; const handleCancelAudioRecording = () => { + setIsTranscribingAudio(false); audioRecorder.resetRecording(); setIsAudioRecorderOpen(false); }; + const handleTranscribeAudioRecording = () => { + if (!transcriptionProvider || isTranscribingAudio) { + return; + } + + setIsTranscribingAudio(true); + const didStop = audioRecorder.stopRecording("transcribe"); + if (!didStop) { + setIsTranscribingAudio(false); + } + }; + useKeyboard(editorRef, handleSave); async function handleSave() { @@ -203,14 +298,18 @@ const MemoEditorImpl: React.FC = ({ {/* Editor content grows to fill available space in focus mode */} - {isAudioRecorderOpen && (state.audioRecorder.status === "recording" || state.audioRecorder.status === "requesting_permission") && ( - - )} + {isAudioRecorderOpen && + (state.audioRecorder.status === "recording" || state.audioRecorder.status === "requesting_permission" || isTranscribingAudio) && ( + + )} {/* Metadata and toolbar grouped together at bottom */}
diff --git a/web/src/components/MemoEditor/services/errorService.ts b/web/src/components/MemoEditor/services/errorService.ts index 76ce3bbe5..7913fd881 100644 --- a/web/src/components/MemoEditor/services/errorService.ts +++ b/web/src/components/MemoEditor/services/errorService.ts @@ -1,5 +1,9 @@ export const errorService = { getErrorMessage(error: unknown): string { + if (error && typeof error === "object" && "rawMessage" in error) { + return (error as { rawMessage?: string }).rawMessage || "An error occurred"; + } + // Handle ConnectError or errors with details property if (error && typeof error === "object" && "details" in error) { return (error as { details?: string }).details || "An error occurred"; diff --git a/web/src/components/MemoEditor/services/index.ts b/web/src/components/MemoEditor/services/index.ts index 7b9fb3f4c..5bd92a587 100644 --- a/web/src/components/MemoEditor/services/index.ts +++ b/web/src/components/MemoEditor/services/index.ts @@ -1,5 +1,6 @@ export * from "./cacheService"; export * from "./errorService"; export * from "./memoService"; +export * from "./transcriptionService"; export * from "./uploadService"; export * from "./validationService"; diff --git a/web/src/components/MemoEditor/services/transcriptionService.ts b/web/src/components/MemoEditor/services/transcriptionService.ts new file mode 100644 index 000000000..ea4fc9f82 --- /dev/null +++ b/web/src/components/MemoEditor/services/transcriptionService.ts @@ -0,0 +1,26 @@ +import { create } from "@bufbuild/protobuf"; +import { aiServiceClient } from "@/connect"; +import { TranscribeRequestSchema, TranscriptionAudioSchema, TranscriptionConfigSchema } from "@/types/proto/api/v1/ai_service_pb"; +import type { InstanceSetting_AIProviderConfig } from "@/types/proto/api/v1/instance_service_pb"; + +export const transcriptionService = { + async transcribeFile(file: File, provider: InstanceSetting_AIProviderConfig): Promise { + const content = new Uint8Array(await file.arrayBuffer()); + const response = await aiServiceClient.transcribe( + create(TranscribeRequestSchema, { + providerId: provider.id, + config: create(TranscriptionConfigSchema, {}), + audio: create(TranscriptionAudioSchema, { + source: { + case: "content", + value: content, + }, + filename: file.name, + contentType: file.type, + }), + }), + ); + + return response.text; + }, +}; diff --git a/web/src/components/MemoEditor/types/components.ts b/web/src/components/MemoEditor/types/components.ts index 3b33d577a..367a082d5 100644 --- a/web/src/components/MemoEditor/types/components.ts +++ b/web/src/components/MemoEditor/types/components.ts @@ -36,6 +36,9 @@ export interface AudioRecorderPanelProps { mediaStream: MediaStream | null; onStop: () => void; onCancel: () => void; + onTranscribe?: () => void; + canTranscribe?: boolean; + isTranscribing?: boolean; } export interface FocusModeOverlayProps { diff --git a/web/src/components/Settings/AISection.tsx b/web/src/components/Settings/AISection.tsx index f5396dac7..7e0fda616 100644 --- a/web/src/components/Settings/AISection.tsx +++ b/web/src/components/Settings/AISection.tsx @@ -10,7 +10,6 @@ import { DropdownMenu, DropdownMenuContent, DropdownMenuItem, DropdownMenuTrigge import { Input } from "@/components/ui/input"; import { Label } from "@/components/ui/label"; import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "@/components/ui/select"; -import { Textarea } from "@/components/ui/textarea"; import { useInstance } from "@/contexts/InstanceContext"; import { handleError } from "@/lib/error"; import { @@ -34,16 +33,11 @@ type LocalAIProvider = { apiKey: string; apiKeySet: boolean; apiKeyHint: string; - models: string[]; - defaultModel: string; }; -const providerTypeOptions = [ - InstanceSetting_AIProviderType.OPENAI, - InstanceSetting_AIProviderType.OPENAI_COMPATIBLE, - InstanceSetting_AIProviderType.GEMINI, - InstanceSetting_AIProviderType.ANTHROPIC, -]; +const providerTypeOptions = [InstanceSetting_AIProviderType.OPENAI, InstanceSetting_AIProviderType.GEMINI]; + +const byokNotes = ["setting.ai.byok-key-note", "setting.ai.byok-storage-note", "setting.ai.byok-model-note"] as const; const createProviderID = () => { if (typeof crypto !== "undefined" && "randomUUID" in crypto) { @@ -64,18 +58,8 @@ const toLocalProvider = (provider: InstanceSetting_AIProviderConfig): LocalAIPro apiKey: "", apiKeySet: provider.apiKeySet, apiKeyHint: provider.apiKeyHint, - models: [...provider.models], - defaultModel: provider.defaultModel, }); -const normalizeModels = (value: string) => { - const models = value - .split(/\r?\n/) - .map((model) => model.trim()) - .filter(Boolean); - return Array.from(new Set(models)); -}; - const newProvider = (): LocalAIProvider => ({ id: createProviderID(), title: "", @@ -84,8 +68,6 @@ const newProvider = (): LocalAIProvider => ({ apiKey: "", apiKeySet: false, apiKeyHint: "", - models: [], - defaultModel: "", }); const toProviderConfig = (provider: LocalAIProvider) => @@ -95,8 +77,6 @@ const toProviderConfig = (provider: LocalAIProvider) => type: provider.type, endpoint: provider.endpoint.trim(), apiKey: provider.apiKey, - models: provider.models, - defaultModel: provider.defaultModel.trim(), }); const AISection = () => { @@ -124,36 +104,20 @@ const AISection = () => { const handleSaveProvider = (provider: LocalAIProvider) => { const title = provider.title.trim(); const endpoint = provider.endpoint.trim(); - const models = provider.models.map((model) => model.trim()).filter(Boolean); - const defaultModel = provider.defaultModel.trim() || models[0] || ""; if (!title) { toast.error(t("setting.ai.provider-title-required")); return; } - if (provider.type === InstanceSetting_AIProviderType.OPENAI_COMPATIBLE && !endpoint) { - toast.error(t("setting.ai.endpoint-required")); - return; - } if (!provider.apiKeySet && !provider.apiKey.trim()) { toast.error(t("setting.ai.api-key-required")); return; } - if (models.length === 0) { - toast.error(t("setting.ai.models-required")); - return; - } - if (defaultModel && !models.includes(defaultModel)) { - toast.error(t("setting.ai.default-model-required")); - return; - } const normalizedProvider = { ...provider, title, endpoint, - models, - defaultModel, }; setProviders((prev) => { const exists = prev.some((item) => item.id === normalizedProvider.id); @@ -203,6 +167,26 @@ const AISection = () => { } > +
+
+
+ + {t("setting.ai.byok-label")} + +

{t("setting.ai.byok-title")}

+
+

{t("setting.ai.byok-description")}

+
    + {byokNotes.map((note) => ( +
  • + + {t(note)} +
  • + ))} +
+
+
+ { render: (_, provider: LocalAIProvider) => {getProviderTypeLabel(provider.type)}, }, { - key: "models", - header: t("setting.ai.models"), + key: "endpoint", + header: t("setting.ai.endpoint"), render: (_, provider: LocalAIProvider) => ( -
- {provider.defaultModel || provider.models[0] || "-"} - {t("setting.ai.model-count", { count: provider.models.length })} -
+ {provider.endpoint || t("setting.ai.default-endpoint")} ), }, { @@ -299,12 +280,10 @@ interface AIProviderDialogProps { const AIProviderDialog = ({ provider, onOpenChange, onSave }: AIProviderDialogProps) => { const t = useTranslate(); const [draft, setDraft] = useState(() => provider ?? newProvider()); - const [modelsText, setModelsText] = useState(""); useEffect(() => { const next = provider ?? newProvider(); setDraft(next); - setModelsText(next.models.join("\n")); }, [provider]); const updateDraft = (partial: Partial) => { @@ -312,10 +291,7 @@ const AIProviderDialog = ({ provider, onOpenChange, onSave }: AIProviderDialogPr }; const handleSave = () => { - onSave({ - ...draft, - models: normalizeModels(modelsText), - }); + onSave(draft); }; return ( @@ -356,8 +332,9 @@ const AIProviderDialog = ({ provider, onOpenChange, onSave }: AIProviderDialogPr updateDraft({ endpoint: e.target.value })} - placeholder={draft.type === InstanceSetting_AIProviderType.OPENAI ? "https://api.openai.com/v1" : "https://example.com/v1"} + placeholder={getDefaultEndpointPlaceholder(draft.type)} /> +

{t("setting.ai.endpoint-hint")}

@@ -372,26 +349,6 @@ const AIProviderDialog = ({ provider, onOpenChange, onSave }: AIProviderDialogPr

{t("setting.ai.current-key", { key: draft.apiKeyHint || "-" })}

)}
- -
- -