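Notes from learning to write a Java native agent, and building a small wheel along the way: a detector for the log4j vulnerability.
Compared with a Java agent, the biggest advantage of a Java native agent is speed: it is written in C++ and is very fast. The biggest drawback is that it is tedious to write, since driving an object-oriented runtime through a procedural, C-style interface (JNI) is inherently clunky.
The goal this time is to build a loader that attaches an agent dynamically, and then to implement log4j detection inside that agent.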
Initializing the JVM
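First we need either to create a JVM of our own (a standalone process that then attaches to other processes) or to use an existing one (our DLL injected into a running JVM). Either way we get a handle to jvm.dll and resolve one of its exports: JNI_CreateJavaVM to create a new JVM, or JNI_GetCreatedJavaVMs to obtain the existing one.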
// No path given: assume we were injected as a DLL into a running JVM
if (path_ == nullptr) {
const auto jvmHandle = GetModuleHandleA("jvm.dll");
if (jvmHandle == nullptr) {
isVm.store(false);
throw InitializationException("Could not load JVM library");
}
typedef jint(JNICALL* GetCreatedJavaVMs_t)(JavaVM**, jsize, jsize*);
GetCreatedJavaVMs_t JNI_GetCreatedJavaVMs = (GetCreatedJavaVMs_t)GetProcAddress(jvmHandle, "JNI_GetCreatedJavaVMs");
if (JNI_GetCreatedJavaVMs == NULL || JNI_GetCreatedJavaVMs(nullptr, 0, &vmNums) != JNI_OK || vmNums == 0)
{
isVm.store(false);
throw InitializationException("Java Virtual Machine failed during creation");
}
JavaVM** buffer = new JavaVM*[vmNums];
if (JNI_GetCreatedJavaVMs(buffer, vmNums, &vmNums) != 0) {
delete[] buffer;
isVm.store(false);
throw InitializationException("Java Virtual Machine failed during creation #2");
}
// HotSpot supports only one JVM per process, so the first entry is the one we want
javaVm = buffer[0];
delete[] buffer;
}
else {
HMODULE lib = ::LoadLibraryA(path_);
if (lib == NULL)
{
isVm.store(false);
throw InitializationException("Could not load JVM library");
}
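// Signature of JNI_CreateJavaVM; args is a JavaVMInitArgs prepared beforehand
typedef jint(JNICALL* CreateVm_t)(JavaVM**, void**, void*);
CreateVm_t JNI_CreateJavaVM = (CreateVm_t) ::GetProcAddress(lib, "JNI_CreateJavaVM");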
/**
Is your debugger catching an error here? This is normal. Just continue. The JVM
intentionally does this to test how the OS handles memory-reference exceptions.
*/
if (JNI_CreateJavaVM == NULL || JNI_CreateJavaVM(&javaVm, (void**)&env, &args) != 0)
{
isVm.store(false);
::FreeLibrary(lib);
throw InitializationException("Java Virtual Machine failed during creation");
}
}
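When the loader creates its own JVM, path_ has to point at jvm.dll, and the full loader hardcodes that path. The following is a minimal sketch of deriving it from JAVA_HOME instead. The helper names are hypothetical, and the sketch assumes the JDK 9+ layout where jvm.dll lives in bin\server (JDK 8 keeps it under jre\bin\server).

```cpp
#include <cassert>
#include <cstdlib>
#include <optional>
#include <string>

// Hypothetical helper: builds the jvm.dll path from a JDK home directory,
// assuming the JDK 9+ layout <home>\bin\server\jvm.dll.
std::string jvmDllPathFromHome(std::string javaHome) {
    // Drop trailing separators so we don't produce "home\\\bin"
    while (!javaHome.empty() &&
           (javaHome.back() == '\\' || javaHome.back() == '/')) {
        javaHome.pop_back();
    }
    return javaHome + "\\bin\\server\\jvm.dll";
}

// Hypothetical helper: reads JAVA_HOME from the environment, if it is set.
std::optional<std::string> jvmDllPathFromEnv() {
    const char* home = std::getenv("JAVA_HOME");
    if (home == nullptr || *home == '\0') {
        return std::nullopt;
    }
    return jvmDllPathFromHome(home);
}
```

The result can be passed to jni::Vm in place of the literal path, falling back to a fixed path when JAVA_HOME is not set.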
Once we have javaVm, we can get a JNIEnv from it:
if (vm->GetEnv((void**)&_env, JNI_VERSION_1_2) != JNI_OK)
{
#ifdef __ANDROID__
if (vm->AttachCurrentThread(&_env, nullptr) != 0)
#else
if (vm->AttachCurrentThread((void**)&_env, nullptr) != 0)
#endif
throw InitializationException("Could not attach JNI to thread");
_attached = true;
}
Now we have access to the JVM in the current process.
Locating the class and method
Next we locate the VirtualMachine class and use its attach method to attach to the target process and load our native agent.
First, find the class that provides this API:
static jclass findClass(const char* name)
{
jclass ref = env()->FindClass(name);
if (ref == nullptr)
{
env()->ExceptionClear();
throw NameResolutionException(name);
}
return ref;
}
Usage:
jni::Class virtualMachineClass = jni::Class("com/sun/tools/attach/VirtualMachine");
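When native code calls a Java API, it has to supply the method's signature (its JNI type descriptor), which follows a fixed naming scheme. The easiest way to get one is to run javap -s against the class. For example, the signature of VirtualMachine.attach is (Ljava/lang/String;)Lcom/sun/tools/attach/VirtualMachine;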
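Descriptors can also be assembled mechanically. The following is a sketch and is not part of the project. toDescriptor and methodDescriptor are hypothetical helpers that map Java type names to JNI descriptors using the standard codes: I for int, V for void, L...; for classes, and [ for arrays. Nested classes must be given in binary form, e.g. java.util.Map$Entry.

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Hypothetical helper: converts a Java source-level type name into a JNI
// type descriptor, e.g. "int" -> "I", "java.lang.String" -> "Ljava/lang/String;".
std::string toDescriptor(std::string type) {
    std::string prefix;
    // Each trailing "[]" adds one array dimension
    while (type.size() >= 2 && type.compare(type.size() - 2, 2, "[]") == 0) {
        prefix += '[';
        type.erase(type.size() - 2);
    }
    static const std::pair<const char*, const char*> primitives[] = {
        {"void", "V"},  {"boolean", "Z"}, {"byte", "B"},
        {"char", "C"},  {"short", "S"},   {"int", "I"},
        {"long", "J"},  {"float", "F"},   {"double", "D"}};
    for (const auto& [javaName, code] : primitives) {
        if (type == javaName) {
            return prefix + code;
        }
    }
    // Reference type: dots become slashes, wrapped in L...;
    for (char& c : type) {
        if (c == '.') c = '/';
    }
    return prefix + "L" + type + ";";
}

// Builds a method descriptor "(args)ret" from Java type names.
std::string methodDescriptor(const std::string& returnType,
                             const std::vector<std::string>& argTypes) {
    std::string sig = "(";
    for (const auto& arg : argTypes) sig += toDescriptor(arg);
    return sig + ")" + toDescriptor(returnType);
}
```

For example, methodDescriptor("com.sun.tools.attach.VirtualMachine", {"java.lang.String"}) produces the attach signature quoted above. javap -s remains the authoritative source when in doubt.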
Wrapping it up
A simple wrapper:
template <class TReturn, class... TArgs>
TReturn cs_dynamic_call(const Object& obj, const char* name, const char* sig, const TArgs&... args) const {
method_t method = getMethod(name, sig);
return call<TReturn>(obj, method, args...);
}
Now we can call methods freely:
jni::Object virtualMachineObject = virtualMachineClass.cs_call<jni::Object>("attach","(Ljava/lang/String;)Lcom/sun/tools/attach/VirtualMachine;", argv[1]); // target pid
virtualMachineClass.cs_dynamic_call<void>(virtualMachineObject, "loadAgentPath", "(Ljava/lang/String;)V", "E:\\agent.dll");
virtualMachineClass.cs_dynamic_call<void>(virtualMachineObject, "detach", "()V");
That's all it takes to attach an agent dynamically. Here is the complete loader:
#include "pch.h"
#include <assert.h>
auto main(int argc, char** argv) -> int {
if (argc < 2) {
return 1; // usage: <loader> <target pid>
}
jni::Vm vm("C:\\Program Files\\Java\\jdk-16.0.2\\bin\\server\\jvm.dll");
jni::Class virtualMachineClass = jni::Class("com/sun/tools/attach/VirtualMachine");
jni::Object virtualMachineObject = virtualMachineClass.cs_call<jni::Object>("attach","(Ljava/lang/String;)Lcom/sun/tools/attach/VirtualMachine;", argv[1]); // target pid
virtualMachineClass.cs_dynamic_call<void>(virtualMachineObject, "loadAgentPath", "(Ljava/lang/String;)V", "E:\\agent.dll");
virtualMachineClass.cs_dynamic_call<void>(virtualMachineObject, "detach", "()V");
return 0;
}
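The loader passes argv[1] straight to attach. A small validation step can reject obviously malformed pids before the JVM is even created. Here it is sketched with a hypothetical parsePidArg helper, which assumes Windows pids fit in a 32-bit DWORD (at most 10 decimal digits).

```cpp
#include <cassert>
#include <cctype>
#include <optional>
#include <string>

// Hypothetical helper: accepts only a non-empty run of decimal digits with
// no leading zero, since VirtualMachine.attach takes the pid as a String.
std::optional<std::string> parsePidArg(const char* arg) {
    if (arg == nullptr || *arg == '\0') return std::nullopt;
    std::string pid(arg);
    // A DWORD pid has at most 10 decimal digits; "0" and "012" are rejected
    if (pid.size() > 10 || pid[0] == '0') return std::nullopt;
    for (unsigned char c : pid) {
        if (!std::isdigit(c)) return std::nullopt;
    }
    return pid;
}
```

In main this could be used as `auto pid = parsePidArg(argc > 1 ? argv[1] : nullptr); if (!pid) return 1;`, with pid->c_str() then passed to attach.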
Log4j detection
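For the agent to be loaded dynamically, it exports two functions (strictly speaking, only Agent_OnAttach is required for dynamic attach; Agent_OnUnload is optional):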
Agent_OnAttach
Agent_OnUnload
That's all it takes:
JNIEXPORT auto __stdcall Agent_OnAttach(JavaVM* vm, char* options,
void* reserved) -> jint {
return Agent::Init(vm);
}
JNIEXPORT auto __stdcall Agent_OnUnload(JavaVM* vm) -> void {
Tools::DbgPrint("Agent_OnUnload");
}
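Agent_OnAttach also receives an options string. It contains whatever was passed as the second argument of VirtualMachine.loadAgentPath(String, String), or is NULL when the one-argument overload is used, as in the loader above. The key=value,key=value format below is an assumption for illustration, since the JVM does not impose any format, and parseAgentOptions is a hypothetical name.

```cpp
#include <cassert>
#include <map>
#include <string>

// Hypothetical option format "key=value,key2=value2"; a bare "key" maps to "".
// The JVM hands the options string over verbatim, or NULL if none was given.
std::map<std::string, std::string> parseAgentOptions(const char* options) {
    std::map<std::string, std::string> result;
    if (options == nullptr) return result;
    const std::string s(options);
    std::size_t start = 0;
    while (start <= s.size()) {
        std::size_t end = s.find(',', start);
        if (end == std::string::npos) end = s.size();
        const std::string item = s.substr(start, end - start);
        if (!item.empty()) {  // skip empty items such as in "a,,b"
            const std::size_t eq = item.find('=');
            if (eq == std::string::npos) {
                result[item] = "";
            } else {
                result[item.substr(0, eq)] = item.substr(eq + 1);
            }
        }
        start = end + 1;
    }
    return result;
}
```

Agent::Init could then take the parsed map to decide, for instance, which stack frames count as suspicious.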
Initializing the callbacks
So what does the agent do once it is loaded? See the JVMTI reference:
https://docs.oracle.com/javase/8/docs/platform/jvmti/jvmti.html
It lists every event callback and setting available to an agent. The one we care about is ClassFileLoadHook:
https://docs.oracle.com/javase/8/docs/platform/jvmti/jvmti.html#ClassFileLoadHook
This event is sent when the VM obtains class file data, but before it constructs the in-memory representation for that class. This event is also sent when the class is being modified by the RetransformClasses function or the RedefineClasses function, called in any JVM TI environment. The agent can instrument the existing class file data sent by the VM to include profiling/debugging hooks. See the description of bytecode instrumentation for usage information.
Following the documentation, we first add the capability:
jvmtiCapabilities capabilities = {0};
capabilities.can_generate_all_class_hook_events = 1;
jvmti->AddCapabilities(&capabilities);
Then register the event callback:
jvmtiEventCallbacks callbacks = {0};
callbacks.ClassFileLoadHook = Callback::ClassFileLoadHook;
jvmti->SetEventCallbacks(&callbacks, sizeof(callbacks));
Finally, enable event notification:
jvmti->SetEventNotificationMode(JVMTI_ENABLE, JVMTI_EVENT_CLASS_FILE_LOAD_HOOK, NULL);
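Then comes the catch: when built with Visual Studio's cl compiler, the agent crashes.
It only worked after I switched to clang. I don't know why and didn't dig into it.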
And don't forget the callback's definition:
auto __stdcall ClassFileLoadHook(jvmtiEnv* jvmti_env, JNIEnv* jni_env,
jclass class_being_redefined, jobject loader,
const char* name, jobject protection_domain,
jint class_data_len,
const unsigned char* class_data,
jint* new_class_data_len,
unsigned char** new_class_data) -> void
Detecting log4j
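To detect log4j exploitation, we check the following conditions:
- Is the class being loaded absent from the project's files (i.e. loaded from memory)?
- Is the loader suspicious, e.g. one of the libraries commonly abused in deserialization attacks?
- Is there a suspicious library on the stack, such as log4j?
Let's implement them one at a time: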
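The three checks can be expressed as a single decision function. This is a hypothetical policy, not the project's exact logic: the final listing combines only the stack check with the FindClass check and leaves the loader blacklist as a TODO. LoadSignals and isSuspiciousLoad are illustrative names.

```cpp
#include <cassert>
#include <string>
#include <unordered_set>
#include <vector>

// Signals gathered in ClassFileLoadHook for one class load (illustrative).
struct LoadSignals {
    bool foundByFindClass;                 // false -> treated as loaded from memory
    std::string loaderClassName;           // dotted name of the loader's class
    std::vector<std::string> stackFrames;  // "className.methodName" entries
};

// Hypothetical policy: a class that FindClass cannot resolve is reported if
// its loader is blacklisted or a suspicious frame is on the current stack.
bool isSuspiciousLoad(const LoadSignals& s,
                      const std::unordered_set<std::string>& suspiciousFrames,
                      const std::unordered_set<std::string>& loaderBlacklist) {
    if (s.foundByFindClass) return false;
    if (loaderBlacklist.count(s.loaderClassName) != 0) return true;
    for (const auto& frame : s.stackFrames) {
        if (suspiciousFrames.count(frame) != 0) return true;
    }
    return false;
}
```

Requiring the "not on the class path" signal first keeps ordinary application classes that happen to load while log4j is on the stack from being reported.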
Is the loaded class absent from the project's files (loaded from memory)?
auto findJavaClass(JNIEnv* jni_env, const char* name) -> jclass {
jclass ref = jni_env->FindClass(name);
if (ref == nullptr) {
jni_env->ExceptionClear();
}
return ref;
}
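If FindClass returns nullptr, the class cannot be found through the normal class path, so we treat it as something loaded from memory.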
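The name handed to ClassFileLoadHook is in internal form with slashes (java/util/List). That is exactly what FindClass expects, but stack frames and most blacklists use dotted names. The following is a sketch of a small conversion helper; internalToBinaryName is a hypothetical name.

```cpp
#include <cassert>
#include <string>

// Converts a JVM internal class name ("java/util/List") to the dotted
// binary form ("java.util.List") used in stack traces and log output.
std::string internalToBinaryName(const char* internalName) {
    // The hook may report a NULL name for classes defined without one
    if (internalName == nullptr) return {};
    std::string name(internalName);
    for (char& c : name) {
        if (c == '/') c = '.';
    }
    return name;
}
```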
Is the loader suspicious, e.g. a library commonly abused in deserialization attacks?
To get information about the loader, I wrote a helper function.
It works the same way as this Java snippet, just written procedurally through JNI:
System.out.println(obj.getClass().getResource(obj.getClass().getSimpleName()+".class"));
auto getClassFullPath(JNIEnv* jni_env, jobject javaObj) -> std::string {
std::string result{};
// getSimpleName
jclass javaClass = jni_env->GetObjectClass(javaObj);
jmethodID method_GetClass =
jni_env->GetMethodID(javaClass, "getClass", "()Ljava/lang/Class;");
jobject getClass_CallObj =
jni_env->CallObjectMethod(javaObj, method_GetClass);
jclass getClass_CallClass = jni_env->GetObjectClass(getClass_CallObj);
jmethodID method_GetSimpleName = jni_env->GetMethodID(
getClass_CallClass, "getSimpleName", "()Ljava/lang/String;");
jstring jstring_GetSimpleName = (jstring)jni_env->CallObjectMethod(
getClass_CallObj, method_GetSimpleName);
const char* simpleName_StringBuffer =
jni_env->GetStringUTFChars(jstring_GetSimpleName, nullptr);
std::string simpleName = simpleName_StringBuffer;
simpleName += ".class";
jni_env->ReleaseStringUTFChars(jstring_GetSimpleName,
simpleName_StringBuffer);
// getResource
jmethodID method_GetResource =
jni_env->GetMethodID(getClass_CallClass, "getResource",
"(Ljava/lang/String;)Ljava/net/URL;");
jstring jstring_SimpleNameBuffer =
jni_env->NewStringUTF(simpleName.c_str());
jobject urlObj = jni_env->CallObjectMethod(
getClass_CallObj, method_GetResource, jstring_SimpleNameBuffer);
if (jni_env->ExceptionCheck()) {
// Don't leave a pending Java exception behind in the hook
jni_env->ExceptionClear();
}
if (urlObj == nullptr) {
// fix me
return result;
}
// url_obj to string
jclass javaUrlClass = jni_env->GetObjectClass(urlObj);
jmethodID method_ToString =
jni_env->GetMethodID(javaUrlClass, "toString", "()Ljava/lang/String;");
jstring jstring_urlStringBuffer =
(jstring)jni_env->CallObjectMethod(urlObj, method_ToString);
const char* url =
jni_env->GetStringUTFChars(jstring_urlStringBuffer, nullptr);
result = url;
jni_env->ReleaseStringUTFChars(jstring_urlStringBuffer, url);
return result;
}
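getResource returns a URL such as jar:file:/.../x.jar!/com/foo/Bar.class. To compare loader locations against a list of jars, it helps to strip the entry part. The following sketch uses a hypothetical resourceContainer helper. For nested jars, such as Spring Boot fat jars, it returns only the outermost container.

```cpp
#include <cassert>
#include <string>

// Hypothetical helper: reduces a getResource URL to the container it came from.
//   "jar:file:/app/lib/x.jar!/com/foo/Bar.class" -> "file:/app/lib/x.jar"
//   "file:/app/classes/com/foo/Bar.class"        -> returned unchanged
std::string resourceContainer(const std::string& url) {
    const std::string jarPrefix = "jar:";
    if (url.compare(0, jarPrefix.size(), jarPrefix) == 0) {
        // The first "!/" separates the outermost jar from the entry path
        const auto bang = url.find("!/");
        if (bang != std::string::npos) {
            return url.substr(jarPrefix.size(), bang - jarPrefix.size());
        }
    }
    return url;
}
```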
Is there a suspicious library on the stack?
The suspicious frame I chose is:
org.apache.logging.log4j.core.lookup.JndiLookup.lookup
Implementing this took a bit of work:
auto getStackPackageList(JNIEnv* jni_env) -> std::vector<std::string> {
std::vector<std::string> packageList{};
// get StackTraceElement by Thread.currentThread().getStackTrace()
jclass javaThreadClass = findJavaClass(jni_env, "java/lang/Thread");
jmethodID method_CurrentThread = jni_env->GetStaticMethodID(
javaThreadClass, "currentThread", "()Ljava/lang/Thread;");
jobject javaThreadObj =
jni_env->CallStaticObjectMethod(javaThreadClass, method_CurrentThread);
jclass javaThreadObjClass = jni_env->GetObjectClass(javaThreadObj);
jmethodID method_GetStackTrace =
jni_env->GetMethodID(javaThreadObjClass, "getStackTrace",
"()[Ljava/lang/StackTraceElement;");
jobjectArray javaStackTraceElementArray =
(jobjectArray)jni_env->CallObjectMethod(javaThreadObj,
method_GetStackTrace);
// get StackTraceElement
jclass javaStackTraceElementClass =
findJavaClass(jni_env, "java/lang/StackTraceElement");
jmethodID method_GetClassName = jni_env->GetMethodID(
javaStackTraceElementClass, "getClassName", "()Ljava/lang/String;");
jmethodID method_GetMethodName = jni_env->GetMethodID(
javaStackTraceElementClass, "getMethodName", "()Ljava/lang/String;");
/*
jmethodID method_GetLineNumber = jni_env->GetMethodID(
javaStackTraceElementClass, "getLineNumber", "()I");
*/
// get StackTraceElement
const auto javaStackTraceElementArrayLength =
jni_env->GetArrayLength(javaStackTraceElementArray);
for (auto i = 0; i < javaStackTraceElementArrayLength; i++) {
jobject javaStackTraceElementObj =
jni_env->GetObjectArrayElement(javaStackTraceElementArray, i);
jstring jstring_ClassNameBuffer = (jstring)jni_env->CallObjectMethod(
javaStackTraceElementObj, method_GetClassName);
const char* className =
jni_env->GetStringUTFChars(jstring_ClassNameBuffer, nullptr);
jstring jstring_MethodNameBuffer = (jstring)jni_env->CallObjectMethod(
javaStackTraceElementObj, method_GetMethodName);
const char* methodName =
jni_env->GetStringUTFChars(jstring_MethodNameBuffer, nullptr);
std::string fullPackageName{};
fullPackageName += className;
fullPackageName += ".";
fullPackageName += methodName;
packageList.push_back(fullPackageName);
// Tools::DbgPrint("fullPackageName: %s \n", fullPackageName.c_str());
/*
jint lineNumber =
jni_env->CallIntMethod(javaStackTraceElementObj,
method_GetLineNumber);
Tools::DbgPrint("StackTraceElement: %s.%s:%d \n", className,
methodName, lineNumber);
*/
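jni_env->ReleaseStringUTFChars(jstring_ClassNameBuffer, className);
jni_env->ReleaseStringUTFChars(jstring_MethodNameBuffer, methodName);
// Free per-frame local refs so deep stacks don't exhaust the local reference table
jni_env->DeleteLocalRef(jstring_ClassNameBuffer);
jni_env->DeleteLocalRef(jstring_MethodNameBuffer);
jni_env->DeleteLocalRef(javaStackTraceElementObj);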
}
return packageList;
}
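The final check compares each frame against one exact className.methodName string. A looser variant matches by prefix, so that any method of a listed class counts. It is sketched here as an assumption rather than what the project does, and stackHasPrefix is a hypothetical name.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Returns true if any stack frame starts with one of the given prefixes,
// e.g. "org.apache.logging.log4j.core.lookup.JndiLookup." matches every
// method of JndiLookup.
bool stackHasPrefix(const std::vector<std::string>& frames,
                    const std::vector<std::string>& prefixes) {
    return std::any_of(frames.begin(), frames.end(), [&](const std::string& frame) {
        return std::any_of(prefixes.begin(), prefixes.end(),
                           [&](const std::string& p) {
                               // compare() clips to frame's length, so shorter frames never match
                               return frame.compare(0, p.size(), p) == 0;
                           });
    });
}
```

The trailing dot in the prefix keeps a class such as JndiLookupHelper from matching JndiLookup.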
Putting it all together:
// dllmain.cpp : Defines the entry point for the DLL application.
#include "pch.h"
namespace Agent {
namespace Callback {
auto findJavaClass(JNIEnv* jni_env, const char* name) -> jclass {
jclass ref = jni_env->FindClass(name);
if (ref == nullptr) {
jni_env->ExceptionClear();
}
return ref;
}
auto getClassFullPath(JNIEnv* jni_env, jobject javaObj) -> std::string {
std::string result{};
// getSimpleName
jclass javaClass = jni_env->GetObjectClass(javaObj);
jmethodID method_GetClass =
jni_env->GetMethodID(javaClass, "getClass", "()Ljava/lang/Class;");
jobject getClass_CallObj =
jni_env->CallObjectMethod(javaObj, method_GetClass);
jclass getClass_CallClass = jni_env->GetObjectClass(getClass_CallObj);
jmethodID method_GetSimpleName = jni_env->GetMethodID(
getClass_CallClass, "getSimpleName", "()Ljava/lang/String;");
jstring jstring_GetSimpleName = (jstring)jni_env->CallObjectMethod(
getClass_CallObj, method_GetSimpleName);
const char* simpleName_StringBuffer =
jni_env->GetStringUTFChars(jstring_GetSimpleName, nullptr);
std::string simpleName = simpleName_StringBuffer;
simpleName += ".class";
jni_env->ReleaseStringUTFChars(jstring_GetSimpleName,
simpleName_StringBuffer);
// getResource
jmethodID method_GetResource =
jni_env->GetMethodID(getClass_CallClass, "getResource",
"(Ljava/lang/String;)Ljava/net/URL;");
jstring jstring_SimpleNameBuffer =
jni_env->NewStringUTF(simpleName.c_str());
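jobject urlObj = jni_env->CallObjectMethod(
getClass_CallObj, method_GetResource, jstring_SimpleNameBuffer);
if (jni_env->ExceptionCheck()) {
// Don't leave a pending Java exception behind in the hook
jni_env->ExceptionClear();
}
if (urlObj == nullptr) {
// fix me
return result;
}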
// url_obj to string
jclass javaUrlClass = jni_env->GetObjectClass(urlObj);
jmethodID method_ToString =
jni_env->GetMethodID(javaUrlClass, "toString", "()Ljava/lang/String;");
jstring jstring_urlStringBuffer =
(jstring)jni_env->CallObjectMethod(urlObj, method_ToString);
const char* url =
jni_env->GetStringUTFChars(jstring_urlStringBuffer, nullptr);
result = url;
jni_env->ReleaseStringUTFChars(jstring_urlStringBuffer, url);
return result;
}
auto getStackPackageList(JNIEnv* jni_env) -> std::vector<std::string> {
std::vector<std::string> packageList{};
// get StackTraceElement by Thread.currentThread().getStackTrace()
jclass javaThreadClass = findJavaClass(jni_env, "java/lang/Thread");
jmethodID method_CurrentThread = jni_env->GetStaticMethodID(
javaThreadClass, "currentThread", "()Ljava/lang/Thread;");
jobject javaThreadObj =
jni_env->CallStaticObjectMethod(javaThreadClass, method_CurrentThread);
jclass javaThreadObjClass = jni_env->GetObjectClass(javaThreadObj);
jmethodID method_GetStackTrace =
jni_env->GetMethodID(javaThreadObjClass, "getStackTrace",
"()[Ljava/lang/StackTraceElement;");
jobjectArray javaStackTraceElementArray =
(jobjectArray)jni_env->CallObjectMethod(javaThreadObj,
method_GetStackTrace);
// get StackTraceElement
jclass javaStackTraceElementClass =
findJavaClass(jni_env, "java/lang/StackTraceElement");
jmethodID method_GetClassName = jni_env->GetMethodID(
javaStackTraceElementClass, "getClassName", "()Ljava/lang/String;");
jmethodID method_GetMethodName = jni_env->GetMethodID(
javaStackTraceElementClass, "getMethodName", "()Ljava/lang/String;");
/*
jmethodID method_GetLineNumber = jni_env->GetMethodID(
javaStackTraceElementClass, "getLineNumber", "()I");
*/
// get StackTraceElement
const auto javaStackTraceElementArrayLength =
jni_env->GetArrayLength(javaStackTraceElementArray);
for (auto i = 0; i < javaStackTraceElementArrayLength; i++) {
jobject javaStackTraceElementObj =
jni_env->GetObjectArrayElement(javaStackTraceElementArray, i);
jstring jstring_ClassNameBuffer = (jstring)jni_env->CallObjectMethod(
javaStackTraceElementObj, method_GetClassName);
const char* className =
jni_env->GetStringUTFChars(jstring_ClassNameBuffer, nullptr);
jstring jstring_MethodNameBuffer = (jstring)jni_env->CallObjectMethod(
javaStackTraceElementObj, method_GetMethodName);
const char* methodName =
jni_env->GetStringUTFChars(jstring_MethodNameBuffer, nullptr);
std::string fullPackageName{};
fullPackageName += className;
fullPackageName += ".";
fullPackageName += methodName;
packageList.push_back(fullPackageName);
// Tools::DbgPrint("fullPackageName: %s \n", fullPackageName.c_str());
/*
jint lineNumber =
jni_env->CallIntMethod(javaStackTraceElementObj,
method_GetLineNumber);
Tools::DbgPrint("StackTraceElement: %s.%s:%d \n", className,
methodName, lineNumber);
*/
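jni_env->ReleaseStringUTFChars(jstring_ClassNameBuffer, className);
jni_env->ReleaseStringUTFChars(jstring_MethodNameBuffer, methodName);
// Free per-frame local refs so deep stacks don't exhaust the local reference table
jni_env->DeleteLocalRef(jstring_ClassNameBuffer);
jni_env->DeleteLocalRef(jstring_MethodNameBuffer);
jni_env->DeleteLocalRef(javaStackTraceElementObj);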
}
return packageList;
}
auto __stdcall ClassFileLoadHook(jvmtiEnv* jvmti_env, JNIEnv* jni_env,
jclass class_being_redefined, jobject loader,
const char* name, jobject protection_domain,
jint class_data_len,
const unsigned char* class_data,
jint* new_class_data_len,
unsigned char** new_class_data) -> void {
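// name can be NULL for classes defined without a name and must not reach FindClass
if (loader == nullptr || jni_env == nullptr || name == nullptr) {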
return;
}
clock_t startTime, endTime;
startTime = clock();
// Is the class present in the original files (resolvable via FindClass)?
const auto javaClass = findJavaClass(jni_env, name);
// Check the loader's class. TODO: blacklist of loader classes
const auto loaderPath = getClassFullPath(jni_env, loader);
if (loaderPath.size() == 0) {
return;
}
// Check the stack
const auto packageList = getStackPackageList(jni_env);
static std::string log4jJniPackageName =
"org.apache.logging.log4j.core.lookup.JndiLookup.lookup";
bool lookUpJniManager = false;
for (const auto& package : packageList) {
if (package == log4jJniPackageName) {
lookUpJniManager = true;
break;
}
}
endTime = clock();
if (lookUpJniManager && javaClass == nullptr) {
Tools::DbgPrint("suspicious package load : %s from %s by log4j \n",
name, loaderPath.c_str());
Tools::DbgPrint("time: %lf ms \n",
(double)(endTime - startTime) * 1000.0 / CLOCKS_PER_SEC);
}
}
} // namespace Callback
auto Init(JavaVM* vm) -> jint {
jvmtiEnv* jvmti = nullptr;
jint res = vm->GetEnv((void**)&jvmti, JVMTI_VERSION_1_0);
if (res != JNI_OK || jvmti == nullptr) {
Tools::DbgPrint(
"ERROR: Unable to access JVMTI Version 1, 2 or higher\n");
return JNI_ERR;
}
// https://docs.oracle.com/javase/8/docs/platform/jvmti/jvmti.html#ClassFileLoadHook
jvmtiEventCallbacks callbacks = {0};
callbacks.ClassFileLoadHook = Callback::ClassFileLoadHook;
jvmtiCapabilities capabilities = {0};
capabilities.can_generate_all_class_hook_events = 1;
jvmti->AddCapabilities(&capabilities);
jvmti->SetEventCallbacks(&callbacks, sizeof(jvmtiEventCallbacks));
jvmti->SetEventNotificationMode(JVMTI_ENABLE,
JVMTI_EVENT_CLASS_FILE_LOAD_HOOK, NULL);
Tools::DbgPrint("Agent::Init\n");
return 0;
}
} // namespace Agent
JNIEXPORT auto __stdcall Agent_OnAttach(JavaVM* vm, char* options,
void* reserved) -> jint {
return Agent::Init(vm);
}
JNIEXPORT auto __stdcall Agent_OnUnload(JavaVM* vm) -> void {
Tools::DbgPrint("Agent_OnUnload");
}
auto __stdcall DllMain(HMODULE hModule, DWORD ul_reason_for_call,
LPVOID lpReserved) -> BOOL {
// Nothing to do here: all setup happens in Agent_OnAttach
switch (ul_reason_for_call) {
case DLL_PROCESS_ATTACH:
case DLL_THREAD_ATTACH:
case DLL_THREAD_DETACH:
case DLL_PROCESS_DETACH:
break;
}
return TRUE;
}
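The hook measures its cost with clock(), whose meaning differs by platform: ISO C defines it as processor time, while the MSVC runtime returns wall-clock time. A std::chrono::steady_clock stopwatch gives an unambiguous wall-clock duration in milliseconds. It is sketched here with the hypothetical name StopwatchMs.

```cpp
#include <cassert>
#include <chrono>

// Wall-clock stopwatch based on a monotonic clock; elapsedMs() can be
// called repeatedly and never goes backwards.
class StopwatchMs {
public:
    StopwatchMs() : start_(std::chrono::steady_clock::now()) {}

    double elapsedMs() const {
        const auto d = std::chrono::steady_clock::now() - start_;
        return std::chrono::duration<double, std::milli>(d).count();
    }

private:
    std::chrono::steady_clock::time_point start_;
};
```

In ClassFileLoadHook, one could create a StopwatchMs at the top of the callback and print timer.elapsedMs() in place of the clock() arithmetic.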
Testing
I put together a Spring Boot app:
package com.example.sprinboottest;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.springframework.web.bind.annotation.RequestHeader;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import java.lang.management.ManagementFactory;
@RestController
public class webpage {
private static final Logger logger = LogManager.getLogger(webpage.class);
@RequestMapping("/")
public String demo(@RequestParam("a") String s){
logger.error("str={}",s);
return "Hello World!";
}
}
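Note that for the Spring Boot project you need to edit the IntelliJ .iml file and change the log4j dependencies to a vulnerable version: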
<orderEntry type="library" name="Maven: org.apache.logging.log4j:log4j-core:2.14.1" level="project" />
<orderEntry type="library" name="Maven: org.apache.logging.log4j:log4j-api:2.14.1" level="project" />
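Then start the Spring Boot app, attach the agent with the loader from earlier, trigger the vulnerability, and check the agent's debug output.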
Looks pretty good.
Project on GitHub
https://github.com/huoji120/log4j_detect
For licensing, questions about the article, or removal requests, contact FreeBuf customer service (WeChat: freebee1024).