Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions backend/database/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,40 @@ def get_people_by_ids(uid: str, person_ids: list[str]):
return all_people


def add_shared_person_to_user(target_uid: str, source_uid: str, name: str, embedding: list, profile_url: str = None) -> None:
    """
    Upsert a shared-profile record into the target user's `shared_people` subcollection.

    The sharer's uid doubles as the document id, so a later revoke only needs
    that uid to find the record.
    """
    shared_doc = (
        db.collection('users')
        .document(target_uid)
        .collection('shared_people')
        .document(source_uid)
    )
    # Written with merge=True so a repeated share updates the existing document.
    shared_doc.set(
        {
            'id': source_uid,
            'source_uid': source_uid,
            'name': name,
            'speaker_embedding': embedding,
            'profile_url': profile_url,
        },
        merge=True,
    )


def remove_shared_person_from_user(target_uid: str, source_uid: str) -> bool:
    """Delete the profile that `source_uid` shared with `target_uid`.

    Returns True when the document existed and was deleted, False when there
    was nothing to remove.
    """
    doc_ref = (
        db.collection('users')
        .document(target_uid)
        .collection('shared_people')
        .document(source_uid)
    )
    snapshot = doc_ref.get()
    if snapshot.exists:
        doc_ref.delete()
        return True
    return False


def get_shared_people(target_uid: str) -> list:
    """Fetch every document in the target user's `shared_people` subcollection as a dict."""
    collection = db.collection('users').document(target_uid).collection('shared_people')
    people = []
    for snapshot in collection.stream():
        people.append(snapshot.to_dict())
    return people


def update_person(uid: str, person_id: str, name: str):
person_ref = db.collection('users').document(uid).collection('people').document(person_id)
person_ref.update({'name': name})
Expand Down
71 changes: 71 additions & 0 deletions backend/routers/speech_profile.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import json
from typing import Optional

import av
Expand Down Expand Up @@ -26,6 +27,8 @@
get_user_has_speech_profile,
)
from utils.stt.vad import apply_vad_for_speech_profile
from database.users import add_shared_person_to_user, remove_shared_person_from_user, get_shared_people, get_person
from database import redis_db

router = APIRouter()

Expand Down Expand Up @@ -101,3 +104,71 @@ def get_extra_speech_profile_samples(person_id: Optional[str] = None, uid: str =
if person_id:
return get_user_person_speech_samples(uid, person_id)
return get_additional_profile_recordings(uid)


@router.post('/v3/speech-profile/share', tags=['v3'])
def share_speech_profile_with_user(target_uid: str, name: str = None, source_person_id: str = None, uid: str = Depends(auth.get_current_user_uid)):
    """Share the caller's speech profile (embedding + metadata) with another user.

    A `shared_people` record keyed by the caller's uid is written under
    `target_uid`, then an `add` notification is published on the Redis channel
    `users:{target_uid}:shared_profiles`.

    Args:
        target_uid: User who receives the shared profile.
        name: Display name for the shared profile. Falls back to the source
            person's name, then to 'Unknown'.
        source_person_id: Person (owned by the caller) whose embedding is
            shared. When omitted, the first entry returned by `get_people` is used.
        uid: Caller uid, injected by the auth dependency.

    Returns:
        `{'status': 'ok'}`. A Redis publish failure is logged and does not fail
        the request.
    """
    profile_url = get_profile_audio_if_exists(uid, download=False)

    # Pick the source person: the explicit one if given, else the caller's first person.
    person_doc = None
    if source_person_id:
        person_doc = get_person(uid, source_person_id)
    else:
        try:
            from database.users import get_people

            all_people = get_people(uid)
            if all_people:
                person_doc = all_people[0]
        except Exception as e:
            print(f"Failed to get people for user {uid} when sharing profile: {e}")
            person_doc = None

    # Treat "no person found" as an empty document so every lookup below has one code path.
    person_doc = person_doc or {}
    embedding = person_doc.get('speaker_embedding') or []
    # A person document without a usable name must not leak None into the stored
    # record or the notification; fall through to 'Unknown' instead.
    person_name = name or person_doc.get('name') or 'Unknown'

    add_shared_person_to_user(target_uid, uid, person_name, embedding, profile_url)

    # Notify target user's active sessions via Redis pubsub (best effort).
    try:
        channel = f'users:{target_uid}:shared_profiles'
        payload = {
            'action': 'add',
            'source_uid': uid,
            'name': person_name,
            'speaker_embedding': embedding,
        }
        redis_db.r.publish(channel, json.dumps(payload))
    except Exception as e:
        print(f"Failed to publish 'add' shared profile notification to Redis for user {target_uid}: {e}")

    return {'status': 'ok'}


@router.post('/v3/speech-profile/revoke', tags=['v3'])
def revoke_speech_profile_from_user(target_uid: str, uid: str = Depends(auth.get_current_user_uid)):
    """Withdraw the caller's shared speech profile from `target_uid`.

    Returns `{'status': 'ok'}` when a shared document was removed and
    `{'status': 'not_found'}` otherwise.
    """
    removed = remove_shared_person_from_user(target_uid, uid)

    # The 'remove' notification is best effort and is sent regardless of
    # whether a document was actually deleted.
    try:
        message = json.dumps({'action': 'remove', 'source_uid': uid})
        redis_db.r.publish(f'users:{target_uid}:shared_profiles', message)
    except Exception as e:
        print(f"Failed to publish 'remove' shared profile notification to Redis for user {target_uid}: {e}")

    if removed:
        return {'status': 'ok'}
    return {'status': 'not_found'}


@router.get('/v3/speech-profile/shared', tags=['v3'])
def list_shared_profiles(uid: str = Depends(auth.get_current_user_uid)):
    """Return the profiles other users have shared with the caller, under the `shared` key."""
    return {'shared': get_shared_people(uid)}
72 changes: 70 additions & 2 deletions backend/routers/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1329,10 +1329,26 @@ async def speaker_identification_task():
nonlocal websocket_active, speaker_to_person_map
nonlocal person_embeddings_cache, audio_ring_buffer

def _load_shared_people_into_cache(uid: str, cache: dict):
    """Add the profiles shared with `uid` to the speaker-embedding cache.

    Each shared profile that carries an embedding is stored in `cache` under the
    key `shared_<id>` as `{'embedding': (1, N) float32 array, 'name': str}`.
    Entries without an embedding are skipped. Failures are logged, never raised:
    a fetch failure loads nothing, and a malformed entry is skipped on its own
    so the remaining shared profiles still load.
    """
    try:
        shared = user_db.get_shared_people(uid)
    except Exception as e:
        print(f"Speaker ID: failed loading shared people: {e}", uid, session_id)
        return

    for entry in shared:
        sid = None
        try:
            sid = entry.get('id')
            emb = entry.get('speaker_embedding')
            if not emb:
                continue
            cache[f"shared_{sid}"] = {
                'embedding': np.array(emb, dtype=np.float32).reshape(1, -1),
                'name': entry.get('name') or f"Shared-{sid}",
            }
        except Exception as e:
            # Isolate the bad entry instead of abandoning the rest of the list.
            print(f"Speaker ID: skipping malformed shared profile {sid}: {e}", uid, session_id)

if not speaker_id_enabled:
return

# Load person embeddings
# Load person embeddings (own people + profiles shared with this user)
try:
people = user_db.get_people(uid)
for person in people:
Expand All @@ -1342,7 +1358,11 @@ async def speaker_identification_task():
'embedding': np.array(emb, dtype=np.float32).reshape(1, -1),
'name': person['name'],
}
print(f"Speaker ID: loaded {len(person_embeddings_cache)} person embeddings", uid, session_id)

# Also load shared people (profiles other users shared with this user)
_load_shared_people_into_cache(uid, person_embeddings_cache)

print(f"Speaker ID: loaded {len(person_embeddings_cache)} person embeddings (including shared)", uid, session_id)
except Exception as e:
print(f"Speaker ID: failed to load embeddings: {e}", uid, session_id)
return
Expand All @@ -1351,6 +1371,50 @@ async def speaker_identification_task():
print("Speaker ID: no stored embeddings, task disabled", uid, session_id)
return

# Start a background listener for shared profile pubsub updates so we can refresh in real-time
async def _shared_profiles_listener():
    """Keep the speaker-embedding cache in sync with share/revoke notifications.

    Subscribes to the Redis channel `users:{uid}:shared_profiles` and applies
    each JSON message to `person_embeddings_cache`:
      - action 'add': store the embedding carried in the payload under
        `shared_<source_uid>` (no database read);
      - action 'remove': drop `shared_<source_uid>` if present.

    Runs until `websocket_active` becomes false or the task is cancelled. The
    Redis connection is always closed on the way out. Errors are logged and
    never propagate to the caller.
    """
    # No `nonlocal` declaration is needed: the cache dict is only mutated in
    # place and `websocket_active` is only read.
    sub = None
    try:
        import aioredis

        # NOTE(review): this URL carries no password/TLS settings; confirm it
        # matches how the shared `redis_db` client connects.
        redis_url = f"redis://{os.getenv('REDIS_DB_HOST')}:{os.getenv('REDIS_DB_PORT') or 6379}"
        sub = await aioredis.create_redis(redis_url)
        ch, = await sub.subscribe(f'users:{uid}:shared_profiles')
        while websocket_active:
            msg = await ch.get(encoding='utf-8')
            if not msg:
                await asyncio.sleep(0.1)
                continue
            try:
                payload = json.loads(msg)
                action = payload.get('action')
                source_uid = payload.get('source_uid')
                ns_id = f"shared_{source_uid}"
                if action == 'add':
                    # Use embedding from payload directly to avoid DB roundtrip
                    emb = payload.get('speaker_embedding')
                    name = payload.get('name') or f"Shared-{source_uid}"
                    if emb and source_uid:
                        try:
                            person_embeddings_cache[ns_id] = {
                                'embedding': np.array(emb, dtype=np.float32).reshape(1, -1),
                                'name': name,
                            }
                        except Exception as e:
                            print(f"Failed to load embedding from payload for user {uid}, source {source_uid}: {e}")
                elif action == 'remove':
                    person_embeddings_cache.pop(ns_id, None)
            except Exception as e:
                print(f"Error processing shared profile message for user {uid}: {e}. Payload: {msg}")
                continue
    except Exception as e:
        print(f"Shared profiles listener error: {e}", uid, session_id)
    finally:
        # Runs on normal exit, on error and on task cancellation, so the
        # per-session Redis connection is not leaked.
        if sub is not None:
            try:
                sub.close()
                await sub.wait_closed()
            except Exception as e:
                print(f"Shared profiles listener close error: {e}", uid, session_id)

shared_listener_task = asyncio.create_task(_shared_profiles_listener())

# Consume loop
while websocket_active:
try:
Expand All @@ -1369,6 +1433,10 @@ async def speaker_identification_task():
asyncio.create_task(_match_speaker_embedding(speaker_id, seg))

print("Speaker ID task ended", uid, session_id)
try:
shared_listener_task.cancel()
except Exception:
pass

async def _match_speaker_embedding(speaker_id: int, segment: dict):
"""Extract audio from ring buffer and match against stored embeddings."""
Expand Down
38 changes: 38 additions & 0 deletions backend/scripts/SMOKE_TEST.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
Smoke test: share/revoke speech profile

Purpose
- Quick verification that share and revoke endpoints work and that Redis pubsub notifications are emitted.

Prerequisites
- A staging backend whose base URL is set in the `API_BASE` environment variable.
- User A (sharer) and user B (recipient) tokens (`A_TOKEN`, `B_TOKEN`) and the UID of user B (`B_UID`).
- (Optional) Redis connection if you want to observe pubsub events.
- `aioredis` must be available in the runtime for `/v4/listen` sessions to subscribe to pubsub (add to requirements if needed).
- Note: share payload now includes `speaker_embedding` so listeners can update caches without a DB roundtrip.

How to run

Set environment variables (example):

```bash
export API_BASE=https://staging-api.example.com
export A_TOKEN=<A bearer token>
export B_TOKEN=<B bearer token>
export B_UID=<uid of B>
export SOURCE_PERSON_ID=<optional person id owned by A>
# optional: redis info
export REDIS_HOST=redis.example.com
export REDIS_PORT=6379
export REDIS_PASSWORD=<password if needed>

python3 backend/scripts/smoke_shared_profile_test.py
```

Expected
- The share endpoint returns status 200 and the response body `{"status":"ok"}`.
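- If Redis is reachable, a pubsub message with action `add` should be observed on `users:{B_UID}:shared_profiles`. Redis pub/sub does not replay past messages, so the subscriber must already be subscribed when the share request is made.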
- The subsequent GET /v3/speech-profile/shared as B shows the shared doc.
- After revoke, GET /v3/speech-profile/shared does not include the shared doc and a `remove` event is published.

Notes
- This test does not exercise the `/v4/listen` real-time matching, which requires streaming audio and a running STT/embedding stack. Use a separate test to stream audio and confirm `SpeakerLabelSuggestionEvent` arrives in a live session.
90 changes: 90 additions & 0 deletions backend/scripts/smoke_shared_profile_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
"""
Smoke test for sharing/revoking speech profiles.

Run this from a machine with network access to your staging backend and Redis.
It performs:
1) POST /v3/speech-profile/share as user A to share a person with user B
2) (optional) listen on Redis channel users:{B}:shared_profiles for the pubsub notification
3) GET /v3/speech-profile/shared as user B to verify the doc exists
4) POST /v3/speech-profile/revoke as user A to revoke
5) GET /v3/speech-profile/shared as user B to verify removal

Environment variables required:
- API_BASE (e.g. https://staging-api.omi.me)
- A_TOKEN (Bearer token for sharing user A)
- B_TOKEN (Bearer token for target user B)
- B_UID (uid of target user B)
- SOURCE_PERSON_ID (optional person id owned by user A to share)
- REDIS_HOST, REDIS_PORT, REDIS_PASSWORD (optional, for listening to pubsub)

This script does not create persons or embeddings. Ensure user A has a person with a stored speaker_embedding if you want matching to work.
"""

import os
import time
import json
import requests

API_BASE = os.getenv('API_BASE')
A_TOKEN = os.getenv('A_TOKEN')
B_TOKEN = os.getenv('B_TOKEN')
B_UID = os.getenv('B_UID')
SOURCE_PERSON_ID = os.getenv('SOURCE_PERSON_ID')

# Optional Redis settings, only used to observe the pubsub notifications.
REDIS_HOST = os.getenv('REDIS_HOST')
REDIS_PORT = int(os.getenv('REDIS_PORT') or 6379)
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')

HTTP_TIMEOUT_SECONDS = 30  # keep a hung backend from blocking the script forever
PUBSUB_WAIT_SECONDS = 5


def _subscribe(channel):
    """Return a redis-py PubSub subscribed to `channel`, or None if Redis is unset or unreachable."""
    if not REDIS_HOST:
        print('REDIS_HOST not set; skipping pubsub listening')
        return None
    try:
        import redis

        print('Connecting to Redis to subscribe to pubsub channel...')
        r = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, decode_responses=True)
        p = r.pubsub()
        p.subscribe(channel)
        # Consume the subscribe confirmation so the subscription is known to be
        # active before any request that triggers a publish is sent.
        p.get_message(timeout=1.0)
        print('Subscribed to', channel)
        return p
    except Exception as e:
        print('Redis listen failed:', str(e))
        return None


def _wait_for_message(pubsub, timeout=PUBSUB_WAIT_SECONDS):
    """Poll `pubsub` for up to `timeout` seconds; return the first message payload or None."""
    if pubsub is None:
        return None
    print(f'Waiting up to {timeout} seconds for a message...')
    deadline = time.monotonic() + timeout
    try:
        while time.monotonic() < deadline:
            m = pubsub.get_message()
            if m and m.get('type') == 'message':
                return m['data']
            time.sleep(0.2)
    except Exception as e:
        print('Redis listen failed:', str(e))
    return None


def main():
    """Run the share -> list -> revoke -> list sequence and print each result."""
    if not API_BASE or not A_TOKEN or not B_TOKEN or not B_UID:
        print('Missing required environment variables. See header of this script for details.')
        raise SystemExit(1)

    headers_a = {'Authorization': f'Bearer {A_TOKEN}'}
    headers_b = {'Authorization': f'Bearer {B_TOKEN}'}

    # Subscribe BEFORE sharing: Redis pub/sub does not replay messages that were
    # published before the subscription existed.
    pubsub = _subscribe(f'users:{B_UID}:shared_profiles')
    try:
        print('1) Share profile from A -> B')
        params = {'target_uid': B_UID}
        if SOURCE_PERSON_ID:
            params['source_person_id'] = SOURCE_PERSON_ID

        resp = requests.post(
            f'{API_BASE}/v3/speech-profile/share', headers=headers_a, params=params, timeout=HTTP_TIMEOUT_SECONDS
        )
        print('share status:', resp.status_code, resp.text)
        if pubsub is not None:
            print('pubsub message:', _wait_for_message(pubsub))

        print('\n2) As B, list shared profiles')
        resp = requests.get(f'{API_BASE}/v3/speech-profile/shared', headers=headers_b, timeout=HTTP_TIMEOUT_SECONDS)
        print('list-shared status:', resp.status_code, resp.text)

        print('\n3) Revoke from A -> B')
        params = {'target_uid': B_UID}
        resp = requests.post(
            f'{API_BASE}/v3/speech-profile/revoke', headers=headers_a, params=params, timeout=HTTP_TIMEOUT_SECONDS
        )
        print('revoke status:', resp.status_code, resp.text)
        if pubsub is not None:
            print('pubsub message:', _wait_for_message(pubsub))

        print('\n4) As B, list shared profiles (post-revoke)')
        resp = requests.get(f'{API_BASE}/v3/speech-profile/shared', headers=headers_b, timeout=HTTP_TIMEOUT_SECONDS)
        print('list-shared (after revoke) status:', resp.status_code, resp.text)
    finally:
        if pubsub is not None:
            try:
                pubsub.close()
            except Exception as e:
                print('Redis close failed:', str(e))

    print('\nSmoke test finished.')


# Guarded entry point: the file name matches pytest's `*_test.py` pattern, so a
# bare import must not fire network calls or exit the interpreter.
if __name__ == '__main__':
    main()
33 changes: 33 additions & 0 deletions backend/tests/unit/test_speech_profile_endpoints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import json
from unittest.mock import MagicMock, patch

from fastapi import FastAPI
from fastapi.testclient import TestClient

from backend.routers import speech_profile as sp


# The endpoints resolve the caller through Depends(auth.get_current_user_uid).
# A `uid` query parameter cannot satisfy a dependency, so the router is mounted
# on an app whose auth dependency is overridden with a fixed uid.
CALLER_UID = 'caller'

app = FastAPI()
app.include_router(sp.router)
app.dependency_overrides[sp.auth.get_current_user_uid] = lambda: CALLER_UID
client = TestClient(app)


@patch('backend.routers.speech_profile.get_person')
@patch('backend.routers.speech_profile.get_profile_audio_if_exists')
@patch('backend.routers.speech_profile.add_shared_person_to_user')
@patch('backend.routers.speech_profile.redis_db')
def test_share_endpoint_calls_add_and_publishes(mock_redis, mock_add_shared, mock_get_profile, mock_get_person):
    """Sharing stores the record for the target and publishes an 'add' notification."""
    mock_get_profile.return_value = 'http://example.com/profile.wav'
    # Pin the source person so the test never reaches the real people lookup.
    mock_get_person.return_value = {'name': 'Alice', 'speaker_embedding': [0.1, 0.2]}
    mock_redis.r = MagicMock()

    response = client.post('/v3/speech-profile/share?target_uid=target&name=Bob&source_person_id=p1')

    assert response.status_code == 200
    assert response.json() == {'status': 'ok'}
    mock_get_person.assert_called_once_with(CALLER_UID, 'p1')
    mock_add_shared.assert_called_once_with(
        'target', CALLER_UID, 'Bob', [0.1, 0.2], 'http://example.com/profile.wav'
    )
    mock_redis.r.publish.assert_called_once()
    channel, raw = mock_redis.r.publish.call_args[0]
    assert channel == 'users:target:shared_profiles'
    assert json.loads(raw) == {
        'action': 'add',
        'source_uid': CALLER_UID,
        'name': 'Bob',
        'speaker_embedding': [0.1, 0.2],
    }


@patch('backend.routers.speech_profile.remove_shared_person_from_user')
@patch('backend.routers.speech_profile.redis_db')
def test_revoke_endpoint_calls_remove_and_publishes(mock_redis, mock_remove):
    """Revoking removes the caller's shared record and publishes a 'remove' notification."""
    mock_remove.return_value = True
    mock_redis.r = MagicMock()

    response = client.post('/v3/speech-profile/revoke?target_uid=target')

    assert response.status_code == 200
    assert response.json() == {'status': 'ok'}
    mock_remove.assert_called_once_with('target', CALLER_UID)
    mock_redis.r.publish.assert_called_once()
    channel, raw = mock_redis.r.publish.call_args[0]
    assert channel == 'users:target:shared_profiles'
    assert json.loads(raw) == {'action': 'remove', 'source_uid': CALLER_UID}


@patch('backend.routers.speech_profile.remove_shared_person_from_user')
@patch('backend.routers.speech_profile.redis_db')
def test_revoke_endpoint_reports_not_found(mock_redis, mock_remove):
    """Revoking a profile that was never shared reports 'not_found'."""
    mock_remove.return_value = False
    mock_redis.r = MagicMock()

    response = client.post('/v3/speech-profile/revoke?target_uid=target')

    assert response.status_code == 200
    assert response.json() == {'status': 'not_found'}
Loading