mirror of
https://github.com/xzeldon/whisper-api-server.git
synced 2025-07-13 13:34:39 +03:00
Implement automatic model and Whisper.dll downloading
Additionally, made the following changes:
- Allow listening only on local interfaces
- Update project structure
This commit is contained in:
45
internal/api/handler.go
Normal file
45
internal/api/handler.go
Normal file
@ -0,0 +1,45 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
|
||||
// TranscribeResponse is the JSON payload sent back for a successful
// transcription request. It serializes as {"text": "..."}.
type TranscribeResponse struct {
	// Text is the transcribed text.
	Text string `json:"text"`
}
|
||||
|
||||
func Transcribe(c echo.Context, whisperState *WhisperState) error {
|
||||
audioPath, err := saveFormFile("file", c)
|
||||
if err != nil {
|
||||
c.Logger().Errorf("Error reading file: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
whisperState.mutex.Lock()
|
||||
buffer, err := whisperState.media.LoadAudioFile(audioPath, true)
|
||||
if err != nil {
|
||||
c.Logger().Errorf("Error loading audio file data: %s", err)
|
||||
}
|
||||
|
||||
err = whisperState.context.RunFull(whisperState.params, buffer)
|
||||
|
||||
result, err := getResult(whisperState.context)
|
||||
if err != nil {
|
||||
c.Logger().Error(err)
|
||||
}
|
||||
|
||||
defer whisperState.mutex.Unlock()
|
||||
|
||||
if len(result) == 0 {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": "Internal server error"})
|
||||
}
|
||||
|
||||
response := TranscribeResponse{
|
||||
Text: strings.TrimLeft(result, " "),
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, response)
|
||||
}
|
75
internal/api/state.go
Normal file
75
internal/api/state.go
Normal file
@ -0,0 +1,75 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/xzeldon/whisper-api-server/pkg/whisper"
|
||||
)
|
||||
|
||||
type WhisperState struct {
|
||||
model *whisper.Model
|
||||
context *whisper.IContext
|
||||
media *whisper.IMediaFoundation
|
||||
params *whisper.FullParams
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
func InitializeWhisperState(modelPath string) (*WhisperState, error) {
|
||||
lib, err := whisper.New(whisper.LlDebug, whisper.LfUseStandardError, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
model, err := lib.LoadModel(modelPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
context, err := model.CreateContext()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
media, err := lib.InitMediaFoundation()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
params, err := context.FullDefaultParams(whisper.SsBeamSearch)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
params.AddFlags(whisper.FlagNoContext)
|
||||
params.AddFlags(whisper.FlagTokenTimestamps)
|
||||
|
||||
fmt.Printf("Params CPU Threads : %d\n", params.CpuThreads())
|
||||
|
||||
return &WhisperState{
|
||||
model: model,
|
||||
context: context,
|
||||
media: media,
|
||||
params: params,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func getResult(ctx *whisper.IContext) (string, error) {
|
||||
results := &whisper.ITranscribeResult{}
|
||||
ctx.GetResults(whisper.RfTokens|whisper.RfTimestamps, &results)
|
||||
|
||||
length, err := results.GetSize()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
segments := results.GetSegments(length.CountSegments)
|
||||
|
||||
var result string
|
||||
|
||||
for _, seg := range segments {
|
||||
result += seg.Text()
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
48
internal/api/utils.go
Normal file
48
internal/api/utils.go
Normal file
@ -0,0 +1,48 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
|
||||
func saveFormFile(name string, c echo.Context) (string, error) {
|
||||
file, err := c.FormFile(name)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
src, err := file.Open()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer src.Close()
|
||||
|
||||
ext := filepath.Ext(file.Filename)
|
||||
filename := time.Now().Format(time.RFC3339)
|
||||
filename = "./tmp/" + sanitizeFilename(filename) + ext
|
||||
|
||||
dst, err := os.Create(filename)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer dst.Close()
|
||||
|
||||
if _, err = io.Copy(dst, src); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return filename, nil
|
||||
}
|
||||
|
||||
// sanitizeFilename returns filename with every character from the set
// \ / : * ? " < > | replaced by a dash.
func sanitizeFilename(filename string) string {
	const invalidChars = `\/:*?"<>|`

	// All forbidden characters are single ASCII bytes, so a byte-wise pass
	// never touches part of a multi-byte UTF-8 sequence.
	cleaned := []byte(filename)
	for i := range cleaned {
		if strings.IndexByte(invalidChars, cleaned[i]) >= 0 {
			cleaned[i] = '-'
		}
	}
	return string(cleaned)
}
|
38
internal/resources/download.go
Normal file
38
internal/resources/download.go
Normal file
@ -0,0 +1,38 @@
|
||||
package resources
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/schollz/progressbar/v3"
|
||||
)
|
||||
|
||||
func DownloadFile(url string, filepath string) error {
|
||||
out, err := os.Create(filepath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer out.Close()
|
||||
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
fileSize := resp.ContentLength
|
||||
bar := progressbar.DefaultBytes(
|
||||
fileSize,
|
||||
"Downloading",
|
||||
)
|
||||
|
||||
writer := io.MultiWriter(out, bar)
|
||||
|
||||
_, err = io.Copy(writer, resp.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
29
internal/resources/model.go
Normal file
29
internal/resources/model.go
Normal file
@ -0,0 +1,29 @@
|
||||
package resources
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
func GetModel(modelType string) (string, error) {
|
||||
fileURL := fmt.Sprintf("https://huggingface.co/ggerganov/whisper.cpp/resolve/main/%s", modelType)
|
||||
filePath := modelType
|
||||
|
||||
isModelFileExists := IsFileExists(filePath)
|
||||
|
||||
if !isModelFileExists {
|
||||
fmt.Println("Model not found.")
|
||||
err := DownloadFile(fileURL, filePath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
absPath, err := filepath.Abs(filePath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
fmt.Printf("Model found: %s\n", absPath)
|
||||
return filePath, nil
|
||||
}
|
13
internal/resources/utils.go
Normal file
13
internal/resources/utils.go
Normal file
@ -0,0 +1,13 @@
|
||||
package resources
|
||||
|
||||
import "os"
|
||||
|
||||
// IsFileExists reports whether filename exists. Only a "does not exist"
// result from os.Stat yields false; any other outcome, including other Stat
// errors, yields true.
func IsFileExists(filename string) bool {
	_, statErr := os.Stat(filename)
	// os.IsNotExist(nil) is false, so a successful Stat reports true.
	return !os.IsNotExist(statErr)
}
|
78
internal/resources/whisper.go
Normal file
78
internal/resources/whisper.go
Normal file
@ -0,0 +1,78 @@
|
||||
package resources
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
func GetWhisperDll(version string) (string, error) {
|
||||
fileUrl := fmt.Sprintf("https://github.com/Const-me/Whisper/releases/download/%s/Library.zip", version)
|
||||
fileToExtract := "Binary/Whisper.dll"
|
||||
|
||||
isWhisperDllExists := IsFileExists("Whisper.dll")
|
||||
|
||||
if !isWhisperDllExists {
|
||||
fmt.Println("Whisper DLL not found.")
|
||||
archivePath, err := os.CreateTemp("", "WhisperLibrary-*.zip")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer archivePath.Close()
|
||||
|
||||
err = DownloadFile(fileUrl, archivePath.Name())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
err = extractFile(archivePath.Name(), fileToExtract)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
absPath, err := filepath.Abs("Whisper.dll")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
fmt.Printf("Library found: %s\n", absPath)
|
||||
return "Whisper.dll", nil
|
||||
}
|
||||
|
||||
// extractFile extracts the entry named fileToExtract from the zip archive at
// archivePath into the current working directory, using the entry's base
// name as the file name.
//
// It returns an error when the archive cannot be opened, when the entry is
// not present, or when writing the file fails. In the last case the
// incomplete output file is removed.
func extractFile(archivePath string, fileToExtract string) error {
	reader, err := zip.OpenReader(archivePath)
	if err != nil {
		return err
	}
	defer reader.Close()

	for _, file := range reader.File {
		if file.Name != fileToExtract {
			continue
		}
		return writeZipEntry(file, filepath.Base(fileToExtract))
	}

	return fmt.Errorf("File not found in the archive")
}

// writeZipEntry copies the content of the zip entry file to targetPath and
// removes targetPath again if the copy or the final close fails.
func writeZipEntry(file *zip.File, targetPath string) error {
	// Open the source first so that a broken entry does not leave an empty
	// target file behind.
	src, err := file.Open()
	if err != nil {
		return err
	}
	defer src.Close()

	dst, err := os.Create(targetPath)
	if err != nil {
		return err
	}

	_, err = io.Copy(dst, src)
	// Close explicitly so that a failed flush is reported, not swallowed.
	if closeErr := dst.Close(); err == nil {
		err = closeErr
	}
	if err != nil {
		// A truncated file would later be mistaken for a valid one.
		os.Remove(targetPath)
		return err
	}

	return nil
}
|
Reference in New Issue
Block a user