@@ -68,6 +68,7 @@ contract SemaphoreMSAValidator is ERC7579ValidatorBase {
6868 ISemaphoreGroups public groups;
6969 mapping (address account = > uint256 groupId ) public groupMapping;
7070 mapping (address account = > uint8 threshold ) public thresholds;
71+ mapping (address account = > uint8 count ) public memberCount;
7172
7273 // smart account -> hash(call(params)) -> valid proof count
7374 mapping (address account = > mapping (bytes32 txHash = > ExtCallCount callDataCount )) public
@@ -133,6 +134,7 @@ contract SemaphoreMSAValidator is ERC7579ValidatorBase {
133134
134135 // Add members to the group
135136 semaphore.addMembers (groupId, cmts);
137+ memberCount[account] = uint8 (cmts.length );
136138
137139 emit ModuleInitialized (account);
138140 }
@@ -143,6 +145,7 @@ contract SemaphoreMSAValidator is ERC7579ValidatorBase {
143145 delete thresholds[account];
144146 delete groupMapping[account];
145147 delete acctSeqNum[account];
148+ delete memberCount[account];
146149
147150 //TODO: what is a good way to delete entries associated with `acctTxCount[account]`,
148151 // The following line will make the compiler fail.
@@ -151,15 +154,9 @@ contract SemaphoreMSAValidator is ERC7579ValidatorBase {
151154 emit ModuleUninitialized (account);
152155 }
153156
154- function memberCount (address account ) public view returns (uint8 cnt ) {
155- // account doesn't belong to a semaphore group. We return 0
156- if (thresholds[account] == 0 ) return 0 ;
157- cnt = uint8 (groups.getMerkleTreeSize (groupMapping[account]));
158- }
159-
160157 function setThreshold (uint8 newThreshold ) external moduleInstalled {
161158 address account = msg .sender ;
162- if (newThreshold == 0 || newThreshold > memberCount ( account) ) {
159+ if (newThreshold == 0 || newThreshold > memberCount[ account] ) {
163160 revert InvalidThreshold (account);
164161 }
165162
@@ -171,14 +168,16 @@ contract SemaphoreMSAValidator is ERC7579ValidatorBase {
171168 address account = msg .sender ;
172169 uint256 groupId = groupMapping[account];
173170
174- if (memberCount ( account) + cmts.length > MAX_MEMBERS) revert MaxMemberReached (account);
171+ if (memberCount[ account] + cmts.length > MAX_MEMBERS) revert MaxMemberReached (account);
175172
176173 for (uint256 i = 0 ; i < cmts.length ; ++ i) {
177174 if (cmts[i] == uint256 (0 )) revert InvalidCommitment (account);
178175 if (groups.hasMember (groupId, cmts[i])) revert IsMemberAlready (account, cmts[i]);
179176 }
180177
181178 semaphore.addMembers (groupId, cmts);
179+ memberCount[account] += uint8 (cmts.length );
180+
182181 emit AddedMembers (account, cmts.length );
183182 }
184183
@@ -191,12 +190,13 @@ contract SemaphoreMSAValidator is ERC7579ValidatorBase {
191190 {
192191 address account = msg .sender ;
193192
194- if (memberCount ( account) == thresholds[account]) revert MemberCntReachesThreshold (account);
193+ if (memberCount[ account] == thresholds[account]) revert MemberCntReachesThreshold (account);
195194
196195 uint256 groupId = groupMapping[account];
197196 if (! groups.hasMember (groupId, cmt)) revert MemberNotExists (account, cmt);
198197
199198 semaphore.removeMember (groupId, cmt, merkleProofSiblings);
199+ memberCount[account] -= 1 ;
200200
201201 emit RemovedMember (account, cmt);
202202 }
@@ -347,7 +347,7 @@ contract SemaphoreMSAValidator is ERC7579ValidatorBase {
347347 uint256 cmt = Identity.getCommitment (pubKey);
348348 if (! groups.hasMember (groupId, cmt)) revert MemberNotExists (account, cmt);
349349
350- // We don't allow call to other contract .
350+ // We don't allow call to other contracts .
351351 address targetAddr = address (bytes20 (userOp.callData[100 :120 ]));
352352 if (targetAddr != address (this )) revert NonValidatorCallBanned (targetAddr, address (this ));
353353
@@ -356,10 +356,9 @@ contract SemaphoreMSAValidator is ERC7579ValidatorBase {
356356 bytes memory valAndCallData = userOp.callData[120 :];
357357 bytes4 funcSel = bytes4 (LibBytes.slice (valAndCallData, 32 , 36 ));
358358
359- // Allow only these few types on function calls to pass, and reject all other on-chain
360- // calls. They must be executed via `executeTx()` function .
359+ // We only allow calls to `initiateTx()`, `signTx()`, and `executeTx()` to pass,
360+ // and reject the rest .
361361 if (_isAllowedSelector (funcSel)) return VALIDATION_SUCCESS;
362-
363362 revert NonAllowedSelector (account, funcSel);
364363 }
365364
0 commit comments