Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 78 additions & 49 deletions crates/cervo-cli/src/commands/benchmark.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use anyhow::{bail, Result};
use cervo::asset::AssetData;
use cervo::core::prelude::{Batcher, Inferer, InfererExt, State};
use cervo::core::recurrent::{RecurrentInfo, RecurrentTracker};
use cervo::core::epsilon::EpsilonInjectorWrapper;
use cervo::core::model::{BaseCase, Model, ModelWrapper};
use cervo::core::prelude::{Batcher, Inferer, State};
use cervo::core::recurrent::{RecurrentInfo, RecurrentTrackerWrapper};
use clap::Parser;
use clap::ValueEnum;
use serde::Serialize;
Expand Down Expand Up @@ -222,85 +224,112 @@ pub fn build_inputs_from_desc(
.collect()
}

fn do_run(mut inferer: impl Inferer, batch_size: usize, config: &Args) -> Result<Record> {
let shapes = inferer.input_shapes().to_vec();
let observations = build_inputs_from_desc(batch_size as u64, &shapes);
for id in 0..batch_size {
inferer.begin_agent(id as u64);
}
let res = execute_load_metrics(batch_size, observations, config.count, &mut inferer)?;
for id in 0..batch_size {
inferer.end_agent(id as u64);
fn do_run(
wrapper: impl ModelWrapper,
inferer: impl Inferer + 'static,
config: &Args,
) -> Result<Vec<Record>> {
let mut model = Model::new(wrapper, Box::new(inferer) as Box<dyn Inferer>);

let mut records = Vec::with_capacity(config.batch_sizes.len());
for batch_size in config.batch_sizes.clone() {
let mut reader = File::open(&config.file)?;
let inferer = if cervo::nnef::is_nnef_tar(&config.file) {
cervo::nnef::builder(&mut reader).build_fixed(&[batch_size])?
} else {
match config.file.extension().and_then(|ext| ext.to_str()) {
Some("onnx") => cervo::onnx::builder(&mut reader).build_fixed(&[batch_size])?,
Some("crvo") => AssetData::deserialize(&mut reader)?.load_fixed(&[batch_size])?,
Some(other) => bail!("unknown file type {:?}", other),
None => bail!("missing file extension {:?}", config.file),
}
};

model = model
.with_new_policy(Box::new(inferer) as Box<dyn Inferer>)
.map_err(|(_, e)| e)?;

let shapes = model.input_shapes().to_vec();
let observations = build_inputs_from_desc(batch_size as u64, &shapes);
for id in 0..batch_size {
model.begin_agent(id as u64);
}
let res = execute_load_metrics(batch_size, observations, config.count, &mut model)?;

// Print Text
if matches!(config.output, OutputFormat::Text) {
println!(
"Batch Size {}: {:.2} ms ± {:.2} per element, {:.2} ms total",
res.batch_size, res.mean, res.stddev, res.total,
);
}

records.push(res);
for id in 0..batch_size {
model.end_agent(id as u64);
}
}

Ok(res)
Ok(records)
}

fn run_apply_epsilon_config(
inferer: impl Inferer,
batch_size: usize,
wrapper: impl ModelWrapper,
inferer: impl Inferer + 'static,
config: &Args,
) -> Result<Record> {
) -> Result<Vec<Record>> {
if let Some(epsilon) = config.with_epsilon.as_ref() {
let inferer = inferer.with_default_epsilon(epsilon)?;
do_run(inferer, batch_size, config)
let wrapper = EpsilonInjectorWrapper::wrap(wrapper, &inferer, epsilon)?;
do_run(wrapper, inferer, config)
} else {
do_run(inferer, batch_size, config)
do_run(wrapper, inferer, config)
}
}

fn run_apply_recurrent(inferer: impl Inferer, batch_size: usize, config: &Args) -> Result<Record> {
fn run_apply_recurrent(
wrapper: impl ModelWrapper,
inferer: impl Inferer + 'static,
config: &Args,
) -> Result<Vec<Record>> {
if let Some(recurrent) = config.recurrent.as_ref() {
if matches!(recurrent, RecurrentConfig::None) {
run_apply_epsilon_config(inferer, batch_size, config)
run_apply_epsilon_config(wrapper, inferer, config)
} else {
let inferer = match recurrent {
let wrapper = match recurrent {
RecurrentConfig::None => unreachable!(),
RecurrentConfig::Auto => RecurrentTracker::wrap(inferer),
RecurrentConfig::Auto => RecurrentTrackerWrapper::wrap(wrapper, &inferer),
RecurrentConfig::Mapped(map) => {
let infos = map
.iter()
.cloned()
.map(|(inkey, outkey)| RecurrentInfo { inkey, outkey })
.collect::<Vec<_>>();
RecurrentTracker::new(inferer, infos)
RecurrentTrackerWrapper::new(wrapper, &inferer, infos)
}
}?;

run_apply_epsilon_config(inferer, batch_size, config)
run_apply_epsilon_config(wrapper, inferer, config)
}
} else {
run_apply_epsilon_config(inferer, batch_size, config)
run_apply_epsilon_config(wrapper, inferer, config)
}
}

pub(super) fn run(config: Args) -> Result<()> {
let mut records: Vec<Record> = Vec::new();
for batch_size in config.batch_sizes.clone() {
let mut reader = File::open(&config.file)?;
let inferer = if cervo::nnef::is_nnef_tar(&config.file) {
cervo::nnef::builder(&mut reader).build_fixed(&[batch_size])?
} else {
match config.file.extension().and_then(|ext| ext.to_str()) {
Some("onnx") => cervo::onnx::builder(&mut reader).build_fixed(&[batch_size])?,
Some("crvo") => AssetData::deserialize(&mut reader)?.load_fixed(&[batch_size])?,
Some(other) => bail!("unknown file type {:?}", other),
None => bail!("missing file extension {:?}", config.file),
}
};

let record = run_apply_recurrent(inferer, batch_size, &config)?;

// Print Text
if matches!(config.output, OutputFormat::Text) {
println!(
"Batch Size {}: {:.2} ms ± {:.2} per element, {:.2} ms total",
record.batch_size, record.mean, record.stddev, record.total,
);
let mut reader = File::open(&config.file)?;
let inferer = if cervo::nnef::is_nnef_tar(&config.file) {
cervo::nnef::builder(&mut reader).build_basic()?
} else {
match config.file.extension().and_then(|ext| ext.to_str()) {
Some("onnx") => cervo::onnx::builder(&mut reader).build_basic()?,
Some("crvo") => AssetData::deserialize(&mut reader)?.load_basic()?,
Some(other) => bail!("unknown file type {:?}", other),
None => bail!("missing file extension {:?}", config.file),
}
};

let records = run_apply_recurrent(BaseCase, inferer, &config)?;

records.push(record);
}
// Print JSON
if matches!(config.output, OutputFormat::Json) {
let json = serde_json::to_string_pretty(&records)?;
Expand Down
10 changes: 1 addition & 9 deletions crates/cervo-cli/src/commands/run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,7 @@ pub(super) fn run(config: Args) -> Result<()> {

let elapsed = if let Some(epsilon) = config.with_epsilon.as_ref() {
let inferer = inferer.with_default_epsilon(epsilon)?;
// TODO[TSolberg]: Issue #31.
let shapes = inferer
.raw_input_shapes()
.iter()
.filter(|(k, _)| k.as_str() != epsilon)
.cloned()
.collect::<Vec<_>>();

let observations = build_inputs_from_desc(config.batch_size as u64, &shapes);
let observations = build_inputs_from_desc(config.batch_size as u64, inferer.input_shapes());

if config.print_input {
print_input(&observations);
Expand Down
127 changes: 113 additions & 14 deletions crates/cervo-core/src/epsilon.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Utilities for filling noise inputs for an inference model.

use std::cell::RefCell;

use crate::{batcher::ScratchPadView, inferer::Inferer};
use crate::{batcher::ScratchPadView, inferer::Inferer, prelude::ModelWrapper};
use anyhow::{bail, Result};
use perchance::PerchanceContext;
use rand::thread_rng;
Expand Down Expand Up @@ -112,6 +112,13 @@ impl NoiseGenerator for HighQualityNoiseGenerator {
}
}

struct EpsilonInjectorState<NG: NoiseGenerator> {
count: usize,
index: usize,
generator: NG,

inputs: Vec<(String, Vec<usize>)>,
}
/// The [`EpsilonInjector`] wraps an inferer to add noise values as one of the input data points. This is useful for
/// continuous action policies where you might have trained your agent to follow a stochastic policy trained with the
/// reparametrization trick.
Expand All @@ -120,11 +127,8 @@ impl NoiseGenerator for HighQualityNoiseGenerator {
/// wrapper.
pub struct EpsilonInjector<T: Inferer, NG: NoiseGenerator = HighQualityNoiseGenerator> {
inner: T,
count: usize,
index: usize,
generator: NG,

inputs: Vec<(String, Vec<usize>)>,
state: EpsilonInjectorState<NG>,
}

impl<T> EpsilonInjector<T, HighQualityNoiseGenerator>
Expand Down Expand Up @@ -169,11 +173,12 @@ where

Ok(Self {
inner: inferer,
index,
count,
generator,

inputs,
state: EpsilonInjectorState {
index,
count,
generator,
inputs,
},
})
}
}
Expand All @@ -188,15 +193,15 @@ where
}

fn infer_raw(&self, batch: &mut ScratchPadView<'_>) -> Result<(), anyhow::Error> {
let total_count = self.count * batch.len();
let output = batch.input_slot_mut(self.index);
self.generator.generate(total_count, output);
let total_count = self.state.count * batch.len();
let output = batch.input_slot_mut(self.state.index);
self.state.generator.generate(total_count, output);

self.inner.infer_raw(batch)
}

fn input_shapes(&self) -> &[(String, Vec<usize>)] {
&self.inputs
&self.state.inputs
}

fn raw_input_shapes(&self) -> &[(String, Vec<usize>)] {
Expand All @@ -215,3 +220,97 @@ where
self.inner.end_agent(id);
}
}

pub struct EpsilonInjectorWrapper<Inner: ModelWrapper, NG: NoiseGenerator> {
inner: Inner,
state: EpsilonInjectorState<NG>,
}

impl<Inner: ModelWrapper> EpsilonInjectorWrapper<Inner, HighQualityNoiseGenerator> {
/// Wraps the provided `inferer` to automatically generate noise for the input named by `key`.
///
/// This function will use [`HighQualityNoiseGenerator`] as the noise source.
///
/// # Errors
///
/// Will return an error if the provided key doesn't match an input on the model.
pub fn wrap(
inner: Inner,
inferer: &dyn Inferer,
key: &str,
) -> Result<EpsilonInjectorWrapper<Inner, HighQualityNoiseGenerator>> {
Self::with_generator(inner, inferer, HighQualityNoiseGenerator::default(), key)
}
}

impl<Inner, NG> EpsilonInjectorWrapper<Inner, NG>
where
Inner: ModelWrapper,
NG: NoiseGenerator,
{
/// Create a new injector for the provided `key`, using the custom `generator` as the noise source.
///
/// # Errors
///
/// Will return an error if the provided key doesn't match an input on the model.
pub fn with_generator(
inner: Inner,
inferer: &dyn Inferer,
generator: NG,
key: &str,
) -> Result<Self> {
let inputs = inferer.input_shapes();

let (index, count) = match inputs.iter().enumerate().find(|(_, (k, _))| k == key) {
Some((index, (_, shape))) => (index, shape.iter().product()),
None => bail!("model has no input key {:?}", key),
};

let inputs = inputs
.iter()
.filter(|(k, _)| *k != key)
.map(|(k, v)| (k.to_owned(), v.to_owned()))
.collect::<Vec<_>>();

Ok(Self {
inner,
state: EpsilonInjectorState {
index,
count,
generator,
inputs,
},
})
}
}

impl<Inner, NG> ModelWrapper for EpsilonInjectorWrapper<Inner, NG>
where
Inner: ModelWrapper,
NG: NoiseGenerator,
{
fn invoke(&self, inferer: &impl Inferer, batch: &mut ScratchPadView<'_>) -> anyhow::Result<()> {
self.inner.invoke(inferer, batch)?;
let total_count = self.state.count * batch.len();
let output = batch.input_slot_mut(self.state.index);
self.state.generator.generate(total_count, output);

self.inner.invoke(inferer, batch)
}

fn input_shapes<'a>(&'a self, _inferer: &'a dyn Inferer) -> &'a [(String, Vec<usize>)] {
self.state.inputs.as_ref()
}

fn output_shapes<'a>(&'a self, inferer: &'a dyn Inferer) -> &'a [(String, Vec<usize>)] {
self.inner.output_shapes(inferer)
}

fn begin_agent(&self, id: u64) {
self.inner.begin_agent(id);
}

fn end_agent(&self, id: u64) {
self.inner.end_agent(id);
}
}
2 changes: 2 additions & 0 deletions crates/cervo-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ pub use tract_hir;
pub mod batcher;
pub mod epsilon;
pub mod inferer;
pub mod model;
mod model_api;
pub mod recurrent;

Expand All @@ -29,6 +30,7 @@ pub mod prelude {
InfererProvider, MemoizingDynamicInferer, Response, State,
};

pub use super::model::ModelWrapper;
pub use super::model_api::ModelApi;
pub use super::recurrent::{RecurrentInfo, RecurrentTracker};
}
Loading
Loading