Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SHA256 hash #93

Merged
merged 4 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

## Added

- `TimestampRequest` now accepts setting the hash algorithm to `SHA256` (in addition to `SHA512`)
([93](https://github.com/trailofbits/rfc3161-client/pull/93))

## [0.1.2] - 2024-12-11

### Changed
Expand Down
50 changes: 46 additions & 4 deletions rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -623,21 +623,63 @@ pub(crate) fn parse_timestamp_request(
Ok(TimeStampReq { raw: raw.into() })
}

struct HashInfo<'a> {
params: cryptography_x509::common::AlgorithmParameters<'a>,
hash_fn: fn(&[u8]) -> Vec<u8>,
}

fn detect_hash_algorithm<'a>(
py: Python<'a>,
hash_algorithm: Option<pyo3::Bound<'a, pyo3::PyAny>>,
) -> PyResult<HashInfo<'a>> {
let name = if hash_algorithm.is_none() {
"SHA512".to_string()
} else {
let algorithm = hash_algorithm.unwrap();
if !algorithm.is_instance(&crate::util::HASH_ALGORITHM.get(py)?)? {
return Err(pyo3::exceptions::PyValueError::new_err(
"invalid hash algorithm",
));
}
let name_str = algorithm
.getattr(pyo3::intern!(py, "name"))?
.extract::<pyo3::pybacked::PyBackedStr>()?;
name_str.to_string()
};

match name.as_str() {
"SHA256" => Ok(HashInfo {
params: cryptography_x509::common::AlgorithmParameters::Sha256(Some(())),
hash_fn: |data| sha2::Sha256::digest(data).to_vec(),
}),
"SHA512" => Ok(HashInfo {
params: cryptography_x509::common::AlgorithmParameters::Sha512(Some(())),
hash_fn: |data| sha2::Sha512::digest(data).to_vec(),
}),
_ => Err(pyo3::exceptions::PyValueError::new_err(format!(
"unsupported hash algorithm {:?}",
name
))),
}
}

#[pyo3::pyfunction]
#[pyo3(signature = (data, nonce, cert))]
#[pyo3(signature = (data, nonce, cert, hash_algorithm=None))]
pub(crate) fn create_timestamp_request(
Comment on lines +667 to 668
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick, nonblocking: it's unfortunate that the default here is None instead of HashAlgorithm.SHA512, but I don't think PyO3 would allow us to set the latter anyways.

Not a blocker.

py: pyo3::Python<'_>,
data: pyo3::Py<pyo3::types::PyBytes>,
nonce: bool,
cert: bool,
hash_algorithm: Option<pyo3::Bound<'_, pyo3::PyAny>>,
) -> PyResult<TimeStampReq> {
let data_bytes = data.as_bytes(py);
let hash = sha2::Sha512::digest(data_bytes);
let hash_info = detect_hash_algorithm(py, hash_algorithm)?;

let data_bytes = data.as_bytes(py);
let hash = (hash_info.hash_fn)(data_bytes);
let message_imprint = tsp_asn1::tsp::MessageImprint {
hash_algorithm: cryptography_x509::common::AlgorithmIdentifier {
oid: asn1::DefinedByMarker::marker(),
params: cryptography_x509::common::AlgorithmParameters::Sha512(Some(())),
params: hash_info.params,
},
hashed_message: hash.as_slice(),
};
Expand Down
3 changes: 3 additions & 0 deletions rust/src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ pub static NAME: LazyPyImport = LazyPyImport::new("cryptography.x509", &["Name"]
pub static DIRECTORY_NAME: LazyPyImport =
LazyPyImport::new("cryptography.x509", &["DirectoryName"]);

pub static HASH_ALGORITHM: LazyPyImport =
LazyPyImport::new("rfc3161_client.base", &["HashAlgorithm"]);

pub fn generate_random_bytes_for_asn1_biguint() -> Vec<u8> {
let mut rng = rand::thread_rng();
let nonce_random: u64 = rng.gen_range(0..u64::MAX);
Expand Down
2 changes: 2 additions & 0 deletions src/rfc3161_client/_rust/__init__.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from rfc3161_client.tsp import TimeStampRequest, TimeStampResponse
from rfc3161_client.base import HashAlgorithm

class PyMessageImprint: ...

Expand All @@ -18,6 +19,7 @@ def create_timestamp_request(
data: bytes,
nonce: bool,
cert: bool,
hash_algorithm: HashAlgorithm | None = None,
) -> TimeStampRequest: ...


Expand Down
2 changes: 2 additions & 0 deletions src/rfc3161_client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
class HashAlgorithm(enum.Enum):
"""Hash algorithms."""

SHA256 = "SHA256"
SHA512 = "SHA512"


Expand Down Expand Up @@ -83,6 +84,7 @@ def build(self) -> TimeStampRequest:
data=self._data,
nonce=self._nonce,
cert=self._cert_req,
hash_algorithm=self._algorithm,
)


Expand Down
4 changes: 4 additions & 0 deletions test/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
import cryptography.x509

SHA256_OID = cryptography.x509.ObjectIdentifier("2.16.840.1.101.3.4.2.1")
SHA512_OID = cryptography.x509.ObjectIdentifier("2.16.840.1.101.3.4.2.3")
DarkaMaul marked this conversation as resolved.
Show resolved Hide resolved
32 changes: 21 additions & 11 deletions test/test_base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import cryptography.x509
import pytest
from cryptography.hazmat.primitives import hashes

from rfc3161_client.base import HashAlgorithm, TimestampRequestBuilder

SHA512_OID = cryptography.x509.ObjectIdentifier("2.16.840.1.101.3.4.2.3")
from .common import SHA256_OID, SHA512_OID


class TestRequestBuilder:
Expand All @@ -18,13 +17,6 @@ def test_succeeds(self):
assert request.nonce is not None
assert request.policy is None

message_imprint = request.message_imprint
assert message_imprint.hash_algorithm == SHA512_OID

digest = hashes.Hash(hashes.SHA512())
digest.update(message)
assert digest.finalize() == message_imprint.message

def test_data(self):
with pytest.raises(ValueError):
TimestampRequestBuilder().build()
Expand All @@ -35,15 +27,33 @@ def test_data(self):
with pytest.raises(ValueError, match="once"):
TimestampRequestBuilder().data(b"hello").data(b"world")

def test_set_algorithm(self):
def test_algorithm_sha256(self):
message = b"random-message"
request = (
TimestampRequestBuilder().data(message).hash_algorithm(HashAlgorithm.SHA256).build()
)
assert request.message_imprint.hash_algorithm == SHA256_OID

digest = hashes.Hash(hashes.SHA256())
digest.update(message)
assert digest.finalize() == request.message_imprint.message

def test_algorithm_sha512(self):
message = b"random-message"
request = (
TimestampRequestBuilder().hash_algorithm(HashAlgorithm.SHA512).data(b"hello").build()
TimestampRequestBuilder().data(message).hash_algorithm(HashAlgorithm.SHA512).build()
)
assert request.message_imprint.hash_algorithm == SHA512_OID

digest = hashes.Hash(hashes.SHA512())
digest.update(message)
assert digest.finalize() == request.message_imprint.message

def test_set_algorithm(self):
with pytest.raises(TypeError):
TimestampRequestBuilder().hash_algorithm("invalid hash algorihtm")

# Default hash algorithm
request = TimestampRequestBuilder().data(b"hello").build()
assert request.message_imprint.hash_algorithm == SHA512_OID

Expand Down
21 changes: 21 additions & 0 deletions test/test_rust.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from rfc3161_client._rust import create_timestamp_request
from rfc3161_client.base import HashAlgorithm

from .common import SHA256_OID, SHA512_OID


def test_create_timestamp_request():
request = create_timestamp_request(
data=b"hello", nonce=True, cert=False, hash_algorithm=HashAlgorithm.SHA512
)

assert request.message_imprint.hash_algorithm == SHA512_OID

# Optional parameter
request = create_timestamp_request(data=b"hello", nonce=True, cert=True)
assert request.message_imprint.hash_algorithm == SHA512_OID

request = create_timestamp_request(
data=b"hello", nonce=True, cert=True, hash_algorithm=HashAlgorithm.SHA256
)
assert request.message_imprint.hash_algorithm == SHA256_OID
Loading