# python3 code for finding CSS codes of arbitrary code distance using random search in the CPC formalism
# written by Nicholas Chancellor
# based on the methods of arXiv:1611.08012 by N. Chancellor, A. Kissinger, J. Roffe, S. Zohren, and D. Horsman
# feel free to reuse/modify but please attribute the source and cite our paper in any published work

import numpy as np
import numpy.random as r
import scipy as scp
import scipy.misc as misc


def find_codes(n,k,nErr,*args,**kwargs): # function for finding [[n,k,2*nErr+1]] codes
	"""Randomly search for [[n,k,2*nErr+1]] codes in the CPC formalism.

	A random code is accepted when every pattern of at most nErr single-qubit
	errors produces a distinct syndrome; the no-error pattern, whose syndrome
	is all zeros, is included in the comparison.  Bit flips and phase flips
	are counted separately here, so a Y error on one qubit counts as two
	errors (find_codes_includeY treats Y as a single error).

	Parameters
	----------
	n : int
		total number of qubits
	k : int
		number of data (logical) qubits; the other n-k qubits are parity check qubits
	nErr : int
		maximum number of errors whose patterns must all be distinguishable
	*args
		ignored
	nCheck : int, keyword, optional
		number of random codes to test (default 1000)
	nKeep : int, keyword, optional
		maximum number of successful codes to store (default 10)
	matrix_save : bool, keyword, optional
		if True, also save the stored lists to MbCell.npy, MpCell.npy and McCell.npy

	Returns
	-------
	(MbCell, MpCell, McCell)
		lists of length nKeep holding, for the first successful codes, the bit
		check and phase check matricies (shape (n-k, k), transposed to match
		the convention of the paper) and the strictly upper triangular cross
		check matricies (shape (n-k, n-k)); unused slots are None.
	"""
	from itertools import combinations # local import keeps this function self contained

	nCheck=kwargs.get('nCheck',1000) # number of random codes to check
	nKeep=kwargs.get('nKeep',10) # maximum number of successful codes to keep

	MbCell=[None]*nKeep # list of bit check matricies for successful codes
	MpCell=[None]*nKeep # list of phase check matricies for successful codes
	McCell=[None]*nKeep # list of cross check matricies for successful codes

	nSuccess=0 # number of codes successfully found

	# The error patterns do not depend on the code, so enumerate them once.
	# Slot layout: bit errors on data qubits [0,k), phase errors on data qubits [k,2k),
	# bit errors on parity qubits [2k,n+k), phase errors on parity qubits [n+k,2n).
	nSlots=2*n
	patterns=[np.zeros(nSlots)] # first pattern is the no-error case
	for weight in range(1,nErr+1):
		for positions in combinations(range(nSlots),weight):
			pattern=np.zeros(nSlots)
			pattern[list(positions)]=1
			patterns.append(pattern)
	errPatterns=np.array(patterns) # one row per error pattern
	identity=np.eye(n-k) # a bit error on parity qubit j only flips syndrome bit j

	for _ in range(nCheck): # loop for generating random code matricies
		# draw order (bit, phase, cross) matches the original search so seeded runs are reproducible
		Mb=np.round(r.rand(k,n-k)) # random bit check matrix
		Mp=np.round(r.rand(k,n-k)) # random phase check matrix
		Mc=np.triu(np.round(r.rand(n-k,n-k)),1) # random strictly upper triangular cross check matrix
		# phase errors on parity qubits propagate directly through the cross checks and indirectly through the data qubits
		phaseProp=Mc+Mc.T+np.dot(Mp.T,Mb)
		syndromeMap=np.vstack([Mb,Mp,identity,phaseProp]) # row s is the syndrome of a single error in slot s
		syndromes=np.mod(np.dot(errPatterns,syndromeMap),2) # errors add mod 2
		# compare whole syndrome rows; an integer encoding would overflow once n-k is large
		if len(np.unique(syndromes,axis=0))==len(syndromes): # check if syndromes are unique
			nSuccess=nSuccess+1 # another code successfully found!
			if nSuccess<=nKeep: # only keep so many codes to avoid memory issues
				MbCell[nSuccess-1]=Mb.T # store bit check matrix, transpose is performed to agree with convention in paper
				MpCell[nSuccess-1]=Mp.T # store phase check matrix, transpose is performed to agree with convention in paper
				McCell[nSuccess-1]=Mc # store cross check matrix
	if kwargs.get('matrix_save',False)==True: # save code matricies if set to true
		np.save('MbCell',MbCell)
		np.save('MpCell',MpCell)
		np.save('McCell',McCell)
	print(str(nSuccess)+' out of '+str(nCheck)+' succeeded')
	return(MbCell,MpCell,McCell)

def find_codes_includeY(n,k,nErr,*args,**kwargs): # function for finding [[n,k,2*nErr+1]] codes including the possibility of Y errors
	"""Randomly search for [[n,k,2*nErr+1]] CPC codes, counting a Y error as a single error.

	Each qubit can suffer an X (bit), Z (phase) or Y (bit and phase) error,
	each counting as one error.  A random code is accepted when every pattern
	of at most nErr such errors on distinct qubits produces a distinct
	syndrome (the all-zero no-error syndrome included).  Patterns that put two
	error types on the same qubit are equal to a lower weight pattern (e.g.
	X*Z=Y) and are therefore not enumerated.

	Parameters
	----------
	n : int
		total number of qubits
	k : int
		number of data (logical) qubits; the other n-k qubits are parity check qubits
	nErr : int
		maximum number of errors whose patterns must all be distinguishable
	*args
		ignored
	nCheck : int, keyword, optional
		number of random codes to test (default 1000)
	nKeep : int, keyword, optional
		maximum number of successful codes to store (default 10)
	matrix_save : bool, keyword, optional
		if True, also save the stored lists to MbCellY.npy, MpCellY.npy and McCellY.npy

	Returns
	-------
	(MbCell, MpCell, McCell)
		lists of length nKeep holding, for the first successful codes, the bit
		check and phase check matricies (shape (n-k, k), transposed to match
		the convention of the paper) and the strictly upper triangular cross
		check matricies (shape (n-k, n-k)); unused slots are None.
	"""
	from itertools import combinations # local import keeps this function self contained

	nCheck=kwargs.get('nCheck',1000) # number of random codes to check
	nKeep=kwargs.get('nKeep',10) # maximum number of successful codes to keep

	MbCell=[None]*nKeep # list of bit check matricies for successful codes
	MpCell=[None]*nKeep # list of phase check matricies for successful codes
	McCell=[None]*nKeep # list of cross check matricies for successful codes

	nSuccess=0 # number of codes successfully found

	# The error patterns do not depend on the code, so enumerate them once.
	# Slot layout: bit errors [0,n), phase errors [n,2n), Y errors [2n,3n);
	# within each group data qubits come first, followed by parity check qubits.
	nSlots=3*n
	patterns=[np.zeros(nSlots)] # first pattern is the no-error case
	for weight in range(1,nErr+1):
		for positions in combinations(range(nSlots),weight):
			# Skip patterns with two error types on one qubit: they were already created a simpler way.
			# (Marking them with NaN does not work: np.unique treats all NaNs as equal since NumPy 1.21.)
			if len({p%n for p in positions})<weight:
				continue
			pattern=np.zeros(nSlots)
			pattern[list(positions)]=1
			patterns.append(pattern)
	errPatterns=np.array(patterns) # one row per error pattern
	identity=np.eye(n-k) # a bit error on parity qubit j only flips syndrome bit j

	for _ in range(nCheck): # loop for generating random code matricies
		# draw order (bit, phase, cross) matches the original search so seeded runs are reproducible
		Mb=np.round(r.rand(k,n-k)) # random bit check matrix
		Mp=np.round(r.rand(k,n-k)) # random phase check matrix
		Mc=np.triu(np.round(r.rand(n-k,n-k)),1) # random strictly upper triangular cross check matrix
		# phase errors on parity qubits propagate directly through the cross checks and indirectly through the data qubits
		phaseProp=Mc+Mc.T+np.dot(Mp.T,Mb)
		bitMap=np.vstack([Mb,identity]) # syndromes of bit errors on data then parity qubits
		phaseMap=np.vstack([Mp,phaseProp]) # syndromes of phase errors on data then parity qubits
		syndromeMap=np.vstack([bitMap,phaseMap,bitMap+phaseMap]) # a Y error is a bit and a phase error on the same qubit
		syndromes=np.mod(np.dot(errPatterns,syndromeMap),2) # errors add mod 2
		# compare whole syndrome rows; an integer encoding would overflow once n-k is large
		if len(np.unique(syndromes,axis=0))==len(syndromes): # check if syndromes are unique
			nSuccess=nSuccess+1 # another code successfully found!
			if nSuccess<=nKeep: # only keep so many codes to avoid memory issues
				MbCell[nSuccess-1]=Mb.T # store bit check matrix, transpose to agree with convention in paper
				MpCell[nSuccess-1]=Mp.T # store phase check matrix, transpose to agree with convention in paper
				McCell[nSuccess-1]=Mc # store cross check matrix
	if kwargs.get('matrix_save',False)==True: # save code matricies if set to true
		np.save('MbCellY',MbCell)
		np.save('MpCellY',MpCell)
		np.save('McCellY',McCell)
	print(str(nSuccess)+' out of '+str(nCheck)+' succeeded')
	return(MbCell,MpCell,McCell)

def CPC_mats_2_stabilizers(Mb,Mp,Mc,*args,**kwargs): # converts CPC matricies to latex formatted stabilizer tables, and save latex formatted versions if desired
	"""Print the stabilizer table of a CPC code and optionally save it as a LaTeX array.

	Mb, Mp and Mc are the bit, phase and cross check matricies in the format of
	arXiv:1611.08012 (Mb and Mp as returned by find_codes, i.e. transposed
	relative to the internal convention).  Row i of the table is the
	stabilizer belonging to parity check qubit i, with one symbol per qubit
	('1', 'X', 'Y' or 'Z'), data qubits first.  The space separated table is
	printed.  If the keyword saveName is given, a LaTeX version (entries
	separated by ' & ', rows separated by a LaTeX line break) is written to
	that file.  Returns None.
	"""
	bitChecks=Mb.T # internal convention is the transpose of the paper's
	phaseChecks=Mp.T # internal convention is the transpose of the paper's
	nData=bitChecks.shape[0] # number of logical qubits
	nQubits=nData+bitChecks.shape[1] # number of total qubits
	indirect=np.dot(np.transpose(phaseChecks),bitChecks) # phase information propagated indirectly through the data qubits
	# (Z count parity, X count parity) -> Pauli symbol; X and Z on the same qubit combine to form Y
	symbols={(0,0):'1',(1,0):'Z',(0,1):'X',(1,1):'Y'}

	latexRows=[]
	displayRows=[]
	for row in range(nQubits-nData): # one stabilizer per parity check qubit
		zCount=np.zeros(nQubits) # number of Z contributions on each qubit
		xCount=np.zeros(nQubits) # number of X contributions on each qubit
		zCount[nData+row]+=1 # bit information of the measured parity qubit
		zCount[:nData]+=bitChecks[:,row] # bit information from the data qubits
		xCount[:nData]+=phaseChecks[:,row] # phase information from the data qubits
		xCount[nData:]+=Mc[:,row]+Mc[row,:] # phase information propagated by cross checks
		xCount[nData:]+=indirect[:,row] # phase information propagated indirectly
		chars=[symbols.get((z%2,x%2)) for z,x in zip(zCount,xCount)]
		latexRows.append(' & '.join(chars))
		displayRows.append(' '.join(chars))

	latex_output='\\\\\n'.join(latexRows) # LaTeX array body
	display_output='\n'.join(displayRows) # plain text version of the table

	if 'saveName' in kwargs: # write latex array to file if file name provided
		np.savetxt(kwargs['saveName'],[latex_output],fmt='%s')
	print(display_output)

def convert_to_latex_array(arr,filename): # function for converting integer arrays, such as CPC matricies to latex formatted arrays
	"""Write the 2D array arr to filename as the body of a LaTeX array.

	Every entry is converted with int() and then str(); entries within a row
	are separated by ' & ', and rows are separated by a space followed by a
	LaTeX line break and a newline.
	"""
	nRows,nCols=arr.shape[0],arr.shape[1] # array dimensions (fails for arrays with fewer than 2 dimensions)
	lines=[' & '.join(str(int(arr[iRow,iCol])) for iCol in range(nCols)) for iRow in range(nRows)]
	np.savetxt(filename,[' \\\\\n'.join(lines)],fmt='%s') # save latex formatted array
