From 0a281f79d6a7d6cee203c32a71b3050560d93def Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=C5=81ukasz=20Kwiecie=C5=84?=
Date: Thu, 14 Mar 2024 21:06:34 +0100
Subject: [PATCH] Feat/accept cmd args (#2)

* make working dir as executable dir

* Add support for cli args
---
 internal/api/state.go               |  4 +-
 internal/resources/cli_arguments.go | 72 +++++++++++++++++++++++
 internal/resources/lang_mapper.go   | 90 +++++++++++++++++++++++++++++
 main.go                             | 43 ++++++++++----
 pkg/whisper/FullParams.go           |  4 +-
 5 files changed, 200 insertions(+), 13 deletions(-)
 create mode 100644 internal/resources/cli_arguments.go
 create mode 100644 internal/resources/lang_mapper.go

diff --git a/internal/api/state.go b/internal/api/state.go
index 8389a52..5958c8e 100644
--- a/internal/api/state.go
+++ b/internal/api/state.go
@@ -15,7 +15,7 @@ type WhisperState struct {
 	mutex sync.Mutex
 }
 
-func InitializeWhisperState(modelPath string) (*WhisperState, error) {
+func InitializeWhisperState(modelPath string, lang int32) (*WhisperState, error) {
 	lib, err := whisper.New(whisper.LlDebug, whisper.LfUseStandardError, nil)
 	if err != nil {
 		return nil, err
@@ -41,6 +41,8 @@ func InitializeWhisperState(modelPath string) (*WhisperState, error) {
 		return nil, err
 	}
 
+	params.SetLanguage(lang)
+
 	fmt.Printf("Params CPU Threads : %d\n", params.CpuThreads())
 
 	return &WhisperState{
diff --git a/internal/resources/cli_arguments.go b/internal/resources/cli_arguments.go
new file mode 100644
index 0000000..27dd6df
--- /dev/null
+++ b/internal/resources/cli_arguments.go
@@ -0,0 +1,72 @@
+package resources
+
+import (
+	"flag"
+	"fmt"
+	"strings"
+)
+
+// defaultLanguage is the LanguageMap key used when no valid language is given
+const defaultLanguage = "en"
+
+// Arguments holds the raw command line values before validation
+type Arguments struct {
+	Language  string
+	ModelPath string
+}
+
+// ParsedArguments holds the validated values: the numeric language code
+// looked up in LanguageMap and the path to the model file
+type ParsedArguments struct {
+	Language  int32
+	ModelPath string
+}
+
+// ParseFlags parses command line arguments and returns a ParsedArguments struct.
+// A missing model path is an error; an omitted or unsupported language falls
+// back to defaultLanguage
+func ParseFlags() (*ParsedArguments, error) {
+	args := &Arguments{}
+
+	flag.StringVar(&args.Language, "l", "", "Language to be processed")
+	flag.StringVar(&args.Language, "language", "", "Language to be processed") // Long form of -l
+	flag.StringVar(&args.ModelPath, "m", "", "Path to the model file (required)")
+	flag.StringVar(&args.ModelPath, "modelPath", "", "Path to the model file (required)") // Long form of -m
+
+	flag.Usage = func() {
+		// Write to the same destination PrintDefaults uses
+		out := flag.CommandLine.Output()
+		fmt.Fprintf(out, "Usage: %s [OPTIONS]\n", flag.CommandLine.Name())
+		fmt.Fprintln(out, "Options:")
+		flag.PrintDefaults() // Print default values for all flags
+	}
+
+	// Parsing flags
+	flag.Parse()
+
+	// Check for required flags
+	if args.ModelPath == "" {
+		return nil, fmt.Errorf("modelPath argument is required")
+	}
+
+	args.Language = strings.ToLower(args.Language)
+
+	// Validate against LanguageMap and get associated code
+	pickedCode, exists := LanguageMap[args.Language]
+	switch {
+	case exists:
+		fmt.Println("Language code:", pickedCode)
+	case args.Language == "":
+		// Flag omitted: use the default without reporting an
+		// "unsupported language" the user never typed
+		pickedCode = LanguageMap[defaultLanguage]
+	default:
+		fmt.Println("unsupported language: ", args.Language, " Defaulting to english")
+		pickedCode = LanguageMap[defaultLanguage]
+	}
+
+	return &ParsedArguments{
+		Language:  pickedCode,
+		ModelPath: args.ModelPath,
+	}, nil
+}
diff --git a/internal/resources/lang_mapper.go b/internal/resources/lang_mapper.go
new file mode 100644
index 0000000..1e0126c
--- /dev/null
+++ b/internal/resources/lang_mapper.go
@@ -0,0 +1,90 @@
+package resources
+
+// LanguageMap maps a lowercase two-letter language code to the numeric value
+// handed to FullParams.SetLanguage. Each value is the code's two ASCII bytes
+// packed with the first letter in the low byte and the second letter in the
+// high byte (e.g. "en" -> 'e' 0x65 | 'n' 0x6E << 8 = 0x6E65)
+var LanguageMap = map[string]int32{
+	"af": 0x6661, // Afrikaans
+	"sq": 0x7173, // Albanian
+	"am": 0x6D61, // Amharic
+	"ar": 0x7261, // Arabic
+	"hy": 0x7968, // Armenian
+	"as": 0x7361, // Assamese
+	"az": 0x7A61, // Azerbaijani
+	"ba": 0x6162, // Bashkir
+	"eu": 0x7565, // Basque
+	"be": 0x6562, // Belarusian
+	"bn": 0x6E62, // Bengali
+	"bs": 0x7362, // Bosnian
+	"br": 0x7262, // Breton
+	"bg": 0x6762, // Bulgarian
+	"ca": 0x6163, // Catalan
+	"zh": 0x687A, // Chinese
+	"hr": 0x7268, // Croatian
+	"cs": 0x7363, // Czech
+	"da": 0x6164, // Danish
+	"nl": 0x6C6E, // Dutch
+	"en": 0x6E65, // English
+	"et": 0x7465, // Estonian
+	"fo": 0x6F66, // Faroese
+	"fi": 0x6966, // Finnish
+	"fr": 0x7266, // French
+	"gl": 0x6C67, // Galician
+	"ka": 0x616B, // Georgian
+	"de": 0x6564, // German
+	"el": 0x6C65, // Greek
+	"gu": 0x7567, // Gujarati
+	"he": 0x6568, // Hebrew
+	"hi": 0x6968, // Hindi
+	"hu": 0x7568, // Hungarian
+	"is": 0x7369, // Icelandic
+	"id": 0x6469, // Indonesian
+	"it": 0x7469, // Italian
+	"ja": 0x616A, // Japanese
+	"kn": 0x6E6B, // Kannada
+	"kk": 0x6B6B, // Kazakh
+	"km": 0x6D6B, // Khmer
+	"ko": 0x6F6B, // Korean
+	"ky": 0x796B, // Kyrgyz
+	"lo": 0x6F6C, // Lao
+	"lv": 0x766C, // Latvian
+	"lt": 0x746C, // Lithuanian
+	"mk": 0x6B6D, // Macedonian
+	"ms": 0x736D, // Malay
+	"ml": 0x6C6D, // Malayalam
+	"mr": 0x726D, // Marathi
+	"mn": 0x6E6D, // Mongolian
+	"ne": 0x656E, // Nepali
+	"no": 0x6F6E, // Norwegian
+	"or": 0x726F, // Oriya
+	"ps": 0x7370, // Pashto
+	"fa": 0x6166, // Persian
+	"pl": 0x6C70, // Polish
+	"pt": 0x7470, // Portuguese
+	"pa": 0x6170, // Punjabi
+	"ro": 0x6F72, // Romanian
+	"ru": 0x7572, // Russian
+	"sa": 0x6173, // Sanskrit
+	"sr": 0x7273, // Serbian
+	"sd": 0x6473, // Sindhi
+	"si": 0x6973, // Sinhalese
+	"sk": 0x6B73, // Slovak
+	"sl": 0x6C73, // Slovenian
+	"es": 0x7365, // Spanish
+	"sw": 0x7773, // Swahili
+	"sv": 0x7673, // Swedish
+	"tg": 0x6774, // Tajik
+	"ta": 0x6174, // Tamil
+	"te": 0x6574, // Telugu
+	"th": 0x6874, // Thai
+	"tr": 0x7274, // Turkish
+	"uk": 0x6B75, // Ukrainian
+	"ur": 0x7275, // Urdu
+	"uz": 0x7A75, // Uzbek
+	"vi": 0x6976, // Vietnamese
+	"cy": 0x7963, // Welsh
+	"xh": 0x6878, // Xhosa
+	"yi": 0x6979, // Yiddish
+	"yo": 0x6F79, // Yoruba
+}
diff --git a/main.go b/main.go
index 12ef89f..a138a33 100644
--- a/main.go
+++ b/main.go
@@ -1,6 +1,10 @@
 package main
 
 import (
+	"fmt"
+	"os"
+	"path/filepath"
+
 	"github.com/labstack/echo/v4"
 	"github.com/labstack/echo/v4/middleware"
 	"github.com/labstack/gommon/log"
@@ -9,31 +13,50 @@ import (
 )
 
 func main() {
+
 	e := echo.New()
 	e.HideBanner = true
 
+	args, errParsing := resources.ParseFlags()
+	if errParsing != nil {
+		e.Logger.Error("Error parsing flags: ", errParsing)
+		return
+	}
+
 	e.Use(middleware.CORS())
 
+	exePath, errs := os.Executable()
+	if errs != nil {
+		e.Logger.Error(errs)
+		return
+	}
+
+	exeDir := filepath.Dir(exePath)
+
+	// Change the working directory to the executable directory
+	errs = os.Chdir(exeDir)
+	if errs != nil {
+		e.Logger.Error(errs)
+		return
+	}
+
+	cwd, _ := os.Getwd()
+	fmt.Println("Current working directory:", cwd)
+
 	if l, ok := e.Logger.(*log.Logger); ok {
 		l.SetHeader("${time_rfc3339} ${level}")
 	}
 
-	_, err := resources.GetWhisperDll("1.12.0")
-	if err != nil {
-		e.Logger.Error(err)
-	}
+	whisperState, err := api.InitializeWhisperState(args.ModelPath, args.Language)
 
-	model, err := resources.GetModel("ggml-medium.bin")
-	if err != nil {
-		e.Logger.Error(err)
-	}
-
-	whisperState, err := api.InitializeWhisperState(model)
 	if err != nil {
 		e.Logger.Error(err)
+		// Stop here rather than registering the handler with a failed state
+		return
 	}
 
 	e.POST("/v1/audio/transcriptions", func(c echo.Context) error {
+
 		return api.Transcribe(c, whisperState)
 	})
 
diff --git a/pkg/whisper/FullParams.go b/pkg/whisper/FullParams.go
index cda84e2..4bbdd0a 100644
--- a/pkg/whisper/FullParams.go
+++ b/pkg/whisper/FullParams.go
@@ -86,14 +86,14 @@ func (this *FullParams) RemoveFlags(newflag eFullParamsFlags) {
 	this.cStruct.Flags = this.cStruct.Flags ^ newflag
 }
 
-func (this *FullParams) SetLanguage(language eLanguage) {
+func (this *FullParams) SetLanguage(language int32) {
 	if this == nil {
 		return
 	} else if this.cStruct == nil {
 		return
 	}
 
-	this.cStruct.Language = language
+	this.cStruct.Language = eLanguage(language)
 }
 
 /*using pfnNewSegment = HRESULT( __cdecl* )( iContext* ctx, uint32_t n_new, void* user_data ) noexcept;*/