initial commit

This commit is contained in:
2023-10-04 01:09:38 +03:00
commit 8e4cb72b50
21 changed files with 2045 additions and 0 deletions

248
pkg/whisper/FullParams.go Normal file
View File

@ -0,0 +1,248 @@
package whisper
import (
"syscall"
"unsafe"
)
// https://github.com/Const-me/Whisper/blob/master/Whisper/API/sFullParams.h
// https://github.com/Const-me/Whisper/blob/master/WhisperNet/API/Parameters.cs
type eSamplingStrategy uint32
const (
SsGreedy eSamplingStrategy = iota
SsBeamSearch
SsINVALIDARG
)
type eFullParamsFlags uint32
const (
FlagNone eFullParamsFlags = 0
FlagTranslate = 1 << 0
FlagNoContext = 1 << 1
FlagSingleSegment = 1 << 2
FlagPrintSpecial = 1 << 3
FlagPrintProgress = 1 << 4
FlagPrintRealtime = 1 << 5
FlagPrintTimestamps = 1 << 6
FlagTokenTimestamps = 1 << 7 // Experimental
FlagSpeedupAudio = 1 << 8
)
type EWhisperHWND uintptr
const (
S_OK EWhisperHWND = 0
S_FALSE EWhisperHWND = 1
)
type FullParams struct {
cStruct *_FullParams
}
func (this *FullParams) CpuThreads() int32 {
if this == nil {
return 0
} else if this.cStruct == nil {
return 0
}
return this.cStruct.cpuThreads
}
func (this *FullParams) setCpuThreads(val int32) {
if this == nil {
return
} else if this.cStruct == nil {
return
}
this.cStruct.cpuThreads = val
}
func (this *FullParams) SetMaxTextCTX(val int32) {
this.cStruct.n_max_text_ctx = val
}
func (this *FullParams) AddFlags(newflag eFullParamsFlags) {
if this == nil {
return
} else if this.cStruct == nil {
return
}
this.cStruct.Flags = this.cStruct.Flags | newflag
}
func (this *FullParams) RemoveFlags(newflag eFullParamsFlags) {
if this == nil {
return
} else if this.cStruct == nil {
return
}
this.cStruct.Flags = this.cStruct.Flags ^ newflag
}
/*using pfnNewSegment = HRESULT( __cdecl* )( iContext* ctx, uint32_t n_new, void* user_data ) noexcept;*/
type NewSegmentCallback_Type func(context *IContext, n_new uint32, user_data unsafe.Pointer) EWhisperHWND
func (this *FullParams) SetNewSegmentCallback(cb NewSegmentCallback_Type) {
if this == nil {
return
} else if this.cStruct == nil {
return
}
this.cStruct.new_segment_callback = syscall.NewCallback(cb)
}
/*
Return S_OK to proceed, or S_FALSE to stop the process
*/
type EncoderBeginCallback_Type func(context *IContext, user_data unsafe.Pointer) EWhisperHWND
func (this *FullParams) SetEncoderBeginCallback(cb EncoderBeginCallback_Type) {
if this == nil {
return
} else if this.cStruct == nil {
return
}
this.cStruct.encoder_begin_callback = syscall.NewCallback(cb)
}
func (this *FullParams) TestDefaultsOK() bool {
if this == nil {
return false
} else if this.cStruct == nil {
return false
}
if this.cStruct.n_max_text_ctx != 16384 {
return false
}
if this.cStruct.Flags != (FlagPrintProgress | FlagPrintTimestamps) {
return false
}
if this.cStruct.thold_pt != 0.01 {
return false
}
if this.cStruct.thold_ptsum != 0.01 {
return false
}
if this.cStruct.Language != English {
return false
}
// Todo ... why do these not line up as expected.. is our struct out of alignment ?
/*
if this.cStruct.strategy == ssGreedy {
if this.cStruct.beam_search.n_past != -1 ||
this.cStruct.beam_search.beam_width != -1 ||
this.cStruct.beam_search.n_best != -1 {
return false
}
} else if this.cStruct.strategy == ssBeamSearch {
if this.cStruct.greedy.n_past != -1 ||
this.cStruct.beam_search.beam_width != 10 ||
this.cStruct.beam_search.n_best != 5 {
return false
}
}
*/
return true
}
type _FullParams struct {
strategy eSamplingStrategy
cpuThreads int32
n_max_text_ctx int32
offset_ms int32
duration_ms int32
Flags eFullParamsFlags
Language eLanguage
thold_pt float32
thold_ptsum float32
max_len int32
max_tokens int32
greedy struct{ n_past int32 }
beam_search struct {
n_past int32
beam_width int32
n_best int32
}
audio_ctx int32 // overwrite the audio context size (0 = use default)
prompt_tokens uintptr
prompt_n_tokens int32
new_segment_callback uintptr
new_segment_callback_user_data uintptr
encoder_begin_callback uintptr
encoder_begin_callback_user_data uintptr
// Are these needed ?? Jay
// setFlag uintptr
}
func NewFullParams(cstruct *_FullParams) *FullParams {
this := FullParams{}
this.cStruct = cstruct
return &this
}
func _newFullParams_cStruct() *_FullParams {
return &_FullParams{
strategy: 0,
cpuThreads: 0,
n_max_text_ctx: 0,
offset_ms: 0,
duration_ms: 0,
Flags: 0,
Language: 0,
thold_pt: 0,
thold_ptsum: 0,
max_len: 0,
max_tokens: 0,
// anonymous int32
greedy: struct{ n_past int32 }{n_past: 0},
// anonymous struct
beam_search: struct {
n_past int32
beam_width int32
n_best int32
}{
n_past: 0,
beam_width: 0,
n_best: 0,
},
audio_ctx: 0,
prompt_tokens: 0,
prompt_n_tokens: 0,
new_segment_callback: 0,
new_segment_callback_user_data: 0,
encoder_begin_callback: 0,
encoder_begin_callback_user_data: 0,
}
}

View File

@ -0,0 +1,222 @@
package whisper
import (
"errors"
"fmt"
"syscall"
"unsafe"
"golang.org/x/sys/windows"
)
// https://github.com/Const-me/Whisper/blob/843a2a6ca6ea47c5ac4889a281badfc808d0ea01/Whisper/API/IMediaFoundation.h
type IMediaFoundation struct {
lpVtbl *IMediaFoundationVtbl
}
type IMediaFoundationVtbl struct {
QueryInterface uintptr
AddRef uintptr
Release uintptr
loadAudioFile uintptr // ( LPCTSTR path, bool stereo, iAudioBuffer** pp ) const;
openAudioFile uintptr // ( LPCTSTR path, bool stereo, iAudioReader** pp );
loadAudioFileData uintptr // ( const void* data, uint64_t size, bool stereo, iAudioReader** pp ); HRESULT
listCaptureDevices uintptr // ( pfnFoundCaptureDevices pfn, void* pv );
openCaptureDevice uintptr // ( LPCTSTR endpoint, const sCaptureParams& captureParams, iAudioCapture** pp );
}
func (this *IMediaFoundation) AddRef() int32 {
ret, _, _ := syscall.Syscall(
this.lpVtbl.AddRef,
1,
uintptr(unsafe.Pointer(this)),
0,
0)
return int32(ret)
}
func (this *IMediaFoundation) Release() int32 {
ret, _, _ := syscall.Syscall(
this.lpVtbl.Release,
1,
uintptr(unsafe.Pointer(this)),
0,
0)
return int32(ret)
}
// ( LPCTSTR path, bool stereo, iAudioBuffer** pp ) const;
func (this *IMediaFoundation) LoadAudioFile(file string, stereo bool) (*iAudioBuffer, error) {
var buffer *iAudioBuffer
UTFFileName, _ := windows.UTF16PtrFromString(file)
ret, _, _ := syscall.SyscallN(
this.lpVtbl.loadAudioFile,
uintptr(unsafe.Pointer(this)),
uintptr(unsafe.Pointer(UTFFileName)),
uintptr(1), // Todo ... Stereo !
uintptr(unsafe.Pointer(&buffer)))
if windows.Handle(ret) != windows.S_OK {
fmt.Printf("loadAudioFile failed: %s\n", syscall.Errno(ret).Error())
return nil, syscall.Errno(ret)
}
return buffer, nil
}
func (this *IMediaFoundation) OpenAudioFile(file string, stereo bool) (*iAudioReader, error) {
var buffer *iAudioReader
UTFFileName, _ := windows.UTF16PtrFromString(file)
ret, _, _ := syscall.SyscallN(
this.lpVtbl.openAudioFile,
uintptr(unsafe.Pointer(this)),
uintptr(unsafe.Pointer(UTFFileName)),
uintptr(1), // Todo ... Stereo !
uintptr(unsafe.Pointer(&buffer)))
if windows.Handle(ret) != windows.S_OK {
fmt.Printf("openAudioFile failed: %s\n", syscall.Errno(ret).Error())
return nil, syscall.Errno(ret)
}
return buffer, nil
}
func (this *IMediaFoundation) LoadAudioFileData(inbuffer *[]byte, stereo bool) (*iAudioReader, error) {
var reader *iAudioReader
// loadAudioFileData( const void* data, uint64_t size, bool stereo, iAudioReader** pp );
ret, _, _ := syscall.SyscallN(
this.lpVtbl.loadAudioFileData,
uintptr(unsafe.Pointer(this)),
uintptr(unsafe.Pointer(&(*inbuffer)[0])),
uintptr(uint64(len(*inbuffer))),
uintptr(1), // Todo ... Stereo !
uintptr(unsafe.Pointer(&reader)))
if windows.Handle(ret) != windows.S_OK {
fmt.Printf("LoadAudioFileData failed: %s\n", syscall.Errno(ret).Error())
return nil, syscall.Errno(ret)
}
return reader, nil
}
// ************************************************************
type iAudioBuffer struct {
lpVtbl *iAudioBufferVtbl
}
type iAudioBufferVtbl struct {
QueryInterface uintptr
AddRef uintptr
Release uintptr
countSamples uintptr // returns uint32_t
getPcmMono uintptr // returns float*
getPcmStereo uintptr // returns float*
getTime uintptr // ( int64_t& rdi )
}
func (this *iAudioBuffer) AddRef() int32 {
ret, _, _ := syscall.SyscallN(
this.lpVtbl.AddRef,
uintptr(unsafe.Pointer(this)),
)
return int32(ret)
}
func (this *iAudioBuffer) Release() int32 {
ret, _, _ := syscall.SyscallN(
this.lpVtbl.Release,
uintptr(unsafe.Pointer(this)),
)
return int32(ret)
}
func (this *iAudioBuffer) CountSamples() (uint32, error) {
ret, _, err := syscall.SyscallN(
this.lpVtbl.countSamples,
uintptr(unsafe.Pointer(this)),
)
if err != 0 {
return 0, errors.New(err.Error())
}
return uint32(ret), nil
}
// ************************************************************
type iAudioReader struct {
lpVtbl *iAudioReaderVtbl
}
type iAudioReaderVtbl struct {
QueryInterface uintptr
AddRef uintptr
Release uintptr
getDuration uintptr // ( int64_t& rdi )
getReader uintptr // ( IMFSourceReader** pp )
requestedStereo uintptr // ()
}
func (this *iAudioReader) AddRef() int32 {
ret, _, _ := syscall.SyscallN(
this.lpVtbl.AddRef,
uintptr(unsafe.Pointer(this)),
)
return int32(ret)
}
func (this *iAudioReader) Release() int32 {
ret, _, _ := syscall.SyscallN(
this.lpVtbl.Release,
uintptr(unsafe.Pointer(this)),
)
return int32(ret)
}
func (this *iAudioReader) GetDuration() (uint64, error) {
var rdi int64
ret, _, _ := syscall.SyscallN(
this.lpVtbl.getDuration,
uintptr(unsafe.Pointer(this)),
uintptr(unsafe.Pointer(&rdi)),
)
if windows.Handle(ret) != windows.S_OK {
fmt.Printf("LoadAudioFileData failed: %s\n", syscall.Errno(ret).Error())
return 0, syscall.Errno(ret)
}
return uint64(rdi), nil
}
// ************************************************************
type iAudioCapture struct {
lpVtbl *iAudioCaptureVtbl
}
type iAudioCaptureVtbl struct {
QueryInterface uintptr
AddRef uintptr
Release uintptr
getReader uintptr // ( IMFSourceReader** pp )
getParams uintptr // returns sCaptureParams&
}

119
pkg/whisper/Model.go Normal file
View File

@ -0,0 +1,119 @@
package whisper
import (
"errors"
"fmt"
"syscall"
"unsafe"
"golang.org/x/sys/windows"
)
// External - Go version of the struct
type Model struct {
cStruct *_IModel
setup *sModelSetup
}
// Internal - C Version of the structs
type _IModel struct {
lpVtbl *IModelVtbl
}
// https://github.com/Const-me/Whisper/blob/master/Whisper/API/iContext.cl.h
type IModelVtbl struct {
QueryInterface uintptr
AddRef uintptr
Release uintptr
createContext uintptr //( iContext** pp ) = 0;
tokenize uintptr /* HRESULT __stdcall tokenize( const char* text, pfnDecodedTokens pfn, void* pv ); */
isMultilingual uintptr //() = 0;
getSpecialTokens uintptr //( SpecialTokens& rdi ) = 0;
stringFromToken uintptr //( whisper_token token ) = 0;
clone uintptr //( iModel** rdi ) = 0;
}
func NewModel(setup *sModelSetup, cstruct *_IModel) *Model {
this := Model{}
this.setup = setup
this.cStruct = cstruct
return &this
}
func (this *Model) AddRef() int32 {
ret, _, _ := syscall.Syscall(
this.cStruct.lpVtbl.AddRef,
1,
uintptr(unsafe.Pointer(this.cStruct)),
0,
0)
return int32(ret)
}
func (this *Model) Release() int32 {
ret, _, _ := syscall.Syscall(
this.cStruct.lpVtbl.Release,
1,
uintptr(unsafe.Pointer(this.cStruct)),
0,
0)
return int32(ret)
}
func (this *Model) CreateContext() (*IContext, error) {
var context *IContext
/*
ret, _, err := syscall.Syscall(
this.cStruct.lpVtbl.createContext,
2, // Why was this 1, rather than 2 ?? 1 seemed to work fine
uintptr(unsafe.Pointer(this.cStruct)),
uintptr(unsafe.Pointer(&context)),
0)*/
ret, _, err := syscall.SyscallN(
this.cStruct.lpVtbl.createContext,
uintptr(unsafe.Pointer(this.cStruct)),
uintptr(unsafe.Pointer(&context)))
if windows.Handle(ret) != windows.S_OK {
fmt.Printf("createContext failed: %w", err.Error())
}
if windows.Handle(ret) != windows.S_OK {
return nil, fmt.Errorf("loadModel failed: %w", err)
}
return context, nil
}
func (this *Model) IsMultilingual() bool {
ret, _, _ := syscall.SyscallN(
this.cStruct.lpVtbl.isMultilingual,
uintptr(unsafe.Pointer(this.cStruct)),
)
return bool(windows.Handle(ret) == windows.S_OK)
}
func (this *Model) Clone() (*_IModel, error) {
if this.setup.isFlagSet(gmf_Cloneable) {
return nil, errors.New("Model is not cloneable")
}
//this.Cloneable ?
var modelptr *_IModel
ret, _, _ := syscall.SyscallN(
this.cStruct.lpVtbl.clone,
uintptr(unsafe.Pointer(this.cStruct)),
uintptr(unsafe.Pointer(&modelptr)),
)
if windows.Handle(ret) == windows.S_OK {
return modelptr, nil
} else {
return nil, errors.New("Model.Clone() failed : " + syscall.Errno(ret).Error())
}
}

101
pkg/whisper/ModelSetup.go Normal file
View File

@ -0,0 +1,101 @@
package whisper
import (
"unsafe"
"golang.org/x/sys/windows"
)
// Re-implemented sModelSetup.h
// enum struct eModelImplementation : uint32_t
type eModelImplementation uint32
const (
// GPGPU implementation based on Direct3D 11.0 compute shaders
mi_GPU eModelImplementation = 1
// A hybrid implementation which uses DirectCompute for encode, and decodes on CPU
// Not implemented in the published builds of the DLL. To enable, change BUILD_HYBRID_VERSION macro to 1
mi_Hybrid eModelImplementation = 2
// A reference implementation which uses the original GGML CPU-running code
// Not implemented in the published builds of the DLL. To enable, change BUILD_BOTH_VERSIONS macro to 1
mi_Reference eModelImplementation = 3
)
// enum struct eGpuModelFlags : uint32_t
type eGpuModelFlags uint32
const (
// <summary>Equivalent to <c>Wave32 | NoReshapedMatMul</c> on Intel and nVidia GPUs,<br/>
// and <c>Wave64 | UseReshapedMatMul</c> on AMD GPUs</summary>
gmf_None eGpuModelFlags = 0
// <summary>Use Wave32 version of compute shaders even on AMD GPUs</summary>
// <remarks>Incompatible with <see cref="Wave64" /></remarks>
gmf_Wave32 eGpuModelFlags = 1
// <summary>Use Wave64 version of compute shaders even on nVidia and Intel GPUs</summary>
// <remarks>Incompatible with <see cref="Wave32" /></remarks>
gmf_Wave64 eGpuModelFlags = 2
// <summary>Do not use reshaped matrix multiplication shaders on AMD GPUs</summary>
// <remarks>Incompatible with <see cref="UseReshapedMatMul" /></remarks>
gmf_NoReshapedMatMul eGpuModelFlags = 4
// <summary>Use reshaped matrix multiplication shaders even on nVidia and Intel GPUs</summary>
// <remarks>Incompatible with <see cref="NoReshapedMatMul" /></remarks>
gmf_UseReshapedMatMul eGpuModelFlags = 8
// <summary>Create GPU tensors in a way which allows sharing across D3D devices</summary>
gmf_Cloneable eGpuModelFlags = 0x10
)
// struct sModelSetup
type sModelSetup struct {
impl eModelImplementation
flags eGpuModelFlags
adapter string
}
type _sModelSetup struct {
impl eModelImplementation
flags eGpuModelFlags
adapter uintptr
}
func ModelSetup(flags eGpuModelFlags, GPU string) *sModelSetup {
this := sModelSetup{}
this.impl = mi_GPU
this.flags = flags
this.adapter = GPU
return &this
}
func (this *sModelSetup) isFlagSet(flag eGpuModelFlags) bool {
return (this.flags & flag) == 0
}
func (this *sModelSetup) AsCType() *_sModelSetup {
var err error
ctype := _sModelSetup{}
ctype.impl = this.impl
ctype.flags = this.flags
ctype.adapter = 0
// Conver Go String to wchar_t, AKA UTF-16
if this.adapter != "" {
var UTF16str *uint16
UTF16str, err = windows.UTF16PtrFromString(this.adapter)
ctype.adapter = uintptr(unsafe.Pointer(UTF16str))
}
if err != nil {
return nil
} else {
return &ctype
}
}

View File

@ -0,0 +1,178 @@
package whisper
import (
"C"
"errors"
"fmt"
"syscall"
"unsafe"
"golang.org/x/sys/windows"
)
type eTokenFlags uint32
const (
TfNone eTokenFlags = 0
TfSpecial = 1
)
type sTranscribeLength struct {
CountSegments uint32
CountTokens uint32
}
type sTimeSpan struct {
// The value is expressed in 100-nanoseconds ticks: compatible with System.Timespan, FILETIME, and many other things
Ticks uint64
/*
operator sTimeSpanFields() const
{
return sTimeSpanFields{ ticks };
}
void operator=( uint64_t tt )
{
ticks = tt;
}
void operator=( int64_t tt )
{
assert( tt >= 0 );
ticks = (uint64_t)tt;
} */
}
type sTimeInterval struct {
Begin sTimeSpan
End sTimeSpan
}
type sSegment struct {
// Segment text, null-terminated, and probably UTF-8 encoded
text *C.char
// Start and end times of the segment
Time sTimeInterval
// These two integers define the slice of the tokens in this segment, in the array returned by iTranscribeResult.getTokens method
FirstToken uint32
CountTokens uint32
}
func (this *sSegment) Text() string {
return C.GoString(this.text)
}
type sSegmentArray []sSegment
type SToken struct {
// Token text, null-terminated, and usually UTF-8 encoded.
// I think for Chinese language the models sometimes outputs invalid UTF8 strings here, Unicode code points can be split between adjacent tokens in the same segment
// More info: https://github.com/ggerganov/whisper.cpp/issues/399
text *C.char
// Start and end times of the token
Time sTimeInterval
// Probability of the token
Probability float32
// Probability of the timestamp token
ProbabilityTimestamp float32
// Sum of probabilities of all timestamp tokens
Ptsum float32
// Voice length of the token
Vlen float32
// Token id
Id int32
Flags eTokenFlags
}
func (this *SToken) Text() string {
return C.GoString(this.text)
}
type sTokenArray []SToken
type iTranscribeResultVtbl struct {
QueryInterface uintptr
AddRef uintptr
Release uintptr
getSize uintptr // ( sTranscribeLength& rdi ) HRESULT
getSegments uintptr // () getTokens
getTokens uintptr // () getToken*
}
type ITranscribeResult struct {
lpVtbl *iTranscribeResultVtbl
}
func (this *ITranscribeResult) AddRef() int32 {
ret, _, _ := syscall.Syscall(
this.lpVtbl.AddRef,
1,
uintptr(unsafe.Pointer(this)),
0,
0)
return int32(ret)
}
func (this *ITranscribeResult) Release() int32 {
ret, _, _ := syscall.Syscall(
this.lpVtbl.Release,
1,
uintptr(unsafe.Pointer(this)),
0,
0)
return int32(ret)
}
func (this *ITranscribeResult) GetSize() (*sTranscribeLength, error) {
var result sTranscribeLength
ret, _, _ := syscall.SyscallN(
this.lpVtbl.getSize,
uintptr(unsafe.Pointer(this)),
uintptr(unsafe.Pointer(&result)),
)
if windows.Handle(ret) != windows.S_OK {
fmt.Printf("iTranscribeResult.GetSize failed: %s\n", syscall.Errno(ret).Error())
return nil, errors.New(syscall.Errno(ret).Error())
}
return &result, nil
}
func (this *ITranscribeResult) GetSegments(len uint32) []sSegment {
ret, _, _ := syscall.SyscallN(
this.lpVtbl.getSegments,
uintptr(unsafe.Pointer(this)),
)
data := unsafe.Slice((*sSegment)(unsafe.Pointer(ret)), len)
return data
}
func (this *ITranscribeResult) GetTokens(len uint32) []SToken {
ret, _, _ := syscall.SyscallN(
this.lpVtbl.getTokens,
uintptr(unsafe.Pointer(this)),
)
if unsafe.Pointer(ret) != nil {
return unsafe.Slice((*SToken)(unsafe.Pointer(ret)), len)
} else {
return []SToken{}
}
}

281
pkg/whisper/context.go Normal file
View File

@ -0,0 +1,281 @@
package whisper
import (
"errors"
"fmt"
"syscall"
"unsafe"
"golang.org/x/sys/windows"
)
type uuid [16]byte
type eResultFlags uint32
const (
RfNone eResultFlags = 0
// Return individual tokens in addition to the segments
RfTokens = 1
// Return timestamps
RfTimestamps = 2
// Create a new COM object for the results.
// Without this flag, the context returns a pointer to the COM object stored in the context.
// The content of that object is replaced every time you call IContext.getResults method
RfNewObject = 0x100
)
type IContextVtbl struct {
QueryInterface uintptr
AddRef uintptr
Release uintptr
RunFull uintptr
RunStreamed uintptr
RunCapture uintptr
GetResults uintptr
DetectSpeaker uintptr
GetModel uintptr
FullDefaultParams uintptr
TimingsPrint uintptr
TimingsReset uintptr
}
type IContext struct {
lpVtbl *IContextVtbl
}
//type sFullParams struct{}
// type iAudioBuffer struct{}
type sProgressSink struct {
pfn uintptr
pv uintptr
}
// type iAudioReader struct{}
type sCaptureCallbacks struct{}
// type iAudioCapture struct{}
// type eResultFlags int32
// type iTranscribeResult struct{}
// type sTimeInterval struct{}
type eSpeakerChannel int32
//type eSamplingStrategy int32
// Create a new IContext instance
func newIContext() *IContext {
return &IContext{
lpVtbl: &IContextVtbl{
QueryInterface: 0,
AddRef: 0,
Release: 0,
RunFull: 0,
RunStreamed: 0,
RunCapture: 0,
GetResults: 0,
DetectSpeaker: 0,
GetModel: 0,
FullDefaultParams: 0,
TimingsPrint: 0,
TimingsReset: 0,
},
}
}
func (context *IContext) TimingsPrint() error {
// TimingsPrint();
ret, _, _ := syscall.SyscallN(
context.lpVtbl.TimingsPrint,
uintptr(unsafe.Pointer(context)),
)
if windows.Handle(ret) != windows.S_OK {
fmt.Printf("RunFull failed: %s\n", syscall.Errno(ret).Error())
return errors.New(syscall.Errno(ret).Error())
}
return nil
}
// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
// Uses the specified decoding strategy to obtain the text.
func (context *IContext) RunFull(params *FullParams, buffer *iAudioBuffer) error {
// runFull( const sFullParams& params, const iAudioBuffer* buffer );
ret, _, _ := syscall.SyscallN(
context.lpVtbl.RunFull,
uintptr(unsafe.Pointer(context)),
uintptr(unsafe.Pointer(params.cStruct)),
uintptr(unsafe.Pointer(buffer)),
)
if windows.Handle(ret) != windows.S_OK {
fmt.Printf("RunFull failed: %s\n", syscall.Errno(ret).Error())
return errors.New(syscall.Errno(ret).Error())
}
return nil
}
func (context *IContext) RunStreamed(params *FullParams, reader *iAudioReader) error {
cb := sProgressSink{}
// runStreamed( const sFullParams& params, const sProgressSink& progress, const iAudioReader* reader );
ret, _, _ := syscall.SyscallN(
context.lpVtbl.RunStreamed,
uintptr(unsafe.Pointer(context)),
uintptr(unsafe.Pointer(params.cStruct)),
uintptr(unsafe.Pointer(&cb)), // No progress cb yet
uintptr(unsafe.Pointer(reader)),
)
if windows.Handle(ret) != windows.S_OK {
fmt.Printf("RunStreamed failed: %s\n", syscall.Errno(ret).Error())
return errors.New(syscall.Errno(ret).Error())
}
return nil
}
func (this *IContext) AddRef() int32 {
ret, _, _ := syscall.Syscall(
this.lpVtbl.AddRef,
1,
uintptr(unsafe.Pointer(this)),
0,
0)
return int32(ret)
}
func (this *IContext) Release() int32 {
ret, _, _ := syscall.Syscall(
this.lpVtbl.Release,
1,
uintptr(unsafe.Pointer(this)),
0,
0)
return int32(ret)
}
/*
https://github.com/Const-me/Whisper/blob/f6f743c7b3570b85ccf47f74b84e06a73667ef3e/Whisper/Whisper/ContextImpl.misc.cpp
Returns E_POINTER if null pointer provided in params
Initialises params to all 0
sets values in struct, does not malloc
*/
func (context *IContext) FullDefaultParams(strategy eSamplingStrategy) (*FullParams, error) {
/*
ERR : unreadable Only part of a ReadProcessMemory or WriteProcessMemory request was completed
* not related to stratergy ... tested 0, 1 and 2 ... 2 produced E_INVALIDARG as expected
* not a nil ptr to params ... nil poitner produced E_POINTER as expected
* params seems to return 0x4000
* !!!!! FullParams is not a com interface !!!
* so no lpVtbl *FullParamsVtbl , no queryinterface, addref etc
*/
params := _newFullParams_cStruct()
//params := &[160]byte{}
ret, _, _ := syscall.SyscallN(
context.lpVtbl.FullDefaultParams,
uintptr(unsafe.Pointer(context)),
uintptr(strategy),
uintptr(unsafe.Pointer(params)),
)
// nil ptr should be 0x80004003L
// unsafe.Pointer(0xc00011dc28)
// unsafe.Pointer(0x4000)
if windows.Handle(ret) != windows.S_OK {
fmt.Printf("FullDefaultParams failed: %s\n", syscall.Errno(ret).Error())
return nil, syscall.Errno(ret)
}
if params == nil {
return nil, errors.New("FullDefaultParams did not return params")
}
ParamObj := NewFullParams(params)
if ParamObj.TestDefaultsOK() {
return ParamObj, nil
}
return nil, nil
}
func (context *IContext) GetModel() (*_IModel, error) {
var modelptr *_IModel
// getModel( iModel** pp );
ret, _, _ := syscall.SyscallN(
context.lpVtbl.GetModel,
uintptr(unsafe.Pointer(context)),
uintptr(unsafe.Pointer(&modelptr)),
)
if windows.Handle(ret) != windows.S_OK {
fmt.Printf("FullDefaultParams failed: %s\n", syscall.Errno(ret).Error())
return nil, syscall.Errno(ret)
}
if modelptr == nil {
return nil, errors.New("loadModel did not return a Model")
}
if modelptr.lpVtbl == nil {
return nil, errors.New("loadModel method table is nil")
}
return modelptr, nil
}
// ************************************************************************************************************************************************
// Not really implemented / tested
// ************************************************************************************************************************************************
func (context *IContext) RunCapture(params *FullParams, callbacks *sCaptureCallbacks, reader *iAudioCapture) uintptr {
ret, _, _ := syscall.SyscallN(
context.lpVtbl.RunCapture,
//3,
uintptr(unsafe.Pointer(context)),
uintptr(unsafe.Pointer(params)),
uintptr(unsafe.Pointer(callbacks)),
uintptr(unsafe.Pointer(reader)),
)
return ret
}
func (context *IContext) GetResults(flags eResultFlags, pp **ITranscribeResult) uintptr {
ret, _, _ := syscall.Syscall(
context.lpVtbl.GetResults,
3,
uintptr(unsafe.Pointer(context)),
uintptr(flags),
uintptr(unsafe.Pointer(pp)),
)
return ret
}
func (context *IContext) DetectSpeaker(time *sTimeInterval, result *eSpeakerChannel) uintptr {
ret, _, _ := syscall.Syscall(
context.lpVtbl.DetectSpeaker,
3,
uintptr(unsafe.Pointer(context)),
uintptr(unsafe.Pointer(time)),
uintptr(unsafe.Pointer(result)),
)
return ret
}

5
pkg/whisper/doc.go Normal file
View File

@ -0,0 +1,5 @@
/*
github.com/jaybinks/goConstmeWhispers
Go Bindings for https://github.com/Const-me/Whisper
*/
package whisper

207
pkg/whisper/language.go Normal file
View File

@ -0,0 +1,207 @@
package whisper
// https://github.com/Const-me/Whisper/blob/master/WhisperNet/API/eLanguage.cs
type eLanguage int32
const (
Auto eLanguage = -1 // "af"
Afrikaans = 0x6661 // "af"
/// <summary>Albanian</summary>
Albanian = 0x7173 // "sq"
/// <summary>Amharic</summary>
Amharic = 0x6D61 // "am"
/// <summary>Arabic</summary>
Arabic = 0x7261 // "ar"
/// <summary>Armenian</summary>
Armenian = 0x7968 // "hy"
/// <summary>Assamese</summary>
Assamese = 0x7361 // "as"
/// <summary>Azerbaijani</summary>
Azerbaijani = 0x7A61 // "az"
/// <summary>Bashkir</summary>
Bashkir = 0x6162 // "ba"
/// <summary>Basque</summary>
Basque = 0x7565 // "eu"
/// <summary>Belarusian</summary>
Belarusian = 0x6562 // "be"
/// <summary>Bengali</summary>
Bengali = 0x6E62 // "bn"
/// <summary>Bosnian</summary>
Bosnian = 0x7362 // "bs"
/// <summary>Breton</summary>
Breton = 0x7262 // "br"
/// <summary>Bulgarian</summary>
Bulgarian = 0x6762 // "bg"
/// <summary>Catalan</summary>
Catalan = 0x6163 // "ca"
/// <summary>Chinese</summary>
Chinese = 0x687A // "zh"
/// <summary>Croatian</summary>
Croatian = 0x7268 // "hr"
/// <summary>Czech</summary>
Czech = 0x7363 // "cs"
/// <summary>Danish</summary>
Danish = 0x6164 // "da"
/// <summary>Dutch</summary>
Dutch = 0x6C6E // "nl"
/// <summary>English</summary>
English = 0x6E65 // "en"
/// <summary>Estonian</summary>
Estonian = 0x7465 // "et"
/// <summary>Faroese</summary>
Faroese = 0x6F66 // "fo"
/// <summary>Finnish</summary>
Finnish = 0x6966 // "fi"
/// <summary>French</summary>
French = 0x7266 // "fr"
/// <summary>Galician</summary>
Galician = 0x6C67 // "gl"
/// <summary>Georgian</summary>
Georgian = 0x616B // "ka"
/// <summary>German</summary>
German = 0x6564 // "de"
/// <summary>Greek</summary>
Greek = 0x6C65 // "el"
/// <summary>Gujarati</summary>
Gujarati = 0x7567 // "gu"
/// <summary>Haitian Creole</summary>
HaitianCreole = 0x7468 // "ht"
/// <summary>Hausa</summary>
Hausa = 0x6168 // "ha"
/// <summary>Hawaiian</summary>
Hawaiian = 0x776168 // "haw"
/// <summary>Hebrew</summary>
Hebrew = 0x7769 // "iw"
/// <summary>Hindi</summary>
Hindi = 0x6968 // "hi"
/// <summary>Hungarian</summary>
Hungarian = 0x7568 // "hu"
/// <summary>Icelandic</summary>
Icelandic = 0x7369 // "is"
/// <summary>Indonesian</summary>
Indonesian = 0x6469 // "id"
/// <summary>Italian</summary>
Italian = 0x7469 // "it"
/// <summary>Japanese</summary>
Japanese = 0x616A // "ja"
/// <summary>Javanese</summary>
Javanese = 0x776A // "jw"
/// <summary>Kannada</summary>
Kannada = 0x6E6B // "kn"
/// <summary>Kazakh</summary>
Kazakh = 0x6B6B // "kk"
/// <summary>Khmer</summary>
Khmer = 0x6D6B // "km"
/// <summary>Korean</summary>
Korean = 0x6F6B // "ko"
/// <summary>Lao</summary>
Lao = 0x6F6C // "lo"
/// <summary>Latin</summary>
Latin = 0x616C // "la"
/// <summary>Latvian</summary>
Latvian = 0x766C // "lv"
/// <summary>Lingala</summary>
Lingala = 0x6E6C // "ln"
/// <summary>Lithuanian</summary>
Lithuanian = 0x746C // "lt"
/// <summary>Luxembourgish</summary>
Luxembourgish = 0x626C // "lb"
/// <summary>Macedonian</summary>
Macedonian = 0x6B6D // "mk"
/// <summary>Malagasy</summary>
Malagasy = 0x676D // "mg"
/// <summary>Malay</summary>
Malay = 0x736D // "ms"
/// <summary>Malayalam</summary>
Malayalam = 0x6C6D // "ml"
/// <summary>Maltese</summary>
Maltese = 0x746D // "mt"
/// <summary>Maori</summary>
Maori = 0x696D // "mi"
/// <summary>Marathi</summary>
Marathi = 0x726D // "mr"
/// <summary>Mongolian</summary>
Mongolian = 0x6E6D // "mn"
/// <summary>Myanmar</summary>
Myanmar = 0x796D // "my"
/// <summary>Nepali</summary>
Nepali = 0x656E // "ne"
/// <summary>Norwegian</summary>
Norwegian = 0x6F6E // "no"
/// <summary>Nynorsk</summary>
Nynorsk = 0x6E6E // "nn"
/// <summary>Occitan</summary>
Occitan = 0x636F // "oc"
/// <summary>Pashto</summary>
Pashto = 0x7370 // "ps"
/// <summary>Persian</summary>
Persian = 0x6166 // "fa"
/// <summary>Polish</summary>
Polish = 0x6C70 // "pl"
/// <summary>Portuguese</summary>
Portuguese = 0x7470 // "pt"
/// <summary>Punjabi</summary>
Punjabi = 0x6170 // "pa"
/// <summary>Romanian</summary>
Romanian = 0x6F72 // "ro"
/// <summary>Russian</summary>
Russian = 0x7572 // "ru"
/// <summary>Sanskrit</summary>
Sanskrit = 0x6173 // "sa"
/// <summary>Serbian</summary>
Serbian = 0x7273 // "sr"
/// <summary>Shona</summary>
Shona = 0x6E73 // "sn"
/// <summary>Sindhi</summary>
Sindhi = 0x6473 // "sd"
/// <summary>Sinhala</summary>
Sinhala = 0x6973 // "si"
/// <summary>Slovak</summary>
Slovak = 0x6B73 // "sk"
/// <summary>Slovenian</summary>
Slovenian = 0x6C73 // "sl"
/// <summary>Somali</summary>
Somali = 0x6F73 // "so"
/// <summary>Spanish</summary>
Spanish = 0x7365 // "es"
/// <summary>Sundanese</summary>
Sundanese = 0x7573 // "su"
/// <summary>Swahili</summary>
Swahili = 0x7773 // "sw"
/// <summary>Swedish</summary>
Swedish = 0x7673 // "sv"
/// <summary>Tagalog</summary>
Tagalog = 0x6C74 // "tl"
/// <summary>Tajik</summary>
Tajik = 0x6774 // "tg"
/// <summary>Tamil</summary>
Tamil = 0x6174 // "ta"
/// <summary>Tatar</summary>
Tatar = 0x7474 // "tt"
/// <summary>Telugu</summary>
Telugu = 0x6574 // "te"
/// <summary>Thai</summary>
Thai = 0x6874 // "th"
/// <summary>Tibetan</summary>
Tibetan = 0x6F62 // "bo"
/// <summary>Turkish</summary>
Turkish = 0x7274 // "tr"
/// <summary>Turkmen</summary>
Turkmen = 0x6B74 // "tk"
/// <summary>Ukrainian</summary>
Ukrainian = 0x6B75 // "uk"
/// <summary>Urdu</summary>
Urdu = 0x7275 // "ur"
/// <summary>Uzbek</summary>
Uzbek = 0x7A75 // "uz"
/// <summary>Vietnamese</summary>
Vietnamese = 0x6976 // "vi"
/// <summary>Welsh</summary>
Welsh = 0x7963 // "cy"
/// <summary>Yiddish</summary>
Yiddish = 0x6979 // "yi"
/// <summary>Yoruba</summary>
Yoruba = 0x6F79 // "yo"
)

43
pkg/whisper/logger.go Normal file
View File

@ -0,0 +1,43 @@
package whisper
import (
"C"
"fmt"
)
/*
https://github.com/Const-me/Whisper/blob/843a2a6ca6ea47c5ac4889a281badfc808d0ea01/Whisper/API/loggerApi.h
*/
type eLogLevel uint8
const (
LlError eLogLevel = 0
LlWarning = 1
LlInfo = 2
LlDebug = 3
)
type eLogFlags uint8
const (
LfNone eLogFlags = 0
LfUseStandardError = 1
LfSkipFormatMessage = 2
)
type sLoggerSetup struct {
sink uintptr // pfnLoggerSink
context uintptr // void*
level eLogLevel // eLogLevel
flags eLogFlags // eLoggerFlags
}
func fnLoggerSink(context uintptr, lvl eLogLevel, message *C.char) uintptr {
strmessage := C.GoString(message)
fmt.Printf("%d - %s\n", lvl, strmessage)
return 0
}

188
pkg/whisper/whisper.go Normal file
View File

@ -0,0 +1,188 @@
//go:build windows
// +build windows
package whisper
import (
"C"
"errors"
"syscall"
"unsafe"
// Using lxn/win because its COM functions expose raw HRESULTs
"golang.org/x/sys/windows"
)
import (
"fmt"
)
/*
eModelImplementation - TranscribeStructs.h
// GPGPU implementation based on Direct3D 11.0 compute shaders
GPU = 1,
// A hybrid implementation which uses DirectCompute for encode, and decodes on CPU
// Not implemented in the published builds of the DLL. To enable, change BUILD_HYBRID_VERSION macro to 1
Hybrid = 2,
// A reference implementation which uses the original GGML CPU-running code
// Not implemented in the published builds of the DLL. To enable, change BUILD_BOTH_VERSIONS macro to 1
Reference = 3,
*/
// https://learn.microsoft.com/en-us/windows/win32/seccrypto/common-hresult-values
// https://pkg.go.dev/golang.org/x/sys/windows
const (
E_INVALIDARG = 0x80070057
ERROR_HV_CPUID_FEATURE_VALIDATION = 0xC0350038
DLLName = "whisper.dll"
)
type Libwhisper struct {
dll *syscall.LazyDLL
ver WinVersion
existing_model map[string]*Model
proc_setupLogger *syscall.LazyProc
proc_loadModel *syscall.LazyProc
proc_initMediaFoundation *syscall.LazyProc
// proc_findLanguageKeyW *syscall.LazyProc
// proc_findLanguageKeyA *syscall.LazyProc
// proc_getSupportedLanguages *syscall.LazyProc
}
var singleton_whisper *Libwhisper = nil
func New(level eLogLevel, flags eLogFlags, cb *any) (*Libwhisper, error) {
if singleton_whisper != nil {
return singleton_whisper, nil
}
var err error
this := &Libwhisper{}
this.ver, err = GetFileVersion(DLLName)
if err != nil {
return nil, err
}
if this.ver.Major < 1 && this.ver.Minor < 9 {
return nil, errors.New("This library requires whisper.dll version 1.9 or higher.") // or less than 1.11 for now .. because the API changed
}
this.dll = syscall.NewLazyDLL(DLLName) // Todo wrap this in a class, check file exists, handle errors ... you know, just a few things.. AKA Stop being lazy
this.proc_setupLogger = this.dll.NewProc("setupLogger")
this.proc_loadModel = this.dll.NewProc("loadModel")
this.proc_initMediaFoundation = this.dll.NewProc("initMediaFoundation")
/*
this.proc_findLanguageKeyW = this.dll.NewProc("findLanguageKeyW")
this.proc_findLanguageKeyA = this.dll.NewProc("findLanguageKeyA")
this.proc_getSupportedLanguages = this.dll.NewProc("getSupportedLanguages")
*/
ok, err := this._setupLogger(level, flags, cb)
if !ok {
return nil, errors.New("Logger Error : " + err.Error())
}
this.existing_model = make(map[string]*Model)
singleton_whisper = this
return singleton_whisper, nil
}
func (this *Libwhisper) Version() string {
return fmt.Sprintf("%d.%d.%d.%d.", this.ver.Major, this.ver.Minor, this.ver.Patch, this.ver.Build)
}
func (this *Libwhisper) SupportsMultiThread() bool {
return this.ver.Major >= 1 && this.ver.Minor >= 10
}
func (this *Libwhisper) _setupLogger(level eLogLevel, flags eLogFlags, cb *any) (bool, error) {
setup := sLoggerSetup{}
setup.sink = 0
setup.context = 0
setup.level = level
setup.flags = flags
if cb != nil {
setup.sink = syscall.NewCallback(cb)
}
res, _, err := this.proc_setupLogger.Call(uintptr(unsafe.Pointer(&setup)))
if windows.Handle(res) == windows.S_OK {
return true, nil
} else {
return false, err
}
}
func (this *Libwhisper) LoadModel(path string, aGPU ...string) (*Model, error) {
var modelptr *_IModel
whisperpath, _ := windows.UTF16PtrFromString(path)
GPU := ""
if len(aGPU) == 1 {
GPU = aGPU[0]
}
setup := ModelSetup(gmf_Cloneable, GPU)
// Construct our map hash
singleton_hash := GPU + "|" + path
if this.existing_model[singleton_hash] != nil {
ClonedModel, err := this.existing_model[singleton_hash].Clone()
if ClonedModel != nil {
return NewModel(setup, ClonedModel), nil
} else {
return nil, err
}
}
obj, _, _ := this.proc_loadModel.Call(uintptr(unsafe.Pointer(whisperpath)), uintptr(unsafe.Pointer(setup.AsCType())), uintptr(unsafe.Pointer(nil)), uintptr(unsafe.Pointer(&modelptr)))
if windows.Handle(obj) != windows.S_OK {
fmt.Printf("loadModel failed: %s\n", syscall.Errno(obj).Error())
return nil, fmt.Errorf("loadModel failed: %s", syscall.Errno(obj))
}
if modelptr == nil {
return nil, errors.New("loadModel did not return a Model")
}
if modelptr.lpVtbl == nil {
return nil, errors.New("loadModel method table is nil")
}
model := NewModel(setup, modelptr)
this.existing_model[singleton_hash] = model
return model, nil
}
func (this *Libwhisper) InitMediaFoundation() (*IMediaFoundation, error) {
var mediafoundation *IMediaFoundation
// initMediaFoundation( iMediaFoundation** pp );
obj, _, _ := this.proc_initMediaFoundation.Call(uintptr(unsafe.Pointer(&mediafoundation)))
if windows.Handle(obj) != windows.S_OK {
fmt.Printf("initMediaFoundation failed: %s\n", syscall.Errno(obj).Error())
return nil, fmt.Errorf("initMediaFoundation failed: %s", syscall.Errno(obj))
}
if mediafoundation.lpVtbl == nil {
return nil, errors.New("initMediaFoundation method table is nil")
}
return mediafoundation, nil
}

123
pkg/whisper/winversion.go Normal file
View File

@ -0,0 +1,123 @@
// Copyright 2018 Keybase Inc. All rights reserved.
// Use of this source code is governed by a BSD
// license that can be found in the LICENSE file.
// Adapted mainly from github.com/gonutz/w32
//go:build windows
// +build windows
package whisper
import (
"errors"
"syscall"
"unsafe"
"golang.org/x/sys/windows"
)
var (
version = windows.NewLazySystemDLL("version.dll")
getFileVersionInfoSize = version.NewProc("GetFileVersionInfoSizeW")
getFileVersionInfo = version.NewProc("GetFileVersionInfoW")
verQueryValue = version.NewProc("VerQueryValueW")
)
type VS_FIXEDFILEINFO struct {
Signature uint32
StrucVersion uint32
FileVersionMS uint32
FileVersionLS uint32
ProductVersionMS uint32
ProductVersionLS uint32
FileFlagsMask uint32
FileFlags uint32
FileOS uint32
FileType uint32
FileSubtype uint32
FileDateMS uint32
FileDateLS uint32
}
type WinVersion struct {
Major uint32
Minor uint32
Patch uint32
Build uint32
}
// FileVersion concatenates FileVersionMS and FileVersionLS to a uint64 value.
func (fi VS_FIXEDFILEINFO) FileVersion() uint64 {
return uint64(fi.FileVersionMS)<<32 | uint64(fi.FileVersionLS)
}
func GetFileVersionInfoSize(path string) uint32 {
ret, _, _ := getFileVersionInfoSize.Call(
uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr(path))),
0,
)
return uint32(ret)
}
func GetFileVersionInfo(path string, data []byte) bool {
ret, _, _ := getFileVersionInfo.Call(
uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr(path))),
0,
uintptr(len(data)),
uintptr(unsafe.Pointer(&data[0])),
)
return ret != 0
}
// VerQueryValueRoot calls VerQueryValue
// (https://msdn.microsoft.com/en-us/library/windows/desktop/ms647464(v=vs.85).aspx)
// with `\` (root) to retieve the VS_FIXEDFILEINFO.
func VerQueryValueRoot(block []byte) (VS_FIXEDFILEINFO, error) {
var offset uintptr
var length uint
blockStart := unsafe.Pointer(&block[0])
ret, _, _ := verQueryValue.Call(
uintptr(blockStart),
uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr(`\`))),
uintptr(unsafe.Pointer(&offset)),
uintptr(unsafe.Pointer(&length)),
)
if ret == 0 {
return VS_FIXEDFILEINFO{}, errors.New("VerQueryValueRoot: verQueryValue failed")
}
start := int(offset) - int(uintptr(blockStart))
end := start + int(length)
if start < 0 || start >= len(block) || end < start || end > len(block) {
return VS_FIXEDFILEINFO{}, errors.New("VerQueryValueRoot: find failed")
}
data := block[start:end]
info := *((*VS_FIXEDFILEINFO)(unsafe.Pointer(&data[0])))
return info, nil
}
func GetFileVersion(path string) (WinVersion, error) {
var result WinVersion
size := GetFileVersionInfoSize(path)
if size <= 0 {
return result, errors.New("GetFileVersionInfoSize failed")
}
info := make([]byte, size)
ok := GetFileVersionInfo(path, info)
if !ok {
return result, errors.New("GetFileVersionInfo failed")
}
fixed, err := VerQueryValueRoot(info)
if err != nil {
return result, err
}
version := fixed.FileVersion()
result.Major = uint32(version & 0xFFFF000000000000 >> 48)
result.Minor = uint32(version & 0x0000FFFF00000000 >> 32)
result.Patch = uint32(version & 0x00000000FFFF0000 >> 16)
result.Build = uint32(version & 0x000000000000FFFF)
return result, nil
}