"""ROS 2 service node performing streaming speech recognition via AssemblyAI.

Exposes a `get_transcript` service; each request records from the default
microphone for up to `request.duration` seconds, stopping early after
`SILENCE_TIMEOUT_S` seconds without a final transcript, and returns the
concatenated final transcripts.
"""

import asyncio
import base64
import json

import pyaudio
import websockets

# NOTE(review): API_KEY_ASSEMBLY is a real key committed to the repo in
# api.py -- rotate it and load it from an environment variable instead.
from api import API_KEY_ASSEMBLY
from custom_interfaces.srv import GetTranscript

import rclpy
from rclpy.node import Node

# Audio capture parameters expected by AssemblyAI's realtime endpoint.
FRAMES_PER_BUFFER = 3200          # 200 ms of audio at 16 kHz
FORMAT = pyaudio.paInt16
CHANNELS = 1
RATE = 16000

ASSEMBLYAI_URL = "wss://api.assemblyai.com/v2/realtime/ws?sample_rate=16000"
SILENCE_TIMEOUT_S = 3.0           # end the session after this much silence


class ASRService(Node):
    """Service node wrapping one-shot automatic speech recognition."""

    def __init__(self):
        super().__init__('asr_service')
        self.srv = self.create_service(GetTranscript, 'get_transcript', self.asr_callback)
        self.get_logger().info('ASR Service initialized')

    def asr_callback(self, request, response):
        """Service callback: record, transcribe, and fill `response`.

        Never raises -- failures are reported through `response.success`.
        """
        self.get_logger().info(f'Incoming request: duration={request.duration} seconds')
        try:
            response.transcript = asyncio.run(self.run_asr(request.duration))
            response.success = True
        except Exception as e:
            self.get_logger().error(f'ASR Error: {e}')
            response.transcript = ""
            response.success = False
        return response

    async def run_asr(self, duration_seconds):
        """Stream microphone audio to AssemblyAI until `duration_seconds`
        elapse or `SILENCE_TIMEOUT_S` pass without a final transcript.

        Returns the space-joined final transcripts (possibly empty).
        Raises OSError when no microphone input device exists.
        """
        p = pyaudio.PyAudio()
        stream = None
        try:
            stream = self._open_input_stream(p)
            return await self._stream_transcripts(stream, duration_seconds)
        finally:
            # BUG FIX: the original only cleaned up after a successful
            # connection; PyAudio leaked when no device was found or
            # websockets.connect raised. Always release audio resources.
            if stream is not None:
                stream.stop_stream()
                stream.close()
            p.terminate()

    def _open_input_stream(self, p):
        """Open the first device with input channels; raise OSError if none."""
        device_index = None
        for i in range(p.get_device_count()):
            info = p.get_device_info_by_index(i)
            if info.get('maxInputChannels') > 0:
                device_index = i
                self.get_logger().info(f"Using input device {i}: {info.get('name')}")
                break

        if device_index is None:
            raise OSError("No microphone input device found!")

        return p.open(
            format=FORMAT,
            channels=CHANNELS,
            rate=RATE,
            input=True,
            input_device_index=device_index,
            frames_per_buffer=FRAMES_PER_BUFFER,
        )

    async def _stream_transcripts(self, stream, duration_seconds):
        """Run the send/receive/watchdog tasks over one websocket session."""
        async with websockets.connect(
            ASSEMBLYAI_URL,
            ping_timeout=20,
            ping_interval=5,
            extra_headers={"Authorization": API_KEY_ASSEMBLY},
        ) as _ws:
            await asyncio.sleep(0.1)
            await _ws.recv()  # consume the SessionBegins message
            self.get_logger().info('ASR session started')

            transcripts = []
            stop_event = asyncio.Event()
            loop = asyncio.get_running_loop()  # get_event_loop() is deprecated here
            start_time = loop.time()
            last_transcript_time = start_time

            async def send():
                # Forward raw microphone frames as base64 JSON payloads.
                # NOTE(review): stream.read blocks the event loop briefly;
                # acceptable at 200 ms chunks, but an executor would be cleaner.
                while not stop_event.is_set():
                    try:
                        data = stream.read(FRAMES_PER_BUFFER, exception_on_overflow=False)
                        payload = json.dumps(
                            {"audio_data": base64.b64encode(data).decode("utf-8")}
                        )
                        await _ws.send(payload)
                    except Exception as e:
                        self.get_logger().error(f'Send error: {e}')
                        break
                    await asyncio.sleep(0.01)

            async def receive():
                # Collect FinalTranscript messages; track time of last speech.
                nonlocal last_transcript_time
                while not stop_event.is_set():
                    try:
                        result = json.loads(await _ws.recv())
                    except Exception as e:
                        self.get_logger().error(f'Receive error: {e}')
                        break
                    prompt = result.get("text")
                    if prompt and result.get("message_type") == "FinalTranscript":
                        self.get_logger().info(f'Transcript: {prompt}')
                        transcripts.append(prompt)
                        last_transcript_time = loop.time()

            async def check_stop_conditions():
                # Watchdog: end the session on duration or silence timeout.
                while not stop_event.is_set():
                    now = loop.time()
                    if now - start_time > duration_seconds:
                        self.get_logger().info('Duration limit reached')
                        stop_event.set()
                        break
                    if now - last_transcript_time > SILENCE_TIMEOUT_S:
                        self.get_logger().info('Silence detected')
                        stop_event.set()
                        break
                    await asyncio.sleep(0.5)

            tasks = [
                asyncio.create_task(send()),
                asyncio.create_task(receive()),
                asyncio.create_task(check_stop_conditions()),
            ]
            try:
                # BUG FIX: the original gathered all three tasks, but after
                # the watchdog sets stop_event, receive() stays blocked in
                # _ws.recv(), so gather() -- and the service call -- could
                # hang. Wait for the first task to finish, then cancel the rest.
                await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
            finally:
                stop_event.set()
                for t in tasks:
                    t.cancel()
                await asyncio.gather(*tasks, return_exceptions=True)

            return " ".join(transcripts)


def main():
    rclpy.init()
    asr_service = ASRService()
    rclpy.spin(asr_service)
    rclpy.shutdown()


if __name__ == '__main__':
    main()
"""ROS 2 node that continuously streams microphone audio to AssemblyAI and
publishes each final transcript on the `asr_node` topic. Transcripts that
contain an emergency keyword additionally publish an `Empty` on `emergency`.
"""

import asyncio
import base64
import json
import threading

import pyaudio
import websockets

from api import API_KEY_ASSEMBLY

import rclpy
from rclpy.node import Node
from std_msgs.msg import String, Empty

# Audio capture parameters expected by AssemblyAI's realtime endpoint.
FRAMES_PER_BUFFER = 3200
FORMAT = pyaudio.paInt16
CHANNELS = 1
RATE = 16000

ASSEMBLYAI_URL = "wss://api.assemblyai.com/v2/realtime/ws?sample_rate=16000"


class ASRPublisher(Node):
    """Continuously-listening ASR node with emergency-keyword detection."""

    def __init__(self):
        super().__init__('asr_publisher')

        # Publisher for recognized transcripts.
        self.transcript_publisher = self.create_publisher(String, 'asr_node', 10)

        # Publisher for the emergency-stop signal.
        self.emergency_publisher = self.create_publisher(Empty, 'emergency', 10)

        # Keywords that trigger an emergency-stop publication.
        self.declare_parameter('emergency_keywords', ['stop', 'halt', 'emergency'])
        self.emergency_keywords = self.get_parameter('emergency_keywords').value

        self.get_logger().info('ASR Publisher Node initialized')
        # BUG FIX: the original logged 'voice_command' but actually publishes
        # to 'asr_node' -- the log now matches the real topic.
        self.get_logger().info('Publishing to topic: asr_node')
        self.get_logger().info('Emergency stop topic: emergency')
        self.get_logger().info(f'Emergency keywords: {self.emergency_keywords}')

        # Run ASR in a daemon thread so rclpy.spin keeps the main thread.
        self.asr_thread = threading.Thread(target=self.run_asr_thread, daemon=True)
        self.asr_thread.start()

    def run_asr_thread(self):
        """Thread target: own a private asyncio loop for the ASR coroutine."""
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            loop.run_until_complete(self.continuous_asr())
        except Exception as e:
            self.get_logger().error(f'ASR thread error: {e}')
        finally:
            loop.close()

    def _handle_transcript(self, prompt):
        """Publish `prompt`; also publish an emergency stop on keyword match.

        BUG FIX: the emergency publisher and keyword parameter were declared
        but never used in the original -- this wires them up.
        """
        msg = String()
        msg.data = prompt
        self.transcript_publisher.publish(msg)
        self.get_logger().info(f'Published: "{prompt}"')

        lowered = prompt.lower()
        if any(keyword.lower() in lowered for keyword in self.emergency_keywords):
            self.emergency_publisher.publish(Empty())
            self.get_logger().warn(f'Emergency keyword detected in: "{prompt}"')

    async def continuous_asr(self):
        """Stream microphone audio forever, reconnecting on failures."""
        p = pyaudio.PyAudio()
        stream = None
        try:
            # Find the first device that offers input channels.
            device_index = None
            for i in range(p.get_device_count()):
                info = p.get_device_info_by_index(i)
                if info.get('maxInputChannels') > 0:
                    device_index = i
                    self.get_logger().info(f"Using input device {i}: {info.get('name')}")
                    break

            if device_index is None:
                raise OSError("No microphone input device found!")

            stream = p.open(
                format=FORMAT,
                channels=CHANNELS,
                rate=RATE,
                input=True,
                input_device_index=device_index,
                frames_per_buffer=FRAMES_PER_BUFFER,
            )

            while rclpy.ok():
                try:
                    await self._run_session(stream)
                except Exception as e:
                    self.get_logger().error(f'Connection error: {e}')
                    self.get_logger().info('Reconnecting in 3 seconds...')
                    await asyncio.sleep(3)
        finally:
            # BUG FIX: the original only reached cleanup after the while loop
            # exited normally; release audio resources on every path.
            if stream is not None:
                stream.stop_stream()
                stream.close()
            p.terminate()

    async def _run_session(self, stream):
        """One websocket session: pump audio out, publish transcripts in."""
        async with websockets.connect(
            ASSEMBLYAI_URL,
            ping_timeout=20,
            ping_interval=5,
            extra_headers={"Authorization": API_KEY_ASSEMBLY},
        ) as _ws:
            await asyncio.sleep(0.1)
            await _ws.recv()  # consume the SessionBegins message
            self.get_logger().info('ASR session started, listening...')

            stop_event = asyncio.Event()

            async def send():
                while not stop_event.is_set() and rclpy.ok():
                    try:
                        data = stream.read(FRAMES_PER_BUFFER, exception_on_overflow=False)
                        payload = json.dumps(
                            {"audio_data": base64.b64encode(data).decode("utf-8")}
                        )
                        await _ws.send(payload)
                    except Exception as e:
                        self.get_logger().error(f'Send error: {e}')
                        break
                    await asyncio.sleep(0.01)

            async def receive():
                while not stop_event.is_set() and rclpy.ok():
                    try:
                        result = json.loads(await _ws.recv())
                    except Exception as e:
                        self.get_logger().error(f'Receive error: {e}')
                        break
                    prompt = result.get("text")
                    if prompt and result.get("message_type") == "FinalTranscript":
                        self.get_logger().info(f'Transcript: {prompt}')
                        self._handle_transcript(prompt)

            tasks = [asyncio.create_task(send()), asyncio.create_task(receive())]
            try:
                # Stop the session as soon as either direction fails, then
                # cancel the survivor so the caller can reconnect.
                await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
            finally:
                stop_event.set()
                for t in tasks:
                    t.cancel()
                await asyncio.gather(*tasks, return_exceptions=True)


def main():
    rclpy.init()
    asr_publisher = ASRPublisher()
    try:
        rclpy.spin(asr_publisher)
    except KeyboardInterrupt:
        pass
    finally:
        asr_publisher.destroy_node()
        rclpy.shutdown()


if __name__ == '__main__':
    main()
"""Transcribe a local audio file with the AssemblyAI batch API.

Usage: python vcs.py <audio-file>
"""

import sys
import time

import requests

from api import API_KEY_ASSEMBLY

upload_endpoint = "https://api.assemblyai.com/v2/upload"
transcript_endpoint = "https://api.assemblyai.com/v2/transcript"
headers = {'authorization': API_KEY_ASSEMBLY}

POLL_INTERVAL_S = 3  # pause between status polls instead of busy-looping


def upload(filename):
    """Upload `filename` to AssemblyAI; return the hosted audio URL."""
    def read_file(filename, chunk_size=5242880):
        # Stream the file in 5 MiB chunks so large audio never sits in memory.
        with open(filename, 'rb') as _file:
            while True:
                data = _file.read(chunk_size)
                if not data:
                    break
                yield data

    upload_response = requests.post(upload_endpoint,
                                    headers=headers,
                                    data=read_file(filename))
    return upload_response.json()['upload_url']


def transcribe(audio_url):
    """Submit a transcription job for `audio_url`; return the job id."""
    transcript_request = {"audio_url": audio_url}
    transcript_response = requests.post(transcript_endpoint, json=transcript_request, headers=headers)
    return transcript_response.json()['id']


def poll(transcript_id):
    """Fetch the current status document for a transcription job."""
    polling_endpoint = transcript_endpoint + '/' + transcript_id
    polling_response = requests.get(polling_endpoint, headers=headers)
    return polling_response.json()


def get_transcription_result_url(audio_url):
    """Block until the job finishes; return (data, error).

    BUG FIX: the original tested `status == 'completed'` in both branches,
    so an 'error' status was unreachable and a failed job polled forever.
    Also adds a delay between polls instead of hammering the API.
    """
    transcript_id = transcribe(audio_url)
    while True:
        data = poll(transcript_id)
        if data['status'] == 'completed':
            return data, None
        elif data['status'] == 'error':
            return data, data['error']
        time.sleep(POLL_INTERVAL_S)


if __name__ == '__main__':
    # BUG FIX: the original ran at import time and crashed with IndexError
    # when no argument was given; it also printed data['text'] (None) even
    # when the job had errored.
    if len(sys.argv) < 2:
        print('Usage: python vcs.py <audio-file>', file=sys.stderr)
        sys.exit(1)

    filename = sys.argv[1]
    audio_url = upload(filename)
    data, error = get_transcription_result_url(audio_url)
    if error:
        print(f'Transcription failed: {error}', file=sys.stderr)
        sys.exit(1)
    print(data['text'])