Fix load driver issue

The Buffer pointer of UNICODE_STRING seems to be cleaned up after
routine, so we cannot store the string, but have to init the string when
needed.
This commit is contained in:
nganhkhoa 2020-02-25 01:20:54 +07:00
parent 8928e4e4cb
commit 2ee77d16c7

View File

@ -1,4 +1,5 @@
use std::ffi::CString; use std::ffi::CString;
use std::mem::transmute;
use std::ptr::null_mut; use std::ptr::null_mut;
use widestring::U16CString; use widestring::U16CString;
@ -20,6 +21,8 @@ use winapi::um::securitybaseapi::{AdjustTokenPrivileges};
use winapi::um::winbase::{LookupPrivilegeValueA}; use winapi::um::winbase::{LookupPrivilegeValueA};
use winapi::um::winreg::{RegCreateKeyExA, RegSetValueExA, RegCloseKey, HKEY_LOCAL_MACHINE}; use winapi::um::winreg::{RegCreateKeyExA, RegSetValueExA, RegCloseKey, HKEY_LOCAL_MACHINE};
const STR_DRIVER_REGISTRY_PATH: &str = "\\Registry\\Machine\\System\\CurrentControlSet\\Services\\nganhkhoa";
#[allow(dead_code)] #[allow(dead_code)]
#[derive(Debug)] #[derive(Debug)]
pub enum WindowsVersion { pub enum WindowsVersion {
@ -37,12 +40,11 @@ pub enum WindowsVersion {
pub struct WindowsFFI { pub struct WindowsFFI {
pub version_info: OSVERSIONINFOW, pub version_info: OSVERSIONINFOW,
pub short_version: WindowsVersion, pub short_version: WindowsVersion,
// driver_registry_string: UNICODE_STRING,
driver_handle: HANDLE, driver_handle: HANDLE,
ntdll: HMODULE, ntdll: HMODULE,
nt_load_driver: extern "stdcall" fn(PUNICODE_STRING) -> NTSTATUS, nt_load_driver: extern "system" fn(PUNICODE_STRING) -> NTSTATUS,
nt_unload_driver: extern "stdcall" fn(PUNICODE_STRING) -> NTSTATUS, nt_unload_driver: extern "system" fn(PUNICODE_STRING) -> NTSTATUS,
rtl_init_unicode_str: extern "stdcall" fn(PUNICODE_STRING, PCWSTR), rtl_init_unicode_str: extern "system" fn(PUNICODE_STRING, PCWSTR),
rtl_get_version: extern "system" fn(PRTL_OSVERSIONINFOW) -> NTSTATUS, rtl_get_version: extern "system" fn(PRTL_OSVERSIONINFOW) -> NTSTATUS,
} }
@ -62,8 +64,6 @@ impl WindowsFFI {
let str_start = CString::new("Start").unwrap(); let str_start = CString::new("Start").unwrap();
let str_image_path = CString::new("ImagePath").unwrap(); let str_image_path = CString::new("ImagePath").unwrap();
// let mut str_driver_reg_unicode = UNICODE_STRING::default();
let str_driver_reg_unicode: UNICODE_STRING;
let mut version_info = OSVERSIONINFOW { let mut version_info = OSVERSIONINFOW {
dwOSVersionInfoSize: 0u32, dwOSVersionInfoSize: 0u32,
dwMajorVersion: 0u32, dwMajorVersion: 0u32,
@ -74,9 +74,9 @@ impl WindowsFFI {
}; };
let ntdll: HMODULE; let ntdll: HMODULE;
let nt_load_driver: extern "stdcall" fn(PUNICODE_STRING) -> NTSTATUS; let nt_load_driver: extern "system" fn(PUNICODE_STRING) -> NTSTATUS;
let nt_unload_driver: extern "stdcall" fn(PUNICODE_STRING) -> NTSTATUS; let nt_unload_driver: extern "system" fn(PUNICODE_STRING) -> NTSTATUS;
let rtl_init_unicode_str: extern "stdcall" fn(PUNICODE_STRING, PCWSTR); let rtl_init_unicode_str: extern "system" fn(PUNICODE_STRING, PCWSTR);
let rtl_get_version: extern "system" fn(PRTL_OSVERSIONINFOW) -> NTSTATUS; let rtl_get_version: extern "system" fn(PRTL_OSVERSIONINFOW) -> NTSTATUS;
// some pointer unsafe C code // some pointer unsafe C code
@ -87,10 +87,10 @@ impl WindowsFFI {
let rtl_init_unicode_str_ = GetProcAddress(ntdll, str_rtl_init_unicode_str.as_ptr()); let rtl_init_unicode_str_ = GetProcAddress(ntdll, str_rtl_init_unicode_str.as_ptr());
let rtl_get_version_ = GetProcAddress(ntdll, str_rtl_get_version.as_ptr()); let rtl_get_version_ = GetProcAddress(ntdll, str_rtl_get_version.as_ptr());
nt_load_driver = std::mem::transmute(nt_load_driver_); nt_load_driver = transmute(nt_load_driver_);
nt_unload_driver = std::mem::transmute(nt_unload_driver_); nt_unload_driver = transmute(nt_unload_driver_);
rtl_init_unicode_str = std::mem::transmute(rtl_init_unicode_str_); rtl_init_unicode_str = transmute(rtl_init_unicode_str_);
rtl_get_version = std::mem::transmute(rtl_get_version_); rtl_get_version = transmute(rtl_get_version_);
// setup registry // setup registry
let mut registry_key: HKEY = null_mut(); let mut registry_key: HKEY = null_mut();
@ -133,15 +133,6 @@ impl WindowsFFI {
AdjustTokenPrivileges( AdjustTokenPrivileges(
token_handle, 0, &mut new_token_state, 16, null_mut(), null_mut()); token_handle, 0, &mut new_token_state, 16, null_mut(), null_mut());
CloseHandle(token_handle); CloseHandle(token_handle);
// init string for loading and unloading driver routine
// rtl_init_unicode_str(&mut str_driver_reg_unicode, str_driver_reg.as_ptr());
//
// let unicode_str =
// U16CString::from_ptr_unchecked(
// str_driver_reg_unicode.Buffer, (str_driver_reg_unicode.Length / 2) as usize);
//
// println!("unicode string created: {:p} {}", str_driver_reg_unicode.Buffer, unicode_str.to_string_lossy());
} }
rtl_get_version(&mut version_info); rtl_get_version(&mut version_info);
@ -157,7 +148,6 @@ impl WindowsFFI {
Self { Self {
version_info, version_info,
short_version, short_version,
// driver_registry_string: str_driver_reg_unicode,
driver_handle: null_mut(), driver_handle: null_mut(),
ntdll, ntdll,
nt_load_driver, nt_load_driver,
@ -168,68 +158,32 @@ impl WindowsFFI {
} }
pub fn load_driver(&mut self) -> NTSTATUS { pub fn load_driver(&mut self) -> NTSTATUS {
let mut str_driver_reg_unicode: UNICODE_STRING = UNICODE_STRING::default(); let str_driver_reg = U16CString::from_str(STR_DRIVER_REGISTRY_PATH).unwrap();
unsafe { let mut str_driver_reg_unicode = UNICODE_STRING::default();
let str_driver_reg = (self.rtl_init_unicode_str)(&mut str_driver_reg_unicode, str_driver_reg.as_ptr() as *const u16);
U16CString::from_str(
"\\Registry\\Machine\\System\\CurrentControlSet\\Services\\nganhkhoa")
.expect("");
(self.rtl_init_unicode_str)(&mut str_driver_reg_unicode, str_driver_reg.as_ptr());
// str_driver_reg_unicode = UNICODE_STRING {
// Length: (str_driver_reg.len() * 2) as u16,
// MaximumLength: (str_driver_reg.len() * 2) as u16,
// Buffer: str_driver_reg.as_ptr() as *mut u16
// };
let unicode_str =
U16CString::from_ptr_unchecked(
str_driver_reg_unicode.Buffer, (str_driver_reg_unicode.Length / 2) as usize);
println!("unicode string called: {:p} {}", str_driver_reg_unicode.Buffer, unicode_str.to_string_lossy());
println!("unicode string called: {:?}", unicode_str.into_vec_with_nul());
}
let status = (self.nt_load_driver)(&mut str_driver_reg_unicode); let status = (self.nt_load_driver)(&mut str_driver_reg_unicode);
// Create a device handle to loaded driver
let driver_system_path = CString::new("\\Device\\poolscanner").unwrap(); let filename = CString::new("\\Device\\poolscanner").unwrap();
let driver_handle; let driver_file_handle: HANDLE = unsafe {
unsafe { CreateFileA(filename.as_ptr(),
driver_handle = CreateFileA(driver_system_path.as_ptr(),
GENERIC_READ | GENERIC_WRITE, GENERIC_READ | GENERIC_WRITE,
0, 0, null_mut(), CREATE_ALWAYS,
null_mut(), FILE_ATTRIBUTE_NORMAL, null_mut())
CREATE_ALWAYS, };
FILE_ATTRIBUTE_NORMAL,
null_mut()); if driver_file_handle == INVALID_HANDLE_VALUE {
} println!("Driver CreateFileA failed");
// TODO: check driver_handle return status
self.driver_handle = driver_handle;
if driver_handle == INVALID_HANDLE_VALUE {
println!("Driver create failed");
status
} }
else { else {
status self.driver_handle = driver_file_handle;
} }
status
} }
pub fn unload_driver(&mut self) -> NTSTATUS { pub fn unload_driver(&mut self) -> NTSTATUS {
let mut str_driver_reg_unicode: UNICODE_STRING; let str_driver_reg = U16CString::from_str(STR_DRIVER_REGISTRY_PATH).unwrap();
unsafe { let mut str_driver_reg_unicode = UNICODE_STRING::default();
let str_driver_reg = (self.rtl_init_unicode_str)(&mut str_driver_reg_unicode, str_driver_reg.as_ptr());
U16CString::from_str(
"\\Registry\\Machine\\System\\CurrentControlSet\\Services\\nganhkhoa")
.expect("");
str_driver_reg_unicode = UNICODE_STRING {
Length: (str_driver_reg.len() * 2) as u16,
MaximumLength: (str_driver_reg.len() * 2) as u16,
Buffer: str_driver_reg.as_ptr() as *mut u16
};
let unicode_str =
U16CString::from_ptr_unchecked(
str_driver_reg_unicode.Buffer, (str_driver_reg_unicode.Length / 2) as usize);
println!("unicode string called: {:p} {}", str_driver_reg_unicode.Buffer, unicode_str.to_string_lossy());
println!("unicode string called: {:?}", unicode_str.into_vec_with_nul());
}
(self.nt_unload_driver)(&mut str_driver_reg_unicode) (self.nt_unload_driver)(&mut str_driver_reg_unicode)
} }