-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathencoding_2bit.py
More file actions
106 lines (89 loc) · 3.34 KB
/
encoding_2bit.py
File metadata and controls
106 lines (89 loc) · 3.34 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
#!/usr/bin/env python3
from dataclasses import dataclass
from code_mapping import DECODE_MAPPING, ENCODE_MAPPING, ENCODING_TO_BASES, Encoding
from generic_encoding import (
DecodingError,
EncodedQuality,
EncodedSequence,
EncodingError,
bits_to_bytes,
bytes_to_bits,
)
# Encoding variant handled by this module; used as the key into the
# ENCODING_TO_BASES / ENCODE_MAPPING / DECODE_MAPPING tables.
ENCODING = Encoding.BIT2_ATCG
# Two-bit tag written at the very start of every encoded stream and checked
# by the decoder before anything else is read.
TAG_BIT2 = "01"
# Width, in bits, of the header field that records how much zero padding
# follows the header.
HEADER_PAD_LENGTH = 2
# Total header width in bits: the tag followed by the padding-count field.
HEADER_LENGTH = len(TAG_BIT2) + HEADER_PAD_LENGTH
##
## Bit layout of each 2-bit symbol, read from LEFT to RIGHT:
## First bit (leftmost): KETO (KETO=1; AMINO=0)
## KETO=GT / AMINO=AC
##
## Second bit (rightmost): PYRIMIDINE (PYRIMIDINE=1; PURINE=0)
## PYRIMIDINE=CT / PURINE=AG
##
def encode_2bit_sequence(sequence: str) -> bytes:
    """
    Pack a nucleotide string into the byte-aligned BIT2 representation.

    Layout (BIT2):
    [2b TAG=01][2b HEADER_PAD][zero pad bits][2-bit symbols...]

    The input is upper-cased and stripped of "\\n" / "\\r" before encoding.
    Zero bits are inserted between the header and the symbols so that the
    whole bitstring ends on a byte boundary. HEADER_PAD stores that padding
    as a count of bit *pairs* (padding bits // 2); with a 4-bit header and
    2-bit symbols the padding is 0, 2, 4 or 6 bits, i.e. HEADER_PAD is 0..3.

    Raises EncodingError when the cleaned sequence contains a symbol that is
    not listed in ENCODING_TO_BASES for this encoding.
    """
    cleaned = sequence.upper().replace("\n", "").replace("\r", "")

    unsupported = set(cleaned).difference(ENCODING_TO_BASES[ENCODING])
    if unsupported:
        raise EncodingError(
            f"Unsupported symbols in sequence ({sorted(unsupported)})",
            encoding=ENCODING.value,
        )

    base_to_bits = ENCODE_MAPPING[ENCODING]
    payload = "".join([base_to_bits[symbol] for symbol in cleaned])

    # Number of zero bits needed to round header + payload up to a whole byte
    # (0 when it is already aligned).
    zero_fill = -(HEADER_LENGTH + len(payload)) % 8

    # The header field records the padding in units of two bits.
    pad_field = f"{zero_fill // 2:02b}"

    return bits_to_bytes("".join((TAG_BIT2, pad_field, "0" * zero_fill, payload)))
def decode_2bit_sequence(encoded_bytes: bytes) -> str:
    """
    Decode a BIT2 byte string (as produced by ``encode_2bit_sequence``).

    Layout (BIT2):
    [2b TAG=01][2b HEADER_PAD][2 * HEADER_PAD zero pad bits][2-bit symbols...]

    Returns the decoded sequence, one symbol per 2-bit code, looked up in
    DECODE_MAPPING for this encoding.

    Raises DecodingError when:
    - the leading tag is not TAG_BIT2,
    - the header declares more padding bits than the input contains,
    - any padding bit is non-zero,
    - the number of bits after the padding is odd.
    """
    bits = bytes_to_bits(encoded_bytes)

    tag = bits[: len(TAG_BIT2)]
    if tag != TAG_BIT2:
        raise DecodingError(
            f"Wrong tag in header (found {tag}, expected {TAG_BIT2})",
            encoding=ENCODING.value,
        )

    # The header field counts padding in 2-bit pairs; convert to bits.
    pad_length = int(bits[len(TAG_BIT2) : HEADER_LENGTH], 2) * 2
    payload_start = HEADER_LENGTH + pad_length

    # Slicing past the end silently shortens the padding, so a truncated
    # stream would otherwise be accepted as an empty sequence.
    if len(bits) < payload_start:
        raise DecodingError(
            (
                "Truncated input: header declares more padding bits than are"
                f" present (expected {pad_length}, found"
                f" {max(len(bits) - HEADER_LENGTH, 0)})."
            ),
            encoding=ENCODING.value,
        )

    padding = bits[HEADER_LENGTH:payload_start]
    if padding.strip("0"):
        raise DecodingError(
            (
                "Non-zero padding bits found in header"
                f" (expected '{pad_length * '0'}', found"
                f" '{padding}')."
            ),
            encoding=ENCODING.value,
        )

    mapping = DECODE_MAPPING[ENCODING]

    seq_bits = bits[payload_start:]
    if len(seq_bits) % 2 != 0:
        raise DecodingError(
            (
                "bitstring length after header is not divisible by 2"
                f" (found length {len(seq_bits)} % 2 = {len(seq_bits) % 2})."
            ),
            encoding=ENCODING.value,
        )

    # Walk the payload two bits at a time and translate each code.
    return "".join(
        mapping[seq_bits[i : i + 2]] for i in range(0, len(seq_bits), 2)
    )
@dataclass
class Encoded2bitSequence(EncodedSequence):
    """
    A 2-bit encoded sequence together with its optional quality data and
    header text.

    The two static methods delegate to this module's BIT2 encoder and
    decoder functions.
    """

    # Encoded sequence bytes.
    encoded_sequence: bytes
    # Encoded quality scores, when available.
    encoded_quality: EncodedQuality | None = None
    # Header information, when available.
    header: str | None = None

    @staticmethod
    def encode_sequence(sequence: str) -> bytes:
        """Encode *sequence* with ``encode_2bit_sequence``."""
        return encode_2bit_sequence(sequence)

    @staticmethod
    def decode_sequence(encoded_sequence: bytes) -> str:
        """Decode *encoded_sequence* with ``decode_2bit_sequence``."""
        return decode_2bit_sequence(encoded_sequence)