Skip to content

Commit cd14087

Browse files
committed
Clean up more
Change-Id: I419e86af4a04a4cdbc93c0bb897ecbc1612d25dd
1 parent 84f78cd commit cd14087

9 files changed

Lines changed: 45 additions & 204 deletions

File tree

oak_private_memory/BUILD

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -133,18 +133,6 @@ rust_library(
133133
],
134134
)
135135

136-
rust_library(
137-
name = "session",
138-
srcs = ["src/session.rs"],
139-
deps = [
140-
"@oak//oak_proto_rust",
141-
"@oak//oak_session",
142-
"@oak//oak_session/tls/rust:oak_session_tls",
143-
"@oak_crates_index//:anyhow",
144-
"@oak_crates_index//:prost",
145-
],
146-
)
147-
148136
rust_library(
149137
name = "session_binder",
150138
srcs = ["src/session_binder.rs"],
@@ -176,7 +164,6 @@ rust_library(
176164
":external_db_client",
177165
":log",
178166
":metrics",
179-
":session",
180167
":session_binder",
181168
"//app",
182169
"//database",

oak_private_memory/app/BUILD

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ rust_library(
4040
"//:external_db_client",
4141
"//:log",
4242
"//:metrics",
43-
"//:session",
4443
"//database",
4544
"//proto:sealed_memory_rust_proto",
4645
"//proto/grpc:sealed_memory_grpc_proto",

oak_private_memory/app/service.rs

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ use sealed_memory_grpc_proto::oak::private_memory::sealed_memory_service_server:
3030
SealedMemoryService, SealedMemoryServiceServer,
3131
};
3232
use sealed_memory_rust_proto::{oak::private_memory::TlsSessionFrame, prelude::v1::*};
33-
use session::TlsEncryptedSession;
3433
use tokio::{net::TcpListener, sync::mpsc};
3534
use tokio_stream::{Stream, StreamExt, wrappers::TcpListenerStream};
3635

@@ -200,7 +199,7 @@ impl OakSessionHandler {
200199
/// data.
201200
struct TlsSessionHandler {
202201
metrics: Arc<metrics::Metrics>,
203-
session: TlsEncryptedSession,
202+
session: oak_session_tls::OakSessionTls,
204203
application_handler: SealedMemorySessionHandler,
205204
}
206205

@@ -214,7 +213,7 @@ impl TlsSessionHandler {
214213
) -> Self {
215214
Self {
216215
metrics: metrics.clone(),
217-
session: TlsEncryptedSession::new(tls_session),
216+
session: tls_session,
218217
application_handler: SealedMemorySessionHandler::new(
219218
metrics.clone(),
220219
persistence_tx.clone(),
@@ -229,10 +228,9 @@ impl TlsSessionHandler {
229228
pub async fn handle_app_request(&mut self, encrypted_request: &[u8]) -> tonic::Result<Vec<u8>> {
230229
self.metrics.inc_requests(RequestMetricName::total());
231230

232-
let decrypted_request = self
233-
.session
234-
.decrypt(encrypted_request)
235-
.into_invalid_argument("failed to decrypt TLS request")?;
231+
let decrypted_request = self.session.decrypt(encrypted_request).map_err(|e| {
232+
tonic::Status::invalid_argument(format!("failed to decrypt TLS request: {e}"))
233+
})?;
236234

237235
if decrypted_request.is_empty() {
238236
// This can happen if the TLS frame only contained handshake data
@@ -246,10 +244,9 @@ impl TlsSessionHandler {
246244
self.metrics.inc_failures(RequestMetricName::total());
247245
Err(e)
248246
}
249-
Ok(plaintext_response) => self
250-
.session
251-
.encrypt(&plaintext_response)
252-
.into_internal_error("failed to encrypt TLS response"),
247+
Ok(plaintext_response) => self.session.encrypt(&plaintext_response).map_err(|e| {
248+
tonic::Status::internal(format!("failed to encrypt TLS response: {e}"))
249+
}),
253250
}
254251
}
255252
}

oak_private_memory/src/client.rs

Lines changed: 34 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -202,8 +202,16 @@ impl PrivateMemoryClient {
202202

203203
sealed_memory_response.response.ok_or_else(|| anyhow!("empty response"))
204204
}
205+
}
206+
207+
#[async_trait]
208+
pub trait PrivateMemoryAppClient {
209+
async fn invoke(
210+
&mut self,
211+
request: sealed_memory_request::Request,
212+
) -> Result<sealed_memory_response::Response>;
205213

206-
pub async fn register_user(
214+
async fn register_user(
207215
&mut self,
208216
pm_uid: &str,
209217
kek: &[u8],
@@ -221,11 +229,7 @@ impl PrivateMemoryClient {
221229
}
222230
}
223231

224-
pub async fn key_sync(
225-
&mut self,
226-
pm_uid: &str,
227-
kek: &[u8],
228-
) -> Result<key_sync_response::Status> {
232+
async fn key_sync(&mut self, pm_uid: &str, kek: &[u8]) -> Result<key_sync_response::Status> {
229233
let request =
230234
KeySyncRequest { pm_uid: pm_uid.to_string(), key_encryption_key: kek.to_vec() };
231235
let response = self.invoke(sealed_memory_request::Request::KeySyncRequest(request)).await?;
@@ -235,15 +239,15 @@ impl PrivateMemoryClient {
235239
}
236240
}
237241

238-
pub async fn add_memory(&mut self, memory: Memory) -> Result<AddMemoryResponse> {
242+
async fn add_memory(&mut self, memory: Memory) -> Result<AddMemoryResponse> {
239243
let request = AddMemoryRequest { memory: Some(memory) };
240244
let response =
241245
self.invoke(sealed_memory_request::Request::AddMemoryRequest(request)).await?;
242246
expect_response_type!(response, sealed_memory_response::Response::AddMemoryResponse)
243247
}
244248

245249
#[allow(deprecated)]
246-
pub async fn get_memories(
250+
async fn get_memories(
247251
&mut self,
248252
tag: &str,
249253
page_size: i32,
@@ -261,7 +265,7 @@ impl PrivateMemoryClient {
261265
expect_response_type!(response, sealed_memory_response::Response::GetMemoriesResponse)
262266
}
263267

264-
pub async fn get_memory_by_id(
268+
async fn get_memory_by_id(
265269
&mut self,
266270
id: &str,
267271
result_mask: Option<ResultMask>,
@@ -272,7 +276,7 @@ impl PrivateMemoryClient {
272276
expect_response_type!(response, sealed_memory_response::Response::GetMemoryByIdResponse)
273277
}
274278

275-
pub async fn search_memory(
279+
async fn search_memory(
276280
&mut self,
277281
query: SearchMemoryQuery,
278282
page_size: i32,
@@ -292,21 +296,21 @@ impl PrivateMemoryClient {
292296
expect_response_type!(response, sealed_memory_response::Response::SearchMemoryResponse)
293297
}
294298

295-
pub async fn delete_memory(&mut self, ids: Vec<String>) -> Result<DeleteMemoryResponse> {
299+
async fn delete_memory(&mut self, ids: Vec<String>) -> Result<DeleteMemoryResponse> {
296300
let request = DeleteMemoryRequest { ids };
297301
let response =
298302
self.invoke(sealed_memory_request::Request::DeleteMemoryRequest(request)).await?;
299303
expect_response_type!(response, sealed_memory_response::Response::DeleteMemoryResponse)
300304
}
301305

302-
pub async fn reset_memory(&mut self) -> Result<ResetMemoryResponse> {
306+
async fn reset_memory(&mut self) -> Result<ResetMemoryResponse> {
303307
let request = ResetMemoryRequest::default();
304308
let response =
305309
self.invoke(sealed_memory_request::Request::ResetMemoryRequest(request)).await?;
306310
expect_response_type!(response, sealed_memory_response::Response::ResetMemoryResponse)
307311
}
308312

309-
pub async fn get_memories_by_id(
313+
async fn get_memories_by_id(
310314
&mut self,
311315
ids: Vec<String>,
312316
result_mask: Option<ResultMask>,
@@ -317,7 +321,7 @@ impl PrivateMemoryClient {
317321
expect_response_type!(response, sealed_memory_response::Response::GetMemoriesByIdResponse)
318322
}
319323

320-
pub async fn get_database_metrics(&mut self) -> Result<GetDatabaseMetricsResponse> {
324+
async fn get_database_metrics(&mut self) -> Result<GetDatabaseMetricsResponse> {
321325
let request = GetDatabaseMetricsRequest::default();
322326
let response =
323327
self.invoke(sealed_memory_request::Request::GetDatabaseMetricsRequest(request)).await?;
@@ -328,6 +332,15 @@ impl PrivateMemoryClient {
328332
}
329333
}
330334

335+
#[async_trait]
336+
impl PrivateMemoryAppClient for PrivateMemoryClient {
337+
async fn invoke(
338+
&mut self,
339+
request: sealed_memory_request::Request,
340+
) -> Result<sealed_memory_response::Response> {
341+
self.invoke(request).await
342+
}
343+
}
331344
// ---------------------------------------------------------------------------
332345
// TLS Client Support
333346
// ---------------------------------------------------------------------------
@@ -460,109 +473,14 @@ impl PrivateMemoryTlsClient {
460473

461474
sealed_memory_response.response.ok_or_else(|| anyhow!("empty response"))
462475
}
476+
}
463477

464-
pub async fn register_user(
465-
&mut self,
466-
pm_uid: &str,
467-
kek: &[u8],
468-
) -> Result<user_registration_response::Status> {
469-
let request = UserRegistrationRequest {
470-
pm_uid: pm_uid.to_string(),
471-
key_encryption_key: kek.to_vec(),
472-
boot_strap_info: Some(KeyDerivationInfo::default()),
473-
};
474-
let response =
475-
self.invoke(sealed_memory_request::Request::UserRegistrationRequest(request)).await?;
476-
match response {
477-
sealed_memory_response::Response::UserRegistrationResponse(resp) => Ok(resp.status()),
478-
_ => Err(anyhow!("unexpected response type for user registration")),
479-
}
480-
}
481-
482-
pub async fn key_sync(
483-
&mut self,
484-
pm_uid: &str,
485-
kek: &[u8],
486-
) -> Result<key_sync_response::Status> {
487-
let request =
488-
KeySyncRequest { pm_uid: pm_uid.to_string(), key_encryption_key: kek.to_vec() };
489-
let response = self.invoke(sealed_memory_request::Request::KeySyncRequest(request)).await?;
490-
match response {
491-
sealed_memory_response::Response::KeySyncResponse(resp) => Ok(resp.status()),
492-
_ => Err(anyhow!("unexpected response type for key sync")),
493-
}
494-
}
495-
496-
pub async fn add_memory(&mut self, memory: Memory) -> Result<AddMemoryResponse> {
497-
let request = AddMemoryRequest { memory: Some(memory) };
498-
let response =
499-
self.invoke(sealed_memory_request::Request::AddMemoryRequest(request)).await?;
500-
expect_response_type!(response, sealed_memory_response::Response::AddMemoryResponse)
501-
}
502-
503-
pub async fn search_memory(
504-
&mut self,
505-
query: SearchMemoryQuery,
506-
page_size: i32,
507-
result_mask: Option<ResultMask>,
508-
page_token: &str,
509-
keep_all_llm_views: bool,
510-
) -> Result<SearchMemoryResponse> {
511-
let request = SearchMemoryRequest {
512-
query: Some(query),
513-
page_size,
514-
result_mask,
515-
page_token: page_token.to_string(),
516-
keep_all_llm_views,
517-
};
518-
let response =
519-
self.invoke(sealed_memory_request::Request::SearchMemoryRequest(request)).await?;
520-
expect_response_type!(response, sealed_memory_response::Response::SearchMemoryResponse)
521-
}
522-
523-
pub async fn delete_memory(&mut self, ids: Vec<String>) -> Result<DeleteMemoryResponse> {
524-
let request = DeleteMemoryRequest { ids };
525-
let response =
526-
self.invoke(sealed_memory_request::Request::DeleteMemoryRequest(request)).await?;
527-
expect_response_type!(response, sealed_memory_response::Response::DeleteMemoryResponse)
528-
}
529-
530-
pub async fn reset_memory(&mut self) -> Result<ResetMemoryResponse> {
531-
let request = ResetMemoryRequest::default();
532-
let response =
533-
self.invoke(sealed_memory_request::Request::ResetMemoryRequest(request)).await?;
534-
expect_response_type!(response, sealed_memory_response::Response::ResetMemoryResponse)
535-
}
536-
537-
pub async fn get_memory_by_id(
538-
&mut self,
539-
id: &str,
540-
result_mask: Option<ResultMask>,
541-
) -> Result<GetMemoryByIdResponse> {
542-
let request = GetMemoryByIdRequest { id: id.to_string(), result_mask };
543-
let response =
544-
self.invoke(sealed_memory_request::Request::GetMemoryByIdRequest(request)).await?;
545-
expect_response_type!(response, sealed_memory_response::Response::GetMemoryByIdResponse)
546-
}
547-
548-
pub async fn get_memories_by_id(
478+
#[async_trait]
479+
impl PrivateMemoryAppClient for PrivateMemoryTlsClient {
480+
async fn invoke(
549481
&mut self,
550-
ids: Vec<String>,
551-
result_mask: Option<ResultMask>,
552-
) -> Result<GetMemoriesByIdResponse> {
553-
let request = GetMemoriesByIdRequest { ids, result_mask };
554-
let response =
555-
self.invoke(sealed_memory_request::Request::GetMemoriesByIdRequest(request)).await?;
556-
expect_response_type!(response, sealed_memory_response::Response::GetMemoriesByIdResponse)
557-
}
558-
559-
pub async fn get_database_metrics(&mut self) -> Result<GetDatabaseMetricsResponse> {
560-
let request = GetDatabaseMetricsRequest::default();
561-
let response =
562-
self.invoke(sealed_memory_request::Request::GetDatabaseMetricsRequest(request)).await?;
563-
expect_response_type!(
564-
response,
565-
sealed_memory_response::Response::GetDatabaseMetricsResponse
566-
)
482+
request: sealed_memory_request::Request,
483+
) -> Result<sealed_memory_response::Response> {
484+
self.invoke(request).await
567485
}
568486
}

oak_private_memory/src/lib.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,4 @@ pub use encryption;
1818
pub use external_db_client;
1919
pub use log;
2020
pub use metrics;
21-
pub use session;
2221
pub use session_binder;

oak_private_memory/src/session.rs

Lines changed: 0 additions & 59 deletions
This file was deleted.

oak_private_memory/test/client_test.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ use attestation_testing::{
2020
DUMMY_ATTESTATION_ID, DummySessionBindingVerifierProvider, RejectingVerifier,
2121
dummy_client_session_config,
2222
};
23-
use client::PrivateMemoryClient;
23+
use client::{PrivateMemoryAppClient, PrivateMemoryClient};
2424
use oak_session::{attestation::AttestationType, config::SessionConfig, handshake::HandshakeType};
2525
use private_memory_test_utils::{start_server, system_time_to_timestamp};
2626
use sealed_memory_rust_proto::{

0 commit comments

Comments
 (0)