Skip to content

Commit bd495d4

Browse files
authored
Merge pull request #876 from padix-key/rust-alphabet
Port alphabet codec to Rust
2 parents 40da2cf + c0e31af commit bd495d4

File tree

6 files changed

+200
-161
lines changed

6 files changed

+200
-161
lines changed

src/biotite/sequence/alphabet.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import string
1616
from numbers import Integral
1717
import numpy as np
18-
from biotite.sequence.codec import decode_to_chars, encode_chars, map_sequence_code
18+
from biotite.rust.sequence import AlphabetCodec
1919

2020

2121
class Alphabet(object):
@@ -319,6 +319,7 @@ def __init__(self, symbols):
319319
self._symbols = np.frombuffer(
320320
np.array(self._symbols, dtype="|S1"), dtype=np.ubyte
321321
)
322+
self._codec = AlphabetCodec(self._symbols)
322323

323324
def __repr__(self):
324325
"""Represent LetterAlphabet as a string for debugging."""
@@ -378,7 +379,7 @@ def encode_multiple(self, symbols, dtype=None):
378379
symbols = np.frombuffer(
379380
np.array(list(symbols), dtype="|S1"), dtype=np.ubyte
380381
)
381-
return encode_chars(alphabet=self._symbols, symbols=symbols)
382+
return self._codec.encode(symbols)
382383

383384
def decode_multiple(self, code, as_bytes=False):
384385
"""
@@ -403,7 +404,7 @@ def decode_multiple(self, code, as_bytes=False):
403404
if not isinstance(code, np.ndarray):
404405
code = np.array(code, dtype=np.uint8)
405406
code = code.astype(np.uint8, copy=False)
406-
symbols = decode_to_chars(alphabet=self._symbols, code=code)
407+
symbols = self._codec.decode(code)
407408
# Symbols must be converted from 'np.ubyte' to '|S1'
408409
symbols = np.frombuffer(symbols, dtype="|S1")
409410
if not as_bytes:
@@ -497,9 +498,7 @@ def __getitem__(self, code):
497498
):
498499
code = np.array(code, dtype=np.uint64)
499500
if self._necessary_mapping:
500-
mapped_code = np.empty(len(code), dtype=self._mapper.dtype)
501-
map_sequence_code(self._mapper, code, mapped_code)
502-
return mapped_code
501+
return self._mapper[code]
503502
else:
504503
return code
505504

src/biotite/sequence/codec.pyx

Lines changed: 0 additions & 155 deletions
This file was deleted.

src/rust/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use pyo3::prelude::*;
22
use pyo3::types::PyDict;
33

4+
mod sequence;
45
mod structure;
56
pub mod util;
67

@@ -21,6 +22,7 @@ fn add_subpackage(
2122

2223
#[pymodule]
2324
fn rust(module: &Bound<'_, PyModule>) -> PyResult<()> {
25+
add_subpackage(module, &sequence::module(module)?, "biotite.rust.sequence")?;
2426
add_subpackage(
2527
module,
2628
&structure::module(module)?,

src/rust/sequence/codec.rs

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
use numpy::ndarray::Array1;
2+
use numpy::{IntoPyArray, PyArray1, PyReadonlyArray1};
3+
use pyo3::prelude::*;
4+
use pyo3::types::PyTuple;
5+
6+
/// A codec for encoding/decoding between ASCII symbols and integer codes,
7+
/// based on a given alphabet of allowed symbols.
8+
#[pyclass(module = "biotite.rust.sequence")]
9+
pub struct AlphabetCodec {
10+
/// Lookup table indexed by byte value that yields the symbol code.
/// Entries for bytes that are not part of the alphabet hold `illegal_code`.
11+
symbol_to_code: [u8; 256],
12+
/// The alphabet symbols (as bytes), indexed by code.
/// This is also the state the codec is rebuilt from when it is pickled.
13+
code_to_symbol: Vec<u8>,
14+
/// The sentinel code that marks an illegal/unmapped symbol.
/// It is set to the number of symbols in the alphabet, i.e. the first unused code.
15+
illegal_code: u8,
16+
}
17+
18+
#[pymethods]
19+
impl AlphabetCodec {
20+
/// Create a new codec from the alphabet's symbols.
21+
///
22+
/// This class can only be used if the symbols are ASCII characters.
23+
///
24+
/// Parameters
25+
/// ----------
26+
/// symbols : ndarray, dtype=uint8
27+
/// The ASCII characters as bytes.
28+
/// The index of each symbol becomes its code.
29+
#[new]
30+
fn new(symbols: PyReadonlyArray1<u8>) -> PyResult<Self> {
31+
let symbols = symbols.as_slice()?;
32+
if symbols.len() > 255 {
33+
return Err(pyo3::exceptions::PyValueError::new_err(
34+
"Alphabet must have at most 255 symbols",
35+
));
36+
}
37+
// As the symbol code `symbols.len()` is always illegal,
38+
// it can be used later to check for invalid input symbols
39+
let illegal_code = symbols.len() as u8;
40+
// An array based map that maps from symbol to code
41+
// Since the maximum value of a char is 256
42+
// the size of the map is known at compile time
43+
// Initially fill the map with the illegal symbol
44+
// Consequently, the map will later return the illegal symbol
45+
// when indexed with a character that is not part of the alphabet
46+
let mut symbol_to_code = [illegal_code; 256];
47+
for (i, &symbol) in symbols.iter().enumerate() {
48+
symbol_to_code[symbol as usize] = i as u8;
49+
}
50+
Ok(AlphabetCodec {
51+
symbol_to_code,
52+
code_to_symbol: symbols.to_vec(),
53+
illegal_code,
54+
})
55+
}
56+
57+
/// Encode an array of ASCII symbols into an array of symbol codes.
58+
///
59+
/// Parameters
60+
/// ----------
61+
/// symbols : ndarray, dtype=uint8
62+
/// The symbols (as bytes) to encode.
63+
///
64+
/// Returns
65+
/// -------
66+
/// code : ndarray, dtype=uint8
67+
/// The encoded symbol codes.
68+
///
69+
/// Raises
70+
/// ------
71+
/// AlphabetError
72+
/// If any symbol is not in the alphabet.
73+
fn encode<'py>(
74+
&self,
75+
py: Python<'py>,
76+
symbols: PyReadonlyArray1<u8>,
77+
) -> PyResult<Bound<'py, PyArray1<u8>>> {
78+
let symbols = symbols.as_slice()?;
79+
let mut code = Array1::<u8>::uninit(symbols.len());
80+
81+
for (&sym, out_code) in symbols.iter().zip(
82+
code.as_slice_mut()
83+
.expect("Array not contiguous")
84+
.iter_mut(),
85+
) {
86+
let c = self.symbol_to_code[sym as usize];
87+
if c == self.illegal_code {
88+
let alphabet_error = py
89+
.import("biotite.sequence.alphabet")?
90+
.getattr("AlphabetError")?;
91+
return Err(PyErr::from_value(alphabet_error.call1((format!(
92+
"Symbol {} is not in the alphabet",
93+
repr_char(sym)
94+
),))?));
95+
}
96+
out_code.write(c);
97+
}
98+
// SAFETY: All n elements have been written above
99+
let code = unsafe { code.assume_init() };
100+
Ok(code.into_pyarray(py))
101+
}
102+
103+
/// Decode an array of symbol codes into an array of ASCII symbols.
104+
///
105+
/// Parameters
106+
/// ----------
107+
/// code : ndarray, dtype=uint8
108+
/// The symbol codes to decode.
109+
///
110+
/// Returns
111+
/// -------
112+
/// symbols : ndarray, dtype=uint8
113+
/// The decoded symbols as bytes.
114+
///
115+
/// Raises
116+
/// ------
117+
/// AlphabetError
118+
/// If any code is not valid in the alphabet.
119+
fn decode<'py>(
120+
&self,
121+
py: Python<'py>,
122+
code: PyReadonlyArray1<u8>,
123+
) -> PyResult<Bound<'py, PyArray1<u8>>> {
124+
let code = code.as_slice()?;
125+
let mut symbols = Array1::<u8>::uninit(code.len());
126+
127+
for (&c, out_symbol) in code.iter().zip(
128+
symbols
129+
.as_slice_mut()
130+
.expect("Array not contiguous")
131+
.iter_mut(),
132+
) {
133+
if (c as usize) >= self.code_to_symbol.len() {
134+
let alphabet_error = py
135+
.import("biotite.sequence.alphabet")?
136+
.getattr("AlphabetError")?;
137+
return Err(PyErr::from_value(
138+
alphabet_error.call1((format!("'{}' is not a valid code", c),))?,
139+
));
140+
}
141+
out_symbol.write(self.code_to_symbol[c as usize]);
142+
}
143+
// SAFETY: All n elements have been written above
144+
let symbols = unsafe { symbols.assume_init() };
145+
Ok(symbols.into_pyarray(py))
146+
}
147+
148+
fn __reduce__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
149+
let cls = py
150+
.import("biotite.rust.sequence")?
151+
.getattr("AlphabetCodec")?;
152+
let code_to_symbol = self.code_to_symbol.clone().into_pyarray(py);
153+
let args = PyTuple::new(py, [code_to_symbol.into_any()])?;
154+
PyTuple::new(py, [cls.unbind(), args.into_any().unbind()])
155+
}
156+
}
157+
158+
/// Format a byte value as a Python-style repr for error messages.
///
/// Printable ASCII characters (including the space) are wrapped in single quotes.
/// As in Python's `repr()`, an apostrophe is wrapped in double quotes instead,
/// a backslash is doubled and newline, carriage return and tab use their
/// short escapes, so the result is never ambiguous.
/// Every other byte is rendered as a two-digit `\x` escape.
fn repr_char(byte: u8) -> String {
    match byte {
        // Single quotes around an apostrophe would read as `'''`
        b'\'' => String::from("\"'\""),
        // An unescaped backslash would appear to escape the closing quote
        b'\\' => String::from("'\\\\'"),
        b'\n' => String::from("'\\n'"),
        b'\r' => String::from("'\\r'"),
        b'\t' => String::from("'\\t'"),
        // Remaining printable ASCII range: space up to and including tilde
        b' '..=b'~' => format!("'{}'", char::from(byte)),
        _ => format!("'\\x{:02x}'", byte),
    }
}

src/rust/sequence/mod.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
use pyo3::prelude::*;
2+
3+
pub mod codec;
4+
5+
/// Create the `sequence` submodule and register the `AlphabetCodec` class on it.
///
/// The parent module is only used to obtain the Python token;
/// attaching the returned submodule to a package is left to the caller.
pub fn module<'py>(parent_module: &Bound<'py, PyModule>) -> PyResult<Bound<'py, PyModule>> {
    let py = parent_module.py();
    let submodule = PyModule::new(py, "sequence")?;
    submodule.add_class::<codec::AlphabetCodec>()?;
    Ok(submodule)
}

0 commit comments

Comments
 (0)