Fix load driver issue

The Buffer pointer of UNICODE_STRING seems to be cleaned up after
routine, so we cannot store the string, but have to init the string when
needed.
This commit is contained in:
nganhkhoa 2020-02-25 01:20:54 +07:00
parent 8928e4e4cb
commit 2ee77d16c7

View File

@ -1,4 +1,5 @@
use std::ffi::CString;
use std::mem::transmute;
use std::ptr::null_mut;
use widestring::U16CString;
@ -20,6 +21,8 @@ use winapi::um::securitybaseapi::{AdjustTokenPrivileges};
use winapi::um::winbase::{LookupPrivilegeValueA};
use winapi::um::winreg::{RegCreateKeyExA, RegSetValueExA, RegCloseKey, HKEY_LOCAL_MACHINE};
const STR_DRIVER_REGISTRY_PATH: &str = "\\Registry\\Machine\\System\\CurrentControlSet\\Services\\nganhkhoa";
#[allow(dead_code)]
#[derive(Debug)]
pub enum WindowsVersion {
@ -37,12 +40,11 @@ pub enum WindowsVersion {
pub struct WindowsFFI {
pub version_info: OSVERSIONINFOW,
pub short_version: WindowsVersion,
// driver_registry_string: UNICODE_STRING,
driver_handle: HANDLE,
ntdll: HMODULE,
nt_load_driver: extern "stdcall" fn(PUNICODE_STRING) -> NTSTATUS,
nt_unload_driver: extern "stdcall" fn(PUNICODE_STRING) -> NTSTATUS,
rtl_init_unicode_str: extern "stdcall" fn(PUNICODE_STRING, PCWSTR),
nt_load_driver: extern "system" fn(PUNICODE_STRING) -> NTSTATUS,
nt_unload_driver: extern "system" fn(PUNICODE_STRING) -> NTSTATUS,
rtl_init_unicode_str: extern "system" fn(PUNICODE_STRING, PCWSTR),
rtl_get_version: extern "system" fn(PRTL_OSVERSIONINFOW) -> NTSTATUS,
}
@ -62,8 +64,6 @@ impl WindowsFFI {
let str_start = CString::new("Start").unwrap();
let str_image_path = CString::new("ImagePath").unwrap();
// let mut str_driver_reg_unicode = UNICODE_STRING::default();
let str_driver_reg_unicode: UNICODE_STRING;
let mut version_info = OSVERSIONINFOW {
dwOSVersionInfoSize: 0u32,
dwMajorVersion: 0u32,
@ -74,9 +74,9 @@ impl WindowsFFI {
};
let ntdll: HMODULE;
let nt_load_driver: extern "stdcall" fn(PUNICODE_STRING) -> NTSTATUS;
let nt_unload_driver: extern "stdcall" fn(PUNICODE_STRING) -> NTSTATUS;
let rtl_init_unicode_str: extern "stdcall" fn(PUNICODE_STRING, PCWSTR);
let nt_load_driver: extern "system" fn(PUNICODE_STRING) -> NTSTATUS;
let nt_unload_driver: extern "system" fn(PUNICODE_STRING) -> NTSTATUS;
let rtl_init_unicode_str: extern "system" fn(PUNICODE_STRING, PCWSTR);
let rtl_get_version: extern "system" fn(PRTL_OSVERSIONINFOW) -> NTSTATUS;
// some pointer unsafe C code
@ -87,10 +87,10 @@ impl WindowsFFI {
let rtl_init_unicode_str_ = GetProcAddress(ntdll, str_rtl_init_unicode_str.as_ptr());
let rtl_get_version_ = GetProcAddress(ntdll, str_rtl_get_version.as_ptr());
nt_load_driver = std::mem::transmute(nt_load_driver_);
nt_unload_driver = std::mem::transmute(nt_unload_driver_);
rtl_init_unicode_str = std::mem::transmute(rtl_init_unicode_str_);
rtl_get_version = std::mem::transmute(rtl_get_version_);
nt_load_driver = transmute(nt_load_driver_);
nt_unload_driver = transmute(nt_unload_driver_);
rtl_init_unicode_str = transmute(rtl_init_unicode_str_);
rtl_get_version = transmute(rtl_get_version_);
// setup registry
let mut registry_key: HKEY = null_mut();
@ -133,15 +133,6 @@ impl WindowsFFI {
AdjustTokenPrivileges(
token_handle, 0, &mut new_token_state, 16, null_mut(), null_mut());
CloseHandle(token_handle);
// init string for loading and unloading driver routine
// rtl_init_unicode_str(&mut str_driver_reg_unicode, str_driver_reg.as_ptr());
//
// let unicode_str =
// U16CString::from_ptr_unchecked(
// str_driver_reg_unicode.Buffer, (str_driver_reg_unicode.Length / 2) as usize);
//
// println!("unicode string created: {:p} {}", str_driver_reg_unicode.Buffer, unicode_str.to_string_lossy());
}
rtl_get_version(&mut version_info);
@ -157,7 +148,6 @@ impl WindowsFFI {
Self {
version_info,
short_version,
// driver_registry_string: str_driver_reg_unicode,
driver_handle: null_mut(),
ntdll,
nt_load_driver,
@ -168,68 +158,32 @@ impl WindowsFFI {
}
pub fn load_driver(&mut self) -> NTSTATUS {
let mut str_driver_reg_unicode: UNICODE_STRING = UNICODE_STRING::default();
unsafe {
let str_driver_reg =
U16CString::from_str(
"\\Registry\\Machine\\System\\CurrentControlSet\\Services\\nganhkhoa")
.expect("");
(self.rtl_init_unicode_str)(&mut str_driver_reg_unicode, str_driver_reg.as_ptr());
// str_driver_reg_unicode = UNICODE_STRING {
// Length: (str_driver_reg.len() * 2) as u16,
// MaximumLength: (str_driver_reg.len() * 2) as u16,
// Buffer: str_driver_reg.as_ptr() as *mut u16
// };
let unicode_str =
U16CString::from_ptr_unchecked(
str_driver_reg_unicode.Buffer, (str_driver_reg_unicode.Length / 2) as usize);
println!("unicode string called: {:p} {}", str_driver_reg_unicode.Buffer, unicode_str.to_string_lossy());
println!("unicode string called: {:?}", unicode_str.into_vec_with_nul());
}
let str_driver_reg = U16CString::from_str(STR_DRIVER_REGISTRY_PATH).unwrap();
let mut str_driver_reg_unicode = UNICODE_STRING::default();
(self.rtl_init_unicode_str)(&mut str_driver_reg_unicode, str_driver_reg.as_ptr() as *const u16);
let status = (self.nt_load_driver)(&mut str_driver_reg_unicode);
// Create a device handle to loaded driver
let driver_system_path = CString::new("\\Device\\poolscanner").unwrap();
let driver_handle;
unsafe {
driver_handle = CreateFileA(driver_system_path.as_ptr(),
let filename = CString::new("\\Device\\poolscanner").unwrap();
let driver_file_handle: HANDLE = unsafe {
CreateFileA(filename.as_ptr(),
GENERIC_READ | GENERIC_WRITE,
0,
null_mut(),
CREATE_ALWAYS,
FILE_ATTRIBUTE_NORMAL,
null_mut());
}
// TODO: check driver_handle return status
self.driver_handle = driver_handle;
if driver_handle == INVALID_HANDLE_VALUE {
println!("Driver create failed");
status
0, null_mut(), CREATE_ALWAYS,
FILE_ATTRIBUTE_NORMAL, null_mut())
};
if driver_file_handle == INVALID_HANDLE_VALUE {
println!("Driver CreateFileA failed");
}
else {
status
self.driver_handle = driver_file_handle;
}
status
}
pub fn unload_driver(&mut self) -> NTSTATUS {
let mut str_driver_reg_unicode: UNICODE_STRING;
unsafe {
let str_driver_reg =
U16CString::from_str(
"\\Registry\\Machine\\System\\CurrentControlSet\\Services\\nganhkhoa")
.expect("");
str_driver_reg_unicode = UNICODE_STRING {
Length: (str_driver_reg.len() * 2) as u16,
MaximumLength: (str_driver_reg.len() * 2) as u16,
Buffer: str_driver_reg.as_ptr() as *mut u16
};
let unicode_str =
U16CString::from_ptr_unchecked(
str_driver_reg_unicode.Buffer, (str_driver_reg_unicode.Length / 2) as usize);
println!("unicode string called: {:p} {}", str_driver_reg_unicode.Buffer, unicode_str.to_string_lossy());
println!("unicode string called: {:?}", unicode_str.into_vec_with_nul());
}
let str_driver_reg = U16CString::from_str(STR_DRIVER_REGISTRY_PATH).unwrap();
let mut str_driver_reg_unicode = UNICODE_STRING::default();
(self.rtl_init_unicode_str)(&mut str_driver_reg_unicode, str_driver_reg.as_ptr());
(self.nt_unload_driver)(&mut str_driver_reg_unicode)
}