Commit ab60654e authored by Lubomir Bulej's avatar Lubomir Bulej

DiSL-Agent: Fixed problem with partial reads of the response.

DiSL-Agent: Revamped message creation and handling.
DiSL-Agent: Various cleanups.
parent 2bfcd547
......@@ -35,7 +35,7 @@ LINK_SHARED=$(LINK.c) -shared -o $@
COMMON_FLAGS=-fPIC -std=gnu99
# Options that help find errors
COMMON_FLAGS+= -W -Wall -Wextra -Wno-unused-parameter
COMMON_FLAGS+= -W -Wall -Wextra -Wno-unused-parameter -lpthread
CFLAGS += $(COMMON_FLAGS)
......@@ -44,7 +44,7 @@ CFLAGS += -I$(JAVA_HOME)/include -I$(JAVA_HOME)/include/linux
# add debugging output
ifneq ($(DEBUG),)
CFLAGS += -DDEBUG
CFLAGS += -DDEBUG -g
else
CFLAGS += -DNDEBUG -O3
endif
......
......@@ -6,9 +6,9 @@
/**
* Reports an actual general error. This function implements the slow path of
* check_error(). It prints the given error message and exits with an error code
* indicating a general error.
* Reports an error and terminates the program. This function implements the
* slow path of check_error(). It prints the given error message and exits with
* an error code indicating a general error.
*/
void
report_error (const char * message) {
......@@ -18,17 +18,15 @@ report_error (const char * message) {
/**
* Reports an actual standard library error. This function implements the slow path
* of check_std_error(). It prints the given error message along with the error
* message provided by the standard library, and exits with an error code indicating
* failure in standard library call.
* Reports a standard library error and terminates the program. This function
* implements the slow path of check_std_error(). It prints the given error
* message along with the error message provided by the standard library and
* exits with an error code indicating failure in standard library call.
*/
void
report_std_error (const char * message) {
char msgbuf [1024];
snprintf (msgbuf, sizeof (msgbuf), "%s%s", ERROR_PREFIX, message);
perror (msgbuf);
fprintf (stderr, "%s%s\n", ERROR_PREFIX, message);
perror ("cause");
exit (ERROR_STD);
}
......
......@@ -10,7 +10,7 @@
#define ERROR_STD 10002
#define ERROR_JVMTI 10003
#define ERROR_PREFIX "DiSL agent error: "
#define ERROR_PREFIX "DiSL-agent error: "
void report_error (const char * message);
......@@ -25,8 +25,8 @@ void report_std_error (const char * message);
/**
* Checks for a general error by testing the results of an error condition.
* Reports a general error and terminates the program if the condition is true.
* Reports a general error and terminates the program if the provided
* error condition is true.
*/
inline static void
check_error (bool error, const char * message) {
......@@ -37,13 +37,12 @@ check_error (bool error, const char * message) {
/**
* Checks the return value of a standard library call for error. Reports a standard
* library error and terminates the program if the the returned value matches the
* specified error value.
* Reports a standard library error and terminates the program if the provided
* error condition is true.
*/
inline static void
check_std_error (int retval, int errorval, const char * message) {
if (retval == errorval) {
check_std_error (bool error, const char * message) {
if (error) {
report_std_error (message);
}
}
......
......@@ -17,12 +17,12 @@
static void
__connection_init (struct connection * connection, const int sockfd) {
connection->sockfd = sockfd;
list_init (& connection->cp_link);
list_init (&connection->cp_link);
#ifdef DEBUG
connection->sent_bytes = 0;
connection->recv_bytes = 0;
#endif /* DEBUG */
#endif
}
......@@ -36,17 +36,17 @@ struct connection * connection_open (struct addrinfo * addr) {
// sender side and create a wrapper object for the connection.
//
int sockfd = socket(addr->ai_family, SOCK_STREAM, 0);
check_std_error(sockfd, -1, "failed to create socket");
check_std_error (sockfd < 0, "failed to create socket");
int connect_result = connect(sockfd, addr->ai_addr, addr->ai_addrlen);
check_std_error(connect_result, -1, "failed to connect to server");
check_std_error (connect_result < 0, "failed to connect to server");
int tcp_nodelay = 1;
int sso_result = setsockopt (
sockfd, IPPROTO_TCP, TCP_NODELAY,
&tcp_nodelay, sizeof (tcp_nodelay)
);
check_std_error(sso_result, -1, "failed to enable TCP_NODELAY");
check_std_error (sso_result < 0, "failed to enable TCP_NODELAY");
//
......@@ -67,11 +67,11 @@ connection_close (struct connection * connection) {
assert (connection != NULL);
#if DEBUG
fprintf (
stderr, "socket %d: sent bytes %llu, recv bytes %llu\n",
printf (
"debug: socket %d: sent bytes %llu, recv bytes %llu\n",
connection->sockfd, connection->sent_bytes, connection->recv_bytes
);
#endif /* DEBUG */
#endif
close (connection->sockfd);
free (connection);
......@@ -79,101 +79,103 @@ connection_close (struct connection * connection) {
//
static void
__socket_send (const int sockfd, const void * buf, const ssize_t len) {
typedef ssize_t (* xfer_fn) (int sockfd, void * buf, size_t len, int flags);
inline static ssize_t
__socket_xfer (xfer_fn xfer, const int sockfd, const void * buf, const ssize_t len) {
unsigned char * buf_tail = (unsigned char *) buf;
size_t remaining = len;
while (remaining > 0) {
int sent = send (sockfd, buf_tail, remaining, 0);
check_std_error (sent, -1, "error sending data to server");
ssize_t xferred = xfer (sockfd, buf_tail, remaining, 0);
if (xferred < 0) {
return -remaining;
}
remaining -= sent;
buf_tail += sent;
remaining -= xferred;
buf_tail += xferred;
}
return len;
}
/**
* Sends data into the given connection.
* Sends data into the given connection. Does not return until all provided
* data has been sent.
*/
void
ssize_t
connection_send (struct connection * connection, const void * buf, const ssize_t len) {
assert (connection != NULL);
assert (buf != NULL);
assert (len >= 0);
__socket_send (connection->sockfd, buf, len);
ssize_t sent = __socket_xfer ((xfer_fn) send, connection->sockfd, buf, len);
check_std_error (sent < 0, "error sending data to server");
#ifdef DEBUG
connection->sent_bytes += len;
#endif /* DEBUG*/
connection->sent_bytes += sent;
#endif
return sent;
}
/**
* Sends vectored data into the given connection.
* Sends vectored data into the given connection. May send less data than requested.
*/
void
ssize_t
connection_send_iov (struct connection * connection, const struct iovec * iov, int iovcnt) {
assert (connection != NULL);
assert (iov != NULL);
ssize_t written = writev (connection->sockfd, iov, iovcnt);
check_std_error (written, -1, "error sending data to server");
ssize_t sent = writev (connection->sockfd, iov, iovcnt);
check_std_error (sent < 0, "error sending data to server");
#ifdef DEBUG
connection->sent_bytes += written;
#endif /* DEBUG */
}
//
static void
__socket_recv (const int sockfd, void * buf, ssize_t len) {
unsigned char * buf_tail = (unsigned char *) buf;
ssize_t remaining = len;
connection->sent_bytes += sent;
#endif
while (remaining > 0) {
int received = recv (sockfd, buf_tail, remaining, 0);
check_std_error(received, -1, "error receiving data from server");
remaining -= received;
buf_tail += received;
}
check_error (remaining < 0, "received more data than expected");
return sent;
}
//
/**
* Receives a predefined amount of data from the given connection.
* Receives a predefined amount of data from the given connection. Does not return
* until all requested data has been received.
*/
void
ssize_t
connection_recv (struct connection * connection, void * buf, const ssize_t len) {
assert (connection != NULL);
assert (buf != NULL);
assert (len >= 0);
__socket_recv (connection->sockfd, buf, len);
ssize_t received = __socket_xfer ((xfer_fn) recv, connection->sockfd, buf, len);
check_std_error (received < 0, "error receiving data from server");
#ifdef DEBUG
connection->recv_bytes += len;
#endif /* DEBUG */
connection->recv_bytes += received;
#endif
return received;
}
/**
* Receives vectored data from the given connection.
* Receives vectored data from the given connection. May receive less data than requested.
*/
void
ssize_t
connection_recv_iov (struct connection * connection, const struct iovec * iov, int iovcnt) {
assert (connection != NULL);
assert (iov != NULL);
ssize_t read = readv (connection->sockfd, iov, iovcnt);
check_std_error (read, -1, "error receiving data from server");
ssize_t received = readv (connection->sockfd, iov, iovcnt);
check_std_error (received < 0, "error receiving data from server");
#ifdef DEBUG
connection->recv_bytes += read;
#endif /* DEBUG */
connection->recv_bytes += received;
#endif
return received;
}
......@@ -28,9 +28,10 @@ struct connection {
struct connection * connection_open (struct addrinfo * addr);
void connection_close (struct connection * connection);
void connection_send (struct connection * connection, const void * buf, const ssize_t len);
void connection_send_iov (struct connection * connection, const struct iovec * iov, int iovcnt);
void connection_recv (struct connection * connection, void * buf, const ssize_t len);
void connection_recv_iov (struct connection * connection, const struct iovec * iov, int iovcnt);
ssize_t connection_send (struct connection * connection, const void * buf, const ssize_t len);
ssize_t connection_send_iov (struct connection * connection, const struct iovec * iov, int iovcnt);
ssize_t connection_recv (struct connection * connection, void * buf, const ssize_t len);
ssize_t connection_recv_iov (struct connection * connection, const struct iovec * iov, int iovcnt);
#endif /* _CONNECTION_H_ */
......@@ -15,8 +15,8 @@ connection_pool_init (struct connection_pool * cp, struct addrinfo * endpoint) {
assert (endpoint != NULL);
cp->connections_count = 0;
list_init (& cp->free_connections);
list_init (& cp->busy_connections);
list_init (&cp->free_connections);
list_init (&cp->busy_connections);
cp->endpoint = endpoint;
cp->after_open_hook = NULL;
......@@ -52,9 +52,9 @@ connection_pool_get_connection (struct connection_pool * cp) {
// Grab the first available connection and return. If there is no connection
// available, create a new one and add it to the busy connection list.
//
if (!list_is_empty (& cp->free_connections)) {
struct list * item = list_remove_after (& cp->free_connections);
list_insert_after (item, & cp->busy_connections);
if (!list_is_empty (&cp->free_connections)) {
struct list * item = list_remove_after (&cp->free_connections);
list_insert_after (item, &cp->busy_connections);
return list_item (item, struct connection, cp_link);
} else {
......@@ -63,8 +63,11 @@ connection_pool_get_connection (struct connection_pool * cp) {
cp->after_open_hook (connection);
}
list_insert_after (& connection->cp_link, & cp->busy_connections);
list_insert_after (&connection->cp_link, &cp->busy_connections);
cp->connections_count++;
#ifdef DEBUG
printf ("[new connection, %d in total] ", cp->connections_count);
#endif
return connection;
}
}
......@@ -85,8 +88,8 @@ connection_pool_put_connection (
// Move the connection from the list of busy connections to
// the list of available connections.
//
struct list * item = list_remove (& connection->cp_link);
list_insert_after (item, & cp->free_connections);
struct list * item = list_remove (&connection->cp_link);
list_insert_after (item, &cp->free_connections);
}
//
......@@ -114,15 +117,15 @@ connection_pool_close (struct connection_pool * cp) {
assert (cp != NULL);
#ifdef DEBUG
fprintf (
stderr, "connection pool %s: max connections %d\n",
printf (
"debug: connection pool for %s: max connections %d\n",
cp->endpoint->ai_canonname, cp->connections_count
);
#endif /* DEBUG */
#endif
list_destroy (& cp->free_connections, __connection_destructor, (void *) cp);
if (!list_is_empty (& cp->busy_connections)) {
list_destroy (&cp->free_connections, __connection_destructor, (void *) cp);
if (!list_is_empty (&cp->busy_connections)) {
fprintf (stderr, "warning: closing %d active connections", cp->connections_count);
list_destroy (& cp->busy_connections, __connection_destructor, (void *) cp);
list_destroy (&cp->busy_connections, __connection_destructor, (void *) cp);
}
}
......@@ -8,8 +8,8 @@
#include <pthread.h>
#include <jvmti.h>
#include <jni.h>
#include <jvmti.h>
#include "common.h"
#include "jvmtiutil.h"
......@@ -27,8 +27,11 @@
// AGENT CONFIG
// ****************************************************************************
#define DISL_HOST_DEFAULT "localhost"
#define DISL_PORT_DEFAULT "11217"
#define DISLSERVER_HOST "dislserver.host"
#define DISLSERVER_HOST_DEFAULT "localhost"
#define DISLSERVER_PORT "dislserver.port"
#define DISLSERVER_PORT_DEFAULT "11217"
#define DISL_BYPASS "disl.bypass"
#define DISL_BYPASS_DEFAULT "dynamic"
......@@ -130,34 +133,65 @@ __calc_code_flags (struct config * config, bool jvm_is_booting) {
/**
* Sends the given class to the remote server for instrumentation. Returns
* the response from the server that contains the instrumented class.
* Sends the given class to the remote server for instrumentation. If the
* server modified the class, updates the provided class definition structure
* and returns true. Otherwise, the structure is left unmodified and false
* is returned.
*/
static struct message
static bool
__instrument_class (
jint request_flags, const char * classname,
const unsigned char * classcode, jint classcode_size
jint request_flags, const char * class_name,
jvmtiClassDefinition * class_def
) {
//
// Acquire a connection, put the class data into the message and
// send it to the server, wait for the response, and release the
// Put the class data into a request message, acquire a connection and
// send the it to the server. Receive the response and release the
// connection again.
//
struct message request = {
.message_flags = request_flags,
.control_size = (class_name != NULL) ? strlen (class_name) : 0,
.classcode_size = class_def->class_byte_count,
.control = (unsigned char *) class_name,
.classcode = class_def->class_bytes,
};
//
struct connection * conn = network_acquire_connection ();
message_send (conn, &request);
struct message response;
message_recv (conn, &response);
network_release_connection (conn);
//
// Check if error occurred on the server.
// The control field of the response contains the error message.
//
if (response.control_size > 0) {
fprintf (
stderr,
"%sinstrumentation server error:\n%s\n",
ERROR_PREFIX, response.control
);
// TODO: This would do just with a thread-local message
struct message request = create_message (
request_flags,
(const unsigned char *) classname, strlen (classname),
classcode, classcode_size
);
send_message (conn, &request);
struct message result = recv_message (conn);
exit (ERROR_SERVER);
}
//
// Update the class definition and signal modified class if
// any class code has been returned. If not, the class has
// not been modified.
//
if (response.classcode_size > 0) {
class_def->class_byte_count = response.classcode_size;
class_def->class_bytes = response.classcode;
return true;
network_release_connection (conn);
return result;
} else {
return false;
}
}
......@@ -166,60 +200,71 @@ jvmti_callback_class_file_load (
jvmtiEnv * jvmti, JNIEnv * jni,
jclass class_being_redefined, jobject loader,
const char * class_name, jobject protection_domain,
jint class_data_len, const unsigned char * class_data,
jint * new_class_data_len, unsigned char ** new_class_data
jint class_byte_count, const unsigned char * class_bytes,
jint * new_class_byte_count, unsigned char ** new_class_bytes
) {
assert (jvmti != NULL);
#ifdef DEBUG
if (class_name != NULL) {
printf ("Instrumenting class %s\n", class_name);
} else {
printf ("Instrumenting unknown class\n");
}
printf (
"debug: instrumenting class %s, %d bytes at %p\n",
(class_name != NULL) ? class_name : "<unknown>",
class_byte_count, class_bytes
);
#endif
// skip instrumentation of the bypass check class
if (strcmp (class_name, BPC_CLASS_NAME) == 0) {
//
// Avoid instrumenting the bypass check class.
//
if (class_name != NULL && (strcmp (class_name, BPC_CLASS_NAME) == 0)) {
#ifdef DEBUG
printf ("Skipping class %s\n", class_name);
printf ("debug: skipping bypass check class (%s)\n", class_name);
#endif
return;
}
// ask the server to instrument the class
struct message instrclass = __instrument_class (
agent_code_flags, class_name, class_data, class_data_len
);
//
// Instrument the class and if changed by the server, provide the
// code to the JVM in its own memory.
//
jvmtiClassDefinition class_def = {
.class_byte_count = class_byte_count,
.class_bytes = class_bytes,
};
// error on the server
if (instrclass.control_size > 0) {
// classname contains the error message
fprintf(stderr, "%sError occurred in the remote instrumentation server\n", ERROR_PREFIX);
fprintf(stderr, " Reason: %s\n", instrclass.control);
exit (ERROR_SERVER);
}
bool class_changed = __instrument_class (
agent_code_flags, class_name, &class_def
);
// instrumented class received (0 - means no instrumentation done)
if(instrclass.classcode_size > 0) {
// give to JVM the instrumented class
unsigned char * new_class_space;
if (class_changed) {
unsigned char * jvm_class_bytes;
jvmtiError error = (*jvmti)->Allocate (
jvmti, (jlong) class_def.class_byte_count, &jvm_class_bytes
);
check_jvmti_error (
jvmti, error,
"failed to allocate memory for the instrumented class"
);
// let JVMTI to allocate the mem for the new class
jvmtiError err = (*jvmti)->Allocate (jvmti, (jlong) instrclass.classcode_size, & new_class_space);
check_jvmti_error (jvmti, err, "Cannot allocate memory for the instrumented class");
//
memcpy (new_class_space, instrclass.classcode, instrclass.classcode_size);
memcpy (jvm_class_bytes, class_def.class_bytes, class_def.class_byte_count);
free ((void *) class_def.class_bytes);
// set the newly instrumented class + len
*(new_class_data_len) = instrclass.classcode_size;
*(new_class_data) = new_class_space;
*new_class_byte_count = class_def.class_byte_count;
*new_class_bytes = jvm_class_bytes;
// free memory
free_message (&instrclass);
#ifdef DEBUG
printf (
"debug: class redefined, %d bytes at %p\n",
class_def.class_byte_count, jvm_class_bytes
);
#endif
}
#ifdef DEBUG
printf("Instrumentation done\n");
printf ("debug: instrumentation done\n");
#endif
}
......@@ -230,10 +275,14 @@ jvmti_callback_class_file_load (
static void JNICALL
jvmti_callback_vm_init (jvmtiEnv * jvmti, JNIEnv * jni, jthread thread) {
#ifdef DEBUG
printf ("debug: the VM has been initialized\n");
#endif
//
// Update code flags to reflect that the VM has stopped booting.
//
agent_code_flags = __calc_code_flags (& agent_config, false);
agent_code_flags = __calc_code_flags (&agent_config, false);
//
// Redefine the bypass check class. If dynamic bypass is required, use
......@@ -243,14 +292,14 @@ jvmti_callback_vm_init (jvmtiEnv * jvmti, JNIEnv * jni, jthread thread) {
jvmtiClassDefinition * bpc_classdef;
if (agent_config.bypass_mode == BYPASS_MODE_DYNAMIC) {
#ifdef DEBUG
fprintf (stderr, "vm_init: redefining BypassCheck for dynamic bypass\n");
printf ("debug: redefining BypassCheck for dynamic bypass\n");
#endif
bpc_classdef = & bpc_dynamic_classdef;
bpc_classdef = &bpc_dynamic_classdef;
} else {
#ifdef DEBUG
fprintf (stderr, "vm_init: redefining BypassCheck to disable bypass\n");
printf ("debug: redefining BypassCheck to disable bypass\n");
#endif
bpc_classdef = & bpc_never_classdef;
bpc_classdef = &bpc_never_classdef;
}