get_model_latents_states

behavenet.plotting.arhmm_utils.get_model_latents_states(hparams, version, sess_idx=0, return_samples=0, cond_sampling=False, dtype='test', dtypes=['train', 'val', 'test'], rng_seed=0)

Return the ARHMM defined in hparams, along with its associated latents and states.

Can also return latents and states sampled from the model.

Parameters:
  • hparams (dict) – must contain enough information to specify an ARHMM

  • version (str or int) – test tube model version (can be ‘best’)

  • sess_idx (int, optional) – session index into data generator

  • return_samples (int, optional) – number of trials to sample from the model

  • cond_sampling (bool, optional) – if True, return samples conditioned on the most likely state sequence; otherwise, return unconditioned samples

  • dtype (str, optional) – trial type to use for conditional sampling; ‘train’ | ‘val’ | ‘test’

  • dtypes (array-like, optional) – trial types for which to collect latents and states

  • rng_seed (int, optional) – random number generator seed to control sampling
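To make the difference between conditioned and unconditioned sampling concrete, here is a toy sketch of sampling from a first-order autoregressive HMM in plain NumPy. It is an illustration under stated assumptions, not behavenet's or ssm's implementation: the helper `sample_arhmm` and its arguments (`As`, `bs`, `sigma`, `trans`, `init`) are hypothetical. Unconditioned sampling draws the discrete state sequence from the Markov chain; conditioned sampling reuses a supplied state sequence (such as the most likely one) and only samples the continuous latents. A fixed `rng_seed` makes repeated calls reproducible, which mirrors the role of the rng_seed parameter above.

```python
import numpy as np


def sample_arhmm(As, bs, sigma, trans, init, n_steps, states=None, rng_seed=0):
    """Toy AR(1)-HMM sampler (hypothetical helper, for illustration only).

    As:     (K, D, D) per-state dynamics matrices
    bs:     (K, D) per-state biases
    sigma:  scalar std of the Gaussian noise
    trans:  (K, K) transition matrix whose rows sum to 1
    init:   (K,) initial state distribution
    states: optional 1-D int array; if given, sampling is conditioned on it
            and n_steps is taken from its length
    """
    rng = np.random.default_rng(rng_seed)
    K, D, _ = As.shape
    if states is None:
        # Unconditioned: sample the discrete states from the Markov chain.
        z = np.empty(n_steps, dtype=int)
        z[0] = rng.choice(K, p=init)
        for t in range(1, n_steps):
            z[t] = rng.choice(K, p=trans[z[t - 1]])
    else:
        # Conditioned: keep the supplied state sequence fixed.
        z = np.asarray(states, dtype=int)
        n_steps = len(z)
    # Continuous latents follow x_t = A_{z_t} x_{t-1} + b_{z_t} + noise.
    x = np.zeros((n_steps, D))
    x[0] = rng.normal(scale=sigma, size=D)
    for t in range(1, n_steps):
        k = z[t]
        x[t] = As[k] @ x[t - 1] + bs[k] + rng.normal(scale=sigma, size=D)
    return x, z


# Example: a 2-state model with 3-dimensional latents.
K, D = 2, 3
As = np.stack([0.9 * np.eye(D), 0.5 * np.eye(D)])
bs = np.zeros((K, D))
bs[1] = 1.0
trans = np.array([[0.95, 0.05], [0.10, 0.90]])
init = np.array([0.5, 0.5])

x_u, z_u = sample_arhmm(As, bs, 0.1, trans, init, n_steps=100, rng_seed=0)
fixed = np.repeat([0, 1], 50)
x_c, z_c = sample_arhmm(As, bs, 0.1, trans, init, n_steps=100,
                        states=fixed, rng_seed=0)
```

In the conditioned call the returned states are exactly the supplied sequence, while in the unconditioned call they vary with the seed.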

Returns:

  • ‘model’ (ssm.HMM object)

  • ‘latents’ (dict): latents from train, val and test trials

  • ‘states’ (dict): states from train, val and test trials

  • ‘trial_idxs’ (dict): trial indices from train, val and test trials

  • ‘latents_gen’ (list): sampled latents (see return_samples and cond_sampling)

  • ‘states_gen’ (list): sampled states (see return_samples and cond_sampling)

Return type:

dict
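As an example of consuming the returned dict, the sketch below computes how often each discrete state is used, separately for each trial type. It assumes, without confirmation from this page, that output['states'] maps each dtype to a list of 1-D integer arrays (one per trial); the helper `state_usage` and the `fake_states` data are hypothetical stand-ins, not behavenet output. With real output you would pass output['states'] and the number of states (ssm's HMM stores this as its K attribute).

```python
import numpy as np


def state_usage(states_by_dtype, n_states):
    """Fraction of time bins spent in each discrete state, per trial type.

    Assumes states_by_dtype maps 'train' / 'val' / 'test' to a list of
    1-D integer arrays, one per trial (an assumed layout).
    """
    usage = {}
    for dtype, trials in states_by_dtype.items():
        if len(trials) == 0:
            # No trials of this type: report zero usage for every state.
            usage[dtype] = np.zeros(n_states)
            continue
        all_states = np.concatenate(trials)
        counts = np.bincount(all_states, minlength=n_states)
        usage[dtype] = counts / counts.sum()
    return usage


# Hypothetical stand-in for output['states'].
fake_states = {
    'train': [np.array([0, 0, 1, 1]), np.array([1, 1, 1, 2])],
    'val': [np.array([2, 2])],
    'test': [],
}
usage = state_usage(fake_states, n_states=3)
```

Pooling time bins across trials (rather than averaging per-trial fractions) weights longer trials more heavily; either choice is reasonable depending on the analysis.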