I’m trying to perform spectral unmixing with an MCMC algorithm, but I ran into a problem when running it. This is the part of the code I’m running:
```python
import numpy as np
import pymc3 as pm
import theano.tensor as tt
import arviz as az

"""
Perform the spectral unmixing through a MCMC algorithm
"""
# TOT_names, EM, i, j and JM0340_matrix are defined earlier in the script (not shown)

### SHOW AVAILABLE END MEMBERS
for l in range(len(TOT_names)):
    print('Index: ' + str(l) + ', Compound: ' + str(TOT_names[l]))

### ENDMEMBERS
my_arr = []
lst = list(map(int, input("Which end-members (enter comma separated values): ").split(",")))
for k in lst:
    my_arr.append(EM[k][i:j])
ENDM = pm.floatX(np.array(my_arr))
print('ENDMEMBERS ARRAY: ', ENDM.shape)

### DATA MATRIX
MATRIX = JM0340_matrix

### DEFINE MCMC MODEL
with pm.Model() as model:
    # Prior distributions for the abundances
    ABUNDANCES = pm.Dirichlet('abundances', a=np.ones(len(lst)))
    print('ABUNDANCES ARRAY: ', ABUNDANCES)
    # Constraint on non-negativity of abundances
    pm.Potential('AB_POS_CONSTRAINT', pm.math.switch(pm.math.sum(pm.math.maximum(ABUNDANCES, 0)) - pm.math.sum(ABUNDANCES) < 0, -np.inf, 0))
    # Constraint on the sum of abundances
    AB_SUM = pm.Deterministic('ab_sum', pm.math.sum(ABUNDANCES))
    pm.Potential('AB_SUM_CONSTRAINT', pm.math.switch(tt.abs_(AB_SUM - 1) > 1e-3, -np.inf, 0))
    # Compute modeled spectra
    MODELED_SPECTRA = pm.Deterministic('modeled_spectra', pm.math.dot(ABUNDANCES, ENDM))
    # Likelihood distribution for the data
    DATA = pm.Normal('data', mu=MODELED_SPECTRA, sd=1, observed=MATRIX)  # sd: standard deviation
    # Sample the MCMC model
    TRACE = pm.sample(draws=400, tune=210, chains=8, cores=8, step=pm.Metropolis(), return_inferencedata=True)

# Extract the trace of the abundances
ABUND = TRACE['abundances'].mean(axis=0)
print('ABUNDANCES ARRAY: ', ABUND.shape)
```
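A side note on the model, separate from the error: a `pm.Dirichlet` prior already places every abundance vector on the simplex (non-negative entries that sum to 1), so the two `pm.Potential` constraints are most likely redundant. The non-negativity condition as written can also never be true, because `max(x, 0) >= x` for every entry. The NumPy-only sketch here checks both points with flat Dirichlet draws; the choice of 3 end-members is an assumption (the model uses `len(lst)`), and it does not run PyMC3.

```python
import numpy as np

# Flat Dirichlet draws over 3 hypothetical end-members,
# mirroring a=np.ones(len(lst)) in the model.
rng = np.random.default_rng(1)
a = rng.dirichlet(np.ones(3), size=10_000)  # shape (10000, 3)

# Every draw is already non-negative and sums to 1 (up to rounding),
# so both Potentials evaluate to 0 for any point the prior can produce.
nonneg = bool(np.all(a >= 0))
max_sum_error = float(np.max(np.abs(a.sum(axis=1) - 1)))

# The test sum(max(x, 0)) - sum(x) < 0 can never be true: max(x, 0) >= x
# elementwise, so the difference is >= 0 even for an invalid vector.
x = np.array([0.7, -0.2, 0.5])  # deliberately contains a negative entry
diff = float(np.sum(np.maximum(x, 0)) - np.sum(x))
print(nonneg, max_sum_error, diff)
```

If explicit constraints are still wanted, the Dirichlet prior alone should be enough; removing the Potentials is unrelated to the KeyError.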
And I got this error:
```
     81 # Extract the trace of the abundances
---> 82 ABUND = TRACE['abundances'].mean(axis=0)
     83 print('ABUNDANCES ARRAY: ', ABUND.shape)

File ~\anaconda3\lib\site-packages\arviz\data\inference_data.py:236, in InferenceData.__getitem__(self, key)
    234 """Get item by key."""
    235 if key not in self._groups_all:
--> 236     raise KeyError(key)
    237 return getattr(self, key)

KeyError: 'abundances'
```
I tried both NUTS and Metropolis and got the same error with each. I can’t figure out what is wrong.
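The traceback points at the cause: with `return_inferencedata=True`, `pm.sample` returns an `arviz.InferenceData` object, and `InferenceData.__getitem__` only accepts group names (the `key not in self._groups_all` check), so `'abundances'` is rejected even though sampling finished. That is also why switching between NUTS and Metropolis changes nothing: the failure happens after sampling, in the indexing line. The variable lives in the `posterior` group, with separate `chain` and `draw` dimensions. The minimal sketch here does not run PyMC3; it builds a stand-in `InferenceData` from random Dirichlet draws with `az.from_dict`, and the dimension name `endmember` and the 8 chains x 400 draws x 3 components shape are assumptions chosen to mirror the question.

```python
import numpy as np
import arviz as az

# Hypothetical stand-in for the sampler output: 8 chains x 400 draws of a
# 3-component abundance vector (random Dirichlet draws, not real results).
rng = np.random.default_rng(0)
samples = rng.dirichlet(np.ones(3), size=(8, 400))  # shape (chain, draw, 3)
TRACE = az.from_dict(posterior={"abundances": samples},
                     dims={"abundances": ["endmember"]})

# Indexing InferenceData directly looks up groups, not variables.
raised = False
try:
    TRACE["abundances"]
except KeyError as err:
    raised = True
    print("KeyError:", err)

# Select the variable from the posterior group, then average over chains and draws.
ABUND = TRACE.posterior["abundances"].mean(dim=("chain", "draw")).values
print("ABUNDANCES ARRAY: ", ABUND.shape)
```

In the original script, the equivalent line would be `ABUND = TRACE.posterior['abundances'].mean(dim=('chain', 'draw')).values`. Alternatively, passing `return_inferencedata=False` makes PyMC3 return a `MultiTrace`, where `TRACE['abundances']` stacks the draws of all chains along the first axis, so the original `.mean(axis=0)` should work unchanged.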