You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
memos/server/router/api/v1/ai_service.go

187 lines
6.0 KiB
Go

package v1
import (
	"bytes"
	"context"
	"mime"
	"net/http"
	"strings"
	"unicode/utf8"

	"github.com/pkg/errors"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"

	"github.com/usememos/memos/internal/ai"
	"github.com/usememos/memos/internal/ai/gemini"
	"github.com/usememos/memos/internal/ai/openai"
	v1pb "github.com/usememos/memos/proto/gen/api/v1"
	storepb "github.com/usememos/memos/proto/gen/store"
)
// Upper bounds applied to the fields of a transcription request before it is
// handed to an AI provider.
const (
	// maxTranscriptionAudioSizeBytes caps the inline audio payload at 25 MiB.
	maxTranscriptionAudioSizeBytes = 25 * MebiByte

	// maxTranscriptionPromptLength bounds the length of the trimmed prompt.
	maxTranscriptionPromptLength = 4096

	// maxTranscriptionLanguageLength bounds the length of the trimmed language
	// value.
	maxTranscriptionLanguageLength = 32

	// maxTranscriptionFilenameLength bounds the length of the trimmed audio
	// filename.
	maxTranscriptionFilenameLength = 255
)
// supportedTranscriptionContentTypes is the allow-list of media types accepted
// for transcription. Keys are lower-case media types without parameters.
var supportedTranscriptionContentTypes = map[string]bool{
	// MPEG audio.
	"audio/mpeg": true,
	"audio/mp3":  true,
	"audio/mpga": true,

	// AAC / MPEG-4 audio.
	"audio/aac":   true,
	"audio/mp4":   true,
	"audio/x-m4a": true,

	// WAV, AIFF and FLAC.
	"audio/wav":    true,
	"audio/x-wav":  true,
	"audio/aiff":   true,
	"audio/flac":   true,
	"audio/x-flac": true,

	// Ogg and WebM audio.
	"audio/ogg":  true,
	"audio/webm": true,

	// Video containers.
	"video/mp4":  true,
	"video/mpeg": true,
	"video/webm": true,
}
// Transcribe transcribes an audio file using an instance AI provider.
//
// The caller must be authenticated. The request must name a provider, carry a
// config, and embed the audio bytes inline; audio URIs are rejected. The
// prompt, language and filename are trimmed and their lengths are checked in
// characters (runes), the audio is capped at maxTranscriptionAudioSizeBytes,
// and the content type must be on the supported list. When the request omits a
// content type it is sniffed from the audio bytes.
//
// Request validation failures are returned as codes.InvalidArgument, a missing
// user as codes.Unauthenticated, and a failed user lookup or provider call as
// codes.Internal. Errors from provider resolution are returned unchanged.
func (s *APIV1Service) Transcribe(ctx context.Context, request *v1pb.TranscribeRequest) (*v1pb.TranscribeResponse, error) {
	user, err := s.fetchCurrentUser(ctx)
	if err != nil {
		return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
	}
	if user == nil {
		return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
	}

	if strings.TrimSpace(request.ProviderId) == "" {
		return nil, status.Errorf(codes.InvalidArgument, "provider_id is required")
	}
	if request.Config == nil {
		return nil, status.Errorf(codes.InvalidArgument, "config is required")
	}

	// The limits are reported to the client in characters, so count runes
	// rather than bytes; len() would reject multi-byte text well below the
	// advertised limit.
	prompt := strings.TrimSpace(request.Config.GetPrompt())
	if utf8.RuneCountInString(prompt) > maxTranscriptionPromptLength {
		return nil, status.Errorf(codes.InvalidArgument, "prompt is too long; maximum length is %d characters", maxTranscriptionPromptLength)
	}
	language := strings.TrimSpace(request.Config.GetLanguage())
	if utf8.RuneCountInString(language) > maxTranscriptionLanguageLength {
		return nil, status.Errorf(codes.InvalidArgument, "language is too long; maximum length is %d characters", maxTranscriptionLanguageLength)
	}

	if request.Audio == nil {
		return nil, status.Errorf(codes.InvalidArgument, "audio is required")
	}
	if request.Audio.GetUri() != "" {
		return nil, status.Errorf(codes.InvalidArgument, "audio uri is not supported")
	}
	content := request.Audio.GetContent()
	if len(content) == 0 {
		return nil, status.Errorf(codes.InvalidArgument, "audio content is required")
	}
	if len(content) > maxTranscriptionAudioSizeBytes {
		// Derive the figure from the constant so the message cannot drift
		// from the enforced limit.
		return nil, status.Errorf(codes.InvalidArgument, "audio file is too large; maximum size is %d MiB", maxTranscriptionAudioSizeBytes/MebiByte)
	}
	filename := strings.TrimSpace(request.Audio.GetFilename())
	if utf8.RuneCountInString(filename) > maxTranscriptionFilenameLength {
		return nil, status.Errorf(codes.InvalidArgument, "filename is too long; maximum length is %d characters", maxTranscriptionFilenameLength)
	}

	contentType := strings.TrimSpace(request.Audio.GetContentType())
	if contentType == "" {
		contentType = sniffTranscriptionContentType(content)
	}
	if !isSupportedTranscriptionContentType(contentType) {
		return nil, status.Errorf(codes.InvalidArgument, "audio content type %q is not supported", contentType)
	}

	provider, model, err := s.resolveAIProviderForTranscription(ctx, request.ProviderId)
	if err != nil {
		return nil, err
	}
	transcriber, err := newAITranscriber(provider)
	if err != nil {
		return nil, status.Errorf(codes.InvalidArgument, "failed to create AI transcriber: %v", err)
	}

	transcription, err := transcriber.Transcribe(ctx, ai.TranscribeRequest{
		Model:       model,
		Filename:    filename,
		ContentType: contentType,
		Audio:       bytes.NewReader(content),
		Size:        int64(len(content)),
		Prompt:      prompt,
		Language:    language,
	})
	if err != nil {
		return nil, status.Errorf(codes.Internal, "failed to transcribe audio: %v", err)
	}

	return &v1pb.TranscribeResponse{
		Text: transcription.Text,
	}, nil
}

// sniffTranscriptionContentType guesses the media type of audio from its
// leading bytes. http.DetectContentType reports WAV data as "audio/wave" and
// Ogg data as "application/ogg"; both are rewritten to the spellings used by
// supportedTranscriptionContentTypes so that sniffed WAV and Ogg uploads are
// not rejected. Every other detected type is returned unchanged.
func sniffTranscriptionContentType(content []byte) string {
	switch detected := http.DetectContentType(content); detected {
	case "audio/wave":
		return "audio/wav"
	case "application/ogg":
		return "audio/ogg"
	default:
		return detected
	}
}
// resolveAIProviderForTranscription looks up the instance AI provider whose ID
// is providerID and returns it together with the default transcription model
// for its type.
//
// It returns codes.Internal when the instance AI setting cannot be loaded,
// codes.NotFound when no provider matches, and codes.InvalidArgument when no
// default transcription model is available for the provider's type.
func (s *APIV1Service) resolveAIProviderForTranscription(ctx context.Context, providerID string) (ai.ProviderConfig, string, error) {
	var none ai.ProviderConfig

	aiSetting, err := s.Store.GetInstanceAISetting(ctx)
	if err != nil {
		return none, "", status.Errorf(codes.Internal, "failed to get AI setting: %v", err)
	}

	// Convert every non-nil stored provider into the ai package's form.
	stored := aiSetting.GetProviders()
	candidates := make([]ai.ProviderConfig, 0, len(stored))
	for _, entry := range stored {
		if entry != nil {
			candidates = append(candidates, convertAIProviderConfigFromStore(entry))
		}
	}

	match, err := ai.FindProvider(candidates, providerID)
	if err != nil {
		return none, "", status.Errorf(codes.NotFound, "AI provider not found")
	}

	model, err := ai.DefaultTranscriptionModel(match.Type)
	if err != nil {
		return none, "", status.Errorf(codes.InvalidArgument, "%v", err)
	}

	return *match, model, nil
}
// convertAIProviderConfigFromStore copies the ID, title, type, endpoint and
// API key of a stored provider config into an ai.ProviderConfig.
func convertAIProviderConfigFromStore(provider *storepb.AIProviderConfig) ai.ProviderConfig {
	var converted ai.ProviderConfig
	converted.ID = provider.GetId()
	converted.Title = provider.GetTitle()
	converted.Type = convertAIProviderTypeFromStore(provider.GetType())
	converted.Endpoint = provider.GetEndpoint()
	converted.APIKey = provider.GetApiKey()
	return converted
}
// convertAIProviderTypeFromStore maps a stored provider type enum onto the
// matching ai.ProviderType. Values other than OPENAI and GEMINI map to the
// empty provider type.
func convertAIProviderTypeFromStore(providerType storepb.AIProviderType) ai.ProviderType {
	if providerType == storepb.AIProviderType_OPENAI {
		return ai.ProviderOpenAI
	}
	if providerType == storepb.AIProviderType_GEMINI {
		return ai.ProviderGemini
	}
	return ""
}
// newAITranscriber builds the transcriber implementation that matches the
// provider's type. A provider type that is neither OpenAI nor Gemini yields
// ai.ErrCapabilityUnsupported wrapped with the offending type.
func newAITranscriber(provider ai.ProviderConfig) (ai.Transcriber, error) {
	kind := provider.Type
	if kind == ai.ProviderOpenAI {
		return openai.NewTranscriber(provider)
	}
	if kind == ai.ProviderGemini {
		return gemini.NewTranscriber(provider)
	}
	return nil, errors.Wrapf(ai.ErrCapabilityUnsupported, "provider type %q", kind)
}
// isSupportedTranscriptionContentType reports whether contentType parses as a
// media type that is listed in supportedTranscriptionContentTypes. Surrounding
// whitespace and letter case are ignored; a value that fails to parse is
// treated as unsupported.
func isSupportedTranscriptionContentType(contentType string) bool {
	parsed, _, parseErr := mime.ParseMediaType(strings.TrimSpace(contentType))
	return parseErr == nil && supportedTranscriptionContentTypes[strings.ToLower(parsed)]
}