Skip to content

Commit

Permalink
Merge pull request #54 from mkkellogg/feature/optimized-sort-v1
Browse files Browse the repository at this point in the history
Optimized sort V1
  • Loading branch information
mkkellogg authored Nov 18, 2023
2 parents 10e88f8 + 9fc8ba5 commit 79d9ae9
Show file tree
Hide file tree
Showing 11 changed files with 451 additions and 196 deletions.
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"type": "git",
"url": "https://github.com/mkkellogg/GaussianSplat3D"
},
"version": "1.0.0",
"version": "0.1.1",
"description": "Three.js-based 3D Gaussian splat viewer",
"main": "src/index.js",
"author": "Mark Kellogg",
Expand Down
54 changes: 27 additions & 27 deletions src/PlyParser.js
Original file line number Diff line number Diff line change
Expand Up @@ -162,24 +162,24 @@ export class PlyParser {

console.log('Total valid splats: ', validVertexes.length, 'out of', splatCount);

const positionsForBucketCalcs = [];
const centersForBucketCalcs = [];
for (let row = 0; row < validVertexes.length; row++) {
rawVertex = validVertexes[row];
positionsForBucketCalcs.push([rawVertex.x, rawVertex.y, rawVertex.z]);
centersForBucketCalcs.push([rawVertex.x, rawVertex.y, rawVertex.z]);
}
const buckets = this.computeBuckets(positionsForBucketCalcs);
const buckets = this.computeBuckets(centersForBucketCalcs);

const paddedSplatCount = buckets.length * SplatBufferBucketSize;
const headerSize = SplatBuffer.HeaderSizeBytes;
const header = new Uint8Array(new ArrayBuffer(headerSize));
header[3] = compressionLevel;
(new Uint32Array(header.buffer, 4, 1))[0] = paddedSplatCount;

let bytesPerPosition = SplatBuffer.CompressionLevels[compressionLevel].BytesPerPosition;
let bytesPerCenter = SplatBuffer.CompressionLevels[compressionLevel].BytesPerCenter;
let bytesPerScale = SplatBuffer.CompressionLevels[compressionLevel].BytesPerScale;
let bytesPerColor = SplatBuffer.CompressionLevels[compressionLevel].BytesPerColor;
let bytesPerRotation = SplatBuffer.CompressionLevels[compressionLevel].BytesPerRotation;
const positionBuffer = new ArrayBuffer(bytesPerPosition * paddedSplatCount);
const centerBuffer = new ArrayBuffer(bytesPerCenter * paddedSplatCount);
const scaleBuffer = new ArrayBuffer(bytesPerScale * paddedSplatCount);
const colorBuffer = new ArrayBuffer(bytesPerColor * paddedSplatCount);
const rotationBuffer = new ArrayBuffer(bytesPerRotation * paddedSplatCount);
Expand All @@ -204,7 +204,7 @@ export class PlyParser {
rawVertex = validVertexes[row];

if (compressionLevel === 0) {
const position = new Float32Array(positionBuffer, outSplatIndex * bytesPerPosition, 3);
const center = new Float32Array(centerBuffer, outSplatIndex * bytesPerCenter, 3);
const scales = new Float32Array(scaleBuffer, outSplatIndex * bytesPerScale, 3);
const rot = new Float32Array(rotationBuffer, outSplatIndex * bytesPerRotation, 4);
if (propertyTypes['scale_0']) {
Expand All @@ -216,9 +216,9 @@ export class PlyParser {
scales.set([0.01, 0.01, 0.01]);
rot.set([1.0, 0.0, 0.0, 0.0]);
}
position.set([rawVertex.x, rawVertex.y, rawVertex.z]);
center.set([rawVertex.x, rawVertex.y, rawVertex.z]);
} else {
const position = new Uint16Array(positionBuffer, outSplatIndex * bytesPerPosition, 3);
const center = new Uint16Array(centerBuffer, outSplatIndex * bytesPerCenter, 3);
const scales = new Uint16Array(scaleBuffer, outSplatIndex * bytesPerScale, 3);
const rot = new Uint16Array(rotationBuffer, outSplatIndex * bytesPerRotation, 4);
const thf = THREE.DataUtils.toHalfFloat.bind(THREE.DataUtils);
Expand All @@ -238,7 +238,7 @@ export class PlyParser {
bucketCenterDelta.y = clamp(bucketCenterDelta.y, 0, doubleCompressionScaleRange);
bucketCenterDelta.z = Math.round(bucketCenterDelta.z * compressionScaleFactor) + compressionScaleRange;
bucketCenterDelta.z = clamp(bucketCenterDelta.z, 0, doubleCompressionScaleRange);
position.set([bucketCenterDelta.x, bucketCenterDelta.y, bucketCenterDelta.z]);
center.set([bucketCenterDelta.x, bucketCenterDelta.y, bucketCenterDelta.z]);
}

const rgba = new Uint8ClampedArray(colorBuffer, outSplatIndex * bytesPerColor, 4);
Expand Down Expand Up @@ -269,7 +269,7 @@ export class PlyParser {

const bytesPerBucket = 12;
const bucketsSize = bytesPerBucket * buckets.length;
const splatDataBufferSize = positionBuffer.byteLength + scaleBuffer.byteLength +
const splatDataBufferSize = centerBuffer.byteLength + scaleBuffer.byteLength +
colorBuffer.byteLength + rotationBuffer.byteLength;

const headerArrayUint32 = new Uint32Array(header.buffer);
Expand All @@ -286,11 +286,11 @@ export class PlyParser {

const unifiedBuffer = new ArrayBuffer(unifiedBufferSize);
new Uint8Array(unifiedBuffer, 0, headerSize).set(header);
new Uint8Array(unifiedBuffer, headerSize, positionBuffer.byteLength).set(new Uint8Array(positionBuffer));
new Uint8Array(unifiedBuffer, headerSize + positionBuffer.byteLength, scaleBuffer.byteLength).set(new Uint8Array(scaleBuffer));
new Uint8Array(unifiedBuffer, headerSize + positionBuffer.byteLength + scaleBuffer.byteLength,
new Uint8Array(unifiedBuffer, headerSize, centerBuffer.byteLength).set(new Uint8Array(centerBuffer));
new Uint8Array(unifiedBuffer, headerSize + centerBuffer.byteLength, scaleBuffer.byteLength).set(new Uint8Array(scaleBuffer));
new Uint8Array(unifiedBuffer, headerSize + centerBuffer.byteLength + scaleBuffer.byteLength,
colorBuffer.byteLength).set(new Uint8Array(colorBuffer));
new Uint8Array(unifiedBuffer, headerSize + positionBuffer.byteLength + scaleBuffer.byteLength + colorBuffer.byteLength,
new Uint8Array(unifiedBuffer, headerSize + centerBuffer.byteLength + scaleBuffer.byteLength + colorBuffer.byteLength,
rotationBuffer.byteLength).set(new Uint8Array(rotationBuffer));

if (compressionLevel > 0) {
Expand All @@ -314,23 +314,23 @@ export class PlyParser {
return splatBuffer;
}

computeBuckets(positions) {
computeBuckets(centers) {
const blockSize = SplatBufferBucketBlockSize;
const halfBlockSize = blockSize / 2.0;
const splatCount = positions.length;
const splatCount = centers.length;

const min = new THREE.Vector3();
const max = new THREE.Vector3();

// ignore the first splat since it's the invalid designator
for (let i = 1; i < splatCount; i++) {
const position = positions[i];
if (i === 0 || position[0] < min.x) min.x = position[0];
if (i === 0 || position[0] > max.x) max.x = position[0];
if (i === 0 || position[1] < min.y) min.y = position[1];
if (i === 0 || position[1] > max.y) max.y = position[1];
if (i === 0 || position[2] < min.z) min.z = position[2];
if (i === 0 || position[2] > max.z) max.z = position[2];
const center = centers[i];
if (i === 0 || center[0] < min.x) min.x = center[0];
if (i === 0 || center[0] > max.x) max.x = center[0];
if (i === 0 || center[1] < min.y) min.y = center[1];
if (i === 0 || center[1] > max.y) max.y = center[1];
if (i === 0 || center[2] < min.z) min.z = center[2];
if (i === 0 || center[2] > max.z) max.z = center[2];
}

const dimensions = new THREE.Vector3().copy(max).sub(min);
Expand All @@ -343,10 +343,10 @@ export class PlyParser {

// ignore the first splat since it's the invalid designator
for (let i = 1; i < splatCount; i++) {
const position = positions[i];
const xBlock = Math.ceil((position[0] - min.x) / blockSize);
const yBlock = Math.ceil((position[1] - min.y) / blockSize);
const zBlock = Math.ceil((position[2] - min.z) / blockSize);
const center = centers[i];
const xBlock = Math.ceil((center[0] - min.x) / blockSize);
const yBlock = Math.ceil((center[1] - min.y) / blockSize);
const zBlock = Math.ceil((center[2] - min.z) / blockSize);

blockCenter.x = (xBlock - 1) * blockSize + min.x + halfBlockSize;
blockCenter.y = (yBlock - 1) * blockSize + min.y + halfBlockSize;
Expand Down
76 changes: 38 additions & 38 deletions src/SplatBuffer.js
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,21 @@ let tbf;

export class SplatBuffer {

static PositionComponentCount = 3;
static CenterComponentCount = 3;
static ScaleComponentCount = 3;
static RotationComponentCount = 4;
static ColorComponentCount = 4;

static CompressionLevels = {
0: {
BytesPerPosition: 12,
BytesPerCenter: 12,
BytesPerScale: 12,
BytesPerColor: 4,
BytesPerRotation: 16,
ScaleRange: 1
},
1: {
BytesPerPosition: 6,
BytesPerCenter: 6,
BytesPerScale: 6,
BytesPerColor: 4,
BytesPerRotation: 8,
Expand Down Expand Up @@ -62,12 +62,12 @@ export class SplatBuffer {
this.splatBufferData = new ArrayBuffer(dataBufferSizeBytes);
new Uint8Array(this.splatBufferData).set(new Uint8Array(bufferData, SplatBuffer.HeaderSizeBytes, dataBufferSizeBytes));

this.bytesPerPosition = SplatBuffer.CompressionLevels[this.compressionLevel].BytesPerPosition;
this.bytesPerCenter = SplatBuffer.CompressionLevels[this.compressionLevel].BytesPerCenter;
this.bytesPerScale = SplatBuffer.CompressionLevels[this.compressionLevel].BytesPerScale;
this.bytesPerColor = SplatBuffer.CompressionLevels[this.compressionLevel].BytesPerColor;
this.bytesPerRotation = SplatBuffer.CompressionLevels[this.compressionLevel].BytesPerRotation;

this.bytesPerSplat = this.bytesPerPosition + this.bytesPerScale + this.bytesPerColor + this.bytesPerRotation;
this.bytesPerSplat = this.bytesPerCenter + this.bytesPerScale + this.bytesPerColor + this.bytesPerRotation;

fbf = this.fbf.bind(this);
tbf = this.tbf.bind(this);
Expand All @@ -77,13 +77,13 @@ export class SplatBuffer {

linkBufferArrays() {
let FloatArray = (this.compressionLevel === 0) ? Float32Array : Uint16Array;
this.positionArray = new FloatArray(this.splatBufferData, 0, this.splatCount * SplatBuffer.PositionComponentCount);
this.scaleArray = new FloatArray(this.splatBufferData, this.bytesPerPosition * this.splatCount,
this.centerArray = new FloatArray(this.splatBufferData, 0, this.splatCount * SplatBuffer.CenterComponentCount);
this.scaleArray = new FloatArray(this.splatBufferData, this.bytesPerCenter * this.splatCount,
this.splatCount * SplatBuffer.ScaleComponentCount);
this.colorArray = new Uint8Array(this.splatBufferData, (this.bytesPerPosition + this.bytesPerScale) * this.splatCount,
this.colorArray = new Uint8Array(this.splatBufferData, (this.bytesPerCenter + this.bytesPerScale) * this.splatCount,
this.splatCount * SplatBuffer.ColorComponentCount);
this.rotationArray = new FloatArray(this.splatBufferData,
(this.bytesPerPosition + this.bytesPerScale + this.bytesPerColor) * this.splatCount,
(this.bytesPerCenter + this.bytesPerScale + this.bytesPerColor) * this.splatCount,
this.splatCount * SplatBuffer.RotationComponentCount);
this.bucketsBase = this.splatCount * this.bytesPerSplat;
}
Expand Down Expand Up @@ -112,41 +112,41 @@ export class SplatBuffer {
return this.splatBufferData;
}

getPosition(index, outPosition = new THREE.Vector3()) {
getCenter(index, outCenter = new THREE.Vector3()) {
let bucket = [0, 0, 0];
const positionBase = index * SplatBuffer.PositionComponentCount;
const centerBase = index * SplatBuffer.CenterComponentCount;
if (this.compressionLevel > 0) {
const sf = this.compressionScaleFactor;
const sr = this.compressionScaleRange;
const bucketIndex = Math.floor(index / this.bucketSize);
bucket = new Float32Array(this.splatBufferData, this.bucketsBase + bucketIndex * this.bytesPerBucket, 3);
outPosition.x = (this.positionArray[positionBase] - sr) * sf + bucket[0];
outPosition.y = (this.positionArray[positionBase + 1] - sr) * sf + bucket[1];
outPosition.z = (this.positionArray[positionBase + 2] - sr) * sf + bucket[2];
outCenter.x = (this.centerArray[centerBase] - sr) * sf + bucket[0];
outCenter.y = (this.centerArray[centerBase + 1] - sr) * sf + bucket[1];
outCenter.z = (this.centerArray[centerBase + 2] - sr) * sf + bucket[2];
} else {
outPosition.x = this.positionArray[positionBase];
outPosition.y = this.positionArray[positionBase + 1];
outPosition.z = this.positionArray[positionBase + 2];
outCenter.x = this.centerArray[centerBase];
outCenter.y = this.centerArray[centerBase + 1];
outCenter.z = this.centerArray[centerBase + 2];
}
return outPosition;
return outCenter;
}

setPosition(index, position) {
setCenter(index, center) {
let bucket = [0, 0, 0];
const positionBase = index * SplatBuffer.PositionComponentCount;
const centerBase = index * SplatBuffer.CenterComponentCount;
if (this.compressionLevel > 0) {
const sf = 1.0 / this.compressionScaleFactor;
const sr = this.compressionScaleRange;
const maxR = sr * 2 + 1;
const bucketIndex = Math.floor(index / this.bucketSize);
bucket = new Float32Array(this.splatBufferData, this.bucketsBase + bucketIndex * this.bytesPerBucket, 3);
this.positionArray[positionBase] = clamp(Math.round((position.x - bucket[0]) * sf) + sr, 0, maxR);
this.positionArray[positionBase + 1] = clamp(Math.round((position.y - bucket[1]) * sf) + sr, 0, maxR);
this.positionArray[positionBase + 2] = clamp(Math.round((position.z - bucket[2]) * sf) + sr, 0, maxR);
this.centerArray[centerBase] = clamp(Math.round((center.x - bucket[0]) * sf) + sr, 0, maxR);
this.centerArray[centerBase + 1] = clamp(Math.round((center.y - bucket[1]) * sf) + sr, 0, maxR);
this.centerArray[centerBase + 2] = clamp(Math.round((center.z - bucket[2]) * sf) + sr, 0, maxR);
} else {
this.positionArray[positionBase] = position.x;
this.positionArray[positionBase + 1] = position.y;
this.positionArray[positionBase + 2] = position.z;
this.centerArray[centerBase] = center.x;
this.centerArray[centerBase + 1] = center.y;
this.centerArray[centerBase + 2] = center.z;
}
}

Expand Down Expand Up @@ -232,23 +232,23 @@ export class SplatBuffer {
}
}

fillPositionArray(outPositionArray) {
fillCenterArray(outCenterArray) {
const splatCount = this.splatCount;
let bucket = [0, 0, 0];
for (let i = 0; i < splatCount; i++) {
const positionBase = i * SplatBuffer.PositionComponentCount;
const centerBase = i * SplatBuffer.CenterComponentCount;
if (this.compressionLevel > 0) {
const bucketIndex = Math.floor(i / this.bucketSize);
bucket = new Float32Array(this.splatBufferData, this.bucketsBase + bucketIndex * this.bytesPerBucket, 3);
const sf = this.compressionScaleFactor;
const sr = this.compressionScaleRange;
outPositionArray[positionBase] = (this.positionArray[positionBase] - sr) * sf + bucket[0];
outPositionArray[positionBase + 1] = (this.positionArray[positionBase + 1] - sr) * sf + bucket[1];
outPositionArray[positionBase + 2] = (this.positionArray[positionBase + 2] - sr) * sf + bucket[2];
outCenterArray[centerBase] = (this.centerArray[centerBase] - sr) * sf + bucket[0];
outCenterArray[centerBase + 1] = (this.centerArray[centerBase + 1] - sr) * sf + bucket[1];
outCenterArray[centerBase + 2] = (this.centerArray[centerBase + 2] - sr) * sf + bucket[2];
} else {
outPositionArray[positionBase] = this.positionArray[positionBase];
outPositionArray[positionBase + 1] = this.positionArray[positionBase + 1];
outPositionArray[positionBase + 2] = this.positionArray[positionBase + 2];
outCenterArray[centerBase] = this.centerArray[centerBase];
outCenterArray[centerBase + 1] = this.centerArray[centerBase + 1];
outCenterArray[centerBase + 2] = this.centerArray[centerBase + 2];
}
}
}
Expand Down Expand Up @@ -289,10 +289,10 @@ export class SplatBuffer {

swapVertices(indexA, indexB) {

this.getPosition(indexA, tempVector3A);
this.getPosition(indexB, tempVector3B);
this.setPosition(indexB, tempVector3A);
this.setPosition(indexA, tempVector3B);
this.getCenter(indexA, tempVector3A);
this.getCenter(indexB, tempVector3B);
this.setCenter(indexB, tempVector3A);
this.setCenter(indexA, tempVector3B);

this.getScale(indexA, tempVector3A);
this.getScale(indexB, tempVector3B);
Expand Down
Loading

0 comments on commit 79d9ae9

Please sign in to comment.