001/*
002 *  Licensed to the Apache Software Foundation (ASF) under one
003 *  or more contributor license agreements.  See the NOTICE file
004 *  distributed with this work for additional information
005 *  regarding copyright ownership.  The ASF licenses this file
006 *  to you under the Apache License, Version 2.0 (the
007 *  "License"); you may not use this file except in compliance
008 *  with the License.  You may obtain a copy of the License at
009 *  
010 *    http://www.apache.org/licenses/LICENSE-2.0
011 *  
012 *  Unless required by applicable law or agreed to in writing,
013 *  software distributed under the License is distributed on an
014 *  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
015 *  KIND, either express or implied.  See the License for the
016 *  specific language governing permissions and limitations
017 *  under the License. 
018 *  
019 */
020package org.apache.directory.server.kerberos.shared.crypto.encryption;
021
022
023/**
024 * An implementation of the n-fold algorithm, as required by RFC 3961,
025 * "Encryption and Checksum Specifications for Kerberos 5."
026 * 
027 * "To n-fold a number X, replicate the input value to a length that
028 * is the least common multiple of n and the length of X.  Before
029 * each repetition, the input is rotated to the right by 13 bit
030 * positions.  The successive n-bit chunks are added together using
031 * 1's-complement addition (that is, with end-around carry) to yield
032 * a n-bit result."
033 * 
034 * @author <a href="mailto:dev@directory.apache.org">Apache Directory Project</a>
035 */
public class NFold
{
    /** Number of bit positions each successive replica is rotated right by, relative to the previous one. */
    private static final int ROTATION_STEP = 13;


    /**
     * N-fold the data into a result that is exactly <code>n</code> bits long.
     * <p>
     * The input is replicated until its total length is the least common
     * multiple of <code>n</code> and the input length in bits; replica
     * <code>i</code> is rotated right by <code>13 * i</code> bits.  The
     * replicated string is then cut into <code>n</code>-bit chunks which are
     * added together with end-around carry.
     * <p>
     * The implementation works on whole bytes, so <code>n</code> is expected
     * to be a positive multiple of 8 and <code>data</code> is expected to be
     * non-empty; empty data results in an {@link ArithmeticException}.
     * 
     * @param n The size, in bits, of the folded result.
     * @param data The data to n-fold.
     * @return The n-folded data, <code>n / 8</code> bytes long.
     */
    public static byte[] nFold( int n, byte[] data )
    {
        int k = data.length * 8;
        int lcm = getLcm( n, k );
        int replicate = lcm / k;
        byte[] sumBytes = new byte[lcm / 8];

        // Lay the rotated replicas end to end.
        for ( int i = 0; i < replicate; i++ )
        {
            byte[] rotated = rotateRight( data, k, ROTATION_STEP * i );

            System.arraycopy( rotated, 0, sumBytes, i * rotated.length, rotated.length );
        }

        int chunkBytes = n / 8;
        byte[] chunk = new byte[chunkBytes];
        byte[] nfold = new byte[chunkBytes];

        // Accumulate every n-bit chunk into the running one's-complement sum.
        for ( int m = 0; m < lcm / n; m++ )
        {
            System.arraycopy( sumBytes, m * n / 8, chunk, 0, chunkBytes );

            nfold = sum( nfold, chunk, nfold.length * 8 );
        }

        return nfold;
    }


    /**
     * For 2 numbers, return the least-common multiple.
     * <p>
     * Both numbers are expected to be positive.  The result is computed as
     * <code>( n1 / gcd ) * n2</code> so that the intermediate value never
     * exceeds the least-common multiple itself; multiplying first would
     * overflow for inputs whose product does not fit in an <code>int</code>
     * even though their least-common multiple does.
     *
     * @param n1 The first number.
     * @param n2 The second number.
     * @return The least-common multiple.
     */
    protected static int getLcm( int n1, int n2 )
    {
        int a = n1;
        int b = n2;

        // Euclid's algorithm; when the loop ends, b holds the greatest common divisor.
        do
        {
            if ( a < b )
            {
                int temp = a;
                a = b;
                b = temp;
            }
            a = a % b;
        }
        while ( a != 0 );

        return ( n1 / b ) * n2;
    }


    /**
     * Right-rotate the given byte array, treating it as a string of bits
     * where bit 0 is the most significant bit of the first byte.
     *
     * @param in The byte array to right-rotate.
     * @param len The number of bits of the byte array to rotate.
     * @param step The number of bit positions to rotate the byte array.
     * @return The right-rotated byte array.
     */
    private static byte[] rotateRight( byte[] in, int len, int step )
    {
        int numOfBytes = ( len - 1 ) / 8 + 1;
        byte[] out = new byte[numOfBytes];

        for ( int i = 0; i < len; i++ )
        {
            // Bit i of the input lands 'step' positions further right, wrapping around.
            setBit( out, ( i + step ) % len, getBit( in, i ) );
        }

        return out;
    }


    /**
     * Perform one's complement addition (addition with end-around carry).  Note
     * that for purposes of n-folding, we do not actually complement the
     * result of the addition.
     * <p>
     * Only the first <code>len</code> bits of each operand take part in the
     * addition, and a carry out of the most significant bit is added back in
     * at bit <code>len - 1</code>, the least significant bit of that range.
     * 
     * @param n1 The first number.
     * @param n2 The second number.
     * @param len The number of bits to sum.
     * @return The sum with end-around carry.
     */
    protected static byte[] sum( byte[] n1, byte[] n2, int len )
    {
        int numOfBytes = ( len - 1 ) / 8 + 1;
        byte[] out = new byte[numOfBytes];
        int carry = 0;

        // Ripple-carry addition from the least significant bit upwards.
        for ( int i = len - 1; i >= 0; i-- )
        {
            int total = getBit( n1, i ) + getBit( n2, i ) + carry;

            setBit( out, i, total & 1 );
            carry = total >> 1;
        }

        // End-around carry: feed a carry out of the top bit back into the
        // bottom bit.  After a carry-out the partial sum can never be all
        // ones, so this second pass always absorbs the carry.
        for ( int i = len - 1; carry == 1 && i >= 0; i-- )
        {
            int total = getBit( out, i ) + carry;

            setBit( out, i, total & 1 );
            carry = total >> 1;
        }

        return out;
    }


    /**
     * Get a bit from a byte array at a given position.  Position 0 is the
     * most significant bit of the first byte.
     *
     * @param data The data to get the bit from.
     * @param pos The position to get the bit at.
     * @return The value of the bit, either 0 or 1.
     */
    private static int getBit( byte[] data, int pos )
    {
        int shift = 7 - ( pos % 8 );

        return ( data[pos / 8] >> shift ) & 0x01;
    }


    /**
     * Set a bit in a byte array at a given position.  Position 0 is the
     * most significant bit of the first byte.
     *
     * @param data The data to set the bit in.
     * @param pos The position of the bit to set.
     * @param val The value to set the bit to, either 0 or 1.
     */
    private static void setBit( byte[] data, int pos, int val )
    {
        int posByte = pos / 8;
        int shift = 7 - ( pos % 8 );

        // Clear the target bit, then OR in the new value.
        int cleared = data[posByte] & ~( 1 << shift ) & 0xFF;
        data[posByte] = ( byte ) ( cleared | ( val << shift ) );
    }
}