diff --git a/backend/Dockerfile.golang b/backend/Dockerfile.golang index 4d0980a81e37..af704d375478 100644 --- a/backend/Dockerfile.golang +++ b/backend/Dockerfile.golang @@ -21,20 +21,28 @@ ENV AMDGPU_TARGETS=${AMDGPU_TARGETS} ARG APT_MIRROR ARG APT_PORTS_MIRROR +# gcc-14 is the default on noble (ubuntu:24.04) but absent from jammy +# (the L4T jetpack r36.4.0 base). LocalVQE specifically needs it; the +# other Go backends compile fine with the default gcc shipped via +# build-essential. So: try gcc-14 from the configured repos, fall back +# gracefully when it's not available so jammy-based builds don't fail +# at the apt step. RUN --mount=type=bind,source=.docker/apt-mirror.sh,target=/usr/local/sbin/apt-mirror \ APT_MIRROR="${APT_MIRROR}" APT_PORTS_MIRROR="${APT_PORTS_MIRROR}" sh /usr/local/sbin/apt-mirror && \ apt-get update && \ apt-get install -y --no-install-recommends \ build-essential \ - gcc-14 g++-14 \ git ccache \ ca-certificates \ make cmake wget libopenblas-dev \ curl unzip \ libssl-dev && \ - update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-14 100 \ - --slave /usr/bin/g++ g++ /usr/bin/g++-14 \ - --slave /usr/bin/gcov gcov /usr/bin/gcov-14 && \ + if apt-cache show gcc-14 >/dev/null 2>&1 && apt-cache show g++-14 >/dev/null 2>&1; then \ + apt-get install -y --no-install-recommends gcc-14 g++-14 && \ + update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-14 100 \ + --slave /usr/bin/g++ g++ /usr/bin/g++-14 \ + --slave /usr/bin/gcov gcov /usr/bin/gcov-14; \ + fi && \ apt-get clean && \ rm -rf /var/lib/apt/lists/* diff --git a/backend/backend.proto b/backend/backend.proto index dbfaff0114c2..a5a5926e4a4a 100644 --- a/backend/backend.proto +++ b/backend/backend.proto @@ -41,6 +41,8 @@ service Backend { rpc VAD(VADRequest) returns (VADResponse) {} + rpc Diarize(DiarizeRequest) returns (DiarizeResponse) {} + rpc AudioEncode(AudioEncodeRequest) returns (AudioEncodeResult) {} rpc AudioDecode(AudioDecodeRequest) returns (AudioDecodeResult) {} @@ 
-416,6 +418,43 @@ message VADResponse { repeated VADSegment segments = 1; } +// --- Speaker diarization messages --- +// +// Pure speaker diarization: "who spoke when". Returns time-stamped segments +// labelled with cluster IDs (the same string for the same speaker across +// segments). Some backends (e.g. vibevoice.cpp) produce diarization as a +// by-product of ASR and may also fill in `text` per segment; backends with a +// dedicated diarization pipeline (e.g. sherpa-onnx pyannote) leave `text` +// empty and emit only the segmentation. + +message DiarizeRequest { + string dst = 1; // path to audio file (HTTP layer materialises uploads to a temp file) + uint32 threads = 2; + string language = 3; // optional; only meaningful for transcription-bundling backends + int32 num_speakers = 4; // exact speaker count if known (>0 forces); 0 = auto + int32 min_speakers = 5; // hint when auto-detecting; 0 = unset + int32 max_speakers = 6; // hint when auto-detecting; 0 = unset + float clustering_threshold = 7; // distance threshold when num_speakers unknown; 0 = backend default + float min_duration_on = 8; // discard segments shorter than this (seconds); 0 = backend default + float min_duration_off = 9; // merge gaps shorter than this (seconds); 0 = backend default + bool include_text = 10; // when the backend can emit per-segment transcript for free, ask it to populate `text` +} + +message DiarizeSegment { + int32 id = 1; + float start = 2; // seconds + float end = 3; // seconds + string speaker = 4; // backend-emitted speaker label (e.g. 
"0", "SPEAKER_00") + string text = 5; // optional per-segment transcript (empty unless include_text and supported) +} + +message DiarizeResponse { + repeated DiarizeSegment segments = 1; + int32 num_speakers = 2; // count of distinct speaker labels in `segments` + float duration = 3; // total audio duration in seconds (0 if unknown) + string language = 4; // optional, when the backend bundles transcription +} + message SoundGenerationRequest { string text = 1; string model = 2; diff --git a/backend/go/sherpa-onnx/backend.go b/backend/go/sherpa-onnx/backend.go index 5d858357f73b..d73474af6bb8 100644 --- a/backend/go/sherpa-onnx/backend.go +++ b/backend/go/sherpa-onnx/backend.go @@ -29,6 +29,12 @@ type SherpaBackend struct { vadWindowSize int ttsSpeed float32 onlineChunkSamples int + + // Speaker diarization (offline pyannote + embedding extractor + clustering). + // diarSampleRate is reported by sherpa at create time; we cache it so + // runDiarization can resample only when the input doesn't already match. + diarizer uintptr + diarSampleRate int } var onnxProvider = "cpu" @@ -128,6 +134,25 @@ var ( // TTS streaming callback trampoline shimTtsGenerateWithCallback func(tts uintptr, text string, sid int32, speed float32, cb uintptr, ud uintptr) uintptr + + // Diarization config + result accessors (see csrc/shim.h). 
+ shimDiarizeConfigNew func() uintptr + shimDiarizeConfigFree func(uintptr) + shimDiarizeConfigSetSegmentationModel func(uintptr, string) + shimDiarizeConfigSetSegmentationNumThreads func(uintptr, int32) + shimDiarizeConfigSetSegmentationProvider func(uintptr, string) + shimDiarizeConfigSetSegmentationDebug func(uintptr, int32) + shimDiarizeConfigSetEmbeddingModel func(uintptr, string) + shimDiarizeConfigSetEmbeddingNumThreads func(uintptr, int32) + shimDiarizeConfigSetEmbeddingProvider func(uintptr, string) + shimDiarizeConfigSetEmbeddingDebug func(uintptr, int32) + shimDiarizeConfigSetClusteringNumClusters func(uintptr, int32) + shimDiarizeConfigSetClusteringThreshold func(uintptr, float32) + shimDiarizeConfigSetMinDurationOn func(uintptr, float32) + shimDiarizeConfigSetMinDurationOff func(uintptr, float32) + shimCreateOfflineSpeakerDiarization func(uintptr) uintptr + shimDiarizeSetClustering func(uintptr, int32, float32) + shimDiarizeSegmentAt func(segs uintptr, i int32, outStart unsafe.Pointer, outEnd unsafe.Pointer, outSpeaker unsafe.Pointer) ) // libsherpa-onnx-c-api pass-throughs — called directly from Go via purego. @@ -172,6 +197,18 @@ var ( sherpaOfflineTtsGenerate func(tts uintptr, text string, sid int32, speed float32) uintptr sherpaDestroyOfflineTtsGeneratedAudio func(audio uintptr) sherpaOfflineTtsSampleRate func(tts uintptr) int32 + + // Offline speaker diarization. Result handle owns the segment-array + // pointer returned by ResultSortByStartTime; destroy the segment + // array first, then the result, then (at backend Free()) the diarizer. 
+ sherpaDestroyOfflineSpeakerDiarization func(sd uintptr) + sherpaOfflineSpeakerDiarizationGetSampleRate func(sd uintptr) int32 + sherpaOfflineSpeakerDiarizationProcess func(sd uintptr, samples unsafe.Pointer, n int32) uintptr + sherpaOfflineSpeakerDiarizationResultGetNumSegments func(result uintptr) int32 + sherpaOfflineSpeakerDiarizationResultGetNumSpeakers func(result uintptr) int32 + sherpaOfflineSpeakerDiarizationResultSortByStartTime func(result uintptr) uintptr + sherpaOfflineSpeakerDiarizationDestroySegment func(segs uintptr) + sherpaDestroyOfflineSpeakerDiarizationResult func(result uintptr) ) var ( @@ -292,6 +329,24 @@ func loadSherpaLibsOnce() error { {&shimSpeechSegmentStart, "sherpa_shim_speech_segment_start"}, {&shimSpeechSegmentN, "sherpa_shim_speech_segment_n"}, {&shimTtsGenerateWithCallback, "sherpa_shim_tts_generate_with_callback"}, + + {&shimDiarizeConfigNew, "sherpa_shim_diarize_config_new"}, + {&shimDiarizeConfigFree, "sherpa_shim_diarize_config_free"}, + {&shimDiarizeConfigSetSegmentationModel, "sherpa_shim_diarize_config_set_segmentation_model"}, + {&shimDiarizeConfigSetSegmentationNumThreads, "sherpa_shim_diarize_config_set_segmentation_num_threads"}, + {&shimDiarizeConfigSetSegmentationProvider, "sherpa_shim_diarize_config_set_segmentation_provider"}, + {&shimDiarizeConfigSetSegmentationDebug, "sherpa_shim_diarize_config_set_segmentation_debug"}, + {&shimDiarizeConfigSetEmbeddingModel, "sherpa_shim_diarize_config_set_embedding_model"}, + {&shimDiarizeConfigSetEmbeddingNumThreads, "sherpa_shim_diarize_config_set_embedding_num_threads"}, + {&shimDiarizeConfigSetEmbeddingProvider, "sherpa_shim_diarize_config_set_embedding_provider"}, + {&shimDiarizeConfigSetEmbeddingDebug, "sherpa_shim_diarize_config_set_embedding_debug"}, + {&shimDiarizeConfigSetClusteringNumClusters, "sherpa_shim_diarize_config_set_clustering_num_clusters"}, + {&shimDiarizeConfigSetClusteringThreshold, "sherpa_shim_diarize_config_set_clustering_threshold"}, + 
{&shimDiarizeConfigSetMinDurationOn, "sherpa_shim_diarize_config_set_min_duration_on"}, + {&shimDiarizeConfigSetMinDurationOff, "sherpa_shim_diarize_config_set_min_duration_off"}, + {&shimCreateOfflineSpeakerDiarization, "sherpa_shim_create_offline_speaker_diarization"}, + {&shimDiarizeSetClustering, "sherpa_shim_diarize_set_clustering"}, + {&shimDiarizeSegmentAt, "sherpa_shim_diarize_segment_at"}, } { purego.RegisterLibFunc(r.ptr, shim, r.name) } @@ -334,6 +389,15 @@ func loadSherpaLibsOnce() error { {&sherpaOfflineTtsGenerate, "SherpaOnnxOfflineTtsGenerate"}, {&sherpaDestroyOfflineTtsGeneratedAudio, "SherpaOnnxDestroyOfflineTtsGeneratedAudio"}, {&sherpaOfflineTtsSampleRate, "SherpaOnnxOfflineTtsSampleRate"}, + + {&sherpaDestroyOfflineSpeakerDiarization, "SherpaOnnxDestroyOfflineSpeakerDiarization"}, + {&sherpaOfflineSpeakerDiarizationGetSampleRate, "SherpaOnnxOfflineSpeakerDiarizationGetSampleRate"}, + {&sherpaOfflineSpeakerDiarizationProcess, "SherpaOnnxOfflineSpeakerDiarizationProcess"}, + {&sherpaOfflineSpeakerDiarizationResultGetNumSegments, "SherpaOnnxOfflineSpeakerDiarizationResultGetNumSegments"}, + {&sherpaOfflineSpeakerDiarizationResultGetNumSpeakers, "SherpaOnnxOfflineSpeakerDiarizationResultGetNumSpeakers"}, + {&sherpaOfflineSpeakerDiarizationResultSortByStartTime, "SherpaOnnxOfflineSpeakerDiarizationResultSortByStartTime"}, + {&sherpaOfflineSpeakerDiarizationDestroySegment, "SherpaOnnxOfflineSpeakerDiarizationDestroySegment"}, + {&sherpaDestroyOfflineSpeakerDiarizationResult, "SherpaOnnxOfflineSpeakerDiarizationDestroyResult"}, } { purego.RegisterLibFunc(r.ptr, capi, r.name) } @@ -383,6 +447,11 @@ func isVADType(t string) bool { return t == "vad" } +func isDiarizationType(t string) bool { + t = strings.ToLower(t) + return t == "diarization" || t == "diarize" || t == "speaker-diarization" +} + // Model-options prefixes recognised by this backend. 
Kept as typed // constants so the asrFamily / loadWhisperASR / loadGenericASR paths // can all speak the same vocabulary. @@ -423,6 +492,19 @@ const ( optionOnlineRule2 = "online.rule2_min_trailing_silence=" optionOnlineRule3 = "online.rule3_min_utterance_length=" optionOnlineChunkSamples = "online.chunk_samples=" + + // Speaker diarization (offline pyannote + speaker-embedding extractor). + // `diarize.segmentation_model` overrides the auto-detected pyannote + // segmentation .onnx in modelDir; `diarize.embedding_model` does the + // same for the speaker-embedding extractor. `diarize.num_clusters` + // pins a known speaker count at load time; per-call DiarizeRequest + // fields take precedence at process time. + optionDiarizeSegmentationModel = "diarize.segmentation_model=" + optionDiarizeEmbeddingModel = "diarize.embedding_model=" + optionDiarizeNumClusters = "diarize.num_clusters=" + optionDiarizeThreshold = "diarize.threshold=" + optionDiarizeMinDurationOn = "diarize.min_duration_on=" + optionDiarizeMinDurationOff = "diarize.min_duration_off=" ) func hasOption(opts *pb.ModelOptions, prefix string) bool { @@ -493,6 +575,9 @@ func (s *SherpaBackend) Load(opts *pb.ModelOptions) error { if isVADType(opts.Type) { return s.loadVAD(opts) } + if isDiarizationType(opts.Type) { + return s.loadDiarization(opts) + } // An explicit `subtype=...` option routes to ASR even when Type is // unset — handy for the e2e-backends harness, which doesn't know // about ModelOptions.Type. @@ -1247,3 +1332,176 @@ func (s *SherpaBackend) TTSStream(req *pb.TTSRequest, results chan []byte) error } return nil } + +// ============================================================= +// Speaker diarization (offline) +// ============================================================= +// +// Conventions: +// - opts.ModelFile is the pyannote segmentation .onnx (e.g. model.onnx +// under sherpa-onnx-pyannote-segmentation-3-0/). 
Override with +// `diarize.segmentation_model=` if the gallery layout differs. +// - The speaker-embedding extractor must be provided via +// `diarize.embedding_model=`. There's no reliable filename heuristic +// we can rely on (3dspeaker, NeMo, WeSpeaker all ship with +// model-specific names), so we require it to be explicit. +// - Both paths are resolved relative to opts.ModelPath if not absolute. + +func (s *SherpaBackend) loadDiarization(opts *pb.ModelOptions) error { + if s.diarizer != 0 { + return nil + } + + modelDir := filepath.Dir(opts.ModelFile) + segModel := findOptionValue(opts, optionDiarizeSegmentationModel, opts.ModelFile) + if segModel != "" && !filepath.IsAbs(segModel) && opts.ModelPath != "" { + segModel = filepath.Join(opts.ModelPath, segModel) + } + if !fileExists(segModel) { + return fmt.Errorf("sherpa-onnx diarization: pyannote segmentation model not found at %q (set diarize.segmentation_model=...)", segModel) + } + + embModel := findOptionValue(opts, optionDiarizeEmbeddingModel, "") + if embModel == "" { + return fmt.Errorf("sherpa-onnx diarization: speaker-embedding model is required — pass options: [diarize.embedding_model=] (e.g. 
3dspeaker_speech_campplus_sv_zh-cn_16k-common.onnx)") + } + if !filepath.IsAbs(embModel) { + base := opts.ModelPath + if base == "" { + base = modelDir + } + embModel = filepath.Join(base, embModel) + } + if !fileExists(embModel) { + return fmt.Errorf("sherpa-onnx diarization: speaker-embedding model not found at %q", embModel) + } + + threads := int32(1) + if opts.Threads != 0 { + threads = opts.Threads + } + + cfg := shimDiarizeConfigNew() + defer shimDiarizeConfigFree(cfg) + + shimDiarizeConfigSetSegmentationModel(cfg, segModel) + shimDiarizeConfigSetSegmentationNumThreads(cfg, threads) + shimDiarizeConfigSetSegmentationProvider(cfg, onnxProvider) + shimDiarizeConfigSetSegmentationDebug(cfg, 0) + + shimDiarizeConfigSetEmbeddingModel(cfg, embModel) + shimDiarizeConfigSetEmbeddingNumThreads(cfg, threads) + shimDiarizeConfigSetEmbeddingProvider(cfg, onnxProvider) + shimDiarizeConfigSetEmbeddingDebug(cfg, 0) + + shimDiarizeConfigSetClusteringNumClusters(cfg, findOptionInt(opts, optionDiarizeNumClusters, -1)) + shimDiarizeConfigSetClusteringThreshold(cfg, findOptionFloat(opts, optionDiarizeThreshold, 0.5)) + shimDiarizeConfigSetMinDurationOn(cfg, findOptionFloat(opts, optionDiarizeMinDurationOn, 0.3)) + shimDiarizeConfigSetMinDurationOff(cfg, findOptionFloat(opts, optionDiarizeMinDurationOff, 0.5)) + + sd := shimCreateOfflineSpeakerDiarization(cfg) + if sd == 0 { + return fmt.Errorf("sherpa-onnx diarization: failed to create diarizer (segmentation=%s embedding=%s)", segModel, embModel) + } + s.diarizer = sd + s.diarSampleRate = int(sherpaOfflineSpeakerDiarizationGetSampleRate(sd)) + return nil +} + +// applyDiarizeOverrides re-applies clustering knobs onto an existing +// diarizer when per-call DiarizeRequest fields are set. Both -1/0 sentinels +// follow sherpa's convention: num_clusters<=0 → use threshold-based +// clustering, threshold<=0 → keep load-time default. 
+func (s *SherpaBackend) applyDiarizeOverrides(req *pb.DiarizeRequest) { + num := int32(-1) + if req.NumSpeakers > 0 { + num = req.NumSpeakers + } + threshold := float32(0) + if req.ClusteringThreshold > 0 { + threshold = req.ClusteringThreshold + } + if num > 0 || threshold > 0 { + shimDiarizeSetClustering(s.diarizer, num, threshold) + } +} + +func (s *SherpaBackend) Diarize(req *pb.DiarizeRequest) (pb.DiarizeResponse, error) { + if s.diarizer == 0 { + return pb.DiarizeResponse{}, fmt.Errorf("sherpa-onnx diarization not loaded (model must be loaded with type=diarization)") + } + if req.Dst == "" { + return pb.DiarizeResponse{}, fmt.Errorf("sherpa-onnx diarization: DiarizeRequest.dst (audio path) is required") + } + + dir, err := os.MkdirTemp("", "sherpa-diarize") + if err != nil { + return pb.DiarizeResponse{}, fmt.Errorf("failed to create temp dir: %w", err) + } + defer func() { _ = os.RemoveAll(dir) }() + + wavPath := filepath.Join(dir, "input.wav") + if err := utils.AudioToWav(req.Dst, wavPath); err != nil { + return pb.DiarizeResponse{}, fmt.Errorf("failed to convert audio to wav: %w", err) + } + + wave := sherpaReadWave(wavPath) + if wave == 0 { + return pb.DiarizeResponse{}, fmt.Errorf("failed to read wav %s", wavPath) + } + defer sherpaFreeWave(wave) + + sr := int(shimWaveSampleRate(wave)) + nSamples := shimWaveNumSamples(wave) + samples := shimWaveSamples(wave) + duration := float32(nSamples) / float32(sr) + if sr != s.diarSampleRate { + // AudioToWav already targets 16 kHz; pyannote-3.0 also wants 16 kHz, so + // this branch should be unreachable. Fail loudly instead of silently + // passing mismatched audio to the model. 
+ return pb.DiarizeResponse{}, fmt.Errorf("sherpa-onnx diarization: input sample rate %d Hz does not match model %d Hz", sr, s.diarSampleRate) + } + + s.applyDiarizeOverrides(req) + + result := sherpaOfflineSpeakerDiarizationProcess(s.diarizer, samples, nSamples) + if result == 0 { + return pb.DiarizeResponse{}, fmt.Errorf("sherpa-onnx diarization: process failed") + } + defer sherpaDestroyOfflineSpeakerDiarizationResult(result) + + numSegments := sherpaOfflineSpeakerDiarizationResultGetNumSegments(result) + numSpeakers := sherpaOfflineSpeakerDiarizationResultGetNumSpeakers(result) + if numSegments <= 0 { + return pb.DiarizeResponse{ + Segments: []*pb.DiarizeSegment{}, + NumSpeakers: numSpeakers, + Duration: duration, + }, nil + } + + segs := sherpaOfflineSpeakerDiarizationResultSortByStartTime(result) + if segs == 0 { + return pb.DiarizeResponse{}, fmt.Errorf("sherpa-onnx diarization: failed to retrieve segments") + } + defer sherpaOfflineSpeakerDiarizationDestroySegment(segs) + + out := make([]*pb.DiarizeSegment, 0, numSegments) + for i := range int(numSegments) { + var start, end float32 + var spk int32 + shimDiarizeSegmentAt(segs, int32(i), + unsafe.Pointer(&start), unsafe.Pointer(&end), unsafe.Pointer(&spk)) + out = append(out, &pb.DiarizeSegment{ + Id: int32(i), + Start: start, + End: end, + Speaker: strconv.FormatInt(int64(spk), 10), + }) + } + return pb.DiarizeResponse{ + Segments: out, + NumSpeakers: numSpeakers, + Duration: duration, + }, nil +} diff --git a/backend/go/sherpa-onnx/csrc/shim.c b/backend/go/sherpa-onnx/csrc/shim.c index c09a449033e5..f6cae4453764 100644 --- a/backend/go/sherpa-onnx/csrc/shim.c +++ b/backend/go/sherpa-onnx/csrc/shim.c @@ -310,6 +310,87 @@ int32_t sherpa_shim_speech_segment_n(const void *h) { return ((const SherpaOnnxSpeechSegment *)h)->n; } +// ================================================================== +// Offline speaker diarization config +// ================================================================== + 
+void *sherpa_shim_diarize_config_new(void) { + return calloc(1, sizeof(SherpaOnnxOfflineSpeakerDiarizationConfig)); +} + +void sherpa_shim_diarize_config_free(void *h) { + if (!h) return; + SherpaOnnxOfflineSpeakerDiarizationConfig *c = + (SherpaOnnxOfflineSpeakerDiarizationConfig *)h; + free((char *)c->segmentation.pyannote.model); + free((char *)c->segmentation.provider); + free((char *)c->embedding.model); + free((char *)c->embedding.provider); + free(c); +} + +void sherpa_shim_diarize_config_set_segmentation_model(void *h, const char *v) { + shim_set_str(&((SherpaOnnxOfflineSpeakerDiarizationConfig *)h)->segmentation.pyannote.model, v); +} +void sherpa_shim_diarize_config_set_segmentation_num_threads(void *h, int32_t v) { + ((SherpaOnnxOfflineSpeakerDiarizationConfig *)h)->segmentation.num_threads = v; +} +void sherpa_shim_diarize_config_set_segmentation_provider(void *h, const char *v) { + shim_set_str(&((SherpaOnnxOfflineSpeakerDiarizationConfig *)h)->segmentation.provider, v); +} +void sherpa_shim_diarize_config_set_segmentation_debug(void *h, int32_t v) { + ((SherpaOnnxOfflineSpeakerDiarizationConfig *)h)->segmentation.debug = v; +} +void sherpa_shim_diarize_config_set_embedding_model(void *h, const char *v) { + shim_set_str(&((SherpaOnnxOfflineSpeakerDiarizationConfig *)h)->embedding.model, v); +} +void sherpa_shim_diarize_config_set_embedding_num_threads(void *h, int32_t v) { + ((SherpaOnnxOfflineSpeakerDiarizationConfig *)h)->embedding.num_threads = v; +} +void sherpa_shim_diarize_config_set_embedding_provider(void *h, const char *v) { + shim_set_str(&((SherpaOnnxOfflineSpeakerDiarizationConfig *)h)->embedding.provider, v); +} +void sherpa_shim_diarize_config_set_embedding_debug(void *h, int32_t v) { + ((SherpaOnnxOfflineSpeakerDiarizationConfig *)h)->embedding.debug = v; +} +void sherpa_shim_diarize_config_set_clustering_num_clusters(void *h, int32_t v) { + ((SherpaOnnxOfflineSpeakerDiarizationConfig *)h)->clustering.num_clusters = v; +} +void 
sherpa_shim_diarize_config_set_clustering_threshold(void *h, float v) { + ((SherpaOnnxOfflineSpeakerDiarizationConfig *)h)->clustering.threshold = v; +} +void sherpa_shim_diarize_config_set_min_duration_on(void *h, float v) { + ((SherpaOnnxOfflineSpeakerDiarizationConfig *)h)->min_duration_on = v; +} +void sherpa_shim_diarize_config_set_min_duration_off(void *h, float v) { + ((SherpaOnnxOfflineSpeakerDiarizationConfig *)h)->min_duration_off = v; +} + +void *sherpa_shim_create_offline_speaker_diarization(void *h) { + return (void *)SherpaOnnxCreateOfflineSpeakerDiarization( + (const SherpaOnnxOfflineSpeakerDiarizationConfig *)h); +} + +void sherpa_shim_diarize_set_clustering(void *sd, int32_t num_clusters, float threshold) { + if (!sd) return; + SherpaOnnxOfflineSpeakerDiarizationConfig cfg; + memset(&cfg, 0, sizeof(cfg)); + cfg.clustering.num_clusters = num_clusters; + cfg.clustering.threshold = threshold; + SherpaOnnxOfflineSpeakerDiarizationSetConfig( + (const SherpaOnnxOfflineSpeakerDiarization *)sd, &cfg); +} + +void sherpa_shim_diarize_segment_at(const void *segs, int32_t i, + float *out_start, float *out_end, + int32_t *out_speaker) { + const SherpaOnnxOfflineSpeakerDiarizationSegment *arr = + (const SherpaOnnxOfflineSpeakerDiarizationSegment *)segs; + if (out_start) *out_start = arr[i].start; + if (out_end) *out_end = arr[i].end; + if (out_speaker) *out_speaker = arr[i].speaker; +} + // ================================================================== // TTS streaming callback trampoline // ================================================================== diff --git a/backend/go/sherpa-onnx/csrc/shim.h b/backend/go/sherpa-onnx/csrc/shim.h index d479a33a308b..7b9b249cc03f 100644 --- a/backend/go/sherpa-onnx/csrc/shim.h +++ b/backend/go/sherpa-onnx/csrc/shim.h @@ -109,6 +109,41 @@ const float *sherpa_shim_generated_audio_samples(const void *audio); int32_t sherpa_shim_speech_segment_start(const void *seg); int32_t sherpa_shim_speech_segment_n(const void 
*seg); +// --- Offline speaker diarization config ----------------------------- +// Pyannote segmentation + speaker-embedding extractor + fast clustering. +// The upstream config is a struct of nested structs; purego can't read or +// build those across dlopen, so we expose a calloc'd opaque holder plus +// flat setters, then hand it to sherpa via the create wrapper. +void *sherpa_shim_diarize_config_new(void); +void sherpa_shim_diarize_config_free(void *cfg); +void sherpa_shim_diarize_config_set_segmentation_model(void *cfg, const char *path); +void sherpa_shim_diarize_config_set_segmentation_num_threads(void *cfg, int32_t v); +void sherpa_shim_diarize_config_set_segmentation_provider(void *cfg, const char *v); +void sherpa_shim_diarize_config_set_segmentation_debug(void *cfg, int32_t v); +void sherpa_shim_diarize_config_set_embedding_model(void *cfg, const char *path); +void sherpa_shim_diarize_config_set_embedding_num_threads(void *cfg, int32_t v); +void sherpa_shim_diarize_config_set_embedding_provider(void *cfg, const char *v); +void sherpa_shim_diarize_config_set_embedding_debug(void *cfg, int32_t v); +void sherpa_shim_diarize_config_set_clustering_num_clusters(void *cfg, int32_t v); +void sherpa_shim_diarize_config_set_clustering_threshold(void *cfg, float v); +void sherpa_shim_diarize_config_set_min_duration_on(void *cfg, float v); +void sherpa_shim_diarize_config_set_min_duration_off(void *cfg, float v); +void *sherpa_shim_create_offline_speaker_diarization(void *cfg); + +// Apply just the clustering knobs onto a loaded diarizer (sherpa +// supports re-clustering after Create), so per-call overrides like +// num_speakers don't require re-loading the heavy ONNX models. +void sherpa_shim_diarize_set_clustering(void *sd, int32_t num_clusters, float threshold); + +// Sherpa's ResultSortByStartTime returns a sherpa-allocated array of +// SherpaOnnxOfflineSpeakerDiarizationSegment structs (free with +// SherpaOnnxOfflineSpeakerDiarizationDestroySegment). 
Purego can't read +// fields out of an array of C structs, so this getter copies one +// segment's fields into the caller-supplied float/int32 cells. +void sherpa_shim_diarize_segment_at(const void *segs, int32_t i, + float *out_start, float *out_end, + int32_t *out_speaker); + // --- TTS streaming callback trampoline ----------------------------- // Replaces the //export sherpaTtsGoCallback + callbacks.c bridge pattern. // `callback_ptr` is the C-callable function pointer returned by diff --git a/backend/go/vibevoice-cpp/govibevoicecpp.go b/backend/go/vibevoice-cpp/govibevoicecpp.go index 516ffed518e0..7067c162d97b 100644 --- a/backend/go/vibevoice-cpp/govibevoicecpp.go +++ b/backend/go/vibevoice-cpp/govibevoicecpp.go @@ -3,7 +3,9 @@ package main import ( "encoding/json" "fmt" + "io" "os" + "os/exec" "path/filepath" "strings" @@ -12,6 +14,84 @@ import ( pb "github.com/mudler/LocalAI/pkg/grpc/proto" ) +// vv_capi_asr loads audio with load_wav_24k_mono — a 24 kHz mono s16le +// WAV is the format the model was trained on. Inputs already in that +// format pass through; everything else is converted via ffmpeg, which +// is therefore a runtime requirement only when callers upload non-WAV +// (or non-24 kHz mono s16le WAV) audio. Skipping ffmpeg on the happy +// path matters for the e2e-backends test container, which does not +// ship ffmpeg but feeds the backend pre-cooked 24 kHz mono WAVs. +const vibevoiceASRSampleRate = 24000 + +// prepareWavInput resolves `src` to a 24 kHz mono s16le WAV path that +// vv_capi_asr's load_wav_24k_mono accepts. Returns the resolved path +// plus a cleanup func; both must be honoured by the caller. +// +// Pass-through happens when `src` already has the right WAV format — +// no ffmpeg required. Otherwise we shell out to ffmpeg into a temp +// dir; if ffmpeg isn't on PATH we surface a clear error mentioning the +// underlying format mismatch. 
+func prepareWavInput(src string) (string, func(), error) { + if src == "" { + return "", func() {}, fmt.Errorf("empty audio path") + } + if isVibevoiceCompatibleWav(src) { + return src, func() {}, nil + } + + dir, err := os.MkdirTemp("", "vibevoice-asr") + if err != nil { + return "", func() {}, fmt.Errorf("mkdtemp: %w", err) + } + cleanup := func() { _ = os.RemoveAll(dir) } + wavPath := filepath.Join(dir, "input.wav") + + // -y: overwrite, -ar 24000: target sample rate, -ac 1: mono, + // -acodec pcm_s16le: signed 16-bit little-endian PCM (load_wav_24k_mono + // only accepts s16le). + cmd := exec.Command("ffmpeg", + "-y", "-i", src, + "-ar", fmt.Sprintf("%d", vibevoiceASRSampleRate), + "-ac", "1", + "-acodec", "pcm_s16le", + wavPath, + ) + cmd.Env = []string{} + if out, err := cmd.CombinedOutput(); err != nil { + cleanup() + return "", func() {}, fmt.Errorf("ffmpeg convert to 24k mono wav: %w (output: %s)", err, string(out)) + } + return wavPath, cleanup, nil +} + +// isVibevoiceCompatibleWav returns true when `src` carries the RIFF/WAVE +// magic bytes. vibevoice's load_wav_24k_mono uses drwav under the hood, +// which accepts any PCM/IEEE-float WAV at any sample rate and downmixes +// multi-channel input to mono on its own — so any valid WAV passes +// through to the C side without conversion. Anything else (MP3, OGG, +// FLAC, ...) needs ffmpeg. +func isVibevoiceCompatibleWav(src string) bool { + f, err := os.Open(src) + if err != nil { + return false + } + defer func() { _ = f.Close() }() + + // 0..3 = "RIFF", 8..11 = "WAVE". + var hdr [12]byte + if _, err := io.ReadFull(f, hdr[:]); err != nil { + return false + } + return string(hdr[0:4]) == "RIFF" && string(hdr[8:12]) == "WAVE" +} + +// asrMaxNewTokens caps the ASR generation budget. The C ABI defaults to +// 256 when 0 is passed — far too small for anything past ~10s of speech. 
+// Vibevoice generates ~30 tokens per second of audio, so 16 384 covers +// roughly 9 minutes of dialogue, well past any normal /v1/audio/diarization +// upload. Going higher costs little since generation stops at EOS. +const asrMaxNewTokens = 16384 + // vibevoice.cpp synthesizes 24 kHz mono 16-bit PCM. Hardcoded - the // model itself is fixed-rate; if the upstream ever changes this we'll // pick it up via vv_capi_version(). @@ -302,7 +382,13 @@ func (v *VibevoiceCpp) AudioTranscription(req *pb.TranscriptRequest) (pb.Transcr return pb.TranscriptResult{}, fmt.Errorf("vibevoice-cpp: TranscriptRequest.dst (audio path) is required") } - out, err := v.callASR(req.Dst, 0) + wavPath, cleanup, err := prepareWavInput(req.Dst) + if err != nil { + return pb.TranscriptResult{}, fmt.Errorf("vibevoice-cpp: %w", err) + } + defer cleanup() + + out, err := v.callASR(wavPath, asrMaxNewTokens) if err != nil { return pb.TranscriptResult{}, err } @@ -346,6 +432,83 @@ func (v *VibevoiceCpp) AudioTranscription(req *pb.TranscriptRequest) (pb.Transcr }, nil } +// Diarize runs vibevoice's ASR and projects the speaker-labelled segment +// list it returns natively. vibevoice.cpp's ASR prompt asks the model to +// emit `[{"Start":..,"End":..,"Speaker":..,"Content":..}]`, so diarization +// is a by-product of the same pass — we reuse callASR and re-shape. +// +// Speaker hints (num_speakers/min/max/threshold) and min_duration_on/off are +// not actionable here: vibevoice's model picks the speaker count itself and +// has no clustering knob. The HTTP layer documents this; we accept the +// fields for API symmetry and ignore them. 
+func (v *VibevoiceCpp) Diarize(req *pb.DiarizeRequest) (pb.DiarizeResponse, error) { + if v.asrModel == "" { + return pb.DiarizeResponse{}, fmt.Errorf("vibevoice-cpp: Diarize requires an ASR model (load options: type=asr)") + } + if req.Dst == "" { + return pb.DiarizeResponse{}, fmt.Errorf("vibevoice-cpp: DiarizeRequest.dst (audio path) is required") + } + + wavPath, cleanup, err := prepareWavInput(req.Dst) + if err != nil { + return pb.DiarizeResponse{}, fmt.Errorf("vibevoice-cpp: %w", err) + } + defer cleanup() + + out, err := v.callASR(wavPath, asrMaxNewTokens) + if err != nil { + return pb.DiarizeResponse{}, err + } + if out == "" { + return pb.DiarizeResponse{}, nil + } + + var segs []asrSegment + if err := json.Unmarshal([]byte(out), &segs); err != nil { + // Mirror AudioTranscription's fallback: vibevoice's ASR sometimes + // emits free-form text instead of JSON for short or unusual audio. + // Surface a single unknown-speaker segment carrying the full text + // (when include_text is set) so the caller still gets coverage of + // the whole clip rather than a hard failure. 
+ fmt.Fprintf(os.Stderr, + "[vibevoice-cpp] WARNING: vv_capi_asr returned non-JSON for diarization, falling back to single segment: %v\n", err) + text := strings.TrimSpace(out) + seg := &pb.DiarizeSegment{Id: 0, Speaker: "0"} + if req.IncludeText { + seg.Text = text + } + return pb.DiarizeResponse{ + Segments: []*pb.DiarizeSegment{seg}, + NumSpeakers: 1, + }, nil + } + + speakers := make(map[int]struct{}) + segments := make([]*pb.DiarizeSegment, 0, len(segs)) + var duration float32 + for i, s := range segs { + ds := &pb.DiarizeSegment{ + Id: int32(i), + Start: float32(s.Start), + End: float32(s.End), + Speaker: fmt.Sprintf("%d", s.Speaker), + } + if req.IncludeText { + ds.Text = strings.TrimSpace(s.Content) + } + segments = append(segments, ds) + speakers[s.Speaker] = struct{}{} + if float32(s.End) > duration { + duration = float32(s.End) + } + } + return pb.DiarizeResponse{ + Segments: segments, + NumSpeakers: int32(len(speakers)), + Duration: duration, + }, nil +} + // AudioTranscriptionStream wraps AudioTranscription so the streaming // gRPC endpoint (server.go:AudioTranscriptionStream) sees its channel // close and the client doesn't sit waiting until deadline. vibevoice's diff --git a/core/backend/diarization.go b/core/backend/diarization.go new file mode 100644 index 000000000000..d311d4c4581e --- /dev/null +++ b/core/backend/diarization.go @@ -0,0 +1,158 @@ +package backend + +import ( + "context" + "fmt" + "sort" + + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/schema" + + grpcPkg "github.com/mudler/LocalAI/pkg/grpc" + "github.com/mudler/LocalAI/pkg/grpc/proto" + "github.com/mudler/LocalAI/pkg/model" +) + +// DiarizationRequest carries the diarization-specific knobs the HTTP +// layer collects. Speaker hints (NumSpeakers / MinSpeakers / MaxSpeakers) +// and clustering knobs are optional — backends ignore the ones they +// don't act on. 
IncludeText only matters for backends that emit +// per-segment transcripts as a by-product (e.g. vibevoice.cpp). +type DiarizationRequest struct { + Audio string + Language string + NumSpeakers int32 + MinSpeakers int32 + MaxSpeakers int32 + ClusteringThreshold float32 + MinDurationOn float32 + MinDurationOff float32 + IncludeText bool +} + +func (r *DiarizationRequest) toProto(threads uint32) *proto.DiarizeRequest { + return &proto.DiarizeRequest{ + Dst: r.Audio, + Threads: threads, + Language: r.Language, + NumSpeakers: r.NumSpeakers, + MinSpeakers: r.MinSpeakers, + MaxSpeakers: r.MaxSpeakers, + ClusteringThreshold: r.ClusteringThreshold, + MinDurationOn: r.MinDurationOn, + MinDurationOff: r.MinDurationOff, + IncludeText: r.IncludeText, + } +} + +func loadDiarizationModel(ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (grpcPkg.Backend, error) { + if modelConfig.Backend == "" { + return nil, fmt.Errorf("diarization: model %q has no backend set; supported backends include vibevoice-cpp and sherpa-onnx", modelConfig.Name) + } + opts := ModelOptions(modelConfig, appConfig) + m, err := ml.Load(opts...) + if err != nil { + recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil) + return nil, err + } + if m == nil { + return nil, fmt.Errorf("could not load diarization model") + } + return m, nil +} + +// ModelDiarization runs the Diarize RPC against the configured backend +// and returns a normalized schema.DiarizationResult. 
+func ModelDiarization(req DiarizationRequest, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.DiarizationResult, error) { + m, err := loadDiarizationModel(ml, modelConfig, appConfig) + if err != nil { + return nil, err + } + + threads := uint32(0) + if modelConfig.Threads != nil { + threads = uint32(*modelConfig.Threads) + } + + r, err := m.Diarize(context.Background(), req.toProto(threads)) + if err != nil { + return nil, err + } + return diarizationResultFromProto(r), nil +} + +// diarizationResultFromProto normalizes backend speaker labels to +// "SPEAKER_NN" — the convention pyannote/RTTM tooling expects — while +// keeping the original label available via the Speaker field. Each +// distinct backend label gets its own normalized id, in first-seen order. +func diarizationResultFromProto(r *proto.DiarizeResponse) *schema.DiarizationResult { + if r == nil { + return &schema.DiarizationResult{Segments: []schema.DiarizationSegment{}} + } + + out := &schema.DiarizationResult{ + Task: "diarize", + Duration: float64(r.Duration), + Language: r.Language, + Segments: make([]schema.DiarizationSegment, 0, len(r.Segments)), + } + + type speakerStats struct { + idx int + duration float64 + segments int + } + stats := map[string]*speakerStats{} + order := []string{} + + for i, s := range r.Segments { + if s == nil { + continue + } + raw := s.Speaker + if raw == "" { + raw = "0" + } + st, ok := stats[raw] + if !ok { + st = &speakerStats{idx: len(order)} + stats[raw] = st + order = append(order, raw) + } + dur := float64(s.End) - float64(s.Start) + if dur > 0 { + st.duration += dur + } + st.segments++ + + out.Segments = append(out.Segments, schema.DiarizationSegment{ + Id: i, + Speaker: fmt.Sprintf("SPEAKER_%02d", st.idx), + Label: raw, + Start: float64(s.Start), + End: float64(s.End), + Text: s.Text, + }) + } + + out.NumSpeakers = len(order) + if out.NumSpeakers == 0 && r.NumSpeakers > 0 { + out.NumSpeakers = 
int(r.NumSpeakers) + } + + out.Speakers = make([]schema.DiarizationSpeaker, 0, len(order)) + for _, raw := range order { + st := stats[raw] + out.Speakers = append(out.Speakers, schema.DiarizationSpeaker{ + Id: fmt.Sprintf("SPEAKER_%02d", st.idx), + Label: raw, + TotalSpeechDuration: st.duration, + SegmentCount: st.segments, + }) + } + sort.SliceStable(out.Speakers, func(i, j int) bool { + return out.Speakers[i].Id < out.Speakers[j].Id + }) + + return out +} diff --git a/core/backend/diarization_test.go b/core/backend/diarization_test.go new file mode 100644 index 000000000000..3d86a1b0068e --- /dev/null +++ b/core/backend/diarization_test.go @@ -0,0 +1,76 @@ +package backend + +import ( + "github.com/mudler/LocalAI/pkg/grpc/proto" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("diarizationResultFromProto", func() { + It("normalises raw backend speaker labels to SPEAKER_NN in first-seen order", func() { + in := &proto.DiarizeResponse{ + Duration: 10.5, + Language: "en", + Segments: []*proto.DiarizeSegment{ + {Start: 0.0, End: 1.0, Speaker: "5", Text: "hi"}, + {Start: 1.0, End: 2.0, Speaker: "2"}, + {Start: 2.0, End: 3.5, Speaker: "5"}, + {Start: 3.5, End: 4.0, Speaker: ""}, // empty → coerced to "0" + }, + } + + got := diarizationResultFromProto(in) + + Expect(got.Task).To(Equal("diarize")) + Expect(got.NumSpeakers).To(Equal(3), "expected 3 distinct speakers (5, 2, 0)") + Expect(got.Duration).To(BeEquivalentTo(10.5)) + Expect(got.Language).To(Equal("en")) + Expect(got.Segments).To(HaveLen(4)) + + // First-seen-order normalisation: "5"→SPEAKER_00, "2"→SPEAKER_01, ""→SPEAKER_02 + want := []struct { + speaker string + label string + }{ + {"SPEAKER_00", "5"}, + {"SPEAKER_01", "2"}, + {"SPEAKER_00", "5"}, + {"SPEAKER_02", "0"}, + } + for i, w := range want { + Expect(got.Segments[i].Speaker).To(Equal(w.speaker), "seg[%d].speaker", i) + Expect(got.Segments[i].Label).To(Equal(w.label), "seg[%d].label", i) + } + + // Per-speaker 
totals reflect cumulative speech duration and segment count. + Expect(got.Speakers).To(HaveLen(3)) + byID := map[string]float64{} + countByID := map[string]int{} + for _, sp := range got.Speakers { + byID[sp.Id] = sp.TotalSpeechDuration + countByID[sp.Id] = sp.SegmentCount + } + Expect(byID["SPEAKER_00"]).To(BeNumerically("~", 2.5, 0.001), "1.0 + 1.5") + Expect(byID["SPEAKER_01"]).To(BeNumerically("~", 1.0, 0.001)) + Expect(countByID["SPEAKER_00"]).To(Equal(2)) + Expect(countByID["SPEAKER_01"]).To(Equal(1)) + Expect(countByID["SPEAKER_02"]).To(Equal(1)) + }) + + It("returns a non-nil result with a non-nil segments slice for nil input", func() { + got := diarizationResultFromProto(nil) + Expect(got).ToNot(BeNil()) + Expect(got.Segments).ToNot(BeNil()) + Expect(got.Segments).To(BeEmpty()) + }) + + It("keeps the backend speaker count when no segments are returned", func() { + // Backend reports a non-zero NumSpeakers but no segments (early stop, + // silence-only audio after VAD trim). Surface the backend's count. 
+ in := &proto.DiarizeResponse{NumSpeakers: 2, Duration: 5} + got := diarizationResultFromProto(in) + Expect(got.NumSpeakers).To(Equal(2)) + Expect(got.Segments).To(BeEmpty()) + }) +}) diff --git a/core/config/model_config.go b/core/config/model_config.go index 5f051251a4c0..5a0dfb7ed803 100644 --- a/core/config/model_config.go +++ b/core/config/model_config.go @@ -607,6 +607,7 @@ const ( FLAG_FACE_RECOGNITION ModelConfigUsecase = 0b10000000000000 FLAG_SPEAKER_RECOGNITION ModelConfigUsecase = 0b100000000000000 FLAG_AUDIO_TRANSFORM ModelConfigUsecase = 0b1000000000000000 + FLAG_DIARIZATION ModelConfigUsecase = 0b10000000000000000 // Common Subsets FLAG_LLM ModelConfigUsecase = FLAG_CHAT | FLAG_COMPLETION | FLAG_EDIT @@ -633,6 +634,7 @@ func GetAllModelConfigUsecases() map[string]ModelConfigUsecase { "FLAG_FACE_RECOGNITION": FLAG_FACE_RECOGNITION, "FLAG_SPEAKER_RECOGNITION": FLAG_SPEAKER_RECOGNITION, "FLAG_AUDIO_TRANSFORM": FLAG_AUDIO_TRANSFORM, + "FLAG_DIARIZATION": FLAG_DIARIZATION, } } @@ -797,6 +799,16 @@ func (c *ModelConfig) GuessUsecases(u ModelConfigUsecase) bool { } } + if (u & FLAG_DIARIZATION) == FLAG_DIARIZATION { + // vibevoice-cpp emits speaker-labelled segments natively from its + // ASR pass; sherpa-onnx pipes pyannote segmentation + speaker + // embeddings + clustering. Both surface as a Diarize gRPC. 
+ diarizationBackends := []string{"vibevoice-cpp", "sherpa-onnx"} + if !slices.Contains(diarizationBackends, c.Backend) { + return false + } + } + return true } diff --git a/core/http/auth/features.go b/core/http/auth/features.go index 7b3ae6a9ea61..77199580a7a5 100644 --- a/core/http/auth/features.go +++ b/core/http/auth/features.go @@ -44,6 +44,10 @@ var RouteFeatureRegistry = []RouteFeature{ {"POST", "/v1/audio/transcriptions", FeatureAudioTranscription}, {"POST", "/audio/transcriptions", FeatureAudioTranscription}, + // Audio diarization (speaker turns) + {"POST", "/v1/audio/diarization", FeatureAudioDiarization}, + {"POST", "/audio/diarization", FeatureAudioDiarization}, + // Audio speech / TTS {"POST", "/v1/audio/speech", FeatureAudioSpeech}, {"POST", "/audio/speech", FeatureAudioSpeech}, @@ -163,6 +167,7 @@ func APIFeatureMetas() []FeatureMeta { {FeatureImages, "Image Generation", true}, {FeatureAudioSpeech, "Audio Speech / TTS", true}, {FeatureAudioTranscription, "Audio Transcription", true}, + {FeatureAudioDiarization, "Audio Diarization", true}, {FeatureVAD, "Voice Activity Detection", true}, {FeatureDetection, "Detection", true}, {FeatureVideo, "Video Generation", true}, diff --git a/core/http/auth/permissions.go b/core/http/auth/permissions.go index bccceb56c2a0..fb8246f7c5f0 100644 --- a/core/http/auth/permissions.go +++ b/core/http/auth/permissions.go @@ -42,6 +42,7 @@ const ( FeatureImages = "images" FeatureAudioSpeech = "audio_speech" FeatureAudioTranscription = "audio_transcription" + FeatureAudioDiarization = "audio_diarization" FeatureVAD = "vad" FeatureDetection = "detection" FeatureVideo = "video" @@ -66,6 +67,7 @@ var GeneralFeatures = []string{FeatureFineTuning, FeatureQuantization} // APIFeatures lists API endpoint features (default ON). 
var APIFeatures = []string{ FeatureChat, FeatureImages, FeatureAudioSpeech, FeatureAudioTranscription, + FeatureAudioDiarization, FeatureVAD, FeatureDetection, FeatureVideo, FeatureEmbeddings, FeatureSound, FeatureRealtime, FeatureRerank, FeatureTokenize, FeatureMCP, FeatureStores, FeatureFaceRecognition, FeatureVoiceRecognition, FeatureAudioTransform, diff --git a/core/http/endpoints/localai/api_instructions.go b/core/http/endpoints/localai/api_instructions.go index c4bc41f2f3c8..103c87443209 100644 --- a/core/http/endpoints/localai/api_instructions.go +++ b/core/http/endpoints/localai/api_instructions.go @@ -32,8 +32,9 @@ var instructionDefs = []instructionDef{ }, { Name: "audio", - Description: "Text-to-speech, voice activity detection, transcription, and sound generation", + Description: "Text-to-speech, voice activity detection, transcription, speaker diarization, and sound generation", Tags: []string{"audio"}, + Intro: "Diarization (/v1/audio/diarization) returns speaker-labelled time segments. Backends with native ASR-diarization (vibevoice-cpp) can also emit per-segment text via include_text=true; backends with a dedicated pipeline (sherpa-onnx + pyannote) emit segmentation only. 
Response formats: json (default), verbose_json (adds speakers summary + text), rttm (NIST format).", }, { Name: "images", diff --git a/core/http/endpoints/openai/diarization.go b/core/http/endpoints/openai/diarization.go new file mode 100644 index 000000000000..2f927ddae1aa --- /dev/null +++ b/core/http/endpoints/openai/diarization.go @@ -0,0 +1,181 @@ +package openai + +import ( + "errors" + "fmt" + "io" + "net/http" + "os" + "path" + "path/filepath" + "strconv" + "strings" + + "github.com/labstack/echo/v4" + "github.com/mudler/LocalAI/core/backend" + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/http/middleware" + "github.com/mudler/LocalAI/core/schema" + model "github.com/mudler/LocalAI/pkg/model" + + "github.com/mudler/xlog" +) + +// DiarizationEndpoint runs offline speaker diarization on an uploaded +// audio file and returns "who spoke when". Backends with a pure +// diarization pipeline (sherpa-onnx + pyannote) emit only segmentation; +// backends that produce diarization as a by-product of ASR (vibevoice.cpp) +// can additionally fill in the per-segment transcript when the caller +// passes `include_text=true`. +// +// Response formats follow transcription's: `json` (default, segments only), +// `verbose_json` (adds speaker summary and per-segment text), and `rttm` +// (NIST RTTM, the standard interchange format used by pyannote/dscore). +// +// @Summary Identify speakers in audio (who spoke when). 
+// @Tags audio +// @accept multipart/form-data +// @Param model formData string true "model" +// @Param file formData file true "audio file" +// @Param num_speakers formData int false "exact speaker count (>0 forces; 0 = auto)" +// @Param min_speakers formData int false "lower bound when auto-detecting" +// @Param max_speakers formData int false "upper bound when auto-detecting" +// @Param clustering_threshold formData number false "clustering distance threshold when num_speakers is unknown" +// @Param min_duration_on formData number false "discard segments shorter than this (seconds)" +// @Param min_duration_off formData number false "merge gaps shorter than this (seconds)" +// @Param language formData string false "audio language hint (only meaningful for backends that bundle ASR)" +// @Param include_text formData boolean false "include per-segment transcript when the backend supports it" +// @Param response_format formData string false "json (default), verbose_json, or rttm" +// @Success 200 {object} schema.DiarizationResult +// @Router /v1/audio/diarization [post] +func DiarizationEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { + return func(c echo.Context) error { + input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest) + if !ok || input.Model == "" { + return echo.ErrBadRequest + } + + modelConfig, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) + if !ok || modelConfig == nil { + return echo.ErrBadRequest + } + + req := backend.DiarizationRequest{ + Language: input.Language, + IncludeText: parseFormBool(c, "include_text", false), + } + req.NumSpeakers = int32(parseFormInt(c, "num_speakers", 0)) + req.MinSpeakers = int32(parseFormInt(c, "min_speakers", 0)) + req.MaxSpeakers = int32(parseFormInt(c, "max_speakers", 0)) + req.ClusteringThreshold = float32(parseFormFloat(c, "clustering_threshold", 0)) + req.MinDurationOn = 
float32(parseFormFloat(c, "min_duration_on", 0))
+		req.MinDurationOff = float32(parseFormFloat(c, "min_duration_off", 0))
+
+		responseFormat := schema.DiarizationResponseFormatType(strings.ToLower(c.FormValue("response_format")))
+		if responseFormat == "" {
+			responseFormat = schema.DiarizationResponseFormatJson
+		}
+
+		file, err := c.FormFile("file")
+		if err != nil {
+			return err
+		}
+		f, err := file.Open()
+		if err != nil {
+			return err
+		}
+		defer func() { _ = f.Close() }()
+
+		dir, err := os.MkdirTemp("", "diarize")
+		if err != nil {
+			return err
+		}
+		defer func() { _ = os.RemoveAll(dir) }()
+
+		dst := filepath.Join(dir, path.Base(file.Filename))
+		dstFile, err := os.Create(dst)
+		if err != nil {
+			return err
+		}
+		if _, err := io.Copy(dstFile, f); err != nil {
+			xlog.Debug("Audio file copying error", "filename", file.Filename, "dst", dst, "error", err)
+			_ = dstFile.Close()
+			return err
+		}
+		_ = dstFile.Close()
+		req.Audio = dst
+
+		result, err := backend.ModelDiarization(req, ml, *modelConfig, appConfig)
+		if err != nil {
+			return err
+		}
+
+		switch responseFormat {
+		case schema.DiarizationResponseFormatRTTM:
+			c.Response().Header().Set(echo.HeaderContentType, "text/plain; charset=utf-8")
+			return c.String(http.StatusOK, renderRTTM(result, file.Filename))
+		case schema.DiarizationResponseFormatJson:
+			// Default JSON: drop the heavy per-speaker summary and any
+			// optional per-segment text so simple consumers see a tight
+			// payload. verbose_json keeps everything.
+			result.Speakers = nil
+			for i := range result.Segments {
+				result.Segments[i].Text = ""
+			}
+			return c.JSON(http.StatusOK, result)
+		case schema.DiarizationResponseFormatJsonVerbose:
+			return c.JSON(http.StatusOK, result)
+		default:
+			return errors.New("invalid response_format (expected: json, verbose_json, rttm)")
+		}
+	}
+}
+
+// renderRTTM emits NIST RTTM rows. Each row:
+// SPEAKER <file-id> 1 <turn-onset-seconds> <turn-duration-seconds> <speaker-label>
+// Field separators are spaces; one row per segment.
+func renderRTTM(r *schema.DiarizationResult, sourceFile string) string { + id := strings.TrimSuffix(filepath.Base(sourceFile), filepath.Ext(sourceFile)) + // filepath.Base("") returns "." — treat both as a missing source name and + // fall back to a stable placeholder so the RTTM row stays parseable. + if id == "" || id == "." { + id = "audio" + } + var sb strings.Builder + for _, seg := range r.Segments { + dur := seg.End - seg.Start + if dur < 0 { + dur = 0 + } + fmt.Fprintf(&sb, "SPEAKER %s 1 %.3f %.3f %s \n", + id, seg.Start, dur, seg.Speaker) + } + return sb.String() +} + +func parseFormInt(c echo.Context, key string, def int) int { + if v := c.FormValue(key); v != "" { + if n, err := strconv.Atoi(v); err == nil { + return n + } + } + return def +} + +func parseFormFloat(c echo.Context, key string, def float64) float64 { + if v := c.FormValue(key); v != "" { + if f, err := strconv.ParseFloat(v, 64); err == nil { + return f + } + } + return def +} + +func parseFormBool(c echo.Context, key string, def bool) bool { + if v := c.FormValue(key); v != "" { + if b, err := strconv.ParseBool(v); err == nil { + return b + } + } + return def +} diff --git a/core/http/endpoints/openai/diarization_test.go b/core/http/endpoints/openai/diarization_test.go new file mode 100644 index 000000000000..9cba206a304b --- /dev/null +++ b/core/http/endpoints/openai/diarization_test.go @@ -0,0 +1,51 @@ +package openai + +import ( + "strings" + + "github.com/mudler/LocalAI/core/schema" + + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" +) + +var _ = Describe("renderRTTM", func() { + It("formats segments as NIST RTTM rows", func() { + r := &schema.DiarizationResult{ + Segments: []schema.DiarizationSegment{ + {Id: 0, Speaker: "SPEAKER_00", Start: 0, End: 2.34}, + {Id: 1, Speaker: "SPEAKER_01", Start: 2.34, End: 4.10}, + }, + } + out := renderRTTM(r, "/tmp/uploads/meeting.wav") + + lines := strings.Split(strings.TrimSpace(out), "\n") + Expect(lines).To(HaveLen(2)) + + // File ID should be the basename without extension; durations are + // (end - start) with millisecond precision. + Expect(lines[0]).To(HavePrefix("SPEAKER meeting 1 ")) + Expect(lines[0]).To(ContainSubstring(" 0.000 2.340 SPEAKER_00 ")) + Expect(lines[1]).To(ContainSubstring(" 2.340 1.760 SPEAKER_01 ")) + }) + + It("clamps negative duration to zero", func() { + // Backends shouldn't emit end0 forces; 0 = auto) | +| `min_speakers` | int | hint when auto-detecting | +| `max_speakers` | int | hint when auto-detecting | +| `clustering_threshold` | float | cosine distance threshold used when `num_speakers` is unknown | +| `min_duration_on` | float | discard segments shorter than this many seconds | +| `min_duration_off` | float | merge gaps shorter than this many seconds | +| `language` | string | only meaningful for backends that bundle ASR (e.g. vibevoice) | +| `include_text` | bool | when the backend can emit per-segment transcript for free, populate it | +| `response_format` | string | `json` (default), `verbose_json`, or `rttm` | + +### Response — `json` (default) + +Compact payload, no transcription, no per-speaker summary: + +```json +{ + "task": "diarize", + "duration": 12.34, + "num_speakers": 2, + "segments": [ + {"id": 0, "speaker": "SPEAKER_00", "label": "0", "start": 0.00, "end": 2.34}, + {"id": 1, "speaker": "SPEAKER_01", "label": "1", "start": 2.34, "end": 4.10} + ] +} +``` + +`speaker` is the normalized, zero-padded label clients should display. 
`label` preserves the raw backend-emitted ID for clients that maintain their own speaker dictionary. + +### Response — `verbose_json` + +Adds per-speaker totals and (when the backend supports it and `include_text=true`) the per-segment transcript: + +```json +{ + "task": "diarize", + "duration": 12.34, + "language": "en", + "num_speakers": 2, + "segments": [ + {"id": 0, "speaker": "SPEAKER_00", "label": "0", "start": 0.00, "end": 2.34, "text": "Hello, world."}, + {"id": 1, "speaker": "SPEAKER_01", "label": "1", "start": 2.34, "end": 4.10, "text": "How are you?"} + ], + "speakers": [ + {"id": "SPEAKER_00", "label": "0", "total_speech_duration": 5.6, "segment_count": 3}, + {"id": "SPEAKER_01", "label": "1", "total_speech_duration": 1.76, "segment_count": 1} + ] +} +``` + +### Response — `rttm` + +NIST RTTM, the standard interchange format used by `pyannote.metrics` / `dscore`: + +``` +SPEAKER audio 1 0.000 2.340 SPEAKER_00 +SPEAKER audio 1 2.340 1.760 SPEAKER_01 +``` + +Returned as `Content-Type: text/plain; charset=utf-8`. + +## Quick start + +```bash +curl http://localhost:8080/v1/audio/diarization \ + -H "Content-Type: multipart/form-data" \ + -F file="@meeting.wav" \ + -F model="pyannote-diarization" \ + -F num_speakers=3 +``` + +## Backend setup — sherpa-onnx (pure diarization) + +Sherpa-onnx needs two ONNX models: pyannote segmentation and a speaker-embedding extractor. 
Place them under your LocalAI models directory and reference them from the YAML: + +```yaml +name: pyannote-diarization +backend: sherpa-onnx +type: diarization +parameters: + model: sherpa-onnx-pyannote-segmentation-3-0/model.onnx +options: + - diarize.embedding_model=3dspeaker_speech_campplus_sv_zh-cn_16k-common.onnx + # Optional clustering knobs (per-call DiarizeRequest fields override these): + - diarize.threshold=0.5 + - diarize.min_duration_on=0.3 + - diarize.min_duration_off=0.5 +known_usecases: + - FLAG_DIARIZATION +``` + +Both `model:` and `diarize.embedding_model=` are resolved relative to the LocalAI models directory. + +## Backend setup — vibevoice.cpp (diarization + ASR) + +vibevoice.cpp's ASR mode emits `[{Start, End, Speaker, Content}]` natively, so a single pass gives both diarization and transcription: + +```yaml +name: vibevoice-diarize +backend: vibevoice-cpp +parameters: + model: vibevoice-asr.gguf +options: + - type=asr + - tokenizer=vibevoice-tokenizer.gguf +known_usecases: + - FLAG_DIARIZATION + - FLAG_TRANSCRIPT +``` + +Pass `include_text=true` on the request to populate the `text` field on each diarization segment. + +```bash +curl http://localhost:8080/v1/audio/diarization \ + -H "Content-Type: multipart/form-data" \ + -F file="@interview.wav" \ + -F model="vibevoice-diarize" \ + -F include_text=true \ + -F response_format=verbose_json +``` + +## Notes + +- **Speaker identity across files**: speaker IDs (`SPEAKER_00`, `SPEAKER_01`, …) are local to each request. To track the same person across multiple recordings, combine `/v1/audio/diarization` with `/v1/voice/embed` (speaker embedding) and maintain your own embedding store. +- **Hints vs. forces**: `num_speakers` overrides clustering when set; `min_speakers` / `max_speakers` are advisory and only honored by backends that expose a range hint. vibevoice.cpp ignores them — its model picks the count itself. 
+- **Sample rate**: input is automatically converted to 16 kHz mono via ffmpeg before the backend sees it; sherpa-onnx pyannote-3.0 requires 16 kHz. diff --git a/docs/content/features/audio-to-text.md b/docs/content/features/audio-to-text.md index 36686e11d60b..d4b6d37bbbfd 100644 --- a/docs/content/features/audio-to-text.md +++ b/docs/content/features/audio-to-text.md @@ -16,6 +16,8 @@ The transcription endpoint allows to convert audio files to text. The endpoint s The endpoint input supports all the audio formats supported by `ffmpeg`. +> Looking for **"who spoke when"** instead of a flat transcript? See [Speaker Diarization](/features/audio-diarization/) — `/v1/audio/diarization` returns time-stamped speaker segments and supports the `rttm` format used by `pyannote.metrics`. + ## Usage Once LocalAI is started and whisper models are installed, you can use the `/v1/audio/transcriptions` API endpoint. diff --git a/docs/content/whats-new.md b/docs/content/whats-new.md index 62ccf99c3383..8a393b4b4e0d 100644 --- a/docs/content/whats-new.md +++ b/docs/content/whats-new.md @@ -14,6 +14,7 @@ You can see the release notes [here](https://github.com/mudler/LocalAI/releases) - **April 2026**: [Audio Transform](/features/audio-transform/) — generic audio-in / audio-out endpoint with optional reference signal. First implementation: [LocalVQE](https://github.com/localai-org/LocalVQE) C++ backend (joint AEC + noise suppression + dereverberation, DeepVQE-style). Both batch (`POST /audio/transformations`) and bidirectional WebSocket streaming (`/audio/transformations/stream`). Studio "Transform" tab with synchronized waveform players for input / reference / output. - **April 2026**: [Face recognition backend](/features/face-recognition/) — `insightface`-powered 1:1 verification, 1:N identification, face embedding, face detection, and demographic analysis. Ships both a non-commercial `buffalo_l` model and an Apache 2.0 OpenCV Zoo alternative. 
+- **May 2026**: [Speaker diarization](/features/audio-diarization/) — new `/v1/audio/diarization` endpoint returning "who spoke when" segments. Backed by `sherpa-onnx` (pyannote-3.0 + speaker embeddings + clustering) for pure diarization, and `vibevoice-cpp` for diarization bundled with long-form ASR. Supports `json` / `verbose_json` / `rttm` response formats. ## 2024 Highlights diff --git a/pkg/grpc/backend.go b/pkg/grpc/backend.go index 996e189b11e4..2f3b2a192087 100644 --- a/pkg/grpc/backend.go +++ b/pkg/grpc/backend.go @@ -75,6 +75,8 @@ type Backend interface { VAD(ctx context.Context, in *pb.VADRequest, opts ...grpc.CallOption) (*pb.VADResponse, error) + Diarize(ctx context.Context, in *pb.DiarizeRequest, opts ...grpc.CallOption) (*pb.DiarizeResponse, error) + AudioEncode(ctx context.Context, in *pb.AudioEncodeRequest, opts ...grpc.CallOption) (*pb.AudioEncodeResult, error) AudioDecode(ctx context.Context, in *pb.AudioDecodeRequest, opts ...grpc.CallOption) (*pb.AudioDecodeResult, error) diff --git a/pkg/grpc/base/base.go b/pkg/grpc/base/base.go index 455ec609aba1..1fdf258cddf7 100644 --- a/pkg/grpc/base/base.go +++ b/pkg/grpc/base/base.go @@ -101,6 +101,10 @@ func (llm *Base) VoiceEmbed(*pb.VoiceEmbedRequest) (pb.VoiceEmbedResponse, error return pb.VoiceEmbedResponse{}, fmt.Errorf("unimplemented") } +func (llm *Base) Diarize(*pb.DiarizeRequest) (pb.DiarizeResponse, error) { + return pb.DiarizeResponse{}, fmt.Errorf("unimplemented") +} + func (llm *Base) TokenizeString(opts *pb.PredictOptions) (pb.TokenizationResponse, error) { return pb.TokenizationResponse{}, fmt.Errorf("unimplemented") } diff --git a/pkg/grpc/client.go b/pkg/grpc/client.go index 40be6d6e25e0..d7277ee6b05f 100644 --- a/pkg/grpc/client.go +++ b/pkg/grpc/client.go @@ -562,6 +562,24 @@ func (c *Client) VAD(ctx context.Context, in *pb.VADRequest, opts ...grpc.CallOp return client.VAD(ctx, in, opts...) 
} +func (c *Client) Diarize(ctx context.Context, in *pb.DiarizeRequest, opts ...grpc.CallOption) (*pb.DiarizeResponse, error) { + if !c.parallel { + c.opMutex.Lock() + defer c.opMutex.Unlock() + } + c.setBusy(true) + defer c.setBusy(false) + c.wdMark() + defer c.wdUnMark() + conn, err := c.dial() + if err != nil { + return nil, err + } + defer func() { _ = conn.Close() }() + client := pb.NewBackendClient(conn) + return client.Diarize(ctx, in, opts...) +} + func (c *Client) Detect(ctx context.Context, in *pb.DetectOptions, opts ...grpc.CallOption) (*pb.DetectResponse, error) { if !c.parallel { c.opMutex.Lock() diff --git a/pkg/grpc/embed.go b/pkg/grpc/embed.go index 5b28a00f9c84..c9fd307bf291 100644 --- a/pkg/grpc/embed.go +++ b/pkg/grpc/embed.go @@ -136,6 +136,10 @@ func (e *embedBackend) VAD(ctx context.Context, in *pb.VADRequest, opts ...grpc. return e.s.VAD(ctx, in) } +func (e *embedBackend) Diarize(ctx context.Context, in *pb.DiarizeRequest, opts ...grpc.CallOption) (*pb.DiarizeResponse, error) { + return e.s.Diarize(ctx, in) +} + func (e *embedBackend) AudioEncode(ctx context.Context, in *pb.AudioEncodeRequest, opts ...grpc.CallOption) (*pb.AudioEncodeResult, error) { return e.s.AudioEncode(ctx, in) } diff --git a/pkg/grpc/interface.go b/pkg/grpc/interface.go index 1d79c3ea55bb..8fb800aec62e 100644 --- a/pkg/grpc/interface.go +++ b/pkg/grpc/interface.go @@ -36,6 +36,7 @@ type AIModel interface { StoresFind(*pb.StoresFindOptions) (pb.StoresFindResult, error) VAD(*pb.VADRequest) (pb.VADResponse, error) + Diarize(*pb.DiarizeRequest) (pb.DiarizeResponse, error) AudioEncode(*pb.AudioEncodeRequest) (*pb.AudioEncodeResult, error) AudioDecode(*pb.AudioDecodeRequest) (*pb.AudioDecodeResult, error) diff --git a/pkg/grpc/server.go b/pkg/grpc/server.go index 6a30611b1462..a931e65564a1 100644 --- a/pkg/grpc/server.go +++ b/pkg/grpc/server.go @@ -377,6 +377,18 @@ func (s *server) VAD(ctx context.Context, in *pb.VADRequest) (*pb.VADResponse, e return &res, nil } +func (s 
*server) Diarize(ctx context.Context, in *pb.DiarizeRequest) (*pb.DiarizeResponse, error) { + if s.llm.Locking() { + s.llm.Lock() + defer s.llm.Unlock() + } + res, err := s.llm.Diarize(in) + if err != nil { + return nil, err + } + return &res, nil +} + func (s *server) AudioEncode(ctx context.Context, in *pb.AudioEncodeRequest) (*pb.AudioEncodeResult, error) { if s.llm.Locking() { s.llm.Lock() diff --git a/swagger/docs.go b/swagger/docs.go index 014ed6f23a4f..bc4604983b1b 100644 --- a/swagger/docs.go +++ b/swagger/docs.go @@ -1671,6 +1671,95 @@ const docTemplate = `{ } } }, + "/v1/audio/diarization": { + "post": { + "consumes": [ + "multipart/form-data" + ], + "tags": [ + "audio" + ], + "summary": "Identify speakers in audio (who spoke when).", + "parameters": [ + { + "type": "string", + "description": "model", + "name": "model", + "in": "formData", + "required": true + }, + { + "type": "file", + "description": "audio file", + "name": "file", + "in": "formData", + "required": true + }, + { + "type": "integer", + "description": "exact speaker count (\u003e0 forces; 0 = auto)", + "name": "num_speakers", + "in": "formData" + }, + { + "type": "integer", + "description": "lower bound when auto-detecting", + "name": "min_speakers", + "in": "formData" + }, + { + "type": "integer", + "description": "upper bound when auto-detecting", + "name": "max_speakers", + "in": "formData" + }, + { + "type": "number", + "description": "clustering distance threshold when num_speakers is unknown", + "name": "clustering_threshold", + "in": "formData" + }, + { + "type": "number", + "description": "discard segments shorter than this (seconds)", + "name": "min_duration_on", + "in": "formData" + }, + { + "type": "number", + "description": "merge gaps shorter than this (seconds)", + "name": "min_duration_off", + "in": "formData" + }, + { + "type": "string", + "description": "audio language hint (only meaningful for backends that bundle ASR)", + "name": "language", + "in": "formData" + }, + { + 
"type": "boolean", + "description": "include per-segment transcript when the backend supports it", + "name": "include_text", + "in": "formData" + }, + { + "type": "string", + "description": "json (default), verbose_json, or rttm", + "name": "response_format", + "in": "formData" + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/schema.DiarizationResult" + } + } + } + } + }, "/v1/audio/speech": { "post": { "consumes": [ @@ -3588,6 +3677,75 @@ const docTemplate = `{ } } }, + "schema.DiarizationResult": { + "type": "object", + "properties": { + "duration": { + "type": "number" + }, + "language": { + "type": "string" + }, + "num_speakers": { + "type": "integer" + }, + "segments": { + "type": "array", + "items": { + "$ref": "#/definitions/schema.DiarizationSegment" + } + }, + "speakers": { + "type": "array", + "items": { + "$ref": "#/definitions/schema.DiarizationSpeaker" + } + }, + "task": { + "type": "string" + } + } + }, + "schema.DiarizationSegment": { + "type": "object", + "properties": { + "end": { + "type": "number" + }, + "id": { + "type": "integer" + }, + "label": { + "type": "string" + }, + "speaker": { + "type": "string" + }, + "start": { + "type": "number" + }, + "text": { + "type": "string" + } + } + }, + "schema.DiarizationSpeaker": { + "type": "object", + "properties": { + "id": { + "type": "string" + }, + "label": { + "type": "string" + }, + "segment_count": { + "type": "integer" + }, + "total_speech_duration": { + "type": "number" + } + } + }, "schema.ElevenLabsSoundGenerationRequest": { "type": "object", "properties": { diff --git a/swagger/swagger.json b/swagger/swagger.json index 2f20b061a8da..666970636281 100644 --- a/swagger/swagger.json +++ b/swagger/swagger.json @@ -1668,6 +1668,95 @@ } } }, + "/v1/audio/diarization": { + "post": { + "consumes": [ + "multipart/form-data" + ], + "tags": [ + "audio" + ], + "summary": "Identify speakers in audio (who spoke when).", + "parameters": [ + { + "type": 
"string", + "description": "model", + "name": "model", + "in": "formData", + "required": true + }, + { + "type": "file", + "description": "audio file", + "name": "file", + "in": "formData", + "required": true + }, + { + "type": "integer", + "description": "exact speaker count (\u003e0 forces; 0 = auto)", + "name": "num_speakers", + "in": "formData" + }, + { + "type": "integer", + "description": "lower bound when auto-detecting", + "name": "min_speakers", + "in": "formData" + }, + { + "type": "integer", + "description": "upper bound when auto-detecting", + "name": "max_speakers", + "in": "formData" + }, + { + "type": "number", + "description": "clustering distance threshold when num_speakers is unknown", + "name": "clustering_threshold", + "in": "formData" + }, + { + "type": "number", + "description": "discard segments shorter than this (seconds)", + "name": "min_duration_on", + "in": "formData" + }, + { + "type": "number", + "description": "merge gaps shorter than this (seconds)", + "name": "min_duration_off", + "in": "formData" + }, + { + "type": "string", + "description": "audio language hint (only meaningful for backends that bundle ASR)", + "name": "language", + "in": "formData" + }, + { + "type": "boolean", + "description": "include per-segment transcript when the backend supports it", + "name": "include_text", + "in": "formData" + }, + { + "type": "string", + "description": "json (default), verbose_json, or rttm", + "name": "response_format", + "in": "formData" + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/schema.DiarizationResult" + } + } + } + } + }, "/v1/audio/speech": { "post": { "consumes": [ @@ -3585,6 +3674,75 @@ } } }, + "schema.DiarizationResult": { + "type": "object", + "properties": { + "duration": { + "type": "number" + }, + "language": { + "type": "string" + }, + "num_speakers": { + "type": "integer" + }, + "segments": { + "type": "array", + "items": { + "$ref": 
"#/definitions/schema.DiarizationSegment" + } + }, + "speakers": { + "type": "array", + "items": { + "$ref": "#/definitions/schema.DiarizationSpeaker" + } + }, + "task": { + "type": "string" + } + } + }, + "schema.DiarizationSegment": { + "type": "object", + "properties": { + "end": { + "type": "number" + }, + "id": { + "type": "integer" + }, + "label": { + "type": "string" + }, + "speaker": { + "type": "string" + }, + "start": { + "type": "number" + }, + "text": { + "type": "string" + } + } + }, + "schema.DiarizationSpeaker": { + "type": "object", + "properties": { + "id": { + "type": "string" + }, + "label": { + "type": "string" + }, + "segment_count": { + "type": "integer" + }, + "total_speech_duration": { + "type": "number" + } + } + }, "schema.ElevenLabsSoundGenerationRequest": { "type": "object", "properties": { diff --git a/swagger/swagger.yaml b/swagger/swagger.yaml index 94a0652b974d..285b2e78b07b 100644 --- a/swagger/swagger.yaml +++ b/swagger/swagger.yaml @@ -595,6 +595,51 @@ definitions: $ref: '#/definitions/schema.Detection' type: array type: object + schema.DiarizationResult: + properties: + duration: + type: number + language: + type: string + num_speakers: + type: integer + segments: + items: + $ref: '#/definitions/schema.DiarizationSegment' + type: array + speakers: + items: + $ref: '#/definitions/schema.DiarizationSpeaker' + type: array + task: + type: string + type: object + schema.DiarizationSegment: + properties: + end: + type: number + id: + type: integer + label: + type: string + speaker: + type: string + start: + type: number + text: + type: string + type: object + schema.DiarizationSpeaker: + properties: + id: + type: string + label: + type: string + segment_count: + type: integer + total_speech_duration: + type: number + type: object schema.ElevenLabsSoundGenerationRequest: properties: bpm: @@ -3224,6 +3269,66 @@ paths: summary: Generates audio from the input text. 
tags: - audio + /v1/audio/diarization: + post: + consumes: + - multipart/form-data + parameters: + - description: model + in: formData + name: model + required: true + type: string + - description: audio file + in: formData + name: file + required: true + type: file + - description: exact speaker count (>0 forces; 0 = auto) + in: formData + name: num_speakers + type: integer + - description: lower bound when auto-detecting + in: formData + name: min_speakers + type: integer + - description: upper bound when auto-detecting + in: formData + name: max_speakers + type: integer + - description: clustering distance threshold when num_speakers is unknown + in: formData + name: clustering_threshold + type: number + - description: discard segments shorter than this (seconds) + in: formData + name: min_duration_on + type: number + - description: merge gaps shorter than this (seconds) + in: formData + name: min_duration_off + type: number + - description: audio language hint (only meaningful for backends that bundle + ASR) + in: formData + name: language + type: string + - description: include per-segment transcript when the backend supports it + in: formData + name: include_text + type: boolean + - description: json (default), verbose_json, or rttm + in: formData + name: response_format + type: string + responses: + "200": + description: OK + schema: + $ref: '#/definitions/schema.DiarizationResult' + summary: Identify speakers in audio (who spoke when). 
+ tags: + - audio /v1/audio/speech: post: consumes: diff --git a/tests/e2e/e2e_suite_test.go b/tests/e2e/e2e_suite_test.go index f6cef3fdfb36..e140b7f1bbc5 100644 --- a/tests/e2e/e2e_suite_test.go +++ b/tests/e2e/e2e_suite_test.go @@ -169,6 +169,21 @@ var _ = BeforeSuite(func() { Expect(os.WriteFile(filepath.Join(modelsPath, name+".yaml"), data, 0644)).To(Succeed()) } + // Diarization model — known_usecases bypasses the FLAG_DIARIZATION + // backend-name guard so the /v1/audio/diarization route can dispatch + // to the mock backend. + diarizeCfg := map[string]any{ + "name": "mock-diarize", + "backend": "mock-backend", + "known_usecases": []string{"FLAG_DIARIZATION"}, + "parameters": map[string]any{ + "model": "mock-diarize.bin", + }, + } + diarizeData, err := yaml.Marshal(diarizeCfg) + Expect(err).ToNot(HaveOccurred()) + Expect(os.WriteFile(filepath.Join(modelsPath, "mock-diarize.yaml"), diarizeData, 0644)).To(Succeed()) + // Pipeline model that wires the component models together. pipelineCfg := map[string]any{ "name": "realtime-pipeline", diff --git a/tests/e2e/mock-backend/main.go b/tests/e2e/mock-backend/main.go index c9e096f42a33..f3523d628a8e 100644 --- a/tests/e2e/mock-backend/main.go +++ b/tests/e2e/mock-backend/main.go @@ -631,6 +631,36 @@ func (m *MockBackend) VAD(ctx context.Context, in *pb.VADRequest) (*pb.VADRespon }, nil } +// Diarize returns a deterministic two-speaker layout that exercises the +// HTTP layer's normalisation: raw labels "5" and "2" should become +// SPEAKER_00 and SPEAKER_01 in first-seen order, the SPEAKER_00 totals +// should reflect two segments (1.0s + 1.5s = 2.5s), and IncludeText must +// gate the per-segment Text field. 
+func (m *MockBackend) Diarize(ctx context.Context, in *pb.DiarizeRequest) (*pb.DiarizeResponse, error) { + xlog.Debug("Diarize called", + "dst", in.Dst, + "num_speakers", in.NumSpeakers, + "include_text", in.IncludeText) + + seg := func(start, end float32, speaker, text string) *pb.DiarizeSegment { + out := &pb.DiarizeSegment{Start: start, End: end, Speaker: speaker} + if in.IncludeText { + out.Text = text + } + return out + } + return &pb.DiarizeResponse{ + Segments: []*pb.DiarizeSegment{ + seg(0.0, 1.0, "5", "hello there"), + seg(1.0, 2.0, "2", "general kenobi"), + seg(2.0, 3.5, "5", "you are a bold one"), + }, + NumSpeakers: 2, + Duration: 3.5, + Language: in.Language, + }, nil +} + func (m *MockBackend) AudioEncode(ctx context.Context, in *pb.AudioEncodeRequest) (*pb.AudioEncodeResult, error) { xlog.Debug("AudioEncode called", "pcm_len", len(in.PcmData), "sample_rate", in.SampleRate) // Return a single mock Opus frame per 960-sample chunk (20ms at 48kHz). diff --git a/tests/e2e/mock_backend_test.go b/tests/e2e/mock_backend_test.go index 9e7cecf88c36..24f4cbd94cf5 100644 --- a/tests/e2e/mock_backend_test.go +++ b/tests/e2e/mock_backend_test.go @@ -1,9 +1,11 @@ package e2e_test import ( + "bytes" "context" "encoding/json" "io" + "mime/multipart" "net/http" "strings" "time" @@ -225,6 +227,124 @@ var _ = Describe("Mock Backend E2E Tests", Label("MockBackend"), func() { }) }) + Describe("Audio Diarization API", func() { + // Helper: build a multipart/form-data request to /v1/audio/diarization + // with a tiny stub WAV. The backend ignores the audio payload + // (it returns a deterministic three-segment layout), so a 4-byte + // stub is enough to exercise the HTTP layer. 
+ postDiarize := func(extraFields map[string]string) (*http.Response, []byte) { + body := &bytes.Buffer{} + mw := multipart.NewWriter(body) + + Expect(mw.WriteField("model", "mock-diarize")).To(Succeed()) + for k, v := range extraFields { + Expect(mw.WriteField(k, v)).To(Succeed()) + } + + part, err := mw.CreateFormFile("file", "stub.wav") + Expect(err).ToNot(HaveOccurred()) + _, err = part.Write([]byte{0, 0, 0, 0}) + Expect(err).ToNot(HaveOccurred()) + Expect(mw.Close()).To(Succeed()) + + req, err := http.NewRequest("POST", apiURL+"/audio/diarization", body) + Expect(err).ToNot(HaveOccurred()) + req.Header.Set("Content-Type", mw.FormDataContentType()) + + httpClient := &http.Client{Timeout: 30 * time.Second} + resp, err := httpClient.Do(req) + Expect(err).ToNot(HaveOccurred()) + defer func() { _ = resp.Body.Close() }() + data, err := io.ReadAll(resp.Body) + Expect(err).ToNot(HaveOccurred()) + return resp, data + } + + It("normalizes raw backend speaker labels to SPEAKER_NN in first-seen order", func() { + resp, data := postDiarize(nil) + Expect(resp.StatusCode).To(Equal(http.StatusOK)) + + var got map[string]any + Expect(json.Unmarshal(data, &got)).To(Succeed()) + + Expect(got["task"]).To(Equal("diarize")) + Expect(got["num_speakers"]).To(BeEquivalentTo(2)) + // json (default) drops the heavy speakers summary + Expect(got).ToNot(HaveKey("speakers")) + + segs, ok := got["segments"].([]any) + Expect(ok).To(BeTrue()) + Expect(segs).To(HaveLen(3)) + + // Mock emits raw labels "5", "2", "5" — first-seen order maps: + // 5 → SPEAKER_00, 2 → SPEAKER_01. 
+ seg0 := segs[0].(map[string]any)
+ seg1 := segs[1].(map[string]any)
+ seg2 := segs[2].(map[string]any)
+ Expect(seg0["speaker"]).To(Equal("SPEAKER_00"))
+ Expect(seg0["label"]).To(Equal("5"))
+ Expect(seg1["speaker"]).To(Equal("SPEAKER_01"))
+ Expect(seg2["speaker"]).To(Equal("SPEAKER_00"))
+
+ // json default suppresses per-segment text even when the backend
+ // happened to emit some (here, IncludeText was not set so the
+ // backend already stripped — but the HTTP layer also gates).
+ _, hasText := seg0["text"].(string)
+ if hasText {
+ Expect(seg0["text"]).To(Equal(""))
+ }
+ })
+
+ It("verbose_json emits speakers summary and per-segment transcripts when include_text is set", func() {
+ resp, data := postDiarize(map[string]string{
+ "response_format": "verbose_json",
+ "include_text": "true",
+ })
+ Expect(resp.StatusCode).To(Equal(http.StatusOK))
+
+ var got map[string]any
+ Expect(json.Unmarshal(data, &got)).To(Succeed())
+
+ speakers, ok := got["speakers"].([]any)
+ Expect(ok).To(BeTrue(), "verbose_json must include speakers summary")
+ Expect(speakers).To(HaveLen(2))
+
+ // SPEAKER_00 should reflect both of its segments (1.0s + 1.5s = 2.5s, 2 segments)
+ byID := map[string]map[string]any{}
+ for _, sp := range speakers {
+ m := sp.(map[string]any)
+ byID[m["id"].(string)] = m
+ }
+ Expect(byID).To(HaveKey("SPEAKER_00"))
+ Expect(byID["SPEAKER_00"]["total_speech_duration"]).To(BeNumerically("~", 2.5, 0.001))
+ Expect(byID["SPEAKER_00"]["segment_count"]).To(BeEquivalentTo(2))
+
+ segs := got["segments"].([]any)
+ Expect(segs[0].(map[string]any)["text"]).To(Equal("hello there"))
+ Expect(segs[1].(map[string]any)["text"]).To(Equal("general kenobi"))
+ })
+
+ It("rttm response_format returns NIST RTTM rows", func() {
+ resp, data := postDiarize(map[string]string{"response_format": "rttm"})
+ Expect(resp.StatusCode).To(Equal(http.StatusOK))
+ Expect(resp.Header.Get("Content-Type")).To(HavePrefix("text/plain"))
+
+ body := string(data)
+ lines := 
strings.Split(strings.TrimSpace(body), "\n") + Expect(lines).To(HaveLen(3)) + // "SPEAKER stub 1 0.000 1.000 SPEAKER_00 " + Expect(lines[0]).To(HavePrefix("SPEAKER stub 1 ")) + Expect(lines[0]).To(ContainSubstring(" SPEAKER_00 ")) + Expect(lines[1]).To(ContainSubstring(" SPEAKER_01 ")) + Expect(lines[2]).To(ContainSubstring(" SPEAKER_00 ")) + }) + + It("rejects unknown response_format with 4xx/5xx", func() { + resp, _ := postDiarize(map[string]string{"response_format": "csv"}) + Expect(resp.StatusCode).To(BeNumerically(">=", 400)) + }) + }) + Describe("Rerank API", func() { It("should return mocked reranking results", func() { req, err := http.NewRequest("POST", apiURL+"/rerank", nil)