feat: big update

add cli parsing, model and .dll download prompt and so on
This commit is contained in:
Timofey Gelazoniya 2024-12-26 01:35:38 +03:00
parent 1924c01353
commit b228b436b9
Signed by: zeldon
GPG Key ID: 047886915281DD2A
10 changed files with 393 additions and 143 deletions

View File

@ -42,7 +42,7 @@ go build -ldflags "-s -w" -o server.exe main.go
Make a request to the server using the following command:
```sh
curl http://localhost:3031/v1/audio/transcriptions \
curl http://localhost:3000/v1/audio/transcriptions \
-H "Content-Type: multipart/form-data" \
-F file="@/path/to/file/audio.mp3" \
```
@ -61,7 +61,7 @@ Receive a response in JSON format:
2. Open the plugin's settings.
3. Set the following values:
- API KEY: `sk-1`
- API URL: `http://localhost:3031/v1/audio/transcriptions`
- API URL: `http://localhost:3000/v1/audio/transcriptions`
- Model: `whisper-1`
# Roadmap
@ -70,9 +70,8 @@ Receive a response in JSON format:
- [x] Implement automatic `Whisper.dll` downloading from [Guthub releases](https://github.com/Const-me/Whisper/releases)
- [x] Provide prebuilt binaries for Windows
- [ ] Include instructions for running on Linux with Wine (likely possible).
- [ ] Use flags to override the model path
- [ ] Use flags to override the model type (when downloading the model)
- [ ] Use flags to override the port
- [x] Use flags to override the model path
- [x] Use flags to override the port
# Credits

5
go.mod
View File

@ -10,17 +10,20 @@ require (
require (
github.com/golang-jwt/jwt v3.2.2+incompatible // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/mattn/go-runewidth v0.0.14 // indirect
github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db // indirect
github.com/rivo/uniseg v0.2.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
golang.org/x/term v0.10.0 // indirect
golang.org/x/time v0.3.0 // indirect
)
require (
github.com/labstack/gommon v0.4.0
github.com/labstack/gommon v0.4.0 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.19 // indirect
github.com/spf13/cobra v1.8.1
github.com/valyala/bytebufferpool v1.0.0 // indirect
github.com/valyala/fasttemplate v1.2.2 // indirect
golang.org/x/crypto v0.11.0 // indirect

8
go.sum
View File

@ -1,8 +1,11 @@
github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/k0kubun/go-ansi v0.0.0-20180517002512-3bf9e2903213/go.mod h1:vNUNkEQ1e29fT/6vq2aBdFsgNPmy8qMdSay1npru+Sw=
github.com/labstack/echo/v4 v4.11.1 h1:dEpLU2FLg4UVmvCGPuk/APjlH6GDpbEPti61srUUUs4=
github.com/labstack/echo/v4 v4.11.1/go.mod h1:YuYRTSM3CHs2ybfrL8Px48bO6BAnYIN4l8wSTMP6BDQ=
@ -24,8 +27,13 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/schollz/progressbar/v3 v3.13.1 h1:o8rySDYiQ59Mwzy2FELeHY5ZARXZTVJC7iHD6PEFUiE=
github.com/schollz/progressbar/v3 v3.13.1/go.mod h1:xvrbki8kfT1fzWzBT/UZd9L6GA+jdL7HAgq2RFnO6fQ=
github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM=
github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=

View File

@ -0,0 +1,103 @@
package resources
import (
_ "embed"
"encoding/json"
"fmt"
"os"
"strconv"
"strings"
"github.com/spf13/cobra"
)
//go:embed languageMap.json
var languageMapData []byte // Embedded language map file as a byte slice
// Arguments holds the parsed CLI arguments
type Arguments struct {
Language string
ModelPath string
Port int
}
// ParsedArguments holds the processed arguments
type ParsedArguments struct {
Language int32
ModelPath string
Port int
}
// LanguageMap represents the mapping of languages to their hex codes
type LanguageMap map[string]string
func processLanguageAndCode(language string) (int32, error) {
var languageMap LanguageMap
err := json.Unmarshal(languageMapData, &languageMap)
if err != nil {
return 0x6E65, fmt.Errorf("error parsing language map: %w", err)
}
hexCode, ok := languageMap[strings.ToLower(language)]
if !ok {
return 0x6E65, fmt.Errorf("unsupported language")
}
fmt.Printf("Hex Code Found: %s\n", hexCode)
languageCode, err := strconv.ParseInt(hexCode, 0, 32)
if err != nil {
return 0x6E65, fmt.Errorf("error converting hex code: %w", err)
}
return int32(languageCode), nil
}
func ApplyExitOnHelp(c *cobra.Command, exitCode int) {
helpFunc := c.HelpFunc()
c.SetHelpFunc(func(c *cobra.Command, s []string) {
helpFunc(c, s)
os.Exit(exitCode)
})
}
func ParseFlags() (*ParsedArguments, error) {
args := &Arguments{}
var parsedArgs *ParsedArguments
rootCmd := &cobra.Command{
Use: "whisper",
Short: "Audio transcription using the OpenAI Whisper models",
RunE: func(cmd *cobra.Command, _ []string) error {
// Process language code with fallback
languageCode, err := processLanguageAndCode(args.Language)
if err != nil {
fmt.Printf("Error setting language, defaulting to English")
// Default to English
languageCode = 0x6E65
}
parsedArgs = &ParsedArguments{
Language: languageCode,
ModelPath: args.ModelPath,
Port: args.Port,
}
return nil
},
}
rootCmd.Flags().StringVarP(&args.Language, "language", "l", "", "Language to be processed")
rootCmd.Flags().StringVarP(&args.ModelPath, "modelPath", "m", "ggml-medium.bin", "Path to the model file (required)")
rootCmd.Flags().IntVarP(&args.Port, "port", "p", 3000, "Port to start the server on")
ApplyExitOnHelp(rootCmd, 0)
err := rootCmd.Execute()
if err != nil {
return nil, err
}
return parsedArgs, nil
}

View File

@ -1,96 +0,0 @@
package resources
import (
"encoding/json"
"flag"
"fmt"
"io"
"os"
"strconv"
"strings"
)
// Arguments defines the structure to hold parsed arguments
type Arguments struct {
Language string
ModelPath string
Port int
}
type ParsedArguments struct {
Language int32
ModelPath string
Port int
}
type LanguageMap map[string]string
func processLanguageAndCode(args *Arguments) (int32, error) {
// Read the language map from JSON file
jsonFile, err := os.Open("languageMap.json")
if err != nil {
return 0x6E65, fmt.Errorf("error opening language map: %w", err) // Wrap error for context
}
defer jsonFile.Close()
byteData, err := io.ReadAll(jsonFile)
if err != nil {
return 0x6E65, fmt.Errorf("error reading language map: %w", err)
}
var languageMap LanguageMap
err = json.Unmarshal(byteData, &languageMap)
if err != nil {
return 0x6E65, fmt.Errorf("error parsing language map: %w", err)
}
hexCode, ok := languageMap[strings.ToLower(args.Language)]
if !ok {
return 0x6E65, fmt.Errorf("unsupported language: %s", args.Language)
}
languageCode, err := strconv.ParseInt(hexCode, 0, 32)
if err != nil {
return 0x6E65, fmt.Errorf("error converting hex code: %w", err)
}
return int32(languageCode), nil
}
// ParseFlags parses command line arguments and returns an Arguments struct
func ParseFlags() (*ParsedArguments, error) {
args := &Arguments{}
flag.StringVar(&args.Language, "l", "", "Language to be processed")
flag.StringVar(&args.Language, "language", "", "Language to be processed") // Optional: Redundant to demonstrate
flag.StringVar(&args.ModelPath, "m", "", "Path to the model file (required)")
flag.StringVar(&args.ModelPath, "modelPath", "", "Path to the model file (required)") // Optional: Redundant
flag.IntVar(&args.Port, "p", 3031, "Port to start the server on")
flag.IntVar(&args.Port, "port", 3031, "Port to start the server on") // Optional: Redundant
flag.Usage = func() {
fmt.Println("Usage: your_program [OPTIONS]")
fmt.Println("Options:")
flag.PrintDefaults() // Print default values for all flags
}
// Parsing flags
flag.Parse()
args.Language = strings.ToLower(args.Language)
if args.ModelPath == "" {
return nil, fmt.Errorf("modelPath argument is required")
}
languageCode, err := processLanguageAndCode(args)
if err != nil {
fmt.Println("Error setting language, defaulting to English:", err)
// Use default language code directly as the result here
}
return &ParsedArguments{
Language: languageCode,
ModelPath: args.ModelPath,
Port: args.Port,
}, nil
}

View File

@ -0,0 +1,143 @@
package resources
import (
"archive/zip"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"github.com/schollz/progressbar/v3"
)
func GetModel(modelType string) (string, error) {
fileURL := fmt.Sprintf("https://huggingface.co/ggerganov/whisper.cpp/resolve/main/%s", modelType)
filePath := modelType
isModelFileExists := IsFileExists(filePath)
if !isModelFileExists {
fmt.Println("Model not found.")
err := DownloadFile(fileURL, filePath)
if err != nil {
return "", err
}
}
absPath, err := filepath.Abs(filePath)
if err != nil {
return "", err
}
fmt.Printf("Model found: %s\n", absPath)
return filePath, nil
}
func DownloadFile(url string, filepath string) error {
out, err := os.Create(filepath)
if err != nil {
return err
}
defer out.Close()
resp, err := http.Get(url)
if err != nil {
return err
}
defer resp.Body.Close()
fileSize := resp.ContentLength
bar := progressbar.DefaultBytes(
fileSize,
"Downloading",
)
writer := io.MultiWriter(out, bar)
_, err = io.Copy(writer, resp.Body)
if err != nil {
return err
}
return nil
}
func GetWhisperDll(version string) (string, error) {
fileUrl := fmt.Sprintf("https://github.com/Const-me/Whisper/releases/download/%s/Library.zip", version)
fileToExtract := "Binary/Whisper.dll"
isWhisperDllExists := IsFileExists("Whisper.dll")
if !isWhisperDllExists {
fmt.Println("Whisper DLL not found.")
archivePath, err := os.CreateTemp("", "WhisperLibrary-*.zip")
if err != nil {
return "", err
}
defer archivePath.Close()
err = DownloadFile(fileUrl, archivePath.Name())
if err != nil {
return "", err
}
err = extractFile(archivePath.Name(), fileToExtract)
if err != nil {
return "", err
}
}
absPath, err := filepath.Abs("Whisper.dll")
if err != nil {
return "", err
}
fmt.Printf("Library found: %s\n", absPath)
return "Whisper.dll", nil
}
func extractFile(archivePath string, fileToExtract string) error {
reader, err := zip.OpenReader(archivePath)
if err != nil {
return err
}
defer reader.Close()
for _, file := range reader.File {
if file.Name == fileToExtract {
targetPath := filepath.Base(fileToExtract)
writer, err := os.Create(targetPath)
if err != nil {
return err
}
defer writer.Close()
src, err := file.Open()
if err != nil {
return err
}
defer src.Close()
_, err = io.Copy(writer, src)
if err != nil {
return err
}
return nil
}
}
return fmt.Errorf("File not found in the archive")
}
func IsFileExists(filename string) bool {
_, err := os.Stat(filename)
if err != nil {
if os.IsNotExist(err) {
return false
}
}
return true
}

View File

@ -0,0 +1,76 @@
package resources
import (
"bufio"
"fmt"
"os"
"path/filepath"
"strings"
)
// PromptUser prompts the user with a question and returns true if they agree
func PromptUser(question string) bool {
fmt.Printf("%s (y/n): ", question)
reader := bufio.NewReader(os.Stdin)
response, err := reader.ReadString('\n')
if err != nil {
fmt.Println("Error reading input:", err)
return false
}
response = strings.TrimSpace(strings.ToLower(response))
return response == "y" || response == "yes"
}
// HandleWhisperDll checks if Whisper.dll exists or prompts the user to download it
func HandleWhisperDll(version string) (string, error) {
if IsFileExists("Whisper.dll") {
absPath, err := filepath.Abs("Whisper.dll")
if err != nil {
return "", err
}
fmt.Printf("Library found: %s\n", absPath)
return "Whisper.dll", nil
}
fmt.Println("Whisper DLL not found.")
if PromptUser("Do you want to download Whisper.dll automatically?") {
path, err := GetWhisperDll(version)
if err != nil {
return "", fmt.Errorf("failed to download Whisper.dll: %w", err)
}
return path, nil
}
fmt.Println("To use Whisper, download the DLL manually:")
fmt.Printf("URL: https://github.com/Const-me/Whisper/releases/download/%s/Library.zip\n", version)
fmt.Println("Extract 'Binary/Whisper.dll' from the archive and place it in the executable's directory.")
fmt.Println("You can manually specify path to .dll file using cli arguments, use --help to print available cli flags")
return "", fmt.Errorf("whisper.dll not found and user chose not to download")
}
// HandleDefaultModel checks if the default model exists or prompts the user to download it
func HandleDefaultModel(modelType string) (string, error) {
if IsFileExists(modelType) {
absPath, err := filepath.Abs(modelType)
if err != nil {
return "", err
}
fmt.Printf("Model found: %s\n", absPath)
return modelType, nil
}
fmt.Println("Default model not found.")
if PromptUser("Do you want to download the default model (ggml-medium.bin) automatically?") {
path, err := GetModel(modelType)
if err != nil {
return "", fmt.Errorf("failed to download the default model: %w", err)
}
return path, nil
}
fmt.Println("To use Whisper, download the model manually:")
fmt.Println("URL: https://huggingface.co/ggerganov/whisper.cpp/tree/main")
fmt.Println("Place the model file in the executable's directory or specify its path using cli arguments.")
fmt.Println("You can manually specify path to model file using cli arguments, use --help to print available cli flags")
return "", fmt.Errorf("default model not found and user chose not to download")
}

52
main.go
View File

@ -7,24 +7,25 @@ import (
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
"github.com/labstack/gommon/log"
"github.com/xzeldon/whisper-api-server/internal/api"
"github.com/xzeldon/whisper-api-server/internal/resources"
)
func change_working_directory(e *echo.Echo) {
exePath, errs := os.Executable()
if errs != nil {
e.Logger.Error(errs)
const (
defaultModelType = "ggml-medium.bin"
defaultWhisperVersion = "1.12.0"
)
func changeWorkingDirectory(e *echo.Echo) {
exePath, err := os.Executable()
if err != nil {
e.Logger.Error("Error getting executable path: ", err)
return
}
exeDir := filepath.Dir(exePath)
// Change the working directory to the executable directory
errs = os.Chdir(exeDir)
if errs != nil {
e.Logger.Error(errs)
if err := os.Chdir(exeDir); err != nil {
e.Logger.Error("Error changing working directory: ", err)
return
}
@ -33,33 +34,38 @@ func change_working_directory(e *echo.Echo) {
}
func main() {
e := echo.New()
e.HideBanner = true
change_working_directory(e)
changeWorkingDirectory(e)
args, errParsing := resources.ParseFlags()
if errParsing != nil {
e.Logger.Error("Error parsing flags: ", errParsing)
args, err := resources.ParseFlags()
if err != nil {
e.Logger.Error("Error parsing flags: ", err)
return
}
if _, err := resources.HandleWhisperDll(defaultWhisperVersion); err != nil {
e.Logger.Error("Error handling Whisper.dll: ", err)
return
}
if _, err := resources.HandleDefaultModel(defaultModelType); err != nil {
e.Logger.Error("Error handling model file: ", err)
return
}
e.Use(middleware.CORS())
if l, ok := e.Logger.(*log.Logger); ok {
l.SetHeader("${time_rfc3339} ${level}")
}
whisperState, err := api.InitializeWhisperState(args.ModelPath, args.Language)
if err != nil {
e.Logger.Error(err)
e.Logger.Error("Error initializing Whisper state: ", err)
return
}
e.POST("/v1/audio/transcriptions", func(c echo.Context) error {
return api.Transcribe(c, whisperState)
})
e.Logger.Fatal(e.Start(fmt.Sprintf("127.0.0.1:%d", args.Port)))
address := fmt.Sprintf("127.0.0.1:%d", args.Port)
e.Logger.Fatal(e.Start(address))
}

View File

@ -2,7 +2,6 @@
// Use of this source code is governed by a BSD
// license that can be found in the LICENSE file.
// Adapted mainly from github.com/gonutz/w32
//go:build windows
// +build windows
@ -10,7 +9,7 @@ package whisper
import (
"errors"
"syscall"
"fmt"
"unsafe"
"golang.org/x/sys/windows"
@ -51,34 +50,46 @@ func (fi VS_FIXEDFILEINFO) FileVersion() uint64 {
return uint64(fi.FileVersionMS)<<32 | uint64(fi.FileVersionLS)
}
func GetFileVersionInfoSize(path string) uint32 {
func GetFileVersionInfoSize(path string) (uint32, error) {
pathPtr, err := windows.UTF16PtrFromString(path)
if err != nil {
return 0, err
}
ret, _, _ := getFileVersionInfoSize.Call(
uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr(path))),
uintptr(unsafe.Pointer(pathPtr)),
0,
)
return uint32(ret)
return uint32(ret), nil
}
func GetFileVersionInfo(path string, data []byte) bool {
func GetFileVersionInfo(path string, data []byte) (bool, error) {
pathPtr, err := windows.UTF16PtrFromString(path)
if err != nil {
return false, err
}
ret, _, _ := getFileVersionInfo.Call(
uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr(path))),
uintptr(unsafe.Pointer(pathPtr)),
0,
uintptr(len(data)),
uintptr(unsafe.Pointer(&data[0])),
)
return ret != 0
return ret != 0, nil
}
// VerQueryValueRoot calls VerQueryValue
// (https://msdn.microsoft.com/en-us/library/windows/desktop/ms647464(v=vs.85).aspx)
// with `\` (root) to retieve the VS_FIXEDFILEINFO.
// with \ (root) to retrieve the VS_FIXEDFILEINFO.
func VerQueryValueRoot(block []byte) (VS_FIXEDFILEINFO, error) {
var offset uintptr
var length uint
blockStart := unsafe.Pointer(&block[0])
rootPathPtr, err := windows.UTF16PtrFromString(`\`)
if err != nil {
return VS_FIXEDFILEINFO{}, err
}
ret, _, _ := verQueryValue.Call(
uintptr(blockStart),
uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr(`\`))),
uintptr(unsafe.Pointer(rootPathPtr)),
uintptr(unsafe.Pointer(&offset)),
uintptr(unsafe.Pointer(&length)),
)
@ -97,27 +108,24 @@ func VerQueryValueRoot(block []byte) (VS_FIXEDFILEINFO, error) {
func GetFileVersion(path string) (WinVersion, error) {
var result WinVersion
size := GetFileVersionInfoSize(path)
if size <= 0 {
size, err := GetFileVersionInfoSize(path)
fmt.Println(path)
if err != nil || size <= 0 {
return result, errors.New("GetFileVersionInfoSize failed")
}
info := make([]byte, size)
ok := GetFileVersionInfo(path, info)
if !ok {
ok, err := GetFileVersionInfo(path, info)
if err != nil || !ok {
return result, errors.New("GetFileVersionInfo failed")
}
fixed, err := VerQueryValueRoot(info)
if err != nil {
return result, err
}
version := fixed.FileVersion()
result.Major = uint32(version & 0xFFFF000000000000 >> 48)
result.Minor = uint32(version & 0x0000FFFF00000000 >> 32)
result.Patch = uint32(version & 0x00000000FFFF0000 >> 16)
result.Build = uint32(version & 0x000000000000FFFF)
return result, nil
}