From 8e4cb72b50785d4feeca431fc2bb2ec523126e64 Mon Sep 17 00:00:00 2001 From: xzeldon Date: Wed, 4 Oct 2023 01:09:38 +0300 Subject: [PATCH] initial commit --- .gitignore | 7 + LICENSE | 21 +++ README.md | 43 +++++ go.mod | 19 +++ go.sum | 43 +++++ handler.go | 45 +++++ main.go | 29 ++++ pkg/whisper/FullParams.go | 248 ++++++++++++++++++++++++++++ pkg/whisper/MediaFoundation.go | 222 +++++++++++++++++++++++++ pkg/whisper/Model.go | 119 ++++++++++++++ pkg/whisper/ModelSetup.go | 101 ++++++++++++ pkg/whisper/TranscribeResult.go | 178 ++++++++++++++++++++ pkg/whisper/context.go | 281 ++++++++++++++++++++++++++++++++ pkg/whisper/doc.go | 5 + pkg/whisper/language.go | 207 +++++++++++++++++++++++ pkg/whisper/logger.go | 43 +++++ pkg/whisper/whisper.go | 188 +++++++++++++++++++++ pkg/whisper/winversion.go | 123 ++++++++++++++ state.go | 75 +++++++++ tmp/.gitkeep | 0 utils.go | 48 ++++++ 21 files changed, 2045 insertions(+) create mode 100644 .gitignore create mode 100644 LICENSE create mode 100644 README.md create mode 100644 go.mod create mode 100644 go.sum create mode 100644 handler.go create mode 100644 main.go create mode 100644 pkg/whisper/FullParams.go create mode 100644 pkg/whisper/MediaFoundation.go create mode 100644 pkg/whisper/Model.go create mode 100644 pkg/whisper/ModelSetup.go create mode 100644 pkg/whisper/TranscribeResult.go create mode 100644 pkg/whisper/context.go create mode 100644 pkg/whisper/doc.go create mode 100644 pkg/whisper/language.go create mode 100644 pkg/whisper/logger.go create mode 100644 pkg/whisper/whisper.go create mode 100644 pkg/whisper/winversion.go create mode 100644 state.go create mode 100644 tmp/.gitkeep create mode 100644 utils.go diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..5b27170 --- /dev/null +++ b/.gitignore @@ -0,0 +1,7 @@ +Whisper.dll +ggml-medium.bin + +whisper-api-server.exe + +tmp/* +!tmp/.gitkeep \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..675d1d2 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Timofey Gelazoniya + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..619a737 --- /dev/null +++ b/README.md @@ -0,0 +1,43 @@ +# Whisper API Server (Go) + +## ⚠️ This project is a work in progress (WIP). + +This API server enables audio transcription using the OpenAI Whisper models. + +# Setup + +- Download the desired model from [huggingface](https://huggingface.co/ggerganov/whisper.cpp/tree/main) +- Update the model path in the `main.go` file +- Download `Whisper.dll` from [github](https://github.com/Const-me/Whisper/releases/tag/1.12.0) (`Library.zip`) and place it in the project's root directory +- Build project: `go build .` (you only need go compiler, without gcc) + +# Usage example + +Make a request to the server using the following command: + +```sh +curl http://localhost:3000/v1/audio/transcriptions \ + -H "Content-Type: multipart/form-data" \ + -F file="@/path/to/file/audio.mp3" \ +``` + +Receive a response in JSON format: + +```json +{ + "text": "Imagine the wildest idea that you've ever had, and you're curious about how it might scale to something that's a 100, a 1,000 times bigger. This is a place where you can get to do that." +} +``` + +# Roadmap + +- [ ] Implement automatic model downloading from [huggingface](https://huggingface.co/ggerganov/whisper.cpp/tree/main) +- [ ] Implement automatic `Whisper.dll` downloading from [Guthub releases](https://github.com/Const-me/Whisper/releases) +- [ ] Provide prebuilt binaries for Windows +- [ ] Include instructions for running on Linux with Wine (likely possible). + +# Credits + +- [Const-me/Whisper](https://github.com/Const-me/Whisper) project +- [goConstmeWhisper](https://github.com/jaybinks/goConstmeWhisper) for the remarkable Go bindings for [Const-me/Whisper](https://github.com/Const-me/Whisper) +- [Georgi Gerganov](https://github.com/ggerganov) for GGML models diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..b4f3c0d --- /dev/null +++ b/go.mod @@ -0,0 +1,19 @@ +module github.com/xzeldon/whisper-api-server + +go 1.21.1 + +require ( + github.com/labstack/echo/v4 v4.11.1 + golang.org/x/sys v0.12.0 +) + +require ( + github.com/labstack/gommon v0.4.0 + github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-isatty v0.0.19 // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect + github.com/valyala/fasttemplate v1.2.2 // indirect + golang.org/x/crypto v0.11.0 // indirect + golang.org/x/net v0.12.0 // indirect + golang.org/x/text v0.11.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..119ff10 --- /dev/null +++ b/go.sum @@ -0,0 +1,43 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/labstack/echo/v4 v4.11.1 h1:dEpLU2FLg4UVmvCGPuk/APjlH6GDpbEPti61srUUUs4= +github.com/labstack/echo/v4 v4.11.1/go.mod h1:YuYRTSM3CHs2ybfrL8Px48bO6BAnYIN4l8wSTMP6BDQ= +github.com/labstack/gommon v0.4.0 h1:y7cvthEAEbU0yHOf4axH8ZG2NH8knB9iNSoTO8dyIk8= +github.com/labstack/gommon v0.4.0/go.mod h1:uW6kP17uPlLJsD3ijUYn3/M5bAxtlZhMI6m3MFxTMTM= +github.com/mattn/go-colorable v0.1.11/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasttemplate v1.2.1/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= +github.com/valyala/fasttemplate v1.2.2 h1:lxLXG0uE3Qnshl9QyaK6XJxMXlQZELvChBOCmQD0Loo= +github.com/valyala/fasttemplate v1.2.2/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= +golang.org/x/crypto v0.11.0 h1:6Ewdq3tDic1mg5xRO4milcWCfMVQhI4NkqWWvqejpuA= +golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio= +golang.org/x/net v0.12.0 h1:cfawfvKITfUsFCeJIHJrbSxpeu/E81khclypR0GVT50= +golang.org/x/net v0.12.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA= +golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211103235746-7861aae1554b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/text v0.11.0 h1:LAntKIrcmeSKERyiOh0XMV39LXS8IE9UL2yP7+f5ij4= +golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/handler.go b/handler.go new file mode 100644 index 0000000..6b4ca39 --- /dev/null +++ b/handler.go @@ -0,0 +1,45 @@ +package main + +import ( + "net/http" + "strings" + + "github.com/labstack/echo/v4" +) + +type TranscribeResponse struct { + Text string `json:"text"` +} + +func transcribe(c echo.Context, whisperState *WhisperState) error { + audioPath, err := saveFormFile("file", c) + if err != nil { + c.Logger().Errorf("Error reading file: %s", err) + return err + } + + whisperState.mutex.Lock() + buffer, err := whisperState.media.LoadAudioFile(audioPath, true) + if err != nil { + c.Logger().Errorf("Error loading audio file data: %s", err) + } + + err = whisperState.context.RunFull(whisperState.params, buffer) + + result, err := getResult(whisperState.context) + if err != nil { + c.Logger().Error(err) + } + + defer whisperState.mutex.Unlock() + + if len(result) == 0 { + return c.JSON(http.StatusInternalServerError, map[string]string{"error": "Internal server error"}) + } + + response := TranscribeResponse{ + Text: strings.TrimLeft(result, " "), + } + + return c.JSON(http.StatusOK, response) +} diff --git a/main.go b/main.go new file mode 100644 index 0000000..cf1328a --- /dev/null +++ b/main.go @@ -0,0 +1,29 @@ +package main + +import ( + "github.com/labstack/echo/v4" + "github.com/labstack/gommon/log" +) + +const MODEL_PATH = "./ggml-medium.bin" + +func main() { + e := echo.New() + e.HideBanner = true + + if l, ok := e.Logger.(*log.Logger); ok { + l.SetHeader("${time_rfc3339} ${level}") + } + + whisperState, err := InitializeWhisperState(MODEL_PATH) + if err != nil { + e.Logger.Error(err) + return + } + + e.POST("/v1/audio/transcriptions", func(c echo.Context) error { + return transcribe(c, whisperState) + }) + + e.Logger.Fatal(e.Start(":3000")) +} diff --git a/pkg/whisper/FullParams.go b/pkg/whisper/FullParams.go new file mode 100644 index 0000000..cb49136 --- /dev/null +++ b/pkg/whisper/FullParams.go @@ -0,0 +1,248 @@ +package whisper + +import ( + "syscall" + "unsafe" +) + +// https://github.com/Const-me/Whisper/blob/master/Whisper/API/sFullParams.h +// https://github.com/Const-me/Whisper/blob/master/WhisperNet/API/Parameters.cs + +type eSamplingStrategy uint32 + +const ( + SsGreedy eSamplingStrategy = iota + SsBeamSearch + SsINVALIDARG +) + +type eFullParamsFlags uint32 + +const ( + FlagNone eFullParamsFlags = 0 + FlagTranslate = 1 << 0 + FlagNoContext = 1 << 1 + FlagSingleSegment = 1 << 2 + FlagPrintSpecial = 1 << 3 + FlagPrintProgress = 1 << 4 + FlagPrintRealtime = 1 << 5 + FlagPrintTimestamps = 1 << 6 + FlagTokenTimestamps = 1 << 7 // Experimental + FlagSpeedupAudio = 1 << 8 +) + +type EWhisperHWND uintptr + +const ( + S_OK EWhisperHWND = 0 + S_FALSE EWhisperHWND = 1 +) + +type FullParams struct { + cStruct *_FullParams +} + +func (this *FullParams) CpuThreads() int32 { + if this == nil { + return 0 + } else if this.cStruct == nil { + return 0 + } + + return this.cStruct.cpuThreads +} + +func (this *FullParams) setCpuThreads(val int32) { + if this == nil { + return + } else if this.cStruct == nil { + return + } + + this.cStruct.cpuThreads = val +} + +func (this *FullParams) SetMaxTextCTX(val int32) { + this.cStruct.n_max_text_ctx = val +} + +func (this *FullParams) AddFlags(newflag eFullParamsFlags) { + if this == nil { + return + } else if this.cStruct == nil { + return + } + + this.cStruct.Flags = this.cStruct.Flags | newflag +} + +func (this *FullParams) RemoveFlags(newflag eFullParamsFlags) { + if this == nil { + return + } else if this.cStruct == nil { + return + } + + this.cStruct.Flags = this.cStruct.Flags ^ newflag +} + +/*using pfnNewSegment = HRESULT( __cdecl* )( iContext* ctx, uint32_t n_new, void* user_data ) noexcept;*/ +type NewSegmentCallback_Type func(context *IContext, n_new uint32, user_data unsafe.Pointer) EWhisperHWND + +func (this *FullParams) SetNewSegmentCallback(cb NewSegmentCallback_Type) { + if this == nil { + return + } else if this.cStruct == nil { + return + } + this.cStruct.new_segment_callback = syscall.NewCallback(cb) +} + +/* +Return S_OK to proceed, or S_FALSE to stop the process +*/ +type EncoderBeginCallback_Type func(context *IContext, user_data unsafe.Pointer) EWhisperHWND + +func (this *FullParams) SetEncoderBeginCallback(cb EncoderBeginCallback_Type) { + if this == nil { + return + } else if this.cStruct == nil { + return + } + + this.cStruct.encoder_begin_callback = syscall.NewCallback(cb) +} + +func (this *FullParams) TestDefaultsOK() bool { + if this == nil { + return false + } else if this.cStruct == nil { + return false + } + + if this.cStruct.n_max_text_ctx != 16384 { + return false + } + + if this.cStruct.Flags != (FlagPrintProgress | FlagPrintTimestamps) { + return false + } + + if this.cStruct.thold_pt != 0.01 { + return false + } + + if this.cStruct.thold_ptsum != 0.01 { + return false + } + + if this.cStruct.Language != English { + return false + } + + // Todo ... why do these not line up as expected.. is our struct out of alignment ? + /* + if this.cStruct.strategy == ssGreedy { + if this.cStruct.beam_search.n_past != -1 || + this.cStruct.beam_search.beam_width != -1 || + this.cStruct.beam_search.n_best != -1 { + return false + } + + } else if this.cStruct.strategy == ssBeamSearch { + if this.cStruct.greedy.n_past != -1 || + this.cStruct.beam_search.beam_width != 10 || + this.cStruct.beam_search.n_best != 5 { + return false + } + } + */ + + return true +} + +type _FullParams struct { + strategy eSamplingStrategy + cpuThreads int32 + n_max_text_ctx int32 + offset_ms int32 + duration_ms int32 + Flags eFullParamsFlags + Language eLanguage + + thold_pt float32 + thold_ptsum float32 + max_len int32 + max_tokens int32 + + greedy struct{ n_past int32 } + beam_search struct { + n_past int32 + beam_width int32 + n_best int32 + } + + audio_ctx int32 // overwrite the audio context size (0 = use default) + + prompt_tokens uintptr + prompt_n_tokens int32 + + new_segment_callback uintptr + new_segment_callback_user_data uintptr + + encoder_begin_callback uintptr + encoder_begin_callback_user_data uintptr + + // Are these needed ?? Jay + // setFlag uintptr +} + +func NewFullParams(cstruct *_FullParams) *FullParams { + this := FullParams{} + this.cStruct = cstruct + return &this +} + +func _newFullParams_cStruct() *_FullParams { + return &_FullParams{ + + strategy: 0, + cpuThreads: 0, + n_max_text_ctx: 0, + offset_ms: 0, + duration_ms: 0, + + Flags: 0, + Language: 0, + + thold_pt: 0, + thold_ptsum: 0, + max_len: 0, + max_tokens: 0, + + // anonymous int32 + greedy: struct{ n_past int32 }{n_past: 0}, + + // anonymous struct + beam_search: struct { + n_past int32 + beam_width int32 + n_best int32 + }{ + n_past: 0, + beam_width: 0, + n_best: 0, + }, + + audio_ctx: 0, + + prompt_tokens: 0, + prompt_n_tokens: 0, + + new_segment_callback: 0, + new_segment_callback_user_data: 0, + + encoder_begin_callback: 0, + encoder_begin_callback_user_data: 0, + } +} diff --git a/pkg/whisper/MediaFoundation.go b/pkg/whisper/MediaFoundation.go new file mode 100644 index 0000000..b5d665e --- /dev/null +++ b/pkg/whisper/MediaFoundation.go @@ -0,0 +1,222 @@ +package whisper + +import ( + "errors" + "fmt" + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +// https://github.com/Const-me/Whisper/blob/843a2a6ca6ea47c5ac4889a281badfc808d0ea01/Whisper/API/IMediaFoundation.h + +type IMediaFoundation struct { + lpVtbl *IMediaFoundationVtbl +} + +type IMediaFoundationVtbl struct { + QueryInterface uintptr + AddRef uintptr + Release uintptr + loadAudioFile uintptr // ( LPCTSTR path, bool stereo, iAudioBuffer** pp ) const; + openAudioFile uintptr // ( LPCTSTR path, bool stereo, iAudioReader** pp ); + loadAudioFileData uintptr // ( const void* data, uint64_t size, bool stereo, iAudioReader** pp ); HRESULT + listCaptureDevices uintptr // ( pfnFoundCaptureDevices pfn, void* pv ); + openCaptureDevice uintptr // ( LPCTSTR endpoint, const sCaptureParams& captureParams, iAudioCapture** pp ); +} + +func (this *IMediaFoundation) AddRef() int32 { + ret, _, _ := syscall.Syscall( + this.lpVtbl.AddRef, + 1, + uintptr(unsafe.Pointer(this)), + 0, + 0) + return int32(ret) +} + +func (this *IMediaFoundation) Release() int32 { + ret, _, _ := syscall.Syscall( + this.lpVtbl.Release, + 1, + uintptr(unsafe.Pointer(this)), + 0, + 0) + return int32(ret) +} + +// ( LPCTSTR path, bool stereo, iAudioBuffer** pp ) const; +func (this *IMediaFoundation) LoadAudioFile(file string, stereo bool) (*iAudioBuffer, error) { + + var buffer *iAudioBuffer + + UTFFileName, _ := windows.UTF16PtrFromString(file) + + ret, _, _ := syscall.SyscallN( + this.lpVtbl.loadAudioFile, + uintptr(unsafe.Pointer(this)), + uintptr(unsafe.Pointer(UTFFileName)), + uintptr(1), // Todo ... Stereo ! + uintptr(unsafe.Pointer(&buffer))) + + if windows.Handle(ret) != windows.S_OK { + fmt.Printf("loadAudioFile failed: %s\n", syscall.Errno(ret).Error()) + return nil, syscall.Errno(ret) + } + + return buffer, nil +} + +func (this *IMediaFoundation) OpenAudioFile(file string, stereo bool) (*iAudioReader, error) { + + var buffer *iAudioReader + + UTFFileName, _ := windows.UTF16PtrFromString(file) + + ret, _, _ := syscall.SyscallN( + this.lpVtbl.openAudioFile, + uintptr(unsafe.Pointer(this)), + uintptr(unsafe.Pointer(UTFFileName)), + uintptr(1), // Todo ... Stereo ! + uintptr(unsafe.Pointer(&buffer))) + + if windows.Handle(ret) != windows.S_OK { + fmt.Printf("openAudioFile failed: %s\n", syscall.Errno(ret).Error()) + return nil, syscall.Errno(ret) + } + + return buffer, nil +} + +func (this *IMediaFoundation) LoadAudioFileData(inbuffer *[]byte, stereo bool) (*iAudioReader, error) { + + var reader *iAudioReader + + // loadAudioFileData( const void* data, uint64_t size, bool stereo, iAudioReader** pp ); + ret, _, _ := syscall.SyscallN( + this.lpVtbl.loadAudioFileData, + uintptr(unsafe.Pointer(this)), + + uintptr(unsafe.Pointer(&(*inbuffer)[0])), + uintptr(uint64(len(*inbuffer))), + uintptr(1), // Todo ... Stereo ! + uintptr(unsafe.Pointer(&reader))) + + if windows.Handle(ret) != windows.S_OK { + fmt.Printf("LoadAudioFileData failed: %s\n", syscall.Errno(ret).Error()) + return nil, syscall.Errno(ret) + } + + return reader, nil +} + +// ************************************************************ + +type iAudioBuffer struct { + lpVtbl *iAudioBufferVtbl +} + +type iAudioBufferVtbl struct { + QueryInterface uintptr + AddRef uintptr + Release uintptr + countSamples uintptr // returns uint32_t + getPcmMono uintptr // returns float* + getPcmStereo uintptr // returns float* + getTime uintptr // ( int64_t& rdi ) +} + +func (this *iAudioBuffer) AddRef() int32 { + ret, _, _ := syscall.SyscallN( + this.lpVtbl.AddRef, + uintptr(unsafe.Pointer(this)), + ) + return int32(ret) +} + +func (this *iAudioBuffer) Release() int32 { + ret, _, _ := syscall.SyscallN( + this.lpVtbl.Release, + uintptr(unsafe.Pointer(this)), + ) + return int32(ret) +} + +func (this *iAudioBuffer) CountSamples() (uint32, error) { + + ret, _, err := syscall.SyscallN( + this.lpVtbl.countSamples, + uintptr(unsafe.Pointer(this)), + ) + + if err != 0 { + return 0, errors.New(err.Error()) + } + + return uint32(ret), nil +} + +// ************************************************************ + +type iAudioReader struct { + lpVtbl *iAudioReaderVtbl +} + +type iAudioReaderVtbl struct { + QueryInterface uintptr + AddRef uintptr + Release uintptr + + getDuration uintptr // ( int64_t& rdi ) + getReader uintptr // ( IMFSourceReader** pp ) + requestedStereo uintptr // () +} + +func (this *iAudioReader) AddRef() int32 { + ret, _, _ := syscall.SyscallN( + this.lpVtbl.AddRef, + uintptr(unsafe.Pointer(this)), + ) + return int32(ret) +} + +func (this *iAudioReader) Release() int32 { + ret, _, _ := syscall.SyscallN( + this.lpVtbl.Release, + uintptr(unsafe.Pointer(this)), + ) + return int32(ret) +} + +func (this *iAudioReader) GetDuration() (uint64, error) { + + var rdi int64 + + ret, _, _ := syscall.SyscallN( + this.lpVtbl.getDuration, + uintptr(unsafe.Pointer(this)), + uintptr(unsafe.Pointer(&rdi)), + ) + + if windows.Handle(ret) != windows.S_OK { + fmt.Printf("LoadAudioFileData failed: %s\n", syscall.Errno(ret).Error()) + return 0, syscall.Errno(ret) + } + + return uint64(rdi), nil +} + +// ************************************************************ + +type iAudioCapture struct { + lpVtbl *iAudioCaptureVtbl +} + +type iAudioCaptureVtbl struct { + QueryInterface uintptr + AddRef uintptr + Release uintptr + getReader uintptr // ( IMFSourceReader** pp ) + getParams uintptr // returns sCaptureParams& +} diff --git a/pkg/whisper/Model.go b/pkg/whisper/Model.go new file mode 100644 index 0000000..1635325 --- /dev/null +++ b/pkg/whisper/Model.go @@ -0,0 +1,119 @@ +package whisper + +import ( + "errors" + "fmt" + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +// External - Go version of the struct +type Model struct { + cStruct *_IModel + setup *sModelSetup +} + +// Internal - C Version of the structs +type _IModel struct { + lpVtbl *IModelVtbl +} + +// https://github.com/Const-me/Whisper/blob/master/Whisper/API/iContext.cl.h +type IModelVtbl struct { + QueryInterface uintptr + AddRef uintptr + Release uintptr + + createContext uintptr //( iContext** pp ) = 0; + tokenize uintptr /* HRESULT __stdcall tokenize( const char* text, pfnDecodedTokens pfn, void* pv ); */ + isMultilingual uintptr //() = 0; + getSpecialTokens uintptr //( SpecialTokens& rdi ) = 0; + stringFromToken uintptr //( whisper_token token ) = 0; + clone uintptr //( iModel** rdi ) = 0; +} + +func NewModel(setup *sModelSetup, cstruct *_IModel) *Model { + this := Model{} + this.setup = setup + this.cStruct = cstruct + return &this +} + +func (this *Model) AddRef() int32 { + ret, _, _ := syscall.Syscall( + this.cStruct.lpVtbl.AddRef, + 1, + uintptr(unsafe.Pointer(this.cStruct)), + 0, + 0) + return int32(ret) +} + +func (this *Model) Release() int32 { + ret, _, _ := syscall.Syscall( + this.cStruct.lpVtbl.Release, + 1, + uintptr(unsafe.Pointer(this.cStruct)), + 0, + 0) + return int32(ret) +} + +func (this *Model) CreateContext() (*IContext, error) { + var context *IContext + + /* + ret, _, err := syscall.Syscall( + this.cStruct.lpVtbl.createContext, + 2, // Why was this 1, rather than 2 ?? 1 seemed to work fine + uintptr(unsafe.Pointer(this.cStruct)), + uintptr(unsafe.Pointer(&context)), + 0)*/ + ret, _, err := syscall.SyscallN( + this.cStruct.lpVtbl.createContext, + uintptr(unsafe.Pointer(this.cStruct)), + uintptr(unsafe.Pointer(&context))) + + if windows.Handle(ret) != windows.S_OK { + fmt.Printf("createContext failed: %w", err.Error()) + } + + if windows.Handle(ret) != windows.S_OK { + return nil, fmt.Errorf("loadModel failed: %w", err) + } + + return context, nil +} + +func (this *Model) IsMultilingual() bool { + ret, _, _ := syscall.SyscallN( + this.cStruct.lpVtbl.isMultilingual, + uintptr(unsafe.Pointer(this.cStruct)), + ) + + return bool(windows.Handle(ret) == windows.S_OK) +} + +func (this *Model) Clone() (*_IModel, error) { + + if this.setup.isFlagSet(gmf_Cloneable) { + return nil, errors.New("Model is not cloneable") + } + //this.Cloneable ? + + var modelptr *_IModel + + ret, _, _ := syscall.SyscallN( + this.cStruct.lpVtbl.clone, + uintptr(unsafe.Pointer(this.cStruct)), + uintptr(unsafe.Pointer(&modelptr)), + ) + + if windows.Handle(ret) == windows.S_OK { + return modelptr, nil + } else { + return nil, errors.New("Model.Clone() failed : " + syscall.Errno(ret).Error()) + } +} diff --git a/pkg/whisper/ModelSetup.go b/pkg/whisper/ModelSetup.go new file mode 100644 index 0000000..08c82ee --- /dev/null +++ b/pkg/whisper/ModelSetup.go @@ -0,0 +1,101 @@ +package whisper + +import ( + "unsafe" + + "golang.org/x/sys/windows" +) + +// Re-implemented sModelSetup.h + +// enum struct eModelImplementation : uint32_t +type eModelImplementation uint32 + +const ( + // GPGPU implementation based on Direct3D 11.0 compute shaders + mi_GPU eModelImplementation = 1 + + // A hybrid implementation which uses DirectCompute for encode, and decodes on CPU + // Not implemented in the published builds of the DLL. To enable, change BUILD_HYBRID_VERSION macro to 1 + mi_Hybrid eModelImplementation = 2 + + // A reference implementation which uses the original GGML CPU-running code + // Not implemented in the published builds of the DLL. To enable, change BUILD_BOTH_VERSIONS macro to 1 + mi_Reference eModelImplementation = 3 +) + +// enum struct eGpuModelFlags : uint32_t +type eGpuModelFlags uint32 + +const ( + // Equivalent to Wave32 | NoReshapedMatMul on Intel and nVidia GPUs,
+ // and Wave64 | UseReshapedMatMul on AMD GPUs
+ gmf_None eGpuModelFlags = 0 + + // Use Wave32 version of compute shaders even on AMD GPUs + // Incompatible with + gmf_Wave32 eGpuModelFlags = 1 + + // Use Wave64 version of compute shaders even on nVidia and Intel GPUs + // Incompatible with + gmf_Wave64 eGpuModelFlags = 2 + + // Do not use reshaped matrix multiplication shaders on AMD GPUs + // Incompatible with + gmf_NoReshapedMatMul eGpuModelFlags = 4 + + // Use reshaped matrix multiplication shaders even on nVidia and Intel GPUs + // Incompatible with + gmf_UseReshapedMatMul eGpuModelFlags = 8 + + // Create GPU tensors in a way which allows sharing across D3D devices + gmf_Cloneable eGpuModelFlags = 0x10 +) + +// struct sModelSetup +type sModelSetup struct { + impl eModelImplementation + flags eGpuModelFlags + adapter string +} + +type _sModelSetup struct { + impl eModelImplementation + flags eGpuModelFlags + adapter uintptr +} + +func ModelSetup(flags eGpuModelFlags, GPU string) *sModelSetup { + this := sModelSetup{} + this.impl = mi_GPU + this.flags = flags + this.adapter = GPU + + return &this +} + +func (this *sModelSetup) isFlagSet(flag eGpuModelFlags) bool { + return (this.flags & flag) == 0 +} + +func (this *sModelSetup) AsCType() *_sModelSetup { + var err error + + ctype := _sModelSetup{} + ctype.impl = this.impl + ctype.flags = this.flags + ctype.adapter = 0 + + // Conver Go String to wchar_t, AKA UTF-16 + if this.adapter != "" { + var UTF16str *uint16 + UTF16str, err = windows.UTF16PtrFromString(this.adapter) + ctype.adapter = uintptr(unsafe.Pointer(UTF16str)) + } + + if err != nil { + return nil + } else { + return &ctype + } +} diff --git a/pkg/whisper/TranscribeResult.go b/pkg/whisper/TranscribeResult.go new file mode 100644 index 0000000..8ea0367 --- /dev/null +++ b/pkg/whisper/TranscribeResult.go @@ -0,0 +1,178 @@ +package whisper + +import ( + "C" + "errors" + "fmt" + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +type eTokenFlags uint32 + +const ( + TfNone eTokenFlags = 0 + TfSpecial = 1 +) + +type sTranscribeLength struct { + CountSegments uint32 + CountTokens uint32 +} + +type sTimeSpan struct { + + // The value is expressed in 100-nanoseconds ticks: compatible with System.Timespan, FILETIME, and many other things + Ticks uint64 + + /* + operator sTimeSpanFields() const + { + return sTimeSpanFields{ ticks }; + } + void operator=( uint64_t tt ) + { + ticks = tt; + } + void operator=( int64_t tt ) + { + assert( tt >= 0 ); + ticks = (uint64_t)tt; + } */ +} + +type sTimeInterval struct { + Begin sTimeSpan + End sTimeSpan +} + +type sSegment struct { + // Segment text, null-terminated, and probably UTF-8 encoded + text *C.char + + // Start and end times of the segment + Time sTimeInterval + + // These two integers define the slice of the tokens in this segment, in the array returned by iTranscribeResult.getTokens method + FirstToken uint32 + CountTokens uint32 +} + +func (this *sSegment) Text() string { + return C.GoString(this.text) +} + +type sSegmentArray []sSegment + +type SToken struct { + // Token text, null-terminated, and usually UTF-8 encoded. + // I think for Chinese language the models sometimes outputs invalid UTF8 strings here, Unicode code points can be split between adjacent tokens in the same segment + // More info: https://github.com/ggerganov/whisper.cpp/issues/399 + text *C.char + + // Start and end times of the token + Time sTimeInterval + // Probability of the token + Probability float32 + + // Probability of the timestamp token + ProbabilityTimestamp float32 + + // Sum of probabilities of all timestamp tokens + Ptsum float32 + + // Voice length of the token + Vlen float32 + + // Token id + Id int32 + + Flags eTokenFlags +} + +func (this *SToken) Text() string { + return C.GoString(this.text) +} + +type sTokenArray []SToken + +type iTranscribeResultVtbl struct { + QueryInterface uintptr + AddRef uintptr + Release uintptr + + getSize uintptr // ( sTranscribeLength& rdi ) HRESULT + getSegments uintptr // () getTokens + getTokens uintptr // () getToken* +} + +type ITranscribeResult struct { + lpVtbl *iTranscribeResultVtbl +} + +func (this *ITranscribeResult) AddRef() int32 { + ret, _, _ := syscall.Syscall( + this.lpVtbl.AddRef, + 1, + uintptr(unsafe.Pointer(this)), + 0, + 0) + return int32(ret) +} + +func (this *ITranscribeResult) Release() int32 { + ret, _, _ := syscall.Syscall( + this.lpVtbl.Release, + 1, + uintptr(unsafe.Pointer(this)), + 0, + 0) + return int32(ret) +} + +func (this *ITranscribeResult) GetSize() (*sTranscribeLength, error) { + + var result sTranscribeLength + + ret, _, _ := syscall.SyscallN( + this.lpVtbl.getSize, + uintptr(unsafe.Pointer(this)), + uintptr(unsafe.Pointer(&result)), + ) + + if windows.Handle(ret) != windows.S_OK { + fmt.Printf("iTranscribeResult.GetSize failed: %s\n", syscall.Errno(ret).Error()) + return nil, errors.New(syscall.Errno(ret).Error()) + } + + return &result, nil + +} + +func (this *ITranscribeResult) GetSegments(len uint32) []sSegment { + + ret, _, _ := syscall.SyscallN( + this.lpVtbl.getSegments, + uintptr(unsafe.Pointer(this)), + ) + + data := unsafe.Slice((*sSegment)(unsafe.Pointer(ret)), len) + + return data +} + +func (this *ITranscribeResult) GetTokens(len uint32) []SToken { + + ret, _, _ := syscall.SyscallN( + this.lpVtbl.getTokens, + uintptr(unsafe.Pointer(this)), + ) + + if unsafe.Pointer(ret) != nil { + return unsafe.Slice((*SToken)(unsafe.Pointer(ret)), len) + } else { + return []SToken{} + } +} diff --git a/pkg/whisper/context.go b/pkg/whisper/context.go new file mode 100644 index 0000000..8c469f6 --- /dev/null +++ b/pkg/whisper/context.go @@ -0,0 +1,281 @@ +package whisper + +import ( + "errors" + "fmt" + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +type uuid [16]byte + +type eResultFlags uint32 + +const ( + RfNone eResultFlags = 0 + + // Return individual tokens in addition to the segments + RfTokens = 1 + + // Return timestamps + RfTimestamps = 2 + + // Create a new COM object for the results. + // Without this flag, the context returns a pointer to the COM object stored in the context. + // The content of that object is replaced every time you call IContext.getResults method + RfNewObject = 0x100 +) + +type IContextVtbl struct { + QueryInterface uintptr + AddRef uintptr + Release uintptr + + RunFull uintptr + RunStreamed uintptr + RunCapture uintptr + GetResults uintptr + DetectSpeaker uintptr + GetModel uintptr + FullDefaultParams uintptr + TimingsPrint uintptr + TimingsReset uintptr +} + +type IContext struct { + lpVtbl *IContextVtbl +} + +//type sFullParams struct{} + +// type iAudioBuffer struct{} +type sProgressSink struct { + pfn uintptr + pv uintptr +} + +// type iAudioReader struct{} +type sCaptureCallbacks struct{} + +// type iAudioCapture struct{} +// type eResultFlags int32 +// type iTranscribeResult struct{} +// type sTimeInterval struct{} +type eSpeakerChannel int32 + +//type eSamplingStrategy int32 + +// Create a new IContext instance +func newIContext() *IContext { + return &IContext{ + lpVtbl: &IContextVtbl{ + QueryInterface: 0, + AddRef: 0, + Release: 0, + RunFull: 0, + RunStreamed: 0, + RunCapture: 0, + GetResults: 0, + DetectSpeaker: 0, + GetModel: 0, + FullDefaultParams: 0, + TimingsPrint: 0, + TimingsReset: 0, + }, + } +} + +func (context *IContext) TimingsPrint() error { + + // TimingsPrint(); + ret, _, _ := syscall.SyscallN( + context.lpVtbl.TimingsPrint, + uintptr(unsafe.Pointer(context)), + ) + + if windows.Handle(ret) != windows.S_OK { + fmt.Printf("RunFull failed: %s\n", syscall.Errno(ret).Error()) + return errors.New(syscall.Errno(ret).Error()) + } + + return nil +} + +// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text +// Uses the specified decoding strategy to obtain the text. +func (context *IContext) RunFull(params *FullParams, buffer *iAudioBuffer) error { + + // runFull( const sFullParams& params, const iAudioBuffer* buffer ); + ret, _, _ := syscall.SyscallN( + context.lpVtbl.RunFull, + uintptr(unsafe.Pointer(context)), + + uintptr(unsafe.Pointer(params.cStruct)), + uintptr(unsafe.Pointer(buffer)), + ) + + if windows.Handle(ret) != windows.S_OK { + fmt.Printf("RunFull failed: %s\n", syscall.Errno(ret).Error()) + return errors.New(syscall.Errno(ret).Error()) + } + + return nil +} + +func (context *IContext) RunStreamed(params *FullParams, reader *iAudioReader) error { + + cb := sProgressSink{} + + // runStreamed( const sFullParams& params, const sProgressSink& progress, const iAudioReader* reader ); + ret, _, _ := syscall.SyscallN( + context.lpVtbl.RunStreamed, + uintptr(unsafe.Pointer(context)), + uintptr(unsafe.Pointer(params.cStruct)), + uintptr(unsafe.Pointer(&cb)), // No progress cb yet + uintptr(unsafe.Pointer(reader)), + ) + + if windows.Handle(ret) != windows.S_OK { + fmt.Printf("RunStreamed failed: %s\n", syscall.Errno(ret).Error()) + return errors.New(syscall.Errno(ret).Error()) + } + + return nil +} + +func (this *IContext) AddRef() int32 { + ret, _, _ := syscall.Syscall( + this.lpVtbl.AddRef, + 1, + uintptr(unsafe.Pointer(this)), + 0, + 0) + return int32(ret) +} + +func (this *IContext) Release() int32 { + ret, _, _ := syscall.Syscall( + this.lpVtbl.Release, + 1, + uintptr(unsafe.Pointer(this)), + 0, + 0) + return int32(ret) +} + +/* +https://github.com/Const-me/Whisper/blob/f6f743c7b3570b85ccf47f74b84e06a73667ef3e/Whisper/Whisper/ContextImpl.misc.cpp + +Returns E_POINTER if null pointer provided in params +Initialises params to all 0 +sets values in struct, does not malloc +*/ +func (context *IContext) FullDefaultParams(strategy eSamplingStrategy) (*FullParams, error) { + + /* + ERR : unreadable Only part of a ReadProcessMemory or WriteProcessMemory request was completed + * not related to stratergy ... tested 0, 1 and 2 ... 2 produced E_INVALIDARG as expected + * not a nil ptr to params ... nil poitner produced E_POINTER as expected + * params seems to return 0x4000 + * !!!!! FullParams is not a com interface !!! + * so no lpVtbl *FullParamsVtbl , no queryinterface, addref etc + */ + + params := _newFullParams_cStruct() + //params := &[160]byte{} + + ret, _, _ := syscall.SyscallN( + context.lpVtbl.FullDefaultParams, + uintptr(unsafe.Pointer(context)), + uintptr(strategy), + uintptr(unsafe.Pointer(params)), + ) + + // nil ptr should be 0x80004003L + // unsafe.Pointer(0xc00011dc28) + // unsafe.Pointer(0x4000) + + if windows.Handle(ret) != windows.S_OK { + fmt.Printf("FullDefaultParams failed: %s\n", syscall.Errno(ret).Error()) + return nil, syscall.Errno(ret) + + } + + if params == nil { + return nil, errors.New("FullDefaultParams did not return params") + } + ParamObj := NewFullParams(params) + + if ParamObj.TestDefaultsOK() { + return ParamObj, nil + } + + return nil, nil +} + +func (context *IContext) GetModel() (*_IModel, error) { + + var modelptr *_IModel + + // getModel( iModel** pp ); + ret, _, _ := syscall.SyscallN( + context.lpVtbl.GetModel, + uintptr(unsafe.Pointer(context)), + uintptr(unsafe.Pointer(&modelptr)), + ) + + if windows.Handle(ret) != windows.S_OK { + fmt.Printf("FullDefaultParams failed: %s\n", syscall.Errno(ret).Error()) + return nil, syscall.Errno(ret) + } + + if modelptr == nil { + return nil, errors.New("loadModel did not return a Model") + } + + if modelptr.lpVtbl == nil { + return nil, errors.New("loadModel method table is nil") + } + + return modelptr, nil +} + +// ************************************************************************************************************************************************ +// Not really implemented / tested +// ************************************************************************************************************************************************ + +func (context *IContext) RunCapture(params *FullParams, callbacks *sCaptureCallbacks, reader *iAudioCapture) uintptr { + ret, _, _ := syscall.SyscallN( + context.lpVtbl.RunCapture, + //3, + uintptr(unsafe.Pointer(context)), + uintptr(unsafe.Pointer(params)), + uintptr(unsafe.Pointer(callbacks)), + uintptr(unsafe.Pointer(reader)), + ) + return ret +} + +func (context *IContext) GetResults(flags eResultFlags, pp **ITranscribeResult) uintptr { + ret, _, _ := syscall.Syscall( + context.lpVtbl.GetResults, + 3, + uintptr(unsafe.Pointer(context)), + uintptr(flags), + uintptr(unsafe.Pointer(pp)), + ) + return ret +} + +func (context *IContext) DetectSpeaker(time *sTimeInterval, result *eSpeakerChannel) uintptr { + ret, _, _ := syscall.Syscall( + context.lpVtbl.DetectSpeaker, + 3, + uintptr(unsafe.Pointer(context)), + uintptr(unsafe.Pointer(time)), + uintptr(unsafe.Pointer(result)), + ) + return ret +} diff --git a/pkg/whisper/doc.go b/pkg/whisper/doc.go new file mode 100644 index 0000000..dd1f9f9 --- /dev/null +++ b/pkg/whisper/doc.go @@ -0,0 +1,5 @@ +/* +github.com/jaybinks/goConstmeWhispers +Go Bindings for https://github.com/Const-me/Whisper +*/ +package whisper diff --git a/pkg/whisper/language.go b/pkg/whisper/language.go new file mode 100644 index 0000000..4e6a042 --- /dev/null +++ b/pkg/whisper/language.go @@ -0,0 +1,207 @@ +package whisper + +// https://github.com/Const-me/Whisper/blob/master/WhisperNet/API/eLanguage.cs + +type eLanguage int32 + +const ( + Auto eLanguage = -1 // "af" + + Afrikaans = 0x6661 // "af" + /// Albanian + Albanian = 0x7173 // "sq" + /// Amharic + Amharic = 0x6D61 // "am" + /// Arabic + Arabic = 0x7261 // "ar" + /// Armenian + Armenian = 0x7968 // "hy" + /// Assamese + Assamese = 0x7361 // "as" + /// Azerbaijani + Azerbaijani = 0x7A61 // "az" + /// Bashkir + Bashkir = 0x6162 // "ba" + /// Basque + Basque = 0x7565 // "eu" + /// Belarusian + Belarusian = 0x6562 // "be" + /// Bengali + Bengali = 0x6E62 // "bn" + /// Bosnian + Bosnian = 0x7362 // "bs" + /// Breton + Breton = 0x7262 // "br" + /// Bulgarian + Bulgarian = 0x6762 // "bg" + /// Catalan + Catalan = 0x6163 // "ca" + /// Chinese + Chinese = 0x687A // "zh" + /// Croatian + Croatian = 0x7268 // "hr" + /// Czech + Czech = 0x7363 // "cs" + /// Danish + Danish = 0x6164 // "da" + /// Dutch + Dutch = 0x6C6E // "nl" + /// English + English = 0x6E65 // "en" + /// Estonian + Estonian = 0x7465 // "et" + /// Faroese + Faroese = 0x6F66 // "fo" + /// Finnish + Finnish = 0x6966 // "fi" + /// French + French = 0x7266 // "fr" + /// Galician + Galician = 0x6C67 // "gl" + /// Georgian + Georgian = 0x616B // "ka" + /// German + German = 0x6564 // "de" + /// Greek + Greek = 0x6C65 // "el" + /// Gujarati + Gujarati = 0x7567 // "gu" + /// Haitian Creole + HaitianCreole = 0x7468 // "ht" + /// Hausa + Hausa = 0x6168 // "ha" + /// Hawaiian + Hawaiian = 0x776168 // "haw" + /// Hebrew + Hebrew = 0x7769 // "iw" + /// Hindi + Hindi = 0x6968 // "hi" + /// Hungarian + Hungarian = 0x7568 // "hu" + /// Icelandic + Icelandic = 0x7369 // "is" + /// Indonesian + Indonesian = 0x6469 // "id" + /// Italian + Italian = 0x7469 // "it" + /// Japanese + Japanese = 0x616A // "ja" + /// Javanese + Javanese = 0x776A // "jw" + /// Kannada + Kannada = 0x6E6B // "kn" + /// Kazakh + Kazakh = 0x6B6B // "kk" + /// Khmer + Khmer = 0x6D6B // "km" + /// Korean + Korean = 0x6F6B // "ko" + /// Lao + Lao = 0x6F6C // "lo" + /// Latin + Latin = 0x616C // "la" + /// Latvian + Latvian = 0x766C // "lv" + /// Lingala + Lingala = 0x6E6C // "ln" + /// Lithuanian + Lithuanian = 0x746C // "lt" + /// Luxembourgish + Luxembourgish = 0x626C // "lb" + /// Macedonian + Macedonian = 0x6B6D // "mk" + /// Malagasy + Malagasy = 0x676D // "mg" + /// Malay + Malay = 0x736D // "ms" + /// Malayalam + Malayalam = 0x6C6D // "ml" + /// Maltese + Maltese = 0x746D // "mt" + /// Maori + Maori = 0x696D // "mi" + /// Marathi + Marathi = 0x726D // "mr" + /// Mongolian + Mongolian = 0x6E6D // "mn" + /// Myanmar + Myanmar = 0x796D // "my" + /// Nepali + Nepali = 0x656E // "ne" + /// Norwegian + Norwegian = 0x6F6E // "no" + /// Nynorsk + Nynorsk = 0x6E6E // "nn" + /// Occitan + Occitan = 0x636F // "oc" + /// Pashto + Pashto = 0x7370 // "ps" + /// Persian + Persian = 0x6166 // "fa" + /// Polish + Polish = 0x6C70 // "pl" + /// Portuguese + Portuguese = 0x7470 // "pt" + /// Punjabi + Punjabi = 0x6170 // "pa" + /// Romanian + Romanian = 0x6F72 // "ro" + /// Russian + Russian = 0x7572 // "ru" + /// Sanskrit + Sanskrit = 0x6173 // "sa" + /// Serbian + Serbian = 0x7273 // "sr" + /// Shona + Shona = 0x6E73 // "sn" + /// Sindhi + Sindhi = 0x6473 // "sd" + /// Sinhala + Sinhala = 0x6973 // "si" + /// Slovak + Slovak = 0x6B73 // "sk" + /// Slovenian + Slovenian = 0x6C73 // "sl" + /// Somali + Somali = 0x6F73 // "so" + /// Spanish + Spanish = 0x7365 // "es" + /// Sundanese + Sundanese = 0x7573 // "su" + /// Swahili + Swahili = 0x7773 // "sw" + /// Swedish + Swedish = 0x7673 // "sv" + /// Tagalog + Tagalog = 0x6C74 // "tl" + /// Tajik + Tajik = 0x6774 // "tg" + /// Tamil + Tamil = 0x6174 // "ta" + /// Tatar + Tatar = 0x7474 // "tt" + /// Telugu + Telugu = 0x6574 // "te" + /// Thai + Thai = 0x6874 // "th" + /// Tibetan + Tibetan = 0x6F62 // "bo" + /// Turkish + Turkish = 0x7274 // "tr" + /// Turkmen + Turkmen = 0x6B74 // "tk" + /// Ukrainian + Ukrainian = 0x6B75 // "uk" + /// Urdu + Urdu = 0x7275 // "ur" + /// Uzbek + Uzbek = 0x7A75 // "uz" + /// Vietnamese + Vietnamese = 0x6976 // "vi" + /// Welsh + Welsh = 0x7963 // "cy" + /// Yiddish + Yiddish = 0x6979 // "yi" + /// Yoruba + Yoruba = 0x6F79 // "yo" +) diff --git a/pkg/whisper/logger.go b/pkg/whisper/logger.go new file mode 100644 index 0000000..5c9aacb --- /dev/null +++ b/pkg/whisper/logger.go @@ -0,0 +1,43 @@ +package whisper + +import ( + "C" + "fmt" +) + +/* + https://github.com/Const-me/Whisper/blob/843a2a6ca6ea47c5ac4889a281badfc808d0ea01/Whisper/API/loggerApi.h + +*/ + +type eLogLevel uint8 + +const ( + LlError eLogLevel = 0 + LlWarning = 1 + LlInfo = 2 + LlDebug = 3 +) + +type eLogFlags uint8 + +const ( + LfNone eLogFlags = 0 + LfUseStandardError = 1 + LfSkipFormatMessage = 2 +) + +type sLoggerSetup struct { + sink uintptr // pfnLoggerSink + context uintptr // void* + level eLogLevel // eLogLevel + flags eLogFlags // eLoggerFlags +} + +func fnLoggerSink(context uintptr, lvl eLogLevel, message *C.char) uintptr { + + strmessage := C.GoString(message) + fmt.Printf("%d - %s\n", lvl, strmessage) + + return 0 +} diff --git a/pkg/whisper/whisper.go b/pkg/whisper/whisper.go new file mode 100644 index 0000000..04b327f --- /dev/null +++ b/pkg/whisper/whisper.go @@ -0,0 +1,188 @@ +//go:build windows +// +build windows + +package whisper + +import ( + "C" + "errors" + "syscall" + "unsafe" + + // Using lxn/win because its COM functions expose raw HRESULTs + "golang.org/x/sys/windows" +) +import ( + "fmt" +) + +/* + eModelImplementation - TranscribeStructs.h + + // GPGPU implementation based on Direct3D 11.0 compute shaders + GPU = 1, + + // A hybrid implementation which uses DirectCompute for encode, and decodes on CPU + // Not implemented in the published builds of the DLL. To enable, change BUILD_HYBRID_VERSION macro to 1 + Hybrid = 2, + + // A reference implementation which uses the original GGML CPU-running code + // Not implemented in the published builds of the DLL. To enable, change BUILD_BOTH_VERSIONS macro to 1 + Reference = 3, +*/ + +// https://learn.microsoft.com/en-us/windows/win32/seccrypto/common-hresult-values +// https://pkg.go.dev/golang.org/x/sys/windows +const ( + E_INVALIDARG = 0x80070057 + ERROR_HV_CPUID_FEATURE_VALIDATION = 0xC0350038 + + DLLName = "whisper.dll" +) + +type Libwhisper struct { + dll *syscall.LazyDLL + ver WinVersion + existing_model map[string]*Model + + proc_setupLogger *syscall.LazyProc + proc_loadModel *syscall.LazyProc + proc_initMediaFoundation *syscall.LazyProc + // proc_findLanguageKeyW *syscall.LazyProc + // proc_findLanguageKeyA *syscall.LazyProc + // proc_getSupportedLanguages *syscall.LazyProc +} + +var singleton_whisper *Libwhisper = nil + +func New(level eLogLevel, flags eLogFlags, cb *any) (*Libwhisper, error) { + if singleton_whisper != nil { + return singleton_whisper, nil + } + + var err error + this := &Libwhisper{} + + this.ver, err = GetFileVersion(DLLName) + if err != nil { + return nil, err + } + + if this.ver.Major < 1 && this.ver.Minor < 9 { + return nil, errors.New("This library requires whisper.dll version 1.9 or higher.") // or less than 1.11 for now .. because the API changed + } + + this.dll = syscall.NewLazyDLL(DLLName) // Todo wrap this in a class, check file exists, handle errors ... you know, just a few things.. AKA Stop being lazy + + this.proc_setupLogger = this.dll.NewProc("setupLogger") + this.proc_loadModel = this.dll.NewProc("loadModel") + this.proc_initMediaFoundation = this.dll.NewProc("initMediaFoundation") + /* + this.proc_findLanguageKeyW = this.dll.NewProc("findLanguageKeyW") + this.proc_findLanguageKeyA = this.dll.NewProc("findLanguageKeyA") + this.proc_getSupportedLanguages = this.dll.NewProc("getSupportedLanguages") + */ + + ok, err := this._setupLogger(level, flags, cb) + if !ok { + return nil, errors.New("Logger Error : " + err.Error()) + } + + this.existing_model = make(map[string]*Model) + singleton_whisper = this + + return singleton_whisper, nil +} + +func (this *Libwhisper) Version() string { + return fmt.Sprintf("%d.%d.%d.%d.", this.ver.Major, this.ver.Minor, this.ver.Patch, this.ver.Build) +} + +func (this *Libwhisper) SupportsMultiThread() bool { + return this.ver.Major >= 1 && this.ver.Minor >= 10 +} + +func (this *Libwhisper) _setupLogger(level eLogLevel, flags eLogFlags, cb *any) (bool, error) { + + setup := sLoggerSetup{} + setup.sink = 0 + setup.context = 0 + setup.level = level + setup.flags = flags + + if cb != nil { + setup.sink = syscall.NewCallback(cb) + } + + res, _, err := this.proc_setupLogger.Call(uintptr(unsafe.Pointer(&setup))) + + if windows.Handle(res) == windows.S_OK { + return true, nil + } else { + return false, err + } +} + +func (this *Libwhisper) LoadModel(path string, aGPU ...string) (*Model, error) { + var modelptr *_IModel + + whisperpath, _ := windows.UTF16PtrFromString(path) + + GPU := "" + if len(aGPU) == 1 { + GPU = aGPU[0] + } + + setup := ModelSetup(gmf_Cloneable, GPU) + + // Construct our map hash + singleton_hash := GPU + "|" + path + if this.existing_model[singleton_hash] != nil { + ClonedModel, err := this.existing_model[singleton_hash].Clone() + if ClonedModel != nil { + return NewModel(setup, ClonedModel), nil + } else { + return nil, err + } + } + + obj, _, _ := this.proc_loadModel.Call(uintptr(unsafe.Pointer(whisperpath)), uintptr(unsafe.Pointer(setup.AsCType())), uintptr(unsafe.Pointer(nil)), uintptr(unsafe.Pointer(&modelptr))) + + if windows.Handle(obj) != windows.S_OK { + fmt.Printf("loadModel failed: %s\n", syscall.Errno(obj).Error()) + return nil, fmt.Errorf("loadModel failed: %s", syscall.Errno(obj)) + } + + if modelptr == nil { + return nil, errors.New("loadModel did not return a Model") + } + + if modelptr.lpVtbl == nil { + return nil, errors.New("loadModel method table is nil") + } + + model := NewModel(setup, modelptr) + + this.existing_model[singleton_hash] = model + + return model, nil +} + +func (this *Libwhisper) InitMediaFoundation() (*IMediaFoundation, error) { + + var mediafoundation *IMediaFoundation + + // initMediaFoundation( iMediaFoundation** pp ); + obj, _, _ := this.proc_initMediaFoundation.Call(uintptr(unsafe.Pointer(&mediafoundation))) + + if windows.Handle(obj) != windows.S_OK { + fmt.Printf("initMediaFoundation failed: %s\n", syscall.Errno(obj).Error()) + return nil, fmt.Errorf("initMediaFoundation failed: %s", syscall.Errno(obj)) + } + + if mediafoundation.lpVtbl == nil { + return nil, errors.New("initMediaFoundation method table is nil") + } + + return mediafoundation, nil +} diff --git a/pkg/whisper/winversion.go b/pkg/whisper/winversion.go new file mode 100644 index 0000000..c51bb1d --- /dev/null +++ b/pkg/whisper/winversion.go @@ -0,0 +1,123 @@ +// Copyright 2018 Keybase Inc. All rights reserved. +// Use of this source code is governed by a BSD +// license that can be found in the LICENSE file. +// Adapted mainly from github.com/gonutz/w32 + +//go:build windows +// +build windows + +package whisper + +import ( + "errors" + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +var ( + version = windows.NewLazySystemDLL("version.dll") + getFileVersionInfoSize = version.NewProc("GetFileVersionInfoSizeW") + getFileVersionInfo = version.NewProc("GetFileVersionInfoW") + verQueryValue = version.NewProc("VerQueryValueW") +) + +type VS_FIXEDFILEINFO struct { + Signature uint32 + StrucVersion uint32 + FileVersionMS uint32 + FileVersionLS uint32 + ProductVersionMS uint32 + ProductVersionLS uint32 + FileFlagsMask uint32 + FileFlags uint32 + FileOS uint32 + FileType uint32 + FileSubtype uint32 + FileDateMS uint32 + FileDateLS uint32 +} + +type WinVersion struct { + Major uint32 + Minor uint32 + Patch uint32 + Build uint32 +} + +// FileVersion concatenates FileVersionMS and FileVersionLS to a uint64 value. +func (fi VS_FIXEDFILEINFO) FileVersion() uint64 { + return uint64(fi.FileVersionMS)<<32 | uint64(fi.FileVersionLS) +} + +func GetFileVersionInfoSize(path string) uint32 { + ret, _, _ := getFileVersionInfoSize.Call( + uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr(path))), + 0, + ) + return uint32(ret) +} + +func GetFileVersionInfo(path string, data []byte) bool { + ret, _, _ := getFileVersionInfo.Call( + uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr(path))), + 0, + uintptr(len(data)), + uintptr(unsafe.Pointer(&data[0])), + ) + return ret != 0 +} + +// VerQueryValueRoot calls VerQueryValue +// (https://msdn.microsoft.com/en-us/library/windows/desktop/ms647464(v=vs.85).aspx) +// with `\` (root) to retieve the VS_FIXEDFILEINFO. +func VerQueryValueRoot(block []byte) (VS_FIXEDFILEINFO, error) { + var offset uintptr + var length uint + blockStart := unsafe.Pointer(&block[0]) + ret, _, _ := verQueryValue.Call( + uintptr(blockStart), + uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr(`\`))), + uintptr(unsafe.Pointer(&offset)), + uintptr(unsafe.Pointer(&length)), + ) + if ret == 0 { + return VS_FIXEDFILEINFO{}, errors.New("VerQueryValueRoot: verQueryValue failed") + } + start := int(offset) - int(uintptr(blockStart)) + end := start + int(length) + if start < 0 || start >= len(block) || end < start || end > len(block) { + return VS_FIXEDFILEINFO{}, errors.New("VerQueryValueRoot: find failed") + } + data := block[start:end] + info := *((*VS_FIXEDFILEINFO)(unsafe.Pointer(&data[0]))) + return info, nil +} + +func GetFileVersion(path string) (WinVersion, error) { + var result WinVersion + size := GetFileVersionInfoSize(path) + if size <= 0 { + return result, errors.New("GetFileVersionInfoSize failed") + } + + info := make([]byte, size) + ok := GetFileVersionInfo(path, info) + if !ok { + return result, errors.New("GetFileVersionInfo failed") + } + + fixed, err := VerQueryValueRoot(info) + if err != nil { + return result, err + } + version := fixed.FileVersion() + + result.Major = uint32(version & 0xFFFF000000000000 >> 48) + result.Minor = uint32(version & 0x0000FFFF00000000 >> 32) + result.Patch = uint32(version & 0x00000000FFFF0000 >> 16) + result.Build = uint32(version & 0x000000000000FFFF) + + return result, nil +} diff --git a/state.go b/state.go new file mode 100644 index 0000000..55e1d1c --- /dev/null +++ b/state.go @@ -0,0 +1,75 @@ +package main + +import ( + "fmt" + "sync" + + "github.com/xzeldon/whisper-api-server/pkg/whisper" +) + +type WhisperState struct { + model *whisper.Model + context *whisper.IContext + media *whisper.IMediaFoundation + params *whisper.FullParams + mutex sync.Mutex +} + +func InitializeWhisperState(modelPath string) (*WhisperState, error) { + lib, err := whisper.New(whisper.LlDebug, whisper.LfUseStandardError, nil) + if err != nil { + return nil, err + } + + model, err := lib.LoadModel(modelPath) + if err != nil { + return nil, err + } + + context, err := model.CreateContext() + if err != nil { + return nil, err + } + + media, err := lib.InitMediaFoundation() + if err != nil { + return nil, err + } + + params, err := context.FullDefaultParams(whisper.SsBeamSearch) + if err != nil { + return nil, err + } + + params.AddFlags(whisper.FlagNoContext) + params.AddFlags(whisper.FlagTokenTimestamps) + + fmt.Printf("Params CPU Threads : %d\n", params.CpuThreads()) + + return &WhisperState{ + model: model, + context: context, + media: media, + params: params, + }, nil +} + +func getResult(ctx *whisper.IContext) (string, error) { + results := &whisper.ITranscribeResult{} + ctx.GetResults(whisper.RfTokens|whisper.RfTimestamps, &results) + + length, err := results.GetSize() + if err != nil { + return "", err + } + + segments := results.GetSegments(length.CountSegments) + + var result string + + for _, seg := range segments { + result += seg.Text() + } + + return result, nil +} diff --git a/tmp/.gitkeep b/tmp/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/utils.go b/utils.go new file mode 100644 index 0000000..e7f868f --- /dev/null +++ b/utils.go @@ -0,0 +1,48 @@ +package main + +import ( + "io" + "os" + "path/filepath" + "strings" + "time" + + "github.com/labstack/echo/v4" +) + +func saveFormFile(name string, c echo.Context) (string, error) { + file, err := c.FormFile(name) + if err != nil { + return "", err + } + + src, err := file.Open() + if err != nil { + return "", err + } + defer src.Close() + + ext := filepath.Ext(file.Filename) + filename := time.Now().Format(time.RFC3339) + filename = "./tmp/" + sanitizeFilename(filename) + ext + + dst, err := os.Create(filename) + if err != nil { + return "", err + } + defer dst.Close() + + if _, err = io.Copy(dst, src); err != nil { + return "", err + } + + return filename, nil +} + +func sanitizeFilename(filename string) string { + invalidChars := []string{`\`, `/`, `:`, `*`, `?`, `"`, `<`, `>`, `|`} + for _, char := range invalidChars { + filename = strings.ReplaceAll(filename, char, "-") + } + return filename +}