Implement automatic model and Whisper.dll downloading

Additionally, made the following changes:
- Allow listening only on local interfaces
- Update project structure
This commit is contained in:
2023-10-05 21:05:30 +03:00
parent 8e4cb72b50
commit a025285256
10 changed files with 199 additions and 10 deletions

45
internal/api/handler.go Normal file
View File

@ -0,0 +1,45 @@
package api
import (
"net/http"
"strings"
"github.com/labstack/echo/v4"
)
// TranscribeResponse is the JSON body that Transcribe sends with a 200 status.
type TranscribeResponse struct {
// Text holds the transcription; it is serialized under the "text" key.
Text string `json:"text"`
}
// Transcribe handles an audio upload sent in the "file" multipart form field.
//
// The upload is stored on disk by saveFormFile, decoded with the media
// foundation held in whisperState, run through the whisper context, and the
// concatenated segment text is returned as a TranscribeResponse with status 200.
//
// An error from saveFormFile is logged and returned to the caller unchanged.
// Any failure while decoding, transcribing or reading results is logged and
// answered with a 500 JSON body {"error": "Internal server error"}; an empty
// transcription is answered the same way.
func Transcribe(c echo.Context, whisperState *WhisperState) error {
	audioPath, err := saveFormFile("file", c)
	if err != nil {
		c.Logger().Errorf("Error reading file: %s", err)
		return err
	}

	// Release the lock on every exit path, including panics raised by the
	// whisper bindings, so one failed request cannot block all later ones.
	whisperState.mutex.Lock()
	defer whisperState.mutex.Unlock()

	internalError := func() error {
		return c.JSON(http.StatusInternalServerError, map[string]string{"error": "Internal server error"})
	}

	buffer, err := whisperState.media.LoadAudioFile(audioPath, true)
	if err != nil {
		c.Logger().Errorf("Error loading audio file data: %s", err)
		// Without a decoded buffer there is nothing to transcribe.
		return internalError()
	}

	if err := whisperState.context.RunFull(whisperState.params, buffer); err != nil {
		c.Logger().Errorf("Error running transcription: %s", err)
		return internalError()
	}

	result, err := getResult(whisperState.context)
	if err != nil {
		c.Logger().Error(err)
		return internalError()
	}
	if len(result) == 0 {
		return internalError()
	}

	response := TranscribeResponse{
		Text: strings.TrimLeft(result, " "),
	}
	return c.JSON(http.StatusOK, response)
}

75
internal/api/state.go Normal file
View File

@ -0,0 +1,75 @@
package api
import (
"fmt"
"sync"
"github.com/xzeldon/whisper-api-server/pkg/whisper"
)
// WhisperState bundles the whisper objects created once by
// InitializeWhisperState and reused by every request.
type WhisperState struct {
// model is the model loaded from the path given to InitializeWhisperState.
model *whisper.Model
// context is created from model and is what transcription runs on.
context *whisper.IContext
// media is used to load audio files into buffers.
media *whisper.IMediaFoundation
// params are the transcription parameters passed to each run.
params *whisper.FullParams
// mutex is held by Transcribe while it uses media and context.
mutex sync.Mutex
}
// InitializeWhisperState loads the whisper library and the model stored at
// modelPath, then prepares everything a transcription request needs: a
// context, a media foundation and beam-search parameters with the NoContext
// and TokenTimestamps flags added.
//
// The configured CPU thread count is printed to standard output. The first
// failing step aborts initialization and its error is returned as-is.
func InitializeWhisperState(modelPath string) (*WhisperState, error) {
	library, libErr := whisper.New(whisper.LlDebug, whisper.LfUseStandardError, nil)
	if libErr != nil {
		return nil, libErr
	}

	loadedModel, modelErr := library.LoadModel(modelPath)
	if modelErr != nil {
		return nil, modelErr
	}

	ctx, ctxErr := loadedModel.CreateContext()
	if ctxErr != nil {
		return nil, ctxErr
	}

	mediaFoundation, mediaErr := library.InitMediaFoundation()
	if mediaErr != nil {
		return nil, mediaErr
	}

	fullParams, paramsErr := ctx.FullDefaultParams(whisper.SsBeamSearch)
	if paramsErr != nil {
		return nil, paramsErr
	}

	// Flags are added one call at a time, exactly as the bindings are used
	// elsewhere in this package.
	fullParams.AddFlags(whisper.FlagNoContext)
	fullParams.AddFlags(whisper.FlagTokenTimestamps)

	fmt.Printf("Params CPU Threads : %d\n", fullParams.CpuThreads())

	state := &WhisperState{
		model:   loadedModel,
		context: ctx,
		media:   mediaFoundation,
		params:  fullParams,
	}
	return state, nil
}
// getResult reads the results currently held by ctx and returns the text of
// all segments joined together in order, with no separator.
//
// An error from GetSize is returned with an empty string.
// NOTE(review): whatever GetResults itself returns is not inspected here —
// confirm against the bindings whether it reports an error.
func getResult(ctx *whisper.IContext) (string, error) {
	transcript := &whisper.ITranscribeResult{}
	ctx.GetResults(whisper.RfTokens|whisper.RfTimestamps, &transcript)

	size, sizeErr := transcript.GetSize()
	if sizeErr != nil {
		return "", sizeErr
	}

	var text string
	parts := transcript.GetSegments(size.CountSegments)
	for i := range parts {
		text += parts[i].Text()
	}
	return text, nil
}

48
internal/api/utils.go Normal file
View File

@ -0,0 +1,48 @@
package api
import (
"io"
"os"
"path/filepath"
"strings"
"time"
"github.com/labstack/echo/v4"
)
// saveFormFile stores the multipart form file called name from the request c
// in the ./tmp/ directory and returns the path of the stored copy.
//
// The stored name is a sanitized RFC 3339 timestamp, a random component, and
// the sanitized extension of the uploaded file name. The random component
// keeps uploads that arrive within the same second from overwriting each
// other. The directory is created if it does not exist yet.
//
// On any failure an empty path and the error are returned, and a partially
// written file is removed.
func saveFormFile(name string, c echo.Context) (savedPath string, err error) {
	file, err := c.FormFile(name)
	if err != nil {
		return "", err
	}

	src, err := file.Open()
	if err != nil {
		return "", err
	}
	defer src.Close()

	const uploadDir = "./tmp/"
	if err := os.MkdirAll(uploadDir, 0o755); err != nil {
		return "", err
	}

	// The extension comes from the client, so it is sanitized as well; this
	// also guarantees the pattern contains no extra "*" or path separator.
	ext := sanitizeFilename(filepath.Ext(file.Filename))
	stamp := sanitizeFilename(time.Now().Format(time.RFC3339))

	dst, err := os.CreateTemp(uploadDir, stamp+"-*"+ext)
	if err != nil {
		return "", err
	}
	defer func() {
		// A failed Close can mean the data never reached the disk.
		if closeErr := dst.Close(); err == nil {
			err = closeErr
		}
		if err != nil {
			// Best-effort cleanup; the original error is the one worth reporting.
			_ = os.Remove(dst.Name())
			savedPath = ""
		}
	}()

	if _, err = io.Copy(dst, src); err != nil {
		return "", err
	}
	return dst.Name(), nil
}
// sanitizeFilename returns filename with each of the characters
// \ / : * ? " < > | replaced by a dash. All other bytes are left untouched.
func sanitizeFilename(filename string) string {
	// Every pattern and its replacement is a single byte, so one pass over
	// the input yields the same result as replacing each character in turn.
	cleaner := strings.NewReplacer(
		`\`, "-",
		`/`, "-",
		`:`, "-",
		`*`, "-",
		`?`, "-",
		`"`, "-",
		`<`, "-",
		`>`, "-",
		`|`, "-",
	)
	return cleaner.Replace(filename)
}

View File

@ -0,0 +1,38 @@
package resources
import (
"io"
"net/http"
"os"
"github.com/schollz/progressbar/v3"
)
// downloadStatusError reports that a download was answered with a non-2xx
// HTTP status.
type downloadStatusError struct {
	url    string
	status string
}

// Error implements the error interface.
func (e *downloadStatusError) Error() string {
	return "download " + e.url + ": unexpected HTTP status " + e.status
}

// DownloadFile fetches url with an HTTP GET and writes the response body to
// filepath, showing a byte progress bar while it copies.
//
// The destination is only created (or truncated) once the server has answered
// with a 2xx status, so a failed request never leaves an empty file behind.
// If copying or closing fails, the incomplete file is removed so that a later
// existence check does not mistake it for a finished download.
//
// It returns the request error, a *downloadStatusError for a non-2xx answer,
// or the first error hit while creating, writing or closing the file.
func DownloadFile(url string, filepath string) (err error) {
	resp, err := http.Get(url)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	// http.Get only reports transport failures; an error page still arrives
	// with a nil error and must be rejected explicitly.
	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
		return &downloadStatusError{url: url, status: resp.Status}
	}

	out, err := os.Create(filepath)
	if err != nil {
		return err
	}
	defer func() {
		if closeErr := out.Close(); err == nil {
			err = closeErr
		}
		if err != nil {
			// Best-effort cleanup; the original error is the one worth reporting.
			_ = os.Remove(filepath)
		}
	}()

	bar := progressbar.DefaultBytes(
		resp.ContentLength,
		"Downloading",
	)

	_, err = io.Copy(io.MultiWriter(out, bar), resp.Body)
	return err
}

View File

@ -0,0 +1,29 @@
package resources
import (
"fmt"
"path/filepath"
)
// GetModel makes sure the model file named modelType is present in the
// current directory, downloading it from the ggerganov/whisper.cpp repository
// on Hugging Face when IsFileExists reports it missing.
//
// It prints the absolute location of the model and returns the relative path
// (equal to modelType). A download or path-resolution error is returned with
// an empty path.
func GetModel(modelType string) (string, error) {
	localPath := modelType

	if !IsFileExists(localPath) {
		fmt.Println("Model not found.")

		remote := fmt.Sprintf("https://huggingface.co/ggerganov/whisper.cpp/resolve/main/%s", modelType)
		if downloadErr := DownloadFile(remote, localPath); downloadErr != nil {
			return "", downloadErr
		}
	}

	resolved, absErr := filepath.Abs(localPath)
	if absErr != nil {
		return "", absErr
	}

	fmt.Printf("Model found: %s\n", resolved)
	return localPath, nil
}

View File

@ -0,0 +1,13 @@
package resources
import "os"
// IsFileExists reports whether filename can be assumed to exist.
//
// It returns false only when os.Stat fails with a "does not exist" error.
// Any other outcome — success or a different error — yields true.
func IsFileExists(filename string) bool {
	_, statErr := os.Stat(filename)
	// os.IsNotExist(nil) is false, so a successful Stat also reports true.
	return !os.IsNotExist(statErr)
}

View File

@ -0,0 +1,78 @@
package resources
import (
"archive/zip"
"fmt"
"io"
"os"
"path/filepath"
)
// GetWhisperDll makes sure Whisper.dll is present in the current directory.
//
// When IsFileExists reports it missing, the Library.zip asset of the given
// Const-me/Whisper release version is downloaded to a temporary file and
// Binary/Whisper.dll is extracted from it. The temporary archive is deleted
// afterwards, whether or not extraction succeeded.
//
// It prints the absolute location of the library and returns the relative
// path "Whisper.dll". Any download, extraction or path-resolution error is
// returned with an empty path.
func GetWhisperDll(version string) (string, error) {
	fileUrl := fmt.Sprintf("https://github.com/Const-me/Whisper/releases/download/%s/Library.zip", version)
	fileToExtract := "Binary/Whisper.dll"

	if !IsFileExists("Whisper.dll") {
		fmt.Println("Whisper DLL not found.")

		if err := fetchWhisperDll(fileUrl, fileToExtract); err != nil {
			return "", err
		}
	}

	absPath, err := filepath.Abs("Whisper.dll")
	if err != nil {
		return "", err
	}

	fmt.Printf("Library found: %s\n", absPath)
	return "Whisper.dll", nil
}

// fetchWhisperDll downloads the archive at fileUrl into a temporary file,
// extracts fileToExtract from it and removes the temporary file again.
func fetchWhisperDll(fileUrl string, fileToExtract string) error {
	archive, err := os.CreateTemp("", "WhisperLibrary-*.zip")
	if err != nil {
		return err
	}
	archivePath := archive.Name()

	// Only the reserved name is needed. Closing the handle right away lets
	// the file be rewritten and later deleted on platforms that refuse to
	// remove open files.
	closeErr := archive.Close()
	// Best-effort cleanup of the temporary archive on every exit path.
	defer os.Remove(archivePath)
	if closeErr != nil {
		return closeErr
	}

	if err := DownloadFile(fileUrl, archivePath); err != nil {
		return err
	}
	return extractFile(archivePath, fileToExtract)
}
// extractFile copies the entry named fileToExtract out of the zip archive at
// archivePath into the current directory, using only the base name of the
// entry as the file name.
//
// It returns an error when the archive cannot be opened, when it holds no
// entry with exactly that name, or when writing the file fails. A file that
// could not be written completely is removed, so a corrupt archive never
// leaves a truncated file behind.
func extractFile(archivePath string, fileToExtract string) error {
	reader, err := zip.OpenReader(archivePath)
	if err != nil {
		return err
	}
	defer reader.Close()

	for _, file := range reader.File {
		if file.Name != fileToExtract {
			continue
		}
		return writeZipEntry(file, filepath.Base(fileToExtract))
	}

	return fmt.Errorf("File not found in the archive")
}

// writeZipEntry writes the content of file to targetPath and deletes
// targetPath again if the content could not be written in full.
func writeZipEntry(file *zip.File, targetPath string) (err error) {
	// Open the entry first so that an unreadable entry does not create or
	// truncate the target.
	src, err := file.Open()
	if err != nil {
		return err
	}
	defer src.Close()

	writer, err := os.Create(targetPath)
	if err != nil {
		return err
	}
	defer func() {
		// A failed Close can mean the data never reached the disk.
		if closeErr := writer.Close(); err == nil {
			err = closeErr
		}
		if err != nil {
			// Best-effort cleanup; the original error is the one worth reporting.
			_ = os.Remove(targetPath)
		}
	}()

	_, err = io.Copy(writer, src)
	return err
}