diff --git a/src/enclave/App/App.cpp b/src/enclave/App/App.cpp index b2ae5f74a7..303c3b7445 100644 --- a/src/enclave/App/App.cpp +++ b/src/enclave/App/App.cpp @@ -239,6 +239,88 @@ JNIEXPORT void JNICALL Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_Fin env->ReleaseByteArrayElements(shared_key_msg_input, shared_key_msg_bytes, 0); } +/////////////////////////////// Shared Key Gen Begin //////////////////////////////// + +JNIEXPORT jbyteArray JNICALL Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_GetPublicKey( + JNIEnv *env, jobject obj, jlong eid) { + (void)obj; + (void)eid; + + uint8_t* report_msg = NULL; + size_t report_msg_size = 0; + + oe_check_and_time("Get enclave public key", + ecall_get_public_key((oe_enclave_t*)eid, + &report_msg, + &report_msg_size)); + + // Allocate memory + jbyteArray report_msg_bytes = env->NewByteArray(report_msg_size); + env->SetByteArrayRegion(report_msg_bytes, 0, report_msg_size, reinterpret_cast(report_msg)); + + return report_msg_bytes; +} + +JNIEXPORT jbyteArray JNICALL Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_GetListEncrypted( + JNIEnv *env, jobject obj, jlong eid, jbyteArray shared_key_msg_input) { + (void)obj; + + jboolean if_copy = false; + jbyte *shared_key_msg_bytes = env->GetByteArrayElements(shared_key_msg_input, &if_copy); + uint32_t shared_key_msg_size = static_cast(env->GetArrayLength(shared_key_msg_input)); + + size_t report_msg_size = OE_SHARED_KEY_CIPHERTEXT_SIZE * (shared_key_msg_size / OE_PUBLIC_KEY_SIZE); + uint8_t* report_msg = new uint8_t[report_msg_size]; + + oe_check_and_time("Get List Encrypted", + ecall_get_list_encrypted((oe_enclave_t*)eid, + reinterpret_cast(shared_key_msg_bytes), + shared_key_msg_size, + report_msg, + report_msg_size)); + + // Allocate memory + jbyteArray report_msg_bytes = env->NewByteArray(report_msg_size); + env->SetByteArrayRegion(report_msg_bytes, 0, report_msg_size, reinterpret_cast(report_msg)); + + env->ReleaseByteArrayElements(shared_key_msg_input, (jbyte *) shared_key_msg_bytes, 0); + + delete[] report_msg; + + return report_msg_bytes; +} + +JNIEXPORT jbyteArray JNICALL Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_FinishSharedKey( + JNIEnv *env, jobject obj, jlong eid, jbyteArray shared_key_msg_input) { + (void)obj; + + jboolean if_copy = false; + + jbyte *shared_key_msg_bytes = env->GetByteArrayElements(shared_key_msg_input, &if_copy); + uint32_t shared_key_msg_size = static_cast(env->GetArrayLength(shared_key_msg_input)); + + size_t report_msg_size = SGX_AESGCM_KEY_SIZE; + uint8_t* report_msg = new uint8_t[report_msg_size]; + + oe_check_and_time("Finish attestation", + ecall_finish_shared_key((oe_enclave_t*)eid, + reinterpret_cast(shared_key_msg_bytes), + shared_key_msg_size, + report_msg, + report_msg_size)); + + // Allocate memory + jbyteArray report_msg_bytes = env->NewByteArray(report_msg_size); + env->SetByteArrayRegion(report_msg_bytes, 0, report_msg_size, reinterpret_cast(report_msg)); + + env->ReleaseByteArrayElements(shared_key_msg_input, shared_key_msg_bytes, 0); + delete[] report_msg; + + return report_msg_bytes; +} + +/////////////////////////////// Shared Key Gen End //////////////////////////////// + JNIEXPORT void JNICALL Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_StopEnclave( JNIEnv *env, jobject obj, jlong eid) { (void)env; @@ -345,6 +427,72 @@ JNIEXPORT jbyteArray JNICALL Java_edu_berkeley_cs_rise_opaque_execution_SGXEncla return ciphertext; } +JNIEXPORT jbyteArray JNICALL Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_ReEncryptUser( + JNIEnv *env, jobject obj, jlong eid, jbyteArray ciphertext, jstring username) { + (void)obj; + + uint32_t clength = (uint32_t)env->GetArrayLength(ciphertext); + + jboolean if_copy = false; + uint8_t *ciphertext_ptr = (uint8_t *)env->GetByteArrayElements(ciphertext, &if_copy); + + uint8_t *new_ciphertext_copy = nullptr; + jsize new_clength = clength; + + const char *username_str = env->GetStringUTFChars(username, nullptr); + + if (ciphertext_ptr == nullptr) { + ocall_throw("Encrypt: JNI failed to get input byte array."); + } else { + new_ciphertext_copy = new uint8_t[new_clength]; + + oe_check("Encrypt", ecall_re_encrypt_user((oe_enclave_t *)eid, ciphertext_ptr, clength, + new_ciphertext_copy, (uint32_t)new_clength, username_str)); + } + + jbyteArray new_ciphertext = env->NewByteArray(new_clength); + env->SetByteArrayRegion(new_ciphertext, 0, new_clength, (jbyte *)new_ciphertext_copy); + + env->ReleaseByteArrayElements(ciphertext, (jbyte *)ciphertext_ptr, 0); + + delete[] new_ciphertext_copy; + + return new_ciphertext; +} + + +JNIEXPORT jbyteArray JNICALL Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_Decrypt( + JNIEnv *env, jobject obj, jlong eid, jbyteArray ciphertext) { + (void)obj; + + uint32_t clength = (uint32_t)env->GetArrayLength(ciphertext); + jboolean if_copy = false; + uint8_t *ciphertext_ptr = (uint8_t *)env->GetByteArrayElements(ciphertext, &if_copy); + + uint8_t *plaintext_copy = nullptr; + jsize plength = 0; + + if (ciphertext_ptr == nullptr) { + ocall_throw("Encrypt: JNI failed to get input byte array."); + } else { + plength = clength - SGX_AESGCM_IV_SIZE - SGX_AESGCM_MAC_SIZE; + plaintext_copy = new uint8_t[plength]; + + oe_check("Decrypt", ecall_decrypt((oe_enclave_t *)eid, ciphertext_ptr, clength, + plaintext_copy, (uint32_t)plength)); + } + + jbyteArray plaintext = env->NewByteArray(plength); + env->SetByteArrayRegion(plaintext, 0, plength, (jbyte *)plaintext_copy); + + env->ReleaseByteArrayElements(ciphertext, (jbyte *)ciphertext_ptr, 0); + + delete[] plaintext_copy; + + return plaintext; +} + + JNIEXPORT jbyteArray JNICALL Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_Sample( JNIEnv *env, jobject obj, jlong eid, jbyteArray input_rows) { (void)obj; diff --git a/src/enclave/App/SGXEnclave.h b/src/enclave/App/SGXEnclave.h index 3fd2d58e2f..e2a28ff53b 100644 --- a/src/enclave/App/SGXEnclave.h +++ b/src/enclave/App/SGXEnclave.h @@ -21,6 +21,9 @@ JNIEXPORT jbyteArray JNICALL Java_edu_berkeley_cs_rise_opaque_execution_SGXEncla JNIEXPORT jbyteArray JNICALL Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_Encrypt( JNIEnv *, jobject, jlong, jbyteArray); +JNIEXPORT jbyteArray JNICALL Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_ReEncryptUser( + JNIEnv *, jobject, jlong, jbyteArray, jstring); + JNIEXPORT jbyteArray JNICALL Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_Decrypt( JNIEnv *, jobject, jlong, jbyteArray); @@ -80,6 +83,15 @@ Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_GenerateReport(JNIEnv *, j JNIEXPORT void JNICALL Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_FinishAttestation( JNIEnv *, jobject, jlong, jbyteArray); + JNIEXPORT jbyteArray JNICALL Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_GetPublicKey( + JNIEnv *, jobject, jlong); + + JNIEXPORT jbyteArray JNICALL Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_GetListEncrypted( + JNIEnv *, jobject, jlong, jbyteArray); + + JNIEXPORT jbyteArray JNICALL Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_FinishSharedKey( + JNIEnv *, jobject, jlong, jbyteArray); + #ifdef __cplusplus } #endif diff --git a/src/enclave/Enclave/Crypto.cpp b/src/enclave/Enclave/Crypto.cpp index fe7e240bf9..dbf7618955 100644 --- a/src/enclave/Enclave/Crypto.cpp +++ b/src/enclave/Enclave/Crypto.cpp @@ -1,9 +1,20 @@ +#include "Crypto.h" +#include "Random.h" + #include +#include +#include +#include #include "Crypto.h" #include "Random.h" #include "common.h" #include "util.h" +#include +#include + +// Set this number before creating the enclave +uint8_t num_clients = 1; /** * Symmetric key used to encrypt row data. This key is shared among the driver @@ -17,16 +28,48 @@ unsigned char shared_key[SGX_AESGCM_KEY_SIZE] = {0}; std::unique_ptr ks; +// may user name to public key +std::unordered_map> client_key_schedules; +std::unordered_map> client_keys; + void initKeySchedule() { + // Use shared key to init key schedule ks.reset(new KeySchedule(reinterpret_cast(shared_key), SGX_AESGCM_KEY_SIZE)); } +void initKeySchedule(char* username) { + std::string user(username); + std::unique_ptr user_ks; + unsigned char client_key[SGX_AESGCM_KEY_SIZE]; + + auto iter = client_keys.find(user); + // if (iter == client_keys.end()) { + // ocall_throw("No client key for user: %s", username); + // } else { + memcpy(client_key, (uint8_t*) iter->second.data(), SGX_AESGCM_KEY_SIZE); + // } + + user_ks.reset(new KeySchedule(reinterpret_cast(client_key), SGX_AESGCM_KEY_SIZE)); + client_key_schedules[user] = std::move(user_ks); +} + +void add_client_key(uint8_t *client_key_bytes, uint32_t client_key_size, char* username) { + if (client_key_size <= 0) { + throw std::runtime_error("Add client key failed: Invalid client key size"); + } + + std::vector user_private_key(client_key_bytes, client_key_bytes + client_key_size); + std::string user(username); + client_keys[user] = user_private_key; + + initKeySchedule(username); +} + void set_shared_key(uint8_t *shared_key_bytes, uint32_t shared_key_size) { if (shared_key_size <= 0) { throw std::runtime_error("Attempting to set a shared key with invalid key size."); } memcpy_s(shared_key, sizeof(shared_key), shared_key_bytes, shared_key_size); - initKeySchedule(); } @@ -46,13 +89,36 @@ void encrypt(uint8_t *plaintext, uint32_t plaintext_length, uint8_t *ciphertext) AesGcm cipher(ks.get(), reinterpret_cast(iv_ptr), SGX_AESGCM_IV_SIZE); cipher.encrypt(plaintext, plaintext_length, ciphertext_ptr, plaintext_length); memcpy(mac_ptr, cipher.tag().t, SGX_AESGCM_MAC_SIZE); + +} + +void encrypt_user(uint8_t *plaintext, uint32_t plaintext_length, uint8_t *ciphertext, const char * user_name) { + + for (auto& keypair : client_key_schedules) { + + if (strcmp(keypair.first.c_str(), user_name) == 0) { + uint8_t *iv_ptr = ciphertext; + uint8_t *ciphertext_ptr = ciphertext + SGX_AESGCM_IV_SIZE; + sgx_aes_gcm_128bit_tag_t *mac_ptr = + (sgx_aes_gcm_128bit_tag_t *)(ciphertext + SGX_AESGCM_IV_SIZE + plaintext_length); + mbedtls_read_rand(reinterpret_cast(iv_ptr), SGX_AESGCM_IV_SIZE); + + AesGcm cipher(keypair.second.get(), reinterpret_cast(iv_ptr), SGX_AESGCM_IV_SIZE); + cipher.encrypt(plaintext, plaintext_length, ciphertext_ptr, plaintext_length); + memcpy(mac_ptr, cipher.tag().t, SGX_AESGCM_MAC_SIZE); + return; + } + } + throw std::runtime_error("Couldn't find user and key"); } void decrypt(const uint8_t *ciphertext, uint32_t ciphertext_length, uint8_t *plaintext) { + if (!ks) { throw std::runtime_error("Cannot encrypt without a shared key. Ensure all " "enclaves have completed attestation."); } + uint32_t plaintext_length = dec_size(ciphertext_length); uint8_t *iv_ptr = (uint8_t *)ciphertext; @@ -63,7 +129,27 @@ void decrypt(const uint8_t *ciphertext, uint32_t ciphertext_length, uint8_t *pla AesGcm decipher(ks.get(), iv_ptr, SGX_AESGCM_IV_SIZE); decipher.decrypt(ciphertext_ptr, plaintext_length, plaintext, plaintext_length); if (memcmp(mac_ptr, decipher.tag().t, SGX_AESGCM_MAC_SIZE) != 0) { - printf("Decrypt: invalid mac\n"); + // Shared key doesn't work + // Perhaps we need to use a client key instead + int success = -1; + for (auto& keypair : client_key_schedules) { + std::vector print_key = client_keys[keypair.first]; + for(size_t i=0; i < print_key.size(); i++) + std::cout << print_key.at(i) << ' '; + std::cout << std::endl; + + AesGcm decipher(keypair.second.get(), iv_ptr, SGX_AESGCM_IV_SIZE); + decipher.decrypt(ciphertext_ptr, plaintext_length, plaintext, plaintext_length); + if (memcmp(mac_ptr, decipher.tag().t, SGX_AESGCM_MAC_SIZE) == 0) { + std::cout << "We found the proper key, of user " << keypair.first << std::endl; + success = 0; + break; + } + } + + if (success == -1) { + throw std::runtime_error("Couldn't decrypt -- proper key unknown\n"); + } } } diff --git a/src/enclave/Enclave/Crypto.h b/src/enclave/Enclave/Crypto.h index 821253711a..d37f7f0b5d 100644 --- a/src/enclave/Enclave/Crypto.h +++ b/src/enclave/Enclave/Crypto.h @@ -19,8 +19,12 @@ extern const sgx_ec256_public_t g_sp_pub_key; * Set the symmetric key used to encrypt row data using message 4 of the remote * attestation process. */ +void add_client_key(uint8_t *client_key_bytes, uint32_t client_key_size, char* username); + void set_shared_key(uint8_t *msg4, uint32_t msg4_size); +void add_client_key(uint8_t *client_key_bytes, uint32_t client_key_size, char* username); + /** * Encrypt the given plaintext using AES-GCM with a 128-bit key and write the * result to `ciphertext`. The encrypted data will be formatted as follows, @@ -35,6 +39,8 @@ void set_shared_key(uint8_t *msg4, uint32_t msg4_size); */ void encrypt(uint8_t *plaintext, uint32_t plaintext_length, uint8_t *ciphertext); +void encrypt_user(uint8_t *plaintext, uint32_t plaintext_length, uint8_t *ciphertext, const char * user_name); + /** * Decrypt the given ciphertext using AES-GCM with a 128-bit key and write the * result to `plaintext`. The encrypted data must be formatted as described in diff --git a/src/enclave/Enclave/Enclave.cpp b/src/enclave/Enclave/Enclave.cpp index ec9b6bcadb..713c76274f 100644 --- a/src/enclave/Enclave/Enclave.cpp +++ b/src/enclave/Enclave/Enclave.cpp @@ -1,9 +1,13 @@ #include "Enclave_t.h" +#include #include #include +#include "Random.h" + #include "Aggregate.h" +#include "Attestation.h" #include "BroadcastNestedLoopJoin.h" #include "Crypto.h" #include "Filter.h" @@ -22,6 +26,17 @@ #include #include +// needed for certificate +#include +#include +#include +#include +#include +#include + +// needed for openenclave evidences +#include + // This file contains definitions of the ecalls declared in Enclave.edl. Errors // originating within these ecalls are signaled by throwing a // std::runtime_error, which is caught at the top level of the ecall (i.e., @@ -46,6 +61,48 @@ void ecall_encrypt(uint8_t *plaintext, uint32_t plaintext_length, uint8_t *ciphe } } +void ecall_re_encrypt_user(uint8_t *ciphertext, uint32_t ciphertext_length, uint8_t *new_ciphertext, + uint32_t new_cipher_length, const char * user_name) { + // Guard against encrypting or overwriting enclave memory + assert(oe_is_outside_enclave(ciphertext, ciphertext_length) == 1); + assert(oe_is_outside_enclave(new_ciphertext, new_cipher_length) == 1); + __builtin_ia32_lfence(); + + (void) new_cipher_length; + + try { + // IV (12 bytes) + ciphertext + mac (16 bytes) + assert(ciphertext_length == new_cipher_length); + + uint32_t plength = ciphertext_length - SGX_AESGCM_IV_SIZE - SGX_AESGCM_MAC_SIZE; + uint8_t * plaintext = new uint8_t[plength]; + decrypt(ciphertext, ciphertext_length, plaintext); + encrypt_user(plaintext, plength, new_ciphertext, user_name); + } catch (const std::runtime_error &e) { + ocall_throw(e.what()); + } +} + +void ecall_decrypt(uint8_t *ciphertext, uint32_t cipher_length, uint8_t *plaintext, + uint32_t plaintext_length) { + + // Guard against decrypting or overwriting enclave memory + assert(oe_is_outside_enclave(plaintext, plaintext_length) == 1); + assert(oe_is_outside_enclave(ciphertext, cipher_length) == 1); + __builtin_ia32_lfence(); + + try { + // IV (12 bytes) + ciphertext + mac (16 bytes) + assert(cipher_length >= plaintext_length + SGX_AESGCM_IV_SIZE + SGX_AESGCM_MAC_SIZE); + (void)cipher_length; + (void)plaintext_length; + decrypt(ciphertext, cipher_length, plaintext); + } catch (const std::runtime_error &e) { + ocall_throw(e.what()); + } + +} + void ecall_project(uint8_t *condition, size_t condition_length, uint8_t *input_rows, size_t input_rows_length, uint8_t **output_rows, size_t *output_rows_length) { // Guard against operating on arbitrary enclave memory @@ -241,17 +298,41 @@ void ecall_limit_return_rows(uint64_t partition_id, uint8_t *limits, size_t limi static Crypto g_crypto; void ecall_finish_attestation(uint8_t *shared_key_msg_input, uint32_t shared_key_msg_size) { + (void) shared_key_msg_size; + try { + (void) shared_key_msg_size; oe_shared_key_msg_t *shared_key_msg = (oe_shared_key_msg_t *)shared_key_msg_input; uint8_t shared_key_plaintext[SGX_AESGCM_KEY_SIZE]; size_t shared_key_plaintext_size = sizeof(shared_key_plaintext); - bool ret = g_crypto.decrypt(shared_key_msg->shared_key_ciphertext, shared_key_msg_size, + bool ret = g_crypto.decrypt(shared_key_msg->shared_key_ciphertext, OE_SHARED_KEY_CIPHERTEXT_SIZE, shared_key_plaintext, &shared_key_plaintext_size); if (!ret) { ocall_throw("shared key decryption failed"); } - set_shared_key(shared_key_plaintext, shared_key_plaintext_size); + // Add verifySignatureFromCertificate from XGBoost + // Get name from certificate + unsigned char nameptr[50]; + size_t name_len; + int res; + mbedtls_x509_crt user_cert; + mbedtls_x509_crt_init(&user_cert); + if ((res = mbedtls_x509_crt_parse(&user_cert, (const unsigned char*) shared_key_msg->user_cert, shared_key_msg->user_cert_len)) != 0) { + // char tmp[50]; + // mbedtls_strerror(res, tmp, 50); + // std::cout << tmp << std::endl; + ocall_throw("Verification failed - could not read user certificate\n. mbedtls_x509_crt_parse returned"); + } + + mbedtls_x509_name subject_name = user_cert.subject; + mbedtls_asn1_buf name = subject_name.val; + strcpy((char*) nameptr, (const char*) name.p); + name_len = name.len; + std::string user_nam(nameptr, nameptr + name_len); + + add_client_key(shared_key_plaintext, shared_key_plaintext_size, (char*) user_nam.c_str()); + } catch (const std::runtime_error &e) { ocall_throw(e.what()); } @@ -316,3 +397,177 @@ void ecall_generate_report(uint8_t **report_msg_data, size_t *report_msg_data_si } oe_free_report(report); } + +//////////////////////////////////// Generate Shared Key Begin ////////////////////////////////////// + +static Attestation attestation(&g_crypto); + +void ecall_get_public_key(uint8_t **report_msg_data, + size_t* report_msg_data_size) { + +#ifndef SIMULATE + oe_uuid_t sgx_local_uuid = {OE_FORMAT_UUID_SGX_LOCAL_ATTESTATION}; + oe_uuid_t* format_id = &sgx_local_uuid; + + uint8_t* format_settings = NULL; + size_t format_settings_size = 0; + + if (!attestation.get_format_settings( + format_id, + &format_settings, + &format_settings_size)) { + ocall_throw("Unable to get enclave format settings"); + } +#endif + + uint8_t pem_public_key[512]; + size_t public_key_size = sizeof(pem_public_key); + uint8_t* evidence = nullptr; + size_t evidence_size = 0; + + g_crypto.retrieve_public_key(pem_public_key); + +#ifndef SIMULATE + if (attestation.generate_attestation_evidence( + format_id, + format_settings, + format_settings_size, + pem_public_key, + public_key_size, + &evidence, + &evidence_size) == false) { + ocall_throw("Unable to retrieve enclave evidence"); + } + + if (!attestation.attest_attestation_evidence(format_id, evidence, evidence_size, pem_public_key, public_key_size)) { + ocall_throw("Unable to verify FRESH attestation!"); + } +#endif + + // The report msg includes the public key, the size of the evidence, and the evidence itself + *report_msg_data_size = public_key_size + sizeof(evidence_size) + evidence_size; + *report_msg_data = (uint8_t*)oe_host_malloc(*report_msg_data_size); + + memcpy_s(*report_msg_data, public_key_size, pem_public_key, public_key_size); + memcpy_s(*report_msg_data + public_key_size, sizeof(size_t), &evidence_size, sizeof(evidence_size)); + + memcpy_s(*report_msg_data + public_key_size + sizeof(size_t), evidence_size, evidence, evidence_size); + +} + +void ecall_get_list_encrypted(uint8_t * pk_list, + uint32_t pk_list_size, + uint8_t * sk_list, + uint32_t sk_list_size) { + + // Guard against encrypting or overwriting enclave memory + assert(oe_is_outside_enclave(pk_list, pk_list_size) == 1); + assert(oe_is_outside_enclave(sk_list, sk_list_size) == 1); + __builtin_ia32_lfence(); + + (void) sk_list_size; + + try { + // Generate a random value used for key + // Size of shared key is 16 from ServiceProvider - LC_AESGCM_KEY_SIZE + // For now SGX_AESGCM_KEY_SIZE is also 16, so will just use that for now + + unsigned char secret_key[SGX_AESGCM_KEY_SIZE] = {0}; + mbedtls_read_rand(secret_key, SGX_AESGCM_KEY_SIZE); + + uint8_t public_key[OE_PUBLIC_KEY_SIZE] = {}; + uint8_t *pk_pointer = pk_list; + + unsigned char encrypted_sharedkey[OE_SHARED_KEY_CIPHERTEXT_SIZE]; + size_t encrypted_sharedkey_size = sizeof(encrypted_sharedkey); + + uint8_t *sk_pointer = sk_list; + + size_t evidence_size[1] = {}; + +#ifndef SIMULATE + oe_uuid_t sgx_local_uuid = {OE_FORMAT_UUID_SGX_LOCAL_ATTESTATION}; + oe_uuid_t* format_id = &sgx_local_uuid; + + uint8_t* format_settings = NULL; + size_t format_settings_size = 0; +#endif + + while (pk_pointer < pk_list + pk_list_size) { + +#ifndef SIMULATE + if (!attestation.get_format_settings( + format_id, + &format_settings, + &format_settings_size)) { + ocall_throw("Unable to get enclave format settings"); + } +#endif + + // Read public key, size of evidence, and evidence + memcpy_s(public_key, OE_PUBLIC_KEY_SIZE, pk_pointer, OE_PUBLIC_KEY_SIZE); + +#ifndef SIMULATE + memcpy_s(evidence_size, sizeof(evidence_size), pk_pointer + OE_PUBLIC_KEY_SIZE, sizeof(size_t)); + uint8_t evidence[evidence_size[0]] = {}; + memcpy_s(evidence, evidence_size[0], pk_pointer + OE_PUBLIC_KEY_SIZE + sizeof(size_t), evidence_size[0]); + + // Verify the provided public key is valid + if (!attestation.attest_attestation_evidence(format_id, evidence, evidence_size[0], public_key, sizeof(public_key))) { + std::cout << "get_list_encrypted - unable to verify attestation evidence" << std::endl; + ocall_throw("Unable to verify attestation evidence"); + } +#endif + + g_crypto.encrypt(public_key, + secret_key, + SGX_AESGCM_KEY_SIZE, + encrypted_sharedkey, + &encrypted_sharedkey_size); + memcpy_s(sk_pointer, OE_SHARED_KEY_CIPHERTEXT_SIZE, encrypted_sharedkey, OE_SHARED_KEY_CIPHERTEXT_SIZE); + + pk_pointer += OE_PUBLIC_KEY_SIZE + sizeof(size_t) + evidence_size[0]; + sk_pointer += OE_SHARED_KEY_CIPHERTEXT_SIZE; + } + } catch (const std::runtime_error &e) { + ocall_throw(e.what()); + } + +} + +void ecall_finish_shared_key(uint8_t *sk_list, + uint32_t sk_list_size, + uint8_t *sk, + uint32_t sk_size) { + + (void) sk; + (void) sk_size; + + uint8_t *sk_pointer = sk_list; + + uint8_t secret_key[SGX_AESGCM_KEY_SIZE] = {0}; + size_t sk_length = sizeof(secret_key); + assert(sk_length == sk_size); + + while (sk_pointer < sk_list + sk_list_size) { + uint8_t encrypted_sharedkey[OE_SHARED_KEY_CIPHERTEXT_SIZE]; + size_t encrypted_sharedkey_size = sizeof(encrypted_sharedkey); + + memcpy_s(encrypted_sharedkey, encrypted_sharedkey_size, sk_pointer, OE_SHARED_KEY_CIPHERTEXT_SIZE); + + try { + bool ret = g_crypto.decrypt(encrypted_sharedkey, encrypted_sharedkey_size, secret_key, &sk_length); + if (ret) {break;} // Decryption was successful to obtain secret key + } catch (const std::runtime_error &e) { + ocall_throw(e.what()); + } + + sk_pointer += OE_SHARED_KEY_CIPHERTEXT_SIZE; + } + + set_shared_key(secret_key, sk_size); + + memcpy(sk, secret_key, SGX_AESGCM_KEY_SIZE); +} + +//////////////////////////////////// Generate Shared Key End ////////////////////////////////////// diff --git a/src/enclave/Enclave/Enclave.edl b/src/enclave/Enclave/Enclave.edl index 1789ff2b64..be13e8454b 100644 --- a/src/enclave/Enclave/Enclave.edl +++ b/src/enclave/Enclave/Enclave.edl @@ -1,3 +1,4 @@ + // -*- mode: c++ -*- /* Enclave.edl - Top EDL file. */ @@ -23,6 +24,15 @@ enclave { [user_check] uint8_t *plaintext, uint32_t length, [user_check] uint8_t *ciphertext, uint32_t cipher_length); + public void ecall_re_encrypt_user( + [user_check] uint8_t *plaintext, uint32_t length, + [user_check] uint8_t *ciphertext, uint32_t cipher_length, + [in, string] const char * user_name); + + public void ecall_decrypt( + [user_check] uint8_t *ciphertext, uint32_t cipher_length, + [user_check] uint8_t *plaintext, uint32_t plain_length); + public void ecall_sample( [user_check] uint8_t *input_rows, size_t input_rows_length, [out] uint8_t **output_rows, [out] size_t *output_rows_length); @@ -90,6 +100,18 @@ enclave { public void ecall_finish_attestation( [in,size=msg4_size] uint8_t *msg4, uint32_t msg4_size); + + public void ecall_get_public_key( + [out] uint8_t** msg1, + [out] size_t* msg1_size); + + public void ecall_get_list_encrypted( + [user_check] uint8_t *pk_list, uint32_t length, + [user_check] uint8_t *sk_list, uint32_t cipher_length); + + public void ecall_finish_shared_key( + [in,size=msg4_size] uint8_t *msg4, uint32_t msg4_size, + [user_check] uint8_t *secret_key, uint32_t secret_length); }; untrusted { diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala index 41dfff4446..f45bf7f71c 100644 --- a/src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala +++ b/src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala @@ -19,6 +19,8 @@ package edu.berkeley.cs.rise.opaque import java.io.File import java.io.FileNotFoundException +import java.io.IOException + import java.nio.ByteBuffer import java.nio.ByteOrder import java.nio.file.{Files, Paths} @@ -36,9 +38,11 @@ import com.google.flatbuffers.FlatBufferBuilder import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.Dataset +import org.apache.spark.sql.DataFrame import org.apache.spark.sql.SQLContext import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.Add import org.apache.spark.sql.catalyst.expressions.Alias import org.apache.spark.sql.catalyst.expressions.And @@ -145,6 +149,7 @@ object Utils extends Logging { } private def jsonSerialize(x: Any): String = (x: @unchecked) match { + case x: Int => x.toString case x: Double => x.toString case x: Boolean => x.toString @@ -250,42 +255,108 @@ object Utils extends Logging { final val GCM_KEY_LENGTH = 32 final val GCM_TAG_LENGTH = 16 + // We do not trust the driver. Encryption and decryption done in enclave only. /** * Symmetric key used to encrypt row data. This key is securely sent to the enclaves if * attestation succeeds. For testing/benchmarking, we use a hardcoded key. For all other * cases, the driver SHOULD NOT be able to decrypt anything. */ - var sharedKey: Option[Array[Byte]] = None - - def encrypt(data: Array[Byte]): Array[Byte] = sharedKey match { - case Some(sharedKey) => - val random = SecureRandom.getInstance("SHA1PRNG") - val cipherKey = new SecretKeySpec(sharedKey, "AES") - val iv = new Array[Byte](GCM_IV_LENGTH) - random.nextBytes(iv) - val spec = new GCMParameterSpec(GCM_TAG_LENGTH * 8, iv) - val cipher = Cipher.getInstance("AES/GCM/NoPadding", "SunJCE") - cipher.init(Cipher.ENCRYPT_MODE, cipherKey, spec) - val cipherText = cipher.doFinal(data) - iv ++ cipherText - case None => - throw new OpaqueException("Cannot encrypt without sharedKey.") +// val sharedKey: Array[Byte] = Array.fill[Byte](GCM_KEY_LENGTH)(0) + val sharedKey: Array[Byte] = "01234567890123456789012345678901".getBytes + assert(sharedKey.size == GCM_KEY_LENGTH) + + def encrypt(data: Array[Byte]): Array[Byte] = { + + var numExecutors: Int = 1 + val sc = SparkSession.active.sparkContext + if (!sc.isLocal) { + numExecutors = sc.getConf.getInt("spark.executor.instances", -1) + } + + val rdd = sc.parallelize(Seq.fill(numExecutors) {()}, numExecutors) + + // Only one enclave needs to encrypt + val encryptedResults = rdd.context.parallelize(Array(data), 1) + .map { data => + // Only encrypt if enclave is attested and trusted + if (eid != 0L && attested) { + val enclave = new SGXEnclave() + enclave.Encrypt(eid, data) + } else { + throw new Exception("Enclaves not attested - will not encrypt") + } + }.first() + + return encryptedResults + } + + // Encryption with user string provided for key + def reEncryptUser(data: Array[Byte], user: String): Array[Byte] = { + + // Move encryption to the enclaves + var numExecutors: Int = 1 + val sc = SparkSession.active.sparkContext + if (!sc.isLocal) { + numExecutors = sc.getConf.getInt("spark.executor.instances", -1) + } + + val rdd = sc.parallelize(Seq.fill(numExecutors) {()}, numExecutors) + + // Only one enclave needs to encrypt + val encryptedResults = rdd.context.parallelize(Array(data), 1) + .map { data => + // Only encrypt if enclave is attested and trusted + if (eid != 0L && attested) { + val enclave = new SGXEnclave() + enclave.ReEncryptUser(eid, data, user) + } else { + throw new Exception("Enclaves not attested - will not encrypt") + } + }.first() + + return encryptedResults + } + + def decrypt(data: Array[Byte]): Array[Byte] = { + + var numExecutors: Int = 1 + val sc = SparkSession.active.sparkContext + if (!sc.isLocal) { + numExecutors = sc.getConf.getInt("spark.executor.instances", -1) + } + + val rdd = sc.parallelize(Seq.fill(numExecutors) {()}, numExecutors) + + // Only one enclave needs to encrypt + val encryptedResults = rdd.context.parallelize(Array(data), 1) + .map { data => + // Only encrypt if enclave is attested and trusted + if (eid != 0L && attested) { + val enclave = new SGXEnclave() + enclave.Decrypt(eid, data) + } else { + throw new Exception("Enclaves not attested - will not encrypt") + } + }.first() + + return encryptedResults } - def decrypt(data: Array[Byte]): Array[Byte] = sharedKey match { - case Some(sharedKey) => - val cipherKey = new SecretKeySpec(sharedKey, "AES") - val iv = data.take(GCM_IV_LENGTH) - val cipherText = data.drop(GCM_IV_LENGTH) - val cipher = Cipher.getInstance("AES/GCM/NoPadding", "SunJCE") - cipher.init(Cipher.DECRYPT_MODE, cipherKey, new GCMParameterSpec(GCM_TAG_LENGTH * 8, iv)) - cipher.doFinal(cipherText) - case None => - throw new OpaqueException("Cannot decrypt without sharedKey.") + // This function should never be called from the driver + // We add this function to make the Decrypt UDF work + def decryptWorker(data: Array[Byte]): Array[Byte] = { + // Only decrypt if enclave is attested and trusted + if (eid != 0L && attested) { + val enclave = new SGXEnclave() + return enclave.Decrypt(eid, data) + } else { + throw new Exception("Enclaves not attested - will not decrypt") + } } var eid = 0L var attested: Boolean = false + var laAttested: Boolean = false // Initialize accumulator variables for tracking the state of attestation var acc_registered: Boolean = false val numEnclaves: LongAccumulator = new LongAccumulator @@ -355,6 +426,17 @@ object Utils extends Logging { eid } + // Helper function for printing. + // @source https://alvinalexander.com/source-code/scala-how-to-convert-array-bytes-to-hex-string/ + def convertBytesToHex(bytes: Option[Array[Byte]]): String = { + val sb = new StringBuilder + for (b <- bytes.get) { + sb.append(String.format("%02x", Byte.box(b))) + } + sb.toString + } + + // TODO: Create new laAttested variable to track LA, so that it doesn't happen constantly def generateReport(): (Long, Option[Array[Byte]]) = { this.synchronized { // Only generate report if the enclave has already been started, but not yet attested @@ -384,6 +466,46 @@ object Utils extends Logging { } } + def getEvidence(): (Long, Option[Array[Byte]]) = { + this.synchronized { + // Only generate evidence if the enclave has already been started AND attested + if (eid != 0L && attested && !laAttested) { + val enclave = new SGXEnclave() + val msg1 = enclave.GetPublicKey(eid) + (eid, Option(msg1)) + } else { + (eid, None) + } + } + } + + def getListEncrypted(evidences: Array[Byte]): Array[Byte] = { + this.synchronized { + if (eid != 0L && attested && !laAttested) { + val enclave = new SGXEnclave() + enclave.GetListEncrypted(eid, evidences) + } else { + // Return empty array + Array[Byte]() + } + } + } + + def finishSharedKey(msg3s: Map[Long, Array[Byte]]): Unit = { + this.synchronized { + val enclave = new SGXEnclave() + if (msg3s.contains(eid) && attested && !laAttested) { + val msg3 = msg3s(eid) + enclave.FinishSharedKey(eid, msg3s(eid)) + laAttested = true + } else { + // Return array of 0s. Also throw error + print("Failure to set shared key") +// Array.fill[Byte](GCM_KEY_LENGTH)(0) + } + } + } + def cleanup(spark: SparkSession) { RA.stopThread() spark.stop() @@ -436,8 +558,12 @@ object Utils extends Logging { def force(ds: Dataset[_]): Unit = { val rdd: RDD[_] = ds.queryExecution.executedPlan match { - case p: OpaqueOperatorExec => p.executeBlocked() - case p => p.execute() + case p: OpaqueOperatorExec => { + p.executeBlocked() + } + case p => { + p.execute() + } } rdd.foreach(x => {}) } @@ -840,7 +966,7 @@ object Utils extends Logging { def decryptScalar(ciphertext: String): Any = { val ciphertext_bytes = Base64.getDecoder().decode(ciphertext); - val plaintext = decrypt(ciphertext_bytes) + val plaintext = decryptWorker(ciphertext_bytes) val rows = tuix.Rows.getRootAsRows(ByteBuffer.wrap(plaintext)) val row = rows.rows(0) val field = row.fieldValues(0) @@ -1000,6 +1126,84 @@ object Utils extends Logging { }).flatten } + /** + * Runs post-verification on dataframe. If integrity post-verification passes, + * then the function reads the encrypted files in './tmp' folder, appropriately re-encrypts them + * under the client keys and then prints the results out. + * + * 1. Read the encrypted file from './tmp' folder + * 2. Parse the provided bytes as EncryptedBlocks + * 3. Obtain the rows + * 4. Decrypt the rows with enclave key. Re-encrypt them with client key + * 5. Print out results + */ + def postVerifyAndPrint(df: DataFrame, user: String): Unit = { + + val ciphers = postVerifyAndReturn(df, user) + + for (cipher <- ciphers) { + println(Base64.getEncoder().encodeToString(cipher)) + } + } + + def postVerifyAndReturn(df: DataFrame, user: String): Seq[Array[Byte]] = { + + // Placeholder for post-verification section when merged + if (false) { + throw new Exception("Post verification failed") + } + + var reEncryptedCiphers = Seq[Array[Byte]](); + try { + + // 1 + val dir = new File("tmp"); + if (!dir.exists) { + return reEncryptedCiphers + } + + for (file <- dir.listFiles()) { + + val bytes = Files.readAllBytes(file.toPath()) + val buf = ByteBuffer.wrap(bytes) + + // 2 + val encryptedBlocks = tuix.EncryptedBlocks.getRootAsEncryptedBlocks(buf) + for (i <- 0 until encryptedBlocks.blocksLength) yield { + val encryptedBlock = encryptedBlocks.blocks(i) + + // 3 + val ciphertextBuf = encryptedBlock.encRowsAsByteBuffer + val ciphertext = new Array[Byte](ciphertextBuf.remaining) + ciphertextBuf.get(ciphertext) + + // 4 + val cipher = reEncryptUser(ciphertext, user) + + reEncryptedCiphers = reEncryptedCiphers :+ cipher + + } + } + } catch { + case x: FileNotFoundException => + { + throw new FileNotFoundException("Exception: File missing") + } + + case x: IOException => + { + throw new IOException("Input/output Exception") + } + } + + reEncryptedCiphers + } + + /** + * Given the dataframe name, reads the dataframe encrypted blocks and encrypts/decrypts them accordingly + */ + + // Helper function for printing. def treeFold[BaseType <: TreeNode[BaseType], B]( tree: BaseType )(op: (Seq[B], BaseType) => B): B = { diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/execution/SGXEnclave.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/execution/SGXEnclave.scala index 584b08f208..d859227d5c 100644 --- a/src/main/scala/edu/berkeley/cs/rise/opaque/execution/SGXEnclave.scala +++ b/src/main/scala/edu/berkeley/cs/rise/opaque/execution/SGXEnclave.scala @@ -29,6 +29,7 @@ class SGXEnclave extends java.io.Serializable { @native def Filter(eid: Long, condition: Array[Byte], input: Array[Byte]): Array[Byte] @native def Encrypt(eid: Long, plaintext: Array[Byte]): Array[Byte] + @native def ReEncryptUser(eid: Long, ciphertext: Array[Byte], username: String): Array[Byte] @native def Decrypt(eid: Long, ciphertext: Array[Byte]): Array[Byte] @native def Sample(eid: Long, input: Array[Byte]): Array[Byte] @@ -84,4 +85,9 @@ class SGXEnclave extends java.io.Serializable { // Remote attestation, enclave side @native def GenerateReport(eid: Long): Array[Byte] @native def FinishAttestation(eid: Long, attResultInput: Array[Byte]): Unit + + // "Local attestation" to determine shared key, enclave side + @native def GetPublicKey(eid: Long): Array[Byte] + @native def GetListEncrypted(eid: Long, publicKeyList: Array[Byte]): Array[Byte] + @native def FinishSharedKey(eid: Long, encryptedKeyList: Array[Byte]): Array[Byte] } diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/execution/operators.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/execution/operators.scala index 084551015e..599b29575b 100644 --- a/src/main/scala/edu/berkeley/cs/rise/opaque/execution/operators.scala +++ b/src/main/scala/edu/berkeley/cs/rise/opaque/execution/operators.scala @@ -19,6 +19,10 @@ package edu.berkeley.cs.rise.opaque.execution import scala.collection.mutable.ArrayBuffer +import java.io.{File, IOException, FileNotFoundException} +import java.nio.file.Files +import java.nio.file.Paths + import edu.berkeley.cs.rise.opaque.Utils import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -31,6 +35,9 @@ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.execution.SparkPlan +import edu.berkeley.cs.rise.opaque.tuix +import java.nio.ByteBuffer + trait LeafExecNode extends SparkPlan { override final def children: Seq[SparkPlan] = Nil override def producedAttributes: AttributeSet = outputSet @@ -182,23 +189,99 @@ trait OpaqueOperatorExec extends SparkPlan { executeBlocked().collect } + /** + * Write encrypted results to file and return an empty rdd + * + * 1. Creates an empty directory for the dataframe (or if the directory already exists, removes all contents + * inside the directory) + * 2. For each block in the result: + * 2a. Create a file + * 2b. Write the byte contents of the block to the file + */ override def executeCollect(): Array[InternalRow] = { - collectEncrypted().flatMap { block => - Utils.decryptBlockFlatbuffers(block) + + val blocks = collectEncrypted() + + try { + + // 1 + val dir = new File("tmp"); + if (!dir.exists) { + dir.mkdir() + } else { + for(file <- dir.listFiles()) { + file.delete() + } + } + + // 2 + var i = 0 + for (block <- blocks) { + // 2a & b + val f = "tmp/tmp_" + i.toString + Files.write(Paths.get(f), block.bytes) + + // Increment i for next block + i += 1 + } + + } catch { + case x: FileNotFoundException => + { + throw new FileNotFoundException("Exception: File missing") + } + + case x: IOException => + { + throw new IOException("Input/output Exception") + } } + + sqlContext.sparkContext.emptyRDD.collect() } + /* + * executeTake can only take in the closest block increment + * This is because encrypted blocks must be saved to a file in block increments + * Any additional parsing must be done externally. + */ override def executeTake(n: Int): Array[InternalRow] = { + + + // Prepare directory for storing encrypted blocks + try { + val dir = new File("tmp"); + if (!dir.exists) { + dir.mkdir() + } else { + for(file <- dir.listFiles()) { + file.delete() + } + } + } catch { + case x: FileNotFoundException => + { + throw new FileNotFoundException("Exception: File missing") + } + + case x: IOException => + { + throw new IOException("Input/output Exception") + } + } + + // Original code for parsing through blocks if (n == 0) { return new Array[InternalRow](0) } val childRDD = executeBlocked() - val buf = new ArrayBuffer[InternalRow] + var buf = 0 val totalParts = childRDD.partitions.length var partsScanned = 0 - while (buf.size < n && partsScanned < totalParts) { + var i = 0 + while (buf < n && partsScanned < totalParts) { // The number of partitions to try in this iteration. It is ok for this number to be // greater than totalParts because we actually cap it at totalParts in runJob. var numPartsToTry = 1L @@ -206,10 +289,10 @@ trait OpaqueOperatorExec extends SparkPlan { // If we didn't find any rows after the first iteration, just try all partitions next. // Otherwise, interpolate the number of partitions we need to try, but overestimate it // by 50%. - if (buf.size == 0) { + if (buf == 0) { numPartsToTry = totalParts - 1 } else { - numPartsToTry = (1.5 * n * partsScanned / buf.size).toInt + numPartsToTry = (1.5 * n * partsScanned / buf).toInt } } numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions @@ -220,19 +303,26 @@ trait OpaqueOperatorExec extends SparkPlan { sc.runJob(childRDD, (it: Iterator[Block]) => if (it.hasNext) Some(it.next()) else None, p) res.foreach { - case Some(block) => - buf ++= Utils.decryptBlockFlatbuffers(block) + case Some(block) => { + + val blockBuffer = ByteBuffer.wrap(block.bytes) + val encryptedBlocks = tuix.EncryptedBlocks.getRootAsEncryptedBlocks(blockBuffer) + buf = buf + encryptedBlocks.blocksLength + + val f = "tmp/tmp_" + i.toString + Files.write(Paths.get(f), block.bytes) + + // Increment i for next block + i += 1 + + } case None => } partsScanned += p.size } - if (buf.size > n) { - buf.take(n).toArray - } else { - buf.toArray - } + sqlContext.sparkContext.emptyRDD.collect() } } @@ -265,6 +355,7 @@ case class EncryptedFilterExec(condition: Expression, child: SparkPlan) override def output: Seq[Attribute] = child.output override def executeBlocked(): RDD[Block] = { + val conditionSer = Utils.serializeFilterExpression(condition, child.output) val childRDD = child.asInstanceOf[OpaqueOperatorExec].executeBlocked() applyLoggingLevel(childRDD) { childRDD =>