From a0252852569929b3e6a7e933741c073ff26e1cc4 Mon Sep 17 00:00:00 2001 From: xzeldon Date: Thu, 5 Oct 2023 21:05:30 +0300 Subject: [PATCH] Implement automatic model and Whisper.dll downloading MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Additionally, made the following changes: ∙ - Allow listening only on local interfaces ∙ - Update project structure --- go.mod | 8 +++ go.sum | 14 +++++ handler.go => internal/api/handler.go | 4 +- state.go => internal/api/state.go | 2 +- utils.go => internal/api/utils.go | 2 +- internal/resources/download.go | 38 +++++++++++++ internal/resources/model.go | 29 ++++++++++ internal/resources/utils.go | 13 +++++ internal/resources/whisper.go | 78 +++++++++++++++++++++++++++ main.go | 21 +++++--- 10 files changed, 199 insertions(+), 10 deletions(-) rename handler.go => internal/api/handler.go (91%) rename state.go => internal/api/state.go (99%) rename utils.go => internal/api/utils.go (98%) create mode 100644 internal/resources/download.go create mode 100644 internal/resources/model.go create mode 100644 internal/resources/utils.go create mode 100644 internal/resources/whisper.go diff --git a/go.mod b/go.mod index b4f3c0d..7252b49 100644 --- a/go.mod +++ b/go.mod @@ -4,9 +4,17 @@ go 1.21.1 require ( github.com/labstack/echo/v4 v4.11.1 + github.com/schollz/progressbar/v3 v3.13.1 golang.org/x/sys v0.12.0 ) +require ( + github.com/mattn/go-runewidth v0.0.14 // indirect + github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db // indirect + github.com/rivo/uniseg v0.2.0 // indirect + golang.org/x/term v0.10.0 // indirect +) + require ( github.com/labstack/gommon v0.4.0 github.com/mattn/go-colorable v0.1.13 // indirect diff --git a/go.sum b/go.sum index 119ff10..659b594 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,7 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/k0kubun/go-ansi v0.0.0-20180517002512-3bf9e2903213/go.mod h1:vNUNkEQ1e29fT/6vq2aBdFsgNPmy8qMdSay1npru+Sw= github.com/labstack/echo/v4 v4.11.1 h1:dEpLU2FLg4UVmvCGPuk/APjlH6GDpbEPti61srUUUs4= github.com/labstack/echo/v4 v4.11.1/go.mod h1:YuYRTSM3CHs2ybfrL8Px48bO6BAnYIN4l8wSTMP6BDQ= github.com/labstack/gommon v0.4.0 h1:y7cvthEAEbU0yHOf4axH8ZG2NH8knB9iNSoTO8dyIk8= @@ -10,11 +11,21 @@ github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxec github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.17/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWVwUuU= +github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= +github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db h1:62I3jR2EmQ4l5rM/4FEfDWcRD+abF5XlKShorW5LRoQ= +github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db/go.mod h1:l0dey0ia/Uv7NcFFVbCLtqEBQbrT4OCwCSKTEv6enCw= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= +github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= +github.com/schollz/progressbar/v3 v3.13.1 h1:o8rySDYiQ59Mwzy2FELeHY5ZARXZTVJC7iHD6PEFUiE= +github.com/schollz/progressbar/v3 v3.13.1/go.mod h1:xvrbki8kfT1fzWzBT/UZd9L6GA+jdL7HAgq2RFnO6fQ= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= @@ -34,6 +45,9 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U= +golang.org/x/term v0.10.0 h1:3R7pNqamzBraeqj/Tj8qt1aQ2HpmlC+Cx/qL/7hn4/c= +golang.org/x/term v0.10.0/go.mod h1:lpqdcUyK/oCiQxvxVrppt5ggO2KCZ5QblwqPnfZ6d5o= golang.org/x/text v0.11.0 h1:LAntKIrcmeSKERyiOh0XMV39LXS8IE9UL2yP7+f5ij4= golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/handler.go b/internal/api/handler.go similarity index 91% rename from handler.go rename to internal/api/handler.go index 6b4ca39..965f31c 100644 --- a/handler.go +++ b/internal/api/handler.go @@ -1,4 +1,4 @@ -package main +package api import ( "net/http" @@ -11,7 +11,7 @@ type TranscribeResponse struct { Text string `json:"text"` } -func transcribe(c echo.Context, whisperState *WhisperState) error { +func Transcribe(c echo.Context, whisperState *WhisperState) error { audioPath, err := saveFormFile("file", c) if err != nil { c.Logger().Errorf("Error reading file: %s", err) diff --git a/state.go b/internal/api/state.go similarity index 99% rename from state.go rename to internal/api/state.go index 55e1d1c..39e9668 100644 --- a/state.go +++ b/internal/api/state.go @@ -1,4 +1,4 @@ -package main +package api import ( "fmt" diff --git a/utils.go b/internal/api/utils.go similarity index 98% rename from utils.go rename to internal/api/utils.go index e7f868f..bb994e7 100644 --- a/utils.go +++ b/internal/api/utils.go @@ -1,4 +1,4 @@ -package main +package api import ( "io" diff --git a/internal/resources/download.go b/internal/resources/download.go new file mode 100644 index 0000000..644d2dd --- /dev/null +++ b/internal/resources/download.go @@ -0,0 +1,38 @@ +package resources + +import ( + "io" + "net/http" + "os" + + "github.com/schollz/progressbar/v3" +) + +func DownloadFile(url string, filepath string) error { + out, err := os.Create(filepath) + if err != nil { + return err + } + defer out.Close() + + resp, err := http.Get(url) + if err != nil { + return err + } + defer resp.Body.Close() + + fileSize := resp.ContentLength + bar := progressbar.DefaultBytes( + fileSize, + "Downloading", + ) + + writer := io.MultiWriter(out, bar) + + _, err = io.Copy(writer, resp.Body) + if err != nil { + return err + } + + return nil +} diff --git a/internal/resources/model.go b/internal/resources/model.go new file mode 100644 index 0000000..18fa83d --- /dev/null +++ b/internal/resources/model.go @@ -0,0 +1,29 @@ +package resources + +import ( + "fmt" + "path/filepath" +) + +func GetModel(modelType string) (string, error) { + fileURL := fmt.Sprintf("https://huggingface.co/ggerganov/whisper.cpp/resolve/main/%s", modelType) + filePath := modelType + + isModelFileExists := IsFileExists(filePath) + + if !isModelFileExists { + fmt.Println("Model not found.") + err := DownloadFile(fileURL, filePath) + if err != nil { + return "", err + } + } + + absPath, err := filepath.Abs(filePath) + if err != nil { + return "", err + } + + fmt.Printf("Model found: %s\n", absPath) + return filePath, nil +} diff --git a/internal/resources/utils.go b/internal/resources/utils.go new file mode 100644 index 0000000..9186bf6 --- /dev/null +++ b/internal/resources/utils.go @@ -0,0 +1,13 @@ +package resources + +import "os" + +func IsFileExists(filename string) bool { + _, err := os.Stat(filename) + if err != nil { + if os.IsNotExist(err) { + return false + } + } + return true +} diff --git a/internal/resources/whisper.go b/internal/resources/whisper.go new file mode 100644 index 0000000..1084780 --- /dev/null +++ b/internal/resources/whisper.go @@ -0,0 +1,78 @@ +package resources + +import ( + "archive/zip" + "fmt" + "io" + "os" + "path/filepath" +) + +func GetWhisperDll(version string) (string, error) { + fileUrl := fmt.Sprintf("https://github.com/Const-me/Whisper/releases/download/%s/Library.zip", version) + fileToExtract := "Binary/Whisper.dll" + + isWhisperDllExists := IsFileExists("Whisper.dll") + + if !isWhisperDllExists { + fmt.Println("Whisper DLL not found.") + archivePath, err := os.CreateTemp("", "WhisperLibrary-*.zip") + if err != nil { + return "", err + } + defer archivePath.Close() + + err = DownloadFile(fileUrl, archivePath.Name()) + if err != nil { + return "", err + } + + err = extractFile(archivePath.Name(), fileToExtract) + if err != nil { + return "", err + } + } + + absPath, err := filepath.Abs("Whisper.dll") + if err != nil { + return "", err + } + + fmt.Printf("Library found: %s\n", absPath) + return "Whisper.dll", nil +} + +func extractFile(archivePath string, fileToExtract string) error { + reader, err := zip.OpenReader(archivePath) + if err != nil { + return err + } + defer reader.Close() + + for _, file := range reader.File { + if file.Name == fileToExtract { + targetPath := filepath.Base(fileToExtract) + + writer, err := os.Create(targetPath) + if err != nil { + return err + } + defer writer.Close() + + src, err := file.Open() + if err != nil { + return err + } + defer src.Close() + + _, err = io.Copy(writer, src) + if err != nil { + return err + } + + return nil + } + } + + return fmt.Errorf("File not found in the archive") +} diff --git a/main.go b/main.go index cf1328a..4f5bb62 100644 --- a/main.go +++ b/main.go @@ -3,10 +3,10 @@ package main import ( "github.com/labstack/echo/v4" "github.com/labstack/gommon/log" + "github.com/xzeldon/whisper-api-server/internal/api" + "github.com/xzeldon/whisper-api-server/internal/resources" ) -const MODEL_PATH = "./ggml-medium.bin" - func main() { e := echo.New() e.HideBanner = true @@ -15,15 +15,24 @@ func main() { l.SetHeader("${time_rfc3339} ${level}") } - whisperState, err := InitializeWhisperState(MODEL_PATH) + _, err := resources.GetWhisperDll("1.12.0") + if err != nil { + e.Logger.Error(err) + } + + model, err := resources.GetModel("ggml-medium.bin") + if err != nil { + e.Logger.Error(err) + } + + whisperState, err := api.InitializeWhisperState(model) if err != nil { e.Logger.Error(err) - return } e.POST("/v1/audio/transcriptions", func(c echo.Context) error { - return transcribe(c, whisperState) + return api.Transcribe(c, whisperState) }) - e.Logger.Fatal(e.Start(":3000")) + e.Logger.Fatal(e.Start("127.0.0.1:3000")) }