mirror of https://github.com/usememos/memos
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
187 lines
6.0 KiB
Go
187 lines
6.0 KiB
Go
package v1
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"mime"
|
|
"net/http"
|
|
"strings"
|
|
|
|
"github.com/pkg/errors"
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/status"
|
|
|
|
"github.com/usememos/memos/internal/ai"
|
|
"github.com/usememos/memos/internal/ai/gemini"
|
|
"github.com/usememos/memos/internal/ai/openai"
|
|
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
|
storepb "github.com/usememos/memos/proto/gen/store"
|
|
)
|
|
|
|
const (
|
|
maxTranscriptionAudioSizeBytes = 25 * MebiByte
|
|
maxTranscriptionPromptLength = 4096
|
|
maxTranscriptionLanguageLength = 32
|
|
maxTranscriptionFilenameLength = 255
|
|
)
|
|
|
|
// supportedTranscriptionContentTypes is the allow-list of media types accepted
// for transcription. Keys are lowercase media types without parameters, which
// is the form isSupportedTranscriptionContentType looks up after parsing.
//
// The list must include every audio/video type that http.DetectContentType can
// report for a supported container, because Transcribe falls back to sniffing
// when the request does not declare a content type. The sniffer reports WAV as
// "audio/wave" and Ogg as "application/ogg", so both spellings are listed next
// to their more common aliases; without them, WAV and Ogg uploads that omit a
// content type would be rejected as unsupported.
//
// NOTE(review): the sniffed spelling is forwarded to the provider unchanged;
// confirm the provider clients accept "audio/wave" and "application/ogg".
var supportedTranscriptionContentTypes = map[string]bool{
	"application/ogg": true, // reported by http.DetectContentType for Ogg
	"audio/aac":       true,
	"audio/aiff":      true,
	"audio/flac":      true,
	"audio/mpeg":      true,
	"audio/mp3":       true,
	"audio/mp4":       true,
	"audio/mpga":      true,
	"audio/ogg":       true,
	"audio/wav":       true,
	"audio/wave":      true, // reported by http.DetectContentType for WAV
	"audio/x-wav":     true,
	"audio/x-flac":    true,
	"audio/x-m4a":     true,
	"audio/webm":      true,
	"video/mp4":       true,
	"video/mpeg":      true,
	"video/webm":      true,
}
|
|
|
|
// Transcribe transcribes an audio file using an instance AI provider.
|
|
func (s *APIV1Service) Transcribe(ctx context.Context, request *v1pb.TranscribeRequest) (*v1pb.TranscribeResponse, error) {
|
|
user, err := s.fetchCurrentUser(ctx)
|
|
if err != nil {
|
|
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
|
}
|
|
if user == nil {
|
|
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
|
|
}
|
|
|
|
if strings.TrimSpace(request.ProviderId) == "" {
|
|
return nil, status.Errorf(codes.InvalidArgument, "provider_id is required")
|
|
}
|
|
if request.Config == nil {
|
|
return nil, status.Errorf(codes.InvalidArgument, "config is required")
|
|
}
|
|
prompt := strings.TrimSpace(request.Config.GetPrompt())
|
|
if len(prompt) > maxTranscriptionPromptLength {
|
|
return nil, status.Errorf(codes.InvalidArgument, "prompt is too long; maximum length is %d characters", maxTranscriptionPromptLength)
|
|
}
|
|
language := strings.TrimSpace(request.Config.GetLanguage())
|
|
if len(language) > maxTranscriptionLanguageLength {
|
|
return nil, status.Errorf(codes.InvalidArgument, "language is too long; maximum length is %d characters", maxTranscriptionLanguageLength)
|
|
}
|
|
if request.Audio == nil {
|
|
return nil, status.Errorf(codes.InvalidArgument, "audio is required")
|
|
}
|
|
if request.Audio.GetUri() != "" {
|
|
return nil, status.Errorf(codes.InvalidArgument, "audio uri is not supported")
|
|
}
|
|
content := request.Audio.GetContent()
|
|
if len(content) == 0 {
|
|
return nil, status.Errorf(codes.InvalidArgument, "audio content is required")
|
|
}
|
|
if len(content) > maxTranscriptionAudioSizeBytes {
|
|
return nil, status.Errorf(codes.InvalidArgument, "audio file is too large; maximum size is 25 MiB")
|
|
}
|
|
filename := strings.TrimSpace(request.Audio.GetFilename())
|
|
if len(filename) > maxTranscriptionFilenameLength {
|
|
return nil, status.Errorf(codes.InvalidArgument, "filename is too long; maximum length is %d characters", maxTranscriptionFilenameLength)
|
|
}
|
|
contentType := strings.TrimSpace(request.Audio.GetContentType())
|
|
if contentType == "" {
|
|
contentType = http.DetectContentType(content)
|
|
}
|
|
if !isSupportedTranscriptionContentType(contentType) {
|
|
return nil, status.Errorf(codes.InvalidArgument, "audio content type %q is not supported", contentType)
|
|
}
|
|
|
|
provider, model, err := s.resolveAIProviderForTranscription(ctx, request.ProviderId)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
transcriber, err := newAITranscriber(provider)
|
|
if err != nil {
|
|
return nil, status.Errorf(codes.InvalidArgument, "failed to create AI transcriber: %v", err)
|
|
}
|
|
|
|
transcription, err := transcriber.Transcribe(ctx, ai.TranscribeRequest{
|
|
Model: model,
|
|
Filename: filename,
|
|
ContentType: contentType,
|
|
Audio: bytes.NewReader(content),
|
|
Size: int64(len(content)),
|
|
Prompt: prompt,
|
|
Language: language,
|
|
})
|
|
if err != nil {
|
|
return nil, status.Errorf(codes.Internal, "failed to transcribe audio: %v", err)
|
|
}
|
|
return &v1pb.TranscribeResponse{
|
|
Text: transcription.Text,
|
|
}, nil
|
|
}
|
|
|
|
func (s *APIV1Service) resolveAIProviderForTranscription(ctx context.Context, providerID string) (ai.ProviderConfig, string, error) {
|
|
setting, err := s.Store.GetInstanceAISetting(ctx)
|
|
if err != nil {
|
|
return ai.ProviderConfig{}, "", status.Errorf(codes.Internal, "failed to get AI setting: %v", err)
|
|
}
|
|
|
|
providers := make([]ai.ProviderConfig, 0, len(setting.GetProviders()))
|
|
for _, provider := range setting.GetProviders() {
|
|
if provider == nil {
|
|
continue
|
|
}
|
|
providers = append(providers, convertAIProviderConfigFromStore(provider))
|
|
}
|
|
|
|
provider, err := ai.FindProvider(providers, providerID)
|
|
if err != nil {
|
|
return ai.ProviderConfig{}, "", status.Errorf(codes.NotFound, "AI provider not found")
|
|
}
|
|
selectedModel, err := ai.DefaultTranscriptionModel(provider.Type)
|
|
if err != nil {
|
|
return ai.ProviderConfig{}, "", status.Errorf(codes.InvalidArgument, "%v", err)
|
|
}
|
|
return *provider, selectedModel, nil
|
|
}
|
|
|
|
func convertAIProviderConfigFromStore(provider *storepb.AIProviderConfig) ai.ProviderConfig {
|
|
return ai.ProviderConfig{
|
|
ID: provider.GetId(),
|
|
Title: provider.GetTitle(),
|
|
Type: convertAIProviderTypeFromStore(provider.GetType()),
|
|
Endpoint: provider.GetEndpoint(),
|
|
APIKey: provider.GetApiKey(),
|
|
}
|
|
}
|
|
|
|
func convertAIProviderTypeFromStore(providerType storepb.AIProviderType) ai.ProviderType {
|
|
switch providerType {
|
|
case storepb.AIProviderType_OPENAI:
|
|
return ai.ProviderOpenAI
|
|
case storepb.AIProviderType_GEMINI:
|
|
return ai.ProviderGemini
|
|
default:
|
|
return ""
|
|
}
|
|
}
|
|
|
|
func newAITranscriber(provider ai.ProviderConfig) (ai.Transcriber, error) {
|
|
switch provider.Type {
|
|
case ai.ProviderOpenAI:
|
|
return openai.NewTranscriber(provider)
|
|
case ai.ProviderGemini:
|
|
return gemini.NewTranscriber(provider)
|
|
default:
|
|
return nil, errors.Wrapf(ai.ErrCapabilityUnsupported, "provider type %q", provider.Type)
|
|
}
|
|
}
|
|
|
|
func isSupportedTranscriptionContentType(contentType string) bool {
|
|
mediaType, _, err := mime.ParseMediaType(strings.TrimSpace(contentType))
|
|
if err != nil {
|
|
return false
|
|
}
|
|
mediaType = strings.ToLower(mediaType)
|
|
return supportedTranscriptionContentTypes[mediaType]
|
|
}
|