-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add RandomBufferGenerator protocol for WASI random_get
This protocol allows to inject a custom random number generator for `wasi_snapshot_preview1::random_get` function. The default implementation continues to use `swift_stdlib_random` function but users can provide their own implementation. Additionally, types conforming to `RandomNumberGenerator` can automatically conform to the new protocol. This is useful when we want fully deterministic behavior like build tools or tests.
- Loading branch information
1 parent
da13542
commit 99252ae
Showing
3 changed files
with
88 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import SwiftShims // For swift_stdlib_random | ||
|
||
/// A type that provides random bytes. | ||
/// | ||
/// This type is similar to `RandomNumberGenerator` in Swift standard library, | ||
/// but it provides a way to fill a buffer with random bytes instead of a single | ||
/// random number. | ||
public protocol RandomBufferGenerator { | ||
|
||
/// Fills the buffer with random bytes. | ||
/// | ||
/// - Parameter buffer: The destination buffer to fill with random bytes. | ||
mutating func fill(buffer: UnsafeMutableBufferPointer<UInt8>) | ||
} | ||
|
||
extension RandomBufferGenerator where Self: RandomNumberGenerator { | ||
public mutating func fill(buffer: UnsafeMutableBufferPointer<UInt8>) { | ||
// The buffer is filled with 8 bytes at once. | ||
let count = buffer.count / 8 | ||
for i in 0..<count { | ||
let random = self.next() | ||
withUnsafeBytes(of: random) { randomBytes in | ||
let startOffset = i * 8 | ||
let destination = UnsafeMutableBufferPointer(rebasing: buffer[startOffset..<(startOffset + 8)]) | ||
randomBytes.copyBytes(to: destination) | ||
} | ||
} | ||
|
||
// If the buffer size is not a multiple of 8, fill the remaining bytes. | ||
let remaining = buffer.count % 8 | ||
if remaining > 0 { | ||
let random = self.next() | ||
withUnsafeBytes(of: random) { randomBytes in | ||
let startOffset = count * 8 | ||
let destination = UnsafeMutableBufferPointer(rebasing: buffer[startOffset..<(startOffset + remaining)]) | ||
randomBytes.copyBytes(to: destination) | ||
} | ||
} | ||
} | ||
} | ||
|
||
extension SystemRandomNumberGenerator: RandomBufferGenerator { | ||
public mutating func fill(buffer: UnsafeMutableBufferPointer<UInt8>) { | ||
guard let baseAddress = buffer.baseAddress else { return } | ||
// Directly call underlying C function of SystemRandomNumberGenerator | ||
swift_stdlib_random(baseAddress, Int(buffer.count)) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
import XCTest | ||
|
||
@testable import WASI | ||
|
||
final class RandomBufferGeneratorTests: XCTestCase { | ||
struct DeterministicGenerator: RandomNumberGenerator, RandomBufferGenerator { | ||
var items: [UInt64] | ||
|
||
mutating func next() -> UInt64 { | ||
items.removeFirst() | ||
} | ||
} | ||
func testDefaultFill() { | ||
var generator = DeterministicGenerator(items: [ | ||
0x0123456789abcdef, 0xfedcba9876543210, 0xdeadbeefbaddcafe | ||
]) | ||
for (bufferSize, expectedBytes): (Int, [UInt8]) in [ | ||
(10, [0xef, 0xcd, 0xab, 0x89, 0x67, 0x45, 0x23, 0x01, 0x10, 0x32]), | ||
(2, [0xfe, 0xca]), | ||
(0, []) | ||
] { | ||
var buffer: [UInt8] = Array(repeating: 0, count: bufferSize) | ||
buffer.withUnsafeMutableBufferPointer { | ||
generator.fill(buffer: $0) | ||
} | ||
let expected: [UInt8] | ||
#if _endian(little) | ||
expected = expectedBytes | ||
#else | ||
expected = Array(expectedBytes.reversed()) | ||
#endif | ||
XCTAssertEqual(buffer, expected) | ||
} | ||
} | ||
} |