kernel loader: don't try to reload modules already loaded by init

Change-Id: I58fb2ee369bdb9f5c5f9e6c8ea3f15781ae60ece
diff --git a/kernel_module_loader.cpp b/kernel_module_loader.cpp
index 850807e..7baf1f2 100644
--- a/kernel_module_loader.cpp
+++ b/kernel_module_loader.cpp
@@ -2,78 +2,108 @@
 
 const std::vector<std::string> kernel_modules_requested = TWFunc::split_string(EXPAND(TW_LOAD_VENDOR_MODULES), ' ', true);
 
-bool KernelModuleLoader::Load_Vendor_Modules(BOOT_MODE mode) {
-    // check /lib/modules (ramdisk vendor_boot)
-    // check /lib/modules/N.N (ramdisk vendor_boot)
-    // check /lib/modules/N.N-gki (ramdisk vendor_boot)
-    // check /vendor/lib/modules (ramdisk)
-    // check /vendor/lib/modules/1.1 (ramdisk prebuilt modules)
-    // check /vendor/lib/modules/N.N (vendor mounted)
-    // check /vendor/lib/modules/N.N-gki (vendor mounted)
-    int modules_loaded = 0;
+bool KernelModuleLoader::Load_Vendor_Modules() {
+	// check /lib/modules (ramdisk vendor_boot)
+	// check /lib/modules/N.N (ramdisk vendor_boot)
+	// check /lib/modules/N.N-gki (ramdisk vendor_boot)
+	// check /vendor/lib/modules (ramdisk)
+	// check /vendor/lib/modules/1.1 (ramdisk prebuilt modules)
+	// check /vendor/lib/modules/N.N (vendor mounted)
+	// check /vendor/lib/modules/N.N-gki (vendor mounted)
+	int modules_loaded = 0;
 
-    LOGINFO("Attempting to load modules\n");
-    std::string vendor_base_dir(VENDOR_MODULE_DIR);
-    std::string base_dir(VENDOR_BOOT_MODULE_DIR);
-    std::vector<std::string> module_dirs;
-    std::vector<std::string> vendor_module_dirs;
+	LOGINFO("Attempting to load modules\n");
+	std::string vendor_base_dir(VENDOR_MODULE_DIR);
+	std::string base_dir(VENDOR_BOOT_MODULE_DIR);
+	std::vector<std::string> module_dirs;
+	std::vector<std::string> vendor_module_dirs;
 
-    TWPartition* ven = PartitionManager.Find_Partition_By_Path("/vendor");
-    vendor_module_dirs.push_back(VENDOR_MODULE_DIR);
-    vendor_module_dirs.push_back(vendor_base_dir + "/1.1");
+	TWPartition* ven = PartitionManager.Find_Partition_By_Path("/vendor");
+	vendor_module_dirs.push_back(VENDOR_MODULE_DIR);
+	vendor_module_dirs.push_back(vendor_base_dir + "/1.1");
 
-    module_dirs.push_back(base_dir);
+	module_dirs.push_back(base_dir);
 
-    struct utsname uts;
-    if (uname(&uts)) {
-        LOGERR("Unable to query kernel for version info\n");
-    }
+	struct utsname uts;
+	if (uname(&uts)) {
+		LOGERR("Unable to query kernel for version info\n");
+	}
 
-    std::string rls(uts.release);
-    std::vector<std::string> release = TWFunc::split_string(rls, '.', true);
-    int expected_module_count = kernel_modules_requested.size();
-    module_dirs.push_back(base_dir + "/" + release[0] + "." + release[1]);
+	std::string rls(uts.release);
+	std::vector<std::string> release = TWFunc::split_string(rls, '.', true);
+	int expected_module_count = kernel_modules_requested.size();
+	module_dirs.push_back(base_dir + "/" + release[0] + "." + release[1]);
 
-    for (auto&& module_dir:module_dirs) {
-        modules_loaded += Try_And_Load_Modules(module_dir);
-        if (modules_loaded >= expected_module_count) goto exit;
-    }
+	for (auto&& module_dir:module_dirs) {
+		modules_loaded += Try_And_Load_Modules(module_dir, false);
+		if (modules_loaded >= expected_module_count) goto exit;
+	}
 
-    for (auto&& module_dir:vendor_module_dirs) {
-        modules_loaded += Try_And_Load_Modules(module_dir);
-        if (modules_loaded >= expected_module_count) goto exit;
-    }
+	for (auto&& module_dir:vendor_module_dirs) {
+		modules_loaded += Try_And_Load_Modules(module_dir, false);
+		if (modules_loaded >= expected_module_count) goto exit;
+	}
 
-    if (ven) {
-        LOGINFO("Checking mounted /vendor\n");
-        ven->Mount(true);
-    }
+	if (ven) {
+		LOGINFO("Checking mounted /vendor\n");
+		ven->Mount(true);
+	}
 
-    for (auto&& module_dir:vendor_module_dirs) {
-        modules_loaded += Try_And_Load_Modules(module_dir);
-        if (modules_loaded >= expected_module_count) goto exit;
-    }
+	for (auto&& module_dir:vendor_module_dirs) {
+		modules_loaded += Try_And_Load_Modules(module_dir, true);
+		if (modules_loaded >= expected_module_count) goto exit;
+	}
 
 exit:
-    if (ven)
-        ven->UnMount(false);
+	if (ven)
+		ven->UnMount(false);
 
 	return true;
 }
 
-int KernelModuleLoader::Try_And_Load_Modules(std::string module_dir) {
-        LOGINFO("Checking directory: %s\n", module_dir.c_str());
-        std::string dest_module_dir;
-        dest_module_dir = "/tmp" + module_dir;
-        TWFunc::Recursive_Mkdir(dest_module_dir);
-        Copy_Modules_To_Tmpfs(module_dir);
-        Write_Module_List(dest_module_dir);
-        Modprobe m({dest_module_dir}, "modules.load.twrp");
-        m.EnableVerbose(true);
-        m.LoadListedModules(false);
-        int modules_loaded = m.GetModuleCount();
-        LOGINFO("Modules Loaded: %d\n", modules_loaded);
-        return modules_loaded;
+int KernelModuleLoader::Try_And_Load_Modules(std::string module_dir, bool vendor_is_mounted) {
+		LOGINFO("Checking directory: %s\n", module_dir.c_str());
+		int modules_loaded = 0;
+		std::string dest_module_dir;
+		dest_module_dir = "/tmp" + module_dir;
+		TWFunc::Recursive_Mkdir(dest_module_dir);
+		Copy_Modules_To_Tmpfs(module_dir);
+		if (!Write_Module_List(dest_module_dir))
+			return kernel_modules_requested.size();
+		if (!vendor_is_mounted && module_dir == "/vendor/lib/modules") {
+			module_dir = "/lib/modules";
+		}
+		LOGINFO("mounting %s on %s\n", dest_module_dir.c_str(), module_dir.c_str());
+		if (mount(dest_module_dir.c_str(), module_dir.c_str(), "", MS_BIND, NULL) == 0) {
+			Modprobe m({module_dir}, "modules.load.twrp");
+			m.EnableVerbose(true);
+			m.LoadListedModules(false);
+			modules_loaded = m.GetModuleCount();
+			umount2(module_dir.c_str(), MNT_DETACH);
+			LOGINFO("Modules Loaded: %d\n", modules_loaded);
+		}
+		return modules_loaded;
+}
+
+std::vector<string> KernelModuleLoader::Skip_Loaded_Kernel_Modules() {
+	std::vector<string> kernel_modules = kernel_modules_requested;
+	std::vector<string> loaded_modules;
+	std::string kernel_module_file = "/proc/modules";
+	if (TWFunc::read_file(kernel_module_file, loaded_modules) < 0)
+		LOGINFO("failed to get loaded kernel modules\n");
+	LOGINFO("number of modules loaded by init: %lu\n", loaded_modules.size());
+	if (loaded_modules.size() == 0)
+		return kernel_modules;
+	for (auto&& module_line:loaded_modules) {
+		auto module = TWFunc::Split_String(module_line, " ")[0];
+		std::string full_module_name = module + ".ko";
+		auto found = std::find(kernel_modules.begin(), kernel_modules.end(), full_module_name);
+		if (found != kernel_modules.end()) {
+			LOGINFO("found module to dedupe: %s\n", (*found).c_str());
+			kernel_modules.erase(found);
+		}
+	}
+	return kernel_modules;
 }
 
 bool KernelModuleLoader::Write_Module_List(std::string module_dir) {
@@ -81,6 +111,11 @@
 	struct dirent* de;
 	std::vector<std::string> kernel_modules;
 	d = opendir(module_dir.c_str());
+	auto deduped_modules = Skip_Loaded_Kernel_Modules();
+	if (deduped_modules.size() == 0) {
+		LOGINFO("Requested modules are loaded\n");
+		return false;
+	}
 	if (d != nullptr) {
 		while ((de = readdir(d)) != nullptr) {
 			std::string kernel_module = de->d_name;
@@ -89,40 +124,40 @@
 					for (auto&& requested:kernel_modules_requested) {
 						if (kernel_module == requested) {
 							kernel_modules.push_back(kernel_module);
-                            continue;
-                        }
+							continue;
+						}
 					}
 					continue;
 				}
 			}
 		}
-        std::string module_file = module_dir + "/modules.load.twrp";
+		std::string module_file = module_dir + "/modules.load.twrp";
 		TWFunc::write_to_file(module_file, kernel_modules);
-        closedir(d);
+		closedir(d);
 	}
 	return true;
 }
 
 bool KernelModuleLoader::Copy_Modules_To_Tmpfs(std::string module_dir) {
-    std::string ramdisk_dir = "/tmp" + module_dir;
-    DIR* d;
+	std::string ramdisk_dir = "/tmp" + module_dir;
+	DIR* d;
 	struct dirent* de;
-    d = opendir(module_dir.c_str());
-    if (d != nullptr) {
-        while ((de = readdir(d)) != nullptr) {
-            std::string kernel_module = de->d_name;
-            if (de->d_type == DT_REG) {
-                std::string src =  module_dir + "/" + de->d_name;
-                std::string dest = ramdisk_dir + "/" + de->d_name;
-                if (TWFunc::copy_file(src, dest, 0700, false) != 0) {
-                    return false;
-                }
-            }
-        }
-        closedir(d);
-    } else {
-        LOGINFO("Unable to open module directory: %s. Skipping\n", module_dir.c_str());
-        return false;
-    }
-    return true;
-}
\ No newline at end of file
+	d = opendir(module_dir.c_str());
+	if (d != nullptr) {
+		while ((de = readdir(d)) != nullptr) {
+			std::string kernel_module = de->d_name;
+			if (de->d_type == DT_REG) {
+				std::string src =  module_dir + "/" + de->d_name;
+				std::string dest = ramdisk_dir + "/" + de->d_name;
+				if (TWFunc::copy_file(src, dest, 0700, false) != 0) {
+					return false;
+				}
+			}
+		}
+		closedir(d);
+	} else {
+		LOGINFO("Unable to open module directory: %s. Skipping\n", module_dir.c_str());
+		return false;
+	}
+	return true;
+}
diff --git a/kernel_module_loader.hpp b/kernel_module_loader.hpp
index 0115ce4..de28bd2 100644
--- a/kernel_module_loader.hpp
+++ b/kernel_module_loader.hpp
@@ -6,6 +6,7 @@
 #include <vector>
 #include <android-base/strings.h>
 #include <modprobe/modprobe.h>
+#include <sys/mount.h>
 #include <sys/utsname.h>
 
 #include "twcommon.h"
@@ -22,12 +23,13 @@
 class KernelModuleLoader
 {
 public:
-    static bool Load_Vendor_Modules(BOOT_MODE mode); // Load specific maintainer defined kernel modules in TWRP
+    static bool Load_Vendor_Modules(); // Load specific maintainer defined kernel modules in TWRP
 
 private:
-	static int Try_And_Load_Modules(std::string module_dir); // Use libmodprobe to attempt loading kernel modules
+	static int Try_And_Load_Modules(std::string module_dir, bool vendor_is_mounted); // Use libmodprobe to attempt loading kernel modules
 	static bool Write_Module_List(std::string module_dir); // Write list of modules to load from TW_LOAD_VENDOR_MODULES
     static bool Copy_Modules_To_Tmpfs(std::string module_dir); // Copy modules to ramdisk for loading
+	static std::vector<string> Skip_Loaded_Kernel_Modules(); // return list of loaded kernel modules already done by init
 };
 
 #endif // _KERNELMODULELOADER_HPP
\ No newline at end of file
diff --git a/twrp-functions.cpp b/twrp-functions.cpp
index 0f7ebb3..782afe1 100755
--- a/twrp-functions.cpp
+++ b/twrp-functions.cpp
@@ -709,10 +709,7 @@
 	std::ifstream srcfile(src.c_str(), ios::binary);
 	std::ofstream dstfile(dst.c_str(), ios::binary);
 	dstfile << srcfile.rdbuf();
-	if (!dstfile.bad()) {
-		LOGINFO("Copied file %s to %s\n", src.c_str(), dst.c_str());
-	}
-	else {
+	if (dstfile.bad()) {
 		LOGINFO("Unable to copy file %s to %s\n", src.c_str(), dst.c_str());
 		return -1;
 	}
diff --git a/twrp.cpp b/twrp.cpp
index 1b0e337..96446b2 100644
--- a/twrp.cpp
+++ b/twrp.cpp
@@ -119,8 +119,12 @@
 		PartitionManager.Setup_Super_Devices();
 		PartitionManager.Prepare_Super_Volume(ven);
 #ifdef TW_LOAD_VENDOR_MODULES
-		KernelModuleLoader::Load_Vendor_Modules(FASTBOOTD_MODE);
+		KernelModuleLoader::Load_Vendor_Modules();
+		if (android::base::GetBoolProperty("ro.virtual_ab.enabled", false)) {
+			PartitionManager.Unmap_Super_Devices();
+		}
 #endif
+
 		gui_msg(Msg("fastboot_console_msg=Entered Fastboot mode..."));
 		if (gui_startPage("fastboot", 1, 1) != 0) {
 			LOGERR("Failed to start fastbootd page.\n");
@@ -162,9 +166,9 @@
 #ifdef TW_LOAD_VENDOR_MODULES
 	bool fastboot_mode = cmdline.find("twrpfastboot=1") != std::string::npos;
 	if (fastboot_mode)
-		KernelModuleLoader::Load_Vendor_Modules(RECOVERY_FASTBOOT_MODE);
+		KernelModuleLoader::Load_Vendor_Modules();
 	else
-		KernelModuleLoader::Load_Vendor_Modules(RECOVERY_IN_BOOT_MODE);
+		KernelModuleLoader::Load_Vendor_Modules();
 #endif
 
 // We are doing this here to allow super partition to be set up prior to overriding properties