Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

22 changes: 11 additions & 11 deletions crates/api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,23 +168,23 @@ mod tests {

use super::*;
use defs::ContentType;
use tempfile::tempdir;
use tempfile::{TempDir, tempdir};

// Helper function to create a test database
fn create_test_db() -> VectorDb {
fn create_test_db() -> (VectorDb, TempDir) {
let temp_dir = tempdir().unwrap();
let config = DbConfig {
storage_type: StorageType::RocksDb,
index_type: IndexType::Flat,
data_path: temp_dir.path().to_path_buf(),
dimension: 3,
};
init_api(config).unwrap()
(init_api(config).unwrap(), temp_dir)
}

#[test]
fn test_insert_and_get() {
let db = create_test_db();
let (db, _temp_dir) = create_test_db();
let vector = vec![1.0, 2.0, 3.0];
let payload = Payload {
content_type: ContentType::Text,
Expand All @@ -209,7 +209,7 @@ mod tests {

#[test]
fn test_dimension_mismatch() {
let db = create_test_db();
let (db, _temp_dir) = create_test_db();
let v1 = vec![1.0, 2.0, 3.0];
let v2 = vec![1.0, 2.0];
let payload = defs::Payload {
Expand All @@ -228,7 +228,7 @@ mod tests {

#[test]
fn test_delete() {
let db = create_test_db();
let (db, _temp_dir) = create_test_db();
let vector = vec![1.0, 2.0, 3.0];
let payload = Payload {
content_type: ContentType::Text,
Expand All @@ -251,7 +251,7 @@ mod tests {

#[test]
fn test_search() {
let db = create_test_db();
let (db, _temp_dir) = create_test_db();

// Insert some points
let vectors = vec![
Expand Down Expand Up @@ -280,7 +280,7 @@ mod tests {

#[test]
fn test_search_limit() {
let db = create_test_db();
let (db, _temp_dir) = create_test_db();

// Insert 5 points
let mut ids = Vec::new();
Expand All @@ -307,7 +307,7 @@ mod tests {

#[test]
fn test_empty_database() {
let db = create_test_db();
let (db, _temp_dir) = create_test_db();

// Get non-existent point
assert!(db.get(Uuid::new_v4()).unwrap().is_none());
Expand All @@ -319,7 +319,7 @@ mod tests {

#[test]
fn test_list_vectors() {
let db = create_test_db();
let (db, _temp_dir) = create_test_db();
// insert some points
let mut ids = Vec::new();
for i in 0..10 {
Expand Down Expand Up @@ -350,7 +350,7 @@ mod tests {

#[test]
fn test_build_index() {
let db = create_test_db();
let (db, _temp_dir) = create_test_db();

// insert some points
for i in 0..10 {
Expand Down
16 changes: 8 additions & 8 deletions crates/grpc/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use index::IndexType;
use std::net::SocketAddr;
use std::sync::Arc;
use storage::StorageType;
use tempfile::tempdir;
use tempfile::{TempDir, tempdir};
use tonic::transport::Channel;

// Inspired from https://github.com/hyperium/tonic/discussions/924#discussioncomment-9854088
Expand All @@ -22,7 +22,7 @@ fn append_test_auth_header<T>(request: &mut tonic::Request<T>, token: &str) {
.insert(AUTHORIZATION_HEADER_KEY, auth_value.parse().unwrap());
}

async fn start_test_server() -> Result<SocketAddr, Box<dyn std::error::Error>> {
async fn start_test_server() -> Result<(SocketAddr, TempDir), Box<dyn std::error::Error>> {
// using a temporary directory for db datapath
let temp_dir = tempdir().unwrap();

Expand Down Expand Up @@ -50,7 +50,7 @@ async fn start_test_server() -> Result<SocketAddr, Box<dyn std::error::Error>> {
.inspect_err(|err| panic!("Could not start test server : {:?}", err));
});

Ok(listener_addr)
Ok((listener_addr, temp_dir))
}

async fn create_test_client(
Expand All @@ -65,7 +65,7 @@ async fn create_test_client(

#[tokio::test]
async fn test_grpc_server_start() {
let server_addr = start_test_server().await.unwrap();
let (server_addr, _temp_dir) = start_test_server().await.unwrap();
let mut client = create_test_client(server_addr).await.unwrap();

// insert a test vector
Expand All @@ -85,7 +85,7 @@ async fn test_grpc_server_start() {

#[tokio::test]
async fn test_insert_vector_rpc() {
let server_addr = start_test_server().await.unwrap();
let (server_addr, _temp_dir) = start_test_server().await.unwrap();
let mut client = create_test_client(server_addr).await.unwrap();

// insert a test vector
Expand Down Expand Up @@ -133,7 +133,7 @@ async fn test_insert_vector_rpc() {

#[tokio::test]
async fn test_delete_vector_rpc() {
let server_addr = start_test_server().await.unwrap();
let (server_addr, _temp_dir) = start_test_server().await.unwrap();
let mut client = create_test_client(server_addr).await.unwrap();

// insert a test vector
Expand Down Expand Up @@ -175,7 +175,7 @@ async fn test_delete_vector_rpc() {

#[tokio::test]
async fn test_search_vector_rpc() {
let server_addr = start_test_server().await.unwrap();
let (server_addr, _temp_dir) = start_test_server().await.unwrap();
let mut client = create_test_client(server_addr).await.unwrap();

// insert a test vector
Expand Down Expand Up @@ -221,7 +221,7 @@ async fn test_search_vector_rpc() {

#[tokio::test]
async fn test_unauthorized_rpc() {
let server_addr = start_test_server().await.unwrap();
let (server_addr, _temp_dir) = start_test_server().await.unwrap();
let mut client = create_test_client(server_addr).await.unwrap();

// insert a test vector
Expand Down
1 change: 0 additions & 1 deletion crates/server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ prost.workspace = true
serde.workspace = true
serde_json.workspace = true
storage.workspace = true
tempfile.workspace = true
tokio.workspace = true
tokio-stream.workspace = true
tonic.workspace = true
Expand Down
7 changes: 1 addition & 6 deletions crates/server/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ use std::fs;
use std::net::SocketAddr;
use std::path::PathBuf;
use storage::StorageType;
use tempfile::tempdir;
use tracing::{Level, event};

const DEFAULT_HTTP_PORT: &str = "3000";
Expand Down Expand Up @@ -146,11 +145,7 @@ impl ServerConfig {
fs::create_dir_all(&path).map_err(|_| ConfigError::InvalidDataPath)?;
path
} else {
let tempbuf = tempdir()
.map_err(|e| ConfigError::IoError(e.to_string()))?
.path()
.to_path_buf()
.join("vectordb");
let tempbuf = env::temp_dir().join("vectordb");
fs::create_dir_all(&tempbuf).map_err(|e| ConfigError::IoError(e.to_string()))?;
event!(
Level::WARN,
Expand Down
38 changes: 12 additions & 26 deletions crates/storage/src/rocks_db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,26 +160,24 @@ mod tests {
use defs::ContentType;
use uuid::Uuid;

use tempfile::tempdir;
use tempfile::{TempDir, tempdir};

fn create_test_db() -> (RocksDbStorage, String) {
fn create_test_db() -> (RocksDbStorage, TempDir) {
let temp_dir = tempdir().unwrap();
let temp_dir_path = temp_dir.path().to_str().unwrap().to_string();

let db = RocksDbStorage::new(temp_dir_path.clone()).expect("Failed to create RocksDB");
(db, temp_dir_path)
let db = RocksDbStorage::new(temp_dir.path()).expect("Failed to create RocksDB");
(db, temp_dir)
}

#[test]
fn test_new_rocksdb_storage() {
let (db, path) = create_test_db();
assert_eq!(db.get_current_path(), PathBuf::from(path.clone()));
std::fs::remove_dir_all(path).unwrap_or_default();
let (db, temp_dir) = create_test_db();
assert_eq!(db.get_current_path(), PathBuf::from(temp_dir.path()));
}

#[test]
fn test_insert_and_get_vector() {
let (db, path) = create_test_db();
let (db, _temp_dir) = create_test_db();
let id = Uuid::new_v4();
let vector = Some(vec![0.1, 0.2, 0.3]);
let payload = Some(Payload {
Expand All @@ -190,13 +188,11 @@ mod tests {
assert!(db.insert_point(id, vector.clone(), payload).is_ok());
let result = db.get_vector(id).unwrap();
assert_eq!(result, vector);

std::fs::remove_dir_all(path).unwrap_or_default();
}

#[test]
fn test_insert_and_get_payload() {
let (db, path) = create_test_db();
let (db, _temp_dir) = create_test_db();
let id = Uuid::new_v4();
let payload = Some(Payload {
content_type: ContentType::Text,
Expand All @@ -212,13 +208,11 @@ mod tests {
content: "Test".to_string(),
});
assert_eq!(result, expected);

std::fs::remove_dir_all(path).unwrap_or_default();
}

#[test]
fn test_contains_point() {
let (db, path) = create_test_db();
let (db, _temp_dir) = create_test_db();
let id = Uuid::new_v4();
let payload = Some(Payload {
content_type: ContentType::Text,
Expand All @@ -231,13 +225,11 @@ mod tests {
db.insert_point(id, vector, payload).unwrap();

assert!(db.contains_point(id).unwrap());

std::fs::remove_dir_all(path).unwrap_or_default();
}

#[test]
fn test_delete_point() {
let (db, path) = create_test_db();
let (db, _temp_dir) = create_test_db();
let id = Uuid::new_v4();
let payload = Some(Payload {
content_type: ContentType::Text,
Expand All @@ -255,27 +247,21 @@ mod tests {
assert!(!db.contains_point(id).unwrap());
assert_eq!(db.get_vector(id).unwrap(), None);
assert_eq!(db.get_payload(id).unwrap(), None);

std::fs::remove_dir_all(path).unwrap_or_default();
}

#[test]
fn test_get_nonexistent_vector() {
let (db, path) = create_test_db();
let (db, _temp_dir) = create_test_db();
let id = Uuid::new_v4();

assert_eq!(db.get_vector(id).unwrap(), None);

std::fs::remove_dir_all(path).unwrap_or_default();
}

#[test]
fn test_get_nonexistent_payload() {
let (db, path) = create_test_db();
let (db, _temp_dir) = create_test_db();
let id = Uuid::new_v4();

assert_eq!(db.get_payload(id).unwrap(), None);

std::fs::remove_dir_all(path).unwrap_or_default();
}
}
Loading