// 2025-09-11 10:32:09 -04:00
//
// 342 lines
// 9.6 KiB
// Rust

#[macro_use]
extern crate rocket;
extern crate bls12_381;
extern crate rand;
use crate::core::EncryptedMessage;
use bls12_381::pairing;
use bls12_381::G1Affine;
use bls12_381::G2Affine;
use bls12_381::Gt;
use bls12_381::Scalar;
use rocket::request;
use rocket::serde::json::Json;
use rocket::Request;
use rocket_okapi::settings::UrlObject;
use rocket_okapi::{openapi, openapi_get_routes, rapidoc::*, swagger_ui::*};
use std::time::Duration;
use ophe::core;
use ophe::core::{
PublicParameters, RatelimiterRequest, RatelimiterResponse, RatelimiterRotateRequest,
RatelimiterRotateResponse,
};
use ophe::shamir;
use ophe::utils::random_scalar;
use futures::future::join_all;
use rocket::State;
use rocket_okapi::okapi::schemars;
use rocket_okapi::okapi::schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_with::serde_as;
use rocket::http::Status;
/// Connection state for one cryptoservice (ratelimiter) instance.
#[derive(Clone)]
struct RlState {
    /// Public parameters returned by the instance's `/set_key` endpoint at startup.
    pp: PublicParameters,
    /// Base URL of the instance (e.g. `http://localhost:8081`); endpoint paths are appended to it.
    url: String,
    /// HTTP client used for every request sent to this instance.
    client: reqwest::Client,
    /// Per-instance scalar; `main` initialises it to `Scalar::zero()`.
    s: Scalar,
}
/// Application state handed to Rocket with `.manage` and read by every request handler.
struct OpheState {
    /// One entry per cryptoservice, in the order of the configured URLs.
    pps: Vec<RlState>,
    /// Number of cryptoservices.
    n: i64,
    /// Threshold passed to Shamir share generation (`main` sets it equal to `n`).
    t: i64,
    /// Ratelimiter public key recovered from the cryptoservices' public keys at startup.
    pk: Gt,
    /// Server-side secret scalar passed to the encrypt/decrypt finish steps.
    secret_key: Scalar,
}
// JSON body of `POST /encrypt`.
#[derive(Serialize, Deserialize, JsonSchema)]
struct EncryptRequest {
    // Credentials handed to `core::phe_init`.
    username: String,
    password: String,
    /// # data
    /// data to be encrypted
    // Must be at most 64 bytes once UTF-8 encoded; `encrypt` rejects longer input.
    data: String,
}
// JSON body of `POST /decrypt`: credentials (presumably the same ones used for
// `/encrypt`) plus the ciphertext to decrypt.
// NOTE(review): no field carries a `#[serde_as(as = ...)]` attribute, so `#[serde_as]`
// appears to have no effect here - confirm before removing.
#[serde_as]
#[derive(Serialize, Deserialize, JsonSchema)]
pub struct DecryptRequest {
    username: String,
    password: String,
    // Ciphertext as returned by `/encrypt`; its `n` field is passed to `core::phe_init_decrypt`.
    ciphertext: EncryptedMessage,
}
// JSON body of `POST /rotate`; the handler currently returns this ciphertext unchanged.
// NOTE(review): as with `DecryptRequest`, `#[serde_as]` has no `serde_as(as = ...)` field
// attributes to act on.
#[serde_as]
#[derive(Serialize, Deserialize, JsonSchema)]
pub struct RotateRequest {
    ciphertext: EncryptedMessage,
}
// Catcher for HTTP 422 responses (presumably produced when a JSON request body cannot be
// deserialized into the handler's request type); replies with a short plain-text message.
#[catch(422)]
fn serialize_failed(_req: &Request) -> String {
    // A literal needs no formatting machinery (clippy::useless_format).
    "Malformed Request".to_string()
}
/// Entry point.
///
/// 1. Generates a ratelimiter key and a server key.
/// 2. Distributes Shamir shares of the ratelimiter key to every cryptoservice via `/set_key`.
/// 3. Checks that the public key recovered from their replies matches the generated one.
/// 4. Launches the Rocket server with the API, Swagger UI and RapiDoc routes.
///
/// # Panics
///
/// Panics if a cryptoservice cannot be reached, answers with a non-success status or an
/// undecodable body, if not every configured cryptoservice was registered, or if the
/// recovered public key does not match.
#[rocket::main]
async fn main() {
    // Base URLs of the cryptoservice instances. More instances (e.g.
    // "http://localhost:9002", "http://localhost:9003") can be appended here.
    let cryptoservice_urls = vec!["http://localhost:8081"];
    let n = cryptoservice_urls.len();
    // Threshold equals n: every cryptoservice's share is required.
    let t = n;

    // Ratelimiter key; only its Shamir shares are sent out.
    let rl_key = ophe::utils::random_scalar();
    // Server key, kept locally and used by the encrypt/decrypt finish steps.
    let s_key = ophe::utils::random_scalar();
    // Public key the cryptoservices' shares must reconstruct: pairing of the generators scaled by rl_key.
    let rl_public_key = pairing(&G1Affine::generator(), &G2Affine::generator()) * rl_key;
    println!("Generated public key: {:?}", rl_public_key);

    let keys = ophe::shamir::gen_shares_scalar(rl_key, n as i64, t as i64);
    let pps = keys
        .iter()
        .zip(cryptoservice_urls)
        .map(|(k, url)| async move {
            let client = reqwest::ClientBuilder::new()
                .tcp_keepalive(Some(Duration::from_secs(60)))
                .build()
                .expect("failed to build HTTP client");
            let set_key_request = core::SetKeyHelper { key: k.clone() };
            let res = client
                .post(&format!("{}/set_key", url))
                .json(&set_key_request)
                .send()
                .await
                // Treat 4xx/5xx replies as failures instead of trying to parse their bodies.
                .and_then(|response| response.error_for_status())
                .unwrap_or_else(|e| panic!("Cryptoservice not reachable: {}", e))
                .text()
                .await
                .unwrap_or_else(|e| panic!("Cryptoservice not reachable: {}", e));
            // A parse failure means the service answered with something unexpected,
            // not that it is unreachable.
            let pp: PublicParameters = serde_json::from_str(&res).unwrap_or_else(|e| {
                panic!("Invalid cryptoservice response from {}: {}", url, e)
            });
            RlState {
                pp,
                url: url.to_string(),
                client,
                s: Scalar::zero(),
            }
        })
        .collect::<Vec<_>>();
    let pps = join_all(pps).await;
    // `zip` stops at the shorter side; make sure every configured service got a share.
    assert_eq!(pps.len(), n, "Not every cryptoservice received a key share");

    let pp_keys = pps
        .iter()
        .map(|x| x.pp.ratelimiter_public_key)
        .collect::<Vec<_>>();
    let pp_key = ophe::shamir::recover_shares(&pp_keys, n as i64);
    assert_eq!(pp_key, rl_public_key, "Public keys do not match");

    let o_state = OpheState {
        pps,
        n: n as i64,
        t: t as i64,
        pk: pp_key,
        secret_key: s_key,
    };
    println!("Received public parameters from cryptoservice");

    let launch_result = rocket::build()
        .mount("/", openapi_get_routes![encrypt, decrypt, rotate])
        .manage(o_state)
        .mount(
            "/swagger-ui/",
            make_swagger_ui(&SwaggerUIConfig {
                url: "../openapi.json".to_owned(),
                ..Default::default()
            }),
        )
        .mount(
            "/rapidoc/",
            make_rapidoc(&RapiDocConfig {
                general: GeneralConfig {
                    spec_urls: vec![UrlObject::new("General", "../openapi.json")],
                    ..Default::default()
                },
                hide_show: HideShowConfig {
                    allow_spec_url_load: false,
                    allow_spec_file_load: false,
                    ..Default::default()
                },
                ..Default::default()
            }),
        )
        .register("/", catchers![serialize_failed])
        .launch()
        .await;
    match launch_result {
        Ok(_) => println!("Rocket shut down gracefully."),
        Err(err) => {
            eprintln!("Rocket had an error: {}", err);
            // Report the failure to the process supervisor instead of exiting with 0.
            std::process::exit(1);
        }
    }
}
// POST /encrypt: encrypts `data` for the given credentials, using a help response from
// every cryptoservice, and returns the resulting `EncryptedMessage`.
//
// Errors: 400 "Data too long." if `data` exceeds MAX_DATA_LEN bytes; 500 with an
// "Encryption failed: ..." message if a cryptoservice request or the final encryption
// step fails.
#[openapi()]
#[post("/encrypt", format = "json", data = "<request>")]
async fn encrypt(
    request: Json<EncryptRequest>,
    o_state: &State<OpheState>,
) -> Result<Json<EncryptedMessage>, (Status, Json<String>)> {
    // Plaintexts are zero-padded to exactly this many bytes before encryption.
    const MAX_DATA_LEN: usize = 64;

    // Reject oversized input before doing any cryptographic work or network calls.
    if request.data.len() > MAX_DATA_LEN {
        return Err((Status::BadRequest, Json("Data too long.".to_string())));
    }
    let mut msg = request.data.as_bytes().to_vec();
    // NOTE: /decrypt trims trailing NUL characters after decryption, so data that itself
    // ends in '\0' does not round-trip exactly.
    msg.resize(MAX_DATA_LEN, 0);

    let (ss, request1) = core::phe_init(&request.username, &request.password);
    let responses = get_ratelimiter_reponses(&request1, &o_state.pps)
        .await
        .map_err(|x| {
            (
                Status::InternalServerError,
                Json("Encryption failed: ".to_string() + &x),
            )
        })?;
    let pps = o_state.pps.iter().map(|x| x.pp).collect::<Vec<_>>();
    let ciphertext =
        core::phe_enc_finish_t(&msg, &pps, &responses, &ss, o_state.n, &o_state.secret_key)
            .map_err(|x| {
                (
                    Status::InternalServerError,
                    Json("Encryption failed: ".to_string() + &x),
                )
            })?;
    Ok(Json(ciphertext))
}
#[openapi()]
#[post("/decrypt", format = "json", data = "<request>")]
async fn decrypt(
request: Json<DecryptRequest>,
o_state: &State<OpheState>,
) -> Result<Json<String>, (Status, Json<String>)> {
let (ss, request1) =
core::phe_init_decrypt(&request.username, &request.password, &request.ciphertext.n);
let responses = get_ratelimiter_reponses(&request1, &o_state.pps)
.await
.map_err(|x| {
(
Status::InternalServerError,
Json("Decryption failed: ".to_string() + &x),
)
})?;
let pps = o_state.pps.iter().map(|x| x.pp).collect::<Vec<_>>();
let res = core::phe_dec_finish_t(
&request.ciphertext,
&pps,
&responses,
&ss,
o_state.n,
&o_state.secret_key,
)
.map_err(|x| {
(
Status::InternalServerError,
Json("Decryption failed: ".to_string() + &x),
)
})?;
Ok(Json(
String::from_utf8(res)
.map_err(|_x| {
(
Status::InternalServerError,
Json("Decryption failed. Utf8".to_string()),
)
})?
.trim_end_matches(char::from(0))
.to_string(),
))
}
// POST /rotate: sends every cryptoservice a fresh share for /rotate_key and returns the
// submitted ciphertext.
//
// Errors: 500 "Rotation failed: ..." if any cryptoservice request fails or not every
// cryptoservice answered.
#[openapi()]
#[post("/rotate", format = "json", data = "<request>")]
async fn rotate(
    request: Json<RotateRequest>,
    o_state: &State<OpheState>,
) -> Result<Json<EncryptedMessage>, (Status, Json<String>)> {
    // Fresh Shamir shares of zero, one per cryptoservice.
    // NOTE(review): presumably each cryptoservice adds its share to its key share. That
    // re-randomises the shares while keeping the combined ratelimiter key, which is why
    // the ciphertext is returned unchanged. Confirm against the /rotate_key handler. If
    // only some services apply their share, the shares may no longer be consistent.
    let shares = shamir::gen_shares_scalar(Scalar::zero(), o_state.n, o_state.t);
    let responses = get_ratelimiter_rotate_reponses(&shares, &o_state.pps)
        .await
        .map_err(|x| {
            (
                Status::InternalServerError,
                Json("Rotation failed: ".to_string() + &x),
            )
        })?;
    // Guard against a helper that silently drops failed requests: every service must answer.
    if responses.len() != o_state.pps.len() {
        return Err((
            Status::InternalServerError,
            Json("Rotation failed: missing cryptoservice responses.".to_string()),
        ));
    }
    // TODO(review): the server key (`secret_key`) is not rotated and the rotate responses are
    // not otherwise used; a previously generated but unused `random_scalar()` was removed.
    Ok(Json(request.into_inner().ciphertext))
}
/// Sends `request` to the `/phe_help` endpoint of every cryptoservice concurrently.
///
/// Each cryptoservice replies with a JSON-encoded `Result<RatelimiterResponse, String>`.
/// On success the returned vector holds exactly one response per entry of `urls`, in the
/// same order, so it can be paired positionally with the public parameters.
///
/// # Errors
///
/// Returns the first error, in `urls` order, if any cryptoservice is unreachable, sends an
/// undecodable body, or reports an error itself. The previous `.flatten()` silently
/// discarded the first two kinds, yielding a shorter, misaligned `Ok` vector.
async fn get_ratelimiter_reponses(
    request: &RatelimiterRequest,
    urls: &[RlState],
) -> Result<Vec<RatelimiterResponse>, String> {
    let responses = urls.iter().map(|x| async move {
        let res = x
            .client
            .post(&format!("{}/phe_help", x.url))
            .json(request)
            .send()
            .await
            .map_err(|_| "Cryptoserice unreachable.".to_string())?
            .text()
            .await
            .map_err(|_| "Cryptoserice unreachable.".to_string())?;
        // The body is itself a serialized `Result` carrying the service's own outcome.
        let rs: Result<RatelimiterResponse, String> = serde_json::from_str(&res)
            .map_err(|_| "Invalid cryptoservice response.".to_string())?;
        Ok::<_, String>(rs)
    });
    join_all(responses)
        .await
        .into_iter()
        // Merge transport/decoding errors with the service-reported error.
        .map(|outcome| outcome.and_then(|service_result| service_result))
        .collect()
}
/// Sends `shares[i]` to the `/rotate_key` endpoint of `urls[i]`, for all `i`, concurrently.
///
/// Each cryptoservice replies with a JSON-encoded `Result<RatelimiterRotateResponse, String>`.
/// On success the returned vector holds one response per cryptoservice, in `urls` order.
///
/// # Errors
///
/// Returns `Err` in any of these cases:
/// - `shares` and `urls` differ in length (previously this panicked or ignored extra shares);
/// - any cryptoservice is unreachable or sends an undecodable body (previously dropped
///   silently by `.flatten()`);
/// - any cryptoservice reports an error itself.
async fn get_ratelimiter_rotate_reponses(
    shares: &[Scalar],
    urls: &[RlState],
) -> Result<Vec<RatelimiterRotateResponse>, String> {
    if shares.len() != urls.len() {
        return Err(format!(
            "Expected {} rotation shares, got {}.",
            urls.len(),
            shares.len()
        ));
    }
    let responses = urls.iter().zip(shares).map(|(x, share)| async move {
        let request = RatelimiterRotateRequest { s: *share };
        let res = x
            .client
            .post(&format!("{}/rotate_key", x.url))
            .json(&request)
            .send()
            .await
            .map_err(|_| "Cryptoserice unreachable.".to_string())?
            .text()
            .await
            .map_err(|_| "Cryptoserice unreachable.".to_string())?;
        // The body is itself a serialized `Result` carrying the service's own outcome.
        let rs: Result<RatelimiterRotateResponse, String> = serde_json::from_str(&res)
            .map_err(|_| "Invalid cryptoservice response.".to_string())?;
        Ok::<_, String>(rs)
    });
    join_all(responses)
        .await
        .into_iter()
        // Merge transport/decoding errors with the service-reported error.
        .map(|outcome| outcome.and_then(|service_result| service_result))
        .collect()
}