mirror of
https://github.com/xzeldon/whisper-api-server.git
synced 2024-12-25 07:05:48 +00:00
initial commit
This commit is contained in:
commit
8e4cb72b50
7
.gitignore
vendored
Normal file
7
.gitignore
vendored
Normal file
@ -0,0 +1,7 @@
|
||||
Whisper.dll
|
||||
ggml-medium.bin
|
||||
|
||||
whisper-api-server.exe
|
||||
|
||||
tmp/*
|
||||
!tmp/.gitkeep
|
21
LICENSE
Normal file
21
LICENSE
Normal file
@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2023 Timofey Gelazoniya
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
43
README.md
Normal file
43
README.md
Normal file
@ -0,0 +1,43 @@
|
||||
# Whisper API Server (Go)
|
||||
|
||||
## ⚠️ This project is a work in progress (WIP).
|
||||
|
||||
This API server enables audio transcription using the OpenAI Whisper models.
|
||||
|
||||
# Setup
|
||||
|
||||
- Download the desired model from [huggingface](https://huggingface.co/ggerganov/whisper.cpp/tree/main)
|
||||
- Update the model path in the `main.go` file
|
||||
- Download `Whisper.dll` from [github](https://github.com/Const-me/Whisper/releases/tag/1.12.0) (`Library.zip`) and place it in the project's root directory
|
||||
- Build project: `go build .` (you only need go compiler, without gcc)
|
||||
|
||||
# Usage example
|
||||
|
||||
Make a request to the server using the following command:
|
||||
|
||||
```sh
|
||||
curl http://localhost:3000/v1/audio/transcriptions \
|
||||
-H "Content-Type: multipart/form-data" \
|
||||
-F file="@/path/to/file/audio.mp3"
|
||||
```
|
||||
|
||||
Receive a response in JSON format:
|
||||
|
||||
```json
|
||||
{
|
||||
"text": "Imagine the wildest idea that you've ever had, and you're curious about how it might scale to something that's a 100, a 1,000 times bigger. This is a place where you can get to do that."
|
||||
}
|
||||
```
|
||||
|
||||
# Roadmap
|
||||
|
||||
- [ ] Implement automatic model downloading from [huggingface](https://huggingface.co/ggerganov/whisper.cpp/tree/main)
|
||||
- [ ] Implement automatic `Whisper.dll` downloading from [GitHub releases](https://github.com/Const-me/Whisper/releases)
|
||||
- [ ] Provide prebuilt binaries for Windows
|
||||
- [ ] Include instructions for running on Linux with Wine (likely possible).
|
||||
|
||||
# Credits
|
||||
|
||||
- [Const-me/Whisper](https://github.com/Const-me/Whisper) project
|
||||
- [goConstmeWhisper](https://github.com/jaybinks/goConstmeWhisper) for the remarkable Go bindings for [Const-me/Whisper](https://github.com/Const-me/Whisper)
|
||||
- [Georgi Gerganov](https://github.com/ggerganov) for GGML models
|
19
go.mod
Normal file
19
go.mod
Normal file
@ -0,0 +1,19 @@
|
||||
module github.com/xzeldon/whisper-api-server
|
||||
|
||||
go 1.21.1
|
||||
|
||||
require (
|
||||
github.com/labstack/echo/v4 v4.11.1
|
||||
golang.org/x/sys v0.12.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/labstack/gommon v0.4.0
|
||||
github.com/mattn/go-colorable v0.1.13 // indirect
|
||||
github.com/mattn/go-isatty v0.0.19 // indirect
|
||||
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
||||
github.com/valyala/fasttemplate v1.2.2 // indirect
|
||||
golang.org/x/crypto v0.11.0 // indirect
|
||||
golang.org/x/net v0.12.0 // indirect
|
||||
golang.org/x/text v0.11.0 // indirect
|
||||
)
|
43
go.sum
Normal file
43
go.sum
Normal file
@ -0,0 +1,43 @@
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/labstack/echo/v4 v4.11.1 h1:dEpLU2FLg4UVmvCGPuk/APjlH6GDpbEPti61srUUUs4=
|
||||
github.com/labstack/echo/v4 v4.11.1/go.mod h1:YuYRTSM3CHs2ybfrL8Px48bO6BAnYIN4l8wSTMP6BDQ=
|
||||
github.com/labstack/gommon v0.4.0 h1:y7cvthEAEbU0yHOf4axH8ZG2NH8knB9iNSoTO8dyIk8=
|
||||
github.com/labstack/gommon v0.4.0/go.mod h1:uW6kP17uPlLJsD3ijUYn3/M5bAxtlZhMI6m3MFxTMTM=
|
||||
github.com/mattn/go-colorable v0.1.11/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4=
|
||||
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
|
||||
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
|
||||
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
|
||||
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
|
||||
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
|
||||
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
|
||||
github.com/valyala/fasttemplate v1.2.1/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ=
|
||||
github.com/valyala/fasttemplate v1.2.2 h1:lxLXG0uE3Qnshl9QyaK6XJxMXlQZELvChBOCmQD0Loo=
|
||||
github.com/valyala/fasttemplate v1.2.2/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ=
|
||||
golang.org/x/crypto v0.11.0 h1:6Ewdq3tDic1mg5xRO4milcWCfMVQhI4NkqWWvqejpuA=
|
||||
golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio=
|
||||
golang.org/x/net v0.12.0 h1:cfawfvKITfUsFCeJIHJrbSxpeu/E81khclypR0GVT50=
|
||||
golang.org/x/net v0.12.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA=
|
||||
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20211103235746-7861aae1554b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
|
||||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/text v0.11.0 h1:LAntKIrcmeSKERyiOh0XMV39LXS8IE9UL2yP7+f5ij4=
|
||||
golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
45
handler.go
Normal file
45
handler.go
Normal file
@ -0,0 +1,45 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
|
||||
type TranscribeResponse struct {
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
func transcribe(c echo.Context, whisperState *WhisperState) error {
|
||||
audioPath, err := saveFormFile("file", c)
|
||||
if err != nil {
|
||||
c.Logger().Errorf("Error reading file: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
whisperState.mutex.Lock()
|
||||
buffer, err := whisperState.media.LoadAudioFile(audioPath, true)
|
||||
if err != nil {
|
||||
c.Logger().Errorf("Error loading audio file data: %s", err)
|
||||
}
|
||||
|
||||
err = whisperState.context.RunFull(whisperState.params, buffer)
|
||||
|
||||
result, err := getResult(whisperState.context)
|
||||
if err != nil {
|
||||
c.Logger().Error(err)
|
||||
}
|
||||
|
||||
defer whisperState.mutex.Unlock()
|
||||
|
||||
if len(result) == 0 {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": "Internal server error"})
|
||||
}
|
||||
|
||||
response := TranscribeResponse{
|
||||
Text: strings.TrimLeft(result, " "),
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, response)
|
||||
}
|
29
main.go
Normal file
29
main.go
Normal file
@ -0,0 +1,29 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/labstack/gommon/log"
|
||||
)
|
||||
|
||||
const MODEL_PATH = "./ggml-medium.bin"
|
||||
|
||||
func main() {
|
||||
e := echo.New()
|
||||
e.HideBanner = true
|
||||
|
||||
if l, ok := e.Logger.(*log.Logger); ok {
|
||||
l.SetHeader("${time_rfc3339} ${level}")
|
||||
}
|
||||
|
||||
whisperState, err := InitializeWhisperState(MODEL_PATH)
|
||||
if err != nil {
|
||||
e.Logger.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
e.POST("/v1/audio/transcriptions", func(c echo.Context) error {
|
||||
return transcribe(c, whisperState)
|
||||
})
|
||||
|
||||
e.Logger.Fatal(e.Start(":3000"))
|
||||
}
|
248
pkg/whisper/FullParams.go
Normal file
248
pkg/whisper/FullParams.go
Normal file
@ -0,0 +1,248 @@
|
||||
package whisper
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// https://github.com/Const-me/Whisper/blob/master/Whisper/API/sFullParams.h
|
||||
// https://github.com/Const-me/Whisper/blob/master/WhisperNet/API/Parameters.cs
|
||||
|
||||
// eSamplingStrategy selects the decoding strategy stored in _FullParams.strategy.
type eSamplingStrategy uint32

// Sampling strategy values, numbered consecutively from zero.
const (
	SsGreedy     eSamplingStrategy = 0
	SsBeamSearch eSamplingStrategy = 1
	SsINVALIDARG eSamplingStrategy = 2
)
|
||||
|
||||
// eFullParamsFlags is the bit set stored in _FullParams.Flags.
type eFullParamsFlags uint32

// Flag bits. Only FlagNone carries the eFullParamsFlags type; the remaining
// names are untyped constants (one bit each, starting at bit 0) that convert
// implicitly wherever an eFullParamsFlags is expected.
const (
	FlagNone eFullParamsFlags = 0

	FlagTranslate = 1 << (iota - 1)
	FlagNoContext
	FlagSingleSegment
	FlagPrintSpecial
	FlagPrintProgress
	FlagPrintRealtime
	FlagPrintTimestamps
	FlagTokenTimestamps // Experimental
	FlagSpeedupAudio
)
|
||||
|
||||
// EWhisperHWND is the pointer-sized status value returned by the callbacks
// registered through FullParams.
type EWhisperHWND uintptr

// Callback status codes.
const (
	S_OK EWhisperHWND = iota
	S_FALSE
)
|
||||
|
||||
type FullParams struct {
|
||||
cStruct *_FullParams
|
||||
}
|
||||
|
||||
func (this *FullParams) CpuThreads() int32 {
|
||||
if this == nil {
|
||||
return 0
|
||||
} else if this.cStruct == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
return this.cStruct.cpuThreads
|
||||
}
|
||||
|
||||
func (this *FullParams) setCpuThreads(val int32) {
|
||||
if this == nil {
|
||||
return
|
||||
} else if this.cStruct == nil {
|
||||
return
|
||||
}
|
||||
|
||||
this.cStruct.cpuThreads = val
|
||||
}
|
||||
|
||||
func (this *FullParams) SetMaxTextCTX(val int32) {
|
||||
this.cStruct.n_max_text_ctx = val
|
||||
}
|
||||
|
||||
func (this *FullParams) AddFlags(newflag eFullParamsFlags) {
|
||||
if this == nil {
|
||||
return
|
||||
} else if this.cStruct == nil {
|
||||
return
|
||||
}
|
||||
|
||||
this.cStruct.Flags = this.cStruct.Flags | newflag
|
||||
}
|
||||
|
||||
func (this *FullParams) RemoveFlags(newflag eFullParamsFlags) {
|
||||
if this == nil {
|
||||
return
|
||||
} else if this.cStruct == nil {
|
||||
return
|
||||
}
|
||||
|
||||
this.cStruct.Flags = this.cStruct.Flags ^ newflag
|
||||
}
|
||||
|
||||
/*using pfnNewSegment = HRESULT( __cdecl* )( iContext* ctx, uint32_t n_new, void* user_data ) noexcept;*/
|
||||
type NewSegmentCallback_Type func(context *IContext, n_new uint32, user_data unsafe.Pointer) EWhisperHWND
|
||||
|
||||
func (this *FullParams) SetNewSegmentCallback(cb NewSegmentCallback_Type) {
|
||||
if this == nil {
|
||||
return
|
||||
} else if this.cStruct == nil {
|
||||
return
|
||||
}
|
||||
this.cStruct.new_segment_callback = syscall.NewCallback(cb)
|
||||
}
|
||||
|
||||
/*
|
||||
Return S_OK to proceed, or S_FALSE to stop the process
|
||||
*/
|
||||
type EncoderBeginCallback_Type func(context *IContext, user_data unsafe.Pointer) EWhisperHWND
|
||||
|
||||
func (this *FullParams) SetEncoderBeginCallback(cb EncoderBeginCallback_Type) {
|
||||
if this == nil {
|
||||
return
|
||||
} else if this.cStruct == nil {
|
||||
return
|
||||
}
|
||||
|
||||
this.cStruct.encoder_begin_callback = syscall.NewCallback(cb)
|
||||
}
|
||||
|
||||
func (this *FullParams) TestDefaultsOK() bool {
|
||||
if this == nil {
|
||||
return false
|
||||
} else if this.cStruct == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if this.cStruct.n_max_text_ctx != 16384 {
|
||||
return false
|
||||
}
|
||||
|
||||
if this.cStruct.Flags != (FlagPrintProgress | FlagPrintTimestamps) {
|
||||
return false
|
||||
}
|
||||
|
||||
if this.cStruct.thold_pt != 0.01 {
|
||||
return false
|
||||
}
|
||||
|
||||
if this.cStruct.thold_ptsum != 0.01 {
|
||||
return false
|
||||
}
|
||||
|
||||
if this.cStruct.Language != English {
|
||||
return false
|
||||
}
|
||||
|
||||
// Todo ... why do these not line up as expected.. is our struct out of alignment ?
|
||||
/*
|
||||
if this.cStruct.strategy == ssGreedy {
|
||||
if this.cStruct.beam_search.n_past != -1 ||
|
||||
this.cStruct.beam_search.beam_width != -1 ||
|
||||
this.cStruct.beam_search.n_best != -1 {
|
||||
return false
|
||||
}
|
||||
|
||||
} else if this.cStruct.strategy == ssBeamSearch {
|
||||
if this.cStruct.greedy.n_past != -1 ||
|
||||
this.cStruct.beam_search.beam_width != 10 ||
|
||||
this.cStruct.beam_search.n_best != 5 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
type _FullParams struct {
|
||||
strategy eSamplingStrategy
|
||||
cpuThreads int32
|
||||
n_max_text_ctx int32
|
||||
offset_ms int32
|
||||
duration_ms int32
|
||||
Flags eFullParamsFlags
|
||||
Language eLanguage
|
||||
|
||||
thold_pt float32
|
||||
thold_ptsum float32
|
||||
max_len int32
|
||||
max_tokens int32
|
||||
|
||||
greedy struct{ n_past int32 }
|
||||
beam_search struct {
|
||||
n_past int32
|
||||
beam_width int32
|
||||
n_best int32
|
||||
}
|
||||
|
||||
audio_ctx int32 // overwrite the audio context size (0 = use default)
|
||||
|
||||
prompt_tokens uintptr
|
||||
prompt_n_tokens int32
|
||||
|
||||
new_segment_callback uintptr
|
||||
new_segment_callback_user_data uintptr
|
||||
|
||||
encoder_begin_callback uintptr
|
||||
encoder_begin_callback_user_data uintptr
|
||||
|
||||
// Are these needed ?? Jay
|
||||
// setFlag uintptr
|
||||
}
|
||||
|
||||
func NewFullParams(cstruct *_FullParams) *FullParams {
|
||||
this := FullParams{}
|
||||
this.cStruct = cstruct
|
||||
return &this
|
||||
}
|
||||
|
||||
func _newFullParams_cStruct() *_FullParams {
|
||||
return &_FullParams{
|
||||
|
||||
strategy: 0,
|
||||
cpuThreads: 0,
|
||||
n_max_text_ctx: 0,
|
||||
offset_ms: 0,
|
||||
duration_ms: 0,
|
||||
|
||||
Flags: 0,
|
||||
Language: 0,
|
||||
|
||||
thold_pt: 0,
|
||||
thold_ptsum: 0,
|
||||
max_len: 0,
|
||||
max_tokens: 0,
|
||||
|
||||
// anonymous int32
|
||||
greedy: struct{ n_past int32 }{n_past: 0},
|
||||
|
||||
// anonymous struct
|
||||
beam_search: struct {
|
||||
n_past int32
|
||||
beam_width int32
|
||||
n_best int32
|
||||
}{
|
||||
n_past: 0,
|
||||
beam_width: 0,
|
||||
n_best: 0,
|
||||
},
|
||||
|
||||
audio_ctx: 0,
|
||||
|
||||
prompt_tokens: 0,
|
||||
prompt_n_tokens: 0,
|
||||
|
||||
new_segment_callback: 0,
|
||||
new_segment_callback_user_data: 0,
|
||||
|
||||
encoder_begin_callback: 0,
|
||||
encoder_begin_callback_user_data: 0,
|
||||
}
|
||||
}
|
222
pkg/whisper/MediaFoundation.go
Normal file
222
pkg/whisper/MediaFoundation.go
Normal file
@ -0,0 +1,222 @@
|
||||
package whisper
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
// https://github.com/Const-me/Whisper/blob/843a2a6ca6ea47c5ac4889a281badfc808d0ea01/Whisper/API/IMediaFoundation.h
|
||||
|
||||
// IMediaFoundation is a COM-style object reference: a single pointer to the
// table of method addresses.
type IMediaFoundation struct {
	lpVtbl *IMediaFoundationVtbl
}

// IMediaFoundationVtbl lists the method addresses of IMediaFoundation in
// table order. The trailing comments are the native argument lists.
type IMediaFoundationVtbl struct {
	QueryInterface, AddRef, Release uintptr

	loadAudioFile      uintptr // ( LPCTSTR path, bool stereo, iAudioBuffer** pp ) const;
	openAudioFile      uintptr // ( LPCTSTR path, bool stereo, iAudioReader** pp );
	loadAudioFileData  uintptr // ( const void* data, uint64_t size, bool stereo, iAudioReader** pp ); HRESULT
	listCaptureDevices uintptr // ( pfnFoundCaptureDevices pfn, void* pv );
	openCaptureDevice  uintptr // ( LPCTSTR endpoint, const sCaptureParams& captureParams, iAudioCapture** pp );
}
|
||||
|
||||
func (this *IMediaFoundation) AddRef() int32 {
|
||||
ret, _, _ := syscall.Syscall(
|
||||
this.lpVtbl.AddRef,
|
||||
1,
|
||||
uintptr(unsafe.Pointer(this)),
|
||||
0,
|
||||
0)
|
||||
return int32(ret)
|
||||
}
|
||||
|
||||
func (this *IMediaFoundation) Release() int32 {
|
||||
ret, _, _ := syscall.Syscall(
|
||||
this.lpVtbl.Release,
|
||||
1,
|
||||
uintptr(unsafe.Pointer(this)),
|
||||
0,
|
||||
0)
|
||||
return int32(ret)
|
||||
}
|
||||
|
||||
// ( LPCTSTR path, bool stereo, iAudioBuffer** pp ) const;
|
||||
func (this *IMediaFoundation) LoadAudioFile(file string, stereo bool) (*iAudioBuffer, error) {
|
||||
|
||||
var buffer *iAudioBuffer
|
||||
|
||||
UTFFileName, _ := windows.UTF16PtrFromString(file)
|
||||
|
||||
ret, _, _ := syscall.SyscallN(
|
||||
this.lpVtbl.loadAudioFile,
|
||||
uintptr(unsafe.Pointer(this)),
|
||||
uintptr(unsafe.Pointer(UTFFileName)),
|
||||
uintptr(1), // Todo ... Stereo !
|
||||
uintptr(unsafe.Pointer(&buffer)))
|
||||
|
||||
if windows.Handle(ret) != windows.S_OK {
|
||||
fmt.Printf("loadAudioFile failed: %s\n", syscall.Errno(ret).Error())
|
||||
return nil, syscall.Errno(ret)
|
||||
}
|
||||
|
||||
return buffer, nil
|
||||
}
|
||||
|
||||
func (this *IMediaFoundation) OpenAudioFile(file string, stereo bool) (*iAudioReader, error) {
|
||||
|
||||
var buffer *iAudioReader
|
||||
|
||||
UTFFileName, _ := windows.UTF16PtrFromString(file)
|
||||
|
||||
ret, _, _ := syscall.SyscallN(
|
||||
this.lpVtbl.openAudioFile,
|
||||
uintptr(unsafe.Pointer(this)),
|
||||
uintptr(unsafe.Pointer(UTFFileName)),
|
||||
uintptr(1), // Todo ... Stereo !
|
||||
uintptr(unsafe.Pointer(&buffer)))
|
||||
|
||||
if windows.Handle(ret) != windows.S_OK {
|
||||
fmt.Printf("openAudioFile failed: %s\n", syscall.Errno(ret).Error())
|
||||
return nil, syscall.Errno(ret)
|
||||
}
|
||||
|
||||
return buffer, nil
|
||||
}
|
||||
|
||||
func (this *IMediaFoundation) LoadAudioFileData(inbuffer *[]byte, stereo bool) (*iAudioReader, error) {
|
||||
|
||||
var reader *iAudioReader
|
||||
|
||||
// loadAudioFileData( const void* data, uint64_t size, bool stereo, iAudioReader** pp );
|
||||
ret, _, _ := syscall.SyscallN(
|
||||
this.lpVtbl.loadAudioFileData,
|
||||
uintptr(unsafe.Pointer(this)),
|
||||
|
||||
uintptr(unsafe.Pointer(&(*inbuffer)[0])),
|
||||
uintptr(uint64(len(*inbuffer))),
|
||||
uintptr(1), // Todo ... Stereo !
|
||||
uintptr(unsafe.Pointer(&reader)))
|
||||
|
||||
if windows.Handle(ret) != windows.S_OK {
|
||||
fmt.Printf("LoadAudioFileData failed: %s\n", syscall.Errno(ret).Error())
|
||||
return nil, syscall.Errno(ret)
|
||||
}
|
||||
|
||||
return reader, nil
|
||||
}
|
||||
|
||||
// ************************************************************
|
||||
|
||||
// iAudioBuffer is a COM-style object reference: a single pointer to the table
// of method addresses.
type iAudioBuffer struct {
	lpVtbl *iAudioBufferVtbl
}

// iAudioBufferVtbl lists the method addresses of iAudioBuffer in table order.
type iAudioBufferVtbl struct {
	QueryInterface, AddRef, Release uintptr

	countSamples uintptr // returns uint32_t
	getPcmMono   uintptr // returns float*
	getPcmStereo uintptr // returns float*
	getTime      uintptr // ( int64_t& rdi )
}
|
||||
|
||||
func (this *iAudioBuffer) AddRef() int32 {
|
||||
ret, _, _ := syscall.SyscallN(
|
||||
this.lpVtbl.AddRef,
|
||||
uintptr(unsafe.Pointer(this)),
|
||||
)
|
||||
return int32(ret)
|
||||
}
|
||||
|
||||
func (this *iAudioBuffer) Release() int32 {
|
||||
ret, _, _ := syscall.SyscallN(
|
||||
this.lpVtbl.Release,
|
||||
uintptr(unsafe.Pointer(this)),
|
||||
)
|
||||
return int32(ret)
|
||||
}
|
||||
|
||||
func (this *iAudioBuffer) CountSamples() (uint32, error) {
|
||||
|
||||
ret, _, err := syscall.SyscallN(
|
||||
this.lpVtbl.countSamples,
|
||||
uintptr(unsafe.Pointer(this)),
|
||||
)
|
||||
|
||||
if err != 0 {
|
||||
return 0, errors.New(err.Error())
|
||||
}
|
||||
|
||||
return uint32(ret), nil
|
||||
}
|
||||
|
||||
// ************************************************************
|
||||
|
||||
// iAudioReader is a COM-style object reference: a single pointer to the table
// of method addresses.
type iAudioReader struct {
	lpVtbl *iAudioReaderVtbl
}

// iAudioReaderVtbl lists the method addresses of iAudioReader in table order.
type iAudioReaderVtbl struct {
	QueryInterface, AddRef, Release uintptr

	getDuration     uintptr // ( int64_t& rdi )
	getReader       uintptr // ( IMFSourceReader** pp )
	requestedStereo uintptr // ()
}
|
||||
|
||||
func (this *iAudioReader) AddRef() int32 {
|
||||
ret, _, _ := syscall.SyscallN(
|
||||
this.lpVtbl.AddRef,
|
||||
uintptr(unsafe.Pointer(this)),
|
||||
)
|
||||
return int32(ret)
|
||||
}
|
||||
|
||||
func (this *iAudioReader) Release() int32 {
|
||||
ret, _, _ := syscall.SyscallN(
|
||||
this.lpVtbl.Release,
|
||||
uintptr(unsafe.Pointer(this)),
|
||||
)
|
||||
return int32(ret)
|
||||
}
|
||||
|
||||
func (this *iAudioReader) GetDuration() (uint64, error) {
|
||||
|
||||
var rdi int64
|
||||
|
||||
ret, _, _ := syscall.SyscallN(
|
||||
this.lpVtbl.getDuration,
|
||||
uintptr(unsafe.Pointer(this)),
|
||||
uintptr(unsafe.Pointer(&rdi)),
|
||||
)
|
||||
|
||||
if windows.Handle(ret) != windows.S_OK {
|
||||
fmt.Printf("LoadAudioFileData failed: %s\n", syscall.Errno(ret).Error())
|
||||
return 0, syscall.Errno(ret)
|
||||
}
|
||||
|
||||
return uint64(rdi), nil
|
||||
}
|
||||
|
||||
// ************************************************************
|
||||
|
||||
// iAudioCapture is a COM-style object reference: a single pointer to the
// table of method addresses.
type iAudioCapture struct {
	lpVtbl *iAudioCaptureVtbl
}

// iAudioCaptureVtbl lists the method addresses of iAudioCapture in table order.
type iAudioCaptureVtbl struct {
	QueryInterface, AddRef, Release uintptr

	getReader uintptr // ( IMFSourceReader** pp )
	getParams uintptr // returns sCaptureParams&
}
|
119
pkg/whisper/Model.go
Normal file
119
pkg/whisper/Model.go
Normal file
@ -0,0 +1,119 @@
|
||||
package whisper
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
// External - Go version of the struct
|
||||
type Model struct {
|
||||
cStruct *_IModel
|
||||
setup *sModelSetup
|
||||
}
|
||||
|
||||
// _IModel is the internal, C-layout object reference: a single pointer to the
// table of method addresses.
type _IModel struct {
	lpVtbl *IModelVtbl
}

// IModelVtbl lists the method addresses of the model object in table order.
// Reference: https://github.com/Const-me/Whisper/blob/master/Whisper/API/iContext.cl.h
type IModelVtbl struct {
	QueryInterface, AddRef, Release uintptr

	createContext    uintptr // ( iContext** pp ) = 0;
	tokenize         uintptr // HRESULT __stdcall tokenize( const char* text, pfnDecodedTokens pfn, void* pv );
	isMultilingual   uintptr // () = 0;
	getSpecialTokens uintptr // ( SpecialTokens& rdi ) = 0;
	stringFromToken  uintptr // ( whisper_token token ) = 0;
	clone            uintptr // ( iModel** rdi ) = 0;
}
|
||||
|
||||
func NewModel(setup *sModelSetup, cstruct *_IModel) *Model {
|
||||
this := Model{}
|
||||
this.setup = setup
|
||||
this.cStruct = cstruct
|
||||
return &this
|
||||
}
|
||||
|
||||
func (this *Model) AddRef() int32 {
|
||||
ret, _, _ := syscall.Syscall(
|
||||
this.cStruct.lpVtbl.AddRef,
|
||||
1,
|
||||
uintptr(unsafe.Pointer(this.cStruct)),
|
||||
0,
|
||||
0)
|
||||
return int32(ret)
|
||||
}
|
||||
|
||||
func (this *Model) Release() int32 {
|
||||
ret, _, _ := syscall.Syscall(
|
||||
this.cStruct.lpVtbl.Release,
|
||||
1,
|
||||
uintptr(unsafe.Pointer(this.cStruct)),
|
||||
0,
|
||||
0)
|
||||
return int32(ret)
|
||||
}
|
||||
|
||||
func (this *Model) CreateContext() (*IContext, error) {
|
||||
var context *IContext
|
||||
|
||||
/*
|
||||
ret, _, err := syscall.Syscall(
|
||||
this.cStruct.lpVtbl.createContext,
|
||||
2, // Why was this 1, rather than 2 ?? 1 seemed to work fine
|
||||
uintptr(unsafe.Pointer(this.cStruct)),
|
||||
uintptr(unsafe.Pointer(&context)),
|
||||
0)*/
|
||||
ret, _, err := syscall.SyscallN(
|
||||
this.cStruct.lpVtbl.createContext,
|
||||
uintptr(unsafe.Pointer(this.cStruct)),
|
||||
uintptr(unsafe.Pointer(&context)))
|
||||
|
||||
if windows.Handle(ret) != windows.S_OK {
|
||||
fmt.Printf("createContext failed: %w", err.Error())
|
||||
}
|
||||
|
||||
if windows.Handle(ret) != windows.S_OK {
|
||||
return nil, fmt.Errorf("loadModel failed: %w", err)
|
||||
}
|
||||
|
||||
return context, nil
|
||||
}
|
||||
|
||||
func (this *Model) IsMultilingual() bool {
|
||||
ret, _, _ := syscall.SyscallN(
|
||||
this.cStruct.lpVtbl.isMultilingual,
|
||||
uintptr(unsafe.Pointer(this.cStruct)),
|
||||
)
|
||||
|
||||
return bool(windows.Handle(ret) == windows.S_OK)
|
||||
}
|
||||
|
||||
func (this *Model) Clone() (*_IModel, error) {
|
||||
|
||||
if this.setup.isFlagSet(gmf_Cloneable) {
|
||||
return nil, errors.New("Model is not cloneable")
|
||||
}
|
||||
//this.Cloneable ?
|
||||
|
||||
var modelptr *_IModel
|
||||
|
||||
ret, _, _ := syscall.SyscallN(
|
||||
this.cStruct.lpVtbl.clone,
|
||||
uintptr(unsafe.Pointer(this.cStruct)),
|
||||
uintptr(unsafe.Pointer(&modelptr)),
|
||||
)
|
||||
|
||||
if windows.Handle(ret) == windows.S_OK {
|
||||
return modelptr, nil
|
||||
} else {
|
||||
return nil, errors.New("Model.Clone() failed : " + syscall.Errno(ret).Error())
|
||||
}
|
||||
}
|
101
pkg/whisper/ModelSetup.go
Normal file
101
pkg/whisper/ModelSetup.go
Normal file
@ -0,0 +1,101 @@
|
||||
package whisper
|
||||
|
||||
import (
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
// Re-implemented sModelSetup.h
|
||||
|
||||
// eModelImplementation mirrors the native "enum struct eModelImplementation : uint32_t".
type eModelImplementation uint32

// Implementation kinds, numbered from 1.
const (
	// mi_GPU is the GPGPU implementation based on Direct3D 11.0 compute shaders.
	mi_GPU eModelImplementation = iota + 1

	// mi_Hybrid uses DirectCompute for encode and decodes on CPU.
	// Not implemented in the published builds of the DLL; to enable, change
	// the BUILD_HYBRID_VERSION macro to 1.
	mi_Hybrid

	// mi_Reference is a reference implementation which uses the original GGML
	// CPU-running code.
	// Not implemented in the published builds of the DLL; to enable, change
	// the BUILD_BOTH_VERSIONS macro to 1.
	mi_Reference
)
|
||||
|
||||
// eGpuModelFlags mirrors the native "enum struct eGpuModelFlags : uint32_t".
type eGpuModelFlags uint32

// GPU model flag bits.
const (
	// gmf_None is equivalent to Wave32 | NoReshapedMatMul on Intel and nVidia
	// GPUs, and Wave64 | UseReshapedMatMul on AMD GPUs.
	gmf_None eGpuModelFlags = 0

	// gmf_Wave32 uses the Wave32 version of compute shaders even on AMD GPUs.
	// Incompatible with gmf_Wave64.
	gmf_Wave32 eGpuModelFlags = 1 << 0

	// gmf_Wave64 uses the Wave64 version of compute shaders even on nVidia
	// and Intel GPUs. Incompatible with gmf_Wave32.
	gmf_Wave64 eGpuModelFlags = 1 << 1

	// gmf_NoReshapedMatMul disables reshaped matrix multiplication shaders on
	// AMD GPUs. Incompatible with gmf_UseReshapedMatMul.
	gmf_NoReshapedMatMul eGpuModelFlags = 1 << 2

	// gmf_UseReshapedMatMul uses reshaped matrix multiplication shaders even
	// on nVidia and Intel GPUs. Incompatible with gmf_NoReshapedMatMul.
	gmf_UseReshapedMatMul eGpuModelFlags = 1 << 3

	// gmf_Cloneable creates GPU tensors in a way which allows sharing across
	// D3D devices.
	gmf_Cloneable eGpuModelFlags = 1 << 4
)
|
||||
|
||||
// struct sModelSetup
|
||||
type sModelSetup struct {
|
||||
impl eModelImplementation
|
||||
flags eGpuModelFlags
|
||||
adapter string
|
||||
}
|
||||
|
||||
type _sModelSetup struct {
|
||||
impl eModelImplementation
|
||||
flags eGpuModelFlags
|
||||
adapter uintptr
|
||||
}
|
||||
|
||||
func ModelSetup(flags eGpuModelFlags, GPU string) *sModelSetup {
|
||||
this := sModelSetup{}
|
||||
this.impl = mi_GPU
|
||||
this.flags = flags
|
||||
this.adapter = GPU
|
||||
|
||||
return &this
|
||||
}
|
||||
|
||||
func (this *sModelSetup) isFlagSet(flag eGpuModelFlags) bool {
|
||||
return (this.flags & flag) == 0
|
||||
}
|
||||
|
||||
func (this *sModelSetup) AsCType() *_sModelSetup {
|
||||
var err error
|
||||
|
||||
ctype := _sModelSetup{}
|
||||
ctype.impl = this.impl
|
||||
ctype.flags = this.flags
|
||||
ctype.adapter = 0
|
||||
|
||||
// Conver Go String to wchar_t, AKA UTF-16
|
||||
if this.adapter != "" {
|
||||
var UTF16str *uint16
|
||||
UTF16str, err = windows.UTF16PtrFromString(this.adapter)
|
||||
ctype.adapter = uintptr(unsafe.Pointer(UTF16str))
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil
|
||||
} else {
|
||||
return &ctype
|
||||
}
|
||||
}
|
178
pkg/whisper/TranscribeResult.go
Normal file
178
pkg/whisper/TranscribeResult.go
Normal file
@ -0,0 +1,178 @@
|
||||
package whisper
|
||||
|
||||
import (
|
||||
"C"
|
||||
"errors"
|
||||
"fmt"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
// eTokenFlags is the flag set stored in SToken.Flags.
type eTokenFlags uint32

// Token flag values. Only TfNone carries the eTokenFlags type; TfSpecial is
// an untyped constant.
const (
	TfNone    eTokenFlags = iota
	TfSpecial             = iota
)
|
||||
|
||||
// sTranscribeLength holds the segment and token counts of a transcription
// result.
type sTranscribeLength struct {
	CountSegments, CountTokens uint32
}

// sTimeSpan is a time value expressed in 100-nanosecond ticks: compatible
// with System.Timespan, FILETIME, and many other things.
//
// The C++ original additionally declares a conversion operator to
// sTimeSpanFields and assignment operators from uint64_t and int64_t (the
// latter asserting a non-negative value); none of them is reproduced here.
type sTimeSpan struct {
	Ticks uint64
}

// sTimeInterval is a pair of time values marking a beginning and an end.
type sTimeInterval struct {
	Begin, End sTimeSpan
}
|
||||
|
||||
type sSegment struct {
|
||||
// Segment text, null-terminated, and probably UTF-8 encoded
|
||||
text *C.char
|
||||
|
||||
// Start and end times of the segment
|
||||
Time sTimeInterval
|
||||
|
||||
// These two integers define the slice of the tokens in this segment, in the array returned by iTranscribeResult.getTokens method
|
||||
FirstToken uint32
|
||||
CountTokens uint32
|
||||
}
|
||||
|
||||
func (this *sSegment) Text() string {
|
||||
return C.GoString(this.text)
|
||||
}
|
||||
|
||||
// sSegmentArray is a slice of sSegment values.
type sSegmentArray []sSegment
|
||||
|
||||
type SToken struct {
|
||||
// Token text, null-terminated, and usually UTF-8 encoded.
|
||||
// I think for Chinese language the models sometimes outputs invalid UTF8 strings here, Unicode code points can be split between adjacent tokens in the same segment
|
||||
// More info: https://github.com/ggerganov/whisper.cpp/issues/399
|
||||
text *C.char
|
||||
|
||||
// Start and end times of the token
|
||||
Time sTimeInterval
|
||||
// Probability of the token
|
||||
Probability float32
|
||||
|
||||
// Probability of the timestamp token
|
||||
ProbabilityTimestamp float32
|
||||
|
||||
// Sum of probabilities of all timestamp tokens
|
||||
Ptsum float32
|
||||
|
||||
// Voice length of the token
|
||||
Vlen float32
|
||||
|
||||
// Token id
|
||||
Id int32
|
||||
|
||||
Flags eTokenFlags
|
||||
}
|
||||
|
||||
func (this *SToken) Text() string {
|
||||
return C.GoString(this.text)
|
||||
}
|
||||
|
||||
// sTokenArray is a slice of SToken values.
type sTokenArray []SToken
|
||||
|
||||
// iTranscribeResultVtbl is the method table of ITranscribeResult. Every field
// holds the address of one method; the wrappers in this file call them with
// the object pointer as first argument.
type iTranscribeResultVtbl struct {
	QueryInterface, AddRef, Release uintptr

	getSize     uintptr // called with a pointer to an sTranscribeLength; result is checked against S_OK
	getSegments uintptr // result is used as the address of the first sSegment
	getTokens   uintptr // result is used as the address of the first SToken
}
|
||||
|
||||
type ITranscribeResult struct {
|
||||
lpVtbl *iTranscribeResultVtbl
|
||||
}
|
||||
|
||||
func (this *ITranscribeResult) AddRef() int32 {
|
||||
ret, _, _ := syscall.Syscall(
|
||||
this.lpVtbl.AddRef,
|
||||
1,
|
||||
uintptr(unsafe.Pointer(this)),
|
||||
0,
|
||||
0)
|
||||
return int32(ret)
|
||||
}
|
||||
|
||||
func (this *ITranscribeResult) Release() int32 {
|
||||
ret, _, _ := syscall.Syscall(
|
||||
this.lpVtbl.Release,
|
||||
1,
|
||||
uintptr(unsafe.Pointer(this)),
|
||||
0,
|
||||
0)
|
||||
return int32(ret)
|
||||
}
|
||||
|
||||
func (this *ITranscribeResult) GetSize() (*sTranscribeLength, error) {
|
||||
|
||||
var result sTranscribeLength
|
||||
|
||||
ret, _, _ := syscall.SyscallN(
|
||||
this.lpVtbl.getSize,
|
||||
uintptr(unsafe.Pointer(this)),
|
||||
uintptr(unsafe.Pointer(&result)),
|
||||
)
|
||||
|
||||
if windows.Handle(ret) != windows.S_OK {
|
||||
fmt.Printf("iTranscribeResult.GetSize failed: %s\n", syscall.Errno(ret).Error())
|
||||
return nil, errors.New(syscall.Errno(ret).Error())
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
|
||||
}
|
||||
|
||||
func (this *ITranscribeResult) GetSegments(len uint32) []sSegment {
|
||||
|
||||
ret, _, _ := syscall.SyscallN(
|
||||
this.lpVtbl.getSegments,
|
||||
uintptr(unsafe.Pointer(this)),
|
||||
)
|
||||
|
||||
data := unsafe.Slice((*sSegment)(unsafe.Pointer(ret)), len)
|
||||
|
||||
return data
|
||||
}
|
||||
|
||||
func (this *ITranscribeResult) GetTokens(len uint32) []SToken {
|
||||
|
||||
ret, _, _ := syscall.SyscallN(
|
||||
this.lpVtbl.getTokens,
|
||||
uintptr(unsafe.Pointer(this)),
|
||||
)
|
||||
|
||||
if unsafe.Pointer(ret) != nil {
|
||||
return unsafe.Slice((*SToken)(unsafe.Pointer(ret)), len)
|
||||
} else {
|
||||
return []SToken{}
|
||||
}
|
||||
}
|
281
pkg/whisper/context.go
Normal file
281
pkg/whisper/context.go
Normal file
@ -0,0 +1,281 @@
|
||||
package whisper
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
// uuid is a 16-byte identifier.
// NOTE(review): not referenced by the visible code — presumably intended for
// COM interface IDs; confirm before relying on it.
type uuid [16]byte
|
||||
|
||||
// eResultFlags is the bit set passed to IContext.GetResults.
type eResultFlags uint32

// Result flag values. Every constant carries the eResultFlags type;
// previously all but RfNone were untyped because only the first entry of the
// block named the type.
const (
	RfNone eResultFlags = 0

	// RfTokens returns individual tokens in addition to the segments.
	RfTokens eResultFlags = 1

	// RfTimestamps returns timestamps.
	RfTimestamps eResultFlags = 2

	// RfNewObject creates a new COM object for the results.
	// Without this flag, the context returns a pointer to the COM object
	// stored in the context. The content of that object is replaced every
	// time you call the IContext.getResults method.
	RfNewObject eResultFlags = 0x100
)
|
||||
|
||||
// IContextVtbl is the method table of IContext. Every field holds the address
// of one method, and the wrappers in this file call them with the context
// pointer as first argument. The field order must not be changed.
type IContextVtbl struct {
	QueryInterface, AddRef, Release uintptr

	RunFull, RunStreamed, RunCapture uintptr

	GetResults        uintptr
	DetectSpeaker     uintptr
	GetModel          uintptr
	FullDefaultParams uintptr

	TimingsPrint, TimingsReset uintptr
}
|
||||
|
||||
type IContext struct {
|
||||
lpVtbl *IContextVtbl
|
||||
}
|
||||
|
||||
//type sFullParams struct{}
|
||||
|
||||
// type iAudioBuffer struct{}
|
||||
// sProgressSink is the progress-callback descriptor handed to RunStreamed,
// which currently always passes a zero value (no progress callback yet).
// NOTE(review): pfn/pv presumably are the callback address and its user
// context pointer — confirm against the C++ header.
type sProgressSink struct {
	pfn, pv uintptr
}
|
||||
|
||||
// type iAudioReader struct{}
|
||||
// sCaptureCallbacks is an empty placeholder for the callbacks argument of
// IContext.RunCapture.
type sCaptureCallbacks struct{}
|
||||
|
||||
// type iAudioCapture struct{}
|
||||
// type eResultFlags int32
|
||||
// type iTranscribeResult struct{}
|
||||
// type sTimeInterval struct{}
|
||||
// eSpeakerChannel is the type of the result written by IContext.DetectSpeaker.
type eSpeakerChannel int32
|
||||
|
||||
//type eSamplingStrategy int32
|
||||
|
||||
// Create a new IContext instance
|
||||
func newIContext() *IContext {
|
||||
return &IContext{
|
||||
lpVtbl: &IContextVtbl{
|
||||
QueryInterface: 0,
|
||||
AddRef: 0,
|
||||
Release: 0,
|
||||
RunFull: 0,
|
||||
RunStreamed: 0,
|
||||
RunCapture: 0,
|
||||
GetResults: 0,
|
||||
DetectSpeaker: 0,
|
||||
GetModel: 0,
|
||||
FullDefaultParams: 0,
|
||||
TimingsPrint: 0,
|
||||
TimingsReset: 0,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (context *IContext) TimingsPrint() error {
|
||||
|
||||
// TimingsPrint();
|
||||
ret, _, _ := syscall.SyscallN(
|
||||
context.lpVtbl.TimingsPrint,
|
||||
uintptr(unsafe.Pointer(context)),
|
||||
)
|
||||
|
||||
if windows.Handle(ret) != windows.S_OK {
|
||||
fmt.Printf("RunFull failed: %s\n", syscall.Errno(ret).Error())
|
||||
return errors.New(syscall.Errno(ret).Error())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
|
||||
// Uses the specified decoding strategy to obtain the text.
|
||||
func (context *IContext) RunFull(params *FullParams, buffer *iAudioBuffer) error {
|
||||
|
||||
// runFull( const sFullParams& params, const iAudioBuffer* buffer );
|
||||
ret, _, _ := syscall.SyscallN(
|
||||
context.lpVtbl.RunFull,
|
||||
uintptr(unsafe.Pointer(context)),
|
||||
|
||||
uintptr(unsafe.Pointer(params.cStruct)),
|
||||
uintptr(unsafe.Pointer(buffer)),
|
||||
)
|
||||
|
||||
if windows.Handle(ret) != windows.S_OK {
|
||||
fmt.Printf("RunFull failed: %s\n", syscall.Errno(ret).Error())
|
||||
return errors.New(syscall.Errno(ret).Error())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (context *IContext) RunStreamed(params *FullParams, reader *iAudioReader) error {
|
||||
|
||||
cb := sProgressSink{}
|
||||
|
||||
// runStreamed( const sFullParams& params, const sProgressSink& progress, const iAudioReader* reader );
|
||||
ret, _, _ := syscall.SyscallN(
|
||||
context.lpVtbl.RunStreamed,
|
||||
uintptr(unsafe.Pointer(context)),
|
||||
uintptr(unsafe.Pointer(params.cStruct)),
|
||||
uintptr(unsafe.Pointer(&cb)), // No progress cb yet
|
||||
uintptr(unsafe.Pointer(reader)),
|
||||
)
|
||||
|
||||
if windows.Handle(ret) != windows.S_OK {
|
||||
fmt.Printf("RunStreamed failed: %s\n", syscall.Errno(ret).Error())
|
||||
return errors.New(syscall.Errno(ret).Error())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *IContext) AddRef() int32 {
|
||||
ret, _, _ := syscall.Syscall(
|
||||
this.lpVtbl.AddRef,
|
||||
1,
|
||||
uintptr(unsafe.Pointer(this)),
|
||||
0,
|
||||
0)
|
||||
return int32(ret)
|
||||
}
|
||||
|
||||
func (this *IContext) Release() int32 {
|
||||
ret, _, _ := syscall.Syscall(
|
||||
this.lpVtbl.Release,
|
||||
1,
|
||||
uintptr(unsafe.Pointer(this)),
|
||||
0,
|
||||
0)
|
||||
return int32(ret)
|
||||
}
|
||||
|
||||
/*
|
||||
https://github.com/Const-me/Whisper/blob/f6f743c7b3570b85ccf47f74b84e06a73667ef3e/Whisper/Whisper/ContextImpl.misc.cpp
|
||||
|
||||
Returns E_POINTER if null pointer provided in params
|
||||
Initialises params to all 0
|
||||
sets values in struct, does not malloc
|
||||
*/
|
||||
func (context *IContext) FullDefaultParams(strategy eSamplingStrategy) (*FullParams, error) {
|
||||
|
||||
/*
|
||||
ERR : unreadable Only part of a ReadProcessMemory or WriteProcessMemory request was completed
|
||||
* not related to stratergy ... tested 0, 1 and 2 ... 2 produced E_INVALIDARG as expected
|
||||
* not a nil ptr to params ... nil poitner produced E_POINTER as expected
|
||||
* params seems to return 0x4000
|
||||
* !!!!! FullParams is not a com interface !!!
|
||||
* so no lpVtbl *FullParamsVtbl , no queryinterface, addref etc
|
||||
*/
|
||||
|
||||
params := _newFullParams_cStruct()
|
||||
//params := &[160]byte{}
|
||||
|
||||
ret, _, _ := syscall.SyscallN(
|
||||
context.lpVtbl.FullDefaultParams,
|
||||
uintptr(unsafe.Pointer(context)),
|
||||
uintptr(strategy),
|
||||
uintptr(unsafe.Pointer(params)),
|
||||
)
|
||||
|
||||
// nil ptr should be 0x80004003L
|
||||
// unsafe.Pointer(0xc00011dc28)
|
||||
// unsafe.Pointer(0x4000)
|
||||
|
||||
if windows.Handle(ret) != windows.S_OK {
|
||||
fmt.Printf("FullDefaultParams failed: %s\n", syscall.Errno(ret).Error())
|
||||
return nil, syscall.Errno(ret)
|
||||
|
||||
}
|
||||
|
||||
if params == nil {
|
||||
return nil, errors.New("FullDefaultParams did not return params")
|
||||
}
|
||||
ParamObj := NewFullParams(params)
|
||||
|
||||
if ParamObj.TestDefaultsOK() {
|
||||
return ParamObj, nil
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (context *IContext) GetModel() (*_IModel, error) {
|
||||
|
||||
var modelptr *_IModel
|
||||
|
||||
// getModel( iModel** pp );
|
||||
ret, _, _ := syscall.SyscallN(
|
||||
context.lpVtbl.GetModel,
|
||||
uintptr(unsafe.Pointer(context)),
|
||||
uintptr(unsafe.Pointer(&modelptr)),
|
||||
)
|
||||
|
||||
if windows.Handle(ret) != windows.S_OK {
|
||||
fmt.Printf("FullDefaultParams failed: %s\n", syscall.Errno(ret).Error())
|
||||
return nil, syscall.Errno(ret)
|
||||
}
|
||||
|
||||
if modelptr == nil {
|
||||
return nil, errors.New("loadModel did not return a Model")
|
||||
}
|
||||
|
||||
if modelptr.lpVtbl == nil {
|
||||
return nil, errors.New("loadModel method table is nil")
|
||||
}
|
||||
|
||||
return modelptr, nil
|
||||
}
|
||||
|
||||
// ************************************************************************************************************************************************
|
||||
// Not really implemented / tested
|
||||
// ************************************************************************************************************************************************
|
||||
|
||||
func (context *IContext) RunCapture(params *FullParams, callbacks *sCaptureCallbacks, reader *iAudioCapture) uintptr {
|
||||
ret, _, _ := syscall.SyscallN(
|
||||
context.lpVtbl.RunCapture,
|
||||
//3,
|
||||
uintptr(unsafe.Pointer(context)),
|
||||
uintptr(unsafe.Pointer(params)),
|
||||
uintptr(unsafe.Pointer(callbacks)),
|
||||
uintptr(unsafe.Pointer(reader)),
|
||||
)
|
||||
return ret
|
||||
}
|
||||
|
||||
func (context *IContext) GetResults(flags eResultFlags, pp **ITranscribeResult) uintptr {
|
||||
ret, _, _ := syscall.Syscall(
|
||||
context.lpVtbl.GetResults,
|
||||
3,
|
||||
uintptr(unsafe.Pointer(context)),
|
||||
uintptr(flags),
|
||||
uintptr(unsafe.Pointer(pp)),
|
||||
)
|
||||
return ret
|
||||
}
|
||||
|
||||
func (context *IContext) DetectSpeaker(time *sTimeInterval, result *eSpeakerChannel) uintptr {
|
||||
ret, _, _ := syscall.Syscall(
|
||||
context.lpVtbl.DetectSpeaker,
|
||||
3,
|
||||
uintptr(unsafe.Pointer(context)),
|
||||
uintptr(unsafe.Pointer(time)),
|
||||
uintptr(unsafe.Pointer(result)),
|
||||
)
|
||||
return ret
|
||||
}
|
5
pkg/whisper/doc.go
Normal file
5
pkg/whisper/doc.go
Normal file
@ -0,0 +1,5 @@
|
||||
// Package whisper provides Go bindings for the Const-me Whisper library,
// https://github.com/Const-me/Whisper
//
// Source: github.com/jaybinks/goConstmeWhispers
|
||||
package whisper
|
207
pkg/whisper/language.go
Normal file
207
pkg/whisper/language.go
Normal file
@ -0,0 +1,207 @@
|
||||
package whisper
|
||||
|
||||
// https://github.com/Const-me/Whisper/blob/master/WhisperNet/API/eLanguage.cs
|
||||
|
||||
// eLanguage identifies a language to the Whisper DLL.
// Reference: https://github.com/Const-me/Whisper/blob/master/WhisperNet/API/eLanguage.cs
type eLanguage int32

// Language identifiers. Apart from Auto, each value packs the ASCII bytes of
// the language code shown in the trailing comment, first character in the
// lowest byte (for example "en" is 'e' | 'n'<<8 = 0x6E65).
//
// Every constant carries the eLanguage type; previously only Auto did, and
// its comment wrongly carried the "af" code.
const (
	Auto eLanguage = -1 // no language code

	Afrikaans     eLanguage = 0x6661   // "af"
	Albanian      eLanguage = 0x7173   // "sq"
	Amharic       eLanguage = 0x6D61   // "am"
	Arabic        eLanguage = 0x7261   // "ar"
	Armenian      eLanguage = 0x7968   // "hy"
	Assamese      eLanguage = 0x7361   // "as"
	Azerbaijani   eLanguage = 0x7A61   // "az"
	Bashkir       eLanguage = 0x6162   // "ba"
	Basque        eLanguage = 0x7565   // "eu"
	Belarusian    eLanguage = 0x6562   // "be"
	Bengali       eLanguage = 0x6E62   // "bn"
	Bosnian       eLanguage = 0x7362   // "bs"
	Breton        eLanguage = 0x7262   // "br"
	Bulgarian     eLanguage = 0x6762   // "bg"
	Catalan       eLanguage = 0x6163   // "ca"
	Chinese       eLanguage = 0x687A   // "zh"
	Croatian      eLanguage = 0x7268   // "hr"
	Czech         eLanguage = 0x7363   // "cs"
	Danish        eLanguage = 0x6164   // "da"
	Dutch         eLanguage = 0x6C6E   // "nl"
	English       eLanguage = 0x6E65   // "en"
	Estonian      eLanguage = 0x7465   // "et"
	Faroese       eLanguage = 0x6F66   // "fo"
	Finnish       eLanguage = 0x6966   // "fi"
	French        eLanguage = 0x7266   // "fr"
	Galician      eLanguage = 0x6C67   // "gl"
	Georgian      eLanguage = 0x616B   // "ka"
	German        eLanguage = 0x6564   // "de"
	Greek         eLanguage = 0x6C65   // "el"
	Gujarati      eLanguage = 0x7567   // "gu"
	HaitianCreole eLanguage = 0x7468   // "ht"
	Hausa         eLanguage = 0x6168   // "ha"
	Hawaiian      eLanguage = 0x776168 // "haw"
	Hebrew        eLanguage = 0x7769   // "iw"
	Hindi         eLanguage = 0x6968   // "hi"
	Hungarian     eLanguage = 0x7568   // "hu"
	Icelandic     eLanguage = 0x7369   // "is"
	Indonesian    eLanguage = 0x6469   // "id"
	Italian       eLanguage = 0x7469   // "it"
	Japanese      eLanguage = 0x616A   // "ja"
	Javanese      eLanguage = 0x776A   // "jw"
	Kannada       eLanguage = 0x6E6B   // "kn"
	Kazakh        eLanguage = 0x6B6B   // "kk"
	Khmer         eLanguage = 0x6D6B   // "km"
	Korean        eLanguage = 0x6F6B   // "ko"
	Lao           eLanguage = 0x6F6C   // "lo"
	Latin         eLanguage = 0x616C   // "la"
	Latvian       eLanguage = 0x766C   // "lv"
	Lingala       eLanguage = 0x6E6C   // "ln"
	Lithuanian    eLanguage = 0x746C   // "lt"
	Luxembourgish eLanguage = 0x626C   // "lb"
	Macedonian    eLanguage = 0x6B6D   // "mk"
	Malagasy      eLanguage = 0x676D   // "mg"
	Malay         eLanguage = 0x736D   // "ms"
	Malayalam     eLanguage = 0x6C6D   // "ml"
	Maltese       eLanguage = 0x746D   // "mt"
	Maori         eLanguage = 0x696D   // "mi"
	Marathi       eLanguage = 0x726D   // "mr"
	Mongolian     eLanguage = 0x6E6D   // "mn"
	Myanmar       eLanguage = 0x796D   // "my"
	Nepali        eLanguage = 0x656E   // "ne"
	Norwegian     eLanguage = 0x6F6E   // "no"
	Nynorsk       eLanguage = 0x6E6E   // "nn"
	Occitan       eLanguage = 0x636F   // "oc"
	Pashto        eLanguage = 0x7370   // "ps"
	Persian       eLanguage = 0x6166   // "fa"
	Polish        eLanguage = 0x6C70   // "pl"
	Portuguese    eLanguage = 0x7470   // "pt"
	Punjabi       eLanguage = 0x6170   // "pa"
	Romanian      eLanguage = 0x6F72   // "ro"
	Russian       eLanguage = 0x7572   // "ru"
	Sanskrit      eLanguage = 0x6173   // "sa"
	Serbian       eLanguage = 0x7273   // "sr"
	Shona         eLanguage = 0x6E73   // "sn"
	Sindhi        eLanguage = 0x6473   // "sd"
	Sinhala       eLanguage = 0x6973   // "si"
	Slovak        eLanguage = 0x6B73   // "sk"
	Slovenian     eLanguage = 0x6C73   // "sl"
	Somali        eLanguage = 0x6F73   // "so"
	Spanish       eLanguage = 0x7365   // "es"
	Sundanese     eLanguage = 0x7573   // "su"
	Swahili       eLanguage = 0x7773   // "sw"
	Swedish       eLanguage = 0x7673   // "sv"
	Tagalog       eLanguage = 0x6C74   // "tl"
	Tajik         eLanguage = 0x6774   // "tg"
	Tamil         eLanguage = 0x6174   // "ta"
	Tatar         eLanguage = 0x7474   // "tt"
	Telugu        eLanguage = 0x6574   // "te"
	Thai          eLanguage = 0x6874   // "th"
	Tibetan       eLanguage = 0x6F62   // "bo"
	Turkish       eLanguage = 0x7274   // "tr"
	Turkmen       eLanguage = 0x6B74   // "tk"
	Ukrainian     eLanguage = 0x6B75   // "uk"
	Urdu          eLanguage = 0x7275   // "ur"
	Uzbek         eLanguage = 0x7A75   // "uz"
	Vietnamese    eLanguage = 0x6976   // "vi"
	Welsh         eLanguage = 0x7963   // "cy"
	Yiddish       eLanguage = 0x6979   // "yi"
	Yoruba        eLanguage = 0x6F79   // "yo"
)
|
43
pkg/whisper/logger.go
Normal file
43
pkg/whisper/logger.go
Normal file
@ -0,0 +1,43 @@
|
||||
package whisper
|
||||
|
||||
import (
|
||||
"C"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
/*
|
||||
https://github.com/Const-me/Whisper/blob/843a2a6ca6ea47c5ac4889a281badfc808d0ea01/Whisper/API/loggerApi.h
|
||||
|
||||
*/
|
||||
|
||||
// eLogLevel is the level field of sLoggerSetup and the level argument handed
// to the logger sink.
type eLogLevel uint8

// Log level values. Every constant carries the eLogLevel type; previously all
// but LlError were untyped because only the first entry of the block named
// the type.
const (
	LlError   eLogLevel = 0
	LlWarning eLogLevel = 1
	LlInfo    eLogLevel = 2
	LlDebug   eLogLevel = 3
)
|
||||
|
||||
// eLogFlags is the flags field of sLoggerSetup.
type eLogFlags uint8

// Logger flag values. Every constant carries the eLogFlags type; previously
// all but LfNone were untyped because only the first entry of the block named
// the type.
const (
	LfNone              eLogFlags = 0
	LfUseStandardError  eLogFlags = 1
	LfSkipFormatMessage eLogFlags = 2
)
|
||||
|
||||
type sLoggerSetup struct {
|
||||
sink uintptr // pfnLoggerSink
|
||||
context uintptr // void*
|
||||
level eLogLevel // eLogLevel
|
||||
flags eLogFlags // eLoggerFlags
|
||||
}
|
||||
|
||||
func fnLoggerSink(context uintptr, lvl eLogLevel, message *C.char) uintptr {
|
||||
|
||||
strmessage := C.GoString(message)
|
||||
fmt.Printf("%d - %s\n", lvl, strmessage)
|
||||
|
||||
return 0
|
||||
}
|
188
pkg/whisper/whisper.go
Normal file
188
pkg/whisper/whisper.go
Normal file
@ -0,0 +1,188 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package whisper
|
||||
|
||||
import (
|
||||
"C"
|
||||
"errors"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
// golang.org/x/sys/windows supplies the HRESULT constants (S_OK) and the UTF-16 helpers used below
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
/*
|
||||
eModelImplementation - TranscribeStructs.h
|
||||
|
||||
// GPGPU implementation based on Direct3D 11.0 compute shaders
|
||||
GPU = 1,
|
||||
|
||||
// A hybrid implementation which uses DirectCompute for encode, and decodes on CPU
|
||||
// Not implemented in the published builds of the DLL. To enable, change BUILD_HYBRID_VERSION macro to 1
|
||||
Hybrid = 2,
|
||||
|
||||
// A reference implementation which uses the original GGML CPU-running code
|
||||
// Not implemented in the published builds of the DLL. To enable, change BUILD_BOTH_VERSIONS macro to 1
|
||||
Reference = 3,
|
||||
*/
|
||||
|
||||
// https://learn.microsoft.com/en-us/windows/win32/seccrypto/common-hresult-values
|
||||
// https://pkg.go.dev/golang.org/x/sys/windows
|
||||
const (
	// HRESULT values, see
	// https://learn.microsoft.com/en-us/windows/win32/seccrypto/common-hresult-values
	// and https://pkg.go.dev/golang.org/x/sys/windows
	E_INVALIDARG                      = 0x80070057
	ERROR_HV_CPUID_FEATURE_VALIDATION = 0xC0350038

	// DLLName is the file name of the Whisper library; it is used both for
	// the version check and for loading the DLL.
	DLLName = "whisper.dll"
)
|
||||
|
||||
type Libwhisper struct {
|
||||
dll *syscall.LazyDLL
|
||||
ver WinVersion
|
||||
existing_model map[string]*Model
|
||||
|
||||
proc_setupLogger *syscall.LazyProc
|
||||
proc_loadModel *syscall.LazyProc
|
||||
proc_initMediaFoundation *syscall.LazyProc
|
||||
// proc_findLanguageKeyW *syscall.LazyProc
|
||||
// proc_findLanguageKeyA *syscall.LazyProc
|
||||
// proc_getSupportedLanguages *syscall.LazyProc
|
||||
}
|
||||
|
||||
// singleton_whisper holds the instance created by the first successful call
// to New; later calls to New return it unchanged.
var singleton_whisper *Libwhisper = nil
|
||||
|
||||
func New(level eLogLevel, flags eLogFlags, cb *any) (*Libwhisper, error) {
|
||||
if singleton_whisper != nil {
|
||||
return singleton_whisper, nil
|
||||
}
|
||||
|
||||
var err error
|
||||
this := &Libwhisper{}
|
||||
|
||||
this.ver, err = GetFileVersion(DLLName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if this.ver.Major < 1 && this.ver.Minor < 9 {
|
||||
return nil, errors.New("This library requires whisper.dll version 1.9 or higher.") // or less than 1.11 for now .. because the API changed
|
||||
}
|
||||
|
||||
this.dll = syscall.NewLazyDLL(DLLName) // Todo wrap this in a class, check file exists, handle errors ... you know, just a few things.. AKA Stop being lazy
|
||||
|
||||
this.proc_setupLogger = this.dll.NewProc("setupLogger")
|
||||
this.proc_loadModel = this.dll.NewProc("loadModel")
|
||||
this.proc_initMediaFoundation = this.dll.NewProc("initMediaFoundation")
|
||||
/*
|
||||
this.proc_findLanguageKeyW = this.dll.NewProc("findLanguageKeyW")
|
||||
this.proc_findLanguageKeyA = this.dll.NewProc("findLanguageKeyA")
|
||||
this.proc_getSupportedLanguages = this.dll.NewProc("getSupportedLanguages")
|
||||
*/
|
||||
|
||||
ok, err := this._setupLogger(level, flags, cb)
|
||||
if !ok {
|
||||
return nil, errors.New("Logger Error : " + err.Error())
|
||||
}
|
||||
|
||||
this.existing_model = make(map[string]*Model)
|
||||
singleton_whisper = this
|
||||
|
||||
return singleton_whisper, nil
|
||||
}
|
||||
|
||||
func (this *Libwhisper) Version() string {
|
||||
return fmt.Sprintf("%d.%d.%d.%d.", this.ver.Major, this.ver.Minor, this.ver.Patch, this.ver.Build)
|
||||
}
|
||||
|
||||
func (this *Libwhisper) SupportsMultiThread() bool {
|
||||
return this.ver.Major >= 1 && this.ver.Minor >= 10
|
||||
}
|
||||
|
||||
func (this *Libwhisper) _setupLogger(level eLogLevel, flags eLogFlags, cb *any) (bool, error) {
|
||||
|
||||
setup := sLoggerSetup{}
|
||||
setup.sink = 0
|
||||
setup.context = 0
|
||||
setup.level = level
|
||||
setup.flags = flags
|
||||
|
||||
if cb != nil {
|
||||
setup.sink = syscall.NewCallback(cb)
|
||||
}
|
||||
|
||||
res, _, err := this.proc_setupLogger.Call(uintptr(unsafe.Pointer(&setup)))
|
||||
|
||||
if windows.Handle(res) == windows.S_OK {
|
||||
return true, nil
|
||||
} else {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
|
||||
func (this *Libwhisper) LoadModel(path string, aGPU ...string) (*Model, error) {
|
||||
var modelptr *_IModel
|
||||
|
||||
whisperpath, _ := windows.UTF16PtrFromString(path)
|
||||
|
||||
GPU := ""
|
||||
if len(aGPU) == 1 {
|
||||
GPU = aGPU[0]
|
||||
}
|
||||
|
||||
setup := ModelSetup(gmf_Cloneable, GPU)
|
||||
|
||||
// Construct our map hash
|
||||
singleton_hash := GPU + "|" + path
|
||||
if this.existing_model[singleton_hash] != nil {
|
||||
ClonedModel, err := this.existing_model[singleton_hash].Clone()
|
||||
if ClonedModel != nil {
|
||||
return NewModel(setup, ClonedModel), nil
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
obj, _, _ := this.proc_loadModel.Call(uintptr(unsafe.Pointer(whisperpath)), uintptr(unsafe.Pointer(setup.AsCType())), uintptr(unsafe.Pointer(nil)), uintptr(unsafe.Pointer(&modelptr)))
|
||||
|
||||
if windows.Handle(obj) != windows.S_OK {
|
||||
fmt.Printf("loadModel failed: %s\n", syscall.Errno(obj).Error())
|
||||
return nil, fmt.Errorf("loadModel failed: %s", syscall.Errno(obj))
|
||||
}
|
||||
|
||||
if modelptr == nil {
|
||||
return nil, errors.New("loadModel did not return a Model")
|
||||
}
|
||||
|
||||
if modelptr.lpVtbl == nil {
|
||||
return nil, errors.New("loadModel method table is nil")
|
||||
}
|
||||
|
||||
model := NewModel(setup, modelptr)
|
||||
|
||||
this.existing_model[singleton_hash] = model
|
||||
|
||||
return model, nil
|
||||
}
|
||||
|
||||
func (this *Libwhisper) InitMediaFoundation() (*IMediaFoundation, error) {
|
||||
|
||||
var mediafoundation *IMediaFoundation
|
||||
|
||||
// initMediaFoundation( iMediaFoundation** pp );
|
||||
obj, _, _ := this.proc_initMediaFoundation.Call(uintptr(unsafe.Pointer(&mediafoundation)))
|
||||
|
||||
if windows.Handle(obj) != windows.S_OK {
|
||||
fmt.Printf("initMediaFoundation failed: %s\n", syscall.Errno(obj).Error())
|
||||
return nil, fmt.Errorf("initMediaFoundation failed: %s", syscall.Errno(obj))
|
||||
}
|
||||
|
||||
if mediafoundation.lpVtbl == nil {
|
||||
return nil, errors.New("initMediaFoundation method table is nil")
|
||||
}
|
||||
|
||||
return mediafoundation, nil
|
||||
}
|
123
pkg/whisper/winversion.go
Normal file
123
pkg/whisper/winversion.go
Normal file
@ -0,0 +1,123 @@
|
||||
// Copyright 2018 Keybase Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD
|
||||
// license that can be found in the LICENSE file.
|
||||
// Adapted mainly from github.com/gonutz/w32
|
||||
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package whisper
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
var (
|
||||
version = windows.NewLazySystemDLL("version.dll")
|
||||
getFileVersionInfoSize = version.NewProc("GetFileVersionInfoSizeW")
|
||||
getFileVersionInfo = version.NewProc("GetFileVersionInfoW")
|
||||
verQueryValue = version.NewProc("VerQueryValueW")
|
||||
)
|
||||
|
||||
// VS_FIXEDFILEINFO is the fixed-size version record that VerQueryValueRoot
// copies out of a version-information block. It consists of thirteen 32-bit
// fields whose order must not be changed.
type VS_FIXEDFILEINFO struct {
	Signature, StrucVersion uint32

	// File version: most and least significant 32 bits.
	FileVersionMS, FileVersionLS uint32

	// Product version: most and least significant 32 bits.
	ProductVersionMS, ProductVersionLS uint32

	FileFlagsMask, FileFlags uint32

	FileOS, FileType, FileSubtype uint32

	FileDateMS, FileDateLS uint32
}
|
||||
|
||||
// WinVersion is a four-part file version as produced by GetFileVersion.
type WinVersion struct {
	Major, Minor, Patch, Build uint32
}
|
||||
|
||||
// FileVersion concatenates FileVersionMS and FileVersionLS to a uint64 value.
|
||||
func (fi VS_FIXEDFILEINFO) FileVersion() uint64 {
|
||||
return uint64(fi.FileVersionMS)<<32 | uint64(fi.FileVersionLS)
|
||||
}
|
||||
|
||||
func GetFileVersionInfoSize(path string) uint32 {
|
||||
ret, _, _ := getFileVersionInfoSize.Call(
|
||||
uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr(path))),
|
||||
0,
|
||||
)
|
||||
return uint32(ret)
|
||||
}
|
||||
|
||||
func GetFileVersionInfo(path string, data []byte) bool {
|
||||
ret, _, _ := getFileVersionInfo.Call(
|
||||
uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr(path))),
|
||||
0,
|
||||
uintptr(len(data)),
|
||||
uintptr(unsafe.Pointer(&data[0])),
|
||||
)
|
||||
return ret != 0
|
||||
}
|
||||
|
||||
// VerQueryValueRoot calls VerQueryValue
|
||||
// (https://msdn.microsoft.com/en-us/library/windows/desktop/ms647464(v=vs.85).aspx)
|
||||
// with `\` (root) to retieve the VS_FIXEDFILEINFO.
|
||||
func VerQueryValueRoot(block []byte) (VS_FIXEDFILEINFO, error) {
|
||||
var offset uintptr
|
||||
var length uint
|
||||
blockStart := unsafe.Pointer(&block[0])
|
||||
ret, _, _ := verQueryValue.Call(
|
||||
uintptr(blockStart),
|
||||
uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr(`\`))),
|
||||
uintptr(unsafe.Pointer(&offset)),
|
||||
uintptr(unsafe.Pointer(&length)),
|
||||
)
|
||||
if ret == 0 {
|
||||
return VS_FIXEDFILEINFO{}, errors.New("VerQueryValueRoot: verQueryValue failed")
|
||||
}
|
||||
start := int(offset) - int(uintptr(blockStart))
|
||||
end := start + int(length)
|
||||
if start < 0 || start >= len(block) || end < start || end > len(block) {
|
||||
return VS_FIXEDFILEINFO{}, errors.New("VerQueryValueRoot: find failed")
|
||||
}
|
||||
data := block[start:end]
|
||||
info := *((*VS_FIXEDFILEINFO)(unsafe.Pointer(&data[0])))
|
||||
return info, nil
|
||||
}
|
||||
|
||||
func GetFileVersion(path string) (WinVersion, error) {
|
||||
var result WinVersion
|
||||
size := GetFileVersionInfoSize(path)
|
||||
if size <= 0 {
|
||||
return result, errors.New("GetFileVersionInfoSize failed")
|
||||
}
|
||||
|
||||
info := make([]byte, size)
|
||||
ok := GetFileVersionInfo(path, info)
|
||||
if !ok {
|
||||
return result, errors.New("GetFileVersionInfo failed")
|
||||
}
|
||||
|
||||
fixed, err := VerQueryValueRoot(info)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
version := fixed.FileVersion()
|
||||
|
||||
result.Major = uint32(version & 0xFFFF000000000000 >> 48)
|
||||
result.Minor = uint32(version & 0x0000FFFF00000000 >> 32)
|
||||
result.Patch = uint32(version & 0x00000000FFFF0000 >> 16)
|
||||
result.Build = uint32(version & 0x000000000000FFFF)
|
||||
|
||||
return result, nil
|
||||
}
|
75
state.go
Normal file
75
state.go
Normal file
@ -0,0 +1,75 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/xzeldon/whisper-api-server/pkg/whisper"
|
||||
)
|
||||
|
||||
type WhisperState struct {
|
||||
model *whisper.Model
|
||||
context *whisper.IContext
|
||||
media *whisper.IMediaFoundation
|
||||
params *whisper.FullParams
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
func InitializeWhisperState(modelPath string) (*WhisperState, error) {
|
||||
lib, err := whisper.New(whisper.LlDebug, whisper.LfUseStandardError, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
model, err := lib.LoadModel(modelPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
context, err := model.CreateContext()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
media, err := lib.InitMediaFoundation()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
params, err := context.FullDefaultParams(whisper.SsBeamSearch)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
params.AddFlags(whisper.FlagNoContext)
|
||||
params.AddFlags(whisper.FlagTokenTimestamps)
|
||||
|
||||
fmt.Printf("Params CPU Threads : %d\n", params.CpuThreads())
|
||||
|
||||
return &WhisperState{
|
||||
model: model,
|
||||
context: context,
|
||||
media: media,
|
||||
params: params,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func getResult(ctx *whisper.IContext) (string, error) {
|
||||
results := &whisper.ITranscribeResult{}
|
||||
ctx.GetResults(whisper.RfTokens|whisper.RfTimestamps, &results)
|
||||
|
||||
length, err := results.GetSize()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
segments := results.GetSegments(length.CountSegments)
|
||||
|
||||
var result string
|
||||
|
||||
for _, seg := range segments {
|
||||
result += seg.Text()
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
0
tmp/.gitkeep
Normal file
0
tmp/.gitkeep
Normal file
48
utils.go
Normal file
48
utils.go
Normal file
@ -0,0 +1,48 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
|
||||
func saveFormFile(name string, c echo.Context) (string, error) {
|
||||
file, err := c.FormFile(name)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
src, err := file.Open()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer src.Close()
|
||||
|
||||
ext := filepath.Ext(file.Filename)
|
||||
filename := time.Now().Format(time.RFC3339)
|
||||
filename = "./tmp/" + sanitizeFilename(filename) + ext
|
||||
|
||||
dst, err := os.Create(filename)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer dst.Close()
|
||||
|
||||
if _, err = io.Copy(dst, src); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return filename, nil
|
||||
}
|
||||
|
||||
// sanitizeFilename returns filename with every character that is not allowed
// in a Windows file name replaced by "-": the characters \ / : * ? " < > |
// and, newly, the control characters below U+0020.
//
// The string is processed in a single pass instead of once per forbidden
// character. strings.Map decodes the input as UTF-8, so bytes that are not
// valid UTF-8 come out as U+FFFD.
func sanitizeFilename(filename string) string {
	return strings.Map(func(r rune) rune {
		switch r {
		case '\\', '/', ':', '*', '?', '"', '<', '>', '|':
			return '-'
		}
		if r < 0x20 {
			return '-'
		}
		return r
	}, filename)
}
|
Loading…
Reference in New Issue
Block a user