separated buffer logic
SkBlaz committed Apr 8, 2024
1 parent b24b617 commit 445b6a0
Showing 13 changed files with 169 additions and 277 deletions.
53 changes: 28 additions & 25 deletions src/block_ffm.rs
@@ -22,8 +22,8 @@ use crate::model_instance;
use crate::optimizer;
use crate::port_buffer;
use crate::port_buffer::PortBuffer;
use crate::regressor;
use crate::quantization;
use crate::regressor;
use crate::regressor::{BlockCache, FFM_CONTRA_BUF_LEN};

const FFM_STACK_BUF_LEN: usize = 170393;
@@ -458,8 +458,11 @@ impl<L: OptimizerTrait + 'static> BlockTrait for BlockFFM<L> {
contra_fields,
features_present,
ffm,
} = next_cache else {
log::warn!("Unable to downcast cache to BlockFFMCache, executing forward pass without cache");
} = next_cache
else {
log::warn!(
"Unable to downcast cache to BlockFFMCache, executing forward pass without cache"
);
self.forward(further_blocks, fb, pb);
return;
};
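
The changes in this hunk and the next are rustfmt reflow of the `let … else` cache-downcast blocks: the `else` moves onto its own line and long `log::warn!` calls wrap. For reference, a minimal self-contained sketch of the same idiom follows; `DemoCache`, its fields, and `eprintln!` (standing in for `log::warn!`) are invented for illustration and are not the crate's actual types.

// A minimal sketch of the `let … else` cache-downcast idiom used above.
// `DemoCache` is a hypothetical stand-in for regressor::BlockCache.
#[allow(dead_code)]
enum DemoCache {
    Ffm { features_present: Vec<u32> },
    Lr { combo_indexes: Vec<u32> },
}

fn forward_with_cache(caches: &mut [DemoCache]) {
    // Take the first cache entry; if none is available, bail out gracefully.
    let Some((next_cache, _further_caches)) = caches.split_first_mut() else {
        eprintln!("no cache available, skipping cache preparation");
        return;
    };

    // Destructure the expected variant; otherwise fall back to the uncached path.
    let DemoCache::Ffm { features_present } = next_cache else {
        eprintln!("unexpected cache variant, executing without cache");
        return;
    };

    // Use the cached data (here: just record that the cache was touched).
    features_present.push(1);
}

fn main() {
    let mut caches = vec![DemoCache::Ffm { features_present: Vec::new() }];
    forward_with_cache(&mut caches);
}
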
@@ -667,15 +670,18 @@ impl<L: OptimizerTrait + 'static> BlockTrait for BlockFFM<L> {
caches: &mut [BlockCache],
) {
let Some((next_cache, further_caches)) = caches.split_first_mut() else {
log::warn!("Expected BlockFFMCache caches, but non available, skipping cache preparation");
log::warn!(
"Expected BlockFFMCache caches, but non available, skipping cache preparation"
);
return;
};

let BlockCache::FFM {
contra_fields,
features_present,
ffm,
} = next_cache else {
} = next_cache
else {
log::warn!("Unable to downcast cache to BlockFFMCache, skipping cache preparation");
return;
};
@@ -829,32 +835,29 @@ impl<L: OptimizerTrait + 'static> BlockTrait for BlockFFM<L> {
fn write_weights_to_buf(
&self,
output_bufwriter: &mut dyn io::Write,
use_quantization: bool
use_quantization: bool,
) -> Result<(), Box<dyn Error>> {

if use_quantization {

let quantized_weights = quantization::quantize_ffm_weights(&self.weights);
block_helpers::write_weights_to_buf(&quantized_weights, output_bufwriter, false)?;
} else {
if use_quantization {
let quantized_weights = quantization::quantize_ffm_weights(&self.weights);
block_helpers::write_weights_to_buf(&quantized_weights, output_bufwriter, false)?;
} else {
block_helpers::write_weights_to_buf(&self.weights, output_bufwriter, false)?;
}
}
block_helpers::write_weights_to_buf(&self.optimizer, output_bufwriter, false)?;
Ok(())
}

fn read_weights_from_buf(
&mut self,
input_bufreader: &mut dyn io::Read,
use_quantization: bool
use_quantization: bool,
) -> Result<(), Box<dyn Error>> {

if use_quantization {
quantization::dequantize_ffm_weights(input_bufreader, &mut self.weights);
} else {
if use_quantization {
quantization::dequantize_ffm_weights(input_bufreader, &mut self.weights);
} else {
block_helpers::read_weights_from_buf(&mut self.weights, input_bufreader, false)?;
}
}

block_helpers::read_weights_from_buf(&mut self.optimizer, input_bufreader, false)?;
Ok(())
}
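
Apart from rustfmt reflow (trailing commas, removed blank lines, un-nested `if`), the structure carried by these hunks is the quantization gate around weight serialization. The sketch below shows that gate in isolation; `quantize_weights`/`dequantize_weights` are deliberately naive one-byte stand-ins for the crate's `quantization::quantize_ffm_weights`/`dequantize_ffm_weights` and do not reproduce its actual scheme.

use std::error::Error;
use std::io::{self, Read, Write};

// Naive illustrative quantizer: map each f32 weight to one byte on a fixed
// [-1.0, 1.0] grid. NOT the crate's quantization scheme.
fn quantize_weights(weights: &[f32]) -> Vec<u8> {
    weights
        .iter()
        .map(|w| ((w.clamp(-1.0, 1.0) + 1.0) * 127.5).round() as u8)
        .collect()
}

fn dequantize_weights(bytes: &[u8]) -> Vec<f32> {
    bytes.iter().map(|b| *b as f32 / 127.5 - 1.0).collect()
}

// The same gate the diff introduces: quantize on write when requested,
// otherwise dump the raw little-endian f32 bytes.
fn write_weights(
    weights: &[f32],
    out: &mut dyn Write,
    use_quantization: bool,
) -> Result<(), Box<dyn Error>> {
    if use_quantization {
        out.write_all(&quantize_weights(weights))?;
    } else {
        for w in weights {
            out.write_all(&w.to_le_bytes())?;
        }
    }
    Ok(())
}

fn read_weights(
    weights: &mut Vec<f32>,
    input: &mut dyn Read,
    use_quantization: bool,
) -> Result<(), Box<dyn Error>> {
    if use_quantization {
        let mut bytes = vec![0u8; weights.len()];
        input.read_exact(&mut bytes)?;
        *weights = dequantize_weights(&bytes);
    } else {
        for w in weights.iter_mut() {
            let mut buf = [0u8; 4];
            input.read_exact(&mut buf)?;
            *w = f32::from_le_bytes(buf);
        }
    }
    Ok(())
}

fn main() -> Result<(), Box<dyn Error>> {
    let weights = vec![0.25_f32, -0.5, 0.9];
    let mut buf = Vec::new();
    write_weights(&weights, &mut buf, true)?;

    let mut restored = vec![0.0_f32; weights.len()];
    read_weights(&mut restored, &mut io::Cursor::new(buf), true)?;
    println!("{restored:?}"); // close to the originals, within quantization error
    Ok(())
}
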
@@ -877,18 +880,18 @@ impl<L: OptimizerTrait + 'static> BlockTrait for BlockFFM<L> {
&self,
input_bufreader: &mut dyn io::Read,
forward: &mut Box<dyn BlockTrait>,
use_quantization: bool
use_quantization: bool,
) -> Result<(), Box<dyn Error>> {
let forward = forward
.as_any()
.downcast_mut::<BlockFFM<optimizer::OptimizerSGD>>()
.unwrap();

if use_quantization {
quantization::dequantize_ffm_weights(input_bufreader, &mut forward.weights);
} else {
if use_quantization {
quantization::dequantize_ffm_weights(input_bufreader, &mut forward.weights);
} else {
block_helpers::read_weights_from_buf(&mut forward.weights, input_bufreader, false)?;
}
}
block_helpers::skip_weights_from_buf::<OptimizerData<L>>(
self.ffm_weights_len as usize,
input_bufreader,
4 changes: 2 additions & 2 deletions src/block_helpers.rs
@@ -43,7 +43,7 @@ macro_rules! assert_epsilon {
pub fn read_weights_from_buf<L>(
weights: &mut Vec<L>,
input_bufreader: &mut dyn io::Read,
_use_quantization: bool
_use_quantization: bool,
) -> Result<(), Box<dyn Error>> {
if weights.is_empty() {
return Err("Loading weights to unallocated weighs buffer".to_string())?;
@@ -75,7 +75,7 @@ pub fn skip_weights_from_buf<L>(
pub fn write_weights_to_buf<L>(
weights: &Vec<L>,
output_bufwriter: &mut dyn io::Write,
_use_quantization: bool
_use_quantization: bool,
) -> Result<(), Box<dyn Error>> {
if weights.is_empty() {
assert!(false);
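
The signature-only changes above belong to generic helpers that shuttle a preallocated weight vector to and from a byte stream. The sketch below shows one plausible shape for such helpers, assuming `L` is plain-old-data; it is not the crate's actual implementation, and the `_use_quantization` flag is intentionally unused, mirroring the underscore-prefixed parameter in the diff.

use std::error::Error;
use std::io;
use std::io::{Read, Write};
use std::mem::size_of;

// Plausible sketch of byte-level weight helpers. Assumes `L` is plain-old-data
// (fixed layout, no pointers); NOT the crate's exact implementation.
pub fn write_weights_to_buf<L>(
    weights: &Vec<L>,
    output_bufwriter: &mut dyn io::Write,
    _use_quantization: bool,
) -> Result<(), Box<dyn Error>> {
    if weights.is_empty() {
        return Err("Writing weights from an unallocated buffer".to_string())?;
    }
    // Reinterpret the weight vector as raw bytes and write it in one call.
    let bytes = unsafe {
        std::slice::from_raw_parts(
            weights.as_ptr() as *const u8,
            weights.len() * size_of::<L>(),
        )
    };
    output_bufwriter.write_all(bytes)?;
    Ok(())
}

pub fn read_weights_from_buf<L>(
    weights: &mut Vec<L>,
    input_bufreader: &mut dyn io::Read,
    _use_quantization: bool,
) -> Result<(), Box<dyn Error>> {
    if weights.is_empty() {
        return Err("Loading weights to unallocated weights buffer".to_string())?;
    }
    // Fill the preallocated vector in place from the reader.
    let bytes = unsafe {
        std::slice::from_raw_parts_mut(
            weights.as_mut_ptr() as *mut u8,
            weights.len() * size_of::<L>(),
        )
    };
    input_bufreader.read_exact(bytes)?;
    Ok(())
}
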
24 changes: 11 additions & 13 deletions src/block_lr.rs
@@ -175,11 +175,10 @@ impl<L: OptimizerTrait + 'static> BlockTrait for BlockLR<L> {
return;
};

let BlockCache::LR {
lr,
combo_indexes,
} = next_cache else {
log::warn!("Unable to downcast cache to BlockLRCache, executing forward pass without cache");
let BlockCache::LR { lr, combo_indexes } = next_cache else {
log::warn!(
"Unable to downcast cache to BlockLRCache, executing forward pass without cache"
);
self.forward(further_blocks, fb, pb);
return;
};
@@ -222,14 +221,13 @@ impl<L: OptimizerTrait + 'static> BlockTrait for BlockLR<L> {
caches: &mut [BlockCache],
) {
let Some((next_cache, further_caches)) = caches.split_first_mut() else {
log::warn!("Expected BlockLRCache caches, but non available, skipping cache preparation");
log::warn!(
"Expected BlockLRCache caches, but non available, skipping cache preparation"
);
return;
};

let BlockCache::LR {
lr,
combo_indexes
} = next_cache else {
let BlockCache::LR { lr, combo_indexes } = next_cache else {
log::warn!("Unable to downcast cache to BlockLRCache, skipping cache preparation");
return;
};
@@ -263,15 +261,15 @@ impl<L: OptimizerTrait + 'static> BlockTrait for BlockLR<L> {
fn read_weights_from_buf(
&mut self,
input_bufreader: &mut dyn io::Read,
_use_quantization: bool
_use_quantization: bool,
) -> Result<(), Box<dyn Error>> {
block_helpers::read_weights_from_buf(&mut self.weights, input_bufreader, false)
}

fn write_weights_to_buf(
&self,
output_bufwriter: &mut dyn io::Write,
_use_quantization: bool
_use_quantization: bool,
) -> Result<(), Box<dyn Error>> {
block_helpers::write_weights_to_buf(&self.weights, output_bufwriter, false)
}
@@ -280,7 +278,7 @@ impl<L: OptimizerTrait + 'static> BlockTrait for BlockLR<L> {
&self,
input_bufreader: &mut dyn io::Read,
forward: &mut Box<dyn BlockTrait>,
_use_quantization: bool
_use_quantization: bool,
) -> Result<(), Box<dyn Error>> {
let forward = forward
.as_any()
6 changes: 3 additions & 3 deletions src/block_neural.rs
@@ -430,7 +430,7 @@ impl<L: OptimizerTrait + 'static> BlockTrait for BlockNeuronLayer<L> {
fn write_weights_to_buf(
&self,
output_bufwriter: &mut dyn io::Write,
_use_quantization: bool
_use_quantization: bool,
) -> Result<(), Box<dyn Error>> {
block_helpers::write_weights_to_buf(&self.weights, output_bufwriter, false)?;
block_helpers::write_weights_to_buf(&self.weights_optimizer, output_bufwriter, false)?;
@@ -440,7 +440,7 @@ impl<L: OptimizerTrait + 'static> BlockTrait for BlockNeuronLayer<L> {
fn read_weights_from_buf(
&mut self,
input_bufreader: &mut dyn io::Read,
_use_quantization: bool
_use_quantization: bool,
) -> Result<(), Box<dyn Error>> {
block_helpers::read_weights_from_buf(&mut self.weights, input_bufreader, false)?;
block_helpers::read_weights_from_buf(&mut self.weights_optimizer, input_bufreader, false)?;
@@ -466,7 +466,7 @@ impl<L: OptimizerTrait + 'static> BlockTrait for BlockNeuronLayer<L> {
&self,
input_bufreader: &mut dyn io::Read,
forward: &mut Box<dyn BlockTrait>,
_use_quantization: bool
_use_quantization: bool,
) -> Result<(), Box<dyn Error>> {
let forward = forward
.as_any()
1 change: 0 additions & 1 deletion src/feature_transform_implementations.rs
@@ -13,7 +13,6 @@ use crate::feature_transform_executor::{
use crate::feature_transform_parser;
use crate::vwmap::{NamespaceDescriptor, NamespaceFormat, NamespaceType};


// -------------------------------------------------------------------
// TransformerBinner - A basic binner
// It can take any function as a binning function f32 -> f32. Then output is rounded to integer
5 changes: 3 additions & 2 deletions src/lib.rs
@@ -1,4 +1,3 @@
pub mod quantization;
pub mod block_ffm;
pub mod block_helpers;
pub mod block_loss_functions;
@@ -7,6 +6,7 @@ pub mod block_misc;
pub mod block_neural;
pub mod block_normalize;
pub mod block_relu;
pub mod buffer_handler;
pub mod cache;
pub mod cmdline;
pub mod feature_buffer;
@@ -22,15 +22,16 @@ pub mod optimizer;
pub mod parser;
pub mod persistence;
pub mod port_buffer;
pub mod quantization;
pub mod radix_tree;
pub mod regressor;
pub mod serving;
pub mod version;
pub mod vwmap;

extern crate blas;
extern crate intel_mkl_src;
extern crate half;
extern crate intel_mkl_src;

use crate::feature_buffer::FeatureBufferTranslator;
use crate::multithread_helpers::BoxedRegressorTrait;
(Diffs for the remaining 7 changed files were not loaded.)
