diff --git a/README.md b/README.md index 2375202..0bce83a 100644 --- a/README.md +++ b/README.md @@ -42,7 +42,7 @@ go build -ldflags "-s -w" -o server.exe main.go Make a request to the server using the following command: ```sh -curl http://localhost:3031/v1/audio/transcriptions \ +curl http://localhost:3000/v1/audio/transcriptions \ -H "Content-Type: multipart/form-data" \ -F file="@/path/to/file/audio.mp3" \ ``` @@ -61,7 +61,7 @@ Receive a response in JSON format: 2. Open the plugin's settings. 3. Set the following values: - API KEY: `sk-1` - - API URL: `http://localhost:3031/v1/audio/transcriptions` + - API URL: `http://localhost:3000/v1/audio/transcriptions` - Model: `whisper-1` # Roadmap @@ -70,9 +70,8 @@ Receive a response in JSON format: - [x] Implement automatic `Whisper.dll` downloading from [Guthub releases](https://github.com/Const-me/Whisper/releases) - [x] Provide prebuilt binaries for Windows - [ ] Include instructions for running on Linux with Wine (likely possible). -- [ ] Use flags to override the model path -- [ ] Use flags to override the model type (when downloading the model) -- [ ] Use flags to override the port +- [x] Use flags to override the model path +- [x] Use flags to override the port # Credits diff --git a/go.mod b/go.mod index 1c3c9ee..aec0bac 100644 --- a/go.mod +++ b/go.mod @@ -10,17 +10,20 @@ require ( require ( github.com/golang-jwt/jwt v3.2.2+incompatible // indirect + github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/mattn/go-runewidth v0.0.14 // indirect github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db // indirect github.com/rivo/uniseg v0.2.0 // indirect + github.com/spf13/pflag v1.0.5 // indirect golang.org/x/term v0.10.0 // indirect golang.org/x/time v0.3.0 // indirect ) require ( - github.com/labstack/gommon v0.4.0 + github.com/labstack/gommon v0.4.0 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.19 // indirect + github.com/spf13/cobra v1.8.1 github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/fasttemplate v1.2.2 // indirect golang.org/x/crypto v0.11.0 // indirect diff --git a/go.sum b/go.sum index 03296d9..500fa53 100644 --- a/go.sum +++ b/go.sum @@ -1,8 +1,11 @@ +github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/k0kubun/go-ansi v0.0.0-20180517002512-3bf9e2903213/go.mod h1:vNUNkEQ1e29fT/6vq2aBdFsgNPmy8qMdSay1npru+Sw= github.com/labstack/echo/v4 v4.11.1 h1:dEpLU2FLg4UVmvCGPuk/APjlH6GDpbEPti61srUUUs4= github.com/labstack/echo/v4 v4.11.1/go.mod h1:YuYRTSM3CHs2ybfrL8Px48bO6BAnYIN4l8wSTMP6BDQ= @@ -24,8 +27,13 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/schollz/progressbar/v3 v3.13.1 h1:o8rySDYiQ59Mwzy2FELeHY5ZARXZTVJC7iHD6PEFUiE= github.com/schollz/progressbar/v3 v3.13.1/go.mod h1:xvrbki8kfT1fzWzBT/UZd9L6GA+jdL7HAgq2RFnO6fQ= +github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM= +github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y= +github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= +github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= diff --git a/internal/resources/cliArgs.go b/internal/resources/cliArgs.go new file mode 100644 index 0000000..ae4a31a --- /dev/null +++ b/internal/resources/cliArgs.go @@ -0,0 +1,103 @@ +package resources + +import ( + _ "embed" + "encoding/json" + "fmt" + "os" + "strconv" + "strings" + + "github.com/spf13/cobra" +) + +//go:embed languageMap.json +var languageMapData []byte // Embedded language map file as a byte slice + +// Arguments holds the parsed CLI arguments +type Arguments struct { + Language string + ModelPath string + Port int +} + +// ParsedArguments holds the processed arguments +type ParsedArguments struct { + Language int32 + ModelPath string + Port int +} + +// LanguageMap represents the mapping of languages to their hex codes +type LanguageMap map[string]string + +func processLanguageAndCode(language string) (int32, error) { + var languageMap LanguageMap + err := json.Unmarshal(languageMapData, &languageMap) + if err != nil { + return 0x6E65, fmt.Errorf("error parsing language map: %w", err) + } + + hexCode, ok := languageMap[strings.ToLower(language)] + if !ok { + return 0x6E65, fmt.Errorf("unsupported language") + } + + fmt.Printf("Hex Code Found: %s\n", hexCode) + + languageCode, err := strconv.ParseInt(hexCode, 0, 32) + if err != nil { + return 0x6E65, fmt.Errorf("error converting hex code: %w", err) + } + + return int32(languageCode), nil +} + +func ApplyExitOnHelp(c *cobra.Command, exitCode int) { + helpFunc := c.HelpFunc() + c.SetHelpFunc(func(c *cobra.Command, s []string) { + helpFunc(c, s) + os.Exit(exitCode) + }) +} + +func ParseFlags() (*ParsedArguments, error) { + args := &Arguments{} + + var parsedArgs *ParsedArguments + + rootCmd := &cobra.Command{ + Use: "whisper", + Short: "Audio transcription using the OpenAI Whisper models", + RunE: func(cmd *cobra.Command, _ []string) error { + // Process language code with fallback + languageCode, err := processLanguageAndCode(args.Language) + if err != nil { + fmt.Printf("Error setting language, defaulting to English") + // Default to English + languageCode = 0x6E65 + } + + parsedArgs = &ParsedArguments{ + Language: languageCode, + ModelPath: args.ModelPath, + Port: args.Port, + } + return nil + }, + } + + rootCmd.Flags().StringVarP(&args.Language, "language", "l", "", "Language to be processed") + rootCmd.Flags().StringVarP(&args.ModelPath, "modelPath", "m", "ggml-medium.bin", "Path to the model file (required)") + rootCmd.Flags().IntVarP(&args.Port, "port", "p", 3000, "Port to start the server on") + + ApplyExitOnHelp(rootCmd, 0) + + err := rootCmd.Execute() + if err != nil { + return nil, err + } + + return parsedArgs, nil +} + diff --git a/internal/resources/cli_arguments.go b/internal/resources/cli_arguments.go deleted file mode 100644 index 3a29a48..0000000 --- a/internal/resources/cli_arguments.go +++ /dev/null @@ -1,96 +0,0 @@ -package resources - -import ( - "encoding/json" - "flag" - "fmt" - "io" - "os" - "strconv" - "strings" -) - -// Arguments defines the structure to hold parsed arguments -type Arguments struct { - Language string - ModelPath string - Port int -} -type ParsedArguments struct { - Language int32 - ModelPath string - Port int -} - -type LanguageMap map[string]string - -func processLanguageAndCode(args *Arguments) (int32, error) { - // Read the language map from JSON file - jsonFile, err := os.Open("languageMap.json") - if err != nil { - return 0x6E65, fmt.Errorf("error opening language map: %w", err) // Wrap error for context - } - defer jsonFile.Close() - - byteData, err := io.ReadAll(jsonFile) - if err != nil { - return 0x6E65, fmt.Errorf("error reading language map: %w", err) - } - - var languageMap LanguageMap - err = json.Unmarshal(byteData, &languageMap) - if err != nil { - return 0x6E65, fmt.Errorf("error parsing language map: %w", err) - } - - hexCode, ok := languageMap[strings.ToLower(args.Language)] - if !ok { - return 0x6E65, fmt.Errorf("unsupported language: %s", args.Language) - } - - languageCode, err := strconv.ParseInt(hexCode, 0, 32) - if err != nil { - return 0x6E65, fmt.Errorf("error converting hex code: %w", err) - } - - return int32(languageCode), nil -} - -// ParseFlags parses command line arguments and returns an Arguments struct -func ParseFlags() (*ParsedArguments, error) { - args := &Arguments{} - - flag.StringVar(&args.Language, "l", "", "Language to be processed") - flag.StringVar(&args.Language, "language", "", "Language to be processed") // Optional: Redundant to demonstrate - flag.StringVar(&args.ModelPath, "m", "", "Path to the model file (required)") - flag.StringVar(&args.ModelPath, "modelPath", "", "Path to the model file (required)") // Optional: Redundant - flag.IntVar(&args.Port, "p", 3031, "Port to start the server on") - flag.IntVar(&args.Port, "port", 3031, "Port to start the server on") // Optional: Redundant - - flag.Usage = func() { - fmt.Println("Usage: your_program [OPTIONS]") - fmt.Println("Options:") - flag.PrintDefaults() // Print default values for all flags - } - - // Parsing flags - flag.Parse() - - args.Language = strings.ToLower(args.Language) - - if args.ModelPath == "" { - return nil, fmt.Errorf("modelPath argument is required") - } - - languageCode, err := processLanguageAndCode(args) - if err != nil { - fmt.Println("Error setting language, defaulting to English:", err) - // Use default language code directly as the result here - } - - return &ParsedArguments{ - Language: languageCode, - ModelPath: args.ModelPath, - Port: args.Port, - }, nil -} diff --git a/internal/resources/downloadResources.go b/internal/resources/downloadResources.go new file mode 100644 index 0000000..e06d019 --- /dev/null +++ b/internal/resources/downloadResources.go @@ -0,0 +1,143 @@ +package resources + +import ( + "archive/zip" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + + "github.com/schollz/progressbar/v3" +) + +func GetModel(modelType string) (string, error) { + fileURL := fmt.Sprintf("https://huggingface.co/ggerganov/whisper.cpp/resolve/main/%s", modelType) + filePath := modelType + + isModelFileExists := IsFileExists(filePath) + + if !isModelFileExists { + fmt.Println("Model not found.") + err := DownloadFile(fileURL, filePath) + if err != nil { + return "", err + } + } + + absPath, err := filepath.Abs(filePath) + if err != nil { + return "", err + } + + fmt.Printf("Model found: %s\n", absPath) + return filePath, nil +} + +func DownloadFile(url string, filepath string) error { + out, err := os.Create(filepath) + if err != nil { + return err + } + defer out.Close() + + resp, err := http.Get(url) + if err != nil { + return err + } + defer resp.Body.Close() + + fileSize := resp.ContentLength + bar := progressbar.DefaultBytes( + fileSize, + "Downloading", + ) + + writer := io.MultiWriter(out, bar) + + _, err = io.Copy(writer, resp.Body) + if err != nil { + return err + } + + return nil +} + +func GetWhisperDll(version string) (string, error) { + fileUrl := fmt.Sprintf("https://github.com/Const-me/Whisper/releases/download/%s/Library.zip", version) + fileToExtract := "Binary/Whisper.dll" + + isWhisperDllExists := IsFileExists("Whisper.dll") + + if !isWhisperDllExists { + fmt.Println("Whisper DLL not found.") + archivePath, err := os.CreateTemp("", "WhisperLibrary-*.zip") + if err != nil { + return "", err + } + defer archivePath.Close() + + err = DownloadFile(fileUrl, archivePath.Name()) + if err != nil { + return "", err + } + + err = extractFile(archivePath.Name(), fileToExtract) + if err != nil { + return "", err + } + } + + absPath, err := filepath.Abs("Whisper.dll") + if err != nil { + return "", err + } + + fmt.Printf("Library found: %s\n", absPath) + return "Whisper.dll", nil +} + +func extractFile(archivePath string, fileToExtract string) error { + reader, err := zip.OpenReader(archivePath) + if err != nil { + return err + } + defer reader.Close() + + for _, file := range reader.File { + if file.Name == fileToExtract { + targetPath := filepath.Base(fileToExtract) + + writer, err := os.Create(targetPath) + if err != nil { + return err + } + defer writer.Close() + + src, err := file.Open() + if err != nil { + return err + } + defer src.Close() + + _, err = io.Copy(writer, src) + if err != nil { + return err + } + + return nil + } + } + + return fmt.Errorf("File not found in the archive") +} + +func IsFileExists(filename string) bool { + _, err := os.Stat(filename) + if err != nil { + if os.IsNotExist(err) { + return false + } + } + return true +} \ No newline at end of file diff --git a/languageMap.json b/internal/resources/languageMap.json similarity index 100% rename from languageMap.json rename to internal/resources/languageMap.json diff --git a/internal/resources/promt.go b/internal/resources/promt.go new file mode 100644 index 0000000..94ee840 --- /dev/null +++ b/internal/resources/promt.go @@ -0,0 +1,76 @@ +package resources + +import ( + "bufio" + "fmt" + "os" + "path/filepath" + "strings" +) + +// PromptUser prompts the user with a question and returns true if they agree +func PromptUser(question string) bool { + fmt.Printf("%s (y/n): ", question) + reader := bufio.NewReader(os.Stdin) + response, err := reader.ReadString('\n') + if err != nil { + fmt.Println("Error reading input:", err) + return false + } + response = strings.TrimSpace(strings.ToLower(response)) + return response == "y" || response == "yes" +} + +// HandleWhisperDll checks if Whisper.dll exists or prompts the user to download it +func HandleWhisperDll(version string) (string, error) { + if IsFileExists("Whisper.dll") { + absPath, err := filepath.Abs("Whisper.dll") + if err != nil { + return "", err + } + fmt.Printf("Library found: %s\n", absPath) + return "Whisper.dll", nil + } + + fmt.Println("Whisper DLL not found.") + if PromptUser("Do you want to download Whisper.dll automatically?") { + path, err := GetWhisperDll(version) + if err != nil { + return "", fmt.Errorf("failed to download Whisper.dll: %w", err) + } + return path, nil + } + + fmt.Println("To use Whisper, download the DLL manually:") + fmt.Printf("URL: https://github.com/Const-me/Whisper/releases/download/%s/Library.zip\n", version) + fmt.Println("Extract 'Binary/Whisper.dll' from the archive and place it in the executable's directory.") + fmt.Println("You can manually specify path to .dll file using cli arguments, use --help to print available cli flags") + return "", fmt.Errorf("whisper.dll not found and user chose not to download") +} + +// HandleDefaultModel checks if the default model exists or prompts the user to download it +func HandleDefaultModel(modelType string) (string, error) { + if IsFileExists(modelType) { + absPath, err := filepath.Abs(modelType) + if err != nil { + return "", err + } + fmt.Printf("Model found: %s\n", absPath) + return modelType, nil + } + + fmt.Println("Default model not found.") + if PromptUser("Do you want to download the default model (ggml-medium.bin) automatically?") { + path, err := GetModel(modelType) + if err != nil { + return "", fmt.Errorf("failed to download the default model: %w", err) + } + return path, nil + } + + fmt.Println("To use Whisper, download the model manually:") + fmt.Println("URL: https://huggingface.co/ggerganov/whisper.cpp/tree/main") + fmt.Println("Place the model file in the executable's directory or specify its path using cli arguments.") + fmt.Println("You can manually specify path to model file using cli arguments, use --help to print available cli flags") + return "", fmt.Errorf("default model not found and user chose not to download") +} diff --git a/main.go b/main.go index 5465df9..3c4f499 100644 --- a/main.go +++ b/main.go @@ -7,24 +7,25 @@ import ( "github.com/labstack/echo/v4" "github.com/labstack/echo/v4/middleware" - "github.com/labstack/gommon/log" "github.com/xzeldon/whisper-api-server/internal/api" "github.com/xzeldon/whisper-api-server/internal/resources" ) -func change_working_directory(e *echo.Echo) { - exePath, errs := os.Executable() - if errs != nil { - e.Logger.Error(errs) +const ( + defaultModelType = "ggml-medium.bin" + defaultWhisperVersion = "1.12.0" +) + +func changeWorkingDirectory(e *echo.Echo) { + exePath, err := os.Executable() + if err != nil { + e.Logger.Error("Error getting executable path: ", err) return } exeDir := filepath.Dir(exePath) - - // Change the working directory to the executable directory - errs = os.Chdir(exeDir) - if errs != nil { - e.Logger.Error(errs) + if err := os.Chdir(exeDir); err != nil { + e.Logger.Error("Error changing working directory: ", err) return } @@ -33,33 +34,38 @@ func change_working_directory(e *echo.Echo) { } func main() { - e := echo.New() e.HideBanner = true - change_working_directory(e) + changeWorkingDirectory(e) - args, errParsing := resources.ParseFlags() - if errParsing != nil { - e.Logger.Error("Error parsing flags: ", errParsing) + args, err := resources.ParseFlags() + if err != nil { + e.Logger.Error("Error parsing flags: ", err) + return + } + + if _, err := resources.HandleWhisperDll(defaultWhisperVersion); err != nil { + e.Logger.Error("Error handling Whisper.dll: ", err) + return + } + + if _, err := resources.HandleDefaultModel(defaultModelType); err != nil { + e.Logger.Error("Error handling model file: ", err) return } e.Use(middleware.CORS()) - if l, ok := e.Logger.(*log.Logger); ok { - l.SetHeader("${time_rfc3339} ${level}") - } - whisperState, err := api.InitializeWhisperState(args.ModelPath, args.Language) - if err != nil { - e.Logger.Error(err) + e.Logger.Error("Error initializing Whisper state: ", err) + return } e.POST("/v1/audio/transcriptions", func(c echo.Context) error { - return api.Transcribe(c, whisperState) }) - e.Logger.Fatal(e.Start(fmt.Sprintf("127.0.0.1:%d", args.Port))) + address := fmt.Sprintf("127.0.0.1:%d", args.Port) + e.Logger.Fatal(e.Start(address)) } diff --git a/pkg/whisper/winversion.go b/pkg/whisper/winversion.go index c51bb1d..605c086 100644 --- a/pkg/whisper/winversion.go +++ b/pkg/whisper/winversion.go @@ -2,7 +2,6 @@ // Use of this source code is governed by a BSD // license that can be found in the LICENSE file. // Adapted mainly from github.com/gonutz/w32 - //go:build windows // +build windows @@ -10,7 +9,7 @@ package whisper import ( "errors" - "syscall" + "fmt" "unsafe" "golang.org/x/sys/windows" @@ -51,34 +50,46 @@ func (fi VS_FIXEDFILEINFO) FileVersion() uint64 { return uint64(fi.FileVersionMS)<<32 | uint64(fi.FileVersionLS) } -func GetFileVersionInfoSize(path string) uint32 { +func GetFileVersionInfoSize(path string) (uint32, error) { + pathPtr, err := windows.UTF16PtrFromString(path) + if err != nil { + return 0, err + } ret, _, _ := getFileVersionInfoSize.Call( - uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr(path))), + uintptr(unsafe.Pointer(pathPtr)), 0, ) - return uint32(ret) + return uint32(ret), nil } -func GetFileVersionInfo(path string, data []byte) bool { +func GetFileVersionInfo(path string, data []byte) (bool, error) { + pathPtr, err := windows.UTF16PtrFromString(path) + if err != nil { + return false, err + } ret, _, _ := getFileVersionInfo.Call( - uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr(path))), + uintptr(unsafe.Pointer(pathPtr)), 0, uintptr(len(data)), uintptr(unsafe.Pointer(&data[0])), ) - return ret != 0 + return ret != 0, nil } // VerQueryValueRoot calls VerQueryValue // (https://msdn.microsoft.com/en-us/library/windows/desktop/ms647464(v=vs.85).aspx) -// with `\` (root) to retieve the VS_FIXEDFILEINFO. +// with \ (root) to retrieve the VS_FIXEDFILEINFO. func VerQueryValueRoot(block []byte) (VS_FIXEDFILEINFO, error) { var offset uintptr var length uint blockStart := unsafe.Pointer(&block[0]) + rootPathPtr, err := windows.UTF16PtrFromString(`\`) + if err != nil { + return VS_FIXEDFILEINFO{}, err + } ret, _, _ := verQueryValue.Call( uintptr(blockStart), - uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr(`\`))), + uintptr(unsafe.Pointer(rootPathPtr)), uintptr(unsafe.Pointer(&offset)), uintptr(unsafe.Pointer(&length)), ) @@ -97,27 +108,24 @@ func VerQueryValueRoot(block []byte) (VS_FIXEDFILEINFO, error) { func GetFileVersion(path string) (WinVersion, error) { var result WinVersion - size := GetFileVersionInfoSize(path) - if size <= 0 { + size, err := GetFileVersionInfoSize(path) + fmt.Println(path) + if err != nil || size <= 0 { return result, errors.New("GetFileVersionInfoSize failed") } - info := make([]byte, size) - ok := GetFileVersionInfo(path, info) - if !ok { + ok, err := GetFileVersionInfo(path, info) + if err != nil || !ok { return result, errors.New("GetFileVersionInfo failed") } - fixed, err := VerQueryValueRoot(info) if err != nil { return result, err } version := fixed.FileVersion() - result.Major = uint32(version & 0xFFFF000000000000 >> 48) result.Minor = uint32(version & 0x0000FFFF00000000 >> 32) result.Patch = uint32(version & 0x00000000FFFF0000 >> 16) result.Build = uint32(version & 0x000000000000FFFF) - return result, nil }