001/* 002 * Licensed to the Apache Software Foundation (ASF) under one 003 * or more contributor license agreements. See the NOTICE file 004 * distributed with this work for additional information 005 * regarding copyright ownership. The ASF licenses this file 006 * to you under the Apache License, Version 2.0 (the 007 * "License"); you may not use this file except in compliance 008 * with the License. You may obtain a copy of the License at 009 * 010 * http://www.apache.org/licenses/LICENSE-2.0 011 * 012 * Unless required by applicable law or agreed to in writing, 013 * software distributed under the License is distributed on an 014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 015 * KIND, either express or implied. See the License for the 016 * specific language governing permissions and limitations 017 * under the License. 018 * 019 */ 020package org.apache.directory.server.kerberos.shared.crypto.encryption; 021 022 023/** 024 * An implementation of the n-fold algorithm, as required by RFC 3961, 025 * "Encryption and Checksum Specifications for Kerberos 5." 026 * 027 * "To n-fold a number X, replicate the input value to a length that 028 * is the least common multiple of n and the length of X. Before 029 * each repetition, the input is rotated to the right by 13 bit 030 * positions. The successive n-bit chunks are added together using 031 * 1's-complement addition (that is, with end-around carry) to yield 032 * a n-bit result." 033 * 034 * @author <a href="mailto:dev@directory.apache.org">Apache Directory Project</a> 035 */ 036public class NFold 037{ 038 /** 039 * N-fold the data n times. 040 * 041 * @param n The number of times to n-fold the data. 042 * @param data The data to n-fold. 043 * @return The n-folded data. 044 */ 045 public static byte[] nFold( int n, byte[] data ) 046 { 047 int k = data.length * 8; 048 int lcm = getLcm( n, k ); 049 int replicate = lcm / k; 050 byte[] sumBytes = new byte[lcm / 8]; 051 052 for ( int i = 0; i < replicate; i++ ) 053 { 054 int rotation = 13 * i; 055 056 byte[] temp = rotateRight( data, data.length * 8, rotation ); 057 058 for ( int j = 0; j < temp.length; j++ ) 059 { 060 sumBytes[j + i * temp.length] = temp[j]; 061 } 062 } 063 064 byte[] sum = new byte[n / 8]; 065 byte[] nfold = new byte[n / 8]; 066 067 for ( int m = 0; m < lcm / n; m++ ) 068 { 069 for ( int o = 0; o < n / 8; o++ ) 070 { 071 sum[o] = sumBytes[o + ( m * n / 8 )]; 072 } 073 074 nfold = sum( nfold, sum, nfold.length * 8 ); 075 076 } 077 078 return nfold; 079 } 080 081 082 /** 083 * For 2 numbers, return the least-common multiple. 084 * 085 * @param n1 The first number. 086 * @param n2 The second number. 087 * @return The least-common multiple. 088 */ 089 protected static int getLcm( int n1, int n2 ) 090 { 091 int temp; 092 int product; 093 094 product = n1 * n2; 095 096 do 097 { 098 if ( n1 < n2 ) 099 { 100 temp = n1; 101 n1 = n2; 102 n2 = temp; 103 } 104 n1 = n1 % n2; 105 } 106 while ( n1 != 0 ); 107 108 return product / n2; 109 } 110 111 112 /** 113 * Right-rotate the given byte array. 114 * 115 * @param in The byte array to right-rotate. 116 * @param len The length of the byte array to rotate. 117 * @param step The number of positions to rotate the byte array. 118 * @return The right-rotated byte array. 119 */ 120 private static byte[] rotateRight( byte[] in, int len, int step ) 121 { 122 int numOfBytes = ( len - 1 ) / 8 + 1; 123 byte[] out = new byte[numOfBytes]; 124 125 for ( int i = 0; i < len; i++ ) 126 { 127 int val = getBit( in, i ); 128 setBit( out, ( i + step ) % len, val ); 129 } 130 return out; 131 } 132 133 134 /** 135 * Perform one's complement addition (addition with end-around carry). Note 136 * that for purposes of n-folding, we do not actually complement the 137 * result of the addition. 138 * 139 * @param n1 The first number. 140 * @param n2 The second number. 141 * @param len The length of the byte arrays to sum. 142 * @return The sum with end-around carry. 143 */ 144 protected static byte[] sum( byte[] n1, byte[] n2, int len ) 145 { 146 int numOfBytes = ( len - 1 ) / 8 + 1; 147 byte[] out = new byte[numOfBytes]; 148 int carry = 0; 149 150 for ( int i = len - 1; i > -1; i-- ) 151 { 152 int n1b = getBit( n1, i ); 153 int n2b = getBit( n2, i ); 154 155 int sum = n1b + n2b + carry; 156 157 if ( sum == 0 || sum == 1 ) 158 { 159 setBit( out, i, sum ); 160 carry = 0; 161 } 162 else if ( sum == 2 ) 163 { 164 carry = 1; 165 } 166 else if ( sum == 3 ) 167 { 168 setBit( out, i, 1 ); 169 carry = 1; 170 } 171 } 172 173 if ( carry == 1 ) 174 { 175 byte[] carryArray = new byte[n1.length]; 176 carryArray[carryArray.length - 1] = 1; 177 out = sum( out, carryArray, n1.length * 8 ); 178 } 179 180 return out; 181 } 182 183 184 /** 185 * Get a bit from a byte array at a given position. 186 * 187 * @param data The data to get the bit from. 188 * @param pos The position to get the bit at. 189 * @return The value of the bit. 190 */ 191 private static int getBit( byte[] data, int pos ) 192 { 193 int posByte = pos / 8; 194 int posBit = pos % 8; 195 196 byte valByte = data[posByte]; 197 198 return valByte >> ( 8 - ( posBit + 1 ) ) & 0x0001; 199 } 200 201 202 /** 203 * Set a bit in a byte array at a given position. 204 * 205 * @param data The data to set the bit in. 206 * @param pos The position of the bit to set. 207 * @param The value to set the bit to. 208 */ 209 private static void setBit( byte[] data, int pos, int val ) 210 { 211 int posByte = pos / 8; 212 int posBit = pos % 8; 213 byte oldByte = data[posByte]; 214 oldByte = ( byte ) ( ( ( 0xFF7F >> posBit ) & oldByte ) & 0x00FF ); 215 byte newByte = ( byte ) ( ( val << ( 8 - ( posBit + 1 ) ) ) | oldByte ); 216 data[posByte] = newByte; 217 } 218}