initial commit

This commit is contained in:
Timofey Gelazoniya 2023-10-04 01:09:38 +03:00
commit 8e4cb72b50
Signed by: zeldon
GPG Key ID: 047886915281DD2A
21 changed files with 2045 additions and 0 deletions

7
.gitignore vendored Normal file
View File

@ -0,0 +1,7 @@
Whisper.dll
ggml-medium.bin
whisper-api-server.exe
tmp/*
!tmp/.gitkeep

21
LICENSE Normal file
View File

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2023 Timofey Gelazoniya
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

43
README.md Normal file
View File

@ -0,0 +1,43 @@
# Whisper API Server (Go)
## ⚠️ This project is a work in progress (WIP).
This API server enables audio transcription using the OpenAI Whisper models.
# Setup
- Download the desired model from [huggingface](https://huggingface.co/ggerganov/whisper.cpp/tree/main)
- Update the model path in the `main.go` file
- Download `Whisper.dll` from [github](https://github.com/Const-me/Whisper/releases/tag/1.12.0) (`Library.zip`) and place it in the project's root directory
- Build project: `go build .` (you only need go compiler, without gcc)
# Usage example
Make a request to the server using the following command:
```sh
curl http://localhost:3000/v1/audio/transcriptions \
-H "Content-Type: multipart/form-data" \
-F file="@/path/to/file/audio.mp3" \
```
Receive a response in JSON format:
```json
{
"text": "Imagine the wildest idea that you've ever had, and you're curious about how it might scale to something that's a 100, a 1,000 times bigger. This is a place where you can get to do that."
}
```
# Roadmap
- [ ] Implement automatic model downloading from [huggingface](https://huggingface.co/ggerganov/whisper.cpp/tree/main)
- [ ] Implement automatic `Whisper.dll` downloading from [Guthub releases](https://github.com/Const-me/Whisper/releases)
- [ ] Provide prebuilt binaries for Windows
- [ ] Include instructions for running on Linux with Wine (likely possible).
# Credits
- [Const-me/Whisper](https://github.com/Const-me/Whisper) project
- [goConstmeWhisper](https://github.com/jaybinks/goConstmeWhisper) for the remarkable Go bindings for [Const-me/Whisper](https://github.com/Const-me/Whisper)
- [Georgi Gerganov](https://github.com/ggerganov) for GGML models

19
go.mod Normal file
View File

@ -0,0 +1,19 @@
module github.com/xzeldon/whisper-api-server
go 1.21.1
require (
github.com/labstack/echo/v4 v4.11.1
golang.org/x/sys v0.12.0
)
require (
github.com/labstack/gommon v0.4.0
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.19 // indirect
github.com/valyala/bytebufferpool v1.0.0 // indirect
github.com/valyala/fasttemplate v1.2.2 // indirect
golang.org/x/crypto v0.11.0 // indirect
golang.org/x/net v0.12.0 // indirect
golang.org/x/text v0.11.0 // indirect
)

43
go.sum Normal file
View File

@ -0,0 +1,43 @@
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/labstack/echo/v4 v4.11.1 h1:dEpLU2FLg4UVmvCGPuk/APjlH6GDpbEPti61srUUUs4=
github.com/labstack/echo/v4 v4.11.1/go.mod h1:YuYRTSM3CHs2ybfrL8Px48bO6BAnYIN4l8wSTMP6BDQ=
github.com/labstack/gommon v0.4.0 h1:y7cvthEAEbU0yHOf4axH8ZG2NH8knB9iNSoTO8dyIk8=
github.com/labstack/gommon v0.4.0/go.mod h1:uW6kP17uPlLJsD3ijUYn3/M5bAxtlZhMI6m3MFxTMTM=
github.com/mattn/go-colorable v0.1.11/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
github.com/valyala/fasttemplate v1.2.1/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ=
github.com/valyala/fasttemplate v1.2.2 h1:lxLXG0uE3Qnshl9QyaK6XJxMXlQZELvChBOCmQD0Loo=
github.com/valyala/fasttemplate v1.2.2/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ=
golang.org/x/crypto v0.11.0 h1:6Ewdq3tDic1mg5xRO4milcWCfMVQhI4NkqWWvqejpuA=
golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio=
golang.org/x/net v0.12.0 h1:cfawfvKITfUsFCeJIHJrbSxpeu/E81khclypR0GVT50=
golang.org/x/net v0.12.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211103235746-7861aae1554b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/text v0.11.0 h1:LAntKIrcmeSKERyiOh0XMV39LXS8IE9UL2yP7+f5ij4=
golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

45
handler.go Normal file
View File

@ -0,0 +1,45 @@
package main
import (
"net/http"
"strings"
"github.com/labstack/echo/v4"
)
type TranscribeResponse struct {
Text string `json:"text"`
}
func transcribe(c echo.Context, whisperState *WhisperState) error {
audioPath, err := saveFormFile("file", c)
if err != nil {
c.Logger().Errorf("Error reading file: %s", err)
return err
}
whisperState.mutex.Lock()
buffer, err := whisperState.media.LoadAudioFile(audioPath, true)
if err != nil {
c.Logger().Errorf("Error loading audio file data: %s", err)
}
err = whisperState.context.RunFull(whisperState.params, buffer)
result, err := getResult(whisperState.context)
if err != nil {
c.Logger().Error(err)
}
defer whisperState.mutex.Unlock()
if len(result) == 0 {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": "Internal server error"})
}
response := TranscribeResponse{
Text: strings.TrimLeft(result, " "),
}
return c.JSON(http.StatusOK, response)
}

29
main.go Normal file
View File

@ -0,0 +1,29 @@
package main
import (
"github.com/labstack/echo/v4"
"github.com/labstack/gommon/log"
)
const MODEL_PATH = "./ggml-medium.bin"
func main() {
e := echo.New()
e.HideBanner = true
if l, ok := e.Logger.(*log.Logger); ok {
l.SetHeader("${time_rfc3339} ${level}")
}
whisperState, err := InitializeWhisperState(MODEL_PATH)
if err != nil {
e.Logger.Error(err)
return
}
e.POST("/v1/audio/transcriptions", func(c echo.Context) error {
return transcribe(c, whisperState)
})
e.Logger.Fatal(e.Start(":3000"))
}

248
pkg/whisper/FullParams.go Normal file
View File

@ -0,0 +1,248 @@
package whisper
import (
"syscall"
"unsafe"
)
// https://github.com/Const-me/Whisper/blob/master/Whisper/API/sFullParams.h
// https://github.com/Const-me/Whisper/blob/master/WhisperNet/API/Parameters.cs
type eSamplingStrategy uint32
const (
SsGreedy eSamplingStrategy = iota
SsBeamSearch
SsINVALIDARG
)
type eFullParamsFlags uint32
const (
FlagNone eFullParamsFlags = 0
FlagTranslate = 1 << 0
FlagNoContext = 1 << 1
FlagSingleSegment = 1 << 2
FlagPrintSpecial = 1 << 3
FlagPrintProgress = 1 << 4
FlagPrintRealtime = 1 << 5
FlagPrintTimestamps = 1 << 6
FlagTokenTimestamps = 1 << 7 // Experimental
FlagSpeedupAudio = 1 << 8
)
type EWhisperHWND uintptr
const (
S_OK EWhisperHWND = 0
S_FALSE EWhisperHWND = 1
)
type FullParams struct {
cStruct *_FullParams
}
func (this *FullParams) CpuThreads() int32 {
if this == nil {
return 0
} else if this.cStruct == nil {
return 0
}
return this.cStruct.cpuThreads
}
func (this *FullParams) setCpuThreads(val int32) {
if this == nil {
return
} else if this.cStruct == nil {
return
}
this.cStruct.cpuThreads = val
}
func (this *FullParams) SetMaxTextCTX(val int32) {
this.cStruct.n_max_text_ctx = val
}
func (this *FullParams) AddFlags(newflag eFullParamsFlags) {
if this == nil {
return
} else if this.cStruct == nil {
return
}
this.cStruct.Flags = this.cStruct.Flags | newflag
}
func (this *FullParams) RemoveFlags(newflag eFullParamsFlags) {
if this == nil {
return
} else if this.cStruct == nil {
return
}
this.cStruct.Flags = this.cStruct.Flags ^ newflag
}
/*using pfnNewSegment = HRESULT( __cdecl* )( iContext* ctx, uint32_t n_new, void* user_data ) noexcept;*/
type NewSegmentCallback_Type func(context *IContext, n_new uint32, user_data unsafe.Pointer) EWhisperHWND
func (this *FullParams) SetNewSegmentCallback(cb NewSegmentCallback_Type) {
if this == nil {
return
} else if this.cStruct == nil {
return
}
this.cStruct.new_segment_callback = syscall.NewCallback(cb)
}
/*
Return S_OK to proceed, or S_FALSE to stop the process
*/
type EncoderBeginCallback_Type func(context *IContext, user_data unsafe.Pointer) EWhisperHWND
func (this *FullParams) SetEncoderBeginCallback(cb EncoderBeginCallback_Type) {
if this == nil {
return
} else if this.cStruct == nil {
return
}
this.cStruct.encoder_begin_callback = syscall.NewCallback(cb)
}
func (this *FullParams) TestDefaultsOK() bool {
if this == nil {
return false
} else if this.cStruct == nil {
return false
}
if this.cStruct.n_max_text_ctx != 16384 {
return false
}
if this.cStruct.Flags != (FlagPrintProgress | FlagPrintTimestamps) {
return false
}
if this.cStruct.thold_pt != 0.01 {
return false
}
if this.cStruct.thold_ptsum != 0.01 {
return false
}
if this.cStruct.Language != English {
return false
}
// Todo ... why do these not line up as expected.. is our struct out of alignment ?
/*
if this.cStruct.strategy == ssGreedy {
if this.cStruct.beam_search.n_past != -1 ||
this.cStruct.beam_search.beam_width != -1 ||
this.cStruct.beam_search.n_best != -1 {
return false
}
} else if this.cStruct.strategy == ssBeamSearch {
if this.cStruct.greedy.n_past != -1 ||
this.cStruct.beam_search.beam_width != 10 ||
this.cStruct.beam_search.n_best != 5 {
return false
}
}
*/
return true
}
type _FullParams struct {
strategy eSamplingStrategy
cpuThreads int32
n_max_text_ctx int32
offset_ms int32
duration_ms int32
Flags eFullParamsFlags
Language eLanguage
thold_pt float32
thold_ptsum float32
max_len int32
max_tokens int32
greedy struct{ n_past int32 }
beam_search struct {
n_past int32
beam_width int32
n_best int32
}
audio_ctx int32 // overwrite the audio context size (0 = use default)
prompt_tokens uintptr
prompt_n_tokens int32
new_segment_callback uintptr
new_segment_callback_user_data uintptr
encoder_begin_callback uintptr
encoder_begin_callback_user_data uintptr
// Are these needed ?? Jay
// setFlag uintptr
}
func NewFullParams(cstruct *_FullParams) *FullParams {
this := FullParams{}
this.cStruct = cstruct
return &this
}
func _newFullParams_cStruct() *_FullParams {
return &_FullParams{
strategy: 0,
cpuThreads: 0,
n_max_text_ctx: 0,
offset_ms: 0,
duration_ms: 0,
Flags: 0,
Language: 0,
thold_pt: 0,
thold_ptsum: 0,
max_len: 0,
max_tokens: 0,
// anonymous int32
greedy: struct{ n_past int32 }{n_past: 0},
// anonymous struct
beam_search: struct {
n_past int32
beam_width int32
n_best int32
}{
n_past: 0,
beam_width: 0,
n_best: 0,
},
audio_ctx: 0,
prompt_tokens: 0,
prompt_n_tokens: 0,
new_segment_callback: 0,
new_segment_callback_user_data: 0,
encoder_begin_callback: 0,
encoder_begin_callback_user_data: 0,
}
}

View File

@ -0,0 +1,222 @@
package whisper
import (
"errors"
"fmt"
"syscall"
"unsafe"
"golang.org/x/sys/windows"
)
// https://github.com/Const-me/Whisper/blob/843a2a6ca6ea47c5ac4889a281badfc808d0ea01/Whisper/API/IMediaFoundation.h
type IMediaFoundation struct {
lpVtbl *IMediaFoundationVtbl
}
type IMediaFoundationVtbl struct {
QueryInterface uintptr
AddRef uintptr
Release uintptr
loadAudioFile uintptr // ( LPCTSTR path, bool stereo, iAudioBuffer** pp ) const;
openAudioFile uintptr // ( LPCTSTR path, bool stereo, iAudioReader** pp );
loadAudioFileData uintptr // ( const void* data, uint64_t size, bool stereo, iAudioReader** pp ); HRESULT
listCaptureDevices uintptr // ( pfnFoundCaptureDevices pfn, void* pv );
openCaptureDevice uintptr // ( LPCTSTR endpoint, const sCaptureParams& captureParams, iAudioCapture** pp );
}
func (this *IMediaFoundation) AddRef() int32 {
ret, _, _ := syscall.Syscall(
this.lpVtbl.AddRef,
1,
uintptr(unsafe.Pointer(this)),
0,
0)
return int32(ret)
}
func (this *IMediaFoundation) Release() int32 {
ret, _, _ := syscall.Syscall(
this.lpVtbl.Release,
1,
uintptr(unsafe.Pointer(this)),
0,
0)
return int32(ret)
}
// ( LPCTSTR path, bool stereo, iAudioBuffer** pp ) const;
func (this *IMediaFoundation) LoadAudioFile(file string, stereo bool) (*iAudioBuffer, error) {
var buffer *iAudioBuffer
UTFFileName, _ := windows.UTF16PtrFromString(file)
ret, _, _ := syscall.SyscallN(
this.lpVtbl.loadAudioFile,
uintptr(unsafe.Pointer(this)),
uintptr(unsafe.Pointer(UTFFileName)),
uintptr(1), // Todo ... Stereo !
uintptr(unsafe.Pointer(&buffer)))
if windows.Handle(ret) != windows.S_OK {
fmt.Printf("loadAudioFile failed: %s\n", syscall.Errno(ret).Error())
return nil, syscall.Errno(ret)
}
return buffer, nil
}
func (this *IMediaFoundation) OpenAudioFile(file string, stereo bool) (*iAudioReader, error) {
var buffer *iAudioReader
UTFFileName, _ := windows.UTF16PtrFromString(file)
ret, _, _ := syscall.SyscallN(
this.lpVtbl.openAudioFile,
uintptr(unsafe.Pointer(this)),
uintptr(unsafe.Pointer(UTFFileName)),
uintptr(1), // Todo ... Stereo !
uintptr(unsafe.Pointer(&buffer)))
if windows.Handle(ret) != windows.S_OK {
fmt.Printf("openAudioFile failed: %s\n", syscall.Errno(ret).Error())
return nil, syscall.Errno(ret)
}
return buffer, nil
}
func (this *IMediaFoundation) LoadAudioFileData(inbuffer *[]byte, stereo bool) (*iAudioReader, error) {
var reader *iAudioReader
// loadAudioFileData( const void* data, uint64_t size, bool stereo, iAudioReader** pp );
ret, _, _ := syscall.SyscallN(
this.lpVtbl.loadAudioFileData,
uintptr(unsafe.Pointer(this)),
uintptr(unsafe.Pointer(&(*inbuffer)[0])),
uintptr(uint64(len(*inbuffer))),
uintptr(1), // Todo ... Stereo !
uintptr(unsafe.Pointer(&reader)))
if windows.Handle(ret) != windows.S_OK {
fmt.Printf("LoadAudioFileData failed: %s\n", syscall.Errno(ret).Error())
return nil, syscall.Errno(ret)
}
return reader, nil
}
// ************************************************************
type iAudioBuffer struct {
lpVtbl *iAudioBufferVtbl
}
type iAudioBufferVtbl struct {
QueryInterface uintptr
AddRef uintptr
Release uintptr
countSamples uintptr // returns uint32_t
getPcmMono uintptr // returns float*
getPcmStereo uintptr // returns float*
getTime uintptr // ( int64_t& rdi )
}
func (this *iAudioBuffer) AddRef() int32 {
ret, _, _ := syscall.SyscallN(
this.lpVtbl.AddRef,
uintptr(unsafe.Pointer(this)),
)
return int32(ret)
}
func (this *iAudioBuffer) Release() int32 {
ret, _, _ := syscall.SyscallN(
this.lpVtbl.Release,
uintptr(unsafe.Pointer(this)),
)
return int32(ret)
}
func (this *iAudioBuffer) CountSamples() (uint32, error) {
ret, _, err := syscall.SyscallN(
this.lpVtbl.countSamples,
uintptr(unsafe.Pointer(this)),
)
if err != 0 {
return 0, errors.New(err.Error())
}
return uint32(ret), nil
}
// ************************************************************
type iAudioReader struct {
lpVtbl *iAudioReaderVtbl
}
type iAudioReaderVtbl struct {
QueryInterface uintptr
AddRef uintptr
Release uintptr
getDuration uintptr // ( int64_t& rdi )
getReader uintptr // ( IMFSourceReader** pp )
requestedStereo uintptr // ()
}
func (this *iAudioReader) AddRef() int32 {
ret, _, _ := syscall.SyscallN(
this.lpVtbl.AddRef,
uintptr(unsafe.Pointer(this)),
)
return int32(ret)
}
func (this *iAudioReader) Release() int32 {
ret, _, _ := syscall.SyscallN(
this.lpVtbl.Release,
uintptr(unsafe.Pointer(this)),
)
return int32(ret)
}
func (this *iAudioReader) GetDuration() (uint64, error) {
var rdi int64
ret, _, _ := syscall.SyscallN(
this.lpVtbl.getDuration,
uintptr(unsafe.Pointer(this)),
uintptr(unsafe.Pointer(&rdi)),
)
if windows.Handle(ret) != windows.S_OK {
fmt.Printf("LoadAudioFileData failed: %s\n", syscall.Errno(ret).Error())
return 0, syscall.Errno(ret)
}
return uint64(rdi), nil
}
// ************************************************************
type iAudioCapture struct {
lpVtbl *iAudioCaptureVtbl
}
type iAudioCaptureVtbl struct {
QueryInterface uintptr
AddRef uintptr
Release uintptr
getReader uintptr // ( IMFSourceReader** pp )
getParams uintptr // returns sCaptureParams&
}

119
pkg/whisper/Model.go Normal file
View File

@ -0,0 +1,119 @@
package whisper
import (
"errors"
"fmt"
"syscall"
"unsafe"
"golang.org/x/sys/windows"
)
// External - Go version of the struct
type Model struct {
cStruct *_IModel
setup *sModelSetup
}
// Internal - C Version of the structs
type _IModel struct {
lpVtbl *IModelVtbl
}
// https://github.com/Const-me/Whisper/blob/master/Whisper/API/iContext.cl.h
type IModelVtbl struct {
QueryInterface uintptr
AddRef uintptr
Release uintptr
createContext uintptr //( iContext** pp ) = 0;
tokenize uintptr /* HRESULT __stdcall tokenize( const char* text, pfnDecodedTokens pfn, void* pv ); */
isMultilingual uintptr //() = 0;
getSpecialTokens uintptr //( SpecialTokens& rdi ) = 0;
stringFromToken uintptr //( whisper_token token ) = 0;
clone uintptr //( iModel** rdi ) = 0;
}
func NewModel(setup *sModelSetup, cstruct *_IModel) *Model {
this := Model{}
this.setup = setup
this.cStruct = cstruct
return &this
}
func (this *Model) AddRef() int32 {
ret, _, _ := syscall.Syscall(
this.cStruct.lpVtbl.AddRef,
1,
uintptr(unsafe.Pointer(this.cStruct)),
0,
0)
return int32(ret)
}
func (this *Model) Release() int32 {
ret, _, _ := syscall.Syscall(
this.cStruct.lpVtbl.Release,
1,
uintptr(unsafe.Pointer(this.cStruct)),
0,
0)
return int32(ret)
}
func (this *Model) CreateContext() (*IContext, error) {
var context *IContext
/*
ret, _, err := syscall.Syscall(
this.cStruct.lpVtbl.createContext,
2, // Why was this 1, rather than 2 ?? 1 seemed to work fine
uintptr(unsafe.Pointer(this.cStruct)),
uintptr(unsafe.Pointer(&context)),
0)*/
ret, _, err := syscall.SyscallN(
this.cStruct.lpVtbl.createContext,
uintptr(unsafe.Pointer(this.cStruct)),
uintptr(unsafe.Pointer(&context)))
if windows.Handle(ret) != windows.S_OK {
fmt.Printf("createContext failed: %w", err.Error())
}
if windows.Handle(ret) != windows.S_OK {
return nil, fmt.Errorf("loadModel failed: %w", err)
}
return context, nil
}
func (this *Model) IsMultilingual() bool {
ret, _, _ := syscall.SyscallN(
this.cStruct.lpVtbl.isMultilingual,
uintptr(unsafe.Pointer(this.cStruct)),
)
return bool(windows.Handle(ret) == windows.S_OK)
}
func (this *Model) Clone() (*_IModel, error) {
if this.setup.isFlagSet(gmf_Cloneable) {
return nil, errors.New("Model is not cloneable")
}
//this.Cloneable ?
var modelptr *_IModel
ret, _, _ := syscall.SyscallN(
this.cStruct.lpVtbl.clone,
uintptr(unsafe.Pointer(this.cStruct)),
uintptr(unsafe.Pointer(&modelptr)),
)
if windows.Handle(ret) == windows.S_OK {
return modelptr, nil
} else {
return nil, errors.New("Model.Clone() failed : " + syscall.Errno(ret).Error())
}
}

101
pkg/whisper/ModelSetup.go Normal file
View File

@ -0,0 +1,101 @@
package whisper
import (
"unsafe"
"golang.org/x/sys/windows"
)
// Re-implemented sModelSetup.h
// enum struct eModelImplementation : uint32_t
type eModelImplementation uint32
const (
// GPGPU implementation based on Direct3D 11.0 compute shaders
mi_GPU eModelImplementation = 1
// A hybrid implementation which uses DirectCompute for encode, and decodes on CPU
// Not implemented in the published builds of the DLL. To enable, change BUILD_HYBRID_VERSION macro to 1
mi_Hybrid eModelImplementation = 2
// A reference implementation which uses the original GGML CPU-running code
// Not implemented in the published builds of the DLL. To enable, change BUILD_BOTH_VERSIONS macro to 1
mi_Reference eModelImplementation = 3
)
// enum struct eGpuModelFlags : uint32_t
type eGpuModelFlags uint32
const (
// <summary>Equivalent to <c>Wave32 | NoReshapedMatMul</c> on Intel and nVidia GPUs,<br/>
// and <c>Wave64 | UseReshapedMatMul</c> on AMD GPUs</summary>
gmf_None eGpuModelFlags = 0
// <summary>Use Wave32 version of compute shaders even on AMD GPUs</summary>
// <remarks>Incompatible with <see cref="Wave64" /></remarks>
gmf_Wave32 eGpuModelFlags = 1
// <summary>Use Wave64 version of compute shaders even on nVidia and Intel GPUs</summary>
// <remarks>Incompatible with <see cref="Wave32" /></remarks>
gmf_Wave64 eGpuModelFlags = 2
// <summary>Do not use reshaped matrix multiplication shaders on AMD GPUs</summary>
// <remarks>Incompatible with <see cref="UseReshapedMatMul" /></remarks>
gmf_NoReshapedMatMul eGpuModelFlags = 4
// <summary>Use reshaped matrix multiplication shaders even on nVidia and Intel GPUs</summary>
// <remarks>Incompatible with <see cref="NoReshapedMatMul" /></remarks>
gmf_UseReshapedMatMul eGpuModelFlags = 8
// <summary>Create GPU tensors in a way which allows sharing across D3D devices</summary>
gmf_Cloneable eGpuModelFlags = 0x10
)
// struct sModelSetup
type sModelSetup struct {
impl eModelImplementation
flags eGpuModelFlags
adapter string
}
type _sModelSetup struct {
impl eModelImplementation
flags eGpuModelFlags
adapter uintptr
}
func ModelSetup(flags eGpuModelFlags, GPU string) *sModelSetup {
this := sModelSetup{}
this.impl = mi_GPU
this.flags = flags
this.adapter = GPU
return &this
}
func (this *sModelSetup) isFlagSet(flag eGpuModelFlags) bool {
return (this.flags & flag) == 0
}
func (this *sModelSetup) AsCType() *_sModelSetup {
var err error
ctype := _sModelSetup{}
ctype.impl = this.impl
ctype.flags = this.flags
ctype.adapter = 0
// Conver Go String to wchar_t, AKA UTF-16
if this.adapter != "" {
var UTF16str *uint16
UTF16str, err = windows.UTF16PtrFromString(this.adapter)
ctype.adapter = uintptr(unsafe.Pointer(UTF16str))
}
if err != nil {
return nil
} else {
return &ctype
}
}

View File

@ -0,0 +1,178 @@
package whisper
import (
"C"
"errors"
"fmt"
"syscall"
"unsafe"
"golang.org/x/sys/windows"
)
type eTokenFlags uint32
const (
TfNone eTokenFlags = 0
TfSpecial = 1
)
type sTranscribeLength struct {
CountSegments uint32
CountTokens uint32
}
type sTimeSpan struct {
// The value is expressed in 100-nanoseconds ticks: compatible with System.Timespan, FILETIME, and many other things
Ticks uint64
/*
operator sTimeSpanFields() const
{
return sTimeSpanFields{ ticks };
}
void operator=( uint64_t tt )
{
ticks = tt;
}
void operator=( int64_t tt )
{
assert( tt >= 0 );
ticks = (uint64_t)tt;
} */
}
type sTimeInterval struct {
Begin sTimeSpan
End sTimeSpan
}
type sSegment struct {
// Segment text, null-terminated, and probably UTF-8 encoded
text *C.char
// Start and end times of the segment
Time sTimeInterval
// These two integers define the slice of the tokens in this segment, in the array returned by iTranscribeResult.getTokens method
FirstToken uint32
CountTokens uint32
}
func (this *sSegment) Text() string {
return C.GoString(this.text)
}
type sSegmentArray []sSegment
type SToken struct {
// Token text, null-terminated, and usually UTF-8 encoded.
// I think for Chinese language the models sometimes outputs invalid UTF8 strings here, Unicode code points can be split between adjacent tokens in the same segment
// More info: https://github.com/ggerganov/whisper.cpp/issues/399
text *C.char
// Start and end times of the token
Time sTimeInterval
// Probability of the token
Probability float32
// Probability of the timestamp token
ProbabilityTimestamp float32
// Sum of probabilities of all timestamp tokens
Ptsum float32
// Voice length of the token
Vlen float32
// Token id
Id int32
Flags eTokenFlags
}
func (this *SToken) Text() string {
return C.GoString(this.text)
}
type sTokenArray []SToken
type iTranscribeResultVtbl struct {
QueryInterface uintptr
AddRef uintptr
Release uintptr
getSize uintptr // ( sTranscribeLength& rdi ) HRESULT
getSegments uintptr // () getTokens
getTokens uintptr // () getToken*
}
type ITranscribeResult struct {
lpVtbl *iTranscribeResultVtbl
}
func (this *ITranscribeResult) AddRef() int32 {
ret, _, _ := syscall.Syscall(
this.lpVtbl.AddRef,
1,
uintptr(unsafe.Pointer(this)),
0,
0)
return int32(ret)
}
func (this *ITranscribeResult) Release() int32 {
ret, _, _ := syscall.Syscall(
this.lpVtbl.Release,
1,
uintptr(unsafe.Pointer(this)),
0,
0)
return int32(ret)
}
func (this *ITranscribeResult) GetSize() (*sTranscribeLength, error) {
var result sTranscribeLength
ret, _, _ := syscall.SyscallN(
this.lpVtbl.getSize,
uintptr(unsafe.Pointer(this)),
uintptr(unsafe.Pointer(&result)),
)
if windows.Handle(ret) != windows.S_OK {
fmt.Printf("iTranscribeResult.GetSize failed: %s\n", syscall.Errno(ret).Error())
return nil, errors.New(syscall.Errno(ret).Error())
}
return &result, nil
}
func (this *ITranscribeResult) GetSegments(len uint32) []sSegment {
ret, _, _ := syscall.SyscallN(
this.lpVtbl.getSegments,
uintptr(unsafe.Pointer(this)),
)
data := unsafe.Slice((*sSegment)(unsafe.Pointer(ret)), len)
return data
}
func (this *ITranscribeResult) GetTokens(len uint32) []SToken {
ret, _, _ := syscall.SyscallN(
this.lpVtbl.getTokens,
uintptr(unsafe.Pointer(this)),
)
if unsafe.Pointer(ret) != nil {
return unsafe.Slice((*SToken)(unsafe.Pointer(ret)), len)
} else {
return []SToken{}
}
}

281
pkg/whisper/context.go Normal file
View File

@ -0,0 +1,281 @@
package whisper
import (
"errors"
"fmt"
"syscall"
"unsafe"
"golang.org/x/sys/windows"
)
type uuid [16]byte
type eResultFlags uint32
const (
RfNone eResultFlags = 0
// Return individual tokens in addition to the segments
RfTokens = 1
// Return timestamps
RfTimestamps = 2
// Create a new COM object for the results.
// Without this flag, the context returns a pointer to the COM object stored in the context.
// The content of that object is replaced every time you call IContext.getResults method
RfNewObject = 0x100
)
type IContextVtbl struct {
QueryInterface uintptr
AddRef uintptr
Release uintptr
RunFull uintptr
RunStreamed uintptr
RunCapture uintptr
GetResults uintptr
DetectSpeaker uintptr
GetModel uintptr
FullDefaultParams uintptr
TimingsPrint uintptr
TimingsReset uintptr
}
type IContext struct {
lpVtbl *IContextVtbl
}
//type sFullParams struct{}
// type iAudioBuffer struct{}
type sProgressSink struct {
pfn uintptr
pv uintptr
}
// type iAudioReader struct{}
type sCaptureCallbacks struct{}
// type iAudioCapture struct{}
// type eResultFlags int32
// type iTranscribeResult struct{}
// type sTimeInterval struct{}
type eSpeakerChannel int32
//type eSamplingStrategy int32
// Create a new IContext instance
func newIContext() *IContext {
return &IContext{
lpVtbl: &IContextVtbl{
QueryInterface: 0,
AddRef: 0,
Release: 0,
RunFull: 0,
RunStreamed: 0,
RunCapture: 0,
GetResults: 0,
DetectSpeaker: 0,
GetModel: 0,
FullDefaultParams: 0,
TimingsPrint: 0,
TimingsReset: 0,
},
}
}
func (context *IContext) TimingsPrint() error {
// TimingsPrint();
ret, _, _ := syscall.SyscallN(
context.lpVtbl.TimingsPrint,
uintptr(unsafe.Pointer(context)),
)
if windows.Handle(ret) != windows.S_OK {
fmt.Printf("RunFull failed: %s\n", syscall.Errno(ret).Error())
return errors.New(syscall.Errno(ret).Error())
}
return nil
}
// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
// Uses the specified decoding strategy to obtain the text.
func (context *IContext) RunFull(params *FullParams, buffer *iAudioBuffer) error {
// runFull( const sFullParams& params, const iAudioBuffer* buffer );
ret, _, _ := syscall.SyscallN(
context.lpVtbl.RunFull,
uintptr(unsafe.Pointer(context)),
uintptr(unsafe.Pointer(params.cStruct)),
uintptr(unsafe.Pointer(buffer)),
)
if windows.Handle(ret) != windows.S_OK {
fmt.Printf("RunFull failed: %s\n", syscall.Errno(ret).Error())
return errors.New(syscall.Errno(ret).Error())
}
return nil
}
func (context *IContext) RunStreamed(params *FullParams, reader *iAudioReader) error {
cb := sProgressSink{}
// runStreamed( const sFullParams& params, const sProgressSink& progress, const iAudioReader* reader );
ret, _, _ := syscall.SyscallN(
context.lpVtbl.RunStreamed,
uintptr(unsafe.Pointer(context)),
uintptr(unsafe.Pointer(params.cStruct)),
uintptr(unsafe.Pointer(&cb)), // No progress cb yet
uintptr(unsafe.Pointer(reader)),
)
if windows.Handle(ret) != windows.S_OK {
fmt.Printf("RunStreamed failed: %s\n", syscall.Errno(ret).Error())
return errors.New(syscall.Errno(ret).Error())
}
return nil
}
func (this *IContext) AddRef() int32 {
ret, _, _ := syscall.Syscall(
this.lpVtbl.AddRef,
1,
uintptr(unsafe.Pointer(this)),
0,
0)
return int32(ret)
}
func (this *IContext) Release() int32 {
ret, _, _ := syscall.Syscall(
this.lpVtbl.Release,
1,
uintptr(unsafe.Pointer(this)),
0,
0)
return int32(ret)
}
/*
https://github.com/Const-me/Whisper/blob/f6f743c7b3570b85ccf47f74b84e06a73667ef3e/Whisper/Whisper/ContextImpl.misc.cpp
Returns E_POINTER if null pointer provided in params
Initialises params to all 0
sets values in struct, does not malloc
*/
func (context *IContext) FullDefaultParams(strategy eSamplingStrategy) (*FullParams, error) {
/*
ERR : unreadable Only part of a ReadProcessMemory or WriteProcessMemory request was completed
* not related to stratergy ... tested 0, 1 and 2 ... 2 produced E_INVALIDARG as expected
* not a nil ptr to params ... nil poitner produced E_POINTER as expected
* params seems to return 0x4000
* !!!!! FullParams is not a com interface !!!
* so no lpVtbl *FullParamsVtbl , no queryinterface, addref etc
*/
params := _newFullParams_cStruct()
//params := &[160]byte{}
ret, _, _ := syscall.SyscallN(
context.lpVtbl.FullDefaultParams,
uintptr(unsafe.Pointer(context)),
uintptr(strategy),
uintptr(unsafe.Pointer(params)),
)
// nil ptr should be 0x80004003L
// unsafe.Pointer(0xc00011dc28)
// unsafe.Pointer(0x4000)
if windows.Handle(ret) != windows.S_OK {
fmt.Printf("FullDefaultParams failed: %s\n", syscall.Errno(ret).Error())
return nil, syscall.Errno(ret)
}
if params == nil {
return nil, errors.New("FullDefaultParams did not return params")
}
ParamObj := NewFullParams(params)
if ParamObj.TestDefaultsOK() {
return ParamObj, nil
}
return nil, nil
}
func (context *IContext) GetModel() (*_IModel, error) {
var modelptr *_IModel
// getModel( iModel** pp );
ret, _, _ := syscall.SyscallN(
context.lpVtbl.GetModel,
uintptr(unsafe.Pointer(context)),
uintptr(unsafe.Pointer(&modelptr)),
)
if windows.Handle(ret) != windows.S_OK {
fmt.Printf("FullDefaultParams failed: %s\n", syscall.Errno(ret).Error())
return nil, syscall.Errno(ret)
}
if modelptr == nil {
return nil, errors.New("loadModel did not return a Model")
}
if modelptr.lpVtbl == nil {
return nil, errors.New("loadModel method table is nil")
}
return modelptr, nil
}
// ************************************************************************************************************************************************
// Not really implemented / tested
// ************************************************************************************************************************************************
func (context *IContext) RunCapture(params *FullParams, callbacks *sCaptureCallbacks, reader *iAudioCapture) uintptr {
ret, _, _ := syscall.SyscallN(
context.lpVtbl.RunCapture,
//3,
uintptr(unsafe.Pointer(context)),
uintptr(unsafe.Pointer(params)),
uintptr(unsafe.Pointer(callbacks)),
uintptr(unsafe.Pointer(reader)),
)
return ret
}
func (context *IContext) GetResults(flags eResultFlags, pp **ITranscribeResult) uintptr {
ret, _, _ := syscall.Syscall(
context.lpVtbl.GetResults,
3,
uintptr(unsafe.Pointer(context)),
uintptr(flags),
uintptr(unsafe.Pointer(pp)),
)
return ret
}
func (context *IContext) DetectSpeaker(time *sTimeInterval, result *eSpeakerChannel) uintptr {
ret, _, _ := syscall.Syscall(
context.lpVtbl.DetectSpeaker,
3,
uintptr(unsafe.Pointer(context)),
uintptr(unsafe.Pointer(time)),
uintptr(unsafe.Pointer(result)),
)
return ret
}

5
pkg/whisper/doc.go Normal file
View File

@ -0,0 +1,5 @@
/*
github.com/jaybinks/goConstmeWhispers
Go Bindings for https://github.com/Const-me/Whisper
*/
package whisper

207
pkg/whisper/language.go Normal file
View File

@ -0,0 +1,207 @@
package whisper
// https://github.com/Const-me/Whisper/blob/master/WhisperNet/API/eLanguage.cs
type eLanguage int32
const (
Auto eLanguage = -1 // "af"
Afrikaans = 0x6661 // "af"
/// <summary>Albanian</summary>
Albanian = 0x7173 // "sq"
/// <summary>Amharic</summary>
Amharic = 0x6D61 // "am"
/// <summary>Arabic</summary>
Arabic = 0x7261 // "ar"
/// <summary>Armenian</summary>
Armenian = 0x7968 // "hy"
/// <summary>Assamese</summary>
Assamese = 0x7361 // "as"
/// <summary>Azerbaijani</summary>
Azerbaijani = 0x7A61 // "az"
/// <summary>Bashkir</summary>
Bashkir = 0x6162 // "ba"
/// <summary>Basque</summary>
Basque = 0x7565 // "eu"
/// <summary>Belarusian</summary>
Belarusian = 0x6562 // "be"
/// <summary>Bengali</summary>
Bengali = 0x6E62 // "bn"
/// <summary>Bosnian</summary>
Bosnian = 0x7362 // "bs"
/// <summary>Breton</summary>
Breton = 0x7262 // "br"
/// <summary>Bulgarian</summary>
Bulgarian = 0x6762 // "bg"
/// <summary>Catalan</summary>
Catalan = 0x6163 // "ca"
/// <summary>Chinese</summary>
Chinese = 0x687A // "zh"
/// <summary>Croatian</summary>
Croatian = 0x7268 // "hr"
/// <summary>Czech</summary>
Czech = 0x7363 // "cs"
/// <summary>Danish</summary>
Danish = 0x6164 // "da"
/// <summary>Dutch</summary>
Dutch = 0x6C6E // "nl"
/// <summary>English</summary>
English = 0x6E65 // "en"
/// <summary>Estonian</summary>
Estonian = 0x7465 // "et"
/// <summary>Faroese</summary>
Faroese = 0x6F66 // "fo"
/// <summary>Finnish</summary>
Finnish = 0x6966 // "fi"
/// <summary>French</summary>
French = 0x7266 // "fr"
/// <summary>Galician</summary>
Galician = 0x6C67 // "gl"
/// <summary>Georgian</summary>
Georgian = 0x616B // "ka"
/// <summary>German</summary>
German = 0x6564 // "de"
/// <summary>Greek</summary>
Greek = 0x6C65 // "el"
/// <summary>Gujarati</summary>
Gujarati = 0x7567 // "gu"
/// <summary>Haitian Creole</summary>
HaitianCreole = 0x7468 // "ht"
/// <summary>Hausa</summary>
Hausa = 0x6168 // "ha"
/// <summary>Hawaiian</summary>
Hawaiian = 0x776168 // "haw"
/// <summary>Hebrew</summary>
Hebrew = 0x7769 // "iw"
/// <summary>Hindi</summary>
Hindi = 0x6968 // "hi"
/// <summary>Hungarian</summary>
Hungarian = 0x7568 // "hu"
/// <summary>Icelandic</summary>
Icelandic = 0x7369 // "is"
/// <summary>Indonesian</summary>
Indonesian = 0x6469 // "id"
/// <summary>Italian</summary>
Italian = 0x7469 // "it"
/// <summary>Japanese</summary>
Japanese = 0x616A // "ja"
/// <summary>Javanese</summary>
Javanese = 0x776A // "jw"
/// <summary>Kannada</summary>
Kannada = 0x6E6B // "kn"
/// <summary>Kazakh</summary>
Kazakh = 0x6B6B // "kk"
/// <summary>Khmer</summary>
Khmer = 0x6D6B // "km"
/// <summary>Korean</summary>
Korean = 0x6F6B // "ko"
/// <summary>Lao</summary>
Lao = 0x6F6C // "lo"
/// <summary>Latin</summary>
Latin = 0x616C // "la"
/// <summary>Latvian</summary>
Latvian = 0x766C // "lv"
/// <summary>Lingala</summary>
Lingala = 0x6E6C // "ln"
/// <summary>Lithuanian</summary>
Lithuanian = 0x746C // "lt"
/// <summary>Luxembourgish</summary>
Luxembourgish = 0x626C // "lb"
/// <summary>Macedonian</summary>
Macedonian = 0x6B6D // "mk"
/// <summary>Malagasy</summary>
Malagasy = 0x676D // "mg"
/// <summary>Malay</summary>
Malay = 0x736D // "ms"
/// <summary>Malayalam</summary>
Malayalam = 0x6C6D // "ml"
/// <summary>Maltese</summary>
Maltese = 0x746D // "mt"
/// <summary>Maori</summary>
Maori = 0x696D // "mi"
/// <summary>Marathi</summary>
Marathi = 0x726D // "mr"
/// <summary>Mongolian</summary>
Mongolian = 0x6E6D // "mn"
/// <summary>Myanmar</summary>
Myanmar = 0x796D // "my"
/// <summary>Nepali</summary>
Nepali = 0x656E // "ne"
/// <summary>Norwegian</summary>
Norwegian = 0x6F6E // "no"
/// <summary>Nynorsk</summary>
Nynorsk = 0x6E6E // "nn"
/// <summary>Occitan</summary>
Occitan = 0x636F // "oc"
/// <summary>Pashto</summary>
Pashto = 0x7370 // "ps"
/// <summary>Persian</summary>
Persian = 0x6166 // "fa"
/// <summary>Polish</summary>
Polish = 0x6C70 // "pl"
/// <summary>Portuguese</summary>
Portuguese = 0x7470 // "pt"
/// <summary>Punjabi</summary>
Punjabi = 0x6170 // "pa"
/// <summary>Romanian</summary>
Romanian = 0x6F72 // "ro"
/// <summary>Russian</summary>
Russian = 0x7572 // "ru"
/// <summary>Sanskrit</summary>
Sanskrit = 0x6173 // "sa"
/// <summary>Serbian</summary>
Serbian = 0x7273 // "sr"
/// <summary>Shona</summary>
Shona = 0x6E73 // "sn"
/// <summary>Sindhi</summary>
Sindhi = 0x6473 // "sd"
/// <summary>Sinhala</summary>
Sinhala = 0x6973 // "si"
/// <summary>Slovak</summary>
Slovak = 0x6B73 // "sk"
/// <summary>Slovenian</summary>
Slovenian = 0x6C73 // "sl"
/// <summary>Somali</summary>
Somali = 0x6F73 // "so"
/// <summary>Spanish</summary>
Spanish = 0x7365 // "es"
/// <summary>Sundanese</summary>
Sundanese = 0x7573 // "su"
/// <summary>Swahili</summary>
Swahili = 0x7773 // "sw"
/// <summary>Swedish</summary>
Swedish = 0x7673 // "sv"
/// <summary>Tagalog</summary>
Tagalog = 0x6C74 // "tl"
/// <summary>Tajik</summary>
Tajik = 0x6774 // "tg"
/// <summary>Tamil</summary>
Tamil = 0x6174 // "ta"
/// <summary>Tatar</summary>
Tatar = 0x7474 // "tt"
/// <summary>Telugu</summary>
Telugu = 0x6574 // "te"
/// <summary>Thai</summary>
Thai = 0x6874 // "th"
/// <summary>Tibetan</summary>
Tibetan = 0x6F62 // "bo"
/// <summary>Turkish</summary>
Turkish = 0x7274 // "tr"
/// <summary>Turkmen</summary>
Turkmen = 0x6B74 // "tk"
/// <summary>Ukrainian</summary>
Ukrainian = 0x6B75 // "uk"
/// <summary>Urdu</summary>
Urdu = 0x7275 // "ur"
/// <summary>Uzbek</summary>
Uzbek = 0x7A75 // "uz"
/// <summary>Vietnamese</summary>
Vietnamese = 0x6976 // "vi"
/// <summary>Welsh</summary>
Welsh = 0x7963 // "cy"
/// <summary>Yiddish</summary>
Yiddish = 0x6979 // "yi"
/// <summary>Yoruba</summary>
Yoruba = 0x6F79 // "yo"
)

43
pkg/whisper/logger.go Normal file
View File

@ -0,0 +1,43 @@
package whisper
import (
"C"
"fmt"
)
/*
https://github.com/Const-me/Whisper/blob/843a2a6ca6ea47c5ac4889a281badfc808d0ea01/Whisper/API/loggerApi.h
*/
type eLogLevel uint8
const (
LlError eLogLevel = 0
LlWarning = 1
LlInfo = 2
LlDebug = 3
)
type eLogFlags uint8
const (
LfNone eLogFlags = 0
LfUseStandardError = 1
LfSkipFormatMessage = 2
)
type sLoggerSetup struct {
sink uintptr // pfnLoggerSink
context uintptr // void*
level eLogLevel // eLogLevel
flags eLogFlags // eLoggerFlags
}
func fnLoggerSink(context uintptr, lvl eLogLevel, message *C.char) uintptr {
strmessage := C.GoString(message)
fmt.Printf("%d - %s\n", lvl, strmessage)
return 0
}

188
pkg/whisper/whisper.go Normal file
View File

@ -0,0 +1,188 @@
//go:build windows
// +build windows
package whisper
import (
"C"
"errors"
"syscall"
"unsafe"
// Using lxn/win because its COM functions expose raw HRESULTs
"golang.org/x/sys/windows"
)
import (
"fmt"
)
/*
eModelImplementation - TranscribeStructs.h
// GPGPU implementation based on Direct3D 11.0 compute shaders
GPU = 1,
// A hybrid implementation which uses DirectCompute for encode, and decodes on CPU
// Not implemented in the published builds of the DLL. To enable, change BUILD_HYBRID_VERSION macro to 1
Hybrid = 2,
// A reference implementation which uses the original GGML CPU-running code
// Not implemented in the published builds of the DLL. To enable, change BUILD_BOTH_VERSIONS macro to 1
Reference = 3,
*/
// https://learn.microsoft.com/en-us/windows/win32/seccrypto/common-hresult-values
// https://pkg.go.dev/golang.org/x/sys/windows
const (
E_INVALIDARG = 0x80070057
ERROR_HV_CPUID_FEATURE_VALIDATION = 0xC0350038
DLLName = "whisper.dll"
)
type Libwhisper struct {
dll *syscall.LazyDLL
ver WinVersion
existing_model map[string]*Model
proc_setupLogger *syscall.LazyProc
proc_loadModel *syscall.LazyProc
proc_initMediaFoundation *syscall.LazyProc
// proc_findLanguageKeyW *syscall.LazyProc
// proc_findLanguageKeyA *syscall.LazyProc
// proc_getSupportedLanguages *syscall.LazyProc
}
var singleton_whisper *Libwhisper = nil
func New(level eLogLevel, flags eLogFlags, cb *any) (*Libwhisper, error) {
if singleton_whisper != nil {
return singleton_whisper, nil
}
var err error
this := &Libwhisper{}
this.ver, err = GetFileVersion(DLLName)
if err != nil {
return nil, err
}
if this.ver.Major < 1 && this.ver.Minor < 9 {
return nil, errors.New("This library requires whisper.dll version 1.9 or higher.") // or less than 1.11 for now .. because the API changed
}
this.dll = syscall.NewLazyDLL(DLLName) // Todo wrap this in a class, check file exists, handle errors ... you know, just a few things.. AKA Stop being lazy
this.proc_setupLogger = this.dll.NewProc("setupLogger")
this.proc_loadModel = this.dll.NewProc("loadModel")
this.proc_initMediaFoundation = this.dll.NewProc("initMediaFoundation")
/*
this.proc_findLanguageKeyW = this.dll.NewProc("findLanguageKeyW")
this.proc_findLanguageKeyA = this.dll.NewProc("findLanguageKeyA")
this.proc_getSupportedLanguages = this.dll.NewProc("getSupportedLanguages")
*/
ok, err := this._setupLogger(level, flags, cb)
if !ok {
return nil, errors.New("Logger Error : " + err.Error())
}
this.existing_model = make(map[string]*Model)
singleton_whisper = this
return singleton_whisper, nil
}
func (this *Libwhisper) Version() string {
return fmt.Sprintf("%d.%d.%d.%d.", this.ver.Major, this.ver.Minor, this.ver.Patch, this.ver.Build)
}
func (this *Libwhisper) SupportsMultiThread() bool {
return this.ver.Major >= 1 && this.ver.Minor >= 10
}
func (this *Libwhisper) _setupLogger(level eLogLevel, flags eLogFlags, cb *any) (bool, error) {
setup := sLoggerSetup{}
setup.sink = 0
setup.context = 0
setup.level = level
setup.flags = flags
if cb != nil {
setup.sink = syscall.NewCallback(cb)
}
res, _, err := this.proc_setupLogger.Call(uintptr(unsafe.Pointer(&setup)))
if windows.Handle(res) == windows.S_OK {
return true, nil
} else {
return false, err
}
}
func (this *Libwhisper) LoadModel(path string, aGPU ...string) (*Model, error) {
var modelptr *_IModel
whisperpath, _ := windows.UTF16PtrFromString(path)
GPU := ""
if len(aGPU) == 1 {
GPU = aGPU[0]
}
setup := ModelSetup(gmf_Cloneable, GPU)
// Construct our map hash
singleton_hash := GPU + "|" + path
if this.existing_model[singleton_hash] != nil {
ClonedModel, err := this.existing_model[singleton_hash].Clone()
if ClonedModel != nil {
return NewModel(setup, ClonedModel), nil
} else {
return nil, err
}
}
obj, _, _ := this.proc_loadModel.Call(uintptr(unsafe.Pointer(whisperpath)), uintptr(unsafe.Pointer(setup.AsCType())), uintptr(unsafe.Pointer(nil)), uintptr(unsafe.Pointer(&modelptr)))
if windows.Handle(obj) != windows.S_OK {
fmt.Printf("loadModel failed: %s\n", syscall.Errno(obj).Error())
return nil, fmt.Errorf("loadModel failed: %s", syscall.Errno(obj))
}
if modelptr == nil {
return nil, errors.New("loadModel did not return a Model")
}
if modelptr.lpVtbl == nil {
return nil, errors.New("loadModel method table is nil")
}
model := NewModel(setup, modelptr)
this.existing_model[singleton_hash] = model
return model, nil
}
func (this *Libwhisper) InitMediaFoundation() (*IMediaFoundation, error) {
var mediafoundation *IMediaFoundation
// initMediaFoundation( iMediaFoundation** pp );
obj, _, _ := this.proc_initMediaFoundation.Call(uintptr(unsafe.Pointer(&mediafoundation)))
if windows.Handle(obj) != windows.S_OK {
fmt.Printf("initMediaFoundation failed: %s\n", syscall.Errno(obj).Error())
return nil, fmt.Errorf("initMediaFoundation failed: %s", syscall.Errno(obj))
}
if mediafoundation.lpVtbl == nil {
return nil, errors.New("initMediaFoundation method table is nil")
}
return mediafoundation, nil
}

123
pkg/whisper/winversion.go Normal file
View File

@ -0,0 +1,123 @@
// Copyright 2018 Keybase Inc. All rights reserved.
// Use of this source code is governed by a BSD
// license that can be found in the LICENSE file.
// Adapted mainly from github.com/gonutz/w32
//go:build windows
// +build windows
package whisper
import (
"errors"
"syscall"
"unsafe"
"golang.org/x/sys/windows"
)
var (
version = windows.NewLazySystemDLL("version.dll")
getFileVersionInfoSize = version.NewProc("GetFileVersionInfoSizeW")
getFileVersionInfo = version.NewProc("GetFileVersionInfoW")
verQueryValue = version.NewProc("VerQueryValueW")
)
type VS_FIXEDFILEINFO struct {
Signature uint32
StrucVersion uint32
FileVersionMS uint32
FileVersionLS uint32
ProductVersionMS uint32
ProductVersionLS uint32
FileFlagsMask uint32
FileFlags uint32
FileOS uint32
FileType uint32
FileSubtype uint32
FileDateMS uint32
FileDateLS uint32
}
type WinVersion struct {
Major uint32
Minor uint32
Patch uint32
Build uint32
}
// FileVersion concatenates FileVersionMS and FileVersionLS to a uint64 value.
func (fi VS_FIXEDFILEINFO) FileVersion() uint64 {
return uint64(fi.FileVersionMS)<<32 | uint64(fi.FileVersionLS)
}
func GetFileVersionInfoSize(path string) uint32 {
ret, _, _ := getFileVersionInfoSize.Call(
uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr(path))),
0,
)
return uint32(ret)
}
func GetFileVersionInfo(path string, data []byte) bool {
ret, _, _ := getFileVersionInfo.Call(
uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr(path))),
0,
uintptr(len(data)),
uintptr(unsafe.Pointer(&data[0])),
)
return ret != 0
}
// VerQueryValueRoot calls VerQueryValue
// (https://msdn.microsoft.com/en-us/library/windows/desktop/ms647464(v=vs.85).aspx)
// with `\` (root) to retieve the VS_FIXEDFILEINFO.
func VerQueryValueRoot(block []byte) (VS_FIXEDFILEINFO, error) {
var offset uintptr
var length uint
blockStart := unsafe.Pointer(&block[0])
ret, _, _ := verQueryValue.Call(
uintptr(blockStart),
uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr(`\`))),
uintptr(unsafe.Pointer(&offset)),
uintptr(unsafe.Pointer(&length)),
)
if ret == 0 {
return VS_FIXEDFILEINFO{}, errors.New("VerQueryValueRoot: verQueryValue failed")
}
start := int(offset) - int(uintptr(blockStart))
end := start + int(length)
if start < 0 || start >= len(block) || end < start || end > len(block) {
return VS_FIXEDFILEINFO{}, errors.New("VerQueryValueRoot: find failed")
}
data := block[start:end]
info := *((*VS_FIXEDFILEINFO)(unsafe.Pointer(&data[0])))
return info, nil
}
func GetFileVersion(path string) (WinVersion, error) {
var result WinVersion
size := GetFileVersionInfoSize(path)
if size <= 0 {
return result, errors.New("GetFileVersionInfoSize failed")
}
info := make([]byte, size)
ok := GetFileVersionInfo(path, info)
if !ok {
return result, errors.New("GetFileVersionInfo failed")
}
fixed, err := VerQueryValueRoot(info)
if err != nil {
return result, err
}
version := fixed.FileVersion()
result.Major = uint32(version & 0xFFFF000000000000 >> 48)
result.Minor = uint32(version & 0x0000FFFF00000000 >> 32)
result.Patch = uint32(version & 0x00000000FFFF0000 >> 16)
result.Build = uint32(version & 0x000000000000FFFF)
return result, nil
}

75
state.go Normal file
View File

@ -0,0 +1,75 @@
package main
import (
"fmt"
"sync"
"github.com/xzeldon/whisper-api-server/pkg/whisper"
)
type WhisperState struct {
model *whisper.Model
context *whisper.IContext
media *whisper.IMediaFoundation
params *whisper.FullParams
mutex sync.Mutex
}
func InitializeWhisperState(modelPath string) (*WhisperState, error) {
lib, err := whisper.New(whisper.LlDebug, whisper.LfUseStandardError, nil)
if err != nil {
return nil, err
}
model, err := lib.LoadModel(modelPath)
if err != nil {
return nil, err
}
context, err := model.CreateContext()
if err != nil {
return nil, err
}
media, err := lib.InitMediaFoundation()
if err != nil {
return nil, err
}
params, err := context.FullDefaultParams(whisper.SsBeamSearch)
if err != nil {
return nil, err
}
params.AddFlags(whisper.FlagNoContext)
params.AddFlags(whisper.FlagTokenTimestamps)
fmt.Printf("Params CPU Threads : %d\n", params.CpuThreads())
return &WhisperState{
model: model,
context: context,
media: media,
params: params,
}, nil
}
func getResult(ctx *whisper.IContext) (string, error) {
results := &whisper.ITranscribeResult{}
ctx.GetResults(whisper.RfTokens|whisper.RfTimestamps, &results)
length, err := results.GetSize()
if err != nil {
return "", err
}
segments := results.GetSegments(length.CountSegments)
var result string
for _, seg := range segments {
result += seg.Text()
}
return result, nil
}

0
tmp/.gitkeep Normal file
View File

48
utils.go Normal file
View File

@ -0,0 +1,48 @@
package main
import (
"io"
"os"
"path/filepath"
"strings"
"time"
"github.com/labstack/echo/v4"
)
func saveFormFile(name string, c echo.Context) (string, error) {
file, err := c.FormFile(name)
if err != nil {
return "", err
}
src, err := file.Open()
if err != nil {
return "", err
}
defer src.Close()
ext := filepath.Ext(file.Filename)
filename := time.Now().Format(time.RFC3339)
filename = "./tmp/" + sanitizeFilename(filename) + ext
dst, err := os.Create(filename)
if err != nil {
return "", err
}
defer dst.Close()
if _, err = io.Copy(dst, src); err != nil {
return "", err
}
return filename, nil
}
func sanitizeFilename(filename string) string {
invalidChars := []string{`\`, `/`, `:`, `*`, `?`, `"`, `<`, `>`, `|`}
for _, char := range invalidChars {
filename = strings.ReplaceAll(filename, char, "-")
}
return filename
}