diff --git a/src/windows.rs b/src/windows.rs index fc876e8..a009ecd 100644 --- a/src/windows.rs +++ b/src/windows.rs @@ -1,4 +1,5 @@ use std::ffi::CString; +use std::mem::transmute; use std::ptr::null_mut; use widestring::U16CString; @@ -20,6 +21,8 @@ use winapi::um::securitybaseapi::{AdjustTokenPrivileges}; use winapi::um::winbase::{LookupPrivilegeValueA}; use winapi::um::winreg::{RegCreateKeyExA, RegSetValueExA, RegCloseKey, HKEY_LOCAL_MACHINE}; +const STR_DRIVER_REGISTRY_PATH: &str = "\\Registry\\Machine\\System\\CurrentControlSet\\Services\\nganhkhoa"; + #[allow(dead_code)] #[derive(Debug)] pub enum WindowsVersion { @@ -37,12 +40,11 @@ pub enum WindowsVersion { pub struct WindowsFFI { pub version_info: OSVERSIONINFOW, pub short_version: WindowsVersion, - // driver_registry_string: UNICODE_STRING, driver_handle: HANDLE, ntdll: HMODULE, - nt_load_driver: extern "stdcall" fn(PUNICODE_STRING) -> NTSTATUS, - nt_unload_driver: extern "stdcall" fn(PUNICODE_STRING) -> NTSTATUS, - rtl_init_unicode_str: extern "stdcall" fn(PUNICODE_STRING, PCWSTR), + nt_load_driver: extern "system" fn(PUNICODE_STRING) -> NTSTATUS, + nt_unload_driver: extern "system" fn(PUNICODE_STRING) -> NTSTATUS, + rtl_init_unicode_str: extern "system" fn(PUNICODE_STRING, PCWSTR), rtl_get_version: extern "system" fn(PRTL_OSVERSIONINFOW) -> NTSTATUS, } @@ -62,8 +64,6 @@ impl WindowsFFI { let str_start = CString::new("Start").unwrap(); let str_image_path = CString::new("ImagePath").unwrap(); - // let mut str_driver_reg_unicode = UNICODE_STRING::default(); - let str_driver_reg_unicode: UNICODE_STRING; let mut version_info = OSVERSIONINFOW { dwOSVersionInfoSize: 0u32, dwMajorVersion: 0u32, @@ -74,9 +74,9 @@ impl WindowsFFI { }; let ntdll: HMODULE; - let nt_load_driver: extern "stdcall" fn(PUNICODE_STRING) -> NTSTATUS; - let nt_unload_driver: extern "stdcall" fn(PUNICODE_STRING) -> NTSTATUS; - let rtl_init_unicode_str: extern "stdcall" fn(PUNICODE_STRING, PCWSTR); + let nt_load_driver: extern "system" fn(PUNICODE_STRING) -> NTSTATUS; + let nt_unload_driver: extern "system" fn(PUNICODE_STRING) -> NTSTATUS; + let rtl_init_unicode_str: extern "system" fn(PUNICODE_STRING, PCWSTR); let rtl_get_version: extern "system" fn(PRTL_OSVERSIONINFOW) -> NTSTATUS; // some pointer unsafe C code @@ -87,10 +87,10 @@ impl WindowsFFI { let rtl_init_unicode_str_ = GetProcAddress(ntdll, str_rtl_init_unicode_str.as_ptr()); let rtl_get_version_ = GetProcAddress(ntdll, str_rtl_get_version.as_ptr()); - nt_load_driver = std::mem::transmute(nt_load_driver_); - nt_unload_driver = std::mem::transmute(nt_unload_driver_); - rtl_init_unicode_str = std::mem::transmute(rtl_init_unicode_str_); - rtl_get_version = std::mem::transmute(rtl_get_version_); + nt_load_driver = transmute(nt_load_driver_); + nt_unload_driver = transmute(nt_unload_driver_); + rtl_init_unicode_str = transmute(rtl_init_unicode_str_); + rtl_get_version = transmute(rtl_get_version_); // setup registry let mut registry_key: HKEY = null_mut(); @@ -133,15 +133,6 @@ impl WindowsFFI { AdjustTokenPrivileges( token_handle, 0, &mut new_token_state, 16, null_mut(), null_mut()); CloseHandle(token_handle); - - // init string for loading and unloading driver routine - // rtl_init_unicode_str(&mut str_driver_reg_unicode, str_driver_reg.as_ptr()); - // - // let unicode_str = - // U16CString::from_ptr_unchecked( - // str_driver_reg_unicode.Buffer, (str_driver_reg_unicode.Length / 2) as usize); - // - // println!("unicode string created: {:p} {}", str_driver_reg_unicode.Buffer, unicode_str.to_string_lossy()); } rtl_get_version(&mut version_info); @@ -157,7 +148,6 @@ impl WindowsFFI { Self { version_info, short_version, - // driver_registry_string: str_driver_reg_unicode, driver_handle: null_mut(), ntdll, nt_load_driver, @@ -168,68 +158,32 @@ impl WindowsFFI { } pub fn load_driver(&mut self) -> NTSTATUS { - let mut str_driver_reg_unicode: UNICODE_STRING = UNICODE_STRING::default(); - unsafe { - let str_driver_reg = - U16CString::from_str( - "\\Registry\\Machine\\System\\CurrentControlSet\\Services\\nganhkhoa") - .expect(""); - (self.rtl_init_unicode_str)(&mut str_driver_reg_unicode, str_driver_reg.as_ptr()); - // str_driver_reg_unicode = UNICODE_STRING { - // Length: (str_driver_reg.len() * 2) as u16, - // MaximumLength: (str_driver_reg.len() * 2) as u16, - // Buffer: str_driver_reg.as_ptr() as *mut u16 - // }; - let unicode_str = - U16CString::from_ptr_unchecked( - str_driver_reg_unicode.Buffer, (str_driver_reg_unicode.Length / 2) as usize); - - println!("unicode string called: {:p} {}", str_driver_reg_unicode.Buffer, unicode_str.to_string_lossy()); - println!("unicode string called: {:?}", unicode_str.into_vec_with_nul()); - } + let str_driver_reg = U16CString::from_str(STR_DRIVER_REGISTRY_PATH).unwrap(); + let mut str_driver_reg_unicode = UNICODE_STRING::default(); + (self.rtl_init_unicode_str)(&mut str_driver_reg_unicode, str_driver_reg.as_ptr() as *const u16); let status = (self.nt_load_driver)(&mut str_driver_reg_unicode); - // Create a device handle to loaded driver - let driver_system_path = CString::new("\\Device\\poolscanner").unwrap(); - let driver_handle; - unsafe { - driver_handle = CreateFileA(driver_system_path.as_ptr(), - GENERIC_READ | GENERIC_WRITE, - 0, - null_mut(), - CREATE_ALWAYS, - FILE_ATTRIBUTE_NORMAL, - null_mut()); - } - // TODO: check driver_handle return status - self.driver_handle = driver_handle; - if driver_handle == INVALID_HANDLE_VALUE { - println!("Driver create failed"); - status + + let filename = CString::new("\\Device\\poolscanner").unwrap(); + let driver_file_handle: HANDLE = unsafe { + CreateFileA(filename.as_ptr(), + GENERIC_READ | GENERIC_WRITE, + 0, null_mut(), CREATE_ALWAYS, + FILE_ATTRIBUTE_NORMAL, null_mut()) + }; + + if driver_file_handle == INVALID_HANDLE_VALUE { + println!("Driver CreateFileA failed"); } else { - status + self.driver_handle = driver_file_handle; } + status } pub fn unload_driver(&mut self) -> NTSTATUS { - let mut str_driver_reg_unicode: UNICODE_STRING; - unsafe { - let str_driver_reg = - U16CString::from_str( - "\\Registry\\Machine\\System\\CurrentControlSet\\Services\\nganhkhoa") - .expect(""); - str_driver_reg_unicode = UNICODE_STRING { - Length: (str_driver_reg.len() * 2) as u16, - MaximumLength: (str_driver_reg.len() * 2) as u16, - Buffer: str_driver_reg.as_ptr() as *mut u16 - }; - let unicode_str = - U16CString::from_ptr_unchecked( - str_driver_reg_unicode.Buffer, (str_driver_reg_unicode.Length / 2) as usize); - - println!("unicode string called: {:p} {}", str_driver_reg_unicode.Buffer, unicode_str.to_string_lossy()); - println!("unicode string called: {:?}", unicode_str.into_vec_with_nul()); - } + let str_driver_reg = U16CString::from_str(STR_DRIVER_REGISTRY_PATH).unwrap(); + let mut str_driver_reg_unicode = UNICODE_STRING::default(); + (self.rtl_init_unicode_str)(&mut str_driver_reg_unicode, str_driver_reg.as_ptr()); (self.nt_unload_driver)(&mut str_driver_reg_unicode) }