get_model_latents_states
- behavenet.plotting.arhmm_utils.get_model_latents_states(hparams, version, sess_idx=0, return_samples=0, cond_sampling=False, dtype='test', dtypes=['train', 'val', 'test'], rng_seed=0)
Return arhmm defined in hparams with associated latents and states. Can also return sampled latents and states.
- Parameters:
  - hparams (dict) – needs to contain enough information to specify an arhmm
  - version (str or int) – test tube model version (can be ‘best’)
  - sess_idx (int, optional) – session index into data generator
  - return_samples (int, optional) – number of trials to sample from model
  - cond_sampling (bool, optional) – if True return samples conditioned on most likely state sequence; else return unconditioned samples
  - dtype (str, optional) – trial type to use for conditional sampling; ‘train’ | ‘val’ | ‘test’
  - dtypes (array-like, optional) – trial types for which to collect latents and states
  - rng_seed (int, optional) – random number generator seed to control sampling
- Returns:
  - ‘model’ (ssm.HMM object)
  - ‘latents’ (dict): latents from train, val and test trials
  - ‘states’ (dict): states from train, val and test trials
  - ‘trial_idxs’ (dict): trial indices from train, val and test trials
  - ‘latents_gen’ (list)
  - ‘states_gen’ (list)
- Return type:
dict
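
A minimal usage sketch follows. It calls the function with the signature documented above and then summarizes the returned dict. The helper names load_arhmm_outputs and count_trials are hypothetical and not part of behavenet. The sketch assumes that each of the ‘latents’, ‘states’ and ‘trial_idxs’ dicts is keyed by trial type and holds one entry per trial; this is inferred from the descriptions above and has not been checked against the source. The stand-in dict at the end is made-up data so that the summary step can be run without a fitted model.

```python
def load_arhmm_outputs(hparams, version="best", n_samples=0):
    """Hypothetical wrapper around get_model_latents_states.

    `hparams` must already contain enough information to specify an arhmm.
    """
    # Imported lazily so count_trials below can be used without behavenet installed.
    from behavenet.plotting.arhmm_utils import get_model_latents_states

    return get_model_latents_states(
        hparams,
        version,                  # test tube version, or 'best'
        sess_idx=0,
        return_samples=n_samples,  # 0 means no sampled trials are requested
        cond_sampling=False,
        dtype="test",
        dtypes=["train", "val", "test"],
        rng_seed=0,
    )


def count_trials(outputs, dtypes=("train", "val", "test")):
    """Count trials per trial type.

    Assumes outputs['latents'] maps each trial type to a per-trial sequence.
    """
    latents = outputs["latents"]
    return {d: len(latents[d]) for d in dtypes if d in latents}


# Stand-in for the returned dict (illustrative values only, not real model output).
fake_outputs = {
    "model": None,
    "latents": {"train": [[0.1, 0.2], [0.3, 0.4]], "val": [[0.5, 0.6]], "test": []},
    "states": {"train": [[0, 1], [1, 1]], "val": [[0, 0]], "test": []},
    "trial_idxs": {"train": [0, 2], "val": [1], "test": []},
    "latents_gen": [],
    "states_gen": [],
}

print(count_trials(fake_outputs))
```

With a real hparams dict, `count_trials(load_arhmm_outputs(hparams))` would give the same kind of summary for a fitted model, provided the assumption about the dict layout holds.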