Skip to content

Commit

Permalink
Merge branch 'develop' into adjust-test
Browse files Browse the repository at this point in the history
  • Loading branch information
ye-luo authored Oct 18, 2024
2 parents fffced9 + b250174 commit b495b50
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 29 deletions.
26 changes: 15 additions & 11 deletions nexus/bin/qdens
Original file line number Diff line number Diff line change
Expand Up @@ -309,8 +309,8 @@ class SingleDensity(QDBase):
self.check_type('grid' ,grid ,'tuple/list/ndarray',(tuple,list,np.ndarray))
self.check_type('structure',structure,'Structure' ,Structure)
self.check_type('extension',extension,'string' ,str)
self.check_type('density_cell' ,grid ,'tuple/list/ndarray',(tuple,list,np.ndarray))
self.check_type('density_corner',grid ,'tuple/list/ndarray',(tuple,list,np.ndarray))
self.check_type('density_cell' ,grid,'tuple/list/ndarray',(tuple,list,np.ndarray))
self.check_type('density_corner',grid,'tuple/list/ndarray',(tuple,list,np.ndarray))
# process other inputs
if mean is not None:
self.mean = np.array(mean,dtype=float)
Expand All @@ -335,19 +335,12 @@ class SingleDensity(QDBase):
self._error('inputted density_corner must have length 3, received {0}'.format(density_corner))
#end if
self.density_corner = np.array(density_corner,dtype=float)
else:
self.density_corner = np.zeros(3) # If density_corner is not given, use [0, 0, 0] as corner
#end if
if density_cell is not None:
if len(np.array(density_cell).ravel())!=9:
self._error('inputted density_cell must have length 9, received {0}'.format(density_cell))
#end if
self.density_cell = np.array(density_cell,dtype=float)
else:
# If density_cell is not given, use the structure cell
s = self.structure.copy()
s.change_units('A')
self.density_cell = s.axes
#end if
if data is not None:
self.analyze(data)
Expand Down Expand Up @@ -1128,12 +1121,21 @@ class QMCDensityProcessor(QDBase):
self.error('could not identify grid data for spin density named "{0}"\bin QMCPACK input file: {1}'.format(name,opt.input))
#end if
if 'cell' in sd:
self.error('Currently, the cell keyword is not supported due to a bug: See https://github.com/QMCPACK/qmcpack/issues/5201')
density_cell = sd.cell.copy()
density_cell = convert(density_cell, 'B', 'A')
else:
# If density_cell is not given, use the structure cell
dens_struct = structure.copy()
dens_struct.change_units('A')
density_cell = dens_struct.axes
#end if
if 'corner' in sd:
self.error('Currently, the corner keyword is not supported due to a bug: See https://github.com/QMCPACK/qmcpack/issues/5201')
density_corner = sd.corner.copy()
density_corner = convert(density_corner, 'B', 'A')
else:
density_corner = np.zeros(3) # If density_corner is not given, use [0, 0, 0] as corner
#end if
elif name not in grids and isinstance(xml,density_xml):
sd = xml
Expand Down Expand Up @@ -1203,6 +1205,7 @@ class QMCDensityProcessor(QDBase):

# --density_cell option
if opt.density_cell is not None:
self.error('Currently, the cell keyword is not supported due to a bug: See https://github.com/QMCPACK/qmcpack/issues/5201')
density_cell = input_list(opt.density_cell)
try:
density_cell = np.array(density_cell,dtype=float)
Expand All @@ -1220,11 +1223,12 @@ class QMCDensityProcessor(QDBase):

# --density_corner option
if opt.density_corner is not None:
self.error('Currently, the corner keyword is not supported due to a bug: See https://github.com/QMCPACK/qmcpack/issues/5201')
density_corner = input_list(opt.density_corner)
try:
density_corner = np.array(density_corner,dtype=int)
density_corner = np.array(density_corner,dtype=float)
except:
self.error('--density_corner input misformatted\nexpected a list of integers\nreceived: {0}'.format(density_corner))
self.error('--density_corner input misformatted\nexpected a list of floats\nreceived: {0}'.format(density_corner))
#end try
if len(density_corner)!=3:
self.error('--density_corner input misformatted\nexpected 3 elements\nreceived {0} elements with values: {1}'.format(len(density_corner),density_corner))
Expand Down
26 changes: 16 additions & 10 deletions nexus/lib/qmcpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -931,16 +931,22 @@ def incorporate_result(self,result_name,result,sim):
J3 = optwf.get('J3')
if J3 is not None:
corr = J3.get('correlation')
j3_ids = []
for j3_term in corr:
j3_id = j3_term.coefficients.id
j3_ids.append(j3_id)
#end for
for j3_id in j3_ids:
if 'ud' in j3_id:
delattr(corr, j3_id)
#end if
#end for
if hasattr(corr, 'coefficients'):
# For single-species systems, the data structure changes.
# In this case, the only J3 term should be 'uu'.
# Otherwise, the user might be trying to do something strange.
assert 'uu' in corr.coefficients.id, 'Only uu J3 terms are allowed in SOC calculations.'
else:
j3_ids = []
for j3_term in corr:
j3_id = j3_term.coefficients.id
j3_ids.append(j3_id)
#end for
for j3_id in j3_ids:
if 'ud' in j3_id:
delattr(corr, j3_id)
#end if
#end for
#end if
#end if
def process_jastrow(wf):
Expand Down
16 changes: 8 additions & 8 deletions src/QMCWaveFunctions/BsplineFactory/SplineC2ROMPTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ void SplineC2ROMPTarget<ST>::mw_evaluateDetRatios(const RefVectorWithLeader<SPOS
for (const VirtualParticleSet& VP : vp_list)
mw_nVP += VP.getTotalNum();

const size_t packed_size = nw * sizeof(ValueType*) + mw_nVP * (6 * sizeof(TT) + sizeof(int));
const size_t packed_size = nw * sizeof(ValueType*) + mw_nVP * (6 * sizeof(ST) + sizeof(int));
det_ratios_buffer_H2D.resize(packed_size);

// pack invRow_ptr_list to det_ratios_buffer_H2D
Expand All @@ -297,9 +297,9 @@ void SplineC2ROMPTarget<ST>::mw_evaluateDetRatios(const RefVectorWithLeader<SPOS
ptr_buffer[iw] = invRow_ptr_list[iw];

// pack particle positions
auto* pos_ptr = reinterpret_cast<TT*>(det_ratios_buffer_H2D.data() + nw * sizeof(ValueType*));
auto* pos_ptr = reinterpret_cast<ST*>(det_ratios_buffer_H2D.data() + nw * sizeof(ValueType*));
auto* ref_id_ptr =
reinterpret_cast<int*>(det_ratios_buffer_H2D.data() + nw * sizeof(ValueType*) + mw_nVP * 6 * sizeof(TT));
reinterpret_cast<int*>(det_ratios_buffer_H2D.data() + nw * sizeof(ValueType*) + mw_nVP * 6 * sizeof(ST));
size_t iVP = 0;
for (size_t iw = 0; iw < nw; iw++)
{
Expand Down Expand Up @@ -353,14 +353,14 @@ void SplineC2ROMPTarget<ST>::mw_evaluateDetRatios(const RefVectorWithLeader<SPOS

auto* restrict offload_scratch_iat_ptr = offload_scratch_ptr + spline_padded_size * iat;
auto* restrict psi_iat_ptr = results_scratch_ptr + sposet_padded_size * iat;
auto* ref_id_ptr = reinterpret_cast<int*>(buffer_H2D_ptr + nw * sizeof(ValueType*) + mw_nVP * 6 * sizeof(TT));
auto* ref_id_ptr = reinterpret_cast<int*>(buffer_H2D_ptr + nw * sizeof(ValueType*) + mw_nVP * 6 * sizeof(ST));
auto* restrict psiinv_ptr = reinterpret_cast<const ValueType**>(buffer_H2D_ptr)[ref_id_ptr[iat]];
auto* restrict pos_scratch = reinterpret_cast<TT*>(buffer_H2D_ptr + nw * sizeof(ValueType*));
auto* restrict pos_scratch = reinterpret_cast<ST*>(buffer_H2D_ptr + nw * sizeof(ValueType*));

int ix, iy, iz;
ST a[4], b[4], c[4];
spline2::computeLocationAndFractional(spline_ptr, ST(pos_scratch[iat * 6 + 3]), ST(pos_scratch[iat * 6 + 4]),
ST(pos_scratch[iat * 6 + 5]), ix, iy, iz, a, b, c);
spline2::computeLocationAndFractional(spline_ptr, pos_scratch[iat * 6 + 3], pos_scratch[iat * 6 + 4],
pos_scratch[iat * 6 + 5], ix, iy, iz, a, b, c);

PRAGMA_OFFLOAD("omp parallel for")
for (int index = 0; index < last - first; index++)
Expand All @@ -370,7 +370,7 @@ void SplineC2ROMPTarget<ST>::mw_evaluateDetRatios(const RefVectorWithLeader<SPOS
const size_t last_cplx = omptarget::min(last / 2, num_complex_splines);
PRAGMA_OFFLOAD("omp parallel for")
for (int index = first_cplx; index < last_cplx; index++)
C2R::assign_v(ST(pos_scratch[iat * 6]), ST(pos_scratch[iat * 6 + 1]), ST(pos_scratch[iat * 6 + 2]),
C2R::assign_v(pos_scratch[iat * 6], pos_scratch[iat * 6 + 1], pos_scratch[iat * 6 + 2],
psi_iat_ptr, offload_scratch_iat_ptr, myKcart_ptr, myKcart_padded_size, first_spo_local,
nComplexBands_local, index);

Expand Down

0 comments on commit b495b50

Please sign in to comment.