You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
memos/server/router/api/v1/test/ai_service_test.go

281 lines
8.2 KiB
Go

package test
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/require"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store"
)
func TestTranscribe(t *testing.T) {
ctx := context.Background()
t.Run("requires authentication", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
_, err := ts.Service.Transcribe(ctx, &v1pb.TranscribeRequest{
ProviderId: "openai-main",
Config: &v1pb.TranscriptionConfig{},
Audio: &v1pb.TranscriptionAudio{
Source: &v1pb.TranscriptionAudio_Content{Content: []byte("RIFF")},
Filename: "voice.wav",
ContentType: "audio/wav",
},
})
require.Error(t, err)
require.Contains(t, err.Error(), "user not authenticated")
})
t.Run("transcribes audio file with configured provider", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
user, err := ts.CreateRegularUser(ctx, "alice")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, user.ID)
openAIServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "/audio/transcriptions", r.URL.Path)
require.Equal(t, "Bearer sk-test", r.Header.Get("Authorization"))
require.NoError(t, r.ParseMultipartForm(10<<20))
require.Equal(t, "gpt-4o-transcribe", r.FormValue("model"))
require.Equal(t, "names: Alice", r.FormValue("prompt"))
file, header, err := r.FormFile("file")
require.NoError(t, err)
defer file.Close()
require.Equal(t, "voice.wav", header.Filename)
w.Header().Set("Content-Type", "application/json")
require.NoError(t, json.NewEncoder(w).Encode(map[string]string{
"text": "transcribed text",
}))
}))
defer openAIServer.Close()
_, err = ts.Store.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{
Key: storepb.InstanceSettingKey_AI,
Value: &storepb.InstanceSetting_AiSetting{
AiSetting: &storepb.InstanceAISetting{
Providers: []*storepb.AIProviderConfig{
{
Id: "openai-main",
Title: "OpenAI",
Type: storepb.AIProviderType_OPENAI,
Endpoint: openAIServer.URL,
ApiKey: "sk-test",
},
},
},
},
})
require.NoError(t, err)
resp, err := ts.Service.Transcribe(userCtx, &v1pb.TranscribeRequest{
ProviderId: "openai-main",
Config: &v1pb.TranscriptionConfig{
Prompt: "names: Alice",
},
Audio: &v1pb.TranscriptionAudio{
Source: &v1pb.TranscriptionAudio_Content{Content: []byte("RIFF")},
Filename: "voice.wav",
ContentType: "audio/wav",
},
})
require.NoError(t, err)
require.Equal(t, "transcribed text", resp.Text)
})
t.Run("returns provider error without rewriting it", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
user, err := ts.CreateRegularUser(ctx, "notfound-user")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, user.ID)
openAIServer := httptest.NewServer(http.NotFoundHandler())
defer openAIServer.Close()
_, err = ts.Store.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{
Key: storepb.InstanceSettingKey_AI,
Value: &storepb.InstanceSetting_AiSetting{
AiSetting: &storepb.InstanceAISetting{
Providers: []*storepb.AIProviderConfig{
{
Id: "openai-main",
Title: "OpenAI",
Type: storepb.AIProviderType_OPENAI,
Endpoint: openAIServer.URL,
ApiKey: "sk-test",
},
},
},
},
})
require.NoError(t, err)
_, err = ts.Service.Transcribe(userCtx, &v1pb.TranscribeRequest{
ProviderId: "openai-main",
Config: &v1pb.TranscriptionConfig{},
Audio: &v1pb.TranscriptionAudio{
Source: &v1pb.TranscriptionAudio_Content{Content: []byte("RIFF")},
Filename: "voice.wav",
ContentType: "audio/wav",
},
})
require.Error(t, err)
require.Contains(t, err.Error(), "failed to transcribe audio")
})
t.Run("transcribes audio file with Gemini provider", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
user, err := ts.CreateRegularUser(ctx, "gemini-user")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, user.ID)
geminiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "/v1beta/models/gemini-2.5-flash:generateContent", r.URL.Path)
require.Equal(t, "gemini-key", r.Header.Get("x-goog-api-key"))
w.Header().Set("Content-Type", "application/json")
require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"candidates": []map[string]any{
{
"content": map[string]any{
"parts": []map[string]string{{"text": "gemini transcript"}},
},
},
},
}))
}))
defer geminiServer.Close()
_, err = ts.Store.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{
Key: storepb.InstanceSettingKey_AI,
Value: &storepb.InstanceSetting_AiSetting{
AiSetting: &storepb.InstanceAISetting{
Providers: []*storepb.AIProviderConfig{
{
Id: "gemini-main",
Title: "Gemini",
Type: storepb.AIProviderType_GEMINI,
Endpoint: geminiServer.URL + "/v1beta",
ApiKey: "gemini-key",
},
},
},
},
})
require.NoError(t, err)
resp, err := ts.Service.Transcribe(userCtx, &v1pb.TranscribeRequest{
ProviderId: "gemini-main",
Config: &v1pb.TranscriptionConfig{},
Audio: &v1pb.TranscriptionAudio{
Source: &v1pb.TranscriptionAudio_Content{Content: []byte("mp3 bytes")},
Filename: "voice.mp3",
ContentType: "audio/mp3",
},
})
require.NoError(t, err)
require.Equal(t, "gemini transcript", resp.Text)
})
t.Run("uses built-in transcription model", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
user, err := ts.CreateRegularUser(ctx, "bob")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, user.ID)
openAIServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.NoError(t, r.ParseMultipartForm(10<<20))
require.Equal(t, "gpt-4o-transcribe", r.FormValue("model"))
w.Header().Set("Content-Type", "application/json")
require.NoError(t, json.NewEncoder(w).Encode(map[string]string{
"text": "built-in model",
}))
}))
defer openAIServer.Close()
_, err = ts.Store.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{
Key: storepb.InstanceSettingKey_AI,
Value: &storepb.InstanceSetting_AiSetting{
AiSetting: &storepb.InstanceAISetting{
Providers: []*storepb.AIProviderConfig{
{
Id: "openai-main",
Title: "OpenAI",
Type: storepb.AIProviderType_OPENAI,
Endpoint: openAIServer.URL,
ApiKey: "sk-test",
},
},
},
},
})
require.NoError(t, err)
resp, err := ts.Service.Transcribe(userCtx, &v1pb.TranscribeRequest{
ProviderId: "openai-main",
Config: &v1pb.TranscriptionConfig{},
Audio: &v1pb.TranscriptionAudio{
Source: &v1pb.TranscriptionAudio_Content{Content: []byte("RIFF")},
Filename: "voice.wav",
ContentType: "audio/wav",
},
})
require.NoError(t, err)
require.Equal(t, "built-in model", resp.Text)
})
t.Run("rejects non-audio content before provider call", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
user, err := ts.CreateRegularUser(ctx, "charlie")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, user.ID)
_, err = ts.Store.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{
Key: storepb.InstanceSettingKey_AI,
Value: &storepb.InstanceSetting_AiSetting{
AiSetting: &storepb.InstanceAISetting{
Providers: []*storepb.AIProviderConfig{
{
Id: "openai-main",
Title: "OpenAI",
Type: storepb.AIProviderType_OPENAI,
Endpoint: "https://example.com/v1",
ApiKey: "sk-test",
},
},
},
},
})
require.NoError(t, err)
_, err = ts.Service.Transcribe(userCtx, &v1pb.TranscribeRequest{
ProviderId: "openai-main",
Config: &v1pb.TranscriptionConfig{},
Audio: &v1pb.TranscriptionAudio{
Source: &v1pb.TranscriptionAudio_Content{Content: []byte("not audio")},
Filename: "notes.txt",
ContentType: "text/plain",
},
})
require.Error(t, err)
require.Contains(t, err.Error(), "not supported")
})
}