1 /*
2  * Copyright 2009-2017 Alibaba Cloud All rights reserved.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "CryptoStreamBuf.h"
18 #include <algorithm>
19 #include <cstring>
20 
21 using namespace AlibabaCloud::OSS;
22 
23 
CryptoStreamBuf(std::iostream & stream,const std::shared_ptr<SymmetricCipher> & cipher,const ByteBuffer & key,const ByteBuffer & iv,const int skipCnt)24 CryptoStreamBuf::CryptoStreamBuf(std::iostream& stream,
25     const std::shared_ptr<SymmetricCipher>& cipher,
26     const ByteBuffer& key, const ByteBuffer& iv,
27     const int skipCnt) :
28     StreamBufProxy(stream),
29     cipher_(cipher),
30     encBufferCnt_(0),
31     encBufferOff_(0),
32     decBufferCnt_(0),
33     decBufferOff_(0),
34     key_(key),
35     iv_(iv),
36     initEncrypt(false),
37     initDecrypt(false),
38     skipCnt_(skipCnt)
39 {
40     StartPosForIV_ = StreamBufProxy::seekoff(std::streamoff(0), std::ios_base::cur, std::ios_base::in);
41 }
42 
~CryptoStreamBuf()43 CryptoStreamBuf::~CryptoStreamBuf()
44 {
45     //flush decBuffer when its size is BLK_SIZE
46     if (decBufferCnt_ > 0) {
47         unsigned char block[BLK_SIZE];
48         auto ret = cipher_->Decrypt(block, static_cast<int>(decBufferCnt_), decBuffer_, static_cast<int>(decBufferCnt_));
49         decBufferCnt_ = 0;
50         if (ret < 0) {
51             return;
52         }
53         xsputn_with_skip(reinterpret_cast<char *>(block), ret);
54     }
55 }
56 
xsgetn(char * _Ptr,std::streamsize _Count)57 std::streamsize CryptoStreamBuf::xsgetn(char * _Ptr, std::streamsize _Count)
58 {
59     const std::streamsize startCount = _Count;
60     std::streamsize readCnt;
61     unsigned char block[BLK_SIZE];
62 
63     //update iv base the pos
64     if (!initEncrypt) {
65         auto currPos = StreamBufProxy::seekoff(std::streamoff(0), std::ios_base::cur, std::ios_base::in);
66         currPos -= StartPosForIV_;
67         auto blkOff = currPos / BLK_SIZE;
68         auto blkIdx = currPos % BLK_SIZE;
69         auto iv = SymmetricCipher::IncCTRCounter(iv_, blkOff);
70         cipher_->EncryptInit(key_, iv);
71         encBufferCnt_ = 0;
72         encBufferOff_ = 0;
73         if (blkIdx > 0) {
74             StreamBufProxy::seekpos(blkOff * BLK_SIZE, std::ios_base::in);
75             readCnt = StreamBufProxy::xsgetn(reinterpret_cast<char *>(block), BLK_SIZE);
76             auto ret = cipher_->Encrypt(encBuffer_, static_cast<int>(readCnt), block, static_cast<int>(readCnt));
77             if (ret < 0) {
78                 return -1;
79             }
80             encBufferCnt_ = ret - blkIdx;
81             encBufferOff_ = blkIdx;
82         }
83         initEncrypt = true;
84     }
85 
86     //read from inner encBuffer_ first
87     readCnt = read_from_encrypted_buffer(_Ptr, _Count);
88     if (readCnt > 0) {
89         _Count -= readCnt;
90         _Ptr += readCnt;
91     }
92 
93     //read from streambuf by BLK_SIZE
94     while (_Count > 0) {
95         readCnt = StreamBufProxy::xsgetn(reinterpret_cast<char *>(block), BLK_SIZE);
96         if (readCnt <= 0)
97             break;
98 
99         if (_Count < readCnt) {
100             auto ret = cipher_->Encrypt(encBuffer_, static_cast<int>(readCnt), block, static_cast<int>(readCnt));
101             if (ret < 0) {
102                 return -1;
103             }
104             encBufferCnt_ = ret;
105             encBufferOff_ = 0;
106             break;
107         }
108         else {
109             auto ret = cipher_->Encrypt(reinterpret_cast<unsigned char *>(_Ptr), static_cast<int>(readCnt), block, static_cast<int>(readCnt));
110             if (ret < 0) {
111                 return -1;
112             }
113             _Count -= ret;
114             _Ptr += ret;
115         }
116     }
117 
118     //read from inner encBuffer_ again
119     readCnt = read_from_encrypted_buffer(_Ptr, _Count);
120     if (readCnt > 0) {
121         _Count -= readCnt;
122         _Ptr += readCnt;
123     }
124 
125     return startCount - _Count;
126 }
127 
read_from_encrypted_buffer(char * _Ptr,std::streamsize _Count)128 std::streamsize CryptoStreamBuf::read_from_encrypted_buffer(char * _Ptr, std::streamsize _Count)
129 {
130     const std::streamsize startCount = _Count;
131     if (_Count > 0 && encBufferCnt_ > 0) {
132         auto cnt = std::min(_Count, encBufferCnt_);
133         memcpy(_Ptr, encBuffer_ + encBufferOff_, static_cast<size_t>(cnt));
134         _Ptr += cnt;
135         _Count -= cnt;
136         encBufferCnt_ -= cnt;
137         encBufferOff_ += cnt;
138     }
139     return startCount - _Count;
140 }
141 
seekoff(off_type _Off,std::ios_base::seekdir _Way,std::ios_base::openmode _Mode)142 std::streampos CryptoStreamBuf::seekoff(off_type _Off, std::ios_base::seekdir _Way, std::ios_base::openmode _Mode)
143 {
144     if (_Mode & std::ios_base::in) {
145         initEncrypt = false;
146     }
147     if (_Mode & std::ios_base::out) {
148         initDecrypt = false;
149     }
150     return StreamBufProxy::seekoff(_Off, _Way, _Mode);
151 }
152 
seekpos(pos_type _Pos,std::ios_base::openmode _Mode)153 std::streampos CryptoStreamBuf::seekpos(pos_type _Pos, std::ios_base::openmode _Mode)
154 {
155     if (_Mode & std::ios_base::in) {
156         initEncrypt = false;
157     }
158     if (_Mode & std::ios_base::out) {
159         initDecrypt = false;
160     }
161     return StreamBufProxy::seekpos(_Pos, _Mode);
162 }
163 
xsputn(const char * _Ptr,std::streamsize _Count)164 std::streamsize CryptoStreamBuf::xsputn(const char *_Ptr, std::streamsize _Count)
165 {
166     const std::streamsize startCount = _Count;
167     unsigned char block[BLK_SIZE * 2];
168     std::streamsize writeCnt;
169     //update iv
170     if (!initDecrypt) {
171         cipher_->DecryptInit(key_, iv_);
172         decBufferCnt_ = 0;
173         decBufferOff_ = 0;
174         initDecrypt = true;
175     }
176 
177     //append to decBuffer first
178     if (decBufferCnt_ > 0) {
179         writeCnt = std::min(_Count, (BLK_SIZE - decBufferCnt_));
180         memcpy(decBuffer_ + decBufferOff_, _Ptr, static_cast<int>(writeCnt));
181         decBufferOff_ += writeCnt;
182         decBufferCnt_ += writeCnt;
183         _Ptr += writeCnt;
184         _Count -= writeCnt;
185     }
186 
187     //flush decBuffer when its size is BLK_SIZE
188     if (decBufferCnt_ == BLK_SIZE) {
189         auto ret = cipher_->Decrypt(block, static_cast<int>(BLK_SIZE), decBuffer_, static_cast<int>(BLK_SIZE));
190         if (ret < 0) {
191             return -1;
192         }
193         decBufferCnt_ = 0;
194         decBufferOff_ = 0;
195         writeCnt = xsputn_with_skip(reinterpret_cast<char *>(block), BLK_SIZE);
196         if (writeCnt != BLK_SIZE) {
197             //Todo Save decrypted data
198             return startCount - _Count;
199         }
200     }
201 
202     auto blkOff = _Count / BLK_SIZE;
203     auto blkIdx = _Count % BLK_SIZE;
204 
205     //decrypt by BLK_SIZE
206     for (auto i = std::streamsize(0); i < blkOff; i++) {
207         auto ret = cipher_->Decrypt(block, static_cast<int>(BLK_SIZE), reinterpret_cast<const unsigned char *>(_Ptr), static_cast<int>(BLK_SIZE));
208         if (ret < 0) {
209             return -1;
210         }
211         _Ptr += BLK_SIZE;
212         _Count -= BLK_SIZE;
213         writeCnt = xsputn_with_skip(reinterpret_cast<char *>(block), BLK_SIZE);
214         if (writeCnt != BLK_SIZE) {
215             //Todo Save decrypted data
216             return startCount - _Count;
217         }
218     }
219 
220     //save to decBuffer and decrypt next time
221     if (blkIdx > 0) {
222         memcpy(decBuffer_, _Ptr, static_cast<int>(blkIdx));
223         _Ptr += blkIdx;
224         _Count -= blkIdx;
225         decBufferCnt_ = blkIdx;
226         decBufferOff_ = blkIdx;
227     }
228 
229     return startCount - _Count;
230 }
231 
xsputn_with_skip(const char * _Ptr,std::streamsize _Count)232 std::streamsize CryptoStreamBuf::xsputn_with_skip(const char *_Ptr, std::streamsize _Count)
233 {
234     const std::streamsize startCount = _Count;
235     if (skipCnt_ > 0) {
236         auto min = std::min(skipCnt_, _Count);
237         skipCnt_ -= min;
238         _Count -= min;
239         _Ptr += min;
240     }
241 
242     if (_Count > 0) {
243         _Count -= StreamBufProxy::xsputn(_Ptr, _Count);
244     }
245     return startCount - _Count;
246 }