#[macro_use] extern crate rocket; extern crate bls12_381; extern crate rand; use crate::core::EncryptedMessage; use bls12_381::pairing; use bls12_381::G1Affine; use bls12_381::G2Affine; use bls12_381::Gt; use bls12_381::Scalar; use rocket::request; use rocket::serde::json::Json; use rocket::Request; use rocket_okapi::settings::UrlObject; use rocket_okapi::{openapi, openapi_get_routes, rapidoc::*, swagger_ui::*}; use std::time::Duration; use ophe::core; use ophe::core::{ PublicParameters, RatelimiterRequest, RatelimiterResponse, RatelimiterRotateRequest, RatelimiterRotateResponse, }; use ophe::shamir; use ophe::utils::random_scalar; use futures::future::join_all; use rocket::State; use rocket_okapi::okapi::schemars; use rocket_okapi::okapi::schemars::JsonSchema; use serde::{Deserialize, Serialize}; use serde_with::serde_as; use rocket::http::Status; #[derive(Clone)] struct RlState { pp: PublicParameters, url: String, client: reqwest::Client, s: Scalar, } struct OpheState { pps: Vec, n: i64, t: i64, pk: Gt, secret_key: Scalar, } #[derive(Serialize, Deserialize, JsonSchema)] struct EncryptRequest { username: String, password: String, /// # data /// data to be encrypted data: String, } #[serde_as] #[derive(Serialize, Deserialize, JsonSchema)] pub struct DecryptRequest { username: String, password: String, ciphertext: EncryptedMessage, } #[serde_as] #[derive(Serialize, Deserialize, JsonSchema)] pub struct RotateRequest { ciphertext: EncryptedMessage, } #[catch(422)] fn serialize_failed(_req: &Request) -> String { format!("Malformed Request") } #[rocket::main] async fn main() { let cryptoservice_urls = vec!["http://localhost:8081"]; //,"http://localhost:9002","http://localhost:9003"]; let n = cryptoservice_urls.len(); let t = n; let rl_key = ophe::utils::random_scalar(); // we now also have server keys let s_key = ophe::utils::random_scalar(); let rl_public_key = pairing(&G1Affine::generator(), &G2Affine::generator()) * rl_key; println!("Generated public key: {:?}", rl_public_key); let keys = ophe::shamir::gen_shares_scalar(rl_key, n as i64, t as i64); let pps = keys .iter() .zip(cryptoservice_urls) .map(|(k, url)| async move { let client = reqwest::ClientBuilder::new() .tcp_keepalive(Some(Duration::from_secs(60))) .build() .unwrap(); let set_key_request = core::SetKeyHelper { key: k.clone() }; let res = client .post(&format!("{}/set_key", url)) .json(&set_key_request) .send() .await .map_err(|x| format!("Cryptoservice not reachable: {}", x)) .unwrap() .text() .await .unwrap(); let pp: PublicParameters = serde_json::from_str(&res) .map_err(|x| format!("Cryptoservice not reachable: {}", x)) .unwrap(); RlState { pp, url: url.to_string(), client, s: Scalar::zero(), } }) .collect::>(); let pps = join_all(pps).await; let pp_keys = pps .iter() .map(|x| x.pp.ratelimiter_public_key) .collect::>(); let pp_key = ophe::shamir::recover_shares(&pp_keys, n as i64); assert_eq!(pp_key, rl_public_key, "Public keys do not match"); let o_state = OpheState { pps, n: n as i64, t: t as i64, pk: pp_key, secret_key: s_key, }; println!("Received public parameters from crytoservice"); let launch_result = rocket::build() .mount("/", openapi_get_routes![encrypt, decrypt, rotate]) .manage(o_state) .mount( "/swagger-ui/", make_swagger_ui(&SwaggerUIConfig { url: "../openapi.json".to_owned(), ..Default::default() }), ) .mount( "/rapidoc/", make_rapidoc(&RapiDocConfig { general: GeneralConfig { spec_urls: vec![UrlObject::new("General", "../openapi.json")], ..Default::default() }, hide_show: HideShowConfig { allow_spec_url_load: false, allow_spec_file_load: false, ..Default::default() }, ..Default::default() }), ) .register("/", catchers![serialize_failed]) .launch() .await; match launch_result { Ok(_) => println!("Rocket shut down gracefully."), Err(err) => println!("Rocket had an error: {}", err), }; } #[openapi()] #[post("/encrypt", format = "json", data = "")] async fn encrypt( request: Json, o_state: &State, ) -> Result, (Status, Json)> { let (ss, request1) = core::phe_init(&request.username, &request.password); let mut msg = request.data.clone().into_bytes(); if msg.len() > 64 { return Err((Status::BadRequest, Json("Data too long.".to_string()))); } // pad to 64 bytes msg.resize(64, 0); let pps = o_state.pps.iter().map(|x| x.pp).collect::>(); let responses = get_ratelimiter_reponses(&request1, &o_state.pps) .await .map_err(|x| { ( Status::InternalServerError, Json("Decryption failed: ".to_string() + &x), ) })?; let ciphertext = core::phe_enc_finish_t(&msg, &pps, &responses, &ss, o_state.n, &o_state.secret_key) .map_err(|x| { ( Status::InternalServerError, Json("Decryption failed.".to_string() + &x), ) })?; Ok(Json(ciphertext)) } #[openapi()] #[post("/decrypt", format = "json", data = "")] async fn decrypt( request: Json, o_state: &State, ) -> Result, (Status, Json)> { let (ss, request1) = core::phe_init_decrypt(&request.username, &request.password, &request.ciphertext.n); let responses = get_ratelimiter_reponses(&request1, &o_state.pps) .await .map_err(|x| { ( Status::InternalServerError, Json("Decryption failed: ".to_string() + &x), ) })?; let pps = o_state.pps.iter().map(|x| x.pp).collect::>(); let res = core::phe_dec_finish_t( &request.ciphertext, &pps, &responses, &ss, o_state.n, &o_state.secret_key, ) .map_err(|x| { ( Status::InternalServerError, Json("Decryption failed: ".to_string() + &x), ) })?; Ok(Json( String::from_utf8(res) .map_err(|_x| { ( Status::InternalServerError, Json("Decryption failed. Utf8".to_string()), ) })? .trim_end_matches(char::from(0)) .to_string(), )) } #[openapi()] #[post("/rotate", format = "json", data = "")] async fn rotate( request: Json, o_state: &State, ) -> Result, (Status, Json)> { let sk_new = random_scalar(); let shares = shamir::gen_shares_scalar(Scalar::zero(), o_state.n, o_state.t); let ct_new = request.ciphertext.clone(); let responses = get_ratelimiter_rotate_reponses(&shares, &o_state.pps) .await .map_err(|x| { ( Status::InternalServerError, Json("Rotation failed: ".to_string() + &x), ) })?; Ok(Json(ct_new)) } async fn get_ratelimiter_reponses( request: &RatelimiterRequest, urls: &Vec, ) -> Result, String> { let responses = urls.iter().map(|x| async move { let res = x .client .post(&format!("{}/phe_help", x.url)) .json(&request) .send() .await .map_err(|_x| format!("Cryptoserice unreachable."))? .text() .await .map_err(|_x| format!("Cryptoserice unreachable."))?; let rs = serde_json::from_str(&res).map_err(|_x| format!("Invalid cryptoservice response."))?; Ok::<_, String>(rs) }); join_all(responses).await.into_iter().flatten().collect() } async fn get_ratelimiter_rotate_reponses( shares: &Vec, urls: &Vec, ) -> Result, String> { // insert shares to ratelimiter state let mut requests = Vec::new(); for (index, item) in urls.iter().enumerate() { let mut tmp_state: RlState = item.clone(); tmp_state.s = shares[index]; requests.push(tmp_state); } let responses = requests.iter().map(|x| async move { let request = RatelimiterRotateRequest { s: x.s }; let res = x .client .post(&format!("{}/rotate_key", x.url)) .json(&request) .send() .await .map_err(|_x| format!("Cryptoserice unreachable."))? .text() .await .map_err(|_x| format!("Cryptoserice unreachable."))?; let rs = serde_json::from_str(&res).map_err(|_x| format!("Invalid cryptoservice response."))?; Ok::<_, String>(rs) }); join_all(responses).await.into_iter().flatten().collect() }