001/*
002 *  Licensed to the Apache Software Foundation (ASF) under one
003 *  or more contributor license agreements.  See the NOTICE file
004 *  distributed with this work for additional information
005 *  regarding copyright ownership.  The ASF licenses this file
006 *  to you under the Apache License, Version 2.0 (the
007 *  "License"); you may not use this file except in compliance
008 *  with the License.  You may obtain a copy of the License at
009 *  
010 *    http://www.apache.org/licenses/LICENSE-2.0
011 *  
012 *  Unless required by applicable law or agreed to in writing,
013 *  software distributed under the License is distributed on an
014 *  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
015 *  KIND, either express or implied.  See the License for the
016 *  specific language governing permissions and limitations
017 *  under the License. 
018 *  
019 */
020package org.apache.directory.server.kerberos.shared.crypto.encryption;
021
022
023import java.security.GeneralSecurityException;
024import java.security.MessageDigest;
025import java.security.spec.AlgorithmParameterSpec;
026
027import javax.crypto.Cipher;
028import javax.crypto.Mac;
029import javax.crypto.SecretKey;
030import javax.crypto.spec.IvParameterSpec;
031import javax.crypto.spec.SecretKeySpec;
032
033import org.apache.directory.api.util.Strings;
034import org.apache.directory.server.kerberos.shared.crypto.checksum.ChecksumEngine;
035import org.apache.directory.shared.kerberos.codec.types.EncryptionType;
036import org.apache.directory.shared.kerberos.components.EncryptedData;
037import org.apache.directory.shared.kerberos.components.EncryptionKey;
038import org.apache.directory.shared.kerberos.crypto.checksum.ChecksumType;
039import org.apache.directory.shared.kerberos.exceptions.ErrorType;
040import org.apache.directory.shared.kerberos.exceptions.KerberosException;
041
042
043/**
044 * @author <a href="mailto:dev@directory.apache.org">Apache Directory Project</a>
045 */
046public class Des3CbcSha1KdEncryption extends EncryptionEngine implements ChecksumEngine
047{
048    private static final byte[] iv = new byte[]
049        { ( byte ) 0x00, ( byte ) 0x00, ( byte ) 0x00, ( byte ) 0x00, ( byte ) 0x00, ( byte ) 0x00, ( byte ) 0x00,
050            ( byte ) 0x00 };
051
052
053    public EncryptionType getEncryptionType()
054    {
055        return EncryptionType.DES3_CBC_SHA1_KD;
056    }
057
058
059    public int getConfounderLength()
060    {
061        return 8;
062    }
063
064
065    public int getChecksumLength()
066    {
067        return 20;
068    }
069
070
071    public ChecksumType checksumType()
072    {
073        return ChecksumType.HMAC_SHA1_DES3_KD;
074    }
075
076
077    public byte[] calculateChecksum( byte[] data, byte[] key, KeyUsage usage )
078    {
079        byte[] kc = deriveKey( key, getUsageKc( usage ), 64, 168 );
080
081        return processChecksum( data, kc );
082    }
083
084
085    public byte[] calculateIntegrity( byte[] data, byte[] key, KeyUsage usage )
086    {
087        byte[] ki = deriveKey( key, getUsageKi( usage ), 64, 168 );
088
089        return processChecksum( data, ki );
090    }
091
092
093    public byte[] getDecryptedData( EncryptionKey key, EncryptedData data, KeyUsage usage ) throws KerberosException
094    {
095        byte[] ke = deriveKey( key.getKeyValue(), getUsageKe( usage ), 64, 168 );
096
097        byte[] encryptedData = data.getCipher();
098
099        // extract the old checksum
100        byte[] oldChecksum = new byte[getChecksumLength()];
101        System
102            .arraycopy( encryptedData, encryptedData.length - getChecksumLength(), oldChecksum, 0, oldChecksum.length );
103
104        // remove trailing checksum
105        encryptedData = removeTrailingBytes( encryptedData, 0, getChecksumLength() );
106
107        // decrypt the data
108        byte[] decryptedData = decrypt( encryptedData, ke );
109
110        // remove leading confounder
111        byte[] withoutConfounder = removeLeadingBytes( decryptedData, getConfounderLength(), 0 );
112
113        // calculate a new checksum
114        byte[] newChecksum = calculateIntegrity( decryptedData, key.getKeyValue(), usage );
115
116        // compare checksums
117        if ( !MessageDigest.isEqual( oldChecksum, newChecksum ) )
118        {
119            throw new KerberosException( ErrorType.KRB_AP_ERR_BAD_INTEGRITY );
120        }
121
122        return withoutConfounder;
123    }
124
125
126    public EncryptedData getEncryptedData( EncryptionKey key, byte[] plainText, KeyUsage usage )
127    {
128        byte[] ke = deriveKey( key.getKeyValue(), getUsageKe( usage ), 64, 168 );
129
130        // build the ciphertext structure
131        byte[] conFounder = getRandomBytes( getConfounderLength() );
132        byte[] paddedPlainText = padString( plainText );
133        byte[] dataBytes = concatenateBytes( conFounder, paddedPlainText );
134        byte[] checksumBytes = calculateIntegrity( dataBytes, key.getKeyValue(), usage );
135        byte[] encryptedData = encrypt( dataBytes, ke );
136        byte[] cipherText = concatenateBytes( encryptedData, checksumBytes );
137
138        return new EncryptedData( getEncryptionType(), key.getKeyVersion(), cipherText );
139    }
140
141
142    public byte[] encrypt( byte[] plainText, byte[] keyBytes )
143    {
144        return processCipher( true, plainText, keyBytes );
145    }
146
147
148    public byte[] decrypt( byte[] cipherText, byte[] keyBytes )
149    {
150        return processCipher( false, cipherText, keyBytes );
151    }
152
153
154    /**
155     * Derived Key = DK(Base Key, Well-Known Constant)
156     * DK(Key, Constant) = random-to-key(DR(Key, Constant))
157     * DR(Key, Constant) = k-truncate(E(Key, Constant, initial-cipher-state))
158     * 
159     * @param baseKey The base key to derive
160     * @param usage The key usage
161     * @param n The number of resulting bytes
162     * @param k The number of bytes
163     * @return The derived key
164     */
165    protected byte[] deriveKey( byte[] baseKey, byte[] usage, int n, int k )
166    {
167        byte[] result = deriveRandom( baseKey, usage, n, k );
168        result = randomToKey( result );
169
170        return result;
171    }
172
173
174    protected byte[] randomToKey( byte[] seed )
175    {
176        int kBytes = 24;
177        byte[] result = new byte[kBytes];
178
179        byte[] fillingKey = Strings.EMPTY_BYTES;
180
181        int pos = 0;
182
183        for ( int i = 0; i < kBytes; i++ )
184        {
185            if ( pos < fillingKey.length )
186            {
187                result[i] = fillingKey[pos];
188                pos++;
189            }
190            else
191            {
192                fillingKey = getBitGroup( seed, i / 8 );
193                fillingKey = setParity( fillingKey );
194                pos = 0;
195                result[i] = fillingKey[pos];
196                pos++;
197            }
198        }
199
200        return result;
201    }
202
203
204    protected byte[] getBitGroup( byte[] seed, int group )
205    {
206        int srcPos = group * 7;
207
208        byte[] result = new byte[7];
209
210        System.arraycopy( seed, srcPos, result, 0, 7 );
211
212        return result;
213    }
214
215
216    protected byte[] setParity( byte[] in )
217    {
218        byte[] expandedIn = new byte[8];
219
220        System.arraycopy( in, 0, expandedIn, 0, in.length );
221
222        setBit( expandedIn, 62, getBit( in, 7 ) );
223        setBit( expandedIn, 61, getBit( in, 15 ) );
224        setBit( expandedIn, 60, getBit( in, 23 ) );
225        setBit( expandedIn, 59, getBit( in, 31 ) );
226        setBit( expandedIn, 58, getBit( in, 39 ) );
227        setBit( expandedIn, 57, getBit( in, 47 ) );
228        setBit( expandedIn, 56, getBit( in, 55 ) );
229
230        byte[] out = new byte[8];
231
232        int bitCount = 0;
233        int index = 0;
234
235        for ( int i = 0; i < 64; i++ )
236        {
237            if ( ( i + 1 ) % 8 == 0 )
238            {
239                if ( bitCount % 2 == 0 )
240                {
241                    setBit( out, i, 1 );
242                }
243
244                index++;
245                bitCount = 0;
246            }
247            else
248            {
249                int val = getBit( expandedIn, index );
250                boolean bit = val > 0;
251
252                if ( bit )
253                {
254                    setBit( out, i, val );
255                    bitCount++;
256                }
257
258                index++;
259            }
260        }
261
262        return out;
263    }
264
265
266    private byte[] processCipher( boolean isEncrypt, byte[] data, byte[] keyBytes )
267    {
268        try
269        {
270            Cipher cipher = Cipher.getInstance( "DESede/CBC/NoPadding" );
271            SecretKey key = new SecretKeySpec( keyBytes, "DESede" );
272
273            AlgorithmParameterSpec paramSpec = new IvParameterSpec( iv );
274
275            if ( isEncrypt )
276            {
277                cipher.init( Cipher.ENCRYPT_MODE, key, paramSpec );
278            }
279            else
280            {
281                cipher.init( Cipher.DECRYPT_MODE, key, paramSpec );
282            }
283
284            return cipher.doFinal( data );
285        }
286        catch ( GeneralSecurityException nsae )
287        {
288            nsae.printStackTrace();
289            return null;
290        }
291    }
292
293
294    private byte[] processChecksum( byte[] data, byte[] key )
295    {
296        try
297        {
298            SecretKey sk = new SecretKeySpec( key, "DESede" );
299
300            Mac mac = Mac.getInstance( "HmacSHA1" );
301            mac.init( sk );
302
303            return mac.doFinal( data );
304        }
305        catch ( GeneralSecurityException nsae )
306        {
307            nsae.printStackTrace();
308            return null;
309        }
310    }
311}