Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions momentum/character/skeleton_state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,18 @@ void SkeletonStateT<T>::set(
set(referenceSkeleton, computeDeriv);
}

template <typename T>
void SkeletonStateT<T>::set(
const JointParametersT<T>& parameters,
const Skeleton& referenceSkeleton,
const JointSet& needXform,
const JointSet& needDeriv) {
MT_PROFILE_FUNCTION();
this->jointParameters = parameters;
this->jointState.resize(referenceSkeleton.joints.size());
set(referenceSkeleton, needXform, needDeriv);
}

template <typename T>
void SkeletonStateT<T>::set(const Skeleton& referenceSkeleton, bool computeDeriv) {
MT_PROFILE_FUNCTION();
Expand Down Expand Up @@ -127,6 +139,60 @@ void SkeletonStateT<T>::set(const Skeleton& referenceSkeleton, bool computeDeriv
MT_CHECK(jointState.size() == numJoints, "{} is not {}", jointState.size(), numJoints);
}

template <typename T>
void SkeletonStateT<T>::set(
const Skeleton& referenceSkeleton,
const JointSet& needXform,
const JointSet& needDeriv) {
MT_PROFILE_FUNCTION();
// get input joints
const JointListT<T>& joints = ::momentum::cast<T>(referenceSkeleton.joints);

// initialize array size variables
const size_t numJoints = joints.size();

// ensure that all variables are valid
MT_CHECK(
jointParameters.size() == gsl::narrow<Eigen::Index>(numJoints * kParametersPerJoint),
"Unexpected joint parameter size. Expected '{}' (# of joints '{}' X kParametersPerJoint '{}') but got '{}'.",
numJoints * kParametersPerJoint,
numJoints,
kParametersPerJoint,
jointParameters.size());

// go over all joint elements and calculate Transformation
for (size_t jointID = 0; jointID < numJoints; jointID++) {
if (!needXform.test(jointID)) {
continue;
}

const Eigen::Index parameterOffset = jointID * kParametersPerJoint;

// some reference for quick access
const JointT<T>& joint = joints[jointID];

// set joint-state based on parameters
// IMPORTANT: this all assumes that parent joints always appear before their children in the
// joint list, so their joint state will already be calculated when processing the children
if (joint.parent == kInvalidIndex) {
jointState[jointID].set(
joint,
jointParameters.v.template middleRows<7>(parameterOffset),
nullptr,
needDeriv.test(jointID));
} else {
jointState[jointID].set(
joint,
jointParameters.v.template middleRows<7>(parameterOffset),
&jointState[joint.parent],
needDeriv.test(jointID));
}
}

// ensure arrays are valid
MT_CHECK(jointState.size() == numJoints, "{} is not {}", jointState.size(), numJoints);
}

template <typename T>
TransformListT<T> SkeletonStateT<T>::toTransforms() const {
TransformListT<T> result;
Expand Down
22 changes: 22 additions & 0 deletions momentum/character/skeleton_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,21 @@ struct SkeletonStateT {
const Skeleton& referenceSkeleton,
bool computeDeriv = true);

/// Updates the skeleton state with per-joint control over transforms and derivatives
///
/// This overload allows selective computation of transforms and derivatives on a per-joint
/// basis, which can significantly reduce computation when only a subset of joints is needed.
///
/// @param jointParameters New joint parameters for all joints
/// @param referenceSkeleton The skeleton structure defining joint hierarchy
/// @param needXform Bitset indicating which joints need their transforms computed
/// @param needDeriv Bitset indicating which joints need derivative information computed
void set(
const JointParametersT<T>& jointParameters,
const Skeleton& referenceSkeleton,
const JointSet& needXform,
const JointSet& needDeriv);

/// Updates the skeleton state from another skeleton state with a different scalar type
///
/// @tparam T2 Source scalar type
Expand Down Expand Up @@ -138,6 +153,13 @@ struct SkeletonStateT {
/// @param computeDeriv Whether to compute derivative information for the joints
void set(const Skeleton& referenceSkeleton, bool computeDeriv);

/// Updates the joint states with per-joint control over transforms and derivatives
///
/// @param referenceSkeleton The skeleton structure defining joint hierarchy
/// @param needXform Bitset indicating which joints need their transforms computed
/// @param needDeriv Bitset indicating which joints need derivative information computed
void set(const Skeleton& referenceSkeleton, const JointSet& needXform, const JointSet& needDeriv);

/// Copies joint states from another skeleton state with a different scalar type
///
/// @tparam T2 Source scalar type
Expand Down
29 changes: 29 additions & 0 deletions momentum/character_solver/collision_error_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,35 @@ void CollisionErrorFunctionT<T>::computeBroadPhase(const SkeletonStateT<T>& stat
bvh_.setBoundingBoxes(bvh_.getPrimitives());
}

template <typename T>
JointSet CollisionErrorFunctionT<T>::getAffectedJoints() const {
JointSet activeJoints;

// For each collision geom, count how many other geoms it can't collide with.
// If a geom cannot collide with all other n-1 geoms, it's inactive.
// On the other hand, any active geom has a chance of being included in the error function.
const size_t numGeoms = collisionGeometry_.size();
if (numGeoms == 0) {
return activeJoints;
}

Eigen::VectorXi excludeCount = Eigen::VectorXi::Zero(numGeoms);
for (const auto& pair : excludingPairIds_) {
excludeCount[pair.first]++;
excludeCount[pair.second]++;
}

for (size_t iGeom = 0; iGeom < numGeoms; ++iGeom) {
if (static_cast<size_t>(excludeCount[iGeom]) + 1 < numGeoms) {
const auto parent = collisionGeometry_[iGeom].parent;
if (parent != kInvalidIndex) {
activeJoints.set(parent);
}
}
}
return activeJoints;
}

template <typename T>
std::vector<Vector2i> CollisionErrorFunctionT<T>::getCollisionPairs() const {
std::vector<Vector2i> collidingPairs;
Expand Down
2 changes: 2 additions & 0 deletions momentum/character_solver/collision_error_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ class CollisionErrorFunctionT : public SkeletonErrorFunctionT<T> {

[[nodiscard]] std::vector<Vector2i> getCollisionPairs() const;

[[nodiscard]] JointSet getAffectedJoints() const final;

protected:
void updateCollisionPairs();

Expand Down
12 changes: 12 additions & 0 deletions momentum/character_solver/skeleton_error_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,18 @@ class SkeletonErrorFunctionT {
enabledParameters_ = ps;
}

/// Returns the set of joints that are directly used in this error function
///
/// Override this to provide an accurate list of affected joints (e.g., constrained joints
/// in a position error function). The solver uses this information to optimize which
/// joint transforms and derivatives need to be computed.
/// By default, returns all joints set (conservative).
[[nodiscard]] virtual JointSet getAffectedJoints() const {
JointSet allJoints;
allJoints.set();
return allJoints;
}

virtual double getError(
const ModelParametersT<T>& /* params */,
const SkeletonStateT<T>& /* state */,
Expand Down
97 changes: 91 additions & 6 deletions momentum/character_solver/skeleton_solver_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ SkeletonSolverFunctionT<T>::SkeletonSolverFunctionT(
state_ = std::make_unique<SkeletonStateT<T>>(parameterTransform_.zero(), character_.skeleton);
activeJointParams_ = parameterTransform_.activeJointParams;
meshState_ = std::make_unique<MeshStateT<T>>();
enabledParameters_.set();

for (auto& errf : errorFunctions) {
addErrorFunction(std::move(errf));
Expand All @@ -52,18 +53,102 @@ void SkeletonSolverFunctionT<T>::setEnabledParameters(const ParameterSet& ps) {
}
// set the enabled joints based on the parameter set
activeJointParams_ = parameterTransform_.computeActiveJointParams(ps);
enabledParameters_ = ps;

// give data to helper functions
// give data to helper functions (propagate immediately, not deferred to initialize())
for (auto&& solvable : errorFunctions_) {
solvable->setActiveJoints(activeJointParams_);
solvable->setEnabledParameters(ps);
}
}

template <typename T>
void SkeletonSolverFunctionT<T>::initialize() {
// Gather affected joints from all error functions
const size_t numJoints = character_.skeleton.joints.size();
JointSet affectedJoints;
for (const auto& func : errorFunctions_) {
affectedJoints |= func->getAffectedJoints();
}

// Check if all joints are affected — if so, skip all the expensive bitset work
bool allAffected = true;
for (size_t i = 0; i < numJoints; ++i) {
if (!affectedJoints.test(i)) {
allAffected = false;
break;
}
}

if (allAffected) {
allJointsActive_ = true;
return;
}

// Compute active joints from active joint parameters (at the joint level, not parameter level)
solverActiveJoints_.reset();
for (size_t iJoint = 0; iJoint < numJoints; ++iJoint) {
for (size_t jParam = 0; jParam < kParametersPerJoint; ++jParam) {
if (activeJointParams_[iJoint * kParametersPerJoint + jParam]) {
solverActiveJoints_.set(iJoint);
break;
}
}
}

// We need joint derivatives for all *active* ancestors of an affected joint.
// We need joint transformations for all ancestors of an affected joint, active or not.
activeJointDeriv_.reset();
activeJointXform_.reset();
for (size_t iJoint = 0; iJoint < numJoints; ++iJoint) {
if (!affectedJoints.test(iJoint)) {
continue;
}

// The order of joints ensures parents come before their children. So when an activeJointXform
// is set, we can be sure that its ancestors have been set.
size_t jointIndex = iJoint;
while (jointIndex != kInvalidIndex && !activeJointXform_.test(jointIndex)) {
activeJointXform_.set(jointIndex);
if (solverActiveJoints_.test(jointIndex)) {
activeJointDeriv_.set(jointIndex);
}
jointIndex = character_.skeleton.joints[jointIndex].parent;
}
}

// Check if all joints ended up active anyway
allJointsActive_ = true;
for (size_t i = 0; i < numJoints; ++i) {
if (!activeJointXform_.test(i) || !activeJointDeriv_.test(i)) {
allJointsActive_ = false;
break;
}
}
}

template <typename T>
void SkeletonSolverFunctionT<T>::updateSkeletonStateSelective(
const Eigen::VectorX<T>& parameters,
bool computeDeriv) {
// Selective path: only compute transforms/derivatives for active joints
if (computeDeriv) {
state_->set(
parameterTransform_.apply(parameters),
character_.skeleton,
activeJointXform_,
activeJointDeriv_);
} else {
static const JointSet emptySet;
state_->set(
parameterTransform_.apply(parameters), character_.skeleton, activeJointXform_, emptySet);
}
}

template <typename T>
double SkeletonSolverFunctionT<T>::getError(const Eigen::VectorX<T>& parameters) {
// update the state according to the transformed parameters
state_->set(parameterTransform_.apply(parameters), character_.skeleton, false);
// update the state according to the transformed parameters (no derivatives needed for error-only)
updateSkeletonState(parameters, false);

// Update mesh state if needed
if (needsMeshState()) {
Expand All @@ -87,7 +172,7 @@ double SkeletonSolverFunctionT<T>::getGradient(
const Eigen::VectorX<T>& parameters,
Eigen::VectorX<T>& gradient) {
// update the state according to the transformed parameters
state_->set(parameterTransform_.apply(parameters), character_.skeleton);
updateSkeletonState(parameters, true);

// Update mesh state if needed
if (needsMeshState()) {
Expand Down Expand Up @@ -121,7 +206,7 @@ double SkeletonSolverFunctionT<T>::getSolverDerivatives(
// update the state according to the transformed parameters
{
MT_PROFILE_EVENT("UpdateState");
state_->set(parameterTransform_.apply(parameters), character_.skeleton);
updateSkeletonState(parameters, true);
}

// Update mesh state if needed
Expand Down Expand Up @@ -204,7 +289,7 @@ void SkeletonSolverFunctionT<T>::initializeJacobianComputation(
// Update the state according to the transformed parameters
{
MT_PROFILE_EVENT("Initialize - update state");
state_->set(parameterTransform_.apply(parameters), character_.skeleton);
updateSkeletonState(parameters, true);
}

// Update mesh state if needed
Expand Down
30 changes: 30 additions & 0 deletions momentum/character_solver/skeleton_solver_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <momentum/character/character.h>
#include <momentum/character/fwd.h>
#include <momentum/character/parameter_transform.h>
#include <momentum/character/skeleton_state.h>
#include <momentum/character/types.h>
#include <momentum/character_solver/fwd.h>
#include <momentum/solver/solver_function.h>
Expand Down Expand Up @@ -50,6 +51,7 @@ class SkeletonSolverFunctionT : public SolverFunctionT<T> {

void updateParameters(Eigen::VectorX<T>& parameters, const Eigen::VectorX<T>& delta) final;
void setEnabledParameters(const ParameterSet& ps) final;
void initialize() override;

void addErrorFunction(std::shared_ptr<SkeletonErrorFunctionT<T>> solvable);
void clearErrorFunctions();
Expand Down Expand Up @@ -86,7 +88,35 @@ class SkeletonSolverFunctionT : public SolverFunctionT<T> {
bool needsMeshState_;
VectorX<bool> activeJointParams_;

/// True when all skeleton joints need both transforms and derivatives.
/// Enables a fast path that bypasses per-joint bitset checks.
/// Defaults to true so that before initialize() is called, the baseline code path is used.
bool allJointsActive_{true};

std::vector<std::shared_ptr<SkeletonErrorFunctionT<T>>> errorFunctions_;

/// Updates the skeleton state; inlined so the compiler can eliminate the branch
/// in the common case where allJointsActive_ is true.
void updateSkeletonState(const Eigen::VectorX<T>& parameters, bool computeDeriv) {
if (allJointsActive_) {
// Fast path: identical to baseline — no per-joint bitset checks.
state_->set(parameterTransform_.apply(parameters), character_.skeleton, computeDeriv);
} else {
updateSkeletonStateSelective(parameters, computeDeriv);
}
}

/// Slow path for selective joint computation; outlined to keep the fast path lean.
void updateSkeletonStateSelective(const Eigen::VectorX<T>& parameters, bool computeDeriv);

/// Joints that need transforms in any solver computation; computed once at initialization
JointSet activeJointXform_;

/// Joints that need derivatives in any solver computation; computed once at initialization
JointSet activeJointDeriv_;

JointSet solverActiveJoints_;
ParameterSet enabledParameters_;
};

} // namespace momentum
3 changes: 3 additions & 0 deletions momentum/math/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,9 @@ using ColorArray = Matrix3Xb;
constexpr size_t kMaxModelParams = 2048; // at most 2048 parameters per frame
using ParameterSet = std::bitset<kMaxModelParams>;

constexpr size_t kMaxJoints = 1024; // at most 1024 joints in the system
using JointSet = std::bitset<kMaxJoints>;

/// @brief A utility struct that facilitates the deduction of a `std::span` type from a given type.
///
/// This utility is particularly useful when a function accepts a `std::span<Vector3<T>>` as an
Expand Down
1 change: 1 addition & 0 deletions momentum/solver/solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ double SolverT<T>::solve(Eigen::VectorX<T>& params) {
lastError_ = std::numeric_limits<decltype(lastError_)>::max();

initializeSolver();
solverFunction_->initialize();

iteration_ = 0;
for (; iteration_ < maxIterations_; iteration_++) {
Expand Down
Loading
Loading