"""
ciphers.py

By Kirby Urner, Oregon Curriculum Network
Major reference:  http://www.inetarena.com/~pdx4d/ocn/crypto0.html

Last revision:  May 04, 2001

"""

from string import uppercase
from random import shuffle, randint
from xreadlines import xreadlines
from primes import bigppr, pptest
from operator import mul
from binascii import hexlify, unhexlify

"""
Set domain = range(...) if you want to work with
integers instead of uppercase letters -- or experiment
with other domains
"""

#================= Permutations =======================


# The symbols being permuted:  mkcode() shuffles a copy of this list to
# build a random key, and mkdict() pairs with itself every element of it
# that appears in no cycle.
_domain = list(uppercase)

def mkcode():
    """
    Build a random substitution key:  a dictionary pairing every
    element of _domain with exactly one element of a shuffled copy
    of _domain (a bijection from _domain onto itself)
    """
    shuffled = _domain[:]
    shuffle(shuffled)
    return dict(zip(_domain, shuffled))

def encrypt(plaintext,secretkey):
    """
    Substitute each character of plaintext (upper-cased first) by
    its partner in secretkey, or leave it as is if it is not in
    the key -- designed to work with _domain = list(uppercase)

    plaintext -- the string to encipher
    secretkey -- substitution dictionary mapping characters to
                 characters

    Returns the enciphered string
    """
    # dict.get does a single hashed lookup per character (no scan of a
    # key list), and join avoids rebuilding the result string each step
    return "".join([secretkey.get(ch, ch) for ch in plaintext.upper()])


def decrypt(ciphertext,secretkey):
    """
    Undo encrypt():  in this simple substitution scheme,
    decrypting is just encrypting with every pair of the
    key swapped around
    """
    swapped = dict([(cipher, plain) for plain, cipher in secretkey.items()])
    return encrypt(ciphertext, swapped)

def mkcycles(secretkey):
    """
    A dictionary of letter pairs is isomorphic to a
    list of disjoint cycles -- disjoint means no letters
    in common -- where each cycle pairs a letter with
    the one following, with the last pairing with the
    first e.g. {'a':'b','b':'c','c':'a'} -> [(a,b,c)]

    If a letter pairs with itself, it gets left out of
    the tuples.  E.g. the identity dictionary, pairing
    every letter with itself, returns [].

    This function returns a list of such cycles-tuples
    given any dictionary of letter pairs.  The dictionary
    passed in is not modified.  A KeyError is raised if
    some letter's partner is not itself a key (i.e. the
    dictionary is not a permutation).
    """
    cycles = []
    pending = {}
    pending.update(secretkey)   # work on a copy, pairs get deleted below

    # Take one snapshot of the keys instead of rebuilding (and indexing)
    # the key list for every cycle;  keys already swallowed by an
    # earlier cycle are simply skipped.
    for start in list(pending.keys()):
        if start not in pending:
            continue
        cycle = [start]
        while 1:
            following = pending[cycle[-1]]
            del pending[cycle[-1]]
            if following == start:      # back where we began: cycle closed
                break
            cycle.append(following)
        if len(cycle) > 1:              # fixed points are left out
            cycles.append(tuple(cycle))

    return cycles

def mkdict(cycles):
    """
    Inverse of mkcycles:  turn a list of cycle-tuples back into
    the substitution dictionary they describe.  Every element of
    _domain that appears in no cycle is paired with itself.
    """
    unused = _domain[:]
    mapping = {}
    for cyc in cycles:
        first, last = cyc[0], cyc[-1]
        # every member maps to its right-hand neighbour ...
        for here, there in zip(cyc[:-1], cyc[1:]):
            mapping[here] = there
            unused.remove(here)
        # ... and the last one wraps around to the first
        mapping[last] = first
        unused.remove(last)
    for leftover in unused:
        mapping[leftover] = leftover
    return mapping

class P:
    """
    Permutations:  these objects multiply with each other,
    return an inverse, may be raised to an integer power

    The permutation is held in self.dict, a substitution
    dictionary pairing each element with its image
    """

    def __init__(self,cycles=None,dict=None):
        """
        Accept permutation in cyclic notation, or as a
        substitution dictionary.  A random permutation
        is returned if neither is provided.  If both are
        given, cycles takes precedence.
        """
        # "is None" rather than "== None":  an identity test cannot be
        # hijacked by whatever __eq__ the argument happens to define
        if cycles is not None:
            self.dict = mkdict(cycles)
        elif dict is not None:
            self.dict = dict
        else:
            self.dict = mkcode()

    def __mul__(self,other):
        """
        p * q:  apply p first, then q
        """
        newdict = {}
        for i in self.dict.keys():
            newdict[i] = other.dict[self.dict[i]]
        return P(dict=newdict)

    def __call__(self,other):
        """
        p(q):  apply q first, then p (the reverse order of p * q)
        """
        newdict = {}
        for i in other.dict.keys():
            newdict[i] = self.dict[other.dict[i]]
        return P(dict=newdict)

    def __div__(self,other):
        """
        p / q = p * q.inv()
        """
        return self * other.inv()

    # "/" looks for __truediv__ when true division is in effect
    __truediv__ = __div__

    def inv(self):
        """
        Return the inverse permutation (every pair swapped around)
        """
        newdict = {}
        for i,j in self.dict.items():
            newdict[j]=i
        return P(dict=newdict)

    def __eq__(self,other):
        return self.dict==other.dict

    def __ne__(self,other):
        # without this, "!=" on Python 2 would not look at the contents
        return self.dict!=other.dict

    def __pow__(self,n):
        """
        p**n:  p multiplied by itself abs(n) times, inverted if
        n is negative;  p**0 is the identity P([])
        """
        new = P([])
        if n==0: return new
        for i in range(abs(n)):
            new = self * new
        if n<0:
            new = new.inv()
        return new

    def ord(self):
        """
        returns the exponent n of p such that p**n
        = [] (identity element), i.e. the lowest common
        multiple of the lengths of p's cycles (1 for the
        identity itself)
        """
        order = 1
        for cycle in mkcycles(self.dict):
            order = lcm(order, len(cycle))
        return order

    def __repr__(self):
        return "Permutation: " + str(mkcycles(self.dict))

#================= Residue Classes  =====================

    
class R:
    """
    Residue class:  an integer reduced modulo n, with its
    arithmetic carried out modulo n

    Attributes:  v -- the stored representative (val % n)
                 n -- the modulus
    """

    def __init__(self,val,n):
        self.v = val%n
        self.n = n

    def __mul__(self,m):
        """
        a*b = (a*b) mod n
        """
        return R(self.v * m.v, self.n)

    def __div__(self,m):
        """
        a/b = a * b.inv()
        """
        return self * m.inv()

    # "/" looks for __truediv__ when true division is in effect
    __truediv__ = __div__

    def __add__(self,m):
        """
        a+b = (a+b) mod n
        """
        return R(self.v + m.v, self.n)

    def __sub__(self,m):
        """
        a-b = a + (-b)
        """
        return self + (-m)

    def __neg__(self):
        return R(-self.v, self.n)

    def __pow__(self,n):
        """
        g**n mod the modulus;  a negative n gives the inverse
        of g**abs(n).  (Here n is the exponent, not the modulus.)
        """
        if n==0: return R(1,self.n)
        new = R(pow(self.v,abs(n),self.n),self.n)
        if n<0:
            new = new.inv()
        return new

    def ord(self):
        """
        Exponent of g: power of g > 0 that equals 1

        Raises ValueError if no power of g equals 1 (the
        search would otherwise never end)
        """
        power = R(1,self.n)
        # The powers of g take at most n distinct values, so a 1 that has
        # not shown up within the first n powers never will.
        for i in range(1,self.n+1):
            power = power * self
            if power.v == 1:
                return i
        raise ValueError("%s has no order modulo %s" % (self.v, self.n))

    def inv(self):
        """
        Inverse of g, computed as g**(phi(n)-1);  only a true
        inverse when g and n are coprime
        """
        if self.n == 1:
            # phi(1)-1 is negative, which would send __pow__ straight
            # back here;  modulo 1 there is only the residue 0
            return R(0,1)
        return self ** (phi(self.n)-1)

    def __lt__(self,k):
        return self.v < k.v

    def __gt__(self,k):
        return self.v > k.v

    def __eq__(self,k):
        return self.v == k.v

    def __ne__(self,k):
        # without this, "!=" on Python 2 would not compare the values
        return self.v != k.v

    def __repr__(self):
        return str(self.v)

def gcd(a,b):
    """Return greatest common divisor using Euclid's Algorithm.

    gcd(a, 0) is a, so gcd(0, 0) is 0.
    """
    # replace (a, b) by (b, a mod b) until the remainder vanishes
    while b:
        a, b = b, a % b
    return a

def lcm(a,b):
    """
    Return lowest common multiple (0 if either argument is 0)."""
    if not a or not b:
        return 0        # also avoids dividing by gcd(0,0) == 0
    # "//":  the quotient is exact, and "/" would hand back a float
    # (inexact for big numbers) wherever true division is in effect
    return (a*b)//gcd(a,b)

def bingcd(a,b):
    """Extended version of Euclid's Algorithm
    Returns (m,n,gcd) such that  m*a + n*b = gcd(a,b)

    m is reduced modulo b.  Meant for positive a and b;
    b = 0 raises ZeroDivisionError."""
    # invariant:  g[k] == u[k]*b + v[k]*a   for k = 0, 1
    g,u,v = [b,a],[1,0],[0,1]
    while g[1] != 0:
        y = g[0]//g[1]
        g[0],g[1] = g[1],g[0]%g[1]
        u[0],u[1] = u[1],u[0] - y*u[1]
        v[0],v[1] = v[1],v[0] - y*v[1]
    m = v[0]%b
    # g[0] is the gcd itself.  Recomputing it as (m*a)%b gives 0
    # instead of b whenever b divides a.
    divisor = g[0]
    n = (divisor - m*a)//b      # exact:  m*a is congruent to divisor mod b
    return (m,n,divisor)

def inverse(a,b):
    """Multiplicative inverse of a modulo b:  when gcd(a,b)=1
    the result times a is 1 mod b;  when gcd(a,b)!=1 there is
    no inverse and 0 is returned

    Useful in RSA encryption, for finding d such that
    e*d mod totient(n) = 1"""
    coefficient, unused, divisor = bingcd(a,b)
    if divisor == 1:
        return coefficient
    return 0

def relprimes(n):
    """
    Collect, in increasing order, the integers i with
    0 < i < n whose gcd with n is 1
    """
    coprime = []
    for candidate in range(1,n):
        if gcd(n,candidate) == 1:
            coprime.append(candidate)
    return coprime

def phi(n):
    """
    Count the integers i, 0 < i < n, that are coprime to n
    (simply the length of relprimes(n))
    """
    totatives = relprimes(n)
    return len(totatives)

    
class Rgroup:
    """
    A collection of residues sharing one modulus -- by default
    the residues coprime to the modulus -- with printed
    operation and power tables
    """

    # The operations table() knows about.  A lookup table instead of
    # eval() means an operator string can never run arbitrary code.
    _OPS = {"*": lambda a, b: a * b,
            "/": lambda a, b: a / b,
            "+": lambda a, b: a + b,
            "-": lambda a, b: a - b}

    def __init__(self,modulus,elements=None):
        """
        modulus  -- the common modulus
        elements -- the members;  defaults to R(x,modulus) for
                    every x in relprimes(modulus)
        """
        self.modulus = modulus
        if elements is None:
            self.elements = [R(x,modulus) for x in relprimes(modulus)]
        else:
            self.elements = elements
        self.totient = len(self.elements)

    def table(self,op="*"):
        """
        Outputs an operation table for the group, using whatever
        operation: * / + -

        For + and - the residue 0 is added to the elements shown.
        Any other op raises ValueError.
        """
        if op not in self._OPS:
            raise ValueError("unsupported operation: %r" % (op,))
        combine = self._OPS[op]
        if op in ("+", "-"):
            elems = self.elements + [R(0,self.modulus)]
        else:
            elems = self.elements[:]
        elems.sort()
        print(" ")
        head = " "+op + "  "+(len(elems)*" %2s") % tuple(elems)
        print(head)
        print("   " + len(head)*"-")
        for i in elems:
            vals = [combine(i, x) for x in elems]
            seg1 = "%2s| " % i
            seg2 = (len(vals)*" %2s") % tuple(vals)
            print(seg1 + seg2)
        print(" ")

    def exp(self,i):
        """
        List the powers g**1 ... g**g.ord() of the i-th element g
        """
        g = self[i]
        return [g**x for x in range(1,g.ord()+1)]

    def powers(self):
        """
        Outputs a table of every element raised to the powers
        0 ... (number of elements)
        """
        elems = self.elements[:]
        elems.sort()
        print(" ")
        head = "**  "+((len(elems)+1)*" %2s") \
               % tuple(range(len(elems)+1))
        print(head)
        print("   " + len(head)*"-")
        for i in elems:
            vals = [i**x for x in range(len(elems)+1)]
            seg1 = "%2s| " % i
            seg2 = (len(vals)*" %2s") % tuple(vals)
            print(seg1 + seg2)
        print(" ")

    def __getitem__(self,k):
        return self.elements[k]
    
#================= Enigma-style Secret Key machine ============

    
class Number:
    """
    A mixed-base counter:  a row of R columns, each column
    counting modulo its own base (the column's modulus)
    """

    def __init__(self,columns):
        """
        columns -- sequence of R objects, leftmost first
        """
        self.columns = tuple(columns)
        self.bases = [column.n for column in self.columns]

    def __add__(self,input):
        """
        Add the list of integers input (one entry per column)
        to the counter, carrying from right to left.  Returns a
        new Number;  a carry out of the leftmost column is dropped.
        """
        carry = 0
        columns = [0]*len(self.columns)
        # evaluate from right to left

        for i in range(len(self.columns)-1,-1,-1):
            total = self.columns[i].v + input[i] + carry
            columns[i]  = R(total, self.bases[i])
            # "//" keeps the carry a whole number;  "/" would turn it
            # into a float wherever true division is in effect
            carry = total//self.bases[i]
        return Number(columns)

    def __len__(self):
        return len(self.columns)

    def __repr__(self):
        return str(self.columns)

class Enigma:
    """
    Enigma-style machine:  a set of rotors (P permutations)
    driven by a counter with one column per rotor, the column's
    base being that rotor's order.  Every character processed
    advances the counter by one, and each column that changes
    folds its rotor into the running permutation self.permute.
    """

    def __init__(self,rotors):
        self.rotors=list(rotors)
        columns = []
        # loop over the stored list so a one-shot iterator also works
        for rotor in self.rotors:
            columns.append( R(0,rotor.ord()) )
        self.counter = Number(columns)
        self.permute = P([])

    def click(self,direction):
        """
        Advance the counter by one and update self.permute for
        every column whose value changed:  direction 1 (enciphering)
        multiplies the rotor in on the left, any other direction
        (deciphering) multiplies its inverse in on the right.
        """
        before = self.counter.columns[:]
        self.counter = self.counter + ([0]*(len(before)-1)+[1]) # add 1

        after = self.counter.columns[:]
        for i in range(len(before)):
            if before[i].v != after[i].v:  # compare counter values

                if direction==1:
                    self.permute = self.rotors[i] * self.permute
                else:
                    self.permute = self.permute * (self.rotors[i].inv())

    def _translate(self,text,direction):
        # Shared engine of encrypt/decrypt:  click once per character
        # (whether or not it gets substituted), then map the character
        # through the current permutation, passing unknown ones through.
        pieces = []
        for ch in text.upper():
            self.click(direction)
            pieces.append(self.permute.dict.get(ch, ch))
        return "".join(pieces)

    def encrypt(self,plaintext):
        """
        Encipher plaintext (upper-cased first), continuing from
        the machine's current state
        """
        return self._translate(plaintext,1)

    def reset(self):
        """
        Back to the starting state:  identity permutation and
        all counter columns at 0
        """
        self.permute = P([])
        for column in self.counter.columns:
            column.v = 0

    def decrypt(self,ciphertext):
        """
        Decipher ciphertext (upper-cased first), continuing from
        the machine's current state -- call reset() first to undo
        an encrypt() that started from the reset state
        """
        return self._translate(ciphertext,-1)

    def __repr__(self):
        # one-element tuple:  the list of bases is a single %s argument
        return "Enigma-class object: rotors of order %s" % (self.counter.bases,)
    
def encipher(infile,outfile,enigma):
    """
    Apply an enigma object to some input file,
    writing an enciphered output file

    infile, outfile -- file names
    enigma          -- object whose encrypt(line) returns the
                       enciphered line;  it is not reset first

    Both files are closed even if reading, enciphering or
    writing fails part way through.
    """
    f1 = open(infile,'r')
    try:
        f2 = open(outfile,'w')
        try:
            # a file object hands out its lines one at a time
            for line in f1:
                f2.write(enigma.encrypt(line))
        finally:
            f2.close()
    finally:
        f1.close()

def decipher(infile,outfile,enigma):
    """
    Apply an enigma object to some input file,
    writing a deciphered output file

    infile, outfile -- file names
    enigma          -- object with reset() and decrypt(line);
                       reset() is called once, after both files
                       are open and before the first line is read

    Both files are closed even if reading, deciphering or
    writing fails part way through.
    """
    f1 = open(infile,'r')
    try:
        f2 = open(outfile,'w')
        try:
            enigma.reset()
            # a file object hands out its lines one at a time
            for line in f1:
                f2.write(enigma.decrypt(line))
        finally:
            f2.close()
    finally:
        f1.close()

def permutations(n):
    """
    return all unique permutations of 0...n-1 (= n! elements)

    Each permutation is a string of the numbers run together,
    e.g. '021', and the list is in lexicographic order.
    permutations(0) is [''] (the one empty permutation).  For
    n > 10 the strings are ambiguous, as multi-digit numbers
    are run together with no separator.
    """
    # Grow every partial permutation by each number it does not contain
    # yet.  Only n! results are ever built, where stepping a base-n
    # counter through all its states would visit n**n of them.
    partial = [[]]
    for position in range(n):
        partial = [head + [x] for head in partial
                              for x in range(n) if x not in head]
    return ["".join(map(str, perm)) for perm in partial]

def sign(perm):
    """
    return the sign of a permutation
    = -1 if odd number of transpositions
    =  1 if even

    perm is a sequence of numbers or digit characters, e.g.
    '021' or [0, 2, 1].  The transpositions are counted as
    inversions: pairs where a larger value precedes a smaller one.
    """
    values = [int(x) for x in perm]
    trans = 0
    # walk by position;  searching for each value with index() costs an
    # extra scan per step and finds the wrong place if a value repeats
    for i in range(len(values)-1):
        for j in range(i+1,len(values)):
            if values[i]>values[j]:
                trans += 1
    if trans%2==0: return 1
    else: return -1
            
#================= RSA public key encryption ============

    
def mknum(phrase):
    """
    Turn a phrase into one big integer by reading its bytes as
    a base-256 number, first byte most significant.  A text
    (non-bytes) phrase is encoded as UTF-8 first.  The empty
    phrase gives 0.
    """
    if not isinstance(phrase, bytes):
        phrase = phrase.encode("utf-8")
    if not phrase:
        return 0
    # int() parses the hex digits directly -- no need to build source
    # code and eval() it
    return int(hexlify(phrase), 16)

def mkphrase(num):
    """
    Inverse of mknum:  turn a non-negative integer back into
    the byte string whose bytes, read as a base-256 number,
    give num.  0 gives the empty string;  leading zero bytes
    of an original phrase cannot be recovered.  Raises
    ValueError for a negative num.
    """
    if num < 0:
        raise ValueError("cannot turn a negative number into a phrase")
    if num == 0:
        return unhexlify("")
    # "%x" gives bare hex digits, with no "0x" prefix and no "L" suffix
    # to slice away (slicing drops a real digit when there is no "L")
    digits = "%x" % num
    if len(digits) % 2:
        digits = "0" + digits   # unhexlify needs whole bytes
    return unhexlify(digits)

def getpr(digits):
    """
    Getting a good p,q for an RSA modulus (p*q), using a recipe on
    pg. 405 of Volume 2 of The Art of Computer Programming by
    Donald Knuth -- (except this module defaults to smaller
    primes than recommended, and a pseudo-random function, for
    purposes of pedagogy).

    Draws p1 = bigppr(digits) and p2 = bigppr(digits//2), each
    redrawn until (p-1)%3 > 0, then tries cand = k*p1 + 1 for
    k = p2+1, p2+3, ... and returns the first cand for which
    pptest(cand) > 0.99.  If 998 values of k go by without
    success, it starts over with fresh p1, p2 -- a candidate
    that failed the test is never returned.

    NOTE(review): bigppr and pptest come from the primes module;
    presumably bigppr(d) gives a probable prime of about d digits
    and pptest a primality confidence -- confirm there.
    """
    while 1:
        while 1:
            p1 = bigppr(digits)
            if (p1-1)%3>0:
                break
        while 1:
            p2 = bigppr(digits//2)
            if (p2-1)%3>0:
                break
        k = p2 + 1
        for trial in range(998):
            if (k%3 == p1%3) and (k%2==0):
                cand = k*p1 + 1
                if pptest(cand)>0.99:
                    print("OK!")
                    return cand
            k += 2
        # trial budget used up:  go round again with new primes

def rsasetup():
    """
    Make an RSA key for enciphering exponent 3.  Returns (n,d):
    the modulus n = p*q, and d = inverse(3, (p-1)*(q-1)), with
    p = getpr(30) and q = getpr(40)
    """
    p, q = getpr(30), getpr(40)
    d = inverse(3, (p-1)*(q-1))
    return (p*q, d)

def checkrsa(d,n):
    """
    Print a spot check of the key (d,n):  for ten random m in
    2..100, show m, m**3 mod n and m**(3*d) mod n -- the last
    column should repeat the first if d undoes the cubing
    """
    template = "m mod n = %2s : m**3 mod n = %10s : m**(3*d) mod n = %2s\n"
    samples = [randint(2,100) for y in range(10)]
    report = []
    for m in samples:
        report.append(template % (m, pow(m,3,n), pow(m,3*d,n)))
    print("".join(report))
    
def rsaencrypt(message,pubkey,exponent=3):
    """
    Using default enciphering exponent of 3, mod some pubkey
    = p*q (see rsasetup).  Note that the full definition of the
    RSA doesn't insist on using 3 for the encrypting exponent.
    Any e such that gcd(e,totient)=1 will do, but best if not
    too big -- the prime number 65537 is a popular value, for
    technical reasons. However, in this module, we're following
    Knuth's example (pg. 404) and going with 3.

    message  -- the phrase to encipher (turned into a number
                by mknum)
    pubkey   -- the modulus
    exponent -- the enciphering exponent;  defaults to 3, the
                value the key from rsasetup is built for

    Returns mknum(message)**exponent mod pubkey.  The number for
    the message should be smaller than pubkey, or it cannot be
    recovered by rsadecrypt.
    """
    m = mknum(message)
    return pow(m,exponent,pubkey)

def rsadecrypt(c,privkey,n):
    """
    Recover the phrase from ciphertext number c using the
    private exponent privkey and modulus n.

    Why it works:  d (= privkey) is the inverse of 3 mod the
    totient, so 3*d = k*totient + 1 for some integer k>0.
    With c = m**3 mod n as per rsaencrypt:
    c**d mod n  =  m**(k*totient + 1) mod n
                =  m**(k*totient) * m mod n
                =  1**k * m mod n (Euler's Theorem)
                =  m mod n
    """
    return mkphrase(pow(c,privkey,n))

# code highlighted using py2html.py version 0.8