commit 8e4cb72b50785d4feeca431fc2bb2ec523126e64 Author: xzeldon Date: Wed Oct 4 01:09:38 2023 +0300 initial commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..5b27170 --- /dev/null +++ b/.gitignore @@ -0,0 +1,7 @@ +Whisper.dll +ggml-medium.bin + +whisper-api-server.exe + +tmp/* +!tmp/.gitkeep \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..675d1d2 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Timofey Gelazoniya + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..619a737 --- /dev/null +++ b/README.md @@ -0,0 +1,43 @@ +# Whisper API Server (Go) + +## ⚠️ This project is a work in progress (WIP). + +This API server enables audio transcription using the OpenAI Whisper models. + +# Setup + +- Download the desired model from [huggingface](https://huggingface.co/ggerganov/whisper.cpp/tree/main) +- Update the model path in the `main.go` file +- Download `Whisper.dll` from [github](https://github.com/Const-me/Whisper/releases/tag/1.12.0) (`Library.zip`) and place it in the project's root directory +- Build project: `go build .` (you only need go compiler, without gcc) + +# Usage example + +Make a request to the server using the following command: + +```sh +curl http://localhost:3000/v1/audio/transcriptions \ + -H "Content-Type: multipart/form-data" \ + -F file="@/path/to/file/audio.mp3" \ +``` + +Receive a response in JSON format: + +```json +{ + "text": "Imagine the wildest idea that you've ever had, and you're curious about how it might scale to something that's a 100, a 1,000 times bigger. This is a place where you can get to do that." +} +``` + +# Roadmap + +- [ ] Implement automatic model downloading from [huggingface](https://huggingface.co/ggerganov/whisper.cpp/tree/main) +- [ ] Implement automatic `Whisper.dll` downloading from [Guthub releases](https://github.com/Const-me/Whisper/releases) +- [ ] Provide prebuilt binaries for Windows +- [ ] Include instructions for running on Linux with Wine (likely possible). + +# Credits + +- [Const-me/Whisper](https://github.com/Const-me/Whisper) project +- [goConstmeWhisper](https://github.com/jaybinks/goConstmeWhisper) for the remarkable Go bindings for [Const-me/Whisper](https://github.com/Const-me/Whisper) +- [Georgi Gerganov](https://github.com/ggerganov) for GGML models diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..b4f3c0d --- /dev/null +++ b/go.mod @@ -0,0 +1,19 @@ +module github.com/xzeldon/whisper-api-server + +go 1.21.1 + +require ( + github.com/labstack/echo/v4 v4.11.1 + golang.org/x/sys v0.12.0 +) + +require ( + github.com/labstack/gommon v0.4.0 + github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-isatty v0.0.19 // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect + github.com/valyala/fasttemplate v1.2.2 // indirect + golang.org/x/crypto v0.11.0 // indirect + golang.org/x/net v0.12.0 // indirect + golang.org/x/text v0.11.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..119ff10 --- /dev/null +++ b/go.sum @@ -0,0 +1,43 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/labstack/echo/v4 v4.11.1 h1:dEpLU2FLg4UVmvCGPuk/APjlH6GDpbEPti61srUUUs4= +github.com/labstack/echo/v4 v4.11.1/go.mod h1:YuYRTSM3CHs2ybfrL8Px48bO6BAnYIN4l8wSTMP6BDQ= +github.com/labstack/gommon v0.4.0 h1:y7cvthEAEbU0yHOf4axH8ZG2NH8knB9iNSoTO8dyIk8= +github.com/labstack/gommon v0.4.0/go.mod h1:uW6kP17uPlLJsD3ijUYn3/M5bAxtlZhMI6m3MFxTMTM= +github.com/mattn/go-colorable v0.1.11/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasttemplate v1.2.1/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= +github.com/valyala/fasttemplate v1.2.2 h1:lxLXG0uE3Qnshl9QyaK6XJxMXlQZELvChBOCmQD0Loo= +github.com/valyala/fasttemplate v1.2.2/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= +golang.org/x/crypto v0.11.0 h1:6Ewdq3tDic1mg5xRO4milcWCfMVQhI4NkqWWvqejpuA= +golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio= +golang.org/x/net v0.12.0 h1:cfawfvKITfUsFCeJIHJrbSxpeu/E81khclypR0GVT50= +golang.org/x/net v0.12.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA= +golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211103235746-7861aae1554b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/text v0.11.0 h1:LAntKIrcmeSKERyiOh0XMV39LXS8IE9UL2yP7+f5ij4= +golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/handler.go b/handler.go new file mode 100644 index 0000000..6b4ca39 --- /dev/null +++ b/handler.go @@ -0,0 +1,45 @@ +package main + +import ( + "net/http" + "strings" + + "github.com/labstack/echo/v4" +) + +type TranscribeResponse struct { + Text string `json:"text"` +} + +func transcribe(c echo.Context, whisperState *WhisperState) error { + audioPath, err := saveFormFile("file", c) + if err != nil { + c.Logger().Errorf("Error reading file: %s", err) + return err + } + + whisperState.mutex.Lock() + buffer, err := whisperState.media.LoadAudioFile(audioPath, true) + if err != nil { + c.Logger().Errorf("Error loading audio file data: %s", err) + } + + err = whisperState.context.RunFull(whisperState.params, buffer) + + result, err := getResult(whisperState.context) + if err != nil { + c.Logger().Error(err) + } + + defer whisperState.mutex.Unlock() + + if len(result) == 0 { + return c.JSON(http.StatusInternalServerError, map[string]string{"error": "Internal server error"}) + } + + response := TranscribeResponse{ + Text: strings.TrimLeft(result, " "), + } + + return c.JSON(http.StatusOK, response) +} diff --git a/main.go b/main.go new file mode 100644 index 0000000..cf1328a --- /dev/null +++ b/main.go @@ -0,0 +1,29 @@ +package main + +import ( + "github.com/labstack/echo/v4" + "github.com/labstack/gommon/log" +) + +const MODEL_PATH = "./ggml-medium.bin" + +func main() { + e := echo.New() + e.HideBanner = true + + if l, ok := e.Logger.(*log.Logger); ok { + l.SetHeader("${time_rfc3339} ${level}") + } + + whisperState, err := InitializeWhisperState(MODEL_PATH) + if err != nil { + e.Logger.Error(err) + return + } + + e.POST("/v1/audio/transcriptions", func(c echo.Context) error { + return transcribe(c, whisperState) + }) + + e.Logger.Fatal(e.Start(":3000")) +} diff --git a/pkg/whisper/FullParams.go b/pkg/whisper/FullParams.go new file mode 100644 index 0000000..cb49136 --- /dev/null +++ b/pkg/whisper/FullParams.go @@ -0,0 +1,248 @@ +package whisper + +import ( + "syscall" + "unsafe" +) + +// https://github.com/Const-me/Whisper/blob/master/Whisper/API/sFullParams.h +// https://github.com/Const-me/Whisper/blob/master/WhisperNet/API/Parameters.cs + +type eSamplingStrategy uint32 + +const ( + SsGreedy eSamplingStrategy = iota + SsBeamSearch + SsINVALIDARG +) + +type eFullParamsFlags uint32 + +const ( + FlagNone eFullParamsFlags = 0 + FlagTranslate = 1 << 0 + FlagNoContext = 1 << 1 + FlagSingleSegment = 1 << 2 + FlagPrintSpecial = 1 << 3 + FlagPrintProgress = 1 << 4 + FlagPrintRealtime = 1 << 5 + FlagPrintTimestamps = 1 << 6 + FlagTokenTimestamps = 1 << 7 // Experimental + FlagSpeedupAudio = 1 << 8 +) + +type EWhisperHWND uintptr + +const ( + S_OK EWhisperHWND = 0 + S_FALSE EWhisperHWND = 1 +) + +type FullParams struct { + cStruct *_FullParams +} + +func (this *FullParams) CpuThreads() int32 { + if this == nil { + return 0 + } else if this.cStruct == nil { + return 0 + } + + return this.cStruct.cpuThreads +} + +func (this *FullParams) setCpuThreads(val int32) { + if this == nil { + return + } else if this.cStruct == nil { + return + } + + this.cStruct.cpuThreads = val +} + +func (this *FullParams) SetMaxTextCTX(val int32) { + this.cStruct.n_max_text_ctx = val +} + +func (this *FullParams) AddFlags(newflag eFullParamsFlags) { + if this == nil { + return + } else if this.cStruct == nil { + return + } + + this.cStruct.Flags = this.cStruct.Flags | newflag +} + +func (this *FullParams) RemoveFlags(newflag eFullParamsFlags) { + if this == nil { + return + } else if this.cStruct == nil { + return + } + + this.cStruct.Flags = this.cStruct.Flags ^ newflag +} + +/*using pfnNewSegment = HRESULT( __cdecl* )( iContext* ctx, uint32_t n_new, void* user_data ) noexcept;*/ +type NewSegmentCallback_Type func(context *IContext, n_new uint32, user_data unsafe.Pointer) EWhisperHWND + +func (this *FullParams) SetNewSegmentCallback(cb NewSegmentCallback_Type) { + if this == nil { + return + } else if this.cStruct == nil { + return + } + this.cStruct.new_segment_callback = syscall.NewCallback(cb) +} + +/* +Return S_OK to proceed, or S_FALSE to stop the process +*/ +type EncoderBeginCallback_Type func(context *IContext, user_data unsafe.Pointer) EWhisperHWND + +func (this *FullParams) SetEncoderBeginCallback(cb EncoderBeginCallback_Type) { + if this == nil { + return + } else if this.cStruct == nil { + return + } + + this.cStruct.encoder_begin_callback = syscall.NewCallback(cb) +} + +func (this *FullParams) TestDefaultsOK() bool { + if this == nil { + return false + } else if this.cStruct == nil { + return false + } + + if this.cStruct.n_max_text_ctx != 16384 { + return false + } + + if this.cStruct.Flags != (FlagPrintProgress | FlagPrintTimestamps) { + return false + } + + if this.cStruct.thold_pt != 0.01 { + return false + } + + if this.cStruct.thold_ptsum != 0.01 { + return false + } + + if this.cStruct.Language != English { + return false + } + + // Todo ... why do these not line up as expected.. is our struct out of alignment ? + /* + if this.cStruct.strategy == ssGreedy { + if this.cStruct.beam_search.n_past != -1 || + this.cStruct.beam_search.beam_width != -1 || + this.cStruct.beam_search.n_best != -1 { + return false + } + + } else if this.cStruct.strategy == ssBeamSearch { + if this.cStruct.greedy.n_past != -1 || + this.cStruct.beam_search.beam_width != 10 || + this.cStruct.beam_search.n_best != 5 { + return false + } + } + */ + + return true +} + +type _FullParams struct { + strategy eSamplingStrategy + cpuThreads int32 + n_max_text_ctx int32 + offset_ms int32 + duration_ms int32 + Flags eFullParamsFlags + Language eLanguage + + thold_pt float32 + thold_ptsum float32 + max_len int32 + max_tokens int32 + + greedy struct{ n_past int32 } + beam_search struct { + n_past int32 + beam_width int32 + n_best int32 + } + + audio_ctx int32 // overwrite the audio context size (0 = use default) + + prompt_tokens uintptr + prompt_n_tokens int32 + + new_segment_callback uintptr + new_segment_callback_user_data uintptr + + encoder_begin_callback uintptr + encoder_begin_callback_user_data uintptr + + // Are these needed ?? Jay + // setFlag uintptr +} + +func NewFullParams(cstruct *_FullParams) *FullParams { + this := FullParams{} + this.cStruct = cstruct + return &this +} + +func _newFullParams_cStruct() *_FullParams { + return &_FullParams{ + + strategy: 0, + cpuThreads: 0, + n_max_text_ctx: 0, + offset_ms: 0, + duration_ms: 0, + + Flags: 0, + Language: 0, + + thold_pt: 0, + thold_ptsum: 0, + max_len: 0, + max_tokens: 0, + + // anonymous int32 + greedy: struct{ n_past int32 }{n_past: 0}, + + // anonymous struct + beam_search: struct { + n_past int32 + beam_width int32 + n_best int32 + }{ + n_past: 0, + beam_width: 0, + n_best: 0, + }, + + audio_ctx: 0, + + prompt_tokens: 0, + prompt_n_tokens: 0, + + new_segment_callback: 0, + new_segment_callback_user_data: 0, + + encoder_begin_callback: 0, + encoder_begin_callback_user_data: 0, + } +} diff --git a/pkg/whisper/MediaFoundation.go b/pkg/whisper/MediaFoundation.go new file mode 100644 index 0000000..b5d665e --- /dev/null +++ b/pkg/whisper/MediaFoundation.go @@ -0,0 +1,222 @@ +package whisper + +import ( + "errors" + "fmt" + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +// https://github.com/Const-me/Whisper/blob/843a2a6ca6ea47c5ac4889a281badfc808d0ea01/Whisper/API/IMediaFoundation.h + +type IMediaFoundation struct { + lpVtbl *IMediaFoundationVtbl +} + +type IMediaFoundationVtbl struct { + QueryInterface uintptr + AddRef uintptr + Release uintptr + loadAudioFile uintptr // ( LPCTSTR path, bool stereo, iAudioBuffer** pp ) const; + openAudioFile uintptr // ( LPCTSTR path, bool stereo, iAudioReader** pp ); + loadAudioFileData uintptr // ( const void* data, uint64_t size, bool stereo, iAudioReader** pp ); HRESULT + listCaptureDevices uintptr // ( pfnFoundCaptureDevices pfn, void* pv ); + openCaptureDevice uintptr // ( LPCTSTR endpoint, const sCaptureParams& captureParams, iAudioCapture** pp ); +} + +func (this *IMediaFoundation) AddRef() int32 { + ret, _, _ := syscall.Syscall( + this.lpVtbl.AddRef, + 1, + uintptr(unsafe.Pointer(this)), + 0, + 0) + return int32(ret) +} + +func (this *IMediaFoundation) Release() int32 { + ret, _, _ := syscall.Syscall( + this.lpVtbl.Release, + 1, + uintptr(unsafe.Pointer(this)), + 0, + 0) + return int32(ret) +} + +// ( LPCTSTR path, bool stereo, iAudioBuffer** pp ) const; +func (this *IMediaFoundation) LoadAudioFile(file string, stereo bool) (*iAudioBuffer, error) { + + var buffer *iAudioBuffer + + UTFFileName, _ := windows.UTF16PtrFromString(file) + + ret, _, _ := syscall.SyscallN( + this.lpVtbl.loadAudioFile, + uintptr(unsafe.Pointer(this)), + uintptr(unsafe.Pointer(UTFFileName)), + uintptr(1), // Todo ... Stereo ! + uintptr(unsafe.Pointer(&buffer))) + + if windows.Handle(ret) != windows.S_OK { + fmt.Printf("loadAudioFile failed: %s\n", syscall.Errno(ret).Error()) + return nil, syscall.Errno(ret) + } + + return buffer, nil +} + +func (this *IMediaFoundation) OpenAudioFile(file string, stereo bool) (*iAudioReader, error) { + + var buffer *iAudioReader + + UTFFileName, _ := windows.UTF16PtrFromString(file) + + ret, _, _ := syscall.SyscallN( + this.lpVtbl.openAudioFile, + uintptr(unsafe.Pointer(this)), + uintptr(unsafe.Pointer(UTFFileName)), + uintptr(1), // Todo ... Stereo ! + uintptr(unsafe.Pointer(&buffer))) + + if windows.Handle(ret) != windows.S_OK { + fmt.Printf("openAudioFile failed: %s\n", syscall.Errno(ret).Error()) + return nil, syscall.Errno(ret) + } + + return buffer, nil +} + +func (this *IMediaFoundation) LoadAudioFileData(inbuffer *[]byte, stereo bool) (*iAudioReader, error) { + + var reader *iAudioReader + + // loadAudioFileData( const void* data, uint64_t size, bool stereo, iAudioReader** pp ); + ret, _, _ := syscall.SyscallN( + this.lpVtbl.loadAudioFileData, + uintptr(unsafe.Pointer(this)), + + uintptr(unsafe.Pointer(&(*inbuffer)[0])), + uintptr(uint64(len(*inbuffer))), + uintptr(1), // Todo ... Stereo ! + uintptr(unsafe.Pointer(&reader))) + + if windows.Handle(ret) != windows.S_OK { + fmt.Printf("LoadAudioFileData failed: %s\n", syscall.Errno(ret).Error()) + return nil, syscall.Errno(ret) + } + + return reader, nil +} + +// ************************************************************ + +type iAudioBuffer struct { + lpVtbl *iAudioBufferVtbl +} + +type iAudioBufferVtbl struct { + QueryInterface uintptr + AddRef uintptr + Release uintptr + countSamples uintptr // returns uint32_t + getPcmMono uintptr // returns float* + getPcmStereo uintptr // returns float* + getTime uintptr // ( int64_t& rdi ) +} + +func (this *iAudioBuffer) AddRef() int32 { + ret, _, _ := syscall.SyscallN( + this.lpVtbl.AddRef, + uintptr(unsafe.Pointer(this)), + ) + return int32(ret) +} + +func (this *iAudioBuffer) Release() int32 { + ret, _, _ := syscall.SyscallN( + this.lpVtbl.Release, + uintptr(unsafe.Pointer(this)), + ) + return int32(ret) +} + +func (this *iAudioBuffer) CountSamples() (uint32, error) { + + ret, _, err := syscall.SyscallN( + this.lpVtbl.countSamples, + uintptr(unsafe.Pointer(this)), + ) + + if err != 0 { + return 0, errors.New(err.Error()) + } + + return uint32(ret), nil +} + +// ************************************************************ + +type iAudioReader struct { + lpVtbl *iAudioReaderVtbl +} + +type iAudioReaderVtbl struct { + QueryInterface uintptr + AddRef uintptr + Release uintptr + + getDuration uintptr // ( int64_t& rdi ) + getReader uintptr // ( IMFSourceReader** pp ) + requestedStereo uintptr // () +} + +func (this *iAudioReader) AddRef() int32 { + ret, _, _ := syscall.SyscallN( + this.lpVtbl.AddRef, + uintptr(unsafe.Pointer(this)), + ) + return int32(ret) +} + +func (this *iAudioReader) Release() int32 { + ret, _, _ := syscall.SyscallN( + this.lpVtbl.Release, + uintptr(unsafe.Pointer(this)), + ) + return int32(ret) +} + +func (this *iAudioReader) GetDuration() (uint64, error) { + + var rdi int64 + + ret, _, _ := syscall.SyscallN( + this.lpVtbl.getDuration, + uintptr(unsafe.Pointer(this)), + uintptr(unsafe.Pointer(&rdi)), + ) + + if windows.Handle(ret) != windows.S_OK { + fmt.Printf("LoadAudioFileData failed: %s\n", syscall.Errno(ret).Error()) + return 0, syscall.Errno(ret) + } + + return uint64(rdi), nil +} + +// ************************************************************ + +type iAudioCapture struct { + lpVtbl *iAudioCaptureVtbl +} + +type iAudioCaptureVtbl struct { + QueryInterface uintptr + AddRef uintptr + Release uintptr + getReader uintptr // ( IMFSourceReader** pp ) + getParams uintptr // returns sCaptureParams& +} diff --git a/pkg/whisper/Model.go b/pkg/whisper/Model.go new file mode 100644 index 0000000..1635325 --- /dev/null +++ b/pkg/whisper/Model.go @@ -0,0 +1,119 @@ +package whisper + +import ( + "errors" + "fmt" + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +// External - Go version of the struct +type Model struct { + cStruct *_IModel + setup *sModelSetup +} + +// Internal - C Version of the structs +type _IModel struct { + lpVtbl *IModelVtbl +} + +// https://github.com/Const-me/Whisper/blob/master/Whisper/API/iContext.cl.h +type IModelVtbl struct { + QueryInterface uintptr + AddRef uintptr + Release uintptr + + createContext uintptr //( iContext** pp ) = 0; + tokenize uintptr /* HRESULT __stdcall tokenize( const char* text, pfnDecodedTokens pfn, void* pv ); */ + isMultilingual uintptr //() = 0; + getSpecialTokens uintptr //( SpecialTokens& rdi ) = 0; + stringFromToken uintptr //( whisper_token token ) = 0; + clone uintptr //( iModel** rdi ) = 0; +} + +func NewModel(setup *sModelSetup, cstruct *_IModel) *Model { + this := Model{} + this.setup = setup + this.cStruct = cstruct + return &this +} + +func (this *Model) AddRef() int32 { + ret, _, _ := syscall.Syscall( + this.cStruct.lpVtbl.AddRef, + 1, + uintptr(unsafe.Pointer(this.cStruct)), + 0, + 0) + return int32(ret) +} + +func (this *Model) Release() int32 { + ret, _, _ := syscall.Syscall( + this.cStruct.lpVtbl.Release, + 1, + uintptr(unsafe.Pointer(this.cStruct)), + 0, + 0) + return int32(ret) +} + +func (this *Model) CreateContext() (*IContext, error) { + var context *IContext + + /* + ret, _, err := syscall.Syscall( + this.cStruct.lpVtbl.createContext, + 2, // Why was this 1, rather than 2 ?? 1 seemed to work fine + uintptr(unsafe.Pointer(this.cStruct)), + uintptr(unsafe.Pointer(&context)), + 0)*/ + ret, _, err := syscall.SyscallN( + this.cStruct.lpVtbl.createContext, + uintptr(unsafe.Pointer(this.cStruct)), + uintptr(unsafe.Pointer(&context))) + + if windows.Handle(ret) != windows.S_OK { + fmt.Printf("createContext failed: %w", err.Error()) + } + + if windows.Handle(ret) != windows.S_OK { + return nil, fmt.Errorf("loadModel failed: %w", err) + } + + return context, nil +} + +func (this *Model) IsMultilingual() bool { + ret, _, _ := syscall.SyscallN( + this.cStruct.lpVtbl.isMultilingual, + uintptr(unsafe.Pointer(this.cStruct)), + ) + + return bool(windows.Handle(ret) == windows.S_OK) +} + +func (this *Model) Clone() (*_IModel, error) { + + if this.setup.isFlagSet(gmf_Cloneable) { + return nil, errors.New("Model is not cloneable") + } + //this.Cloneable ? + + var modelptr *_IModel + + ret, _, _ := syscall.SyscallN( + this.cStruct.lpVtbl.clone, + uintptr(unsafe.Pointer(this.cStruct)), + uintptr(unsafe.Pointer(&modelptr)), + ) + + if windows.Handle(ret) == windows.S_OK { + return modelptr, nil + } else { + return nil, errors.New("Model.Clone() failed : " + syscall.Errno(ret).Error()) + } +} diff --git a/pkg/whisper/ModelSetup.go b/pkg/whisper/ModelSetup.go new file mode 100644 index 0000000..08c82ee --- /dev/null +++ b/pkg/whisper/ModelSetup.go @@ -0,0 +1,101 @@ +package whisper + +import ( + "unsafe" + + "golang.org/x/sys/windows" +) + +// Re-implemented sModelSetup.h + +// enum struct eModelImplementation : uint32_t +type eModelImplementation uint32 + +const ( + // GPGPU implementation based on Direct3D 11.0 compute shaders + mi_GPU eModelImplementation = 1 + + // A hybrid implementation which uses DirectCompute for encode, and decodes on CPU + // Not implemented in the published builds of the DLL. To enable, change BUILD_HYBRID_VERSION macro to 1 + mi_Hybrid eModelImplementation = 2 + + // A reference implementation which uses the original GGML CPU-running code + // Not implemented in the published builds of the DLL. To enable, change BUILD_BOTH_VERSIONS macro to 1 + mi_Reference eModelImplementation = 3 +) + +// enum struct eGpuModelFlags : uint32_t +type eGpuModelFlags uint32 + +const ( + // Equivalent to Wave32 | NoReshapedMatMul on Intel and nVidia GPUs,
+ // and Wave64 | UseReshapedMatMul on AMD GPUs
+ gmf_None eGpuModelFlags = 0 + + // Use Wave32 version of compute shaders even on AMD GPUs + // Incompatible with + gmf_Wave32 eGpuModelFlags = 1 + + // Use Wave64 version of compute shaders even on nVidia and Intel GPUs + // Incompatible with + gmf_Wave64 eGpuModelFlags = 2 + + // Do not use reshaped matrix multiplication shaders on AMD GPUs + // Incompatible with + gmf_NoReshapedMatMul eGpuModelFlags = 4 + + // Use reshaped matrix multiplication shaders even on nVidia and Intel GPUs + // Incompatible with + gmf_UseReshapedMatMul eGpuModelFlags = 8 + + // Create GPU tensors in a way which allows sharing across D3D devices + gmf_Cloneable eGpuModelFlags = 0x10 +) + +// struct sModelSetup +type sModelSetup struct { + impl eModelImplementation + flags eGpuModelFlags + adapter string +} + +type _sModelSetup struct { + impl eModelImplementation + flags eGpuModelFlags + adapter uintptr +} + +func ModelSetup(flags eGpuModelFlags, GPU string) *sModelSetup { + this := sModelSetup{} + this.impl = mi_GPU + this.flags = flags + this.adapter = GPU + + return &this +} + +func (this *sModelSetup) isFlagSet(flag eGpuModelFlags) bool { + return (this.flags & flag) == 0 +} + +func (this *sModelSetup) AsCType() *_sModelSetup { + var err error + + ctype := _sModelSetup{} + ctype.impl = this.impl + ctype.flags = this.flags + ctype.adapter = 0 + + // Conver Go String to wchar_t, AKA UTF-16 + if this.adapter != "" { + var UTF16str *uint16 + UTF16str, err = windows.UTF16PtrFromString(this.adapter) + ctype.adapter = uintptr(unsafe.Pointer(UTF16str)) + } + + if err != nil { + return nil + } else { + return &ctype + } +} diff --git a/pkg/whisper/TranscribeResult.go b/pkg/whisper/TranscribeResult.go new file mode 100644 index 0000000..8ea0367 --- /dev/null +++ b/pkg/whisper/TranscribeResult.go @@ -0,0 +1,178 @@ +package whisper + +import ( + "C" + "errors" + "fmt" + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +type eTokenFlags uint32 + +const ( + TfNone eTokenFlags = 0 + TfSpecial = 1 +) + +type sTranscribeLength struct { + CountSegments uint32 + CountTokens uint32 +} + +type sTimeSpan struct { + + // The value is expressed in 100-nanoseconds ticks: compatible with System.Timespan, FILETIME, and many other things + Ticks uint64 + + /* + operator sTimeSpanFields() const + { + return sTimeSpanFields{ ticks }; + } + void operator=( uint64_t tt ) + { + ticks = tt; + } + void operator=( int64_t tt ) + { + assert( tt >= 0 ); + ticks = (uint64_t)tt; + } */ +} + +type sTimeInterval struct { + Begin sTimeSpan + End sTimeSpan +} + +type sSegment struct { + // Segment text, null-terminated, and probably UTF-8 encoded + text *C.char + + // Start and end times of the segment + Time sTimeInterval + + // These two integers define the slice of the tokens in this segment, in the array returned by iTranscribeResult.getTokens method + FirstToken uint32 + CountTokens uint32 +} + +func (this *sSegment) Text() string { + return C.GoString(this.text) +} + +type sSegmentArray []sSegment + +type SToken struct { + // Token text, null-terminated, and usually UTF-8 encoded. + // I think for Chinese language the models sometimes outputs invalid UTF8 strings here, Unicode code points can be split between adjacent tokens in the same segment + // More info: https://github.com/ggerganov/whisper.cpp/issues/399 + text *C.char + + // Start and end times of the token + Time sTimeInterval + // Probability of the token + Probability float32 + + // Probability of the timestamp token + ProbabilityTimestamp float32 + + // Sum of probabilities of all timestamp tokens + Ptsum float32 + + // Voice length of the token + Vlen float32 + + // Token id + Id int32 + + Flags eTokenFlags +} + +func (this *SToken) Text() string { + return C.GoString(this.text) +} + +type sTokenArray []SToken + +type iTranscribeResultVtbl struct { + QueryInterface uintptr + AddRef uintptr + Release uintptr + + getSize uintptr // ( sTranscribeLength& rdi ) HRESULT + getSegments uintptr // () getTokens + getTokens uintptr // () getToken* +} + +type ITranscribeResult struct { + lpVtbl *iTranscribeResultVtbl +} + +func (this *ITranscribeResult) AddRef() int32 { + ret, _, _ := syscall.Syscall( + this.lpVtbl.AddRef, + 1, + uintptr(unsafe.Pointer(this)), + 0, + 0) + return int32(ret) +} + +func (this *ITranscribeResult) Release() int32 { + ret, _, _ := syscall.Syscall( + this.lpVtbl.Release, + 1, + uintptr(unsafe.Pointer(this)), + 0, + 0) + return int32(ret) +} + +func (this *ITranscribeResult) GetSize() (*sTranscribeLength, error) { + + var result sTranscribeLength + + ret, _, _ := syscall.SyscallN( + this.lpVtbl.getSize, + uintptr(unsafe.Pointer(this)), + uintptr(unsafe.Pointer(&result)), + ) + + if windows.Handle(ret) != windows.S_OK { + fmt.Printf("iTranscribeResult.GetSize failed: %s\n", syscall.Errno(ret).Error()) + return nil, errors.New(syscall.Errno(ret).Error()) + } + + return &result, nil + +} + +func (this *ITranscribeResult) GetSegments(len uint32) []sSegment { + + ret, _, _ := syscall.SyscallN( + this.lpVtbl.getSegments, + uintptr(unsafe.Pointer(this)), + ) + + data := unsafe.Slice((*sSegment)(unsafe.Pointer(ret)), len) + + return data +} + +func (this *ITranscribeResult) GetTokens(len uint32) []SToken { + + ret, _, _ := syscall.SyscallN( + this.lpVtbl.getTokens, + uintptr(unsafe.Pointer(this)), + ) + + if unsafe.Pointer(ret) != nil { + return unsafe.Slice((*SToken)(unsafe.Pointer(ret)), len) + } else { + return []SToken{} + } +} diff --git a/pkg/whisper/context.go b/pkg/whisper/context.go new file mode 100644 index 0000000..8c469f6 --- /dev/null +++ b/pkg/whisper/context.go @@ -0,0 +1,281 @@ +package whisper + +import ( + "errors" + "fmt" + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +type uuid [16]byte + +type eResultFlags uint32 + +const ( + RfNone eResultFlags = 0 + + // Return individual tokens in addition to the segments + RfTokens = 1 + + // Return timestamps + RfTimestamps = 2 + + // Create a new COM object for the results. + // Without this flag, the context returns a pointer to the COM object stored in the context. + // The content of that object is replaced every time you call IContext.getResults method + RfNewObject = 0x100 +) + +type IContextVtbl struct { + QueryInterface uintptr + AddRef uintptr + Release uintptr + + RunFull uintptr + RunStreamed uintptr + RunCapture uintptr + GetResults uintptr + DetectSpeaker uintptr + GetModel uintptr + FullDefaultParams uintptr + TimingsPrint uintptr + TimingsReset uintptr +} + +type IContext struct { + lpVtbl *IContextVtbl +} + +//type sFullParams struct{} + +// type iAudioBuffer struct{} +type sProgressSink struct { + pfn uintptr + pv uintptr +} + +// type iAudioReader struct{} +type sCaptureCallbacks struct{} + +// type iAudioCapture struct{} +// type eResultFlags int32 +// type iTranscribeResult struct{} +// type sTimeInterval struct{} +type eSpeakerChannel int32 + +//type eSamplingStrategy int32 + +// Create a new IContext instance +func newIContext() *IContext { + return &IContext{ + lpVtbl: &IContextVtbl{ + QueryInterface: 0, + AddRef: 0, + Release: 0, + RunFull: 0, + RunStreamed: 0, + RunCapture: 0, + GetResults: 0, + DetectSpeaker: 0, + GetModel: 0, + FullDefaultParams: 0, + TimingsPrint: 0, + TimingsReset: 0, + }, + } +} + +func (context *IContext) TimingsPrint() error { + + // TimingsPrint(); + ret, _, _ := syscall.SyscallN( + context.lpVtbl.TimingsPrint, + uintptr(unsafe.Pointer(context)), + ) + + if windows.Handle(ret) != windows.S_OK { + fmt.Printf("RunFull failed: %s\n", syscall.Errno(ret).Error()) + return errors.New(syscall.Errno(ret).Error()) + } + + return nil +} + +// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text +// Uses the specified decoding strategy to obtain the text. +func (context *IContext) RunFull(params *FullParams, buffer *iAudioBuffer) error { + + // runFull( const sFullParams& params, const iAudioBuffer* buffer ); + ret, _, _ := syscall.SyscallN( + context.lpVtbl.RunFull, + uintptr(unsafe.Pointer(context)), + + uintptr(unsafe.Pointer(params.cStruct)), + uintptr(unsafe.Pointer(buffer)), + ) + + if windows.Handle(ret) != windows.S_OK { + fmt.Printf("RunFull failed: %s\n", syscall.Errno(ret).Error()) + return errors.New(syscall.Errno(ret).Error()) + } + + return nil +} + +func (context *IContext) RunStreamed(params *FullParams, reader *iAudioReader) error { + + cb := sProgressSink{} + + // runStreamed( const sFullParams& params, const sProgressSink& progress, const iAudioReader* reader ); + ret, _, _ := syscall.SyscallN( + context.lpVtbl.RunStreamed, + uintptr(unsafe.Pointer(context)), + uintptr(unsafe.Pointer(params.cStruct)), + uintptr(unsafe.Pointer(&cb)), // No progress cb yet + uintptr(unsafe.Pointer(reader)), + ) + + if windows.Handle(ret) != windows.S_OK { + fmt.Printf("RunStreamed failed: %s\n", syscall.Errno(ret).Error()) + return errors.New(syscall.Errno(ret).Error()) + } + + return nil +} + +func (this *IContext) AddRef() int32 { + ret, _, _ := syscall.Syscall( + this.lpVtbl.AddRef, + 1, + uintptr(unsafe.Pointer(this)), + 0, + 0) + return int32(ret) +} + +func (this *IContext) Release() int32 { + ret, _, _ := syscall.Syscall( + this.lpVtbl.Release, + 1, + uintptr(unsafe.Pointer(this)), + 0, + 0) + return int32(ret) +} + +/* +https://github.com/Const-me/Whisper/blob/f6f743c7b3570b85ccf47f74b84e06a73667ef3e/Whisper/Whisper/ContextImpl.misc.cpp + +Returns E_POINTER if null pointer provided in params +Initialises params to all 0 +sets values in struct, does not malloc +*/ +func (context *IContext) FullDefaultParams(strategy eSamplingStrategy) (*FullParams, error) { + + /* + ERR : unreadable Only part of a ReadProcessMemory or WriteProcessMemory request was completed + * not related to stratergy ... tested 0, 1 and 2 ... 2 produced E_INVALIDARG as expected + * not a nil ptr to params ... nil poitner produced E_POINTER as expected + * params seems to return 0x4000 + * !!!!! FullParams is not a com interface !!! + * so no lpVtbl *FullParamsVtbl , no queryinterface, addref etc + */ + + params := _newFullParams_cStruct() + //params := &[160]byte{} + + ret, _, _ := syscall.SyscallN( + context.lpVtbl.FullDefaultParams, + uintptr(unsafe.Pointer(context)), + uintptr(strategy), + uintptr(unsafe.Pointer(params)), + ) + + // nil ptr should be 0x80004003L + // unsafe.Pointer(0xc00011dc28) + // unsafe.Pointer(0x4000) + + if windows.Handle(ret) != windows.S_OK { + fmt.Printf("FullDefaultParams failed: %s\n", syscall.Errno(ret).Error()) + return nil, syscall.Errno(ret) + + } + + if params == nil { + return nil, errors.New("FullDefaultParams did not return params") + } + ParamObj := NewFullParams(params) + + if ParamObj.TestDefaultsOK() { + return ParamObj, nil + } + + return nil, nil +} + +func (context *IContext) GetModel() (*_IModel, error) { + + var modelptr *_IModel + + // getModel( iModel** pp ); + ret, _, _ := syscall.SyscallN( + context.lpVtbl.GetModel, + uintptr(unsafe.Pointer(context)), + uintptr(unsafe.Pointer(&modelptr)), + ) + + if windows.Handle(ret) != windows.S_OK { + fmt.Printf("FullDefaultParams failed: %s\n", syscall.Errno(ret).Error()) + return nil, syscall.Errno(ret) + } + + if modelptr == nil { + return nil, errors.New("loadModel did not return a Model") + } + + if modelptr.lpVtbl == nil { + return nil, errors.New("loadModel method table is nil") + } + + return modelptr, nil +} + +// ************************************************************************************************************************************************ +// Not really implemented / tested +// ************************************************************************************************************************************************ + +func (context *IContext) RunCapture(params *FullParams, callbacks *sCaptureCallbacks, reader *iAudioCapture) uintptr { + ret, _, _ := syscall.SyscallN( + context.lpVtbl.RunCapture, + //3, + uintptr(unsafe.Pointer(context)), + uintptr(unsafe.Pointer(params)), + uintptr(unsafe.Pointer(callbacks)), + uintptr(unsafe.Pointer(reader)), + ) + return ret +} + +func (context *IContext) GetResults(flags eResultFlags, pp **ITranscribeResult) uintptr { + ret, _, _ := syscall.Syscall( + context.lpVtbl.GetResults, + 3, + uintptr(unsafe.Pointer(context)), + uintptr(flags), + uintptr(unsafe.Pointer(pp)), + ) + return ret +} + +func (context *IContext) DetectSpeaker(time *sTimeInterval, result *eSpeakerChannel) uintptr { + ret, _, _ := syscall.Syscall( + context.lpVtbl.DetectSpeaker, + 3, + uintptr(unsafe.Pointer(context)), + uintptr(unsafe.Pointer(time)), + uintptr(unsafe.Pointer(result)), + ) + return ret +} diff --git a/pkg/whisper/doc.go b/pkg/whisper/doc.go new file mode 100644 index 0000000..dd1f9f9 --- /dev/null +++ b/pkg/whisper/doc.go @@ -0,0 +1,5 @@ +/* +github.com/jaybinks/goConstmeWhispers +Go Bindings for https://github.com/Const-me/Whisper +*/ +package whisper diff --git a/pkg/whisper/language.go b/pkg/whisper/language.go new file mode 100644 index 0000000..4e6a042 --- /dev/null +++ b/pkg/whisper/language.go @@ -0,0 +1,207 @@ +package whisper + +// https://github.com/Const-me/Whisper/blob/master/WhisperNet/API/eLanguage.cs + +type eLanguage int32 + +const ( + Auto eLanguage = -1 // "af" + + Afrikaans = 0x6661 // "af" + /// Albanian + Albanian = 0x7173 // "sq" + /// Amharic + Amharic = 0x6D61 // "am" + /// Arabic + Arabic = 0x7261 // "ar" + /// Armenian + Armenian = 0x7968 // "hy" + /// Assamese + Assamese = 0x7361 // "as" + /// Azerbaijani + Azerbaijani = 0x7A61 // "az" + /// Bashkir + Bashkir = 0x6162 // "ba" + /// Basque + Basque = 0x7565 // "eu" + /// Belarusian + Belarusian = 0x6562 // "be" + /// Bengali + Bengali = 0x6E62 // "bn" + /// Bosnian + Bosnian = 0x7362 // "bs" + /// Breton + Breton = 0x7262 // "br" + /// Bulgarian + Bulgarian = 0x6762 // "bg" + /// Catalan + Catalan = 0x6163 // "ca" + /// Chinese + Chinese = 0x687A // "zh" + /// Croatian + Croatian = 0x7268 // "hr" + /// Czech + Czech = 0x7363 // "cs" + /// Danish + Danish = 0x6164 // "da" + /// Dutch + Dutch = 0x6C6E // "nl" + /// English + English = 0x6E65 // "en" + /// Estonian + Estonian = 0x7465 // "et" + /// Faroese + Faroese = 0x6F66 // "fo" + /// Finnish + Finnish = 0x6966 // "fi" + /// French + French = 0x7266 // "fr" + /// Galician + Galician = 0x6C67 // "gl" + /// Georgian + Georgian = 0x616B // "ka" + /// German + German = 0x6564 // "de" + /// Greek + Greek = 0x6C65 // "el" + /// Gujarati + Gujarati = 0x7567 // "gu" + /// Haitian Creole + HaitianCreole = 0x7468 // "ht" + /// Hausa + Hausa = 0x6168 // "ha" + /// Hawaiian + Hawaiian = 0x776168 // "haw" + /// Hebrew + Hebrew = 0x7769 // "iw" + /// Hindi + Hindi = 0x6968 // "hi" + /// Hungarian + Hungarian = 0x7568 // "hu" + /// Icelandic + Icelandic = 0x7369 // "is" + /// Indonesian + Indonesian = 0x6469 // "id" + /// Italian + Italian = 0x7469 // "it" + /// Japanese + Japanese = 0x616A // "ja" + /// Javanese + Javanese = 0x776A // "jw" + /// Kannada + Kannada = 0x6E6B // "kn" + /// Kazakh + Kazakh = 0x6B6B // "kk" + /// Khmer + Khmer = 0x6D6B // "km" + /// Korean + Korean = 0x6F6B // "ko" + /// Lao + Lao = 0x6F6C // "lo" + /// Latin + Latin = 0x616C // "la" + /// Latvian + Latvian = 0x766C // "lv" + /// Lingala + Lingala = 0x6E6C // "ln" + /// Lithuanian + Lithuanian = 0x746C // "lt" + /// Luxembourgish + Luxembourgish = 0x626C // "lb" + /// Macedonian + Macedonian = 0x6B6D // "mk" + /// Malagasy + Malagasy = 0x676D // "mg" + /// Malay + Malay = 0x736D // "ms" + /// Malayalam + Malayalam = 0x6C6D // "ml" + /// Maltese + Maltese = 0x746D // "mt" + /// Maori + Maori = 0x696D // "mi" + /// Marathi + Marathi = 0x726D // "mr" + /// Mongolian + Mongolian = 0x6E6D // "mn" + /// Myanmar + Myanmar = 0x796D // "my" + /// Nepali + Nepali = 0x656E // "ne" + /// Norwegian + Norwegian = 0x6F6E // "no" + /// Nynorsk + Nynorsk = 0x6E6E // "nn" + /// Occitan + Occitan = 0x636F // "oc" + /// Pashto + Pashto = 0x7370 // "ps" + /// Persian + Persian = 0x6166 // "fa" + /// Polish + Polish = 0x6C70 // "pl" + /// Portuguese + Portuguese = 0x7470 // "pt" + /// Punjabi + Punjabi = 0x6170 // "pa" + /// Romanian + Romanian = 0x6F72 // "ro" + /// Russian + Russian = 0x7572 // "ru" + /// Sanskrit + Sanskrit = 0x6173 // "sa" + /// Serbian + Serbian = 0x7273 // "sr" + /// Shona + Shona = 0x6E73 // "sn" + /// Sindhi + Sindhi = 0x6473 // "sd" + /// Sinhala + Sinhala = 0x6973 // "si" + /// Slovak + Slovak = 0x6B73 // "sk" + /// Slovenian + Slovenian = 0x6C73 // "sl" + /// Somali + Somali = 0x6F73 // "so" + /// Spanish + Spanish = 0x7365 // "es" + /// Sundanese + Sundanese = 0x7573 // "su" + /// Swahili + Swahili = 0x7773 // "sw" + /// Swedish + Swedish = 0x7673 // "sv" + /// Tagalog + Tagalog = 0x6C74 // "tl" + /// Tajik + Tajik = 0x6774 // "tg" + /// Tamil + Tamil = 0x6174 // "ta" + /// Tatar + Tatar = 0x7474 // "tt" + /// Telugu + Telugu = 0x6574 // "te" + /// Thai + Thai = 0x6874 // "th" + /// Tibetan + Tibetan = 0x6F62 // "bo" + /// Turkish + Turkish = 0x7274 // "tr" + /// Turkmen + Turkmen = 0x6B74 // "tk" + /// Ukrainian + Ukrainian = 0x6B75 // "uk" + /// Urdu + Urdu = 0x7275 // "ur" + /// Uzbek + Uzbek = 0x7A75 // "uz" + /// Vietnamese + Vietnamese = 0x6976 // "vi" + /// Welsh + Welsh = 0x7963 // "cy" + /// Yiddish + Yiddish = 0x6979 // "yi" + /// Yoruba + Yoruba = 0x6F79 // "yo" +) diff --git a/pkg/whisper/logger.go b/pkg/whisper/logger.go new file mode 100644 index 0000000..5c9aacb --- /dev/null +++ b/pkg/whisper/logger.go @@ -0,0 +1,43 @@ +package whisper + +import ( + "C" + "fmt" +) + +/* + https://github.com/Const-me/Whisper/blob/843a2a6ca6ea47c5ac4889a281badfc808d0ea01/Whisper/API/loggerApi.h + +*/ + +type eLogLevel uint8 + +const ( + LlError eLogLevel = 0 + LlWarning = 1 + LlInfo = 2 + LlDebug = 3 +) + +type eLogFlags uint8 + +const ( + LfNone eLogFlags = 0 + LfUseStandardError = 1 + LfSkipFormatMessage = 2 +) + +type sLoggerSetup struct { + sink uintptr // pfnLoggerSink + context uintptr // void* + level eLogLevel // eLogLevel + flags eLogFlags // eLoggerFlags +} + +func fnLoggerSink(context uintptr, lvl eLogLevel, message *C.char) uintptr { + + strmessage := C.GoString(message) + fmt.Printf("%d - %s\n", lvl, strmessage) + + return 0 +} diff --git a/pkg/whisper/whisper.go b/pkg/whisper/whisper.go new file mode 100644 index 0000000..04b327f --- /dev/null +++ b/pkg/whisper/whisper.go @@ -0,0 +1,188 @@ +//go:build windows +// +build windows + +package whisper + +import ( + "C" + "errors" + "syscall" + "unsafe" + + // Using lxn/win because its COM functions expose raw HRESULTs + "golang.org/x/sys/windows" +) +import ( + "fmt" +) + +/* + eModelImplementation - TranscribeStructs.h + + // GPGPU implementation based on Direct3D 11.0 compute shaders + GPU = 1, + + // A hybrid implementation which uses DirectCompute for encode, and decodes on CPU + // Not implemented in the published builds of the DLL. To enable, change BUILD_HYBRID_VERSION macro to 1 + Hybrid = 2, + + // A reference implementation which uses the original GGML CPU-running code + // Not implemented in the published builds of the DLL. To enable, change BUILD_BOTH_VERSIONS macro to 1 + Reference = 3, +*/ + +// https://learn.microsoft.com/en-us/windows/win32/seccrypto/common-hresult-values +// https://pkg.go.dev/golang.org/x/sys/windows +const ( + E_INVALIDARG = 0x80070057 + ERROR_HV_CPUID_FEATURE_VALIDATION = 0xC0350038 + + DLLName = "whisper.dll" +) + +type Libwhisper struct { + dll *syscall.LazyDLL + ver WinVersion + existing_model map[string]*Model + + proc_setupLogger *syscall.LazyProc + proc_loadModel *syscall.LazyProc + proc_initMediaFoundation *syscall.LazyProc + // proc_findLanguageKeyW *syscall.LazyProc + // proc_findLanguageKeyA *syscall.LazyProc + // proc_getSupportedLanguages *syscall.LazyProc +} + +var singleton_whisper *Libwhisper = nil + +func New(level eLogLevel, flags eLogFlags, cb *any) (*Libwhisper, error) { + if singleton_whisper != nil { + return singleton_whisper, nil + } + + var err error + this := &Libwhisper{} + + this.ver, err = GetFileVersion(DLLName) + if err != nil { + return nil, err + } + + if this.ver.Major < 1 && this.ver.Minor < 9 { + return nil, errors.New("This library requires whisper.dll version 1.9 or higher.") // or less than 1.11 for now .. because the API changed + } + + this.dll = syscall.NewLazyDLL(DLLName) // Todo wrap this in a class, check file exists, handle errors ... you know, just a few things.. AKA Stop being lazy + + this.proc_setupLogger = this.dll.NewProc("setupLogger") + this.proc_loadModel = this.dll.NewProc("loadModel") + this.proc_initMediaFoundation = this.dll.NewProc("initMediaFoundation") + /* + this.proc_findLanguageKeyW = this.dll.NewProc("findLanguageKeyW") + this.proc_findLanguageKeyA = this.dll.NewProc("findLanguageKeyA") + this.proc_getSupportedLanguages = this.dll.NewProc("getSupportedLanguages") + */ + + ok, err := this._setupLogger(level, flags, cb) + if !ok { + return nil, errors.New("Logger Error : " + err.Error()) + } + + this.existing_model = make(map[string]*Model) + singleton_whisper = this + + return singleton_whisper, nil +} + +func (this *Libwhisper) Version() string { + return fmt.Sprintf("%d.%d.%d.%d.", this.ver.Major, this.ver.Minor, this.ver.Patch, this.ver.Build) +} + +func (this *Libwhisper) SupportsMultiThread() bool { + return this.ver.Major >= 1 && this.ver.Minor >= 10 +} + +func (this *Libwhisper) _setupLogger(level eLogLevel, flags eLogFlags, cb *any) (bool, error) { + + setup := sLoggerSetup{} + setup.sink = 0 + setup.context = 0 + setup.level = level + setup.flags = flags + + if cb != nil { + setup.sink = syscall.NewCallback(cb) + } + + res, _, err := this.proc_setupLogger.Call(uintptr(unsafe.Pointer(&setup))) + + if windows.Handle(res) == windows.S_OK { + return true, nil + } else { + return false, err + } +} + +func (this *Libwhisper) LoadModel(path string, aGPU ...string) (*Model, error) { + var modelptr *_IModel + + whisperpath, _ := windows.UTF16PtrFromString(path) + + GPU := "" + if len(aGPU) == 1 { + GPU = aGPU[0] + } + + setup := ModelSetup(gmf_Cloneable, GPU) + + // Construct our map hash + singleton_hash := GPU + "|" + path + if this.existing_model[singleton_hash] != nil { + ClonedModel, err := this.existing_model[singleton_hash].Clone() + if ClonedModel != nil { + return NewModel(setup, ClonedModel), nil + } else { + return nil, err + } + } + + obj, _, _ := this.proc_loadModel.Call(uintptr(unsafe.Pointer(whisperpath)), uintptr(unsafe.Pointer(setup.AsCType())), uintptr(unsafe.Pointer(nil)), uintptr(unsafe.Pointer(&modelptr))) + + if windows.Handle(obj) != windows.S_OK { + fmt.Printf("loadModel failed: %s\n", syscall.Errno(obj).Error()) + return nil, fmt.Errorf("loadModel failed: %s", syscall.Errno(obj)) + } + + if modelptr == nil { + return nil, errors.New("loadModel did not return a Model") + } + + if modelptr.lpVtbl == nil { + return nil, errors.New("loadModel method table is nil") + } + + model := NewModel(setup, modelptr) + + this.existing_model[singleton_hash] = model + + return model, nil +} + +func (this *Libwhisper) InitMediaFoundation() (*IMediaFoundation, error) { + + var mediafoundation *IMediaFoundation + + // initMediaFoundation( iMediaFoundation** pp ); + obj, _, _ := this.proc_initMediaFoundation.Call(uintptr(unsafe.Pointer(&mediafoundation))) + + if windows.Handle(obj) != windows.S_OK { + fmt.Printf("initMediaFoundation failed: %s\n", syscall.Errno(obj).Error()) + return nil, fmt.Errorf("initMediaFoundation failed: %s", syscall.Errno(obj)) + } + + if mediafoundation.lpVtbl == nil { + return nil, errors.New("initMediaFoundation method table is nil") + } + + return mediafoundation, nil +} diff --git a/pkg/whisper/winversion.go b/pkg/whisper/winversion.go new file mode 100644 index 0000000..c51bb1d --- /dev/null +++ b/pkg/whisper/winversion.go @@ -0,0 +1,123 @@ +// Copyright 2018 Keybase Inc. All rights reserved. +// Use of this source code is governed by a BSD +// license that can be found in the LICENSE file. +// Adapted mainly from github.com/gonutz/w32 + +//go:build windows +// +build windows + +package whisper + +import ( + "errors" + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +var ( + version = windows.NewLazySystemDLL("version.dll") + getFileVersionInfoSize = version.NewProc("GetFileVersionInfoSizeW") + getFileVersionInfo = version.NewProc("GetFileVersionInfoW") + verQueryValue = version.NewProc("VerQueryValueW") +) + +type VS_FIXEDFILEINFO struct { + Signature uint32 + StrucVersion uint32 + FileVersionMS uint32 + FileVersionLS uint32 + ProductVersionMS uint32 + ProductVersionLS uint32 + FileFlagsMask uint32 + FileFlags uint32 + FileOS uint32 + FileType uint32 + FileSubtype uint32 + FileDateMS uint32 + FileDateLS uint32 +} + +type WinVersion struct { + Major uint32 + Minor uint32 + Patch uint32 + Build uint32 +} + +// FileVersion concatenates FileVersionMS and FileVersionLS to a uint64 value. +func (fi VS_FIXEDFILEINFO) FileVersion() uint64 { + return uint64(fi.FileVersionMS)<<32 | uint64(fi.FileVersionLS) +} + +func GetFileVersionInfoSize(path string) uint32 { + ret, _, _ := getFileVersionInfoSize.Call( + uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr(path))), + 0, + ) + return uint32(ret) +} + +func GetFileVersionInfo(path string, data []byte) bool { + ret, _, _ := getFileVersionInfo.Call( + uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr(path))), + 0, + uintptr(len(data)), + uintptr(unsafe.Pointer(&data[0])), + ) + return ret != 0 +} + +// VerQueryValueRoot calls VerQueryValue +// (https://msdn.microsoft.com/en-us/library/windows/desktop/ms647464(v=vs.85).aspx) +// with `\` (root) to retieve the VS_FIXEDFILEINFO. +func VerQueryValueRoot(block []byte) (VS_FIXEDFILEINFO, error) { + var offset uintptr + var length uint + blockStart := unsafe.Pointer(&block[0]) + ret, _, _ := verQueryValue.Call( + uintptr(blockStart), + uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr(`\`))), + uintptr(unsafe.Pointer(&offset)), + uintptr(unsafe.Pointer(&length)), + ) + if ret == 0 { + return VS_FIXEDFILEINFO{}, errors.New("VerQueryValueRoot: verQueryValue failed") + } + start := int(offset) - int(uintptr(blockStart)) + end := start + int(length) + if start < 0 || start >= len(block) || end < start || end > len(block) { + return VS_FIXEDFILEINFO{}, errors.New("VerQueryValueRoot: find failed") + } + data := block[start:end] + info := *((*VS_FIXEDFILEINFO)(unsafe.Pointer(&data[0]))) + return info, nil +} + +func GetFileVersion(path string) (WinVersion, error) { + var result WinVersion + size := GetFileVersionInfoSize(path) + if size <= 0 { + return result, errors.New("GetFileVersionInfoSize failed") + } + + info := make([]byte, size) + ok := GetFileVersionInfo(path, info) + if !ok { + return result, errors.New("GetFileVersionInfo failed") + } + + fixed, err := VerQueryValueRoot(info) + if err != nil { + return result, err + } + version := fixed.FileVersion() + + result.Major = uint32(version & 0xFFFF000000000000 >> 48) + result.Minor = uint32(version & 0x0000FFFF00000000 >> 32) + result.Patch = uint32(version & 0x00000000FFFF0000 >> 16) + result.Build = uint32(version & 0x000000000000FFFF) + + return result, nil +} diff --git a/state.go b/state.go new file mode 100644 index 0000000..55e1d1c --- /dev/null +++ b/state.go @@ -0,0 +1,75 @@ +package main + +import ( + "fmt" + "sync" + + "github.com/xzeldon/whisper-api-server/pkg/whisper" +) + +type WhisperState struct { + model *whisper.Model + context *whisper.IContext + media *whisper.IMediaFoundation + params *whisper.FullParams + mutex sync.Mutex +} + +func InitializeWhisperState(modelPath string) (*WhisperState, error) { + lib, err := whisper.New(whisper.LlDebug, whisper.LfUseStandardError, nil) + if err != nil { + return nil, err + } + + model, err := lib.LoadModel(modelPath) + if err != nil { + return nil, err + } + + context, err := model.CreateContext() + if err != nil { + return nil, err + } + + media, err := lib.InitMediaFoundation() + if err != nil { + return nil, err + } + + params, err := context.FullDefaultParams(whisper.SsBeamSearch) + if err != nil { + return nil, err + } + + params.AddFlags(whisper.FlagNoContext) + params.AddFlags(whisper.FlagTokenTimestamps) + + fmt.Printf("Params CPU Threads : %d\n", params.CpuThreads()) + + return &WhisperState{ + model: model, + context: context, + media: media, + params: params, + }, nil +} + +func getResult(ctx *whisper.IContext) (string, error) { + results := &whisper.ITranscribeResult{} + ctx.GetResults(whisper.RfTokens|whisper.RfTimestamps, &results) + + length, err := results.GetSize() + if err != nil { + return "", err + } + + segments := results.GetSegments(length.CountSegments) + + var result string + + for _, seg := range segments { + result += seg.Text() + } + + return result, nil +} diff --git a/tmp/.gitkeep b/tmp/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/utils.go b/utils.go new file mode 100644 index 0000000..e7f868f --- /dev/null +++ b/utils.go @@ -0,0 +1,48 @@ +package main + +import ( + "io" + "os" + "path/filepath" + "strings" + "time" + + "github.com/labstack/echo/v4" +) + +func saveFormFile(name string, c echo.Context) (string, error) { + file, err := c.FormFile(name) + if err != nil { + return "", err + } + + src, err := file.Open() + if err != nil { + return "", err + } + defer src.Close() + + ext := filepath.Ext(file.Filename) + filename := time.Now().Format(time.RFC3339) + filename = "./tmp/" + sanitizeFilename(filename) + ext + + dst, err := os.Create(filename) + if err != nil { + return "", err + } + defer dst.Close() + + if _, err = io.Copy(dst, src); err != nil { + return "", err + } + + return filename, nil +} + +func sanitizeFilename(filename string) string { + invalidChars := []string{`\`, `/`, `:`, `*`, `?`, `"`, `<`, `>`, `|`} + for _, char := range invalidChars { + filename = strings.ReplaceAll(filename, char, "-") + } + return filename +}