Commit d7591837 authored by Lubomir Bulej's avatar Lubomir Bulej
Browse files

Cleanup session initiation in legacy DiSL agent

This comprises several changes:
- send_client_message now uses a reference to ClientMessage
- neither session_start nor session_end gets to use the config
  structure directly, they just get the bits they need
- there is a new method session_send_instrumentation, which
  loads the jars and sends them to the server, instead of this
  being handled in session_start
- the jars are loaded directly into the common buffer
parent aa71b9fb
......@@ -52,7 +52,7 @@ struct config {
// Instrumentation jars.
size_t jar_count;
char ** instrumentation_jars;
char ** jar_paths;
// Code generation.
enum bypass_mode bypass_mode;
......
......@@ -203,7 +203,7 @@ __instrument_class (
//
struct connection * conn = network_acquire_connection ();
send_client_message (message, conn);
send_client_message (&message, conn);
ServerMessage * resp = receive_server_message (conn);
network_release_connection (conn);
......@@ -578,7 +578,7 @@ jvmti_callback_vm_init (jvmtiEnv * jvmti, JNIEnv * jni, jthread thread) {
// Update flags to reflect that the VM has stopped booting.
//
jvm_is_initialized = true;
agent_code_flags = __calc_code_flags (&agent_config, false);
agent_code_flags = __calc_code_flags (&agent_config, false /* jvm_is_booting */);
struct thread_info info = INIT_THREAD_INFO;
rdexec {
......@@ -662,7 +662,7 @@ __parse_instrumentations (char * list, struct config * config) {
char ** jar_paths = split_string (list, LIST_SEPARATOR, &jar_count);
config->jar_count = jar_count;
config->instrumentation_jars = jar_paths;
config->jar_paths = jar_paths;
}
static void
......@@ -723,6 +723,10 @@ __configure_from_properties (jvmtiEnv * jvmti, struct config * config) {
rdaprefix ("force superclass: %d\n", config->force_superclass);
rdaprefix ("force interfaces: %d\n", config->force_interfaces);
rdaprefix ("runtime debug: %d\n", config->debug);
rdaprefix ("instrumentation jars: %zu\n", config->jar_count);
for (size_t i = 0; i < config->jar_count; i++) {
rdaprefix ("\t%s\n", config->jar_paths[i]);
}
}
}
......@@ -852,18 +856,22 @@ Agent_OnLoad (JavaVM * jvm, char * options, void * reserved) {
__jvmti_enable_events (jvmti, events, sizeof_array (events));
// configure agent and init connections
// Configure agent and init globals.
__configure_from_options (options, &agent_config);
__configure_from_properties (jvmti, &agent_config);
jvm_is_started = false;
jvm_is_initialized = false;
agent_code_flags = __calc_code_flags (&agent_config, true);
agent_code_flags = __calc_code_flags (&agent_config, true /* jvm_is_booting */);
rdaprintf ("agent loaded, initializing connections\n");
// Start session and send instrumentation to the server.
rdaprintf ("agent initialized, starting server session\n");
network_init (agent_config.server_host, agent_config.server_port);
agent_config.session_id = session_start (true /* disl */, false /* shvm */);
rdaprintf ("received session id %d, sending instrumentation\n", agent_config.session_id);
session_send_instrumentation (agent_config.session_id, agent_config.jar_count, agent_config.jar_paths);
session_start (&agent_config);
return 0;
}
......@@ -879,6 +887,6 @@ Agent_OnUnload (JavaVM * jvm) {
//
// Just close all the connections.
//
// session_end ();
session_end (agent_config.session_id);
network_fini ();
}
......@@ -21,17 +21,17 @@
* Send a client message over a network connection.
*/
void
send_client_message (const ClientMessage message, struct connection * restrict conn) {
size_t send_size = client_message__get_packed_size (&message);
send_client_message (const ClientMessage * message, struct connection * restrict conn) {
size_t send_size = client_message__get_packed_size (message);
void * buffer = malloc (send_size);
check_error (buffer == NULL, "failed to allocate buffer for client message");
client_message__pack (&message, buffer);
client_message__pack (message, buffer);
message_send (conn, buffer, send_size);
free (buffer);
}
/**
* Receive a server message from a network connection.
*/
......@@ -47,57 +47,27 @@ receive_server_message (struct connection * restrict conn) {
return (response);
}
/*
* Instrumentation jar structure
*/
struct inst_jar {
char * name;
void * buffer;
size_t filesize;
};
/*
* Get filesize
*/
static long
get_filesize (const char * filename) {
struct stat buffer;
int result = stat (filename, &buffer);
check_std_error (result != 0, "failed to determine size of %s", filename);
return buffer.st_size;
static off_t
__get_file_size (const char * restrict name) {
struct stat statbuf;
int result = stat (name, &statbuf);
check_std_error (result != 0, "failed to determine size of %s", name);
return statbuf.st_size;
}
static void
__load_file (const char * name, void ** buffer, size_t * size) {
static size_t
__load_file (const char * name, void * restrict buffer, const size_t size) {
FILE * file = fopen (name, "r");
check_std_error (file == NULL, "failed to open %s", name);
size_t file_size = (size_t) get_filesize (name);
void * file_data = malloc (file_size);
check_std_error (file_data == NULL, "failed to allocate buffer for %s", name);
size_t bytes_read = fread (file_data, 1, file_size, file);
check_std_error (bytes_read != file_size, "failed to load %s", name);
size_t bytes_read = fread (buffer, 1, size, file);
check_std_error (bytes_read != size, "failed to load %s", name);
fclose (file);
*buffer = file_data;
*size = file_size;
}
/*
* Load the files to the buffers.
* @return
*/
static void
load_files_to_buffers (struct inst_jar * jars, const size_t count) {
// Load all instrumentation jars or fail.
for (size_t i = 0; i < count; ++i) {
struct inst_jar * jar = &(jars [i]);
__load_file (jar->name, &jar->buffer, &jar->filesize);
}
return bytes_read;
}
......@@ -107,122 +77,139 @@ load_files_to_buffers (struct inst_jar * jars, const size_t count) {
* for jar sizes and jar data.
*/
static InstrumentationDelivery
instrumentationDeliveryInit (struct config * config) {
size_t jarCount = config->jar_count;
InstrumentationDelivery_create (const size_t jar_count, char ** jar_names) {
// First load all the jar files into a contiguous buffer.
int32_t * jar_sizes = malloc (jar_count * sizeof (int32_t));
check_std_error (jar_sizes == NULL, "failed to allocate memory for instrumentation sizes");
size_t jar_sizes_total = 0;
for (size_t i = 0; i < jar_count; i++) {
off_t size = __get_file_size (jar_names [i]);
jar_sizes [i] = (int32_t) size;
jar_sizes_total += size;
}
InstrumentationDelivery delivery = INSTRUMENTATION_DELIVERY__INIT;
delivery.n_sizes = jarCount;
delivery.sizes = malloc (jarCount * sizeof (int32_t));
assert (jar_sizes_total > 0);
struct inst_jar * jars = malloc (jarCount * sizeof (struct inst_jar));
for (size_t i = 0; i < jarCount; ++i) {
jars [i].name = config->instrumentation_jars [i];
void * jar_data = malloc (jar_sizes_total);
check_std_error (jar_data == NULL, "failed to allocate memory for instrumentation data");
void * jar_buffer = jar_data;
for (size_t i = 0; i < jar_count; ++i) {
jar_buffer += __load_file (jar_names [i], jar_buffer, jar_sizes [i]);
}
// Load files and ensure that all of them exist.
load_files_to_buffers (jars, jarCount);
// Initialize the instrumentation message.
InstrumentationDelivery delivery = INSTRUMENTATION_DELIVERY__INIT;
delivery.n_sizes = jar_count;
delivery.sizes = jar_sizes;
delivery.instrumentation.len = jar_sizes_total;
delivery.instrumentation.data = jar_data;
return delivery;
}
size_t total_size = 0;
for (size_t i = 0; i < jarCount; ++i) {
delivery.sizes [i] = (int32_t) jars [i].filesize;
total_size += delivery.sizes [i];
}
assert (total_size > 0);
static void
InstrumentationDelivery_destroy (InstrumentationDelivery * delivery) {
assert (delivery != NULL);
delivery.instrumentation.len = total_size;
delivery.instrumentation.data = malloc (total_size);
if (delivery->sizes != NULL) {
free (delivery->sizes);
delivery->sizes = NULL;
}
void * buffer = delivery.instrumentation.data;
for (size_t i = 0; i < jarCount; ++i) {
memcpy (buffer, jars [i].buffer, jars [i].filesize);
buffer += jars [i].filesize;
free (jars [i].buffer);
if (delivery->instrumentation.data != NULL) {
free (delivery->instrumentation.data);
delivery->instrumentation.data = NULL;
}
}
free (jars);
return delivery;
}
int32_t
session_start (bool require_disl, bool require_shvm) {
// Create a session init request.
SessionInitRequest req = SESSION_INIT_REQUEST__INIT;
req.require_disl = require_disl;
req.require_shvm = require_shvm;
void
session_start (struct config * config) {
int32_t sid = 0;
ClientMessage message = CLIENT_MESSAGE__INIT;
message.request_case = CLIENT_MESSAGE__REQUEST_SESSION_INIT_REQUEST;
message.session_init_request = &req;
// Send the request and await response.
struct connection * conn = network_acquire_connection ();
{
/* Create, pack and send the session init request */
SessionInitRequest req = SESSION_INIT_REQUEST__INIT;
req.require_disl = true;
req.require_shvm = false;
ClientMessage message = CLIENT_MESSAGE__INIT;
message.request_case = CLIENT_MESSAGE__REQUEST_SESSION_INIT_REQUEST;
message.session_init_request = &req;
send_client_message (message, conn);
/* Receive the response and set the session_id */
ServerMessage * response = receive_server_message (conn);
switch (response->response_case) {
case SERVER_MESSAGE__RESPONSE_ERROR:
warn (response->error->message);
network_release_connection (conn);
network_fini ();
exit (1);
case SERVER_MESSAGE__RESPONSE_SESSION_INIT_RESPONSE:
sid = response->session_init_response->session_id;
assert (sid != 0);
break;
default:
warn ("Wrong response to the session init request");
network_release_connection (conn);
network_fini ();
exit (1);
}
send_client_message (&message, conn);
ServerMessage * response = receive_server_message (conn);
network_release_connection (conn);
// Check the response and return the session_id.
if (SERVER_MESSAGE__RESPONSE_SESSION_INIT_RESPONSE == response->response_case) {
int32_t session_id = response->session_init_response->session_id;
server_message__free_unpacked (response, NULL);
check_error (session_id == 0, "received invalid session id");
return session_id;
}
{
InstrumentationDelivery delivery = instrumentationDeliveryInit (config);
// Hard failure.
const char * error_message = "invalid response to the session init request";
if (SERVER_MESSAGE__RESPONSE_ERROR == response->response_case) {
error_message = response->error->message;
}
warn (error_message);
network_fini ();
exit (1);
}
ClientMessage message = CLIENT_MESSAGE__INIT;
message.request_case = CLIENT_MESSAGE__REQUEST_INSTRUMENTATION_DELIVERY;
message.session_id = sid;
message.instrumentation_delivery = &delivery;
send_client_message (message, conn);
void
session_send_instrumentation (const int32_t session_id, const size_t jar_count, char ** jar_paths) {
// Create message with instrumentation.
InstrumentationDelivery delivery = InstrumentationDelivery_create (jar_count, jar_paths);
ServerMessage * response = receive_server_message (conn);
ClientMessage message = CLIENT_MESSAGE__INIT;
message.request_case = CLIENT_MESSAGE__REQUEST_INSTRUMENTATION_DELIVERY;
message.session_id = session_id;
message.instrumentation_delivery = &delivery;
assert (response->response_case == SERVER_MESSAGE__RESPONSE_INSTRUMENTATION_ACCEPT_CONFIRMATION);
assert (response->instrumentation_accept_confirmation->instrumentation_accepted);
// Send the request and await response.
struct connection * conn = network_acquire_connection ();
send_client_message (&message, conn);
server_message__free_unpacked (response, NULL);
}
ServerMessage * response = receive_server_message (conn);
network_release_connection (conn);
config->session_id = sid;
// Check the response.
assert (SERVER_MESSAGE__RESPONSE_INSTRUMENTATION_ACCEPT_CONFIRMATION == response->response_case);
check_error (
!response->instrumentation_accept_confirmation->instrumentation_accepted,
"instrumentation was not accepted by the server"
);
// Release messages.
server_message__free_unpacked (response, NULL);
InstrumentationDelivery_destroy (&delivery);
}
/**
* Send a session-closing message to the server.
*/
void
session_end (struct config * config) {
struct connection * conn = network_acquire_connection ();
session_end(int32_t session_id) {
// Create session close request.
CloseConnection cc = CLOSE_CONNECTION__INIT;
cc.reason = CLOSE_CONNECTION__CLOSE_REASON__FINISHED;
ClientMessage message = CLIENT_MESSAGE__INIT;
message.request_case = CLIENT_MESSAGE__REQUEST_CLOSE_CONNECTION;
message.session_id = config->session_id;
message.session_id = session_id;
message.close_connection = &cc;
send_client_message (message, conn);
// Just send the message, there is no response.
struct connection * conn = network_acquire_connection ();
send_client_message (&message, conn);
network_release_connection (conn);
}
......@@ -8,11 +8,10 @@
#include "msgchannel.h"
#include "connection.h"
#include "config.h"
#include "protocol/main.pb-c.h"
void send_client_message (const ClientMessage message, struct connection * conn);
void send_client_message (const ClientMessage * message, struct connection * conn);
ServerMessage * receive_server_message (struct connection * conn);
......@@ -21,11 +20,16 @@ ServerMessage * receive_server_message (struct connection * conn);
*
* @return Obtained session id.
*/
void session_start (struct config * config);
int32_t session_start (const bool require_disl, const bool require_shvm);
/**
* Announce the end of the session to the server.
*/
void session_end (struct config * config);
void session_end (const int32_t session_id);
/**
* Sends instrumentation jars to the server.
*/
void session_send_instrumentation (const int32_t session_id, const size_t jar_count, char ** jar_paths);
#endif // _SESSIONS_H
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment