Implement automatic model and Whisper.dll downloading

Additionally, made the following changes:
- Allow listening only on local interfaces
- Update project structure
This commit is contained in:
2023-10-05 21:05:30 +03:00
parent 8e4cb72b50
commit a025285256
10 changed files with 199 additions and 10 deletions

45
internal/api/handler.go Normal file
View File

@ -0,0 +1,45 @@
package api
import (
"net/http"
"strings"
"github.com/labstack/echo/v4"
)
// TranscribeResponse is the JSON body that Transcribe sends with a
// 200 OK status.
type TranscribeResponse struct {
// Text holds the transcript; it is serialized under the "text" key.
Text string `json:"text"`
}
// Transcribe saves the uploaded form file "file", runs it through the shared
// whisper context and replies with the recognized text as JSON.
//
// If the upload cannot be saved the error is returned to echo unchanged. Any
// failure while loading the audio, running the model or reading the results
// is logged and answered with a 500 JSON error body, as is an empty
// transcript.
//
// NOTE(review): the file written by saveFormFile is not deleted here; confirm
// whether cleanup happens elsewhere.
func Transcribe(c echo.Context, whisperState *WhisperState) error {
	audioPath, err := saveFormFile("file", c)
	if err != nil {
		c.Logger().Errorf("Error reading file: %s", err)
		return err
	}

	internalError := func() error {
		return c.JSON(http.StatusInternalServerError, map[string]string{"error": "Internal server error"})
	}

	// Register the unlock immediately so the mutex is released on every exit
	// path, including a panic inside one of the whisper calls.
	whisperState.mutex.Lock()
	defer whisperState.mutex.Unlock()

	buffer, err := whisperState.media.LoadAudioFile(audioPath, true)
	if err != nil {
		// Without a decoded buffer there is nothing to transcribe.
		c.Logger().Errorf("Error loading audio file data: %s", err)
		return internalError()
	}

	if err = whisperState.context.RunFull(whisperState.params, buffer); err != nil {
		// Results read after a failed run would not belong to this request.
		c.Logger().Errorf("Error running transcription: %s", err)
		return internalError()
	}

	result, err := getResult(whisperState.context)
	if err != nil {
		c.Logger().Error(err)
		return internalError()
	}
	if len(result) == 0 {
		return internalError()
	}

	response := TranscribeResponse{
		Text: strings.TrimLeft(result, " "),
	}
	return c.JSON(http.StatusOK, response)
}

75
internal/api/state.go Normal file
View File

@ -0,0 +1,75 @@
package api
import (
"fmt"
"sync"
"github.com/xzeldon/whisper-api-server/pkg/whisper"
)
// WhisperState groups the whisper objects created by InitializeWhisperState
// with the mutex that Transcribe holds while it uses them.
type WhisperState struct {
// model is the value returned by lib.LoadModel for the configured path.
model *whisper.Model
// context is created from model and is used to run transcriptions and
// read their results.
context *whisper.IContext
// media is used to load uploaded audio files into a buffer.
media *whisper.IMediaFoundation
// params are the parameters passed to every RunFull call.
params *whisper.FullParams
// mutex is locked by Transcribe around loading audio, running the model
// and reading the results.
mutex sync.Mutex
}
// InitializeWhisperState loads the whisper library and the model stored at
// modelPath, creates a context and a media foundation object, and prepares
// beam-search parameters with the FlagNoContext and FlagTokenTimestamps flags
// set.
//
// It returns the assembled state, or an error that names the step that failed
// and wraps the underlying cause. The CPU thread count reported by the
// parameters is printed to standard output.
func InitializeWhisperState(modelPath string) (*WhisperState, error) {
	lib, err := whisper.New(whisper.LlDebug, whisper.LfUseStandardError, nil)
	if err != nil {
		return nil, fmt.Errorf("initializing whisper library: %w", err)
	}

	model, err := lib.LoadModel(modelPath)
	if err != nil {
		return nil, fmt.Errorf("loading model %q: %w", modelPath, err)
	}

	ctx, err := model.CreateContext()
	if err != nil {
		return nil, fmt.Errorf("creating whisper context: %w", err)
	}

	media, err := lib.InitMediaFoundation()
	if err != nil {
		return nil, fmt.Errorf("initializing media foundation: %w", err)
	}

	params, err := ctx.FullDefaultParams(whisper.SsBeamSearch)
	if err != nil {
		return nil, fmt.Errorf("building default parameters: %w", err)
	}
	params.AddFlags(whisper.FlagNoContext)
	params.AddFlags(whisper.FlagTokenTimestamps)

	fmt.Printf("Params CPU Threads : %d\n", params.CpuThreads())

	return &WhisperState{
		model:   model,
		context: ctx,
		media:   media,
		params:  params,
	}, nil
}
// getResult reads the transcription results from ctx and returns the text of
// all segments concatenated in order. An error from GetSize is returned
// together with an empty string.
func getResult(ctx *whisper.IContext) (string, error) {
	transcript := &whisper.ITranscribeResult{}
	// NOTE(review): anything GetResults returns is discarded here; confirm
	// whether it reports an error that should be checked.
	ctx.GetResults(whisper.RfTokens|whisper.RfTimestamps, &transcript)

	size, err := transcript.GetSize()
	if err != nil {
		return "", err
	}

	text := ""
	for _, segment := range transcript.GetSegments(size.CountSegments) {
		text += segment.Text()
	}
	return text, nil
}

48
internal/api/utils.go Normal file
View File

@ -0,0 +1,48 @@
package api
import (
"io"
"os"
"path/filepath"
"strings"
"time"
"github.com/labstack/echo/v4"
)
// saveFormFile copies the multipart form file found under the field name into
// the ./tmp directory and returns the path of the stored copy.
//
// The stored name is built from the current time (RFC 3339, sanitized), a
// random component and the sanitized extension of the uploaded file name. The
// random component keeps uploads that arrive within the same second from
// overwriting each other. The directory is created when it does not exist.
//
// On any failure an empty path and the error are returned; a partially
// written file is removed.
func saveFormFile(name string, c echo.Context) (string, error) {
	const uploadDir = "./tmp"

	file, err := c.FormFile(name)
	if err != nil {
		return "", err
	}

	src, err := file.Open()
	if err != nil {
		return "", err
	}
	defer src.Close()

	if err = os.MkdirAll(uploadDir, 0755); err != nil {
		return "", err
	}

	// The extension comes from the client, so it is sanitized like the rest
	// of the name; this also removes any '*' that would otherwise be taken as
	// the placeholder of the CreateTemp pattern.
	ext := sanitizeFilename(filepath.Ext(file.Filename))
	prefix := sanitizeFilename(time.Now().Format(time.RFC3339))

	dst, err := os.CreateTemp(uploadDir, prefix+"-*"+ext)
	if err != nil {
		return "", err
	}
	filename := dst.Name()

	if _, err = io.Copy(dst, src); err != nil {
		dst.Close()
		os.Remove(filename)
		return "", err
	}
	// Close explicitly: a failed close can mean the data never reached disk.
	if err = dst.Close(); err != nil {
		os.Remove(filename)
		return "", err
	}

	return filename, nil
}
// sanitizeFilename returns filename with every occurrence of the characters
// \ / : * ? " < > | replaced by a hyphen. All other bytes are left unchanged.
func sanitizeFilename(filename string) string {
	replacer := strings.NewReplacer(
		`\`, "-",
		`/`, "-",
		`:`, "-",
		`*`, "-",
		`?`, "-",
		`"`, "-",
		`<`, "-",
		`>`, "-",
		`|`, "-",
	)
	return replacer.Replace(filename)
}