Feat/accept cmd args (#2)

* make working dir as executable dir

* Add support for cli args
This commit is contained in:
Łukasz Kwiecień 2024-03-14 21:06:34 +01:00 committed by xzeldon
parent 1f8912eb6d
commit 0a281f79d6
Signed by: zeldon
GPG Key ID: 047886915281DD2A
5 changed files with 180 additions and 13 deletions

View File

@ -15,7 +15,7 @@ type WhisperState struct {
mutex sync.Mutex mutex sync.Mutex
} }
func InitializeWhisperState(modelPath string) (*WhisperState, error) { func InitializeWhisperState(modelPath string, lang int32) (*WhisperState, error) {
lib, err := whisper.New(whisper.LlDebug, whisper.LfUseStandardError, nil) lib, err := whisper.New(whisper.LlDebug, whisper.LfUseStandardError, nil)
if err != nil { if err != nil {
return nil, err return nil, err
@ -41,6 +41,8 @@ func InitializeWhisperState(modelPath string) (*WhisperState, error) {
return nil, err return nil, err
} }
params.SetLanguage(lang)
fmt.Printf("Params CPU Threads : %d\n", params.CpuThreads()) fmt.Printf("Params CPU Threads : %d\n", params.CpuThreads())
return &WhisperState{ return &WhisperState{

View File

@ -0,0 +1,58 @@
package resources
import (
"flag"
"fmt"
"strings"
)
// Arguments defines the structure to hold parsed arguments
type Arguments struct {
Language string
ModelPath string
}
type ParsedArguments struct {
Language int32
ModelPath string
}
// ParseFlags parses command line arguments and returns an Arguments struct
func ParseFlags() (*ParsedArguments, error) {
args := &Arguments{}
flag.StringVar(&args.Language, "l", "", "Language to be processed")
flag.StringVar(&args.Language, "language", "", "Language to be processed") // Optional: Redundant to demonstrate
flag.StringVar(&args.ModelPath, "m", "", "Path to the model file (required)")
flag.StringVar(&args.ModelPath, "modelPath", "", "Path to the model file (required)") // Optional: Redundant
flag.Usage = func() {
fmt.Println("Usage: your_program [OPTIONS]")
fmt.Println("Options:")
flag.PrintDefaults() // Print default values for all flags
}
// Parsing flags
flag.Parse()
args.Language = strings.ToLower(args.Language)
var pickedCode int32
// Validate against LanguageMap and get associated code
if code, exists := LanguageMap[args.Language]; exists {
fmt.Println("Language code:", code) // Use the code as needed
pickedCode = code
} else {
fmt.Println("unsupported language: ", args.Language, " Defaulting to english")
pickedCode = 0x6E65 // Default to english
}
// Check for required flags
if args.ModelPath == "" {
return nil, fmt.Errorf("modelPath argument is required")
}
return &ParsedArguments{
Language: pickedCode,
ModelPath: args.ModelPath,
}, nil
}

View File

@ -0,0 +1,86 @@
package resources
var LanguageMap = map[string]int32{
"af": 0x6661, // Afrikaans
"sq": 0x7173, // Albanian
"am": 0x6D61, // Amharic
"ar": 0x7261, // Arabic
"hy": 0x7968, // Armenian
"as": 0x7361, // Assamese
"az": 0x7A61, // Azerbaijani
"ba": 0x6162, // Bashkir
"eu": 0x7565, // Basque
"be": 0x6562, // Belarusian
"bn": 0x6E62, // Bengali
"bs": 0x7362, // Bosnian
"br": 0x7262, // Breton
"bg": 0x6762, // Bulgarian
"ca": 0x6163, // Catalan
"zh": 0x687A, // Chinese
"hr": 0x7268, // Croatian
"cs": 0x7363, // Czech
"da": 0x6164, // Danish
"nl": 0x6C6E, // Dutch
"en": 0x6E65, // English
"et": 0x7465, // Estonian
"fo": 0x6F66, // Faroese
"fi": 0x6966, // Finnish
"fr": 0x7266, // French
"gl": 0x6C67, // Galician
"ka": 0x616B, // Georgian
"de": 0x7265, // German
"el": 0x6C61, // Greek
"gu": 0x7567, // Gujarati
"he": 0x6568, // Hebrew
"hi": 0x6968, // Hindi
"hu": 0x7568, // Hungarian
"is": 0x7369, // Icelandic
"id": 0x6469, // Indonesian
"it": 0x7469, // Italian
"ja": 0x616A, // Japanese
"kn": 0x6E6B, // Kannada
"kk": 0x6B6B, // Kazakh
"km": 0x6D6B, // Khmer
"ko": 0x6F6B, // Korean
"ky": 0x796B, // Kyrgyz
"lo": 0x6F6C, // Lao
"lv": 0x766C, // Latvian
"lt": 0x746C, // Lithuanian
"mk": 0x6B6D, // Macedonian
"ms": 0x736D, // Malay
"ml": 0x6C6D, // Malayalam
"mr": 0x726D, // Marathi
"mn": 0x6E6D, // Mongolian
"ne": 0x6570, // Nepali
"no": 0x6F6E, // Norwegian
"or": 0x726F, // Oriya
"ps": 0x7368, // Pashto
"fa": 0x6172, // Persian
"pl": 0x6C70, // Polish
"pt": 0x7470, // Portuguese
"pa": 0x6170, // Punjabi
"ro": 0x6F72, // Romanian
"ru": 0x7572, // Russian
"sa": 0x6173, // Sanskrit
"sr": 0x7273, // Serbian
"sd": 0x6473, // Sindhi
"si": 0x6973, // Sinhalese
"sk": 0x6B73, // Slovak
"sl": 0x6C73, // Slovenian
"es": 0x6573, // Spanish
"sw": 0x7773, // Swahili
"sv": 0x6576, // Swedish
"tg": 0x6769, // Tajik
"ta": 0x6174, // Tamil
"te": 0x6574, // Telugu
"th": 0x6874, // Thai
"tr": 0x7274, // Turkish
"uk": 0x6B75, // Ukrainian
"ur": 0x7275, // Urdu
"uz": 0x7A75, // Uzbek
"vi": 0x6976, // Vietnamese
"cy": 0x7963, // Welsh
"xh": 0x6877, // Xhosa
"yi": 0x6979, // Yiddish
"yo": 0x6F79, // Yoruba
}

41
main.go
View File

@ -1,6 +1,10 @@
package main package main
import ( import (
"fmt"
"os"
"path/filepath"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware" "github.com/labstack/echo/v4/middleware"
"github.com/labstack/gommon/log" "github.com/labstack/gommon/log"
@ -9,31 +13,48 @@ import (
) )
func main() { func main() {
e := echo.New() e := echo.New()
e.HideBanner = true e.HideBanner = true
args, errParsing := resources.ParseFlags()
if errParsing != nil {
e.Logger.Error("Error parsing flags: ", errParsing)
return
}
e.Use(middleware.CORS()) e.Use(middleware.CORS())
exePath, errs := os.Executable()
if errs != nil {
e.Logger.Error(errs)
return
}
exeDir := filepath.Dir(exePath)
// Change the working directory to the executable directory
errs = os.Chdir(exeDir)
if errs != nil {
e.Logger.Error(errs)
return
}
cwd, _ := os.Getwd()
fmt.Println("Current working directory:", cwd)
if l, ok := e.Logger.(*log.Logger); ok { if l, ok := e.Logger.(*log.Logger); ok {
l.SetHeader("${time_rfc3339} ${level}") l.SetHeader("${time_rfc3339} ${level}")
} }
_, err := resources.GetWhisperDll("1.12.0") whisperState, err := api.InitializeWhisperState(args.ModelPath, args.Language)
if err != nil {
e.Logger.Error(err)
}
model, err := resources.GetModel("ggml-medium.bin")
if err != nil {
e.Logger.Error(err)
}
whisperState, err := api.InitializeWhisperState(model)
if err != nil { if err != nil {
e.Logger.Error(err) e.Logger.Error(err)
} }
e.POST("/v1/audio/transcriptions", func(c echo.Context) error { e.POST("/v1/audio/transcriptions", func(c echo.Context) error {
return api.Transcribe(c, whisperState) return api.Transcribe(c, whisperState)
}) })

View File

@ -86,14 +86,14 @@ func (this *FullParams) RemoveFlags(newflag eFullParamsFlags) {
this.cStruct.Flags = this.cStruct.Flags ^ newflag this.cStruct.Flags = this.cStruct.Flags ^ newflag
} }
func (this *FullParams) SetLanguage(language eLanguage) { func (this *FullParams) SetLanguage(language int32) {
if this == nil { if this == nil {
return return
} else if this.cStruct == nil { } else if this.cStruct == nil {
return return
} }
this.cStruct.Language = language this.cStruct.Language = eLanguage(language)
} }
/*using pfnNewSegment = HRESULT( __cdecl* )( iContext* ctx, uint32_t n_new, void* user_data ) noexcept;*/ /*using pfnNewSegment = HRESULT( __cdecl* )( iContext* ctx, uint32_t n_new, void* user_data ) noexcept;*/